// allframe_core/auth/axum.rs

//! Axum integration for authentication.
//!
//! Provides extractors and middleware for using AllFrame auth with Axum.
//!
//! # Example
//!
//! ```rust,ignore
//! use allframe_core::auth::{AuthenticatedUser, AuthLayer, JwtValidator, JwtConfig};
//! use axum::{Router, routing::get, Extension};
//!
//! #[derive(Clone, serde::Deserialize)]
//! struct Claims {
//!     sub: String,
//!     role: String,
//! }
//!
//! async fn protected_handler(
//!     AuthenticatedUser(claims): AuthenticatedUser<Claims>,
//! ) -> String {
//!     format!("Hello, {}!", claims.sub)
//! }
//!
//! // Setup
//! let validator = JwtValidator::<Claims>::new(JwtConfig::hs256("secret"));
//!
//! let app = Router::new()
//!     .route("/protected", get(protected_handler))
//!     .layer(AuthLayer::new(validator));
//! ```

use std::{
    future::Future,
    marker::PhantomData,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use super::{extract_bearer_token, AuthContext, AuthError, Authenticator};

/// Wrapper around the claims of an authenticated request.
///
/// Intended to be taken as a handler argument: the bearer token from the
/// `Authorization` header is validated and the resulting claims are exposed
/// through [`AuthenticatedUser::claims`], [`AuthenticatedUser::into_inner`],
/// the public tuple field, or `Deref`.
///
/// NOTE(review): this file only defines the newtype. It contains no axum
/// `FromRequestParts` impl that would perform the extraction described above.
/// Confirm that such an impl exists (or add one) before relying on the
/// handler-argument examples below.
///
/// # Example
///
/// ```rust,ignore
/// use allframe_core::auth::AuthenticatedUser;
///
/// async fn handler(AuthenticatedUser(claims): AuthenticatedUser<MyClaims>) -> String {
///     format!("User ID: {}", claims.sub)
/// }
/// ```
///
/// # Extracting Optional Auth
///
/// Wrap in `Option` to make auth optional:
///
/// ```rust,ignore
/// async fn handler(auth: Option<AuthenticatedUser<MyClaims>>) -> String {
///     match auth {
///         Some(AuthenticatedUser(claims)) => format!("Hello, {}", claims.sub),
///         None => "Hello, anonymous!".to_string(),
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct AuthenticatedUser<C>(pub C);

impl<C> AuthenticatedUser<C> {
    /// Borrow the wrapped claims.
    pub fn claims(&self) -> &C {
        &self.0
    }

    /// Consume the wrapper and return the claims by value.
    pub fn into_inner(self) -> C {
        self.0
    }
}

/// Lets callers read claim fields directly, e.g. `user.sub`.
impl<C> std::ops::Deref for AuthenticatedUser<C> {
    type Target = C;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Layer that attaches authentication info to every request.
///
/// For each request the wrapped service reads the `Authorization` header,
/// extracts a bearer token, and checks it with the authenticator. On success
/// an [`AuthContext`] is inserted into the request extensions.
///
/// A missing, malformed, or invalid token is **not** rejected. The request is
/// forwarded unchanged, so handlers (or a later layer) must enforce that an
/// `AuthContext` is present.
///
/// The authenticator is held in an [`Arc`], so it does not need to implement
/// `Clone`.
///
/// # Example
///
/// ```rust,ignore
/// use allframe_core::auth::{AuthLayer, JwtValidator, JwtConfig};
///
/// let validator = JwtValidator::<MyClaims>::new(JwtConfig::hs256("secret"));
///
/// let app = Router::new()
///     .route("/protected", get(handler))
///     .layer(AuthLayer::new(validator));
/// ```
pub struct AuthLayer<A> {
    authenticator: Arc<A>,
}

// Hand-written instead of `#[derive(Clone)]`. The derive would add an
// `A: Clone` bound, but only the `Arc` handle is cloned here.
impl<A> Clone for AuthLayer<A> {
    fn clone(&self) -> Self {
        Self {
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}

impl<A> AuthLayer<A> {
    /// Create a new auth layer with the given authenticator.
    pub fn new(authenticator: A) -> Self {
        Self::from_arc(Arc::new(authenticator))
    }

    /// Create an auth layer from an authenticator that is already shared,
    /// e.g. with other layers or application state.
    pub fn from_arc(authenticator: Arc<A>) -> Self {
        Self { authenticator }
    }
}

121impl<S, A> tower::Layer<S> for AuthLayer<A>
122where
123    A: Clone,
124{
125    type Service = AuthService<S, A>;
126
127    fn layer(&self, inner: S) -> Self::Service {
128        AuthService {
129            inner,
130            authenticator: self.authenticator.clone(),
131        }
132    }
133}
134
/// Service that performs authentication.
///
/// Produced by [`AuthLayer`] and [`OptionalAuthLayer`]. It shares the layer's
/// authenticator through an [`Arc`].
pub struct AuthService<S, A> {
    inner: S,
    authenticator: Arc<A>,
}

// Hand-written so that only `S` must be `Clone`. The authenticator sits behind
// an `Arc`, which is cloneable for any `A`; `#[derive(Clone)]` would also
// demand `A: Clone`.
impl<S: Clone, A> Clone for AuthService<S, A> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}

142impl<S, A, ReqBody> tower::Service<hyper::Request<ReqBody>> for AuthService<S, A>
143where
144    S: tower::Service<hyper::Request<ReqBody>> + Clone + Send + 'static,
145    S::Future: Send,
146    A: Authenticator + 'static,
147    ReqBody: Send + 'static,
148{
149    type Response = S::Response;
150    type Error = S::Error;
151    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
152
153    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154        self.inner.poll_ready(cx)
155    }
156
157    fn call(&mut self, mut req: hyper::Request<ReqBody>) -> Self::Future {
158        let authenticator = self.authenticator.clone();
159        let mut inner = self.inner.clone();
160
161        Box::pin(async move {
162            // Extract token from Authorization header
163            if let Some(auth_header) = req.headers().get(hyper::header::AUTHORIZATION) {
164                if let Ok(header_str) = auth_header.to_str() {
165                    if let Some(token) = extract_bearer_token(header_str) {
166                        // Validate token
167                        if let Ok(claims) = authenticator.authenticate(token).await {
168                            let ctx = AuthContext::new(claims, token);
169                            req.extensions_mut().insert(ctx);
170                        }
171                    }
172                }
173            }
174
175            inner.call(req).await
176        })
177    }
178}
179
/// Optional auth layer that doesn't reject unauthenticated requests.
///
/// Use this when you want to allow both authenticated and unauthenticated
/// access, but still make auth info available when present.
///
/// Note: it wraps services in the same [`AuthService`] as [`AuthLayer`], which
/// never rejects requests either. The two layers currently behave identically
/// and differ only in the intent they express at the call site.
pub struct OptionalAuthLayer<A> {
    authenticator: Arc<A>,
}

// Hand-written instead of `#[derive(Clone)]`, which would add an unneeded
// `A: Clone` bound; only the `Arc` handle is cloned.
impl<A> Clone for OptionalAuthLayer<A> {
    fn clone(&self) -> Self {
        Self {
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}

impl<A> OptionalAuthLayer<A> {
    /// Create a new optional auth layer.
    pub fn new(authenticator: A) -> Self {
        Self::from_arc(Arc::new(authenticator))
    }

    /// Create an optional auth layer from an authenticator that is already shared.
    pub fn from_arc(authenticator: Arc<A>) -> Self {
        Self { authenticator }
    }
}

198impl<S, A> tower::Layer<S> for OptionalAuthLayer<A>
199where
200    A: Clone,
201{
202    type Service = AuthService<S, A>;
203
204    fn layer(&self, inner: S) -> Self::Service {
205        AuthService {
206            inner,
207            authenticator: self.authenticator.clone(),
208        }
209    }
210}
211
212/// Extension trait for extracting auth context from request extensions.
213pub trait AuthExt {
214    /// Get the auth context if present.
215    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>>;
216
217    /// Get the claims if authenticated.
218    fn claims<C: Clone + Send + Sync + 'static>(&self) -> Option<&C> {
219        self.auth_context::<C>().map(|ctx| &ctx.claims)
220    }
221}
222
223impl<B> AuthExt for hyper::Request<B> {
224    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>> {
225        self.extensions().get::<AuthContext<C>>()
226    }
227}
228
229/// Rejection type for authentication failures.
230#[derive(Debug)]
231pub struct AuthRejection {
232    /// The authentication error.
233    pub error: AuthError,
234}
235
236impl AuthRejection {
237    /// Create a new rejection from an auth error.
238    pub fn new(error: AuthError) -> Self {
239        Self { error }
240    }
241}
242
243impl std::fmt::Display for AuthRejection {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        write!(f, "{}", self.error)
246    }
247}
248
249impl std::error::Error for AuthRejection {}
250
// Note: a full `axum::response::IntoResponse` implementation would require
// axum as a dependency, so none is provided here. Because of Rust's orphan
// rule, downstream crates cannot implement that foreign trait directly for
// this (to them) foreign type either. Wrap the rejection in a local newtype
// instead:
//
// struct ApiAuthRejection(AuthRejection);
//
// impl axum::response::IntoResponse for ApiAuthRejection {
//     fn into_response(self) -> axum::response::Response {
//         let status = match self.0.error.status_code() {
//             401 => axum::http::StatusCode::UNAUTHORIZED,
//             403 => axum::http::StatusCode::FORBIDDEN,
//             _ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
//         };
//         (status, self.0.error.to_string()).into_response()
//     }
// }

/// Type-level marker selecting *mandatory* authentication.
///
/// Intended as the `R` parameter of [`Auth`] when a context must be present.
#[derive(Clone, Copy, Debug)]
pub struct Required;

/// Type-level marker selecting *optional* authentication.
///
/// Intended as the `R` parameter of [`Auth`] when a context may be absent.
#[derive(Clone, Copy, Debug)]
pub struct Optional;

277/// Generic auth extractor with configurable requirement.
278#[derive(Debug, Clone)]
279pub struct Auth<C, R = Required> {
280    /// The auth context (None if optional and unauthenticated).
281    pub context: Option<AuthContext<C>>,
282    _requirement: PhantomData<R>,
283}
284
285impl<C: Clone> Auth<C, Required> {
286    /// Get the claims (always present for Required auth).
287    pub fn claims(&self) -> &C {
288        &self.context.as_ref().unwrap().claims
289    }
290
291    /// Get the original token.
292    pub fn token(&self) -> &str {
293        self.context.as_ref().unwrap().token()
294    }
295}
296
297impl<C> Auth<C, Optional> {
298    /// Get the claims if authenticated.
299    pub fn claims(&self) -> Option<&C> {
300        self.context.as_ref().map(|ctx| &ctx.claims)
301    }
302
303    /// Check if authenticated.
304    pub fn is_authenticated(&self) -> bool {
305        self.context.is_some()
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_authenticated_user() {
        #[derive(Clone, Debug, PartialEq)]
        struct Claims {
            sub: String,
        }

        let user = AuthenticatedUser(Claims {
            sub: String::from("user123"),
        });

        // Both the explicit accessor and `Deref` reach the wrapped claims.
        assert_eq!(user.claims().sub, "user123");
        assert_eq!(user.sub, "user123");

        // Unwrapping hands back the same claims by value.
        assert_eq!(user.into_inner().sub, "user123");
    }

    #[test]
    fn test_auth_rejection() {
        let message = AuthRejection::new(AuthError::MissingToken).to_string();
        assert!(message.contains("missing"));
    }

    #[test]
    fn test_auth_ext_trait() {
        #[derive(Clone)]
        struct Claims {
            sub: String,
        }

        let mut req = hyper::Request::new(());

        // A fresh request carries no auth context.
        assert!(req.auth_context::<Claims>().is_none());
        assert!(req.claims::<Claims>().is_none());

        let ctx = AuthContext::new(
            Claims {
                sub: String::from("user123"),
            },
            "token",
        );
        req.extensions_mut().insert(ctx);

        // Once inserted, both lookups find it.
        assert!(req.auth_context::<Claims>().is_some());
        let claims = req.claims::<Claims>().expect("claims were just inserted");
        assert_eq!(claims.sub, "user123");
    }
}