use crate::error::{
input_buffer_doesnt_match_bits, invalid_character, output_buffer_doesnt_match_bits,
};
use crate::stateful_decoder::{
quintet_has_valid_trailing_bits, HaveOctets, NeedQuintets, NextOctetResult,
ProvideQuintetResult,
};
use crate::tables::{CHARACTER_MIN_VALUE, CHARACTER_TO_QUINTET};
use crate::util::{required_octets_buffer_len, required_quintets_buffer_len};
use crate::ZBase32Error;
use core::iter::Peekable;
/// State carried by `QuintetsToOctetsIter` between calls to `next`.
enum QuintetsToOctetsIterState {
    /// No quintets have been consumed yet; holds the decoder's starting state.
    Initial(NeedQuintets),
    /// Decoder state from which octets are pulled via `next_octet`.
    HaveOctets(HaveOctets),
}
/// Iterator adapter that drives the stateful decoder over a stream of quintets
/// (or errors), yielding decoded octets.
///
/// `state` becomes `None` once decoding has finished or an error has been
/// yielded, after which the iterator only returns `None`.
struct QuintetsToOctetsIter<I>
where
    I: Iterator<Item = Result<u8, ZBase32Error>>,
{
    /// Source of quintets; peekable so the decoder can be told which quintet is last.
    quintet_iter: Peekable<I>,
    /// Current decoder state; `None` when the iterator is exhausted.
    state: Option<QuintetsToOctetsIterState>,
}
impl<I> QuintetsToOctetsIter<I>
where
    I: Iterator<Item = Result<u8, ZBase32Error>>,
{
    /// Wraps `quintet_iter`, starting the decoder from `need_quintets`.
    fn new(quintet_iter: I, need_quintets: NeedQuintets) -> Self {
        let initial = QuintetsToOctetsIterState::Initial(need_quintets);
        Self {
            state: Some(initial),
            quintet_iter: quintet_iter.peekable(),
        }
    }
}
/// Feeds quintets from `quintet_iter` into the decoder until it transitions to
/// the `HaveOctets` state.
///
/// Each quintet is handed to `NeedQuintets::provide_quintet` together with a
/// flag saying whether it is the final quintet of the input, which is
/// determined by peeking at the source. On success the returned value is
/// always `Some`.
///
/// # Errors
///
/// Returns the first error yielded by `quintet_iter` or reported by
/// `provide_quintet`.
///
/// # Panics
///
/// Panics if `quintet_iter` runs dry while the decoder still needs quintets.
/// The entry points in this module check the input length against `bits`
/// before decoding, which is expected to make this unreachable.
/// NOTE(review): this also relies on the decoder never requesting more
/// quintets after being told the last one was provided.
fn refill<I>(
    quintet_iter: &mut Peekable<I>,
    mut need_quintets: NeedQuintets,
) -> Result<Option<QuintetsToOctetsIterState>, ZBase32Error>
where
    I: Iterator<Item = Result<u8, ZBase32Error>>,
{
    loop {
        let quintet = quintet_iter
            .next()
            .expect("quintet input exhausted while the decoder still needed quintets")?;
        // The decoder must know when it is receiving the final quintet.
        let last_quintet = quintet_iter.peek().is_none();
        match need_quintets.provide_quintet(quintet, last_quintet)? {
            ProvideQuintetResult::NeedQuintets(need_more) => need_quintets = need_more,
            ProvideQuintetResult::HaveOctets(have_octets) => {
                return Ok(Some(QuintetsToOctetsIterState::HaveOctets(have_octets)));
            }
        }
    }
}
impl<I> Iterator for QuintetsToOctetsIter<I>
where
    I: Iterator<Item = Result<u8, ZBase32Error>>,
{
    type Item = Result<u8, ZBase32Error>;

    /// Yields the next decoded octet, or the first error encountered.
    ///
    /// Once an error has been yielded or decoding is complete, `state` is left
    /// as `None`, so every later call returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Taking the state leaves `None` behind; any arm that continues
            // must store its successor state explicitly.
            match self.state.take()? {
                QuintetsToOctetsIterState::Initial(need_quintets) => {
                    // No quintets at all: nothing to decode.
                    if self.quintet_iter.peek().is_none() {
                        return None;
                    }
                    match refill(&mut self.quintet_iter, need_quintets) {
                        Ok(new_state) => self.state = new_state,
                        Err(err) => return Some(Err(err)),
                    }
                }
                QuintetsToOctetsIterState::HaveOctets(have_octets) => {
                    match have_octets.next_octet() {
                        NextOctetResult::Octet(octet, have_octets) => {
                            self.state = Some(QuintetsToOctetsIterState::HaveOctets(have_octets));
                            return Some(Ok(octet));
                        }
                        NextOctetResult::NeedQuintets(need_quintets) => {
                            match refill(&mut self.quintet_iter, need_quintets) {
                                Ok(new_state) => self.state = new_state,
                                Err(err) => return Some(Err(err)),
                            }
                        }
                        NextOctetResult::Complete => return None,
                    }
                }
            }
        }
    }
}
pub fn character_to_quintet(character: u8) -> Result<u8, ZBase32Error> {
if character < CHARACTER_MIN_VALUE {
return Err(invalid_character());
} else if (character - CHARACTER_MIN_VALUE) as usize >= CHARACTER_TO_QUINTET.len() {
return Err(invalid_character());
}
let val = CHARACTER_TO_QUINTET[(character - CHARACTER_MIN_VALUE) as usize];
if val == 255 {
return Err(invalid_character());
}
Ok(val)
}
/// Number of significant bits (1 through 5) in the final quintet of a
/// `bits`-bit payload, or `None` when the payload is empty (`bits == 0`).
fn calc_last_quintet_bits(bits: u64) -> Option<u8> {
    match bits {
        0 => None,
        // For bits >= 1, (bits - 1) % 5 lies in 0..=4; adding one maps a fully
        // used final quintet to 5 instead of 0. Subtracting first avoids overflow.
        nonzero => Some(((nonzero - 1) % 5) as u8 + 1),
    }
}
/// Reports whether `quintet` is acceptable as the final quintet of a
/// `bits`-bit payload.
///
/// The quintet must fit in five bits (`<= 31`) and pass
/// `quintet_has_valid_trailing_bits` for the number of bits the final quintet
/// carries. Always `false` when `bits` is 0.
pub fn is_last_quintet_valid(bits: u64, quintet: u8) -> bool {
    match calc_last_quintet_bits(bits) {
        Some(used_bits) if quintet <= 31 => quintet_has_valid_trailing_bits(used_bits, quintet),
        _ => false,
    }
}
/// Decodes `in_quintets` (one quintet per byte) encoding `bits` bits into
/// `out_octets`.
///
/// Quintet values are passed to the decoder as-is; this function does not
/// range-check them itself. NOTE(review): whether values above 31 are rejected
/// depends on `NeedQuintets::provide_quintet`.
///
/// # Errors
///
/// Fails if `in_quintets` or `out_octets` does not have the length required for
/// `bits`, if either required length cannot be computed, or if the decoder
/// rejects the quintet sequence.
pub fn quintets_to_octets(
    in_quintets: &[u8],
    out_octets: &mut [u8],
    bits: u64,
) -> Result<(), ZBase32Error> {
    let expected_quintets = required_quintets_buffer_len(bits)?;
    if in_quintets.len() != expected_quintets {
        return Err(input_buffer_doesnt_match_bits().into());
    }
    let expected_octets = required_octets_buffer_len(bits)?;
    if out_octets.len() != expected_octets {
        return Err(output_buffer_doesnt_match_bits().into());
    }
    let tail_bits = match calc_last_quintet_bits(bits) {
        Some(n) => n,
        // An empty payload has nothing to decode.
        None => return Ok(()),
    };
    let quintets = in_quintets.iter().copied().map(Ok);
    let decoded = QuintetsToOctetsIter::new(quintets, NeedQuintets::new(tail_bits));
    for (slot, octet) in out_octets.iter_mut().zip(decoded) {
        *slot = octet?;
    }
    Ok(())
}
pub fn decode_slices(
in_characters: &[u8],
out_octets: &mut [u8],
bits: u64,
) -> Result<(), ZBase32Error> {
if in_characters.len() != required_quintets_buffer_len(bits)? {
return Err(input_buffer_doesnt_match_bits().into());
}
if out_octets.len() != required_octets_buffer_len(bits)? {
return Err(output_buffer_doesnt_match_bits().into());
}
let last_quintet_bits = if let Some(x) = calc_last_quintet_bits(bits) {
x
} else {
return Ok(());
};
let octet_iter = QuintetsToOctetsIter::new(
in_characters.iter().map(|&x| character_to_quintet(x)),
NeedQuintets::new(last_quintet_bits),
);
for (out, next_octet) in out_octets.iter_mut().zip(octet_iter) {
*out = next_octet?;
}
Ok(())
}
/// Decodes the z-base-32 string `input`, which encodes `bits` bits, appending
/// the decoded octets to the end of `output`.
///
/// Bytes already present in `output` are left untouched. When `bits` is 0,
/// nothing is appended, although the input length check still runs.
///
/// # Errors
///
/// * `input.len()` does not match the number of characters required for `bits`.
/// * `input` contains a byte rejected by `character_to_quintet`.
/// * The stateful decoder rejects the quintet sequence.
/// * Errors from `required_quintets_buffer_len` / `required_octets_buffer_len`
///   are propagated.
///
/// On error, `output` is truncated back to its original length, so no
/// zero-filled or partially decoded bytes are left behind.
#[cfg(feature = "std")]
pub fn decode(input: &str, output: &mut Vec<u8>, bits: u64) -> Result<(), ZBase32Error> {
    /// Writes decoded octets into `out`, stopping at the first error.
    fn fill<J>(out: &mut [u8], octets: J) -> Result<(), ZBase32Error>
    where
        J: Iterator<Item = Result<u8, ZBase32Error>>,
    {
        for (slot, octet) in out.iter_mut().zip(octets) {
            *slot = octet?;
        }
        Ok(())
    }

    if input.len() != required_quintets_buffer_len(bits)? {
        return Err(input_buffer_doesnt_match_bits().into());
    }
    let last_quintet_bits = match calc_last_quintet_bits(bits) {
        Some(x) => x,
        None => return Ok(()),
    };
    let needed_octets = required_octets_buffer_len(bits)?;
    let start = output.len();
    output.resize(start + needed_octets, 0);
    let octet_iter = QuintetsToOctetsIter::new(
        input.bytes().map(character_to_quintet),
        NeedQuintets::new(last_quintet_bits),
    );
    let result = fill(&mut output[start..], octet_iter);
    if result.is_err() {
        // Undo the zero-fill / partial decode so the caller's buffer is unchanged.
        output.truncate(start);
    }
    result
}
// `decode` only exists with the `std` feature, so the tests must be gated on it
// too; otherwise a test build without `std` fails on `use super::decode`.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::decode;
    use crate::test_data::{TestCase, RANDOM_TEST_DATA, STANDARD_TEST_DATA};

    /// Decodes every case in `test_cases` and checks the result against the
    /// expected bytes, reusing one output buffer across cases.
    fn run_tests(test_cases: &[TestCase]) {
        let mut buffer = Vec::new();
        for test in test_cases {
            buffer.clear();
            decode(test.encoded, &mut buffer, test.bits).unwrap();
            assert_eq!(&buffer[..], test.unencoded);
        }
    }

    #[test]
    fn test_decode_standard() {
        run_tests(STANDARD_TEST_DATA);
    }

    #[test]
    fn test_decode_random() {
        run_tests(RANDOM_TEST_DATA);
    }
}