// crabmole/encoding/base64.rs

/// Error returned when constructing or reconfiguring a [`Base64`] encoding.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
    /// The alphabet passed to [`Base64::new`] was rejected; see that
    /// function's documentation for the requirements an alphabet must meet.
    InvalidEncoder,

    /// The padding character passed to [`Base64::with_padding`] was rejected.
    InvalidPadding,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The alphabet is a `[u8; 64]`, so its length can never be wrong; the
        // message therefore describes the alphabet itself, not a length rule.
        let msg = match self {
            Error::InvalidEncoder => "Invalid base64 alphabet",
            Error::InvalidPadding => "Invalid padding character",
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}

/// Error returned when decoding malformed base64 input.
///
/// Carries the byte offset into the input that the decoder reported as the
/// location of the problem.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DecodeError(usize);

impl DecodeError {
    /// Returns the offset into the input at which illegal data was detected.
    #[inline]
    pub const fn into_inner(self) -> usize {
        self.0
    }
}

impl core::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "illegal base64 data at input byte {}", self.0)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DecodeError {}

44/// `BASE = 64`
45pub const BASE: usize = 64;
46
47const ENCODE_STD: [u8; BASE] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
48const ENCODE_URL: [u8; BASE] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
49
50/// No padding
51pub const NO_PADDING: Option<char> = None;
52
53/// Standard padding
54pub const STD_PADDING: Option<char> = Some('=');
55
56/// The standard base64 encoding, as defined in
57/// RFC 4648.
58pub const STD_ENCODING: Base64 = Base64::new_unchecked(ENCODE_STD);
59
60/// The alternate base64 encoding defined in RFC 4648.
61/// It is typically used in URLs and file names.
62pub const URL_ENCODING: Base64 = Base64::new_unchecked(ENCODE_URL);
63
64/// The standard raw, unpadded base64 encoding,
65/// as defined in RFC 4648 section 3.2.
66/// This is the same as [`STD_ENCODING`] but omits padding characters.
67pub const RAW_STD_ENCODING: Base64 = Base64::new_unchecked(ENCODE_STD).with_padding_unchecked(None);
68
69/// The unpadded alternate base64 encoding defined in RFC 4648.
70/// It is typically used in URLs and file names.
71/// This is the same as [`URL_ENCODING`] but omits padding characters.
72pub const RAW_URL_ENCODING: Base64 = Base64::new_unchecked(ENCODE_URL).with_padding_unchecked(None);
73
74const DECODE_MAP_INITIALIZE: [u8; 256] = [255; 256];
75
76/// An Base64 is a radix 64 encoding/decoding scheme, defined by a
77/// 64-character alphabet. The most common encoding is the "base64"
78/// encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
79/// (RFC 1421).  RFC 4648 also defines an alternate encoding, which is
80/// the standard encoding with - and _ substituted for + and /.
81#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
82pub struct Base64 {
83    encode: [u8; BASE],
84    decode_map: [u8; 256],
85    pad_char: Option<char>,
86    strict: bool,
87}
88
89impl Default for Base64 {
90    fn default() -> Self {
91        STD_ENCODING
92    }
93}
94
95impl Base64 {
96    /// Returns a new padded Base64 defined by the given alphabet,
97    /// which must be a 64-byte array that does not contain the padding character
98    /// or CR / LF ('\r', '\n').
99    /// The resulting Base64 uses the default padding character ('='),
100    /// which may be changed or disabled via [`Base64::with_padding`].
101    #[inline]
102    pub const fn new(encoder: [u8; BASE]) -> Result<Self, Error> {
103        const CH: char = '=';
104        let mut decode_map = DECODE_MAP_INITIALIZE;
105        let mut idx = 0;
106        while idx < BASE {
107            if encoder[idx] == b'\n' || encoder[idx] == b'\r' || encoder[idx] == CH as u8 {
108                return Err(Error::InvalidEncoder);
109            }
110            decode_map[encoder[idx] as usize] = idx as u8;
111            idx += 1;
112        }
113
114        Ok(Self {
115            encode: encoder,
116            decode_map,
117            pad_char: Some('='),
118            strict: false,
119        })
120    }
121
122    /// Returns a new padded Base64 defined by the given alphabet,
123    /// which must be a 64-byte array that does not contain the padding character
124    /// or CR / LF ('\r', '\n').
125    /// The resulting Base64 uses the default padding character ('='),
126    /// which may be changed or disabled via [`Base64::with_padding_unchecked`].
127    ///
128    /// # Panic
129    /// 64-byte array that contains the padding character ('=')
130    /// or CR / LF ('\r', '\n').
131    #[inline]
132    pub const fn new_unchecked(encoder: [u8; BASE]) -> Self {
133        const CH: char = '=';
134        let mut decode_map = DECODE_MAP_INITIALIZE;
135        let mut idx = 0;
136        while idx < BASE {
137            if encoder[idx] == b'\n' || encoder[idx] == b'\r' || encoder[idx] == CH as u8 {
138                panic!("encoding alphabet contains newline character or padding character");
139            }
140            decode_map[encoder[idx] as usize] = idx as u8;
141            idx += 1;
142        }
143
144        Self {
145            encode: encoder,
146            decode_map,
147            pad_char: Some('='),
148            strict: false,
149        }
150    }
151
152    /// Creates a new encoding identical to enc except
153    /// with a specified padding character, or NoPadding to disable padding.
154    /// The padding character must not be '\r' or '\n', must not
155    /// be contained in the encoding's alphabet and must be a rune equal or
156    /// below '\xff'.
157    #[inline]
158    pub const fn with_padding(self, pad: Option<char>) -> Result<Self, Error> {
159        let Self {
160            encode: encoder,
161            mut decode_map,
162            pad_char: _,
163            strict,
164        } = self;
165
166        match pad {
167            Some(ch) => {
168                let mut idx = 0;
169                while idx < BASE {
170                    if encoder[idx] == b'\n' || encoder[idx] == b'\r' || encoder[idx] == ch as u8 {
171                        return Err(Error::InvalidPadding);
172                    }
173                    decode_map[encoder[idx] as usize] = idx as u8;
174                    idx += 1;
175                }
176            }
177            None => {
178                let mut idx = 0;
179                while idx < BASE {
180                    if encoder[idx] == b'\n' || encoder[idx] == b'\r' {
181                        return Err(Error::InvalidPadding);
182                    }
183                    decode_map[encoder[idx] as usize] = idx as u8;
184                    idx += 1;
185                }
186            }
187        }
188
189        Ok(Self {
190            encode: encoder,
191            decode_map,
192            pad_char: pad,
193            strict,
194        })
195    }
196
197    /// Creates a new encoding identical to enc except
198    /// with a specified padding character, or [`NO_PADDING`] to disable padding.
199    /// The padding character must not be '\r' or '\n', must not
200    /// be contained in the encoding's alphabet and must be a rune equal or
201    /// below '\xff'.
202    ///
203    /// # Panic
204    /// 64-byte array that contains the padding character
205    /// or CR / LF ('\r', '\n').
206    #[inline]
207    pub const fn with_padding_unchecked(self, pad: Option<char>) -> Self {
208        let Self {
209            encode: encoder,
210            mut decode_map,
211            pad_char: _,
212            strict,
213        } = self;
214
215        match pad {
216            Some(ch) => {
217                let mut idx = 0;
218                while idx < BASE {
219                    if encoder[idx] == b'\n' || encoder[idx] == b'\r' || encoder[idx] == ch as u8 {
220                        panic!("encoding alphabet contains newline character or padding character");
221                    }
222                    decode_map[encoder[idx] as usize] = idx as u8;
223                    idx += 1;
224                }
225            }
226            None => {
227                let mut idx = 0;
228                while idx < BASE {
229                    if encoder[idx] == b'\n' || encoder[idx] == b'\r' {
230                        panic!("encoding alphabet contains newline character or padding character");
231                    }
232                    decode_map[encoder[idx] as usize] = idx as u8;
233                    idx += 1;
234                }
235            }
236        }
237
238        Self {
239            encode: encoder,
240            decode_map,
241            pad_char: pad,
242            strict,
243        }
244    }
245
246    /// Creates a new encoding identical to enc except with
247    /// strict decoding enabled. In this mode, the decoder requires that
248    /// trailing padding bits are zero, as described in RFC 4648 section 3.5.
249    ///
250    /// Note that the input is still malleable, as new line characters
251    /// (CR and LF) are still ignored.
252    #[inline]
253    pub const fn with_strict(mut self) -> Self {
254        self.strict = true;
255        self
256    }
257
258    /// Returns the length in bytes of the base64 encoding
259    /// of an input buffer of length n.
260    #[inline]
261    pub const fn encoded_len(&self, n: usize) -> usize {
262        if self.pad_char.is_none() {
263            return (n * 8 + 5) / 6;
264        }
265        (n + 2) / 3 * 4
266    }
267
268    /// Returns a base64 encoder.
269    #[inline]
270    pub const fn encoder<W: std::io::Write>(self, w: W) -> Encoder<W> {
271        Encoder::new(self, w)
272    }
273
274    /// Encodes src using the encoding enc, writing
275    /// EncodedLen(len(src)) bytes to dst.
276    ///
277    /// The encoding pads the output to a multiple of 4 bytes,
278    /// so Encode is not appropriate for use on individual blocks
279    /// of a large data stream. Use NewEncoder() instead.
280    pub fn encode(&self, src: &[u8], dst: &mut [u8]) {
281        if src.is_empty() {
282            return;
283        }
284
285        let (mut di, mut si) = (0, 0);
286        let n = (src.len() / 3) * 3;
287        while si < n {
288            // Convert 3x 8bit source bytes into 4 bytes
289            let val =
290                ((src[si] as usize) << 16) | ((src[si + 1] as usize) << 8) | (src[si + 2] as usize);
291
292            dst[di] = self.encode[(val >> 18) & 0x3f];
293            dst[di + 1] = self.encode[(val >> 12) & 0x3f];
294            dst[di + 2] = self.encode[(val >> 6) & 0x3f];
295            dst[di + 3] = self.encode[val & 0x3f];
296
297            si += 3;
298            di += 4;
299        }
300
301        let remain = src.len() - si;
302        if remain == 0 {
303            return;
304        }
305
306        // Add the remaining small block
307        let mut val = (src[si] as usize) << 16;
308        if remain == 2 {
309            val |= (src[si + 1] as usize) << 8;
310        }
311
312        dst[di] = self.encode[(val >> 18) & 0x3f];
313        dst[di + 1] = self.encode[(val >> 12) & 0x3f];
314
315        match remain {
316            2 => {
317                dst[di + 2] = self.encode[(val >> 6) & 0x3f];
318                if let Some(ch) = self.pad_char {
319                    dst[di + 3] = ch as u8;
320                }
321            }
322            1 => {
323                if let Some(ch) = self.pad_char {
324                    dst[di + 2] = ch as u8;
325                    dst[di + 3] = ch as u8;
326                }
327            }
328            _ => {}
329        }
330    }
331
332    /// Returns the base64 encoding of src.
333    #[cfg(feature = "alloc")]
334    pub fn encode_to_vec(&self, src: &[u8]) -> alloc::vec::Vec<u8> {
335        let mut buf = alloc::vec![0; self.encoded_len(src.len())];
336        self.encode(src, &mut buf);
337        buf
338    }
339
340    /// Returns the base64 decoder
341    #[cfg(feature = "std")]
342    pub const fn decoder<R: std::io::Read>(self, r: R) -> Decoder<R> {
343        Decoder::new(self, r)
344    }
345
346    /// Decodes src using the encoding enc. It writes at most
347    /// `self.decoded_len(src.len())` bytes to dst and returns the number of bytes
348    /// written. If src contains invalid base64 data, it will return the
349    /// number of bytes successfully written and DecodeError.
350    /// New line characters (\r and \n) are ignored.
351    pub fn decode(&self, src: &[u8], dst: &mut [u8]) -> Result<usize, DecodeError> {
352        if src.is_empty() {
353            return Ok(0);
354        }
355
356        let mut n = 0;
357        let mut si = 0;
358        while usize::BITS >= 64 && src.len() - si >= 8 && dst.len() - n >= 8 {
359            let src2 = &src[si..si + 8];
360            let (dn, ok) = assemble_64(
361                self.decode_map[src2[0] as usize],
362                self.decode_map[src2[1] as usize],
363                self.decode_map[src2[2] as usize],
364                self.decode_map[src2[3] as usize],
365                self.decode_map[src2[4] as usize],
366                self.decode_map[src2[5] as usize],
367                self.decode_map[src2[6] as usize],
368                self.decode_map[src2[7] as usize],
369            );
370
371            if ok {
372                dst[n..n + core::mem::size_of::<u64>()].copy_from_slice(&dn.to_be_bytes());
373                n += 6;
374                si += 8;
375            } else {
376                let (si1, ninc) = self.decode_quantum(src, &mut dst[n..], si)?;
377                si = si1;
378                n += ninc;
379            }
380        }
381
382        while src.len() - si >= 4 && dst.len() - n >= 4 {
383            let src2 = &src[si..si + 4];
384            let (dn, ok) = assemble_32(
385                self.decode_map[src2[0] as usize],
386                self.decode_map[src2[1] as usize],
387                self.decode_map[src2[2] as usize],
388                self.decode_map[src2[3] as usize],
389            );
390            if ok {
391                dst[n..n + core::mem::size_of::<u32>()].copy_from_slice(&dn.to_be_bytes());
392                n += 3;
393                si += 4;
394            } else {
395                let (si1, ninc) = self.decode_quantum(src, &mut dst[n..], si)?;
396                si = si1;
397                n += ninc;
398            }
399        }
400
401        while si < src.len() {
402            let (si1, ninc) = self.decode_quantum(src, &mut dst[n..], si)?;
403            si = si1;
404            n += ninc;
405        }
406        Ok(n)
407    }
408
409    /// Returns the bytes represented by the base64 vec s.
410    #[cfg(feature = "alloc")]
411    pub fn decode_to_vec(&self, src: &[u8]) -> Result<alloc::vec::Vec<u8>, DecodeError> {
412        let mut buf = alloc::vec![0; self.decoded_len(src.len())];
413        let n = self.decode(src, &mut buf)?;
414        buf.truncate(n);
415        Ok(buf)
416    }
417
418    /// Decodes up to 4 base64 bytes. The received parameters are
419    /// the destination buffer dst, the source buffer src and an index in the
420    /// source buffer si.
421    /// It returns the number of bytes read from src, the number of bytes written
422    /// to dst, and an error, if any.
423    #[inline]
424    fn decode_quantum(
425        self,
426        src: &[u8],
427        dst: &mut [u8],
428        mut si: usize,
429    ) -> Result<(usize, usize), DecodeError> {
430        let mut dbuf = [0; 4];
431        let mut dlen = 4;
432        let mut j = 0;
433        while j < dbuf.len() {
434            if src.len() == si {
435                match () {
436                    () if j == 0 => {
437                        return Ok((si, 0));
438                    }
439                    () if j == 1 || self.pad_char.is_some() => {
440                        return Err(DecodeError(si - j));
441                    }
442                    _ => {}
443                }
444                dlen = j;
445                break;
446            }
447            let in_ = src[si];
448            si += 1;
449
450            let out = self.decode_map[in_ as usize];
451            if out != 0xff {
452                dbuf[j] = out;
453                j += 1;
454                continue;
455            }
456
457            if in_ == b'\n' || in_ == b'\r' {
458                continue;
459            }
460
461            if let Some(ch) = self.pad_char {
462                if (in_ as char) != ch {
463                    return Err(DecodeError(si - 1));
464                }
465            }
466            // We've reached the end and there's padding
467            match j {
468                0 | 1 => {
469                    // incorrect padding
470                    return Err(DecodeError(si - 1));
471                }
472                2 => {
473                    // "==" is expected, the first "=" is already consumed.
474                    // skip over newlines
475                    while si < src.len() && (src[si] == b'\n' || src[si] == b'\r') {
476                        si += 1;
477                    }
478                    if si == src.len() {
479                        // not enough padding
480                        return Err(DecodeError(src.len()));
481                    }
482                    if let Some(ch) = self.pad_char {
483                        if (src[si] as char) != ch {
484                            return Err(DecodeError(si - 1));
485                        }
486                    }
487                    si += 1;
488                }
489                _ => {}
490            }
491            // skip over newlines
492            while si < src.len() && (src[si] == b'\n' || src[si] == b'\r') {
493                si += 1;
494            }
495            if si < src.len() {
496                // trailing garbage
497                return Err(DecodeError(si));
498            }
499            dlen = j;
500            break;
501        }
502
503        // Convert 4x 6bit source bytes into 3 bytes
504        let val = ((dbuf[0] as usize) << 18)
505            | ((dbuf[1] as usize) << 12)
506            | ((dbuf[2] as usize) << 6)
507            | (dbuf[3] as usize);
508        dbuf[2] = val as u8;
509        dbuf[1] = (val >> 8) as u8;
510        dbuf[0] = (val >> 16) as u8;
511
512        match dlen {
513            4 => {
514                dst[2] = dbuf[2];
515                dbuf[2] = 0;
516                dst[1] = dbuf[1];
517                if self.strict && dbuf[2] != 0 {
518                    return Err(DecodeError(si - 1));
519                }
520                dbuf[1] = 0;
521                dst[0] = dbuf[0];
522                if self.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
523                    return Err(DecodeError(si - 2));
524                }
525            }
526            3 => {
527                dst[1] = dbuf[1];
528                if self.strict && dbuf[2] != 0 {
529                    return Err(DecodeError(si - 1));
530                }
531                dbuf[1] = 0;
532                dst[0] = dbuf[0];
533                if self.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
534                    return Err(DecodeError(si - 2));
535                }
536            }
537            2 => {
538                dst[0] = dbuf[0];
539                if self.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
540                    return Err(DecodeError(si - 2));
541                }
542            }
543            _ => {}
544        }
545        Ok((si, dlen - 1))
546    }
547
548    /// Returns the maximum length in bytes of the decoded data
549    /// corresponding to n bytes of base64-encoded
550    #[inline]
551    pub const fn decoded_len(&self, n: usize) -> usize {
552        if self.pad_char.is_none() {
553            return n * 6 / 8;
554        }
555        n / 4 * 3
556    }
557}
558
559/// Base64 encoder
560pub struct Encoder<W> {
561    enc: Base64,
562    w: W,
563    buf: [u8; 3],
564    nbuf: usize,
565    out: [u8; 1024],
566}
567
568impl<W> Encoder<W> {
569    /// Returns a new encoder based on the given encoding
570    #[inline]
571    pub const fn new(enc: Base64, w: W) -> Self {
572        Self {
573            enc,
574            w,
575            buf: [0; 3],
576            nbuf: 0,
577            out: [0; 1024],
578        }
579    }
580}
581
#[cfg(feature = "std")]
impl<W: std::io::Write> std::io::Write for Encoder<W> {
    /// Encodes `buf` and writes every complete 4-byte quantum to the inner
    /// writer.
    ///
    /// Up to 2 trailing bytes that do not yet form a full 3-byte group are
    /// buffered until the next `write` or `flush`. On success the whole of
    /// `buf` is reported as consumed.
    #[inline]
    fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
        let mut n = 0;

