use crate::{Finish, Poll, SinkError};
use embedded_io::{Error, ErrorKind, ErrorType, Read, Write};
/// [`Write`] adapter that runs everything written to it through an encoder.
///
/// Bytes handed to this writer are offered to the encoder `ENC`; whatever the
/// encoder produces is forwarded to the wrapped writer `inner`.
pub struct EncoderWriter<W, ENC> {
    /// Destination that receives the encoder's output.
    inner: W,
    /// Encoder state, created with `ENC::default()` by [`EncoderWriter::new`].
    enc: ENC,
    /// Staging buffer the encoder is polled into before its bytes go to `inner`.
    poll_buf: [u8; 256],
}

impl<W, ENC> EncoderWriter<W, ENC>
where
    ENC: Default,
{
    /// Wraps `inner` together with a default-constructed encoder and a zeroed
    /// 256-byte staging buffer.
    pub fn new(inner: W) -> Self {
        let enc = ENC::default();
        let poll_buf = [0; 256];
        Self {
            inner,
            enc,
            poll_buf,
        }
    }
}
/// The writer reports exactly the error type of the wrapped writer `W`.
impl<W, ENC> ErrorType for EncoderWriter<W, ENC>
where
W: ErrorType,
{
// Only `W`'s errors can be surfaced; encoder-side failures have no
// representation in this type.
type Error = W::Error;
}
impl<W, ENC> EncoderWriter<W, ENC>
where
    W: Write,
    ENC: EncoderOps,
{
    /// Polls the encoder repeatedly and forwards each chunk to `inner` until
    /// the encoder answers `Poll::Empty`.
    ///
    /// A poll error stops the drain and is reported as success, since it
    /// cannot be expressed as `W::Error`.
    ///
    /// # Errors
    /// Propagates any error from `inner.write_all`.
    fn drain_poll(&mut self) -> Result<(), W::Error> {
        loop {
            let (produced, drained) = match self.enc.poll(&mut self.poll_buf) {
                Ok(Poll::More(n)) => (n, false),
                Ok(Poll::Empty(n)) => (n, true),
                Err(_) => return Ok(()),
            };
            // `More` chunks are always forwarded; a final `Empty` chunk only
            // when it actually carries bytes.
            if !drained || produced > 0 {
                self.inner.write_all(&self.poll_buf[..produced])?;
            }
            if drained {
                return Ok(());
            }
        }
    }

    /// Finishes the encoded stream: while the encoder answers `Finish::More`
    /// its output is drained to `inner`; once it answers `Finish::Done`,
    /// `inner` is flushed.
    ///
    /// # Errors
    /// Propagates any error from writing to or flushing `inner`.
    pub fn finish(&mut self) -> Result<(), W::Error> {
        while let Finish::More = self.enc.finish() {
            self.drain_poll()?;
        }
        self.inner.flush()
    }
}
impl<W, ENC> Write for EncoderWriter<W, ENC>
where
    W: Write,
    ENC: EncoderOps,
{
    /// Offers `buf` to the encoder, draining encoder output to `inner`
    /// whenever the encoder reports `SinkError::Full` and once more at the end.
    ///
    /// Returns the number of input bytes the encoder accepted. Feeding stops
    /// early if the encoder reports `SinkError::Misuse`.
    ///
    /// # Errors
    /// Propagates any error from writing encoder output to `inner`.
    fn write(&mut self, buf: &[u8]) -> Result<usize, W::Error> {
        let mut accepted = 0;
        while accepted < buf.len() {
            let pending = &buf[accepted..];
            match self.enc.sink(pending) {
                Ok(taken) => accepted += taken,
                Err(SinkError::Misuse) => break,
                // Make room by pushing pending output downstream, then retry.
                Err(SinkError::Full) => self.drain_poll()?,
            }
        }
        self.drain_poll()?;
        Ok(accepted)
    }

    /// Flushes the wrapped writer. Encoder output is already drained by every
    /// `write`, so only `inner` is flushed here.
    fn flush(&mut self) -> Result<(), W::Error> {
        self.inner.flush()
    }
}
/// [`Write`] adapter that runs everything written to it through a decoder.
///
/// Bytes handed to this writer are offered to the decoder `DEC`; whatever the
/// decoder produces is forwarded to the wrapped writer `inner`.
pub struct DecoderWriter<W, DEC> {
    /// Destination that receives the decoder's output.
    inner: W,
    /// Decoder state, created with `DEC::default()` by [`DecoderWriter::new`].
    dec: DEC,
    /// Staging buffer the decoder is polled into before its bytes go to `inner`.
    poll_buf: [u8; 256],
}

