ringkernel_core/
auth.rs

1//! Authentication framework for RingKernel.
2//!
//! This module provides a pluggable authentication system with built-in
//! providers for API keys and JWT bearer tokens; other schemes (such as OAuth2)
//! can be plugged in by implementing [`AuthProvider`].
5//!
6//! # Feature Flags
7//!
8//! - `auth` - Enables JWT token validation (requires `jsonwebtoken` crate)
9//!
10//! # Example
11//!
//! ```rust,ignore
//! use ringkernel_core::auth::{
//!     ApiKeyAuth, AuthProvider, Credentials, JwtAuth, JwtConfig,
//! };
//!
//! // Simple API key authentication
//! let auth = ApiKeyAuth::new()
//!     .add_key("admin", "secret-key-123", &["admin", "read", "write"])
//!     .add_key("readonly", "readonly-key-456", &["read"]);
//!
//! let ctx = auth.authenticate(&Credentials::ApiKey("secret-key-123".to_string())).await?;
//! assert!(ctx.has_permission("write"));
//!
//! // JWT authentication (requires the `auth` feature)
//! let jwt_auth = JwtAuth::new(JwtConfig {
//!     secret: Some("your-256-bit-secret".to_string()),
//!     issuer: Some("ringkernel".to_string()),
//!     audience: Some("api".to_string()),
//!     ..Default::default()
//! });
//! let ctx = jwt_auth.authenticate(&Credentials::Bearer(token)).await?;
//! ```
32
33use async_trait::async_trait;
34use parking_lot::RwLock;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::sync::Arc;
38use std::time::{Duration, Instant};
39#[cfg(feature = "auth")]
40use std::time::{SystemTime, UNIX_EPOCH};
41
42#[cfg(feature = "auth")]
43use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
44
45// ============================================================================
46// CREDENTIALS
47// ============================================================================
48
/// Credentials presented by a client for authentication.
///
/// Values are usually built with the constructor helpers
/// ([`Credentials::api_key`], [`Credentials::bearer`], [`Credentials::basic`])
/// or parsed from an HTTP `Authorization` header with
/// [`Credentials::from_header`].
#[derive(Debug, Clone)]
pub enum Credentials {
    /// API key authentication.
    ApiKey(String),
    /// Bearer token (JWT) authentication.
    Bearer(String),
    /// Basic authentication (username:password).
    Basic {
        /// Username for basic auth.
        username: String,
        /// Password for basic auth.
        password: String,
    },
    /// Custom credential type.
    Custom {
        /// Authentication scheme name (lowercased when parsed from a header).
        scheme: String,
        /// Credential value.
        value: String,
    },
}

impl Credentials {
    /// Create API key credentials.
    pub fn api_key(key: impl Into<String>) -> Self {
        Self::ApiKey(key.into())
    }

    /// Create bearer token credentials.
    pub fn bearer(token: impl Into<String>) -> Self {
        Self::Bearer(token.into())
    }

