use std::collections::HashMap;
use std::fmt;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// How serious a safety rule (and any violation of it) is, listed from most
/// to least severe.
///
/// The declaration order matches the descending numeric weights returned by
/// `SafetySeverity::weight` (1.0, 0.7, 0.4, 0.1).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SafetySeverity {
    /// Most severe level (weight 1.0).
    Critical,
    /// Second most severe level (weight 0.7).
    High,
    /// Intermediate level (weight 0.4).
    Medium,
    /// Least severe level (weight 0.1).
    Low,
}
impl SafetySeverity {
pub(super) fn weight(self) -> f64 {
match self {
Self::Critical => 1.0,
Self::High => 0.7,
Self::Medium => 0.4,
Self::Low => 0.1,
}
}
}
impl fmt::Display for SafetySeverity {
    /// Writes the upper-case severity label: `CRITICAL`, `HIGH`, `MEDIUM` or `LOW`.
    ///
    /// The label goes through [`fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags are honoured exactly as for a `&str`
    /// (e.g. `{:<8}` pads to eight columns). Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Critical => "CRITICAL",
            Self::High => "HIGH",
            Self::Medium => "MEDIUM",
            Self::Low => "LOW",
        };
        // `write!(f, "...")` would ignore the caller's format spec; `pad` applies it.
        f.pad(label)
    }
}
/// Enforcement mode attached to a [`SafetyPolicy`].
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names; the code that acts on this value is not in this file section — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SafetyEnforcement {
    /// Presumably: refuse an action that matches one of the policy's rules.
    Block,
    /// Presumably: let the action proceed but surface a warning.
    Warn,
    /// Presumably: let the action proceed and only record it for auditing.
    AuditOnly,
}
/// The kind of check a [`SafetyRule`] performs, together with its parameters.
///
/// NOTE(review): pattern syntax (substring / glob / regex) and how each variant
/// is matched against a [`SafetyAction`] are defined by the evaluator, which is
/// not part of this file section — confirm before relying on specifics.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SafetyRuleType {
    /// Upper bound `max_value` on the resource named `resource`
    /// (units depend on the resource).
    ResourceLimit {
        resource: String,
        max_value: u64,
    },
    /// Forbidden-action rule keyed on `pattern`.
    ForbiddenAction {
        pattern: String,
    },
    /// Approval requirement for actions matching `action_pattern`.
    RequireApproval {
        action_pattern: String,
    },
    /// At most `max_per_minute` actions matching `action_pattern` per minute.
    RateLimit {
        action_pattern: String,
        max_per_minute: u32,
    },
    /// Content filter listing the disallowed `forbidden_patterns`.
    ContentFilter {
        forbidden_patterns: Vec<String>,
    },
    /// Path scoping via an allow-list and a deny-list.
    ScopeRestriction {
        allowed_paths: Vec<String>,
        denied_paths: Vec<String>,
    },
    /// Escalation from the level named `from_level` to the one named `to_level`.
    EscalationRequired {
        from_level: String,
        to_level: String,
    },
    /// Output constraints: a maximum length and an optional UTF-8 requirement.
    OutputValidation {
        max_length: usize,
        require_utf8: bool,
    },
}
/// A single identified rule belonging to a [`SafetyPolicy`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyRule {
    /// Identifier of this rule.
    pub rule_id: String,
    /// Human-readable description of the rule.
    pub description: String,
    /// What the rule checks, with its parameters.
    pub rule_type: SafetyRuleType,
    /// Severity assigned to this rule.
    pub severity: SafetySeverity,
}
/// A named collection of [`SafetyRule`]s sharing one enforcement mode, a
/// priority and an on/off switch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyPolicy {
    /// Identifier of this policy.
    pub policy_id: String,
    /// Human-readable policy name.
    pub name: String,
    /// Rules that make up this policy.
    pub rules: Vec<SafetyRule>,
    /// Enforcement mode applied to this policy's rules.
    pub enforcement: SafetyEnforcement,
    /// Policy priority.
    /// NOTE(review): whether higher or lower values take precedence is decided
    /// by the evaluator — confirm.
    pub priority: u8,
    /// Whether this policy is active.
    pub enabled: bool,
}
/// Category of a [`SafetyAction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ActionType {
    /// File access.
    FileAccess,
    /// Spawning a process.
    ProcessSpawn,
    /// A network request.
    NetworkRequest,
    /// A system command.
    SystemCommand,
    /// Emitting output data.
    DataOutput,
    /// A privilege escalation.
    PrivilegeEscalation,
}
impl fmt::Display for ActionType {
    /// Writes the variant name (e.g. `FileAccess`, `PrivilegeEscalation`).
    ///
    /// The name goes through [`fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags are honoured exactly as for a `&str`
    /// (e.g. `{:>20}` right-aligns). Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::FileAccess => "FileAccess",
            Self::ProcessSpawn => "ProcessSpawn",
            Self::NetworkRequest => "NetworkRequest",
            Self::SystemCommand => "SystemCommand",
            Self::DataOutput => "DataOutput",
            Self::PrivilegeEscalation => "PrivilegeEscalation",
        };
        // `write!(f, "...")` would ignore the caller's format spec; `pad` applies it.
        f.pad(name)
    }
}
/// An action described by its category, a target and free-form string parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyAction {
    /// Category of the action.
    pub action_type: ActionType,
    /// What the action operates on.
    /// NOTE(review): interpretation (path, URL, command, ...) presumably depends
    /// on `action_type` — confirm.
    pub target: String,
    /// Additional key/value parameters of the action.
    pub parameters: HashMap<String, String>,
}
/// Outcome of a safety check on an action.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SafetyVerdict {
    /// No rule objected.
    Allowed,
    /// Refused; `reason` explains why and `rule_id` identifies the rule.
    Blocked {
        reason: String,
        rule_id: String,
    },
    /// Needs approval; `reason` explains why and `rule_id` identifies the rule.
    RequiresApproval {
        reason: String,
        rule_id: String,
    },
    /// Rate-limited; `retry_after_secs` is the suggested wait in seconds.
    RateLimited {
        retry_after_secs: u32,
    },
    /// Permitted with a warning `message`.
    Warning {
        message: String,
    },
}
/// Record of a rule violation by an agent, including the verdict issued.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetyViolation {
    /// Identifier of this violation record.
    pub violation_id: String,
    /// Identifier of the agent involved.
    pub agent_id: String,
    /// Timestamp of the violation, in UTC.
    pub timestamp: DateTime<Utc>,
    /// Identifier of the rule that was violated.
    pub rule_id: String,
    /// Description of the attempted action.
    /// NOTE(review): string format is chosen by the producer — confirm.
    pub action_attempted: String,
    /// Verdict issued for the attempted action.
    pub verdict: SafetyVerdict,
    /// Severity recorded for the violation.
    pub severity: SafetySeverity,
}
/// Sliding one-minute window of event instants, used for rate counting.
#[derive(Debug, Clone)]
pub(super) struct RateBucket {
    /// Instants of recorded events. `RateBucket::record_and_count` appends one
    /// entry per call (in call order) after pruning entries older than 60 s.
    pub(super) timestamps: Vec<std::time::Instant>,
}
impl RateBucket {
    /// Length of the sliding window used by [`Self::record_and_count`].
    const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates an empty bucket with no recorded events.
    pub(super) fn new() -> Self {
        Self {
            timestamps: Vec::new(),
        }
    }

    /// Records an event at the current instant and returns how many events,
    /// including this one, fall within the last 60 seconds.
    ///
    /// Entries older than the window are pruned before the new one is pushed,
    /// so the vector only ever holds events from the most recent window.
    pub(super) fn record_and_count(&mut self) -> usize {
        let now = std::time::Instant::now();
        // Computing `now - WINDOW` can panic when the monotonic clock's origin
        // is less than 60 s in the past (possible shortly after boot on some
        // platforms). Comparing elapsed durations keeps exactly the same
        // entries (`t >= now - 60s` <=> `now - t <= 60s`) without that risk.
        self.timestamps
            .retain(|t| now.saturating_duration_since(*t) <= Self::WINDOW);
        self.timestamps.push(now);
        self.timestamps.len()
    }
}