base64easy/
lib.rs

1#![doc = include_str!("../README.md")]
2
/// The base64 encoding engine kind to use when encoding/decoding data.
#[derive(Debug, Clone, Copy)]
pub enum EngineKind {
    /// Standard alphabet (`+`, `/`) with `=` padding.
    Standard,
    /// Standard alphabet without padding.
    StandardNoPad,
    /// URL-safe alphabet (`-`, `_`) with `=` padding.
    UrlSafe,
    /// URL-safe alphabet without padding.
    UrlSafeNoPad,
}

impl EngineKind {
    /// Returns `true` if this engine kind appends `=` padding when encoding.
    pub fn padding(&self) -> bool {
        matches!(self, EngineKind::Standard | EngineKind::UrlSafe)
    }
}
26
/// Encode bytes to base64 string using the engine selected by `engine`.
#[cfg(feature = "alloc")]
pub fn encode<T: AsRef<[u8]>>(bytes: T, engine: EngineKind) -> String {
    use EngineKind::*;
    match engine {
        Standard => STANDARD.encode(bytes),
        StandardNoPad => STANDARD_NO_PAD.encode(bytes),
        UrlSafe => URL_SAFE.encode(bytes),
        UrlSafeNoPad => URL_SAFE_NO_PAD.encode(bytes),
    }
}
37
38/// Encode bytes to base64 string into a pre-allocated buffer.
39pub fn encode_slice<T: AsRef<[u8]>>(bytes: T, output_buf: &mut [u8], engine: EngineKind) -> Result<usize, Error> {
40    match engine {
41        EngineKind::Standard => STANDARD.encode_slice(bytes, output_buf),
42        EngineKind::StandardNoPad => STANDARD_NO_PAD.encode_slice(bytes, output_buf),
43        EngineKind::UrlSafe => URL_SAFE.encode_slice(bytes, output_buf),
44        EngineKind::UrlSafeNoPad => URL_SAFE_NO_PAD.encode_slice(bytes, output_buf),
45    }
46}
47
/// Decode base64 string to bytes using the engine selected by `engine`.
#[cfg(feature = "alloc")]
pub fn decode<T: AsRef<[u8]>>(b64str: T, engine: EngineKind) -> Result<Vec<u8>, Error> {
    use EngineKind::*;
    match engine {
        Standard => STANDARD.decode(b64str),
        StandardNoPad => STANDARD_NO_PAD.decode(b64str),
        UrlSafe => URL_SAFE.decode(b64str),
        UrlSafeNoPad => URL_SAFE_NO_PAD.decode(b64str),
    }
}
58
59/// Decode base64 string to bytes into a pre-allocated buffer.
60pub fn decode_slice<T: AsRef<[u8]>>(b64str: T, output: &mut [u8], engine: EngineKind) -> Result<usize, Error> {
61    match engine {
62        EngineKind::Standard => STANDARD.decode_slice(b64str, output),
63        EngineKind::StandardNoPad => STANDARD_NO_PAD.decode_slice(b64str, output),
64        EngineKind::UrlSafe => URL_SAFE.decode_slice(b64str, output),
65        EngineKind::UrlSafeNoPad => URL_SAFE_NO_PAD.decode_slice(b64str, output),
66    }
67}
68
/// Crate-internal encode/decode engine.
///
/// Implementors supply the `internal_*` primitives; the user-facing
/// `encode*`/`decode*` entry points are provided here as default methods.
pub(crate) trait Engine: Send + Sync {
    /// Engine configuration (padding behavior); see [`Config`].
    type Config: Config;
    /// Pre-decode buffer sizing; see [`DecodeEstimate`].
    type DecodeEstimate: DecodeEstimate;

    fn config(&self) -> &Self::Config;

    /// Encode `input` into `output` without padding; returns bytes written.
    fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize;
    /// Conservative decoded-size estimate for `input_len` encoded symbols.
    fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate;
    /// Decode `input` into `output`; returns metadata about the decoded data.
    fn internal_decode(&self, input: &[u8], output: &mut [u8], decode_estimate: Self::DecodeEstimate) -> Result<DecodeMetadata, Error>;

    /// Encode `input` into a freshly allocated `String`.
    #[cfg(feature = "alloc")]
    #[inline]
    fn encode<T: AsRef<[u8]>>(&self, input: T) -> String {
        // Inner fn is generic only over the engine, not over T, so the body is
        // not duplicated per input type.
        fn inner<E: Engine + ?Sized>(engine: &E, input_bytes: &[u8]) -> String {
            let encoded_size =
                encoded_len(input_bytes.len(), engine.config().encode_padding()).expect("integer overflow when calculating buffer size");

            let mut buf = vec![0; encoded_size];

            encode_with_padding(input_bytes, &mut buf[..], engine, encoded_size);

            // Encoded output is built only from alphabet symbols and '=', all ASCII.
            String::from_utf8(buf).expect("Invalid UTF8")
        }

        inner(self, input.as_ref())
    }