impl<W, DEC> DecoderWriter<W, DEC>
where
    DEC: Default,
{
    /// Wraps `inner` together with a default-constructed decoder and a zeroed
    /// 256-byte staging buffer.
    pub fn new(inner: W) -> Self {
        let dec = DEC::default();
        let poll_buf = [0; 256];
        Self {
            inner,
            dec,
            poll_buf,
        }
    }
}
/// The writer reports exactly the error type of the wrapped writer `W`.
impl<W, DEC> ErrorType for DecoderWriter<W, DEC>
where
W: ErrorType,
{
// Only `W`'s errors can be surfaced; decoder-side failures have no
// representation in this type.
type Error = W::Error;
}
impl<W, DEC> DecoderWriter<W, DEC>
where
    W: Write,
    DEC: DecoderOps,
{
    /// Polls the decoder repeatedly and forwards each chunk to `inner` until
    /// the decoder answers `Poll::Empty`.
    ///
    /// A poll error stops the drain and is reported as success, since it
    /// cannot be expressed as `W::Error`.
    ///
    /// # Errors
    /// Propagates any error from `inner.write_all`.
    fn drain_poll(&mut self) -> Result<(), W::Error> {
        loop {
            let (produced, drained) = match self.dec.poll(&mut self.poll_buf) {
                Ok(Poll::More(n)) => (n, false),
                Ok(Poll::Empty(n)) => (n, true),
                Err(_) => return Ok(()),
            };
            // `More` chunks are always forwarded; a final `Empty` chunk only
            // when it actually carries bytes.
            if !drained || produced > 0 {
                self.inner.write_all(&self.poll_buf[..produced])?;
            }
            if drained {
                return Ok(());
            }
        }
    }

    /// Drains any remaining decoder output to `inner`, then flushes `inner`.
    ///
    /// # Errors
    /// Propagates any error from writing to or flushing `inner`.
    pub fn finish(&mut self) -> Result<(), W::Error> {
        self.drain_poll().and_then(|()| self.inner.flush())
    }
}
impl<W, DEC> Write for DecoderWriter<W, DEC>
where
    W: Write,
    DEC: DecoderOps,
{
    /// Offers `buf` to the decoder, draining decoder output to `inner`
    /// whenever the decoder reports `SinkError::Full` and once more at the end.
    ///
    /// Returns the number of input bytes the decoder accepted. Feeding stops
    /// early if the decoder reports `SinkError::Misuse`.
    ///
    /// # Errors
    /// Propagates any error from writing decoder output to `inner`.
    fn write(&mut self, buf: &[u8]) -> Result<usize, W::Error> {
        let mut accepted = 0;
        while accepted < buf.len() {
            let pending = &buf[accepted..];
            match self.dec.sink(pending) {
                Ok(taken) => accepted += taken,
                Err(SinkError::Misuse) => break,
                // Make room by pushing pending output downstream, then retry.
                Err(SinkError::Full) => self.drain_poll()?,
            }
        }
        self.drain_poll()?;
        Ok(accepted)
    }

    /// Flushes the wrapped writer. Decoder output is already drained by every
    /// `write`, so only `inner` is flushed here.
    fn flush(&mut self) -> Result<(), W::Error> {
        self.inner.flush()
    }
}
/// [`Read`] adapter that yields the decoded form of the bytes read from `inner`.
///
/// `OUT` is the size of the internal output staging buffer (256 by default).
pub struct DecoderReader<R, DEC, const OUT: usize = 256> {
    /// Source of encoded input.
    inner: R,
    /// Decoder state, created with `DEC::default()` by [`DecoderReader::new`].
    dec: DEC,
    /// Staging buffer the decoder is polled into.
    out_buf: [u8; OUT],
    /// Start of the not-yet-delivered bytes in `out_buf`.
    out_start: usize,
    /// End (exclusive) of the not-yet-delivered bytes in `out_buf`.
    out_end: usize,
    /// Input read from `inner` that is waiting to be given to the decoder.
    in_buf: [u8; 64],
    /// Start of the bytes in `in_buf` the decoder has not accepted yet.
    in_start: usize,
    /// End (exclusive) of the valid bytes in `in_buf`.
    in_end: usize,
    /// Set once `inner.read` has returned 0.
    eof: bool,
}

