1use std::collections::HashMap;
5
6use serde::{Deserialize, Serialize};
7
8use hirn_core::id::MemoryId;
9use hirn_core::timestamp::Timestamp;
10use hirn_core::types::AgentId;
11use hirn_core::{GeneratedCognitionReview, QuarantinedRecordKind};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub enum QuarantineStatus {
16 Pending,
18 Approved,
20 Rejected,
22 RolledBack,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct QuarantineEntry {
29 pub memory_id: MemoryId,
31 pub record_kind: QuarantinedRecordKind,
33 pub record: Vec<u8>,
35 pub anomaly_score: f32,
37 pub reason: String,
39 pub status: QuarantineStatus,
41 pub created_at: Timestamp,
43 pub reviewed_by: Option<AgentId>,
45 pub reviewed_at: Option<Timestamp>,
47 pub generated_review: Option<GeneratedCognitionReview>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct QuarantineApprovalOutcome {
54 pub approved_entry_id: MemoryId,
56 pub applied_memory_ids: Vec<MemoryId>,
58 pub change_summary: String,
60 pub generated_review: Option<GeneratedCognitionReview>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct QuarantineRollbackOutcome {
67 pub rolled_back_entry_id: MemoryId,
68 pub removed_memory_ids: Vec<MemoryId>,
69 pub restored_memory_ids: Vec<MemoryId>,
70 pub reason: String,
71 pub generated_review: Option<GeneratedCognitionReview>,
72}
73
/// Tunables for the per-agent quarantine rate limit enforced by
/// `CorruptionDefense`.
#[derive(Debug, Clone)]
pub struct CorruptionDefenseConfig {
    /// An agent counts as rate-limited once it has strictly *more* than this
    /// many quarantines inside the sliding window.
    pub max_quarantines_per_window: usize,
    /// Length of the sliding window, in seconds.
    pub window_seconds: u64,
}
85
86impl Default for CorruptionDefenseConfig {
87 fn default() -> Self {
88 Self {
89 max_quarantines_per_window: 5,
90 window_seconds: 300, }
92 }
93}
94
95#[derive(Debug, Default)]
97pub struct CorruptionDefense {
98 history: HashMap<AgentId, Vec<Timestamp>>,
100 config: CorruptionDefenseConfig,
101}
102
103impl CorruptionDefense {
104 pub fn new(config: CorruptionDefenseConfig) -> Self {
106 Self {
107 history: HashMap::new(),
108 config,
109 }
110 }
111
112 pub fn record_quarantine(&mut self, agent_id: &AgentId) -> bool {
115 let now = Timestamp::now();
116 let cutoff = now
117 .as_datetime()
118 .checked_sub_signed(chrono::Duration::seconds(self.config.window_seconds as i64));
119
120 let timestamps = self.history.entry(agent_id.clone()).or_default();
121
122 if let Some(cutoff_dt) = cutoff {
124 timestamps.retain(|ts| ts.as_datetime() >= cutoff_dt);
125 }
126
127 timestamps.push(now);
128
129 timestamps.len() > self.config.max_quarantines_per_window
130 }
131
132 pub fn is_rate_limited(&self, agent_id: &AgentId) -> bool {
134 let Some(timestamps) = self.history.get(agent_id) else {
135 return false;
136 };
137
138 let now = Timestamp::now();
139 let cutoff = now
140 .as_datetime()
141 .checked_sub_signed(chrono::Duration::seconds(self.config.window_seconds as i64));
142
143 let recent_count = match cutoff {
144 Some(cutoff_dt) => timestamps
145 .iter()
146 .filter(|ts| ts.as_datetime() >= cutoff_dt)
147 .count(),
148 None => timestamps.len(),
149 };
150
151 recent_count > self.config.max_quarantines_per_window
152 }
153
154 pub fn clear_agent(&mut self, agent_id: &AgentId) {
156 self.history.remove(agent_id);
157 }
158
159 pub fn config(&self) -> &CorruptionDefenseConfig {
161 &self.config
162 }
163
164 pub fn snapshot(&self) -> Vec<(String, Vec<u64>)> {
166 self.history
167 .iter()
168 .map(|(agent_id, timestamps)| {
169 let ms: Vec<u64> = timestamps.iter().map(|ts| ts.millis()).collect();
170 (agent_id.to_string(), ms)
171 })
172 .collect()
173 }
174
175 pub fn restore(&mut self, entries: &[(String, Vec<u64>)]) {
177 for (agent_str, timestamps_ms) in entries {
178 if let Ok(agent_id) = AgentId::new(agent_str) {
179 let timestamps: Vec<Timestamp> = timestamps_ms
180 .iter()
181 .map(|&ms| Timestamp::from_millis(ms))
182 .collect();
183 self.history.insert(agent_id, timestamps);
184 }
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quarantine_entry_serde_round_trip() {
        let entry = QuarantineEntry {
            memory_id: MemoryId::new(),
            record_kind: QuarantinedRecordKind::Episodic,
            record: vec![1, 2, 3],
            anomaly_score: 0.85,
            reason: "outlier embedding".to_string(),
            status: QuarantineStatus::Pending,
            created_at: Timestamp::now(),
            reviewed_by: None,
            reviewed_at: None,
            generated_review: None,
        };
        let bytes = bincode::serialize(&entry).unwrap();
        let back: QuarantineEntry = bincode::deserialize(&bytes).unwrap();
        assert_eq!(back.memory_id, entry.memory_id);
        assert_eq!(back.status, QuarantineStatus::Pending);
    }

    #[test]
    fn corruption_defense_rate_limits_after_burst() {
        let config = CorruptionDefenseConfig {
            max_quarantines_per_window: 3,
            window_seconds: 300,
        };
        let mut defense = CorruptionDefense::new(config);
        let agent = AgentId::new("bad-agent").unwrap();

        // Up to the limit is allowed; the next event trips it.
        assert!(!defense.record_quarantine(&agent));
        assert!(!defense.record_quarantine(&agent));
        assert!(!defense.record_quarantine(&agent));
        assert!(defense.record_quarantine(&agent));
        assert!(defense.is_rate_limited(&agent));

        // Limits are tracked per agent.
        let good_agent = AgentId::new("good-agent").unwrap();
        assert!(!defense.is_rate_limited(&good_agent));
    }

    #[test]
    fn corruption_defense_clear_resets() {
        let config = CorruptionDefenseConfig {
            max_quarantines_per_window: 1,
            window_seconds: 300,
        };
        let mut defense = CorruptionDefense::new(config);
        let agent = AgentId::new("agent").unwrap();

        defense.record_quarantine(&agent);
        defense.record_quarantine(&agent);
        assert!(defense.is_rate_limited(&agent));

        defense.clear_agent(&agent);
        assert!(!defense.is_rate_limited(&agent));
    }

    #[test]
    fn corruption_defense_snapshot_restore_round_trip() {
        let config = CorruptionDefenseConfig {
            max_quarantines_per_window: 1,
            window_seconds: 300,
        };
        let mut defense = CorruptionDefense::new(config.clone());
        let noisy = AgentId::new("noisy-agent").unwrap();
        let quiet = AgentId::new("quiet-agent").unwrap();

        defense.record_quarantine(&noisy);
        defense.record_quarantine(&noisy);
        defense.record_quarantine(&quiet);

        // Compare order-insensitively: snapshot order is not part of this check.
        let mut snapshot = defense.snapshot();
        snapshot.sort();

        let mut restored = CorruptionDefense::new(config);
        restored.restore(&snapshot);

        let mut restored_snapshot = restored.snapshot();
        restored_snapshot.sort();
        assert_eq!(restored_snapshot, snapshot);
        assert!(restored.is_rate_limited(&noisy));
        assert!(!restored.is_rate_limited(&quiet));
    }
}