use crate::wrappers::{SkipWhitespace, Skipper};
/// Lookup table for an encoding defined by an ASCII alphabet whose length is a power of two.
///
/// Each alphabet character stands for a digit equal to its position in the alphabet;
/// every digit carries `bits_per_char` bits of the decoded output.
#[derive(Debug, Clone, Copy)]
pub struct Encoding {
    /// Maps an ASCII byte to its position in the alphabet, or to `NO_MAPPING`
    /// if the byte is not a part of the alphabet.
    table: [u8; 128],
    /// Number of bits encoded by a single alphabet character (1 to 6).
    bits_per_char: u8,
}

impl Encoding {
    /// Sentinel stored in `table` for bytes that are not in the alphabet. It cannot clash
    /// with a real digit since digits are below the maximum alphabet length (64).
    const NO_MAPPING: u8 = u8::MAX;

    /// Alphabet ending in `+/`.
    const BASE64: Self =
        Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
    /// Alphabet ending in `-_`.
    const BASE64_URL: Self =
        Self::new("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_");

    /// Creates an encoding from the provided `alphabet`. The character at position `i`
    /// of the alphabet is mapped to the digit `i`.
    ///
    /// # Panics
    ///
    /// - Panics if the alphabet length is not one of 2, 4, 8, 16, 32, or 64.
    /// - Panics if the alphabet contains a non-ASCII character.
    /// - Panics if a character occurs in the alphabet more than once.
    // The `index as u8` cast cannot truncate: `index` is below the alphabet length,
    // which is at most 64 at that point.
    #[allow(clippy::cast_possible_truncation)]
    pub const fn new(alphabet: &str) -> Self {
        let bits_per_char = match alphabet.len() {
            2 => 1,
            4 => 2,
            8 => 3,
            16 => 4,
            32 => 5,
            64 => 6,
            _ => panic!("Invalid alphabet length; must be one of 2, 4, 8, 16, 32, or 64"),
        };

        let mut table = [Self::NO_MAPPING; 128];
        let alphabet_bytes = alphabet.as_bytes();
        // `for` loops are not available in `const fn`s, hence the manual index loop.
        let mut index = 0;
        while index < alphabet_bytes.len() {
            let byte = alphabet_bytes[index];
            assert!(byte < 0x80, "Non-ASCII alphabet character");
            let byte_idx = byte as usize;
            assert!(
                table[byte_idx] == Self::NO_MAPPING,
                "Alphabet character is mentioned several times"
            );
            table[byte_idx] = index as u8;
            index += 1;
        }

        Self {
            table,
            bits_per_char,
        }
    }

    /// Returns the digit that `ascii_char` stands for in this encoding.
    ///
    /// # Panics
    ///
    /// Panics if `ascii_char` is not present in the alphabet. Non-ASCII bytes can never
    /// be present, so they are reported with the same message rather than triggering
    /// an out-of-bounds access to the 128-entry table.
    const fn lookup(&self, ascii_char: u8) -> u8 {
        let mapping = if ascii_char < 0x80 {
            self.table[ascii_char as usize]
        } else {
            Self::NO_MAPPING
        };
        assert!(
            mapping != Self::NO_MAPPING,
            "Character is not present in the alphabet"
        );
        mapping
    }
}
/// Incremental state of the hex decoder: the value of the pending high digit, if any.
#[derive(Debug, Clone, Copy)]
struct HexDecoderState(Option<u8>);

impl HexDecoderState {
    /// Converts an ASCII hex digit (either case) into its numeric value in `0..16`.
    ///
    /// # Panics
    ///
    /// Panics if `val` is not a hex digit.
    const fn byte_value(val: u8) -> u8 {
        // Pick the amount to subtract so that each digit range lands on its numeric value.
        let offset = match val {
            b'0'..=b'9' => b'0',
            b'a'..=b'f' => b'a' - 10,
            b'A'..=b'F' => b'A' - 10,
            _ => panic!("Invalid character in input; expected a hex digit"),
        };
        val - offset
    }

    /// Creates a state with no pending digit.
    const fn new() -> Self {
        Self(None)
    }

