rspc_procedure/
stream.rs

1use core::fmt;
2use std::{
3    cell::RefCell,
4    future::{poll_fn, Future},
5    panic::{catch_unwind, AssertUnwindSafe},
6    pin::Pin,
7    sync::Arc,
8    task::{ready, Context, Poll, Waker},
9};
10
11use futures_core::Stream;
12use pin_project_lite::pin_project;
13use serde::Serialize;
14
15use crate::{DynOutput, ProcedureError};
16
thread_local! {
    // Read by `flush` to decide whether requesting a flush is currently permitted on this thread.
    // Starts out as `false`.
    static CAN_FLUSH: RefCell<bool> = RefCell::new(false);
    // Set to `Some(true)` by `flush` when a flush has been requested on this thread.
    // Starts out as `None`.
    static SHOULD_FLUSH: RefCell<Option<bool>> = RefCell::new(None);
}
21
22/// TODO
23pub async fn flush() {
24    if CAN_FLUSH.with(|v| *v.borrow()) {
25        let mut pending = true;
26        poll_fn(|_| {
27            if pending {
28                pending = false;
29                SHOULD_FLUSH.replace(Some(true));
30                return Poll::Pending;
31            }
32
33            Poll::Ready(())
34        })
35        .await;
36    }
37}
38
// The two possible sources of data behind a `ProcedureStream`.
enum Inner {
    // A type-erased future and/or stream (see `GenericDynReturnValue`).
    Dyn(Pin<Box<dyn DynReturnValue>>),
    // A single error stored directly (see `From<ProcedureError> for ProcedureStream`).
    Value(Option<ProcedureError>),
}
43
/// A type-erased stream of procedure results.
///
/// It is created from a stream, a future, a future resolving to a stream (`Self::from_*`)
/// or directly from a [`ProcedureError`], and is driven with `Self::poll_next` / `Self::next`.
#[must_use = "`ProcedureStream` does nothing unless polled"]
pub struct ProcedureStream {
    // The underlying source of values.
    inner: Inner,
    // If `None` values may be yielded.
    // This is the default but will also be set after `Self::stream` is called.
    //
    // If `Some` then `Self::stream` must be called before the next value is yielded.
    // Will poll until the first value and then return `Poll::Pending` and record the waker.
    // The stored value will be yielded immediately after `Self::stream` is called.
    flush: Option<Waker>,
    // This is set `true` if the inner stream produces a value while `flush` is `Some`.
    // This informs the stream to yield the value immediately when `flush` is `None` again.
    pending_value: bool, // TODO: Could we just check for a value on `inner`? Less chance of panic in the case of a bug.
}
59
60impl From<ProcedureError> for ProcedureStream {
61    fn from(err: ProcedureError) -> Self {
62        Self {
63            inner: Inner::Value(Some(err)),
64            flush: None,
65            pending_value: false,
66        }
67    }
68}
69
70impl ProcedureStream {
71    /// TODO
72    pub fn from_stream<T, S>(s: S) -> Self
73    where
74        S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
75        T: Serialize + Send + Sync + 'static,
76    {
77        Self {
78            inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
79                inner: s,
80                poll: |s, cx| s.poll_next(cx),
81                size_hint: |s| s.size_hint(),
82                resolved: |_| true,
83                as_value: |v| {
84                    DynOutput::new_serialize(
85                        v.as_mut()
86                            // Error's are caught before `as_value` is called.
87                            .expect("unreachable")
88                            .as_mut()
89                            // Attempted to access value when `Poll::Ready(None)` was not returned.
90                            .expect("unreachable"),
91                    )
92                },
93                flushed: false,
94                unwound: false,
95                value: None,
96            })),
97            flush: None,
98            pending_value: false,
99        }
100    }
101
102    /// TODO
103    pub fn from_future<T, F>(f: F) -> Self
104    where
105        F: Future<Output = Result<T, ProcedureError>> + Send + 'static,
106        T: Serialize + Send + Sync + 'static,
107    {
108        pin_project! {
109            #[project = ReprProj]
110            struct Repr<F> {
111                #[pin]
112                inner: Option<F>,
113            }
114        }
115
116        Self {
117            inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
118                inner: Repr { inner: Some(f) },
119                poll: |f, cx| {
120                    let mut this = f.project();
121                    let v = match this.inner.as_mut().as_pin_mut() {
122                        Some(fut) => ready!(fut.poll(cx)),
123                        None => return Poll::Ready(None),
124                    };
125
126                    this.inner.set(None);
127                    Poll::Ready(Some(v))
128                },
129                size_hint: |f| {
130                    if f.inner.is_some() {
131                        (1, Some(1))
132                    } else {
133                        (0, Some(0))
134                    }
135                },
136                as_value: |v| {
137                    DynOutput::new_serialize(
138                        v.as_mut()
139                            // Error's are caught before `as_value` is called.
140                            .expect("unreachable")
141                            .as_mut()
142                            // Attempted to access value when `Poll::Ready(None)` was not returned.
143                            .expect("unreachable"),
144                    )
145                },
146                resolved: |f| f.inner.is_none(),
147                flushed: false,
148                unwound: false,
149                value: None,
150            })),
151            flush: None,
152            pending_value: false,
153        }
154    }
155
156    /// TODO
157    pub fn from_future_stream<T, F, S>(f: F) -> Self
158    where
159        F: Future<Output = Result<S, ProcedureError>> + Send + 'static,
160        S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
161        T: Serialize + Send + Sync + 'static,
162    {
163        pin_project! {
164            #[project = ReprProj]
165            enum Repr<F, S> {
166                Future {
167                    #[pin]
168                    inner: F,
169                },
170                Stream {
171                    #[pin]
172                    inner: S,
173                },
174            }
175        }
176
177        Self {
178            inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
179                inner: Repr::<F, S>::Future { inner: f },
180                poll: |mut f, cx| loop {
181                    let this = f.as_mut().project();
182                    match this {
183                        ReprProj::Future { inner } => {
184                            let Poll::Ready(Ok(stream)) = inner.poll(cx) else {
185                                return Poll::Pending;
186                            };
187
188                            f.set(Repr::Stream { inner: stream });
189                            continue;
190                        }
191                        ReprProj::Stream { inner } => return inner.poll_next(cx),
192                    }
193                },
194                size_hint: |_| (1, Some(1)),
195                resolved: |f| matches!(f, Repr::Stream { .. }),
196                as_value: |v| {
197                    DynOutput::new_serialize(
198                        v.as_mut()
199                            // Error's are caught before `as_value` is called.
200                            .expect("unreachable")
201                            .as_mut()
202                            // Attempted to access value when `Poll::Ready(None)` was not returned.
203                            .expect("unreachable"),
204                    )
205                },
206                flushed: false,
207                unwound: false,
208                value: None,
209            })),
210            flush: None,
211            pending_value: false,
212        }
213    }
214
215    /// TODO
216    pub fn from_stream_value<T, S>(s: S) -> Self
217    where
218        S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
219        T: Send + Sync + 'static,
220    {
221        Self {
222            inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
223                inner: s,
224                poll: |s, cx| s.poll_next(cx),
225                size_hint: |s| s.size_hint(),
226                resolved: |_| true,
227                // We passthrough the whole `Option` intentionally.
228                as_value: |v| DynOutput::new_value(v),
229                flushed: false,
230                unwound: false,
231                value: None,
232            })),
233            flush: None,
234            pending_value: false,
235        }
236    }
237
238    /// TODO
239    pub fn from_future_value<T, F>(f: F) -> Self
240    where
241        F: Future<Output = Result<T, ProcedureError>> + Send + 'static,
242        T: Send + Sync + 'static,
243    {
244        pin_project! {
245            #[project = ReprProj]
246            struct Repr<F> {
247                #[pin]
248                inner: Option<F>,
249            }
250        }
251
252        Self {
253            inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
254                inner: Repr { inner: Some(f) },
255                poll: |f, cx| {
256                    let mut this = f.project();
257                    let v = match this.inner.as_mut().as_pin_mut() {
258                        Some(fut) => ready!(fut.poll(cx)),
259                        None => return Poll::Ready(None),
260                    };
261
262                    this.inner.set(None);
263                    Poll::Ready(Some(v))
264                },
265                size_hint: |f| {
266                    if f.inner.is_some() {
267                        (1, Some(1))
268                    } else {
269                        (0, Some(0))
270                    }
271                },
272                as_value: |v| DynOutput::new_value(v),
273                resolved: |f| f.inner.is_none(),
274                flushed: false,
275                unwound: false,
276                value: None,
277            })),
278            flush: None,
279            pending_value: false,
280        }
281    }
282
283    /// TODO
284    pub fn from_future_stream_value<T, F, S>(f: F) -> Self
285    where
286        F: Future<Output = Result<S, ProcedureError>> + Send + 'static,
287        S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
288        T: Send + Sync + 'static,
289    {
290        pin_project! {
291            #[project = ReprProj]
292            enum Repr<F, S> {
293                Future {
294                    #[pin]
295                    inner: F,
296                },
297                Stream {
298                    #[pin]
299                    inner: S,
300                },
301            }
302        }
303
304        Self {
305            inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
306                inner: Repr::<F, S>::Future { inner: f },
307                poll: |mut f, cx| loop {
308                    let this = f.as_mut().project();
309                    match this {
310                        ReprProj::Future { inner } => {
311                            let Poll::Ready(Ok(stream)) = inner.poll(cx) else {
312                                return Poll::Pending;
313                            };
314
315                            f.set(Repr::Stream { inner: stream });
316                            continue;
317                        }
318                        ReprProj::Stream { inner } => return inner.poll_next(cx),
319                    }
320                },
321                size_hint: |_| (1, Some(1)),
322                resolved: |f| matches!(f, Repr::Stream { .. }),
323                as_value: |v| DynOutput::new_value(v),
324                flushed: false,
325                unwound: false,
326                value: None,
327            })),
328            flush: None,
329            pending_value: false,
330        }
331    }
332
333    /// By setting this the stream will delay returning any data until instructed by the caller (via `Self::stream`).
334    ///
335    /// This allows you to progress an entire runtime of streams until all of them are in a state ready to start returning responses.
336    /// This mechanism allows anything that could need to modify the HTTP response headers to do so before the body starts being streamed.
337    ///
338    /// # Behaviour
339    ///
340    /// `ProcedureStream` will poll the underlying stream until the first value is ready.
341    /// It will then return `Poll::Pending` and go inactive until `Self::stream` is called.
342    /// When polled for the first time after `Self::stream` is called if a value was already ready it will be immediately returned.
343    /// It is *guaranteed* that the stream will never yield `Poll::Ready` until `flush` is called if this is set.
344    ///
345    /// # Usage
346    ///
347    /// It's generally expected you will continue to poll the runtime until some criteria based on `Self::resolved` & `Self::flushable` is met on all streams.
348    /// Once this is met you can call `Self::stream` on all of the streams at once to begin streaming data.
349    ///
350    pub fn require_manual_stream(mut self) -> Self {
351        // TODO: When stablised replace with - https://doc.rust-lang.org/stable/std/task/struct.Waker.html#method.noop
352        struct NoOpWaker;
353        impl std::task::Wake for NoOpWaker {
354            fn wake(self: std::sync::Arc<Self>) {}
355        }
356
357        // This `Arc` is inefficient but `Waker::noop` is coming soon which will solve it.
358        self.flush = Some(Arc::new(NoOpWaker).into());
359        self
360    }
361
362    /// Start streaming data.
363    /// Refer to `Self::require_manual_stream` for more information.
364    pub fn stream(&mut self) {
365        if let Some(waker) = self.flush.take() {
366            waker.wake();
367        }
368    }
369
370    /// Will return `true` if the future has resolved.
371    ///
372    /// For a stream created via `Self::from_future*` this will be `true` once the future has resolved and for all other streams this will always be `true`.
373    pub fn resolved(&self) -> bool {
374        match &self.inner {
375            Inner::Dyn(stream) => stream.resolved(),
376            Inner::Value(_) => true,
377        }
378    }
379
380    /// Will return `true` if the stream is ready to start streaming data.
381    ///
382    /// This is `false` until the `flush` function is called by the user.
383    pub fn flushable(&self) -> bool {
384        match &self.inner {
385            Inner::Dyn(stream) => stream.flushed(),
386            Inner::Value(_) => false,
387        }
388    }
389
390    /// TODO
391    pub fn size_hint(&self) -> (usize, Option<usize>) {
392        match &self.inner {
393            Inner::Dyn(stream) => stream.size_hint(),
394            Inner::Value(_) => (1, Some(1)),
395        }
396    }
397
398    fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
399        // Ensure the waker is up to date.
400        if let Some(waker) = &mut self.flush {
401            if !waker.will_wake(cx.waker()) {
402                self.flush = Some(cx.waker().clone());
403            }
404        }
405
406        if self.pending_value {
407            return if self.flush.is_none() {
408                // We have a queued value ready to be flushed.
409                self.pending_value = false;
410                Poll::Ready(Some(()))
411            } else {
412                // The async runtime would have no reason to be polling right now but we protect against it anyway.
413                Poll::Pending
414            };
415        }
416
417        match &mut self.inner {
418            Inner::Dyn(v) => match v.as_mut().poll_next_value(cx) {
419                Poll::Ready(v) => {
420                    if self.flush.is_none() {
421                        Poll::Ready(v)
422                    } else {
423                        match v {
424                            Some(v) => {
425                                self.pending_value = true;
426                                Poll::Pending
427                            }
428                            None => Poll::Ready(None),
429                        }
430                    }
431                }
432                Poll::Pending => Poll::Pending,
433            },
434            Inner::Value(v) => {
435                if self.flush.is_none() {
436                    // Poll::Ready(v.take().map(Err))
437                    todo!();
438                } else {
439                    Poll::Pending
440                }
441            }
442        }
443    }
444
445    /// TODO
446    pub fn poll_next(
447        &mut self,
448        cx: &mut Context<'_>,
449    ) -> Poll<Option<Result<DynOutput<'_>, ProcedureError>>> {
450        self.poll_inner(cx).map(|v| {
451            v.map(|_: ()| {
452                let Inner::Dyn(s) = &mut self.inner else {
453                    unreachable!(); // TODO: Handle this?
454                };
455                s.as_mut().value()
456            })
457        })
458    }
459
460    /// TODO
461    pub async fn next(&mut self) -> Option<Result<DynOutput<'_>, ProcedureError>> {
462        poll_fn(|cx| self.poll_inner(cx)).await.map(|_: ()| {
463            let Inner::Dyn(s) = &mut self.inner else {
464                unreachable!(); // TODO: Handle this?
465            };
466            s.as_mut().value()
467        })
468    }
469
470    /// TODO
471    // TODO: Should error be `String` type?
472    pub fn map<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String>, T>(
473        self,
474        map: F,
475    ) -> ProcedureStreamMap<F, T> {
476        ProcedureStreamMap { stream: self, map }
477    }
478}
479
/// A [`ProcedureStream`] combined with a function converting each of its results into `T`.
///
/// Created by `ProcedureStream::map`.
pub struct ProcedureStreamMap<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String>, T> {
    // The stream being driven.
    stream: ProcedureStream,
    // Applied to every value produced by `stream`.
    map: F,
}
484
485impl<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String>, T> ProcedureStreamMap<F, T> {
486    /// Start streaming data.
487    /// Refer to `Self::require_manual_stream` for more information.
488    pub fn stream(&mut self) {
489        self.stream.stream();
490    }
491
492    /// Will return `true` if the future has resolved.
493    ///
494    /// For a stream created via `Self::from_future*` this will be `true` once the future has resolved and for all other streams this will always be `true`.
495    pub fn resolved(&self) -> bool {
496        self.stream.resolved()
497    }
498
499    /// Will return `true` if the stream is ready to start streaming data.
500    ///
501    /// This is `false` until the `flush` function is called by the user.
502    pub fn flushable(&self) -> bool {
503        self.stream.flushable()
504    }
505}
506
507// TODO: Drop `Unpin` requirement
508impl<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String> + Unpin, T> Stream
509    for ProcedureStreamMap<F, T>
510{
511    type Item = T;
512
513    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
514        let this = self.get_mut();
515
516        this.stream.poll_inner(cx).map(|v| {
517            v.map(|_: ()| {
518                let Inner::Dyn(s) = &mut this.stream.inner else {
519                    unreachable!();
520                };
521
522                match (this.map)(s.as_mut().value()) {
523                    Ok(v) => v,
524                    // TODO: Exposing this error to the client or not?
525                    // TODO: Error type???
526                    Err(err) => {
527                        println!("Error serialzing {err:?}");
528                        todo!();
529                    }
530                }
531            })
532        })
533    }
534
535    fn size_hint(&self) -> (usize, Option<usize>) {
536        self.stream.size_hint()
537    }
538}
539
540impl fmt::Debug for ProcedureStream {
541    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
542        todo!();
543    }
544}
545
// Object-safe interface over a future and/or stream whose item type has been erased.
// Implemented by `GenericDynReturnValue`.
trait DynReturnValue: Send {
    // Poll for the next value. `Poll::Ready(Some(()))` means a value has been stored and can be
    // read with `value`; `Poll::Ready(None)` means the stream has ended.
    fn poll_next_value<'a>(self: Pin<&'a mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>>;
    // Access the value stored by the last `poll_next_value`, or take the stored error.
    fn value(self: Pin<&mut Self>) -> Result<DynOutput<'_>, ProcedureError>;
    // Bounds on the number of remaining items.
    fn size_hint(&self) -> (usize, Option<usize>);
    // Whether the future (if there is one) has resolved.
    fn resolved(&self) -> bool;
    // Whether a flush has been recorded for this stream.
    fn flushed(&self) -> bool;
}
553
pin_project! {
    // Concrete implementation of `DynReturnValue` for a source `S` yielding `Result<T, ProcedureError>`.
    // The behaviour is supplied through plain function pointers so each constructor can plug in its own logic.
    struct GenericDynReturnValue<S, T> {
        #[pin]
        inner: S,
        // `Stream::poll`
        poll: fn(Pin<&mut S>, &mut Context) -> Poll<Option<Result<T, ProcedureError>>>,
        // `Stream::size_hint`
        size_hint: fn(&S) -> (usize, Option<usize>),
        // convert the current value to a `DynOutput`
        as_value: fn(&mut Option<Result<T, ProcedureError>>) -> DynOutput<'_>,
        // detect when the stream has finished its future if it has one.
        resolved: fn(&S) -> bool,
        // has the user called `flush` within it?
        flushed: bool,
        // has the user panicked?
        unwound: bool,
        // the last yielded value. We place `T` here so we can type-erase it and avoid boxing every value.
        // we hold `Result<_, ProcedureError>` for `ProcedureStream::require_manual_stream` to be possible.
        // Be extremely careful changing this type as it's used in `DynOutput`'s downcasting!
        value: Option<Result<T, ProcedureError>>,
    }
}
576
577impl<S: Send, T: Send> DynReturnValue for GenericDynReturnValue<S, T> {
578    fn poll_next_value<'a>(mut self: Pin<&'a mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>> {
579        if self.unwound {
580            // The stream is now done.
581            return Poll::Ready(None);
582        }
583
584        let this = self.as_mut().project();
585        let r = catch_unwind(AssertUnwindSafe(|| {
586            let _ = this.value.take(); // Reset value to ensure `take` being misused causes it to panic.
587            (this.poll)(this.inner, cx).map(|v| {
588                v.map(|v| {
589                    *this.value = Some(v);
590                    ()
591                })
592            })
593        }));
594
595        match r {
596            Ok(v) => v,
597            Err(err) => {
598                *this.unwound = true;
599                *this.value = Some(Err(ProcedureError::Unwind(err)));
600                Poll::Ready(Some(()))
601            }
602        }
603    }
604
605    fn value(self: Pin<&mut Self>) -> Result<DynOutput<'_>, ProcedureError> {
606        let this = self.project();
607        match this.value {
608            Some(Err(_)) => {
609                let Some(Err(err)) = std::mem::replace(this.value, None) else {
610                    unreachable!(); // checked above
611                };
612                Err(err)
613            }
614            v => Ok((this.as_value)(v)),
615        }
616    }
617
618    fn size_hint(&self) -> (usize, Option<usize>) {
619        (self.size_hint)(&self.inner)
620    }
621
622    fn resolved(&self) -> bool {
623        (self.resolved)(&self.inner)
624    }
625    fn flushed(&self) -> bool {
626        self.flushed
627    }
628}