1use std::collections::{HashMap, HashSet};
35use std::sync::atomic::{AtomicU64, Ordering};
36use std::sync::Arc;
37use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
38
39use parking_lot::RwLock;
40
41#[derive(Debug, Clone)]
43pub struct Principal {
44 pub id: String,
46 pub tenant_id: String,
48 pub capabilities: HashSet<Capability>,
50 pub expires_at: Option<u64>,
52 pub auth_method: AuthMethod,
54}
55
56impl Principal {
57 pub fn has_capability(&self, cap: &Capability) -> bool {
59 self.capabilities.contains(cap) || self.capabilities.contains(&Capability::Admin)
61 }
62
63 pub fn is_expired(&self) -> bool {
65 if let Some(exp) = self.expires_at {
66 let now = SystemTime::now()
67 .duration_since(UNIX_EPOCH)
68 .unwrap_or_default()
69 .as_secs();
70 now >= exp
71 } else {
72 false
73 }
74 }
75}
76
/// Mechanism through which a [`Principal`] was authenticated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
    /// Mutual-TLS client certificate.
    MtlsCertificate,
    /// JWT presented as an `Authorization: Bearer` token.
    JwtBearer,
    /// Pre-registered API key.
    ApiKey,
    /// No credentials were presented.
    Anonymous,
}
89
/// A permission that can be granted to a [`Principal`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Capability {
    /// Implies every other capability.
    Admin,
    /// Read access.
    Read,
    /// Write access.
    Write,
    /// Manage collections.
    ManageCollections,
    /// Manage indexes.
    ManageIndexes,
    /// View metrics.
    ViewMetrics,
    /// Manage backups.
    ManageBackups,
    /// Application-defined capability, compared by exact string.
    Custom(String),
}

impl Capability {
    /// Parses a capability name.
    ///
    /// Built-in names are recognised case-insensitively (the input is lowercased
    /// before comparison). Any other input becomes [`Capability::Custom`], holding the
    /// original text unchanged, including its case.
    pub fn from_str(s: &str) -> Self {
        const BUILT_IN: [(&str, Capability); 7] = [
            ("admin", Capability::Admin),
            ("read", Capability::Read),
            ("write", Capability::Write),
            ("manage_collections", Capability::ManageCollections),
            ("manage_indexes", Capability::ManageIndexes),
            ("view_metrics", Capability::ViewMetrics),
            ("manage_backups", Capability::ManageBackups),
        ];

        let wanted = s.to_lowercase();
        BUILT_IN
            .iter()
            .find(|(name, _)| *name == wanted)
            .map(|(_, cap)| cap.clone())
            .unwrap_or_else(|| Capability::Custom(s.to_string()))
    }
}
126
127pub struct RateLimiter {
129 buckets: RwLock<HashMap<String, TokenBucket>>,
131 default_rate: u64,
133 default_burst: u64,
135 tenant_limits: RwLock<HashMap<String, (u64, u64)>>,
137}
138
/// Token-bucket state for a single caller.
struct TokenBucket {
    /// Current token balance. Fractional, because refill is continuous.
    tokens: f64,
    /// Instant at which `tokens` was last refilled.
    last_update: Instant,
    /// Refill rate in tokens per second.
    rate: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
}
145
146impl RateLimiter {
147 pub fn new(default_rate: u64, default_burst: u64) -> Self {
149 Self {
150 buckets: RwLock::new(HashMap::new()),
151 default_rate,
152 default_burst,
153 tenant_limits: RwLock::new(HashMap::new()),
154 }
155 }
156
157 pub fn set_tenant_limit(&self, tenant_id: &str, rate: u64, burst: u64) {
159 self.tenant_limits
160 .write()
161 .insert(tenant_id.to_string(), (rate, burst));
162 }
163
164 pub fn check(&self, principal_id: &str, tenant_id: &str) -> RateLimitResult {
166 let now = Instant::now();
167
168 let (rate, burst) = self
170 .tenant_limits
171 .read()
172 .get(tenant_id)
173 .copied()
174 .unwrap_or((self.default_rate, self.default_burst));
175
176 let mut buckets = self.buckets.write();
177 let bucket = buckets
178 .entry(principal_id.to_string())
179 .or_insert(TokenBucket {
180 tokens: burst as f64,
181 last_update: now,
182 rate: rate as f64,
183 capacity: burst as f64,
184 });
185
186 let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
188 bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(bucket.capacity);
189 bucket.last_update = now;
190
191 if bucket.tokens >= 1.0 {
192 bucket.tokens -= 1.0;
193 RateLimitResult::Allowed {
194 remaining: bucket.tokens as u64,
195 }
196 } else {
197 let retry_after = (1.0 - bucket.tokens) / bucket.rate;
198 RateLimitResult::Limited {
199 retry_after_ms: (retry_after * 1000.0) as u64,
200 }
201 }
202 }
203
204 pub fn cleanup(&self, max_age: Duration) {
206 let now = Instant::now();
207 let mut buckets = self.buckets.write();
208 buckets.retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
209 }
210}
211
/// Outcome of a rate-limit check.
#[derive(Debug)]
pub enum RateLimitResult {
    /// The request may proceed; `remaining` is the number of whole tokens left.
    Allowed { remaining: u64 },
    /// The request is throttled; a token should be available after roughly
    /// `retry_after_ms` milliseconds.
    Limited { retry_after_ms: u64 },
}
220
/// One audit record describing an authorization decision.
#[derive(Debug, Clone)]
pub struct AuditLogEntry {
    /// Unix timestamp in seconds (the `AuditLogger` helpers fill it from `SystemTime::now()`).
    pub timestamp: u64,
    /// Id of the principal that made the request.
    pub principal_id: String,
    /// Tenant of that principal.
    pub tenant_id: String,
    /// Action that was attempted.
    pub action: String,
    /// Resource the action targeted.
    pub resource: String,
    /// Outcome of the decision.
    pub result: AuditResult,
    /// Optional free-form detail. `AuditLogger::log_denied` stores a small JSON object
    /// holding the denial reason here.
    pub context: Option<String>,
    /// Request identifier supplied by the caller.
    pub request_id: String,
    /// Client IP address, when known.
    pub client_ip: Option<String>,
}