impl<R, DEC, const OUT: usize> DecoderReader<R, DEC, OUT>
where
    DEC: Default,
{
    /// Wraps `inner` together with a default-constructed decoder; both
    /// buffers start out zeroed and empty.
    pub fn new(inner: R) -> Self {
        let dec = DEC::default();
        Self {
            inner,
            dec,
            out_buf: [0; OUT],
            out_start: 0,
            out_end: 0,
            in_buf: [0; 64],
            in_start: 0,
            in_end: 0,
            eof: false,
        }
    }
}
/// The reader reports exactly the error type of the wrapped reader `R`.
impl<R, DEC, const OUT: usize> ErrorType for DecoderReader<R, DEC, OUT>
where
R: ErrorType,
{
// Only `R`'s errors can be surfaced; decoder-side failures have no
// representation in this type.
type Error = R::Error;
}
impl<R, DEC, const OUT: usize> Read for DecoderReader<R, DEC, OUT>
where
    R: Read,
    DEC: DecoderOps,
{
    /// Reads decoded bytes into `buf`, pulling encoded input from `inner` as
    /// needed.
    ///
    /// Returns the number of bytes stored in `buf`. A zero-length `buf`
    /// returns `Ok(0)` immediately without touching `inner` or the decoder.
    /// For a non-empty `buf`, `Ok(0)` is returned once `inner` has reported
    /// end of input and the decoder has nothing left to emit. `Ok(0)` is also
    /// returned when the decoder reports `SinkError::Misuse` or a poll error,
    /// because those conditions cannot be expressed as `R::Error`.
    ///
    /// # Errors
    /// Propagates any error returned by `inner.read`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, R::Error> {
        // A zero-length read must not block on `inner` or advance the decoder.
        if buf.is_empty() {
            return Ok(0);
        }

        // Hand out bytes left over from a previous poll before decoding more.
        if self.out_start < self.out_end {
            let n = (self.out_end - self.out_start).min(buf.len());
            buf[..n].copy_from_slice(&self.out_buf[self.out_start..self.out_start + n]);
            self.out_start += n;
            if self.out_start == self.out_end {
                self.out_start = 0;
                self.out_end = 0;
            }
            return Ok(n);
        }

        loop {
            // Feed buffered input to the decoder until it is all accepted or
            // the decoder reports that it is full.
            while self.in_start < self.in_end {
                match self.dec.sink(&self.in_buf[self.in_start..self.in_end]) {
                    Ok(n) => self.in_start += n,
                    Err(SinkError::Full) => break,
                    Err(SinkError::Misuse) => return Ok(0),
                }
            }
            if self.in_start == self.in_end {
                self.in_start = 0;
                self.in_end = 0;
            }

            match self.dec.poll(&mut self.out_buf) {
                Ok(Poll::More(n)) | Ok(Poll::Empty(n)) if n > 0 => {
                    let give = n.min(buf.len());
                    buf[..give].copy_from_slice(&self.out_buf[..give]);
                    // Keep whatever did not fit for the next call.
                    if n > give {
                        self.out_start = give;
                        self.out_end = n;
                    }
                    return Ok(give);
                }
                Ok(Poll::More(_)) => {}
                Ok(Poll::Empty(_)) => {
                    // No output, no buffered input, and the source is
                    // exhausted: the stream is over.
                    if self.eof && self.in_start == self.in_end {
                        return Ok(0);
                    }
                }
                Err(_) => return Ok(0),
            }

            // Only read from `inner` when all buffered input has been consumed
            // and the source has not already reported end of input.
            if self.in_start < self.in_end || self.eof {
                continue;
            }
            let n_read = self.inner.read(&mut self.in_buf)?;
            if n_read == 0 {
                self.eof = true;
                continue;
            }
            self.in_start = 0;
            self.in_end = n_read;
        }
    }
}
/// [`Read`] adapter that yields the encoded form of the bytes read from `inner`.
///
/// `OUT` is the size of the internal output staging buffer (256 by default).
pub struct EncoderReader<R, ENC, const OUT: usize = 256> {
    /// Source of unencoded input.
    inner: R,
    /// Encoder state, created with `ENC::default()` by [`EncoderReader::new`].
    enc: ENC,
    /// Staging buffer the encoder is polled into.
    out_buf: [u8; OUT],
    /// Start of the not-yet-delivered bytes in `out_buf`.
    out_start: usize,
    /// End (exclusive) of the not-yet-delivered bytes in `out_buf`.
    out_end: usize,
    /// Input read from `inner` that is waiting to be given to the encoder.
    in_buf: [u8; 64],
    /// Start of the bytes in `in_buf` the encoder has not accepted yet.
    in_start: usize,
    /// End (exclusive) of the valid bytes in `in_buf`.
    in_end: usize,
    /// Set once `inner.read` has returned 0.
    eof: bool,
    /// Set once all input has been consumed after end of input; from then on
    /// the encoder's `finish` is consulted whenever a poll yields nothing.
    finishing: bool,
}

