Skip to main content

orpc_procedure/
stream.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_core::Stream;
6use pin_project_lite::pin_project;
7
8use crate::error::ProcedureError;
9use crate::output::DynOutput;
10
11/// Type-erased async stream of procedure results.
12///
13/// Unifies single-value responses (queries, mutations) and streaming responses
14/// (subscriptions) behind a common `Stream` interface.
15pub struct ProcedureStream {
16    inner: Pin<Box<dyn Stream<Item = Result<DynOutput, ProcedureError>> + Send>>,
17}
18
19impl ProcedureStream {
20    /// Create from an existing stream.
21    pub fn from_stream<S>(stream: S) -> Self
22    where
23        S: Stream<Item = Result<DynOutput, ProcedureError>> + Send + 'static,
24    {
25        ProcedureStream {
26            inner: Box::pin(stream),
27        }
28    }
29
30    /// Create from a single-value future (for queries and mutations).
31    pub fn from_future<F>(future: F) -> Self
32    where
33        F: Future<Output = Result<DynOutput, ProcedureError>> + Send + 'static,
34    {
35        ProcedureStream {
36            inner: Box::pin(FutureStream::Pending { future }),
37        }
38    }
39
40    /// Create an error stream that yields a single error.
41    pub fn error(err: ProcedureError) -> Self {
42        ProcedureStream::from_future(async move { Err(err) })
43    }
44}
45
46impl Stream for ProcedureStream {
47    type Item = Result<DynOutput, ProcedureError>;
48
49    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
50        self.inner.as_mut().poll_next(cx)
51    }
52
53    fn size_hint(&self) -> (usize, Option<usize>) {
54        self.inner.size_hint()
55    }
56}
57
58impl std::fmt::Debug for ProcedureStream {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        f.debug_struct("ProcedureStream").finish_non_exhaustive()
61    }
62}
63
// Internal helper: adapts a `Future` into a `Stream` that yields the future's
// output exactly once and then ends.
//
// It is modeled as an enum so that, once the future completes, `poll_next`
// can overwrite the `Pending` variant with `Done`. That drops the finished
// future (and anything it owns) immediately and guarantees it is never polled
// again. The enum's storage is still sized for `F`; only the future's own
// resources are released.
pin_project! {
    #[project = FutureStreamProj]
    enum FutureStream<F> {
        // The future has not completed yet. It is structurally pinned so that
        // it can be polled in place.
        Pending { #[pin] future: F },
        // The output has already been yielded; only `None` is returned from now on.
        Done,
    }
}
73
74impl<F> Stream for FutureStream<F>
75where
76    F: Future<Output = Result<DynOutput, ProcedureError>>,
77{
78    type Item = Result<DynOutput, ProcedureError>;
79
80    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81        match self.as_mut().project() {
82            FutureStreamProj::Pending { future } => {
83                let result = std::task::ready!(future.poll(cx));
84                self.set(FutureStream::Done);
85                Poll::Ready(Some(result))
86            }
87            FutureStreamProj::Done => Poll::Ready(None),
88        }
89    }
90
91    fn size_hint(&self) -> (usize, Option<usize>) {
92        match self {
93            FutureStream::Pending { .. } => (1, Some(1)),
94            FutureStream::Done => (0, Some(0)),
95        }
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn from_future_ok() {
        let stream = ProcedureStream::from_future(async { Ok(DynOutput::new(42u32)) });
        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 1);
        assert!(results[0].is_ok());
        let value = results[0].as_ref().unwrap().to_value().unwrap();
        assert_eq!(value, serde_json::json!(42));
    }

    #[tokio::test]
    async fn from_future_err() {
        let stream = ProcedureStream::from_future(async {
            Err(ProcedureError::Resolver(Box::new(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "not found",
            ))))
        });
        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 1);
        assert!(results[0].is_err());
    }

    #[tokio::test]
    async fn from_future_yields_none_after_first() {
        let mut stream = ProcedureStream::from_future(async { Ok(DynOutput::new("hello")) });
        assert!(stream.next().await.is_some());
        assert!(stream.next().await.is_none());
    }

    // The wrapped future must not run until the stream is polled; this is
    // why `ProcedureStream` values should not be discarded.
    #[tokio::test]
    async fn from_future_is_lazy() {
        let polled = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&polled);
        let mut stream = ProcedureStream::from_future(async move {
            flag.store(true, Ordering::SeqCst);
            Ok(DynOutput::new(1u32))
        });
        assert!(!polled.load(Ordering::SeqCst));
        assert!(stream.next().await.is_some());
        assert!(polled.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn from_stream_multi_item() {
        let items = vec![
            Ok(DynOutput::new(1u32)),
            Ok(DynOutput::new(2u32)),
            Ok(DynOutput::new(3u32)),
        ];
        let stream = ProcedureStream::from_stream(futures_util::stream::iter(items));
        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 3);
        for (i, result) in results.iter().enumerate() {
            let value = result.as_ref().unwrap().to_value().unwrap();
            assert_eq!(value, serde_json::json!(i as u32 + 1));
        }
    }

    #[tokio::test]
    async fn from_stream_empty() {
        let stream = ProcedureStream::from_stream(futures_util::stream::empty());
        let results: Vec<Result<DynOutput, ProcedureError>> = stream.collect().await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn error_stream() {
        let stream = ProcedureStream::error(ProcedureError::Unwind(Box::new("panic!")));
        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 1);
        assert!(matches!(&results[0], Err(ProcedureError::Unwind(_))));
    }

    #[test]
    fn procedure_stream_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ProcedureStream>();
    }

    #[test]
    fn size_hint_from_future() {
        let stream = ProcedureStream::from_future(async { Ok(DynOutput::new(1u32)) });
        let (lower, upper) = stream.size_hint();
        assert_eq!(lower, 1);
        assert_eq!(upper, Some(1));
    }

    // Covers the `Done` branch of `FutureStream::size_hint`.
    #[tokio::test]
    async fn size_hint_after_completion() {
        let mut stream = ProcedureStream::from_future(async { Ok(DynOutput::new(1u32)) });
        assert!(stream.next().await.is_some());
        assert_eq!(stream.size_hint(), (0, Some(0)));
    }

    // `from_stream` must report the wrapped stream's size hint unchanged.
    #[test]
    fn size_hint_from_stream_is_forwarded() {
        let items = vec![Ok(DynOutput::new(1u32)), Ok(DynOutput::new(2u32))];
        let stream = ProcedureStream::from_stream(futures_util::stream::iter(items));
        assert_eq!(stream.size_hint(), (2, Some(2)));
    }
}