1use crate::error::*;
2use crate::{Decoder, Encoder};
3
/// Private, data-less namespace type; the Base64 engine is implemented as its
/// associated functions.
struct Base64Impl;
5
/// Selects which Base64 flavour to encode or decode.
///
/// The discriminants are small bit sets: value `2` marks "no padding" and
/// value `4` marks the URL-safe alphabet, so a variant can be tested with a
/// plain bitwise AND. Bit `1` is set in every variant.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Base64Variant {
    /// Standard alphabet, `=` padding.
    Original = 1,
    /// Standard alphabet, no padding.
    OriginalNoPadding = 3,
    /// URL-safe alphabet, `=` padding.
    UrlSafe = 5,
    /// URL-safe alphabet, no padding.
    UrlSafeNoPadding = 7,
}
13
/// Bit masks used to test the properties encoded in a variant discriminant.
enum VariantMask {
    /// Set when the variant neither emits nor expects `=` padding.
    NoPadding = 2,
    /// Set when the variant uses the URL-safe alphabet.
    UrlSafe = 4,
}
18
19impl Base64Impl {
20    #[inline]
21    fn _eq(x: u8, y: u8) -> u8 {
22        !(((0u16.wrapping_sub((x as u16) ^ (y as u16))) >> 8) as u8)
23    }
24
25    #[inline]
26    fn _gt(x: u8, y: u8) -> u8 {
27        (((y as u16).wrapping_sub(x as u16)) >> 8) as u8
28    }
29
30    #[inline]
31    fn _ge(x: u8, y: u8) -> u8 {
32        !Self::_gt(y, x)
33    }
34
35    #[inline]
36    fn _lt(x: u8, y: u8) -> u8 {
37        Self::_gt(y, x)
38    }
39
40    #[inline]
41    fn _le(x: u8, y: u8) -> u8 {
42        Self::_ge(y, x)
43    }
44
45    #[inline]
46    fn b64_byte_to_char(x: u8) -> u8 {
47        (Self::_lt(x, 26) & (x.wrapping_add(b'A')))
48            | (Self::_ge(x, 26) & Self::_lt(x, 52) & (x.wrapping_add(b'a'.wrapping_sub(26))))
49            | (Self::_ge(x, 52) & Self::_lt(x, 62) & (x.wrapping_add(b'0'.wrapping_sub(52))))
50            | (Self::_eq(x, 62) & b'+')
51            | (Self::_eq(x, 63) & b'/')
52    }
53
54    #[inline]
55    fn b64_char_to_byte(c: u8) -> u8 {
56        let x = (Self::_ge(c, b'A') & Self::_le(c, b'Z') & (c.wrapping_sub(b'A')))
57            | (Self::_ge(c, b'a') & Self::_le(c, b'z') & (c.wrapping_sub(b'a'.wrapping_sub(26))))
58            | (Self::_ge(c, b'0') & Self::_le(c, b'9') & (c.wrapping_sub(b'0'.wrapping_sub(52))))
59            | (Self::_eq(c, b'+') & 62)
60            | (Self::_eq(c, b'/') & 63);
61        x | (Self::_eq(x, 0) & (Self::_eq(c, b'A') ^ 0xff))
62    }
63
64    #[inline]
65    fn b64_byte_to_urlsafe_char(x: u8) -> u8 {
66        (Self::_lt(x, 26) & (x.wrapping_add(b'A')))
67            | (Self::_ge(x, 26) & Self::_lt(x, 52) & (x.wrapping_add(b'a'.wrapping_sub(26))))
68            | (Self::_ge(x, 52) & Self::_lt(x, 62) & (x.wrapping_add(b'0'.wrapping_sub(52))))
69            | (Self::_eq(x, 62) & b'-')
70            | (Self::_eq(x, 63) & b'_')
71    }
72
73    #[inline]
74    fn b64_urlsafe_char_to_byte(c: u8) -> u8 {
75        let x = (Self::_ge(c, b'A') & Self::_le(c, b'Z') & (c.wrapping_sub(b'A')))
76            | (Self::_ge(c, b'a') & Self::_le(c, b'z') & (c.wrapping_sub(b'a'.wrapping_sub(26))))
77            | (Self::_ge(c, b'0') & Self::_le(c, b'9') & (c.wrapping_sub(b'0'.wrapping_sub(52))))
78            | (Self::_eq(c, b'-') & 62)
79            | (Self::_eq(c, b'_') & 63);
80        x | (Self::_eq(x, 0) & (Self::_eq(c, b'A') ^ 0xff))
81    }
82
83    #[inline]
84    fn encoded_len(bin_len: usize, variant: Base64Variant) -> Result<usize, Error> {
85        let nibbles = bin_len / 3;
86        let rounded = nibbles * 3;
87        let pad = bin_len - rounded;
88        Ok(nibbles.checked_mul(4).ok_or(Error::Overflow)?
89            + ((pad | (pad >> 1)) & 1)
90                * (4 - (!((((variant as usize) & 2) >> 1).wrapping_sub(1)) & (3 - pad)))
91            + 1)
92    }
93
94    pub fn encode<'t>(
95        b64: &'t mut [u8],
96        bin: &[u8],
97        variant: Base64Variant,
98    ) -> Result<&'t [u8], Error> {
99        let bin_len = bin.len();
100        let b64_maxlen = b64.len();
101        let mut acc_len = 0usize;
102        let mut b64_pos = 0usize;
103        let mut acc = 0u16;
104
105        let nibbles = bin_len / 3;
106        let remainder = bin_len - 3 * nibbles;
107        let mut b64_len = nibbles * 4;
108        if remainder != 0 {
109            if (variant as u16 & VariantMask::NoPadding as u16) == 0 {
110                b64_len += 4;
111            } else {
112                b64_len += 2 + (remainder >> 1);
113            }
114        }
115        if b64_maxlen < b64_len {
116            return Err(Error::Overflow);
117        }
118        if (variant as u16 & VariantMask::UrlSafe as u16) != 0 {
119            for &v in bin {
120                acc = (acc << 8) + v as u16;
121                acc_len += 8;
122                while acc_len >= 6 {
123                    acc_len -= 6;
124                    b64[b64_pos] = Self::b64_byte_to_urlsafe_char(((acc >> acc_len) & 0x3f) as u8);
125                    b64_pos += 1;
126                }
127            }
128            if acc_len > 0 {
129                b64[b64_pos] =
130                    Self::b64_byte_to_urlsafe_char(((acc << (6 - acc_len)) & 0x3f) as u8);
131                b64_pos += 1;
132            }
133        } else {
134            for &v in bin {
135                acc = (acc << 8) + v as u16;
136                acc_len += 8;
137                while acc_len >= 6 {
138                    acc_len -= 6;
139                    b64[b64_pos] = Self::b64_byte_to_char(((acc >> acc_len) & 0x3f) as u8);
140                    b64_pos += 1;
141                }
142            }
143            if acc_len > 0 {
144                b64[b64_pos] = Self::b64_byte_to_char(((acc << (6 - acc_len)) & 0x3f) as u8);
145                b64_pos += 1;
146            }
147        }
148        while b64_pos < b64_len {
149            b64[b64_pos] = b'=';
150            b64_pos += 1
151        }
152        Ok(&b64[..b64_pos])
153    }
154
155    fn skip_padding<'t>(
156        b64: &'t [u8],
157        mut padding_len: usize,
158        ignore: Option<&[u8]>,
159    ) -> Result<&'t [u8], Error> {
160        let b64_len = b64.len();
161        let mut b64_pos = 0usize;
162        while padding_len > 0 {
163            if b64_pos >= b64_len {
164                return Err(Error::InvalidInput);
165            }
166            let c = b64[b64_pos];
167            if c == b'=' {
168                padding_len -= 1
169            } else {
170                match ignore {
171                    Some(ignore) if ignore.contains(&c) => {}
172                    _ => return Err(Error::InvalidInput),
173                }
174            }
175            b64_pos += 1
176        }
177        Ok(&b64[b64_pos..])
178    }
179
180    pub fn decode<'t>(
181        bin: &'t mut [u8],
182        b64: &[u8],
183        ignore: Option<&[u8]>,
184        variant: Base64Variant,
185    ) -> Result<&'t [u8], Error> {
186        let bin_maxlen = bin.len();
187        let is_urlsafe = (variant as u16 & VariantMask::UrlSafe as u16) != 0;
188        let mut acc = 0u16;
189        let mut acc_len = 0usize;
190        let mut bin_pos = 0usize;
191        let mut premature_end = None;
192        for (b64_pos, &c) in b64.iter().enumerate() {
193            let d = if is_urlsafe {
194                Self::b64_urlsafe_char_to_byte(c)
195            } else {
196                Self::b64_char_to_byte(c)
197            };
198            if d == 0xff {
199                match ignore {
200                    Some(ignore) if ignore.contains(&c) => continue,
201                    _ => {
202                        premature_end = Some(b64_pos);
203                        break;
204                    }
205                }
206            }
207            acc = (acc << 6) + d as u16;
208            acc_len += 6;
209            if acc_len >= 8 {
210                acc_len -= 8;
211                if bin_pos >= bin_maxlen {
212                    return Err(Error::Overflow);
213                }
214                bin[bin_pos] = (acc >> acc_len) as u8;
215                bin_pos += 1;
216            }
217        }
218        if acc_len > 4 || (acc & ((1u16 << acc_len).wrapping_sub(1))) != 0 {
219            return Err(Error::InvalidInput);
220        }
221        let padding_len = acc_len / 2;
222        if let Some(premature_end) = premature_end {
223            let remaining = if variant as u16 & VariantMask::NoPadding as u16 == 0 {
224                Self::skip_padding(&b64[premature_end..], padding_len, ignore)?
225            } else {
226                &b64[premature_end..]
227            };
228            match ignore {
229                None => {
230                    if !remaining.is_empty() {
231                        return Err(Error::InvalidInput);
232                    }
233                }
234                Some(ignore) => {
235                    for &c in remaining {
236                        if !ignore.contains(&c) {
237                            return Err(Error::InvalidInput);
238                        }
239                    }
240                }
241            }
242        } else if variant as u16 & VariantMask::NoPadding as u16 == 0 && padding_len != 0 {
243            return Err(Error::InvalidInput);
244        }
245        Ok(&bin[..bin_pos])
246    }
247}
248
/// Base64 codec using the standard alphabet (`+`, `/`) with `=` padding.
pub struct Base64;
/// Base64 codec using the standard alphabet (`+`, `/`) without padding.
pub struct Base64NoPadding;
/// Base64 codec using the URL-safe alphabet (`-`, `_`) with `=` padding.
pub struct Base64UrlSafe;
/// Base64 codec using the URL-safe alphabet (`-`, `_`) without padding.
pub struct Base64UrlSafeNoPadding;
253
254impl Encoder for Base64 {
255    #[inline]
256    fn encoded_len(bin_len: usize) -> Result<usize, Error> {
257        Base64Impl::encoded_len(bin_len, Base64Variant::Original)
258    }
259
260    #[inline]
261    fn encode<IN: AsRef<[u8]>>(b64: &mut [u8], bin: IN) -> Result<&[u8], Error> {
262        Base64Impl::encode(b64, bin.as_ref(), Base64Variant::Original)
263    }
264}
265
266impl Decoder for Base64 {
267    #[inline]
268    fn decode<'t, IN: AsRef<[u8]>>(
269        bin: &'t mut [u8],
270        b64: IN,
271        ignore: Option<&[u8]>,
272    ) -> Result<&'t [u8], Error> {
273        Base64Impl::decode(bin, b64.as_ref(), ignore, Base64Variant::Original)
274    }
275}
276
277impl Encoder for Base64NoPadding {
278    #[inline]
279    fn encoded_len(bin_len: usize) -> Result<usize, Error> {
280        Base64Impl::encoded_len(bin_len, Base64Variant::OriginalNoPadding)
281    }
282
283    #[inline]
284    fn encode<IN: AsRef<[u8]>>(b64: &mut [u8], bin: IN) -> Result<&[u8], Error> {
285        Base64Impl::encode(b64, bin.as_ref(), Base64Variant::OriginalNoPadding)
286    }
287}
288
289impl Decoder for Base64NoPadding {
290    #[inline]
291    fn decode<'t, IN: AsRef<[u8]>>(
292        bin: &'t mut [u8],
293        b64: IN,
294        ignore: Option<&[u8]>,
295    ) -> Result<&'t [u8], Error> {
296        Base64Impl::decode(bin, b64.as_ref(), ignore, Base64Variant::OriginalNoPadding)
297    }
298}
299
300impl Encoder for Base64UrlSafe {
301    #[inline]
302    fn encoded_len(bin_len: usize) -> Result<usize, Error> {
303        Base64Impl::encoded_len(bin_len, Base64Variant::UrlSafe)
304    }
305
306    #[inline]
307    fn encode<IN: AsRef<[u8]>>(b64: &mut [u8], bin: IN) -> Result<&[u8], Error> {
308        Base64Impl::encode(b64, bin.as_ref(), Base64Variant::UrlSafe)
309    }
310}
311
312impl Decoder for Base64UrlSafe {
313    #[inline]
314    fn decode<'t, IN: AsRef<[u8]>>(
315        bin: &'t mut [u8],
316        b64: IN,
317        ignore: Option<&[u8]>,
318    ) -> Result<&'t [u8], Error> {
319        Base64Impl::decode(bin, b64.as_ref(), ignore, Base64Variant::UrlSafe)
320    }
321}
322
323impl Encoder for Base64UrlSafeNoPadding {
324    #[inline]
325    fn encoded_len(bin_len: usize) -> Result<usize, Error> {
326        Base64Impl::encoded_len(bin_len, Base64Variant::UrlSafeNoPadding)
327    }
328
329    #[inline]
330    fn encode<IN: AsRef<[u8]>>(b64: &mut [u8], bin: IN) -> Result<&[u8], Error> {
331        Base64Impl::encode(b64, bin.as_ref(), Base64Variant::UrlSafeNoPadding)
332    }
333}
334
335impl Decoder for Base64UrlSafeNoPadding {
336    #[inline]
337    fn decode<'t, IN: AsRef<[u8]>>(
338        bin: &'t mut [u8],
339        b64: IN,
340        ignore: Option<&[u8]>,
341    ) -> Result<&'t [u8], Error> {
342        Base64Impl::decode(bin, b64.as_ref(), ignore, Base64Variant::UrlSafeNoPadding)
343    }
344}
345
// Round-trips a 7-byte input (one trailing byte, hence "==" padding) through
// the padded standard codec using the `encode_to_string` / `decode_to_vec`
// helpers, which are only exercised when the `std` feature is enabled.
#[cfg(feature = "std")]
#[test]
fn test_base64() {
    let bin = [1u8, 5, 11, 15, 19, 131, 122];
    let expected = "AQULDxODeg==";
    let b64 = Base64::encode_to_string(bin).unwrap();
    assert_eq!(b64, expected);
    let bin2 = Base64::decode_to_vec(&b64, None).unwrap();
    assert_eq!(bin, &bin2[..]);
}
356
// Inputs whose padding was stripped must be rejected by the padded codec and
// accepted by the unpadded one, for both a 2-symbol and a 3-symbol tail.
#[cfg(feature = "std")]
#[test]
fn test_base64_mising_padding() {
    let missing_padding = "AA";
    assert!(Base64::decode_to_vec(missing_padding, None).is_err());
    assert!(Base64NoPadding::decode_to_vec(missing_padding, None).is_ok());
    let missing_padding = "AAA";
    assert!(Base64::decode_to_vec(missing_padding, None).is_err());
    assert!(Base64NoPadding::decode_to_vec(missing_padding, None).is_ok());
}
367
// Same round trip as the string-based test, but using only caller-provided
// fixed-size buffers so it does not depend on the `std` feature.
#[test]
fn test_base64_no_std() {
    let bin = [1u8, 5, 11, 15, 19, 131, 122];
    // ASCII bytes of "AQULDxODeg==".
    let expected = [65, 81, 85, 76, 68, 120, 79, 68, 101, 103, 61, 61];
    let mut b64 = [0u8; 12];
    let b64 = Base64::encode(&mut b64, bin).unwrap();
    assert_eq!(b64, expected);
    let mut bin2 = [0u8; 7];
    let bin2 = Base64::decode(&mut bin2, b64, None).unwrap();
    assert_eq!(bin, bin2);
}
379
// A single-byte payload needs exactly two '=' characters; one is not enough.
// Gated on `std` like the other tests that rely on `decode_to_vec` and `vec!`.
#[cfg(feature = "std")]
#[test]
fn test_base64_invalid_padding() {
    let valid_padding = "AA==";
    assert_eq!(Base64::decode_to_vec(valid_padding, None), Ok(vec![0u8; 1]));
    let invalid_padding = "AA=";
    assert_eq!(
        Base64::decode_to_vec(invalid_padding, None),
        Err(Error::InvalidInput)
    );
}