1#![warn(missing_docs)]
129pub mod allowed_hosts;
130pub mod auth;
132pub mod broken_link;
133#[cfg(feature = "compression")]
134pub mod brotli;
135pub mod cache;
136pub mod circuit_breaker;
137pub mod common;
138pub mod conditional;
139#[cfg(feature = "cors")]
140pub mod cors;
142pub mod csp;
143pub mod csp_helpers;
144pub mod csrf;
145pub mod etag;
146pub mod flatpages;
147#[cfg(feature = "compression")]
148pub mod gzip;
149pub mod honeypot;
150pub mod https_redirect;
151pub mod locale;
152pub mod logging;
154pub mod messages;
155pub mod metrics;
156#[cfg(feature = "rate-limit")]
157pub mod rate_limit;
158pub mod redirect_fallback;
159pub mod request_id;
160#[cfg(feature = "security")]
161pub mod security_middleware;
162pub mod session;
163pub mod site;
164pub mod timeout;
165pub mod tracing;
166pub mod xframe;
167pub mod xss;
168
169pub use reinhardt_http::{Handler, Middleware, MiddlewareChain};
171
172pub use allowed_hosts::{AllowedHostsConfig, AllowedHostsMiddleware};
173#[cfg(feature = "sessions")]
174pub use auth::AuthenticationMiddleware;
175pub use broken_link::{BrokenLinkConfig, BrokenLinkEmailsMiddleware};
176#[cfg(feature = "compression")]
177pub use brotli::{BrotliConfig, BrotliMiddleware, BrotliQuality};
178pub use cache::{CacheConfig, CacheKeyStrategy, CacheMiddleware, CacheStore};
179pub use circuit_breaker::{CircuitBreakerConfig, CircuitBreakerMiddleware, CircuitState};
180pub use common::{CommonConfig, CommonMiddleware};
181pub use conditional::ConditionalGetMiddleware;
182#[cfg(feature = "cors")]
183pub use cors::CorsMiddleware;
184pub use csp::{CspConfig, CspMiddleware, CspNonce};
185pub use csp_helpers::{csp_nonce_attr, get_csp_nonce};
186pub use csrf::{
187 CSRF_ALLOWED_CHARS, CSRF_SECRET_LENGTH, CSRF_SESSION_KEY, CSRF_TOKEN_LENGTH, CsrfConfig,
188 CsrfMeta, CsrfMiddleware, CsrfMiddlewareConfig, CsrfToken, InvalidTokenFormat,
189 REASON_BAD_ORIGIN, REASON_BAD_REFERER, REASON_CSRF_TOKEN_MISSING, REASON_INCORRECT_LENGTH,
190 REASON_INSECURE_REFERER, REASON_INVALID_CHARACTERS, REASON_MALFORMED_REFERER,
191 REASON_NO_CSRF_COOKIE, REASON_NO_REFERER, RejectRequest, SameSite, check_origin, check_referer,
192 check_token, get_secret, get_token, is_same_domain,
193};
194pub use etag::{ETagConfig, ETagMiddleware};
195pub use flatpages::{Flatpage, FlatpageStore, FlatpagesConfig, FlatpagesMiddleware};
196#[cfg(feature = "compression")]
197pub use gzip::{GZipConfig, GZipMiddleware};
198pub use honeypot::{HoneypotError, HoneypotField};
199pub use https_redirect::{HttpsRedirectConfig, HttpsRedirectMiddleware};
200pub use locale::{LocaleConfig, LocaleMiddleware};
201pub use logging::{LoggingConfig, LoggingMiddleware};
202pub use messages::{CookieStorage, Message, MessageLevel, MessageStorage, SessionStorage};
203pub use metrics::{MetricsConfig, MetricsMiddleware, MetricsStore};
204#[cfg(feature = "rate-limit")]
205pub use rate_limit::{RateLimitConfig, RateLimitMiddleware, RateLimitStore, RateLimitStrategy};
206pub use redirect_fallback::{RedirectFallbackMiddleware, RedirectResponseConfig};
207pub use request_id::{REQUEST_ID_HEADER, RequestIdConfig, RequestIdMiddleware};
208#[cfg(feature = "security")]
209pub use security_middleware::{SecurityConfig, SecurityMiddleware};
210pub use session::{SessionConfig, SessionData, SessionMiddleware, SessionStore};
211pub use site::{SITE_ID_HEADER, Site, SiteConfig, SiteMiddleware, SiteRegistry};
212pub use timeout::{TimeoutConfig, TimeoutMiddleware};
213pub use tracing::{
214 PARENT_SPAN_ID_HEADER, SPAN_ID_HEADER, Span, SpanStatus, TRACE_ID_HEADER, TraceStore,
215 TracingConfig, TracingMiddleware,
216};
217pub use xframe::{XFrameOptions, XFrameOptionsMiddleware};
218pub use xss::{XssConfig, XssError, XssProtector};
219
/// Unit tests for the middleware re-exported from this crate.
///
/// The module is compiled for every test build. Tests that need CORS types are
/// gated individually on the `cors` feature, so feature-independent tests
/// (e.g. the logging test) still run when `cors` is disabled.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use reinhardt_http::{Handler, Middleware, Request, Response};
    use std::sync::Arc;

    /// Terminal handler: always returns `Response::ok()` with the body `"test response"`.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
            Ok(Response::ok().with_body("test response".as_bytes()))
        }
    }

    /// Builds an HTTP/1.1 request for `/test` with the given method and headers
    /// and an empty body.
    fn build_request(method: Method, headers: HeaderMap) -> Request {
        Request::builder()
            .method(method)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .expect("test request should be well-formed")
    }

    /// Returns a header map containing only `origin: http://example.com`.
    #[cfg(feature = "cors")]
    fn example_origin_headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("origin", "http://example.com".parse().unwrap());
        headers
    }

    /// CORS configuration shared by the CORS tests: allows `http://example.com`,
    /// GET/POST, the `Content-Type` header, no credentials, and a max-age of 3600.
    #[cfg(feature = "cors")]
    fn example_cors_config() -> super::cors::CorsConfig {
        super::cors::CorsConfig {
            allow_origins: vec!["http://example.com".to_string()],
            allow_methods: vec!["GET".to_string(), "POST".to_string()],
            allow_headers: vec!["Content-Type".to_string()],
            allow_credentials: false,
            max_age: Some(3600),
        }
    }

    #[cfg(feature = "cors")]
    #[tokio::test]
    async fn test_cors_middleware_simple_request() {
        let middleware = CorsMiddleware::new(example_cors_config());
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, example_origin_headers());

        let response = middleware.process(request, handler).await.unwrap();

        // A simple (non-preflight) request must still reach the handler.
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin").unwrap(),
            "http://example.com"
        );
    }

    #[cfg(feature = "cors")]
    #[tokio::test]
    async fn test_cors_middleware_preflight_request() {
        let middleware = CorsMiddleware::new(example_cors_config());
        let handler = Arc::new(TestHandler);

        // A CORS preflight is an OPTIONS request that also carries
        // `Access-Control-Request-Method`; send it so the request is a real preflight.
        let mut headers = example_origin_headers();
        headers.insert("access-control-request-method", "POST".parse().unwrap());
        let request = build_request(Method::OPTIONS, headers);

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::NO_CONTENT);
        assert!(response.headers.contains_key("Access-Control-Allow-Origin"));
        assert!(response.headers.contains_key("Access-Control-Allow-Methods"));
        assert!(response.headers.contains_key("Access-Control-Allow-Headers"));
    }

    #[cfg(feature = "cors")]
    #[tokio::test]
    async fn test_cors_middleware_permissive() {
        let middleware = CorsMiddleware::permissive();
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert!(response.headers.contains_key("Access-Control-Allow-Origin"));
    }

    // Not feature-gated: logging middleware is always compiled.
    #[tokio::test]
    async fn test_logging_middleware() {
        let middleware = LoggingMiddleware::new();
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
    }
}