domain_core/utils/base64.rs

1//! Decoding and encoding of Base64.
2
3use std::fmt;
4use bytes::{BufMut, Bytes, BytesMut};
5
6
7//------------ Convenience Functions -----------------------------------------
8
9pub fn decode(s: &str) -> Result<Bytes, DecodeError> {
10    let mut decoder = Decoder::new();
11    for ch in s.chars() {
12        decoder.push(ch)?;
13    }
14    decoder.finalize()
15}
16
17
18pub fn display<B, W>(bytes: &B, f: &mut W) -> fmt::Result
19where B: AsRef<[u8]> + ?Sized, W: fmt::Write { 
20    fn ch(i: u8) -> char {
21        ENCODE_ALPHABET[i as usize]
22    }
23
24    for chunk in bytes.as_ref().chunks(3) {
25        match chunk.len() {
26            1 => {
27                f.write_char(ch(chunk[0] >> 2))?;
28                f.write_char(ch((chunk[0] & 0x03) << 4))?;
29                f.write_char('=')?;
30                f.write_char('=')?;
31            }
32            2 => {
33                f.write_char(ch(chunk[0] >> 2))?;
34                f.write_char(ch((chunk[0] & 0x03) << 4 | chunk[1] >> 4))?;
35                f.write_char(ch((chunk[1] & 0x0F) << 2))?;
36                f.write_char('=')?;
37            }
38            3 => {
39                f.write_char(ch(chunk[0] >> 2))?;
40                f.write_char(ch((chunk[0] & 0x03) << 4 | chunk[1] >> 4))?;
41                f.write_char(ch((chunk[1] & 0x0F) << 2 | chunk[2] >> 6))?;
42                f.write_char(ch(chunk[2] & 0x3F))?;
43            }
44            _ => unreachable!()
45        }
46    }
47    Ok(())
48}
49
50
51//------------ Decoder -------------------------------------------------------
52
53/// A Base64 decoder.
54pub struct Decoder {
55    /// A buffer for up to four characters.
56    ///
57    /// We only keep `u8`s here because only ASCII characters are used by
58    /// Base64.
59    buf: [u8; 4],
60
61    /// The index in `buf` where we place the next character.
62    ///
63    /// We also abuse this to mark when we are done (because there was
64    /// padding, in which case we set it to 0xF0).
65    next: usize,
66
67    /// The target or an error if something went wrong.
68    target: Result<BytesMut, DecodeError>,
69}
70
71impl Decoder {
72    pub fn new() -> Self {
73        Decoder {
74            buf: [0; 4],
75            next: 0,
76            target: Ok(BytesMut::new()),
77        }
78    }
79
80    pub fn finalize(self) -> Result<Bytes, DecodeError> {
81        let (target, next) = (self.target, self.next);
82        target.and_then(|bytes| {
83            // next is either 0 or 0xF0 for a completed group.
84            if next & 0x0F != 0 {
85                Err(DecodeError::IncompleteInput)
86            }
87            else {
88                Ok(bytes.freeze())
89            }
90        })
91    }
92
93    pub fn push(&mut self, ch: char) -> Result<(), DecodeError> {
94        if self.next == 0xF0 {
95            self.target = Err(DecodeError::TrailingInput);
96            return Err(DecodeError::TrailingInput)
97        }
98
99        let val = if ch == PAD {
100            // Only up to two padding characters possible.
101            if self.next < 2 {
102                return Err(DecodeError::IllegalChar(ch))
103            }
104            0x80 // Acts as a marker later on.
105        }
106        else {
107            if ch > (127 as char) {
108                return Err(DecodeError::IllegalChar(ch))
109            }
110            let val = DECODE_ALPHABET[ch as usize];
111            if val == 0xFF {
112                return Err(DecodeError::IllegalChar(ch))
113            }
114            val
115        };
116        self.buf[self.next] = val;
117        self.next += 1;
118
119        if self.next == 4 {
120            let target = self.target.as_mut().unwrap(); // Err covered above.
121            target.reserve(3);
122            target.put_u8(self.buf[0] << 2 | self.buf[1] >> 4);
123            if self.buf[2] != 0x80 {
124                target.put_u8(self.buf[1] << 4 | self.buf[2] >> 2)
125            }
126            if self.buf[3] != 0x80 {
127                if self.buf[2] == 0x80 {
128                    return Err(DecodeError::TrailingInput)
129                }
130                target.put_u8((self.buf[2] << 6) | self.buf[3]);
131                self.next = 0
132            }
133            else {
134                self.next = 0xF0
135            }
136        }
137
138        Ok(())
139    }
140}
141
142
143//--- Default
144
145impl Default for Decoder {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151//------------ DecodeError ---------------------------------------------------
152
153/// An error happened while decoding a Base64 string.
154#[derive(Debug, Fail, Eq, PartialEq)]
155pub enum DecodeError {
156    #[fail(display="incomplete input")]
157    IncompleteInput,
158
159    #[fail(display="trailing input")]
160    TrailingInput,
161
162    #[fail(display="illegal character '{}'", _0)]
163    IllegalChar(char),
164}
165
166
167//------------ Constants -----------------------------------------------------
168
/// The alphabet used by the decoder.
///
/// The table is indexed by the code point of an ASCII character. Each entry
/// is the 6-bit value the character stands for in Base64, or 255 (0xFF) if
/// the character is not part of the alphabet. Only the 128 ASCII code points
/// are covered because Base64 uses nothing else. Non-ASCII characters must be
/// rejected before indexing.
const DECODE_ALPHABET: [u8; 128] = [
    // 0x00 ..= 0x0F: control characters
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0x10 ..= 0x1F: control characters
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0x20 ..= 0x2F: punctuation; only '+' (0x2B) and '/' (0x2F) are valid
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,  62, 255, 255, 255,  63,
    // 0x30 ..= 0x3F: '0' ..= '9' map to 52 ..= 61
     52,  53,  54,  55,  56,  57,  58,  59,  60,  61, 255, 255, 255, 255, 255, 255,
    // 0x40 ..= 0x4F: 'A' ..= 'O' map to 0 ..= 14
    255,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
    // 0x50 ..= 0x5F: 'P' ..= 'Z' map to 15 ..= 25
     15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25, 255, 255, 255, 255, 255,
    // 0x60 ..= 0x6F: 'a' ..= 'o' map to 26 ..= 40
    255,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
    // 0x70 ..= 0x7F: 'p' ..= 'z' map to 41 ..= 51
     41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51, 255, 255, 255, 255, 255,
];
199
/// The alphabet used by the encoder.
///
/// Maps each 6-bit value (the index) to the character representing it. This
/// is the standard Base64 alphabet: upper-case letters, lower-case letters,
/// digits, then '+' and '/'.
const ENCODE_ALPHABET: [char; 64] = [
    // 0 ..= 25: upper-case letters
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
    // 26 ..= 51: lower-case letters
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    // 52 ..= 61: digits
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    // 62 and 63: the two special characters
    '+', '/',
];
213
/// The padding character.
///
/// The decoder accepts it only in the last two positions of a
/// four-character group.
const PAD: char = '=';
216
217
218//============ Test ==========================================================
219
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn decode_str() {
        assert_eq!(decode("").unwrap().as_ref(), b"");
        assert_eq!(decode("Zg==").unwrap().as_ref(), b"f");
        assert_eq!(decode("Zm8=").unwrap().as_ref(), b"fo");
        assert_eq!(decode("Zm9v").unwrap().as_ref(), b"foo");
        assert_eq!(decode("Zm9vYg==").unwrap().as_ref(), b"foob");
        assert_eq!(decode("Zm9vYmE=").unwrap().as_ref(), b"fooba");
        assert_eq!(decode("Zm9vYmFy").unwrap().as_ref(), b"foobar");

        assert_eq!(decode("FPucA").unwrap_err(),
                   DecodeError::IncompleteInput);
        assert_eq!(decode("FPucA=").unwrap_err(),
                   DecodeError::IllegalChar('='));
        assert_eq!(decode("FPucAw=").unwrap_err(),
                   DecodeError::IncompleteInput);
        assert_eq!(decode("FPucAw=a").unwrap_err(),
                   DecodeError::TrailingInput);
        assert_eq!(decode("FPucAw==a").unwrap_err(),
                   DecodeError::TrailingInput);
    }

    #[test]
    fn decode_illegal_chars() {
        // Characters outside the alphabet, including whitespace.
        assert_eq!(decode("Zm9v!").unwrap_err(),
                   DecodeError::IllegalChar('!'));
        assert_eq!(decode("Zm 9v").unwrap_err(),
                   DecodeError::IllegalChar(' '));
        // Non-ASCII characters must not index past the decode table.
        assert_eq!(decode("Zm9v\u{e9}").unwrap_err(),
                   DecodeError::IllegalChar('\u{e9}'));
        // Padding in the first or second position of a group.
        assert_eq!(decode("=").unwrap_err(),
                   DecodeError::IllegalChar('='));
        assert_eq!(decode("Zm9v=").unwrap_err(),
                   DecodeError::IllegalChar('='));
    }

    #[test]
    fn decode_data_after_padding() {
        assert_eq!(decode("Zg=a").unwrap_err(),
                   DecodeError::TrailingInput);
    }

    #[test]
    fn display_bytes() {
        fn fmt(s: &[u8]) -> String {
            let mut out = String::new();
            display(s, &mut out).unwrap();
            out
        }

        assert_eq!(fmt(b""), "");
        assert_eq!(fmt(b"f"), "Zg==");
        assert_eq!(fmt(b"fo"), "Zm8=");
        assert_eq!(fmt(b"foo"), "Zm9v");
        assert_eq!(fmt(b"foob"), "Zm9vYg==");
        assert_eq!(fmt(b"fooba"), "Zm9vYmE=");
        assert_eq!(fmt(b"foobar"), "Zm9vYmFy");
        // Exercises the two non-alphanumeric characters of the alphabet.
        assert_eq!(fmt(&[0xFB, 0xFF]), "+/8=");
    }

    #[test]
    fn round_trip() {
        let data: Vec<u8> = (0..=255u8).collect();
        for len in 0..=data.len() {
            let mut encoded = String::new();
            display(&data[..len], &mut encoded).unwrap();
            assert_eq!(encoded.len() % 4, 0);
            assert_eq!(decode(&encoded).unwrap().as_ref(), &data[..len]);
        }
    }
}
263