// base_d/encoders/algorithms/errors.rs

1use std::fmt;
2
/// Errors that can occur during decoding.
///
/// Each variant carries the context needed to print a helpful diagnostic.
#[derive(Debug, PartialEq, Eq)]
pub enum DecodeError {
    /// The input contains a character not in the dictionary.
    InvalidCharacter {
        /// The offending character.
        char: char,
        /// Where the offending character occurs in the input, as reported by
        /// the caller. NOTE(review): the unit (char index vs byte offset) is
        /// chosen by callers — confirm at the call sites.
        position: usize,
        /// The input being decoded. When built via
        /// [`DecodeError::invalid_character`], inputs longer than 60 characters
        /// are cut to their first 60 characters followed by `...`.
        input: String,
        /// Text listing or describing the characters the dictionary accepts.
        valid_chars: String,
    },
    /// The input string is empty.
    EmptyInput,
    /// The padding is malformed or incorrect.
    InvalidPadding,
    /// The input length is invalid for the encoding format.
    InvalidLength {
        /// Length of the rejected input.
        actual: usize,
        /// Description of the accepted length(s), e.g. `"multiple of 4"`.
        expected: String,
        /// Suggestion shown to the user for fixing the input.
        hint: String,
    },
}

impl DecodeError {
    /// Create an [`DecodeError::InvalidCharacter`] error with context.
    ///
    /// `input` is copied for display. If it holds more than 60 characters, only
    /// the first 60 are kept and `...` is appended. Truncation counts `char`s
    /// (Unicode scalar values), so multi-byte input is never split mid-character.
    pub fn invalid_character(c: char, position: usize, input: &str, valid_chars: &str) -> Self {
        // Upper bound on the number of input characters kept for display.
        const MAX_INPUT_CHARS: usize = 60;

        // `char_indices().nth(60)` yields the byte offset of the 61st character.
        // It exists only when truncation is needed and is always a valid slice
        // boundary. Slicing by bytes (`&input[..60]`) would panic whenever byte
        // 60 falls inside a multi-byte UTF-8 character.
        let display_input = match input.char_indices().nth(MAX_INPUT_CHARS) {
            Some((cut, _)) => format!("{}...", &input[..cut]),
            None => input.to_string(),
        };

        DecodeError::InvalidCharacter {
            char: c,
            position,
            input: display_input,
            valid_chars: valid_chars.to_string(),
        }
    }

    /// Create an [`DecodeError::InvalidLength`] error.
    ///
    /// `expected` describes the accepted length(s) and `hint` suggests a fix.
    /// Both accept anything convertible into a `String`.
    pub fn invalid_length(
        actual: usize,
        expected: impl Into<String>,
        hint: impl Into<String>,
    ) -> Self {
        DecodeError::InvalidLength {
            actual,
            expected: expected.into(),
            hint: hint.into(),
        }
    }
}