impl<R, ENC, const OUT: usize> EncoderReader<R, ENC, OUT>
where
    ENC: Default,
{
    /// Wraps `inner` together with a default-constructed encoder; both
    /// buffers start out zeroed and empty.
    pub fn new(inner: R) -> Self {
        let enc = ENC::default();
        Self {
            inner,
            enc,
            out_buf: [0; OUT],
            out_start: 0,
            out_end: 0,
            in_buf: [0; 64],
            in_start: 0,
            in_end: 0,
            eof: false,
            finishing: false,
        }
    }
}
/// The reader reports exactly the error type of the wrapped reader `R`.
impl<R, ENC, const OUT: usize> ErrorType for EncoderReader<R, ENC, OUT>
where
R: ErrorType,
{
// Only `R`'s errors can be surfaced; encoder-side failures have no
// representation in this type.
type Error = R::Error;
}
impl<R, ENC, const OUT: usize> Read for EncoderReader<R, ENC, OUT>
where
    R: Read,
    ENC: EncoderOps,
{
    /// Reads encoded bytes into `buf`, pulling unencoded input from `inner`
    /// as needed.
    ///
    /// Returns the number of bytes stored in `buf`. A zero-length `buf`
    /// returns `Ok(0)` immediately without touching `inner` or the encoder.
    /// For a non-empty `buf`, `Ok(0)` is returned once `inner` has reported
    /// end of input and the encoder answers `Finish::Done`. `Ok(0)` is also
    /// returned when the encoder reports `SinkError::Misuse` or a poll error,
    /// because those conditions cannot be expressed as `R::Error`.
    ///
    /// # Errors
    /// Propagates any error returned by `inner.read`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, R::Error> {
        // A zero-length read must not block on `inner` or advance the encoder.
        if buf.is_empty() {
            return Ok(0);
        }

        // Hand out bytes left over from a previous poll before encoding more.
        if self.out_start < self.out_end {
            let n = (self.out_end - self.out_start).min(buf.len());
            buf[..n].copy_from_slice(&self.out_buf[self.out_start..self.out_start + n]);
            self.out_start += n;
            if self.out_start == self.out_end {
                self.out_start = 0;
                self.out_end = 0;
            }
            return Ok(n);
        }

        loop {
            // Feed buffered input to the encoder until it is all accepted or
            // the encoder reports that it is full.
            while self.in_start < self.in_end {
                match self.enc.sink(&self.in_buf[self.in_start..self.in_end]) {
                    Ok(n) => self.in_start += n,
                    Err(SinkError::Full) => break,
                    Err(SinkError::Misuse) => return Ok(0),
                }
            }
            if self.in_start == self.in_end {
                self.in_start = 0;
                self.in_end = 0;
            }

            match self.enc.poll(&mut self.out_buf) {
                Ok(Poll::More(n)) | Ok(Poll::Empty(n)) if n > 0 => {
                    let give = n.min(buf.len());
                    buf[..give].copy_from_slice(&self.out_buf[..give]);
                    // Keep whatever did not fit for the next call.
                    if n > give {
                        self.out_start = give;
                        self.out_end = n;
                    }
                    return Ok(give);
                }
                Ok(Poll::More(_)) => {}
                Ok(Poll::Empty(_)) => {
                    // While finishing, an empty poll means: ask the encoder
                    // whether it is done or has more to emit.
                    if self.finishing {
                        match self.enc.finish() {
                            Finish::Done => return Ok(0),
                            Finish::More => continue,
                        }
                    }
                    // Source exhausted and all input consumed: switch to the
                    // finishing phase and poll again.
                    if self.eof && self.in_start == self.in_end {
                        self.finishing = true;
                        continue;
                    }
                }
                Err(_) => return Ok(0),
            }

            if self.in_start < self.in_end {
                continue;
            }
            // Never read from `inner` again once it has reported end of input.
            if self.eof || self.finishing {
                continue;
            }
            let n_read = self.inner.read(&mut self.in_buf)?;
            if n_read == 0 {
                self.eof = true;
                continue;
            }
            self.in_start = 0;
            self.in_end = n_read;
        }
    }
}
// Private module holding the `Sealed` marker trait. The module itself is not
// `pub`, so `Sealed` cannot be named from outside this file's module; that
// restricts who can implement the traits that list it as a supertrait.
mod private {
/// Marker supertrait required by `EncoderOps` and `DecoderOps`.
pub trait Sealed {}
}
/// Operations the encoder adapters in this file need from an encoder.
///
/// The trait is sealed through `private::Sealed`; its methods are hidden from
/// the rendered documentation.
pub trait EncoderOps: private::Sealed {
/// Offers `input` to the encoder. The adapters treat `Ok(n)` as "the first
/// `n` bytes were accepted" and respond to `SinkError::Full` by polling
/// output before offering more.
#[doc(hidden)]
fn sink(&mut self, input: &[u8]) -> Result<usize, SinkError>;
/// Writes pending output into `output`. The adapters treat the count carried
/// by both `Poll::More` and `Poll::Empty` as the number of bytes written to
/// the front of `output`.
#[doc(hidden)]
fn poll(&mut self, output: &mut [u8]) -> Result<Poll, crate::PollError>;
/// Signals that no more input will follow. The adapters poll again on
/// `Finish::More` and stop on `Finish::Done`.
#[doc(hidden)]
fn finish(&mut self) -> Finish;
}
/// Operations the decoder adapters in this file need from a decoder.
///
/// The trait is sealed through `private::Sealed`; its methods are hidden from
/// the rendered documentation.
pub trait DecoderOps: private::Sealed {
/// Offers `input` to the decoder. The adapters treat `Ok(n)` as "the first
/// `n` bytes were accepted" and respond to `SinkError::Full` by polling
/// output before offering more.
#[doc(hidden)]
fn sink(&mut self, input: &[u8]) -> Result<usize, SinkError>;
/// Writes pending output into `output`. The adapters treat the count carried
/// by both `Poll::More` and `Poll::Empty` as the number of bytes written to
/// the front of `output`.
#[doc(hidden)]
fn poll(&mut self, output: &mut [u8]) -> Result<Poll, crate::PollError>;
/// Reports the decoder's finish state without mutating it (`&self`).
#[doc(hidden)]
// `dead_code` is allowed because none of the adapters visible in this file
// call this method.
#[allow(dead_code)]
fn finish(&self) -> Finish;
}
// Marks `HeatshrinkEncoder` as `Sealed`, which `EncoderOps` requires of its
// implementors.
impl<const W: usize, const L: usize, const BUF: usize> private::Sealed
for crate::encoder::HeatshrinkEncoder<W, L, BUF>
{
}
/// Implements [`EncoderOps`] for `HeatshrinkEncoder` by forwarding each method
/// to the `HeatshrinkEncoder` function of the same name.
impl<const W: usize, const L: usize, const BUF: usize> EncoderOps
for crate::encoder::HeatshrinkEncoder<W, L, BUF>
{
// NOTE(review): the fully-qualified calls below presumably resolve to inherent
// methods of `HeatshrinkEncoder` (defined outside this file) — confirm; if no
// inherent method existed they would resolve back to this trait and recurse.
#[inline]
fn sink(&mut self, input: &[u8]) -> Result<usize, SinkError> {
crate::encoder::HeatshrinkEncoder::sink(self, input)
}
#[inline]
fn poll(&mut self, output: &mut [u8]) -> Result<Poll, crate::PollError> {
crate::encoder::HeatshrinkEncoder::poll(self, output)
}
#[inline]
fn finish(&mut self) -> Finish {
crate::encoder::HeatshrinkEncoder::finish(self)
}
}
// Marks `HeatshrinkDecoder` as `Sealed`, which `DecoderOps` requires of its
// implementors.
impl<const W: usize, const L: usize, const I: usize, const WIN: usize> private::Sealed
for crate::decoder::HeatshrinkDecoder<W, L, I, WIN>
{
}
/// Implements [`DecoderOps`] for `HeatshrinkDecoder` by forwarding each method
/// to the `HeatshrinkDecoder` function of the same name.
impl<const W: usize, const L: usize, const I: usize, const WIN: usize> DecoderOps
for crate::decoder::HeatshrinkDecoder<W, L, I, WIN>
{
// NOTE(review): the fully-qualified calls below presumably resolve to inherent
// methods of `HeatshrinkDecoder` (defined outside this file) — confirm; if no
// inherent method existed they would resolve back to this trait and recurse.
#[inline]
fn sink(&mut self, input: &[u8]) -> Result<usize, SinkError> {
crate::decoder::HeatshrinkDecoder::sink(self, input)
}
#[inline]
fn poll(&mut self, output: &mut [u8]) -> Result<Poll, crate::PollError> {
crate::decoder::HeatshrinkDecoder::poll(self, output)
}
#[inline]
fn finish(&self) -> Finish {
crate::decoder::HeatshrinkDecoder::finish(self)
}
}
/// Writer that stores bytes into a caller-provided mutable slice.
pub struct SliceSink<'a> {
    /// Backing storage.
    buf: &'a mut [u8],
    /// Number of bytes stored so far; the next write lands at this offset.
    pos: usize,
}