    /// Create basic auth credentials.
    pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self::Basic {
            username: username.into(),
            password: password.into(),
        }
    }

    /// Parse credentials from an HTTP `Authorization` header value.
    ///
    /// The header has the form `<scheme> <value>`, and the scheme is matched
    /// case-insensitively:
    ///
    /// - `bearer` → [`Credentials::Bearer`]
    /// - `basic` → [`Credentials::Basic`] (base64-encoded `user:pass`). This is
    ///   only decoded when the `auth` feature is enabled; otherwise `None` is
    ///   returned.
    /// - `apikey` / `api-key` / `x-api-key` → [`Credentials::ApiKey`]
    /// - any other scheme → [`Credentials::Custom`] with the lowercased scheme
    ///
    /// Whitespace around the header and any extra spaces between the scheme
    /// and the value are ignored.
    ///
    /// Returns `None` when:
    /// - there is no value, or the value is empty;
    /// - a `basic` value cannot be decoded.
    pub fn from_header(header: &str) -> Option<Self> {
        let (scheme, value) = header.trim().split_once(' ')?;
        // RFC 7235 permits one or more spaces between scheme and credentials.
        // Without trimming, "Bearer  tok" would produce the token " tok".
        let value = value.trim();
        if value.is_empty() {
            return None;
        }

        match scheme.to_lowercase().as_str() {
            "bearer" => Some(Self::Bearer(value.to_string())),
            "basic" => Self::decode_basic(value),
            "apikey" | "api-key" | "x-api-key" => Some(Self::ApiKey(value.to_string())),
            other => Some(Self::Custom {
                scheme: other.to_string(),
                value: value.to_string(),
            }),
        }
    }

    /// Decode a base64 `username:password` payload into [`Credentials::Basic`].
    ///
    /// Only the first `:` separates username from password, so passwords may
    /// themselves contain `:`. Returns `None` on invalid base64, non-UTF-8
    /// content, or a missing `:`.
    #[cfg(feature = "auth")]
    fn decode_basic(encoded: &str) -> Option<Self> {
        use base64::Engine;
        let raw = base64::engine::general_purpose::STANDARD
            .decode(encoded)
            .ok()?;
        let text = String::from_utf8(raw).ok()?;
        let (username, password) = text.split_once(':')?;
        Some(Self::basic(username, password))
    }

    /// Basic auth decoding needs the `auth` feature (for base64); without it
    /// Basic headers are not recognised.
    #[cfg(not(feature = "auth"))]
    fn decode_basic(_encoded: &str) -> Option<Self> {
        None
    }
}
131
132// ============================================================================
133// AUTH CONTEXT
134// ============================================================================
135
/// Identity of an authenticated principal.
#[derive(Debug, Clone)]
pub struct Identity {
    /// Unique identifier (user ID, service account, etc.).
    pub id: String,
    /// Display name.
    pub name: Option<String>,
    /// Email address.
    pub email: Option<String>,
    /// Tenant/organization ID (for multi-tenancy).
    pub tenant_id: Option<String>,
    /// Additional string claims/attributes (e.g. string-valued JWT claims).
    pub claims: HashMap<String, String>,
}

impl Identity {
    /// Create an identity with only an ID; all optional fields are empty.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            name: None,
            email: None,
            tenant_id: None,
            claims: HashMap::new(),
        }
    }

    /// Set the display name.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Set the email address.
    pub fn with_email(mut self, email: impl Into<String>) -> Self {
        self.email = Some(email.into());
        self
    }

    /// Set the tenant ID.
    pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
        self.tenant_id = Some(tenant_id.into());
        self
    }

    /// Add a claim, replacing any existing claim with the same key.
    pub fn with_claim(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.claims.insert(key.into(), value.into());
        self
    }
}

/// Authentication context for an authenticated request.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// The authenticated identity.
    pub identity: Identity,
    /// Roles assigned to this identity.
    pub roles: HashSet<String>,
    /// Permissions granted.
    pub permissions: HashSet<String>,
    /// When authentication occurred.
    pub authenticated_at: Instant,
    /// When the authentication expires; `None` means it never expires.
    pub expires_at: Option<Instant>,
    /// The authentication method used (a provider-specific tag).
    pub auth_method: String,
}

impl AuthContext {
    /// Create a context with no roles, no permissions, and no expiry.
    pub fn new(identity: Identity, auth_method: impl Into<String>) -> Self {
        Self {
            identity,
            roles: HashSet::new(),
            permissions: HashSet::new(),
            authenticated_at: Instant::now(),
            expires_at: None,
            auth_method: auth_method.into(),
        }
    }

    /// Add a role.
    pub fn with_role(mut self, role: impl Into<String>) -> Self {
        self.roles.insert(role.into());
        self
    }

    /// Add roles.
    pub fn with_roles<I, S>(mut self, roles: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.roles.extend(roles.into_iter().map(Into::into));
        self
    }

    /// Add a permission.
    pub fn with_permission(mut self, permission: impl Into<String>) -> Self {
        self.permissions.insert(permission.into());
        self
    }

    /// Add permissions.
    pub fn with_permissions<I, S>(mut self, permissions: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.permissions
            .extend(permissions.into_iter().map(Into::into));
        self
    }

    /// Make the context expire `duration` from now.
    ///
    /// `Instant + Duration` panics on overflow. A duration too large to
    /// represent (e.g. `Duration::MAX`, or a far-future token expiry) is
    /// therefore treated as "never expires" instead of crashing.
    pub fn with_expiry(mut self, duration: Duration) -> Self {
        self.expires_at = Instant::now().checked_add(duration);
        self
    }

    /// Check if the context has expired. Contexts without an expiry never do.
    pub fn is_expired(&self) -> bool {
        self.expires_at
            .map(|exp| Instant::now() > exp)
            .unwrap_or(false)
    }