/// Outcome recorded in an [`AuditLogEntry`].
#[derive(Debug, Clone, Copy)]
pub enum AuditResult {
    /// The action was permitted.
    Success,
    /// The action failed (meaning defined by whoever records it).
    Failure,
    /// The action was refused by an authorization check.
    Denied,
}

impl AuditLogEntry {
    /// Serializes the entry as a single-line JSON object.
    ///
    /// Every string value is escaped per RFC 8259: quotes, backslashes and control
    /// characters are escaped, so untrusted ids cannot break out of their field or
    /// split the log line. `client_ip` and `context` are emitted as `null` when absent.
    /// `context` is emitted as a JSON *string*, because it is free-form text and is not
    /// trusted to be valid JSON.
    pub fn to_json(&self) -> String {
        format!(
            r#"{{"timestamp":{},"principal_id":{},"tenant_id":{},"action":{},"resource":{},"result":"{}","request_id":{},"client_ip":{},"context":{}}}"#,
            self.timestamp,
            Self::json_string(&self.principal_id),
            Self::json_string(&self.tenant_id),
            Self::json_string(&self.action),
            Self::json_string(&self.resource),
            match self.result {
                AuditResult::Success => "success",
                AuditResult::Failure => "failure",
                AuditResult::Denied => "denied",
            },
            Self::json_string(&self.request_id),
            Self::json_optional(self.client_ip.as_deref()),
            Self::json_optional(self.context.as_deref()),
        )
    }

    /// Renders `value` as a quoted JSON string literal, escaping as RFC 8259 requires.
    fn json_string(value: &str) -> String {
        let mut out = String::with_capacity(value.len() + 2);
        out.push('"');
        for ch in value.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // Remaining C0 control characters must use the \uXXXX form.
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }

