Skip to main content

nidus_http/middleware/
security.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use axum::{body::Body, extract::Request};
9use http::{HeaderValue, Response, StatusCode, header};
10use tower::{Layer, Service};
11use tower_http::limit::RequestBodyLimitLayer;
12
13/// Creates a layer that applies conservative API security headers.
14///
15/// Responses receive:
16/// - `x-content-type-options: nosniff`
17/// - `x-frame-options: DENY`
18/// - `referrer-policy: no-referrer`
19///
20/// Existing values for those headers are replaced.
21pub fn security_headers_layer() -> SecurityHeadersLayer {
22    SecurityHeadersLayer
23}
24
25/// Tower layer that adds conservative API security headers to responses.
26///
27/// This layer only mutates response headers after the inner service returns. It
28/// does not perform authentication, CORS, CSRF, or content-security-policy
29/// enforcement.
30#[derive(Clone, Copy, Debug, Default)]
31pub struct SecurityHeadersLayer;
32
33impl<S> Layer<S> for SecurityHeadersLayer {
34    type Service = SecurityHeadersService<S>;
35
36    fn layer(&self, inner: S) -> Self::Service {
37        SecurityHeadersService { inner }
38    }
39}
40
41/// Service produced by [`SecurityHeadersLayer`].
42#[derive(Clone, Debug)]
43pub struct SecurityHeadersService<S> {
44    inner: S,
45}
46
47impl<S> Service<Request> for SecurityHeadersService<S>
48where
49    S: Service<Request, Response = Response<Body>> + Send + 'static,
50    S::Future: Send + 'static,
51    S::Error: Send + 'static,
52{
53    type Response = Response<Body>;
54    type Error = S::Error;
55    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
56
57    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58        self.inner.poll_ready(cx)
59    }
60
61    fn call(&mut self, request: Request) -> Self::Future {
62        let future = self.inner.call(request);
63        Box::pin(async move {
64            let mut response = future.await?;
65            response.headers_mut().insert(
66                "x-content-type-options",
67                HeaderValue::from_static("nosniff"),
68            );
69            response
70                .headers_mut()
71                .insert("x-frame-options", HeaderValue::from_static("DENY"));
72            response
73                .headers_mut()
74                .insert("referrer-policy", HeaderValue::from_static("no-referrer"));
75            Ok(response)
76        })
77    }
78}
79
80/// Creates a request body limit layer using the declared `Content-Length`.
81///
82/// The layer rejects requests when `Content-Length` parses as `u64` and is
83/// greater than `max_bytes`. The rejection is `413 Payload Too Large` with a
84/// plain-text `payload too large` body.
85///
86/// If `Content-Length` is absent, not UTF-8, or not a valid integer, the layer
87/// lets the request through. It does not count streamed bytes as the body is
88/// read; pair it with extractor/server limits when you need hard streaming
89/// enforcement.
90pub fn body_limit_layer(max_bytes: u64) -> BodyLimitLayer {
91    BodyLimitLayer {
92        max_bytes,
93        webhook_boundary: false,
94    }
95}
96
97/// Creates a streaming request body limit layer.
98///
99/// Unlike [`body_limit_layer`], this wraps the request body and enforces
100/// `max_bytes` as the downstream extractor or handler reads the stream. Requests
101/// with an oversized `Content-Length` are rejected before the inner service is
102/// called; requests without `Content-Length` fail with `413 Payload Too Large`
103/// when the body is read past the configured limit.
104///
105/// Use this when you need a hard read-time cap across streaming bodies. Keep
106/// [`body_limit_layer`] when you only want the lightweight declared
107/// `Content-Length` boundary used by [`crate::middleware::ApiDefaults`].
108pub fn streaming_body_limit_layer(max_bytes: usize) -> RequestBodyLimitLayer {
109    RequestBodyLimitLayer::new(max_bytes)
110}
111
112/// Creates a request body limit layer for webhook/raw-body routes.
113///
114/// This has the same declared `Content-Length` behavior as
115/// [`body_limit_layer`], but `413` responses include
116/// `x-nidus-body-limit: webhook-raw-body`. Use it at raw-body/webhook
117/// boundaries where callers or tests need to distinguish this limit from a
118/// generic API body limit.
119pub fn webhook_body_limit_layer(max_bytes: u64) -> BodyLimitLayer {
120    BodyLimitLayer {
121        max_bytes,
122        webhook_boundary: true,
123    }
124}
125
126/// Tower layer that rejects requests with a declared oversized body.
127///
128/// Enforcement is header-based: only a parseable `Content-Length` value above
129/// `max_bytes` is rejected. Missing or invalid `Content-Length` values are
130/// passed to the inner service unchanged.
131#[derive(Clone, Copy, Debug)]
132pub struct BodyLimitLayer {
133    max_bytes: u64,
134    webhook_boundary: bool,
135}
136
137impl<S> Layer<S> for BodyLimitLayer {
138    type Service = BodyLimitService<S>;
139
140    fn layer(&self, inner: S) -> Self::Service {
141        BodyLimitService {
142            inner,
143            max_bytes: self.max_bytes,
144            webhook_boundary: self.webhook_boundary,
145        }
146    }
147}
148
149/// Service produced by [`BodyLimitLayer`].
150#[derive(Clone, Debug)]
151pub struct BodyLimitService<S> {
152    inner: S,
153    max_bytes: u64,
154    webhook_boundary: bool,
155}
156
157impl<S> Service<Request> for BodyLimitService<S>
158where
159    S: Service<Request, Response = Response<Body>> + Send + 'static,
160    S::Future: Send + 'static,
161    S::Error: Send + 'static,
162{
163    type Response = Response<Body>;
164    type Error = S::Error;
165    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
166
167    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168        self.inner.poll_ready(cx)
169    }
170
171    fn call(&mut self, request: Request) -> Self::Future {
172        let too_large = request
173            .headers()
174            .get(header::CONTENT_LENGTH)
175            .and_then(|value| value.to_str().ok())
176            .and_then(|value| value.parse::<u64>().ok())
177            .is_some_and(|length| length > self.max_bytes);
178        if too_large {
179            let webhook_boundary = self.webhook_boundary;
180            return Box::pin(async move { Ok(body_too_large_response(webhook_boundary)) });
181        }
182
183        let future = self.inner.call(request);
184        Box::pin(future)
185    }
186}
187
188fn body_too_large_response(webhook_boundary: bool) -> Response<Body> {
189    let mut response = Response::new(Body::from("payload too large"));
190    *response.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
191    if webhook_boundary {
192        response.headers_mut().insert(
193            "x-nidus-body-limit",
194            HeaderValue::from_static("webhook-raw-body"),
195        );
196    }
197    response
198}
199
200/// Creates a timeout layer that maps elapsed inner work to `408 Request Timeout`.
201///
202/// If the inner service completes before `timeout`, its response is returned
203/// unchanged. If the timeout elapses first, the response is `408 Request
204/// Timeout` with a plain-text `request timed out` body.
205pub fn timeout_response_layer(timeout: Duration) -> TimeoutResponseLayer {
206    TimeoutResponseLayer { timeout }
207}
208
209/// Tower layer that maps elapsed inner work to an HTTP timeout response.
210///
211/// This is an HTTP response-mapping layer, not Tower's error-returning timeout
212/// layer. It keeps the service error type unchanged and turns elapsed requests
213/// into a concrete `408` response.
214#[derive(Clone, Copy, Debug)]
215pub struct TimeoutResponseLayer {
216    timeout: Duration,
217}
218
219impl<S> Layer<S> for TimeoutResponseLayer {
220    type Service = TimeoutResponseService<S>;
221
222    fn layer(&self, inner: S) -> Self::Service {
223        TimeoutResponseService {
224            inner,
225            timeout: self.timeout,
226        }
227    }
228}
229
230/// Service produced by [`TimeoutResponseLayer`].
231#[derive(Clone, Debug)]
232pub struct TimeoutResponseService<S> {
233    inner: S,
234    timeout: Duration,
235}
236
237impl<S> Service<Request> for TimeoutResponseService<S>
238where
239    S: Service<Request, Response = Response<Body>> + Send + 'static,
240    S::Future: Send + 'static,
241    S::Error: Send + 'static,
242{
243    type Response = Response<Body>;
244    type Error = S::Error;
245    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
246
247    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248        self.inner.poll_ready(cx)
249    }
250
251    fn call(&mut self, request: Request) -> Self::Future {
252        let timeout_duration = self.timeout;
253        let future = self.inner.call(request);
254        Box::pin(async move {
255            match tokio::time::timeout(timeout_duration, future).await {
256                Ok(response) => response,
257                Err(_) => Ok(timeout_response()),
258            }
259        })
260    }
261}
262
263fn timeout_response() -> Response<Body> {
264    let mut response = Response::new(Body::from("request timed out"));
265    *response.status_mut() = StatusCode::REQUEST_TIMEOUT;
266    response
267}