base_d/features/detection.rs

1use crate::core::config::{DictionaryRegistry, DictionaryType, EncodingMode};
2use crate::core::dictionary::Dictionary;
3use crate::decode;
4use std::collections::HashSet;
5
6/// A match result from dictionary detection.
7#[derive(Debug, Clone)]
8pub struct DictionaryMatch {
9    /// Name of the matched dictionary
10    pub name: String,
11    /// Confidence score (0.0 to 1.0)
12    pub confidence: f64,
13    /// The dictionary itself
14    pub dictionary: Dictionary,
15}
16
17/// Detector for automatically identifying which dictionary was used to encode data.
18pub struct DictionaryDetector {
19    dictionaries: Vec<(String, Dictionary)>,
20}
21
22impl DictionaryDetector {
23    /// Creates a new detector from a configuration.
24    ///
25    /// Note: Word-based dictionaries are skipped as they require different detection logic.
26    pub fn new(config: &DictionaryRegistry) -> Result<Self, Box<dyn std::error::Error>> {
27        let mut dictionaries = Vec::new();
28
29        for (name, dict_config) in &config.dictionaries {
30            // Skip word-based dictionaries - they use different encoding
31            if dict_config.dictionary_type == DictionaryType::Word {
32                continue;
33            }
34
35            let effective_mode = dict_config.effective_mode();
36            let dictionary = match effective_mode {
37                EncodingMode::ByteRange => {
38                    let start = dict_config
39                        .start_codepoint
40                        .ok_or("ByteRange mode requires start_codepoint")?;
41                    Dictionary::builder()
42                        .mode(effective_mode)
43                        .start_codepoint(start)
44                        .build()?
45                }
46                _ => {
47                    let chars: Vec<char> = dict_config.effective_chars()?.chars().collect();
48                    let padding = dict_config.padding.as_ref().and_then(|s| s.chars().next());
49                    let mut builder = Dictionary::builder().chars(chars).mode(effective_mode);
50                    if let Some(p) = padding {
51                        builder = builder.padding(p);
52                    }
53                    builder.build()?
54                }
55            };
56            dictionaries.push((name.clone(), dictionary));
57        }
58
59        Ok(DictionaryDetector { dictionaries })
60    }
61
62    /// Detect which dictionary was likely used to encode the input.
63    /// Returns matches sorted by confidence (highest first).
64    pub fn detect(&self, input: &str) -> Vec<DictionaryMatch> {
65        let input = input.trim();
66        if input.is_empty() {
67            return Vec::new();
68        }
69
70        let mut matches = Vec::new();
71
72        for (name, dict) in &self.dictionaries {
73            if let Some(confidence) = self.score_dictionary(input, dict) {
74                matches.push(DictionaryMatch {
75                    name: name.clone(),
76                    confidence,
77                    dictionary: dict.clone(),
78                });
79            }
80        }
81
82        // Sort by confidence descending
83        matches.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
84
85        matches
86    }
87
88    /// Score how likely a dictionary matches the input.
89    /// Returns Some(confidence) if it's a plausible match, None otherwise.
90    fn score_dictionary(&self, input: &str, dict: &Dictionary) -> Option<f64> {
91        let mut score = 0.0;
92        let mut weight_sum = 0.0;
93
94        // Weight for each scoring component
95        const CHARSET_WEIGHT: f64 = 0.25;
96        const SPECIFICITY_WEIGHT: f64 = 0.20; // Increased
97        const PADDING_WEIGHT: f64 = 0.30; // Increased (very important for RFC standards)
98        const LENGTH_WEIGHT: f64 = 0.15;
99        const DECODE_WEIGHT: f64 = 0.10;
100
101        // 1. Character set matching
102        let charset_score = self.score_charset(input, dict);
103        score += charset_score * CHARSET_WEIGHT;
104        weight_sum += CHARSET_WEIGHT;
105
106        // If character set score is too low, skip this dictionary
107        if charset_score < 0.5 {
108            return None;
109        }
110
111        // 1.5. Specificity - does this dictionary use a focused character set?
112        let specificity_score = self.score_specificity(input, dict);
113        score += specificity_score * SPECIFICITY_WEIGHT;
114        weight_sum += SPECIFICITY_WEIGHT;
115
116        // 2. Padding detection (for chunked modes)
117        if let Some(padding_score) = self.score_padding(input, dict) {
118            score += padding_score * PADDING_WEIGHT;
119            weight_sum += PADDING_WEIGHT;
120        }
121
122        // 3. Length validation
123        let length_score = self.score_length(input, dict);
124        score += length_score * LENGTH_WEIGHT;
125        weight_sum += LENGTH_WEIGHT;
126
127        // 4. Decode validation (try to actually decode)
128        if let Some(decode_score) = self.score_decode(input, dict) {
129            score += decode_score * DECODE_WEIGHT;
130            weight_sum += DECODE_WEIGHT;
131        }
132
133        // Normalize score
134        if weight_sum > 0.0 {
135            Some(score / weight_sum)
136        } else {
137            None
138        }
139    }
140
141    /// Score based on character set matching.
142    fn score_charset(&self, input: &str, dict: &Dictionary) -> f64 {
143        // Get all unique characters in input (excluding whitespace and padding)
144        let input_chars: HashSet<char> = input
145            .chars()
146            .filter(|c| !c.is_whitespace() && Some(*c) != dict.padding())
147            .collect();
148
149        if input_chars.is_empty() {
150            return 0.0;
151        }
152
153        // For ByteRange mode, check if characters are in the expected range
154        if let Some(start) = dict.start_codepoint() {
155            let in_range = input_chars
156                .iter()
157                .filter(|&&c| {
158                    let code = c as u32;
159                    code >= start && code < start + 256
160                })
161                .count();
162            return in_range as f64 / input_chars.len() as f64;
163        }
164
165        // Check if all input characters are in the dictionary
166        let mut valid_count = 0;
167        for c in &input_chars {
168            if dict.decode_char(*c).is_some() {
169                valid_count += 1;
170            }
171        }
172
173        if valid_count < input_chars.len() {
174            // Not all characters are valid - reject this dictionary
175            return 0.0;
176        }
177
178        // All characters are valid. Now check how well the dictionary size matches
179        let dict_size = dict.base();
180        let input_unique = input_chars.len();
181
182        // Calculate what percentage of the dictionary is actually used
183        let usage_ratio = input_unique as f64 / dict_size as f64;
184
185        // Prefer dictionaries where we use most of the character set
186        // This helps distinguish base64 (64 chars) from base85 (85 chars)
187        if usage_ratio > 0.7 {
188            // We're using >70% of dictionary - excellent match
189            1.0
190        } else if usage_ratio > 0.5 {
191            // We're using >50% of dictionary - good match
192            0.85
193        } else if usage_ratio > 0.3 {
194            // We're using >30% of dictionary - okay match
195            0.7
196        } else {
197            // We're using <30% of dictionary - probably wrong
198            // (e.g., using 20 chars of a 85-char dictionary)
199            0.5
200        }
201    }
202
203    /// Score based on how specific/focused the dictionary character set is.
204    /// Smaller, more focused dictionaries score higher.
205    fn score_specificity(&self, _input: &str, dict: &Dictionary) -> f64 {
206        let dict_size = dict.base();
207
208        // Prefer smaller, more common dictionaries
209        // This helps distinguish base64 (64) from base85 (85) when both match
210        match dict_size {
211            16 => 1.0,   // hex
212            32 => 0.95,  // base32
213            58 => 0.90,  // base58
214            62 => 0.88,  // base62
215            64 => 0.92,  // base64 (very common)
216            85 => 0.70,  // base85 (less common)
217            256 => 0.60, // base256
218            _ if dict_size < 64 => 0.85,
219            _ if dict_size < 128 => 0.75,
220            _ => 0.65,
221        }
222    }
223
224    /// Score based on padding character presence and position.
225    fn score_padding(&self, input: &str, dict: &Dictionary) -> Option<f64> {
226        let padding = dict.padding()?;
227
228        // Chunked modes should have padding at the end (or no padding)
229        if *dict.mode() == EncodingMode::Chunked {
230            let has_padding = input.ends_with(padding);
231            let padding_count = input.chars().filter(|c| *c == padding).count();
232
233            if has_padding {
234                // Padding should only be at the end
235                let trimmed = input.trim_end_matches(padding);
236                let internal_padding = trimmed.chars().any(|c| c == padding);
237
238                if internal_padding {
239                    Some(0.5) // Suspicious padding in middle
240                } else if padding_count <= 3 {
241                    Some(1.0) // Valid padding
242                } else {
243                    Some(0.3) // Too much padding
244                }
245            } else {
246                // No padding is also valid for chunked mode
247                Some(0.8)
248            }
249        } else {
250            None
251        }
252    }
253
254    /// Score based on input length validation for the encoding mode.
255    fn score_length(&self, input: &str, dict: &Dictionary) -> f64 {
256        let length = input.trim().len();
257
258        match dict.mode() {
259            EncodingMode::Chunked => {
260                // Chunked mode should have specific alignment
261                let base = dict.base();
262
263                // Remove padding to check alignment
264                let trimmed = if let Some(pad) = dict.padding() {
265                    input.trim_end_matches(pad)
266                } else {
267                    input
268                };
269
270                // For base64 (6 bits per char), output should be multiple of 4
271                // For base32 (5 bits per char), output should be multiple of 8
272                // For base16 (4 bits per char), output should be multiple of 2
273                let expected_multiple = match base {
274                    64 => 4,
275                    32 => 8,
276                    16 => 2,
277                    _ => return 0.5, // Unknown chunked base
278                };
279
280                if trimmed.len() % expected_multiple == 0 {
281                    1.0
282                } else {
283                    0.3
284                }
285            }
286            EncodingMode::ByteRange => {
287                // ByteRange is 1:1 mapping, any length is valid
288                1.0
289            }
290            EncodingMode::Radix => {
291                // Radix conversion can produce any length
292                if length > 0 { 1.0 } else { 0.0 }
293            }
294        }
295    }
296
297    /// Score based on whether the input can be successfully decoded.
298    fn score_decode(&self, input: &str, dict: &Dictionary) -> Option<f64> {
299        match decode(input, dict) {
300            Ok(decoded) => {
301                if decoded.is_empty() {
302                    Some(0.5)
303                } else {
304                    // Successfully decoded!
305                    Some(1.0)
306                }
307            }
308            Err(_) => {
309                // Failed to decode
310                Some(0.0)
311            }
312        }
313    }
314}
315
316/// Convenience function to detect dictionary from input.
317pub fn detect_dictionary(input: &str) -> Result<Vec<DictionaryMatch>, Box<dyn std::error::Error>> {
318    let config = DictionaryRegistry::load_with_overrides()?;
319    let detector = DictionaryDetector::new(&config)?;
320    Ok(detector.detect(input))
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode;

    /// Builds a detector over the built-in default registry.
    fn default_detector() -> DictionaryDetector {
        let registry = DictionaryRegistry::load_default().unwrap();
        DictionaryDetector::new(&registry).unwrap()
    }

    /// True when `name` is one of the accepted `candidates`.
    fn is_one_of(name: &str, candidates: &[&str]) -> bool {
        candidates.contains(&name)
    }

    #[test]
    fn test_detect_base64() {
        // Standard base64 with padding
        let matches = default_detector().detect("SGVsbG8sIFdvcmxkIQ==");
        assert!(!matches.is_empty());

        // base64 and base64url are very similar, so either is acceptable
        let best = &matches[0];
        assert!(is_one_of(&best.name, &["base64", "base64url"]));
        assert!(best.confidence > 0.7);
    }

    #[test]
    fn test_detect_base32() {
        let matches = default_detector().detect("JBSWY3DPEBLW64TMMQ======");
        assert!(!matches.is_empty());

        // base32 should be in top 10 candidates (more dictionaries now)
        let in_top_ten = matches
            .iter()
            .take(10)
            .any(|candidate| candidate.name.starts_with("base32"));
        assert!(in_top_ten, "base32 should be in top 10 candidates");
    }

    #[test]
    fn test_detect_hex() {
        let matches = default_detector().detect("48656c6c6f");
        assert!(!matches.is_empty());

        // hex or hex_radix are both correct
        let best = &matches[0];
        assert!(is_one_of(&best.name, &["hex", "hex_radix"]));
        assert!(best.confidence > 0.8);
    }

    #[test]
    fn test_detect_from_encoded() {
        let registry = DictionaryRegistry::load_default().unwrap();

        // Build the configured base64 dictionary and encode real data with it
        let base64_config = registry.get_dictionary("base64").unwrap();
        let alphabet: Vec<char> = base64_config.effective_chars().unwrap().chars().collect();
        let mut builder = Dictionary::builder()
            .chars(alphabet)
            .mode(base64_config.effective_mode());
        if let Some(pad) = base64_config.padding.as_ref().and_then(|s| s.chars().next()) {
            builder = builder.padding(pad);
        }
        let dict = builder.build().unwrap();

        let encoded = encode(b"Hello, World!", &dict);

        let matches = DictionaryDetector::new(&registry).unwrap().detect(&encoded);
        assert!(!matches.is_empty());

        // base64 and base64url only differ by 2 chars, so both are valid
        assert!(is_one_of(&matches[0].name, &["base64", "base64url"]));
    }

    #[test]
    fn test_detect_empty_input() {
        assert!(default_detector().detect("").is_empty());
    }

    #[test]
    fn test_detect_invalid_input() {
        // Input with characters not in any dictionary
        let matches = default_detector().detect("こんにちは世界");

        // Should return few or no high-confidence matches
        if let Some(best) = matches.first() {
            assert!(best.confidence < 0.5);
        }
    }
}