    /// Encode `input` into the start of `output_buf`; returns bytes written,
    /// or `Error::OutputSliceTooSmall` if the buffer can't hold the result.
    #[inline]
    fn encode_slice<T: AsRef<[u8]>>(&self, input: T, output_buf: &mut [u8]) -> Result<usize, Error> {
        fn inner<E>(engine: &E, input_bytes: &[u8], output_buf: &mut [u8]) -> Result<usize, Error>
        where
            E: Engine + ?Sized,
        {
            let encoded_size =
                encoded_len(input_bytes.len(), engine.config().encode_padding()).expect("usize overflow when calculating buffer size");

            if output_buf.len() < encoded_size {
                return Err(Error::OutputSliceTooSmall);
            }

            let b64_output = &mut output_buf[0..encoded_size];

            encode_with_padding(input_bytes, b64_output, engine, encoded_size);

            Ok(encoded_size)
        }

        inner(self, input.as_ref(), output_buf)
    }

    /// Decode `input` into a freshly allocated `Vec<u8>`.
    #[cfg(feature = "alloc")]
    #[inline]
    fn decode<T: AsRef<[u8]>>(&self, input: T) -> Result<Vec<u8>, Error> {
        fn inner<E: Engine + ?Sized>(engine: &E, input_bytes: &[u8]) -> Result<Vec<u8>, Error> {
            let estimate = engine.internal_decoded_len_estimate(input_bytes.len());
            let mut buffer = vec![0; estimate.decoded_len_estimate()];

            let bytes_written = engine.internal_decode(input_bytes, &mut buffer, estimate)?.decoded_len;

            // The estimate may overshoot (by up to 2 bytes); trim to actual size.
            buffer.truncate(bytes_written);

            Ok(buffer)
        }

        inner(self, input.as_ref())
    }

    /// Decode `input` into the start of `output`; returns decoded bytes written.
    #[inline]
    fn decode_slice<T: AsRef<[u8]>>(&self, input: T, output: &mut [u8]) -> Result<usize, Error> {
        fn inner<E: Engine + ?Sized>(eng: &E, input: &[u8], output: &mut [u8]) -> Result<usize, Error> {
            let decode_estimate = eng.internal_decoded_len_estimate(input.len());
            eng.internal_decode(input, output, decode_estimate).map(|dm| dm.decoded_len)
        }

        inner(self, input.as_ref(), output)
    }
}
146
/// Pre-decode sizing: supplies an upper bound on the decoded length of an input.
#[allow(dead_code)]
pub(crate) trait DecodeEstimate {
    /// Conservative (never too small) estimate of the decoded length.
    fn decoded_len_estimate(&self) -> usize;
}
151
/// Facts the decoder learned about the input while decoding it.
#[derive(PartialEq, Eq, Debug)]
pub(crate) struct DecodeMetadata {
    /// Number of decoded bytes output
    pub(crate) decoded_len: usize,
    /// Offset of the first padding byte in the input, if any
    pub(crate) padding_offset: Option<usize>,
}
159
160impl DecodeMetadata {
161    pub(crate) fn new(decoded_bytes: usize, padding_index: Option<usize>) -> Self {
162        Self {
163            decoded_len: decoded_bytes,
164            padding_offset: padding_index,
165        }
166    }
167}
168
/// Calculate the base64 encoded length for a given input length, optionally including any
/// appropriate padding bytes.
///
/// Returns `None` if the encoded length can't be represented in `usize`. This will happen for
/// input lengths in approximately the top quarter of the range of `usize`.
pub const fn encoded_len(bytes_len: usize, padding: bool) -> Option<usize> {
    // Each complete 3-byte input chunk produces 4 output symbols. `?` is not
    // usable in const fns on old toolchains, hence the explicit match.
    let complete_chunk_output = match (bytes_len / 3).checked_mul(4) {
        Some(n) => n,
        None => return None,
    };

    match bytes_len % 3 {
        0 => Some(complete_chunk_output),
        rem => {
            // A padded partial chunk always occupies a full quad; without
            // padding, 1 leftover byte encodes to 2 symbols and 2 bytes to 3.
            let tail = if padding {
                4
            } else if rem == 1 {
                2
            } else {
                3
            };
            complete_chunk_output.checked_add(tail)
        }
    }
}
202
203/// Returns a conservative estimate of the decoded size of `encoded_len` base64 symbols (rounded up
204/// to the next group of 3 decoded bytes).
205///
206/// The resulting length will be a safe choice for the size of a decode buffer, but may have up to
207/// 2 trailing bytes that won't end up being needed.
208///
209/// # Examples
210///
211/// ```
212/// use base64easy::decoded_len_estimate;
213///
214/// assert_eq!(3, decoded_len_estimate(1));
215/// assert_eq!(3, decoded_len_estimate(2));
216/// assert_eq!(3, decoded_len_estimate(3));
217/// assert_eq!(3, decoded_len_estimate(4));
218/// // start of the next quad of encoded symbols
219/// assert_eq!(6, decoded_len_estimate(5));
220/// ```
221pub fn decoded_len_estimate(encoded_len: usize) -> usize {
222    STANDARD.internal_decoded_len_estimate(encoded_len).decoded_len_estimate()
223}
224
225pub(crate) fn encode_with_padding<E: Engine + ?Sized>(input: &[u8], output: &mut [u8], engine: &E, expected_encoded_size: usize) {
226    debug_assert_eq!(expected_encoded_size, output.len());
227
228    let b64_bytes_written = engine.internal_encode(input, output);
229
230    let padding_bytes = if engine.config().encode_padding() {
231        add_padding(b64_bytes_written, &mut output[b64_bytes_written..])
232    } else {
233        0
234    };
235
236    let encoded_bytes = b64_bytes_written
237        .checked_add(padding_bytes)
238        .expect("usize overflow when calculating b64 length");
239
240    debug_assert_eq!(expected_encoded_size, encoded_bytes);
241}
242
/// The `=` byte used to pad encoded output up to a 4-symbol boundary.
pub(crate) const PAD_BYTE: u8 = b'=';

