1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{Arc, Mutex},
5};
6
7use axum::{
8 body::Body,
9 extract::Request,
10 http::{request::Parts, Method, Uri},
11 middleware::Next,
12 response::IntoResponse,
13 response::Response,
14};
15
16use crate::AuthIdentity;
17use crate::HttpException;
18
/// Per-request snapshot handed to guards, interceptors, and exception filters.
///
/// Built from the incoming request's method, URI, and extensions via
/// [`RequestContext::from_parts`] or [`RequestContext::from_request`]. The identity is
/// shared behind an `Arc`; the remaining fields are owned copies.
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// HTTP method of the request.
    pub method: Method,
    /// Request URI, cloned from the request.
    pub uri: Uri,
    /// Request identifier as returned by `crate::request::request_id_from_extensions`;
    /// `None` when that helper found nothing.
    pub request_id: Option<String>,
    /// Caller identity, read from an `Arc<AuthIdentity>` request extension. `None` means the
    /// request carried no such extension and is treated as anonymous.
    /// NOTE(review): presumably inserted by an upstream auth layer — confirm.
    pub auth_identity: Option<Arc<AuthIdentity>>,
}
30
31impl RequestContext {
32 pub fn from_parts(parts: &Parts) -> Self {
33 Self {
34 method: parts.method.clone(),
35 uri: parts.uri.clone(),
36 request_id: crate::request::request_id_from_extensions(&parts.extensions),
37 auth_identity: parts.extensions.get::<Arc<AuthIdentity>>().cloned(),
38 }
39 }
40
41 pub fn from_request(req: &Request) -> Self {
42 Self {
43 method: req.method().clone(),
44 uri: req.uri().clone(),
45 request_id: crate::request::request_id_from_extensions(req.extensions()),
46 auth_identity: req.extensions().get::<Arc<AuthIdentity>>().cloned(),
47 }
48 }
49
50 pub fn is_authenticated(&self) -> bool {
51 self.auth_identity.is_some()
52 }
53
54 pub fn has_role(&self, role: &str) -> bool {
55 self.auth_identity
56 .as_ref()
57 .map(|identity| identity.has_role(role))
58 .unwrap_or(false)
59 }
60}
61
/// Hook that can inspect or replace an [`HttpException`] before it becomes a response.
///
/// Filters are chained by `apply_exception_filters`: each filter receives the exception
/// returned by the previous one, in slice order.
pub trait ExceptionFilter: Send + Sync + 'static {
    /// Returns the exception to propagate; return `exception` unchanged to pass it through.
    ///
    /// `ctx` describes the request that produced the exception.
    fn catch(&self, exception: HttpException, ctx: &RequestContext) -> HttpException;
}
79
/// Synchronous admission check run before interceptors and the handler.
///
/// `run_guards` evaluates guards in order and stops at the first rejection.
pub trait Guard: Send + Sync + 'static {
    /// Returns `Ok(())` to let the request continue, or `Err` with the exception to respond
    /// with (which `execute_pipeline` routes through the exception filters).
    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException>;
}
89
/// Guard that admits any request carrying an authenticated identity.
///
/// Anonymous requests are rejected with `HttpException::unauthorized`. This is a stateless
/// unit struct, so it is also `Copy`, `Clone`, and `Debug` for easy reuse and logging.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequireAuthenticationGuard;
95
96impl Guard for RequireAuthenticationGuard {
97 fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
98 if ctx.is_authenticated() {
99 Ok(())
100 } else {
101 Err(HttpException::unauthorized("Authentication required"))
102 }
103 }
104}
105
/// Guard that admits authenticated requests holding at least one of a set of roles.
///
/// Derives `Clone` and `Debug` so configured guards can be duplicated and logged.
#[derive(Clone, Debug)]
pub struct RoleRequirementsGuard {
    /// Accepted roles; a request passes if it has any one of them.
    roles: Vec<String>,
}
112
113impl RoleRequirementsGuard {
114 pub fn new<I, S>(roles: I) -> Self
115 where
116 I: IntoIterator<Item = S>,
117 S: Into<String>,
118 {
119 Self {
120 roles: roles.into_iter().map(Into::into).collect(),
121 }
122 }
123}
124
125impl Guard for RoleRequirementsGuard {
126 fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
127 if !ctx.is_authenticated() {
128 return Err(HttpException::unauthorized("Authentication required"));
129 }
130
131 if self.roles.iter().any(|role| ctx.has_role(role)) {
132 Ok(())
133 } else {
134 Err(HttpException::forbidden(format!(
135 "Missing required role. Expected one of: {}",
136 self.roles.join(", ")
137 )))
138 }
139 }
140}
141
/// Boxed, `Send` future resolving to the downstream response.
pub type NextFuture = Pin<Box<dyn Future<Output = Response> + Send>>;
/// Shareable continuation an interceptor calls to forward a request down the chain.
pub type NextFn = Arc<dyn Fn(Request<Body>) -> NextFuture + Send + Sync + 'static>;
144
/// Middleware-style hook wrapped around the remainder of the pipeline.
///
/// Interceptors run after all guards pass, in slice order; each one decides whether and how
/// to call `next`.
pub trait Interceptor: Send + Sync + 'static {
    /// Handles `req`, optionally delegating to `next` (possibly with a modified request).
    ///
    /// Returning without calling `next` skips later interceptors and the handler. The
    /// handler at the end of the chain can run only once: if `next` ends up reaching it a
    /// second time, that call yields a 500 response instead.
    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture;
}
167
168pub fn run_guards(guards: &[Arc<dyn Guard>], ctx: &RequestContext) -> Result<(), HttpException> {
169 for guard in guards {
170 guard.can_activate(ctx)?;
171 }
172 Ok(())
173}
174
175fn next_to_fn(next: Next) -> NextFn {
176 let next = Arc::new(Mutex::new(Some(next)));
182
183 Arc::new(move |req: Request<Body>| {
184 let next = Arc::clone(&next);
185 Box::pin(async move {
186 let next = {
187 let mut guard = match next.lock() {
188 Ok(guard) => guard,
189 Err(_) => {
190 return HttpException::internal_server_error("Pipeline lock poisoned")
191 .into_response();
192 }
193 };
194 guard.take()
195 };
196
197 match next {
198 Some(next) => next.run(req).await,
199 std::option::Option::None => crate::HttpException::internal_server_error(
200 "Pipeline next called multiple times",
201 )
202 .into_response(),
203 }
204 })
205 })
206}
207
208fn run_interceptor_chain(
209 interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
210 index: usize,
211 ctx: RequestContext,
212 req: Request<Body>,
213 terminal: NextFn,
214) -> NextFuture {
215 if index >= interceptors.len() {
216 return terminal(req);
217 }
218
219 let current = Arc::clone(&interceptors[index]);
220 let interceptors_for_next = Arc::clone(&interceptors);
221 let ctx_for_next = ctx.clone();
222 let terminal_for_next = Arc::clone(&terminal);
223
224 let next_fn: NextFn = Arc::new(move |next_req: Request<Body>| {
225 run_interceptor_chain(
226 Arc::clone(&interceptors_for_next),
227 index + 1,
228 ctx_for_next.clone(),
229 next_req,
230 Arc::clone(&terminal_for_next),
231 )
232 });
233
234 current.around(ctx, req, next_fn)
235}
236
237pub async fn execute_pipeline(
243 req: Request<Body>,
244 next: Next,
245 guards: Arc<Vec<Arc<dyn Guard>>>,
246 interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
247 filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
248) -> Response {
249 let ctx = RequestContext::from_request(&req);
250
251 if let Err(err) = run_guards(guards.as_slice(), &ctx) {
252 return apply_exception_filters(err, &ctx, filters.as_slice()).into_response();
253 }
254
255 let terminal = next_to_fn(next);
256 run_interceptor_chain(interceptors, 0, ctx, req, terminal).await
257}
258
259pub fn apply_exception_filters(
260 mut exception: HttpException,
261 ctx: &RequestContext,
262 filters: &[Arc<dyn ExceptionFilter>],
263) -> HttpException {
264 for filter in filters {
265 exception = filter.catch(exception, ctx);
266 }
267
268 exception
269}
270
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use axum::http::Method;

    use crate::{AuthIdentity, Guard};

    use super::{RequestContext, RequireAuthenticationGuard, RoleRequirementsGuard};

    /// Builds a `GET /` context with the given identity (or none).
    fn context_with(auth_identity: Option<Arc<AuthIdentity>>) -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity,
        }
    }

    fn anonymous_context() -> RequestContext {
        context_with(None)
    }

    fn authenticated_context(roles: &[&str]) -> RequestContext {
        let identity = AuthIdentity::new("user-1").with_roles(roles.iter().copied());
        context_with(Some(Arc::new(identity)))
    }

    #[test]
    fn authentication_guard_rejects_anonymous_requests() {
        let guard = RequireAuthenticationGuard;

        let anonymous_outcome = guard.can_activate(&anonymous_context());
        let signed_in_outcome = guard.can_activate(&authenticated_context(&[]));

        assert!(anonymous_outcome.is_err());
        assert!(signed_in_outcome.is_ok());
    }

    #[test]
    fn role_guard_accepts_any_matching_role() {
        let guard = RoleRequirementsGuard::new(["admin", "support"]);

        let cases = [
            (authenticated_context(&["support"]), true),
            (authenticated_context(&["viewer"]), false),
            (anonymous_context(), false),
        ];

        for (ctx, should_pass) in &cases {
            assert_eq!(guard.can_activate(ctx).is_ok(), *should_pass);
        }
    }
}