impl<'a> SliceSink<'a> {
    /// Creates a sink that fills `buf` starting at its first byte.
    pub fn new(buf: &'a mut [u8]) -> Self {
        let pos = 0;
        Self { buf, pos }
    }

    /// Number of bytes stored so far.
    pub fn len(&self) -> usize {
        self.pos
    }

    /// `true` while nothing has been stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The stored bytes, i.e. the first `len()` bytes of the backing slice.
    pub fn written(&self) -> &[u8] {
        let (filled, _unused) = self.buf.split_at(self.pos);
        filled
    }
}
/// Error type of [`SliceSink`]; a thin wrapper around the [`ErrorKind`] it
/// reports through [`Error::kind`].
#[derive(Debug)]
pub struct SliceSinkError(ErrorKind);
impl Error for SliceSinkError {
    /// Returns the [`ErrorKind`] wrapped by this error.
    fn kind(&self) -> ErrorKind {
        let Self(kind) = self;
        *kind
    }
}
/// [`SliceSink`] reports its failures as [`SliceSinkError`].
impl ErrorType for SliceSink<'_> {
type Error = SliceSinkError;
}
impl Write for SliceSink<'_> {
    /// Copies as much of `buf` as fits into the remaining space and returns
    /// the number of bytes stored.
    ///
    /// An empty `buf` always yields `Ok(0)`, even when the sink is full:
    /// nothing needed storing, so there is nothing to fail.
    ///
    /// # Errors
    /// Returns a [`SliceSinkError`] of kind `ErrorKind::OutOfMemory` when
    /// `buf` is non-empty and the backing slice has no space left.
    fn write(&mut self, buf: &[u8]) -> Result<usize, SliceSinkError> {
        if buf.is_empty() {
            return Ok(0);
        }
        let space = self.buf.len() - self.pos;
        if space == 0 {
            return Err(SliceSinkError(ErrorKind::OutOfMemory));
        }
        let n = buf.len().min(space);
        self.buf[self.pos..self.pos + n].copy_from_slice(&buf[..n]);
        self.pos += n;
        Ok(n)
    }

    /// Nothing is buffered outside the backing slice, so flushing is a no-op.
    fn flush(&mut self) -> Result<(), SliceSinkError> {
        Ok(())
    }
}
/// Reader that hands out the bytes of a borrowed slice from front to back.
pub struct SliceSource<'a> {
    /// Data to read from.
    buf: &'a [u8],
    /// Offset of the next unread byte in `buf`.
    pos: usize,
}

