use crate::{PeerId, Result, P2PError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use sha2::{Digest, Sha256};
use tokio::sync::RwLock;
use std::sync::Arc;
use base64::prelude::*;
/// Structured form of a token: header, payload and signature kept as separate parts.
///
/// NOTE(review): no code in this module builds an `MCPToken` outside tests;
/// `MCPSecurityManager::generate_token` returns the compact
/// `header.payload.signature` string instead. Presumably kept for callers that
/// want the parts separately — confirm before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPToken {
/// Token header (algorithm, type, optional key id).
pub header: TokenHeader,
/// Claims carried by the token.
pub payload: TokenPayload,
/// Signature text. The encoding is not fixed by this type — presumably the
/// base64url segment of the compact form (TODO confirm).
pub signature: String,
}
/// JWT-style token header.
///
/// `MCPSecurityManager::generate_token` writes `alg: "HS256"`, `typ: "JWT"` and
/// `kid: None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenHeader {
/// Name of the signature algorithm.
pub alg: String,
/// Token type.
pub typ: String,
/// Optional key identifier; always `None` in tokens produced by `generate_token`.
pub kid: Option<String>,
}
/// Claims carried by a token.
///
/// As filled in by `MCPSecurityManager::generate_token`:
/// - `iss` and `sub` are the peer id;
/// - `aud` is `"mcp-server"`;
/// - `exp`, `nbf` and `iat` are whole seconds since the UNIX epoch;
/// - `jti` is a random UUID v4 string;
/// - `claims["permissions"]` is a JSON array of permission names (see `MCPPermission::as_str`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPayload {
/// Issuer.
pub iss: PeerId,
/// Subject.
pub sub: String,
/// Intended audience.
pub aud: String,
/// Expiry time, seconds since the UNIX epoch.
pub exp: u64,
/// Not-before time, seconds since the UNIX epoch.
pub nbf: u64,
/// Issued-at time, seconds since the UNIX epoch.
pub iat: u64,
/// Unique token id.
pub jti: String,
/// Free-form additional claims.
pub claims: HashMap<String, serde_json::Value>,
}
/// Security tier attached to peers (`PeerACL::security_level`) and to tools
/// (`MCPSecurityManager::get_tool_policy`, which defaults to `Basic`).
///
/// The derived `Ord` follows declaration order, so
/// `Public < Basic < Strong < Admin`. Do not reorder the variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum SecurityLevel {
Public,
Basic,
Strong,
Admin,
}
/// A permission a peer may hold. The wire names are given by `as_str`.
///
/// `Admin` acts as a wildcard in `PeerACL::has_permission`. `Custom` carries any
/// other permission string.
///
/// NOTE(review): `Custom("admin".into())` compares unequal to `Admin` under the
/// derived `PartialEq`, even though both render as `"admin"`. `from_str("admin")`
/// yields `Admin`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MCPPermission {
ReadTools,
ExecuteTools,
RegisterTools,
ModifyTools,
DeleteTools,
AccessPrompts,
AccessResources,
Admin,
Custom(String),
}
impl MCPPermission {
pub fn as_str(&self) -> &str {
match self {
MCPPermission::ReadTools => "read:tools",
MCPPermission::ExecuteTools => "execute:tools",
MCPPermission::RegisterTools => "register:tools",
MCPPermission::ModifyTools => "modify:tools",
MCPPermission::DeleteTools => "delete:tools",
MCPPermission::AccessPrompts => "access:prompts",
MCPPermission::AccessResources => "access:resources",
MCPPermission::Admin => "admin",
MCPPermission::Custom(s) => s,
}
}
pub fn from_str(s: &str) -> Option<Self> {
match s {
"read:tools" => Some(MCPPermission::ReadTools),
"execute:tools" => Some(MCPPermission::ExecuteTools),
"register:tools" => Some(MCPPermission::RegisterTools),
"modify:tools" => Some(MCPPermission::ModifyTools),
"delete:tools" => Some(MCPPermission::DeleteTools),
"access:prompts" => Some(MCPPermission::AccessPrompts),
"access:resources" => Some(MCPPermission::AccessResources),
"admin" => Some(MCPPermission::Admin),
_ => Some(MCPPermission::Custom(s.to_string())),
}
}
}
/// Per-peer access-control state: granted permissions, counters and ban status.
#[derive(Debug, Clone)]
pub struct PeerACL {
/// Peer this entry belongs to.
pub peer_id: PeerId,
/// Granted permissions; containing `Admin` grants every permission (see `has_permission`).
pub permissions: Vec<MCPPermission>,
/// Security tier; `new` sets `Basic`. NOTE(review): no check in this module reads it.
pub security_level: SecurityLevel,
/// Reputation score; starts at 0.5, and `MCPSecurityManager::update_reputation` keeps it within [0.0, 1.0].
pub reputation: f64,
/// Wall-clock time of creation or of the latest `record_access` call.
pub last_access: SystemTime,
/// Number of `record_access` calls.
pub access_count: u64,
/// Number of recorded rate-limit violations; no code in this module resets it.
pub rate_violations: u32,
/// End of the current ban, if any; set by `record_rate_violation` once violations reach 10.
pub banned_until: Option<SystemTime>,
}
impl PeerACL {
    /// Number of rate violations at which a peer becomes banned.
    const BAN_THRESHOLD: u32 = 10;
    /// How long a ban lasts; it is re-armed on every violation at or past the threshold.
    const BAN_DURATION: Duration = Duration::from_secs(3600);

