glassy 0.0.3

An easy and fast library for encoding and decoding binary data.
Documentation
//! A generic implementation of Base32.
//!
//! This module optimizes for simplicity and readability rather than for
//! performance.  It is meant to provide an overview of the algorithm and to
//! be usable on any platform.

use core::convert::Infallible;
use core::mem::{self, MaybeUninit};
use core::iter;

use crate::{Decode, Encode};
use crate::util::{Carry, Group, Writer};
use super::{DecodeError, Format};

/// Encoder state.
///
/// Input bytes are held back until a complete 5-byte block is available;
/// each complete block is then written out as eight Base32 characters.
#[derive(Debug, Clone)]
pub struct Encoder {
    /// Alphabet and padding settings used when emitting characters.
    format: Format,

    /// Input bytes that do not yet make up a complete 5-byte block.
    carry: Carry<u8, 5>,
}

impl Encoder {
    /// Create an [`Encoder`] for `format`, starting from scratch with
    /// nothing buffered.
    pub fn new(format: Format) -> Self {
        let carry = Carry::new();
        Self { format, carry }
    }
}
 
impl Encode for Encoder {
    type Error = Infallible;

    /// The number of bytes `encode` writes for `decoded_size` more input
    /// bytes.
    ///
    /// Only complete 5-byte blocks are written, counting bytes already held
    /// in the carry buffer. Each complete block becomes exactly eight
    /// characters.
    fn encode_buffer_size(&self, decoded_size: usize) -> usize {
        let complete_blocks = (self.carry.len() + decoded_size) / 5;
        complete_blocks * 8
    }

    /// Encode `decoded`, writing each completed 8-character group to
    /// `encoded` and returning the number of bytes written.
    ///
    /// Bytes that do not complete a block stay in the carry buffer. They are
    /// used by a later call or flushed by `encode_end`.
    ///
    /// # Panics
    ///
    /// Panics if the output capacity is below
    /// `encode_buffer_size(decoded.len())`.
    fn encode(
        &mut self,
        decoded: &[u8],
        encoded: &mut [MaybeUninit<u8>],
    ) -> Result<usize, Self::Error> {
        let mut encoded = Writer::new(encoded);
        assert!(encoded.cap() >= self.encode_buffer_size(decoded.len()));

        for &byte in decoded {
            // The carry only hands back a block once it has been filled.
            if let Some(full) = self.carry.push(byte) {
                let chars = Block(full.into()).encode(&self.format);
                encoded.extend(chars);
            }
        }

        Ok(encoded.as_ref().len())
    }

    /// The number of bytes `encode_end` writes for the bytes still in the
    /// carry buffer, including any `=` padding the format asks for.
    fn encode_end_buffer_size(&self) -> usize {
        let pending = Block(self.carry.clone().into());
        let (_, total) = pending.encode_len(&self.format);
        total
    }

    /// Flush the carry buffer as a final, possibly partial, group and return
    /// the number of bytes written.
    ///
    /// The carry buffer is reset to its default state.
    ///
    /// # Panics
    ///
    /// Panics if the output capacity is below `encode_end_buffer_size()`.
    fn encode_end(
        &mut self,
        encoded: &mut [MaybeUninit<u8>],
    ) -> Result<usize, Self::Error> {
        let mut encoded = Writer::new(encoded);
        assert!(encoded.cap() >= self.encode_end_buffer_size());

        let leftover = mem::take(&mut self.carry);
        let tail = Block(leftover.into()).encode(&self.format);
        encoded.extend(tail);

        Ok(encoded.as_ref().len())
    }
}

/// Decoder state.
///
/// Input characters are collected as 5-bit symbols until a group of eight is
/// complete. Each complete group is decoded into at most five bytes, which
/// pass through a second carry buffer on their way out.
#[derive(Debug, Clone)]
pub struct Decoder {
    /// Alphabet, padding and garbage-handling settings used when decoding.
    format: Format,

    /// Symbols that do not yet make up a complete 8-symbol group.
    en_carry: Carry<Index, 8>,

    /// Decoded bytes that have not been written out yet.
    de_carry: Carry<u8, 5>,
}

impl Decoder {
    /// Create a [`Decoder`] for `format`, starting from scratch with nothing
    /// buffered.
    pub fn new(format: Format) -> Self {
        let (en_carry, de_carry) = (Carry::new(), Carry::new());
        Self {
            format,
            en_carry,
            de_carry,
        }
    }
}

impl Decode for Decoder {
    type Error = DecodeError;

    fn decode_buffer_size(&self, encoded_size: usize) -> usize {
        ((encoded_size + self.en_carry.len()) / 8) * 5
    }

    fn decode(
        &mut self,
        encoded: &[u8],
        decoded: &mut [MaybeUninit<u8>],
    ) -> Result<usize, Self::Error> {
        let mut decoded = Writer::new(decoded);
        assert!(decoded.cap() >= self.decode_buffer_size(encoded.len()));

        let g_err = (!self.format.decode_garbage()).then_some(Err(DecodeError));
        encoded.iter()
            .map(|&c| Index::decode(c, &self.format))
            .filter_map(|i| i.map(Ok).or(g_err.clone()))
            .filter_map(|i| i.map(|i| self.en_carry.push(i)).transpose())
            .map(|b| Ok(Block::decode(b?.into(), &self.format)?.0))
            .filter_map(|b| b.map(|b| self.de_carry.append_group(b)).transpose())
            .try_for_each(|b| Ok(decoded.extend(b?)))?;

        Ok(decoded.as_ref().len())
    }

