1use crate::config::{ApiKeySettings, AuthSettings, JwtSettings, MtlsSettings};
14use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
15use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use std::collections::HashMap;
19use std::fs;
20use std::path::Path;
21use std::sync::Arc;
22use thiserror::Error;
23use tracing::{debug, info, warn};
24use x509_parser::prelude::*;
25
/// Errors produced while configuring authenticators or validating credentials.
#[derive(Error, Debug)]
pub enum AuthError {
    /// Generic authentication failure carrying a human-readable reason.
    /// (Not constructed in this file; presumably used by callers.)
    #[error("Authentication failed: {0}")]
    AuthenticationFailed(String),

    /// The presented credential matched nothing known; returned by the API key
    /// validator when a key lookup misses.
    #[error("Invalid credentials")]
    InvalidCredentials,

    /// Client certificate could not be parsed or failed a check (validity
    /// window, missing CN, organization not allowed, ...).
    #[error("Certificate validation failed: {0}")]
    CertificateError(String),

    /// Token decoding or claim validation failed.
    #[error("JWT validation failed: {0}")]
    JwtError(String),

    /// API key specific failure. (Not constructed in this file.)
    #[error("API key validation failed: {0}")]
    ApiKeyError(String),

    /// Invalid or incomplete configuration detected while building a validator
    /// (missing files/secrets, unsupported algorithm, bad key material, ...).
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Filesystem error, converted automatically via `?`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// JSON (de)serialization error, e.g. while parsing the API keys file.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// No credential was supplied. (Not constructed in this file.)
    #[error("No authentication provided")]
    NoAuthProvided,