impl<'a> SliceSource<'a> {
    /// Creates a source positioned at the first byte of `buf`.
    pub fn new(buf: &'a [u8]) -> Self {
        let pos = 0;
        Self { buf, pos }
    }
}
/// Error type of [`SliceSource`]; a thin wrapper around the [`ErrorKind`] it
/// reports through [`Error::kind`]. The `Read` implementation visible in this
/// file always returns `Ok`, so it never constructs this error.
#[derive(Debug)]
pub struct SliceSourceError(ErrorKind);
impl Error for SliceSourceError {
    /// Returns the [`ErrorKind`] wrapped by this error.
    fn kind(&self) -> ErrorKind {
        let Self(kind) = self;
        *kind
    }
}
/// [`SliceSource`] declares [`SliceSourceError`] as its error type.
impl ErrorType for SliceSource<'_> {
type Error = SliceSourceError;
}
impl Read for SliceSource<'_> {
    /// Copies the next unread bytes into `buf` and advances past them.
    ///
    /// Returns how many bytes were copied: the smaller of `buf.len()` and the
    /// number of unread bytes, so `Ok(0)` once the slice is exhausted. This
    /// implementation never returns an error.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, SliceSourceError> {
        let unread = &self.buf[self.pos..];
        let count = unread.len().min(buf.len());
        let (dst, _) = buf.split_at_mut(count);
        dst.copy_from_slice(&unread[..count]);
        self.pos += count;
        Ok(count)
    }
}
/// Round-trip tests for the `embedded-io` adapters, covering every pairing of
/// writer-style and reader-style encoder/decoder wrappers.
#[cfg(all(test, feature = "embedded-io"))]
mod test_io {
    use super::{
        DecoderReader, DecoderWriter, EncoderReader, EncoderWriter, SliceSink, SliceSource,
    };
    use crate::{DefaultDecoder, DefaultEncoder};
    use embedded_io::{Read, Write as _};

