nidus_http/middleware/
security.rs1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use axum::{body::Body, extract::Request};
9use http::{HeaderValue, Response, StatusCode, header};
10use tower::{Layer, Service};
11use tower_http::limit::RequestBodyLimitLayer;
12
13pub fn security_headers_layer() -> SecurityHeadersLayer {
22 SecurityHeadersLayer
23}
24
25#[derive(Clone, Copy, Debug, Default)]
31pub struct SecurityHeadersLayer;
32
33impl<S> Layer<S> for SecurityHeadersLayer {
34 type Service = SecurityHeadersService<S>;
35
36 fn layer(&self, inner: S) -> Self::Service {
37 SecurityHeadersService { inner }
38 }
39}
40
/// Service produced by [`SecurityHeadersLayer`]; forwards requests to `inner`
/// and adds security-related headers to the responses it returns.
#[derive(Debug, Clone)]
pub struct SecurityHeadersService<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
}
46
47impl<S> Service<Request> for SecurityHeadersService<S>
48where
49 S: Service<Request, Response = Response<Body>> + Send + 'static,
50 S::Future: Send + 'static,
51 S::Error: Send + 'static,
52{
53 type Response = Response<Body>;
54 type Error = S::Error;
55 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
56
57 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58 self.inner.poll_ready(cx)
59 }
60
61 fn call(&mut self, request: Request) -> Self::Future {
62 let future = self.inner.call(request);
63 Box::pin(async move {
64 let mut response = future.await?;
65 response.headers_mut().insert(
66 "x-content-type-options",
67 HeaderValue::from_static("nosniff"),
68 );
69 response
70 .headers_mut()
71 .insert("x-frame-options", HeaderValue::from_static("DENY"));
72 response
73 .headers_mut()
74 .insert("referrer-policy", HeaderValue::from_static("no-referrer"));
75 Ok(response)
76 })
77 }
78}
79
80pub fn body_limit_layer(max_bytes: u64) -> BodyLimitLayer {
91 BodyLimitLayer {
92 max_bytes,
93 webhook_boundary: false,
94 }
95}
96
97pub fn streaming_body_limit_layer(max_bytes: usize) -> RequestBodyLimitLayer {
109 RequestBodyLimitLayer::new(max_bytes)
110}
111
112pub fn webhook_body_limit_layer(max_bytes: u64) -> BodyLimitLayer {
120 BodyLimitLayer {
121 max_bytes,
122 webhook_boundary: true,
123 }
124}
125
126#[derive(Clone, Copy, Debug)]
132pub struct BodyLimitLayer {
133 max_bytes: u64,
134 webhook_boundary: bool,
135}
136
137impl<S> Layer<S> for BodyLimitLayer {
138 type Service = BodyLimitService<S>;
139
140 fn layer(&self, inner: S) -> Self::Service {
141 BodyLimitService {
142 inner,
143 max_bytes: self.max_bytes,
144 webhook_boundary: self.webhook_boundary,
145 }
146 }
147}
148
/// Service produced by [`BodyLimitLayer`]. It answers requests whose declared
/// `Content-Length` exceeds `max_bytes` with a 413 response and forwards
/// everything else to `inner`.
#[derive(Debug, Clone)]
pub struct BodyLimitService<S> {
    /// The wrapped service that handles requests within the limit.
    inner: S,
    /// Largest accepted `Content-Length`, in bytes.
    max_bytes: u64,
    /// Passed through to the 413 response builder to tag webhook rejections.
    webhook_boundary: bool,
}
156
157impl<S> Service<Request> for BodyLimitService<S>
158where
159 S: Service<Request, Response = Response<Body>> + Send + 'static,
160 S::Future: Send + 'static,
161 S::Error: Send + 'static,
162{
163 type Response = Response<Body>;
164 type Error = S::Error;
165 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
166
167 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 self.inner.poll_ready(cx)
169 }
170
171 fn call(&mut self, request: Request) -> Self::Future {
172 let too_large = request
173 .headers()
174 .get(header::CONTENT_LENGTH)
175 .and_then(|value| value.to_str().ok())
176 .and_then(|value| value.parse::<u64>().ok())
177 .is_some_and(|length| length > self.max_bytes);
178 if too_large {
179 let webhook_boundary = self.webhook_boundary;
180 return Box::pin(async move { Ok(body_too_large_response(webhook_boundary)) });
181 }
182
183 let future = self.inner.call(request);
184 Box::pin(future)
185 }
186}
187
188fn body_too_large_response(webhook_boundary: bool) -> Response<Body> {
189 let mut response = Response::new(Body::from("payload too large"));
190 *response.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;
191 if webhook_boundary {
192 response.headers_mut().insert(
193 "x-nidus-body-limit",
194 HeaderValue::from_static("webhook-raw-body"),
195 );
196 }
197 response
198}
199
200pub fn timeout_response_layer(timeout: Duration) -> TimeoutResponseLayer {
206 TimeoutResponseLayer { timeout }
207}
208
209#[derive(Clone, Copy, Debug)]
215pub struct TimeoutResponseLayer {
216 timeout: Duration,
217}
218
219impl<S> Layer<S> for TimeoutResponseLayer {
220 type Service = TimeoutResponseService<S>;
221
222 fn layer(&self, inner: S) -> Self::Service {
223 TimeoutResponseService {
224 inner,
225 timeout: self.timeout,
226 }
227 }
228}
229
/// Service produced by [`TimeoutResponseLayer`]. It races `inner` against
/// `timeout` and substitutes a 408 response if the deadline wins.
#[derive(Debug, Clone)]
pub struct TimeoutResponseService<S> {
    /// The wrapped service whose response is being timed.
    inner: S,
    /// Deadline applied to each call.
    timeout: Duration,
}
236
237impl<S> Service<Request> for TimeoutResponseService<S>
238where
239 S: Service<Request, Response = Response<Body>> + Send + 'static,
240 S::Future: Send + 'static,
241 S::Error: Send + 'static,
242{
243 type Response = Response<Body>;
244 type Error = S::Error;
245 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
246
247 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248 self.inner.poll_ready(cx)
249 }
250
251 fn call(&mut self, request: Request) -> Self::Future {
252 let timeout_duration = self.timeout;
253 let future = self.inner.call(request);
254 Box::pin(async move {
255 match tokio::time::timeout(timeout_duration, future).await {
256 Ok(response) => response,
257 Err(_) => Ok(timeout_response()),
258 }
259 })
260 }
261}
262
263fn timeout_response() -> Response<Body> {
264 let mut response = Response::new(Body::from("request timed out"));
265 *response.status_mut() = StatusCode::REQUEST_TIMEOUT;
266 response
267}