    /// Creates the default ACL for a newly seen peer.
    ///
    /// The default grants `ReadTools` and `ExecuteTools`, sets level `Basic` and
    /// reputation 0.5, and starts with no violations and no ban.
    pub fn new(peer_id: PeerId) -> Self {
        Self {
            peer_id,
            permissions: vec![MCPPermission::ReadTools, MCPPermission::ExecuteTools],
            security_level: SecurityLevel::Basic,
            reputation: 0.5,
            last_access: SystemTime::now(),
            access_count: 0,
            rate_violations: 0,
            banned_until: None,
        }
    }

    /// Returns whether the peer currently holds `permission`.
    ///
    /// A banned peer holds nothing. A peer holding `Admin` holds everything.
    pub fn has_permission(&self, permission: &MCPPermission) -> bool {
        if self.is_banned() {
            return false;
        }
        self.permissions.contains(&MCPPermission::Admin) || self.permissions.contains(permission)
    }

    /// Returns whether a ban is in force, judged against the current wall-clock time.
    pub fn is_banned(&self) -> bool {
        self.banned_until
            .map_or(false, |until| SystemTime::now() < until)
    }

    /// Records an access: stamps `last_access` and bumps `access_count`.
    ///
    /// The counter saturates at `u64::MAX` rather than overflowing.
    pub fn record_access(&mut self) {
        self.last_access = SystemTime::now();
        self.access_count = self.access_count.saturating_add(1);
    }

    /// Records one rate-limit violation.
    ///
    /// Once the count reaches `BAN_THRESHOLD` (10), every violation (re)starts a
    /// one-hour ban. The count saturates instead of overflowing: a plain `+= 1`
    /// panics in debug builds and, in release builds, wraps back to 0 under a
    /// long flood.
    pub fn record_rate_violation(&mut self) {
        self.rate_violations = self.rate_violations.saturating_add(1);
        if self.rate_violations >= Self::BAN_THRESHOLD {
            self.banned_until = Some(SystemTime::now() + Self::BAN_DURATION);
        }
    }

    /// Grants `permission`; a no-op if it is already held.
    pub fn grant_permission(&mut self, permission: MCPPermission) {
        if !self.permissions.contains(&permission) {
            self.permissions.push(permission);
        }
    }

    /// Removes every occurrence of `permission`.
    pub fn revoke_permission(&mut self, permission: &MCPPermission) {
        self.permissions.retain(|held| held != permission);
    }
}
/// Sliding-window rate limiter: each peer may have at most `rpm_limit`
/// accepted requests within any trailing 60-second window.
///
/// Clones share the same request history, because the table sits behind an `Arc`.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Maximum accepted requests per peer per window.
    pub rpm_limit: u32,
    /// Timestamps of accepted requests, per peer. Monotonic `Instant`s are used
    /// so wall-clock adjustments cannot stretch or shrink the window.
    /// (Wall-clock adjustments include NTP steps and manual changes.)
    /// With `SystemTime` stamps, a backward clock step left entries "in the
    /// future" that kept a peer throttled.
    requests: Arc<RwLock<HashMap<PeerId, Vec<std::time::Instant>>>>,
}