        // Leading fringe: complete the group left over from a previous call.
        if self.nbuf > 0 {
            let take = (3 - self.nbuf).min(buf.len());
            self.buf[self.nbuf..self.nbuf + take].copy_from_slice(&buf[..take]);
            self.nbuf += take;
            n += take;
            buf = &buf[take..];
            if self.nbuf < 3 {
                return Ok(n);
            }
            self.enc.encode(&self.buf, &mut self.out);
            self.w.write_all(&self.out[..4])?;
            self.nbuf = 0;
        }

        // Large interior chunks: as many whole groups as fit in `out` at once.
        let max_chunk = self.out.len() / 4 * 3;
        while buf.len() >= 3 {
            let nn = max_chunk.min(buf.len() / 3 * 3);
            self.enc.encode(&buf[..nn], &mut self.out);
            self.w.write_all(&self.out[..nn / 3 * 4])?;
            n += nn;
            buf = &buf[nn..];
        }

        // Trailing fringe: keep the 0-2 leftover bytes for later.
        self.buf[..buf.len()].copy_from_slice(buf);
        self.nbuf = buf.len();
        n += buf.len();
        Ok(n)
    }

    /// Encodes any buffered partial group, writes it out, and then flushes the
    /// inner writer. Padding is emitted if the encoding uses it.
    ///
    /// A partial group is finalized here. Calling `flush` and then writing more
    /// data therefore generally produces output that is not a valid encoding
    /// of the concatenated input (for padded encodings, padding ends up in the
    /// middle of the stream).
    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        if self.nbuf > 0 {
            let len = self.enc.encoded_len(self.nbuf);
            self.enc.encode(&self.buf[..self.nbuf], &mut self.out);
            self.w.write_all(&self.out[..len])?;
            self.nbuf = 0;
        }
        // Honour the `Write::flush` contract for the wrapped writer as well.
        self.w.flush()
    }
}