    /// Check if the identity has a role.
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.contains(role)
    }

    /// Check if the identity has any of the specified roles.
    pub fn has_any_role(&self, roles: &[&str]) -> bool {
        roles.iter().any(|r| self.roles.contains(*r))
    }

    /// Check if the identity has all of the specified roles.
    pub fn has_all_roles(&self, roles: &[&str]) -> bool {
        roles.iter().all(|r| self.roles.contains(*r))
    }

    /// Check if the identity has a permission.
    pub fn has_permission(&self, permission: &str) -> bool {
        self.permissions.contains(permission)
    }

    /// Check if the identity has any of the specified permissions.
    pub fn has_any_permission(&self, permissions: &[&str]) -> bool {
        permissions.iter().any(|p| self.permissions.contains(*p))
    }

    /// Check if the identity has all of the specified permissions.
    ///
    /// Returns `true` for an empty slice (mirroring [`AuthContext::has_all_roles`]).
    pub fn has_all_permissions(&self, permissions: &[&str]) -> bool {
        permissions.iter().all(|p| self.permissions.contains(*p))
    }

    /// Get the tenant ID (for multi-tenant operations).
    pub fn tenant_id(&self) -> Option<&str> {
        self.identity.tenant_id.as_deref()
    }
}
294
295// ============================================================================
296// AUTH ERROR
297// ============================================================================
298
/// Failure modes of authentication and authorization operations.
///
/// Each variant carries a free-form detail message. The `Display`
/// implementation renders it as `<variant-specific prefix>: <detail>`.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// The supplied credentials were rejected.
    InvalidCredentials(String),
    /// The credentials are no longer valid because they have expired.
    Expired(String),
    /// Required credentials were not supplied.
    MissingCredentials(String),
    /// The caller is authenticated but not permitted to perform the action.
    AccessDenied(String),
    /// A token could not be validated.
    TokenInvalid(String),
    /// The backing authentication service could not be reached.
    ServiceUnavailable(String),
    /// The caller has been rate limited.
    RateLimited(String),
    /// Any other authentication failure.
    Other(String),
}

impl fmt::Display for AuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (prefix, detail) = match self {
            Self::InvalidCredentials(detail) => ("Invalid credentials", detail),
            Self::Expired(detail) => ("Credentials expired", detail),
            Self::MissingCredentials(detail) => ("Missing credentials", detail),
            Self::AccessDenied(detail) => ("Access denied", detail),
            Self::TokenInvalid(detail) => ("Token invalid", detail),
            Self::ServiceUnavailable(detail) => ("Auth service unavailable", detail),
            Self::RateLimited(detail) => ("Rate limited", detail),
            Self::Other(detail) => ("Auth error", detail),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}

impl std::error::Error for AuthError {}

