Skip to main content

nestforge_core/
pipeline.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{Arc, Mutex},
5};
6
7use axum::{
8    body::Body,
9    extract::Request,
10    http::{request::Parts, Method, Uri},
11    middleware::Next,
12    response::IntoResponse,
13    response::Response,
14};
15
16use crate::HttpException;
17use crate::AuthIdentity;
18
19#[derive(Clone, Debug)]
20pub struct RequestContext {
21    pub method: Method,
22    pub uri: Uri,
23    pub request_id: Option<String>,
24    pub auth_identity: Option<Arc<AuthIdentity>>,
25}
26
27impl RequestContext {
28    pub fn from_parts(parts: &Parts) -> Self {
29        Self {
30            method: parts.method.clone(),
31            uri: parts.uri.clone(),
32            request_id: crate::request::request_id_from_extensions(&parts.extensions),
33            auth_identity: parts.extensions.get::<Arc<AuthIdentity>>().cloned(),
34        }
35    }
36
37    pub fn from_request(req: &Request) -> Self {
38        Self {
39            method: req.method().clone(),
40            uri: req.uri().clone(),
41            request_id: crate::request::request_id_from_extensions(req.extensions()),
42            auth_identity: req.extensions().get::<Arc<AuthIdentity>>().cloned(),
43        }
44    }
45
46    pub fn is_authenticated(&self) -> bool {
47        self.auth_identity.is_some()
48    }
49
50    pub fn has_role(&self, role: &str) -> bool {
51        self.auth_identity
52            .as_ref()
53            .map(|identity| identity.has_role(role))
54            .unwrap_or(false)
55    }
56}
57
58pub trait ExceptionFilter: Send + Sync + 'static {
59    fn catch(&self, exception: HttpException, ctx: &RequestContext) -> HttpException;
60}
61
62pub trait Guard: Send + Sync + 'static {
63    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException>;
64}
65
66#[derive(Default)]
67pub struct RequireAuthenticationGuard;
68
69impl Guard for RequireAuthenticationGuard {
70    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
71        if ctx.is_authenticated() {
72            Ok(())
73        } else {
74            Err(HttpException::unauthorized("Authentication required"))
75        }
76    }
77}
78
79pub struct RoleRequirementsGuard {
80    roles: Vec<String>,
81}
82
83impl RoleRequirementsGuard {
84    pub fn new<I, S>(roles: I) -> Self
85    where
86        I: IntoIterator<Item = S>,
87        S: Into<String>,
88    {
89        Self {
90            roles: roles.into_iter().map(Into::into).collect(),
91        }
92    }
93}
94
95impl Guard for RoleRequirementsGuard {
96    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
97        if !ctx.is_authenticated() {
98            return Err(HttpException::unauthorized("Authentication required"));
99        }
100
101        if self.roles.iter().any(|role| ctx.has_role(role)) {
102            Ok(())
103        } else {
104            Err(HttpException::forbidden(format!(
105                "Missing required role. Expected one of: {}",
106                self.roles.join(", ")
107            )))
108        }
109    }
110}
111
112pub type NextFuture = Pin<Box<dyn Future<Output = Response> + Send>>;
113pub type NextFn = Arc<dyn Fn(Request<Body>) -> NextFuture + Send + Sync + 'static>;
114
115pub trait Interceptor: Send + Sync + 'static {
116    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture;
117}
118
119pub fn run_guards(guards: &[Arc<dyn Guard>], ctx: &RequestContext) -> Result<(), HttpException> {
120    for guard in guards {
121        guard.can_activate(ctx)?;
122    }
123    Ok(())
124}
125
126fn next_to_fn(next: Next) -> NextFn {
127    let next = Arc::new(Mutex::new(Some(next)));
128
129    Arc::new(move |req: Request<Body>| {
130        let next = Arc::clone(&next);
131        Box::pin(async move {
132            let next = {
133                let mut guard = match next.lock() {
134                    Ok(guard) => guard,
135                    Err(_) => {
136                        return HttpException::internal_server_error("Pipeline lock poisoned")
137                            .into_response();
138                    }
139                };
140                guard.take()
141            };
142
143            match next {
144                Some(next) => next.run(req).await,
145                None => HttpException::internal_server_error("Pipeline next called multiple times")
146                    .into_response(),
147            }
148        })
149    })
150}
151
152fn run_interceptor_chain(
153    interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
154    index: usize,
155    ctx: RequestContext,
156    req: Request<Body>,
157    terminal: NextFn,
158) -> NextFuture {
159    if index >= interceptors.len() {
160        return terminal(req);
161    }
162
163    let current = Arc::clone(&interceptors[index]);
164    let interceptors_for_next = Arc::clone(&interceptors);
165    let ctx_for_next = ctx.clone();
166    let terminal_for_next = Arc::clone(&terminal);
167
168    let next_fn: NextFn = Arc::new(move |next_req: Request<Body>| {
169        run_interceptor_chain(
170            Arc::clone(&interceptors_for_next),
171            index + 1,
172            ctx_for_next.clone(),
173            next_req,
174            Arc::clone(&terminal_for_next),
175        )
176    });
177
178    current.around(ctx, req, next_fn)
179}
180
181pub async fn execute_pipeline(
182    req: Request<Body>,
183    next: Next,
184    guards: Arc<Vec<Arc<dyn Guard>>>,
185    interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
186    filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
187) -> Response {
188    let ctx = RequestContext::from_request(&req);
189
190    if let Err(err) = run_guards(guards.as_slice(), &ctx) {
191        return apply_exception_filters(err, &ctx, filters.as_slice()).into_response();
192    }
193
194    let terminal = next_to_fn(next);
195    run_interceptor_chain(interceptors, 0, ctx, req, terminal).await
196}
197
198pub fn apply_exception_filters(
199    mut exception: HttpException,
200    ctx: &RequestContext,
201    filters: &[Arc<dyn ExceptionFilter>],
202) -> HttpException {
203    for filter in filters {
204        exception = filter.catch(exception, ctx);
205    }
206
207    exception
208}
209
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use axum::http::Method;

    use crate::{AuthIdentity, Guard, HttpException};

    use super::{run_guards, RequestContext, RequireAuthenticationGuard, RoleRequirementsGuard};

    /// `GET /` context carrying the given identity (or none).
    fn context_with(auth_identity: Option<Arc<AuthIdentity>>) -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity,
        }
    }

    fn anonymous_context() -> RequestContext {
        context_with(None)
    }

    fn authenticated_context(roles: &[&str]) -> RequestContext {
        context_with(Some(Arc::new(AuthIdentity::new("user-1").with_roles(roles))))
    }

    /// Always-allowing guard that records how many times it was consulted.
    struct CountingGuard(Arc<AtomicUsize>);

    impl Guard for CountingGuard {
        fn can_activate(&self, _ctx: &RequestContext) -> Result<(), HttpException> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[test]
    fn authentication_guard_rejects_anonymous_requests() {
        let guard = RequireAuthenticationGuard;

        assert!(guard.can_activate(&anonymous_context()).is_err());
        assert!(guard.can_activate(&authenticated_context(&[])).is_ok());
    }

    #[test]
    fn role_guard_accepts_any_matching_role() {
        let guard = RoleRequirementsGuard::new(["admin", "support"]);

        assert!(guard.can_activate(&authenticated_context(&["support"])).is_ok());
        assert!(guard.can_activate(&authenticated_context(&["viewer"])).is_err());
        assert!(guard.can_activate(&anonymous_context()).is_err());
    }

    #[test]
    fn role_guard_with_no_roles_rejects_everyone() {
        let guard = RoleRequirementsGuard::new(Vec::<&str>::new());

        assert!(guard.can_activate(&authenticated_context(&["admin"])).is_err());
        assert!(guard.can_activate(&anonymous_context()).is_err());
    }

    #[test]
    fn context_reports_authentication_and_roles() {
        let anonymous = anonymous_context();
        assert!(!anonymous.is_authenticated());
        assert!(!anonymous.has_role("admin"));

        let admin = authenticated_context(&["admin"]);
        assert!(admin.is_authenticated());
        assert!(admin.has_role("admin"));
        assert!(!admin.has_role("viewer"));
    }

    #[test]
    fn run_guards_short_circuits_on_first_rejection() {
        let calls = Arc::new(AtomicUsize::new(0));
        let guards: Vec<Arc<dyn Guard>> = vec![
            Arc::new(RequireAuthenticationGuard),
            Arc::new(CountingGuard(Arc::clone(&calls))),
        ];

        // The first guard rejects, so the second must never be consulted.
        assert!(run_guards(&guards, &anonymous_context()).is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        assert!(run_guards(&guards, &authenticated_context(&[])).is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn run_guards_allows_requests_when_no_guards_are_registered() {
        assert!(run_guards(&[], &anonymous_context()).is_ok());
    }
}