use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{broadcast, Mutex as AsyncMutex, RwLock as AsyncRwLock};
use tracing::{debug, info, warn};
use crate::ast::{Document, OperationType};
#[derive(Debug, Clone)]
pub struct ZeroTrustConfig {
pub enable_continuous_auth: bool,
pub enable_behavioral_analysis: bool,
pub enable_threat_detection: bool,
pub enable_data_loss_prevention: bool,
pub enable_encryption_at_rest: bool,
pub enable_network_segmentation: bool,
pub enable_device_trust: bool,
pub auth_token_lifetime: Duration,
pub session_timeout: Duration,
pub max_failed_attempts: usize,
pub rate_limiting: RateLimitConfig,
pub encryption_config: EncryptionConfig,
pub audit_config: AuditConfig,
pub threat_detection_config: ThreatDetectionConfig,
}
impl Default for ZeroTrustConfig {
fn default() -> Self {
Self {
enable_continuous_auth: true,
enable_behavioral_analysis: true,
enable_threat_detection: true,
enable_data_loss_prevention: true,
enable_encryption_at_rest: true,
enable_network_segmentation: true,
enable_device_trust: true,
auth_token_lifetime: Duration::from_secs(3600),
session_timeout: Duration::from_secs(1800),
max_failed_attempts: 5,
rate_limiting: RateLimitConfig::default(),
encryption_config: EncryptionConfig::default(),
audit_config: AuditConfig::default(),
threat_detection_config: ThreatDetectionConfig::default(),
}
}
}
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
pub requests_per_minute: usize,
pub burst_capacity: usize,
pub sliding_window_duration: Duration,
pub enable_adaptive_limits: bool,
pub enable_ip_based_limiting: bool,
pub enable_user_based_limiting: bool,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests_per_minute: 1000,
burst_capacity: 100,
sliding_window_duration: Duration::from_secs(60),
enable_adaptive_limits: true,
enable_ip_based_limiting: true,
enable_user_based_limiting: true,
}
}
}
#[derive(Debug, Clone)]
pub struct EncryptionConfig {
pub algorithm: EncryptionAlgorithm,
pub key_rotation_interval: Duration,
pub enable_field_level_encryption: bool,
pub enable_query_encryption: bool,
pub enable_result_encryption: bool,
pub kms_integration: bool,
}
impl Default for EncryptionConfig {
fn default() -> Self {
Self {
algorithm: EncryptionAlgorithm::AES256GCM,
key_rotation_interval: Duration::from_secs(86400), enable_field_level_encryption: true,
enable_query_encryption: true,
enable_result_encryption: true,
kms_integration: true,
}
}
}
#[derive(Debug, Clone)]
pub enum EncryptionAlgorithm {
AES256GCM,
ChaCha20Poly1305,
AES256CTR,
}
#[derive(Debug, Clone)]
pub struct AuditConfig {
pub enable_audit_logging: bool,
pub log_all_queries: bool,
pub log_authentication_events: bool,
pub log_authorization_events: bool,
pub log_data_access: bool,
pub log_admin_actions: bool,
pub retention_period: Duration,
pub enable_real_time_alerts: bool,
}
impl Default for AuditConfig {
fn default() -> Self {
Self {
enable_audit_logging: true,
log_all_queries: true,
log_authentication_events: true,
log_authorization_events: true,
log_data_access: true,
log_admin_actions: true,
retention_period: Duration::from_secs(86400 * 365), enable_real_time_alerts: true,
}
}
}
#[derive(Debug, Clone)]
pub struct ThreatDetectionConfig {
pub enable_anomaly_detection: bool,
pub enable_pattern_matching: bool,
pub enable_ml_threat_detection: bool,
pub enable_network_analysis: bool,
pub threat_score_threshold: f64,
pub response_actions: Vec<ThreatResponseAction>,
}
impl Default for ThreatDetectionConfig {
fn default() -> Self {
Self {
enable_anomaly_detection: true,
enable_pattern_matching: true,
enable_ml_threat_detection: true,
enable_network_analysis: true,
threat_score_threshold: 0.7,
response_actions: vec![
ThreatResponseAction::LogEvent,
ThreatResponseAction::NotifyAdmin,
ThreatResponseAction::IncreaseMonitoring,
],
}
}
}
#[derive(Debug, Clone)]
pub enum ThreatResponseAction {
LogEvent,
NotifyAdmin,
BlockUser,
BlockIP,
RequireReauth,
IncreaseMonitoring,
RevokeTokens,
TriggerCircuitBreaker,
}
#[derive(Debug, Clone)]
pub struct SecurityContext {
pub user_id: Option<String>,
pub session_id: String,
pub device_id: Option<String>,
pub ip_address: IpAddr,
pub user_agent: Option<String>,
pub auth_level: AuthenticationLevel,
pub permissions: HashSet<Permission>,
pub trust_score: f64,
pub session_start: SystemTime,
pub last_activity: SystemTime,
pub authentication_factors: Vec<AuthenticationFactor>,
pub risk_factors: Vec<RiskFactor>,
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum AuthenticationLevel {
Anonymous,
Basic,
MultiFactorAuthenticated,
CertificateBased,
BiometricVerified,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum Permission {
ReadData,
WriteData,
DeleteData,
AdminAccess,
QueryIntrospection,
SchemaAccess,
MetricsAccess,
AuditLogAccess,
UserManagement,
SystemConfiguration,
}
#[derive(Debug, Clone)]
pub enum AuthenticationFactor {
Password {
verified_at: SystemTime,
},
TwoFactorCode {
verified_at: SystemTime,
},
BiometricScan {
scan_type: BiometricType,
verified_at: SystemTime,
},
HardwareToken {
token_id: String,
verified_at: SystemTime,
},
CertificateAuth {
cert_fingerprint: String,
verified_at: SystemTime,
},
}
#[derive(Debug, Clone)]
pub enum BiometricType {
Fingerprint,
FaceRecognition,
VoicePrint,
IrisScanning,
}
#[derive(Debug, Clone)]
pub struct RiskFactor {
pub factor_type: RiskFactorType,
pub severity: RiskSeverity,
pub description: String,
pub detected_at: SystemTime,
pub auto_mitigated: bool,
}
#[derive(Debug, Clone)]
pub enum RiskFactorType {
UnusualLocation,
NewDevice,
SuspiciousQuery,
HighFailureRate,
DataExfiltrationAttempt,
PrivilegeEscalation,
AnomalousAccess,
NetworkAnomaly,
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum RiskSeverity {
Low,
Medium,
High,
Critical,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityEvent {
pub event_id: String,
pub event_type: SecurityEventType,
pub severity: RiskSeverity,
pub user_id: Option<String>,
pub session_id: String,
pub ip_address: IpAddr,
pub timestamp: SystemTime,
pub description: String,
pub metadata: HashMap<String, String>,
pub threat_score: Option<f64>,
pub mitigated: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SecurityEventType {
AuthenticationAttempt,
AuthenticationFailure,
AuthenticationSuccess,
AuthorizationDenied,
SuspiciousQuery,
DataAccess,
AdminAction,
ThreatDetected,
SecurityViolation,
PolicyEnforcement,
}
#[derive(Debug, Clone)]
pub struct BehavioralProfile {
pub user_id: String,
pub typical_access_patterns: Vec<AccessPattern>,
pub common_query_types: HashSet<String>,
pub usual_access_times: Vec<TimeRange>,
pub typical_locations: HashSet<IpAddr>,
pub baseline_metrics: BaselineMetrics,
pub last_updated: SystemTime,
}
#[derive(Debug, Clone)]
pub struct AccessPattern {
pub pattern_type: String,
pub frequency: f64,
pub typical_duration: Duration,
pub confidence: f64,
}
#[derive(Debug, Clone)]
pub struct TimeRange {
pub start_hour: u8,
pub end_hour: u8,
pub days_of_week: HashSet<u8>,
}
#[derive(Debug, Clone)]
pub struct BaselineMetrics {
pub avg_queries_per_session: f64,
pub avg_session_duration: Duration,
pub typical_query_complexity: f64,
pub common_field_access: HashSet<String>,
}
pub struct ZeroTrustSecurityManager {
config: ZeroTrustConfig,
active_sessions: Arc<AsyncRwLock<HashMap<String, SecurityContext>>>,
rate_limiters: Arc<AsyncRwLock<HashMap<String, RateLimiter>>>,
behavioral_profiles: Arc<AsyncRwLock<HashMap<String, BehavioralProfile>>>,
threat_detector: Arc<AsyncMutex<ThreatDetector>>,
audit_logger: Arc<AsyncMutex<AuditLogger>>,
encryption_manager: Arc<AsyncMutex<EncryptionManager>>,
security_event_sender: broadcast::Sender<SecurityEvent>,
blocked_ips: Arc<AsyncRwLock<HashSet<IpAddr>>>,
blocked_users: Arc<AsyncRwLock<HashSet<String>>>,
}
impl ZeroTrustSecurityManager {
pub fn new(config: ZeroTrustConfig) -> (Self, broadcast::Receiver<SecurityEvent>) {
let (security_event_sender, security_event_receiver) = broadcast::channel(1000);
let manager = Self {
config: config.clone(),
active_sessions: Arc::new(AsyncRwLock::new(HashMap::new())),
rate_limiters: Arc::new(AsyncRwLock::new(HashMap::new())),
behavioral_profiles: Arc::new(AsyncRwLock::new(HashMap::new())),
threat_detector: Arc::new(AsyncMutex::new(ThreatDetector::new(
&config.threat_detection_config,
))),
audit_logger: Arc::new(AsyncMutex::new(AuditLogger::new(&config.audit_config))),
encryption_manager: Arc::new(AsyncMutex::new(EncryptionManager::new(
&config.encryption_config,
))),
security_event_sender,
blocked_ips: Arc::new(AsyncRwLock::new(HashSet::new())),
blocked_users: Arc::new(AsyncRwLock::new(HashSet::new())),
};
(manager, security_event_receiver)
}
pub async fn authenticate_request(&self, request: &SecurityRequest) -> Result<SecurityContext> {
info!("Authenticating request from {}", request.ip_address);
if self.is_ip_blocked(request.ip_address).await? {
return Err(anyhow!("IP address blocked"));
}
if let Some(ref user_id) = request.user_id {
if self.is_user_blocked(user_id).await? {
return Err(anyhow!("User account blocked"));
}
}
self.check_rate_limits(request).await?;
let auth_info = self
.validate_authentication_token(&request.auth_token)
.await?;
let mut context = SecurityContext {
user_id: auth_info.user_id,
session_id: auth_info.session_id,
device_id: request.device_id.clone(),
ip_address: request.ip_address,
user_agent: request.user_agent.clone(),
auth_level: auth_info.auth_level,
permissions: auth_info.permissions,
trust_score: 1.0, session_start: auth_info.session_start,
last_activity: SystemTime::now(),
authentication_factors: auth_info.authentication_factors,
risk_factors: Vec::new(),
};
context.trust_score = self.calculate_trust_score(&context, request).await?;
if self.config.enable_behavioral_analysis {
self.analyze_behavior(&mut context, request).await?;
}
if self.config.enable_threat_detection {
self.detect_threats(&mut context, request).await?;
}
self.active_sessions
.write()
.await
.insert(context.session_id.clone(), context.clone());
self.log_security_event(SecurityEvent {
event_id: uuid::Uuid::new_v4().to_string(),
event_type: SecurityEventType::AuthenticationSuccess,
severity: RiskSeverity::Low,
user_id: context.user_id.clone(),
session_id: context.session_id.clone(),
ip_address: context.ip_address,
timestamp: SystemTime::now(),
description: "User authenticated successfully".to_string(),
metadata: HashMap::new(),
threat_score: Some(context.trust_score),
mitigated: false,
})
.await?;
Ok(context)
}
pub async fn authorize_query(
&self,
context: &SecurityContext,
query: &Document,
) -> Result<AuthorizationResult> {
debug!("Authorizing query for user {:?}", context.user_id);
self.validate_session(context).await?;
let required_permissions = self.extract_query_permissions(query).await?;
let mut missing_permissions = Vec::new();
for permission in &required_permissions {
if !context.permissions.contains(permission) {
missing_permissions.push(permission.clone());
}
}
if !missing_permissions.is_empty() {
self.log_security_event(SecurityEvent {
event_id: uuid::Uuid::new_v4().to_string(),
event_type: SecurityEventType::AuthorizationDenied,
severity: RiskSeverity::Medium,
user_id: context.user_id.clone(),
session_id: context.session_id.clone(),
ip_address: context.ip_address,
timestamp: SystemTime::now(),
description: format!(
"Authorization denied - missing permissions: {missing_permissions:?}"
),
metadata: HashMap::new(),
threat_score: None,
mitigated: false,
})
.await?;
return Err(anyhow!(
"Insufficient permissions: {:?}",
missing_permissions
));
}
let suspicious_score = self.analyze_query_suspiciousness(query).await?;
if self.config.enable_data_loss_prevention {
self.apply_data_loss_prevention(context, query).await?;
}
Ok(AuthorizationResult {
authorized: true,
required_permissions,
applied_filters: Vec::new(),
suspicious_score,
additional_monitoring: suspicious_score > 0.5,
})
}
pub async fn encrypt_data(&self, data: &str, context: &SecurityContext) -> Result<String> {
let encryption_manager = self.encryption_manager.lock().await;
encryption_manager.encrypt(data, &context.session_id).await
}
pub async fn decrypt_data(
&self,
encrypted_data: &str,
context: &SecurityContext,
) -> Result<String> {
let encryption_manager = self.encryption_manager.lock().await;
encryption_manager
.decrypt(encrypted_data, &context.session_id)
.await
}
async fn is_ip_blocked(&self, ip: IpAddr) -> Result<bool> {
let blocked_ips = self.blocked_ips.read().await;
Ok(blocked_ips.contains(&ip))
}
async fn is_user_blocked(&self, user_id: &str) -> Result<bool> {
let blocked_users = self.blocked_users.read().await;
Ok(blocked_users.contains(user_id))
}
async fn check_rate_limits(&self, request: &SecurityRequest) -> Result<()> {
let key = format!("{}:{:?}", request.ip_address, request.user_id);
let mut rate_limiters = self.rate_limiters.write().await;
let rate_limiter = rate_limiters
.entry(key)
.or_insert_with(|| RateLimiter::new(&self.config.rate_limiting));
if !rate_limiter.allow_request().await? {
return Err(anyhow!("Rate limit exceeded"));
}
Ok(())
}
async fn validate_authentication_token(&self, _token: &str) -> Result<AuthenticationInfo> {
Ok(AuthenticationInfo {
user_id: Some("test_user".to_string()),
session_id: uuid::Uuid::new_v4().to_string(),
auth_level: AuthenticationLevel::Basic,
permissions: HashSet::from([Permission::ReadData, Permission::WriteData]),
session_start: SystemTime::now(),
authentication_factors: Vec::new(),
})
}
async fn calculate_trust_score(
&self,
context: &SecurityContext,
_request: &SecurityRequest,
) -> Result<f64> {
let mut score: f64 = 1.0;
for risk_factor in &context.risk_factors {
match risk_factor.severity {
RiskSeverity::Low => score -= 0.1,
RiskSeverity::Medium => score -= 0.3,
RiskSeverity::High => score -= 0.5,
RiskSeverity::Critical => score -= 0.8,
}
}
match context.auth_level {
AuthenticationLevel::Anonymous => score -= 0.5,
AuthenticationLevel::Basic => score -= 0.2,
AuthenticationLevel::MultiFactorAuthenticated => score += 0.1,
AuthenticationLevel::CertificateBased => score += 0.2,
AuthenticationLevel::BiometricVerified => score += 0.3,
}
Ok(score.clamp(0.0, 1.0))
}
async fn analyze_behavior(
&self,
context: &mut SecurityContext,
_request: &SecurityRequest,
) -> Result<()> {
if let Some(ref user_id) = context.user_id {
let profiles = self.behavioral_profiles.read().await;
if let Some(profile) = profiles.get(user_id) {
if self
.is_unusual_access_time(profile, SystemTime::now())
.await?
{
context.risk_factors.push(RiskFactor {
factor_type: RiskFactorType::AnomalousAccess,
severity: RiskSeverity::Medium,
description: "Access outside normal hours".to_string(),
detected_at: SystemTime::now(),
auto_mitigated: false,
});
}
if self
.is_unusual_location(profile, context.ip_address)
.await?
{
context.risk_factors.push(RiskFactor {
factor_type: RiskFactorType::UnusualLocation,
severity: RiskSeverity::High,
description: "Access from unusual location".to_string(),
detected_at: SystemTime::now(),
auto_mitigated: false,
});
}
}
}
Ok(())
}
async fn detect_threats(
&self,
context: &mut SecurityContext,
request: &SecurityRequest,
) -> Result<()> {
let mut threat_detector = self.threat_detector.lock().await;
let threats = threat_detector.analyze_request(context, request).await?;
for threat in threats {
if threat.score > self.config.threat_detection_config.threat_score_threshold {
context.risk_factors.push(RiskFactor {
factor_type: threat.threat_type.clone(),
severity: threat.severity.clone(),
description: threat.description.clone(),
detected_at: SystemTime::now(),
auto_mitigated: false,
});
for action in &self.config.threat_detection_config.response_actions {
self.apply_threat_response(action, context, &threat).await?;
}
}
}
Ok(())
}
async fn validate_session(&self, context: &SecurityContext) -> Result<()> {
let sessions = self.active_sessions.read().await;
if let Some(session) = sessions.get(&context.session_id) {
if session
.last_activity
.elapsed()
.unwrap_or(Duration::from_secs(0))
> self.config.session_timeout
{
return Err(anyhow!("Session expired"));
}
} else {
return Err(anyhow!("Invalid session"));
}
Ok(())
}
async fn extract_query_permissions(&self, query: &Document) -> Result<Vec<Permission>> {
let mut permissions = Vec::new();
for definition in &query.definitions {
if let crate::ast::Definition::Operation(op) = definition {
match op.operation_type {
OperationType::Query => permissions.push(Permission::ReadData),
OperationType::Mutation => permissions.push(Permission::WriteData),
OperationType::Subscription => permissions.push(Permission::ReadData),
}
}
}
Ok(permissions)
}
async fn analyze_query_suspiciousness(&self, _query: &Document) -> Result<f64> {
Ok(0.0)
}
async fn apply_data_loss_prevention(
&self,
_context: &SecurityContext,
_query: &Document,
) -> Result<()> {
Ok(())
}
async fn is_unusual_access_time(
&self,
_profile: &BehavioralProfile,
_access_time: SystemTime,
) -> Result<bool> {
Ok(false)
}
async fn is_unusual_location(&self, _profile: &BehavioralProfile, _ip: IpAddr) -> Result<bool> {
Ok(false)
}
async fn apply_threat_response(
&self,
action: &ThreatResponseAction,
context: &SecurityContext,
threat: &ThreatDetection,
) -> Result<()> {
match action {
ThreatResponseAction::LogEvent => {
self.log_security_event(SecurityEvent {
event_id: uuid::Uuid::new_v4().to_string(),
event_type: SecurityEventType::ThreatDetected,
severity: threat.severity.clone(),
user_id: context.user_id.clone(),
session_id: context.session_id.clone(),
ip_address: context.ip_address,
timestamp: SystemTime::now(),
description: threat.description.clone(),
metadata: HashMap::new(),
threat_score: Some(threat.score),
mitigated: false,
})
.await?;
}
ThreatResponseAction::BlockIP => {
self.blocked_ips.write().await.insert(context.ip_address);
}
ThreatResponseAction::BlockUser => {
if let Some(ref user_id) = context.user_id {
self.blocked_users.write().await.insert(user_id.clone());
}
}
_ => {
}
}
Ok(())
}
async fn log_security_event(&self, event: SecurityEvent) -> Result<()> {
let mut audit_logger = self.audit_logger.lock().await;
audit_logger.log_event(&event).await?;
if self.security_event_sender.send(event).is_err() {
warn!("No security event subscribers");
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct SecurityRequest {
pub auth_token: String,
pub ip_address: IpAddr,
pub user_agent: Option<String>,
pub device_id: Option<String>,
pub user_id: Option<String>,
}
#[derive(Debug, Clone)]
pub struct AuthenticationInfo {
pub user_id: Option<String>,
pub session_id: String,
pub auth_level: AuthenticationLevel,
pub permissions: HashSet<Permission>,
pub session_start: SystemTime,
pub authentication_factors: Vec<AuthenticationFactor>,
}
#[derive(Debug, Clone)]
pub struct AuthorizationResult {
pub authorized: bool,
pub required_permissions: Vec<Permission>,
pub applied_filters: Vec<String>,
pub suspicious_score: f64,
pub additional_monitoring: bool,
}
#[derive(Debug, Clone)]
pub struct ThreatDetection {
pub threat_type: RiskFactorType,
pub severity: RiskSeverity,
pub score: f64,
pub description: String,
pub confidence: f64,
}
#[derive(Debug)]
pub struct RateLimiter {
config: RateLimitConfig,
requests: VecDeque<SystemTime>,
}
impl RateLimiter {
pub fn new(config: &RateLimitConfig) -> Self {
Self {
config: config.clone(),
requests: VecDeque::new(),
}
}
pub async fn allow_request(&mut self) -> Result<bool> {
let now = SystemTime::now();
while let Some(&front) = self.requests.front() {
if now.duration_since(front).unwrap_or(Duration::from_secs(0))
> self.config.sliding_window_duration
{
self.requests.pop_front();
} else {
break;
}
}
if self.requests.len() >= self.config.requests_per_minute {
return Ok(false);
}
self.requests.push_back(now);
Ok(true)
}
}
#[derive(Debug)]
pub struct ThreatDetector {
#[allow(dead_code)]
config: ThreatDetectionConfig,
#[allow(dead_code)]
ml_models: HashMap<String, ThreatModel>,
}
impl ThreatDetector {
pub fn new(config: &ThreatDetectionConfig) -> Self {
Self {
config: config.clone(),
ml_models: HashMap::new(),
}
}
pub async fn analyze_request(
&mut self,
_context: &SecurityContext,
_request: &SecurityRequest,
) -> Result<Vec<ThreatDetection>> {
Ok(Vec::new())
}
}
#[derive(Debug)]
pub struct ThreatModel {
pub model_type: String,
pub accuracy: f64,
pub last_trained: SystemTime,
}
#[derive(Debug)]
pub struct AuditLogger {
config: AuditConfig,
events: VecDeque<SecurityEvent>,
}
impl AuditLogger {
pub fn new(config: &AuditConfig) -> Self {
Self {
config: config.clone(),
events: VecDeque::new(),
}
}
pub async fn log_event(&mut self, event: &SecurityEvent) -> Result<()> {
if self.config.enable_audit_logging {
self.events.push_back(event.clone());
info!("Security event logged: {:?}", event);
}
Ok(())
}
}
#[derive(Debug)]
pub struct EncryptionManager {
#[allow(dead_code)]
config: EncryptionConfig,
#[allow(dead_code)]
active_keys: HashMap<String, EncryptionKey>,
}
impl EncryptionManager {
pub fn new(config: &EncryptionConfig) -> Self {
Self {
config: config.clone(),
active_keys: HashMap::new(),
}
}
pub async fn encrypt(&self, data: &str, _key_id: &str) -> Result<String> {
Ok(format!("encrypted:{data}"))
}
pub async fn decrypt(&self, encrypted_data: &str, _key_id: &str) -> Result<String> {
if let Some(stripped) = encrypted_data.strip_prefix("encrypted:") {
Ok(stripped.to_string())
} else {
Err(anyhow!("Invalid encrypted data format"))
}
}
}
#[derive(Debug, Clone)]
pub struct EncryptionKey {
pub key_id: String,
pub algorithm: EncryptionAlgorithm,
pub created_at: SystemTime,
pub expires_at: Option<SystemTime>,
}