Skip to main content

base64_ng/engine/
decode_const.rs

1use crate::{Alphabet, DecodeError, Engine};
2
3impl<A, const PAD: bool> Engine<A, PAD>
4where
5    A: Alphabet,
6{
7    /// Decodes a fixed-size Base64 input into a fixed-size output array in
8    /// const contexts.
9    ///
10    /// The returned tuple contains the output array and the number of decoded
11    /// bytes written into that array. Bytes after the decoded prefix are zero.
12    ///
13    /// Unlike [`Engine::encode_array`](crate::Engine::encode_array), this
14    /// function does not use panics for sizing mistakes. If `OUTPUT_CAP` is too
15    /// small or the input is malformed, it returns [`DecodeError`]. This keeps
16    /// the same function suitable for compile-time constants and runtime calls.
17    ///
18    /// # Security
19    ///
20    /// This is the normal strict decoder in const form. It is not a
21    /// constant-time-oriented secret decoder, and strict errors may reveal
22    /// input-derived indexes and bytes. Use [`crate::ct`] for sensitive decode
23    /// timing posture.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// use base64_ng::STANDARD;
29    ///
30    /// const DECODED: ([u8; 5], usize) = match STANDARD.decode_array(b"aGVsbG8=") {
31    ///     Ok(decoded) => decoded,
32    ///     Err(_) => panic!("static base64 literal should decode"),
33    /// };
34    ///
35    /// assert_eq!(&DECODED.0[..DECODED.1], b"hello");
36    /// ```
37    pub const fn decode_array<const INPUT_LEN: usize, const OUTPUT_CAP: usize>(
38        &self,
39        input: &[u8; INPUT_LEN],
40    ) -> Result<([u8; OUTPUT_CAP], usize), DecodeError> {
41        let required = match const_decoded_len::<PAD>(input) {
42            Ok(required) => required,
43            Err(error) => return Err(error),
44        };
45        if OUTPUT_CAP < required {
46            return Err(DecodeError::OutputTooSmall {
47                required,
48                available: OUTPUT_CAP,
49            });
50        }
51
52        let mut output = [0u8; OUTPUT_CAP];
53        let written = if PAD {
54            match const_decode_padded::<A, INPUT_LEN, OUTPUT_CAP>(input, &mut output) {
55                Ok(written) => written,
56                Err(error) => return Err(error),
57            }
58        } else {
59            match const_decode_unpadded::<A, INPUT_LEN, OUTPUT_CAP>(input, &mut output) {
60                Ok(written) => written,
61                Err(error) => return Err(error),
62            }
63        };
64
65        Ok((output, written))
66    }
67}
68
69const fn const_decoded_len<const PAD: bool>(input: &[u8]) -> Result<usize, DecodeError> {
70    if PAD {
71        const_decoded_len_padded(input)
72    } else {
73        const_decoded_len_unpadded(input)
74    }
75}
76
77const fn const_decoded_len_padded(input: &[u8]) -> Result<usize, DecodeError> {
78    let len = input.len();
79    if len == 0 {
80        return Ok(0);
81    }
82    if len & 3 != 0 {
83        return Err(DecodeError::InvalidLength);
84    }
85
86    let mut padding = 0;
87    if input[len - 1] == b'=' {
88        padding += 1;
89    }
90    if input[len - 2] == b'=' {
91        padding += 1;
92    }
93
94    let first_pad = len - padding;
95    let mut index = 0;
96    while index < first_pad {
97        if input[index] == b'=' {
98            return Err(DecodeError::InvalidPadding { index });
99        }
100        index += 1;
101    }
102
103    Ok(len / 4 * 3 - padding)
104}
105
106const fn const_decoded_len_unpadded(input: &[u8]) -> Result<usize, DecodeError> {
107    let len = input.len();
108    let remainder = len & 3;
109    if remainder == 1 {
110        return Err(DecodeError::InvalidLength);
111    }
112
113    let mut index = 0;
114    while index < len {
115        if input[index] == b'=' {
116            return Err(DecodeError::InvalidPadding { index });
117        }
118        index += 1;
119    }
120
121    Ok(len / 4 * 3
122        + if remainder == 2 {
123            1
124        } else if remainder == 3 {
125            2
126        } else {
127            0
128        })
129}
130
131const fn const_decode_padded<A: Alphabet, const INPUT_LEN: usize, const OUTPUT_CAP: usize>(
132    input: &[u8; INPUT_LEN],
133    output: &mut [u8; OUTPUT_CAP],
134) -> Result<usize, DecodeError> {
135    let mut read = 0;
136    let mut write = 0;
137
138    while read < INPUT_LEN {
139        let written = match const_decode_quantum::<A, true, OUTPUT_CAP>(
140            input[read],
141            input[read + 1],
142            input[read + 2],
143            input[read + 3],
144            read,
145            output,
146            write,
147        ) {
148            Ok(written) => written,
149            Err(error) => return Err(error),
150        };
151        read += 4;
152        write += written;
153        if written < 3 && read != INPUT_LEN {
154            return Err(DecodeError::InvalidPadding { index: read - 4 });
155        }
156    }
157
158    Ok(write)
159}
160
161const fn const_decode_unpadded<A: Alphabet, const INPUT_LEN: usize, const OUTPUT_CAP: usize>(
162    input: &[u8; INPUT_LEN],
163    output: &mut [u8; OUTPUT_CAP],
164) -> Result<usize, DecodeError> {
165    let mut read = 0;
166    let mut write = 0;
167
168    while read + 4 <= INPUT_LEN {
169        let written = match const_decode_quantum::<A, false, OUTPUT_CAP>(
170            input[read],
171            input[read + 1],
172            input[read + 2],
173            input[read + 3],
174            read,
175            output,
176            write,
177        ) {
178            Ok(written) => written,
179            Err(error) => return Err(error),
180        };
181        read += 4;
182        write += written;
183    }
184
185    match INPUT_LEN - read {
186        0 => Ok(write),
187        2 => {
188            let v0 = match const_decode_byte::<A>(input[read], read) {
189                Ok(value) => value,
190                Err(error) => return Err(error),
191            };
192            let v1 = match const_decode_byte::<A>(input[read + 1], read + 1) {
193                Ok(value) => value,
194                Err(error) => return Err(error),
195            };
196            if v1 & 0b0000_1111 != 0 {
197                return Err(DecodeError::InvalidPadding { index: read + 1 });
198            }
199            if let Err(error) = const_ensure_output::<OUTPUT_CAP>(write, 1) {
200                return Err(error);
201            }
202            output[write] = (v0 << 2) | (v1 >> 4);
203            Ok(write + 1)
204        }
205        3 => {
206            let v0 = match const_decode_byte::<A>(input[read], read) {
207                Ok(value) => value,
208                Err(error) => return Err(error),
209            };
210            let v1 = match const_decode_byte::<A>(input[read + 1], read + 1) {
211                Ok(value) => value,
212                Err(error) => return Err(error),
213            };
214            let v2 = match const_decode_byte::<A>(input[read + 2], read + 2) {
215                Ok(value) => value,
216                Err(error) => return Err(error),
217            };
218            if v2 & 0b0000_0011 != 0 {
219                return Err(DecodeError::InvalidPadding { index: read + 2 });
220            }
221            if let Err(error) = const_ensure_output::<OUTPUT_CAP>(write, 2) {
222                return Err(error);
223            }
224            output[write] = (v0 << 2) | (v1 >> 4);
225            output[write + 1] = (v1 << 4) | (v2 >> 2);
226            Ok(write + 2)
227        }
228        _ => Err(DecodeError::InvalidLength),
229    }
230}
231
232const fn const_decode_quantum<A: Alphabet, const PAD: bool, const OUTPUT_CAP: usize>(
233    b0: u8,
234    b1: u8,
235    b2: u8,
236    b3: u8,
237    input_offset: usize,
238    output: &mut [u8; OUTPUT_CAP],
239    write: usize,
240) -> Result<usize, DecodeError> {
241    let v0 = match const_decode_byte::<A>(b0, input_offset) {
242        Ok(value) => value,
243        Err(error) => return Err(error),
244    };
245    let v1 = match const_decode_byte::<A>(b1, input_offset + 1) {
246        Ok(value) => value,
247        Err(error) => return Err(error),
248    };
249
250    match (b2, b3) {
251        (b'=', b'=') if PAD => {
252            if v1 & 0b0000_1111 != 0 {
253                return Err(DecodeError::InvalidPadding {
254                    index: input_offset + 1,
255                });
256            }
257            if let Err(error) = const_ensure_output::<OUTPUT_CAP>(write, 1) {
258                return Err(error);
259            }
260            output[write] = (v0 << 2) | (v1 >> 4);
261            Ok(1)
262        }
263        (b'=', _) if PAD => Err(DecodeError::InvalidPadding {
264            index: input_offset + 2,
265        }),
266        (_, b'=') if PAD => {
267            let v2 = match const_decode_byte::<A>(b2, input_offset + 2) {
268                Ok(value) => value,
269                Err(error) => return Err(error),
270            };
271            if v2 & 0b0000_0011 != 0 {
272                return Err(DecodeError::InvalidPadding {
273                    index: input_offset + 2,
274                });
275            }
276            if let Err(error) = const_ensure_output::<OUTPUT_CAP>(write, 2) {
277                return Err(error);
278            }
279            output[write] = (v0 << 2) | (v1 >> 4);
280            output[write + 1] = (v1 << 4) | (v2 >> 2);
281            Ok(2)
282        }
283        (b'=', _) => Err(DecodeError::InvalidPadding {
284            index: input_offset + 2,
285        }),
286        (_, b'=') => Err(DecodeError::InvalidPadding {
287            index: input_offset + 3,
288        }),
289        _ => {
290            let v2 = match const_decode_byte::<A>(b2, input_offset + 2) {
291                Ok(value) => value,
292                Err(error) => return Err(error),
293            };
294            let v3 = match const_decode_byte::<A>(b3, input_offset + 3) {
295                Ok(value) => value,
296                Err(error) => return Err(error),
297            };
298            if let Err(error) = const_ensure_output::<OUTPUT_CAP>(write, 3) {
299                return Err(error);
300            }
301            output[write] = (v0 << 2) | (v1 >> 4);
302            output[write + 1] = (v1 << 4) | (v2 >> 2);
303            output[write + 2] = (v2 << 6) | v3;
304            Ok(3)
305        }
306    }
307}
308
309const fn const_ensure_output<const OUTPUT_CAP: usize>(
310    write: usize,
311    needed: usize,
312) -> Result<(), DecodeError> {
313    if write > OUTPUT_CAP || OUTPUT_CAP - write < needed {
314        let required = if write > usize::MAX - needed {
315            usize::MAX
316        } else {
317            write + needed
318        };
319        return Err(DecodeError::OutputTooSmall {
320            required,
321            available: OUTPUT_CAP,
322        });
323    }
324
325    Ok(())
326}
327
328const fn const_decode_byte<A: Alphabet>(byte: u8, index: usize) -> Result<u8, DecodeError> {
329    let mut candidate = 0u8;
330    while candidate < 64 {
331        if byte == A::ENCODE[candidate as usize] {
332            return Ok(candidate);
333        }
334        candidate += 1;
335    }
336
337    Err(DecodeError::InvalidByte { index, byte })
338}