1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{Arc, Mutex},
5};
6
7use axum::{
8 body::Body,
9 extract::Request,
10 http::{request::Parts, Method, Uri},
11 middleware::Next,
12 response::IntoResponse,
13 response::Response,
14};
15
16use crate::HttpException;
17use crate::AuthIdentity;
18
19#[derive(Clone, Debug)]
20pub struct RequestContext {
21 pub method: Method,
22 pub uri: Uri,
23 pub request_id: Option<String>,
24 pub auth_identity: Option<Arc<AuthIdentity>>,
25}
26
27impl RequestContext {
28 pub fn from_parts(parts: &Parts) -> Self {
29 Self {
30 method: parts.method.clone(),
31 uri: parts.uri.clone(),
32 request_id: crate::request::request_id_from_extensions(&parts.extensions),
33 auth_identity: parts.extensions.get::<Arc<AuthIdentity>>().cloned(),
34 }
35 }
36
37 pub fn from_request(req: &Request) -> Self {
38 Self {
39 method: req.method().clone(),
40 uri: req.uri().clone(),
41 request_id: crate::request::request_id_from_extensions(req.extensions()),
42 auth_identity: req.extensions().get::<Arc<AuthIdentity>>().cloned(),
43 }
44 }
45
46 pub fn is_authenticated(&self) -> bool {
47 self.auth_identity.is_some()
48 }
49
50 pub fn has_role(&self, role: &str) -> bool {
51 self.auth_identity
52 .as_ref()
53 .map(|identity| identity.has_role(role))
54 .unwrap_or(false)
55 }
56}
57
58pub trait ExceptionFilter: Send + Sync + 'static {
59 fn catch(&self, exception: HttpException, ctx: &RequestContext) -> HttpException;
60}
61
62pub trait Guard: Send + Sync + 'static {
63 fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException>;
64}
65
/// Guard that admits only requests carrying an authenticated identity.
///
/// Stateless unit type: it derives `Copy`/`Clone` so it can be shared freely,
/// and `Debug` so it can appear in diagnostics alongside `RequestContext`.
#[derive(Clone, Copy, Debug, Default)]
pub struct RequireAuthenticationGuard;
68
69impl Guard for RequireAuthenticationGuard {
70 fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
71 if ctx.is_authenticated() {
72 Ok(())
73 } else {
74 Err(HttpException::unauthorized("Authentication required"))
75 }
76 }
77}
78
/// Guard that admits authenticated requests whose identity holds at least one
/// of the configured roles.
///
/// The check is "any of", so a guard built with an empty role list rejects every
/// request, including authenticated ones.
#[derive(Clone, Debug)]
pub struct RoleRequirementsGuard {
    roles: Vec<String>,
}

