use core::convert::Infallible;
use core::mem::{self, MaybeUninit};
use core::iter;
use crate::{Decode, Encode};
use crate::util::{Carry, Group, Writer};
use super::{DecodeError, Format};
/// Streaming base64 encoder.
///
/// Input bytes are buffered in `carry`; every completed three-byte group is
/// turned into four output characters, and `encode_end` flushes whatever is
/// still buffered.
#[derive(Clone, Debug)]
pub struct Encoder {
    /// Output alphabet (`Base64` / `Base64URL`) and padding policy.
    format: Format,
    /// Input bytes that do not yet form a complete three-byte group.
    carry: Carry<u8, 3>,
}

impl Encoder {
    /// Creates an encoder for `format` starting from a fresh `Carry`.
    pub fn new(format: Format) -> Self {
        let carry = Carry::new();
        Self { format, carry }
    }
}
impl Encode for Encoder {
    type Error = Infallible;

    /// Upper bound on the bytes `encode` writes for `decoded_size` more input
    /// bytes: four characters per three-byte group that can be completed
    /// (counting bytes already buffered in the carry).
    fn encode_buffer_size(&self, decoded_size: usize) -> usize {
        let full_groups = (decoded_size + self.carry.len()) / 3;
        full_groups * 4
    }

    /// Encodes `decoded`, writing the characters for every completed
    /// three-byte group into `encoded` and keeping any remainder buffered.
    ///
    /// Returns the number of bytes written.
    ///
    /// # Panics
    /// Panics if the output buffer's capacity is below
    /// `encode_buffer_size(decoded.len())`.
    fn encode(
        &mut self,
        decoded: &[u8],
        encoded: &mut [MaybeUninit<u8>],
    ) -> Result<usize, Self::Error> {
        let mut out = Writer::new(encoded);
        assert!(out.cap() >= self.encode_buffer_size(decoded.len()));
        for &byte in decoded {
            // `push` hands back a group only once the carry has filled up.
            if let Some(group) = self.carry.push(byte) {
                out.extend(Block(group.into()).encode(&self.format));
            }
        }
        Ok(out.as_ref().len())
    }

    /// Number of bytes `encode_end` will write for the currently buffered
    /// (possibly partial) group, including padding if the format requests it.
    fn encode_end_buffer_size(&self) -> usize {
        let pending = Block(self.carry.clone().into());
        let (_, total) = pending.encode_len(&self.format);
        total
    }

    /// Flushes the buffered partial group (emptying the carry) into
    /// `encoded` and returns the number of bytes written.
    ///
    /// # Panics
    /// Panics if the output buffer's capacity is below
    /// `encode_end_buffer_size()`.
    fn encode_end(
        &mut self,
        encoded: &mut [MaybeUninit<u8>],
    ) -> Result<usize, Self::Error> {
        let mut out = Writer::new(encoded);
        assert!(out.cap() >= self.encode_end_buffer_size());
        let tail = Block(mem::take(&mut self.carry).into());
        out.extend(tail.encode(&self.format));
        Ok(out.as_ref().len())
    }
}
/// Streaming base64 decoder.
///
/// Input characters are mapped to alphabet indices and buffered in
/// `en_carry` until a four-symbol group completes; each group is decoded and
/// the resulting bytes pass through `de_carry` before being written out.
#[derive(Clone, Debug)]
pub struct Decoder {
    /// Alphabet, padding and garbage-handling policy used while decoding.
    format: Format,
    /// Symbols of the current, not yet complete, four-symbol group.
    en_carry: Carry<Index, 4>,
    /// Decoded bytes not yet written out; drained by `decode_end`.
    de_carry: Carry<u8, 3>,
}

impl Decoder {
    /// Creates a decoder for `format` with both carries freshly constructed.
    pub fn new(format: Format) -> Self {
        let en_carry = Carry::new();
        let de_carry = Carry::new();
        Self {
            format,
            en_carry,
            de_carry,
        }
    }
}
impl Decode for Decoder {
type Error = DecodeError;
fn decode_buffer_size(&self, encoded_size: usize) -> usize {
((encoded_size + self.en_carry.len()) / 4) * 3
}
fn decode(
&mut self,
encoded: &[u8],
decoded: &mut [MaybeUninit<u8>],
) -> Result<usize, Self::Error> {
let mut decoded = Writer::new(decoded);
assert!(decoded.cap() >= self.decode_buffer_size(encoded.len()));
let g_err = (!self.format.decode_garbage()).then_some(Err(DecodeError));
encoded.iter()
.map(|&c| Index::decode(c, &self.format))
.filter_map(|i| i.map(Ok).or(g_err.clone()))
.filter_map(|i| i.map(|i| self.en_carry.push(i)).transpose())
.map(|b| Ok(Block::decode(b?.into(), &self.format)?.0))
.filter_map(|b| b.map(|b| self.de_carry.append_group(b)).transpose())
.try_for_each(|b| Ok(decoded.extend(b?)))?;
Ok(decoded.as_ref().len())
}
fn decode_end_buffer_size(&self) -> usize {
Block::decode_len(&self.en_carry.clone().into(), &self.format)
.map_or(0, |(_, dst)| dst) + self.de_carry.len()
}
fn decode_end(
&mut self,
decoded: &mut [MaybeUninit<u8>],
) -> Result<usize, Self::Error> {
let mut decoded = Writer::new(decoded);
assert!(decoded.cap() >= self.decode_end_buffer_size());
let encoded = mem::take(&mut self.en_carry).into();
let post_group = Block::decode(encoded, &self.format)?.0;
decoded.extend(self.de_carry.drain().chain(post_group));
Ok(decoded.as_ref().len())
}
}
/// Up to three raw bytes: the unit that encodes to at most four base64
/// characters, and that one group of at most four symbols decodes back into.
#[derive(Clone, Debug)]
struct Block(Group<u8, 3>);