/// Convenience alias for results whose error type is [`AuthError`].
pub type AuthResult<T> = Result<T, AuthError>;
339
340// ============================================================================
341// AUTH PROVIDER TRAIT
342// ============================================================================
343
344/// Trait for pluggable authentication providers.
345#[async_trait]
346pub trait AuthProvider: Send + Sync {
347    /// Authenticate credentials and return an auth context.
348    async fn authenticate(&self, credentials: &Credentials) -> AuthResult<AuthContext>;
349
350    /// Validate an existing auth context (e.g., check if still valid).
351    async fn validate(&self, context: &AuthContext) -> AuthResult<()>;
352
353    /// Revoke authentication (e.g., invalidate a token).
354    async fn revoke(&self, context: &AuthContext) -> AuthResult<()>;
355
356    /// Get the provider name.
357    fn provider_name(&self) -> &str;
358}
359
360// ============================================================================
361// API KEY AUTHENTICATION
362// ============================================================================
363
/// Record stored for each registered API key.
///
/// Entries are keyed in the provider's map by the same hash held in
/// `_key_hash`; no field holds the plaintext key.
#[derive(Debug, Clone)]
struct ApiKeyEntry {
    /// Hash of the API key as produced by `hash_api_key` (plaintext is not kept).
    /// Not read anywhere in this module, hence the leading underscore.
    _key_hash: u64,
    /// Identity associated with this key.
    identity: Identity,
    /// Permissions copied into the `AuthContext` on successful authentication.
    permissions: HashSet<String>,
    /// Roles copied into the `AuthContext` on successful authentication.
    roles: HashSet<String>,
    /// When the key was registered (not read anywhere in this module).
    _created_at: Instant,
    /// Optional expiration; `None` means the key never expires.
    expires_at: Option<Instant>,
    /// Set to `false` when the key is revoked via `revoke_key`.
    active: bool,
}
382
/// Hash an API key into the `u64` used to index the API key table.
///
/// Uses `DefaultHasher::new()`, whose instances all start from the same
/// state, so equal keys hash equally within a build.
///
/// NOTE(review): this is not a cryptographic hash, and a 64-bit digest can
/// collide. A production deployment should use a keyed cryptographic hash
/// (e.g. HMAC-SHA-256) and compare digests in constant time.
fn hash_api_key(key: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    let mut state = DefaultHasher::new();
    std::hash::Hash::hash(key, &mut state);
    state.finish()
}
391
392/// API key authentication provider.
393pub struct ApiKeyAuth {
394    /// Registered API keys (hash -> entry).
395    keys: RwLock<HashMap<u64, ApiKeyEntry>>,
396    /// Default expiration for new keys.
397    default_expiry: Option<Duration>,
398}
399
400impl ApiKeyAuth {
401    /// Create a new API key auth provider.
402    pub fn new() -> Self {
403        Self {
404            keys: RwLock::new(HashMap::new()),
405            default_expiry: None,
406        }
407    }
408
409    /// Set default expiration for new keys.
410    pub fn with_default_expiry(mut self, expiry: Duration) -> Self {
411        self.default_expiry = Some(expiry);
412        self
413    }
414
415    /// Add an API key.
416    pub fn add_key(
417        self,
418        identity_id: impl Into<String>,
419        api_key: &str,
420        permissions: &[&str],
421    ) -> Self {
422        let identity_id = identity_id.into();
423        let key_hash = hash_api_key(api_key);
424
425        let entry = ApiKeyEntry {
426            _key_hash: key_hash,
427            identity: Identity::new(&identity_id),
428            permissions: permissions.iter().map(|s| s.to_string()).collect(),
429            roles: HashSet::new(),
430            _created_at: Instant::now(),
431            expires_at: self.default_expiry.map(|d| Instant::now() + d),
432            active: true,
433        };
434
435        self.keys.write().insert(key_hash, entry);
436        self
437    }
438
439    /// Add an API key with roles.
440    pub fn add_key_with_roles(
441        self,
442        identity_id: impl Into<String>,
443        api_key: &str,
444        permissions: &[&str],
445        roles: &[&str],
446    ) -> Self {
447        let identity_id = identity_id.into();
448        let key_hash = hash_api_key(api_key);
449
450        let entry = ApiKeyEntry {
451            _key_hash: key_hash,
452            identity: Identity::new(&identity_id),
453            permissions: permissions.iter().map(|s| s.to_string()).collect(),
454            roles: roles.iter().map(|s| s.to_string()).collect(),
455            _created_at: Instant::now(),
456            expires_at: self.default_expiry.map(|d| Instant::now() + d),
457            active: true,
458        };
459
460        self.keys.write().insert(key_hash, entry);
461        self
462    }
463
464    /// Revoke an API key.
465    pub fn revoke_key(&self, api_key: &str) -> bool {
466        let key_hash = hash_api_key(api_key);
467        let mut keys = self.keys.write();
468        if let Some(entry) = keys.get_mut(&key_hash) {
469            entry.active = false;
470            true
471        } else {
472            false
473        }
474    }
475
476    /// Get the number of registered keys.
477    pub fn key_count(&self) -> usize {
478        self.keys.read().len()
479    }
480}
481
482impl Default for ApiKeyAuth {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
488#[async_trait]
489impl AuthProvider for ApiKeyAuth {
490    async fn authenticate(&self, credentials: &Credentials) -> AuthResult<AuthContext> {
491        let api_key = match credentials {
492            Credentials::ApiKey(key) => key,
493            _ => {
494                return Err(AuthError::InvalidCredentials(
495                    "Expected API key".to_string(),
496                ))
497            }
498        };
499
500        let key_hash = hash_api_key(api_key);
501        let keys = self.keys.read();
502
503        let entry = keys
504            .get(&key_hash)
505            .ok_or_else(|| AuthError::InvalidCredentials("Unknown API key".to_string()))?;
506
507        if !entry.active {
508            return Err(AuthError::InvalidCredentials(
509                "API key has been revoked".to_string(),
510            ));
511        }
512
513        if let Some(expires) = entry.expires_at {
514            if Instant::now() > expires {
515                return Err(AuthError::Expired("API key has expired".to_string()));
516            }
517        }
518
519        let mut ctx = AuthContext::new(entry.identity.clone(), "api_key")
520            .with_permissions(entry.permissions.iter().cloned())
521            .with_roles(entry.roles.iter().cloned());
522
523        if let Some(expires) = entry.expires_at {
524            let remaining = expires.saturating_duration_since(Instant::now());
525            ctx = ctx.with_expiry(remaining);
526        }
527
528        Ok(ctx)
529    }
530
531    async fn validate(&self, context: &AuthContext) -> AuthResult<()> {
532        if context.is_expired() {
533            return Err(AuthError::Expired("Auth context has expired".to_string()));
534        }
535        Ok(())
536    }
537
538    async fn revoke(&self, _context: &AuthContext) -> AuthResult<()> {
539        // API keys are revoked by key, not by context
540        Ok(())
541    }
542
543    fn provider_name(&self) -> &str {
544        "ApiKeyAuth"
545    }
546}
547
548// ============================================================================
549// JWT AUTHENTICATION (requires auth feature)
550// ============================================================================
551
/// Claims carried by tokens issued or accepted by `JwtAuth`.
///
/// When deserializing:
/// - `sub`, `iat` and `exp` are required;
/// - `iss`, `aud` and `tenant_id` may be absent;
/// - `roles` and `permissions` default to empty.
///
/// Any claim not listed here lands in `custom` via `#[serde(flatten)]`. For
/// example, `jti` is looked up there for revocation.
///
/// NOTE(review): `aud` is modelled as a single string, so a token whose `aud`
/// claim is a JSON array will presumably fail to deserialize. Confirm whether
/// the issuers used with this crate ever emit array audiences.
#[cfg(feature = "auth")]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct JwtClaims {
    /// Subject (user ID); becomes the authenticated identity's ID.
    pub sub: String,
    /// Issued-at time, in seconds since the Unix epoch.
    pub iat: u64,
    /// Expiration time, in seconds since the Unix epoch.
    pub exp: u64,
    /// Issuer (omitted from the serialized token when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// Audience (omitted from the serialized token when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<String>,
    /// Roles granted to the subject; empty if absent.
    #[serde(default)]
    pub roles: Vec<String>,
    /// Permissions granted to the subject; empty if absent.
    #[serde(default)]
    pub permissions: Vec<String>,
    /// Tenant ID (omitted from the serialized token when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// All remaining claims, flattened into the top-level JSON object.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
581
/// Settings controlling how `JwtAuth` signs and verifies tokens.
///
/// The default configuration selects HS256 with 60 seconds of leeway. It sets
/// no key, issuer, or audience.
#[cfg(feature = "auth")]
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Shared secret for the HMAC algorithms (HS256/HS384/HS512).
    pub secret: Option<String>,
    /// PEM-encoded public key intended for asymmetric algorithms (RS256/ES256).
    pub public_key: Option<String>,
    /// Issuer (`iss`) that tokens must carry, when set.
    pub issuer: Option<String>,
    /// Audience (`aud`) that tokens must carry, when set.
    pub audience: Option<String>,
    /// Signing and verification algorithm.
    pub algorithm: Algorithm,
    /// Leeway, in seconds, applied to time-based claim checks.
    pub leeway_seconds: u64,
}

