ser_write_json/
base64.rs

1//! Base-64 codec.
2use core::cell::Cell;
3use crate::SerWrite;
4
5static ALPHABET: &[u8;64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
6
7/// Encode an array of bytes as Base-64 ASCII armour codes into a [`SerWrite`] implementing object.
8///
9/// This function does not append Base-64 `'='` padding characters by itself
10/// and instead returns the number of padding characters required: 0-2.
11pub fn encode<W: SerWrite>(ser: &mut W, bytes: &[u8]) -> Result<u8, W::Error> {
12    let mut chunks = bytes.chunks_exact(3);
13    for slice in chunks.by_ref() {
14        let [a,b,c] = slice.try_into().unwrap();
15        let output = [
16            a >> 2,
17            ((a & 0x03) << 4) | ((b & 0xF0) >> 4),
18            ((b & 0x0F) << 2) | ((c & 0xC0) >> 6),
19            c & 0x3F
20        ].map(|n| ALPHABET[(n & 0x3F) as usize]);
21        ser.write(&output)?;
22    }
23    match chunks.remainder() {
24        [a, b] => {
25            let output = [
26                a >> 2,
27                ((a & 0x03) << 4) | ((b & 0xF0) >> 4),
28                ((b & 0x0F) << 2)
29            ].map(|n| ALPHABET[(n & 0x3F) as usize]);
30            ser.write(&output)?;
31            Ok(1)
32        }
33        [a] => {
34            let output = [
35                a >> 2,
36                ((a & 0x03) << 4),
37            ].map(|n| ALPHABET[(n & 0x3F) as usize]);
38            ser.write(&output)?;
39            Ok(2)
40        }
41        _ => Ok(0)
42    }
43}
44
/// Translate a single Base-64 ASCII character into its 6-bit value.
///
/// Returns `None` for every byte that is not part of the standard alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`), which includes the padding character `'='`.
#[inline]
fn get_code(c: u8) -> Option<u8> {
    if c.is_ascii_uppercase() {
        // 'A'..='Z' => 0..=25
        Some(c - b'A')
    } else if c.is_ascii_lowercase() {
        // 'a'..='z' => 26..=51
        Some(c - b'a' + 26)
    } else if c.is_ascii_digit() {
        // '0'..='9' => 52..=61
        Some(c - b'0' + 52)
    } else if c == b'+' {
        Some(62)
    } else if c == b'/' {
        Some(63)
    } else {
        None
    }
}
56
57// static DIGITS: [u8;80] = [
58// /*   0   1   2   3   4   5   6   7   8   9   A   B   C   D   E   F */
59//                                                 62, 80, 80, 80, 63, /* 0x2B..=0x2F */
60//     52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 80, 80, 80, 64, 80, 80, /* 0x30..=0x3F */
61//     80,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, /* 0x40..=0x4F */
62//     15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 80, 80, 80, 80, 80, /* 0x50..=0x5F */
63//     80, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, /* 0x60..=0x6F */
64//     41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51                      /* 0x70..=0x7A */
65// ];
66
67// #[inline]
68// fn get_code(c: u8) -> Option<u8> {
69//     match c {
70//         0x2B..=0x7A => {
71//             let n = DIGITS[(c - 0x2B) as usize];
72//             (n <= 63).then_some(n)
73//         }
74//         _ => None
75//     }
76// }
77
78//   010100    110111    010101    101110
79//                       010100 << 18
80//                       110111 << 12
81//                       010101 << 6
82//                       101110
83//   01010011 01110101 01101110
84//
85//                            1 (0) (31)
86//                      1010100 (1) (25)
87//                   1 01010000 (1) (25)(<<2)
88//               10101 00110111 (2) (19)
89//          1 01010011 01110000 (2) (19)(<<4)
90//        101 01001101 11010101 (3) (13)
91// 1 01010011 01110101 01000000 (3) (13)(<<6)
92// 1 01010011 01110101 01101110 (4) (7)
93#[inline(always)]
94fn decode_cell(acc: u32, cell: &Cell<u8>) -> core::result::Result<u32, u32> {
95    match get_code(cell.get()) {
96        Some(code) => Ok((acc << 6) | u32::from(code)),
97        None => Err(acc)
98    }
99}
100/// Decode a Base-64 encoded slice of byte characters in-place until a first
101/// invalid character is found or until the end of the slice.
102///
103/// Return a tuple of: `(decoded_len, encoded_len)`.
104///
105/// `decoded_len <= encoded_len <= slice.len()`
106pub fn decode(slice: &mut[u8]) -> (usize, usize) {
107    let cells = Cell::from_mut(slice).as_slice_of_cells();
108    let mut chunks = cells.chunks_exact(4);
109    let mut dest = cells.iter();
110    let mut dcount: usize = 0;
111    for slice in chunks.by_ref() {
112        match slice.iter().try_fold(1, decode_cell) {
113            Ok(packed) => {
114                // SAFETY: dest and chunks iterate over the same cells slice,
115                // while for every 4 byte chunk only 3 dest bytes are consumed,
116                // there's no way dest.next() can be None at any point
117                unsafe {
118                    dest.next().unwrap_unchecked().set((packed >> 16) as u8);
119                    dest.next().unwrap_unchecked().set((packed >> 8) as u8);
120                    dest.next().unwrap_unchecked().set(packed as u8);
121                }
122                dcount += 3;
123            }
124            Err(packed) => return handle_tail(dcount, packed, dest)
125        }
126    }
127    match chunks.remainder().iter().try_fold(1, decode_cell) {
128        /* no tail */
129        Ok(1) => (dcount, dcount * 4 / 3),
130        /* some tail */
131        Ok(packed)|Err(packed) => handle_tail(dcount, packed, dest)
132    }
133}
134
/// Finish decoding with an incomplete group of 0 to 3 Base-64 characters.
///
/// * `dcount` - the number of bytes decoded so far from complete 4-character
///   groups; expected to be a multiple of 3.
/// * `packed` - a marker bit followed by 6 bits for every collected character.
///   Must not be 0.
/// * `dest` - yields the cells to store the remaining decoded bytes into.
///
/// Returns the final `(decoded_len, encoded_len)` tuple.
///
/// # Panics
///
/// Panics if `dest` yields fewer cells than there are bytes left to store.
fn handle_tail<'a, I>(mut dcount: usize, mut packed: u32, mut dest: I) -> (usize, usize)
    where I: Iterator<Item=&'a Cell<u8>>
{
    debug_assert!(dcount % 3 == 0);
    // The position of the marker bit tells how many characters were collected:
    // leading zeros 31->(+0, +0), 25->(+0, +1), 19->(+1, +2), 13->(+2, +3)
    let leftovers = (31 - packed.leading_zeros()) / 6;
    // align the collected bits, so that whole bytes end right above bit 8
    packed <<= leftovers*2;
    // n characters carry n - 1 complete bytes; a lone character carries none
    let tail_dcount = leftovers.saturating_sub(1);
    // dcount is a multiple of 3: dividing first gives the exact result and
    // can't overflow usize, unlike `dcount * 4 / 3`
    let ecount = dcount / 3 * 4 + leftovers as usize;
    dcount += tail_dcount as usize;
    for shift in (1..=tail_dcount).rev() {
        dest.next().unwrap().set((packed >> (shift * 8)) as u8);
    }
    (dcount, ecount)
}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ser_write::{SerError, SliceWriter};

    /// Encode known inputs and verify that running out of space is reported.
    #[test]
    fn test_base64_encode() {
        // (input bytes, expected Base-64 characters without padding)
        let cases: &[(&[u8], &[u8])] = &[
            (&[], b""),
            (&[0], b"AA"),
            (&[1], b"AQ"),
            (&[0,0], b"AAA"),
            (&[0,0,0], b"AAAA"),
            (&[0,0,0,0], b"AAAAAA"),
            (&[1,2], b"AQI"),
            (&[1,2,3], b"AQID"),
            (&[1,2,3,4], b"AQIDBA"),
            (&[0x80], b"gA"),
            (&[0x80,0x81], b"gIE"),
            (&[0x80,0x81,0x82], b"gIGC"),
            (&[0xFF], b"/w"),
            (&[0xFF,0xFF], b"//8"),
            (&[0xFF,0xFF,0xFE], b"///+"),
            (&[0xFF,0xFF,0xFF], b"////"),
        ];
        let mut storage = [0u8;6];
        let writer = &mut SliceWriter::new(&mut storage);
        for &(input, expected) in cases {
            writer.clear();
            encode(writer, input).unwrap();
            assert_eq!(writer.as_ref(), expected);
        }
        // the writer still holds the output of the last case at this point
        assert_eq!(encode(writer, b"12345"), Err(SerError::BufferFull));
        // a single byte of space is not enough for any non-empty input
        for input in [&b"1"[..], &b"12"[..]].iter() {
            let mut tiny = [0u8;1];
            let writer = &mut SliceWriter::new(&mut tiny);
            assert_eq!(encode(writer, *input), Err(SerError::BufferFull));
        }
    }

    /// Decode `encoded` followed by 0 to 4 `'='` characters and verify the
    /// returned lengths, the decoded bytes and where the decoding stopped.
    fn test_decode(buf: &mut[u8], encoded: &[u8], expected: (usize, usize), decoded: &[u8]) {
        let (decoded_len, encoded_len) = expected;
        for padding in 0..=4 {
            let mut writer = SliceWriter::new(buf);
            writer.write(encoded).unwrap();
            for _ in 0..padding {
                writer.write_byte(b'=').unwrap();
            }
            let output = writer.split().0;
            assert_eq!(decode(output), expected);
            assert_eq!(&output[..decoded_len], decoded);
            if padding == 0 {
                // everything that was written has been consumed
                assert_eq!(output.len(), encoded_len);
            }
            else {
                // the decoding stopped at the first padding character
                assert_eq!(output[encoded_len], b'=');
            }
        }
    }

    #[test]
    fn test_base64_decode() {
        // (encoded characters, expected (decoded_len, encoded_len), decoded bytes)
        let cases: &[(&[u8], (usize, usize), &[u8])] = &[
            (b"", (0, 0), &[]),
            (b"A", (0, 1), &[]),
            (b"/", (0, 1), &[]),
            (b"AA", (1, 2), &[0]),
            (b"AAA", (2, 3), &[0,0]),
            (b"AAAA", (3, 4), &[0,0,0]),
            (b"AAAAA", (3, 5), &[0,0,0]),
            (b"AAAAAA", (4, 6), &[0,0,0,0]),
            (b"AQ", (1, 2), &[1]),
            (b"AQI", (2, 3), &[1,2]),
            (b"AQID", (3, 4), &[1,2,3]),
            (b"AQIDB", (3, 5), &[1,2,3]),
            (b"AQIDBA", (4, 6), &[1,2,3,4]),
            (b"gA", (1, 2), &[0x80]),
            (b"gIE", (2, 3), &[0x80,0x81]),
            (b"gIGC", (3, 4), &[0x80,0x81,0x82]),
            (b"/w", (1, 2), &[0xFF]),
            (b"//8", (2, 3), &[0xFF,0xFF]),
            (b"////", (3, 4), &[0xFF,0xFF,0xFF]),
            (b"/////w", (4, 6), &[0xFF,0xFF,0xFF,0xFF]),
            (b"//////8", (5, 7), &[0xFF,0xFF,0xFF,0xFF,0xFF]),
            (b"////////", (6, 8), &[0xFF,0xFF,0xFF,0xFF,0xFF,0xFF]),
            (b"/v", (1, 2), &[0xFE]),
            (b"//7", (2, 3), &[0xFF,0xFE]),
            (b"///+", (3, 4), &[0xFF,0xFF,0xFE]),
            (b"/////v", (4, 6), &[0xFF,0xFF,0xFF,0xFE]),
            (b"///+//7", (5, 7), &[0xFF,0xFF,0xFE,0xFF,0xFE]),
            (b"///+///+", (6, 8), &[0xFF,0xFF,0xFE,0xFF,0xFF,0xFE]),
        ];
        let mut buf = [0u8;12];
        for &(encoded, expected, decoded) in cases {
            test_decode(&mut buf, encoded, expected, decoded);
        }
    }
}