Skip to main content

allframe_core/auth/
mod.rs

1//! Authentication primitives for AllFrame.
2//!
3//! This module provides a layered authentication system:
4//!
//! - **`auth`** (this module): Core traits with minimal dependencies (`async-trait` only)
6//! - **`auth-jwt`**: JWT validation using `jsonwebtoken`
7//! - **`auth-axum`**: Axum extractors and middleware
8//! - **`auth-tonic`**: gRPC interceptors
9//!
10//! # Core Concepts
11//!
12//! The authentication system is built around a few key traits:
13//!
14//! - [`Authenticator`]: Validates tokens and returns claims
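//! - [`Authenticator::Claims`]: The claims type an authenticator returns on success
//! - [`HasSubject`] / [`HasExpiration`]: Accessor traits for common claim fields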
16//! - [`AuthContext`]: Holds authenticated user information
17//!
18//! # Example: Using Core Traits
19//!
20//! ```rust
21//! use allframe_core::auth::{Authenticator, AuthError, AuthContext};
22//!
23//! // Define your claims type
24//! #[derive(Clone, Debug)]
25//! struct MyClaims {
26//!     sub: String,
27//!     email: Option<String>,
28//! }
29//!
30//! // Implement your authenticator
31//! struct MyAuthenticator;
32//!
33//! #[async_trait::async_trait]
34//! impl Authenticator for MyAuthenticator {
35//!     type Claims = MyClaims;
36//!
37//!     async fn authenticate(&self, token: &str) -> Result<Self::Claims, AuthError> {
38//!         // Your validation logic here
39//!         Ok(MyClaims {
40//!             sub: "user123".to_string(),
41//!             email: Some("user@example.com".to_string()),
42//!         })
43//!     }
44//! }
45//! ```
46//!
47//! # Feature Flags
48//!
49//! | Feature | Description |
50//! |---------|-------------|
51//! | `auth` | Core traits (this module) |
52//! | `auth-jwt` | JWT validation with HS256/RS256 support |
53//! | `auth-axum` | Axum extractors and middleware |
54//! | `auth-tonic` | gRPC interceptors |
55
56use std::fmt;
57
58#[cfg(feature = "auth-jwt")]
59pub mod jwt;
60
61#[cfg(feature = "auth-axum")]
62pub mod axum;
63
64#[cfg(feature = "auth-tonic")]
65pub mod tonic;
66
67// Re-exports
68#[cfg(feature = "auth-jwt")]
69pub use jwt::{JwtAlgorithm, JwtConfig, JwtValidator};
70
71#[cfg(feature = "auth-axum")]
72pub use self::axum::{AuthLayer, AuthenticatedUser};
73#[cfg(feature = "auth-tonic")]
74pub use self::tonic::AuthInterceptor;
75
/// Error type for authentication failures.
///
/// Every variant has a human-readable message (via [`fmt::Display`]) and a
/// suggested HTTP status code (via [`AuthError::status_code`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// No token was provided.
    MissingToken,
    /// Token format is invalid; carries a description of the problem.
    InvalidToken(String),
    /// Token has expired.
    TokenExpired,
    /// Token signature is invalid.
    InvalidSignature,
    /// Token issuer doesn't match.
    InvalidIssuer,
    /// Token audience doesn't match.
    InvalidAudience,
    /// Custom validation error; carries a description of the failed check.
    ValidationFailed(String),
    /// Internal error during authentication; carries a description.
    Internal(String),
}

impl fmt::Display for AuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AuthError::MissingToken => write!(f, "missing authentication token"),
            AuthError::InvalidToken(msg) => write!(f, "invalid token: {}", msg),
            AuthError::TokenExpired => write!(f, "token has expired"),
            AuthError::InvalidSignature => write!(f, "invalid token signature"),
            AuthError::InvalidIssuer => write!(f, "invalid token issuer"),
            AuthError::InvalidAudience => write!(f, "invalid token audience"),
            AuthError::ValidationFailed(msg) => write!(f, "validation failed: {}", msg),
            AuthError::Internal(msg) => write!(f, "internal auth error: {}", msg),
        }
    }
}

impl std::error::Error for AuthError {}

impl AuthError {
    /// Returns `true` if no token was supplied at all (as opposed to an
    /// invalid one).
    #[must_use]
    pub fn is_missing(&self) -> bool {
        matches!(self, AuthError::MissingToken)
    }

    /// Returns `true` if the token was rejected because it expired.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        matches!(self, AuthError::TokenExpired)
    }

    /// Suggested HTTP status code for this error.
    ///
    /// - `401` for missing, malformed, expired, or otherwise unverifiable tokens
    /// - `403` for [`AuthError::ValidationFailed`]
    /// - `500` for [`AuthError::Internal`]
    #[must_use]
    pub fn status_code(&self) -> u16 {
        match self {
            AuthError::MissingToken
            | AuthError::InvalidToken(_)
            | AuthError::TokenExpired
            | AuthError::InvalidSignature
            | AuthError::InvalidIssuer
            | AuthError::InvalidAudience => 401,
            AuthError::ValidationFailed(_) => 403,
            AuthError::Internal(_) => 500,
        }
    }
}
139
140/// Trait for types that can validate authentication tokens.
141///
142/// Implement this trait to create custom authenticators for different
143/// token types (JWT, API keys, session tokens, etc.).
144///
145/// # Example
146///
147/// ```rust
148/// use allframe_core::auth::{Authenticator, AuthError};
149///
150/// struct ApiKeyAuthenticator {
151///     valid_keys: Vec<String>,
152/// }
153///
154/// #[async_trait::async_trait]
155/// impl Authenticator for ApiKeyAuthenticator {
156///     type Claims = String; // Just the key itself
157///
158///     async fn authenticate(&self, token: &str) -> Result<Self::Claims, AuthError> {
159///         if self.valid_keys.contains(&token.to_string()) {
160///             Ok(token.to_string())
161///         } else {
162///             Err(AuthError::InvalidToken("unknown API key".into()))
163///         }
164///     }
165/// }
166/// ```
167#[async_trait::async_trait]
168pub trait Authenticator: Send + Sync {
169    /// The claims type returned on successful authentication.
170    type Claims: Clone + Send + Sync + 'static;
171
172    /// Validate a token and extract claims.
173    ///
174    /// # Arguments
175    /// * `token` - The raw token string (without "Bearer " prefix)
176    ///
177    /// # Returns
178    /// * `Ok(Claims)` - Authentication successful
179    /// * `Err(AuthError)` - Authentication failed
180    async fn authenticate(&self, token: &str) -> Result<Self::Claims, AuthError>;
181}
182
/// Context holding authenticated user information.
///
/// This is the result of successful authentication: it bundles the validated
/// claims with the raw token they were extracted from.
///
/// The [`fmt::Debug`] output redacts the token, so logging an `AuthContext`
/// does not leak a usable credential.
#[derive(Clone)]
pub struct AuthContext<C> {
    /// The validated claims.
    pub claims: C,
    /// The original token (for forwarding to downstream services).
    pub token: String,
}

// Hand-written instead of derived: a derived `Debug` would print the raw
// bearer token into logs and panic messages.
impl<C: fmt::Debug> fmt::Debug for AuthContext<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AuthContext")
            .field("claims", &self.claims)
            .field("token", &"<redacted>")
            .finish()
    }
}

// No `C: Clone` bound: none of these methods clone the claims, so requiring it
// would only prevent non-`Clone` claim types from using the accessors.
impl<C> AuthContext<C> {
    /// Create a new auth context from validated `claims` and the raw `token`.
    pub fn new(claims: C, token: impl Into<String>) -> Self {
        Self {
            claims,
            token: token.into(),
        }
    }

    /// Borrow the validated claims.
    pub fn claims(&self) -> &C {
        &self.claims
    }

    /// Borrow the original token string.
    pub fn token(&self) -> &str {
        &self.token
    }

    /// Compute a value from the claims using `f`.
    pub fn get<T>(&self, f: impl FnOnce(&C) -> T) -> T {
        f(&self.claims)
    }
}
219
/// Extract the bearer token from an `Authorization` header value.
///
/// The `Bearer` scheme name is matched case-insensitively and surrounding
/// whitespace is ignored. Returns `None` if the header uses another scheme or
/// carries no token. Never panics, even on arbitrary non-ASCII input.
///
/// # Example
///
/// ```rust
/// use allframe_core::auth::extract_bearer_token;
///
/// assert_eq!(extract_bearer_token("Bearer abc123"), Some("abc123"));
/// assert_eq!(extract_bearer_token("bearer ABC"), Some("ABC"));
/// assert_eq!(extract_bearer_token("Basic xyz"), None);
/// assert_eq!(extract_bearer_token("abc123"), None);
/// ```
pub fn extract_bearer_token(header_value: &str) -> Option<&str> {
    const PREFIX: &str = "bearer ";

    let header = header_value.trim();
    // `str::get` returns `None` (instead of panicking like `header[..7]`) when
    // the header is shorter than the prefix or byte 7 falls inside a
    // multi-byte UTF-8 character — both possible with untrusted input.
    let scheme = header.get(..PREFIX.len())?;
    if !scheme.eq_ignore_ascii_case(PREFIX) {
        return None;
    }

    // The prefix matched as ASCII, so `PREFIX.len()` is a char boundary.
    let token = header[PREFIX.len()..].trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
240
/// Claims that identify a principal by a subject string (the user ID).
pub trait HasSubject {
    /// Returns the subject (user ID) carried by these claims.
    fn subject(&self) -> &str;
}
246
/// Claims that carry an expiration time.
pub trait HasExpiration {
    /// Expiration timestamp in Unix seconds, or `None` if there is none.
    fn expiration(&self) -> Option<i64>;

    /// Returns `true` once the current time has reached the expiration time.
    ///
    /// Following RFC 7519 §4.1.4, `exp` is the instant *on or after which* the
    /// token must not be accepted, so claims whose expiration equals the
    /// current second are already expired. Claims without an expiration never
    /// expire. Does not panic if the system clock is set before the Unix epoch.
    fn is_expired(&self) -> bool {
        use std::time::{SystemTime, UNIX_EPOCH};

        let exp = match self.expiration() {
            Some(exp) => exp,
            None => return false,
        };

        // `duration_since` fails when the clock is before 1970; represent that
        // as a negative timestamp instead of unwrapping (which would panic).
        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since) => i64::try_from(since.as_secs()).unwrap_or(i64::MAX),
            Err(err) => i64::try_from(err.duration().as_secs()).map_or(i64::MIN, |s| -s),
        };

        exp <= now
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_bearer_token() {
        let cases: &[(&str, Option<&str>)] = &[
            ("Bearer abc123", Some("abc123")),
            ("bearer ABC", Some("ABC")),
            ("BEARER token", Some("token")),
            ("Bearer  spaced", Some("spaced")),
            ("Basic xyz", None),
            ("abc123", None),
            ("", None),
            ("Bearer", None),
            // "Bearer " with no token after is treated as invalid
            ("Bearer ", None),
        ];

        for &(input, expected) in cases {
            assert_eq!(extract_bearer_token(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_auth_error_display() {
        let cases = [
            (AuthError::MissingToken, "missing authentication token"),
            (AuthError::TokenExpired, "token has expired"),
            (AuthError::InvalidToken("bad".into()), "invalid token: bad"),
        ];

        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_auth_error_status_codes() {
        let cases = [
            (AuthError::MissingToken, 401_u16),
            (AuthError::TokenExpired, 401_u16),
            (AuthError::ValidationFailed(String::new()), 403_u16),
            (AuthError::Internal(String::new()), 500_u16),
        ];

        for (err, code) in cases {
            assert_eq!(err.status_code(), code, "error: {}", err);
        }
    }

    #[test]
    fn test_auth_context() {
        #[derive(Clone, Debug)]
        struct TestClaims {
            sub: String,
            role: String,
        }

        let claims = TestClaims {
            sub: "user123".into(),
            role: "admin".into(),
        };
        let ctx = AuthContext::new(claims, "token123");

        assert_eq!(ctx.claims().sub, "user123");
        assert_eq!(ctx.token(), "token123");
        assert_eq!(ctx.get(|c| c.role.clone()), "admin");
    }

    #[test]
    fn test_auth_error_predicates() {
        let missing = AuthError::MissingToken;
        let expired = AuthError::TokenExpired;

        assert!(missing.is_missing());
        assert!(!missing.is_expired());
        assert!(expired.is_expired());
        assert!(!expired.is_missing());
    }

    #[derive(Clone)]
    struct MockClaims {
        exp: Option<i64>,
    }

    impl HasExpiration for MockClaims {
        fn expiration(&self) -> Option<i64> {
            self.exp
        }
    }

    #[test]
    fn test_has_expiration() {
        let claims_with = |exp: Option<i64>| MockClaims { exp };

        assert!(claims_with(Some(0)).is_expired(), "epoch timestamp should be expired");
        assert!(
            !claims_with(Some(i64::MAX)).is_expired(),
            "far-future timestamp should still be valid"
        );
        assert!(!claims_with(None).is_expired(), "claims without exp never expire");
    }
}