1pub mod audit;
11pub mod classifier;
12pub mod config;
13pub mod injection;
14pub mod interceptor;
15pub mod sanitizer;
16pub mod taint;
17
18pub use audit::{AuditAction, AuditEntry, AuditEventType, AuditLog};
19pub use classifier::PrivacyClassifier;
20pub use config::{RedactionStrategy, SecurityConfig, SensitivityLevel};
21pub use injection::InjectionDetector;
22pub use injection::ToolOutputInjectionScanner;
23pub use interceptor::ToolInterceptor;
24pub use sanitizer::OutputSanitizer;
25pub use taint::{TaintId, TaintRegistry};
26
27use crate::hooks::HookEventType;
28use crate::hooks::HookHandler;
29use crate::hooks::{Hook, HookConfig, HookEngine};
30use sanitizer::make_replacement;
31use std::sync::{Arc, OnceLock, RwLock};
32
/// Prefix shared by every hook id this module registers; ids are built as
/// `"{HOOK_PREFIX}-<kind>-<session_id>"` (e.g. `security-interceptor-<session_id>`).
const HOOK_PREFIX: &str = "security";
35
/// Per-session security state: the session's [`SecurityConfig`], lazily created
/// shared components (taint registry, classifier, audit log), and the ids of the
/// hooks this guard has registered on a `HookEngine`.
pub struct SecurityGuard {
    /// Session this guard belongs to; embedded in hook ids and audit entries.
    session_id: String,
    /// Feature flags and policy (redaction strategy, classification rules).
    config: SecurityConfig,
    /// Registry of tainted values; created on first access.
    taint_registry: OnceLock<Arc<RwLock<TaintRegistry>>>,
    /// Classifier built from `config.classification_rules`; created on first access.
    classifier: OnceLock<Arc<PrivacyClassifier>>,
    /// Audit log, shared (via `Arc` clones) with the hook handlers; created on first access.
    audit_log: OnceLock<Arc<AuditLog>>,
    /// Ids of the hooks registered on a `HookEngine`, consulted by `teardown`.
    hook_ids: std::sync::Mutex<Vec<String>>,
}
51
52impl SecurityGuard {
53 pub fn new(session_id: String, config: SecurityConfig) -> Self {
59 Self {
60 session_id,
61 config,
62 taint_registry: OnceLock::new(),
63 classifier: OnceLock::new(),
64 audit_log: OnceLock::new(),
65 hook_ids: std::sync::Mutex::new(Vec::new()),
66 }
67 }
68
69 pub fn register_hooks(&self, hook_engine: &HookEngine) {
75 let mut hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
76 if !hook_ids.is_empty() {
77 return;
79 }
80
81 if self.config.features.tool_interceptor {
83 let hook_id = format!("{}-interceptor-{}", HOOK_PREFIX, &self.session_id);
84 let interceptor = ToolInterceptor::new(
85 &self.config,
86 self.taint_registry().clone(),
87 self.audit_log().clone(),
88 self.session_id.clone(),
89 );
90 hook_engine.register(Hook::new(&hook_id, HookEventType::PreToolUse).with_config(
91 HookConfig {
92 priority: 1,
93 ..Default::default()
94 },
95 ));
96 hook_engine.register_handler(&hook_id, Arc::new(interceptor) as Arc<dyn HookHandler>);
97 hook_ids.push(hook_id);
98 }
99
100 if self.config.features.output_sanitizer {
102 let hook_id = format!("{}-sanitizer-{}", HOOK_PREFIX, &self.session_id);
103 let sanitizer = OutputSanitizer::new(
104 self.taint_registry().clone(),
105 self.classifier().clone(),
106 self.config.redaction_strategy,
107 self.audit_log().clone(),
108 self.session_id.clone(),
109 );
110 hook_engine.register(Hook::new(&hook_id, HookEventType::GenerateEnd).with_config(
111 HookConfig {
112 priority: 1,
113 ..Default::default()
114 },
115 ));
116 hook_engine.register_handler(&hook_id, Arc::new(sanitizer) as Arc<dyn HookHandler>);
117 hook_ids.push(hook_id);
118 }
119
120 if self.config.features.injection_defense {
122 let hook_id = format!("{}-injection-{}", HOOK_PREFIX, &self.session_id);
123 let detector =
124 InjectionDetector::new(self.audit_log().clone(), self.session_id.clone());
125 hook_engine.register(
126 Hook::new(&hook_id, HookEventType::GenerateStart).with_config(HookConfig {
127 priority: 1,
128 ..Default::default()
129 }),
130 );
131 hook_engine.register_handler(&hook_id, Arc::new(detector) as Arc<dyn HookHandler>);
132 hook_ids.push(hook_id);
133
134 let scanner_id = format!("{}-injection-output-{}", HOOK_PREFIX, &self.session_id);
136 let scanner =
137 ToolOutputInjectionScanner::new(self.audit_log().clone(), self.session_id.clone());
138 hook_engine.register(
139 Hook::new(&scanner_id, HookEventType::PostToolUse).with_config(HookConfig {
140 priority: 1,
141 ..Default::default()
142 }),
143 );
144 hook_engine.register_handler(&scanner_id, Arc::new(scanner) as Arc<dyn HookHandler>);
145 hook_ids.push(scanner_id);
146 }
147 }
148
149 fn taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
151 self.taint_registry
152 .get_or_init(|| Arc::new(RwLock::new(TaintRegistry::new())))
153 }
154
155 fn classifier(&self) -> &Arc<PrivacyClassifier> {
157 self.classifier
158 .get_or_init(|| Arc::new(PrivacyClassifier::new(&self.config.classification_rules)))
159 }
160
161 fn audit_log(&self) -> &Arc<AuditLog> {
163 self.audit_log
164 .get_or_init(|| Arc::new(AuditLog::new(10_000)))
165 }
166
167 pub fn taint_input(&self, text: &str) {
169 if !self.config.features.taint_tracking {
170 return;
171 }
172
173 let result = self.classifier().classify(text);
174 if !result.matches.is_empty() {
175 let Ok(mut registry) = self.taint_registry().write() else {
176 tracing::error!("Taint registry lock poisoned — skipping taint registration");
177 return;
178 };
179 for m in &result.matches {
180 let id = registry.register(&m.matched_text, &m.rule_name, m.level);
181 self.audit_log().log(AuditEntry {
182 timestamp: chrono::Utc::now(),
183 session_id: self.session_id.clone(),
184 event_type: AuditEventType::TaintRegistered,
185 severity: m.level,
186 details: format!(
187 "Registered tainted value from rule '{}' (id: {})",
188 m.rule_name, id
189 ),
190 tool_name: None,
191 action_taken: AuditAction::Logged,
192 });
193 }
194 }
195 }
196
197 pub fn sanitize_output(&self, text: &str) -> String {
199 if !self.config.features.output_sanitizer {
200 return text.to_string();
201 }
202
203 let mut result = text.to_string();
204
205 {
207 let Ok(registry) = self.taint_registry().read() else {
208 tracing::error!("Taint registry lock poisoned — returning unsanitized output");
209 return result;
210 };
211 for (_, entry) in registry.entries_iter() {
212 if result.contains(&entry.original_value) {
213 let replacement =
214 make_replacement(&entry.original_value, self.config.redaction_strategy);
215 result = result.replace(&entry.original_value, &replacement);
216 }
217 for variant in &entry.variants {
218 if result.contains(variant.as_str()) {
219 result = result.replace(variant.as_str(), "[REDACTED]");
220 }
221 }
222 }
223 }
224
225 result = self
227 .classifier()
228 .redact(&result, self.config.redaction_strategy);
229
230 result
231 }
232
233 pub fn wipe(&self) {
235 if let Some(registry) = self.taint_registry.get() {
236 if let Ok(mut r) = registry.write() {
237 r.wipe();
238 } else {
239 tracing::error!("Taint registry lock poisoned — cannot wipe");
240 }
241 }
242 if let Some(log) = self.audit_log.get() {
243 log.log(AuditEntry {
244 timestamp: chrono::Utc::now(),
245 session_id: self.session_id.clone(),
246 event_type: AuditEventType::SessionWiped,
247 severity: SensitivityLevel::Normal,
248 details: "Session security state wiped".to_string(),
249 tool_name: None,
250 action_taken: AuditAction::Logged,
251 });
252 log.clear();
253 }
254 }
255
256 pub fn teardown(&self, hook_engine: &HookEngine) {
258 let hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
259 for hook_id in hook_ids.iter() {
260 hook_engine.unregister_handler(hook_id);
261 hook_engine.unregister(hook_id);
262 }
263 }
264
265 pub fn audit_entries(&self) -> Vec<AuditEntry> {
267 self.audit_log().entries()
268 }
269
270 pub fn get_taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
272 self.taint_registry()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a guard with the default config for `session` and registers its
    /// hooks on `engine`.
    fn guarded(session: &str, engine: &HookEngine) -> SecurityGuard {
        let guard = SecurityGuard::new(session.to_string(), SecurityConfig::default());
        guard.register_hooks(engine);
        guard
    }

    #[test]
    fn test_guard_lifecycle() {
        let engine = HookEngine::new();
        let guard = guarded("test-session", &engine);
        // The default config yields four registered hooks.
        assert_eq!(engine.hook_count(), 4);

        guard.taint_input("My SSN is 123-45-6789");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        let cleaned = guard.sanitize_output("The SSN 123-45-6789 was found");
        assert!(!cleaned.contains("123-45-6789"));

        guard.wipe();
        assert_eq!(guard.taint_registry().read().unwrap().entry_count(), 0);

        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_taint_input_registers_pii() {
        let engine = HookEngine::new();
        let guard = guarded("s1", &engine);

        guard.taint_input("Contact me at user@example.com or call 555-123-4567");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        let log = guard.audit_entries();
        assert!(!log.is_empty());
        let saw_taint = log
            .iter()
            .any(|entry| entry.event_type == AuditEventType::TaintRegistered);
        assert!(saw_taint);

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_sanitize_output() {
        let engine = HookEngine::new();
        let guard = guarded("s1", &engine);

        guard.taint_input("My SSN is 123-45-6789");
        let cleaned = guard.sanitize_output("Found SSN: 123-45-6789 in the data");
        assert!(!cleaned.contains("123-45-6789"));

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_disabled_features() {
        let engine = HookEngine::new();
        let mut cfg = SecurityConfig::default();
        cfg.features.tool_interceptor = false;
        cfg.features.output_sanitizer = false;
        cfg.features.injection_defense = false;
        cfg.features.taint_tracking = false;

        let guard = SecurityGuard::new("s1".to_string(), cfg);
        guard.register_hooks(&engine);
        assert_eq!(engine.hook_count(), 0);

        // With taint tracking off the registry must never be created.
        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry.get().is_none());

        // With the sanitizer off the text passes through untouched.
        assert_eq!(guard.sanitize_output("SSN: 123-45-6789"), "SSN: 123-45-6789");

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_wipe_and_teardown() {
        let engine = HookEngine::new();
        let guard = guarded("s1", &engine);

        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        guard.wipe();
        assert_eq!(guard.taint_registry().read().unwrap().entry_count(), 0);
        assert!(guard.audit_entries().is_empty());

        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_lazy_init() {
        let guard = SecurityGuard::new("lazy-test".to_string(), SecurityConfig::default());

        // Nothing is allocated by the constructor.
        assert!(guard.taint_registry.get().is_none());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());

        // Touching one component creates only that component.
        let _ = guard.taint_registry();
        assert!(guard.taint_registry.get().is_some());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());
    }

    #[test]
    fn test_register_hooks_idempotent() {
        let engine = HookEngine::new();
        let guard = guarded("s1", &engine);
        let first = engine.hook_count();

        guard.register_hooks(&engine);
        assert_eq!(engine.hook_count(), first);
    }
}