use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
/// Coarse trust tier assigned to an agent.
///
/// Variants are declared and numbered from least trusted (`Untrusted = 0`) to most
/// trusted (`System = 4`), so the derived `PartialOrd`/`Ord` support comparisons such
/// as `level >= TrustLevel::Low`. Serde (de)serializes variants by their lowercase
/// name (e.g. `"medium"`). `Low` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum TrustLevel {
/// Produced by `TrustLevel::from_score` for scores below 0.4.
Untrusted = 0,
/// Default tier (scores in `[0.4, 0.7)`); also the starting level of a new `TrustFactor`.
#[default]
Low = 1,
/// Produced by `TrustLevel::from_score` for scores in `[0.7, 0.9)`.
Medium = 2,
/// Produced by `TrustLevel::from_score` for scores of at least 0.9.
High = 3,
/// Reserved for system agents; `TrustLevel::from_score` never returns it.
System = 4,
}
impl TrustLevel {
    /// Converts a raw numeric level back into a [`TrustLevel`].
    ///
    /// `0..=4` map to `Untrusted..=System`; any other value falls back to
    /// [`TrustLevel::Low`].
    pub fn from_u8(level: u8) -> Self {
        const BY_NUMBER: [TrustLevel; 5] = [
            TrustLevel::Untrusted,
            TrustLevel::Low,
            TrustLevel::Medium,
            TrustLevel::High,
            TrustLevel::System,
        ];
        BY_NUMBER
            .get(usize::from(level))
            .copied()
            .unwrap_or(TrustLevel::Low)
    }

    /// Returns this level's numeric discriminant (`Untrusted` = 0 … `System` = 4).
    pub fn as_u8(&self) -> u8 {
        *self as u8
    }

    /// Buckets a trust score into a tier.
    ///
    /// Thresholds are inclusive lower bounds: `>= 0.9` is `High`, `>= 0.7` is `Medium`,
    /// `>= 0.4` is `Low`, and everything else (including NaN) is `Untrusted`.
    /// `System` is never produced from a score.
    pub fn from_score(score: f32) -> Self {
        match score {
            s if s >= 0.9 => TrustLevel::High,
            s if s >= 0.7 => TrustLevel::Medium,
            s if s >= 0.4 => TrustLevel::Low,
            _ => TrustLevel::Untrusted,
        }
    }
}
impl std::fmt::Display for TrustLevel {
    /// Writes the human-readable level name (e.g. `"Medium"`).
    ///
    /// Uses [`std::fmt::Formatter::pad`] so width, fill and alignment flags are honored,
    /// e.g. `format!("{:<10}", level)` produces a column-aligned name. A plain
    /// `write!(f, "...")` silently ignores those flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TrustLevel::Untrusted => "Untrusted",
            TrustLevel::Low => "Low",
            TrustLevel::Medium => "Medium",
            TrustLevel::High => "High",
            TrustLevel::System => "System",
        };
        f.pad(name)
    }
}
/// Severity of a recorded policy violation.
///
/// Each severity carries a lifetime penalty (`penalty`) and an extra recent-window
/// penalty (`recent_penalty`), both subtracted from an agent's trust score via
/// `ViolationCounts::total_penalty`. Serde uses the lowercase variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ViolationSeverity {
/// Least severe; smallest penalties.
Minor,
/// Intermediate severity.
Major,
/// Most severe; largest penalties.
Critical,
}
impl ViolationSeverity {
    /// Returns `(lifetime penalty, recent-window penalty)` for this severity.
    fn penalty_pair(self) -> (f32, f32) {
        match self {
            ViolationSeverity::Minor => (0.02, 0.04),
            ViolationSeverity::Major => (0.08, 0.15),
            ViolationSeverity::Critical => (0.15, 0.30),
        }
    }

    /// Score penalty charged per violation of this severity in the lifetime totals.
    pub fn penalty(&self) -> f32 {
        let (lifetime, _) = self.penalty_pair();
        lifetime
    }