/// Write the `=` padding needed to round `unpadded_output_len` up to a
/// multiple of 4 into the start of `output`; returns how many pad bytes were
/// written (0, 1, 2, or 3).
pub(crate) fn add_padding(unpadded_output_len: usize, output: &mut [u8]) -> usize {
    let pad_bytes = match unpadded_output_len % 4 {
        0 => 0,
        rem => 4 - rem,
    };
    // for just a couple bytes, this has better performance than using
    // .fill(), or iterating over mutable refs, which call memset()
    #[allow(clippy::needless_range_loop)]
    for i in 0..pad_bytes {
        output[i] = PAD_BYTE;
    }

    pad_bytes
}
256
// One pre-built engine per `EngineKind` variant.
pub(crate) const STANDARD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, PAD);
pub(crate) const STANDARD_NO_PAD: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, NO_PAD);
pub(crate) const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, PAD);
pub(crate) const URL_SAFE_NO_PAD: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, NO_PAD);

// Shared configs: `PAD` pads on encode and requires canonical padding on
// decode (the `GeneralPurposeConfig::new` defaults); `NO_PAD` neither emits
// nor accepts padding.
pub(crate) const PAD: GeneralPurposeConfig = GeneralPurposeConfig::new();
pub(crate) const NO_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
    .with_encode_padding(false)
    .with_decode_padding_mode(DecodePaddingMode::RequireNone);
266
/// A general-purpose base64 engine built from an alphabet plus a config.
#[derive(Debug, Clone)]
pub(crate) struct GeneralPurpose {
    // 6-bit value -> printable symbol byte.
    encode_table: [u8; 64],
    // symbol byte -> 6-bit value, or INVALID_VALUE for non-alphabet bytes.
    decode_table: [u8; 256],
    config: GeneralPurposeConfig,
}
273
impl Engine for GeneralPurpose {
    type Config = GeneralPurposeConfig;
    type DecodeEstimate = GeneralPurposeEstimate;