#[cfg(feature = "auth")]
impl Default for JwtConfig {
    fn default() -> Self {
        let algorithm = Algorithm::HS256;
        let leeway_seconds = 60;
        Self {
            algorithm,
            leeway_seconds,
            secret: None,
            public_key: None,
            issuer: None,
            audience: None,
        }
    }
}
613
/// JWT authentication provider.
///
/// Verifies bearer tokens according to a [`JwtConfig`]. It can also mint
/// secret-signed tokens with [`JwtAuth::generate_token`], and it keeps an
/// in-memory set of revoked token IDs (`jti` claims).
#[cfg(feature = "auth")]
pub struct JwtAuth {
    /// Verification and signing settings.
    config: JwtConfig,
    /// Revoked token IDs (`jti` claims).
    ///
    /// NOTE(review): entries are never removed, so this set grows for the
    /// lifetime of the provider. Consider pruning once tokens expire.
    revoked_tokens: RwLock<HashSet<String>>,
}

#[cfg(feature = "auth")]
impl JwtAuth {
    /// Create a JWT auth provider from an explicit configuration.
    pub fn new(config: JwtConfig) -> Self {
        Self {
            config,
            revoked_tokens: RwLock::new(HashSet::new()),
        }
    }

    /// Create a provider that signs and verifies with `secret` using HS256.
    pub fn with_secret(secret: impl Into<String>) -> Self {
        Self::new(JwtConfig {
            secret: Some(secret.into()),
            algorithm: Algorithm::HS256,
            ..Default::default()
        })
    }