    /// Additional score penalty charged per violation still counted as "recent".
    pub fn recent_penalty(&self) -> f32 {
        let (_, recent) = self.penalty_pair();
        recent
    }
}
/// Per-agent tally of violations by severity.
///
/// The plain counters are lifetime totals. The `recent_*` counters are incremented
/// alongside them by `record` but zeroed by `decay_recent`, so they only reflect
/// violations since the last decay.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ViolationCounts {
/// Lifetime count of minor violations.
pub minor: u32,
/// Lifetime count of major violations.
pub major: u32,
/// Lifetime count of critical violations.
pub critical: u32,
/// Minor violations since the last `decay_recent`.
pub recent_minor: u32,
/// Major violations since the last `decay_recent`.
pub recent_major: u32,
/// Critical violations since the last `decay_recent`.
pub recent_critical: u32,
}
impl ViolationCounts {
    /// Total amount to subtract from an agent's trust score.
    ///
    /// Every recorded violation contributes its severity's `penalty()`. Violations still
    /// counted in the `recent_*` tallies additionally contribute `recent_penalty()`.
    pub fn total_penalty(&self) -> f32 {
        let base_penalty = (self.minor as f32 * ViolationSeverity::Minor.penalty())
            + (self.major as f32 * ViolationSeverity::Major.penalty())
            + (self.critical as f32 * ViolationSeverity::Critical.penalty());
        let recent_penalty = (self.recent_minor as f32
            * ViolationSeverity::Minor.recent_penalty())
            + (self.recent_major as f32 * ViolationSeverity::Major.recent_penalty())
            + (self.recent_critical as f32 * ViolationSeverity::Critical.recent_penalty());
        base_penalty + recent_penalty
    }

    /// Counts one violation of `severity` in both the lifetime and the recent tally.
    ///
    /// Counters saturate at `u32::MAX` instead of overflowing. A plain `+= 1` would
    /// panic in debug builds and, in release builds, wrap to 0, silently erasing the
    /// accumulated penalty.
    pub fn record(&mut self, severity: ViolationSeverity) {
        let (lifetime, recent) = match severity {
            ViolationSeverity::Minor => (&mut self.minor, &mut self.recent_minor),
            ViolationSeverity::Major => (&mut self.major, &mut self.recent_major),
            ViolationSeverity::Critical => (&mut self.critical, &mut self.recent_critical),
        };
        *lifetime = lifetime.saturating_add(1);
        *recent = recent.saturating_add(1);
    }

    /// Clears the `recent_*` counters; lifetime totals are kept.
    pub fn decay_recent(&mut self) {
        self.recent_minor = 0;
        self.recent_major = 0;
        self.recent_critical = 0;
    }
}
/// Trust record for a single agent: its current score and level plus the history
/// (operation counts and violations) the score is derived from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrustFactor {
/// Identifier of the agent this record belongs to.
pub agent_id: String,
/// Current trust score; recalculation clamps it to `[0.0, 1.0]`.
pub score: f32,
/// Tier derived from `score`, or set explicitly via `set_level`.
pub level: TrustLevel,
/// Violation history used to penalize `score`.
pub violations: ViolationCounts,
/// Number of operations recorded as successful.
pub successful_ops: u64,
/// Number of operations recorded in total (successes, failures and violations).
pub total_ops: u64,
/// Timestamp of the last score/level update.
pub last_updated: DateTime<Utc>,
/// Once the current time passes this instant, the `recent_*` violation counters are cleared.
pub violations_decay_at: DateTime<Utc>,
/// System agents keep their pinned level; score recalculation skips them.
pub is_system: bool,
}
impl TrustFactor {
    /// Length, in hours, of the window after which `recent_*` violation counters are cleared.
    const RECENT_WINDOW_HOURS: i64 = 24;

    /// Creates a fresh, non-system record for `agent_id`.
    ///
    /// The agent starts at the neutral score `0.5` (`TrustLevel::Low`) with no recorded
    /// operations or violations. Its recent-violation window ends 24 hours from now.
    pub fn new(agent_id: &str) -> Self {
        let now = Utc::now();
        Self {
            agent_id: agent_id.to_string(),
            score: 0.5,
            level: TrustLevel::Low,
            violations: ViolationCounts::default(),
            successful_ops: 0,
            total_ops: 0,
            last_updated: now,
            violations_decay_at: now + Duration::hours(Self::RECENT_WINDOW_HOURS),
            is_system: false,
        }
    }

