protobuf_json_mapping/
base64.rs

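//! Base64 helpers: `encode` emits the standard alphabet with `=` padding;
//! `decode` accepts the standard or URL-safe alphabet, optional padding and
//! ignores `\r`/`\n`.
//!
//! Copy-pasted from the internet; the original source is not recorded here.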
/// Base64 alphabet variants. Neither variant is currently referenced by
/// `encode` or `decode` in this module.
#[derive(Debug, Clone, Copy)]
enum _CharacterSet {
    /// Standard alphabet, whose last two symbols are `+` and `/`.
    _Standard,
    /// URL-safe alphabet, whose last two symbols are `-` and `_`.
    _UrlSafe,
}
10
/// The standard base64 alphabet (RFC 4648, section 4): values 62 and 63 are `+` and `/`.
static STANDARD_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                 abcdefghijklmnopqrstuvwxyz\
                                 0123456789+/";

/// The URL- and filename-safe base64 alphabet (RFC 4648, section 5): values 62
/// and 63 are `-` and `_`. Not used by `encode`.
static _URLSAFE_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                 abcdefghijklmnopqrstuvwxyz\
                                 0123456789-_";

/// Encodes `input` as base64 using the standard alphabet, padding the output
/// with `=` to a multiple of four characters.
///
/// The result is always `4 * ceil(input.len() / 3)` characters long.
pub fn encode(input: &[u8]) -> String {
    // Alphabet character for the 6-bit group of `n` that starts at bit `shift`.
    let sextet = |n: u32, shift: u32| char::from(STANDARD_CHARS[((n >> shift) & 0x3F) as usize]);

    let mut out = String::with_capacity((input.len() + 2) / 3 * 4);

    // Every full 3-byte group (24 bits) becomes exactly four characters.
    let groups = input.chunks_exact(3);
    let tail = groups.remainder();
    for group in groups {
        let n = (u32::from(group[0]) << 16) | (u32::from(group[1]) << 8) | u32::from(group[2]);
        for shift in [18, 12, 6, 0] {
            out.push(sextet(n, shift));
        }
    }

    // A trailing 1- or 2-byte group is zero-extended to 24 bits. Only the
    // characters that cover real input bits are emitted; `=` fills the rest.
    match *tail {
        [] => {}
        [a] => {
            let n = u32::from(a) << 16;
            out.push(sextet(n, 18));
            out.push(sextet(n, 12));
            out.push_str("==");
        }
        [a, b] => {
            let n = (u32::from(a) << 16) | (u32::from(b) << 8);
            out.push(sextet(n, 18));
            out.push(sextet(n, 12));
            out.push(sextet(n, 6));
            out.push('=');
        }
        _ => unreachable!("chunks_exact(3) leaves a remainder of at most 2 bytes"),
    }

    out
}
75
76/// Errors that can occur when decoding a base64 encoded string
77#[derive(Clone, Copy, Debug, thiserror::Error)]
78pub enum FromBase64Error {
79    /// The input contained a character not part of the base64 format
80    #[error("Invalid base64 byte")]
81    InvalidBase64Byte(u8, usize),
82    /// The input had an invalid length
83    #[error("Invalid base64 length")]
84    InvalidBase64Length,
85}
86
87pub fn decode(input: &str) -> Result<Vec<u8>, FromBase64Error> {
88    let mut r = Vec::with_capacity(input.len());
89    let mut buf: u32 = 0;
90    let mut modulus = 0;
91
92    let mut it = input.as_bytes().iter();
93    for byte in it.by_ref() {
94        let code = DECODE_TABLE[*byte as usize];
95        if code >= SPECIAL_CODES_START {
96            match code {
97                NEWLINE_CODE => continue,
98                EQUALS_CODE => break,
99                INVALID_CODE => {
100                    return Err(FromBase64Error::InvalidBase64Byte(
101                        *byte,
102                        (byte as *const _ as usize) - input.as_ptr() as usize,
103                    ))
104                }
105                _ => unreachable!(),
106            }
107        }
108        buf = (buf | code as u32) << 6;
109        modulus += 1;
110        if modulus == 4 {
111            modulus = 0;
112            r.push((buf >> 22) as u8);
113            r.push((buf >> 14) as u8);
114            r.push((buf >> 6) as u8);
115        }
116    }
117
118    for byte in it {
119        match *byte {
120            b'=' | b'\r' | b'\n' => continue,
121            _ => {
122                return Err(FromBase64Error::InvalidBase64Byte(
123                    *byte,
124                    (byte as *const _ as usize) - input.as_ptr() as usize,
125                ))
126            }
127        }
128    }
129
130    match modulus {
131        2 => {
132            r.push((buf >> 10) as u8);
133        }
134        3 => {
135            r.push((buf >> 16) as u8);
136            r.push((buf >> 8) as u8);
137        }
138        0 => (),
139        _ => return Err(FromBase64Error::InvalidBase64Length),
140    }
141
142    Ok(r)
143}
144
/// Reverse lookup table used by `decode`, indexed by input byte.
///
/// Entries below [`SPECIAL_CODES_START`] are the 6-bit value of a base64
/// character. Both the standard (`+`, `/`) and URL-safe (`-`, `_`) spellings
/// of 62 and 63 are present. `\r`/`\n` map to [`NEWLINE_CODE`] and `=` maps to
/// [`EQUALS_CODE`]; every other byte maps to [`INVALID_CODE`].
const DECODE_TABLE: [u8; 256] = build_decode_table();

/// Builds [`DECODE_TABLE`] at compile time from the standard alphabet, then
/// layers the URL-safe aliases and the special codes on top.
const fn build_decode_table() -> [u8; 256] {
    const STANDARD_ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut table = [INVALID_CODE; 256];
    let mut value = 0;
    while value < 64 {
        table[STANDARD_ALPHABET[value] as usize] = value as u8;
        value += 1;
    }
    // URL-safe spellings share the values of their standard counterparts.
    table[b'-' as usize] = table[b'+' as usize];
    table[b'_' as usize] = table[b'/' as usize];
    table[b'=' as usize] = EQUALS_CODE;
    table[b'\r' as usize] = NEWLINE_CODE;
    table[b'\n' as usize] = NEWLINE_CODE;
    table
}

/// Table marker for bytes that are not valid base64 input.
const INVALID_CODE: u8 = 0xFF;
/// Table marker for the `=` padding character.
const EQUALS_CODE: u8 = 0xFE;
/// Table marker for `\r` and `\n`, which `decode` skips.
const NEWLINE_CODE: u8 = 0xFD;
/// Smallest special marker; every table value below it is a 6-bit sextet.
const SPECIAL_CODES_START: u8 = NEWLINE_CODE;
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_basic() {
        assert_eq!(encode(b""), "");
        assert_eq!(encode(b"f"), "Zg==");
        assert_eq!(encode(b"fo"), "Zm8=");
        assert_eq!(encode(b"foo"), "Zm9v");
        assert_eq!(encode(b"foob"), "Zm9vYg==");
        assert_eq!(encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn test_encode_standard_safe() {
        assert_eq!(encode(&[251, 255]), "+/8=");
    }

    #[test]
    fn test_decode_basic() {
        assert_eq!(decode("").unwrap(), b"");
        assert_eq!(decode("Zg==").unwrap(), b"f");
        assert_eq!(decode("Zm8=").unwrap(), b"fo");
        assert_eq!(decode("Zm9v").unwrap(), b"foo");
        assert_eq!(decode("Zm9vYg==").unwrap(), b"foob");
        assert_eq!(decode("Zm9vYmE=").unwrap(), b"fooba");
        assert_eq!(decode("Zm9vYmFy").unwrap(), b"foobar");
    }

    #[test]
    fn test_decode() {
        assert_eq!(decode("Zm9vYmFy").unwrap(), b"foobar");
    }

    #[test]
    fn test_decode_without_padding() {
        assert_eq!(decode("Zg").unwrap(), b"f");
        assert_eq!(decode("Zm8").unwrap(), b"fo");
        assert_eq!(decode("Zm9vYg").unwrap(), b"foob");
    }

    #[test]
    fn test_decode_newlines() {
        assert_eq!(decode("Zm9v\r\nYmFy").unwrap(), b"foobar");
        assert_eq!(decode("Zm9vYg==\r\n").unwrap(), b"foob");
        assert_eq!(decode("Zm9v\nYmFy").unwrap(), b"foobar");
        assert_eq!(decode("Zm9vYg==\n").unwrap(), b"foob");
        // Line breaks may also split a 4-character group.
        assert_eq!(decode("Zm\r\n9v").unwrap(), b"foo");
    }

    #[test]
    fn test_decode_urlsafe() {
        assert_eq!(decode("-_8").unwrap(), decode("+/8=").unwrap());
    }

    #[test]
    fn test_decode_mixed_alphabets() {
        assert_eq!(decode("-_-_").unwrap(), [0xFB, 0xFF, 0xBF]);
        assert_eq!(decode("+/+/").unwrap(), [0xFB, 0xFF, 0xBF]);
        assert_eq!(decode("+_-/").unwrap(), [0xFB, 0xFF, 0xBF]);
    }

    #[test]
    fn test_from_base64_invalid_char() {
        assert!(decode("Zm$=").is_err());
        assert!(decode("Zg==$").is_err());
    }

    #[test]
    fn test_decode_reports_invalid_byte_offset() {
        assert!(matches!(
            decode("Zm$="),
            Err(FromBase64Error::InvalidBase64Byte(b'$', 2))
        ));
        assert!(matches!(
            decode("Zg==$"),
            Err(FromBase64Error::InvalidBase64Byte(b'$', 4))
        ));
        // Data after the first `=` is rejected at its own offset.
        assert!(matches!(
            decode("Zg=A"),
            Err(FromBase64Error::InvalidBase64Byte(b'A', 3))
        ));
    }

    #[test]
    fn test_decode_invalid_padding() {
        assert!(decode("Z===").is_err());
    }

    #[test]
    fn test_decode_invalid_length() {
        assert!(matches!(decode("Z"), Err(FromBase64Error::InvalidBase64Length)));
        assert!(matches!(decode("Zm9vY"), Err(FromBase64Error::InvalidBase64Length)));
    }

    #[test]
    fn test_round_trip_all_byte_values() {
        let data: Vec<u8> = (0..=255).collect();
        for len in 0..=data.len() {
            assert_eq!(decode(&encode(&data[..len])).unwrap(), &data[..len]);
        }
    }
}