oxide_framework_core/
middleware.rs1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::{IntoResponse, Response};
5use std::convert::Infallible;
6use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::{Duration, Instant};
10use tower::{Layer, Service};
11use tracing::{info, warn};
12
13use crate::state::AppState;
14
15pub(crate) fn panic_json_response(
22 _err: Box<dyn std::any::Any + Send + 'static>,
23) -> axum::http::Response<axum::body::Body> {
24 let body = serde_json::json!({
25 "status": 500,
26 "error": "internal server error"
27 });
28 let bytes = serde_json::to_vec(&body).unwrap_or_default();
29 axum::http::Response::builder()
30 .status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
31 .header("content-type", "application/json")
32 .body(axum::body::Body::from(bytes))
33 .unwrap()
34}
35
36pub async fn request_logger(request: Request, next: Next) -> Response {
44 let method = request.method().clone();
45 let path = request.uri().path().to_string();
46 let start = Instant::now();
47
48 let response = next.run(request).await;
49
50 let latency = start.elapsed();
51 let status = response.status().as_u16();
52
53 info!(
54 method = %method,
55 path = %path,
56 status = status,
57 latency_ms = latency.as_millis() as u64,
58 "request completed"
59 );
60
61 response
62}
63
64#[derive(Clone)]
73pub(crate) struct RequestTimeoutLayer {
74 duration: Duration,
75}
76
77impl RequestTimeoutLayer {
78 pub fn new(duration: Duration) -> Self {
79 Self { duration }
80 }
81}
82
83impl<S: Clone> Layer<S> for RequestTimeoutLayer {
84 type Service = RequestTimeoutService<S>;
85
86 fn layer(&self, inner: S) -> RequestTimeoutService<S> {
87 RequestTimeoutService {
88 inner,
89 duration: self.duration,
90 }
91 }
92}
93
94#[derive(Clone)]
95pub(crate) struct RequestTimeoutService<S> {
96 inner: S,
97 duration: Duration,
98}
99
100impl<S> Service<Request> for RequestTimeoutService<S>
101where
102 S: Service<Request> + Clone + Send + 'static,
103 S::Response: IntoResponse + 'static,
104 S::Error: Into<Infallible> + 'static,
105 S::Future: Send + 'static,
106{
107 type Response = Response;
108 type Error = S::Error;
109 type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;
110
111 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112 self.inner.poll_ready(cx)
113 }
114
115 fn call(&mut self, req: Request) -> Self::Future {
116 let duration = self.duration;
117 let path = req.uri().path().to_string();
118 let future = self.inner.call(req);
119
120 Box::pin(async move {
121 match tokio::time::timeout(duration, future).await {
122 Ok(result) => result.map(|r| r.into_response()),
123 Err(_) => {
124 warn!(
125 path = %path,
126 timeout_ms = duration.as_millis() as u64,
127 "request timed out"
128 );
129 let body = serde_json::json!({
130 "status": 408,
131 "error": "request timeout"
132 });
133 Ok((StatusCode::REQUEST_TIMEOUT, axum::Json(body)).into_response())
134 }
135 }
136 })
137 }
138}
139
140#[derive(Clone)]
146pub(crate) struct InjectStateLayer {
147 state: AppState,
148}
149
150impl InjectStateLayer {
151 pub fn new(state: AppState) -> Self {
152 Self { state }
153 }
154}
155
156impl<S: Clone> Layer<S> for InjectStateLayer {
157 type Service = InjectState<S>;
158
159 fn layer(&self, inner: S) -> InjectState<S> {
160 InjectState {
161 inner,
162 state: self.state.clone(),
163 }
164 }
165}
166
167#[derive(Clone)]
168pub(crate) struct InjectState<S> {
169 inner: S,
170 state: AppState,
171}
172
173impl<S> Service<Request> for InjectState<S>
174where
175 S: Service<Request> + Clone + Send + 'static,
176 S::Response: IntoResponse + 'static,
177 S::Error: Into<Infallible> + 'static,
178 S::Future: Send + 'static,
179{
180 type Response = S::Response;
181 type Error = S::Error;
182 type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;
183
184 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
185 self.inner.poll_ready(cx)
186 }
187
188 fn call(&mut self, mut req: Request) -> Self::Future {
189 req.extensions_mut().insert(self.state.clone());
190 Box::pin(self.inner.call(req))
191 }
192}
193