#[cfg(all(feature = "std", feature = "io"))]
impl<W: std::io::Write> crate::io::Closer for Encoder<W> {
    /// Closes the encoder by flushing it, which emits any pending partial group.
    fn close(&mut self) -> std::io::Result<()> {
        std::io::Write::flush(self)
    }
}

/// Reader adapter that strips `'\r'` and `'\n'` from the bytes of `wrapped`.
struct NewLineFilteringReader<R> {
    wrapped: R,
}

#[cfg(feature = "std")]
impl<R: std::io::Read> std::io::Read for NewLineFilteringReader<R> {
    /// Reads from the wrapped reader and compacts CR / LF out of `buf` in place.
    ///
    /// If an entire chunk consisted of CR / LF, it reads again. `Ok(0)` is
    /// therefore only returned when the wrapped reader itself returns 0.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        loop {
            let read = self.wrapped.read(buf)?;
            if read == 0 {
                return Ok(0);
            }
            let mut kept = 0;
            let mut idx = 0;
            while idx < read {
                let byte = buf[idx];
                if !matches!(byte, b'\r' | b'\n') {
                    buf[kept] = byte;
                    kept += 1;
                }
                idx += 1;
            }
            if kept > 0 {
                return Ok(kept);
            }
        }
    }
}

/// Streaming base64 decoder that reads encoded data from an inner reader `R`.
///
/// `'\r'` and `'\n'` in the input are discarded before decoding.
#[cfg(feature = "alloc")]
pub struct Decoder<R> {
    /// Set once the inner reader has returned 0 bytes.
    eof: bool,
    /// Inner reader, wrapped so that CR / LF are filtered out.
    r: NewLineFilteringReader<R>,
    /// Encoding used for the input.
    enc: Base64,
    /// Encoded input not yet decoded.
    buf: [u8; 1024],
    /// Number of valid bytes in `buf`.
    nbuf: usize,
    /// Decoded bytes that did not fit into the caller's buffer on an earlier read.
    out: alloc::vec::Vec<u8>,
    /// Scratch space large enough for decoding a full `buf` (3 bytes per 4 symbols).
    outbuf: [u8; 1024 / 4 * 3],
}