57impl fmt::Display for DecodeError {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        let use_color = should_use_color();
60
61        match self {
62            DecodeError::InvalidCharacter {
63                char: c,
64                position,
65                input,
66                valid_chars,
67            } => {
68                // Error header
69                if use_color {
70                    writeln!(
71                        f,
72                        "\x1b[1;31merror:\x1b[0m invalid character '{}' at position {}",
73                        c, position
74                    )?;
75                } else {
76                    writeln!(
77                        f,
78                        "error: invalid character '{}' at position {}",
79                        c, position
80                    )?;
81                }
82                writeln!(f)?;
83
84                // Show input with caret pointing at error position
85                // Need to account for multi-byte UTF-8 characters
86                let char_position = input.chars().take(*position).count();
87                writeln!(f, "  {}", input)?;
88                write!(f, "  {}", " ".repeat(char_position))?;
89                if use_color {
90                    writeln!(f, "\x1b[1;31m^\x1b[0m")?;
91                } else {
92                    writeln!(f, "^")?;
93                }
94                writeln!(f)?;
95
96                // Hint with valid characters (truncate if too long)
97                let hint_chars = if valid_chars.len() > 80 {
98                    format!("{}...", &valid_chars[..80])
99                } else {
100                    valid_chars.clone()
101                };
102
103                if use_color {
104                    write!(f, "\x1b[1;36mhint:\x1b[0m valid characters: {}", hint_chars)?;
105                } else {
106                    write!(f, "hint: valid characters: {}", hint_chars)?;
107                }
108                Ok(())
109            }
110            DecodeError::EmptyInput => {
111                if use_color {
112                    write!(f, "\x1b[1;31merror:\x1b[0m cannot decode empty input")?;
113                } else {
114                    write!(f, "error: cannot decode empty input")?;
115                }
116                Ok(())
117            }
118            DecodeError::InvalidPadding => {
119                if use_color {
120                    writeln!(f, "\x1b[1;31merror:\x1b[0m invalid padding")?;
121                    write!(
122                        f,
123                        "\n\x1b[1;36mhint:\x1b[0m check for missing or incorrect '=' characters at end of input"
124                    )?;
125                } else {
126                    writeln!(f, "error: invalid padding")?;
127                    write!(
128                        f,
129                        "\nhint: check for missing or incorrect '=' characters at end of input"
130                    )?;
131                }
132                Ok(())
133            }
134            DecodeError::InvalidLength {
135                actual,
136                expected,
137                hint,
138            } => {
139                if use_color {
140                    writeln!(f, "\x1b[1;31merror:\x1b[0m invalid length for decode",)?;
141                } else {
142                    writeln!(f, "error: invalid length for decode")?;
143                }
144                writeln!(f)?;
145                writeln!(f, "  input is {} characters, expected {}", actual, expected)?;
146                writeln!(f)?;
147                if use_color {
148                    write!(f, "\x1b[1;36mhint:\x1b[0m {}", hint)?;
149                } else {
150                    write!(f, "hint: {}", hint)?;
151                }
152                Ok(())
153            }
154        }
155    }
156}
157
158impl std::error::Error for DecodeError {}
159
/// Decide whether diagnostics should include ANSI color escape sequences.
///
/// Color is suppressed when `NO_COLOR` is set to a non-empty value, following
/// the convention at <https://no-color.org>. Otherwise color is used only when
/// stderr is attached to a terminal.
fn should_use_color() -> bool {
    use std::io::IsTerminal;

    // `var_os` still sees the variable when its value is not valid UTF-8;
    // `env::var(..).is_ok()` would treat such a value as unset. An empty
    // value is ignored, as the NO_COLOR convention requires.
    let no_color = std::env::var_os("NO_COLOR").is_some_and(|value| !value.is_empty());
    if no_color {
        return false;
    }

    std::io::stderr().is_terminal()
}

/// Error reported when a dictionary name cannot be resolved.
///
/// Carries the requested name and, optionally, a similarly named dictionary
/// to suggest to the user.
#[derive(Debug)]
pub struct DictionaryNotFoundError {
    /// The dictionary name that was requested.
    pub name: String,
    /// A close match to offer as "did you mean ...?", if any.
    pub suggestion: Option<String>,
}

impl DictionaryNotFoundError {
    /// Build the error from the requested `name` and an optional `suggestion`.
    pub fn new(name: impl Into<String>, suggestion: Option<String>) -> Self {
        let name: String = name.into();
        DictionaryNotFoundError { name, suggestion }
    }
}

188impl fmt::Display for DictionaryNotFoundError {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        let use_color = should_use_color();
191
192        if use_color {
193            writeln!(
194                f,
195                "\x1b[1;31merror:\x1b[0m dictionary '{}' not found",
196                self.name
197            )?;
198        } else {
199            writeln!(f, "error: dictionary '{}' not found", self.name)?;
200        }
201
202        writeln!(f)?;
203
204        if let Some(suggestion) = &self.suggestion {
205            if use_color {
206                writeln!(f, "\x1b[1;36mhint:\x1b[0m did you mean '{}'?", suggestion)?;
207            } else {
208                writeln!(f, "hint: did you mean '{}'?", suggestion)?;
209            }
210        }
211
212        if use_color {
213            write!(
214                f,
215                "      run \x1b[1m`base-d config --dictionaries`\x1b[0m to see all dictionaries"
216            )?;
217        } else {
218            write!(
219                f,
220                "      run `base-d config --dictionaries` to see all dictionaries"
221            )?;
222        }
223
224        Ok(())
225    }
226}
227
228impl std::error::Error for DictionaryNotFoundError {}
229
/// Compute the Levenshtein (edit) distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions, and substitutions that turn `s1` into `s2`. Characters are
/// compared as `char`s (Unicode scalar values), not bytes. Only two rolling
/// rows of the DP table are kept, sized by the length of `s2`.
fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    let target: Vec<char> = s2.chars().collect();
    let width = target.len();

    // `above[j]` is the distance between the prefix of `s1` processed so far
    // and the first `j` characters of `s2`. Row 0 is the empty prefix.
    // Empty inputs need no special case: they fall out of the recurrence.
    let mut above: Vec<usize> = (0..=width).collect();
    let mut row = vec![0usize; width + 1];

    for (i, source_char) in s1.chars().enumerate() {
        row[0] = i + 1;
        for (j, &target_char) in target.iter().enumerate() {
            let substitute = above[j] + usize::from(source_char != target_char);
            let delete = above[j + 1] + 1;
            let insert = row[j] + 1;
            row[j + 1] = substitute.min(delete).min(insert);
        }
        std::mem::swap(&mut above, &mut row);
    }

    above[width]
}

