1use crate::Config;
2use std::{error, fmt};
3
4pub(crate) mod block;
5pub(crate) mod io;
6
/// Sentinel that `decode_u8` / table lookups yield for a byte that is not part
/// of the alphabet; every decoder in this module compares against it.
pub(crate) const INVALID_VALUE: u8 = 0xFF;
8
/// Reasons decoding encoded text can fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// A byte that the configuration does not map to a 6-bit value; the
    /// offending input byte is carried in the variant.
    InvalidByte(u8),
    /// The input length is unusable: a single symbol is left over after the
    /// full 4-symbol groups, or (for padded configurations) the input length is
    /// not a multiple of 4.
    InvalidLength,
    /// The final symbol has nonzero low bits that would not fit in a whole
    /// output byte and would otherwise be silently discarded.
    InvalidTrailingBits,
}
21
22impl fmt::Display for DecodeError {
23 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
24 match *self {
25 DecodeError::InvalidByte(byte) => write!(f, "invalid byte {}", byte),
26 DecodeError::InvalidLength => write!(f, "encoded text cannot have a 6-bit remainder"),
27 DecodeError::InvalidTrailingBits => {
28 write!(f, "last byte has unnecessary trailing bits")
29 }
30 }
31 }
32}
33
34impl error::Error for DecodeError {
35 fn description(&self) -> &str {
36 match *self {
37 DecodeError::InvalidByte(_) => "invalid byte",
38 DecodeError::InvalidLength => "invalid length",
39 DecodeError::InvalidTrailingBits => "invalid trailing bits",
40 }
41 }
42
43 fn cause(&self) -> Option<&dyn error::Error> {
44 None
45 }
46}
47
48pub(crate) fn decode_slice<C>(
50 config: C,
51 mut input: &[u8],
52 mut output: &mut [u8],
53) -> Result<usize, DecodeError>
54where
55 C: Config,
56{
57 input = remove_padding(config, input)?;
58 let (input_idx, output_idx) = decode_full_chunks_without_padding(config, input, output)?;
59 input = &input[input_idx..];
60 output = &mut output[output_idx..];
61
62 Ok(output_idx + decode_partial_chunk(config, input, output)?)
64}
65
66#[inline]
67fn remove_padding<C>(config: C, input: &[u8]) -> Result<&[u8], DecodeError>
68where
69 C: Config,
70{
71 Ok(if let Some(padding) = config.padding_byte() {
72 if input.len() % 4 != 0 {
73 return Err(DecodeError::InvalidLength);
74 }
75 let num_padding_bytes = input
76 .iter()
77 .rev()
78 .cloned()
79 .take_while(|&b| b == padding)
80 .take(2)
81 .count();
82 match num_padding_bytes {
83 0 => input,
84 1 => &input[..input.len() - 1],
85 2 => &input[..input.len() - 2],
86 _ => unreachable!("impossible number of padding bytes"),
87 }
88 } else {
89 input
90 })
91}
92
93#[inline]
94fn decode_full_chunks_without_padding<C>(
95 config: C,
96 mut input: &[u8],
97 mut output: &mut [u8],
98) -> Result<(usize, usize), DecodeError>
99where
100 C: Config,
101{
102 use crate::decode::block::BlockDecoder;
103 let (input_idx, output_idx) = if input.len() < 32 {
104 (0, 0)
105 } else {
106 let block_encoder = config.into_block_decoder();
112 block_encoder.decode_blocks(input, output)?
113 };
114
115 input = &input[input_idx..];
116 output = &mut output[output_idx..];
117
118 let mut iter = DecodeIter::new(input, output);
119 while let Some((input, output)) = iter.next_chunk() {
120 decode_chunk(config, *input, output).map_err(DecodeError::InvalidByte)?;
121 }
122
123 let (input_idx2, output_idx2) = iter.remaining();
124 Ok((input_idx + input_idx2, output_idx + output_idx2))
125}
126
127#[inline]
128fn decode_partial_chunk<C>(config: C, input: &[u8], output: &mut [u8]) -> Result<usize, DecodeError>
129where
130 C: Config,
131{
132 match input.len() {
134 0 => Ok(0),
135 1 => Err(DecodeError::InvalidLength),
136 2 => {
137 let first = config.decode_u8(input[0]);
138 if first == INVALID_VALUE {
139 return Err(DecodeError::InvalidByte(input[0]));
140 }
141 let second = config.decode_u8(input[1]);
142 if second == INVALID_VALUE {
143 return Err(DecodeError::InvalidByte(input[1]));
144 }
145 output[0] = (first << 2) | (second >> 4);
146 if second & 0b0000_1111 != 0 {
147 return Err(DecodeError::InvalidTrailingBits);
148 }
149 Ok(1)
150 }
151 3 => {
152 let first = config.decode_u8(input[0]);
153 if first == INVALID_VALUE {
154 return Err(DecodeError::InvalidByte(input[0]));
155 }
156 let second = config.decode_u8(input[1]);
157 if second == INVALID_VALUE {
158 return Err(DecodeError::InvalidByte(input[1]));
159 }
160 let third = config.decode_u8(input[2]);
161 if third == INVALID_VALUE {
162 return Err(DecodeError::InvalidByte(input[2]));
163 }
164 output[0] = (first << 2) | (second >> 4);
165 output[1] = (second << 4) | (third >> 2);
166 if third & 0b0000_0011 != 0 {
167 return Err(DecodeError::InvalidTrailingBits);
168 }
169 Ok(2)
170 }
171 x => unreachable!("impossible remainder: {}", x),
172 }
173}
174
175#[inline]
177fn decode_chunk<C: Config>(config: C, input: [u8; 4], output: &mut [u8; 3]) -> Result<(), u8> {
178 let mut chunk_output: u32 = 0;
179 for (idx, input) in input.iter().cloned().enumerate() {
180 let decoded = config.decode_u8(input);
181 if decoded == INVALID_VALUE {
182 return Err(input);
183 }
184 let shift_amount = 32 - (idx as u32 + 1) * 6;
185 chunk_output |= u32::from(decoded) << shift_amount;
186 }
187 debug_assert!(chunk_output.trailing_zeros() >= 8);
188 write_be_u24(chunk_output, output);
189 Ok(())
190}
191
/// Writes the three most significant bytes of `n`, in big-endian order, into `buf`.
///
/// The least significant byte of `n` is discarded. This uses plain shifts
/// instead of reinterpreting the `u32` through a raw pointer, so no `unsafe`
/// is needed and the result does not depend on host endianness.
#[inline]
fn write_be_u24(n: u32, buf: &mut [u8; 3]) {
    // Each `as u8` deliberately keeps only the low 8 bits of the shifted value.
    *buf = [(n >> 24) as u8, (n >> 16) as u8, (n >> 8) as u8];
}
200
/// Looks up the decoded value of `input` in a 256-entry translation table.
#[inline]
pub(crate) fn decode_using_table(table: &[u8; 256], input: u8) -> u8 {
    // Every `u8` is a valid index into a 256-entry table, so this cannot panic.
    let index = usize::from(input);
    table[index]
}
205
// Generates the `DecodeIter` type that `decode_full_chunks_without_padding`
// uses to walk the input in 4-byte chunks while writing 3-byte output chunks.
// It is constructed with `new` and consumed via `next_chunk` and `remaining`.
// The macro is defined elsewhere in the crate. The exact meaning of the
// `*_stride` parameters (presumably the step between chunks) should be
// confirmed against the macro definition.
define_block_iter!(
    name = DecodeIter,
    input_chunk_size = 4,
    input_stride = 4,
    output_chunk_size = 3,
    output_stride = 3
);
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::STD;

    /// Text whose final symbol carries bits that would be dropped must be
    /// rejected instead of being silently accepted.
    #[test]
    fn detect_trailing_bits() {
        assert!(STD.decode("iYU=").is_ok());

        let rejected = ["iYV=", "iYW=", "iYX=", "AAAAiYX="];
        for &encoded in &rejected {
            assert_eq!(Err(DecodeError::InvalidTrailingBits), STD.decode(encoded));
        }
    }
}