radix64/
decode.rs

1use crate::Config;
2use std::{error, fmt};
3
4pub(crate) mod block;
5pub(crate) mod io;
6
/// Sentinel compared against the result of `Config::decode_u8` in this module:
/// a decoded value equal to this constant is treated as "no valid decoding for
/// this byte" and reported as `DecodeError::InvalidByte`.
pub(crate) const INVALID_VALUE: u8 = 255;
8
/// The ways in which decoding an encoded input can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// The input contains a byte that has no valid decoding. The payload is
    /// the offending input byte itself.
    InvalidByte(u8),
    /// The input has a length that no valid encoding can have.
    InvalidLength,
    /// The final non-padding byte of the input carries bits that are discarded
    /// during decoding, and those bits are not all zero. Such input could still
    /// be decoded, but it most likely indicates a corrupted or invalid encoding.
    InvalidTrailingBits,
}
21
22impl fmt::Display for DecodeError {
23    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
24        match *self {
25            DecodeError::InvalidByte(byte) => write!(f, "invalid byte {}", byte),
26            DecodeError::InvalidLength => write!(f, "encoded text cannot have a 6-bit remainder"),
27            DecodeError::InvalidTrailingBits => {
28                write!(f, "last byte has unnecessary trailing bits")
29            }
30        }
31    }
32}
33
34impl error::Error for DecodeError {
35    fn description(&self) -> &str {
36        match *self {
37            DecodeError::InvalidByte(_) => "invalid byte",
38            DecodeError::InvalidLength => "invalid length",
39            DecodeError::InvalidTrailingBits => "invalid trailing bits",
40        }
41    }
42
43    fn cause(&self) -> Option<&dyn error::Error> {
44        None
45    }
46}
47
48// decode_slice on success will return the number of decoded bytes written.
49pub(crate) fn decode_slice<C>(
50    config: C,
51    mut input: &[u8],
52    mut output: &mut [u8],
53) -> Result<usize, DecodeError>
54where
55    C: Config,
56{
57    input = remove_padding(config, input)?;
58    let (input_idx, output_idx) = decode_full_chunks_without_padding(config, input, output)?;
59    input = &input[input_idx..];
60    output = &mut output[output_idx..];
61
62    // Deal with the remaining partial chunk. The padding characters have already been removed.
63    Ok(output_idx + decode_partial_chunk(config, input, output)?)
64}
65
66#[inline]
67fn remove_padding<C>(config: C, input: &[u8]) -> Result<&[u8], DecodeError>
68where
69    C: Config,
70{
71    Ok(if let Some(padding) = config.padding_byte() {
72        if input.len() % 4 != 0 {
73            return Err(DecodeError::InvalidLength);
74        }
75        let num_padding_bytes = input
76            .iter()
77            .rev()
78            .cloned()
79            .take_while(|&b| b == padding)
80            .take(2)
81            .count();
82        match num_padding_bytes {
83            0 => input,
84            1 => &input[..input.len() - 1],
85            2 => &input[..input.len() - 2],
86            _ => unreachable!("impossible number of padding bytes"),
87        }
88    } else {
89        input
90    })
91}
92
93#[inline]
94fn decode_full_chunks_without_padding<C>(
95    config: C,
96    mut input: &[u8],
97    mut output: &mut [u8],
98) -> Result<(usize, usize), DecodeError>
99where
100    C: Config,
101{
102    use crate::decode::block::BlockDecoder;
103    let (input_idx, output_idx) = if input.len() < 32 {
104        (0, 0)
105    } else {
106        // If input is suitably large use an architecture optimized encoder.
107        // The magic value of 27 was chosen because the avx2 encoder works with
108        // 28 byte chunks of input at a time. Benchmarks show that bypassing
109        // creating the block encoder when the input is small is up to 33%
110        // faster (50% throughput improvement).
111        let block_encoder = config.into_block_decoder();
112        block_encoder.decode_blocks(input, output)?
113    };
114
115    input = &input[input_idx..];
116    output = &mut output[output_idx..];
117
118    let mut iter = DecodeIter::new(input, output);
119    while let Some((input, output)) = iter.next_chunk() {
120        decode_chunk(config, *input, output).map_err(DecodeError::InvalidByte)?;
121    }
122
123    let (input_idx2, output_idx2) = iter.remaining();
124    Ok((input_idx + input_idx2, output_idx + output_idx2))
125}
126
127#[inline]
128fn decode_partial_chunk<C>(config: C, input: &[u8], output: &mut [u8]) -> Result<usize, DecodeError>
129where
130    C: Config,
131{
132    // Deal with the remaining partial chunk. The padding characters have already been removed.
133    match input.len() {
134        0 => Ok(0),
135        1 => Err(DecodeError::InvalidLength),
136        2 => {
137            let first = config.decode_u8(input[0]);
138            if first == INVALID_VALUE {
139                return Err(DecodeError::InvalidByte(input[0]));
140            }
141            let second = config.decode_u8(input[1]);
142            if second == INVALID_VALUE {
143                return Err(DecodeError::InvalidByte(input[1]));
144            }
145            output[0] = (first << 2) | (second >> 4);
146            if second & 0b0000_1111 != 0 {
147                return Err(DecodeError::InvalidTrailingBits);
148            }
149            Ok(1)
150        }
151        3 => {
152            let first = config.decode_u8(input[0]);
153            if first == INVALID_VALUE {
154                return Err(DecodeError::InvalidByte(input[0]));
155            }
156            let second = config.decode_u8(input[1]);
157            if second == INVALID_VALUE {
158                return Err(DecodeError::InvalidByte(input[1]));
159            }
160            let third = config.decode_u8(input[2]);
161            if third == INVALID_VALUE {
162                return Err(DecodeError::InvalidByte(input[2]));
163            }
164            output[0] = (first << 2) | (second >> 4);
165            output[1] = (second << 4) | (third >> 2);
166            if third & 0b0000_0011 != 0 {
167                return Err(DecodeError::InvalidTrailingBits);
168            }
169            Ok(2)
170        }
171        x => unreachable!("impossible remainder: {}", x),
172    }
173}
174
175/// Decode a chunk. The chunk cannot contain any padding.
176#[inline]
177fn decode_chunk<C: Config>(config: C, input: [u8; 4], output: &mut [u8; 3]) -> Result<(), u8> {
178    let mut chunk_output: u32 = 0;
179    for (idx, input) in input.iter().cloned().enumerate() {
180        let decoded = config.decode_u8(input);
181        if decoded == INVALID_VALUE {
182            return Err(input);
183        }
184        let shift_amount = 32 - (idx as u32 + 1) * 6;
185        chunk_output |= u32::from(decoded) << shift_amount;
186    }
187    debug_assert!(chunk_output.trailing_zeros() >= 8);
188    write_be_u24(chunk_output, output);
189    Ok(())
190}
191
/// Copy the 24 most significant bits of `n` into `buf`, most significant byte
/// first. The least significant byte of `n` is discarded.
#[inline]
fn write_be_u24(n: u32, buf: &mut [u8; 3]) {
    // `to_be_bytes` yields the bytes most-significant first on every target,
    // so no unsafe pointer reinterpretation is needed to obtain them.
    let [b0, b1, b2, _] = n.to_be_bytes();
    *buf = [b0, b1, b2];
}
200
/// Look up `input` in a 256-entry table and return the entry stored for it.
///
/// Every possible `u8` has an entry, so the lookup cannot go out of bounds.
#[inline]
pub(crate) fn decode_using_table(table: &[u8; 256], input: u8) -> u8 {
    let index = usize::from(input);
    table[index]
}
205
// Generates the `DecodeIter` type used by `decode_full_chunks_without_padding`,
// which pairs each 4-byte chunk of input with a 3-byte chunk of output.
// NOTE(review): the exact meaning of the stride arguments is defined by
// `define_block_iter!`, which lives elsewhere — presumably the distance the
// iterator advances per step; confirm against the macro.
define_block_iter!(
    name = DecodeIter,
    input_chunk_size = 4,
    input_stride = 4,
    output_chunk_size = 3,
    output_stride = 3
);
213
#[cfg(test)]
mod tests {
    use super::DecodeError;
    use crate::STD;

    /// A final byte whose discarded bits are zero decodes fine; any non-zero
    /// discarded bits must be rejected, including after a full leading chunk.
    #[test]
    fn detect_trailing_bits() {
        assert!(STD.decode("iYU=").is_ok());

        for &encoded in &["iYV=", "iYW=", "iYX=", "AAAAiYX="] {
            assert_eq!(
                Err(DecodeError::InvalidTrailingBits),
                STD.decode(encoded)
            );
        }
    }
}