    /// Creates a record for a system agent, pinned at `TrustLevel::System` with score `1.0`.
    ///
    /// Recalculation is skipped for system agents, so operations and violations recorded
    /// against them are counted but never change the score.
    pub fn system(agent_id: &str) -> Self {
        let now = Utc::now();
        Self {
            agent_id: agent_id.to_string(),
            score: 1.0,
            level: TrustLevel::System,
            violations: ViolationCounts::default(),
            successful_ops: 0,
            total_ops: 0,
            last_updated: now,
            violations_decay_at: now,
            is_system: true,
        }
    }

    /// Records a successful operation and recomputes the score.
    pub fn record_success(&mut self) {
        self.successful_ops += 1;
        self.total_ops += 1;
        self.recalculate();
    }

    /// Records a failed operation (it counts toward `total_ops` only) and recomputes the score.
    pub fn record_failure(&mut self) {
        self.total_ops += 1;
        self.recalculate();
    }

    /// Records a violation of `severity` and recomputes the score.
    ///
    /// A violation both counts as a non-successful operation and adds a violation penalty.
    pub fn record_violation(&mut self, severity: ViolationSeverity) {
        self.violations.record(severity);
        self.total_ops += 1;
        self.recalculate();
    }

    /// Recomputes `score` and `level` from the operation history and violation penalties.
    ///
    /// The score is `successful_ops / total_ops` (or `0.5` with no operations) minus
    /// `ViolationCounts::total_penalty`, clamped to `[0.0, 1.0]`. Expired recent
    /// violations are cleared first. System agents are left untouched.
    fn recalculate(&mut self) {
        if self.is_system {
            return;
        }
        let now = Utc::now();
        // The recent window has elapsed: drop the extra "recent" penalty and open a new window.
        if now > self.violations_decay_at {
            self.violations.decay_recent();
            self.violations_decay_at = now + Duration::hours(Self::RECENT_WINDOW_HOURS);
        }
        let base_score = if self.total_ops > 0 {
            self.successful_ops as f32 / self.total_ops as f32
        } else {
            0.5
        };
        let penalty = self.violations.total_penalty();
        self.score = (base_score - penalty).clamp(0.0, 1.0);
        self.level = TrustLevel::from_score(self.score);
        self.last_updated = now;
    }

    /// Overrides the level and sets `score` to a representative value for that tier
    /// (`Untrusted` 0.2, `Low` 0.5, `Medium` 0.75, `High` 0.95, `System` 1.0).
    ///
    /// Operation counts and violations are left untouched. For non-system agents, the
    /// next recorded operation recomputes `score`/`level` from that history and replaces
    /// this override. For system agents (`is_system`), recalculation is skipped, so the
    /// override persists.
    pub fn set_level(&mut self, level: TrustLevel) {
        self.level = level;
        self.score = match level {
            TrustLevel::Untrusted => 0.2,
            TrustLevel::Low => 0.5,
            TrustLevel::Medium => 0.75,
            TrustLevel::High => 0.95,
            TrustLevel::System => 1.0,
        };
        self.last_updated = Utc::now();
    }

    /// Discards all history and returns the record to its freshly-constructed state.
    ///
    /// Non-system agents become what [`TrustFactor::new`] produces: score `0.5`, `Low`,
    /// no history, and a fresh 24-hour violation window. System agents are restored to
    /// [`TrustFactor::system`]'s pinned `System` / `1.0` state. Because `recalculate`
    /// never updates system agents, resetting them to `Low` would leave them stuck there
    /// permanently.
    pub fn reset(&mut self) {
        *self = if self.is_system {
            Self::system(&self.agent_id)
        } else {
            Self::new(&self.agent_id)
        };
    }
}
/// Snapshot persisted by `TrustManager::save` and read back by `TrustManager::load`
/// (serialized as JSON via serde).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct TrustStore {
/// All known trust records, keyed by agent id.
factors: HashMap<String, TrustFactor>,
/// When the snapshot was written; `load` only reads `factors` back, not this field.
last_saved: DateTime<Utc>,
}
/// Tracks `TrustFactor`s for many agents and, unless created in-memory, persists them
/// to a JSON file.
#[derive(Debug)]
pub struct TrustManager {
/// Trust records keyed by agent id.
factors: HashMap<String, TrustFactor>,
/// Location of the JSON store (an empty path for in-memory managers).
store_path: PathBuf,
/// When `false`, `save` does nothing (used by `TrustManager::in_memory`).
persist: bool,
}
impl TrustManager {
    /// Creates a persistent manager backed by `~/.brainwires/trust_store.json` and loads
    /// any previously saved state from it.
    ///
    /// # Errors
    ///
    /// Returns an error if the home directory cannot be determined, or if an existing
    /// store file cannot be read or parsed.
    #[cfg(feature = "native")]
    pub fn new() -> Result<Self> {
        let store_path = dirs::home_dir()
            .ok_or_else(|| anyhow::anyhow!("Failed to get home directory"))?
            .join(".brainwires")
            .join("trust_store.json");
        Self::with_path(store_path)
    }

