#![warn(missing_docs)]
pub mod allowed_hosts;
pub mod auth;
pub mod broken_link;
#[cfg(feature = "compression")]
pub mod brotli;
pub mod cache;
pub mod circuit_breaker;
pub mod common;
pub mod conditional;
#[cfg(feature = "sessions")]
pub mod cookie_session_auth;
#[cfg(feature = "cors")]
pub mod cors;
pub mod csp;
pub mod csp_helpers;
pub mod csrf;
pub mod etag;
pub mod flatpages;
#[cfg(feature = "compression")]
pub mod gzip;
pub mod honeypot;
pub mod https_redirect;
#[cfg(feature = "auth-jwt")]
pub mod jwt_auth;
pub mod locale;
pub mod logging;
pub mod login_required;
pub mod messages;
pub mod metrics;
pub mod origin_guard;
#[cfg(feature = "rate-limit")]
pub mod rate_limit;
pub mod redirect_fallback;
#[cfg(feature = "session-redis")]
pub mod redis_session;
pub mod remote_user;
pub mod request_id;
#[cfg(feature = "security")]
pub mod security_middleware;
pub mod session;
pub mod site;
pub mod timeout;
pub mod tracing;
pub mod xframe;
pub mod xss;
pub use reinhardt_http::{Handler, Middleware, MiddlewareChain};
pub use allowed_hosts::{AllowedHostsConfig, AllowedHostsMiddleware};
#[cfg(feature = "sessions")]
pub use auth::AuthenticationMiddleware;
pub use broken_link::{BrokenLinkConfig, BrokenLinkEmailsMiddleware};
#[cfg(feature = "compression")]
pub use brotli::{BrotliConfig, BrotliMiddleware, BrotliQuality};
pub use cache::{CacheConfig, CacheKeyStrategy, CacheMiddleware, CacheStore};
pub use circuit_breaker::{CircuitBreakerConfig, CircuitBreakerMiddleware, CircuitState};
pub use common::{CommonConfig, CommonMiddleware};
pub use conditional::ConditionalGetMiddleware;
#[cfg(feature = "sessions")]
pub use cookie_session_auth::{CookieSessionAuthMiddleware, CookieSessionConfig};
#[cfg(feature = "cors")]
pub use cors::CorsMiddleware;
pub use csp::{CspConfig, CspMiddleware, CspNonce};
pub use csp_helpers::{csp_nonce_attr, get_csp_nonce};
pub use csrf::{
CSRF_ALLOWED_CHARS, CSRF_SECRET_LENGTH, CSRF_SESSION_KEY, CSRF_TOKEN_LENGTH, CsrfConfig,
CsrfMeta, CsrfMiddleware, CsrfMiddlewareConfig, CsrfToken, InvalidTokenFormat,
REASON_BAD_ORIGIN, REASON_BAD_REFERER, REASON_CSRF_TOKEN_MISSING, REASON_INCORRECT_LENGTH,
REASON_INSECURE_REFERER, REASON_INVALID_CHARACTERS, REASON_MALFORMED_REFERER,
REASON_NO_CSRF_COOKIE, REASON_NO_REFERER, RejectRequest, SameSite, check_origin, check_referer,
check_token, get_secret, get_token, is_same_domain,
};
pub use etag::{ETagConfig, ETagMiddleware};
pub use flatpages::{Flatpage, FlatpageStore, FlatpagesConfig, FlatpagesMiddleware};
#[cfg(feature = "compression")]
pub use gzip::{GZipConfig, GZipMiddleware};
pub use honeypot::{HoneypotError, HoneypotField};
pub use https_redirect::{HttpsRedirectConfig, HttpsRedirectMiddleware};
#[cfg(feature = "auth-jwt")]
pub use jwt_auth::JwtAuthMiddleware;
pub use locale::{LocaleConfig, LocaleMiddleware};
pub use logging::{LoggingConfig, LoggingMiddleware};
pub use login_required::{
DEFAULT_LOGIN_URL, DEFAULT_REDIRECT_FIELD_NAME, LoginRequiredConfig, LoginRequiredMiddleware,
};
pub use messages::{CookieStorage, Message, MessageLevel, MessageStorage, SessionStorage};
pub use metrics::{MetricsConfig, MetricsMiddleware, MetricsStore};
pub use origin_guard::OriginGuardMiddleware;
#[cfg(feature = "rate-limit")]
pub use rate_limit::{RateLimitConfig, RateLimitMiddleware, RateLimitStore, RateLimitStrategy};
pub use redirect_fallback::{RedirectFallbackMiddleware, RedirectResponseConfig};
#[cfg(feature = "session-redis")]
pub use redis_session::RedisSessionBackend;
#[cfg(feature = "sessions")]
pub use remote_user::{PersistentRemoteUserMiddleware, REMOTE_USER_HEADER, RemoteUserMiddleware};
pub use request_id::{REQUEST_ID_HEADER, RequestIdConfig, RequestIdMiddleware};
#[cfg(feature = "security")]
#[allow(deprecated)] pub use security_middleware::{SecurityConfig, SecurityMiddleware};
pub use session::{SessionConfig, SessionData, SessionMiddleware, SessionStore};
pub use site::{SITE_ID_HEADER, Site, SiteConfig, SiteMiddleware, SiteRegistry};
pub use timeout::{TimeoutConfig, TimeoutMiddleware};
pub use tracing::{
PARENT_SPAN_ID_HEADER, SPAN_ID_HEADER, Span, SpanStatus, TRACE_ID_HEADER, TraceStore,
TracingConfig, TracingMiddleware,
};
pub use xframe::{XFrameOptions, XFrameOptionsMiddleware};
pub use xss::{XssConfig, XssError, XssProtector};
/// Smoke tests for middleware re-exported from this crate.
///
/// The module is gated only on `test`. CORS-specific tests and helpers carry
/// their own `#[cfg(feature = "cors")]`, so feature-independent tests (such as
/// the logging test) still run when the `cors` feature is disabled.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use reinhardt_http::{Handler, Middleware, Request, Response};
    use std::sync::Arc;

    #[cfg(feature = "cors")]
    use super::cors::CorsConfig;

    /// Terminal handler that always answers with `Response::ok()` and a fixed body.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
            Ok(Response::ok().with_body("test response".as_bytes()))
        }
    }

    /// Builds an HTTP/1.1 request for `/test` with an empty body and the given
    /// method and headers.
    fn test_request(method: Method, headers: HeaderMap) -> Request {
        Request::builder()
            .method(method)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a header map that contains only an `origin` header with the given value.
    #[cfg(feature = "cors")]
    fn origin_headers(origin: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("origin", origin.parse().unwrap());
        headers
    }

    /// CORS configuration shared by the simple and preflight tests.
    ///
    /// Allows `http://example.com` with GET/POST and `Content-Type`,
    /// no credentials, and a one-hour max age.
    #[cfg(feature = "cors")]
    fn example_cors_config() -> CorsConfig {
        CorsConfig {
            allow_origins: vec!["http://example.com".to_string()],
            allow_methods: vec!["GET".to_string(), "POST".to_string()],
            allow_headers: vec!["Content-Type".to_string()],
            allow_credentials: false,
            max_age: Some(3600),
        }
    }

    #[cfg(feature = "cors")]
    #[tokio::test]
    async fn test_cors_middleware_simple_request() {
        let middleware = CorsMiddleware::new(example_cors_config());
        let handler = Arc::new(TestHandler);
        let request = test_request(Method::GET, origin_headers("http://example.com"));

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin").unwrap(),
            "http://example.com"
        );
    }

    #[cfg(feature = "cors")]
    #[tokio::test]
    async fn test_cors_middleware_preflight_request() {
        let middleware = CorsMiddleware::new(example_cors_config());
        let handler = Arc::new(TestHandler);
        let request = test_request(Method::OPTIONS, origin_headers("http://example.com"));

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::NO_CONTENT);
        assert!(response.headers.contains_key("Access-Control-Allow-Origin"));
        assert!(
            response
                .headers
                .contains_key("Access-Control-Allow-Methods")
        );
        assert!(
            response
                .headers
                .contains_key("Access-Control-Allow-Headers")
        );
    }

    #[cfg(feature = "cors")]
    #[tokio::test]
    async fn test_cors_middleware_permissive() {
        let middleware = CorsMiddleware::permissive();
        let handler = Arc::new(TestHandler);
        let request = test_request(Method::GET, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert!(response.headers.contains_key("Access-Control-Allow-Origin"));
    }

    // Not CORS-related: previously gated behind the `cors` feature by mistake.
    #[tokio::test]
    async fn test_logging_middleware() {
        let middleware = LoggingMiddleware::new();
        let handler = Arc::new(TestHandler);
        let request = test_request(Method::GET, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
    }
}