Skip to main content

anyxml_base64/
lib.rs

1//! A simple implementation of Base64 encoding as defined in
2//! [RFC 2045 6.8. Base64 Content-Transfer-Encoding](https://datatracker.ietf.org/doc/html/rfc2045#section-6.8).
3
/// The Base64 alphabet: the character at index `n` encodes the 6-bit value `n`.
///
/// ```text
/// Table 1: The Base64 Alphabet
///
/// Value Encoding  Value Encoding  Value Encoding  Value Encoding
///     0 A            17 R            34 i            51 z
///     1 B            18 S            35 j            52 0
///     2 C            19 T            36 k            53 1
///     3 D            20 U            37 l            54 2
///     4 E            21 V            38 m            55 3
///     5 F            22 W            39 n            56 4
///     6 G            23 X            40 o            57 5
///     7 H            24 Y            41 p            58 6
///     8 I            25 Z            42 q            59 7
///     9 J            26 a            43 r            60 8
///    10 K            27 b            44 s            61 9
///    11 L            28 c            45 t            62 +
///    12 M            29 d            46 u            63 /
///    13 N            30 e            47 v
///    14 O            31 f            48 w         (pad) =
///    15 P            32 g            49 x
///    16 Q            33 h            50 y
/// ```
const ENCODING_TABLE: &[u8; 1 << 6] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Inverse of [`ENCODING_TABLE`], indexed by byte value.
///
/// Holds the 6-bit value of every alphabet character and `u8::MAX` for every other byte
/// (including the `=` pad character), so `u8::MAX` doubles as the "not Base64" marker.
const DECODING_TABLE: [u8; 1 << 8] = {
    // `for` loops are not allowed in const context, hence the manual index.
    let mut inverse = [u8::MAX; 1 << 8];
    let mut sextet = 0;
    while sextet < ENCODING_TABLE.len() {
        let symbol = ENCODING_TABLE[sextet] as usize;
        inverse[symbol] = sextet as u8;
        sextet += 1;
    }
    inverse
};
37
/// Error returned when a byte sequence is not valid Base64.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Base64Error {
    /// `byte` may not appear at `position`: it is either outside the Base64 alphabet, or
    /// it is an `=` that is not part of the final one or two padding characters.
    ///
    /// `position` is a zero-based index into the original input and counts any
    /// whitespace that was skipped.
    MalformedByte { byte: u8, position: usize },
    /// The number of Base64 characters is not a multiple of four.
    InsufficientPadding,
}

impl std::fmt::Display for Base64Error {
    /// Writes a human-readable description; use `{:?}` for the structured form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            // `escape_ascii` keeps control characters and non-ASCII bytes readable
            // (e.g. `\n`, `\xff`) instead of emitting them raw.
            Self::MalformedByte { byte, position } => write!(
                f,
                "malformed Base64 byte '{}' at position {position}",
                byte.escape_ascii()
            ),
            Self::InsufficientPadding => f.write_str(
                "insufficient Base64 padding: the encoded length is not a multiple of 4",
            ),
        }
    }
}

