rong_stream 0.3.1

Stream module for RongJS
use bytes::Bytes;
use flate2::Compression;
use flate2::write::{
    DeflateDecoder, DeflateEncoder, GzDecoder, GzEncoder, ZlibDecoder, ZlibEncoder,
};
use rong::{HostError, JSContext, JSObject, JSResult, JSValue, js_class, js_export, js_method};
use std::io::{self, Write};
use tokio::sync::mpsc;

use crate::{JSReadableStream, JSWritableStream, writable_stream_to_sender};

/// Item carried on the output channel of a transform task: either a chunk of
/// produced bytes or an error message rendered as a string.
type StreamChunk = Result<Bytes, String>;

/// Compression formats accepted by the stream constructors in this file.
///
/// Each variant corresponds to one identifier recognized by
/// [`CompressionFormat::parse`] and selects the matching flate2 writer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CompressionFormat {
    /// Selected by `"gzip"`; backed by `GzEncoder` / `GzDecoder`.
    Gzip,
    /// Selected by `"deflate"`; backed by `ZlibEncoder` / `ZlibDecoder`.
    Deflate,
    /// Selected by `"deflate-raw"`; backed by `DeflateEncoder` / `DeflateDecoder`.
    DeflateRaw,
}

impl CompressionFormat {
    fn parse(raw: &str) -> JSResult<Self> {
        match raw {
            "gzip" => Ok(Self::Gzip),
            "deflate" => Ok(Self::Deflate),
            "deflate-raw" => Ok(Self::DeflateRaw),
            _ => Err(HostError::new(
                rong::error::E_TYPE,
                format!("Unsupported compression format: {raw}"),
            )
            .with_name("TypeError")
            .into()),
        }
    }
}

/// A flate2 write-side encoder for one of the supported formats, writing its
/// compressed output into an in-memory `Vec<u8>`.
enum EncoderState {
    /// Encoder used for [`CompressionFormat::Gzip`].
    Gzip(GzEncoder<Vec<u8>>),
    /// Encoder used for [`CompressionFormat::Deflate`].
    Deflate(ZlibEncoder<Vec<u8>>),
    /// Encoder used for [`CompressionFormat::DeflateRaw`].
    DeflateRaw(DeflateEncoder<Vec<u8>>),
}

impl EncoderState {
    /// Creates the encoder matching `format`, with an empty `Vec<u8>` as its
    /// sink and `Compression::default()` as its level.
    fn new(format: CompressionFormat) -> Self {
        let level = Compression::default();
        let sink: Vec<u8> = Vec::new();
        match format {
            CompressionFormat::Gzip => Self::Gzip(GzEncoder::new(sink, level)),
            CompressionFormat::Deflate => Self::Deflate(ZlibEncoder::new(sink, level)),
            CompressionFormat::DeflateRaw => Self::DeflateRaw(DeflateEncoder::new(sink, level)),
        }
    }

    /// Feeds all of `chunk` to the active encoder through its `Write` impl.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the encoder's `write_all` reports.
    fn write_all(&mut self, chunk: &[u8]) -> io::Result<()> {
        // All three encoders are driven through the same `Write` interface.
        let writer: &mut dyn Write = match self {
            Self::Gzip(enc) => enc,
            Self::Deflate(enc) => enc,
            Self::DeflateRaw(enc) => enc,
        };
        writer.write_all(chunk)
    }

    /// Calls `try_finish` on the active encoder.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the encoder, if any.
    fn try_finish(&mut self) -> io::Result<()> {
        match self {
            Self::Gzip(enc) => enc.try_finish(),
            Self::Deflate(enc) => enc.try_finish(),
            Self::DeflateRaw(enc) => enc.try_finish(),
        }
    }

    /// Moves out whatever bytes currently sit in the encoder's sink.
    ///
    /// Returns `None` when the sink is empty; otherwise the sink is replaced
    /// by an empty `Vec` and its previous contents are returned as `Bytes`.
    fn take_output(&mut self) -> Option<Bytes> {
        let pending: &mut Vec<u8> = match self {
            Self::Gzip(enc) => enc.get_mut(),
            Self::Deflate(enc) => enc.get_mut(),
            Self::DeflateRaw(enc) => enc.get_mut(),
        };
        if pending.is_empty() {
            None
        } else {
            Some(Bytes::from(std::mem::take(pending)))
        }
    }
}

/// A flate2 write-side decoder for one of the supported formats, writing its
/// decompressed output into an in-memory `Vec<u8>`.
enum DecoderState {
    /// Decoder used for [`CompressionFormat::Gzip`].
    Gzip(GzDecoder<Vec<u8>>),
    /// Decoder used for [`CompressionFormat::Deflate`].
    Deflate(ZlibDecoder<Vec<u8>>),
    /// Decoder used for [`CompressionFormat::DeflateRaw`].
    DeflateRaw(DeflateDecoder<Vec<u8>>),
}

