use crate::{IntoStatic, StreamError, stream::ByteStream, xrpc::XrpcRequest};
use alloc::boxed::Box;
use bytes::Bytes;
use core::{marker::PhantomData, pin::Pin};
use http::StatusCode;
use n0_future::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
#[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
use std::path::Path;
// Heap-allocated, type-erased stream used for all frame streams in this module.
// On non-wasm32 targets the stream must also be `Send`; on wasm32 the `Send` bound is
// dropped (presumably because stream sources there are not `Send` — NOTE(review): confirm).
#[cfg(not(target_arch = "wasm32"))]
type Boxed<T> = Pin<Box<dyn n0_future::Stream<Item = T> + Send>>;
#[cfg(target_arch = "wasm32")]
type Boxed<T> = Pin<Box<dyn n0_future::Stream<Item = T>>>;
/// Describes an XRPC procedure whose request body is sent as a stream of frames.
///
/// The default [`encode_frame`](Self::encode_frame) and
/// [`decode_frame`](Self::decode_frame) implementations use DAG-CBOR via
/// `serde_ipld_dagcbor`. Implementors may override them.
pub trait XrpcProcedureStream {
    /// Identifier of the procedure (presumably its XRPC NSID — confirm against implementors).
    const NSID: &'static str;
    /// Declared encoding of the streamed body.
    ///
    /// NOTE(review): the default codec methods below always use DAG-CBOR and do not
    /// consult this value.
    const ENCODING: &'static str;
    /// One frame of the stream. The `'de` lifetime lets a decoded frame borrow
    /// from the byte slice passed to [`decode_frame`](Self::decode_frame).
    type Frame<'de>;
    /// The request type associated with this procedure.
    type Request: XrpcRequest;
    /// The streaming response type associated with this procedure.
    type Response: XrpcStreamResp;

    /// Serializes a single frame to DAG-CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns the serializer's error wrapped with `StreamError::encode`.
    fn encode_frame<'de>(data: Self::Frame<'de>) -> Result<Bytes, StreamError>
    where
        Self::Frame<'de>: Serialize,
    {
        // `Bytes::from(Vec<u8>)` is the direct, zero-copy conversion for an owned
        // buffer; `from_owner` would add an extra owner allocation.
        serde_ipld_dagcbor::to_vec(&data)
            .map(Bytes::from)
            .map_err(StreamError::encode)
    }

    /// Deserializes a single DAG-CBOR frame, borrowing from `frame` where the
    /// frame type allows it.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error wrapped with `StreamError::decode`.
    fn decode_frame<'de>(frame: &'de [u8]) -> Result<Self::Frame<'de>, StreamError>
    where
        Self::Frame<'de>: Deserialize<'de>,
    {
        serde_ipld_dagcbor::from_slice(frame).map_err(StreamError::decode)
    }
}
/// Describes the frame format of a streamed XRPC response body.
///
/// The default [`encode_frame`](Self::encode_frame) and
/// [`decode_frame`](Self::decode_frame) implementations use DAG-CBOR via
/// `serde_ipld_dagcbor`. Implementors may override them.
pub trait XrpcStreamResp {
    /// Identifier of the method (presumably its XRPC NSID — confirm against implementors).
    const NSID: &'static str;
    /// Declared encoding of the streamed body.
    ///
    /// NOTE(review): the default codec methods below always use DAG-CBOR and do not
    /// consult this value.
    const ENCODING: &'static str;
    /// One frame of the response. Frames may borrow from the decode buffer
    /// (`'de`) and can be converted to an owned form via [`IntoStatic`].
    type Frame<'de>: IntoStatic;

    /// Serializes a single frame to DAG-CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns the serializer's error wrapped with `StreamError::encode`.
    fn encode_frame<'de>(data: Self::Frame<'de>) -> Result<Bytes, StreamError>
    where
        Self::Frame<'de>: Serialize,
    {
        // `Bytes::from(Vec<u8>)` is the direct, zero-copy conversion for an owned
        // buffer; `from_owner` would add an extra owner allocation.
        serde_ipld_dagcbor::to_vec(&data)
            .map(Bytes::from)
            .map_err(StreamError::encode)
    }

    /// Deserializes a single DAG-CBOR frame, borrowing from `frame` where the
    /// frame type allows it.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error wrapped with `StreamError::decode`.
    fn decode_frame<'de>(frame: &'de [u8]) -> Result<Self::Frame<'de>, StreamError>
    where
        Self::Frame<'de>: Deserialize<'de>,
    {
        serde_ipld_dagcbor::from_slice(frame).map_err(StreamError::decode)
    }
}
/// One chunk of an XRPC byte stream.
///
/// `F` is a zero-sized type tag (held in `PhantomData`) describing what the
/// bytes represent; the untyped form uses `()`.
#[repr(transparent)]
pub struct XrpcStreamFrame<F = ()> {
    /// The raw bytes carried by this frame.
    pub buffer: Bytes,
    _marker: PhantomData<F>,
}