impl std::error::Error for Base64Error {}
51
/// Base64-encoded byte sequence.
///
/// The value stores the encoded text itself (for example `SGVsbG8=`), not the decoded
/// bytes; [`Base64Binary::decode`] recovers the original bytes. The derived `PartialEq`,
/// `Eq` and `Hash` operate on that stored text, so two values are equal exactly when
/// their encoded text is identical.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Base64Binary {
    // Encoded text. The constructors in this file only ever store Base64 alphabet
    // characters and `=` padding (whitespace accepted by `from_encoded` is dropped), in
    // whole 4-character groups; the `Display` impl and `decode` rely on this.
    bin: Vec<u8>,
}
57
58impl Base64Binary {
59    /// Encode a binary sequence to Base64 binary.
60    ///
61    /// # Example
62    /// ```rust
63    /// use anyxml_base64::Base64Binary;
64    ///
65    /// let encoded = Base64Binary::encode("Hello".bytes());
66    /// assert_eq!(encoded.to_string(), "SGVsbG8=");
67    /// ```
68    pub fn encode(iter: impl IntoIterator<Item = u8>) -> Self {
69        iter.into_iter().collect()
70    }
71
72    /// Decode the Base64 binary back into the original binary sequence.
73    ///
74    /// # Example
75    /// ```rust
76    /// use anyxml_base64::Base64Binary;
77    ///
78    /// let encoded = Base64Binary::from_encoded(*b"SGVsbG8=", false).unwrap();
79    /// let decoded = encoded.decode().map(|b| b as char).collect::<String>();
80    /// assert_eq!(decoded, "Hello");
81    /// ```
82    pub fn decode(&self) -> impl Iterator<Item = u8> + '_ {
83        assert!(self.bin.len() % 4 == 0);
84        self.bin.chunks_exact(4).flat_map(|chunk| {
85            let b0 = DECODING_TABLE[chunk[0] as usize];
86            let b1 = DECODING_TABLE[chunk[1] as usize];
87            let b2 = DECODING_TABLE[chunk[2] as usize];
88            let b3 = DECODING_TABLE[chunk[3] as usize];
89            let mut r0 = Some((b0 << 2) | (b1 >> 4));
90            let mut r1 = (b2 != u8::MAX).then_some(b1.wrapping_shl(4) | (b2 >> 2));
91            let mut r2 = (b3 != u8::MAX).then_some(b2.wrapping_shl(6) | b3);
92            std::iter::from_fn(move || r0.take().or_else(|| r1.take()).or_else(|| r2.take()))
93        })
94    }
95
96    /// Construct a [`Base64Binary`] from a Base64-encoded byte sequence.
97    ///
98    /// When `allow_whitespace` is true, bytes for which [`u8::is_ascii_whitespace`]
99    /// returns true are ignored during byte sequence validation.
100    ///
101    /// Returns an error if the `iter` is invalid as a Base64 byte sequence.
102    pub fn from_encoded(
103        iter: impl IntoIterator<Item = u8>,
104        allow_whitespace: bool,
105    ) -> Result<Self, Base64Error> {
106        let mut bin = vec![];
107        let mut pad = None;
108        for (position, byte) in iter.into_iter().enumerate() {
109            if allow_whitespace && byte.is_ascii_whitespace() {
110                continue;
111            }
112            if byte == b'=' {
113                pad.get_or_insert((bin.len(), position));
114            } else if DECODING_TABLE[byte as usize] == u8::MAX {
115                return Err(Base64Error::MalformedByte { byte, position });
116            }
117            bin.push(byte);
118        }
119
120        if bin.len() % 4 != 0 {
121            return Err(Base64Error::InsufficientPadding);
122        }
123
124        if let Some((pad, position)) = pad
125            && (bin.len() - pad > 2 || bin[pad..].iter().any(|&b| b != b'='))
126        {
127            return Err(Base64Error::MalformedByte {
128                byte: b'=',
129                position,
130            });
131        }
132
133        Ok(Base64Binary { bin })
134    }
135}
136
137impl FromIterator<u8> for Base64Binary {
138    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
139        let mut iter = iter.into_iter();
140        let mut bin = vec![];
141        while let Some(b0) = iter.next() {
142            bin.push(ENCODING_TABLE[(b0 >> 2) as usize]);
143            match iter.next() {
144                Some(b1) => {
145                    bin.push(ENCODING_TABLE[(((b0 & 0x3) << 4) | (b1 >> 4)) as usize]);
146                    match iter.next() {
147                        Some(b2) => {
148                            bin.push(ENCODING_TABLE[(((b1 & 0xF) << 2) | (b2 >> 6)) as usize]);
149                            bin.push(ENCODING_TABLE[(b2 & 0x3F) as usize]);
150                        }
151                        None => {
152                            bin.push(ENCODING_TABLE[((b1 & 0xF) << 2) as usize]);
153                            bin.push(b'=');
154                        }
155                    }
156                }
157                None => {
158                    bin.push(ENCODING_TABLE[((b0 & 0x3) << 4) as usize]);
159                    bin.push(b'=');
160                    bin.push(b'=');
161                }
162            }
163        }
164        Base64Binary { bin }
165    }
166}
167
168impl From<&str> for Base64Binary {
169    fn from(value: &str) -> Self {
170        value.bytes().collect()
171    }
172}
173
174macro_rules! impl_from_str_for_base64_binary {
175    ( $( $t:ty ),* ) => {
176        $(
177            impl From<$t> for Base64Binary {
178                fn from(value: $t) -> Self {
179                    value.bytes().collect()
180                }
181            }
182            impl From<&$t> for Base64Binary {
183                fn from(value: &$t) -> Self {
184                    value.bytes().collect()
185                }
186            }
187        )*
188    };
189}
190impl_from_str_for_base64_binary!(
191    String,
192    Box<str>,
193    std::rc::Rc<str>,
194    std::sync::Arc<str>,
195    std::borrow::Cow<'_, str>
196);
197
198impl From<&[u8]> for Base64Binary {
199    fn from(value: &[u8]) -> Self {
200        value.iter().copied().collect()
201    }
202}
203macro_rules! impl_from_bytes_for_base64_binary {
204    ( $( $t:ty ),* ) => {
205        $(
206            impl From<$t> for Base64Binary {
207                fn from(value: $t) -> Self {
208                    value.iter().copied().collect()
209                }
210            }
211            impl From<&$t> for Base64Binary {
212                fn from(value: &$t) -> Self {
213                    value.iter().copied().collect()
214                }
215            }
216        )*
217    };
218}
219impl_from_bytes_for_base64_binary!(
220    Vec<u8>,
221    Box<[u8]>,
222    std::rc::Rc<[u8]>,
223    std::sync::Arc<[u8]>,
224    std::borrow::Cow<'_, [u8]>
225);
226
227impl std::fmt::Debug for Base64Binary {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        write!(f, "{}", self)
230    }
231}
232
233impl std::fmt::Display for Base64Binary {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        unsafe {
236            // # Safety
237            // `self.bin` is a Base64 binary consisting entirely of ASCII characters,
238            // so UTF-8 validation will not fail.
239            write!(f, "{}", std::str::from_utf8_unchecked(&self.bin))
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use std::hash::{BuildHasher, Hasher, RandomState};

    use super::*;

    /// Returns a fresh random seed for the randomized tests.
    ///
    /// Every randomized assertion reports the seed, so a failing run can be reproduced by
    /// temporarily returning that value from here.
    fn random_seed() -> u64 {
        RandomState::new().build_hasher().finish()
    }

    /// 32-bit xorshift generator seeded from `seed`.
    fn xor_shift32(seed: u64) -> impl Iterator<Item = u32> {
        // Xorshift never leaves the all-zero state, so fold the whole seed into 32 bits
        // and substitute a fixed non-zero constant if that still yields zero.
        let mut random = match (seed ^ (seed >> 32)) as u32 {
            0 => 0x9E37_79B9,
            folded => folded,
        };

        std::iter::repeat_with(move || {
            random ^= random << 13;
            random ^= random >> 17;
            random ^= random << 5;
            random
        })
    }

    /// Infinite pseudo-random byte stream derived from [`xor_shift32`].
    fn bytes(seed: u64) -> impl Iterator<Item = u8> {
        let mut generator = xor_shift32(seed);
        let mut buf = [0u8; 4];
        // Start "exhausted" so the very first byte already comes from the generator
        // rather than from the zero-initialised buffer.
        let mut counter = buf.len();
        std::iter::from_fn(move || {
            if counter == buf.len() {
                buf = generator.next().unwrap().to_le_bytes();
                counter = 0;
            }
            let ret = buf[counter];
            counter += 1;
            Some(ret)
        })
    }

    #[test]
    fn regression_tests() {
        let seed = random_seed();
        let mut bytes = bytes(seed);
        for _ in 0..10000 {
            let len = bytes.next().unwrap() as usize;
            let bytes = (0..len).map(|_| bytes.next().unwrap()).collect::<Vec<_>>();

            let encoded = bytes.iter().copied().collect::<Base64Binary>();
            assert_eq!(encoded.bin.len(), len.div_ceil(3) * 4, "seed: {seed}, len: {len}");
            let decoded = encoded.decode().collect::<Vec<_>>();

            assert_eq!(
                bytes,
                decoded,
                "seed: {seed}, len: {},{}",
                bytes.len(),
                decoded.len()
            );
            let pad = encoded.bin.iter().filter(|c| **c == b'=').count();
            match encoded.bin.as_slice() {
                [.., b'=', b'='] => assert_eq!(pad, 2, "seed: {seed}"),
                [.., b'='] => assert_eq!(pad, 1, "seed: {seed}"),
                [..] => assert_eq!(pad, 0, "seed: {seed}"),
            }

            let encoded2 = Base64Binary::from_encoded(encoded.to_string().bytes(), false)
                .unwrap_or_else(|err| panic!("seed: {seed}, input: {encoded}, error: {err:?}"));
            assert_eq!(encoded, encoded2, "seed: {seed}");
        }
    }

    #[test]
    fn encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"", false).is_ok());
        let seed = random_seed();
        let mut bytes = bytes(seed);
        for _ in 0..10000 {
            let len = bytes.next().unwrap() as usize;
            let len = len.div_ceil(4) * 4;
            let bytes = bytes
                .by_ref()
                .filter(|b| DECODING_TABLE[*b as usize] != u8::MAX)
                .take(len)
                .collect::<Vec<_>>();

            let encoded = Base64Binary::from_encoded(bytes.iter().copied(), false);
            assert!(
                encoded.is_ok(),
                "seed: {seed}, input: {:?}, result: {encoded:?}",
                String::from_utf8_lossy(&bytes)
            );
        }
    }

    #[test]
    fn rfc4648_test_vectors() {
        let vectors = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (plain, encoded) in vectors {
            assert_eq!(Base64Binary::from(plain).to_string(), encoded);
            let parsed = Base64Binary::from_encoded(encoded.bytes(), false).unwrap();
            assert_eq!(parsed.decode().collect::<Vec<_>>(), plain.as_bytes());
        }
    }

    #[test]
    fn whitespace_handling() {
        let parsed = Base64Binary::from_encoded(*b" SGVs\r\nbG8=\t", true).unwrap();
        assert_eq!(parsed.to_string(), "SGVsbG8=");
        assert_eq!(
            Base64Binary::from_encoded(*b"SGVs\nbG8=", false),
            Err(Base64Error::MalformedByte {
                byte: b'\n',
                position: 4
            })
        );
    }

    #[test]
    fn misplaced_padding_is_rejected() {
        let malformed_pad = |position: usize| -> Result<Base64Binary, Base64Error> {
            Err(Base64Error::MalformedByte {
                byte: b'=',
                position,
            })
        };
        assert_eq!(Base64Binary::from_encoded(*b"ab=c", false), malformed_pad(2));
        assert_eq!(Base64Binary::from_encoded(*b"aa==aaaa", false), malformed_pad(2));
        assert_eq!(Base64Binary::from_encoded(*b"a===", false), malformed_pad(1));
        assert_eq!(Base64Binary::from_encoded(*b"====", false), malformed_pad(0));
        assert_eq!(
            Base64Binary::from_encoded(*b"abc", false),
            Err(Base64Error::InsufficientPadding)
        );
    }

    #[test]
    fn erroneous_encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"a", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaaa", false).is_err());

        assert!(Base64Binary::from_encoded(*b"=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"===", false).is_err());
        assert!(Base64Binary::from_encoded(*b"====", false).is_err());

        assert!(Base64Binary::from_encoded(*b"a=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a===", false).is_err());
    }
}
336        assert!(Base64Binary::from_encoded(*b"a=", false).is_err());
337        assert!(Base64Binary::from_encoded(*b"a==", false).is_err());
338        assert!(Base64Binary::from_encoded(*b"a===", false).is_err());
339    }
340}