ant_core/mcp/
security.rs

1//! MCP Security Module
2//!
3//! This module provides comprehensive security features for the MCP server including:
4//! - JWT-based authentication
5//! - Peer identity verification
6//! - Access control and permissions
7//! - Rate limiting and abuse prevention
8//! - Message integrity and encryption
9
10use crate::{PeerId, Result, P2PError};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14use sha2::{Digest, Sha256};
15use tokio::sync::RwLock;
16use std::sync::Arc;
17use base64::prelude::*;
18
19/// JWT-like token structure for MCP authentication
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct MCPToken {
22    /// Token header
23    pub header: TokenHeader,
24    /// Token payload
25    pub payload: TokenPayload,
26    /// Token signature
27    pub signature: String,
28}
29
30/// Token header information
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct TokenHeader {
33    /// Algorithm used for signing
34    pub alg: String,
35    /// Token type
36    pub typ: String,
37    /// Key ID
38    pub kid: Option<String>,
39}
40
41/// Token payload with claims
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct TokenPayload {
44    /// Issuer (peer ID)
45    pub iss: PeerId,
46    /// Subject (target peer ID or tool)
47    pub sub: String,
48    /// Audience (intended recipient)
49    pub aud: String,
50    /// Expiration time (Unix timestamp)
51    pub exp: u64,
52    /// Not before time (Unix timestamp)
53    pub nbf: u64,
54    /// Issued at time (Unix timestamp)
55    pub iat: u64,
56    /// JWT ID
57    pub jti: String,
58    /// Custom claims
59    pub claims: HashMap<String, serde_json::Value>,
60}
61
/// How much authentication an MCP operation demands.
///
/// The derived ordering follows declaration order, so
/// `Public < Basic < Strong < Admin`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum SecurityLevel {
    /// No authentication needed; open to everyone
    Public,
    /// Requires basic authentication
    Basic,
    /// Requires strong authentication
    Strong,
    /// Requires administrative access
    Admin,
}
74
/// A single capability that can be granted to a peer.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum MCPPermission {
    /// May read tools
    ReadTools,
    /// May execute tools
    ExecuteTools,
    /// May register new tools
    RegisterTools,
    /// May modify existing tools
    ModifyTools,
    /// May delete tools
    DeleteTools,
    /// May access prompts
    AccessPrompts,
    /// May access resources
    AccessResources,
    /// Administrative access
    Admin,
    /// Any permission not covered above, identified by its string name
    Custom(String),
}

impl MCPPermission {
    /// Returns the string name of this permission.
    ///
    /// Built-in permissions map to fixed names such as `"read:tools"`;
    /// a `Custom` permission yields the string it wraps.
    pub fn as_str(&self) -> &str {
        match self {
            Self::ReadTools => "read:tools",
            Self::ExecuteTools => "execute:tools",
            Self::RegisterTools => "register:tools",
            Self::ModifyTools => "modify:tools",
            Self::DeleteTools => "delete:tools",
            Self::AccessPrompts => "access:prompts",
            Self::AccessResources => "access:resources",
            Self::Admin => "admin",
            Self::Custom(name) => name,
        }
    }