    /// Creates a persistent manager backed by the JSON file at `path`, loading it if it exists.
    ///
    /// # Errors
    ///
    /// Returns an error if an existing file at `path` cannot be read or parsed.
    pub fn with_path(path: PathBuf) -> Result<Self> {
        let mut manager = Self {
            factors: HashMap::new(),
            store_path: path,
            persist: true,
        };
        manager.load()?;
        Ok(manager)
    }

    /// Creates a manager that never touches the filesystem; `save` is a no-op.
    pub fn in_memory() -> Self {
        Self {
            factors: HashMap::new(),
            store_path: PathBuf::new(),
            persist: false,
        }
    }

    /// Replaces the in-memory records with the contents of `store_path`, if that file
    /// exists. Recent-violation counters whose decay deadline has already passed are
    /// cleared.
    fn load(&mut self) -> Result<()> {
        if !self.store_path.exists() {
            return Ok(());
        }
        let content = fs::read_to_string(&self.store_path)?;
        let store: TrustStore = serde_json::from_str(&content)?;
        self.factors = store.factors;

        let now = Utc::now();
        for factor in self.factors.values_mut() {
            if now > factor.violations_decay_at {
                factor.violations.decay_recent();
                factor.violations_decay_at = now + Duration::hours(24);
            }
        }
        // NOTE(review): scores are not recomputed here. A factor whose recent violations
        // were just cleared keeps its stored (penalized) score until its next recorded
        // operation.
        Ok(())
    }

    /// Writes all trust records to `store_path` as pretty-printed JSON.
    ///
    /// Does nothing for in-memory managers. The JSON is first written to a sibling
    /// `<store>.tmp` file and then renamed over the store. An interrupted write (e.g.
    /// the process being killed) therefore leaves the previous store intact, instead of
    /// a truncated file that `load` would reject on the next start.
    ///
    /// # Errors
    ///
    /// Returns an error if the parent directory cannot be created, serialization fails,
    /// or the temporary file cannot be written or renamed into place.
    pub fn save(&self) -> Result<()> {
        if !self.persist {
            return Ok(());
        }
        if let Some(parent) = self.store_path.parent() {
            fs::create_dir_all(parent)?;
        }
        let store = TrustStore {
            factors: self.factors.clone(),
            last_saved: Utc::now(),
        };
        let content = serde_json::to_string_pretty(&store)?;

        let mut tmp_name = self.store_path.clone().into_os_string();
        tmp_name.push(".tmp");
        let tmp_path = PathBuf::from(tmp_name);
        fs::write(&tmp_path, content)?;
        if let Err(err) = fs::rename(&tmp_path, &self.store_path) {
            // Best-effort cleanup; the rename failure is the error worth reporting.
            let _ = fs::remove_file(&tmp_path);
            return Err(err.into());
        }
        Ok(())
    }

    /// Returns the record for `agent_id`, inserting a fresh [`TrustFactor::new`] if absent.
    pub fn get_or_create(&mut self, agent_id: &str) -> &mut TrustFactor {
        self.factors
            .entry(agent_id.to_string())
            .or_insert_with(|| TrustFactor::new(agent_id))
    }

    /// Returns the record for `agent_id`, if one exists.
    pub fn get(&self, agent_id: &str) -> Option<&TrustFactor> {
        self.factors.get(agent_id)
    }