    /// Reads from `reader` into `dst` until a read returns 0 and returns the
    /// total number of bytes stored.
    fn read_to_end<R: Read>(reader: &mut R, dst: &mut [u8]) -> usize {
        let mut total = 0;
        loop {
            let n = reader.read(&mut dst[total..]).unwrap();
            if n == 0 {
                return total;
            }
            total += n;
        }
    }

    /// Compresses `src` into `dst` through an `EncoderWriter`; returns the
    /// compressed length.
    fn ew_compress(src: &[u8], dst: &mut [u8]) -> usize {
        let mut sink = SliceSink::new(dst);
        let mut w: EncoderWriter<_, DefaultEncoder> = EncoderWriter::new(&mut sink);
        w.write_all(src).unwrap();
        w.finish().unwrap();
        sink.len()
    }

    /// Decompresses `src` into `dst` through a `DecoderWriter`; returns the
    /// decompressed length.
    fn dw_decompress(src: &[u8], dst: &mut [u8]) -> usize {
        let mut sink = SliceSink::new(dst);
        let mut w: DecoderWriter<_, DefaultDecoder> = DecoderWriter::new(&mut sink);
        w.write_all(src).unwrap();
        w.finish().unwrap();
        sink.len()
    }

    /// Decompresses `src` into `dst` through a `DecoderReader`; returns the
    /// decompressed length.
    fn dr_decompress(src: &[u8], dst: &mut [u8]) -> usize {
        let source = SliceSource::new(src);
        let mut r: DecoderReader<_, DefaultDecoder> = DecoderReader::new(source);
        read_to_end(&mut r, dst)
    }

    /// Compresses `src` into `dst` through an `EncoderReader`; returns the
    /// compressed length.
    fn er_compress(src: &[u8], dst: &mut [u8]) -> usize {
        let source = SliceSource::new(src);
        let mut r: EncoderReader<_, DefaultEncoder> = EncoderReader::new(source);
        read_to_end(&mut r, dst)
    }

    #[test]
    fn writer_roundtrip_short() {
        let src = b"hello heatshrink embedded-io";
        let mut compressed = [0u8; 256];
        let mut decompressed = [0u8; 256];
        let n_enc = ew_compress(src, &mut compressed);
        let n_dec = dw_decompress(&compressed[..n_enc], &mut decompressed);
        assert_eq!(src, &decompressed[..n_dec]);
    }

