Skip to main content

nestforge_core/
pipeline.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{Arc, Mutex},
5};
6
7use axum::{
8    body::Body,
9    extract::Request,
10    http::{request::Parts, Method, Uri},
11    middleware::Next,
12    response::IntoResponse,
13    response::Response,
14};
15
16use crate::AuthIdentity;
17use crate::HttpException;
18
19#[derive(Clone, Debug)]
20pub struct RequestContext {
21    pub method: Method,
22    pub uri: Uri,
23    pub request_id: Option<String>,
24    pub auth_identity: Option<Arc<AuthIdentity>>,
25}
26
27impl RequestContext {
28    pub fn from_parts(parts: &Parts) -> Self {
29        Self {
30            method: parts.method.clone(),
31            uri: parts.uri.clone(),
32            request_id: crate::request::request_id_from_extensions(&parts.extensions),
33            auth_identity: parts.extensions.get::<Arc<AuthIdentity>>().cloned(),
34        }
35    }
36
37    pub fn from_request(req: &Request) -> Self {
38        Self {
39            method: req.method().clone(),
40            uri: req.uri().clone(),
41            request_id: crate::request::request_id_from_extensions(req.extensions()),
42            auth_identity: req.extensions().get::<Arc<AuthIdentity>>().cloned(),
43        }
44    }
45
46    pub fn is_authenticated(&self) -> bool {
47        self.auth_identity.is_some()
48    }
49
50    pub fn has_role(&self, role: &str) -> bool {
51        self.auth_identity
52            .as_ref()
53            .map(|identity| identity.has_role(role))
54            .unwrap_or(false)
55    }
56}
57
58pub trait ExceptionFilter: Send + Sync + 'static {
59    fn catch(&self, exception: HttpException, ctx: &RequestContext) -> HttpException;
60}
61
62pub trait Guard: Send + Sync + 'static {
63    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException>;
64}
65
66#[derive(Default)]
67pub struct RequireAuthenticationGuard;
68
69impl Guard for RequireAuthenticationGuard {
70    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
71        if ctx.is_authenticated() {
72            Ok(())
73        } else {
74            Err(HttpException::unauthorized("Authentication required"))
75        }
76    }
77}
78
79pub struct RoleRequirementsGuard {
80    roles: Vec<String>,
81}
82
83impl RoleRequirementsGuard {
84    pub fn new<I, S>(roles: I) -> Self
85    where
86        I: IntoIterator<Item = S>,
87        S: Into<String>,
88    {
89        Self {
90            roles: roles.into_iter().map(Into::into).collect(),
91        }
92    }
93}
94
95impl Guard for RoleRequirementsGuard {
96    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
97        if !ctx.is_authenticated() {
98            return Err(HttpException::unauthorized("Authentication required"));
99        }
100
101        if self.roles.iter().any(|role| ctx.has_role(role)) {
102            Ok(())
103        } else {
104            Err(HttpException::forbidden(format!(
105                "Missing required role. Expected one of: {}",
106                self.roles.join(", ")
107            )))
108        }
109    }
110}
111
112pub type NextFuture = Pin<Box<dyn Future<Output = Response> + Send>>;
113pub type NextFn = Arc<dyn Fn(Request<Body>) -> NextFuture + Send + Sync + 'static>;
114
115pub trait Interceptor: Send + Sync + 'static {
116    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture;
117}
118
119pub fn run_guards(guards: &[Arc<dyn Guard>], ctx: &RequestContext) -> Result<(), HttpException> {
120    for guard in guards {
121        guard.can_activate(ctx)?;
122    }
123    Ok(())
124}
125
126fn next_to_fn(next: Next) -> NextFn {
127    let next = Arc::new(Mutex::new(Some(next)));
128
129    Arc::new(move |req: Request<Body>| {
130        let next = Arc::clone(&next);
131        Box::pin(async move {
132            let next = {
133                let mut guard = match next.lock() {
134                    Ok(guard) => guard,
135                    Err(_) => {
136                        return HttpException::internal_server_error("Pipeline lock poisoned")
137                            .into_response();
138                    }
139                };
140                guard.take()
141            };
142
143            match next {
144                Some(next) => next.run(req).await,
145                std::option::Option::None => crate::HttpException::internal_server_error(
146                    "Pipeline next called multiple times",
147                )
148                .into_response(),
149            }
150        })
151    })
152}
153
154fn run_interceptor_chain(
155    interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
156    index: usize,
157    ctx: RequestContext,
158    req: Request<Body>,
159    terminal: NextFn,
160) -> NextFuture {
161    if index >= interceptors.len() {
162        return terminal(req);
163    }
164
165    let current = Arc::clone(&interceptors[index]);
166    let interceptors_for_next = Arc::clone(&interceptors);
167    let ctx_for_next = ctx.clone();
168    let terminal_for_next = Arc::clone(&terminal);
169
170    let next_fn: NextFn = Arc::new(move |next_req: Request<Body>| {
171        run_interceptor_chain(
172            Arc::clone(&interceptors_for_next),
173            index + 1,
174            ctx_for_next.clone(),
175            next_req,
176            Arc::clone(&terminal_for_next),
177        )
178    });
179
180    current.around(ctx, req, next_fn)
181}
182
183pub async fn execute_pipeline(
184    req: Request<Body>,
185    next: Next,
186    guards: Arc<Vec<Arc<dyn Guard>>>,
187    interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
188    filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
189) -> Response {
190    let ctx = RequestContext::from_request(&req);
191
192    if let Err(err) = run_guards(guards.as_slice(), &ctx) {
193        return apply_exception_filters(err, &ctx, filters.as_slice()).into_response();
194    }
195
196    let terminal = next_to_fn(next);
197    run_interceptor_chain(interceptors, 0, ctx, req, terminal).await
198}
199
200pub fn apply_exception_filters(
201    mut exception: HttpException,
202    ctx: &RequestContext,
203    filters: &[Arc<dyn ExceptionFilter>],
204) -> HttpException {
205    for filter in filters {
206        exception = filter.catch(exception, ctx);
207    }
208
209    exception
210}
211
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use axum::http::Method;

    use crate::{AuthIdentity, Guard};

    use super::{RequestContext, RequireAuthenticationGuard, RoleRequirementsGuard};

    /// Builds a `GET /` context carrying the given identity, if any.
    fn context_with(auth_identity: Option<Arc<AuthIdentity>>) -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity,
        }
    }

    /// Context without any identity attached.
    fn anonymous_context() -> RequestContext {
        context_with(None)
    }

    /// Context for identity `user-1` holding exactly `roles`.
    fn authenticated_context(roles: &[&str]) -> RequestContext {
        let identity = AuthIdentity::new("user-1").with_roles(roles.iter().copied());

        context_with(Some(Arc::new(identity)))
    }

    #[test]
    fn authentication_guard_rejects_anonymous_requests() {
        let guard = RequireAuthenticationGuard;

        let anonymous = guard.can_activate(&anonymous_context());
        let signed_in = guard.can_activate(&authenticated_context(&[]));

        assert!(anonymous.is_err());
        assert!(signed_in.is_ok());
    }

    #[test]
    fn role_guard_accepts_any_matching_role() {
        let guard = RoleRequirementsGuard::new(["admin", "support"]);

        let matching = guard.can_activate(&authenticated_context(&["support"]));
        let non_matching = guard.can_activate(&authenticated_context(&["viewer"]));
        let anonymous = guard.can_activate(&anonymous_context());

        assert!(matching.is_ok());
        assert!(non_matching.is_err());
        assert!(anonymous.is_err());
    }
}