    /// Returns the current level of `agent_id`, or `TrustLevel::Low` for unknown agents.
    pub fn get_trust_level(&self, agent_id: &str) -> TrustLevel {
        self.factors
            .get(agent_id)
            .map_or(TrustLevel::Low, |factor| factor.level)
    }

    /// Applies `update` to the (possibly newly created) record for `agent_id`, then saves.
    fn update_and_save(&mut self, agent_id: &str, update: impl FnOnce(&mut TrustFactor)) {
        update(self.get_or_create(agent_id));
        // Save errors are ignored: persistence is best-effort for these bookkeeping calls.
        let _ = self.save();
    }

    /// Records a successful operation for `agent_id` (creating it if needed) and saves.
    pub fn record_success(&mut self, agent_id: &str) {
        self.update_and_save(agent_id, TrustFactor::record_success);
    }

    /// Records a failed operation for `agent_id` (creating it if needed) and saves.
    pub fn record_failure(&mut self, agent_id: &str) {
        self.update_and_save(agent_id, TrustFactor::record_failure);
    }

    /// Records a violation of `severity` for `agent_id` (creating it if needed) and saves.
    pub fn record_violation(&mut self, agent_id: &str, severity: ViolationSeverity) {
        self.update_and_save(agent_id, |factor| factor.record_violation(severity));
    }

    /// Overrides the level of `agent_id` (creating it if needed) and saves.
    pub fn set_trust_level(&mut self, agent_id: &str, level: TrustLevel) {
        self.update_and_save(agent_id, |factor| factor.set_level(level));
    }

    /// Resets the history of `agent_id`, if it is known, and saves.
    pub fn reset(&mut self, agent_id: &str) {
        if let Some(factor) = self.factors.get_mut(agent_id) {
            factor.reset();
            let _ = self.save();
        }
    }

    /// Stops tracking `agent_id`, returning its record if it existed.
    ///
    /// The store is only rewritten when something was actually removed.
    pub fn remove(&mut self, agent_id: &str) -> Option<TrustFactor> {
        let removed = self.factors.remove(agent_id);
        if removed.is_some() {
            let _ = self.save();
        }
        removed
    }

    /// Registers (or replaces) `agent_id` as a system agent and saves.
    pub fn register_system_agent(&mut self, agent_id: &str) {
        self.factors
            .insert(agent_id.to_string(), TrustFactor::system(agent_id));
        let _ = self.save();
    }

    /// Returns the ids of all tracked agents, in unspecified order.
    pub fn agents(&self) -> Vec<&str> {
        self.factors.keys().map(|s| s.as_str()).collect()
    }

