// nestforge_core/pipeline.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{Arc, Mutex},
5};
6
7use axum::{
8    body::Body,
9    extract::Request,
10    http::{request::Parts, Method, Uri},
11    middleware::Next,
12    response::IntoResponse,
13    response::Response,
14};
15
16use crate::AuthIdentity;
17use crate::HttpException;
18
19/// Context available to all guards, interceptors, and pipes in the request lifecycle.
20///
21/// It provides access to basic request metadata (method, URI) and authentication state.
22/// This is lightweight and cloned cheaply (mostly `Arc`s or small types).
23#[derive(Clone, Debug)]
24pub struct RequestContext {
25    pub method: Method,
26    pub uri: Uri,
27    pub request_id: Option<String>,
28    pub auth_identity: Option<Arc<AuthIdentity>>,
29}
30
31impl RequestContext {
32    pub fn from_parts(parts: &Parts) -> Self {
33        Self {
34            method: parts.method.clone(),
35            uri: parts.uri.clone(),
36            request_id: crate::request::request_id_from_extensions(&parts.extensions),
37            auth_identity: parts.extensions.get::<Arc<AuthIdentity>>().cloned(),
38        }
39    }
40
41    pub fn from_request(req: &Request) -> Self {
42        Self {
43            method: req.method().clone(),
44            uri: req.uri().clone(),
45            request_id: crate::request::request_id_from_extensions(req.extensions()),
46            auth_identity: req.extensions().get::<Arc<AuthIdentity>>().cloned(),
47        }
48    }
49
50    pub fn is_authenticated(&self) -> bool {
51        self.auth_identity.is_some()
52    }
53
54    pub fn has_role(&self, role: &str) -> bool {
55        self.auth_identity
56            .as_ref()
57            .map(|identity| identity.has_role(role))
58            .unwrap_or(false)
59    }
60}
61
62/// A filter that catches exceptions thrown during request processing.
63///
64/// Use this to transform `HttpException`s into different responses or to log errors.
65///
66/// # Example
67/// ```rust
68/// struct LoggingFilter;
69/// impl ExceptionFilter for LoggingFilter {
70///     fn catch(&self, exception: HttpException, ctx: &RequestContext) -> HttpException {
71///         println!("Error: {:?}", exception);
72///         exception
73///     }
74/// }
75/// ```
76pub trait ExceptionFilter: Send + Sync + 'static {
77    fn catch(&self, exception: HttpException, ctx: &RequestContext) -> HttpException;
78}
79
80/// A guard determines whether a request should be allowed to proceed.
81///
82/// Guards run before any interceptors or the route handler. If a guard returns `Err`,
83/// the request is rejected immediately.
84///
85/// Common uses: Authentication, Authorization, Rate Limiting.
86pub trait Guard: Send + Sync + 'static {
87    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException>;
88}
89
90/// A built-in guard that ensures the user is authenticated.
91///
92/// It checks `ctx.is_authenticated()`. If false, it throws `401 Unauthorized`.
93#[derive(Default)]
94pub struct RequireAuthenticationGuard;
95
96impl Guard for RequireAuthenticationGuard {
97    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
98        if ctx.is_authenticated() {
99            Ok(())
100        } else {
101            Err(HttpException::unauthorized("Authentication required"))
102        }
103    }
104}
105
106/// A built-in guard that checks if the authenticated user has specific roles.
107///
108/// It requires authentication first, then checks `ctx.has_role(role)`.
109pub struct RoleRequirementsGuard {
110    roles: Vec<String>,
111}
112
113impl RoleRequirementsGuard {
114    pub fn new<I, S>(roles: I) -> Self
115    where
116        I: IntoIterator<Item = S>,
117        S: Into<String>,
118    {
119        Self {
120            roles: roles.into_iter().map(Into::into).collect(),
121        }
122    }
123}
124
125impl Guard for RoleRequirementsGuard {
126    fn can_activate(&self, ctx: &RequestContext) -> Result<(), HttpException> {
127        if !ctx.is_authenticated() {
128            return Err(HttpException::unauthorized("Authentication required"));
129        }
130
131        if self.roles.iter().any(|role| ctx.has_role(role)) {
132            Ok(())
133        } else {
134            Err(HttpException::forbidden(format!(
135                "Missing required role. Expected one of: {}",
136                self.roles.join(", ")
137            )))
138        }
139    }
140}
141
142pub type NextFuture = Pin<Box<dyn Future<Output = Response> + Send>>;
143pub type NextFn = Arc<dyn Fn(Request<Body>) -> NextFuture + Send + Sync + 'static>;
144
145/// An interceptor wraps the request/response cycle.
146///
147/// It can execute logic *before* the handler runs (by inspecting the request)
148/// and *after* the handler returns (by transforming the response).
149///
150/// # Example
151/// ```rust
152/// struct LoggingInterceptor;
153/// impl Interceptor for LoggingInterceptor {
154///     fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture {
155///         Box::pin(async move {
156///             println!("Before...");
157///             let res = next(req).await;
158///             println!("After...");
159///             res
160///         })
161///     }
162/// }
163/// ```
164pub trait Interceptor: Send + Sync + 'static {
165    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture;
166}
167
168pub fn run_guards(guards: &[Arc<dyn Guard>], ctx: &RequestContext) -> Result<(), HttpException> {
169    for guard in guards {
170        guard.can_activate(ctx)?;
171    }
172    Ok(())
173}
174
175fn next_to_fn(next: Next) -> NextFn {
176    /*
177    We wrap `Next` in an Arc<Mutex> because axum's `Next` is one-shot, but
178    our interceptor chain needs to be clonable/re-entrant in terms of API structure
179    (though it's still fundamentally linear).
180    */
181    let next = Arc::new(Mutex::new(Some(next)));
182
183    Arc::new(move |req: Request<Body>| {
184        let next = Arc::clone(&next);
185        Box::pin(async move {
186            let next = {
187                let mut guard = match next.lock() {
188                    Ok(guard) => guard,
189                    Err(_) => {
190                        return HttpException::internal_server_error("Pipeline lock poisoned")
191                            .into_response();
192                    }
193                };
194                guard.take()
195            };
196
197            match next {
198                Some(next) => next.run(req).await,
199                std::option::Option::None => crate::HttpException::internal_server_error(
200                    "Pipeline next called multiple times",
201                )
202                .into_response(),
203            }
204        })
205    })
206}
207
208fn run_interceptor_chain(
209    interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
210    index: usize,
211    ctx: RequestContext,
212    req: Request<Body>,
213    terminal: NextFn,
214) -> NextFuture {
215    if index >= interceptors.len() {
216        return terminal(req);
217    }
218
219    let current = Arc::clone(&interceptors[index]);
220    let interceptors_for_next = Arc::clone(&interceptors);
221    let ctx_for_next = ctx.clone();
222    let terminal_for_next = Arc::clone(&terminal);
223
224    let next_fn: NextFn = Arc::new(move |next_req: Request<Body>| {
225        run_interceptor_chain(
226            Arc::clone(&interceptors_for_next),
227            index + 1,
228            ctx_for_next.clone(),
229            next_req,
230            Arc::clone(&terminal_for_next),
231        )
232    });
233
234    current.around(ctx, req, next_fn)
235}
236
237/// Executes the full NestForge request pipeline.
238///
239/// 1. Runs all **Guards**.
240/// 2. If successful, runs the **Interceptor Chain**.
241/// 3. If any step fails (guards or handler), runs **Exception Filters**.
242pub async fn execute_pipeline(
243    req: Request<Body>,
244    next: Next,
245    guards: Arc<Vec<Arc<dyn Guard>>>,
246    interceptors: Arc<Vec<Arc<dyn Interceptor>>>,
247    filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
248) -> Response {
249    let ctx = RequestContext::from_request(&req);
250
251    if let Err(err) = run_guards(guards.as_slice(), &ctx) {
252        return apply_exception_filters(err, &ctx, filters.as_slice()).into_response();
253    }
254
255    let terminal = next_to_fn(next);
256    run_interceptor_chain(interceptors, 0, ctx, req, terminal).await
257}
258
259pub fn apply_exception_filters(
260    mut exception: HttpException,
261    ctx: &RequestContext,
262    filters: &[Arc<dyn ExceptionFilter>],
263) -> HttpException {
264    for filter in filters {
265        exception = filter.catch(exception, ctx);
266    }
267
268    exception
269}
270
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use axum::http::Method;

    use crate::{AuthIdentity, Guard, HttpException};

    use super::{run_guards, RequestContext, RequireAuthenticationGuard, RoleRequirementsGuard};

    fn anonymous_context() -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity: None,
        }
    }

    fn authenticated_context(roles: &[&str]) -> RequestContext {
        RequestContext {
            method: Method::GET,
            uri: "/".parse().expect("uri should parse"),
            request_id: None,
            auth_identity: Some(Arc::new(
                AuthIdentity::new("user-1").with_roles(roles.iter().copied()),
            )),
        }
    }

    /// Test guard that always admits the request and counts how often it was consulted.
    struct CountingGuard(Arc<AtomicUsize>);

    impl Guard for CountingGuard {
        fn can_activate(&self, _ctx: &RequestContext) -> Result<(), HttpException> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[test]
    fn authentication_guard_rejects_anonymous_requests() {
        let guard = RequireAuthenticationGuard;

        assert!(guard.can_activate(&anonymous_context()).is_err());
        assert!(guard.can_activate(&authenticated_context(&[])).is_ok());
    }

    #[test]
    fn role_guard_accepts_any_matching_role() {
        let guard = RoleRequirementsGuard::new(["admin", "support"]);

        assert!(guard
            .can_activate(&authenticated_context(&["support"]))
            .is_ok());
        assert!(guard
            .can_activate(&authenticated_context(&["viewer"]))
            .is_err());
        assert!(guard.can_activate(&anonymous_context()).is_err());
    }

    #[test]
    fn context_reports_identity_and_roles() {
        let anonymous = anonymous_context();
        assert!(!anonymous.is_authenticated());
        assert!(!anonymous.has_role("admin"));

        let admin = authenticated_context(&["admin"]);
        assert!(admin.is_authenticated());
        assert!(admin.has_role("admin"));
        assert!(!admin.has_role("viewer"));
    }

    #[test]
    fn run_guards_stops_at_first_rejection() {
        let calls = Arc::new(AtomicUsize::new(0));
        let guards: Vec<Arc<dyn Guard>> = vec![
            Arc::new(RequireAuthenticationGuard),
            Arc::new(CountingGuard(Arc::clone(&calls))),
        ];

        // The first guard rejects, so the second must never be consulted.
        assert!(run_guards(&guards, &anonymous_context()).is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // Both guards pass, and each is consulted exactly once.
        assert!(run_guards(&guards, &authenticated_context(&[])).is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // An empty guard list admits everything.
        assert!(run_guards(&[], &anonymous_context()).is_ok());
    }
}