Skip to main content

a3s_code_core/security/
mod.rs

1//! Security Module
2//!
3//! Provides security features for A3S Code sessions:
4//! - **Output Sanitizer**: Redacts sensitive data from LLM responses
5//! - **Taint Tracking**: Tracks sensitive values and their encoded variants
6//! - **Tool Interceptor**: Blocks dangerous tool invocations
7//! - **Session Isolation**: Per-session security state with secure wipe
8//! - **Prompt Injection Defense**: Detects and blocks injection attempts
9
10pub mod audit;
11pub mod classifier;
12pub mod config;
13pub mod injection;
14pub mod interceptor;
15pub mod sanitizer;
16pub mod taint;
17
18pub use audit::{AuditAction, AuditEntry, AuditEventType, AuditLog};
19pub use classifier::PrivacyClassifier;
20pub use config::{RedactionStrategy, SecurityConfig, SensitivityLevel};
21pub use injection::InjectionDetector;
22pub use injection::ToolOutputInjectionScanner;
23pub use interceptor::ToolInterceptor;
24pub use sanitizer::OutputSanitizer;
25pub use taint::{TaintId, TaintRegistry};
26
27use crate::hooks::HookEventType;
28use crate::hooks::HookHandler;
29use crate::hooks::{Hook, HookConfig, HookEngine};
30use sanitizer::make_replacement;
31use std::sync::{Arc, OnceLock, RwLock};
32
33/// Hook ID prefix for security hooks
34const HOOK_PREFIX: &str = "security";
35
36/// Per-session security orchestrator
37///
38/// Subsystems (`TaintRegistry`, `PrivacyClassifier`, `AuditLog`) are lazily
39/// initialized on first access via `OnceLock`. This avoids allocating a
40/// `HashMap` (taint), compiling regex (classifier), and creating a 10,000-
41/// capacity `Vec` (audit) when the corresponding feature is disabled.
42pub struct SecurityGuard {
43    session_id: String,
44    config: SecurityConfig,
45    taint_registry: OnceLock<Arc<RwLock<TaintRegistry>>>,
46    classifier: OnceLock<Arc<PrivacyClassifier>>,
47    audit_log: OnceLock<Arc<AuditLog>>,
48    /// Hook IDs registered by this guard (for teardown)
49    hook_ids: std::sync::Mutex<Vec<String>>,
50}
51
52impl SecurityGuard {
53    /// Create a new SecurityGuard without registering hooks.
54    ///
55    /// Call [`register_hooks`] separately with a real `HookEngine` to
56    /// register security hooks. This avoids the previous bug where hooks
57    /// were registered to a temporary engine that was immediately dropped.
58    pub fn new(session_id: String, config: SecurityConfig) -> Self {
59        Self {
60            session_id,
61            config,
62            taint_registry: OnceLock::new(),
63            classifier: OnceLock::new(),
64            audit_log: OnceLock::new(),
65            hook_ids: std::sync::Mutex::new(Vec::new()),
66        }
67    }
68
69    /// Register security hooks with the given engine.
70    ///
71    /// Must be called with a long-lived `HookEngine` — not a temporary.
72    /// Safe to call multiple times (idempotent: skips if hooks already
73    /// registered).
74    pub fn register_hooks(&self, hook_engine: &HookEngine) {
75        let mut hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
76        if !hook_ids.is_empty() {
77            // Already registered
78            return;
79        }
80
81        // Register tool interceptor hook
82        if self.config.features.tool_interceptor {
83            let hook_id = format!("{}-interceptor-{}", HOOK_PREFIX, &self.session_id);
84            let interceptor = ToolInterceptor::new(
85                &self.config,
86                self.taint_registry().clone(),
87                self.audit_log().clone(),
88                self.session_id.clone(),
89            );
90            hook_engine.register(Hook::new(&hook_id, HookEventType::PreToolUse).with_config(
91                HookConfig {
92                    priority: 1,
93                    ..Default::default()
94                },
95            ));
96            hook_engine.register_handler(&hook_id, Arc::new(interceptor) as Arc<dyn HookHandler>);
97            hook_ids.push(hook_id);
98        }
99
100        // Register output sanitizer hook
101        if self.config.features.output_sanitizer {
102            let hook_id = format!("{}-sanitizer-{}", HOOK_PREFIX, &self.session_id);
103            let sanitizer = OutputSanitizer::new(
104                self.taint_registry().clone(),
105                self.classifier().clone(),
106                self.config.redaction_strategy,
107                self.audit_log().clone(),
108                self.session_id.clone(),
109            );
110            hook_engine.register(Hook::new(&hook_id, HookEventType::GenerateEnd).with_config(
111                HookConfig {
112                    priority: 1,
113                    ..Default::default()
114                },
115            ));
116            hook_engine.register_handler(&hook_id, Arc::new(sanitizer) as Arc<dyn HookHandler>);
117            hook_ids.push(hook_id);
118        }
119
120        // Register injection detector hook
121        if self.config.features.injection_defense {
122            let hook_id = format!("{}-injection-{}", HOOK_PREFIX, &self.session_id);
123            let detector = InjectionDetector::new(self.audit_log().clone(), self.session_id.clone());
124            hook_engine.register(
125                Hook::new(&hook_id, HookEventType::GenerateStart).with_config(HookConfig {
126                    priority: 1,
127                    ..Default::default()
128                }),
129            );
130            hook_engine.register_handler(&hook_id, Arc::new(detector) as Arc<dyn HookHandler>);
131            hook_ids.push(hook_id);
132
133            // Also register PostToolUse scanner for indirect injection via tool outputs
134            let scanner_id = format!("{}-injection-output-{}", HOOK_PREFIX, &self.session_id);
135            let scanner =
136                ToolOutputInjectionScanner::new(self.audit_log().clone(), self.session_id.clone());
137            hook_engine.register(
138                Hook::new(&scanner_id, HookEventType::PostToolUse).with_config(HookConfig {
139                    priority: 1,
140                    ..Default::default()
141                }),
142            );
143            hook_engine.register_handler(&scanner_id, Arc::new(scanner) as Arc<dyn HookHandler>);
144            hook_ids.push(scanner_id);
145        }
146    }
147
148    /// Lazily initialize and return the taint registry
149    fn taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
150        self.taint_registry
151            .get_or_init(|| Arc::new(RwLock::new(TaintRegistry::new())))
152    }
153
154    /// Lazily initialize and return the privacy classifier
155    fn classifier(&self) -> &Arc<PrivacyClassifier> {
156        self.classifier
157            .get_or_init(|| Arc::new(PrivacyClassifier::new(&self.config.classification_rules)))
158    }
159
160    /// Lazily initialize and return the audit log
161    fn audit_log(&self) -> &Arc<AuditLog> {
162        self.audit_log.get_or_init(|| Arc::new(AuditLog::new(10_000)))
163    }
164
165    /// Classify input text and register any detected sensitive data as tainted
166    pub fn taint_input(&self, text: &str) {
167        if !self.config.features.taint_tracking {
168            return;
169        }
170
171        let result = self.classifier().classify(text);
172        if !result.matches.is_empty() {
173            let Ok(mut registry) = self.taint_registry().write() else {
174                tracing::error!("Taint registry lock poisoned — skipping taint registration");
175                return;
176            };
177            for m in &result.matches {
178                let id = registry.register(&m.matched_text, &m.rule_name, m.level);
179                self.audit_log().log(AuditEntry {
180                    timestamp: chrono::Utc::now(),
181                    session_id: self.session_id.clone(),
182                    event_type: AuditEventType::TaintRegistered,
183                    severity: m.level,
184                    details: format!(
185                        "Registered tainted value from rule '{}' (id: {})",
186                        m.rule_name, id
187                    ),
188                    tool_name: None,
189                    action_taken: AuditAction::Logged,
190                });
191            }
192        }
193    }
194
195    /// Sanitize output text by redacting tainted and classified sensitive data
196    pub fn sanitize_output(&self, text: &str) -> String {
197        if !self.config.features.output_sanitizer {
198            return text.to_string();
199        }
200
201        let mut result = text.to_string();
202
203        // Check taint registry
204        {
205            let Ok(registry) = self.taint_registry().read() else {
206                tracing::error!("Taint registry lock poisoned — returning unsanitized output");
207                return result;
208            };
209            for (_, entry) in registry.entries_iter() {
210                if result.contains(&entry.original_value) {
211                    let replacement =
212                        make_replacement(&entry.original_value, self.config.redaction_strategy);
213                    result = result.replace(&entry.original_value, &replacement);
214                }
215                for variant in &entry.variants {
216                    if result.contains(variant.as_str()) {
217                        result = result.replace(variant.as_str(), "[REDACTED]");
218                    }
219                }
220            }
221        }
222
223        // Run classifier
224        result = self
225            .classifier()
226            .redact(&result, self.config.redaction_strategy);
227
228        result
229    }
230
231    /// Securely wipe all session security state
232    pub fn wipe(&self) {
233        if let Some(registry) = self.taint_registry.get() {
234            if let Ok(mut r) = registry.write() {
235                r.wipe();
236            } else {
237                tracing::error!("Taint registry lock poisoned — cannot wipe");
238            }
239        }
240        if let Some(log) = self.audit_log.get() {
241            log.log(AuditEntry {
242                timestamp: chrono::Utc::now(),
243                session_id: self.session_id.clone(),
244                event_type: AuditEventType::SessionWiped,
245                severity: SensitivityLevel::Normal,
246                details: "Session security state wiped".to_string(),
247                tool_name: None,
248                action_taken: AuditAction::Logged,
249            });
250            log.clear();
251        }
252    }
253
254    /// Unregister all hooks from the engine
255    pub fn teardown(&self, hook_engine: &HookEngine) {
256        let hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
257        for hook_id in hook_ids.iter() {
258            hook_engine.unregister_handler(hook_id);
259            hook_engine.unregister(hook_id);
260        }
261    }
262
263    /// Get audit log entries
264    pub fn audit_entries(&self) -> Vec<AuditEntry> {
265        self.audit_log().entries()
266    }
267
268    /// Get the taint registry (read-only access)
269    pub fn get_taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
270        self.taint_registry()
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a guard for `session` and register its hooks on a fresh engine.
    fn setup(session: &str, config: SecurityConfig) -> (HookEngine, SecurityGuard) {
        let engine = HookEngine::new();
        let guard = SecurityGuard::new(session.to_string(), config);
        guard.register_hooks(&engine);
        (engine, guard)
    }

    #[test]
    fn test_guard_lifecycle() {
        let (engine, guard) = setup("test-session", SecurityConfig::default());

        // Defaults enable all four hooks: interceptor, sanitizer, injection
        // detector and tool-output injection scanner.
        assert_eq!(engine.hook_count(), 4);

        // Tainting PII populates the registry.
        guard.taint_input("My SSN is 123-45-6789");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        // The tainted value is redacted from output.
        let cleaned = guard.sanitize_output("The SSN 123-45-6789 was found");
        assert!(!cleaned.contains("123-45-6789"));

        // Wiping empties the registry.
        guard.wipe();
        assert_eq!(guard.taint_registry().read().unwrap().entry_count(), 0);

        // Teardown removes every hook.
        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_taint_input_registers_pii() {
        let (engine, guard) = setup("s1", SecurityConfig::default());

        guard.taint_input("Contact me at user@example.com or call 555-123-4567");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        // Registration must be audited.
        let log = guard.audit_entries();
        assert!(!log.is_empty());
        assert!(log
            .iter()
            .any(|entry| matches!(entry.event_type, AuditEventType::TaintRegistered)));

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_sanitize_output() {
        let (engine, guard) = setup("s1", SecurityConfig::default());

        guard.taint_input("My SSN is 123-45-6789");

        let cleaned = guard.sanitize_output("Found SSN: 123-45-6789 in the data");
        assert!(!cleaned.contains("123-45-6789"));

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_disabled_features() {
        let mut config = SecurityConfig::default();
        {
            let features = &mut config.features;
            features.tool_interceptor = false;
            features.output_sanitizer = false;
            features.injection_defense = false;
            features.taint_tracking = false;
        }
        let (engine, guard) = setup("s1", config);

        // Nothing enabled, so nothing registered.
        assert_eq!(engine.hook_count(), 0);

        // With every feature off, nothing ever touches the taint registry.
        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry.get().is_none());

        // Sanitizer disabled: output passes through verbatim.
        assert_eq!(guard.sanitize_output("SSN: 123-45-6789"), "SSN: 123-45-6789");

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_wipe_and_teardown() {
        let (engine, guard) = setup("s1", SecurityConfig::default());

        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        guard.wipe();
        assert_eq!(guard.taint_registry().read().unwrap().entry_count(), 0);
        assert!(guard.audit_entries().is_empty());

        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_lazy_init() {
        let guard = SecurityGuard::new("lazy-test".to_string(), SecurityConfig::default());

        // Freshly built guard: no subsystem exists yet.
        assert!(guard.taint_registry.get().is_none());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());

        // Touching one subsystem initializes only that one.
        let _ = guard.taint_registry();
        assert!(guard.taint_registry.get().is_some());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());
    }

    #[test]
    fn test_register_hooks_idempotent() {
        let (engine, guard) = setup("s1", SecurityConfig::default());
        let registered = engine.hook_count();

        // A second registration must not add anything.
        guard.register_hooks(&engine);
        assert_eq!(engine.hook_count(), registered);
    }
}