    /// Summarizes all tracked agents.
    ///
    /// Reports counts per level, lifetime violation and operation totals, and
    /// `average_score` as the mean of every agent's current `score` (which already
    /// includes violation penalties). `average_score` is `0.0` when no agents are tracked.
    pub fn statistics(&self) -> TrustStatistics {
        let mut stats = TrustStatistics {
            total_agents: self.factors.len(),
            ..Default::default()
        };
        // Accumulate in f64 so many agents do not lose precision before averaging.
        let mut score_sum = 0.0_f64;
        for factor in self.factors.values() {
            match factor.level {
                TrustLevel::Untrusted => stats.untrusted += 1,
                TrustLevel::Low => stats.low_trust += 1,
                TrustLevel::Medium => stats.medium_trust += 1,
                TrustLevel::High => stats.high_trust += 1,
                TrustLevel::System => stats.system += 1,
            }
            stats.total_violations += factor.violations.minor as usize
                + factor.violations.major as usize
                + factor.violations.critical as usize;
            stats.total_operations += factor.total_ops as usize;
            score_sum += f64::from(factor.score);
        }
        if stats.total_agents > 0 {
            stats.average_score = (score_sum / stats.total_agents as f64) as f32;
        }
        stats
    }
}
/// Aggregate snapshot of all agents tracked by a `TrustManager`, as returned by
/// `TrustManager::statistics`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrustStatistics {
/// Number of tracked agents.
pub total_agents: usize,
/// Agents currently at `TrustLevel::Untrusted`.
pub untrusted: usize,
/// Agents currently at `TrustLevel::Low`.
pub low_trust: usize,
/// Agents currently at `TrustLevel::Medium`.
pub medium_trust: usize,
/// Agents currently at `TrustLevel::High`.
pub high_trust: usize,
/// Agents currently at `TrustLevel::System`.
pub system: usize,
/// Sum of lifetime minor, major and critical violations across all agents.
pub total_violations: usize,
/// Sum of `total_ops` across all agents.
pub total_operations: usize,
/// Aggregate score figure (left at the default 0.0 when it cannot be computed);
/// see `TrustManager::statistics` for exactly how it is derived.
pub average_score: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh factor for `agent_id` that has already recorded `successes` successful ops.
    fn seasoned_factor(agent_id: &str, successes: usize) -> TrustFactor {
        let mut factor = TrustFactor::new(agent_id);
        (0..successes).for_each(|_| factor.record_success());
        factor
    }

    #[test]
    fn test_trust_level_from_score() {
        let expectations = [
            (0.95, TrustLevel::High),
            (0.9, TrustLevel::High),
            (0.85, TrustLevel::Medium),
            (0.7, TrustLevel::Medium),
            (0.5, TrustLevel::Low),
            (0.4, TrustLevel::Low),
            (0.3, TrustLevel::Untrusted),
            (0.0, TrustLevel::Untrusted),
        ];
        for (score, expected) in expectations {
            assert_eq!(TrustLevel::from_score(score), expected, "score = {score}");
        }
    }

    #[test]
    fn test_trust_factor_success_increases_score() {
        let factor = seasoned_factor("test-agent", 10);
        assert!(factor.score > 0.5);
        assert_eq!((factor.successful_ops, factor.total_ops), (10, 10));
    }

    #[test]
    fn test_trust_factor_violations_decrease_score() {
        let mut factor = seasoned_factor("test-agent", 10);
        let before = factor.score;
        factor.record_violation(ViolationSeverity::Major);
        assert!(factor.score < before);
    }

    #[test]
    fn test_trust_factor_critical_violation() {
        let mut factor = TrustFactor::new("test-agent");
        factor.record_violation(ViolationSeverity::Critical);
        assert!(factor.score < 0.4);
        assert_eq!(factor.level, TrustLevel::Untrusted);
    }

    #[test]
    fn test_system_agent_always_trusted() {
        let mut factor = TrustFactor::system("system-agent");
        factor.record_violation(ViolationSeverity::Critical);
        assert_eq!(factor.level, TrustLevel::System);
        assert_eq!(factor.score, 1.0);
    }

    #[test]
    fn test_trust_manager() {
        let mut manager = TrustManager::in_memory();
        (0..2).for_each(|_| manager.record_success("agent-1"));
        manager.record_violation("agent-2", ViolationSeverity::Minor);
        assert!(manager.get_trust_level("agent-1") >= TrustLevel::Low);
        assert_eq!(manager.statistics().total_agents, 2);
    }

    #[test]
    fn test_violation_counts() {
        let mut counts = ViolationCounts::default();
        for severity in [
            ViolationSeverity::Minor,
            ViolationSeverity::Major,
            ViolationSeverity::Critical,
        ] {
            counts.record(severity);
        }
        assert_eq!((counts.minor, counts.major, counts.critical), (1, 1, 1));
        assert!(counts.total_penalty() > 0.2);
    }

    #[test]
    fn test_trust_level_ordering() {
        let ascending = [
            TrustLevel::Untrusted,
            TrustLevel::Low,
            TrustLevel::Medium,
            TrustLevel::High,
            TrustLevel::System,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[1] > pair[0], "{} should outrank {}", pair[1], pair[0]);
        }
    }

    #[test]
    fn test_reset_trust() {
        let mut manager = TrustManager::in_memory();
        (0..20).for_each(|_| manager.record_success("agent-1"));
        manager.record_violation("agent-1", ViolationSeverity::Critical);
        manager.reset("agent-1");

        let factor = manager
            .get("agent-1")
            .expect("agent-1 should still be tracked after reset");
        assert_eq!(factor.score, 0.5);
        assert_eq!(factor.level, TrustLevel::Low);
        assert_eq!(factor.successful_ops, 0);
    }
}