Skip to main content

a3s_code_core/security/
taint.rs

//! Security Taint Tracking
//!
//! Tracks sensitive data values and their encoded variants (base64, hex, URL-encoded)
//! so they can be detected in tool arguments and LLM output.
5
6use super::config::SensitivityLevel;
7use base64::Engine;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use uuid::Uuid;
12
13/// Unique identifier for a taint entry
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct TaintId(pub Uuid);
16
17impl TaintId {
18    /// Generate a new random taint ID
19    pub fn new() -> Self {
20        Self(Uuid::new_v4())
21    }
22}
23
24impl Default for TaintId {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl std::fmt::Display for TaintId {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        write!(f, "{}", self.0)
33    }
34}
35
36/// A tracked piece of sensitive data with its encoded variants
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct TaintEntry {
39    /// Unique identifier
40    pub id: TaintId,
41    /// The original sensitive value
42    pub original_value: String,
43    /// Name of the classification rule that matched
44    pub rule_name: String,
45    /// Sensitivity level
46    pub level: SensitivityLevel,
47    /// Encoded variants (base64, hex, url-encoded)
48    pub variants: Vec<String>,
49    /// When this entry was created
50    pub created_at: DateTime<Utc>,
51}
52
53/// Registry of tainted (sensitive) data values
54pub struct TaintRegistry {
55    /// Entries indexed by TaintId
56    entries: HashMap<TaintId, TaintEntry>,
57    /// Reverse index: value/variant -> TaintId for fast lookup
58    value_index: HashMap<String, TaintId>,
59    /// Set of all known values for quick contains check
60    all_values: HashSet<String>,
61}
62
63impl TaintRegistry {
64    /// Create a new empty registry
65    pub fn new() -> Self {
66        Self {
67            entries: HashMap::new(),
68            value_index: HashMap::new(),
69            all_values: HashSet::new(),
70        }
71    }
72
73    /// Register a sensitive value and auto-generate encoded variants
74    pub fn register(&mut self, value: &str, rule_name: &str, level: SensitivityLevel) -> TaintId {
75        // Check if already registered
76        if let Some(&id) = self.value_index.get(value) {
77            return id;
78        }
79
80        let id = TaintId::new();
81        let variants = generate_variants(value);
82
83        let entry = TaintEntry {
84            id,
85            original_value: value.to_string(),
86            rule_name: rule_name.to_string(),
87            level,
88            variants: variants.clone(),
89            created_at: Utc::now(),
90        };
91
92        // Index the original value
93        self.value_index.insert(value.to_string(), id);
94        self.all_values.insert(value.to_string());
95
96        // Index all variants
97        for variant in &variants {
98            self.value_index.insert(variant.clone(), id);
99            self.all_values.insert(variant.clone());
100        }
101
102        self.entries.insert(id, entry);
103        id
104    }
105
106    /// Check if a text contains any tainted value (exact match against all variants)
107    pub fn contains(&self, text: &str) -> bool {
108        for value in &self.all_values {
109            if text.contains(value.as_str()) {
110                return true;
111            }
112        }
113        false
114    }
115
116    /// Check for encoded variants in text by decoding base64/hex/url segments
117    pub fn check_encoded(&self, text: &str) -> bool {
118        // Try to decode base64 segments
119        for word in text.split_whitespace() {
120            // Try base64 decode
121            if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(word) {
122                if let Ok(decoded_str) = String::from_utf8(decoded) {
123                    if self.contains_original(&decoded_str) {
124                        return true;
125                    }
126                }
127            }
128
129            // Try URL decode
130            if word.contains('%') {
131                if let Ok(decoded) = urldecode(word) {
132                    if self.contains_original(&decoded) {
133                        return true;
134                    }
135                }
136            }
137
138            // Try hex decode
139            if word.len() >= 4 && word.len() % 2 == 0 && word.chars().all(|c| c.is_ascii_hexdigit())
140            {
141                if let Some(decoded) = hex_decode(word) {
142                    if self.contains_original(&decoded) {
143                        return true;
144                    }
145                }
146            }
147        }
148        false
149    }
150
151    /// Check if text contains any original (non-variant) tainted value
152    fn contains_original(&self, text: &str) -> bool {
153        for entry in self.entries.values() {
154            if text.contains(&entry.original_value) {
155                return true;
156            }
157        }
158        false
159    }
160
161    /// Securely wipe all taint data
162    pub fn wipe(&mut self) {
163        self.entries.clear();
164        self.value_index.clear();
165        self.all_values.clear();
166    }
167
168    /// Get the number of tracked entries
169    pub fn entry_count(&self) -> usize {
170        self.entries.len()
171    }
172
173    /// Get a taint entry by ID
174    pub fn get(&self, id: &TaintId) -> Option<&TaintEntry> {
175        self.entries.get(id)
176    }
177
178    /// Find the taint ID for a given value or variant
179    pub fn lookup(&self, value: &str) -> Option<TaintId> {
180        self.value_index.get(value).copied()
181    }
182
183    /// Iterate over all entries
184    pub fn entries_iter(&self) -> impl Iterator<Item = (&TaintId, &TaintEntry)> {
185        self.entries.iter()
186    }
187}
188
189impl Default for TaintRegistry {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
195/// Generate encoded variants of a sensitive value
196fn generate_variants(value: &str) -> Vec<String> {
197    let mut variants = Vec::new();
198
199    // Base64 encoded
200    let b64 = base64::engine::general_purpose::STANDARD.encode(value.as_bytes());
201    variants.push(b64);
202
203    // Hex encoded
204    let hex: String = value.bytes().map(|b| format!("{:02x}", b)).collect();
205    variants.push(hex);
206
207    // URL encoded
208    let url_encoded = urlencode(value);
209    if url_encoded != value {
210        variants.push(url_encoded);
211    }
212
213    variants
214}
215
/// Percent-encode every byte of `s` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`), using uppercase hex digits.
fn urlencode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::new();
    for &byte in s.as_bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
231
/// Decode `%XX` percent-escapes (hex digits in either case) in `s`.
///
/// A `%` that is not followed by two hex digits is copied through literally.
/// Decoding works on raw bytes, so arbitrary (including non-ASCII) input never
/// panics.
///
/// # Errors
///
/// Returns `Err(())` when the decoded bytes are not valid UTF-8.
fn urldecode(s: &str) -> Result<String, ()> {
    // Value of one ASCII hex digit, or `None` for anything else. Unlike
    // `u8::from_str_radix`, this rejects a sign character such as `+`.
    fn hex_val(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            // Index bytes (never `&str` slices) so a multibyte char after `%`
            // cannot trigger a char-boundary panic.
            let hi = bytes.get(i + 1).copied().and_then(hex_val);
            let lo = bytes.get(i + 2).copied().and_then(hex_val);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8(decoded).map_err(|_| ())
}
251
/// Decode a string of hex digit pairs (either case) into UTF-8 text.
///
/// Returns `None` when `s` has odd length, contains a non-hex character, or
/// decodes to bytes that are not valid UTF-8. Never panics, whatever `s` holds.
fn hex_decode(s: &str) -> Option<String> {
    let digits = s.as_bytes();
    if digits.len() % 2 != 0 {
        return None;
    }
    let decoded = digits
        .chunks_exact(2)
        .map(|pair| {
            // `to_digit(16)` accepts only 0-9, a-f, A-F, so a sign such as `+`
            // (which `u8::from_str_radix` would accept) is rejected.
            let hi = char::from(pair[0]).to_digit(16)?;
            let lo = char::from(pair[1]).to_digit(16)?;
            u8::try_from((hi << 4) | lo).ok()
        })
        .collect::<Option<Vec<u8>>>()?;
    String::from_utf8(decoded).ok()
}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh registry holding one registered value, plus the id that
    /// the registration produced.
    fn registry_with(
        value: &str,
        rule: &str,
        level: SensitivityLevel,
    ) -> (TaintRegistry, TaintId) {
        let mut registry = TaintRegistry::new();
        let id = registry.register(value, rule, level);
        (registry, id)
    }

    /// Standard (padded) base64 of `raw`.
    fn b64(raw: &str) -> String {
        base64::engine::general_purpose::STANDARD.encode(raw)
    }

    #[test]
    fn test_register_and_contains() {
        let (registry, _) = registry_with(
            "4111-1111-1111-1111",
            "credit_card",
            SensitivityLevel::HighlySensitive,
        );

        assert!(registry.contains("my card is 4111-1111-1111-1111 ok"));
        assert!(!registry.contains("no sensitive data here"));
        assert_eq!(registry.entry_count(), 1);
    }

    #[test]
    fn test_variant_generation_base64() {
        let (registry, _) =
            registry_with("secret123", "api_key", SensitivityLevel::HighlySensitive);
        assert!(registry.contains(&b64("secret123")));
    }

    #[test]
    fn test_variant_generation_hex() {
        let (registry, _) = registry_with("abc", "test", SensitivityLevel::Sensitive);
        // hex("abc") == "616263"
        assert!(registry.contains("616263"));
    }

    #[test]
    fn test_check_encoded_base64() {
        let (registry, _) = registry_with("sensitive-data", "test", SensitivityLevel::Sensitive);
        let text = format!("here is {}", b64("sensitive-data"));
        assert!(registry.check_encoded(&text));
    }

    #[test]
    fn test_check_encoded_hex() {
        let (registry, _) = registry_with("abc", "test", SensitivityLevel::Sensitive);
        // hex("abc") appears as its own whitespace-separated word
        assert!(registry.check_encoded("decoded hex: 616263"));
    }

    #[test]
    fn test_duplicate_register_returns_same_id() {
        let (mut registry, first) = registry_with("value1", "rule1", SensitivityLevel::Sensitive);
        let second = registry.register("value1", "rule1", SensitivityLevel::Sensitive);
        assert_eq!(first, second);
        assert_eq!(registry.entry_count(), 1);
    }

    #[test]
    fn test_wipe_clears_all() {
        let (mut registry, _) = registry_with("value1", "rule1", SensitivityLevel::Sensitive);
        registry.register("value2", "rule2", SensitivityLevel::HighlySensitive);
        assert_eq!(registry.entry_count(), 2);

        registry.wipe();

        assert_eq!(registry.entry_count(), 0);
        for gone in ["value1", "value2"] {
            assert!(!registry.contains(gone));
        }
    }

    #[test]
    fn test_lookup() {
        let (registry, id) = registry_with("test-value", "rule1", SensitivityLevel::Sensitive);
        assert_eq!(registry.lookup("test-value"), Some(id));
        assert_eq!(registry.lookup("nonexistent"), None);
    }

    #[test]
    fn test_get_entry() {
        let (registry, id) = registry_with("test-value", "rule1", SensitivityLevel::Sensitive);

        let entry = registry.get(&id).expect("a freshly registered id must resolve");
        assert_eq!(entry.original_value, "test-value");
        assert_eq!(entry.rule_name, "rule1");
        assert_eq!(entry.level, SensitivityLevel::Sensitive);
        assert!(!entry.variants.is_empty());
    }

    #[test]
    fn test_url_encoding_variant() {
        let (registry, _) = registry_with("hello world", "test", SensitivityLevel::Sensitive);
        // the space is percent-encoded as %20
        assert!(registry.contains("hello%20world"));
    }

    #[test]
    fn test_taint_id_default() {
        // Default never hands out the same id twice.
        assert_ne!(TaintId::default().0, TaintId::default().0);
    }

    #[test]
    fn test_taint_id_display() {
        let rendered = TaintId::new().to_string();
        assert!(!rendered.is_empty());
        // hyphenated UUID text is 36 characters long
        assert_eq!(rendered.len(), 36);
    }

    #[test]
    fn test_taint_registry_default() {
        assert!(TaintRegistry::default().lookup("anything").is_none());
    }

    #[test]
    fn test_urldecode_valid() {
        assert_eq!(urldecode("hello%20world"), Ok("hello world".to_string()));
    }

    #[test]
    fn test_urldecode_no_encoding() {
        assert_eq!(urldecode("hello"), Ok("hello".to_string()));
    }

    #[test]
    fn test_urldecode_invalid_hex_passthrough() {
        // a `%` not followed by two hex digits is copied through untouched
        assert_eq!(urldecode("hello%ZZworld"), Ok("hello%ZZworld".to_string()));
    }

    #[test]
    fn test_urlencode_special_chars() {
        assert_eq!(urlencode("a b@c"), "a%20b%40c");
    }

    #[test]
    fn test_urlencode_no_special() {
        assert_eq!(urlencode("hello"), "hello");
    }

    #[test]
    fn test_generate_variants() {
        let variants = generate_variants("test");
        // base64 and hex always; no URL form because "test" needs no escaping
        assert!(variants.len() >= 2);
        // base64("test") == "dGVzdA==", hex("test") == "74657374"
        for expected in ["dGVzdA==", "74657374"] {
            assert!(variants.iter().any(|v| v == expected));
        }
    }

    #[test]
    fn test_generate_variants_with_special_chars() {
        let variants = generate_variants("hello world");
        // base64, hex and the URL-encoded form
        assert!(variants.len() >= 3);
        assert!(variants.iter().any(|v| v == "hello%20world"));
    }

    #[test]
    fn test_entries_iter() {
        let (mut registry, _) = registry_with("val1", "rule1", SensitivityLevel::Sensitive);
        registry.register("val2", "rule2", SensitivityLevel::HighlySensitive);
        assert_eq!(registry.entries_iter().count(), 2);
    }
}