    fn decode_end_buffer_size(&self) -> usize {
        Block::decode_len(&self.en_carry.clone().into(), &self.format)
            .map_or(0, |(_, dst)| dst) + self.de_carry.len()
    }

    fn decode_end(
        &mut self,
        decoded: &mut [MaybeUninit<u8>],
    ) -> Result<usize, Self::Error> {
        let mut decoded = Writer::new(decoded);
        assert!(decoded.cap() >= self.decode_end_buffer_size());

        let encoded = mem::take(&mut self.en_carry).into();
        let post_group = Block::decode(encoded, &self.format)?.0;
        decoded.extend(self.de_carry.drain().chain(post_group));

        Ok(decoded.as_ref().len())
    }
}

/// A 40-bit Base32 block: up to five raw bytes, which encode to at most
/// eight Base32 data characters.
#[derive(Debug, Clone)]
struct Block(Group<u8, 5>);

impl Block {
    /// Encode this [`Block`] into a group of ASCII characters.
    fn encode(self, format: &Format) -> Group<u8, 8> {
        let (mid, end) = self.encode_len(format);
        let val = self.0.into_iter().enumerate()
            .map(|(i, b)| (b as u64) << (64 - (i + 1) * 8))
            .fold(0u64, |v, b| v | b);
        let mut res = Group::new();
        res.extend((0 .. mid)
            .map(|i| ((val >> (64 - (i + 1) * 5)) & 31) as u8)
            .map(Index::Regular)
            .chain(iter::repeat(Index::Padding))
            .take(end)
            .map(|i| i.encode(format))
            .by_ref());
        res
    }

    /// The size of the output when encoding.
    fn encode_len(&self, format: &Format) -> (usize, usize) {
        let mid = match self.0.len() {
            0 => 0,
            1 => 2,
            2 => 4,
            3 => 5,
            4 => 7,
            5 => 8,
            _ => unreachable!(),
        };

        if format.encode_padding() {
            (mid, mid.div_ceil(8) * 8)
        } else {
            (mid, mid)
        }
    }

    /// Decode a [`Block`] from a group of ASCII characters.
    fn decode(
        data: Group<Index, 8>,
        format: &Format,
    ) -> Result<Self, DecodeError> {
        let (src, dst) = Self::decode_len(&data, format)?;
        let val = data[.. src].iter()
            .enumerate()
            .map(|(i, &e)| match e {
                Index::Regular(b) => (b as u64) << (64 - (i + 1) * 5),
                Index::Padding => unreachable!(),
            })
            .fold(0u64, |v, b| v | b);
        let mut res = Group::new();
        res.extend((0 .. dst)
            .map(|i| (val >> (64 - (i + 1) * 8)) as u8)
            .by_ref());
        Ok(Self(res))
    }

    /// The size of the output when decoding.
    fn decode_len(
        data: &Group<Index, 8>,
        format: &Format,
    ) -> Result<(usize, usize), DecodeError> {
        // Check for and filter out any padding bytes.
        let src = if format.decode_padding() {
            data.len() - data.iter().try_fold(0, |p, &e| match (p, e) {
                (p, Index::Padding) => Ok(p + 1),
                (0, Index::Regular(_)) => Ok(0),
                // A regular byte cannot appear after padding.
                (_, Index::Regular(_)) => Err(DecodeError),
            })?
        } else {
            data.iter().try_fold(0, |n, &e| match e {
                Index::Regular(_) => Ok(n + 1),
                Index::Padding => Err(DecodeError),
            })?
        };

        Ok((src, match src {
            0 => 0,
            2 => 1,
            4 => 2,
            5 => 3,
            7 => 4,
            8 => 5,
            1 | 3 | 6 => return Err(DecodeError),
            _ => unreachable!(),
        }))
    }
}

/// A 5-bit Base32 symbol: either a data value or the padding marker.
#[derive(Clone, Copy, Debug)]
enum Index {
    /// A data symbol carrying a value in `0 ..= 31`.
    Regular(u8),
    /// The padding symbol, written as `=`.
    Padding,
}

impl Index {
    /// The 32-character alphabet for `format`, ordered by symbol value.
    fn alphabet(format: &Format) -> &'static [u8; 32] {
        match format {
            Format::Base32 { .. } => b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567",
            Format::Base32Hex { .. } => b"0123456789ABCDEFGHIJKLMNOPQRSTUV",
        }
    }

    /// Encode this [`Index`] into an ASCII character.
    ///
    /// Padding becomes `=`. It is only expected when the format enables
    /// output padding; any other case is treated as unreachable.
    fn encode(self, format: &Format) -> u8 {
        match self {
            Self::Regular(index) if index < 32 => {
                Self::alphabet(format)[usize::from(index)]
            }
            Self::Padding if format.encode_padding() => b'=',
            _ => unreachable!(),
        }
    }

    /// Decode an [`Index`] from an ASCII character.
    ///
    /// Returns `None` for a character outside the format's alphabet. `=` is
    /// recognised as padding only when the format accepts padding on input.
    fn decode(data: u8, format: &Format) -> Option<Self> {
        if data == b'=' {
            return format.decode_padding().then_some(Self::Padding);
        }

        // A character's position in the alphabet is its symbol value (< 32).
        Self::alphabet(format)
            .iter()
            .position(|&c| c == data)
            .map(|pos| Self::Regular(pos as u8))
    }
}