1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{Arc, Mutex},
5};
6
7use axum::{
8 body::Body,
9 extract::Request,
10 http::{request::Parts, Method, Uri},
11 middleware::Next,
12 response::IntoResponse,
13 response::Response,
14};
15
16use crate::AuthIdentity;
17use crate::HttpException;
18
/// Snapshot of request metadata handed to guards ([`Guard::can_activate`]),
/// interceptors ([`Interceptor::around`]) and exception filters
/// ([`ExceptionFilter::catch`]).
///
/// Built from a full [`Request`] via [`RequestContext::from_request`] or from the
/// request head via [`RequestContext::from_parts`].
#[derive(Clone, Debug)]
pub struct RequestContext {
    /// HTTP method, cloned from the request.
    pub method: Method,
    /// Request URI, cloned from the request.
    pub uri: Uri,
    /// Request id as returned by `crate::request::request_id_from_extensions`;
    /// `None` when that helper finds nothing in the extensions.
    pub request_id: Option<String>,
    /// Identity stored in the request extensions as `Arc<AuthIdentity>`;
    /// `None` means the request is treated as anonymous.
    pub auth_identity: Option<Arc<AuthIdentity>>,
}
26
27impl RequestContext {
28 pub fn from_parts(parts: &Parts) -> Self {
29 Self {
30 method: parts.method.clone(),
31 uri: parts.uri.clone(),
32 request_id: crate::request::request_id_from_extensions(&parts.extensions),
33 auth_identity: parts.extensions.get::<Arc<AuthIdentity>>().cloned(),
34 }
35 }
36
37 pub fn from_request(req: &Request) -> Self {
38 Self {
39 method: req.method().clone(),
40 uri: req.uri().clone(),
41 request_id: crate::request::request_id_from_extensions(req.extensions()),
42 auth_identity: req.extensions().get::<Arc<AuthIdentity>>().cloned(),
43 }
44 }
45
46 pub fn is_authenticated(&self) -> bool {
47 self.auth_identity.is_some()
48 }
49
50 pub fn has_role(&self, role: &str) -> bool {
51 self.auth_identity
52 .as_ref()
53 .map(|identity| identity.has_role(role))
54 .unwrap_or(false)
55 }
56}
57
/// Hook that may inspect or replace an [`HttpException`].
///
/// Filters take the exception by value and return the exception to continue
/// with. [`apply_exception_filters`] runs them in slice order, feeding each
/// filter the previous one's output.
pub trait ExceptionFilter: Send + Sync + 'static {
    /// Returns the exception to propagate for the request described by `ctx`
    /// (either `exception` itself or a replacement).
    fn catch(&self, exception: HttpException, ctx: &RequestContext) -> HttpException;
}
61
/// Access check evaluated before a request reaches the interceptors or the handler.
///
/// [`run_guards`] calls guards in order and stops at the first `Err`. In
/// [`execute_pipeline`], that error is run through the exception filters and
/// becomes the response.
pub trait Guard: Send + Sync + 'static {
    /// Returns `Ok(())` to let the request proceed, or the exception to report instead.
    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException>;
}
65
66#[derive(Default)]
67pub struct RequireAuthenticationGuard;
68
69impl Guard for RequireAuthenticationGuard {
70 fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
71 if ctx.is_authenticated() {
72 Ok(())
73 } else {
74 Err(HttpException::unauthorized("Authentication required"))
75 }
76 }
77}
78
/// Guard that admits authenticated requests holding at least one of a fixed set of roles.
///
/// The role list is captured at construction. An empty list admits nobody:
/// no role can match, so every authenticated request is rejected (fail-closed).
#[derive(Clone, Debug)]
pub struct RoleRequirementsGuard {
    /// Accepted roles; holding any one of them is sufficient.
    roles: Vec<String>,
}