    /// Renders an optional string as a JSON string literal, or `null` when absent.
    fn json_optional(value: Option<&str>) -> String {
        value.map_or_else(|| "null".to_string(), Self::json_string)
    }
}
275
276pub struct AuditLogger {
278 buffer: RwLock<Vec<AuditLogEntry>>,
280 flush_threshold: usize,
282 total_entries: AtomicU64,
284}
285
286impl AuditLogger {
287 pub fn new(flush_threshold: usize) -> Self {
289 Self {
290 buffer: RwLock::new(Vec::with_capacity(flush_threshold)),
291 flush_threshold,
292 total_entries: AtomicU64::new(0),
293 }
294 }
295
296 pub fn log(&self, entry: AuditLogEntry) {
298 self.total_entries.fetch_add(1, Ordering::Relaxed);
299
300 let mut buffer = self.buffer.write();
301 buffer.push(entry);
302
303 if buffer.len() >= self.flush_threshold {
304 buffer.clear();
307 }
308 }
309
310 pub fn log_success(
312 &self,
313 principal: &Principal,
314 action: &str,
315 resource: &str,
316 request_id: &str,
317 ) {
318 self.log(AuditLogEntry {
319 timestamp: SystemTime::now()
320 .duration_since(UNIX_EPOCH)
321 .unwrap_or_default()
322 .as_secs(),
323 principal_id: principal.id.clone(),
324 tenant_id: principal.tenant_id.clone(),
325 action: action.to_string(),
326 resource: resource.to_string(),
327 result: AuditResult::Success,
328 context: None,
329 request_id: request_id.to_string(),
330 client_ip: None,
331 });
332 }
333
334 pub fn log_denied(
336 &self,
337 principal: &Principal,
338 action: &str,
339 resource: &str,
340 request_id: &str,
341 reason: &str,
342 ) {
343 self.log(AuditLogEntry {
344 timestamp: SystemTime::now()
345 .duration_since(UNIX_EPOCH)
346 .unwrap_or_default()
347 .as_secs(),
348 principal_id: principal.id.clone(),
349 tenant_id: principal.tenant_id.clone(),
350 action: action.to_string(),
351 resource: resource.to_string(),
352 result: AuditResult::Denied,
353 context: Some(format!(r#"{{"reason":"{}"}}"#, reason.replace('"', "\\\""))),
354 request_id: request_id.to_string(),
355 client_ip: None,
356 });
357 }
358
359 pub fn total_entries(&self) -> u64 {
361 self.total_entries.load(Ordering::Relaxed)
362 }
363}
364
/// Settings for the security service: authentication methods, rate limits and auditing.
#[derive(Debug, Clone)]
pub struct SecurityConfig {
    /// Accept mutual-TLS client certificates during authentication.
    pub mtls_enabled: bool,
    /// Certificate path (presumably the server TLS certificate; not read by the code shown).
    pub cert_path: Option<String>,
    /// Private-key path (presumably for `cert_path`; not read by the code shown).
    pub key_path: Option<String>,
    /// CA certificate path (presumably for client-cert validation; not read by the code shown).
    pub ca_cert_path: Option<String>,

    /// Accept JWT bearer tokens during authentication.
    pub jwt_enabled: bool,
    /// JWKS endpoint URL (not consulted by the authentication code shown).
    pub jwks_url: Option<String>,
    /// Expected JWT issuer (not consulted by the authentication code shown).
    pub jwt_issuer: Option<String>,
    /// Expected JWT audience (not consulted by the authentication code shown).
    pub jwt_audience: Option<String>,

    /// Accept registered API keys presented as bearer tokens.
    pub api_key_enabled: bool,

    /// Default rate-limit refill rate, in tokens per second.
    pub rate_limit_default: u64,
    /// Default rate-limit burst (bucket capacity).
    pub rate_limit_burst: u64,

    /// Write audit entries for authorization decisions.
    pub audit_enabled: bool,
    /// Number of buffered audit entries that triggers a flush.
    pub audit_flush_threshold: usize,
}

impl Default for SecurityConfig {
    /// Returns a configuration with:
    /// - every authentication method disabled and no paths or endpoints set;
    /// - rate limiting at 1000 tokens/s with a burst of 100;
    /// - auditing enabled, flushing every 100 entries.
    fn default() -> Self {
        SecurityConfig {
            // Authentication: everything off.
            mtls_enabled: false,
            jwt_enabled: false,
            api_key_enabled: false,
            cert_path: None,
            key_path: None,
            ca_cert_path: None,
            jwks_url: None,
            jwt_issuer: None,
            jwt_audience: None,
            // Rate limiting.
            rate_limit_burst: 100,
            rate_limit_default: 1000,
            // Auditing.
            audit_flush_threshold: 100,
            audit_enabled: true,
        }
    }
}
419
420pub struct SecurityService {
422 config: SecurityConfig,
423 rate_limiter: RateLimiter,
424 audit_logger: AuditLogger,
425 api_keys: RwLock<HashMap<String, Principal>>,
427}
428
429impl SecurityService {
430 pub fn new(config: SecurityConfig) -> Self {
432 let rate_limiter = RateLimiter::new(config.rate_limit_default, config.rate_limit_burst);
433 let audit_logger = AuditLogger::new(config.audit_flush_threshold);
434
435 Self {
436 config,
437 rate_limiter,
438 audit_logger,
439 api_keys: RwLock::new(HashMap::new()),
440 }
441 }
442
443 pub fn register_api_key(&self, key: &str, principal: Principal) {
445 self.api_keys.write().insert(key.to_string(), principal);
446 }
447
448 pub fn authenticate(
450 &self,
451 auth_header: Option<&str>,
452 client_cert: Option<&str>,
453 ) -> Result<Principal, AuthError> {
454 if self.config.mtls_enabled {
456 if let Some(_cert) = client_cert {
457 return Ok(Principal {
459 id: "mtls-client".to_string(),
460 tenant_id: "default".to_string(),
461 capabilities: HashSet::from([Capability::Read, Capability::Write]),
462 expires_at: None,
463 auth_method: AuthMethod::MtlsCertificate,
464 });
465 }
466 }
467
468 if let Some(header) = auth_header {
470 if header.starts_with("Bearer ") {
471 let token = &header[7..];
472
473 if self.config.jwt_enabled {
475 return Ok(Principal {
477 id: "jwt-user".to_string(),
478 tenant_id: "default".to_string(),
479 capabilities: HashSet::from([Capability::Read]),
480 expires_at: Some(
481 SystemTime::now()
482 .duration_since(UNIX_EPOCH)
483 .unwrap_or_default()
484 .as_secs()
485 + 3600,
486 ),
487 auth_method: AuthMethod::JwtBearer,
488 });
489 }
490
491 if self.config.api_key_enabled {
493 if let Some(principal) = self.api_keys.read().get(token) {
494 return Ok(principal.clone());
495 }
496 }
497 }
498 }
499
500 Err(AuthError::Unauthenticated)
501 }
502
503 pub fn authorize(
505 &self,
506 principal: &Principal,
507 required_capability: &Capability,
508 ) -> Result<(), AuthError> {
509 if principal.is_expired() {
511 return Err(AuthError::TokenExpired);
512 }
513
514 if principal.has_capability(required_capability) {
516 Ok(())
517 } else {
518 Err(AuthError::Unauthorized {
519 required: format!("{:?}", required_capability),
520 })
521 }
522 }
523
524 pub fn check_rate_limit(&self, principal: &Principal) -> Result<(), AuthError> {
526 match self.rate_limiter.check(&principal.id, &principal.tenant_id) {
527 RateLimitResult::Allowed { .. } => Ok(()),
528 RateLimitResult::Limited { retry_after_ms } => {
529 Err(AuthError::RateLimited { retry_after_ms })
530 }
531 }
532 }
533
534 pub fn audit(&self) -> &AuditLogger {
536 &self.audit_logger
537 }
538
539 pub fn full_check(
541 &self,
542 auth_header: Option<&str>,
543 client_cert: Option<&str>,
544 required_capability: &Capability,
545 action: &str,
546 resource: &str,
547 request_id: &str,
548 ) -> Result<Principal, AuthError> {
549 let principal = self.authenticate(auth_header, client_cert)?;
551
552 self.check_rate_limit(&principal)?;
554
555 match self.authorize(&principal, required_capability) {
557 Ok(()) => {
558 if self.config.audit_enabled {
559 self.audit_logger
560 .log_success(&principal, action, resource, request_id);
561 }
562 Ok(principal)
563 }
564 Err(e) => {
565 if self.config.audit_enabled {
566 self.audit_logger.log_denied(
567 &principal,
568 action,
569 resource,
570 request_id,
571 &format!("{:?}", e),
572 );
573 }
574 Err(e)
575 }
576 }
577 }
578}
579
/// Errors produced by the security checks.
#[derive(Debug)]
pub enum AuthError {
    /// No enabled authentication method accepted the presented credentials.
    Unauthenticated,
    /// The principal's expiry time has passed.
    TokenExpired,
    /// The principal lacks a required capability (`required` names it).
    Unauthorized { required: String },
    /// The caller is throttled; retry after `retry_after_ms` milliseconds.
    RateLimited { retry_after_ms: u64 },
    /// Internal failure with a human-readable description.
    Internal(String),
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unauthenticated => f.write_str("Authentication required"),
            Self::TokenExpired => f.write_str("Token has expired"),
            Self::Unauthorized { required } => {
                write!(f, "Missing required capability: {}", required)
            }
            Self::RateLimited { retry_after_ms } => {
                write!(f, "Rate limit exceeded, retry after {}ms", retry_after_ms)
            }
            Self::Internal(msg) => write!(f, "Internal error: {}", msg),
        }
    }
}

