b64_ct/decode/
mod.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
8mod avx2;
9mod lut_align64;
10
11use alloc::vec::Vec;
12use core::cmp;
13use core::fmt;
14
/// Outcome reported by a `Decoder` after processing one block.
#[must_use]
struct BlockResult {
    /// Number of leading bytes of the processed block that `decode64` forwards
    /// to the packer.
    out_length: u8,
    /// Block-relative index of the first byte the decoder rejected, if any.
    ///
    /// `decode64` treats a rejected `=` as the start of the padding trailer
    /// and any other rejected byte as an error.
    first_invalid: Option<u8>,
}
20
/// Errors that can occur when decoding a base64 encoded string.
///
/// The type is `Copy` and comparable (`PartialEq`/`Eq`), so callers and tests
/// can check for a specific failure with `==` or `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The input had an invalid length.
    InvalidLength,
    /// A trailer was found, but it wasn't the right length.
    InvalidTrailer,
    /// The input contained a character (at the given index) not part of the
    /// base64 format.
    InvalidCharacter(usize),
}
32
33impl fmt::Display for Error {
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        fmt::Debug::fmt(&self, f)
36    }
37}
38
39trait Decoder: Copy {
40    type Block: AsRef<[u8]> + AsMut<[u8]>;
41
42    fn decode_block(self, block: &mut Self::Block) -> BlockResult;
43    fn zero_block() -> Self::Block;
44}
45
/// Strategy for packing a block of input bytes into output bytes.
///
/// Implementations are `Copy` so the packer can be passed around by value.
trait Packer: Copy {
    /// Input buffer consumed by one call to `pack_block`.
    type Input: AsRef<[u8]> + AsMut<[u8]> + Default;

    /// Length of the output slice that `pack_block` expects to receive.
    const OUT_BUF_LEN: usize;