impl Block {
    /// Encodes this block into base64 characters.
    ///
    /// A non-empty block of `n` bytes yields `n + 1` digit characters (an
    /// empty block yields none), followed by `=` up to a multiple of four
    /// when `format.encode_padding()` is set.
    fn encode(self, format: &Format) -> Group<u8, 4> {
        let (digits, total) = self.encode_len(format);
        // Pack the bytes big-endian into the top of a u64 so successive
        // 6-bit digits can be read off from the most significant end.
        let mut bits = 0u64;
        for (i, byte) in self.0.into_iter().enumerate() {
            bits |= (byte as u64) << (64 - (i + 1) * 8);
        }
        let mut out = Group::new();
        let mut chars = (0..digits)
            .map(|i| Index::Regular(((bits >> (64 - (i + 1) * 6)) & 63) as u8))
            .chain(iter::repeat(Index::Padding))
            .take(total)
            .map(|index| index.encode(format));
        out.extend(chars.by_ref());
        out
    }

    /// Returns `(digits, total)`: the number of digit characters this block
    /// encodes to, and the total output length including any padding.
    fn encode_len(&self, format: &Format) -> (usize, usize) {
        let digits = match self.0.len() {
            0 => 0,
            n @ 1..=3 => n + 1,
            _ => unreachable!(),
        };
        let total = if format.encode_padding() {
            digits.div_ceil(4) * 4
        } else {
            digits
        };
        (digits, total)
    }

    /// Decodes one group of up to four symbols into its bytes.
    ///
    /// # Errors
    /// Returns `DecodeError` when `decode_len` rejects the group.
    fn decode(
        data: Group<Index, 4>,
        format: &Format,
    ) -> Result<Self, DecodeError> {
        let (src, dst) = Self::decode_len(&data, format)?;
        // Pack the 6-bit digits big-endian into the top of a u64.
        let mut bits = 0u64;
        for (i, &index) in data[..src].iter().enumerate() {
            // `decode_len` guarantees the first `src` entries are digits.
            let Index::Regular(sextet) = index else {
                unreachable!()
            };
            bits |= (sextet as u64) << (64 - (i + 1) * 6);
        }
        let mut out = Group::new();
        let mut bytes = (0..dst).map(|i| (bits >> (64 - (i + 1) * 8)) as u8);
        out.extend(bytes.by_ref());
        Ok(Self(out))
    }

    /// Returns `(src, dst)`: how many leading symbols of `data` are digits,
    /// and how many bytes they decode to.
    ///
    /// # Errors
    /// Returns `DecodeError` if padding appears when the format does not
    /// accept it, if a digit follows padding, or if exactly one digit is
    /// present (too few bits for a byte).
    fn decode_len(
        data: &Group<Index, 4>,
        format: &Format,
    ) -> Result<(usize, usize), DecodeError> {
        let src = if format.decode_padding() {
            // Padding is only allowed as a trailing run.
            let mut padding = 0;
            for &index in data.iter() {
                match index {
                    Index::Padding => padding += 1,
                    Index::Regular(_) if padding > 0 => return Err(DecodeError),
                    Index::Regular(_) => {}
                }
            }
            data.len() - padding
        } else {
            // Without padding support every symbol must be a digit.
            let mut digits = 0;
            for &index in data.iter() {
                match index {
                    Index::Regular(_) => digits += 1,
                    Index::Padding => return Err(DecodeError),
                }
            }
            digits
        };
        let dst = match src {
            0 => 0,
            1 => return Err(DecodeError),
            2..=4 => src - 1,
            _ => unreachable!(),
        };
        Ok((src, dst))
    }
}
/// A single base64 symbol after alphabet lookup.
#[derive(Copy, Clone, Debug)]
enum Index {
    /// A digit value; `Index::encode` only accepts `0..=63`.
    Regular(u8),
    /// The `=` padding character.
    Padding,
}

impl Index {
    /// Maps this symbol to its character in `format`'s alphabet.
    ///
    /// `Base64` and `Base64URL` share digits 0..=61 and differ only in the
    /// characters for 62 and 63.
    ///
    /// # Panics
    /// Panics on a digit above 63, or on `Padding` when the format does not
    /// emit padding.
    fn encode(self, format: &Format) -> u8 {
        let (digit62, digit63) = match format {
            Format::Base64 { .. } => (b'+', b'/'),
            Format::Base64URL { .. } => (b'-', b'_'),
        };
        match self {
            Self::Regular(value @ 0..=25) => b'A' + value,
            Self::Regular(value @ 26..=51) => b'a' + (value - 26),
            Self::Regular(value @ 52..=61) => b'0' + (value - 52),
            Self::Regular(62) => digit62,
            Self::Regular(63) => digit63,
            Self::Padding if format.encode_padding() => b'=',
            _ => unreachable!(),
        }
    }

    /// Maps a character to its symbol in `format`'s alphabet.
    ///
    /// `=` is recognised as padding only when `format.decode_padding()` is
    /// set. Returns `None` for any character outside the alphabet.
    fn decode(data: u8, format: &Format) -> Option<Self> {
        let (digit62, digit63) = match format {
            Format::Base64 { .. } => (b'+', b'/'),
            Format::Base64URL { .. } => (b'-', b'_'),
        };
        let value = match data {
            b'A'..=b'Z' => data - b'A',
            b'a'..=b'z' => data - b'a' + 26,
            b'0'..=b'9' => data - b'0' + 52,
            b'=' if format.decode_padding() => return Some(Self::Padding),
            _ if data == digit62 => 62,
            _ if data == digit63 => 63,
            _ => return None,
        };
        Some(Self::Regular(value))
    }
}