oozextract 0.5.4

Open source Kraken / Mermaid / Selkie / Leviathan / LZNA / Bitknit decompressor
use crate::ooz::error::{Res, ResultBuilder};
#[cfg(feature = "async")]
use futures::{Stream, StreamExt};
use std::io::Read;
#[cfg(feature = "tokio")]
use tokio::io::{AsyncRead, AsyncReadExt};

/// A source of compressed bytes that the decompressor can pull from.
///
/// `S` is the buffer type handed back by [`Input::read_slice`]; it only has
/// to be viewable as a byte slice.
pub trait Input<S: AsRef<[u8]>> {
    /// Fills the whole of `buf` with the next `buf.len()` bytes of input.
    async fn read_to(&mut self, buf: &mut [u8]) -> Res<()>;

    /// Returns the next `len` bytes of input as an `S`.
    ///
    /// `buf` is scratch storage an implementation may carve its result out
    /// of; implementations that can hand out their own data may ignore it.
    async fn read_slice(&mut self, buf: &mut bytes::BytesMut, len: usize) -> Res<S>;

    /// Reads `to_read` bytes into the tail of a zero-initialised `[u8; N]`.
    ///
    /// The leading `N - to_read` bytes of the returned array stay zero.
    ///
    /// # Panics
    ///
    /// Panics if `to_read` is greater than `N`.
    async fn read_array<const N: usize>(&mut self, to_read: usize) -> Res<[u8; N]> {
        let mut array = [0u8; N];
        let start = N - to_read;
        let tail = &mut array[start..];
        self.read_to(tail).await?;
        Ok(array)
    }
}

/// Lets any blocking [`Read`] act as an [`Input`] that yields `BytesMut`.
///
/// The methods are `async` only because the trait requires it; the reads
/// themselves go through the synchronous `read_exact`.
impl<R: Read> Input<bytes::BytesMut> for R {
    /// Fills `buf` completely from the reader.
    async fn read_to(&mut self, buf: &mut [u8]) -> Res<()> {
        self.read_exact(buf)?;
        Ok(())
    }

    /// Splits the first `len` bytes off `buf`, fills them from the reader and
    /// returns them. `buf` is zero-extended first when it is shorter than
    /// `len`.
    async fn read_slice(&mut self, buf: &mut bytes::BytesMut, len: usize) -> Res<bytes::BytesMut> {
        let available = buf.len();
        if len > available {
            buf.resize(len, 0);
        }
        let mut chunk = buf.split_to(len);
        self.read_exact(&mut chunk[..])?;
        Ok(chunk)
    }
}

/// Read cursor over a borrowed, in-memory byte buffer.
pub(crate) struct Slice<'a> {
    /// The complete input being read from.
    pub(crate) buf: &'a [u8],
    /// Offset into `buf` of the next unread byte.
    pub(crate) pos: usize,
}

/// In-memory [`Input`]: reads are served straight out of the borrowed buffer,
/// so `read_slice` hands back sub-slices of it without copying.
impl<'a> Input<&'a [u8]> for Slice<'a> {
    /// Copies the next `buf.len()` bytes into `buf` and advances the cursor.
    ///
    /// # Errors
    ///
    /// Fails, leaving the cursor untouched, when fewer than `buf.len()` bytes
    /// remain.
    async fn read_to(&mut self, buf: &mut [u8]) -> Res<()> {
        let data: &'a [u8] = self.buf;
        let start = self.pos;
        // `checked_add` turns an overflowing end offset into the ordinary
        // out-of-range error instead of an arithmetic-overflow panic.
        let src = start
            .checked_add(buf.len())
            .and_then(|end| data.get(start..end))
            .msg_of(&(start, data, buf.len()))?;
        buf.copy_from_slice(src);
        self.pos = start + src.len();
        Ok(())
    }

    /// Returns the next `len` bytes as a sub-slice of the underlying buffer
    /// and advances the cursor. The scratch buffer argument is unused.
    ///
    /// # Errors
    ///
    /// Fails, leaving the cursor untouched, when fewer than `len` bytes
    /// remain.
    async fn read_slice(&mut self, _: &mut bytes::BytesMut, len: usize) -> Res<&'a [u8]> {
        let data: &'a [u8] = self.buf;
        let start = self.pos;
        // See `read_to`: an overflowing `start + len` must be reported as an
        // error, not trigger a panic.
        let slice = start
            .checked_add(len)
            .and_then(|end| data.get(start..end))
            .msg_of(&(start, data, len))?;
        self.pos = start + slice.len();
        Ok(slice)
    }
}

/// Newtype around an [`AsyncRead`] so it can be given its own `Input`
/// implementation.
#[cfg(feature = "tokio")]
pub(crate) struct Async<R: AsyncRead + Unpin>(pub(crate) R);