    /// Table-driven encoder: an unrolled fast loop consuming 24 input bytes
    /// per iteration, followed by scalar handling of the remaining bytes.
    /// Returns the number of (unpadded) output bytes written.
    fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize {
        let mut input_index: usize = 0;

        const BLOCKS_PER_FAST_LOOP: usize = 4;
        const LOW_SIX_BITS: u64 = 0x3F;

        // we read 8 bytes at a time (u64) but only actually consume 6 of those bytes. Thus, we need
        // 2 trailing bytes to be available to read..
        let last_fast_index = input.len().saturating_sub(BLOCKS_PER_FAST_LOOP * 6 + 2);
        let mut output_index = 0;

        if last_fast_index > 0 {
            while input_index <= last_fast_index {
                // Major performance wins from letting the optimizer do the bounds check once, mostly
                // on the output side
                let input_chunk = &input[input_index..(input_index + (BLOCKS_PER_FAST_LOOP * 6 + 2))];
                let output_chunk = &mut output[output_index..(output_index + BLOCKS_PER_FAST_LOOP * 8)];

                // Hand-unrolling for 32 vs 16 or 8 bytes yields performance about equivalent
                // to unsafe pointer code on a Xeon E5-1650v3. 64 byte unrolling was slightly better for
                // large inputs but significantly worse for 50-byte input, unsurprisingly. I suspect
                // that it's a not uncommon use case to encode smallish chunks of data (e.g. a 64-byte
                // SHA-512 digest), so it would be nice if that fit in the unrolled loop at least once.
                // Plus, single-digit percentage performance differences might well be quite different
                // on different hardware.

                let input_u64 = read_u64(&input_chunk[0..]);

                output_chunk[0] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize];
                output_chunk[1] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize];
                output_chunk[2] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize];
                output_chunk[3] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize];
                output_chunk[4] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize];
                output_chunk[5] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize];
                output_chunk[6] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize];
                output_chunk[7] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize];

                // Each subsequent read starts 6 bytes further in: only the top
                // 48 bits of each u64 are consumed.
                let input_u64 = read_u64(&input_chunk[6..]);

                output_chunk[8] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize];
                output_chunk[9] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize];
                output_chunk[10] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize];
                output_chunk[11] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize];
                output_chunk[12] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize];
                output_chunk[13] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize];
                output_chunk[14] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize];
                output_chunk[15] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize];

                let input_u64 = read_u64(&input_chunk[12..]);

                output_chunk[16] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize];
                output_chunk[17] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize];
                output_chunk[18] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize];
                output_chunk[19] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize];
                output_chunk[20] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize];
                output_chunk[21] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize];
                output_chunk[22] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize];
                output_chunk[23] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize];

                let input_u64 = read_u64(&input_chunk[18..]);

                output_chunk[24] = self.encode_table[((input_u64 >> 58) & LOW_SIX_BITS) as usize];
                output_chunk[25] = self.encode_table[((input_u64 >> 52) & LOW_SIX_BITS) as usize];
                output_chunk[26] = self.encode_table[((input_u64 >> 46) & LOW_SIX_BITS) as usize];
                output_chunk[27] = self.encode_table[((input_u64 >> 40) & LOW_SIX_BITS) as usize];
                output_chunk[28] = self.encode_table[((input_u64 >> 34) & LOW_SIX_BITS) as usize];
                output_chunk[29] = self.encode_table[((input_u64 >> 28) & LOW_SIX_BITS) as usize];
                output_chunk[30] = self.encode_table[((input_u64 >> 22) & LOW_SIX_BITS) as usize];
                output_chunk[31] = self.encode_table[((input_u64 >> 16) & LOW_SIX_BITS) as usize];

                output_index += BLOCKS_PER_FAST_LOOP * 8;
                input_index += BLOCKS_PER_FAST_LOOP * 6;
            }
        }

        // Encode what's left after the fast loop.

        const LOW_SIX_BITS_U8: u8 = 0x3F;

        let rem = input.len() % 3;
        let start_of_rem = input.len() - rem;

        // start at the first index not handled by fast loop, which may be 0.

        // One complete 3-byte chunk -> 4 symbols per iteration.
        while input_index < start_of_rem {
            let input_chunk = &input[input_index..(input_index + 3)];
            let output_chunk = &mut output[output_index..(output_index + 4)];

            output_chunk[0] = self.encode_table[(input_chunk[0] >> 2) as usize];
            output_chunk[1] = self.encode_table[((input_chunk[0] << 4 | input_chunk[1] >> 4) & LOW_SIX_BITS_U8) as usize];
            output_chunk[2] = self.encode_table[((input_chunk[1] << 2 | input_chunk[2] >> 6) & LOW_SIX_BITS_U8) as usize];
            output_chunk[3] = self.encode_table[(input_chunk[2] & LOW_SIX_BITS_U8) as usize];

            input_index += 3;
            output_index += 4;
        }

        // Final partial chunk: 2 leftover bytes -> 3 symbols, 1 -> 2 symbols.
        // Padding, if any, is added separately by the caller.
        if rem == 2 {
            output[output_index] = self.encode_table[(input[start_of_rem] >> 2) as usize];
            output[output_index + 1] =
                self.encode_table[((input[start_of_rem] << 4 | input[start_of_rem + 1] >> 4) & LOW_SIX_BITS_U8) as usize];
            output[output_index + 2] = self.encode_table[((input[start_of_rem + 1] << 2) & LOW_SIX_BITS_U8) as usize];
            output_index += 3;
        } else if rem == 1 {
            output[output_index] = self.encode_table[(input[start_of_rem] >> 2) as usize];
            output[output_index + 1] = self.encode_table[((input[start_of_rem] << 4) & LOW_SIX_BITS_U8) as usize];
            output_index += 2;
        }

        output_index
    }

    fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate {
        GeneralPurposeEstimate::new(input_len)
    }

    /// Delegate decoding to the shared helper using this engine's decode
    /// table and config.
    fn internal_decode(&self, input: &[u8], output: &mut [u8], estimate: Self::DecodeEstimate) -> Result<DecodeMetadata, Error> {
        decode::decode_helper(
            input,
            estimate,
            output,
            &self.decode_table,
            self.config.decode_allow_trailing_bits,
            self.config.decode_padding_mode,
        )
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}
409
/// Big-endian read of the first 8 bytes of `s`; callers must pass `s.len() >= 8`.
#[inline]
pub(crate) fn read_u64(s: &[u8]) -> u64 {
    let mut buf = [0_u8; 8];
    buf.copy_from_slice(&s[..8]);
    u64::from_be_bytes(buf)
}
414
415impl GeneralPurpose {
416    pub(crate) const fn new(alphabet: &alphabet::Alphabet, config: GeneralPurposeConfig) -> Self {
417        Self {
418            encode_table: encode_table(alphabet),
419            decode_table: decode_table(alphabet),
420            config,
421        }
422    }
423}
424
425pub(crate) const fn encode_table(alphabet: &alphabet::Alphabet) -> [u8; 64] {
426    // the encode table is just the alphabet:
427    // 6-bit index lookup -> printable byte
428    let mut encode_table = [0_u8; 64];
429    {
430        let mut index = 0;
431        while index < 64 {
432            encode_table[index] = alphabet.symbols[index];
433            index += 1;
434        }
435    }
436
437    encode_table
438}
439
440pub(crate) const INVALID_VALUE: u8 = 255;
441
442pub(crate) const fn decode_table(alphabet: &alphabet::Alphabet) -> [u8; 256] {
443    let mut decode_table = [INVALID_VALUE; 256];
444
445    // Since the table is full of `INVALID_VALUE` already, we only need to overwrite
446    // the parts that are valid.
447    let mut index = 0;
448    while index < 64 {
449        // The index in the alphabet is the 6-bit value we care about.
450        // Since the index is in 0-63, it is safe to cast to u8.
451        decode_table[alphabet.symbols[index] as usize] = index as u8;
452        index += 1;
453    }
454
455    decode_table
456}
457
/// The engine configuration surface needed by the shared encode helpers.
pub(crate) trait Config {
    /// Returns `true` if `=` padding should be appended after encoding.
    fn encode_padding(&self) -> bool;
}
461
/// Configuration for a [`GeneralPurpose`] engine.
#[derive(Clone, Copy, Debug)]
pub(crate) struct GeneralPurposeConfig {
    // Whether `=` padding is appended when encoding.
    encode_padding: bool,
    // Forwarded to the decode path (`decode_helper`); per its name, controls
    // tolerance of set trailing bits in the final symbol.
    decode_allow_trailing_bits: bool,
    // How strictly `=` padding is checked when decoding.
    decode_padding_mode: DecodePaddingMode,
}
468
469impl GeneralPurposeConfig {
470    pub(crate) const fn new() -> Self {
471        Self {
472            encode_padding: true,
473            decode_allow_trailing_bits: false,
474            decode_padding_mode: DecodePaddingMode::RequireCanonical,
475        }
476    }
477
478    pub(crate) const fn with_encode_padding(self, padding: bool) -> Self {
479        Self {
480            encode_padding: padding,
481            ..self
482        }
483    }
484
485    pub(crate) const fn with_decode_padding_mode(self, mode: DecodePaddingMode) -> Self {
486        Self {
487            decode_padding_mode: mode,
488            ..self
489        }
490    }
491}
492
493impl Default for GeneralPurposeConfig {
494    fn default() -> Self {
495        Self::new()
496    }
497}
498
impl Config for GeneralPurposeConfig {
    // Simple field accessor satisfying the `Config` trait.
    fn encode_padding(&self) -> bool {
        self.encode_padding
    }
}
504
/// How the presence or absence of `=` padding is treated during decode.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum DecodePaddingMode {
    /// Accept the input regardless of its padding (currently unused).
    #[allow(dead_code)]
    Indifferent,
    /// Padded input is required (used by the padded engine configs).
    RequireCanonical,
    /// Padding must be absent (used by the `NO_PAD` config).
    RequireNone,
}
512
/// Conservative decoded-size estimate used to pre-size decode buffers.
pub(crate) struct GeneralPurposeEstimate {
    /// input len % 4
    rem: usize,
    /// Decoded length rounded up to a whole number of 3-byte groups.
    #[allow(dead_code)]
    conservative_decoded_len: usize,
}
519
520impl GeneralPurposeEstimate {
521    pub(crate) fn new(encoded_len: usize) -> Self {
522        let rem = encoded_len % 4;
523        Self {
524            rem,
525            conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
526        }
527    }
528}
529
impl DecodeEstimate for GeneralPurposeEstimate {
    // Returns the precomputed, rounded-up decoded length.
    fn decoded_len_estimate(&self) -> usize {
        self.conservative_decoded_len
    }
}
535
/// Base64 alphabets: the ordered 64-symbol sets used by each engine variant.
pub(crate) mod alphabet {