    /// The requested method is not listed in `methods` or its validator was not
    /// built because the method's own `enabled` flag is off.
    #[error("Authentication method not enabled: {0}")]
    MethodNotEnabled(String),
}
59
/// Result alias used throughout this module, with [`AuthError`] as the error type.
pub type AuthResult<T> = Result<T, AuthError>;
61
/// An authenticated identity produced by one of the validators in this module.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Principal {
    /// Identifier of the principal: the certificate CN for mTLS, the `sub`
    /// claim for JWT, or the entry's `user_id` for API keys.
    pub id: String,

    /// Display name: the CN for mTLS, the `name` claim (falling back to `sub`)
    /// for JWT, or the entry's `name` for API keys.
    pub name: String,

    /// Which mechanism authenticated this principal.
    pub auth_method: AuthMethod,

    /// Free-form string attributes. Validators here populate, for example,
    /// `roles` (comma-joined), `organization` (mTLS), string-valued custom JWT
    /// claims, and per-entry API key attributes.
    pub attributes: HashMap<String, String>,
}
77
78impl Principal {
79 pub fn new(id: String, name: String, auth_method: AuthMethod) -> Self {
81 Self {
82 id,
83 name,
84 auth_method,
85 attributes: HashMap::new(),
86 }
87 }
88
89 pub fn with_attribute(mut self, key: String, value: String) -> Self {
91 self.attributes.insert(key, value);
92 self
93 }
94
95 pub fn get_attribute(&self, key: &str) -> Option<&String> {
97 self.attributes.get(key)
98 }
99
100 pub fn has_role(&self, role: &str) -> bool {
102 self.get_attribute("role")
103 .map(|r| r == role)
104 .unwrap_or(false)
105 }
106}
107
/// The mechanism used to authenticate a [`Principal`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthMethod {
    /// Client certificate presented over mutual TLS.
    MutualTls,
    /// Signed JSON Web Token.
    Jwt,
    /// Static API key looked up in the configured keys file.
    ApiKey,
}
118
119impl std::fmt::Display for AuthMethod {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 match self {
122 AuthMethod::MutualTls => write!(f, "mTLS"),
123 AuthMethod::Jwt => write!(f, "JWT"),
124 AuthMethod::ApiKey => write!(f, "API Key"),
125 }
126 }
127}
128
129#[derive(Debug, Serialize, Deserialize)]
131struct JwtClaims {
132 sub: String,
134 exp: usize,
136 iat: Option<usize>,
138 iss: Option<String>,
140 aud: Option<String>,
142 name: Option<String>,
144 roles: Option<Vec<String>>,
146 #[serde(flatten)]
148 attributes: HashMap<String, serde_json::Value>,
149}
150
/// One entry of the API keys JSON file (the file is parsed as an array of these).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ApiKeyEntry {
    /// Identifier of the entry itself.
    id: String,
    /// Display name; becomes the principal name.
    name: String,
    /// Base64 SHA-256 of the key, in the format `ApiKeyValidator::hash_key`
    /// produces. Required when `hash_keys` is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    key_hash: Option<String>,
    /// Plaintext key. Required when `hash_keys` is disabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    key: Option<String>,
    /// Becomes the principal id.
    user_id: String,
    /// Roles granted to the key; exposed as the comma-joined `roles` attribute.
    #[serde(default)]
    roles: Vec<String>,
    /// Extra attributes copied verbatim onto the principal.
    #[serde(default)]
    attributes: HashMap<String, String>,
}
173
/// Entry point that dispatches credentials to the per-method validators.
pub struct Authenticator {
    /// Shared authentication settings, including the `methods` allow-list.
    config: Arc<AuthSettings>,
    /// Present only when `config.mtls.enabled` was set at construction.
    mtls_validator: Option<MtlsValidator>,
    /// Present only when `config.jwt.enabled` was set at construction.
    jwt_validator: Option<JwtValidator>,
    /// Present only when `config.api_key.enabled` was set at construction.
    api_key_validator: Option<ApiKeyValidator>,
}
181
182impl Authenticator {
183 pub fn new(config: AuthSettings) -> AuthResult<Self> {
185 let config = Arc::new(config);
186
187 let mtls_validator = if config.mtls.enabled {
189 Some(MtlsValidator::new(config.mtls.clone())?)
190 } else {
191 None
192 };
193
194 let jwt_validator = if config.jwt.enabled {
196 Some(JwtValidator::new(config.jwt.clone())?)
197 } else {
198 None
199 };
200
201 let api_key_validator = if config.api_key.enabled {
203 Some(ApiKeyValidator::new(config.api_key.clone())?)
204 } else {
205 None
206 };
207
208 Ok(Self {
209 config,
210 mtls_validator,
211 jwt_validator,
212 api_key_validator,
213 })
214 }
215
216 pub fn authenticate_certificate(&self, cert_der: &[u8]) -> AuthResult<Principal> {
218 if !self.config.methods.contains(&"mtls".to_string()) {
219 return Err(AuthError::MethodNotEnabled("mTLS".to_string()));
220 }
221
222 let validator = self
223 .mtls_validator
224 .as_ref()
225 .ok_or_else(|| AuthError::MethodNotEnabled("mTLS".to_string()))?;
226
227 validator.validate_certificate(cert_der)
228 }
229
230 pub fn authenticate_jwt(&self, token: &str) -> AuthResult<Principal> {
232 if !self.config.methods.contains(&"jwt".to_string()) {
233 return Err(AuthError::MethodNotEnabled("JWT".to_string()));
234 }
235
236 let validator = self
237 .jwt_validator
238 .as_ref()
239 .ok_or_else(|| AuthError::MethodNotEnabled("JWT".to_string()))?;
240
241 validator.validate_token(token)
242 }
243
244 pub fn authenticate_api_key(&self, key: &str) -> AuthResult<Principal> {
246 if !self.config.methods.contains(&"api_key".to_string()) {
247 return Err(AuthError::MethodNotEnabled("API Key".to_string()));
248 }
249
250 let validator = self
251 .api_key_validator
252 .as_ref()
253 .ok_or_else(|| AuthError::MethodNotEnabled("API Key".to_string()))?;
254
255 validator.validate_key(key)
256 }
257
258 pub fn is_enabled(&self) -> bool {
260 self.config.enabled
261 }
262
263 pub fn is_method_enabled(&self, method: &str) -> bool {
265 self.config.methods.contains(&method.to_string())
266 }
267}
268
/// Validates client certificates presented over mutual TLS.
struct MtlsValidator {
    /// mTLS settings (allowed organizations, CA directory, ...).
    config: MtlsSettings,
    /// Raw bytes of every CA certificate file loaded from `ca_certs_dir`.
    /// NOTE(review): not read by `validate_certificate` in this file — confirm
    /// whether chain verification happens elsewhere (e.g. in the TLS layer).
    ca_certs: Vec<Vec<u8>>,
}
274
275impl MtlsValidator {
276 fn new(config: MtlsSettings) -> AuthResult<Self> {
277 let ca_certs = if let Some(ref ca_dir) = config.ca_certs_dir {
278 Self::load_ca_certificates(ca_dir)?
279 } else {
280 Vec::new()
281 };
282
283 Ok(Self { config, ca_certs })
284 }
285
286 fn load_ca_certificates(dir: &Path) -> AuthResult<Vec<Vec<u8>>> {
287 let mut certs = Vec::new();
288
289 if !dir.exists() {
290 return Err(AuthError::ConfigError(format!(
291 "CA certificates directory does not exist: {}",
292 dir.display()
293 )));
294 }
295
296 for entry_result in fs::read_dir(dir)? {
297 let entry = entry_result?;
298 let path = entry.path();
299
300 if path.is_file() {
301 if let Some(ext) = path.extension() {
302 if ext == "crt" || ext == "pem" || ext == "der" {
303 let cert_data = fs::read(&path)?;
304 certs.push(cert_data);
305 debug!("Loaded CA certificate: {}", path.display());
306 }
307 }
308 }
309 }
310
311 info!("Loaded {} CA certificates", certs.len());
312 Ok(certs)
313 }
314
315 fn validate_certificate(&self, cert_der: &[u8]) -> AuthResult<Principal> {
316 let (_, cert) = X509Certificate::from_der(cert_der).map_err(|e| {
318 AuthError::CertificateError(format!("Failed to parse certificate: {}", e))
319 })?;
320
321 let now = std::time::SystemTime::now();
323 let not_before = cert.validity().not_before.to_datetime();
324 let not_after = cert.validity().not_after.to_datetime();
325
326 if now < not_before {
327 return Err(AuthError::CertificateError(
328 "Certificate not yet valid".to_string(),
329 ));
330 }
331
332 if now > not_after {
333 return Err(AuthError::CertificateError(
334 "Certificate has expired".to_string(),
335 ));
336 }
337
338 let subject = cert.subject();
340 let cn = subject
341 .iter_common_name()
342 .next()
343 .and_then(|cn| cn.as_str().ok())
344 .ok_or_else(|| AuthError::CertificateError("No CN in certificate".to_string()))?;
345
346 let organization = subject
347 .iter_organization()
348 .next()
349 .and_then(|o| o.as_str().ok());
350
351 if !self.config.allowed_organizations.is_empty() {
353 let org = organization.ok_or_else(|| {
354 AuthError::CertificateError("Certificate has no organization".to_string())
355 })?;
356
357 if !self.config.allowed_organizations.contains(&org.to_string()) {
358 return Err(AuthError::CertificateError(format!(
359 "Organization '{}' not allowed",
360 org
361 )));
362 }
363 }
364
365 let mut principal = Principal::new(cn.to_string(), cn.to_string(), AuthMethod::MutualTls);
367
368 if let Some(org) = organization {
369 principal = principal.with_attribute("organization".to_string(), org.to_string());
370 }
371
372 debug!("Successfully authenticated certificate for user: {}", cn);
373 Ok(principal)
374 }
375}
376
/// Validates bearer JWTs with a fixed key and validation policy.
struct JwtValidator {
    /// JWT settings the validator was built from (not read by `validate_token`).
    config: JwtSettings,
    /// Key used to verify token signatures (HMAC secret or RSA public key).
    decoding_key: DecodingKey,
    /// Algorithm, expiry, issuer and audience rules applied on every decode.
    validation: Validation,
}
383
384impl JwtValidator {
385 fn new(config: JwtSettings) -> AuthResult<Self> {
386 let algorithm = match config.algorithm.as_str() {
387 "HS256" => Algorithm::HS256,
388 "RS256" => Algorithm::RS256,
389 _ => {
390 return Err(AuthError::ConfigError(format!(
391 "Unsupported JWT algorithm: {}",
392 config.algorithm
393 )));
394 }
395 };
396
397 let decoding_key = match algorithm {
398 Algorithm::HS256 => {
399 let secret = config.secret.as_ref().ok_or_else(|| {
400 AuthError::ConfigError("JWT secret not configured".to_string())
401 })?;
402 DecodingKey::from_secret(secret.as_bytes())
403 }
404 Algorithm::RS256 => {
405 let public_key_path = config.public_key_path.as_ref().ok_or_else(|| {
406 AuthError::ConfigError("JWT public key path not configured".to_string())
407 })?;
408 let pem = fs::read_to_string(public_key_path)?;
409 DecodingKey::from_rsa_pem(pem.as_bytes()).map_err(|e| {
410 AuthError::ConfigError(format!("Failed to load RSA public key: {}", e))
411 })?
412 }
413 _ => {
414 return Err(AuthError::ConfigError(
415 "Algorithm not implemented".to_string(),
416 ));
417 }
418 };
419
420 let mut validation = Validation::new(algorithm);
421 validation.validate_exp = true;
422
423 if let Some(ref issuer) = config.issuer {
424 validation.set_issuer(&[issuer]);
425 }
426
427 if let Some(ref audience) = config.audience {
428 validation.set_audience(&[audience]);
429 }
430
431 Ok(Self {
432 config,
433 decoding_key,
434 validation,
435 })
436 }
437
438 fn validate_token(&self, token: &str) -> AuthResult<Principal> {
439 let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)
441 .map_err(|e| AuthError::JwtError(format!("Token validation failed: {}", e)))?;
442
443 let claims = token_data.claims;
444
445 let name = claims.name.unwrap_or_else(|| claims.sub.clone());
447 let mut principal = Principal::new(claims.sub, name, AuthMethod::Jwt);
448
449 if let Some(roles) = claims.roles {
451 principal = principal.with_attribute("roles".to_string(), roles.join(","));
452 }
453
454 for (key, value) in claims.attributes {
456 if let Some(s) = value.as_str() {
457 principal = principal.with_attribute(key, s.to_string());
458 }
459 }
460
461 debug!(
462 "Successfully authenticated JWT token for user: {}",
463 principal.name
464 );
465 Ok(principal)
466 }
467}
468
/// Validates static API keys against entries loaded from a JSON file.
struct ApiKeyValidator {
    /// API key settings (keys file location, whether keys are stored hashed).
    config: ApiKeySettings,
    /// Maps the lookup key (the stored `key_hash` when `hash_keys` is on,
    /// otherwise the plaintext `key`) to its entry.
    keys: HashMap<String, ApiKeyEntry>,
}
474
475impl ApiKeyValidator {
476 fn new(config: ApiKeySettings) -> AuthResult<Self> {
477 let keys_file = config
478 .keys_file
479 .as_ref()
480 .ok_or_else(|| AuthError::ConfigError("API keys file not configured".to_string()))?;
481
482 let keys = Self::load_keys(keys_file, config.hash_keys)?;
483
484 info!("Loaded {} API keys", keys.len());
485
486 Ok(Self { config, keys })
487 }
488
489 fn load_keys(path: &Path, hash_keys: bool) -> AuthResult<HashMap<String, ApiKeyEntry>> {
490 if !path.exists() {
491 return Err(AuthError::ConfigError(format!(
492 "API keys file does not exist: {}",
493 path.display()
494 )));
495 }
496
497 let contents = fs::read_to_string(path)?;
498 let entries: Vec<ApiKeyEntry> = serde_json::from_str(&contents)?;
499
500 let mut keys = HashMap::new();
501 for entry in entries {
502 let key_value = if hash_keys {
503 entry
504 .key_hash
505 .clone()
506 .ok_or_else(|| AuthError::ConfigError("Missing key_hash".to_string()))?
507 } else {
508 entry
509 .key
510 .clone()
511 .ok_or_else(|| AuthError::ConfigError("Missing key".to_string()))?
512 };
513
514 keys.insert(key_value, entry);
515 }
516
517 Ok(keys)
518 }
519
520 fn validate_key(&self, key: &str) -> AuthResult<Principal> {
521 let lookup_key = if self.config.hash_keys {
522 Self::hash_key(key)
523 } else {
524 key.to_string()
525 };
526
527 let entry = self
528 .keys
529 .get(&lookup_key)
530 .ok_or(AuthError::InvalidCredentials)?;
531
532 let mut principal = Principal::new(
534 entry.user_id.clone(),
535 entry.name.clone(),
536 AuthMethod::ApiKey,
537 );
538
539 if !entry.roles.is_empty() {
541 principal = principal.with_attribute("roles".to_string(), entry.roles.join(","));
542 }
543
544 for (key, value) in &entry.attributes {
546 principal = principal.with_attribute(key.clone(), value.clone());
547 }
548
549 debug!(
550 "Successfully authenticated API key for user: {}",
551 entry.user_id
552 );
553 Ok(principal)
554 }
555
556 fn hash_key(key: &str) -> String {
557 let mut hasher = Sha256::new();
558 hasher.update(key.as_bytes());
559 let result = hasher.finalize();
560 BASE64.encode(result)
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_principal_creation() {
        let principal = Principal::new(
            "user123".to_string(),
            "John Doe".to_string(),
            AuthMethod::Jwt,
        );

        assert_eq!(principal.id, "user123");
        assert_eq!(principal.name, "John Doe");
        assert_eq!(principal.auth_method, AuthMethod::Jwt);
    }

    #[test]
    fn test_principal_attributes() {
        let principal = Principal::new(
            "user123".to_string(),
            "John Doe".to_string(),
            AuthMethod::Jwt,
        )
        .with_attribute("role".to_string(), "admin".to_string())
        .with_attribute("department".to_string(), "Engineering".to_string());

        assert_eq!(principal.get_attribute("role"), Some(&"admin".to_string()));
        assert!(principal.has_role("admin"));
        assert!(!principal.has_role("user"));
    }

    #[test]
    fn test_api_key_hashing() {
        let key = "test-api-key-12345";
        let hash1 = ApiKeyValidator::hash_key(key);
        let hash2 = ApiKeyValidator::hash_key(key);

        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
        assert_ne!(hash1, ApiKeyValidator::hash_key("test-api-key-12346"));
        // Known answer: SHA-256 of the empty input, standard padded base64.
        assert_eq!(
            ApiKeyValidator::hash_key(""),
            "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU="
        );
    }

    #[test]
    fn test_authenticator_creation() {
        // Use a dedicated empty directory: the shared system temp dir may hold
        // unrelated (possibly unreadable) *.pem files that would make this flaky.
        let ca_dir = env::temp_dir().join(format!("auth-test-empty-ca-{}", std::process::id()));
        std::fs::create_dir_all(&ca_dir).expect("create isolated CA directory");

        let config = AuthSettings {
            enabled: true,
            methods: vec!["mtls".to_string()],
            mtls: MtlsSettings {
                enabled: true,
                ca_certs_dir: Some(ca_dir.clone()),
                crl_path: None,
                verify_cn: true,
                allowed_organizations: Vec::new(),
            },
            jwt: JwtSettings::default(),
            api_key: ApiKeySettings::default(),
            reject_unauthenticated: true,
        };

        let auth_result = Authenticator::new(config);
        let _ = std::fs::remove_dir_all(&ca_dir);

        let auth = auth_result.expect("authenticator should build");
        assert!(auth.is_enabled());
        assert!(auth.is_method_enabled("mtls"));
        assert!(!auth.is_method_enabled("jwt"));
        assert!(matches!(
            auth.authenticate_jwt("not-a-token"),
            Err(AuthError::MethodNotEnabled(_))
        ));
    }

    #[test]
    fn test_jwt_validator_creation_missing_secret() {
        let config = JwtSettings {
            enabled: true,
            secret: None,
            public_key_path: None,
            algorithm: "HS256".to_string(),
            expiration_secs: 3600,
            issuer: None,
            audience: None,
        };

        let result = JwtValidator::new(config);
        assert!(result.is_err());
    }

    #[test]
    fn test_auth_method_display() {
        assert_eq!(format!("{}", AuthMethod::MutualTls), "mTLS");
        assert_eq!(format!("{}", AuthMethod::Jwt), "JWT");
        assert_eq!(format!("{}", AuthMethod::ApiKey), "API Key");
    }
}