1pub mod audit;
11pub mod classifier;
12pub mod config;
13pub mod injection;
14pub mod interceptor;
15pub mod sanitizer;
16pub mod taint;
17
18pub use audit::{AuditAction, AuditEntry, AuditEventType, AuditLog};
19pub use classifier::PrivacyClassifier;
20pub use config::{RedactionStrategy, SecurityConfig, SensitivityLevel};
21pub use injection::InjectionDetector;
22pub use injection::ToolOutputInjectionScanner;
23pub use interceptor::ToolInterceptor;
24pub use sanitizer::OutputSanitizer;
25pub use taint::{TaintId, TaintRegistry};
26
27use crate::hooks::HookEventType;
28use crate::hooks::HookHandler;
29use crate::hooks::{Hook, HookConfig, HookEngine};
30use sanitizer::make_replacement;
31use std::sync::{Arc, OnceLock, RwLock};
32
33const HOOK_PREFIX: &str = "security";
35
36pub struct SecurityGuard {
43 session_id: String,
44 config: SecurityConfig,
45 taint_registry: OnceLock<Arc<RwLock<TaintRegistry>>>,
46 classifier: OnceLock<Arc<PrivacyClassifier>>,
47 audit_log: OnceLock<Arc<AuditLog>>,
48 hook_ids: std::sync::Mutex<Vec<String>>,
50}
51
52impl SecurityGuard {
53 pub fn new(session_id: String, config: SecurityConfig) -> Self {
59 Self {
60 session_id,
61 config,
62 taint_registry: OnceLock::new(),
63 classifier: OnceLock::new(),
64 audit_log: OnceLock::new(),
65 hook_ids: std::sync::Mutex::new(Vec::new()),
66 }
67 }
68
69 pub fn register_hooks(&self, hook_engine: &HookEngine) {
75 let mut hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
76 if !hook_ids.is_empty() {
77 return;
79 }
80
81 if self.config.features.tool_interceptor {
83 let hook_id = format!("{}-interceptor-{}", HOOK_PREFIX, &self.session_id);
84 let interceptor = ToolInterceptor::new(
85 &self.config,
86 self.taint_registry().clone(),
87 self.audit_log().clone(),
88 self.session_id.clone(),
89 );
90 hook_engine.register(Hook::new(&hook_id, HookEventType::PreToolUse).with_config(
91 HookConfig {
92 priority: 1,
93 ..Default::default()
94 },
95 ));
96 hook_engine.register_handler(&hook_id, Arc::new(interceptor) as Arc<dyn HookHandler>);
97 hook_ids.push(hook_id);
98 }
99
100 if self.config.features.output_sanitizer {
102 let hook_id = format!("{}-sanitizer-{}", HOOK_PREFIX, &self.session_id);
103 let sanitizer = OutputSanitizer::new(
104 self.taint_registry().clone(),
105 self.classifier().clone(),
106 self.config.redaction_strategy,
107 self.audit_log().clone(),
108 self.session_id.clone(),
109 );
110 hook_engine.register(Hook::new(&hook_id, HookEventType::GenerateEnd).with_config(
111 HookConfig {
112 priority: 1,
113 ..Default::default()
114 },
115 ));
116 hook_engine.register_handler(&hook_id, Arc::new(sanitizer) as Arc<dyn HookHandler>);
117 hook_ids.push(hook_id);
118 }
119
120 if self.config.features.injection_defense {
122 let hook_id = format!("{}-injection-{}", HOOK_PREFIX, &self.session_id);
123 let detector = InjectionDetector::new(self.audit_log().clone(), self.session_id.clone());
124 hook_engine.register(
125 Hook::new(&hook_id, HookEventType::GenerateStart).with_config(HookConfig {
126 priority: 1,
127 ..Default::default()
128 }),
129 );
130 hook_engine.register_handler(&hook_id, Arc::new(detector) as Arc<dyn HookHandler>);
131 hook_ids.push(hook_id);
132
133 let scanner_id = format!("{}-injection-output-{}", HOOK_PREFIX, &self.session_id);
135 let scanner =
136 ToolOutputInjectionScanner::new(self.audit_log().clone(), self.session_id.clone());
137 hook_engine.register(
138 Hook::new(&scanner_id, HookEventType::PostToolUse).with_config(HookConfig {
139 priority: 1,
140 ..Default::default()
141 }),
142 );
143 hook_engine.register_handler(&scanner_id, Arc::new(scanner) as Arc<dyn HookHandler>);
144 hook_ids.push(scanner_id);
145 }
146 }
147
148 fn taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
150 self.taint_registry
151 .get_or_init(|| Arc::new(RwLock::new(TaintRegistry::new())))
152 }
153
154 fn classifier(&self) -> &Arc<PrivacyClassifier> {
156 self.classifier
157 .get_or_init(|| Arc::new(PrivacyClassifier::new(&self.config.classification_rules)))
158 }
159
160 fn audit_log(&self) -> &Arc<AuditLog> {
162 self.audit_log.get_or_init(|| Arc::new(AuditLog::new(10_000)))
163 }
164
165 pub fn taint_input(&self, text: &str) {
167 if !self.config.features.taint_tracking {
168 return;
169 }
170
171 let result = self.classifier().classify(text);
172 if !result.matches.is_empty() {
173 let Ok(mut registry) = self.taint_registry().write() else {
174 tracing::error!("Taint registry lock poisoned — skipping taint registration");
175 return;
176 };
177 for m in &result.matches {
178 let id = registry.register(&m.matched_text, &m.rule_name, m.level);
179 self.audit_log().log(AuditEntry {
180 timestamp: chrono::Utc::now(),
181 session_id: self.session_id.clone(),
182 event_type: AuditEventType::TaintRegistered,
183 severity: m.level,
184 details: format!(
185 "Registered tainted value from rule '{}' (id: {})",
186 m.rule_name, id
187 ),
188 tool_name: None,
189 action_taken: AuditAction::Logged,
190 });
191 }
192 }
193 }
194
195 pub fn sanitize_output(&self, text: &str) -> String {
197 if !self.config.features.output_sanitizer {
198 return text.to_string();
199 }
200
201 let mut result = text.to_string();
202
203 {
205 let Ok(registry) = self.taint_registry().read() else {
206 tracing::error!("Taint registry lock poisoned — returning unsanitized output");
207 return result;
208 };
209 for (_, entry) in registry.entries_iter() {
210 if result.contains(&entry.original_value) {
211 let replacement =
212 make_replacement(&entry.original_value, self.config.redaction_strategy);
213 result = result.replace(&entry.original_value, &replacement);
214 }
215 for variant in &entry.variants {
216 if result.contains(variant.as_str()) {
217 result = result.replace(variant.as_str(), "[REDACTED]");
218 }
219 }
220 }
221 }
222
223 result = self
225 .classifier()
226 .redact(&result, self.config.redaction_strategy);
227
228 result
229 }
230
231 pub fn wipe(&self) {
233 if let Some(registry) = self.taint_registry.get() {
234 if let Ok(mut r) = registry.write() {
235 r.wipe();
236 } else {
237 tracing::error!("Taint registry lock poisoned — cannot wipe");
238 }
239 }
240 if let Some(log) = self.audit_log.get() {
241 log.log(AuditEntry {
242 timestamp: chrono::Utc::now(),
243 session_id: self.session_id.clone(),
244 event_type: AuditEventType::SessionWiped,
245 severity: SensitivityLevel::Normal,
246 details: "Session security state wiped".to_string(),
247 tool_name: None,
248 action_taken: AuditAction::Logged,
249 });
250 log.clear();
251 }
252 }
253
254 pub fn teardown(&self, hook_engine: &HookEngine) {
256 let hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
257 for hook_id in hook_ids.iter() {
258 hook_engine.unregister_handler(hook_id);
259 hook_engine.unregister(hook_id);
260 }
261 }
262
263 pub fn audit_entries(&self) -> Vec<AuditEntry> {
265 self.audit_log().entries()
266 }
267
268 pub fn get_taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
270 self.taint_registry()
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an engine and a guard for `session` with `config`, with hooks registered.
    fn setup(session: &str, config: SecurityConfig) -> (HookEngine, SecurityGuard) {
        let engine = HookEngine::new();
        let guard = SecurityGuard::new(session.to_string(), config);
        guard.register_hooks(&engine);
        (engine, guard)
    }

    #[test]
    fn test_guard_lifecycle() {
        let (engine, guard) = setup("test-session", SecurityConfig::default());
        assert_eq!(engine.hook_count(), 4);

        guard.taint_input("My SSN is 123-45-6789");
        {
            let reg = guard.taint_registry().read().unwrap();
            assert!(reg.entry_count() > 0);
        }

        let cleaned = guard.sanitize_output("The SSN 123-45-6789 was found");
        assert!(!cleaned.contains("123-45-6789"));

        guard.wipe();
        {
            let reg = guard.taint_registry().read().unwrap();
            assert_eq!(reg.entry_count(), 0);
        }

        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_taint_input_registers_pii() {
        let (engine, guard) = setup("s1", SecurityConfig::default());

        guard.taint_input("Contact me at user@example.com or call 555-123-4567");

        let reg = guard.taint_registry().read().unwrap();
        assert!(reg.entry_count() > 0);

        let log = guard.audit_entries();
        assert!(!log.is_empty());
        let saw_taint = log
            .iter()
            .any(|entry| entry.event_type == AuditEventType::TaintRegistered);
        assert!(saw_taint);

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_sanitize_output() {
        let (engine, guard) = setup("s1", SecurityConfig::default());

        guard.taint_input("My SSN is 123-45-6789");

        let out = guard.sanitize_output("Found SSN: 123-45-6789 in the data");
        assert!(!out.contains("123-45-6789"));

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_disabled_features() {
        let mut config = SecurityConfig::default();
        {
            let features = &mut config.features;
            features.tool_interceptor = false;
            features.output_sanitizer = false;
            features.injection_defense = false;
            features.taint_tracking = false;
        }

        let (engine, guard) = setup("s1", config);
        assert_eq!(engine.hook_count(), 0);

        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry.get().is_none());

        let out = guard.sanitize_output("SSN: 123-45-6789");
        assert_eq!(out, "SSN: 123-45-6789");

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_wipe_and_teardown() {
        let (engine, guard) = setup("s1", SecurityConfig::default());

        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        guard.wipe();
        assert_eq!(guard.taint_registry().read().unwrap().entry_count(), 0);
        assert!(guard.audit_entries().is_empty());

        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_lazy_init() {
        let guard = SecurityGuard::new("lazy-test".to_string(), SecurityConfig::default());

        assert!(guard.taint_registry.get().is_none());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());

        // Touching one component must not initialise the others.
        let _ = guard.taint_registry();
        assert!(guard.taint_registry.get().is_some());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());
    }

    #[test]
    fn test_register_hooks_idempotent() {
        let (engine, guard) = setup("s1", SecurityConfig::default());
        let before = engine.hook_count();

        guard.register_hooks(&engine);
        assert_eq!(engine.hook_count(), before);
    }
}