const_decoder/
decoder.rs

1//! `Decoder` and closely related types.
2
3use compile_fmt::{clip_ascii, compile_assert, compile_panic, fmt, Ascii};
4
5use crate::wrappers::{SkipWhitespace, Skipper};
6
// `?` cannot be used in `const fn`s, so this macro provides a minimal replacement for it:
// it evaluates to the `Ok` value, or returns the `Err` value from the enclosing function as is
// (without the `From` conversion that `?` would perform).
macro_rules! const_try {
    ($fallible:expr) => {{
        match $fallible {
            Err(e) => return Err(e),
            Ok(v) => v,
        }
    }};
}
16
17#[derive(Debug)]
18struct DecodeError {
19    invalid_char: u8,
20    // `None` for hex encoding
21    alphabet: Option<Ascii<'static>>,
22}
23
24impl DecodeError {
25    const fn invalid_char(invalid_char: u8, alphabet: Option<Ascii<'static>>) -> Self {
26        Self {
27            invalid_char,
28            alphabet,
29        }
30    }
31
32    const fn panic(self, input_pos: usize) -> ! {
33        if self.invalid_char.is_ascii() {
34            if let Some(alphabet) = self.alphabet {
35                compile_panic!(
36                    "Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
37                    input_pos => fmt::<usize>(), " is not a part of \
38                    the decoder alphabet '", alphabet => clip_ascii(64, ""), "'"
39                );
40            } else {
41                compile_panic!(
42                    "Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
43                    input_pos => fmt::<usize>(), " is not a hex digit"
44                );
45            }
46        } else {
47            compile_panic!(
48                "Non-ASCII character with decimal code ", self.invalid_char => fmt::<u8>(),
49                " encountered at position ", input_pos => fmt::<usize>()
50            );
51        }
52    }
53}
54
55/// Custom encoding scheme based on a certain alphabet (mapping between a subset of ASCII chars
56/// and digits in `0..P`, where `P` is a power of 2).
57///
58/// # Examples
59///
60/// ```
61/// # use const_decoder::Decoder;
62/// // Decoder for Bech32 encoding as specified in
63/// // https://github.com/bitcoin/bips/blob/master/bip-0173.mediawiki.
64/// const BECH32: Decoder = Decoder::custom("qpzry9x8gf2tvdw0s3jn54khce6mua7l");
65///
66/// // Sample address from the Bech32 spec excluding the `tb1q` prefix
67/// // and the checksum suffix.
68/// const SAMPLE_ADDR: [u8; 32] =
69///     BECH32.decode(b"rp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3q");
70/// ```
71#[derive(Debug, Clone, Copy)]
72pub struct Encoding {
73    alphabet: Ascii<'static>,
74    table: [u8; 128],
75    bits_per_char: u8,
76}
77
78impl Encoding {
79    const NO_MAPPING: u8 = u8::MAX;
80
81    const BASE64: Self =
82        Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
83    const BASE64_URL: Self =
84        Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
85
86    /// Creates an encoding based on the provided `alphabet`: a sequence of ASCII chars
87    /// that correspond to digits 0, 1, 2, etc.
88    ///
89    /// # Panics
90    ///
91    /// - Panics if `alphabet` does not consist of distinct ASCII chars.
92    /// - Panics if `alphabet` length is not 2, 4, 8, 16, 32 or 64.
93    #[allow(clippy::cast_possible_truncation)]
94    pub const fn new(alphabet: &'static str) -> Self {
95        let bits_per_char = match alphabet.len() {
96            2 => 1,
97            4 => 2,
98            8 => 3,
99            16 => 4,
100            32 => 5,
101            64 => 6,
102            other => compile_panic!(
103                "Invalid alphabet length ", other => fmt::<usize>(),
104                "; must be one of 2, 4, 8, 16, 32, or 64"
105            ),
106        };
107
108        let mut table = [Self::NO_MAPPING; 128];
109        let alphabet_bytes = alphabet.as_bytes();
110        let alphabet = Ascii::new(alphabet); // will panic if `alphabet` contains non-ASCII chars
111        let mut index = 0;
112        while index < alphabet_bytes.len() {
113            let byte = alphabet_bytes[index];
114            let byte_idx = byte as usize;
115            compile_assert!(
116                table[byte_idx] == Self::NO_MAPPING,
117                "Alphabet character '", byte as char => fmt::<char>(), "' is mentioned several times"
118            );
119            table[byte_idx] = index as u8;
120            index += 1;
121        }
122
123        Self {
124            alphabet,
125            table,
126            bits_per_char,
127        }
128    }
129
130    const fn lookup(&self, ascii_char: u8) -> Result<u8, DecodeError> {
131        if !ascii_char.is_ascii() {
132            return Err(DecodeError::invalid_char(ascii_char, Some(self.alphabet)));
133        }
134        let mapping = self.table[ascii_char as usize];
135        if mapping == Self::NO_MAPPING {
136            Err(DecodeError::invalid_char(ascii_char, Some(self.alphabet)))
137        } else {
138            Ok(mapping)
139        }
140    }
141}
142
143/// Internal state of the hexadecimal decoder.
144#[derive(Debug, Clone, Copy)]
145struct HexDecoderState(Option<u8>);
146
147impl HexDecoderState {
148    const fn byte_value(val: u8) -> Result<u8, DecodeError> {
149        Ok(match val {
150            b'0'..=b'9' => val - b'0',
151            b'A'..=b'F' => val - b'A' + 10,
152            b'a'..=b'f' => val - b'a' + 10,
153            _ => return Err(DecodeError::invalid_char(val, None)),
154        })
155    }
156
157    const fn new() -> Self {
158        Self(None)
159    }
160
161    #[allow(clippy::option_if_let_else)] // `Option::map_or_else` cannot be used in const fns
162    const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
163        let byte = const_try!(Self::byte_value(byte));
164        let output = if let Some(b) = self.0 {
165            self.0 = None;
166            Some((b << 4) + byte)
167        } else {
168            self.0 = Some(byte);
169            None
170        };
171        Ok((self, output))
172    }
173
174    const fn is_final(self) -> bool {
175        self.0.is_none()
176    }
177}
178
179/// Internal state of a Base64 decoder.
180#[derive(Debug, Clone, Copy)]
181struct CustomDecoderState {
182    table: Encoding,
183    partial_byte: u8,
184    filled_bits: u8,
185}
186
187impl CustomDecoderState {
188    const fn new(table: Encoding) -> Self {
189        Self {
190            table,
191            partial_byte: 0,
192            filled_bits: 0,
193        }
194    }
195
196    #[allow(clippy::comparison_chain)] // not feasible in const context
197    const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
198        let byte = const_try!(self.table.lookup(byte));
199        let output = if self.filled_bits < 8 - self.table.bits_per_char {
200            self.partial_byte = (self.partial_byte << self.table.bits_per_char) + byte;
201            self.filled_bits += self.table.bits_per_char;
202            None
203        } else if self.filled_bits == 8 - self.table.bits_per_char {
204            let output = (self.partial_byte << self.table.bits_per_char) + byte;
205            self.partial_byte = 0;
206            self.filled_bits = 0;
207            Some(output)
208        } else {
209            let remaining_bits = 8 - self.filled_bits;
210            let new_filled_bits = self.table.bits_per_char - remaining_bits;
211            let output = (self.partial_byte << remaining_bits) + (byte >> new_filled_bits);
212            self.partial_byte = byte % (1 << new_filled_bits);
213            self.filled_bits = new_filled_bits;
214            Some(output)
215        };
216        Ok((self, output))
217    }
218
219    const fn is_final(&self) -> bool {
220        // We don't check `self.filled_bits` because padding may be implicit
221        self.partial_byte == 0
222    }
223}
224
225/// State of a decoder.
226#[derive(Debug, Clone, Copy)]
227enum DecoderState {
228    Hex(HexDecoderState),
229    Base64(CustomDecoderState),
230    Custom(CustomDecoderState),
231}
232
233impl DecoderState {
234    const fn update(self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
235        Ok(match self {
236            Self::Hex(state) => {
237                let (updated_state, output) = const_try!(state.update(byte));
238                (Self::Hex(updated_state), output)
239            }
240            Self::Base64(state) => {
241                if byte == b'=' {
242                    (self, None)
243                } else {
244                    let (updated_state, output) = const_try!(state.update(byte));
245                    (Self::Base64(updated_state), output)
246                }
247            }
248            Self::Custom(state) => {
249                let (updated_state, output) = const_try!(state.update(byte));
250                (Self::Custom(updated_state), output)
251            }
252        })
253    }
254
255    const fn is_final(&self) -> bool {
256        match self {
257            Self::Hex(state) => state.is_final(),
258            Self::Base64(state) | Self::Custom(state) => state.is_final(),
259        }
260    }
261}
262
263/// Decoder of a human-friendly encoding, such as hex or base64, into bytes.
264///
265/// # Examples
266///
267/// See the [crate docs](index.html) for examples of usage.
268#[derive(Debug, Clone, Copy)]
269#[non_exhaustive]
270pub enum Decoder {
271    /// Hexadecimal decoder. Supports uppercase and lowercase digits.
272    Hex,
273    /// Base64 decoder accepting standard encoding as per [RFC 3548].
274    /// Does not require padding, but works fine with it.
275    ///
276    /// [RFC 3548]: https://datatracker.ietf.org/doc/html/rfc3548.html
277    Base64,
278    /// Base64 decoder accepting URL / filesystem-safe encoding as per [RFC 3548].
279    /// Does not require padding, but works fine with it.
280    ///
281    /// [RFC 3548]: https://datatracker.ietf.org/doc/html/rfc3548.html
282    Base64Url,
283    /// Decoder based on a custom [`Encoding`].
284    Custom(Encoding),
285}
286
287impl Decoder {
288    /// Creates a new decoder with a custom alphabet.
289    ///
290    /// # Panics
291    ///
292    /// Panics in the same situations as [`Encoding::new()`].
293    pub const fn custom(alphabet: &'static str) -> Self {
294        Self::Custom(Encoding::new(alphabet))
295    }
296
297    /// Makes this decoder skip whitespace chars rather than panicking on encountering them.
298    pub const fn skip_whitespace(self) -> SkipWhitespace {
299        SkipWhitespace(self)
300    }
301
302    const fn new_state(self) -> DecoderState {
303        match self {
304            Self::Hex => DecoderState::Hex(HexDecoderState::new()),
305            Self::Base64 => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64)),
306            Self::Base64Url => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64_URL)),
307            Self::Custom(encoding) => DecoderState::Custom(CustomDecoderState::new(encoding)),
308        }
309    }
310
311    /// Decodes `input` into a byte array.
312    ///
313    /// # Panics
314    ///
315    /// - Panics if the provided length is insufficient or too large for `input`.
316    /// - Panics if `input` contains invalid chars.
317    pub const fn decode<const N: usize>(self, input: &[u8]) -> [u8; N] {
318        self.do_decode(input, None)
319    }
320
321    pub(crate) const fn do_decode<const N: usize>(
322        self,
323        input: &[u8],
324        skipper: Option<Skipper>,
325    ) -> [u8; N] {
326        let mut bytes = [0_u8; N];
327        let mut in_index = 0;
328        let mut out_index = 0;
329        let mut state = self.new_state();
330
331        while in_index < input.len() {
332            if let Some(skipper) = skipper {
333                let new_in_index = skipper.skip(input, in_index);
334                if new_in_index != in_index {
335                    in_index = new_in_index;
336                    continue;
337                }
338            }
339
340            let update = match state.update(input[in_index]) {
341                Ok(update) => update,
342                Err(err) => err.panic(in_index),
343            };
344            state = update.0;
345            if let Some(byte) = update.1 {
346                if out_index < N {
347                    bytes[out_index] = byte;
348                }
349                out_index += 1;
350            }
351            in_index += 1;
352        }
353
354        compile_assert!(
355            out_index <= N,
356            "Output overflow: the input decodes to ", out_index => fmt::<usize>(),
357            " bytes, while type inference implies ",  N => fmt::<usize>(), ". \
358            Either fix the input or change the output buffer length correspondingly"
359        );
360        compile_assert!(
361            out_index == N,
362            "Output underflow: the input decodes to ", out_index => fmt::<usize>(),
363            " bytes, while type inference implies ", N => fmt::<usize>(), ". \
364            Either fix the input or change the output buffer length correspondingly"
365        );
366
367        assert!(
368            state.is_final(),
369            "Left-over state after processing input. This usually means that the input \
370             is incorrect (e.g., an odd number of hex digits)."
371        );
372        bytes
373    }
374
375    pub(crate) const fn do_decode_len(self, input: &[u8], skipper: Option<Skipper>) -> usize {
376        let mut in_index = 0;
377        let mut out_index = 0;
378        let mut state = self.new_state();
379
380        while in_index < input.len() {
381            if let Some(skipper) = skipper {
382                let new_in_index = skipper.skip(input, in_index);
383                if new_in_index != in_index {
384                    in_index = new_in_index;
385                    continue;
386                }
387            }
388
389            let update = match state.update(input[in_index]) {
390                Ok(update) => update,
391                Err(err) => err.panic(in_index),
392            };
393            state = update.0;
394            if update.1.is_some() {
395                out_index += 1;
396            }
397            in_index += 1;
398        }
399        out_index
400    }
401}