Skip to main content

anyxml_base64/
lib.rs

1//! A simple implementation of Base64 encoding as defined in
2//! [RFC 2045 6.8. Base64 Content-Transfer-Encoding](https://datatracker.ietf.org/doc/html/rfc2045#section-6.8).
3
/// The Base64 alphabet: entry `i` is the character that encodes the 6-bit value `i`.
///
/// ```text
/// Table 1: The Base64 Alphabet
///
/// Value Encoding  Value Encoding  Value Encoding  Value Encoding
///     0 A            17 R            34 i            51 z
///     1 B            18 S            35 j            52 0
///     2 C            19 T            36 k            53 1
///     3 D            20 U            37 l            54 2
///     4 E            21 V            38 m            55 3
///     5 F            22 W            39 n            56 4
///     6 G            23 X            40 o            57 5
///     7 H            24 Y            41 p            58 6
///     8 I            25 Z            42 q            59 7
///     9 J            26 a            43 r            60 8
///    10 K            27 b            44 s            61 9
///    11 L            28 c            45 t            62 +
///    12 M            29 d            46 u            63 /
///    13 N            30 e            47 v
///    14 O            31 f            48 w         (pad) =
///    15 P            32 g            49 x
///    16 Q            33 h            50 y
/// ```
const ENCODING_TABLE: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Inverse of `ENCODING_TABLE`: maps a byte to its 6-bit value, or to `u8::MAX` when the
/// byte is not part of the alphabet (this includes the pad character `=`).
const DECODING_TABLE: [u8; 256] = {
    let mut lookup = [u8::MAX; 256];
    // Walk the alphabet backwards. Every character is distinct, so the order in which
    // the entries are written does not affect the result.
    let mut value = ENCODING_TABLE.len();
    while value > 0 {
        value -= 1;
        lookup[ENCODING_TABLE[value] as usize] = value as u8;
    }
    lookup
};
37
/// Error returned by [`Base64Binary::from_encoded`] when its input is not valid Base64.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Base64Error {
    /// `byte` is not allowed at this point of a Base64 sequence.
    ///
    /// `position` is the zero-based index of the offending byte in the raw input; skipped
    /// whitespace still counts towards it. Misplaced padding (more than two `=`, or any
    /// non-`=` byte after the first `=`) is reported as `byte: b'='` at the first `=`.
    MalformedByte { byte: u8, position: usize },
    /// The number of significant (non-skipped) input bytes is not a multiple of 4.
    InsufficientPadding,
}

impl std::fmt::Display for Base64Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Base64Error::MalformedByte { byte, position } => write!(
                f,
                "malformed Base64 byte 0x{:02X} at position {}",
                byte, position
            ),
            Base64Error::InsufficientPadding => {
                f.write_str("Base64 input length is not a multiple of 4 (insufficient padding)")
            }
        }
    }
}