    /// Sign `claims` into a compact JWT using the configured secret and algorithm.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Other`] if no secret is configured or if encoding fails.
    pub fn generate_token(&self, claims: &JwtClaims) -> AuthResult<String> {
        let secret = self.config.secret.as_ref().ok_or_else(|| {
            AuthError::Other("No secret configured for token generation".to_string())
        })?;

        encode(
            &Header::new(self.config.algorithm),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .map_err(|e| AuthError::Other(format!("Token generation failed: {}", e)))
    }

    /// Decode and validate `token`.
    ///
    /// Checks the signature (accepting only the configured algorithm) and the
    /// expiry (with the configured leeway). When configured, it also checks
    /// the issuer and audience.
    ///
    /// # Errors
    ///
    /// - [`AuthError::Other`] if no usable key is configured for the algorithm.
    /// - [`AuthError::TokenInvalid`] if verification or claim decoding fails.
    fn decode_token(&self, token: &str) -> AuthResult<JwtClaims> {
        let mut validation = Validation::new(self.config.algorithm);
        validation.leeway = self.config.leeway_seconds;

        if let Some(issuer) = &self.config.issuer {
            validation.set_issuer(&[issuer]);
        }
        if let Some(audience) = &self.config.audience {
            validation.set_audience(&[audience]);
        }

        let decoding_key = self.decoding_key()?;
        decode::<JwtClaims>(token, &decoding_key, &validation)
            .map(|data| data.claims)
            .map_err(|e| AuthError::TokenInvalid(format!("Token validation failed: {}", e)))
    }

    /// Select the verification key for the configured algorithm.
    ///
    /// - HMAC algorithms use the shared `secret`.
    /// - RSA / RSA-PSS and ECDSA algorithms use the PEM `public_key`.
    ///
    /// Choosing by algorithm (rather than "secret if present") means a
    /// configured `public_key` is actually used.
    fn decoding_key(&self) -> AuthResult<DecodingKey> {
        match self.config.algorithm {
            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => self
                .config
                .secret
                .as_ref()
                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
                .ok_or_else(|| {
                    AuthError::Other("No secret configured for HMAC algorithm".to_string())
                }),
            Algorithm::RS256
            | Algorithm::RS384
            | Algorithm::RS512
            | Algorithm::PS256
            | Algorithm::PS384
            | Algorithm::PS512 => DecodingKey::from_rsa_pem(self.public_key_pem()?)
                .map_err(|e| AuthError::Other(format!("Invalid RSA public key: {}", e))),
            Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(self.public_key_pem()?)
                .map_err(|e| AuthError::Other(format!("Invalid EC public key: {}", e))),
            #[allow(unreachable_patterns)] // only reachable on crate versions with extra algorithms
            other => Err(AuthError::Other(format!(
                "Unsupported JWT algorithm: {:?}",
                other
            ))),
        }
    }

    /// The configured PEM public key as bytes, or an error if none is set.
    fn public_key_pem(&self) -> AuthResult<&[u8]> {
        self.config
            .public_key
            .as_ref()
            .map(|pem| pem.as_bytes())
            .ok_or_else(|| AuthError::Other("No public key configured".to_string()))
    }

    /// Revoke a token by its `jti` claim.
    pub fn revoke_token(&self, jti: impl Into<String>) {
        self.revoked_tokens.write().insert(jti.into());
    }

    /// Check whether the token with this `jti` has been revoked.
    pub fn is_revoked(&self, jti: &str) -> bool {
        self.revoked_tokens.read().contains(jti)
    }
}
699
#[cfg(feature = "auth")]
#[async_trait]
impl AuthProvider for JwtAuth {
    /// Authenticate a [`Credentials::Bearer`] JWT.
    ///
    /// The token's claims build the resulting context:
    /// - `sub` becomes the identity ID;
    /// - `tenant_id` becomes the identity's tenant;
    /// - string-valued custom claims become identity claims;
    /// - `roles` and `permissions` are copied over;
    /// - `exp` sets the context expiry.
    ///
    /// Tokens whose `jti` has been revoked are rejected.
    async fn authenticate(&self, credentials: &Credentials) -> AuthResult<AuthContext> {
        let token = match credentials {
            Credentials::Bearer(t) => t,
            _ => {
                return Err(AuthError::InvalidCredentials(
                    "Expected Bearer token".to_string(),
                ))
            }
        };

        let claims = self.decode_token(token)?;

        // `jti` is not a named field of `JwtClaims`, so it lives in `custom`.
        if let Some(serde_json::Value::String(jti)) = claims.custom.get("jti") {
            if self.is_revoked(jti) {
                return Err(AuthError::TokenInvalid(
                    "Token has been revoked".to_string(),
                ));
            }
        }

        let mut identity = Identity::new(&claims.sub);
        if let Some(tenant) = &claims.tenant_id {
            identity = identity.with_tenant(tenant);
        }

        // Copy string-valued custom claims (including `jti`) onto the
        // identity, so `validate`/`revoke` can find them later.
        for (key, value) in &claims.custom {
            if let serde_json::Value::String(s) = value {
                identity = identity.with_claim(key, s);
            }
        }

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let remaining = Duration::from_secs(claims.exp.saturating_sub(now));

        let mut ctx = AuthContext::new(identity, "jwt")
            .with_roles(claims.roles)
            .with_permissions(claims.permissions);
        // `Instant + Duration` panics on overflow. An `exp` too far in the
        // future to represent is treated as "no expiry" instead of crashing.
        ctx.expires_at = Instant::now().checked_add(remaining);

        Ok(ctx)
    }