/// Marker impl so `AuthError` can be used as a `dyn std::error::Error` (no source chain).
impl std::error::Error for AuthError {}
612
// Unit tests for the capability model, rate limiter, API-key authentication and the
// audit counter.
#[cfg(test)]
mod tests {
    use super::*;

    // A principal holds exactly the capabilities it was granted; Admin is not implied.
    #[test]
    fn test_capability_check() {
        let principal = Principal {
            id: "user1".to_string(),
            tenant_id: "tenant1".to_string(),
            capabilities: HashSet::from([Capability::Read, Capability::Write]),
            expires_at: None,
            auth_method: AuthMethod::ApiKey,
        };

        assert!(principal.has_capability(&Capability::Read));
        assert!(principal.has_capability(&Capability::Write));
        assert!(!principal.has_capability(&Capability::Admin));
    }

    // Holding Capability::Admin satisfies every other capability check.
    #[test]
    fn test_admin_has_all_capabilities() {
        let admin = Principal {
            id: "admin".to_string(),
            tenant_id: "tenant1".to_string(),
            capabilities: HashSet::from([Capability::Admin]),
            expires_at: None,
            auth_method: AuthMethod::ApiKey,
        };

        assert!(admin.has_capability(&Capability::Read));
        assert!(admin.has_capability(&Capability::Write));
        assert!(admin.has_capability(&Capability::ManageBackups));
    }