impl DecoderState {
    /// Creates the decoder matching `format`, with an empty `Vec<u8>` as its
    /// sink.
    fn new(format: CompressionFormat) -> Self {
        let sink: Vec<u8> = Vec::new();
        match format {
            CompressionFormat::Gzip => Self::Gzip(GzDecoder::new(sink)),
            CompressionFormat::Deflate => Self::Deflate(ZlibDecoder::new(sink)),
            CompressionFormat::DeflateRaw => Self::DeflateRaw(DeflateDecoder::new(sink)),
        }
    }

    /// Exposes the active decoder through its `Write` implementation.
    fn as_writer(&mut self) -> &mut dyn Write {
        match self {
            Self::Gzip(dec) => dec,
            Self::Deflate(dec) => dec,
            Self::DeflateRaw(dec) => dec,
        }
    }

    /// Writes `chunk` to the decoder, repeating partial writes until every
    /// byte has been accepted.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from the decoder. If the decoder accepts zero
    /// bytes while input is still pending, this is reported as
    /// `io::ErrorKind::InvalidData` ("Trailing data after end of compressed
    /// stream").
    fn write_chunk(&mut self, chunk: &[u8]) -> io::Result<()> {
        let sink = self.as_writer();
        let mut consumed = 0;
        while consumed < chunk.len() {
            match sink.write(&chunk[consumed..])? {
                // No progress with input left over: the decoder will not take more.
                0 => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "Trailing data after end of compressed stream",
                    ));
                }
                accepted => consumed += accepted,
            }
        }
        Ok(())
    }

    /// Flushes the active decoder via `Write::flush`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the decoder, if any.
    fn flush(&mut self) -> io::Result<()> {
        self.as_writer().flush()
    }

    /// Calls `try_finish` on the active decoder.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by the decoder, if any.
    // NOTE(review): whether input that ends before the compressed stream is
    // complete is reported here depends on flate2's decoders — confirm for
    // each format.
    fn try_finish(&mut self) -> io::Result<()> {
        match self {
            Self::Gzip(dec) => dec.try_finish(),
            Self::Deflate(dec) => dec.try_finish(),
            Self::DeflateRaw(dec) => dec.try_finish(),
        }
    }

    /// Moves out whatever bytes currently sit in the decoder's sink.
    ///
    /// Returns `None` when the sink is empty; otherwise the sink is replaced
    /// by an empty `Vec` and its previous contents are returned as `Bytes`.
    fn take_output(&mut self) -> Option<Bytes> {
        let pending: &mut Vec<u8> = match self {
            Self::Gzip(dec) => dec.get_mut(),
            Self::Deflate(dec) => dec.get_mut(),
            Self::DeflateRaw(dec) => dec.get_mut(),
        };
        (!pending.is_empty()).then(|| Bytes::from(std::mem::take(pending)))
    }
}

/// Spawns (via `rong::spawn_local`) a task that compresses everything
/// received on `input_rx` and forwards the result on `output_tx`.
///
/// For each input chunk the task writes it to the encoder and, if the encoder
/// has produced any bytes, sends them as `Ok`. Once the input channel yields
/// `None`, the encoder is finished and any remaining bytes are sent.
///
/// The task ends early when an encoder call fails (the error text is sent as
/// `Err`) or when sending an output chunk fails.
fn spawn_compression_task(
    mut input_rx: mpsc::Receiver<Bytes>,
    output_tx: mpsc::Sender<StreamChunk>,
    format: CompressionFormat,
) {
    rong::spawn_local(async move {
        let mut encoder = EncoderState::new(format);

        while let Some(chunk) = input_rx.recv().await {
            if let Err(error) = encoder.write_all(&chunk) {
                // Report the failure; a send error here changes nothing.
                let _ = output_tx.send(Err(error.to_string())).await;
                return;
            }

            // The encoder may not have emitted anything for this chunk yet.
            let Some(bytes) = encoder.take_output() else {
                continue;
            };
            if output_tx.send(Ok(bytes)).await.is_err() {
                return;
            }
        }

        // Input is exhausted: finish the stream and emit the final message, if any.
        let last = match encoder.try_finish() {
            Ok(()) => encoder.take_output().map(Ok),
            Err(error) => Some(Err(error.to_string())),
        };
        if let Some(message) = last {
            let _ = output_tx.send(message).await;
        }
    });
}