    /// An ordered set of 64 symbols; the index of a symbol is the 6-bit value
    /// it encodes.
    #[derive(Clone, Debug, Eq, PartialEq)]
    pub(crate) struct Alphabet {
        pub(crate) symbols: [u8; ALPHABET_SIZE],
    }

    impl Alphabet {
        /// Build an alphabet from a 64-character ASCII string without
        /// validating it (callers pass known-good constants below).
        const fn from_str_unchecked(alphabet: &str) -> Self {
            let source_bytes = alphabet.as_bytes();
            let mut symbols = [0_u8; ALPHABET_SIZE];

            let mut i = 0;
            while i < ALPHABET_SIZE {
                symbols[i] = source_bytes[i];
                i += 1;
            }

            Self { symbols }
        }
    }

    /// Every base64 alphabet has exactly 64 symbols.
    pub(crate) const ALPHABET_SIZE: usize = 64;

    /// Standard alphabet (`+` and `/` as the last two symbols).
    pub(crate) const STANDARD: Alphabet = Alphabet::from_str_unchecked("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
    /// URL-safe alphabet (`-` and `_` as the last two symbols).
    pub(crate) const URL_SAFE: Alphabet = Alphabet::from_str_unchecked("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
}
562
563pub(crate) mod decode {
564    use super::*;
565
    /// Decode `input` into `output` using `decode_table`.
    ///
    /// Complete quads are processed first — in unrolled 32-byte chunks, then
    /// in single 4-byte quads — and the final (possibly padded or partial)
    /// quad is handed to `decode_suffix`, which produces the metadata.
    ///
    /// `estimate.rem` must be `input.len() % 4` (it is, when built by
    /// `internal_decoded_len_estimate`).
    #[inline]
    pub(crate) fn decode_helper(
        input: &[u8],
        estimate: GeneralPurposeEstimate,
        output: &mut [u8],
        decode_table: &[u8; 256],
        decode_allow_trailing_bits: bool,
        padding_mode: DecodePaddingMode,
    ) -> Result<DecodeMetadata, Error> {
        // Also validates that `output` can hold everything but the last quad.
        let input_complete_nonterminal_quads_len = complete_quads_len(input, estimate.rem, output.len(), decode_table)?;

        const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
        const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;

        let input_complete_quads_after_unrolled_chunks_len = input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;

        let input_unrolled_loop_len = input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;

        // chunks of 32 bytes
        for (chunk_index, chunk) in input[..input_unrolled_loop_len].chunks_exact(UNROLLED_INPUT_CHUNK_SIZE).enumerate() {
            let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
            let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];

            decode_chunk_8(&chunk[0..8], input_index, decode_table, &mut chunk_output[0..6])?;
            decode_chunk_8(&chunk[8..16], input_index + 8, decode_table, &mut chunk_output[6..12])?;
            decode_chunk_8(&chunk[16..24], input_index + 16, decode_table, &mut chunk_output[12..18])?;
            decode_chunk_8(&chunk[24..32], input_index + 24, decode_table, &mut chunk_output[18..24])?;
        }

        // remaining quads, except for the last possibly partial one, as it may have padding
        let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
        let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
        {
            let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];

            for (chunk_index, chunk) in input[input_unrolled_loop_len..input_complete_nonterminal_quads_len]
                .chunks_exact(4)
                .enumerate()
            {
                let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];

                decode_chunk_4(chunk, input_unrolled_loop_len + chunk_index * 4, decode_table, chunk_output)?;
            }
        }