impl RoleRequirementsGuard {
    /// Creates a guard that accepts any of `roles`.
    ///
    /// Accepts anything iterable over string-like values, e.g. `["admin", "support"]`
    /// or a `Vec<String>`.
    pub fn new<I, S>(roles: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            roles: roles.into_iter().map(Into::into).collect(),
        }
    }
}
94
95impl Guard for RoleRequirementsGuard {
96 fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
97 if !ctx.is_authenticated() {
98 return Err(HttpException::unauthorized("Authentication required"));
99 }
100
101 if self.roles.iter().any(|role| ctx.has_role(role)) {
102 Ok(())
103 } else {
104 Err(HttpException::forbidden(format!(
105 "Missing required role. Expected one of: {}",
106 self.roles.join(", ")
107 )))
108 }
109 }
110}
111
/// Boxed `Send` future resolving to the [`Response`] of the rest of the pipeline.
pub type NextFuture = Pin<Box<dyn Future<Output = Response> + Send>>;
/// Shared callable that forwards a request to the rest of the pipeline.
pub type NextFn = Arc<dyn Fn(Request<Body>) -> NextFuture + Send + Sync + 'static>;

/// Middleware step wrapping the remainder of the pipeline.
///
/// An implementation receives the request context, the request by value and a
/// `next` callable. The returned future's output is used as the response.
/// NOTE: the terminal step built by `next_to_fn` can reach the handler only
/// once. A second call down the chain yields an internal-server-error response.
pub trait Interceptor: Send + Sync + 'static {
    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture;
}
118
119pub fn run_guards(guards: &[Arc<dyn Guard>], ctx: &RequestContext) -> Result<(), HttpException> {
120 for guard in guards {
121 guard.can_activate(ctx)?;
122 }
123 Ok(())
124}
125
126fn next_to_fn(next: Next) -> NextFn {
127 let next = Arc::new(Mutex::new(Some(next)));
128
129 Arc::new(move |req: Request<Body>| {
130 let next = Arc::clone(&next);
131 Box::pin(async move {
132 let next = {
133 let mut guard = match next.lock() {
134 Ok(guard) => guard,
135 Err(_) => {
136 return HttpException::internal_server_error("Pipeline lock poisoned")
137 .into_response();
138 }
139 };
140 guard.take()
141 };
142
143 match next {
144 Some(next) => next.run(req).await,
145 std::option::Option::None => crate::HttpException::internal_server_error(
146 "Pipeline next called multiple times",
147 )
148 .into_response(),
149 }
150 })
151 })
152}
153
154fn run_interceptor_chain(
155 interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
156 index: usize,
157 ctx: RequestContext,
158 req: Request<Body>,
159 terminal: NextFn,
160) -> NextFuture {
161 if index >= interceptors.len() {
162 return terminal(req);
163 }
164
165 let current = Arc::clone(&interceptors[index]);
166 let interceptors_for_next = Arc::clone(&interceptors);
167 let ctx_for_next = ctx.clone();
168 let terminal_for_next = Arc::clone(&terminal);
169
170 let next_fn: NextFn = Arc::new(move |next_req: Request<Body>| {
171 run_interceptor_chain(
172 Arc::clone(&interceptors_for_next),
173 index + 1,
174 ctx_for_next.clone(),
175 next_req,
176 Arc::clone(&terminal_for_next),
177 )
178 });
179
180 current.around(ctx, req, next_fn)
181}
182
183pub async fn execute_pipeline(
184 req: Request<Body>,
185 next: Next,
186 guards: Arc<Vec<Arc<dyn Guard>>>,
187 interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
188 filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
189) -> Response {
190 let ctx = RequestContext::from_request(&req);
191
192 if let Err(err) = run_guards(guards.as_slice(), &ctx) {
193 return apply_exception_filters(err, &ctx, filters.as_slice()).into_response();
194 }
195
196 let terminal = next_to_fn(next);
197 run_interceptor_chain(interceptors, 0, ctx, req, terminal).await
198}
199
200pub fn apply_exception_filters(
201 mut exception: HttpException,
202 ctx: &RequestContext,
203 filters: &[Arc<dyn ExceptionFilter>],
204) -> HttpException {
205 for filter in filters {
206 exception = filter.catch(exception, ctx);
207 }
208
209 exception
210}
211
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    };

    use axum::http::Method;

    use crate::{AuthIdentity, HttpException};

    use super::{
        apply_exception_filters, run_guards, ExceptionFilter, Guard, RequestContext,
        RequireAuthenticationGuard, RoleRequirementsGuard,
    };

    fn anonymous_context() -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity: None,
        }
    }

    fn authenticated_context(roles: &[&str]) -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity: Some(Arc::new(
                AuthIdentity::new("user-1").with_roles(roles.iter().copied()),
            )),
        }
    }

    /// Guard returning a fixed verdict while counting how often it was consulted.
    struct CountingGuard {
        allow: bool,
        calls: Arc<AtomicUsize>,
    }

    impl Guard for CountingGuard {
        fn can_activate(&self, _ctx: &RequestContext) -> Result<(), HttpException> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if self.allow {
                Ok(())
            } else {
                Err(HttpException::unauthorized("denied"))
            }
        }
    }

    /// Filter that records its label and passes the exception through unchanged.
    struct RecordingFilter {
        label: &'static str,
        log: Arc<Mutex<Vec<&'static str>>>,
    }

    impl ExceptionFilter for RecordingFilter {
        fn catch(&self, exception: HttpException, _ctx: &RequestContext) -> HttpException {
            self.log.lock().expect("log lock poisoned").push(self.label);
            exception
        }
    }

    #[test]
    fn authentication_guard_rejects_anonymous_requests() {
        let guard = RequireAuthenticationGuard;

        assert!(guard.can_activate(&anonymous_context()).is_err());
        assert!(guard.can_activate(&authenticated_context(&[])).is_ok());
    }

    #[test]
    fn role_guard_accepts_any_matching_role() {
        let guard = RoleRequirementsGuard::new(["admin", "support"]);

        assert!(guard
            .can_activate(&authenticated_context(&["support"]))
            .is_ok());
        assert!(guard
            .can_activate(&authenticated_context(&["viewer"]))
            .is_err());
        assert!(guard.can_activate(&anonymous_context()).is_err());
    }

    #[test]
    fn role_guard_without_roles_fails_closed() {
        let guard = RoleRequirementsGuard::new(Vec::<String>::new());

        assert!(guard
            .can_activate(&authenticated_context(&["admin"]))
            .is_err());
    }

    #[test]
    fn anonymous_context_has_no_roles() {
        assert!(!anonymous_context().has_role("admin"));
        assert!(!anonymous_context().is_authenticated());
    }

    #[test]
    fn run_guards_accepts_empty_list() {
        assert!(run_guards(&[], &anonymous_context()).is_ok());
    }

    #[test]
    fn run_guards_stops_at_first_rejection() {
        let calls = Arc::new(AtomicUsize::new(0));
        let guards: Vec<Arc<dyn Guard>> = vec![
            Arc::new(CountingGuard { allow: true, calls: Arc::clone(&calls) }),
            Arc::new(CountingGuard { allow: false, calls: Arc::clone(&calls) }),
            Arc::new(CountingGuard { allow: true, calls: Arc::clone(&calls) }),
        ];

        assert!(run_guards(&guards, &anonymous_context()).is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn exception_filters_run_in_registration_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let filters: Vec<Arc<dyn ExceptionFilter>> = vec![
            Arc::new(RecordingFilter { label: "first", log: Arc::clone(&log) }),
            Arc::new(RecordingFilter { label: "second", log: Arc::clone(&log) }),
        ];

        let _ = apply_exception_filters(
            HttpException::unauthorized("boom"),
            &anonymous_context(),
            &filters,
        );

        assert_eq!(*log.lock().expect("log lock poisoned"), vec!["first", "second"]);
    }
}