    #[test]
    fn test_rate_limiter() {
        // 10 tokens/s refill with a burst of 5: the bucket starts full.
        let limiter = RateLimiter::new(10, 5);
        // The first five calls drain the initial burst.
        for _ in 0..5 {
            assert!(matches!(
                limiter.check("user1", "tenant1"),
                RateLimitResult::Allowed { .. }
            ));
        }

        // The sixth call is throttled. NOTE(review): timing-dependent — this holds only
        // if under ~100 ms (one token at 10/s) elapsed since the first call.
        assert!(matches!(
            limiter.check("user1", "tenant1"),
            RateLimitResult::Limited { .. }
        ));
    }

    #[test]
    fn test_security_service_api_key() {
        let config = SecurityConfig {
            api_key_enabled: true,
            ..Default::default()
        };
        let service = SecurityService::new(config);

        // Register a key that authenticates as "service1".
        let principal = Principal {
            id: "service1".to_string(),
            tenant_id: "tenant1".to_string(),
            capabilities: HashSet::from([Capability::Read]),
            expires_at: None,
            auth_method: AuthMethod::ApiKey,
        };
        service.register_api_key("secret-key-123", principal);

        // The registered key, sent as a bearer token, resolves to its principal.
        let result = service.authenticate(Some("Bearer secret-key-123"), None);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().id, "service1");

        // An unregistered key is rejected (JWT is disabled in the default config).
        let result = service.authenticate(Some("Bearer invalid-key"), None);
        assert!(result.is_err());
    }

    // Logging increments the running total even while below the flush threshold.
    #[test]
    fn test_audit_logging() {
        let logger = AuditLogger::new(10);

        let principal = Principal {
            id: "user1".to_string(),
            tenant_id: "tenant1".to_string(),
            capabilities: HashSet::new(),
            expires_at: None,
            auth_method: AuthMethod::ApiKey,
        };

        logger.log_success(&principal, "read", "/collections/test", "req-123");
        assert_eq!(logger.total_entries(), 1);
    }
}