    /// Reject contexts that have expired or whose token `jti` has since been revoked.
    async fn validate(&self, context: &AuthContext) -> AuthResult<()> {
        if context.is_expired() {
            return Err(AuthError::Expired("Token has expired".to_string()));
        }
        // A token revoked after authentication must stop validating. The
        // `jti` was copied into the identity claims by `authenticate`.
        if let Some(jti) = context.identity.claims.get("jti") {
            if self.is_revoked(jti) {
                return Err(AuthError::TokenInvalid(
                    "Token has been revoked".to_string(),
                ));
            }
        }
        Ok(())
    }

    /// Revoke the token behind `context` if it carried a `jti` claim.
    async fn revoke(&self, context: &AuthContext) -> AuthResult<()> {
        if let Some(jti) = context.identity.claims.get("jti") {
            self.revoke_token(jti);
        }
        Ok(())
    }

    fn provider_name(&self) -> &str {
        "JwtAuth"
    }
}
769
770// ============================================================================
771// CHAINED AUTH PROVIDER
772// ============================================================================
773
774/// Authentication provider that tries multiple providers in order.
775pub struct ChainedAuthProvider {
776    providers: Vec<Arc<dyn AuthProvider>>,
777}
778
779impl ChainedAuthProvider {
780    /// Create a new chained auth provider.
781    pub fn new() -> Self {
782        Self {
783            providers: Vec::new(),
784        }
785    }
786
787    /// Add a provider to the chain.
788    pub fn with_provider(mut self, provider: Arc<dyn AuthProvider>) -> Self {
789        self.providers.push(provider);
790        self
791    }
792}
793
794impl Default for ChainedAuthProvider {
795    fn default() -> Self {
796        Self::new()
797    }
798}
799
800#[async_trait]
801impl AuthProvider for ChainedAuthProvider {
802    async fn authenticate(&self, credentials: &Credentials) -> AuthResult<AuthContext> {
803        let mut last_error = AuthError::MissingCredentials("No providers configured".to_string());
804
805        for provider in &self.providers {
806            match provider.authenticate(credentials).await {
807                Ok(ctx) => return Ok(ctx),
808                Err(e) => {
809                    last_error = e;
810                    continue;
811                }
812            }
813        }
814
815        Err(last_error)
816    }
817
818    async fn validate(&self, context: &AuthContext) -> AuthResult<()> {
819        // Validate using the provider that authenticated
820        for provider in &self.providers {
821            if provider.provider_name() == context.auth_method {
822                return provider.validate(context).await;
823            }
824        }
825        Ok(())
826    }
827
828    async fn revoke(&self, context: &AuthContext) -> AuthResult<()> {
829        for provider in &self.providers {
830            if provider.provider_name() == context.auth_method {
831                return provider.revoke(context).await;
832            }
833        }
834        Ok(())
835    }
836
837    fn provider_name(&self) -> &str {
838        "ChainedAuthProvider"
839    }
840}
841
842// ============================================================================
843// TESTS
844// ============================================================================
845
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds.
    #[cfg(feature = "auth")]
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_credentials_from_header() {
        let parsed_bearer = Credentials::from_header("Bearer token123");
        assert!(matches!(parsed_bearer, Some(Credentials::Bearer(_))));

        let parsed_key = Credentials::from_header("ApiKey secret123");
        assert!(matches!(parsed_key, Some(Credentials::ApiKey(_))));

        assert!(Credentials::from_header("invalid").is_none());
    }