/// [`Input`] backed by an asynchronous reader; yields `BytesMut`.
#[cfg(feature = "tokio")]
impl<R: AsyncRead + Unpin> Input<bytes::BytesMut> for Async<R> {
    /// Fills `buf` completely from the wrapped reader.
    async fn read_to(&mut self, buf: &mut [u8]) -> Res<()> {
        // The byte count returned by `read_exact` carries no extra information.
        let _filled = self.0.read_exact(buf).await?;
        Ok(())
    }

    /// Splits the first `len` bytes off `buf`, fills them from the wrapped
    /// reader and returns them. `buf` is zero-extended first when it is
    /// shorter than `len`.
    async fn read_slice(&mut self, buf: &mut bytes::BytesMut, len: usize) -> Res<bytes::BytesMut> {
        let available = buf.len();
        if len > available {
            buf.resize(len, 0);
        }
        let mut chunk = buf.split_to(len);
        let _filled = self.0.read_exact(&mut chunk[..]).await?;
        Ok(chunk)
    }
}

/// Read state over a stream of `Bytes` chunks.
#[cfg(feature = "async")]
pub(crate) struct ByteStream<
    E: 'static + std::error::Error + Send + Sync,
    S: Stream<Item = Result<bytes::Bytes, E>>,
> {
    /// Chunk currently being consumed, if any; `None` means the next chunk
    /// still has to be pulled from `stream`.
    pub(crate) current: Option<bytes::Bytes>,
    /// Source of further chunks.
    pub(crate) stream: S,
}

/// [`Input`] backed by a stream of `Bytes` chunks.
///
/// Requests that fit inside the chunk currently held are served by splitting
/// that chunk; larger requests are assembled across chunk boundaries.
#[cfg(feature = "async")]
impl<
        E: 'static + std::error::Error + Send + Sync,
        S: Stream<Item = Result<bytes::Bytes, E>> + Unpin,
    > Input<bytes::Bytes> for ByteStream<E, S>
{
    /// Fills `out` completely, pulling as many chunks from the stream as
    /// needed.
    ///
    /// An empty `out` succeeds immediately without touching the stream, so a
    /// zero-length read at the end of the stream is not an error.
    ///
    /// # Errors
    ///
    /// Fails if the stream yields an error or ends before `out` is full.
    async fn read_to(&mut self, mut out: &mut [u8]) -> Res<()> {
        use crate::ooz::error::ErrorBuilder;
        // Nothing requested: do not poll the stream, and in particular do not
        // report end-of-stream for a read that needs no data.
        if out.is_empty() {
            return Ok(());
        }
        if self.current.is_none() {
            self.current = ErrorBuilder::invert(self.stream.next().await)?
        }
        // Invariant: `out` is non-empty at the top of every iteration.
        while let Some(ref mut bytes) = self.current {
            if bytes.len() >= out.len() {
                // The current chunk covers the rest of the request.
                out.copy_from_slice(bytes.split_to(out.len()).as_ref());
                if bytes.is_empty() {
                    self.current = None;
                }
                return Ok(());
            } else {
                // Drain the current chunk into `out`, then fetch the next one.
                out[..bytes.len()].copy_from_slice(bytes);
                out = &mut out[bytes.len()..];
                self.current = ErrorBuilder::invert(self.stream.next().await)?;
            }
        }
        // The stream ended while bytes were still owed.
        Err(ErrorBuilder {
            #[cfg(feature = "verbose_errors")]
            message: Some("Unexpected end of stream".into()),
            ..Default::default()
        })?
    }

    /// Returns the next `len` bytes.
    ///
    /// When the current chunk holds at least `len` bytes they are split off
    /// it directly. Otherwise the bytes are gathered into the first `len`
    /// bytes of `buf` (zero-extended if shorter) and that part is frozen and
    /// returned. A zero-length request returns an empty `Bytes` without
    /// touching the stream.
    ///
    /// # Errors
    ///
    /// Fails if the stream yields an error or ends before `len` bytes are
    /// available.
    async fn read_slice(&mut self, buf: &mut bytes::BytesMut, len: usize) -> Res<bytes::Bytes> {
        use crate::ooz::error::ErrorBuilder;
        // Same rule as `read_to`: an empty request never consults the stream.
        if len == 0 {
            return Ok(bytes::Bytes::new());
        }
        if self.current.is_none() {
            self.current = ErrorBuilder::invert(self.stream.next().await)?;
        }
        if let Some(ref mut bytes) = self.current {
            if bytes.len() >= len {
                return Ok(bytes.split_to(len));
            }
        }

        // The request spans chunk boundaries: assemble it in scratch storage.
        if buf.len() < len {
            buf.resize(len, 0);
        }
        let mut slice = buf.split_to(len);
        self.read_to(slice.as_mut()).await?;
        Ok(slice.freeze())
    }
}