        // The last quad (1..=4 bytes) may contain '=' padding; it also
        // determines the final decoded length and padding offset.
        decode_suffix(
            input,
            input_complete_nonterminal_quads_len,
            output,
            output_complete_quad_len,
            decode_table,
            decode_allow_trailing_bits,
            padding_mode,
        )
    }
621
    /// Compute how many leading input bytes form complete, non-terminal quads.
    ///
    /// The final quad is always excluded — even when the input length is a
    /// multiple of 4 — because it may contain `=` padding that the fast decode
    /// paths don't handle. `input_len_rem` must equal `input.len() % 4`.
    /// Errors on a lone trailing invalid byte or an undersized output.
    pub(crate) fn complete_quads_len(
        input: &[u8],
        input_len_rem: usize,
        output_len: usize,
        decode_table: &[u8; 256],
    ) -> Result<usize, Error> {
        debug_assert!(input.len() % 4 == input_len_rem);

        // detect a trailing invalid byte, like a newline, as a user convenience
        if input_len_rem == 1 {
            let last_byte = input[input.len() - 1];
            // exclude pad bytes; might be part of padding that extends from earlier in the input
            if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
                return Err(Error::InvalidByte(input.len() - 1, last_byte));
            }
        };

        // skip last quad, even if it's complete, as it may have padding
        let input_complete_nonterminal_quads_len = input
            .len()
            .saturating_sub(input_len_rem)
            // if rem was 0, subtract 4 to avoid padding
            .saturating_sub((input_len_rem == 0) as usize * 4);
        debug_assert!(input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)));

        // check that everything except the last quad handled by decode_suffix will fit
        if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
            return Err(Error::OutputSliceTooSmall);
        };
        Ok(input_complete_nonterminal_quads_len)
    }
653
    /// Decode 8 symbols from `input` into 6 bytes at the start of `output`.
    ///
    /// Each 6-bit morsel is accumulated into the top 48 bits of a `u64`
    /// (shifts 58, 52, ..., 16), then written out big-endian. Deliberately
    /// hand-unrolled: this is the hot path for bulk decoding. Error positions
    /// are reported relative to `index_at_start_of_input`.
    #[inline(always)]
    pub(crate) fn decode_chunk_8(
        input: &[u8],
        index_at_start_of_input: usize,
        decode_table: &[u8; 256],
        output: &mut [u8],
    ) -> Result<(), Error> {
        let morsel = decode_table[usize::from(input[0])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input, input[0]));
        }
        let mut accum = u64::from(morsel) << 58;

        let morsel = decode_table[usize::from(input[1])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 1, input[1]));
        }
        accum |= u64::from(morsel) << 52;

        let morsel = decode_table[usize::from(input[2])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 2, input[2]));
        }
        accum |= u64::from(morsel) << 46;

        let morsel = decode_table[usize::from(input[3])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 3, input[3]));
        }
        accum |= u64::from(morsel) << 40;

        let morsel = decode_table[usize::from(input[4])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 4, input[4]));
        }
        accum |= u64::from(morsel) << 34;

        let morsel = decode_table[usize::from(input[5])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 5, input[5]));
        }
        accum |= u64::from(morsel) << 28;

        let morsel = decode_table[usize::from(input[6])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 6, input[6]));
        }
        accum |= u64::from(morsel) << 22;

        let morsel = decode_table[usize::from(input[7])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 7, input[7]));
        }
        accum |= u64::from(morsel) << 16;

        // Only the top 6 bytes of the accumulator carry decoded data.
        output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);

        Ok(())
    }