impl RoleRequirementsGuard {
    /// Creates a guard accepting any of `roles`.
    ///
    /// Roles are kept in the order given. That order is also used when listing
    /// the expected roles in the rejection message.
    pub fn new<I, S>(roles: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            roles: roles.into_iter().map(Into::into).collect(),
        }
    }
}
94
95impl Guard for RoleRequirementsGuard {
96 fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
97 if !ctx.is_authenticated() {
98 return Err(HttpException::unauthorized("Authentication required"));
99 }
100
101 if self.roles.iter().any(|role| ctx.has_role(role)) {
102 Ok(())
103 } else {
104 Err(HttpException::forbidden(format!(
105 "Missing required role. Expected one of: {}",
106 self.roles.join(", ")
107 )))
108 }
109 }
110}
111
112pub type NextFuture = Pin<Box<dyn Future<Output = Response> + Send>>;
113pub type NextFn = Arc<dyn Fn(Request<Body>) -> NextFuture + Send + Sync + 'static>;
114
115pub trait Interceptor: Send + Sync + 'static {
116 fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture;
117}
118
119pub fn run_guards(guards: &[Arc<dyn Guard>], ctx: &RequestContext) -> Result<(), HttpException> {
120 for guard in guards {
121 guard.can_activate(ctx)?;
122 }
123 Ok(())
124}
125
126fn next_to_fn(next: Next) -> NextFn {
127 let next = Arc::new(Mutex::new(Some(next)));
128
129 Arc::new(move |req: Request<Body>| {
130 let next = Arc::clone(&next);
131 Box::pin(async move {
132 let next = {
133 let mut guard = match next.lock() {
134 Ok(guard) => guard,
135 Err(_) => {
136 return HttpException::internal_server_error("Pipeline lock poisoned")
137 .into_response();
138 }
139 };
140 guard.take()
141 };
142
143 match next {
144 Some(next) => next.run(req).await,
145 None => HttpException::internal_server_error("Pipeline next called multiple times")
146 .into_response(),
147 }
148 })
149 })
150}
151
152fn run_interceptor_chain(
153 interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
154 index: usize,
155 ctx: RequestContext,
156 req: Request<Body>,
157 terminal: NextFn,
158) -> NextFuture {
159 if index >= interceptors.len() {
160 return terminal(req);
161 }
162
163 let current = Arc::clone(&interceptors[index]);
164 let interceptors_for_next = Arc::clone(&interceptors);
165 let ctx_for_next = ctx.clone();
166 let terminal_for_next = Arc::clone(&terminal);
167
168 let next_fn: NextFn = Arc::new(move |next_req: Request<Body>| {
169 run_interceptor_chain(
170 Arc::clone(&interceptors_for_next),
171 index + 1,
172 ctx_for_next.clone(),
173 next_req,
174 Arc::clone(&terminal_for_next),
175 )
176 });
177
178 current.around(ctx, req, next_fn)
179}
180
181pub async fn execute_pipeline(
182 req: Request<Body>,
183 next: Next,
184 guards: Arc<Vec<Arc<dyn Guard>>>,
185 interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
186 filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
187) -> Response {
188 let ctx = RequestContext::from_request(&req);
189
190 if let Err(err) = run_guards(guards.as_slice(), &ctx) {
191 return apply_exception_filters(err, &ctx, filters.as_slice()).into_response();
192 }
193
194 let terminal = next_to_fn(next);
195 run_interceptor_chain(interceptors, 0, ctx, req, terminal).await
196}
197
198pub fn apply_exception_filters(
199 mut exception: HttpException,
200 ctx: &RequestContext,
201 filters: &[Arc<dyn ExceptionFilter>],
202) -> HttpException {
203 for filter in filters {
204 exception = filter.catch(exception, ctx);
205 }
206
207 exception
208}
209
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use axum::http::Method;

    use crate::{AuthIdentity, Guard};

    use super::{RequestContext, RequireAuthenticationGuard, RoleRequirementsGuard};

    /// `GET /` context with no request id and the given identity.
    fn context_with(auth_identity: Option<Arc<AuthIdentity>>) -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity,
        }
    }

    fn anonymous_context() -> RequestContext {
        context_with(None)
    }

    fn authenticated_context(roles: &[&str]) -> RequestContext {
        let identity = AuthIdentity::new("user-1").with_roles(roles);
        context_with(Some(Arc::new(identity)))
    }

    #[test]
    fn authentication_guard_rejects_anonymous_requests() {
        let guard = RequireAuthenticationGuard;

        let anonymous = guard.can_activate(&anonymous_context());
        let signed_in = guard.can_activate(&authenticated_context(&[]));

        assert!(anonymous.is_err());
        assert!(signed_in.is_ok());
    }

    #[test]
    fn role_guard_accepts_any_matching_role() {
        let guard = RoleRequirementsGuard::new(["admin", "support"]);

        let matching = guard.can_activate(&authenticated_context(&["support"]));
        let unrelated = guard.can_activate(&authenticated_context(&["viewer"]));
        let anonymous = guard.can_activate(&anonymous_context());

        assert!(matching.is_ok());
        assert!(unrelated.is_err());
        assert!(anonymous.is_err());
    }
}