amaters_server/
auth.rs

1//! Authentication module
2//!
3//! This module provides authentication services for the server:
4//! - mTLS (Mutual TLS) client certificate validation
5//! - JWT (JSON Web Token) authentication
6//! - API key authentication
7//!
8//! Security model:
9//! - Secure by default (deny unless explicitly allowed)
10//! - Multiple authentication methods can be enabled simultaneously
11//! - Authentication results in a validated identity (Principal)
12
13use crate::config::{ApiKeySettings, AuthSettings, JwtSettings, MtlsSettings};
14use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
15use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use std::collections::HashMap;
19use std::fs;
20use std::path::Path;
21use std::sync::Arc;
22use thiserror::Error;
23use tracing::{debug, info, warn};
24use x509_parser::prelude::*;
25
26/// Authentication errors
27#[derive(Error, Debug)]
28pub enum AuthError {
29    #[error("Authentication failed: {0}")]
30    AuthenticationFailed(String),
31
32    #[error("Invalid credentials")]
33    InvalidCredentials,
34
35    #[error("Certificate validation failed: {0}")]
36    CertificateError(String),
37
38    #[error("JWT validation failed: {0}")]
39    JwtError(String),
40
41    #[error("API key validation failed: {0}")]
42    ApiKeyError(String),
43
44    #[error("Configuration error: {0}")]
45    ConfigError(String),
46
47    #[error("IO error: {0}")]
48    Io(#[from] std::io::Error),
49
50    #[error("JSON error: {0}")]
51    Json(#[from] serde_json::Error),
52
53    #[error("No authentication provided")]
54    NoAuthProvided,
55
56    #[error("Authentication method not enabled: {0}")]
57    MethodNotEnabled(String),
58}
59
60pub type AuthResult<T> = Result<T, AuthError>;
61
62/// Authenticated principal (user identity)
63#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
64pub struct Principal {
65    /// Unique identifier for the user
66    pub id: String,
67
68    /// Username or common name
69    pub name: String,
70
71    /// Authentication method used
72    pub auth_method: AuthMethod,
73
74    /// Additional attributes (roles, groups, etc.)
75    pub attributes: HashMap<String, String>,
76}
77
78impl Principal {
79    /// Create a new principal
80    pub fn new(id: String, name: String, auth_method: AuthMethod) -> Self {
81        Self {
82            id,
83            name,
84            auth_method,
85            attributes: HashMap::new(),
86        }
87    }
88
89    /// Add an attribute to the principal
90    pub fn with_attribute(mut self, key: String, value: String) -> Self {
91        self.attributes.insert(key, value);
92        self
93    }
94
95    /// Get an attribute value
96    pub fn get_attribute(&self, key: &str) -> Option<&String> {
97        self.attributes.get(key)
98    }
99
100    /// Check if principal has a specific role
101    pub fn has_role(&self, role: &str) -> bool {
102        self.get_attribute("role")
103            .map(|r| r == role)
104            .unwrap_or(false)
105    }
106}
107
108/// Authentication method used
109#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
110pub enum AuthMethod {
111    /// mTLS client certificate
112    MutualTls,
113    /// JWT token
114    Jwt,
115    /// API key
116    ApiKey,
117}
118
119impl std::fmt::Display for AuthMethod {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            AuthMethod::MutualTls => write!(f, "mTLS"),
123            AuthMethod::Jwt => write!(f, "JWT"),
124            AuthMethod::ApiKey => write!(f, "API Key"),
125        }
126    }
127}
128
129/// JWT claims structure
130#[derive(Debug, Serialize, Deserialize)]
131struct JwtClaims {
132    /// Subject (user ID)
133    sub: String,
134    /// Expiration time
135    exp: usize,
136    /// Issued at
137    iat: Option<usize>,
138    /// Issuer
139    iss: Option<String>,
140    /// Audience
141    aud: Option<String>,
142    /// User name
143    name: Option<String>,
144    /// Roles
145    roles: Option<Vec<String>>,
146    /// Custom attributes
147    #[serde(flatten)]
148    attributes: HashMap<String, serde_json::Value>,
149}
150
151/// API key entry
152#[derive(Debug, Clone, Serialize, Deserialize)]
153struct ApiKeyEntry {
154    /// Key ID
155    id: String,
156    /// Key name/description
157    name: String,
158    /// Hashed key value (if hashing enabled)
159    #[serde(skip_serializing_if = "Option::is_none")]
160    key_hash: Option<String>,
161    /// Plain key value (if hashing disabled)
162    #[serde(skip_serializing_if = "Option::is_none")]
163    key: Option<String>,
164    /// User ID
165    user_id: String,
166    /// Roles
167    #[serde(default)]
168    roles: Vec<String>,
169    /// Additional attributes
170    #[serde(default)]
171    attributes: HashMap<String, String>,
172}
173
174/// Authentication service
175pub struct Authenticator {
176    config: Arc<AuthSettings>,
177    mtls_validator: Option<MtlsValidator>,
178    jwt_validator: Option<JwtValidator>,
179    api_key_validator: Option<ApiKeyValidator>,
180}
181
182impl Authenticator {
183    /// Create a new authenticator
184    pub fn new(config: AuthSettings) -> AuthResult<Self> {
185        let config = Arc::new(config);
186
187        // Initialize mTLS validator
188        let mtls_validator = if config.mtls.enabled {
189            Some(MtlsValidator::new(config.mtls.clone())?)
190        } else {
191            None
192        };
193
194        // Initialize JWT validator
195        let jwt_validator = if config.jwt.enabled {
196            Some(JwtValidator::new(config.jwt.clone())?)
197        } else {
198            None
199        };
200
201        // Initialize API key validator
202        let api_key_validator = if config.api_key.enabled {
203            Some(ApiKeyValidator::new(config.api_key.clone())?)
204        } else {
205            None
206        };
207
208        Ok(Self {
209            config,
210            mtls_validator,
211            jwt_validator,
212            api_key_validator,
213        })
214    }
215
216    /// Authenticate using client certificate
217    pub fn authenticate_certificate(&self, cert_der: &[u8]) -> AuthResult<Principal> {
218        if !self.config.methods.contains(&"mtls".to_string()) {
219            return Err(AuthError::MethodNotEnabled("mTLS".to_string()));
220        }
221
222        let validator = self
223            .mtls_validator
224            .as_ref()
225            .ok_or_else(|| AuthError::MethodNotEnabled("mTLS".to_string()))?;
226
227        validator.validate_certificate(cert_der)
228    }
229
230    /// Authenticate using JWT token
231    pub fn authenticate_jwt(&self, token: &str) -> AuthResult<Principal> {
232        if !self.config.methods.contains(&"jwt".to_string()) {
233            return Err(AuthError::MethodNotEnabled("JWT".to_string()));
234        }
235
236        let validator = self
237            .jwt_validator
238            .as_ref()
239            .ok_or_else(|| AuthError::MethodNotEnabled("JWT".to_string()))?;
240
241        validator.validate_token(token)
242    }
243
244    /// Authenticate using API key
245    pub fn authenticate_api_key(&self, key: &str) -> AuthResult<Principal> {
246        if !self.config.methods.contains(&"api_key".to_string()) {
247            return Err(AuthError::MethodNotEnabled("API Key".to_string()));
248        }
249
250        let validator = self
251            .api_key_validator
252            .as_ref()
253            .ok_or_else(|| AuthError::MethodNotEnabled("API Key".to_string()))?;
254
255        validator.validate_key(key)
256    }
257
258    /// Check if authentication is enabled
259    pub fn is_enabled(&self) -> bool {
260        self.config.enabled
261    }
262
263    /// Check if a specific method is enabled
264    pub fn is_method_enabled(&self, method: &str) -> bool {
265        self.config.methods.contains(&method.to_string())
266    }
267}
268
269/// mTLS certificate validator
270struct MtlsValidator {
271    config: MtlsSettings,
272    ca_certs: Vec<Vec<u8>>,
273}
274
275impl MtlsValidator {
276    fn new(config: MtlsSettings) -> AuthResult<Self> {
277        let ca_certs = if let Some(ref ca_dir) = config.ca_certs_dir {
278            Self::load_ca_certificates(ca_dir)?
279        } else {
280            Vec::new()
281        };
282
283        Ok(Self { config, ca_certs })
284    }
285
286    fn load_ca_certificates(dir: &Path) -> AuthResult<Vec<Vec<u8>>> {
287        let mut certs = Vec::new();
288
289        if !dir.exists() {
290            return Err(AuthError::ConfigError(format!(
291                "CA certificates directory does not exist: {}",
292                dir.display()
293            )));
294        }
295
296        for entry_result in fs::read_dir(dir)? {
297            let entry = entry_result?;
298            let path = entry.path();
299
300            if path.is_file() {
301                if let Some(ext) = path.extension() {
302                    if ext == "crt" || ext == "pem" || ext == "der" {
303                        let cert_data = fs::read(&path)?;
304                        certs.push(cert_data);
305                        debug!("Loaded CA certificate: {}", path.display());
306                    }
307                }
308            }
309        }
310
311        info!("Loaded {} CA certificates", certs.len());
312        Ok(certs)
313    }
314
315    fn validate_certificate(&self, cert_der: &[u8]) -> AuthResult<Principal> {
316        // Parse the certificate
317        let (_, cert) = X509Certificate::from_der(cert_der).map_err(|e| {
318            AuthError::CertificateError(format!("Failed to parse certificate: {}", e))
319        })?;
320
321        // Verify certificate validity period
322        let now = std::time::SystemTime::now();
323        let not_before = cert.validity().not_before.to_datetime();
324        let not_after = cert.validity().not_after.to_datetime();
325
326        if now < not_before {
327            return Err(AuthError::CertificateError(
328                "Certificate not yet valid".to_string(),
329            ));
330        }
331
332        if now > not_after {
333            return Err(AuthError::CertificateError(
334                "Certificate has expired".to_string(),
335            ));
336        }
337
338        // Extract subject information
339        let subject = cert.subject();
340        let cn = subject
341            .iter_common_name()
342            .next()
343            .and_then(|cn| cn.as_str().ok())
344            .ok_or_else(|| AuthError::CertificateError("No CN in certificate".to_string()))?;
345
346        let organization = subject
347            .iter_organization()
348            .next()
349            .and_then(|o| o.as_str().ok());
350
351        // Verify organization if restrictions are configured
352        if !self.config.allowed_organizations.is_empty() {
353            let org = organization.ok_or_else(|| {
354                AuthError::CertificateError("Certificate has no organization".to_string())
355            })?;
356
357            if !self.config.allowed_organizations.contains(&org.to_string()) {
358                return Err(AuthError::CertificateError(format!(
359                    "Organization '{}' not allowed",
360                    org
361                )));
362            }
363        }
364
365        // Create principal
366        let mut principal = Principal::new(cn.to_string(), cn.to_string(), AuthMethod::MutualTls);
367
368        if let Some(org) = organization {
369            principal = principal.with_attribute("organization".to_string(), org.to_string());
370        }
371
372        debug!("Successfully authenticated certificate for user: {}", cn);
373        Ok(principal)
374    }
375}
376
377/// JWT token validator
378struct JwtValidator {
379    config: JwtSettings,
380    decoding_key: DecodingKey,
381    validation: Validation,
382}
383
384impl JwtValidator {
385    fn new(config: JwtSettings) -> AuthResult<Self> {
386        let algorithm = match config.algorithm.as_str() {
387            "HS256" => Algorithm::HS256,
388            "RS256" => Algorithm::RS256,
389            _ => {
390                return Err(AuthError::ConfigError(format!(
391                    "Unsupported JWT algorithm: {}",
392                    config.algorithm
393                )));
394            }
395        };
396
397        let decoding_key = match algorithm {
398            Algorithm::HS256 => {
399                let secret = config.secret.as_ref().ok_or_else(|| {
400                    AuthError::ConfigError("JWT secret not configured".to_string())
401                })?;
402                DecodingKey::from_secret(secret.as_bytes())
403            }
404            Algorithm::RS256 => {
405                let public_key_path = config.public_key_path.as_ref().ok_or_else(|| {
406                    AuthError::ConfigError("JWT public key path not configured".to_string())
407                })?;
408                let pem = fs::read_to_string(public_key_path)?;
409                DecodingKey::from_rsa_pem(pem.as_bytes()).map_err(|e| {
410                    AuthError::ConfigError(format!("Failed to load RSA public key: {}", e))
411                })?
412            }
413            _ => {
414                return Err(AuthError::ConfigError(
415                    "Algorithm not implemented".to_string(),
416                ));
417            }
418        };
419
420        let mut validation = Validation::new(algorithm);
421        validation.validate_exp = true;
422
423        if let Some(ref issuer) = config.issuer {
424            validation.set_issuer(&[issuer]);
425        }
426
427        if let Some(ref audience) = config.audience {
428            validation.set_audience(&[audience]);
429        }
430
431        Ok(Self {
432            config,
433            decoding_key,
434            validation,
435        })
436    }
437
438    fn validate_token(&self, token: &str) -> AuthResult<Principal> {
439        // Decode and validate the token
440        let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)
441            .map_err(|e| AuthError::JwtError(format!("Token validation failed: {}", e)))?;
442
443        let claims = token_data.claims;
444
445        // Create principal
446        let name = claims.name.unwrap_or_else(|| claims.sub.clone());
447        let mut principal = Principal::new(claims.sub, name, AuthMethod::Jwt);
448
449        // Add roles
450        if let Some(roles) = claims.roles {
451            principal = principal.with_attribute("roles".to_string(), roles.join(","));
452        }
453
454        // Add custom attributes
455        for (key, value) in claims.attributes {
456            if let Some(s) = value.as_str() {
457                principal = principal.with_attribute(key, s.to_string());
458            }
459        }
460
461        debug!(
462            "Successfully authenticated JWT token for user: {}",
463            principal.name
464        );
465        Ok(principal)
466    }
467}
468
469/// API key validator
470struct ApiKeyValidator {
471    config: ApiKeySettings,
472    keys: HashMap<String, ApiKeyEntry>,
473}
474
475impl ApiKeyValidator {
476    fn new(config: ApiKeySettings) -> AuthResult<Self> {
477        let keys_file = config
478            .keys_file
479            .as_ref()
480            .ok_or_else(|| AuthError::ConfigError("API keys file not configured".to_string()))?;
481
482        let keys = Self::load_keys(keys_file, config.hash_keys)?;
483
484        info!("Loaded {} API keys", keys.len());
485
486        Ok(Self { config, keys })
487    }
488
489    fn load_keys(path: &Path, hash_keys: bool) -> AuthResult<HashMap<String, ApiKeyEntry>> {
490        if !path.exists() {
491            return Err(AuthError::ConfigError(format!(
492                "API keys file does not exist: {}",
493                path.display()
494            )));
495        }
496
497        let contents = fs::read_to_string(path)?;
498        let entries: Vec<ApiKeyEntry> = serde_json::from_str(&contents)?;
499
500        let mut keys = HashMap::new();
501        for entry in entries {
502            let key_value = if hash_keys {
503                entry
504                    .key_hash
505                    .clone()
506                    .ok_or_else(|| AuthError::ConfigError("Missing key_hash".to_string()))?
507            } else {
508                entry
509                    .key
510                    .clone()
511                    .ok_or_else(|| AuthError::ConfigError("Missing key".to_string()))?
512            };
513
514            keys.insert(key_value, entry);
515        }
516
517        Ok(keys)
518    }
519
520    fn validate_key(&self, key: &str) -> AuthResult<Principal> {
521        let lookup_key = if self.config.hash_keys {
522            Self::hash_key(key)
523        } else {
524            key.to_string()
525        };
526
527        let entry = self
528            .keys
529            .get(&lookup_key)
530            .ok_or(AuthError::InvalidCredentials)?;
531
532        // Create principal
533        let mut principal = Principal::new(
534            entry.user_id.clone(),
535            entry.name.clone(),
536            AuthMethod::ApiKey,
537        );
538
539        // Add roles
540        if !entry.roles.is_empty() {
541            principal = principal.with_attribute("roles".to_string(), entry.roles.join(","));
542        }
543
544        // Add custom attributes
545        for (key, value) in &entry.attributes {
546            principal = principal.with_attribute(key.clone(), value.clone());
547        }
548
549        debug!(
550            "Successfully authenticated API key for user: {}",
551            entry.user_id
552        );
553        Ok(principal)
554    }
555
556    fn hash_key(key: &str) -> String {
557        let mut hasher = Sha256::new();
558        hasher.update(key.as_bytes());
559        let result = hasher.finalize();
560        BASE64.encode(result)
561    }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Principal used by the `Principal` tests.
    fn sample_principal() -> Principal {
        Principal::new(
            String::from("user123"),
            String::from("John Doe"),
            AuthMethod::Jwt,
        )
    }

    /// `Principal::new` stores the identity fields as given.
    #[test]
    fn test_principal_creation() {
        let subject = sample_principal();

        assert_eq!(subject.id, "user123");
        assert_eq!(subject.name, "John Doe");
        assert_eq!(subject.auth_method, AuthMethod::Jwt);
    }

    /// Attributes added with `with_attribute` are retrievable and drive `has_role`.
    #[test]
    fn test_principal_attributes() {
        let subject = sample_principal()
            .with_attribute(String::from("role"), String::from("admin"))
            .with_attribute(String::from("department"), String::from("Engineering"));

        let role = subject.get_attribute("role").map(String::as_str);
        assert_eq!(role, Some("admin"));
        assert!(subject.has_role("admin"));
        assert!(!subject.has_role("user"));
    }

    /// Hashing is deterministic and never yields an empty string.
    #[test]
    fn test_api_key_hashing() {
        let raw = "test-api-key-12345";
        let first = ApiKeyValidator::hash_key(raw);
        let second = ApiKeyValidator::hash_key(raw);

        assert_eq!(first, second); // Same key produces same hash
        assert!(!first.is_empty());
    }

    /// An authenticator with only mTLS enabled can be constructed.
    #[test]
    fn test_authenticator_creation() {
        let mtls = MtlsSettings {
            enabled: true,
            ca_certs_dir: Some(env::temp_dir()),
            crl_path: None,
            verify_cn: true,
            allowed_organizations: Vec::new(),
        };
        let settings = AuthSettings {
            enabled: true,
            methods: vec![String::from("mtls")],
            mtls,
            jwt: JwtSettings::default(),
            api_key: ApiKeySettings::default(),
            reject_unauthenticated: true,
        };

        assert!(Authenticator::new(settings).is_ok());
    }

    /// HS256 without a secret is a configuration error.
    #[test]
    fn test_jwt_validator_creation_missing_secret() {
        let settings = JwtSettings {
            enabled: true,
            secret: None,
            public_key_path: None,
            algorithm: String::from("HS256"),
            expiration_secs: 3600,
            issuer: None,
            audience: None,
        };

        assert!(JwtValidator::new(settings).is_err());
    }

    /// Each authentication method renders its short label.
    #[test]
    fn test_auth_method_display() {
        assert_eq!(AuthMethod::MutualTls.to_string(), "mTLS");
        assert_eq!(AuthMethod::Jwt.to_string(), "JWT");
        assert_eq!(AuthMethod::ApiKey.to_string(), "API Key");
    }
}