    /// Packs `input` into `output`.
    ///
    /// The caller should pass `output` as a slice with length `OUT_BUF_LEN`.
    fn pack_block(self, input: &Self::Input, output: &mut [u8]);
}
53
/// Stateless packer used by `decode64_arch` when the AVX2 path is not taken.
#[derive(Clone, Copy)]
struct Simple;
56
57impl Packer for Simple {
58    type Input = [u8; 4];
59    const OUT_BUF_LEN: usize = 3;
60
61    #[inline]
62    fn pack_block(self, input: &Self::Input, output: &mut [u8]) {
63        output[0] = (input[0] << 2) | (input[1] >> 4);
64        output[1] = (input[1] << 4) | (input[2] >> 2);
65        output[2] = (input[2] << 6) | (input[3] >> 0);
66    }
67}
68
69struct PackState<P: Packer> {
70    packer: P,
71    cache: P::Input,
72    pos: usize,
73}
74
75impl<P: Packer> PackState<P> {
76    fn extend(&mut self, mut input: &[u8], out: &mut Vec<u8>) {
77        while !input.is_empty() {
78            let (_, cache_end) = self.cache.as_mut().split_at_mut(self.pos);
79            let (input_start, input_rest) = input.split_at(cmp::min(input.len(), cache_end.len()));
80            input = input_rest;
81            cache_end[..input_start.len()].copy_from_slice(input_start);
82            if input_start.len() != cache_end.len() {
83                self.pos += input_start.len();
84            } else {
85                let out_start = out.len();
86                out.resize(out_start + P::OUT_BUF_LEN, 0);
87                self.packer.pack_block(&self.cache, &mut out[out_start..]);
88                out.truncate(out_start + (core::mem::size_of::<P::Input>() / 4 * 3));
89                self.pos = 0;
90            }
91        }
92    }
93
94    fn flush(&mut self, out: &mut Vec<u8>, trailer_length: Option<usize>) -> Result<(), Error> {
95        if self.pos % 4 == 1 {
96            return Err(Error::InvalidLength);
97        }
98
99        if let Some(trailer_length) = trailer_length {
100            if (self.pos + trailer_length) % 4 != 0 {
101                return Err(Error::InvalidTrailer);
102            }
103        }
104
105        self.cache.as_mut()[self.pos] = 0;
106        let out_start = out.len();
107        out.resize(out.len() + P::OUT_BUF_LEN, 0);
108        self.packer.pack_block(&self.cache, &mut out[out_start..]);
109        out.truncate(out_start + (self.pos * 3 / 4));
110        Ok(())
111    }
112}
113
114fn decode64<D: Decoder, P: Packer>(input: &[u8], decoder: D, packer: P) -> Result<Vec<u8>, Error> {
115    if input.is_empty() {
116        return Ok(Vec::new());
117    }
118
119    let p_in_len = core::mem::size_of::<P::Input>();
120    let p_out_len = p_in_len / 4 * 3;
121    let cap =
122        crate::misc::div_roundup(input.len(), p_in_len) * p_out_len - p_out_len + P::OUT_BUF_LEN;
123    let mut out = Vec::with_capacity(cap);
124
125    let mut packer = PackState::<P> {
126        packer,
127        cache: P::Input::default(),
128        pos: 0,
129    };
130
131    let mut trailer_length = None;
132    for (chunk, chunk_start) in input
133        .chunks(core::mem::size_of::<D::Block>())
134        .zip((0..).step_by(core::mem::size_of::<D::Block>()))
135    {
136        let mut block = D::zero_block();
137        block.as_mut()[..chunk.len()].copy_from_slice(chunk);
138        let result = decoder.decode_block(&mut block);
139
140        if let Some(idx) = result.first_invalid {
141            let idx = idx as usize;
142            if input[chunk_start + idx] == b'=' {
143                let rest_start = chunk_start + idx + 1;
144                let rest = &input[rest_start..];
145                let mut iter = rest
146                    .iter()
147                    .enumerate()
148                    .filter(|(_, c)| !c.is_ascii_whitespace());
149                trailer_length = match (iter.next(), iter.next()) {
150                    (None, _) => Some(1),
151                    (Some((_, b'=')), None) => Some(2),
152                    (Some((_, b'=')), Some((i, _))) | (Some((i, _)), _) => {
153                        return Err(Error::InvalidCharacter(rest_start + i))
154                    }
155                };
156            } else {
157                return Err(Error::InvalidCharacter(chunk_start + idx));
158            }
159        }
160
161        packer.extend(&block.as_ref()[..(result.out_length as _)], &mut out);
162
163        if trailer_length.is_some() {
164            break;
165        }
166    }
167
168    packer.flush(&mut out, trailer_length)?;
169
170    Ok(out)
171}
172
173pub(super) fn decode64_arch(input: &[u8]) -> Result<Vec<u8>, Error> {
174    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
175    unsafe {
176        if is_x86_feature_detected!("avx2")
177            && is_x86_feature_detected!("bmi1")
178            && is_x86_feature_detected!("sse4.2")
179            && is_x86_feature_detected!("popcnt")
180        {
181            let avx2 = avx2::Avx2::new();
182            return decode64(input, avx2, avx2);
183        }
184    }
185    decode64(input, lut_align64::LutAlign64, Simple)
186}
187
/// Unit tests for `decode64`, instantiated for each decoder/packer
/// combination listed in the `generate_tests!` invocation below.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_support::rand_base64_size;
    use crate::ToBase64;

    /// Builds the AVX2 implementation for tests and benchmarks.
    ///
    /// NOTE(review): unlike `decode64_arch`, no CPU feature check happens
    /// here — presumably the tests are only run on AVX2-capable hardware.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    pub(super) fn test_avx2() -> avx2::Avx2 {
        unsafe { avx2::Avx2::new() }
    }

    generate_tests![
        decoders<D>: {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] avx2, test_avx2();
            lut_align64, lut_align64::LutAlign64;
        },
        packers<P>: {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] avx2, test_avx2();
            simple, Simple;
        },
        tests: {
            decode,
            decode_equivalency,
            decode_error,
            cmp_rand_1kb,
            whitespace_skipped,
            all_bytes,
            wrapping_base64,
        },
    ];

    /// Known inputs must decode to their known outputs.
    fn decode<D: Decoder, P: Packer>(decoder: D, packer: P) {
        static DECODE_TESTS: &[(&[u8], &[u8])] = &[
            // basic tests (from rustc-serialize)
            (b"", b""),
            (b"Zg==", b"f"),
            (b"Zm8=", b"fo"),
            (b"Zm9v", b"foo"),
            (b"Zm9vYg==", b"foob"),
            (b"Zm9vYmE=", b"fooba"),
            (b"Zm9vYmFy", b"foobar"),
            // with newlines (from rustc-serialize)
            (b"Zm9v\r\nYmFy", b"foobar"),
            (b"Zm9vYg==\r\n", b"foob"),
            (b"Zm9v\nYmFy", b"foobar"),
            (b"Zm9vYg==\n", b"foob"),
            // white space in trailer
            (b"Zm9vYg  =  =  ", b"foob"),
        ];

        for (input, expected) in DECODE_TESTS {
            let output = decode64(input, decoder, packer).unwrap();
            if &output != expected {
                panic!(
                    "Test failed. Expected specific output. \n\nInput: {}\nOutput: {:02x?}\nExpected output:{:02x?}\n\n",
                    std::str::from_utf8(input).unwrap(),
                    output,
                    expected
                );
            }
        }
    }

    /// Pairs of inputs that must decode to the same bytes.
    fn decode_equivalency<D: Decoder, P: Packer>(decoder: D, packer: P) {
        static DECODE_EQUIVALENCY_TESTS: &[(&[u8], &[u8])] = &[
            // url safe test (from rustc-serialize)
            (b"-_8", b"+/8="),
        ];

        for (input1, input2) in DECODE_EQUIVALENCY_TESTS {
            let output1 = decode64(input1, decoder, packer).unwrap();
            let output2 = decode64(input2, decoder, packer).unwrap();
            if output1 != output2 {
                panic!(
                    "Test failed. Expected same output.\n\nInput 1: {}\nInput 2: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n",
                    std::str::from_utf8(input1).unwrap(),
                    std::str::from_utf8(input2).unwrap(),
                    output1,
                    output2
                );
            }
        }
    }

    /// Malformed inputs must be rejected.
    fn decode_error<D: Decoder, P: Packer>(decoder: D, packer: P) {
        #[rustfmt::skip]
        static DECODE_ERROR_TESTS: &[&[u8]] = &[
            // invalid chars (from rustc-serialize)
            b"Zm$=",
            b"Zg==$",
            // invalid padding (from rustc-serialize)
            b"Z===",
        ];

        for input in DECODE_ERROR_TESTS {
            if decode64(input, decoder, packer).is_ok() {
                panic!(
                    "Test failed. Expected error.\n\nInput: {}\n\n",
                    std::str::from_utf8(input).unwrap(),
                );
            }
        }
    }

    /// The implementation under test must agree with the `LutAlign64`/`Simple`
    /// pair on random input.
    fn cmp_rand_1kb<D: Decoder, P: Packer>(decoder: D, packer: P) {
        let input = rand_base64_size(1024);

        let output1 = decode64(&input, decoder, packer).unwrap();
        let output2 = decode64(&input, lut_align64::LutAlign64, Simple).unwrap();
        if output1 != output2 {
            panic!(
                "Test failed. Expected same output.\n\nInput: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n",
                std::str::from_utf8(&input).unwrap(),
                output1,
                output2
            );
        }
    }

    /// Interleaving a space after every input byte must not change the result.
    fn whitespace_skipped<D: Decoder, P: Packer>(decoder: D, packer: P) {
        let input1 = rand_base64_size(32);
        use core::iter::once;
        let input2 = input1
            .iter()
            .flat_map(|&c| once(c).chain(once(b' ')))
            .collect::<Vec<_>>();

        let output1 = decode64(&input1, decoder, packer).unwrap();
        let output2 = decode64(&input2, decoder, packer).unwrap();
        if output1 != output2 {
            panic!(
                "Test failed. Expected same output.\n\nInput 1: {}\nInput 2: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n",
                std::str::from_utf8(&input1).unwrap(),
                std::str::from_utf8(&input2).unwrap(),
                output1,
                output2
            );
        }
    }

    /// Classifies every possible byte value: alphabet characters decode to
    /// their index, whitespace is skipped, and everything else is an error.
    fn all_bytes<D: Decoder, P: Packer>(decoder: D, packer: P) {
        // Expected result per byte: Ok(Some(v)) = decodes to value v,
        // Ok(None) = skipped, Err(()) = rejected.
        let mut set = std::vec![Err(()); 256];
        for (i, &b) in crate::misc::LUT_STANDARD.iter().enumerate() {
            set[b as usize] = Ok(Some(i as u8));
        }
        // add URL-safe set
        set[b'-' as usize] = Ok(Some(62));
        set[b'_' as usize] = Ok(Some(63));
        // add whitespace
        set[b' ' as usize] = Ok(None);
        set[b'\n' as usize] = Ok(None);
        set[b'\t' as usize] = Ok(None);
        set[b'\r' as usize] = Ok(None);
        set[0x0c] = Ok(None);

        for (i, &expected) in set.iter().enumerate() {
            // Two copies of the byte decode to one output byte whose top six
            // bits are the byte's value.
            let output = match decode64(&[i as u8, i as u8], decoder, packer)
                .as_ref()
                .map(|v| &v[..])
            {
                Ok(&[]) => Ok(None),
                Ok(&[v]) => Ok(Some(v >> 2)),
                Ok(_) => panic!("Result is more than 1 byte long"),
                Err(_) => Err(()),
            };
            assert_eq!(output, expected);
        }
    }

    /// Round-trips line-wrapped, padded base64 for a range of lengths.
    fn wrapping_base64<D: Decoder, P: Packer>(decoder: D, packer: P) {
        const BASE64_PEM_WRAP: usize = 64;

        static BASE64_PEM: crate::Config = crate::Config {
            char_set: crate::CharacterSet::Standard,
            newline: crate::Newline::LF,
            pad: true,
            line_length: Some(BASE64_PEM_WRAP),
        };

        // All-zero payloads covering lengths up to two full lines.
        let mut v: Vec<u8> = Vec::new();
        let bytes_per_line = BASE64_PEM_WRAP * 3 / 4;
        for _i in 0..(2 * bytes_per_line) {
            let encoded = v.to_base64(BASE64_PEM);
            let decoded = decode64(encoded.as_bytes(), decoder, packer).unwrap();
            assert_eq!(v, decoded);
            v.push(0);
        }

        // Random payloads of increasing length.
        v = Vec::new();
        for _i in 0..1000 {
            let encoded = v.to_base64(BASE64_PEM);
            let decoded = decode64(encoded.as_bytes(), decoder, packer).unwrap();
            assert_eq!(v, decoded);
            v.push(rand::random::<u8>());
        }
    }

    /// `Display` for `Error` must produce the variant's `Debug` text.
    #[test]
    fn display_errors() {
        assert_eq!(std::format!("{}", Error::InvalidLength), "InvalidLength");
        assert_eq!(std::format!("{}", Error::InvalidTrailer), "InvalidTrailer");
        assert_eq!(
            std::format!("{}", Error::InvalidCharacter(0)),
            "InvalidCharacter(0)"
        );
    }
}
393
/// Throughput benchmarks for `decode64` (nightly-only `test::Bencher`).
#[cfg(all(test, feature = "nightly"))]
mod benches {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    use super::tests::test_avx2;
    use super::*;