#[cfg(feature = "alloc")]
impl<R> Decoder<R> {
    /// Creates a decoder that reads base64 from `r` and decodes it with `enc`.
    #[inline]
    pub const fn new(enc: Base64, r: R) -> Decoder<R> {
        let r = NewLineFilteringReader { wrapped: r };
        Self {
            eof: false,
            r,
            enc,
            buf: [0; 1024],
            nbuf: 0,
            out: alloc::vec::Vec::new(),
            outbuf: [0; 1024 / 4 * 3],
        }
    }
}

#[cfg(feature = "std")]
impl<R: std::io::Read> std::io::Read for Decoder<R> {
    /// Reads decoded bytes into `buf`.
    ///
    /// Decoded bytes left over from an earlier call are returned first.
    /// Otherwise, encoded input is buffered until at least one 4-byte quantum is
    /// available or the inner reader reports EOF. All complete quanta are then
    /// decoded: directly into `buf` when it is large enough, and via the
    /// internal `out` buffer when it is not.
    ///
    /// # Errors
    /// `InvalidData` if the input is not valid base64, `UnexpectedEof` if a
    /// padded stream ends with an incomplete quantum, and any error of the
    /// inner reader.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        fn invalid(e: DecodeError) -> std::io::Error {
            std::io::Error::new(std::io::ErrorKind::InvalidData, e)
        }

        // Leftover decoded output from a previous read comes first.
        if !self.out.is_empty() {
            let served = crate::copy(&self.out, buf);
            self.out.drain(..served);
            return Ok(served);
        }

        // Refill the input buffer. The inner reader already drops '\r' and '\n'.
        while self.nbuf < 4 && !self.eof {
            // Read roughly as much input as `buf` can hold once decoded, but at
            // least one quantum and never more than the buffer size.
            let limit = (buf.len() / 3 * 4).max(4).min(self.buf.len());
            let got = self.r.read(&mut self.buf[self.nbuf..limit])?;
            if got == 0 {
                self.eof = true;
            } else {
                self.nbuf += got;
            }
        }

        if self.nbuf < 4 {
            // The inner reader is exhausted with less than one full quantum left.
            if self.nbuf == 0 {
                return Ok(0);
            }
            if self.enc.pad_char.is_some() {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "base64 decoder: unexpected EOF",
                ));
            }
            // Unpadded encodings may legitimately end with a short fragment.
            let written = self
                .enc
                .decode(&self.buf[..self.nbuf], &mut self.outbuf)
                .map_err(invalid)?;
            self.nbuf = 0;
            self.out.extend_from_slice(&self.outbuf[..written]);
            let served = crate::copy(&self.out, buf);
            self.out.drain(..served);
            return Ok(served);
        }

        // Decode every complete quantum: straight into `buf` when it is big
        // enough, otherwise through `outbuf`, keeping the excess in `out`.
        let consumed = self.nbuf / 4 * 4;
        let needed = self.nbuf / 4 * 3;
        let n = if needed > buf.len() {
            let written = self
                .enc
                .decode(&self.buf[..consumed], &mut self.outbuf)
                .map_err(invalid)?;
            self.out.extend_from_slice(&self.outbuf[..written]);
            let served = crate::copy(&self.out, buf);
            self.out.drain(..served);
            served
        } else {
            self.enc
                .decode(&self.buf[..consumed], buf)
                .map_err(invalid)?
        };

        // Shift the unconsumed tail (0-3 bytes) to the front of the buffer.
        self.nbuf -= consumed;
        self.buf.copy_within(consumed..consumed + self.nbuf, 0);
        Ok(n)
    }
}

