// worker-plus 0.0.15
//
// A Rust SDK for writing Cloudflare Workers.
// (Source excerpt from the crate's documentation.)
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures_util::{Stream, TryStreamExt};
use js_sys::{BigInt, Uint8Array};
use pin_project::pin_project;
use wasm_bindgen::{JsCast, JsValue};
use wasm_streams::readable::IntoStream;
use web_sys::ReadableStream;
use worker_sys::FixedLengthStream as FixedLengthStreamSys;

use crate::{Error, Result};

#[pin_project]
#[derive(Debug)]
pub struct ByteStream {
    #[pin]
    pub(crate) inner: IntoStream<'static>,
}

impl Stream for ByteStream {
    type Item = Result<Vec<u8>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let item = match futures_util::ready!(this.inner.poll_next(cx)) {
            Some(res) => res.map(Uint8Array::from).map_err(Error::from),
            None => return Poll::Ready(None),
        };

        Poll::Ready(match item {
            Ok(value) => Some(Ok(value.to_vec())),
            Err(e) if e.to_string() == "Error: aborted" => None,
            Err(e) => Some(Err(e)),
        })
    }
}

#[pin_project]
pub struct FixedLengthStream {
    length: u64,
    #[pin]
    bytes_read: u64,
    #[pin]
    inner: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + 'static>>,
}

impl FixedLengthStream {
    pub fn wrap(stream: impl Stream<Item = Result<Vec<u8>>> + 'static, length: u64) -> Self {
        Self {
            length,
            bytes_read: 0,
            inner: Box::pin(stream),
        }
    }
}

impl Stream for FixedLengthStream {
    type Item = Result<Vec<u8>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        let item = if let Some(res) = futures_util::ready!(this.inner.poll_next(cx)) {
            let chunk = match res {
                Ok(chunk) => chunk,
                Err(err) => return Poll::Ready(Some(Err(err))),
            };

            *this.bytes_read += chunk.len() as u64;

            if *this.bytes_read > *this.length {
                let err = Error::from(format!(
                    "fixed length stream had different length than expected (expected {}, got {})",
                    *this.length, *this.bytes_read,
                ));
                Some(Err(err))
            } else {
                Some(Ok(chunk))
            }
        } else if *this.bytes_read != *this.length {
            let err = Error::from(format!(
                "fixed length stream had different length than expected (expected {}, got {})",
                *this.length, *this.bytes_read,
            ));
            Some(Err(err))
        } else {
            None
        };

        Poll::Ready(item)
    }
}

impl From<FixedLengthStream> for FixedLengthStreamSys {
    /// Converts the Rust stream into a JS `FixedLengthStream` fed by a pipe.
    ///
    /// The expected length is passed as a `u32` when it fits and as a `BigInt`
    /// otherwise. Each `Vec<u8>` chunk is copied into a fresh `Uint8Array`.
    /// Errors are converted with `JsValue::from` before the stream is handed to
    /// `wasm_streams`. The resulting readable stream is then piped into the
    /// writable side of the returned object.
    fn from(stream: FixedLengthStream) -> Self {
        // Inclusive bound: a length of exactly `u32::MAX` still fits in a `u32`.
        let raw = if stream.length <= u64::from(u32::MAX) {
            FixedLengthStreamSys::new(stream.length as u32)
        } else {
            FixedLengthStreamSys::new_big_int(BigInt::from(stream.length))
        };

        let js_stream = stream
            .map_ok(|chunk: Vec<u8>| {
                let array = Uint8Array::new_with_length(chunk.len() as _);
                array.copy_from(&chunk);
                JsValue::from(array)
            })
            .map_err(JsValue::from);

        let readable: ReadableStream = wasm_streams::ReadableStream::from_stream(js_stream)
            .as_raw()
            .clone()
            .unchecked_into();

        // NOTE(review): the promise returned by `pipe_to` is not awaited or
        // observed. Under the Streams spec the pipe proceeds without it, but a
        // rejection (e.g. from a length mismatch) is not handled here.
        let _ = readable.pipe_to(&raw.writable());

        raw
    }
}