use crate::{binary::bits, Reveal};
use std::{
io::{self, BufReader, BufWriter, Read, Write},
ops::ControlFlow,
};
/// How the length of the payload extracted from a [`Package`] is determined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PayloadLength {
    /// Extraction stops once this many payload bytes have been written.
    Bound(u64),
    /// Extraction runs until the package reader or the bit pattern is exhausted.
    Unbound,
    /// The first 8 extracted bytes hold the payload length as a big-endian `u64`;
    /// once decoded, this is replaced by `Bound(len)`.
    Embedded,
}
#[derive(Debug)]
pub struct Package<P, R>
where
P: FnMut(usize) -> Option<u8>,
R: Read,
{
pattern: P,
reader: BufReader<R>,
len: PayloadLength,
}
impl<P, R> Package<P, R>
where
P: FnMut(usize) -> Option<u8>,
R: Read,
{
#[must_use]
pub fn with_len(len: usize, pattern: P, reader: R) -> Self {
Self {
pattern,
reader: BufReader::new(reader),
len: PayloadLength::Bound(len as u64),
}
}
#[must_use]
pub fn with_embedded_len(pattern: P, reader: R) -> Self {
Self {
pattern,
reader: BufReader::new(reader),
len: PayloadLength::Embedded,
}
}
#[must_use]
pub fn new(pattern: P, reader: R) -> Self {
Self {
pattern,
reader: BufReader::new(reader),
len: PayloadLength::Unbound,
}
}
}
impl<M, R> Reveal for &mut Package<M, R>
where
    M: FnMut(usize) -> Option<u8>,
    R: Read,
{
    type Err = io::Error;

    /// Extracts the payload from the package and writes it to `output`.
    ///
    /// `pattern(index)` yields a bit mask for the `index`-th byte read from the package.
    /// The bits of that byte at the positions produced by `bits::Ones::from(mask)` are
    /// appended to the payload, filling each payload byte from its least-significant bit
    /// upwards. Every 8 collected bits complete one payload byte.
    ///
    /// Extraction stops when any of the following happens:
    /// - the package reader is exhausted;
    /// - `pattern` returns `None`;
    /// - for a bounded payload, the bound has been written.
    ///
    /// In embedded-length mode, the first 8 payload bytes are not written. They are
    /// decoded as a big-endian `u64`, and the package switches to a bounded payload of that
    /// length. A decoded length of 0 ends extraction immediately. The switch is stored in
    /// the package, so it persists across calls.
    ///
    /// Leftover bits (fewer than 8) at the end are emitted as a final zero-padded byte,
    /// unless the bound has already been reached.
    ///
    /// Returns the number of payload bytes written to `output`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from reading the package or writing/flushing `output`.
    fn reveal<W: Write>(self, output: W) -> io::Result<usize> {
        let mut output = BufWriter::new(output);
        let embedded = matches!(self.len, PayloadLength::Embedded);
        // Borrow only the `len` field, so the closure below does not conflict with the
        // later uses of `self.reader` and `self.pattern` (independent of closure-capture
        // rules of the edition in use).
        let length = &mut self.len;
        // Collects the 8-byte length header in embedded mode; `None` once payload bytes
        // are being written.
        let mut len_bytes = embedded.then(|| Vec::with_capacity(8));
        let mut bytes_written = 0usize;
        // Routes one decoded payload byte and reports whether extraction should stop.
        let mut write_byte = |byte: u8| -> io::Result<ControlFlow<()>> {
            if let Some(header) = len_bytes.as_mut() {
                header.push(byte);
                if header.len() == 8 {
                    let payload_len = u64::from_be_bytes(
                        *header
                            .first_chunk::<8>()
                            .expect("length header holds exactly 8 bytes"),
                    );
                    *length = PayloadLength::Bound(payload_len);
                    len_bytes = None;
                    if payload_len == 0 {
                        return Ok(ControlFlow::Break(()));
                    }
                }
                return Ok(ControlFlow::Continue(()));
            }
            // Check the bound before writing, so a zero-length bound (`with_len(0, ..)`)
            // emits nothing. The check after writing stops as soon as the bound is met.
            if let PayloadLength::Bound(limit) = *length {
                if bytes_written as u64 >= limit {
                    return Ok(ControlFlow::Break(()));
                }
            }
            // `write_all` reports a short write as an error instead of silently dropping
            // the byte.
            output.write_all(&[byte])?;
            bytes_written += 1;
            Ok(match *length {
                PayloadLength::Embedded => unreachable!(
                    "`PayloadLength::Embedded` is replaced with `PayloadLength::Bound(n)` before reaching this"
                ),
                PayloadLength::Unbound => ControlFlow::Continue(()),
                PayloadLength::Bound(limit) if (bytes_written as u64) < limit => {
                    ControlFlow::Continue(())
                }
                PayloadLength::Bound(_) => ControlFlow::Break(()),
            })
        };
        let mut payload_byte = 0u8;
        let mut bit_count = 0usize;
        for (index, package_byte) in self.reader.by_ref().bytes().enumerate() {
            let Some(mask) = (self.pattern)(index) else {
                break;
            };
            let package_byte = package_byte?;
            for pow in bits::Ones::from(mask) {
                payload_byte |= ((package_byte >> pow) & 1) << bit_count;
                bit_count += 1;
                if bit_count < 8 {
                    continue;
                }
                if write_byte(payload_byte)?.is_break() {
                    output.flush()?;
                    return Ok(bytes_written);
                }
                bit_count = 0;
                payload_byte = 0;
            }
        }
        // Emit any leftover bits as a final, zero-padded byte.
        if bit_count > 0 {
            write_byte(payload_byte)?;
        }
        output.flush()?;
        Ok(bytes_written)
    }
}