/// Packs 4 base64 digits (6 bits each) into the top 24 bits of a `u32`.
///
/// Each digit is a decode-table lookup, which yields `0xff` for bytes outside
/// the alphabet. Returns `(packed, true)` on success and `(0, false)` if any
/// digit is invalid.
#[inline]
const fn assemble_32(n1: u8, n2: u8, n3: u8, n4: u8) -> (u32, bool) {
    // Valid digits are < 64, so their OR can only be 0xff if one of them is 0xff.
    let any_invalid = (n1 | n2 | n3 | n4) == 0xff;
    if any_invalid {
        (0, false)
    } else {
        let digits = ((n1 as u32) << 18) | ((n2 as u32) << 12) | ((n3 as u32) << 6) | (n4 as u32);
        (digits << 8, true)
    }
}

/// Packs 8 base64 digits (6 bits each) into the top 48 bits of a `u64`.
///
/// Each digit is a decode-table lookup, which yields `0xff` for bytes outside
/// the alphabet. Returns `(packed, true)` on success and `(0, false)` if any
/// digit is invalid.
#[inline]
// Eight scalar digits keep the call site branch-free; bundling them into an
// array would only move the indexing elsewhere.
#[allow(clippy::too_many_arguments)]
const fn assemble_64(
    n1: u8,
    n2: u8,
    n3: u8,
    n4: u8,
    n5: u8,
    n6: u8,
    n7: u8,
    n8: u8,
) -> (u64, bool) {
    // Valid digits are < 64, so their OR can only be 0xff if one of them is 0xff.
    let any_invalid = (n1 | n2 | n3 | n4 | n5 | n6 | n7 | n8) == 0xff;
    if any_invalid {
        (0, false)
    } else {
        let digits = ((n1 as u64) << 42)
            | ((n2 as u64) << 36)
            | ((n3 as u64) << 30)
            | ((n4 as u64) << 24)
            | ((n5 as u64) << 18)
            | ((n6 as u64) << 12)
            | ((n7 as u64) << 6)
            | (n8 as u64);
        (digits << 16, true)
    }
}