    #[test]
    fn test_identity() {
        let who = Identity::new("user123")
            .with_name("John Doe")
            .with_email("john@example.com")
            .with_tenant("tenant1")
            .with_claim("department", "engineering");

        assert_eq!(who.id, "user123");
        assert_eq!(who.name.as_deref(), Some("John Doe"));
        assert_eq!(who.tenant_id.as_deref(), Some("tenant1"));
    }

    #[test]
    fn test_auth_context() {
        let ctx = AuthContext::new(Identity::new("user1"), "test")
            .with_role("admin")
            .with_role("user")
            .with_permission("read")
            .with_permission("write");

        for role in &["admin", "user"] {
            assert!(ctx.has_role(role), "missing role {}", role);
        }
        assert!(!ctx.has_role("superadmin"));

        for perm in &["read", "write"] {
            assert!(ctx.has_permission(perm), "missing permission {}", perm);
        }
        assert!(!ctx.has_permission("delete"));

        assert!(ctx.has_any_role(&["admin", "guest"]));
        assert!(ctx.has_all_roles(&["admin", "user"]));
        assert!(!ctx.has_all_roles(&["admin", "superadmin"]));
    }

    #[test]
    fn test_auth_context_expiry() {
        let ctx = AuthContext::new(Identity::new("user1"), "test")
            .with_expiry(Duration::from_nanos(1));

        std::thread::sleep(Duration::from_millis(1));
        assert!(ctx.is_expired());
    }

    #[tokio::test]
    async fn test_api_key_auth() {
        let auth = ApiKeyAuth::new()
            .add_key("admin", "secret-key-123", &["admin", "read", "write"])
            .add_key("readonly", "readonly-key-456", &["read"]);

        let admin = auth
            .authenticate(&Credentials::api_key("secret-key-123"))
            .await
            .expect("admin key should authenticate");
        assert_eq!(admin.identity.id, "admin");
        assert!(admin.has_permission("write"));

        let reader = auth
            .authenticate(&Credentials::api_key("readonly-key-456"))
            .await
            .expect("readonly key should authenticate");
        assert_eq!(reader.identity.id, "readonly");
        assert!(reader.has_permission("read"));
        assert!(!reader.has_permission("write"));

        let unknown = auth
            .authenticate(&Credentials::api_key("invalid-key"))
            .await;
        assert!(unknown.is_err());
    }

    #[tokio::test]
    async fn test_api_key_revocation() {
        let auth = ApiKeyAuth::new().add_key("user1", "key-to-revoke", &["read"]);
        let creds = Credentials::api_key("key-to-revoke");

        assert!(auth.authenticate(&creds).await.is_ok());

        auth.revoke_key("key-to-revoke");

        assert!(auth.authenticate(&creds).await.is_err());
    }

    #[cfg(feature = "auth")]
    #[tokio::test]
    async fn test_jwt_auth() {
        let auth = JwtAuth::with_secret("test-secret-key-256-bits-long!");
        let issued_at = unix_now();

        let claims = JwtClaims {
            sub: "user123".to_string(),
            iat: issued_at,
            exp: issued_at + 3600,
            iss: None,
            aud: None,
            roles: vec!["admin".to_string()],
            permissions: vec!["read".to_string(), "write".to_string()],
            tenant_id: Some("tenant1".to_string()),
            custom: HashMap::new(),
        };

        let token = auth
            .generate_token(&claims)
            .expect("token generation should succeed");

        let ctx = auth
            .authenticate(&Credentials::bearer(token))
            .await
            .expect("freshly minted token should authenticate");

        assert_eq!(ctx.identity.id, "user123");
        assert!(ctx.has_role("admin"));
        assert!(ctx.has_permission("read"));
        assert_eq!(ctx.tenant_id(), Some("tenant1"));
    }

    #[tokio::test]
    async fn test_chained_auth() {
        let api_provider = Arc::new(ApiKeyAuth::new().add_key("api_user", "api-key-123", &["api"]));
        let chain = ChainedAuthProvider::new().with_provider(api_provider);

        let ctx = chain
            .authenticate(&Credentials::api_key("api-key-123"))
            .await
            .expect("registered key should authenticate through the chain");
        assert_eq!(ctx.identity.id, "api_user");

        let rejected = chain
            .authenticate(&Credentials::api_key("unknown"))
            .await;
        assert!(rejected.is_err());
    }
}