impl RateLimiter {
    /// Length of the sliding window.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates a limiter allowing `rpm_limit` requests per peer per window.
    ///
    /// A limit of 0 rejects every request.
    pub fn new(rpm_limit: u32) -> Self {
        Self {
            rpm_limit,
            requests: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns whether a request from `peer_id` fits in the current window.
    ///
    /// An allowed request is recorded. Rejected requests are not recorded, so
    /// they do not prolong the throttle.
    pub async fn is_allowed(&self, peer_id: &PeerId) -> bool {
        let mut requests = self.requests.write().await;
        let now = std::time::Instant::now();
        let history = requests.entry(peer_id.clone()).or_insert_with(Vec::new);
        Self::prune(history, now);
        if history.len() < self.rpm_limit as usize {
            history.push(now);
            true
        } else {
            false
        }
    }

    /// Forgets all recorded requests for `peer_id`.
    pub async fn reset_peer(&self, peer_id: &PeerId) {
        let mut requests = self.requests.write().await;
        requests.remove(peer_id);
    }

    /// Drops expired timestamps for every peer, then removes peers with none left.
    pub async fn cleanup(&self) {
        let mut requests = self.requests.write().await;
        let now = std::time::Instant::now();
        requests.retain(|_, history| {
            Self::prune(history, now);
            !history.is_empty()
        });
    }

    /// Keeps only timestamps younger than `WINDOW` relative to `now`.
    ///
    /// Comparing elapsed durations avoids `Instant - Duration`, which panics
    /// when the result is not representable on the platform.
    fn prune(history: &mut Vec<std::time::Instant>, now: std::time::Instant) {
        history.retain(|&stamp| now.saturating_duration_since(stamp) < Self::WINDOW);
    }
}
/// Central access-control state.
///
/// Holds per-peer ACLs, per-peer request rate limiting, the token signing key,
/// per-tool security policies and a list of trusted peers.
pub struct MCPSecurityManager {
/// Per-peer ACLs, created on demand (by `check_permission` / `grant_permission`).
acls: Arc<RwLock<HashMap<PeerId, PeerACL>>>,
/// Per-peer request rate limiter.
rate_limiter: RateLimiter,
/// Key that `sign_data` uses to sign and verify tokens.
secret_key: Vec<u8>,
/// Tool name -> required level; `get_tool_policy` returns `Basic` for unknown tools.
tool_policies: Arc<RwLock<HashMap<String, SecurityLevel>>>,
/// Peers explicitly marked as trusted; deduplicated on insert, searched linearly.
trusted_peers: Arc<RwLock<Vec<PeerId>>>,
}
impl MCPSecurityManager {
pub fn new(secret_key: Vec<u8>, rpm_limit: u32) -> Self {
Self {
acls: Arc::new(RwLock::new(HashMap::new())),
rate_limiter: RateLimiter::new(rpm_limit),
secret_key,
tool_policies: Arc::new(RwLock::new(HashMap::new())),
trusted_peers: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn generate_token(&self, peer_id: &PeerId, permissions: Vec<MCPPermission>, ttl: Duration) -> Result<String> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| P2PError::MCP(format!("Time error: {}", e)))?;
let payload = TokenPayload {
iss: peer_id.clone(),
sub: peer_id.clone(),
aud: "mcp-server".to_string(),
exp: (now + ttl).as_secs(),
nbf: now.as_secs(),
iat: now.as_secs(),
jti: uuid::Uuid::new_v4().to_string(),
claims: {
let mut claims = HashMap::new();
claims.insert("permissions".to_string(),
serde_json::to_value(permissions.iter().map(|p| p.as_str()).collect::<Vec<_>>()).unwrap());
claims
},
};
let header = TokenHeader {
alg: "HS256".to_string(),
typ: "JWT".to_string(),
kid: None,
};
let header_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)
.map_err(|e| P2PError::Serialization(e))?);
let payload_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload)
.map_err(|e| P2PError::Serialization(e))?);
let signing_input = format!("{}.{}", header_b64, payload_b64);
let signature = self.sign_data(signing_input.as_bytes());
let signature_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(signature);
Ok(format!("{}.{}.{}", header_b64, payload_b64, signature_b64))
}
pub async fn verify_token(&self, token: &str) -> Result<TokenPayload> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(P2PError::MCP("Invalid token format".to_string()));
}
let _header_data = base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(parts[0])
.map_err(|e| P2PError::MCP(format!("Invalid header encoding: {}", e)))?;
let payload_data = base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(parts[1])
.map_err(|e| P2PError::MCP(format!("Invalid payload encoding: {}", e)))?;
let signature = base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(parts[2])
.map_err(|e| P2PError::MCP(format!("Invalid signature encoding: {}", e)))?;
let signing_input = format!("{}.{}", parts[0], parts[1]);
let expected_signature = self.sign_data(signing_input.as_bytes());
if signature != expected_signature {
return Err(P2PError::MCP("Invalid token signature".to_string()));
}
let payload: TokenPayload = serde_json::from_slice(&payload_data)
.map_err(|e| P2PError::MCP(format!("Invalid payload: {}", e)))?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| P2PError::MCP(format!("Time error: {}", e)))?
.as_secs();
if payload.exp < now {
return Err(P2PError::MCP("Token expired".to_string()));
}
if payload.nbf > now {
return Err(P2PError::MCP("Token not yet valid".to_string()));
}
Ok(payload)
}
pub async fn check_permission(&self, peer_id: &PeerId, permission: &MCPPermission) -> Result<bool> {
let acls = self.acls.read().await;
if let Some(acl) = acls.get(peer_id) {
Ok(acl.has_permission(permission))
} else {
drop(acls);
let mut acls = self.acls.write().await;
acls.insert(peer_id.clone(), PeerACL::new(peer_id.clone()));
Ok(false) }
}
pub async fn check_rate_limit(&self, peer_id: &PeerId) -> Result<bool> {
if self.rate_limiter.is_allowed(peer_id).await {
Ok(true)
} else {
let mut acls = self.acls.write().await;
if let Some(acl) = acls.get_mut(peer_id) {
acl.record_rate_violation();
}
Ok(false)
}
}
pub async fn grant_permission(&self, peer_id: &PeerId, permission: MCPPermission) -> Result<()> {
let mut acls = self.acls.write().await;
let acl = acls.entry(peer_id.clone()).or_insert_with(|| PeerACL::new(peer_id.clone()));
acl.grant_permission(permission);
Ok(())
}
pub async fn revoke_permission(&self, peer_id: &PeerId, permission: &MCPPermission) -> Result<()> {
let mut acls = self.acls.write().await;
if let Some(acl) = acls.get_mut(peer_id) {
acl.revoke_permission(permission);
}
Ok(())
}
pub async fn add_trusted_peer(&self, peer_id: PeerId) -> Result<()> {
let mut trusted = self.trusted_peers.write().await;
if !trusted.contains(&peer_id) {
trusted.push(peer_id);
}
Ok(())
}
pub async fn is_trusted_peer(&self, peer_id: &PeerId) -> bool {
let trusted = self.trusted_peers.read().await;
trusted.contains(peer_id)
}
pub async fn set_tool_policy(&self, tool_name: String, level: SecurityLevel) -> Result<()> {
let mut policies = self.tool_policies.write().await;
policies.insert(tool_name, level);
Ok(())
}
pub async fn get_tool_policy(&self, tool_name: &str) -> SecurityLevel {
let policies = self.tool_policies.read().await;
policies.get(tool_name).cloned().unwrap_or(SecurityLevel::Basic)
}
fn sign_data(&self, data: &[u8]) -> Vec<u8> {
let mut hasher = Sha256::new();
hasher.update(&self.secret_key);
hasher.update(data);
hasher.finalize().to_vec()
}
pub async fn update_reputation(&self, peer_id: &PeerId, delta: f64) -> Result<()> {
let mut acls = self.acls.write().await;
if let Some(acl) = acls.get_mut(peer_id) {
acl.reputation = (acl.reputation + delta).max(0.0).min(1.0);
}
Ok(())
}
pub async fn get_peer_stats(&self, peer_id: &PeerId) -> Option<PeerACL> {
let acls = self.acls.read().await;
acls.get(peer_id).cloned()
}
pub async fn cleanup(&self) -> Result<()> {
self.rate_limiter.cleanup().await;
let mut acls = self.acls.write().await;
let day_ago = SystemTime::now() - Duration::from_secs(24 * 3600);
acls.retain(|_, acl| acl.last_access > day_ago);
Ok(())
}
}
/// One recorded security event.
#[derive(Debug, Clone)]
pub struct SecurityAuditEntry {
/// Wall-clock time at which `SecurityAuditLogger::log_event` recorded the event.
pub timestamp: SystemTime,
/// Caller-chosen event name.
pub event_type: String,
/// Peer the event concerns.
pub peer_id: PeerId,
/// Free-form key/value context supplied by the caller.
pub details: HashMap<String, String>,
/// Severity of the event.
pub severity: AuditSeverity,
}
/// Severity of an audit event.
///
/// Only `PartialEq` is derived (no ordering), so
/// `SecurityAuditLogger::get_entries_by_severity` filters by exact match, not by
/// "at least".
#[derive(Debug, Clone, PartialEq)]
pub enum AuditSeverity {
Info,
Warning,
Error,
Critical,
}
/// Bounded in-memory log of security events.
///
/// Holds at most `max_entries` entries; once full, each new event evicts the oldest.
pub struct SecurityAuditLogger {
    /// Events in insertion order, oldest at the front.
    ///
    /// A `VecDeque` makes evicting the oldest entry O(1). Front-draining a `Vec`
    /// shifted every retained entry on each insert once the log was full, while
    /// holding the write lock.
    entries: Arc<RwLock<std::collections::VecDeque<SecurityAuditEntry>>>,
    /// Maximum number of retained entries (0 retains nothing).
    max_entries: usize,
}