831#[cfg(test)]
832mod tests {
833    use std::io::{Read, Write};
834
835    use crate::io::Closer;
836
837    use super::*;
838
839    fn big_test() -> TestPair {
840        TestPair {
841            decoded: b"Twas brillig, and the slithy toves".to_vec(),
842            encoded: b"VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==".to_vec(),
843        }
844    }
845
846    struct TestPair {
847        decoded: Vec<u8>,
848        encoded: Vec<u8>,
849    }
850
851    fn pairs() -> Vec<TestPair> {
852        vec![
853            // RFC 3548 examples
854            TestPair {
855                decoded: vec![20, 251, 156, 3, 217, 126],
856                encoded: b"FPucA9l+".to_vec(),
857            },
858            TestPair {
859                decoded: vec![20, 251, 156, 3, 217],
860                encoded: b"FPucA9k=".to_vec(),
861            },
862            TestPair {
863                decoded: vec![20, 251, 156, 3],
864                encoded: b"FPucAw==".to_vec(),
865            },
866            // RFC 4648 examples
867            TestPair {
868                decoded: b"".to_vec(),
869                encoded: b"".to_vec(),
870            },
871            TestPair {
872                decoded: b"f".to_vec(),
873                encoded: b"Zg==".to_vec(),
874            },
875            TestPair {
876                decoded: b"fo".to_vec(),
877                encoded: b"Zm8=".to_vec(),
878            },
879            TestPair {
880                decoded: b"foo".to_vec(),
881                encoded: b"Zm9v".to_vec(),
882            },
883            TestPair {
884                decoded: b"foob".to_vec(),
885                encoded: b"Zm9vYg==".to_vec(),
886            },
887            TestPair {
888                decoded: b"fooba".to_vec(),
889                encoded: b"Zm9vYmE=".to_vec(),
890            },
891            TestPair {
892                decoded: b"foobar".to_vec(),
893                encoded: b"Zm9vYmFy".to_vec(),
894            },
895            // Wikipedia examples
896            TestPair {
897                decoded: b"sure.".to_vec(),
898                encoded: b"c3VyZS4=".to_vec(),
899            },
900            TestPair {
901                decoded: b"sure".to_vec(),
902                encoded: b"c3VyZQ==".to_vec(),
903            },
904            TestPair {
905                decoded: b"sur".to_vec(),
906                encoded: b"c3Vy".to_vec(),
907            },
908            TestPair {
909                decoded: b"su".to_vec(),
910                encoded: b"c3U=".to_vec(),
911            },
912            TestPair {
913                decoded: b"leasure.".to_vec(),
914                encoded: b"bGVhc3VyZS4=".to_vec(),
915            },
916            TestPair {
917                decoded: b"easure.".to_vec(),
918                encoded: b"ZWFzdXJlLg==".to_vec(),
919            },
920            TestPair {
921                decoded: b"asure.".to_vec(),
922                encoded: b"YXN1cmUu".to_vec(),
923            },
924            TestPair {
925                decoded: b"sure.".to_vec(),
926                encoded: b"c3VyZS4=".to_vec(),
927            },
928        ]
929    }
930
931    struct EncodingTest {
932        enc: Base64,
933        conv: Box<dyn Fn(String) -> String>,
934    }
935
936    fn std_ref(r: String) -> String {
937        r
938    }
939
940    fn url_ref(r: String) -> String {
941        r.replace('+', "-").replace('/', "_")
942    }
943
944    fn raw_ref(r: String) -> String {
945        r.trim_end_matches('=').to_string()
946    }
947
948    fn raw_url_ref(r: String) -> String {
949        raw_ref(url_ref(r))
950    }
951
952    const FUNNY_ENCODING: Base64 = STD_ENCODING.with_padding_unchecked(Some('@'));
953
954    fn funny_ref(r: String) -> String {
955        r.replace('=', "@")
956    }
957
958    fn encoding_tests() -> Vec<EncodingTest> {
959        vec![
960            EncodingTest {
961                enc: STD_ENCODING,
962                conv: Box::new(std_ref),
963            },
964            EncodingTest {
965                enc: URL_ENCODING,
966                conv: Box::new(url_ref),
967            },
968            EncodingTest {
969                enc: RAW_STD_ENCODING,
970                conv: Box::new(raw_ref),
971            },
972            EncodingTest {
973                enc: RAW_URL_ENCODING,
974                conv: Box::new(raw_url_ref),
975            },
976            EncodingTest {
977                enc: FUNNY_ENCODING,
978                conv: Box::new(funny_ref),
979            },
980            EncodingTest {
981                enc: STD_ENCODING.with_strict(),
982                conv: Box::new(std_ref),
983            },
984            EncodingTest {
985                enc: URL_ENCODING.with_strict(),
986                conv: Box::new(url_ref),
987            },
988            EncodingTest {
989                enc: RAW_STD_ENCODING.with_strict(),
990                conv: Box::new(raw_ref),
991            },
992            EncodingTest {
993                enc: RAW_URL_ENCODING.with_strict(),
994                conv: Box::new(raw_url_ref),
995            },
996            EncodingTest {
997                enc: FUNNY_ENCODING.with_strict(),
998                conv: Box::new(funny_ref),
999            },
1000        ]
1001    }
1002
1003    #[test]
1004    fn test_encode() {
1005        for p in pairs() {
1006            for tt in encoding_tests() {
1007                let got = tt.enc.encode_to_vec(&p.decoded);
1008                assert_eq!(
1009                    got,
1010                    (tt.conv)(String::from_utf8_lossy(&p.encoded).to_string()).as_bytes()
1011                );
1012            }
1013        }
1014    }
1015
1016    #[test]
1017    fn test_encoder() {
1018        for p in pairs() {
1019            let mut bb = vec![];
1020            let mut encoder = STD_ENCODING.encoder(&mut bb);
1021            encoder.write_all(&p.decoded).unwrap();
1022            encoder.close().unwrap();
1023            assert_eq!(bb, p.encoded);
1024        }
1025    }
1026
1027    #[test]
1028    fn test_encoder_buffering() {
1029        let input = big_test().decoded;
1030        for bs in 1..=12 {
1031            let mut bb = vec![];
1032            let mut encoder = STD_ENCODING.encoder(&mut bb);
1033            let mut pos = 0;
1034            while pos < input.len() {
1035                let mut end = pos + bs;
1036                if end > input.len() {
1037                    end = input.len();
1038                }
1039
1040                let n = encoder.write(&input[pos..end]).unwrap();
1041                assert_eq!(n, end - pos);
1042                pos += bs;
1043            }
1044            encoder.close().unwrap();
1045            assert_eq!(bb, big_test().encoded);
1046        }
1047    }
1048
1049    #[test]
1050    fn test_decode() {
1051        for p in pairs() {
1052            for tt in encoding_tests() {
1053                let encoded = (tt.conv)(String::from_utf8_lossy(p.encoded.as_slice()).to_string());
1054                let mut dbuf = vec![0; tt.enc.decoded_len(encoded.len())];
1055                let count = tt.enc.decode(encoded.as_bytes(), &mut dbuf).unwrap();
1056                assert_eq!(count, p.decoded.len());
1057                assert_eq!(&dbuf[..count], &p.decoded);
1058
1059                let dbuf = tt.enc.decode_to_vec(encoded.as_bytes()).unwrap();
1060                assert_eq!(dbuf, p.decoded);
1061            }
1062        }
1063    }
1064
1065    #[test]
1066    fn test_decoder() {
1067        for p in pairs() {
1068            let mut dbuf = vec![0; STD_ENCODING.decoded_len(p.encoded.len())];
1069            let mut decoder = STD_ENCODING.decoder(std::io::Cursor::new(&p.encoded));
1070            let count = decoder.read(&mut dbuf).unwrap();
1071            assert_eq!(count, p.decoded.len());
1072            assert_eq!(&dbuf[..count], &p.decoded);
1073        }
1074    }
1075
1076    #[test]
1077    fn test_decoder_buffering() {
1078        let input = big_test();
1079        for bs in 1..=12 {
1080            let mut decoder = STD_ENCODING.decoder(std::io::Cursor::new(&input.encoded));
1081            let mut buf = vec![0; input.decoded.len() + 12];
1082            let mut total = 0;
1083            while total < input.decoded.len() {
1084                total += decoder.read(&mut buf[total..total + bs]).unwrap();
1085            }
1086            assert_eq!(&buf[..total], &input.decoded);
1087        }
1088    }
1089
1090    #[test]
1091    fn test_decode_corrupt() {
1092        struct TestCase {
1093            input: Vec<u8>,
1094            offset: isize,
1095        }
1096
1097        let test_cases = vec![
1098            TestCase {
1099                input: b"".to_vec(),
1100                offset: -1,
1101            },
1102            TestCase {
1103                input: b"\n".to_vec(),
1104                offset: -1,
1105            },
1106            TestCase {
1107                input: b"AAA=\n".to_vec(),
1108                offset: -1,
1109            },
1110            TestCase {
1111                input: b"AAAA\n".to_vec(),
1112                offset: -1,
1113            },
1114            TestCase {
1115                input: b"!!!!".to_vec(),
1116                offset: 0,
1117            },
1118            TestCase {
1119                input: b"====".to_vec(),
1120                offset: 0,
1121            },
1122            TestCase {
1123                input: b"x===".to_vec(),
1124                offset: 1,
1125            },
1126            TestCase {
1127                input: b"=AAA".to_vec(),
1128                offset: 0,
1129            },
1130            TestCase {
1131                input: b"A=AA".to_vec(),
1132                offset: 1,
1133            },
1134            TestCase {
1135                input: b"AA=A".to_vec(),
1136                offset: 2,
1137            },
1138            TestCase {
1139                input: b"AA==A".to_vec(),
1140                offset: 4,
1141            },
1142            TestCase {
1143                input: b"AAA=AAAA".to_vec(),
1144                offset: 4,
1145            },
1146            TestCase {
1147                input: b"AAAAA".to_vec(),
1148                offset: 4,
1149            },
1150            TestCase {
1151                input: b"AAAAAA".to_vec(),
1152                offset: 4,
1153            },
1154            TestCase {
1155                input: b"A=".to_vec(),
1156                offset: 1,
1157            },
1158            TestCase {
1159                input: b"A==".to_vec(),
1160                offset: 1,
1161            },
1162            TestCase {
1163                input: b"AA=".to_vec(),
1164                offset: 3,
1165            },
1166            TestCase {
1167                input: b"AA==".to_vec(),
1168                offset: -1,
1169            },
1170            TestCase {
1171                input: b"AAA=".to_vec(),
1172                offset: -1,
1173            },
1174            TestCase {
1175                input: b"AAAA".to_vec(),
1176                offset: -1,
1177            },
1178            TestCase {
1179                input: b"AAAAAA=".to_vec(),
1180                offset: 7,
1181            },
1182            TestCase {
1183                input: b"YWJjZA=====".to_vec(),
1184                offset: 8,
1185            },
1186            TestCase {
1187                input: b"A!\n".to_vec(),
1188                offset: 1,
1189            },
1190            TestCase {
1191                input: b"A=\n".to_vec(),
1192                offset: 1,
1193            },
1194        ];
1195
1196        for tc in test_cases {
1197            let mut dbuf = vec![0; STD_ENCODING.decoded_len(tc.input.len())];
1198            if tc.offset == -1 {
1199                let _ = STD_ENCODING.decode(&tc.input, &mut dbuf).unwrap();
1200                continue;
1201            }
1202
1203            let n = STD_ENCODING
1204                .decode(&tc.input, &mut dbuf)
1205                .unwrap_err()
1206                .into_inner();
1207            assert_eq!(n, tc.offset as usize);
1208        }
1209    }
1210
1211    #[test]
1212    fn test_decode_bounds() {
1213        let mut buf = [0; 32];
1214        let s = STD_ENCODING.encode_to_vec(&buf);
1215        let n = STD_ENCODING.decode(&s, &mut buf).unwrap();
1216        assert_eq!(n, buf.len());
1217    }
1218
    /// Row for the `encoded_len` / `decoded_len` table tests: for encoding
    /// `enc`, the length function applied to `n` should return `want`.
    struct Test {
        /// Encoding whose length function is checked.
        enc: Base64,
        /// Input length passed to the length function.
        n: usize,
        /// Expected result.
        want: usize,
    }
1224
1225    #[test]
1226    fn test_encoded_len() {
1227        for tt in vec![
1228            Test {
1229                enc: RAW_STD_ENCODING,
1230                n: 0,
1231                want: 0,
1232            },
1233            Test {
1234                enc: RAW_STD_ENCODING,
1235                n: 1,
1236                want: 2,
1237            },
1238            Test {
1239                enc: RAW_STD_ENCODING,
1240                n: 2,
1241                want: 3,
1242            },
1243            Test {
1244                enc: RAW_STD_ENCODING,
1245                n: 3,
1246                want: 4,
1247            },
1248            Test {
1249                enc: RAW_STD_ENCODING,
1250                n: 7,
1251                want: 10,
1252            },
1253            Test {
1254                enc: STD_ENCODING,
1255                n: 0,
1256                want: 0,
1257            },
1258            Test {
1259                enc: STD_ENCODING,
1260                n: 1,
1261                want: 4,
1262            },
1263            Test {
1264                enc: STD_ENCODING,
1265                n: 2,
1266                want: 4,
1267            },
1268            Test {
1269                enc: STD_ENCODING,
1270                n: 3,
1271                want: 4,
1272            },
1273            Test {
1274                enc: STD_ENCODING,
1275                n: 4,
1276                want: 8,
1277            },
1278            Test {
1279                enc: STD_ENCODING,
1280                n: 7,
1281                want: 12,
1282            },
1283        ] {
1284            assert_eq!(tt.enc.encoded_len(tt.n), tt.want, "encoded_len({})", tt.n);
1285        }
1286    }
1287
1288    #[test]
1289    fn test_decoded_len() {
1290        for tt in vec![
1291            Test {
1292                enc: RAW_STD_ENCODING,
1293                n: 0,
1294                want: 0,
1295            },
1296            Test {
1297                enc: RAW_STD_ENCODING,
1298                n: 2,
1299                want: 1,
1300            },
1301            Test {
1302                enc: RAW_STD_ENCODING,
1303                n: 3,
1304                want: 2,
1305            },
1306            Test {
1307                enc: RAW_STD_ENCODING,
1308                n: 4,
1309                want: 3,
1310            },
1311            Test {
1312                enc: RAW_STD_ENCODING,
1313                n: 10,
1314                want: 7,
1315            },
1316            Test {
1317                enc: STD_ENCODING,
1318                n: 0,
1319                want: 0,
1320            },
1321            Test {
1322                enc: STD_ENCODING,
1323                n: 4,
1324                want: 3,
1325            },
1326            Test {
1327                enc: STD_ENCODING,
1328                n: 8,
1329                want: 6,
1330            },
1331        ] {
1332            let got = tt.enc.decoded_len(tt.n);
1333            assert_eq!(got, tt.want);
1334        }
1335    }
1336
1337    #[test]
1338    fn test_big() {
1339        const N: usize = 3 * 1000 + 1;
1340        let mut raw = [0; N];
1341        const ALPHA: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
1342        for i in 0..N {
1343            raw[i] = ALPHA[i % ALPHA.len()];
1344        }
1345
1346        let mut encoded = vec![];
1347        let mut w = STD_ENCODING.encoder(&mut encoded);
1348        let nn = w.write(&raw).unwrap();
1349        assert_eq!(nn, N);
1350
1351        w.close().unwrap();
1352        let mut dbuf = vec![];
1353        let mut decoded = STD_ENCODING.decoder(std::io::Cursor::new(&encoded));
1354        decoded.read_to_end(&mut dbuf).unwrap();
1355        assert_eq!(dbuf, raw);
1356    }
1357
1358    #[test]
1359    fn test_new_line_characters() {
1360        const EXPECTED: &str = "sure";
1361        let examples = vec![
1362            "c3VyZQ==",
1363            "c3VyZQ==\r",
1364            "c3VyZQ==\n",
1365            "c3VyZQ==\r\n",
1366            "c3VyZ\r\nQ==",
1367            "c3V\ryZ\nQ==",
1368            "c3V\nyZ\rQ==",
1369            "c3VyZ\nQ==",
1370            "c3VyZQ\n==",
1371            "c3VyZQ=\n=",
1372            "c3VyZQ=\r\n\r\n=",
1373        ];
1374
1375        for e in examples {
1376            let buf = STD_ENCODING.decode_to_vec(e.as_bytes()).unwrap();
1377            assert_eq!(EXPECTED, std::str::from_utf8(&buf).unwrap());
1378        }
1379    }
1380
1381    #[test]
1382    fn test_decoder_issue_3577() {
1383        // TODO: implement this test case
1384    }
1385
1386    #[test]
1387    fn test_decoder_issue_4779() {
1388        let encoded = r#"CP/EAT8AAAEF
1389AQEBAQEBAAAAAAAAAAMAAQIEBQYHCAkKCwEAAQUBAQEBAQEAAAAAAAAAAQACAwQFBgcICQoLEAAB
1390BAEDAgQCBQcGCAUDDDMBAAIRAwQhEjEFQVFhEyJxgTIGFJGhsUIjJBVSwWIzNHKC0UMHJZJT8OHx
1391Y3M1FqKygyZEk1RkRcKjdDYX0lXiZfKzhMPTdePzRieUpIW0lcTU5PSltcXV5fVWZnaGlqa2xtbm
13929jdHV2d3h5ent8fX5/cRAAICAQIEBAMEBQYHBwYFNQEAAhEDITESBEFRYXEiEwUygZEUobFCI8FS
13930fAzJGLhcoKSQ1MVY3M08SUGFqKygwcmNcLSRJNUoxdkRVU2dGXi8rOEw9N14/NGlKSFtJXE1OT0
1394pbXF1eX1VmZ2hpamtsbW5vYnN0dXZ3eHl6e3x//aAAwDAQACEQMRAD8A9VSSSSUpJJJJSkkkJ+Tj
13951kiy1jCJJDnAcCTykpKkuQ6p/jN6FgmxlNduXawwAzaGH+V6jn/R/wCt71zdn+N/qL3kVYFNYB4N
1396ji6PDVjWpKp9TSXnvTf8bFNjg3qOEa2n6VlLpj/rT/pf567DpX1i6L1hs9Py67X8mqdtg/rUWbbf
1397+gkp0kkkklKSSSSUpJJJJT//0PVUkkklKVLq3WMDpGI7KzrNjADtYNXvI/Mqr/Pd/q9W3vaxjnvM
1398NaCXE9gNSvGPrf8AWS3qmba5jjsJhoB0DAf0NDf6sevf+/lf8Hj0JJATfWT6/dV6oXU1uOLQeKKn
1399EQP+Hubtfe/+R7Mf/g7f5xcocp++Z11JMCJPgFBxOg7/AOuqDx8I/ikpkXkmSdU8mJIJA/O8EMAy
1400j+mSARB/17pKVXYWHXjsj7yIex0PadzXMO1zT5KHoNA3HT8ietoGhgjsfA+CSnvvqh/jJtqsrwOv
14012b6NGNzXfTYexzJ+nU7/ALkf4P8Awv6P9KvTQQ4AgyDqCF85Pho3CTB7eHwXoH+LT65uZbX9X+o2
1402bqbPb06551Y4
1403"#;
1404
1405        let encoded_stort = encoded.replace('\n', "");
1406        let mut buf = vec![];
1407        let mut decoder = STD_ENCODING.decoder(std::io::Cursor::new(encoded));
1408        decoder.read_to_end(&mut buf).unwrap();
1409
1410        let mut buf1 = vec![];
1411        let mut decoder1 = STD_ENCODING.decoder(std::io::Cursor::new(encoded_stort));
1412        decoder1.read_to_end(&mut buf1).unwrap();
1413        assert_eq!(buf, buf1);
1414    }
1415
1416    #[test]
1417    fn test_decode_issue_7733() {
1418        let err = STD_ENCODING
1419            .decode_to_vec(b"YWJjZA=====")
1420            .unwrap_err()
1421            .into_inner();
1422        assert_eq!(err, 8);
1423    }
1424
1425    #[test]
1426    fn test_decode_issue_15656() {
1427        let err = STD_ENCODING
1428            .with_strict()
1429            .decode_to_vec(b"WvLTlMrX9NpYDQlEIFlnDB==")
1430            .unwrap_err()
1431            .into_inner();
1432        assert_eq!(err, 22);
1433        STD_ENCODING
1434            .with_strict()
1435            .decode_to_vec(b"WvLTlMrX9NpYDQlEIFlnDA==")
1436            .unwrap();
1437        STD_ENCODING
1438            .decode_to_vec(b"WvLTlMrX9NpYDQlEIFlnDB==")
1439            .unwrap();
1440    }
1441}