    /// Converts a string name back into a permission.
    ///
    /// Names produced by [`MCPPermission::as_str`] for built-in permissions
    /// map back to the matching variant. Every other string becomes
    /// `Custom`, so the result is always `Some`.
    pub fn from_str(s: &str) -> Option<Self> {
        let parsed = match s {
            "read:tools" => Self::ReadTools,
            "execute:tools" => Self::ExecuteTools,
            "register:tools" => Self::RegisterTools,
            "modify:tools" => Self::ModifyTools,
            "delete:tools" => Self::DeleteTools,
            "access:prompts" => Self::AccessPrompts,
            "access:resources" => Self::AccessResources,
            "admin" => Self::Admin,
            other => Self::Custom(other.to_string()),
        };
        Some(parsed)
    }
}
129
130/// Access control list for a peer
131#[derive(Debug, Clone)]
132pub struct PeerACL {
133    /// Peer ID
134    pub peer_id: PeerId,
135    /// Granted permissions
136    pub permissions: Vec<MCPPermission>,
137    /// Security level
138    pub security_level: SecurityLevel,
139    /// Reputation score (0.0 to 1.0)
140    pub reputation: f64,
141    /// Last access time
142    pub last_access: SystemTime,
143    /// Access count
144    pub access_count: u64,
145    /// Rate limit violations
146    pub rate_violations: u32,
147    /// Banned until (if applicable)
148    pub banned_until: Option<SystemTime>,
149}
150
151impl PeerACL {
152    /// Create new peer ACL with default permissions
153    pub fn new(peer_id: PeerId) -> Self {
154        Self {
155            peer_id,
156            permissions: vec![MCPPermission::ReadTools, MCPPermission::ExecuteTools],
157            security_level: SecurityLevel::Basic,
158            reputation: 0.5, // Start with neutral reputation
159            last_access: SystemTime::now(),
160            access_count: 0,
161            rate_violations: 0,
162            banned_until: None,
163        }
164    }
165    
166    /// Check if peer has specific permission
167    pub fn has_permission(&self, permission: &MCPPermission) -> bool {
168        if self.is_banned() {
169            return false;
170        }
171        
172        // Admin permission grants all access
173        if self.permissions.contains(&MCPPermission::Admin) {
174            return true;
175        }
176        
177        self.permissions.contains(permission)
178    }
179    
180    /// Check if peer is currently banned
181    pub fn is_banned(&self) -> bool {
182        if let Some(banned_until) = self.banned_until {
183            SystemTime::now() < banned_until
184        } else {
185            false
186        }
187    }
188    
189    /// Update access statistics
190    pub fn record_access(&mut self) {
191        self.last_access = SystemTime::now();
192        self.access_count += 1;
193    }
194    
195    /// Record rate limit violation
196    pub fn record_rate_violation(&mut self) {
197        self.rate_violations += 1;
198        
199        // Auto-ban after too many violations
200        if self.rate_violations >= 10 {
201            self.banned_until = Some(SystemTime::now() + Duration::from_secs(3600)); // 1 hour
202        }
203    }
204    
205    /// Grant permission to peer
206    pub fn grant_permission(&mut self, permission: MCPPermission) {
207        if !self.permissions.contains(&permission) {
208            self.permissions.push(permission);
209        }
210    }
211    
212    /// Revoke permission from peer
213    pub fn revoke_permission(&mut self, permission: &MCPPermission) {
214        self.permissions.retain(|p| p != permission);
215    }
216}
217
218/// Rate limiter for controlling request frequency
219#[derive(Debug, Clone)]
220pub struct RateLimiter {
221    /// Requests per minute limit
222    pub rpm_limit: u32,
223    /// Request timestamps for each peer
224    requests: Arc<RwLock<HashMap<PeerId, Vec<SystemTime>>>>,
225}
226
227impl RateLimiter {
228    /// Create new rate limiter
229    pub fn new(rpm_limit: u32) -> Self {
230        Self {
231            rpm_limit,
232            requests: Arc::new(RwLock::new(HashMap::new())),
233        }
234    }
235    
236    /// Check if request is allowed for peer
237    pub async fn is_allowed(&self, peer_id: &PeerId) -> bool {
238        let mut requests = self.requests.write().await;
239        let now = SystemTime::now();
240        let minute_ago = now - Duration::from_secs(60);
241        
242        // Get or create request history for peer
243        let peer_requests = requests.entry(peer_id.clone()).or_insert_with(Vec::new);
244        
245        // Remove old requests (older than 1 minute)
246        peer_requests.retain(|&req_time| req_time > minute_ago);
247        
248        // Check if under limit
249        if peer_requests.len() < self.rpm_limit as usize {
250            peer_requests.push(now);
251            true
252        } else {
253            false
254        }
255    }
256    
257    /// Reset rate limit for peer (admin function)
258    pub async fn reset_peer(&self, peer_id: &PeerId) {
259        let mut requests = self.requests.write().await;
260        requests.remove(peer_id);
261    }
262    
263    /// Clean up old entries periodically
264    pub async fn cleanup(&self) {
265        let mut requests = self.requests.write().await;
266        let minute_ago = SystemTime::now() - Duration::from_secs(60);
267        
268        for peer_requests in requests.values_mut() {
269            peer_requests.retain(|&req_time| req_time > minute_ago);
270        }
271        
272        // Remove empty entries
273        requests.retain(|_, reqs| !reqs.is_empty());
274    }
275}
276
/// MCP Security Manager
///
/// Bundles the per-peer access control lists, the request rate limiter, the
/// token signing secret, per-tool security policies and the trusted-peer
/// list behind one type.
pub struct MCPSecurityManager {
    /// Access control lists, keyed by peer ID
    acls: Arc<RwLock<HashMap<PeerId, PeerACL>>>,
    /// Rate limiter applied to peer requests
    rate_limiter: RateLimiter,
    /// Shared secret used to sign and verify tokens
    secret_key: Vec<u8>,
    /// Tool security policies, keyed by tool name
    tool_policies: Arc<RwLock<HashMap<String, SecurityLevel>>>,
    /// Peers that have been marked as trusted
    trusted_peers: Arc<RwLock<Vec<PeerId>>>,
}
290
291impl MCPSecurityManager {
292    /// Create new security manager
293    pub fn new(secret_key: Vec<u8>, rpm_limit: u32) -> Self {
294        Self {
295            acls: Arc::new(RwLock::new(HashMap::new())),
296            rate_limiter: RateLimiter::new(rpm_limit),
297            secret_key,
298            tool_policies: Arc::new(RwLock::new(HashMap::new())),
299            trusted_peers: Arc::new(RwLock::new(Vec::new())),
300        }
301    }
302    
303    /// Generate authentication token for peer
304    pub async fn generate_token(&self, peer_id: &PeerId, permissions: Vec<MCPPermission>, ttl: Duration) -> Result<String> {
305        let now = SystemTime::now()
306            .duration_since(UNIX_EPOCH)
307            .map_err(|e| P2PError::MCP(format!("Time error: {}", e)))?;
308        
309        let payload = TokenPayload {
310            iss: peer_id.clone(),
311            sub: peer_id.clone(),
312            aud: "mcp-server".to_string(),
313            exp: (now + ttl).as_secs(),
314            nbf: now.as_secs(),
315            iat: now.as_secs(),
316            jti: uuid::Uuid::new_v4().to_string(),
317            claims: {
318                let mut claims = HashMap::new();
319                claims.insert("permissions".to_string(), 
320                    serde_json::to_value(permissions.iter().map(|p| p.as_str()).collect::<Vec<_>>()).unwrap());
321                claims
322            },
323        };
324        
325        let header = TokenHeader {
326            alg: "HS256".to_string(),
327            typ: "JWT".to_string(),
328            kid: None,
329        };
330        
331        // Create token without signature first
332        let header_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)
333            .map_err(|e| P2PError::Serialization(e))?);
334        let payload_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload)
335            .map_err(|e| P2PError::Serialization(e))?);
336        
337        // Sign the token
338        let signing_input = format!("{}.{}", header_b64, payload_b64);
339        let signature = self.sign_data(signing_input.as_bytes());
340        let signature_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(signature);
341        
342        Ok(format!("{}.{}.{}", header_b64, payload_b64, signature_b64))
343    }
344    
345    /// Verify authentication token
346    pub async fn verify_token(&self, token: &str) -> Result<TokenPayload> {
347        let parts: Vec<&str> = token.split('.').collect();
348        if parts.len() != 3 {
349            return Err(P2PError::MCP("Invalid token format".to_string()));
350        }
351        
352        let _header_data = base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(parts[0])
353            .map_err(|e| P2PError::MCP(format!("Invalid header encoding: {}", e)))?;
354        let payload_data = base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(parts[1])
355            .map_err(|e| P2PError::MCP(format!("Invalid payload encoding: {}", e)))?;
356        let signature = base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(parts[2])
357            .map_err(|e| P2PError::MCP(format!("Invalid signature encoding: {}", e)))?;
358        
359        // Verify signature
360        let signing_input = format!("{}.{}", parts[0], parts[1]);
361        let expected_signature = self.sign_data(signing_input.as_bytes());
362        
363        if signature != expected_signature {
364            return Err(P2PError::MCP("Invalid token signature".to_string()));
365        }
366        
367        // Parse payload
368        let payload: TokenPayload = serde_json::from_slice(&payload_data)
369            .map_err(|e| P2PError::MCP(format!("Invalid payload: {}", e)))?;
370        
371        // Check expiration
372        let now = SystemTime::now()
373            .duration_since(UNIX_EPOCH)
374            .map_err(|e| P2PError::MCP(format!("Time error: {}", e)))?
375            .as_secs();
376        
377        if payload.exp < now {
378            return Err(P2PError::MCP("Token expired".to_string()));
379        }
380        
381        if payload.nbf > now {
382            return Err(P2PError::MCP("Token not yet valid".to_string()));
383        }
384        
385        Ok(payload)
386    }
387    
388    /// Check if peer has permission for operation
389    pub async fn check_permission(&self, peer_id: &PeerId, permission: &MCPPermission) -> Result<bool> {
390        let acls = self.acls.read().await;
391        
392        if let Some(acl) = acls.get(peer_id) {
393            Ok(acl.has_permission(permission))
394        } else {
395            // Create default ACL for new peer
396            drop(acls);
397            let mut acls = self.acls.write().await;
398            acls.insert(peer_id.clone(), PeerACL::new(peer_id.clone()));
399            Ok(false) // New peers start with no permissions by default
400        }
401    }
402    
403    /// Check rate limit for peer
404    pub async fn check_rate_limit(&self, peer_id: &PeerId) -> Result<bool> {
405        if self.rate_limiter.is_allowed(peer_id).await {
406            Ok(true)
407        } else {
408            // Record violation
409            let mut acls = self.acls.write().await;
410            if let Some(acl) = acls.get_mut(peer_id) {
411                acl.record_rate_violation();
412            }
413            Ok(false)
414        }
415    }
416    
417    /// Grant permission to peer
418    pub async fn grant_permission(&self, peer_id: &PeerId, permission: MCPPermission) -> Result<()> {
419        let mut acls = self.acls.write().await;
420        let acl = acls.entry(peer_id.clone()).or_insert_with(|| PeerACL::new(peer_id.clone()));
421        acl.grant_permission(permission);
422        Ok(())
423    }
424    
425    /// Revoke permission from peer
426    pub async fn revoke_permission(&self, peer_id: &PeerId, permission: &MCPPermission) -> Result<()> {
427        let mut acls = self.acls.write().await;
428        if let Some(acl) = acls.get_mut(peer_id) {
429            acl.revoke_permission(permission);
430        }
431        Ok(())
432    }
433    
434    /// Add trusted peer
435    pub async fn add_trusted_peer(&self, peer_id: PeerId) -> Result<()> {
436        let mut trusted = self.trusted_peers.write().await;
437        if !trusted.contains(&peer_id) {
438            trusted.push(peer_id);
439        }
440        Ok(())
441    }
442    
443    /// Check if peer is trusted
444    pub async fn is_trusted_peer(&self, peer_id: &PeerId) -> bool {
445        let trusted = self.trusted_peers.read().await;
446        trusted.contains(peer_id)
447    }
448    
449    /// Set security policy for tool
450    pub async fn set_tool_policy(&self, tool_name: String, level: SecurityLevel) -> Result<()> {
451        let mut policies = self.tool_policies.write().await;
452        policies.insert(tool_name, level);
453        Ok(())
454    }
455    
456    /// Get security policy for tool
457    pub async fn get_tool_policy(&self, tool_name: &str) -> SecurityLevel {
458        let policies = self.tool_policies.read().await;
459        policies.get(tool_name).cloned().unwrap_or(SecurityLevel::Basic)
460    }
461    
462    /// Sign data with secret key
463    fn sign_data(&self, data: &[u8]) -> Vec<u8> {
464        let mut hasher = Sha256::new();
465        hasher.update(&self.secret_key);
466        hasher.update(data);
467        hasher.finalize().to_vec()
468    }
469    
470    /// Update peer reputation based on behavior
471    pub async fn update_reputation(&self, peer_id: &PeerId, delta: f64) -> Result<()> {
472        let mut acls = self.acls.write().await;
473        if let Some(acl) = acls.get_mut(peer_id) {
474            acl.reputation = (acl.reputation + delta).max(0.0).min(1.0);
475        }
476        Ok(())
477    }
478    
479    /// Get peer statistics
480    pub async fn get_peer_stats(&self, peer_id: &PeerId) -> Option<PeerACL> {
481        let acls = self.acls.read().await;
482        acls.get(peer_id).cloned()
483    }
484    
485    /// Clean up expired data
486    pub async fn cleanup(&self) -> Result<()> {
487        self.rate_limiter.cleanup().await;
488        
489        // Clean up old ACLs (remove entries not accessed in 24 hours)
490        let mut acls = self.acls.write().await;
491        let day_ago = SystemTime::now() - Duration::from_secs(24 * 3600);
492        acls.retain(|_, acl| acl.last_access > day_ago);
493        
494        Ok(())
495    }
496}
497
498/// Security audit log entry
499#[derive(Debug, Clone)]
500pub struct SecurityAuditEntry {
501    /// Timestamp
502    pub timestamp: SystemTime,
503    /// Event type
504    pub event_type: String,
505    /// Peer ID involved
506    pub peer_id: PeerId,
507    /// Event details
508    pub details: HashMap<String, String>,
509    /// Severity level
510    pub severity: AuditSeverity,
511}
512
/// Audit severity levels
///
/// Equality on this fieldless enum is total, so `Eq` is derived alongside
/// `PartialEq`, matching the other enums in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuditSeverity {
    /// Informational
    Info,
    /// Warning
    Warning,
    /// Error
    Error,
    /// Critical security event
    Critical,
}
525
526/// Security audit logger
527pub struct SecurityAuditLogger {
528    /// Audit entries
529    entries: Arc<RwLock<Vec<SecurityAuditEntry>>>,
530    /// Maximum entries to keep
531    max_entries: usize,
532}
533
534impl SecurityAuditLogger {
535    /// Create new audit logger
536    pub fn new(max_entries: usize) -> Self {
537        Self {
538            entries: Arc::new(RwLock::new(Vec::new())),
539            max_entries,
540        }
541    }
542    
543    /// Log security event
544    pub async fn log_event(&self, event_type: String, peer_id: PeerId, details: HashMap<String, String>, severity: AuditSeverity) {
545        let entry = SecurityAuditEntry {
546            timestamp: SystemTime::now(),
547            event_type,
548            peer_id,
549            details,
550            severity,
551        };
552        
553        let mut entries = self.entries.write().await;
554        entries.push(entry);
555        
556        // Keep only recent entries
557        if entries.len() > self.max_entries {
558            let excess = entries.len() - self.max_entries;
559            entries.drain(0..excess);
560        }
561    }
562    
563    /// Get recent audit entries
564    pub async fn get_recent_entries(&self, limit: Option<usize>) -> Vec<SecurityAuditEntry> {
565        let entries = self.entries.read().await;
566        let limit = limit.unwrap_or(entries.len());
567        entries.iter().rev().take(limit).cloned().collect()
568    }
569    
570    /// Get entries by severity
571    pub async fn get_entries_by_severity(&self, severity: AuditSeverity) -> Vec<SecurityAuditEntry> {
572        let entries = self.entries.read().await;
573        entries.iter().filter(|e| e.severity == severity).cloned().collect()
574    }
575}
576
577#[cfg(test)]
578mod tests {
579    use super::*;
580    use std::time::Duration;
581
582    /// Helper function to create a test PeerId
583    fn create_test_peer() -> PeerId {
584        format!("test_peer_{}", rand::random::<u32>())
585    }
586
587    /// Helper function to create a test security manager
588    fn create_test_security_manager() -> MCPSecurityManager {
589        let secret_key = b"test_secret_key_1234567890123456".to_vec();
590        MCPSecurityManager::new(secret_key, 60) // 60 RPM limit
591    }
592
593    #[test]
594    fn test_mcp_permission_string_conversion() {
595        let permissions = vec![
596            (MCPPermission::ReadTools, "read:tools"),
597            (MCPPermission::ExecuteTools, "execute:tools"),
598            (MCPPermission::RegisterTools, "register:tools"),
599            (MCPPermission::ModifyTools, "modify:tools"),
600            (MCPPermission::DeleteTools, "delete:tools"),
601            (MCPPermission::AccessPrompts, "access:prompts"),
602            (MCPPermission::AccessResources, "access:resources"),
603            (MCPPermission::Admin, "admin"),
604        ];
605
606        for (permission, expected_str) in permissions {
607            assert_eq!(permission.as_str(), expected_str);
608            assert_eq!(MCPPermission::from_str(expected_str), Some(permission));
609        }
610
611        // Test custom permission
612        let custom = MCPPermission::Custom("custom:action".to_string());
613        assert_eq!(custom.as_str(), "custom:action");
614        assert_eq!(MCPPermission::from_str("custom:action"), Some(custom));
615
616        // Test unknown permission defaults to custom
617        let unknown = MCPPermission::from_str("unknown:permission");
618        match unknown {
619            Some(MCPPermission::Custom(s)) => assert_eq!(s, "unknown:permission"),
620            _ => panic!("Expected custom permission"),
621        }
622    }
623
624    #[test]
625    fn test_security_level_ordering() {
626        // Test security level ordering
627        assert!(SecurityLevel::Public < SecurityLevel::Basic);
628        assert!(SecurityLevel::Basic < SecurityLevel::Strong);
629        assert!(SecurityLevel::Strong < SecurityLevel::Admin);
630
631        // Test equality
632        assert_eq!(SecurityLevel::Public, SecurityLevel::Public);
633        assert_eq!(SecurityLevel::Basic, SecurityLevel::Basic);
634        assert_eq!(SecurityLevel::Strong, SecurityLevel::Strong);
635        assert_eq!(SecurityLevel::Admin, SecurityLevel::Admin);
636    }
637
638    #[test]
639    fn test_peer_acl_creation() {
640        let peer_id = create_test_peer();
641        let acl = PeerACL::new(peer_id.clone());
642
643        assert_eq!(acl.peer_id, peer_id);
644        assert_eq!(acl.permissions.len(), 2); // Default: ReadTools, ExecuteTools
645        assert!(acl.permissions.contains(&MCPPermission::ReadTools));
646        assert!(acl.permissions.contains(&MCPPermission::ExecuteTools));
647        assert_eq!(acl.security_level, SecurityLevel::Basic);
648        assert_eq!(acl.reputation, 0.5);
649        assert_eq!(acl.access_count, 0);
650        assert_eq!(acl.rate_violations, 0);
651        assert!(acl.banned_until.is_none());
652        assert!(!acl.is_banned());
653    }
654
655    #[test]
656    fn test_peer_acl_permissions() {
657        let peer_id = create_test_peer();
658        let mut acl = PeerACL::new(peer_id);
659
660        // Test default permissions
661        assert!(acl.has_permission(&MCPPermission::ReadTools));
662        assert!(acl.has_permission(&MCPPermission::ExecuteTools));
663        assert!(!acl.has_permission(&MCPPermission::RegisterTools));
664        assert!(!acl.has_permission(&MCPPermission::Admin));
665
666        // Grant admin permission
667        acl.grant_permission(MCPPermission::Admin);
668        // Admin permission grants all access
669        assert!(acl.has_permission(&MCPPermission::ReadTools));
670        assert!(acl.has_permission(&MCPPermission::ExecuteTools));
671        assert!(acl.has_permission(&MCPPermission::RegisterTools));
672        assert!(acl.has_permission(&MCPPermission::DeleteTools));
673        assert!(acl.has_permission(&MCPPermission::Admin));
674
675        // Revoke admin permission
676        acl.revoke_permission(&MCPPermission::Admin);
677        assert!(!acl.has_permission(&MCPPermission::RegisterTools));
678        assert!(!acl.has_permission(&MCPPermission::Admin));
679
680        // Grant specific permission
681        acl.grant_permission(MCPPermission::RegisterTools);
682        assert!(acl.has_permission(&MCPPermission::RegisterTools));
683
684        // Revoke specific permission
685        acl.revoke_permission(&MCPPermission::RegisterTools);
686        assert!(!acl.has_permission(&MCPPermission::RegisterTools));
687    }
688
689    #[test]
690    fn test_peer_acl_ban_functionality() {
691        let peer_id = create_test_peer();
692        let mut acl = PeerACL::new(peer_id);
693
694        // Initially not banned
695        assert!(!acl.is_banned());
696        assert!(acl.has_permission(&MCPPermission::ReadTools));
697
698        // Record violations (but not enough to trigger auto-ban)
699        for _ in 0..5 {
700            acl.record_rate_violation();
701        }
702        assert_eq!(acl.rate_violations, 5);
703        assert!(!acl.is_banned());
704
705        // Record enough violations to trigger auto-ban
706        for _ in 0..5 {
707            acl.record_rate_violation();
708        }
709        assert_eq!(acl.rate_violations, 10);
710        assert!(acl.is_banned());
711
712        // Banned peers have no permissions
713        assert!(!acl.has_permission(&MCPPermission::ReadTools));
714        assert!(!acl.has_permission(&MCPPermission::ExecuteTools));
715    }
716
717    #[test]
718    fn test_peer_acl_access_tracking() {
719        let peer_id = create_test_peer();
720        let mut acl = PeerACL::new(peer_id);
721
722        let initial_time = acl.last_access;
723        assert_eq!(acl.access_count, 0);
724
725        // Record access
726        std::thread::sleep(std::time::Duration::from_millis(10));
727        acl.record_access();
728
729        assert_eq!(acl.access_count, 1);
730        assert!(acl.last_access > initial_time);
731
732        // Record more access
733        acl.record_access();
734        assert_eq!(acl.access_count, 2);
735    }
736
737    #[tokio::test]
738    async fn test_rate_limiter_creation() {
739        let limiter = RateLimiter::new(60);
740        assert_eq!(limiter.rpm_limit, 60);
741    }
742
743    #[tokio::test]
744    async fn test_rate_limiter_basic_functionality() {
745        let limiter = RateLimiter::new(2); // 2 requests per minute
746        let peer_id = create_test_peer();
747
748        // First request should be allowed
749        assert!(limiter.is_allowed(&peer_id).await);
750
751        // Second request should be allowed
752        assert!(limiter.is_allowed(&peer_id).await);
753
754        // Third request should be denied (over limit)
755        assert!(!limiter.is_allowed(&peer_id).await);
756    }
757
758    #[tokio::test]
759    async fn test_rate_limiter_different_peers() {
760        let limiter = RateLimiter::new(1); // 1 request per minute
761        let peer1 = create_test_peer();
762        let peer2 = create_test_peer();
763
764        // Each peer should have their own limit
765        assert!(limiter.is_allowed(&peer1).await);
766        assert!(limiter.is_allowed(&peer2).await);
767
768        // Both should be over their individual limits now
769        assert!(!limiter.is_allowed(&peer1).await);
770        assert!(!limiter.is_allowed(&peer2).await);
771    }
772
773    #[tokio::test]
774    async fn test_rate_limiter_reset() {
775        let limiter = RateLimiter::new(1);
776        let peer_id = create_test_peer();
777
778        // Use up the limit
779        assert!(limiter.is_allowed(&peer_id).await);
780        assert!(!limiter.is_allowed(&peer_id).await);
781
782        // Reset the peer
783        limiter.reset_peer(&peer_id).await;
784
785        // Should be allowed again
786        assert!(limiter.is_allowed(&peer_id).await);
787    }
788
789    #[tokio::test]
790    async fn test_rate_limiter_cleanup() {
791        let limiter = RateLimiter::new(10);
792        let peer_id = create_test_peer();
793
794        // Make some requests
795        limiter.is_allowed(&peer_id).await;
796        limiter.is_allowed(&peer_id).await;
797
798        // Cleanup shouldn't affect recent requests
799        limiter.cleanup().await;
800
801        // Should still have request history
802        let requests = limiter.requests.read().await;
803        assert!(requests.contains_key(&peer_id));
804        let peer_requests = requests.get(&peer_id).unwrap();
805        assert_eq!(peer_requests.len(), 2);
806    }
807
808    #[tokio::test]
809    async fn test_security_manager_creation() {
810        let secret_key = b"test_secret_key".to_vec();
811        let manager = MCPSecurityManager::new(secret_key.clone(), 60);
812
813        // Verify configuration
814        assert_eq!(manager.secret_key, secret_key);
815        assert_eq!(manager.rate_limiter.rpm_limit, 60);
816    }
817
818    #[tokio::test]
819    async fn test_token_generation_and_verification() -> Result<()> {
820        let manager = create_test_security_manager();
821        let peer_id = create_test_peer();
822        let permissions = vec![MCPPermission::ReadTools, MCPPermission::ExecuteTools];
823        let ttl = Duration::from_secs(3600); // 1 hour
824
825        // Generate token
826        let token = manager.generate_token(&peer_id, permissions.clone(), ttl).await?;
827        assert!(!token.is_empty());
828
829        // Verify token
830        let payload = manager.verify_token(&token).await?;
831        assert_eq!(payload.iss, peer_id);
832        assert_eq!(payload.sub, peer_id);
833        assert_eq!(payload.aud, "mcp-server");
834
835        // Check permissions in claims
836        let permissions_claim = payload.claims.get("permissions").unwrap();
837        let permission_strings: Vec<String> = serde_json::from_value(permissions_claim.clone()).unwrap();
838        assert_eq!(permission_strings.len(), 2);
839        assert!(permission_strings.contains(&"read:tools".to_string()));
840        assert!(permission_strings.contains(&"execute:tools".to_string()));
841
842        Ok(())
843    }
844
845    #[tokio::test]
846    async fn test_token_verification_invalid() {
847        let manager = create_test_security_manager();
848
849        // Test invalid token format
850        let result = manager.verify_token("invalid.token").await;
851        assert!(result.is_err());
852
853        // Test malformed token
854        let result = manager.verify_token("invalid.token.format.extra").await;
855        assert!(result.is_err());
856
857        // Test empty token
858        let result = manager.verify_token("").await;
859        assert!(result.is_err());
860    }
861
862    #[tokio::test]
863    async fn test_token_signature_verification() -> Result<()> {
864        let manager1 = create_test_security_manager();
865        let manager2 = MCPSecurityManager::new(b"different_secret".to_vec(), 60);
866
867        let peer_id = create_test_peer();
868        let permissions = vec![MCPPermission::ReadTools];
869        let ttl = Duration::from_secs(3600);
870
871        // Generate token with manager1
872        let token = manager1.generate_token(&peer_id, permissions, ttl).await?;
873
874        // Verify with manager1 should succeed
875        assert!(manager1.verify_token(&token).await.is_ok());
876
877        // Verify with manager2 should fail (different secret)
878        assert!(manager2.verify_token(&token).await.is_err());
879
880        Ok(())
881    }
882
883    #[tokio::test]
884    async fn test_permission_management() -> Result<()> {
885        let manager = create_test_security_manager();
886        let peer_id = create_test_peer();
887
888        // Initially should have no permissions (new peer starts with false)
889        assert!(!manager.check_permission(&peer_id, &MCPPermission::ExecuteTools).await?);
890
891        // Grant permission
892        manager.grant_permission(&peer_id, MCPPermission::ExecuteTools).await?;
893        assert!(manager.check_permission(&peer_id, &MCPPermission::ExecuteTools).await?);
894
895        // Revoke permission
896        manager.revoke_permission(&peer_id, &MCPPermission::ExecuteTools).await?;
897        assert!(!manager.check_permission(&peer_id, &MCPPermission::ExecuteTools).await?);
898
899        Ok(())
900    }
901
902    #[tokio::test]
903    async fn test_rate_limit_checking() -> Result<()> {
904        let manager = MCPSecurityManager::new(b"test_key".to_vec(), 2); // 2 RPM limit
905        let peer_id = create_test_peer();
906
907        // Grant permission first to create ACL entry
908        manager.grant_permission(&peer_id, MCPPermission::ReadTools).await?;
909
910        // First two requests should pass
911        assert!(manager.check_rate_limit(&peer_id).await?);
912        assert!(manager.check_rate_limit(&peer_id).await?);
913
914        // Third request should fail
915        assert!(!manager.check_rate_limit(&peer_id).await?);
916
917        // Check that violation was recorded
918        let stats = manager.get_peer_stats(&peer_id).await;
919        assert!(stats.is_some());
920        let acl = stats.unwrap();
921        assert_eq!(acl.rate_violations, 1);
922
923        Ok(())
924    }
925
926    #[tokio::test]
927    async fn test_trusted_peer_management() -> Result<()> {
928        let manager = create_test_security_manager();
929        let peer_id = create_test_peer();
930
931        // Initially not trusted
932        assert!(!manager.is_trusted_peer(&peer_id).await);
933
934        // Add as trusted
935        manager.add_trusted_peer(peer_id.clone()).await?;
936        assert!(manager.is_trusted_peer(&peer_id).await);
937
938        // Adding same peer again should be idempotent
939        manager.add_trusted_peer(peer_id.clone()).await?;
940        assert!(manager.is_trusted_peer(&peer_id).await);
941
942        Ok(())
943    }
944
945    #[tokio::test]
946    async fn test_tool_security_policies() -> Result<()> {
947        let manager = create_test_security_manager();
948
949        // Default policy should be Basic
950        let policy = manager.get_tool_policy("test_tool").await;
951        assert_eq!(policy, SecurityLevel::Basic);
952
953        // Set custom policy
954        manager.set_tool_policy("test_tool".to_string(), SecurityLevel::Strong).await?;
955        let policy = manager.get_tool_policy("test_tool").await;
956        assert_eq!(policy, SecurityLevel::Strong);
957
958        // Set admin policy
959        manager.set_tool_policy("admin_tool".to_string(), SecurityLevel::Admin).await?;
960        let policy = manager.get_tool_policy("admin_tool").await;
961        assert_eq!(policy, SecurityLevel::Admin);
962
963        Ok(())
964    }
965
966    #[tokio::test]
967    async fn test_reputation_management() -> Result<()> {
968        let manager = create_test_security_manager();
969        let peer_id = create_test_peer();
970
971        // Grant permission to create ACL entry
972        manager.grant_permission(&peer_id, MCPPermission::ReadTools).await?;
973
974        let stats = manager.get_peer_stats(&peer_id).await.unwrap();
975        assert_eq!(stats.reputation, 0.5); // Default reputation
976
977        // Increase reputation
978        manager.update_reputation(&peer_id, 0.2).await?;
979        let stats = manager.get_peer_stats(&peer_id).await.unwrap();
980        assert_eq!(stats.reputation, 0.7);
981
982        // Decrease reputation
983        manager.update_reputation(&peer_id, -0.3).await?;
984        let stats = manager.get_peer_stats(&peer_id).await.unwrap();
985        assert!((stats.reputation - 0.4).abs() < 0.001); // Use epsilon for float comparison
986
987        // Test bounds (should clamp to 0.0-1.0)
988        manager.update_reputation(&peer_id, -1.0).await?;
989        let stats = manager.get_peer_stats(&peer_id).await.unwrap();
990        assert_eq!(stats.reputation, 0.0);
991
992        manager.update_reputation(&peer_id, 2.0).await?;
993        let stats = manager.get_peer_stats(&peer_id).await.unwrap();
994        assert_eq!(stats.reputation, 1.0);
995
996        Ok(())
997    }
998
999    #[tokio::test]
1000    async fn test_security_manager_cleanup() -> Result<()> {
1001        let manager = create_test_security_manager();
1002        let peer_id = create_test_peer();
1003
1004        // Create some data
1005        manager.grant_permission(&peer_id, MCPPermission::ReadTools).await?;
1006        manager.check_rate_limit(&peer_id).await?;
1007
1008        // Cleanup should work without errors
1009        manager.cleanup().await?;
1010
1011        Ok(())
1012    }
1013
1014    #[tokio::test]
1015    async fn test_audit_logger_creation() {
1016        let logger = SecurityAuditLogger::new(100);
1017        assert_eq!(logger.max_entries, 100);
1018
1019        let entries = logger.get_recent_entries(None).await;
1020        assert!(entries.is_empty());
1021    }
1022
1023    #[tokio::test]
1024    async fn test_audit_logger_logging() {
1025        let logger = SecurityAuditLogger::new(10);
1026        let peer_id = create_test_peer();
1027
1028        let mut details = HashMap::new();
1029        details.insert("action".to_string(), "test_action".to_string());
1030        details.insert("result".to_string(), "success".to_string());
1031
1032        // Log an event
1033        logger.log_event(
1034            "test_event".to_string(),
1035            peer_id.clone(),
1036            details.clone(),
1037            AuditSeverity::Info,
1038        ).await;
1039
1040        let entries = logger.get_recent_entries(None).await;
1041        assert_eq!(entries.len(), 1);
1042
1043        let entry = &entries[0];
1044        assert_eq!(entry.event_type, "test_event");
1045        assert_eq!(entry.peer_id, peer_id);
1046        assert_eq!(entry.severity, AuditSeverity::Info);
1047        assert_eq!(entry.details.get("action"), Some(&"test_action".to_string()));
1048    }
1049
1050    #[tokio::test]
1051    async fn test_audit_logger_severity_filtering() {
1052        let logger = SecurityAuditLogger::new(10);
1053        let peer_id = create_test_peer();
1054
1055        // Log events with different severities
1056        logger.log_event("info_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Info).await;
1057        logger.log_event("warning_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Warning).await;
1058        logger.log_event("error_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Error).await;
1059        logger.log_event("critical_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Critical).await;
1060
1061        // Test filtering by severity
1062        let info_entries = logger.get_entries_by_severity(AuditSeverity::Info).await;
1063        assert_eq!(info_entries.len(), 1);
1064        assert_eq!(info_entries[0].event_type, "info_event");
1065
1066        let warning_entries = logger.get_entries_by_severity(AuditSeverity::Warning).await;
1067        assert_eq!(warning_entries.len(), 1);
1068        assert_eq!(warning_entries[0].event_type, "warning_event");
1069
1070        let error_entries = logger.get_entries_by_severity(AuditSeverity::Error).await;
1071        assert_eq!(error_entries.len(), 1);
1072
1073        let critical_entries = logger.get_entries_by_severity(AuditSeverity::Critical).await;
1074        assert_eq!(critical_entries.len(), 1);
1075    }
1076
1077    #[tokio::test]
1078    async fn test_audit_logger_max_entries() {
1079        let logger = SecurityAuditLogger::new(3); // Limit to 3 entries
1080        let peer_id = create_test_peer();
1081
1082        // Log 5 events
1083        for i in 0..5 {
1084            logger.log_event(
1085                format!("event_{}", i),
1086                peer_id.clone(),
1087                HashMap::new(),
1088                AuditSeverity::Info,
1089            ).await;
1090        }
1091
1092        let entries = logger.get_recent_entries(None).await;
1093        assert_eq!(entries.len(), 3); // Should only keep 3 most recent
1094
1095        // Check that we have the most recent events (2, 3, 4)
1096        assert_eq!(entries[0].event_type, "event_4"); // Most recent first
1097        assert_eq!(entries[1].event_type, "event_3");
1098        assert_eq!(entries[2].event_type, "event_2");
1099    }
1100
1101    #[tokio::test]
1102    async fn test_audit_logger_recent_entries_limit() {
1103        let logger = SecurityAuditLogger::new(10);
1104        let peer_id = create_test_peer();
1105
1106        // Log 5 events
1107        for i in 0..5 {
1108            logger.log_event(
1109                format!("event_{}", i),
1110                peer_id.clone(),
1111                HashMap::new(),
1112                AuditSeverity::Info,
1113            ).await;
1114        }
1115
1116        // Get limited number of recent entries
1117        let entries = logger.get_recent_entries(Some(3)).await;
1118        assert_eq!(entries.len(), 3);
1119
1120        // Should be most recent first
1121        assert_eq!(entries[0].event_type, "event_4");
1122        assert_eq!(entries[1].event_type, "event_3");
1123        assert_eq!(entries[2].event_type, "event_2");
1124    }
1125
1126    #[test]
1127    fn test_audit_severity_equality() {
1128        assert_eq!(AuditSeverity::Info, AuditSeverity::Info);
1129        assert_eq!(AuditSeverity::Warning, AuditSeverity::Warning);
1130        assert_eq!(AuditSeverity::Error, AuditSeverity::Error);
1131        assert_eq!(AuditSeverity::Critical, AuditSeverity::Critical);
1132
1133        assert_ne!(AuditSeverity::Info, AuditSeverity::Warning);
1134        assert_ne!(AuditSeverity::Warning, AuditSeverity::Error);
1135        assert_ne!(AuditSeverity::Error, AuditSeverity::Critical);
1136    }
1137
1138    #[test]
1139    fn test_token_header_structure() {
1140        let header = TokenHeader {
1141            alg: "HS256".to_string(),
1142            typ: "JWT".to_string(),
1143            kid: Some("key123".to_string()),
1144        };
1145
1146        assert_eq!(header.alg, "HS256");
1147        assert_eq!(header.typ, "JWT");
1148        assert_eq!(header.kid, Some("key123".to_string()));
1149    }
1150
1151    #[test]
1152    fn test_token_payload_structure() {
1153        let peer_id = create_test_peer();
1154        let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
1155
1156        let mut claims = HashMap::new();
1157        claims.insert("custom".to_string(), serde_json::json!("value"));
1158
1159        let payload = TokenPayload {
1160            iss: peer_id.clone(),
1161            sub: peer_id.to_string(),
1162            aud: "test-audience".to_string(),
1163            exp: now + 3600,
1164            nbf: now,
1165            iat: now,
1166            jti: "unique-id".to_string(),
1167            claims,
1168        };
1169
1170        assert_eq!(payload.iss, peer_id);
1171        assert_eq!(payload.aud, "test-audience");
1172        assert_eq!(payload.jti, "unique-id");
1173        assert!(payload.exp > payload.iat);
1174        assert_eq!(payload.claims.get("custom"), Some(&serde_json::json!("value")));
1175    }
1176
1177    #[test]
1178    fn test_mcp_token_structure() {
1179        let peer_id = create_test_peer();
1180
1181        let header = TokenHeader {
1182            alg: "HS256".to_string(),
1183            typ: "JWT".to_string(),
1184            kid: None,
1185        };
1186
1187        let payload = TokenPayload {
1188            iss: peer_id.clone(),
1189            sub: peer_id.to_string(),
1190            aud: "test".to_string(),
1191            exp: 1234567890,
1192            nbf: 1234567800,
1193            iat: 1234567800,
1194            jti: "test-id".to_string(),
1195            claims: HashMap::new(),
1196        };
1197
1198        let token = MCPToken {
1199            header: header.clone(),
1200            payload: payload.clone(),
1201            signature: "test-signature".to_string(),
1202        };
1203
1204        assert_eq!(token.header.alg, header.alg);
1205        assert_eq!(token.payload.iss, payload.iss);
1206        assert_eq!(token.signature, "test-signature");
1207    }
1208}