    /// Feeds one input character into the state.
    ///
    /// Returns the updated state together with a decoded byte, which is produced
    /// on every second digit.
    // Combinators taking closures (what the lint suggests) cannot be called in a `const fn`.
    #[allow(clippy::option_if_let_else)]
    const fn update(self, byte: u8) -> (Self, Option<u8>) {
        let digit = Self::byte_value(byte);
        match self.0 {
            // Second digit of a pair: combine it with the stored high digit.
            Some(high) => (Self(None), Some((high << 4) | digit)),
            // First digit of a pair: remember it until the next call.
            None => (Self(Some(digit)), None),
        }
    }

    /// Returns `true` if no digit is pending, i.e., the input consisted of whole pairs.
    const fn is_final(self) -> bool {
        matches!(self.0, None)
    }
}
/// Incremental state of a decoder driven by an [`Encoding`] table.
#[derive(Debug, Clone, Copy)]
struct CustomDecoderState {
    /// Encoding used to translate input characters into digits.
    table: Encoding,
    /// Bits accumulated so far for the next output byte (stored in the low bits).
    partial_byte: u8,
    /// Number of meaningful bits in `partial_byte`.
    filled_bits: u8,
}

impl CustomDecoderState {
    /// Creates an empty state for the provided encoding.
    const fn new(table: Encoding) -> Self {
        Self {
            table,
            partial_byte: 0,
            filled_bits: 0,
        }
    }

    /// Feeds one input character into the state.
    ///
    /// Returns the updated state together with a decoded byte if the accumulated bits
    /// have reached a full byte. Panics (via `Encoding::lookup`) if the character
    /// is not in the alphabet.
    // `Ord::cmp` (what the lint suggests matching on) cannot be called in a `const fn`.
    #[allow(clippy::comparison_chain)]
    const fn update(self, byte: u8) -> (Self, Option<u8>) {
        let Self {
            table,
            partial_byte,
            filled_bits,
        } = self;
        let digit = table.lookup(byte);
        let width = table.bits_per_char;
        let total_bits = filled_bits + width;

        if total_bits < 8 {
            // Still short of a full byte: append the digit and wait for more input.
            let next = Self {
                table,
                partial_byte: (partial_byte << width) | digit,
                filled_bits: total_bits,
            };
            (next, None)
        } else if total_bits == 8 {
            // The digit completes the byte exactly; start over with an empty state.
            (Self::new(table), Some((partial_byte << width) | digit))
        } else {
            // The digit straddles a byte boundary: its high bits complete the current byte,
            // and its low `carried_bits` bits start the next one.
            let carried_bits = total_bits - 8;
            let free_bits = width - carried_bits;
            let output = (partial_byte << free_bits) | (digit >> carried_bits);
            let carried_mask = (1_u8 << carried_bits) - 1;
            let next = Self {
                table,
                partial_byte: digit & carried_mask,
                filled_bits: carried_bits,
            };
            (next, Some(output))
        }
    }

    /// Returns `true` if the accumulated bits are all zero.
    ///
    /// Only the bit values are inspected, not `filled_bits`, so zero-valued trailing
    /// bits are accepted. NOTE(review): presumably this lets explicit padding
    /// be omitted; confirm.
    const fn is_final(&self) -> bool {
        matches!(self.partial_byte, 0)
    }
}
/// Incremental decoder state for all supported kinds of decoders.
#[derive(Debug, Clone, Copy)]
enum DecoderState {
    /// State for hex input.
    Hex(HexDecoderState),
    /// State for table-driven input in which `=` characters are ignored.
    Base64(CustomDecoderState),
    /// State for table-driven input without any special characters.
    Custom(CustomDecoderState),
}

impl DecoderState {
    /// Feeds one input character into the state, returning the updated state
    /// and a decoded byte if one has been completed.
    const fn update(self, byte: u8) -> (Self, Option<u8>) {
        match self {
            // `=` leaves the base64 state untouched and produces no output.
            Self::Base64(_) if byte == b'=' => (self, None),
            Self::Hex(inner) => {
                let (next, decoded) = inner.update(byte);
                (Self::Hex(next), decoded)
            }
            Self::Base64(inner) => {
                let (next, decoded) = inner.update(byte);
                (Self::Base64(next), decoded)
            }
            Self::Custom(inner) => {
                let (next, decoded) = inner.update(byte);
                (Self::Custom(next), decoded)
            }
        }
    }

