base_d/features/
detection.rs

1use crate::core::config::{DictionaryRegistry, EncodingMode};
2use crate::core::dictionary::Dictionary;
3use crate::decode;
4use std::collections::HashSet;
5
6/// A match result from dictionary detection.
7#[derive(Debug, Clone)]
8pub struct DictionaryMatch {
9    /// Name of the matched dictionary
10    pub name: String,
11    /// Confidence score (0.0 to 1.0)
12    pub confidence: f64,
13    /// The dictionary itself
14    pub dictionary: Dictionary,
15}
16
17/// Detector for automatically identifying which dictionary was used to encode data.
18pub struct DictionaryDetector {
19    dictionaries: Vec<(String, Dictionary)>,
20}
21
22impl DictionaryDetector {
23    /// Creates a new detector from a configuration.
24    pub fn new(config: &DictionaryRegistry) -> Result<Self, Box<dyn std::error::Error>> {
25        let mut dictionaries = Vec::new();
26
27        for (name, dict_config) in &config.dictionaries {
28            let dictionary = match dict_config.mode {
29                EncodingMode::ByteRange => {
30                    let start = dict_config
31                        .start_codepoint
32                        .ok_or("ByteRange mode requires start_codepoint")?;
33                    Dictionary::builder()
34                        .mode(dict_config.mode.clone())
35                        .start_codepoint(start)
36                        .build()?
37                }
38                _ => {
39                    let chars: Vec<char> = dict_config.chars.chars().collect();
40                    let padding = dict_config.padding.as_ref().and_then(|s| s.chars().next());
41                    let mut builder = Dictionary::builder()
42                        .chars(chars)
43                        .mode(dict_config.mode.clone());
44                    if let Some(p) = padding {
45                        builder = builder.padding(p);
46                    }
47                    builder.build()?
48                }
49            };
50            dictionaries.push((name.clone(), dictionary));
51        }
52
53        Ok(DictionaryDetector { dictionaries })
54    }
55
56    /// Detect which dictionary was likely used to encode the input.
57    /// Returns matches sorted by confidence (highest first).
58    pub fn detect(&self, input: &str) -> Vec<DictionaryMatch> {
59        let input = input.trim();
60        if input.is_empty() {
61            return Vec::new();
62        }
63
64        let mut matches = Vec::new();
65
66        for (name, dict) in &self.dictionaries {
67            if let Some(confidence) = self.score_dictionary(input, dict) {
68                matches.push(DictionaryMatch {
69                    name: name.clone(),
70                    confidence,
71                    dictionary: dict.clone(),
72                });
73            }
74        }
75
76        // Sort by confidence descending
77        matches.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
78
79        matches
80    }
81
82    /// Score how likely a dictionary matches the input.
83    /// Returns Some(confidence) if it's a plausible match, None otherwise.
84    fn score_dictionary(&self, input: &str, dict: &Dictionary) -> Option<f64> {
85        let mut score = 0.0;
86        let mut weight_sum = 0.0;
87
88        // Weight for each scoring component
89        const CHARSET_WEIGHT: f64 = 0.25;
90        const SPECIFICITY_WEIGHT: f64 = 0.20; // Increased
91        const PADDING_WEIGHT: f64 = 0.30; // Increased (very important for RFC standards)
92        const LENGTH_WEIGHT: f64 = 0.15;
93        const DECODE_WEIGHT: f64 = 0.10;
94
95        // 1. Character set matching
96        let charset_score = self.score_charset(input, dict);
97        score += charset_score * CHARSET_WEIGHT;
98        weight_sum += CHARSET_WEIGHT;
99
100        // If character set score is too low, skip this dictionary
101        if charset_score < 0.5 {
102            return None;
103        }
104
105        // 1.5. Specificity - does this dictionary use a focused character set?
106        let specificity_score = self.score_specificity(input, dict);
107        score += specificity_score * SPECIFICITY_WEIGHT;
108        weight_sum += SPECIFICITY_WEIGHT;
109
110        // 2. Padding detection (for chunked modes)
111        if let Some(padding_score) = self.score_padding(input, dict) {
112            score += padding_score * PADDING_WEIGHT;
113            weight_sum += PADDING_WEIGHT;
114        }
115
116        // 3. Length validation
117        let length_score = self.score_length(input, dict);
118        score += length_score * LENGTH_WEIGHT;
119        weight_sum += LENGTH_WEIGHT;
120
121        // 4. Decode validation (try to actually decode)
122        if let Some(decode_score) = self.score_decode(input, dict) {
123            score += decode_score * DECODE_WEIGHT;
124            weight_sum += DECODE_WEIGHT;
125        }
126
127        // Normalize score
128        if weight_sum > 0.0 {
129            Some(score / weight_sum)
130        } else {
131            None
132        }
133    }
134
135    /// Score based on character set matching.
136    fn score_charset(&self, input: &str, dict: &Dictionary) -> f64 {
137        // Get all unique characters in input (excluding whitespace and padding)
138        let input_chars: HashSet<char> = input
139            .chars()
140            .filter(|c| !c.is_whitespace() && Some(*c) != dict.padding())
141            .collect();
142
143        if input_chars.is_empty() {
144            return 0.0;
145        }
146
147        // For ByteRange mode, check if characters are in the expected range
148        if let Some(start) = dict.start_codepoint() {
149            let in_range = input_chars
150                .iter()
151                .filter(|&&c| {
152                    let code = c as u32;
153                    code >= start && code < start + 256
154                })
155                .count();
156            return in_range as f64 / input_chars.len() as f64;
157        }
158
159        // Check if all input characters are in the dictionary
160        let mut valid_count = 0;
161        for c in &input_chars {
162            if dict.decode_char(*c).is_some() {
163                valid_count += 1;
164            }
165        }
166
167        if valid_count < input_chars.len() {
168            // Not all characters are valid - reject this dictionary
169            return 0.0;
170        }
171
172        // All characters are valid. Now check how well the dictionary size matches
173        let dict_size = dict.base();
174        let input_unique = input_chars.len();
175
176        // Calculate what percentage of the dictionary is actually used
177        let usage_ratio = input_unique as f64 / dict_size as f64;
178
179        // Prefer dictionaries where we use most of the character set
180        // This helps distinguish base64 (64 chars) from base85 (85 chars)
181        if usage_ratio > 0.7 {
182            // We're using >70% of dictionary - excellent match
183            1.0
184        } else if usage_ratio > 0.5 {
185            // We're using >50% of dictionary - good match
186            0.85
187        } else if usage_ratio > 0.3 {
188            // We're using >30% of dictionary - okay match
189            0.7
190        } else {
191            // We're using <30% of dictionary - probably wrong
192            // (e.g., using 20 chars of a 85-char dictionary)
193            0.5
194        }
195    }
196
197    /// Score based on how specific/focused the dictionary character set is.
198    /// Smaller, more focused dictionaries score higher.
199    fn score_specificity(&self, _input: &str, dict: &Dictionary) -> f64 {
200        let dict_size = dict.base();
201
202        // Prefer smaller, more common dictionaries
203        // This helps distinguish base64 (64) from base85 (85) when both match
204        match dict_size {
205            16 => 1.0,   // hex
206            32 => 0.95,  // base32
207            58 => 0.90,  // base58
208            62 => 0.88,  // base62
209            64 => 0.92,  // base64 (very common)
210            85 => 0.70,  // base85 (less common)
211            256 => 0.60, // base256
212            _ if dict_size < 64 => 0.85,
213            _ if dict_size < 128 => 0.75,
214            _ => 0.65,
215        }
216    }
217
218    /// Score based on padding character presence and position.
219    fn score_padding(&self, input: &str, dict: &Dictionary) -> Option<f64> {
220        let padding = dict.padding()?;
221
222        // Chunked modes should have padding at the end (or no padding)
223        if *dict.mode() == EncodingMode::Chunked {
224            let has_padding = input.ends_with(padding);
225            let padding_count = input.chars().filter(|c| *c == padding).count();
226
227            if has_padding {
228                // Padding should only be at the end
229                let trimmed = input.trim_end_matches(padding);
230                let internal_padding = trimmed.chars().any(|c| c == padding);
231
232                if internal_padding {
233                    Some(0.5) // Suspicious padding in middle
234                } else if padding_count <= 3 {
235                    Some(1.0) // Valid padding
236                } else {
237                    Some(0.3) // Too much padding
238                }
239            } else {
240                // No padding is also valid for chunked mode
241                Some(0.8)
242            }
243        } else {
244            None
245        }
246    }
247
248    /// Score based on input length validation for the encoding mode.
249    fn score_length(&self, input: &str, dict: &Dictionary) -> f64 {
250        let length = input.trim().len();
251
252        match dict.mode() {
253            EncodingMode::Chunked => {
254                // Chunked mode should have specific alignment
255                let base = dict.base();
256
257                // Remove padding to check alignment
258                let trimmed = if let Some(pad) = dict.padding() {
259                    input.trim_end_matches(pad)
260                } else {
261                    input
262                };
263
264                // For base64 (6 bits per char), output should be multiple of 4
265                // For base32 (5 bits per char), output should be multiple of 8
266                // For base16 (4 bits per char), output should be multiple of 2
267                let expected_multiple = match base {
268                    64 => 4,
269                    32 => 8,
270                    16 => 2,
271                    _ => return 0.5, // Unknown chunked base
272                };
273
274                if trimmed.len() % expected_multiple == 0 {
275                    1.0
276                } else {
277                    0.3
278                }
279            }
280            EncodingMode::ByteRange => {
281                // ByteRange is 1:1 mapping, any length is valid
282                1.0
283            }
284            EncodingMode::BaseConversion => {
285                // Mathematical conversion can produce any length
286                if length > 0 { 1.0 } else { 0.0 }
287            }
288        }
289    }
290
291    /// Score based on whether the input can be successfully decoded.
292    fn score_decode(&self, input: &str, dict: &Dictionary) -> Option<f64> {
293        match decode(input, dict) {
294            Ok(decoded) => {
295                if decoded.is_empty() {
296                    Some(0.5)
297                } else {
298                    // Successfully decoded!
299                    Some(1.0)
300                }
301            }
302            Err(_) => {
303                // Failed to decode
304                Some(0.0)
305            }
306        }
307    }
308}
309
310/// Convenience function to detect dictionary from input.
311pub fn detect_dictionary(input: &str) -> Result<Vec<DictionaryMatch>, Box<dyn std::error::Error>> {
312    let config = DictionaryRegistry::load_with_overrides()?;
313    let detector = DictionaryDetector::new(&config)?;
314    Ok(detector.detect(input))
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode;

    /// Builds a detector backed by the default dictionary registry.
    fn default_detector() -> DictionaryDetector {
        let registry = DictionaryRegistry::load_default().unwrap();
        DictionaryDetector::new(&registry).unwrap()
    }

    #[test]
    fn test_detect_base64() {
        // Standard base64 with padding.
        let candidates = default_detector().detect("SGVsbG8sIFdvcmxkIQ==");
        assert!(!candidates.is_empty());

        // base64 and base64url are very similar, so either is acceptable.
        let best = &candidates[0];
        assert!(best.name == "base64" || best.name == "base64url");
        assert!(best.confidence > 0.7);
    }

    #[test]
    fn test_detect_base32() {
        let candidates = default_detector().detect("JBSWY3DPEBLW64TMMQ======");
        assert!(!candidates.is_empty());

        // base32 only needs to appear among the five best candidates.
        let base32_found = candidates
            .iter()
            .take(5)
            .any(|m| m.name.starts_with("base32"));
        assert!(base32_found, "base32 should be in top 5 candidates");
    }

    #[test]
    fn test_detect_hex() {
        let candidates = default_detector().detect("48656c6c6f");
        assert!(!candidates.is_empty());

        // hex or hex_math are both correct.
        let best = &candidates[0];
        assert!(best.name == "hex" || best.name == "hex_math");
        assert!(best.confidence > 0.8);
    }

    #[test]
    fn test_detect_from_encoded() {
        let config = DictionaryRegistry::load_default().unwrap();

        // Rebuild the base64 dictionary from its registry entry so real encoder
        // output can be fed to the detector.
        let entry = config.get_dictionary("base64").unwrap();
        let alphabet: Vec<char> = entry.chars.chars().collect();
        let pad = entry.padding.as_ref().and_then(|s| s.chars().next());
        let base_builder = Dictionary::builder()
            .chars(alphabet)
            .mode(entry.mode.clone());
        let dict = match pad {
            Some(p) => base_builder.padding(p),
            None => base_builder,
        }
        .build()
        .unwrap();

        let encoded = encode(b"Hello, World!", &dict);

        let detector = DictionaryDetector::new(&config).unwrap();
        let candidates = detector.detect(&encoded);

        assert!(!candidates.is_empty());
        // base64 and base64url only differ by 2 chars, so both are valid.
        let best = &candidates[0];
        assert!(best.name == "base64" || best.name == "base64url");
    }

    #[test]
    fn test_detect_empty_input() {
        assert!(default_detector().detect("").is_empty());
    }

    #[test]
    fn test_detect_invalid_input() {
        // Input with characters not expected in any dictionary.
        let candidates = default_detector().detect("こんにちは世界");

        // Any match that does come back must be low-confidence.
        if let Some(best) = candidates.first() {
            assert!(best.confidence < 0.5);
        }
    }
}