713
    /// Decode 4 symbols from `input` into 3 bytes at the start of `output`.
    ///
    /// The `u32` analogue of `decode_chunk_8`: morsels fill the top 24 bits
    /// (shifts 26, 20, 14, 8) and are written out big-endian. Error positions
    /// are reported relative to `index_at_start_of_input`.
    #[inline(always)]
    pub(crate) fn decode_chunk_4(
        input: &[u8],
        index_at_start_of_input: usize,
        decode_table: &[u8; 256],
        output: &mut [u8],
    ) -> Result<(), Error> {
        let morsel = decode_table[usize::from(input[0])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input, input[0]));
        }
        let mut accum = u32::from(morsel) << 26;

        let morsel = decode_table[usize::from(input[1])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 1, input[1]));
        }
        accum |= u32::from(morsel) << 20;

        let morsel = decode_table[usize::from(input[2])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 2, input[2]));
        }
        accum |= u32::from(morsel) << 14;

        let morsel = decode_table[usize::from(input[3])];
        if morsel == INVALID_VALUE {
            return Err(Error::InvalidByte(index_at_start_of_input + 3, input[3]));
        }
        accum |= u32::from(morsel) << 8;

        // Only the top 3 bytes of the accumulator carry decoded data.
        output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);

        Ok(())
    }
749
    /// Decode the trailing, possibly-partial quad of `input` — the at-most-4 bytes starting
    /// at `input_index` — into `output` starting at `output_index`.
    ///
    /// Validates `=` padding placement (padding may only occupy the last 1-2 positions of the
    /// quad), checks the outcome against `padding_mode`, and, unless
    /// `decode_allow_trailing_bits` is set, rejects a final symbol whose low bits would be
    /// discarded (a non-canonical encoding) with [Error::InvalidLastSymbol].
    ///
    /// On success returns the total output length written so far and the offset of the first
    /// padding byte (if any) wrapped in a [DecodeMetadata]; otherwise the first error
    /// encountered, with byte indices relative to the full `input`.
    ///
    /// NOTE(review): relies on callers passing a nonempty leftover chunk whenever `input`
    /// itself is nonempty; with `input_index == input.len()` and nonempty `input`, the
    /// `InvalidLength` check below would fire spuriously — confirm against call sites.
    pub(crate) fn decode_suffix(
        input: &[u8],
        input_index: usize,
        output: &mut [u8],
        mut output_index: usize,
        decode_table: &[u8; 256],
        decode_allow_trailing_bits: bool,
        padding_mode: DecodePaddingMode,
    ) -> Result<DecodeMetadata, Error> {
        debug_assert!((input.len() - input_index) <= 4);

        // Decode any leftovers that might not be a complete input chunk of 4 bytes.
        // Use a u32 as a stack-resident 4 byte buffer.
        let mut morsels_in_leftover = 0;
        let mut padding_bytes_count = 0;
        // offset from input_index
        let mut first_padding_offset: usize = 0;
        // most recent non-padding symbol, kept only for InvalidLastSymbol reporting
        let mut last_symbol = 0_u8;
        let mut morsels = [0_u8; 4];

        for (leftover_index, &b) in input[input_index..].iter().enumerate() {
            // '=' padding
            if b == PAD_BYTE {
                // There can be bad padding bytes in a few ways:
                // 1 - Padding with non-padding characters after it
                // 2 - Padding after zero or one characters in the current quad (should only
                //     be after 2 or 3 chars)
                // 3 - More than two characters of padding. If 3 or 4 padding chars
                //     are in the same quad, that implies it will be caught by #2.
                //     If it spreads from one quad to another, it will be an invalid byte
                //     in the first quad.
                // 4 - Non-canonical padding -- 1 byte when it should be 2, etc.
                //     Per config, non-canonical but still functional non- or partially-padded base64
                //     may be treated as an error condition.

                if leftover_index < 2 {
                    // Check for error #2.
                    // Either the previous byte was padding, in which case we would have already hit
                    // this case, or it wasn't, in which case this is the first such error.
                    debug_assert!(leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0));
                    let bad_padding_index = input_index + leftover_index;
                    return Err(Error::InvalidByte(bad_padding_index, b));
                }

                if padding_bytes_count == 0 {
                    first_padding_offset = leftover_index;
                }

                padding_bytes_count += 1;
                continue;
            }

            // Check for case #1.
            // To make '=' handling consistent with the main loop, don't allow
            // non-suffix '=' in trailing chunk either. Report error as first
            // erroneous padding.
            if padding_bytes_count > 0 {
                return Err(Error::InvalidByte(input_index + first_padding_offset, PAD_BYTE));
            }

            last_symbol = b;

            // can use up to 4 * 6 = 24 bits of the u32, if last chunk has no padding.
            // Pack the leftovers from left to right.
            let morsel = decode_table[b as usize];
            if morsel == INVALID_VALUE {
                return Err(Error::InvalidByte(input_index + leftover_index, b));
            }

            morsels[morsels_in_leftover] = morsel;
            morsels_in_leftover += 1;
        }

        // If there was 1 trailing byte, and it was valid, and we got to this point without hitting
        // an invalid byte, now we can report invalid length
        if !input.is_empty() && morsels_in_leftover < 2 {
            return Err(Error::InvalidLength(input_index + morsels_in_leftover));
        }

        match padding_mode {
            DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ }
            DecodePaddingMode::RequireCanonical => {
                // allow empty input
                if (padding_bytes_count + morsels_in_leftover) % 4 != 0 {
                    return Err(Error::InvalidPadding);
                }
            }
            DecodePaddingMode::RequireNone => {
                if padding_bytes_count > 0 {
                    // check at the end to make sure we let the cases of padding that should be InvalidByte get hit
                    return Err(Error::InvalidPadding);
                }
            }
        }

        // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed.
        // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits
        // of bottom 6 bits set).
        // When decoding two symbols back to one trailing byte, any final symbol higher than
        // w would still decode to the original byte because we only care about the top two
        // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a
        // mask based on how many bits are used for just the canonical encoding, and optionally
        // error if any other bits are set. In the example of one encoded byte -> 2 symbols,
        // 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and
        // useless since there are no more symbols to provide the necessary 4 additional bits
        // to finish the second original byte.

        let leftover_bytes_to_append = morsels_in_leftover * 6 / 8;
        // Put the up to 3 complete bytes as the high bytes of the u32; unused morsels are 0,
        // so they contribute nothing. The low 8 bits are always left clear.
        let mut leftover_num =
            (u32::from(morsels[0]) << 26) | (u32::from(morsels[1]) << 20) | (u32::from(morsels[2]) << 14) | (u32::from(morsels[3]) << 8);

        // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
        // will not be included in the output
        // (mask covers every bit below the `leftover_bytes_to_append` output bytes; shift is
        // at most 24, so it cannot overflow the u32)
        let mask = !0_u32 >> (leftover_bytes_to_append * 8);
        if !decode_allow_trailing_bits && (leftover_num & mask) != 0 {
            // last morsel is at `morsels_in_leftover` - 1
            return Err(Error::InvalidLastSymbol(input_index + morsels_in_leftover - 1, last_symbol));
        }

        // Strangely, this approach benchmarks better than writing bytes one at a time,
        // or copy_from_slice into output.
        for _ in 0..leftover_bytes_to_append {
            let hi_byte = (leftover_num >> 24) as u8;
            leftover_num <<= 8;
            *output.get_mut(output_index).ok_or(Error::OutputSliceTooSmall)? = hi_byte;
            output_index += 1;
        }

        let padding_index = if padding_bytes_count > 0 {
            Some(input_index + first_padding_offset)
        } else {
            None
        };
        Ok(DecodeMetadata::new(output_index, padding_index))
    }