    /// Returns `true` if the wrapped state reports that it is final,
    /// i.e., the input may end at this point.
    const fn is_final(&self) -> bool {
        match *self {
            Self::Hex(inner) => inner.is_final(),
            Self::Base64(inner) => inner.is_final(),
            Self::Custom(inner) => inner.is_final(),
        }
    }
}
/// Decoder of a human-friendly encoding, such as hex or base64, usable in compile-time
/// (`const`) contexts.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum Decoder {
    /// Hex decoder. Both uppercase and lowercase digits are accepted.
    Hex,
    /// Base64 decoder using the alphabet that ends in `+/`. `=` characters in the input
    /// are ignored.
    Base64,
    /// Base64 decoder using the alphabet that ends in `-_`. `=` characters in the input
    /// are ignored.
    Base64Url,
    /// Decoder based on a custom [`Encoding`].
    Custom(Encoding),
}

impl Decoder {
    /// Creates a decoder for the encoding defined by `alphabet`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Encoding::new()`].
    pub const fn custom(alphabet: &str) -> Self {
        Self::Custom(Encoding::new(alphabet))
    }

    /// Wraps this decoder into [`SkipWhitespace`].
    ///
    /// NOTE(review): the wrapper is defined in `crate::wrappers`; judging by its name,
    /// it ignores whitespace in the decoded input. Confirm against that module.
    pub const fn skip_whitespace(self) -> SkipWhitespace {
        SkipWhitespace(self)
    }

    /// Creates the initial decoding state matching this decoder.
    const fn new_state(self) -> DecoderState {
        match self {
            Self::Hex => DecoderState::Hex(HexDecoderState::new()),
            Self::Base64 => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64)),
            Self::Base64Url => DecoderState::Base64(CustomDecoderState::new(Encoding::BASE64_URL)),
            Self::Custom(encoding) => DecoderState::Custom(CustomDecoderState::new(encoding)),
        }
    }

    /// Decodes `input` into a byte array of length `N`.
    ///
    /// # Panics
    ///
    /// - Panics if the input contains a character that is invalid for this decoder.
    /// - Panics if the input does not decode into exactly `N` bytes.
    /// - Panics if the input ends with left-over state (e.g., an odd number of hex digits).
    pub const fn decode<const N: usize>(self, input: &[u8]) -> [u8; N] {
        self.do_decode(input, None)
    }

    /// Decodes `input` into a byte array of length `N`, optionally consulting `skipper`
    /// before each character.
    ///
    /// Whenever `skipper.skip()` returns a position different from the current one,
    /// decoding jumps to that position without feeding anything to the decoder state.
    ///
    /// # Panics
    ///
    /// Panics in the same cases as [`Self::decode()`].
    pub(crate) const fn do_decode<const N: usize>(
        self,
        input: &[u8],
        skipper: Option<Skipper>,
    ) -> [u8; N] {
        let mut decoded = [0_u8; N];
        let mut written = 0;
        let mut cursor = 0;
        let mut state = self.new_state();

        // Iterators are not available in `const fn`s, hence the manual cursor.
        while cursor < input.len() {
            let resume_at = match skipper {
                Some(skipper) => skipper.skip(input, cursor),
                None => cursor,
            };
            if resume_at != cursor {
                cursor = resume_at;
                continue;
            }

            let (next_state, emitted) = state.update(input[cursor]);
            state = next_state;
            if let Some(byte) = emitted {
                assert!(
                    written < N,
                    "Output overflow: the input decodes to more bytes than specified \
                     as the output length"
                );
                decoded[written] = byte;
                written += 1;
            }
            cursor += 1;
        }

        assert!(
            written == N,
            "Output underflow: the input was decoded into less bytes than specified \
             as the output length"
        );
        assert!(
            state.is_final(),
            "Left-over state after processing input. This usually means that the input \
             is incorrect (e.g., an odd number of hex digits)."
        );
        decoded
    }
}