// No underlying cause is wrapped, so the default `source()` (returning `None`) is correct.
impl std::error::Error for Base64Error {}
43
/// A Base64-encoded byte sequence (RFC 2045 §6.8), stored in its encoded textual form.
///
/// Every constructor in this module maintains two invariants on the stored bytes:
/// - each byte is either a character of the Base64 alphabet or the pad byte `=`;
/// - the length is a multiple of 4.
///
/// Whitespace skipped by [`Base64Binary::from_encoded`] is not stored.
///
/// Equality and hashing compare the stored encoded bytes, not the decoded data.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Base64Binary {
    // Encoded Base64 text; always ASCII, as described by the invariants above.
    bin: Vec<u8>,
}
49
50impl Base64Binary {
51    /// Encode a binary sequence to Base64 binary.
52    ///
53    /// # Example
54    /// ```rust
55    /// use anyxml_base64::Base64Binary;
56    ///
57    /// let encoded = Base64Binary::encode("Hello".bytes());
58    /// assert_eq!(encoded.to_string(), "SGVsbG8=");
59    /// ```
60    pub fn encode(iter: impl IntoIterator<Item = u8>) -> Self {
61        iter.into_iter().collect()
62    }
63
64    /// Decode the Base64 binary back into the original binary sequence.
65    ///
66    /// # Example
67    /// ```rust
68    /// use anyxml_base64::Base64Binary;
69    ///
70    /// let encoded = Base64Binary::from_encoded(*b"SGVsbG8=", false).unwrap();
71    /// let decoded = encoded.decode().map(|b| b as char).collect::<String>();
72    /// assert_eq!(decoded, "Hello");
73    /// ```
74    pub fn decode(&self) -> impl Iterator<Item = u8> + '_ {
75        assert!(self.bin.len() % 4 == 0);
76        self.bin.chunks_exact(4).flat_map(|chunk| {
77            let b0 = DECODING_TABLE[chunk[0] as usize];
78            let b1 = DECODING_TABLE[chunk[1] as usize];
79            let b2 = DECODING_TABLE[chunk[2] as usize];
80            let b3 = DECODING_TABLE[chunk[3] as usize];
81            let mut r0 = Some((b0 << 2) | (b1 >> 4));
82            let mut r1 = (b2 != u8::MAX).then_some(b1.wrapping_shl(4) | (b2 >> 2));
83            let mut r2 = (b3 != u8::MAX).then_some(b2.wrapping_shl(6) | b3);
84            std::iter::from_fn(move || r0.take().or_else(|| r1.take()).or_else(|| r2.take()))
85        })
86    }
87
88    /// Construct a [`Base64Binary`] from a Base64-encoded byte sequence.
89    ///
90    /// When `allow_whitespace` is true, bytes for which [`u8::is_ascii_whitespace`]
91    /// returns true are ignored during byte sequence validation.
92    ///
93    /// Returns an error if the `iter` is invalid as a Base64 byte sequence.
94    pub fn from_encoded(
95        iter: impl IntoIterator<Item = u8>,
96        allow_whitespace: bool,
97    ) -> Result<Self, Base64Error> {
98        let mut bin = vec![];
99        let mut pad = None;
100        for (position, byte) in iter.into_iter().enumerate() {
101            if allow_whitespace && byte.is_ascii_whitespace() {
102                continue;
103            }
104            if byte == b'=' {
105                pad.get_or_insert((bin.len(), position));
106            } else if DECODING_TABLE[byte as usize] == u8::MAX {
107                return Err(Base64Error::MalformedByte { byte, position });
108            }
109            bin.push(byte);
110        }
111
112        if bin.len() % 4 != 0 {
113            return Err(Base64Error::InsufficientPadding);
114        }
115
116        if let Some((pad, position)) = pad
117            && (bin.len() - pad > 2 || bin[pad..].iter().any(|&b| b != b'='))
118        {
119            return Err(Base64Error::MalformedByte {
120                byte: b'=',
121                position,
122            });
123        }
124
125        Ok(Base64Binary { bin })
126    }
127}
128
129impl FromIterator<u8> for Base64Binary {
130    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
131        let mut iter = iter.into_iter();
132        let mut bin = vec![];
133        while let Some(b0) = iter.next() {
134            bin.push(ENCODING_TABLE[(b0 >> 2) as usize]);
135            match iter.next() {
136                Some(b1) => {
137                    bin.push(ENCODING_TABLE[(((b0 & 0x3) << 4) | (b1 >> 4)) as usize]);
138                    match iter.next() {
139                        Some(b2) => {
140                            bin.push(ENCODING_TABLE[(((b1 & 0xF) << 2) | (b2 >> 6)) as usize]);
141                            bin.push(ENCODING_TABLE[(b2 & 0x3F) as usize]);
142                        }
143                        None => {
144                            bin.push(ENCODING_TABLE[((b1 & 0xF) << 2) as usize]);
145                            bin.push(b'=');
146                        }
147                    }
148                }
149                None => {
150                    bin.push(ENCODING_TABLE[((b0 & 0x3) << 4) as usize]);
151                    bin.push(b'=');
152                    bin.push(b'=');
153                }
154            }
155        }
156        Base64Binary { bin }
157    }
158}
159
160impl From<&str> for Base64Binary {
161    fn from(value: &str) -> Self {
162        value.bytes().collect()
163    }
164}
165
166macro_rules! impl_from_str_for_base64_binary {
167    ( $( $t:ty ),* ) => {
168        $(
169            impl From<$t> for Base64Binary {
170                fn from(value: $t) -> Self {
171                    value.bytes().collect()
172                }
173            }
174            impl From<&$t> for Base64Binary {
175                fn from(value: &$t) -> Self {
176                    value.bytes().collect()
177                }
178            }
179        )*
180    };
181}
182impl_from_str_for_base64_binary!(
183    String,
184    Box<str>,
185    std::rc::Rc<str>,
186    std::sync::Arc<str>,
187    std::borrow::Cow<'_, str>
188);
189
190impl From<&[u8]> for Base64Binary {
191    fn from(value: &[u8]) -> Self {
192        value.iter().copied().collect()
193    }
194}
195macro_rules! impl_from_bytes_for_base64_binary {
196    ( $( $t:ty ),* ) => {
197        $(
198            impl From<$t> for Base64Binary {
199                fn from(value: $t) -> Self {
200                    value.iter().copied().collect()
201                }
202            }
203            impl From<&$t> for Base64Binary {
204                fn from(value: &$t) -> Self {
205                    value.iter().copied().collect()
206                }
207            }
208        )*
209    };
210}
211impl_from_bytes_for_base64_binary!(
212    Vec<u8>,
213    Box<[u8]>,
214    std::rc::Rc<[u8]>,
215    std::sync::Arc<[u8]>,
216    std::borrow::Cow<'_, [u8]>
217);
218
219impl std::fmt::Debug for Base64Binary {
220    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221        write!(f, "{}", self)
222    }
223}
224
225impl std::fmt::Display for Base64Binary {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        unsafe {
228            // # Safety
229            // `self.bin` is a Base64 binary consisting entirely of ASCII characters,
230            // so UTF-8 validation will not fail.
231            write!(f, "{}", std::str::from_utf8_unchecked(&self.bin))
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use std::hash::{BuildHasher, Hasher, RandomState};

    use super::*;

    /// Returns a fresh per-run seed. Every randomized assertion reports it, so a failure can
    /// be replayed by substituting the printed value.
    fn random_seed() -> u64 {
        RandomState::new().build_hasher().finish()
    }

    /// Infinite xorshift32 sequence seeded with the low 32 bits of `seed`.
    ///
    /// Zero is a fixed point of xorshift, so a zero state is replaced by a non-zero
    /// constant. Otherwise the stream would consist only of zeros.
    fn xor_shift32(seed: u64) -> impl Iterator<Item = u32> {
        let mut random = match seed as u32 {
            0 => 0x9E37_79B9,
            nonzero => nonzero,
        };

        std::iter::repeat_with(move || {
            random ^= random << 13;
            random ^= random >> 17;
            random ^= random << 5;
            random
        })
    }

    /// Infinite pseudo-random byte stream: each xorshift word contributes its four
    /// little-endian bytes.
    fn bytes(seed: u64) -> impl Iterator<Item = u8> {
        xor_shift32(seed).flat_map(u32::to_le_bytes)
    }

    #[test]
    fn rfc4648_test_vectors() {
        const VECTORS: [(&str, &str); 7] = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (plain, encoded) in VECTORS {
            let b = Base64Binary::encode(plain.bytes());
            assert_eq!(b.to_string(), encoded, "encoding {plain:?}");
            let parsed = Base64Binary::from_encoded(encoded.bytes(), false).unwrap();
            assert_eq!(parsed, b, "parsing {encoded:?}");
            assert_eq!(parsed.decode().collect::<Vec<_>>(), plain.as_bytes());
        }
    }

    #[test]
    fn regression_tests() {
        let seed = random_seed();
        let mut bytes = bytes(seed);
        for _ in 0..10000 {
            let len = bytes.next().unwrap() as usize;
            let bytes = bytes.by_ref().take(len).collect::<Vec<_>>();

            let encoded = bytes.iter().copied().collect::<Base64Binary>();
            let decoded = encoded.decode().collect::<Vec<_>>();
            assert_eq!(bytes, decoded, "seed: {seed:#x}, input: {bytes:?}");

            // One `=` per byte missing from the final 3-byte group, all at the very end.
            let pad = encoded.bin.iter().filter(|&&c| c == b'=').count();
            let trailing = encoded.bin.iter().rev().take_while(|&&c| c == b'=').count();
            assert_eq!(pad, (3 - bytes.len() % 3) % 3, "seed: {seed:#x}, encoded: {encoded}");
            assert_eq!(pad, trailing, "seed: {seed:#x}, encoded: {encoded}");

            let reparsed = Base64Binary::from_encoded(encoded.to_string().bytes(), false);
            assert_eq!(reparsed, Ok(encoded), "seed: {seed:#x}");
        }
    }

    #[test]
    fn encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"", false).is_ok());
        let seed = random_seed();
        let mut bytes = bytes(seed);
        for _ in 0..10000 {
            let len = (bytes.next().unwrap() as usize).div_ceil(4) * 4;
            let input = bytes
                .by_ref()
                .filter(|b| DECODING_TABLE[*b as usize] != u8::MAX)
                .take(len)
                .collect::<Vec<_>>();

            let encoded = Base64Binary::from_encoded(input.iter().copied(), false)
                .unwrap_or_else(|err| panic!("seed: {seed:#x}, input: {input:?}, error: {err:?}"));
            // Unpadded input of 4k alphabet characters always decodes to exactly 3k bytes.
            assert_eq!(encoded.decode().count(), len / 4 * 3, "seed: {seed:#x}");
        }
    }

    #[test]
    fn erroneous_encoded_bytes_tests() {
        for input in [
            "a", "aa", "aaa", "aaaaa", "aaaaaa", "aaaaaaa", "=", "==", "===", "====", "a=",
            "a==", "a===",
        ] {
            assert!(
                Base64Binary::from_encoded(input.bytes(), false).is_err(),
                "accepted {input:?}"
            );
        }

        assert_eq!(
            Base64Binary::from_encoded(*b"aaa", false),
            Err(Base64Error::InsufficientPadding)
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"a===", false),
            Err(Base64Error::MalformedByte { byte: b'=', position: 1 })
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"aa=a", false),
            Err(Base64Error::MalformedByte { byte: b'=', position: 2 })
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"aa==aaaa", false),
            Err(Base64Error::MalformedByte { byte: b'=', position: 2 })
        );
    }

    #[test]
    fn whitespace_tests() {
        let wrapped = Base64Binary::from_encoded(*b" Zm9v\r\nYmFy\t", true).unwrap();
        assert_eq!(wrapped, Base64Binary::encode("foobar".bytes()));
        assert_eq!(
            Base64Binary::from_encoded(*b"Zm9v\nYmFy", false),
            Err(Base64Error::MalformedByte { byte: b'\n', position: 4 })
        );
        // Reported positions index the raw input, skipped whitespace included.
        assert_eq!(
            Base64Binary::from_encoded(*b" Zm9v!", true),
            Err(Base64Error::MalformedByte { byte: b'!', position: 5 })
        );
    }
}