1use compile_fmt::{clip_ascii, compile_assert, compile_panic, fmt, Ascii};
4
5use crate::wrappers::{SkipWhitespace, Skipper};
6
// Stand-in for the `?` operator that can be used inside `const fn`s: evaluates to the `Ok` payload,
// or returns the `Err` payload from the enclosing function as is (no `From` conversion is applied).
macro_rules! const_try {
    ($fallible:expr) => {
        match $fallible {
            Err(error) => return Err(error),
            Ok(unwrapped) => unwrapped,
        }
    };
}
16
/// Decoding failure caused by a single input byte that the decoder does not accept.
#[derive(Debug)]
struct DecodeError {
    /// The offending input byte.
    invalid_char: u8,
    /// Alphabet of the encoding that rejected the byte. `None` if the byte was rejected
    /// by the hex decoder (which has no explicit alphabet).
    alphabet: Option<Ascii<'static>>,
}
23
24impl DecodeError {
25 const fn invalid_char(invalid_char: u8, alphabet: Option<Ascii<'static>>) -> Self {
26 Self {
27 invalid_char,
28 alphabet,
29 }
30 }
31
32 const fn panic(self, input_pos: usize) -> ! {
33 if self.invalid_char.is_ascii() {
34 if let Some(alphabet) = self.alphabet {
35 compile_panic!(
36 "Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
37 input_pos => fmt::<usize>(), " is not a part of \
38 the decoder alphabet '", alphabet => clip_ascii(64, ""), "'"
39 );
40 } else {
41 compile_panic!(
42 "Character '", self.invalid_char as char => fmt::<char>(), "' at position ",
43 input_pos => fmt::<usize>(), " is not a hex digit"
44 );
45 }
46 } else {
47 compile_panic!(
48 "Non-ASCII character with decimal code ", self.invalid_char => fmt::<u8>(),
49 " encountered at position ", input_pos => fmt::<usize>()
50 );
51 }
52 }
53}
54
/// Custom encoding scheme based on an alphabet: a sequence of 2, 4, 8, 16, 32 or 64 distinct
/// ASCII characters, in which each character stands for its position in the sequence.
///
/// Created with `Encoding::new()`.
#[derive(Debug, Clone, Copy)]
pub struct Encoding {
    /// Alphabet the encoding was created from. Retained for error messages.
    alphabet: Ascii<'static>,
    /// Lookup table indexed by an ASCII byte. Contains the position of the byte in the alphabet,
    /// or `Self::NO_MAPPING` if the byte is not a part of the alphabet.
    table: [u8; 128],
    /// Number of bits encoded by a single alphabet character (base-2 logarithm of the alphabet
    /// length).
    bits_per_char: u8,
}
77
78impl Encoding {
79 const NO_MAPPING: u8 = u8::MAX;
80
81 const BASE64: Self =
82 Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
83 const BASE64_URL: Self =
84 Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");
85
86 #[allow(clippy::cast_possible_truncation)]
94 pub const fn new(alphabet: &'static str) -> Self {
95 let bits_per_char = match alphabet.len() {
96 2 => 1,
97 4 => 2,
98 8 => 3,
99 16 => 4,
100 32 => 5,
101 64 => 6,
102 other => compile_panic!(
103 "Invalid alphabet length ", other => fmt::<usize>(),
104 "; must be one of 2, 4, 8, 16, 32, or 64"
105 ),
106 };
107
108 let mut table = [Self::NO_MAPPING; 128];
109 let alphabet_bytes = alphabet.as_bytes();
110 let alphabet = Ascii::new(alphabet); let mut index = 0;
112 while index < alphabet_bytes.len() {
113 let byte = alphabet_bytes[index];
114 let byte_idx = byte as usize;
115 compile_assert!(
116 table[byte_idx] == Self::NO_MAPPING,
117 "Alphabet character '", byte as char => fmt::<char>(), "' is mentioned several times"
118 );
119 table[byte_idx] = index as u8;
120 index += 1;
121 }
122
123 Self {
124 alphabet,
125 table,
126 bits_per_char,
127 }
128 }
129
130 const fn lookup(&self, ascii_char: u8) -> Result<u8, DecodeError> {
131 if !ascii_char.is_ascii() {
132 return Err(DecodeError::invalid_char(ascii_char, Some(self.alphabet)));
133 }
134 let mapping = self.table[ascii_char as usize];
135 if mapping == Self::NO_MAPPING {
136 Err(DecodeError::invalid_char(ascii_char, Some(self.alphabet)))
137 } else {
138 Ok(mapping)
139 }
140 }
141}
142
/// Internal state of the hexadecimal decoder: the value of the first hex digit of a byte
/// if the second digit has not been read yet, or `None` if the decoder is at a byte boundary.
#[derive(Debug, Clone, Copy)]
struct HexDecoderState(Option<u8>);
146
147impl HexDecoderState {
148 const fn byte_value(val: u8) -> Result<u8, DecodeError> {
149 Ok(match val {
150 b'0'..=b'9' => val - b'0',
151 b'A'..=b'F' => val - b'A' + 10,
152 b'a'..=b'f' => val - b'a' + 10,
153 _ => return Err(DecodeError::invalid_char(val, None)),
154 })
155 }
156
157 const fn new() -> Self {
158 Self(None)
159 }
160
161 #[allow(clippy::option_if_let_else)] const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
163 let byte = const_try!(Self::byte_value(byte));
164 let output = if let Some(b) = self.0 {
165 self.0 = None;
166 Some((b << 4) + byte)
167 } else {
168 self.0 = Some(byte);
169 None
170 };
171 Ok((self, output))
172 }
173
174 const fn is_final(self) -> bool {
175 self.0.is_none()
176 }
177}
178
/// Internal state of a decoder based on an [`Encoding`] (used for Base64 and custom alphabets).
#[derive(Debug, Clone, Copy)]
struct CustomDecoderState {
    /// Encoding used to map input characters to digits.
    table: Encoding,
    /// Decoded bits that have not been output yet, stored in the lowest bits.
    partial_byte: u8,
    /// Number of meaningful bits in `partial_byte`.
    filled_bits: u8,
}
186
187impl CustomDecoderState {
188 const fn new(table: Encoding) -> Self {
189 Self {
190 table,
191 partial_byte: 0,
192 filled_bits: 0,
193 }
194 }
195
196 #[allow(clippy::comparison_chain)] const fn update(mut self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
198 let byte = const_try!(self.table.lookup(byte));
199 let output = if self.filled_bits < 8 - self.table.bits_per_char {
200 self.partial_byte = (self.partial_byte << self.table.bits_per_char) + byte;
201 self.filled_bits += self.table.bits_per_char;
202 None
203 } else if self.filled_bits == 8 - self.table.bits_per_char {
204 let output = (self.partial_byte << self.table.bits_per_char) + byte;
205 self.partial_byte = 0;
206 self.filled_bits = 0;
207 Some(output)
208 } else {
209 let remaining_bits = 8 - self.filled_bits;
210 let new_filled_bits = self.table.bits_per_char - remaining_bits;
211 let output = (self.partial_byte << remaining_bits) + (byte >> new_filled_bits);
212 self.partial_byte = byte % (1 << new_filled_bits);
213 self.filled_bits = new_filled_bits;
214 Some(output)
215 };
216 Ok((self, output))
217 }
218
219 const fn is_final(&self) -> bool {
220 self.partial_byte == 0
222 }
223}
224
/// State of a decoder while it processes input.
#[derive(Debug, Clone, Copy)]
enum DecoderState {
    /// State of the hex decoder.
    Hex(HexDecoderState),
    /// State of a Base64 decoder. Unlike `Custom`, skips `=` bytes in the input.
    Base64(CustomDecoderState),
    /// State of a decoder with a custom alphabet.
    Custom(CustomDecoderState),
}
232
233impl DecoderState {
234 const fn update(self, byte: u8) -> Result<(Self, Option<u8>), DecodeError> {
235 Ok(match self {
236 Self::Hex(state) => {
237 let (updated_state, output) = const_try!(state.update(byte));
238 (Self::Hex(updated_state), output)
239 }
240 Self::Base64(state) => {
241 if byte == b'=' {
242 (self, None)
243 } else {
244 let (updated_state, output) = const_try!(state.update(byte));
245 (Self::Base64(updated_state), output)
246 }
247 }
248 Self::Custom(state) => {
249 let (updated_state, output) = const_try!(state.update(byte));
250 (Self::Custom(updated_state), output)
251 }
252 })
253 }
254
255 const fn is_final(&self) -> bool {
256 match self {
257 Self::Hex(state) => state.is_final(),
258 Self::Base64(state) | Self::Custom(state) => state.is_final(),
259 }
260 }
261}
262
/// Decoder of a human-friendly encoding, such as hex or Base64, into bytes.
///
/// All decoding methods are `const fn`s, so a decoder can be used to define constants.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum Decoder {
    /// Hexadecimal decoder. Accepts digits `0-9` and letters `a-f` in either case.
    Hex,
    /// Base64 decoder using the alphabet `A-Z`, `a-z`, `0-9`, `+`, `/`.
    /// `=` bytes in the input are skipped.
    Base64,
    /// Base64 decoder using the alphabet `A-Z`, `a-z`, `0-9`, `-`, `_`.
    /// `=` bytes in the input are skipped.
    Base64Url,
    /// Decoder based on a custom [`Encoding`].
    Custom(Encoding),
}
286
287impl Decoder {
288 pub const fn custom(alphabet: &'static str) -> Self {
294 Self::Custom(Encoding::new(alphabet))
295 }
296
297 pub const fn skip_whitespace(self) -> SkipWhitespace {
299 SkipWhitespace(self)
300 }
301
302 const fn new_state(self) -> DecoderState {
303 match self {
304 Self::Hex => DecoderState::Hex(HexDecoderState::new()),
305 Self::Base64 => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64)),
306 Self::Base64Url => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64_URL)),
307 Self::Custom(encoding) => DecoderState::Custom(CustomDecoderState::new(encoding)),
308 }
309 }
310
311 pub const fn decode<const N: usize>(self, input: &[u8]) -> [u8; N] {
318 self.do_decode(input, None)
319 }
320
321 pub(crate) const fn do_decode<const N: usize>(
322 self,
323 input: &[u8],
324 skipper: Option<Skipper>,
325 ) -> [u8; N] {
326 let mut bytes = [0_u8; N];
327 let mut in_index = 0;
328 let mut out_index = 0;
329 let mut state = self.new_state();
330
331 while in_index < input.len() {
332 if let Some(skipper) = skipper {
333 let new_in_index = skipper.skip(input, in_index);
334 if new_in_index != in_index {
335 in_index = new_in_index;
336 continue;
337 }
338 }
339
340 let update = match state.update(input[in_index]) {
341 Ok(update) => update,
342 Err(err) => err.panic(in_index),
343 };
344 state = update.0;
345 if let Some(byte) = update.1 {
346 if out_index < N {
347 bytes[out_index] = byte;
348 }
349 out_index += 1;
350 }
351 in_index += 1;
352 }
353
354 compile_assert!(
355 out_index <= N,
356 "Output overflow: the input decodes to ", out_index => fmt::<usize>(),
357 " bytes, while type inference implies ", N => fmt::<usize>(), ". \
358 Either fix the input or change the output buffer length correspondingly"
359 );
360 compile_assert!(
361 out_index == N,
362 "Output underflow: the input decodes to ", out_index => fmt::<usize>(),
363 " bytes, while type inference implies ", N => fmt::<usize>(), ". \
364 Either fix the input or change the output buffer length correspondingly"
365 );
366
367 assert!(
368 state.is_final(),
369 "Left-over state after processing input. This usually means that the input \
370 is incorrect (e.g., an odd number of hex digits)."
371 );
372 bytes
373 }
374
375 pub(crate) const fn do_decode_len(self, input: &[u8], skipper: Option<Skipper>) -> usize {
376 let mut in_index = 0;
377 let mut out_index = 0;
378 let mut state = self.new_state();
379
380 while in_index < input.len() {
381 if let Some(skipper) = skipper {
382 let new_in_index = skipper.skip(input, in_index);
383 if new_in_index != in_index {
384 in_index = new_in_index;
385 continue;
386 }
387 }
388
389 let update = match state.update(input[in_index]) {
390 Ok(update) => update,
391 Err(err) => err.panic(in_index),
392 };
393 state = update.0;
394 if update.1.is_some() {
395 out_index += 1;
396 }
397 in_index += 1;
398 }
399 out_index
400 }
401}