impl XrpcStreamFrame {
    /// Wraps `buffer` in an untyped frame.
    pub fn new(buffer: Bytes) -> Self {
        Self::new_typed::<()>(buffer)
    }
}

impl<F> XrpcStreamFrame<F> {
    /// Wraps `buffer` in a frame tagged with `F`.
    ///
    /// NOTE(review): `G` does not appear in the signature or body; the frame's tag
    /// `F` is determined by the calling context, not by `G`.
    pub fn new_typed<G>(buffer: Bytes) -> Self {
        let _marker = PhantomData;
        Self { buffer, _marker }
    }
}
/// Opens the file at `file` and exposes its contents as an untyped procedure body.
///
/// The file is read lazily in chunks produced by `tokio_util::io::ReaderStream`;
/// each chunk becomes one [`XrpcStreamFrame`].
///
/// # Errors
///
/// Returns the I/O error if the file cannot be opened. Read errors that occur while
/// the stream is consumed are yielded as `Err` items wrapped with
/// `StreamError::transport`.
#[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
pub async fn upload_stream(file: impl AsRef<Path>) -> Result<XrpcProcedureSend, tokio::io::Error> {
    use tokio_util::io::ReaderStream;

    let handle = tokio::fs::File::open(file).await?;
    let frames = ReaderStream::new(handle).map(|chunk| {
        chunk
            .map(XrpcStreamFrame::new)
            .map_err(StreamError::transport)
    });
    Ok(XrpcProcedureSend(Box::pin(frames)))
}
/// Turns a stream of typed procedure frames into an encoded procedure body.
///
/// Each frame is serialized with [`XrpcProcedureStream::encode_frame`] and wrapped in an
/// [`XrpcStreamFrame`] tagged with the frame type.
pub fn encode_stream<P: XrpcProcedureStream + 'static>(
    s: Boxed<P::Frame<'static>>,
) -> XrpcProcedureSend<P::Frame<'static>>
where
    <P as XrpcProcedureStream>::Frame<'static>: Serialize,
{
    let frames = s.map(|frame| {
        // An encoding failure is yielded as an `Err` item for that frame.
        P::encode_frame(frame).map(XrpcStreamFrame::new_typed::<P::Frame<'static>>)
    });
    XrpcProcedureSend(Box::pin(frames))
}
/// Outgoing procedure body: a boxed stream of frames (tagged with `F`) or stream errors.
pub struct XrpcProcedureSend<F = ()>(pub Boxed<Result<XrpcStreamFrame<F>, StreamError>>);
/// Boxed sink accepting frames tagged with `F`, failing with [`StreamError`].
///
/// NOTE(review): unlike `Boxed`, this requires `Send` on every target, including
/// wasm32 — confirm this is intended.
pub struct XrpcProcedureSink<F = ()>(
pub Pin<Box<dyn n0_future::Sink<XrpcStreamFrame<F>, Error = StreamError> + Send>>,
);
/// A streamed XRPC response: the HTTP response head plus a stream of frames tagged with `F`.
pub struct XrpcResponseStream<F = ()> {
// HTTP status, version and headers of the response.
parts: http::response::Parts,
// Response body as a boxed stream of frames or stream errors.
body: Boxed<Result<XrpcStreamFrame<F>, StreamError>>,
}
impl XrpcResponseStream {
    /// Builds an untyped response stream from a [`StreamingResponse`].
    ///
    /// Equivalent to [`XrpcResponseStream::from_parts`] on the response's parts and body.
    pub fn from_bytestream(StreamingResponse { parts, body }: StreamingResponse) -> Self {
        Self::from_parts(parts, body)
    }

    /// Builds an untyped response stream from HTTP response parts and a body stream.
    ///
    /// Each `Ok` chunk of `body` becomes one [`XrpcStreamFrame`]; errors pass through
    /// unchanged.
    pub fn from_parts(parts: http::response::Parts, body: ByteStream) -> Self {
        Self {
            parts,
            body: Box::pin(body.into_inner().map_ok(XrpcStreamFrame::new)),
        }
    }

    /// Splits the stream into its HTTP response parts and a plain byte stream
    /// carrying each frame's `buffer`.
    pub fn into_parts(self) -> (http::response::Parts, ByteStream) {
        let Self { parts, body } = self;
        (parts, ByteStream::new(Box::pin(body.map_ok(|f| f.buffer))))
    }

    /// Discards the HTTP response parts and returns the body as a plain byte stream.
    pub fn into_bytestream(self) -> ByteStream {
        self.into_parts().1
    }
}
impl<F: XrpcStreamResp> XrpcResponseStream<F> {
    /// Builds a typed response stream from a [`StreamingResponse`].
    ///
    /// Same as calling `from_typed_parts` with the response's parts and body.
    pub fn from_stream(StreamingResponse { parts, body }: StreamingResponse) -> Self {
        Self::from_typed_parts(parts, body)
    }

    /// Builds a typed response stream from HTTP response parts and a body stream.
    ///
    /// Each `Ok` chunk of `body` becomes one frame tagged with `F`; errors pass
    /// through unchanged. No decoding happens here.
    pub fn from_typed_parts(parts: http::response::Parts, body: ByteStream) -> Self {
        // A closure (not a fn path) is used so the boxed stream does not pick up an
        // `F: 'static` requirement that this impl does not declare.
        let frames = body
            .into_inner()
            .map_ok(|chunk| XrpcStreamFrame::new_typed::<F::Frame<'_>>(chunk));
        Self {
            parts,
            body: Box::pin(frames),
        }
    }
}
impl<F: XrpcStreamResp + 'static> XrpcResponseStream<F> {
    /// Discards the HTTP response parts and the frame type tag, returning each
    /// frame's raw `buffer` as a plain byte stream.
    pub fn into_bytestream(self) -> ByteStream {
        let raw = self.body.map_ok(|frame| frame.buffer);
        ByteStream::new(Box::pin(raw))
    }
}
/// An HTTP response whose head has been received and whose body is a [`ByteStream`].
pub struct StreamingResponse {
    parts: http::response::Parts,
    body: ByteStream,
}

impl StreamingResponse {
    /// Pairs response metadata with a streaming body.
    pub fn new(parts: http::response::Parts, body: ByteStream) -> Self {
        StreamingResponse { parts, body }
    }

    /// HTTP status code of the response.
    pub fn status(&self) -> StatusCode {
        let Self { parts, .. } = self;
        parts.status
    }

    /// HTTP headers of the response.
    pub fn headers(&self) -> &http::HeaderMap {
        let Self { parts, .. } = self;
        &parts.headers
    }

    /// HTTP version of the response.
    pub fn version(&self) -> http::Version {
        let Self { parts, .. } = self;
        parts.version
    }

    /// Consumes the response, returning its parts and body.
    pub fn into_parts(self) -> (http::response::Parts, ByteStream) {
        let Self { parts, body } = self;
        (parts, body)
    }

    /// Mutable access to the body stream.
    pub fn body_mut(&mut self) -> &mut ByteStream {
        let Self { body, .. } = self;
        body
    }

    /// Shared access to the body stream.
    pub fn body(&self) -> &ByteStream {
        let Self { body, .. } = self;
        body
    }
}

impl core::fmt::Debug for StreamingResponse {
    /// Formats the response head; the body stream is omitted (`..`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let http::response::Parts {
            status,
            version,
            headers,
            ..
        } = &self.parts;
        f.debug_struct("StreamingResponse")
            .field("status", status)
            .field("version", version)
            .field("headers", headers)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream;

    /// `StreamingResponse::new` keeps the status from the supplied response parts.
    #[test]
    fn streaming_response_holds_parts_and_body() {
        // Build a bodiless response just to obtain its `Parts`.
        let (head, ()) = http::Response::builder()
            .status(StatusCode::OK)
            .body(())
            .unwrap()
            .into_parts();
        let chunks = stream::iter(vec![Ok(Bytes::from("test"))]);
        let wrapped = StreamingResponse::new(head, ByteStream::new(chunks));
        assert_eq!(wrapped.status(), StatusCode::OK);
    }
}