Skip to main content

axum_acl/
middleware.rs

//! ACL middleware implementation for axum.
//!
//! This module provides the [`AclLayer`] and [`AclMiddleware`] types that
//! integrate with axum's middleware system, as well as the generic
//! [`GenericAclLayer`] and [`GenericAclMiddleware`] for custom auth types.
6
7use crate::error::{AccessDenied, AccessDeniedHandler, DefaultDeniedHandler};
8use crate::extractor::{
9    AuthExtractor, AuthResult, HeaderIdExtractor, HeaderRoleExtractor, IdExtractor, RoleExtractor,
10};
11use crate::rule::{AclAction, BitmaskAuth, RequestMeta};
12use crate::table::AclTable;
13
14use axum::extract::ConnectInfo;
15use axum::response::Response;
16use futures_util::future::BoxFuture;
17use http::{Request, StatusCode};
18use http_body::Body;
19use std::collections::HashMap;
20use std::net::{IpAddr, SocketAddr};
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use tower::{Layer, Service};
24
// ============================================================================
// Legacy Middleware (BitmaskAuth via RoleExtractor + IdExtractor)
// ============================================================================
28
29/// Configuration for the ACL middleware.
30pub struct AclConfig<E, I> {
31    /// The ACL table containing the rules.
32    pub table: Arc<AclTable>,
33    /// The role extractor.
34    pub role_extractor: Arc<E>,
35    /// The ID extractor.
36    pub id_extractor: Arc<I>,
37    /// The handler for access denied responses.
38    pub denied_handler: Arc<dyn AccessDeniedHandler>,
39    /// The roles bitmask to use for anonymous users.
40    pub anonymous_roles: u32,
41    /// Header to check for forwarded IP (e.g., X-Forwarded-For).
42    pub forwarded_ip_header: Option<String>,
43    /// Default ID when ID extractor returns anonymous.
44    pub default_id: String,
45}
46
47// Manual Clone impl to avoid requiring E/I: Clone (since they're behind Arc)
48impl<E, I> Clone for AclConfig<E, I> {
49    fn clone(&self) -> Self {
50        Self {
51            table: self.table.clone(),
52            role_extractor: self.role_extractor.clone(),
53            id_extractor: self.id_extractor.clone(),
54            denied_handler: self.denied_handler.clone(),
55            anonymous_roles: self.anonymous_roles,
56            forwarded_ip_header: self.forwarded_ip_header.clone(),
57            default_id: self.default_id.clone(),
58        }
59    }
60}
61
62/// A Tower layer that adds ACL middleware to a service.
63///
64/// # Example
65/// ```no_run
66/// use axum::{Router, routing::get};
67/// use axum_acl::{AclLayer, AclTable, AclRuleFilter, AclAction};
68/// use std::net::SocketAddr;
69///
70/// async fn handler() -> &'static str {
71///     "Hello, World!"
72/// }
73///
74/// #[tokio::main]
75/// async fn main() {
76///     let acl_table = AclTable::builder()
77///         .default_action(AclAction::Deny)
78///         .add_any(AclRuleFilter::new()
79///             .role_mask(0b1)  // admin role
80///             .action(AclAction::Allow))
81///         .build();
82///
83///     let app = Router::new()
84///         .route("/", get(handler))
85///         .layer(AclLayer::new(acl_table));
86///
87///     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
88///     axum::serve(
89///         listener,
90///         app.into_make_service_with_connect_info::<SocketAddr>()
91///     ).await.unwrap();
92/// }
93/// ```
94#[derive(Clone)]
95pub struct AclLayer<E, I> {
96    config: AclConfig<E, I>,
97}
98
99impl AclLayer<HeaderRoleExtractor, HeaderIdExtractor> {
100    /// Create a new ACL layer with the given table.
101    ///
102    /// Uses the default header role extractor (`X-Roles` header),
103    /// default header ID extractor (`X-User-Id` header), and
104    /// default denied handler (plain text 403).
105    pub fn new(table: AclTable) -> Self {
106        Self {
107            config: AclConfig {
108                table: Arc::new(table),
109                role_extractor: Arc::new(HeaderRoleExtractor::new("X-Roles")),
110                id_extractor: Arc::new(HeaderIdExtractor::new("X-User-Id")),
111                denied_handler: Arc::new(DefaultDeniedHandler),
112                anonymous_roles: 0,
113                forwarded_ip_header: None,
114                default_id: "*".to_string(),
115            },
116        }
117    }
118}
119
120impl<E, I> AclLayer<E, I> {
121    /// Create a new ACL layer with a custom role extractor.
122    ///
123    /// # Example
124    /// ```
125    /// use axum_acl::{AclLayer, AclTable, HeaderRoleExtractor};
126    ///
127    /// let table = AclTable::new();
128    /// let layer = AclLayer::new(table)
129    ///     .with_role_extractor(HeaderRoleExtractor::new("X-User-Roles"));
130    /// ```
131    pub fn with_role_extractor<E2>(self, extractor: E2) -> AclLayer<E2, I> {
132        AclLayer {
133            config: AclConfig {
134                table: self.config.table,
135                role_extractor: Arc::new(extractor),
136                id_extractor: self.config.id_extractor,
137                denied_handler: self.config.denied_handler,
138                anonymous_roles: self.config.anonymous_roles,
139                forwarded_ip_header: self.config.forwarded_ip_header,
140                default_id: self.config.default_id,
141            },
142        }
143    }
144
145    /// Create a new ACL layer with a custom ID extractor.
146    ///
147    /// # Example
148    /// ```
149    /// use axum_acl::{AclLayer, AclTable, HeaderIdExtractor};
150    ///
151    /// let table = AclTable::new();
152    /// let layer = AclLayer::new(table)
153    ///     .with_id_extractor(HeaderIdExtractor::new("X-User-Id"));
154    /// ```
155    pub fn with_id_extractor<I2>(self, extractor: I2) -> AclLayer<E, I2> {
156        AclLayer {
157            config: AclConfig {
158                table: self.config.table,
159                role_extractor: self.config.role_extractor,
160                id_extractor: Arc::new(extractor),
161                denied_handler: self.config.denied_handler,
162                anonymous_roles: self.config.anonymous_roles,
163                forwarded_ip_header: self.config.forwarded_ip_header,
164                default_id: self.config.default_id,
165            },
166        }
167    }
168
169    /// Create a new ACL layer with a custom role extractor.
170    #[deprecated(since = "0.2.0", note = "Use with_role_extractor instead")]
171    pub fn with_extractor<E2>(self, extractor: E2) -> AclLayer<E2, I> {
172        self.with_role_extractor(extractor)
173    }
174
175    /// Set a custom access denied handler.
176    pub fn with_denied_handler(mut self, handler: impl AccessDeniedHandler + 'static) -> Self {
177        self.config.denied_handler = Arc::new(handler);
178        self
179    }
180
181    /// Set the roles bitmask to use for anonymous/unauthenticated users.
182    pub fn with_anonymous_roles(mut self, roles: u32) -> Self {
183        self.config.anonymous_roles = roles;
184        self
185    }
186
187    /// Set a header to extract the client IP from (e.g., X-Forwarded-For).
188    ///
189    /// When behind a reverse proxy, the client IP may be in a header.
190    /// This setting tells the middleware which header to check.
191    pub fn with_forwarded_ip_header(mut self, header: impl Into<String>) -> Self {
192        self.config.forwarded_ip_header = Some(header.into());
193        self
194    }
195
196    /// Set the default ID to use when the ID extractor returns anonymous.
197    pub fn with_default_id(mut self, id: impl Into<String>) -> Self {
198        self.config.default_id = id.into();
199        self
200    }
201
202    /// Get a reference to the ACL table.
203    pub fn table(&self) -> &AclTable {
204        &self.config.table
205    }
206}
207
208impl<S, E: Clone, I: Clone> Layer<S> for AclLayer<E, I> {
209    type Service = AclMiddleware<S, E, I>;
210
211    fn layer(&self, inner: S) -> Self::Service {
212        AclMiddleware {
213            inner,
214            config: self.config.clone(),
215        }
216    }
217}
218
219/// The ACL middleware service.
220#[derive(Clone)]
221pub struct AclMiddleware<S, E, I> {
222    inner: S,
223    config: AclConfig<E, I>,
224}
225
226impl<S, E, I, ReqBody, ResBody> Service<Request<ReqBody>> for AclMiddleware<S, E, I>
227where
228    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
229    S::Future: Send,
230    E: RoleExtractor<ReqBody> + 'static,
231    I: IdExtractor<ReqBody> + 'static,
232    ReqBody: Body + Send + 'static,
233    ResBody: Body + Default + Send + 'static,
234{
235    type Response = Response<ResBody>;
236    type Error = S::Error;
237    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
238
239    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240        self.inner.poll_ready(cx)
241    }
242
243    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
244        let config = self.config.clone();
245        let mut inner = self.inner.clone();
246
247        let role_result = config.role_extractor.extract_roles(&request);
248        let roles = role_result.roles_or(config.anonymous_roles);
249
250        let client_ip = extract_client_ip(&request, config.forwarded_ip_header.as_deref());
251
252        let id_result = config.id_extractor.extract_id(&request);
253        let id = id_result.id_or(&config.default_id);
254
255        let method = request.method().clone();
256        let path = request.uri().path().to_string();
257
258        Box::pin(async move {
259            let Some(client_ip) = client_ip else {
260                tracing::warn!("Failed to extract client IP address");
261                let response = Response::builder()
262                    .status(StatusCode::INTERNAL_SERVER_ERROR)
263                    .body(ResBody::default())
264                    .unwrap();
265                return Ok(response);
266            };
267
268            let auth = BitmaskAuth {
269                roles,
270                id: id.clone(),
271            };
272            let meta = RequestMeta {
273                method,
274                path: path.clone(),
275                path_params: HashMap::new(),
276                ip: client_ip,
277            };
278
279            let action = config.table.evaluate_request(&auth, &meta);
280
281            handle_action(action, &path, &id, roles, client_ip, &config.denied_handler, request, &mut inner).await
282        })
283    }
284}
285
// ============================================================================
// Generic Middleware (custom auth type via AuthExtractor)
// ============================================================================
289
290/// Configuration for the generic ACL middleware.
291pub struct GenericAclConfig<A, X> {
292    /// The ACL table containing the rules.
293    pub table: Arc<AclTable<A>>,
294    /// The auth extractor.
295    pub auth_extractor: Arc<X>,
296    /// The handler for access denied responses.
297    pub denied_handler: Arc<dyn AccessDeniedHandler>,
298    /// Header to check for forwarded IP (e.g., X-Forwarded-For).
299    pub forwarded_ip_header: Option<String>,
300}
301
302impl<A, X> Clone for GenericAclConfig<A, X> {
303    fn clone(&self) -> Self {
304        Self {
305            table: self.table.clone(),
306            auth_extractor: self.auth_extractor.clone(),
307            denied_handler: self.denied_handler.clone(),
308            forwarded_ip_header: self.forwarded_ip_header.clone(),
309        }
310    }
311}
312
313/// A Tower layer for the generic ACL middleware.
314#[derive(Clone)]
315pub struct GenericAclLayer<A, X> {
316    config: GenericAclConfig<A, X>,
317}
318
319impl<A, X> GenericAclLayer<A, X> {
320    /// Create a new generic ACL layer.
321    pub fn with_auth(table: AclTable<A>, extractor: X) -> Self {
322        Self {
323            config: GenericAclConfig {
324                table: Arc::new(table),
325                auth_extractor: Arc::new(extractor),
326                denied_handler: Arc::new(DefaultDeniedHandler),
327                forwarded_ip_header: None,
328            },
329        }
330    }
331
332    /// Set a custom access denied handler.
333    pub fn with_denied_handler(
334        mut self,
335        handler: impl AccessDeniedHandler + 'static,
336    ) -> Self {
337        self.config.denied_handler = Arc::new(handler);
338        self
339    }
340
341    /// Set a header to extract the client IP from.
342    pub fn with_forwarded_ip_header(mut self, header: impl Into<String>) -> Self {
343        self.config.forwarded_ip_header = Some(header.into());
344        self
345    }
346}
347
348impl<S, A: Clone, X: Clone> Layer<S> for GenericAclLayer<A, X> {
349    type Service = GenericAclMiddleware<S, A, X>;
350
351    fn layer(&self, inner: S) -> Self::Service {
352        GenericAclMiddleware {
353            inner,
354            config: self.config.clone(),
355        }
356    }
357}
358
359/// The generic ACL middleware service.
360#[derive(Clone)]
361pub struct GenericAclMiddleware<S, A, X> {
362    inner: S,
363    config: GenericAclConfig<A, X>,
364}
365
366impl<S, A, X, ReqBody, ResBody> Service<Request<ReqBody>> for GenericAclMiddleware<S, A, X>
367where
368    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
369    S::Future: Send,
370    A: Send + Sync + 'static,
371    X: AuthExtractor<A, ReqBody> + 'static,
372    ReqBody: Body + Send + 'static,
373    ResBody: Body + Default + Send + 'static,
374{
375    type Response = Response<ResBody>;
376    type Error = S::Error;
377    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
378
379    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
380        self.inner.poll_ready(cx)
381    }
382
383    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
384        let config = self.config.clone();
385        let mut inner = self.inner.clone();
386
387        let auth_result = config.auth_extractor.extract_auth(&request);
388        let client_ip = extract_client_ip(&request, config.forwarded_ip_header.as_deref());
389        let method = request.method().clone();
390        let path = request.uri().path().to_string();
391
392        Box::pin(async move {
393            let Some(client_ip) = client_ip else {
394                tracing::warn!("Failed to extract client IP address");
395                let response = Response::builder()
396                    .status(StatusCode::INTERNAL_SERVER_ERROR)
397                    .body(ResBody::default())
398                    .unwrap();
399                return Ok(response);
400            };
401
402            let meta = RequestMeta {
403                method,
404                path: path.clone(),
405                path_params: HashMap::new(),
406                ip: client_ip,
407            };
408
409            let action = match auth_result {
410                AuthResult::Auth(auth) => config.table.evaluate_request(&auth, &meta),
411                AuthResult::Anonymous => config.table.default_action(),
412                AuthResult::Error(e) => {
413                    tracing::warn!(error = %e, "Auth extraction failed");
414                    AclAction::Deny
415                }
416            };
417
418            handle_action(action, &path, "*", 0, client_ip, &config.denied_handler, request, &mut inner).await
419        })
420    }
421}
422
// ============================================================================
// Shared Helpers
// ============================================================================
426
427/// Handle an ACL action, producing the appropriate response.
428async fn handle_action<S, ReqBody, ResBody>(
429    action: AclAction,
430    path: &str,
431    id: &str,
432    roles: u32,
433    client_ip: IpAddr,
434    denied_handler: &Arc<dyn AccessDeniedHandler>,
435    request: Request<ReqBody>,
436    inner: &mut S,
437) -> Result<Response<ResBody>, S::Error>
438where
439    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
440    S::Future: Send,
441    ResBody: Body + Default + Send + 'static,
442{
443    match action {
444        AclAction::Allow => {
445            tracing::trace!(
446                path = %path,
447                ip = %client_ip,
448                "ACL allowed request"
449            );
450            inner.call(request).await
451        }
452        AclAction::Deny => {
453            tracing::info!(
454                path = %path,
455                ip = %client_ip,
456                "ACL denied request"
457            );
458
459            let denied = AccessDenied::new_with_roles(roles, path, id);
460            let response = denied_handler.handle(&denied);
461            let (parts, _body) = response.into_parts();
462            let response = Response::from_parts(parts, ResBody::default());
463            Ok(response)
464        }
465        AclAction::Error { code, ref message } => {
466            tracing::info!(
467                path = %path,
468                ip = %client_ip,
469                code = code,
470                message = ?message,
471                "ACL returned error"
472            );
473
474            let status = StatusCode::from_u16(code).unwrap_or(StatusCode::FORBIDDEN);
475            let response = Response::builder()
476                .status(status)
477                .header("content-type", "text/plain")
478                .body(ResBody::default())
479                .unwrap();
480            Ok(response)
481        }
482        AclAction::Reroute {
483            ref target,
484            preserve_path,
485        } => {
486            tracing::info!(
487                path = %path,
488                ip = %client_ip,
489                target = %target,
490                "ACL rerouting request"
491            );
492
493            let mut response = Response::builder()
494                .status(StatusCode::TEMPORARY_REDIRECT)
495                .header("location", target.as_str())
496                .body(ResBody::default())
497                .unwrap();
498
499            if preserve_path {
500                response.headers_mut().insert(
501                    "x-original-path",
502                    path.parse().unwrap_or_else(|_| "/".parse().unwrap()),
503                );
504            }
505
506            Ok(response)
507        }
508        AclAction::RateLimit {
509            max_requests,
510            window_secs,
511        } => {
512            tracing::warn!(
513                path = %path,
514                ip = %client_ip,
515                max_requests = max_requests,
516                window_secs = window_secs,
517                "ACL rate limit action - not implemented, allowing request"
518            );
519            inner.call(request).await
520        }
521        AclAction::Log {
522            ref level,
523            ref message,
524        } => {
525            let msg = message.clone().unwrap_or_else(|| {
526                format!("ACL log: path={}, ip={}", path, client_ip)
527            });
528
529            match level.as_str() {
530                "trace" => tracing::trace!("{}", msg),
531                "debug" => tracing::debug!("{}", msg),
532                "warn" => tracing::warn!("{}", msg),
533                "error" => tracing::error!("{}", msg),
534                _ => tracing::info!("{}", msg),
535            }
536
537            inner.call(request).await
538        }
539    }
540}
541
542/// Extract the client IP address from the request.
543fn extract_client_ip<B>(request: &Request<B>, forwarded_header: Option<&str>) -> Option<IpAddr> {
544    if let Some(header_name) = forwarded_header {
545        if let Some(value) = request.headers().get(header_name) {
546            if let Ok(s) = value.to_str() {
547                if let Some(first_ip) = s.split(',').next() {
548                    if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
549                        return Some(ip);
550                    }
551                }
552            }
553        }
554    }
555
556    request
557        .extensions()
558        .get::<ConnectInfo<SocketAddr>>()
559        .map(|ci| ci.0.ip())
560}
561
#[cfg(test)]
mod tests {
    // Intentionally empty: the middleware is exercised by the integration
    // examples under `examples/`. Unit tests here would need a mock of axum's
    // body type.
}