base_d/features/
detection.rs

1use crate::core::config::{DictionaryRegistry, EncodingMode};
2use crate::core::dictionary::Dictionary;
3use crate::decode;
4use std::collections::HashSet;
5
6/// A match result from dictionary detection.
7#[derive(Debug, Clone)]
8pub struct DictionaryMatch {
9    /// Name of the matched dictionary
10    pub name: String,
11    /// Confidence score (0.0 to 1.0)
12    pub confidence: f64,
13    /// The dictionary itself
14    pub dictionary: Dictionary,
15}
16
17/// Detector for automatically identifying which dictionary was used to encode data.
18pub struct DictionaryDetector {
19    dictionaries: Vec<(String, Dictionary)>,
20}
21
22impl DictionaryDetector {
23    /// Creates a new detector from a configuration.
24    pub fn new(config: &DictionaryRegistry) -> Result<Self, Box<dyn std::error::Error>> {
25        let mut dictionaries = Vec::new();
26
27        for (name, dict_config) in &config.dictionaries {
28            let effective_mode = dict_config.effective_mode();
29            let dictionary = match effective_mode {
30                EncodingMode::ByteRange => {
31                    let start = dict_config
32                        .start_codepoint
33                        .ok_or("ByteRange mode requires start_codepoint")?;
34                    Dictionary::builder()
35                        .mode(effective_mode)
36                        .start_codepoint(start)
37                        .build()?
38                }
39                _ => {
40                    let chars: Vec<char> = dict_config.effective_chars()?.chars().collect();
41                    let padding = dict_config.padding.as_ref().and_then(|s| s.chars().next());
42                    let mut builder = Dictionary::builder().chars(chars).mode(effective_mode);
43                    if let Some(p) = padding {
44                        builder = builder.padding(p);
45                    }
46                    builder.build()?
47                }
48            };
49            dictionaries.push((name.clone(), dictionary));
50        }
51
52        Ok(DictionaryDetector { dictionaries })
53    }
54
55    /// Detect which dictionary was likely used to encode the input.
56    /// Returns matches sorted by confidence (highest first).
57    pub fn detect(&self, input: &str) -> Vec<DictionaryMatch> {
58        let input = input.trim();
59        if input.is_empty() {
60            return Vec::new();
61        }
62
63        let mut matches = Vec::new();
64
65        for (name, dict) in &self.dictionaries {
66            if let Some(confidence) = self.score_dictionary(input, dict) {
67                matches.push(DictionaryMatch {
68                    name: name.clone(),
69                    confidence,
70                    dictionary: dict.clone(),
71                });
72            }
73        }
74
75        // Sort by confidence descending
76        matches.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
77
78        matches
79    }
80
81    /// Score how likely a dictionary matches the input.
82    /// Returns Some(confidence) if it's a plausible match, None otherwise.
83    fn score_dictionary(&self, input: &str, dict: &Dictionary) -> Option<f64> {
84        let mut score = 0.0;
85        let mut weight_sum = 0.0;
86
87        // Weight for each scoring component
88        const CHARSET_WEIGHT: f64 = 0.25;
89        const SPECIFICITY_WEIGHT: f64 = 0.20; // Increased
90        const PADDING_WEIGHT: f64 = 0.30; // Increased (very important for RFC standards)
91        const LENGTH_WEIGHT: f64 = 0.15;
92        const DECODE_WEIGHT: f64 = 0.10;
93
94        // 1. Character set matching
95        let charset_score = self.score_charset(input, dict);
96        score += charset_score * CHARSET_WEIGHT;
97        weight_sum += CHARSET_WEIGHT;
98
99        // If character set score is too low, skip this dictionary
100        if charset_score < 0.5 {
101            return None;
102        }
103
104        // 1.5. Specificity - does this dictionary use a focused character set?
105        let specificity_score = self.score_specificity(input, dict);
106        score += specificity_score * SPECIFICITY_WEIGHT;
107        weight_sum += SPECIFICITY_WEIGHT;
108
109        // 2. Padding detection (for chunked modes)
110        if let Some(padding_score) = self.score_padding(input, dict) {
111            score += padding_score * PADDING_WEIGHT;
112            weight_sum += PADDING_WEIGHT;
113        }
114
115        // 3. Length validation
116        let length_score = self.score_length(input, dict);
117        score += length_score * LENGTH_WEIGHT;
118        weight_sum += LENGTH_WEIGHT;
119
120        // 4. Decode validation (try to actually decode)
121        if let Some(decode_score) = self.score_decode(input, dict) {
122            score += decode_score * DECODE_WEIGHT;
123            weight_sum += DECODE_WEIGHT;
124        }
125
126        // Normalize score
127        if weight_sum > 0.0 {
128            Some(score / weight_sum)
129        } else {
130            None
131        }
132    }
133
134    /// Score based on character set matching.
135    fn score_charset(&self, input: &str, dict: &Dictionary) -> f64 {
136        // Get all unique characters in input (excluding whitespace and padding)
137        let input_chars: HashSet<char> = input
138            .chars()
139            .filter(|c| !c.is_whitespace() && Some(*c) != dict.padding())
140            .collect();
141
142        if input_chars.is_empty() {
143            return 0.0;
144        }
145
146        // For ByteRange mode, check if characters are in the expected range
147        if let Some(start) = dict.start_codepoint() {
148            let in_range = input_chars
149                .iter()
150                .filter(|&&c| {
151                    let code = c as u32;
152                    code >= start && code < start + 256
153                })
154                .count();
155            return in_range as f64 / input_chars.len() as f64;
156        }
157
158        // Check if all input characters are in the dictionary
159        let mut valid_count = 0;
160        for c in &input_chars {
161            if dict.decode_char(*c).is_some() {
162                valid_count += 1;
163            }
164        }
165
166        if valid_count < input_chars.len() {
167            // Not all characters are valid - reject this dictionary
168            return 0.0;
169        }
170
171        // All characters are valid. Now check how well the dictionary size matches
172        let dict_size = dict.base();
173        let input_unique = input_chars.len();
174
175        // Calculate what percentage of the dictionary is actually used
176        let usage_ratio = input_unique as f64 / dict_size as f64;
177
178        // Prefer dictionaries where we use most of the character set
179        // This helps distinguish base64 (64 chars) from base85 (85 chars)
180        if usage_ratio > 0.7 {
181            // We're using >70% of dictionary - excellent match
182            1.0
183        } else if usage_ratio > 0.5 {
184            // We're using >50% of dictionary - good match
185            0.85
186        } else if usage_ratio > 0.3 {
187            // We're using >30% of dictionary - okay match
188            0.7
189        } else {
190            // We're using <30% of dictionary - probably wrong
191            // (e.g., using 20 chars of a 85-char dictionary)
192            0.5
193        }
194    }
195
196    /// Score based on how specific/focused the dictionary character set is.
197    /// Smaller, more focused dictionaries score higher.
198    fn score_specificity(&self, _input: &str, dict: &Dictionary) -> f64 {
199        let dict_size = dict.base();
200
201        // Prefer smaller, more common dictionaries
202        // This helps distinguish base64 (64) from base85 (85) when both match
203        match dict_size {
204            16 => 1.0,   // hex
205            32 => 0.95,  // base32
206            58 => 0.90,  // base58
207            62 => 0.88,  // base62
208            64 => 0.92,  // base64 (very common)
209            85 => 0.70,  // base85 (less common)
210            256 => 0.60, // base256
211            _ if dict_size < 64 => 0.85,
212            _ if dict_size < 128 => 0.75,
213            _ => 0.65,
214        }
215    }
216
217    /// Score based on padding character presence and position.
218    fn score_padding(&self, input: &str, dict: &Dictionary) -> Option<f64> {
219        let padding = dict.padding()?;
220
221        // Chunked modes should have padding at the end (or no padding)
222        if *dict.mode() == EncodingMode::Chunked {
223            let has_padding = input.ends_with(padding);
224            let padding_count = input.chars().filter(|c| *c == padding).count();
225
226            if has_padding {
227                // Padding should only be at the end
228                let trimmed = input.trim_end_matches(padding);
229                let internal_padding = trimmed.chars().any(|c| c == padding);
230
231                if internal_padding {
232                    Some(0.5) // Suspicious padding in middle
233                } else if padding_count <= 3 {
234                    Some(1.0) // Valid padding
235                } else {
236                    Some(0.3) // Too much padding
237                }
238            } else {
239                // No padding is also valid for chunked mode
240                Some(0.8)
241            }
242        } else {
243            None
244        }
245    }
246
247    /// Score based on input length validation for the encoding mode.
248    fn score_length(&self, input: &str, dict: &Dictionary) -> f64 {
249        let length = input.trim().len();
250
251        match dict.mode() {
252            EncodingMode::Chunked => {
253                // Chunked mode should have specific alignment
254                let base = dict.base();
255
256                // Remove padding to check alignment
257                let trimmed = if let Some(pad) = dict.padding() {
258                    input.trim_end_matches(pad)
259                } else {
260                    input
261                };
262
263                // For base64 (6 bits per char), output should be multiple of 4
264                // For base32 (5 bits per char), output should be multiple of 8
265                // For base16 (4 bits per char), output should be multiple of 2
266                let expected_multiple = match base {
267                    64 => 4,
268                    32 => 8,
269                    16 => 2,
270                    _ => return 0.5, // Unknown chunked base
271                };
272
273                if trimmed.len() % expected_multiple == 0 {
274                    1.0
275                } else {
276                    0.3
277                }
278            }
279            EncodingMode::ByteRange => {
280                // ByteRange is 1:1 mapping, any length is valid
281                1.0
282            }
283            EncodingMode::Radix => {
284                // Radix conversion can produce any length
285                if length > 0 { 1.0 } else { 0.0 }
286            }
287        }
288    }
289
290    /// Score based on whether the input can be successfully decoded.
291    fn score_decode(&self, input: &str, dict: &Dictionary) -> Option<f64> {
292        match decode(input, dict) {
293            Ok(decoded) => {
294                if decoded.is_empty() {
295                    Some(0.5)
296                } else {
297                    // Successfully decoded!
298                    Some(1.0)
299                }
300            }
301            Err(_) => {
302                // Failed to decode
303                Some(0.0)
304            }
305        }
306    }
307}
308
309/// Convenience function to detect dictionary from input.
310pub fn detect_dictionary(input: &str) -> Result<Vec<DictionaryMatch>, Box<dyn std::error::Error>> {
311    let config = DictionaryRegistry::load_with_overrides()?;
312    let detector = DictionaryDetector::new(&config)?;
313    Ok(detector.detect(input))
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode;

    /// Detector built from the default dictionary registry.
    fn default_detector() -> DictionaryDetector {
        let registry = DictionaryRegistry::load_default().unwrap();
        DictionaryDetector::new(&registry).unwrap()
    }

    #[test]
    fn test_detect_base64() {
        // Standard base64 with padding
        let matches = default_detector().detect("SGVsbG8sIFdvcmxkIQ==");
        assert!(!matches.is_empty());

        let best = &matches[0];
        // base64 and base64url are nearly identical, so either is acceptable
        assert!(best.name == "base64" || best.name == "base64url");
        assert!(best.confidence > 0.7);
    }

    #[test]
    fn test_detect_base32() {
        let matches = default_detector().detect("JBSWY3DPEBLW64TMMQ======");
        assert!(!matches.is_empty());

        // With many dictionaries registered, base32 only needs to rank in the top 10
        let base32_found = matches
            .iter()
            .take(10)
            .any(|candidate| candidate.name.starts_with("base32"));
        assert!(base32_found, "base32 should be in top 10 candidates");
    }

    #[test]
    fn test_detect_hex() {
        let matches = default_detector().detect("48656c6c6f");
        assert!(!matches.is_empty());

        let best = &matches[0];
        // hex and hex_radix are both correct answers
        assert!(best.name == "hex" || best.name == "hex_radix");
        assert!(best.confidence > 0.8);
    }

    #[test]
    fn test_detect_from_encoded() {
        let registry = DictionaryRegistry::load_default().unwrap();

        // Build the base64 dictionary by hand and encode real data with it
        let base64_config = registry.get_dictionary("base64").unwrap();
        let alphabet: Vec<char> = base64_config.effective_chars().unwrap().chars().collect();
        let pad = base64_config
            .padding
            .as_ref()
            .and_then(|s| s.chars().next());
        let builder = Dictionary::builder()
            .chars(alphabet)
            .mode(base64_config.effective_mode());
        let builder = match pad {
            Some(p) => builder.padding(p),
            None => builder,
        };
        let dict = builder.build().unwrap();

        let data = b"Hello, World!";
        let encoded = encode(data, &dict);

        let detector = DictionaryDetector::new(&registry).unwrap();
        let matches = detector.detect(&encoded);

        assert!(!matches.is_empty());
        // base64 and base64url differ by only 2 characters, so both are valid
        let best = &matches[0];
        assert!(best.name == "base64" || best.name == "base64url");
    }

    #[test]
    fn test_detect_empty_input() {
        let matches = default_detector().detect("");
        assert!(matches.is_empty());
    }

    #[test]
    fn test_detect_invalid_input() {
        // Input made of characters that the test expects no dictionary to fit well
        let matches = default_detector().detect("こんにちは世界");

        // Any match that does come back must be low-confidence
        if let Some(best) = matches.first() {
            assert!(best.confidence < 0.5);
        }
    }
}