    use test::Bencher;

    use crate::test_support::rand_base64_size;

    /// Size argument for the small benchmarks.
    const KB: usize = 1024;
    /// Size argument for the large benchmarks.
    const MB: usize = 1024 * 1024;

    /// Benchmarks `decode64` on input from `rand_base64_size(size)`.
    ///
    /// `make` builds the decoder/packer pair and is called inside the timed
    /// closure, so construction is part of every measured iteration.
    fn bench_decode<D: Decoder, P: Packer>(b: &mut Bencher, size: usize, make: impl Fn() -> (D, P)) {
        let input = rand_base64_size(size);
        b.iter(|| {
            let (decoder, packer) = make();
            std::hint::black_box(decode64(&input, decoder, packer).unwrap());
        });
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[bench]
    fn avx2_1mb(b: &mut Bencher) {
        bench_decode(b, MB, || (test_avx2(), test_avx2()));
    }

    #[bench]
    fn lut_align64_1mb(b: &mut Bencher) {
        bench_decode(b, MB, || (lut_align64::LutAlign64, Simple));
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[bench]
    fn avx2_1kb(b: &mut Bencher) {
        bench_decode(b, KB, || (test_avx2(), test_avx2()));
    }

    #[bench]
    fn lut_align64_1kb(b: &mut Bencher) {
        bench_decode(b, KB, || (lut_align64::LutAlign64, Simple));
    }
}