Skip to main content

a3s_code_core/security/
mod.rs

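//! Security Module
//!
//! Provides security features for A3S Code sessions:
//! - **Output Sanitizer**: Redacts sensitive data from LLM responses
//! - **Taint Tracking**: Tracks sensitive values and their encoded variants
//! - **Privacy Classifier**: Detects sensitive data (e.g. SSNs, emails, phone
//!   numbers) using the configured classification rules
//! - **Tool Interceptor**: Blocks dangerous tool invocations
//! - **Session Isolation**: Per-session security state with secure wipe
//! - **Prompt Injection Defense**: Detects and blocks injection attempts,
//!   including indirect injection via tool outputs
//! - **Audit Log**: Records security events such as taint registration and
//!   session wipes
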
10pub mod audit;
11pub mod classifier;
12pub mod config;
13pub mod injection;
14pub mod interceptor;
15pub mod sanitizer;
16pub mod taint;
17
18pub use audit::{AuditAction, AuditEntry, AuditEventType, AuditLog};
19pub use classifier::PrivacyClassifier;
20pub use config::{RedactionStrategy, SecurityConfig, SensitivityLevel};
21pub use injection::InjectionDetector;
22pub use injection::ToolOutputInjectionScanner;
23pub use interceptor::ToolInterceptor;
24pub use sanitizer::OutputSanitizer;
25pub use taint::{TaintId, TaintRegistry};
26
27use crate::hooks::HookEventType;
28use crate::hooks::HookHandler;
29use crate::hooks::{Hook, HookConfig, HookEngine};
30use sanitizer::make_replacement;
31use std::sync::{Arc, OnceLock, RwLock};
32
33/// Hook ID prefix for security hooks
34const HOOK_PREFIX: &str = "security";
35
36/// Per-session security orchestrator
37///
38/// Subsystems (`TaintRegistry`, `PrivacyClassifier`, `AuditLog`) are lazily
39/// initialized on first access via `OnceLock`. This avoids allocating a
40/// `HashMap` (taint), compiling regex (classifier), and creating a 10,000-
41/// capacity `Vec` (audit) when the corresponding feature is disabled.
42pub struct SecurityGuard {
43    session_id: String,
44    config: SecurityConfig,
45    taint_registry: OnceLock<Arc<RwLock<TaintRegistry>>>,
46    classifier: OnceLock<Arc<PrivacyClassifier>>,
47    audit_log: OnceLock<Arc<AuditLog>>,
48    /// Hook IDs registered by this guard (for teardown)
49    hook_ids: std::sync::Mutex<Vec<String>>,
50}
51
52impl SecurityGuard {
53    /// Create a new SecurityGuard without registering hooks.
54    ///
55    /// Call [`register_hooks`] separately with a real `HookEngine` to
56    /// register security hooks. This avoids the previous bug where hooks
57    /// were registered to a temporary engine that was immediately dropped.
58    pub fn new(session_id: String, config: SecurityConfig) -> Self {
59        Self {
60            session_id,
61            config,
62            taint_registry: OnceLock::new(),
63            classifier: OnceLock::new(),
64            audit_log: OnceLock::new(),
65            hook_ids: std::sync::Mutex::new(Vec::new()),
66        }
67    }
68
69    /// Register security hooks with the given engine.
70    ///
71    /// Must be called with a long-lived `HookEngine` — not a temporary.
72    /// Safe to call multiple times (idempotent: skips if hooks already
73    /// registered).
74    pub fn register_hooks(&self, hook_engine: &HookEngine) {
75        let mut hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
76        if !hook_ids.is_empty() {
77            // Already registered
78            return;
79        }
80
81        // Register tool interceptor hook
82        if self.config.features.tool_interceptor {
83            let hook_id = format!("{}-interceptor-{}", HOOK_PREFIX, &self.session_id);
84            let interceptor = ToolInterceptor::new(
85                &self.config,
86                self.taint_registry().clone(),
87                self.audit_log().clone(),
88                self.session_id.clone(),
89            );
90            hook_engine.register(Hook::new(&hook_id, HookEventType::PreToolUse).with_config(
91                HookConfig {
92                    priority: 1,
93                    ..Default::default()
94                },
95            ));
96            hook_engine.register_handler(&hook_id, Arc::new(interceptor) as Arc<dyn HookHandler>);
97            hook_ids.push(hook_id);
98        }
99
100        // Register output sanitizer hook
101        if self.config.features.output_sanitizer {
102            let hook_id = format!("{}-sanitizer-{}", HOOK_PREFIX, &self.session_id);
103            let sanitizer = OutputSanitizer::new(
104                self.taint_registry().clone(),
105                self.classifier().clone(),
106                self.config.redaction_strategy,
107                self.audit_log().clone(),
108                self.session_id.clone(),
109            );
110            hook_engine.register(Hook::new(&hook_id, HookEventType::GenerateEnd).with_config(
111                HookConfig {
112                    priority: 1,
113                    ..Default::default()
114                },
115            ));
116            hook_engine.register_handler(&hook_id, Arc::new(sanitizer) as Arc<dyn HookHandler>);
117            hook_ids.push(hook_id);
118        }
119
120        // Register injection detector hook
121        if self.config.features.injection_defense {
122            let hook_id = format!("{}-injection-{}", HOOK_PREFIX, &self.session_id);
123            let detector =
124                InjectionDetector::new(self.audit_log().clone(), self.session_id.clone());
125            hook_engine.register(
126                Hook::new(&hook_id, HookEventType::GenerateStart).with_config(HookConfig {
127                    priority: 1,
128                    ..Default::default()
129                }),
130            );
131            hook_engine.register_handler(&hook_id, Arc::new(detector) as Arc<dyn HookHandler>);
132            hook_ids.push(hook_id);
133
134            // Also register PostToolUse scanner for indirect injection via tool outputs
135            let scanner_id = format!("{}-injection-output-{}", HOOK_PREFIX, &self.session_id);
136            let scanner =
137                ToolOutputInjectionScanner::new(self.audit_log().clone(), self.session_id.clone());
138            hook_engine.register(
139                Hook::new(&scanner_id, HookEventType::PostToolUse).with_config(HookConfig {
140                    priority: 1,
141                    ..Default::default()
142                }),
143            );
144            hook_engine.register_handler(&scanner_id, Arc::new(scanner) as Arc<dyn HookHandler>);
145            hook_ids.push(scanner_id);
146        }
147    }
148
149    /// Lazily initialize and return the taint registry
150    fn taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
151        self.taint_registry
152            .get_or_init(|| Arc::new(RwLock::new(TaintRegistry::new())))
153    }
154
155    /// Lazily initialize and return the privacy classifier
156    fn classifier(&self) -> &Arc<PrivacyClassifier> {
157        self.classifier
158            .get_or_init(|| Arc::new(PrivacyClassifier::new(&self.config.classification_rules)))
159    }
160
161    /// Lazily initialize and return the audit log
162    fn audit_log(&self) -> &Arc<AuditLog> {
163        self.audit_log
164            .get_or_init(|| Arc::new(AuditLog::new(10_000)))
165    }
166
167    /// Classify input text and register any detected sensitive data as tainted
168    pub fn taint_input(&self, text: &str) {
169        if !self.config.features.taint_tracking {
170            return;
171        }
172
173        let result = self.classifier().classify(text);
174        if !result.matches.is_empty() {
175            let Ok(mut registry) = self.taint_registry().write() else {
176                tracing::error!("Taint registry lock poisoned — skipping taint registration");
177                return;
178            };
179            for m in &result.matches {
180                let id = registry.register(&m.matched_text, &m.rule_name, m.level);
181                self.audit_log().log(AuditEntry {
182                    timestamp: chrono::Utc::now(),
183                    session_id: self.session_id.clone(),
184                    event_type: AuditEventType::TaintRegistered,
185                    severity: m.level,
186                    details: format!(
187                        "Registered tainted value from rule '{}' (id: {})",
188                        m.rule_name, id
189                    ),
190                    tool_name: None,
191                    action_taken: AuditAction::Logged,
192                });
193            }
194        }
195    }
196
197    /// Sanitize output text by redacting tainted and classified sensitive data
198    pub fn sanitize_output(&self, text: &str) -> String {
199        if !self.config.features.output_sanitizer {
200            return text.to_string();
201        }
202
203        let mut result = text.to_string();
204
205        // Check taint registry
206        {
207            let Ok(registry) = self.taint_registry().read() else {
208                tracing::error!("Taint registry lock poisoned — returning unsanitized output");
209                return result;
210            };
211            for (_, entry) in registry.entries_iter() {
212                if result.contains(&entry.original_value) {
213                    let replacement =
214                        make_replacement(&entry.original_value, self.config.redaction_strategy);
215                    result = result.replace(&entry.original_value, &replacement);
216                }
217                for variant in &entry.variants {
218                    if result.contains(variant.as_str()) {
219                        result = result.replace(variant.as_str(), "[REDACTED]");
220                    }
221                }
222            }
223        }
224
225        // Run classifier
226        result = self
227            .classifier()
228            .redact(&result, self.config.redaction_strategy);
229
230        result
231    }
232
233    /// Securely wipe all session security state
234    pub fn wipe(&self) {
235        if let Some(registry) = self.taint_registry.get() {
236            if let Ok(mut r) = registry.write() {
237                r.wipe();
238            } else {
239                tracing::error!("Taint registry lock poisoned — cannot wipe");
240            }
241        }
242        if let Some(log) = self.audit_log.get() {
243            log.log(AuditEntry {
244                timestamp: chrono::Utc::now(),
245                session_id: self.session_id.clone(),
246                event_type: AuditEventType::SessionWiped,
247                severity: SensitivityLevel::Normal,
248                details: "Session security state wiped".to_string(),
249                tool_name: None,
250                action_taken: AuditAction::Logged,
251            });
252            log.clear();
253        }
254    }
255
256    /// Unregister all hooks from the engine
257    pub fn teardown(&self, hook_engine: &HookEngine) {
258        let hook_ids = self.hook_ids.lock().unwrap_or_else(|p| p.into_inner());
259        for hook_id in hook_ids.iter() {
260            hook_engine.unregister_handler(hook_id);
261            hook_engine.unregister(hook_id);
262        }
263    }
264
265    /// Get audit log entries
266    pub fn audit_entries(&self) -> Vec<AuditEntry> {
267        self.audit_log().entries()
268    }
269
270    /// Get the taint registry (read-only access)
271    pub fn get_taint_registry(&self) -> &Arc<RwLock<TaintRegistry>> {
272        self.taint_registry()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a default-config guard for `session` whose hooks are already
    /// registered on `engine`.
    fn guard_on(engine: &HookEngine, session: &str) -> SecurityGuard {
        let guard = SecurityGuard::new(session.to_string(), SecurityConfig::default());
        guard.register_hooks(engine);
        guard
    }

    #[test]
    fn test_guard_lifecycle() {
        let engine = HookEngine::new();
        let guard = guard_on(&engine, "test-session");

        // Interceptor, sanitizer, injection detector and tool-output scanner.
        assert_eq!(engine.hook_count(), 4);

        // PII in the input becomes tainted...
        guard.taint_input("My SSN is 123-45-6789");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        // ...and is scrubbed from later output.
        let sanitized = guard.sanitize_output("The SSN 123-45-6789 was found");
        assert!(!sanitized.contains("123-45-6789"));

        guard.wipe();
        assert_eq!(guard.taint_registry().read().unwrap().entry_count(), 0);

        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_taint_input_registers_pii() {
        let engine = HookEngine::new();
        let guard = guard_on(&engine, "s1");

        guard.taint_input("Contact me at user@example.com or call 555-123-4567");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        // Every registration is mirrored in the audit log.
        let entries = guard.audit_entries();
        assert!(!entries.is_empty());
        let has_taint_event = entries
            .iter()
            .any(|entry| entry.event_type == AuditEventType::TaintRegistered);
        assert!(has_taint_event);

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_sanitize_output() {
        let engine = HookEngine::new();
        let guard = guard_on(&engine, "s1");

        guard.taint_input("My SSN is 123-45-6789");
        let output = guard.sanitize_output("Found SSN: 123-45-6789 in the data");
        assert!(!output.contains("123-45-6789"));

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_disabled_features() {
        let engine = HookEngine::new();
        let mut config = SecurityConfig::default();
        let features = &mut config.features;
        features.tool_interceptor = false;
        features.output_sanitizer = false;
        features.injection_defense = false;
        features.taint_tracking = false;

        let guard = SecurityGuard::new("s1".to_string(), config);
        guard.register_hooks(&engine);

        // Nothing enabled, so nothing registered.
        assert_eq!(engine.hook_count(), 0);

        // With taint tracking off, the registry is never even created.
        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry.get().is_none());

        // With the sanitizer off, output passes through untouched.
        assert_eq!(guard.sanitize_output("SSN: 123-45-6789"), "SSN: 123-45-6789");

        guard.teardown(&engine);
    }

    #[test]
    fn test_guard_wipe_and_teardown() {
        let engine = HookEngine::new();
        let guard = guard_on(&engine, "s1");

        guard.taint_input("SSN: 123-45-6789");
        assert!(guard.taint_registry().read().unwrap().entry_count() > 0);

        guard.wipe();
        assert_eq!(guard.taint_registry().read().unwrap().entry_count(), 0);
        assert!(guard.audit_entries().is_empty());

        guard.teardown(&engine);
        assert_eq!(engine.hook_count(), 0);
    }

    #[test]
    fn test_guard_lazy_init() {
        let guard = SecurityGuard::new("lazy-test".to_string(), SecurityConfig::default());

        // Nothing is allocated up front.
        assert!(guard.taint_registry.get().is_none());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());

        // Touching one subsystem initializes only that subsystem.
        let _ = guard.taint_registry();
        assert!(guard.taint_registry.get().is_some());
        assert!(guard.classifier.get().is_none());
        assert!(guard.audit_log.get().is_none());
    }

    #[test]
    fn test_register_hooks_idempotent() {
        let engine = HookEngine::new();
        let guard = guard_on(&engine, "s1");
        let registered = engine.hook_count();

        // A second registration must not add duplicates.
        guard.register_hooks(&engine);
        assert_eq!(engine.hook_count(), registered);
    }
}
426        // Second call should be a no-op
427        guard.register_hooks(&engine);
428        assert_eq!(engine.hook_count(), count);
429    }
430}