261/// Find the closest matching dictionary name
262pub fn find_closest_dictionary(name: &str, available: &[String]) -> Option<String> {
263    if available.is_empty() {
264        return None;
265    }
266
267    let mut best_match = None;
268    let mut best_distance = usize::MAX;
269
270    for dict_name in available {
271        let distance = levenshtein_distance(name, dict_name);
272
273        // Only suggest if distance is reasonably small
274        // (e.g., 1-2 character typos for short names, up to 3 for longer names)
275        let threshold = if name.len() < 5 { 2 } else { 3 };
276
277        if distance < best_distance && distance <= threshold {
278            best_distance = distance;
279            best_match = Some(dict_name.clone());
280        }
281    }
282
283    best_match
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Strip ANSI CSI escape sequences (such as `ESC[1;31m` and `ESC[0m`).
    ///
    /// This lets the assertions below hold whether or not `should_use_color()`
    /// enabled color.
    fn strip_ansi(text: &str) -> String {
        let mut plain = String::with_capacity(text.len());
        let mut chars = text.chars().peekable();
        while let Some(c) = chars.next() {
            if c == '\x1b' && chars.peek() == Some(&'[') {
                chars.next(); // consume '['
                // Skip parameter bytes up to and including the final byte
                // (0x40..=0x7E, e.g. 'm').
                for next in chars.by_ref() {
                    if ('\x40'..='\x7e').contains(&next) {
                        break;
                    }
                }
            } else {
                plain.push(c);
            }
        }
        plain
    }

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein_distance("base64", "base64"), 0);
        assert_eq!(levenshtein_distance("base64", "base32"), 2);
        assert_eq!(levenshtein_distance("bas64", "base64"), 1);
        assert_eq!(levenshtein_distance("", "base64"), 6);
    }

    #[test]
    fn test_find_closest_dictionary() {
        let dicts = vec![
            "base64".to_string(),
            "base32".to_string(),
            "base16".to_string(),
            "hex".to_string(),
        ];

        assert_eq!(
            find_closest_dictionary("bas64", &dicts),
            Some("base64".to_string())
        );
        assert_eq!(
            find_closest_dictionary("base63", &dicts),
            Some("base64".to_string())
        );
        assert_eq!(
            find_closest_dictionary("hex_radix", &dicts),
            None // too different
        );
    }

    // The rendering tests below intentionally do not set NO_COLOR. Tests run
    // on parallel threads. Mutating the process environment while another
    // thread reads it (e.g. via `should_use_color()`) is a data race, which is
    // why `std::env::set_var` is `unsafe` in Rust 2024. Instead the output is
    // passed through `strip_ansi`, so the checks hold with or without color.

    #[test]
    fn test_error_display_no_color() {
        let err = DecodeError::invalid_character('_', 12, "SGVsbG9faW52YWxpZA==", "A-Za-z0-9+/=");
        let display = strip_ansi(&err.to_string());

        assert!(display.contains("invalid character '_' at position 12"));
        assert!(display.contains("SGVsbG9faW52YWxpZA=="));
        // Two spaces of indent plus 12 spaces puts the caret under position 12.
        assert!(display.contains(&format!("\n{}^\n", " ".repeat(14))));
        assert!(display.contains("hint:"));
    }

    #[test]
    fn test_invalid_length_error() {
        let err = DecodeError::invalid_length(
            13,
            "multiple of 4",
            "add padding (=) or check for missing characters",
        );
        let display = strip_ansi(&err.to_string());

        assert!(display.contains("invalid length"));
        assert!(display.contains("13 characters"));
        assert!(display.contains("multiple of 4"));
        assert!(display.contains("add padding"));
    }

    #[test]
    fn test_dictionary_not_found_error() {
        let err = DictionaryNotFoundError::new("bas64", Some("base64".to_string()));
        let display = strip_ansi(&err.to_string());

        assert!(display.contains("dictionary 'bas64' not found"));
        assert!(display.contains("did you mean 'base64'?"));
        assert!(display.contains("base-d config --dictionaries"));
    }
}