887}
888
889pub use error::Error;
890
pub(crate) mod error {
    /// Everything that can go wrong while encoding or decoding base64.
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub enum Error {
        /// The input contained a byte outside the alphabet; carries the byte's offset and
        /// its value.
        ///
        /// A `=` that is not part of the final 0-2 bytes of padding also produces this
        /// error, as do extraneous trailing bytes that push otherwise-valid padding away
        /// from the end of the input.
        InvalidByte(usize, u8),

        /// Counting only valid base64 symbols, the input length is impossible: the final
        /// quad must hold 2, 3, or 4 symbols.
        InvalidLength(usize),

        /// The final non-padding symbol carries low-order bits that decoding would have to
        /// discard — a sign of truncated or corrupted base64. Unlike [Error::InvalidByte],
        /// the symbol itself is in the alphabet; it just encodes a value no canonical
        /// encoder would produce.
        InvalidLastSymbol(usize, u8),

        /// Padding did not match the configured policy: absent or wrong where canonical
        /// padding is required, present where none is allowed, and so on.
        InvalidPadding,

        /// The caller-supplied output slice cannot hold the result.
        OutputSliceTooSmall,
    }

    impl core::fmt::Display for Error {
        fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
            // Payloads are Copy, so match by value; constant messages skip the formatting
            // machinery via write_str.
            match *self {
                Self::InvalidByte(index, byte) => write!(f, "Invalid symbol {byte}, offset {index}."),
                Self::InvalidLastSymbol(index, byte) => write!(f, "Invalid last symbol {byte}, offset {index}."),
                Self::InvalidLength(len) => write!(f, "Invalid input length: {len}"),
                Self::InvalidPadding => f.write_str("Invalid padding"),
                Self::OutputSliceTooSmall => f.write_str("Output slice too small"),
            }
        }
    }

    #[cfg(feature = "std")]
    impl std::error::Error for Error {}
}