// oxide_framework_core/middleware.rs

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::{IntoResponse, Response};
5use std::convert::Infallible;
6use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::{Duration, Instant};
10use tower::{Layer, Service};
11use tracing::{info, warn};
12
13use crate::state::AppState;
14
// ---------------------------------------------------------------------------
// Built-in: panic recovery (JSON 500 response)
// ---------------------------------------------------------------------------

19/// Produces a JSON 500 response when a handler panics.
20/// Used with `tower_http::catch_panic::CatchPanicLayer`.
21pub(crate) fn panic_json_response(
22    _err: Box<dyn std::any::Any + Send + 'static>,
23) -> axum::http::Response<axum::body::Body> {
24    let body = serde_json::json!({
25        "status": 500,
26        "error": "internal server error"
27    });
28    let bytes = serde_json::to_vec(&body).unwrap_or_default();
29    axum::http::Response::builder()
30        .status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
31        .header("content-type", "application/json")
32        .body(axum::body::Body::from(bytes))
33        .unwrap()
34}
35
// ---------------------------------------------------------------------------
// Built-in: request logging middleware
// ---------------------------------------------------------------------------

40/// Logs every request's method, path, response status, and latency.
41///
42/// Applied automatically by `App::run()`. Disable with `App::disable_request_logging()`.
43pub async fn request_logger(request: Request, next: Next) -> Response {
44    let method = request.method().clone();
45    let path = request.uri().path().to_string();
46    let start = Instant::now();
47
48    let response = next.run(request).await;
49
50    let latency = start.elapsed();
51    let status = response.status().as_u16();
52
53    info!(
54        method = %method,
55        path = %path,
56        status = status,
57        latency_ms = latency.as_millis() as u64,
58        "request completed"
59    );
60
61    response
62}
63
// ---------------------------------------------------------------------------
// Built-in: request timeout layer
// ---------------------------------------------------------------------------

68/// Tower `Layer` that enforces a maximum request processing duration.
69///
70/// Returns `408 Request Timeout` with a JSON error body if the handler
71/// does not complete within the configured duration.
72#[derive(Clone)]
73pub(crate) struct RequestTimeoutLayer {
74    duration: Duration,
75}
76
77impl RequestTimeoutLayer {
78    pub fn new(duration: Duration) -> Self {
79        Self { duration }
80    }
81}
82
83impl<S: Clone> Layer<S> for RequestTimeoutLayer {
84    type Service = RequestTimeoutService<S>;
85
86    fn layer(&self, inner: S) -> RequestTimeoutService<S> {
87        RequestTimeoutService {
88            inner,
89            duration: self.duration,
90        }
91    }
92}
93
94#[derive(Clone)]
95pub(crate) struct RequestTimeoutService<S> {
96    inner: S,
97    duration: Duration,
98}
99
100impl<S> Service<Request> for RequestTimeoutService<S>
101where
102    S: Service<Request> + Clone + Send + 'static,
103    S::Response: IntoResponse + 'static,
104    S::Error: Into<Infallible> + 'static,
105    S::Future: Send + 'static,
106{
107    type Response = Response;
108    type Error = S::Error;
109    type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;
110
111    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112        self.inner.poll_ready(cx)
113    }
114
115    fn call(&mut self, req: Request) -> Self::Future {
116        let duration = self.duration;
117        let path = req.uri().path().to_string();
118        let future = self.inner.call(req);
119
120        Box::pin(async move {
121            match tokio::time::timeout(duration, future).await {
122                Ok(result) => result.map(|r| r.into_response()),
123                Err(_) => {
124                    warn!(
125                        path = %path,
126                        timeout_ms = duration.as_millis() as u64,
127                        "request timed out"
128                    );
129                    let body = serde_json::json!({
130                        "status": 408,
131                        "error": "request timeout"
132                    });
133                    Ok((StatusCode::REQUEST_TIMEOUT, axum::Json(body)).into_response())
134                }
135            }
136        })
137    }
138}
139
// ---------------------------------------------------------------------------
// Internal: state injection layer
// ---------------------------------------------------------------------------

144/// Tower `Layer` that injects `AppState` into every request's extensions.
145#[derive(Clone)]
146pub(crate) struct InjectStateLayer {
147    state: AppState,
148}
149
150impl InjectStateLayer {
151    pub fn new(state: AppState) -> Self {
152        Self { state }
153    }
154}
155
156impl<S: Clone> Layer<S> for InjectStateLayer {
157    type Service = InjectState<S>;
158
159    fn layer(&self, inner: S) -> InjectState<S> {
160        InjectState {
161            inner,
162            state: self.state.clone(),
163        }
164    }
165}
166
167#[derive(Clone)]
168pub(crate) struct InjectState<S> {
169    inner: S,
170    state: AppState,
171}
172
173impl<S> Service<Request> for InjectState<S>
174where
175    S: Service<Request> + Clone + Send + 'static,
176    S::Response: IntoResponse + 'static,
177    S::Error: Into<Infallible> + 'static,
178    S::Future: Send + 'static,
179{
180    type Response = S::Response;
181    type Error = S::Error;
182    type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;
183
184    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
185        self.inner.poll_ready(cx)
186    }
187
188    fn call(&mut self, mut req: Request) -> Self::Future {
189        req.extensions_mut().insert(self.state.clone());
190        Box::pin(self.inner.call(req))
191    }
192}
193