1use super::config::SensitivityLevel;
7use base64::Engine;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct TaintId(pub Uuid);
16
17impl TaintId {
18 pub fn new() -> Self {
20 Self(Uuid::new_v4())
21 }
22}
23
24impl Default for TaintId {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl std::fmt::Display for TaintId {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 write!(f, "{}", self.0)
33 }
34}
35
/// One registered tainted value together with the metadata captured when it was registered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaintEntry {
    /// Identifier assigned to this entry at registration.
    pub id: TaintId,
    /// The exact value passed to `TaintRegistry::register`.
    pub original_value: String,
    /// Name of the rule that flagged the value, as supplied by the caller.
    pub rule_name: String,
    /// Sensitivity level supplied by the caller at registration.
    pub level: SensitivityLevel,
    /// Encoded forms of `original_value` produced by `generate_variants`
    /// (base64, lowercase hex, and the percent-encoded form when it differs).
    pub variants: Vec<String>,
    /// UTC timestamp taken when the entry was built.
    pub created_at: DateTime<Utc>,
}
52
53pub struct TaintRegistry {
55 entries: HashMap<TaintId, TaintEntry>,
57 value_index: HashMap<String, TaintId>,
59 all_values: HashSet<String>,
61}
62
63impl TaintRegistry {
64 pub fn new() -> Self {
66 Self {
67 entries: HashMap::new(),
68 value_index: HashMap::new(),
69 all_values: HashSet::new(),
70 }
71 }
72
73 pub fn register(&mut self, value: &str, rule_name: &str, level: SensitivityLevel) -> TaintId {
75 if let Some(&id) = self.value_index.get(value) {
77 return id;
78 }
79
80 let id = TaintId::new();
81 let variants = generate_variants(value);
82
83 let entry = TaintEntry {
84 id,
85 original_value: value.to_string(),
86 rule_name: rule_name.to_string(),
87 level,
88 variants: variants.clone(),
89 created_at: Utc::now(),
90 };
91
92 self.value_index.insert(value.to_string(), id);
94 self.all_values.insert(value.to_string());
95
96 for variant in &variants {
98 self.value_index.insert(variant.clone(), id);
99 self.all_values.insert(variant.clone());
100 }
101
102 self.entries.insert(id, entry);
103 id
104 }
105
106 pub fn contains(&self, text: &str) -> bool {
108 for value in &self.all_values {
109 if text.contains(value.as_str()) {
110 return true;
111 }
112 }
113 false
114 }
115
116 pub fn check_encoded(&self, text: &str) -> bool {
118 for word in text.split_whitespace() {
120 if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(word) {
122 if let Ok(decoded_str) = String::from_utf8(decoded) {
123 if self.contains_original(&decoded_str) {
124 return true;
125 }
126 }
127 }
128
129 if word.contains('%') {
131 if let Ok(decoded) = urldecode(word) {
132 if self.contains_original(&decoded) {
133 return true;
134 }
135 }
136 }
137
138 if word.len() >= 4 && word.len() % 2 == 0 && word.chars().all(|c| c.is_ascii_hexdigit())
140 {
141 if let Some(decoded) = hex_decode(word) {
142 if self.contains_original(&decoded) {
143 return true;
144 }
145 }
146 }
147 }
148 false
149 }
150
151 fn contains_original(&self, text: &str) -> bool {
153 for entry in self.entries.values() {
154 if text.contains(&entry.original_value) {
155 return true;
156 }
157 }
158 false
159 }
160
161 pub fn wipe(&mut self) {
163 self.entries.clear();
164 self.value_index.clear();
165 self.all_values.clear();
166 }
167
168 pub fn entry_count(&self) -> usize {
170 self.entries.len()
171 }
172
173 pub fn get(&self, id: &TaintId) -> Option<&TaintEntry> {
175 self.entries.get(id)
176 }
177
178 pub fn lookup(&self, value: &str) -> Option<TaintId> {
180 self.value_index.get(value).copied()
181 }
182
183 pub fn entries_iter(&self) -> impl Iterator<Item = (&TaintId, &TaintEntry)> {
185 self.entries.iter()
186 }
187}
188
189impl Default for TaintRegistry {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195fn generate_variants(value: &str) -> Vec<String> {
197 let mut variants = Vec::new();
198
199 let b64 = base64::engine::general_purpose::STANDARD.encode(value.as_bytes());
201 variants.push(b64);
202
203 let hex: String = value.bytes().map(|b| format!("{:02x}", b)).collect();
205 variants.push(hex);
206
207 let url_encoded = urlencode(value);
209 if url_encoded != value {
210 variants.push(url_encoded);
211 }
212
213 variants
214}
215
/// Percent-encodes `s`.
///
/// Bytes in the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) are copied verbatim. Every
/// other UTF-8 byte becomes `%XX` with uppercase hex digits (e.g. `' '` -> `%20`).
fn urlencode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // Push digits from a table rather than allocating a `String` per escaped byte.
    let mut result = String::with_capacity(s.len());
    for b in s.bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            result.push(char::from(b));
        } else {
            result.push('%');
            result.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            result.push(char::from(HEX_DIGITS[usize::from(b & 0x0f)]));
        }
    }
    result
}
231
/// Decodes `%XX` percent-escapes in `s` (hex digits in either case).
///
/// A `%` that is not followed by two ASCII hex digits is passed through unchanged, so
/// malformed escapes are never an error by themselves. The function never panics, whatever
/// the input.
///
/// # Errors
///
/// Returns `Err(())` if the decoded bytes are not valid UTF-8.
fn urldecode(s: &str) -> Result<String, ()> {
    /// Value of a single ASCII hex digit, or `None` for any other byte (including `+`).
    fn hex_val(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut result = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        // Inspect raw bytes: slicing the `&str` here could split a multi-byte character
        // that follows a `%` and panic.
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
                result.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        result.push(bytes[i]);
        i += 1;
    }
    String::from_utf8(result).map_err(|_| ())
}
251
/// Decodes a string of hex digit pairs (either case) into UTF-8 text.
///
/// Returns `None` in any of these cases:
/// - the input has odd length;
/// - it contains a character that is not an ASCII hex digit;
/// - the decoded bytes are not valid UTF-8.
///
/// An empty input decodes to an empty string. The function never panics.
fn hex_decode(s: &str) -> Option<String> {
    /// Value of a single ASCII hex digit, or `None` for any other byte (including `+`).
    fn nibble(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    // Work on bytes rather than `&str` slices: odd lengths or multi-byte characters
    // would otherwise make the slicing panic.
    if bytes.len() % 2 != 0 {
        return None;
    }
    let decoded = bytes
        .chunks_exact(2)
        .map(|pair| Some((nibble(pair[0])? << 4) | nibble(pair[1])?))
        .collect::<Option<Vec<u8>>>()?;
    String::from_utf8(decoded).ok()
}
260
// Unit tests for the taint registry and its private encode/decode helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // A registered value is found as a substring of larger text and counted once.
    #[test]
    fn test_register_and_contains() {
        let mut registry = TaintRegistry::new();
        registry.register(
            "4111-1111-1111-1111",
            "credit_card",
            SensitivityLevel::HighlySensitive,
        );

        assert!(registry.contains("my card is 4111-1111-1111-1111 ok"));
        assert!(!registry.contains("no sensitive data here"));
        assert_eq!(registry.entry_count(), 1);
    }

    // The base64 variant generated at registration is matched by `contains`.
    #[test]
    fn test_variant_generation_base64() {
        let mut registry = TaintRegistry::new();
        registry.register("secret123", "api_key", SensitivityLevel::HighlySensitive);

        let b64 = base64::engine::general_purpose::STANDARD.encode("secret123");
        assert!(registry.contains(&b64));
    }

    // The lowercase hex variant ("616263" == hex of "abc") is matched by `contains`.
    #[test]
    fn test_variant_generation_hex() {
        let mut registry = TaintRegistry::new();
        registry.register("abc", "test", SensitivityLevel::Sensitive);

        assert!(registry.contains("616263"));
    }

    // A base64 word inside free text decodes back to a registered original.
    #[test]
    fn test_check_encoded_base64() {
        let mut registry = TaintRegistry::new();
        registry.register("sensitive-data", "test", SensitivityLevel::Sensitive);

        let b64 = base64::engine::general_purpose::STANDARD.encode("sensitive-data");
        assert!(registry.check_encoded(&format!("here is {}", b64)));
    }

    // A hex word ("616263" -> "abc") inside free text decodes back to a registered original.
    #[test]
    fn test_check_encoded_hex() {
        let mut registry = TaintRegistry::new();
        registry.register("abc", "test", SensitivityLevel::Sensitive);

        assert!(registry.check_encoded("decoded hex: 616263"));
    }

    // Re-registering the same original returns the first id and adds no entry.
    #[test]
    fn test_duplicate_register_returns_same_id() {
        let mut registry = TaintRegistry::new();
        let id1 = registry.register("value1", "rule1", SensitivityLevel::Sensitive);
        let id2 = registry.register("value1", "rule1", SensitivityLevel::Sensitive);
        assert_eq!(id1, id2);
        assert_eq!(registry.entry_count(), 1);
    }

    // After `wipe`, nothing is counted or matched.
    #[test]
    fn test_wipe_clears_all() {
        let mut registry = TaintRegistry::new();
        registry.register("value1", "rule1", SensitivityLevel::Sensitive);
        registry.register("value2", "rule2", SensitivityLevel::HighlySensitive);
        assert_eq!(registry.entry_count(), 2);

        registry.wipe();
        assert_eq!(registry.entry_count(), 0);
        assert!(!registry.contains("value1"));
        assert!(!registry.contains("value2"));
    }

    // Exact-string lookup returns the registered id, and `None` for unknown strings.
    #[test]
    fn test_lookup() {
        let mut registry = TaintRegistry::new();
        let id = registry.register("test-value", "rule1", SensitivityLevel::Sensitive);

        assert_eq!(registry.lookup("test-value"), Some(id));
        assert!(registry.lookup("nonexistent").is_none());
    }

    // The stored entry keeps the registration metadata and at least one variant.
    #[test]
    fn test_get_entry() {
        let mut registry = TaintRegistry::new();
        let id = registry.register("test-value", "rule1", SensitivityLevel::Sensitive);

        let entry = registry.get(&id).unwrap();
        assert_eq!(entry.original_value, "test-value");
        assert_eq!(entry.rule_name, "rule1");
        assert_eq!(entry.level, SensitivityLevel::Sensitive);
        assert!(!entry.variants.is_empty());
    }

    // A value containing a space also gets a percent-encoded variant.
    #[test]
    fn test_url_encoding_variant() {
        let mut registry = TaintRegistry::new();
        registry.register("hello world", "test", SensitivityLevel::Sensitive);

        assert!(registry.contains("hello%20world"));
    }

    // Default ids are freshly generated, so two of them differ.
    #[test]
    fn test_taint_id_default() {
        let id1 = TaintId::default();
        let id2 = TaintId::default();
        assert_ne!(id1.0, id2.0);
    }

    // Display yields the 36-character hyphenated UUID form.
    #[test]
    fn test_taint_id_display() {
        let id = TaintId::new();
        let display = format!("{}", id);
        assert!(!display.is_empty());
        assert_eq!(display.len(), 36);
    }

    // A default registry starts empty.
    #[test]
    fn test_taint_registry_default() {
        let registry = TaintRegistry::default();
        assert!(registry.lookup("anything").is_none());
    }

    #[test]
    fn test_urldecode_valid() {
        let decoded = urldecode("hello%20world");
        assert!(decoded.is_ok());
        assert_eq!(decoded.unwrap(), "hello world");
    }

    #[test]
    fn test_urldecode_no_encoding() {
        let decoded = urldecode("hello");
        assert!(decoded.is_ok());
        assert_eq!(decoded.unwrap(), "hello");
    }

    // A `%` not followed by two hex digits is kept literally.
    #[test]
    fn test_urldecode_invalid_hex_passthrough() {
        let decoded = urldecode("hello%ZZworld");
        assert!(decoded.is_ok());
        assert_eq!(decoded.unwrap(), "hello%ZZworld");
    }

    // Reserved characters are escaped with uppercase hex digits.
    #[test]
    fn test_urlencode_special_chars() {
        let encoded = urlencode("a b@c");
        assert_eq!(encoded, "a%20b%40c");
    }

    #[test]
    fn test_urlencode_no_special() {
        let encoded = urlencode("hello");
        assert_eq!(encoded, "hello");
    }

    // "dGVzdA==" is base64 and "74657374" is hex of "test"; no URL variant since it is unchanged.
    #[test]
    fn test_generate_variants() {
        let variants = generate_variants("test");
        assert!(variants.len() >= 2);
        assert!(variants.contains(&"dGVzdA==".to_string()));
        assert!(variants.contains(&"74657374".to_string()));
    }

    // A value needing escaping gets the extra percent-encoded variant.
    #[test]
    fn test_generate_variants_with_special_chars() {
        let variants = generate_variants("hello world");
        assert!(variants.len() >= 3);
        assert!(variants.contains(&"hello%20world".to_string()));
    }

    #[test]
    fn test_entries_iter() {
        let mut registry = TaintRegistry::new();
        registry.register("val1", "rule1", SensitivityLevel::Sensitive);
        registry.register("val2", "rule2", SensitivityLevel::HighlySensitive);

        let entries: Vec<_> = registry.entries_iter().collect();
        assert_eq!(entries.len(), 2);
    }
}