impl SecurityAuditLogger {
    /// Creates an empty logger that retains at most `max_entries` events.
    pub fn new(max_entries: usize) -> Self {
        Self {
            entries: Arc::new(RwLock::new(std::collections::VecDeque::new())),
            max_entries,
        }
    }

    /// Records an event stamped with the current wall-clock time.
    ///
    /// Evicts the oldest entries while the log exceeds `max_entries`.
    pub async fn log_event(&self, event_type: String, peer_id: PeerId, details: HashMap<String, String>, severity: AuditSeverity) {
        let entry = SecurityAuditEntry {
            timestamp: SystemTime::now(),
            event_type,
            peer_id,
            details,
            severity,
        };
        let mut entries = self.entries.write().await;
        entries.push_back(entry);
        while entries.len() > self.max_entries {
            entries.pop_front();
        }
    }

    /// Returns up to `limit` entries, newest first. `None` returns every entry.
    pub async fn get_recent_entries(&self, limit: Option<usize>) -> Vec<SecurityAuditEntry> {
        let entries = self.entries.read().await;
        let limit = limit.unwrap_or(entries.len());
        entries.iter().rev().take(limit).cloned().collect()
    }

    /// Returns every entry whose severity equals `severity` exactly, oldest first.
    pub async fn get_entries_by_severity(&self, severity: AuditSeverity) -> Vec<SecurityAuditEntry> {
        let entries = self.entries.read().await;
        entries.iter().filter(|entry| entry.severity == severity).cloned().collect()
    }
}
// Unit tests for permissions, ACLs, the rate limiter, tokens, the security
// manager and the audit logger.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
// Builds a peer id with a random u32 suffix so ids differ between calls.
fn create_test_peer() -> PeerId {
format!("test_peer_{}", rand::random::<u32>())
}
// Manager with a fixed 32-byte test key and a limit of 60 requests per window.
fn create_test_security_manager() -> MCPSecurityManager {
let secret_key = b"test_secret_key_1234567890123456".to_vec();
MCPSecurityManager::new(secret_key, 60) }
// Every built-in permission round-trips through as_str/from_str; unknown
// strings become Custom.
#[test]
fn test_mcp_permission_string_conversion() {
let permissions = vec![
(MCPPermission::ReadTools, "read:tools"),
(MCPPermission::ExecuteTools, "execute:tools"),
(MCPPermission::RegisterTools, "register:tools"),
(MCPPermission::ModifyTools, "modify:tools"),
(MCPPermission::DeleteTools, "delete:tools"),
(MCPPermission::AccessPrompts, "access:prompts"),
(MCPPermission::AccessResources, "access:resources"),
(MCPPermission::Admin, "admin"),
];
for (permission, expected_str) in permissions {
assert_eq!(permission.as_str(), expected_str);
assert_eq!(MCPPermission::from_str(expected_str), Some(permission));
}
let custom = MCPPermission::Custom("custom:action".to_string());
assert_eq!(custom.as_str(), "custom:action");
assert_eq!(MCPPermission::from_str("custom:action"), Some(custom));
let unknown = MCPPermission::from_str("unknown:permission");
match unknown {
Some(MCPPermission::Custom(s)) => assert_eq!(s, "unknown:permission"),
_ => panic!("Expected custom permission"),
}
}
// The derived Ord follows declaration order.
#[test]
fn test_security_level_ordering() {
assert!(SecurityLevel::Public < SecurityLevel::Basic);
assert!(SecurityLevel::Basic < SecurityLevel::Strong);
assert!(SecurityLevel::Strong < SecurityLevel::Admin);
assert_eq!(SecurityLevel::Public, SecurityLevel::Public);
assert_eq!(SecurityLevel::Basic, SecurityLevel::Basic);
assert_eq!(SecurityLevel::Strong, SecurityLevel::Strong);
assert_eq!(SecurityLevel::Admin, SecurityLevel::Admin);
}
// Defaults of a fresh ACL: read/execute tools, Basic level, 0.5 reputation, no ban.
#[test]
fn test_peer_acl_creation() {
let peer_id = create_test_peer();
let acl = PeerACL::new(peer_id.clone());
assert_eq!(acl.peer_id, peer_id);
assert_eq!(acl.permissions.len(), 2); assert!(acl.permissions.contains(&MCPPermission::ReadTools));
assert!(acl.permissions.contains(&MCPPermission::ExecuteTools));
assert_eq!(acl.security_level, SecurityLevel::Basic);
assert_eq!(acl.reputation, 0.5);
assert_eq!(acl.access_count, 0);
assert_eq!(acl.rate_violations, 0);
assert!(acl.banned_until.is_none());
assert!(!acl.is_banned());
}
// Admin acts as a wildcard; revoking it removes the implied permissions.
#[test]
fn test_peer_acl_permissions() {
let peer_id = create_test_peer();
let mut acl = PeerACL::new(peer_id);
assert!(acl.has_permission(&MCPPermission::ReadTools));
assert!(acl.has_permission(&MCPPermission::ExecuteTools));
assert!(!acl.has_permission(&MCPPermission::RegisterTools));
assert!(!acl.has_permission(&MCPPermission::Admin));
acl.grant_permission(MCPPermission::Admin);
assert!(acl.has_permission(&MCPPermission::ReadTools));
assert!(acl.has_permission(&MCPPermission::ExecuteTools));
assert!(acl.has_permission(&MCPPermission::RegisterTools));
assert!(acl.has_permission(&MCPPermission::DeleteTools));
assert!(acl.has_permission(&MCPPermission::Admin));
acl.revoke_permission(&MCPPermission::Admin);
assert!(!acl.has_permission(&MCPPermission::RegisterTools));
assert!(!acl.has_permission(&MCPPermission::Admin));
acl.grant_permission(MCPPermission::RegisterTools);
assert!(acl.has_permission(&MCPPermission::RegisterTools));
acl.revoke_permission(&MCPPermission::RegisterTools);
assert!(!acl.has_permission(&MCPPermission::RegisterTools));
}
// The tenth violation bans the peer, and a banned peer holds no permissions.
#[test]
fn test_peer_acl_ban_functionality() {
let peer_id = create_test_peer();
let mut acl = PeerACL::new(peer_id);
assert!(!acl.is_banned());
assert!(acl.has_permission(&MCPPermission::ReadTools));
for _ in 0..5 {
acl.record_rate_violation();
}
assert_eq!(acl.rate_violations, 5);
assert!(!acl.is_banned());
for _ in 0..5 {
acl.record_rate_violation();
}
assert_eq!(acl.rate_violations, 10);
assert!(acl.is_banned());
assert!(!acl.has_permission(&MCPPermission::ReadTools));
assert!(!acl.has_permission(&MCPPermission::ExecuteTools));
}
// record_access bumps the counter and advances last_access. The short sleep
// makes the timestamp comparison observable.
#[test]
fn test_peer_acl_access_tracking() {
let peer_id = create_test_peer();
let mut acl = PeerACL::new(peer_id);
let initial_time = acl.last_access;
assert_eq!(acl.access_count, 0);
std::thread::sleep(std::time::Duration::from_millis(10));
acl.record_access();
assert_eq!(acl.access_count, 1);
assert!(acl.last_access > initial_time);
acl.record_access();
assert_eq!(acl.access_count, 2);
}
#[tokio::test]
async fn test_rate_limiter_creation() {
let limiter = RateLimiter::new(60);
assert_eq!(limiter.rpm_limit, 60);
}
// With a limit of 2, the third request in the same window is rejected.
#[tokio::test]
async fn test_rate_limiter_basic_functionality() {
let limiter = RateLimiter::new(2); let peer_id = create_test_peer();
assert!(limiter.is_allowed(&peer_id).await);
assert!(limiter.is_allowed(&peer_id).await);
assert!(!limiter.is_allowed(&peer_id).await);
}
// Limits are tracked independently per peer.
#[tokio::test]
async fn test_rate_limiter_different_peers() {
let limiter = RateLimiter::new(1); let peer1 = create_test_peer();
let peer2 = create_test_peer();
assert!(limiter.is_allowed(&peer1).await);
assert!(limiter.is_allowed(&peer2).await);
assert!(!limiter.is_allowed(&peer1).await);
assert!(!limiter.is_allowed(&peer2).await);
}
// reset_peer clears the history, so the peer is allowed again immediately.
#[tokio::test]
async fn test_rate_limiter_reset() {
let limiter = RateLimiter::new(1);
let peer_id = create_test_peer();
assert!(limiter.is_allowed(&peer_id).await);
assert!(!limiter.is_allowed(&peer_id).await);
limiter.reset_peer(&peer_id).await;
assert!(limiter.is_allowed(&peer_id).await);
}
// cleanup keeps requests that are still inside the window (reads the private map).
#[tokio::test]
async fn test_rate_limiter_cleanup() {
let limiter = RateLimiter::new(10);
let peer_id = create_test_peer();
limiter.is_allowed(&peer_id).await;
limiter.is_allowed(&peer_id).await;
limiter.cleanup().await;
let requests = limiter.requests.read().await;
assert!(requests.contains_key(&peer_id));
let peer_requests = requests.get(&peer_id).unwrap();
assert_eq!(peer_requests.len(), 2);
}
#[tokio::test]
async fn test_security_manager_creation() {
let secret_key = b"test_secret_key".to_vec();
let manager = MCPSecurityManager::new(secret_key.clone(), 60);
assert_eq!(manager.secret_key, secret_key);
assert_eq!(manager.rate_limiter.rpm_limit, 60);
}
// A generated token verifies and carries the expected issuer, audience and permission claims.
#[tokio::test]
async fn test_token_generation_and_verification() -> Result<()> {
let manager = create_test_security_manager();
let peer_id = create_test_peer();
let permissions = vec![MCPPermission::ReadTools, MCPPermission::ExecuteTools];
let ttl = Duration::from_secs(3600);
let token = manager.generate_token(&peer_id, permissions.clone(), ttl).await?;
assert!(!token.is_empty());
let payload = manager.verify_token(&token).await?;
assert_eq!(payload.iss, peer_id);
assert_eq!(payload.sub, peer_id);
assert_eq!(payload.aud, "mcp-server");
let permissions_claim = payload.claims.get("permissions").unwrap();
let permission_strings: Vec<String> = serde_json::from_value(permissions_claim.clone()).unwrap();
assert_eq!(permission_strings.len(), 2);
assert!(permission_strings.contains(&"read:tools".to_string()));
assert!(permission_strings.contains(&"execute:tools".to_string()));
Ok(())
}
// Tokens without exactly three dot-separated parts are rejected.
#[tokio::test]
async fn test_token_verification_invalid() {
let manager = create_test_security_manager();
let result = manager.verify_token("invalid.token").await;
assert!(result.is_err());
let result = manager.verify_token("invalid.token.format.extra").await;
assert!(result.is_err());
let result = manager.verify_token("").await;
assert!(result.is_err());
}
// A token signed with one key is rejected by a manager holding a different key.
#[tokio::test]
async fn test_token_signature_verification() -> Result<()> {
let manager1 = create_test_security_manager();
let manager2 = MCPSecurityManager::new(b"different_secret".to_vec(), 60);
let peer_id = create_test_peer();
let permissions = vec![MCPPermission::ReadTools];
let ttl = Duration::from_secs(3600);
let token = manager1.generate_token(&peer_id, permissions, ttl).await?;
assert!(manager1.verify_token(&token).await.is_ok());
assert!(manager2.verify_token(&token).await.is_err());
Ok(())
}
// The first check for an unknown peer is denied (and creates its ACL);
// grant/revoke then toggle the result.
#[tokio::test]
async fn test_permission_management() -> Result<()> {
let manager = create_test_security_manager();
let peer_id = create_test_peer();
assert!(!manager.check_permission(&peer_id, &MCPPermission::ExecuteTools).await?);
manager.grant_permission(&peer_id, MCPPermission::ExecuteTools).await?;
assert!(manager.check_permission(&peer_id, &MCPPermission::ExecuteTools).await?);
manager.revoke_permission(&peer_id, &MCPPermission::ExecuteTools).await?;
assert!(!manager.check_permission(&peer_id, &MCPPermission::ExecuteTools).await?);
Ok(())
}
// A rejected request is recorded as a violation on the peer's existing ACL.
#[tokio::test]
async fn test_rate_limit_checking() -> Result<()> {
let manager = MCPSecurityManager::new(b"test_key".to_vec(), 2); let peer_id = create_test_peer();
manager.grant_permission(&peer_id, MCPPermission::ReadTools).await?;
assert!(manager.check_rate_limit(&peer_id).await?);
assert!(manager.check_rate_limit(&peer_id).await?);
assert!(!manager.check_rate_limit(&peer_id).await?);
let stats = manager.get_peer_stats(&peer_id).await;
assert!(stats.is_some());
let acl = stats.unwrap();
assert_eq!(acl.rate_violations, 1);
Ok(())
}
// Adding a trusted peer is idempotent.
#[tokio::test]
async fn test_trusted_peer_management() -> Result<()> {
let manager = create_test_security_manager();
let peer_id = create_test_peer();
assert!(!manager.is_trusted_peer(&peer_id).await);
manager.add_trusted_peer(peer_id.clone()).await?;
assert!(manager.is_trusted_peer(&peer_id).await);
manager.add_trusted_peer(peer_id.clone()).await?;
assert!(manager.is_trusted_peer(&peer_id).await);
Ok(())
}
// Unknown tools default to Basic; explicit policies override the default.
#[tokio::test]
async fn test_tool_security_policies() -> Result<()> {
let manager = create_test_security_manager();
let policy = manager.get_tool_policy("test_tool").await;
assert_eq!(policy, SecurityLevel::Basic);
manager.set_tool_policy("test_tool".to_string(), SecurityLevel::Strong).await?;
let policy = manager.get_tool_policy("test_tool").await;
assert_eq!(policy, SecurityLevel::Strong);
manager.set_tool_policy("admin_tool".to_string(), SecurityLevel::Admin).await?;
let policy = manager.get_tool_policy("admin_tool").await;
assert_eq!(policy, SecurityLevel::Admin);
Ok(())
}
// Reputation updates are clamped to [0.0, 1.0]; 0.4 is compared with a tolerance
// because of floating-point rounding.
#[tokio::test]
async fn test_reputation_management() -> Result<()> {
let manager = create_test_security_manager();
let peer_id = create_test_peer();
manager.grant_permission(&peer_id, MCPPermission::ReadTools).await?;
let stats = manager.get_peer_stats(&peer_id).await.unwrap();
assert_eq!(stats.reputation, 0.5);
manager.update_reputation(&peer_id, 0.2).await?;
let stats = manager.get_peer_stats(&peer_id).await.unwrap();
assert_eq!(stats.reputation, 0.7);
manager.update_reputation(&peer_id, -0.3).await?;
let stats = manager.get_peer_stats(&peer_id).await.unwrap();
assert!((stats.reputation - 0.4).abs() < 0.001);
manager.update_reputation(&peer_id, -1.0).await?;
let stats = manager.get_peer_stats(&peer_id).await.unwrap();
assert_eq!(stats.reputation, 0.0);
manager.update_reputation(&peer_id, 2.0).await?;
let stats = manager.get_peer_stats(&peer_id).await.unwrap();
assert_eq!(stats.reputation, 1.0);
Ok(())
}
// Smoke test: cleanup succeeds with a fresh ACL and recent rate-limiter history.
#[tokio::test]
async fn test_security_manager_cleanup() -> Result<()> {
let manager = create_test_security_manager();
let peer_id = create_test_peer();
manager.grant_permission(&peer_id, MCPPermission::ReadTools).await?;
manager.check_rate_limit(&peer_id).await?;
manager.cleanup().await?;
Ok(())
}
#[tokio::test]
async fn test_audit_logger_creation() {
let logger = SecurityAuditLogger::new(100);
assert_eq!(logger.max_entries, 100);
let entries = logger.get_recent_entries(None).await;
assert!(entries.is_empty());
}
// A logged event is returned with its type, peer, severity and details intact.
#[tokio::test]
async fn test_audit_logger_logging() {
let logger = SecurityAuditLogger::new(10);
let peer_id = create_test_peer();
let mut details = HashMap::new();
details.insert("action".to_string(), "test_action".to_string());
details.insert("result".to_string(), "success".to_string());
logger.log_event(
"test_event".to_string(),
peer_id.clone(),
details.clone(),
AuditSeverity::Info,
).await;
let entries = logger.get_recent_entries(None).await;
assert_eq!(entries.len(), 1);
let entry = &entries[0];
assert_eq!(entry.event_type, "test_event");
assert_eq!(entry.peer_id, peer_id);
assert_eq!(entry.severity, AuditSeverity::Info);
assert_eq!(entry.details.get("action"), Some(&"test_action".to_string()));
}
// Severity filtering matches exactly one entry per level.
#[tokio::test]
async fn test_audit_logger_severity_filtering() {
let logger = SecurityAuditLogger::new(10);
let peer_id = create_test_peer();
logger.log_event("info_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Info).await;
logger.log_event("warning_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Warning).await;
logger.log_event("error_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Error).await;
logger.log_event("critical_event".to_string(), peer_id.clone(), HashMap::new(), AuditSeverity::Critical).await;
let info_entries = logger.get_entries_by_severity(AuditSeverity::Info).await;
assert_eq!(info_entries.len(), 1);
assert_eq!(info_entries[0].event_type, "info_event");
let warning_entries = logger.get_entries_by_severity(AuditSeverity::Warning).await;
assert_eq!(warning_entries.len(), 1);
assert_eq!(warning_entries[0].event_type, "warning_event");
let error_entries = logger.get_entries_by_severity(AuditSeverity::Error).await;
assert_eq!(error_entries.len(), 1);
let critical_entries = logger.get_entries_by_severity(AuditSeverity::Critical).await;
assert_eq!(critical_entries.len(), 1);
}
// With a cap of 3, only the newest three of five events remain, newest first.
#[tokio::test]
async fn test_audit_logger_max_entries() {
let logger = SecurityAuditLogger::new(3); let peer_id = create_test_peer();
for i in 0..5 {
logger.log_event(
format!("event_{}", i),
peer_id.clone(),
HashMap::new(),
AuditSeverity::Info,
).await;
}
let entries = logger.get_recent_entries(None).await;
assert_eq!(entries.len(), 3);
assert_eq!(entries[0].event_type, "event_4"); assert_eq!(entries[1].event_type, "event_3");
assert_eq!(entries[2].event_type, "event_2");
}
// An explicit limit returns that many entries, newest first.
#[tokio::test]
async fn test_audit_logger_recent_entries_limit() {
let logger = SecurityAuditLogger::new(10);
let peer_id = create_test_peer();
for i in 0..5 {
logger.log_event(
format!("event_{}", i),
peer_id.clone(),
HashMap::new(),
AuditSeverity::Info,
).await;
}
let entries = logger.get_recent_entries(Some(3)).await;
assert_eq!(entries.len(), 3);
assert_eq!(entries[0].event_type, "event_4");
assert_eq!(entries[1].event_type, "event_3");
assert_eq!(entries[2].event_type, "event_2");
}
#[test]
fn test_audit_severity_equality() {
assert_eq!(AuditSeverity::Info, AuditSeverity::Info);
assert_eq!(AuditSeverity::Warning, AuditSeverity::Warning);
assert_eq!(AuditSeverity::Error, AuditSeverity::Error);
assert_eq!(AuditSeverity::Critical, AuditSeverity::Critical);
assert_ne!(AuditSeverity::Info, AuditSeverity::Warning);
assert_ne!(AuditSeverity::Warning, AuditSeverity::Error);
assert_ne!(AuditSeverity::Error, AuditSeverity::Critical);
}
// Plain construction/field-access checks for the token data types.
#[test]
fn test_token_header_structure() {
let header = TokenHeader {
alg: "HS256".to_string(),
typ: "JWT".to_string(),
kid: Some("key123".to_string()),
};
assert_eq!(header.alg, "HS256");
assert_eq!(header.typ, "JWT");
assert_eq!(header.kid, Some("key123".to_string()));
}
#[test]
fn test_token_payload_structure() {
let peer_id = create_test_peer();
let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
let mut claims = HashMap::new();
claims.insert("custom".to_string(), serde_json::json!("value"));
let payload = TokenPayload {
iss: peer_id.clone(),
sub: peer_id.to_string(),
aud: "test-audience".to_string(),
exp: now + 3600,
nbf: now,
iat: now,
jti: "unique-id".to_string(),
claims,
};
assert_eq!(payload.iss, peer_id);
assert_eq!(payload.aud, "test-audience");
assert_eq!(payload.jti, "unique-id");
assert!(payload.exp > payload.iat);
assert_eq!(payload.claims.get("custom"), Some(&serde_json::json!("value")));
}
#[test]
fn test_mcp_token_structure() {
let peer_id = create_test_peer();
let header = TokenHeader {
alg: "HS256".to_string(),
typ: "JWT".to_string(),
kid: None,
};
let payload = TokenPayload {
iss: peer_id.clone(),
sub: peer_id.to_string(),
aud: "test".to_string(),
exp: 1234567890,
nbf: 1234567800,
iat: 1234567800,
jti: "test-id".to_string(),
claims: HashMap::new(),
};
let token = MCPToken {
header: header.clone(),
payload: payload.clone(),
signature: "test-signature".to_string(),
};
assert_eq!(token.header.alg, header.alg);
assert_eq!(token.payload.iss, payload.iss);
assert_eq!(token.signature, "test-signature");
}
}