fn spawn_decompression_task(
    mut input_rx: mpsc::Receiver<Bytes>,
    output_tx: mpsc::Sender<StreamChunk>,
    format: CompressionFormat,
) {
    rong::spawn_local(async move {
        let mut decoder = DecoderState::new(format);

        while let Some(chunk) = input_rx.recv().await {
            let result = decoder
                .write_chunk(&chunk)
                .and_then(|_| decoder.flush())
                .map_err(|error| error.to_string());

            match result {
                Ok(()) => {
                    if let Some(bytes) = decoder.take_output()
                        && output_tx.send(Ok(bytes)).await.is_err()
                    {
                        return;
                    }
                }
                Err(error) => {
                    let _ = output_tx.send(Err(error)).await;
                    return;
                }
            }
        }

        match decoder.try_finish() {
            Ok(()) => {
                if let Some(bytes) = decoder.take_output() {
                    let _ = output_tx.send(Ok(bytes)).await;
                }
            }
            Err(error) => {
                let _ = output_tx.send(Err(error.to_string())).await;
            }
        }
    });
}

/// Creates the two stream objects of a (de)compression transform and starts
/// the task that connects them.
///
/// Two bounded channels (capacity 16 each) are created: one carrying input
/// bytes into the task and one carrying its output. The task is spawned
/// first, then the writable stream is built around the input sender and the
/// readable stream around the output receiver.
///
/// `decode` selects the decompression task when `true` and the compression
/// task when `false`.
///
/// Returns `(readable, writable)` — note the order.
///
/// # Errors
///
/// Propagates any error from constructing either stream object.
fn build_transform_pair(
    ctx: &JSContext,
    format: CompressionFormat,
    decode: bool,
) -> JSResult<(JSObject, JSObject)> {
    const CHANNEL_CAPACITY: usize = 16;

    let (input_tx, input_rx) = mpsc::channel::<Bytes>(CHANNEL_CAPACITY);
    let (output_tx, output_rx) = mpsc::channel::<StreamChunk>(CHANNEL_CAPACITY);

    match decode {
        true => spawn_decompression_task(input_rx, output_tx, format),
        false => spawn_compression_task(input_rx, output_tx, format),
    }

    let sink = writable_stream_to_sender(input_tx);
    let writable = JSWritableStream::new(ctx, sink)?.into_object();
    let readable = JSReadableStream::from_receiver(ctx, output_rx)?.into_object();

    Ok((readable, writable))
}

/// Exported class pairing a writable stream (input to be compressed) with a
/// readable stream (compressed output).
#[js_export]
pub struct CompressionStream {
    // Stream object built from the compression task's output receiver.
    readable: JSObject,
    // Stream object built around the sender feeding the compression task.
    writable: JSObject,
}

#[js_class]
impl CompressionStream {
    #[js_method(constructor)]
    fn constructor(ctx: JSContext, format: String) -> JSResult<Self> {
        let format = CompressionFormat::parse(&format)?;
        let (readable, writable) = build_transform_pair(&ctx, format, false)?;
        Ok(Self { readable, writable })
    }

    #[js_method(getter)]
    fn readable(&self) -> JSObject {
        self.readable.clone()
    }

    #[js_method(getter)]
    fn writable(&self) -> JSObject {
        self.writable.clone()
    }

    #[js_method(gc_mark)]
    fn gc_mark_with<F>(&self, mut mark_fn: F)
    where
        F: FnMut(&JSValue),
    {
        mark_fn(self.readable.as_js_value());
        mark_fn(self.writable.as_js_value());
    }
}

/// Exported class pairing a writable stream (compressed input) with a
/// readable stream (decompressed output).
#[js_export]
pub struct DecompressionStream {
    // Stream object built from the decompression task's output receiver.
    readable: JSObject,
    // Stream object built around the sender feeding the decompression task.
    writable: JSObject,
}

#[js_class]
impl DecompressionStream {
    #[js_method(constructor)]
    fn constructor(ctx: JSContext, format: String) -> JSResult<Self> {
        let format = CompressionFormat::parse(&format)?;
        let (readable, writable) = build_transform_pair(&ctx, format, true)?;
        Ok(Self { readable, writable })
    }

    #[js_method(getter)]
    fn readable(&self) -> JSObject {
        self.readable.clone()
    }

    #[js_method(getter)]
    fn writable(&self) -> JSObject {
        self.writable.clone()
    }

    #[js_method(gc_mark)]
    fn gc_mark_with<F>(&self, mut mark_fn: F)
    where
        F: FnMut(&JSValue),
    {
        mark_fn(self.readable.as_js_value());
        mark_fn(self.writable.as_js_value());
    }
}

/// Registers the `CompressionStream` and `DecompressionStream` classes on
/// `ctx`.
///
/// # Errors
///
/// Returns the first registration error encountered; if registering
/// `CompressionStream` fails, `DecompressionStream` is not attempted.
pub fn init(ctx: &JSContext) -> JSResult<()> {
    ctx.register_class::<CompressionStream>()?;
    ctx.register_class::<DecompressionStream>()?;
    Ok(())
}