use std::io;
use super::{zstd_encoder, zstd_encoder_with_clevel};
use crate::{Compression, Result};
pub struct DynWriter<'a, W>(DynWriterImpl<'a, W>)
where
W: io::Write;
enum DynWriterImpl<'a, W>
where
W: io::Write,
{
Uncompressed(W),
Zstd(zstd::stream::AutoFinishEncoder<'a, W>),
}
impl<W> DynWriter<'_, W>
where
W: io::Write,
{
pub fn new(writer: W, compression: Compression) -> Result<Self> {
match compression {
Compression::None => Ok(Self(DynWriterImpl::Uncompressed(writer))),
Compression::Zstd => zstd_encoder(writer).map(|enc| Self(DynWriterImpl::Zstd(enc))),
}
}
pub fn with_compression_level(writer: W, level: i32) -> Result<Self> {
Ok(Self(DynWriterImpl::Zstd(zstd_encoder_with_clevel(
writer, level,
)?)))
}
pub fn get_mut(&mut self) -> &mut W {
match &mut self.0 {
DynWriterImpl::Uncompressed(w) => w,
DynWriterImpl::Zstd(enc) => enc.get_mut(),
}
}
}
impl<W> io::Write for DynWriter<'_, W>
where
W: io::Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match &mut self.0 {
DynWriterImpl::Uncompressed(writer) => writer.write(buf),
DynWriterImpl::Zstd(writer) => writer.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match &mut self.0 {
DynWriterImpl::Uncompressed(writer) => writer.flush(),
DynWriterImpl::Zstd(writer) => writer.flush(),
}
}
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
match &mut self.0 {
DynWriterImpl::Uncompressed(writer) => writer.write_vectored(bufs),
DynWriterImpl::Zstd(writer) => writer.write_vectored(bufs),
}
}
fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
match &mut self.0 {
DynWriterImpl::Uncompressed(writer) => writer.write_all(buf),
DynWriterImpl::Zstd(writer) => writer.write_all(buf),
}
}
fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> io::Result<()> {
match &mut self.0 {
DynWriterImpl::Uncompressed(writer) => writer.write_fmt(fmt),
DynWriterImpl::Zstd(writer) => writer.write_fmt(fmt),
}
}
}
#[cfg(feature = "async")]
pub use r#async::DynBufWriter as DynAsyncBufWriter;
#[cfg(feature = "async")]
pub use r#async::DynWriter as DynAsyncWriter;
#[cfg(feature = "async")]
mod r#async {
use std::{
pin::Pin,
task::{Context, Poll},
};
use async_compression::tokio::write::ZstdEncoder;
use tokio::io::{self, BufWriter};
use crate::{
encode::{async_zstd_encoder, async_zstd_encoder_with_clevel},
enums::Compression,
};
pub struct DynBufWriter<W, B = W>(DynBufWriterImpl<W, B>);
enum DynBufWriterImpl<W, B> {
Uncompressed(B),
Zstd(ZstdEncoder<W>),
}
impl<W> DynBufWriter<W, BufWriter<W>>
where
W: io::AsyncWriteExt + Unpin,
{
pub fn new(writer: W, compression: Compression) -> Self {
Self(match compression {
Compression::None => DynBufWriterImpl::Uncompressed(BufWriter::new(writer)),
Compression::Zstd => DynBufWriterImpl::Zstd(async_zstd_encoder(writer)),
})
}
pub fn with_compression_level(writer: W, level: i32) -> Self {
Self(DynBufWriterImpl::Zstd(async_zstd_encoder_with_clevel(
writer, level,
)))
}
}
impl<W> io::AsyncWrite for DynBufWriter<W>
where
W: io::AsyncWrite + io::AsyncWriteExt + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut self.0 {
DynBufWriterImpl::Uncompressed(w) => {
io::AsyncWrite::poll_write(Pin::new(w), cx, buf)
}
DynBufWriterImpl::Zstd(enc) => io::AsyncWrite::poll_write(Pin::new(enc), cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.0 {
DynBufWriterImpl::Uncompressed(w) => io::AsyncWrite::poll_flush(Pin::new(w), cx),
DynBufWriterImpl::Zstd(enc) => io::AsyncWrite::poll_flush(Pin::new(enc), cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.0 {
DynBufWriterImpl::Uncompressed(w) => io::AsyncWrite::poll_shutdown(Pin::new(w), cx),
DynBufWriterImpl::Zstd(enc) => io::AsyncWrite::poll_shutdown(Pin::new(enc), cx),
}
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
match &mut self.0 {
DynBufWriterImpl::Uncompressed(w) => {
io::AsyncWrite::poll_write_vectored(Pin::new(w), cx, bufs)
}
DynBufWriterImpl::Zstd(enc) => {
io::AsyncWrite::poll_write_vectored(Pin::new(enc), cx, bufs)
}
}
}
fn is_write_vectored(&self) -> bool {
match &self.0 {
DynBufWriterImpl::Uncompressed(w) => w.is_write_vectored(),
DynBufWriterImpl::Zstd(enc) => enc.is_write_vectored(),
}
}
}
pub struct DynWriter<W>(DynWriterImpl<W>)
where
W: io::AsyncWriteExt + Unpin;
enum DynWriterImpl<W>
where
W: io::AsyncWriteExt + Unpin,
{
Uncompressed(W),
Zstd(ZstdEncoder<W>),
}
impl<W> DynWriter<W>
where
W: io::AsyncWriteExt + Unpin,
{
pub fn new(writer: W, compression: Compression) -> Self {
Self(match compression {
Compression::None => DynWriterImpl::Uncompressed(writer),
Compression::Zstd => DynWriterImpl::Zstd(async_zstd_encoder(writer)),
})
}
pub fn with_compression_level(writer: W, level: i32) -> Self {
Self(DynWriterImpl::Zstd(async_zstd_encoder_with_clevel(
writer, level,
)))
}
pub fn get_mut(&mut self) -> &mut W {
match &mut self.0 {
DynWriterImpl::Uncompressed(w) => w,
DynWriterImpl::Zstd(enc) => enc.get_mut(),
}
}
}
impl<W> io::AsyncWrite for DynWriter<W>
where
W: io::AsyncWrite + io::AsyncWriteExt + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut self.0 {
DynWriterImpl::Uncompressed(w) => io::AsyncWrite::poll_write(Pin::new(w), cx, buf),
DynWriterImpl::Zstd(enc) => io::AsyncWrite::poll_write(Pin::new(enc), cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.0 {
DynWriterImpl::Uncompressed(w) => io::AsyncWrite::poll_flush(Pin::new(w), cx),
DynWriterImpl::Zstd(enc) => io::AsyncWrite::poll_flush(Pin::new(enc), cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.0 {
DynWriterImpl::Uncompressed(w) => io::AsyncWrite::poll_shutdown(Pin::new(w), cx),
DynWriterImpl::Zstd(enc) => io::AsyncWrite::poll_shutdown(Pin::new(enc), cx),
}
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
match &mut self.0 {
DynWriterImpl::Uncompressed(w) => {
io::AsyncWrite::poll_write_vectored(Pin::new(w), cx, bufs)
}
DynWriterImpl::Zstd(enc) => {
io::AsyncWrite::poll_write_vectored(Pin::new(enc), cx, bufs)
}
}
}
fn is_write_vectored(&self) -> bool {
match &self.0 {
DynWriterImpl::Uncompressed(w) => w.is_write_vectored(),
DynWriterImpl::Zstd(enc) => enc.is_write_vectored(),
}
}
}
}