    #[test]
    fn writer_roundtrip_repetitive() {
        let src: [u8; 4096] = core::array::from_fn(|i| (i % 251) as u8);
        let mut compressed = [0u8; 4096];
        let mut decompressed = [0u8; 4096];
        let n_enc = ew_compress(&src, &mut compressed);
        let n_dec = dw_decompress(&compressed[..n_enc], &mut decompressed);
        assert_eq!(&src[..], &decompressed[..n_dec]);
    }

    #[test]
    fn writer_roundtrip_random() {
        // The generator runs in `u64` so the 64-bit constants and the shift by
        // 56 are valid on every host, including 32-bit ones where `usize` is
        // too narrow for them.
        let src: [u8; 4096] = core::array::from_fn(|i| {
            ((i as u64)
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407)
                >> 56) as u8
        });
        let mut compressed = [0u8; 8192];
        let mut decompressed = [0u8; 8192];
        let n_enc = ew_compress(&src, &mut compressed);
        let n_dec = dw_decompress(&compressed[..n_enc], &mut decompressed);
        assert_eq!(&src[..], &decompressed[..n_dec]);
    }

    #[test]
    fn writer_reader_roundtrip() {
        let src = b"the quick brown fox jumps over the lazy dog";
        let mut compressed = [0u8; 256];
        let mut decompressed = [0u8; 256];
        let n_enc = ew_compress(src, &mut compressed);
        let n_dec = dr_decompress(&compressed[..n_enc], &mut decompressed);
        assert_eq!(src, &decompressed[..n_dec]);
    }

    #[test]
    fn reader_writer_roundtrip() {
        let src = b"the quick brown fox jumps over the lazy dog";
        let mut compressed = [0u8; 256];
        let mut decompressed = [0u8; 256];
        let n_enc = er_compress(src, &mut compressed);
        let n_dec = dw_decompress(&compressed[..n_enc], &mut decompressed);
        assert_eq!(src, &decompressed[..n_dec]);
    }

    #[test]
    fn reader_reader_roundtrip() {
        let src: [u8; 4096] = core::array::from_fn(|i| (i % 251) as u8);
        let mut compressed = [0u8; 4096];
        let mut decompressed = [0u8; 4096];
        let n_enc = er_compress(&src, &mut compressed);
        let n_dec = dr_decompress(&compressed[..n_enc], &mut decompressed);
        assert_eq!(&src[..], &decompressed[..n_dec]);
    }

    /// Drives both readers with caller buffers far smaller than their internal
    /// staging buffers (1 byte for the encoder, 3 bytes for the decoder).
    #[test]
    fn reader_small_output_buf() {
        let src: [u8; 512] = core::array::from_fn(|i| (i % 137) as u8);
        let mut compressed = [0u8; 1024];
        let mut decompressed = [0u8; 1024];

        let source = SliceSource::new(&src);
        let mut r: EncoderReader<_, DefaultEncoder> = EncoderReader::new(source);
        let mut n_enc = 0;
        loop {
            let n = r.read(&mut compressed[n_enc..n_enc + 1]).unwrap();
            if n == 0 {
                break;
            }
            n_enc += n;
        }

        let source2 = SliceSource::new(&compressed[..n_enc]);
        let mut r2: DecoderReader<_, DefaultDecoder> = DecoderReader::new(source2);
        let mut n_dec = 0;
        let mut tmp = [0u8; 3];
        loop {
            let n = r2.read(&mut tmp).unwrap();
            if n == 0 {
                break;
            }
            decompressed[n_dec..n_dec + n].copy_from_slice(&tmp[..n]);
            n_dec += n;
        }
        assert_eq!(&src[..], &decompressed[..n_dec]);
    }

    #[test]
    fn writer_roundtrip_empty() {
        let src: &[u8] = &[];
        let mut compressed = [0u8; 64];
        let mut decompressed = [0u8; 64];
        let n_enc = ew_compress(src, &mut compressed);
        let n_dec = dw_decompress(&compressed[..n_enc], &mut decompressed);
        assert_eq!(src, &decompressed[..n_dec]);
    }
}