suspend_channel/
async_stream.rs

// This currently does not pass under Miri, for the same reason generators have issues:
// https://github.com/rust-lang/unsafe-code-guidelines/issues/148
// Each poll of an `AsyncStream` creates a `Pin` that invalidates the channel
// pointer held by the associated `AsyncStreamScope`.
5
6use core::{
7    future::Future,
8    marker::PhantomData,
9    pin::Pin,
10    ptr::NonNull,
11    task::{Context, Poll},
12};
13
14use futures_core::{FusedStream, Stream};
15
16use crate::util::Maybe;
17
18#[inline]
19unsafe fn channel_send<T>(channel: NonNull<Maybe<Option<T>>>, value: T) {
20    if channel.as_ref().replace(Some(value)).is_some() {
21        panic!("Invalid use of stream sender");
22    }
23}
24
25#[inline]
26unsafe fn channel_recv<T>(channel: &Maybe<Option<T>>) -> Option<T> {
27    channel.replace(None)
28}
29
30/// A utility class for providing values to the stream.
31#[derive(Debug)]
32pub struct AsyncStreamScope<'a, T> {
33    channel: NonNull<Maybe<Option<T>>>,
34    _marker: PhantomData<&'a mut std::cell::Cell<T>>,
35}
36
37impl<T> AsyncStreamScope<'_, T> {
38    pub(crate) unsafe fn new(channel: &mut Maybe<Option<T>>) -> Self {
39        Self {
40            channel: NonNull::new_unchecked(channel),
41            _marker: PhantomData,
42        }
43    }
44
45    /// Dispatch a value to the stream, returning a [`Future`] which will resolve
46    /// when the value has been received.
47    pub fn send<'a, 'b>(&'b mut self, value: T) -> AsyncStreamSend<'a, T>
48    where
49        'b: 'a,
50    {
51        unsafe {
52            channel_send(self.channel, value);
53            AsyncStreamSend {
54                channel: self.channel.as_ref(),
55                first: true,
56                _marker: PhantomData,
57            }
58        }
59    }
60}
61
62impl<T> Clone for AsyncStreamScope<'_, T> {
63    fn clone(&self) -> Self {
64        Self {
65            channel: self.channel,
66            _marker: PhantomData,
67        }
68    }
69}
70
71/// A [`Future`] which resolves when the dispatched value has been received.
72#[derive(Debug)]
73pub struct AsyncStreamSend<'a, T> {
74    channel: &'a Maybe<Option<T>>,
75    first: bool,
76    _marker: PhantomData<&'a mut std::cell::Cell<T>>,
77}
78
79impl<T> Future for AsyncStreamSend<'_, T> {
80    type Output = ();
81
82    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
83        if self.first || unsafe { self.channel.as_ref().as_ref().is_some() } {
84            // always wait on first poll - the sender has just been filled.
85            // remain pending while the lock is occupied
86            self.first = false;
87            Poll::Pending
88        } else {
89            Poll::Ready(())
90        }
91    }
92}
93
94unsafe impl<T> Send for AsyncStreamSend<'_, T> {}
95
96/// A [`Stream`] implementation wrapping a generator `Future`.
97#[derive(Debug)]
98pub struct AsyncStream<'a, T, I, F> {
99    state: Maybe<AsyncStreamState<I, F>>,
100    channel: Maybe<Option<T>>,
101    _marker: PhantomData<&'a mut std::cell::Cell<T>>,
102}
103
104#[derive(Debug)]
105enum AsyncStreamState<I, F> {
106    Init(I),
107    Poll(F),
108    Complete,
109}
110
111/// Construct a new [`AsyncStream`] from a generator function.
112pub fn make_stream<'a, T, I, F>(init: I) -> AsyncStream<'a, T, I, F>
113where
114    I: FnOnce(AsyncStreamScope<'a, T>) -> F + 'a,
115    F: Future<Output = ()> + 'a,
116{
117    AsyncStream {
118        state: AsyncStreamState::Init(init).into(),
119        channel: None.into(),
120        _marker: PhantomData,
121    }
122}
123
124impl<'a, T, I, F> Stream for AsyncStream<'a, T, I, F>
125where
126    I: FnOnce(AsyncStreamScope<'a, T>) -> F,
127    F: Future<Output = ()>,
128{
129    type Item = T;
130
131    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
132        unsafe {
133            let slf = Pin::get_unchecked_mut(self);
134            loop {
135                match slf.state.as_ref() {
136                    AsyncStreamState::Init(_) => {
137                        let init = match slf.state.load() {
138                            AsyncStreamState::Init(init) => init,
139                            _ => unreachable!(),
140                        };
141                        let fut = init(AsyncStreamScope::new(&mut slf.channel));
142                        slf.state.store(AsyncStreamState::Poll(fut));
143                    }
144                    AsyncStreamState::Poll(_) => {
145                        let poll = match slf.state.as_mut() {
146                            AsyncStreamState::Poll(poll) => Pin::new_unchecked(poll),
147                            _ => unreachable!(),
148                        };
149                        if let Poll::Ready(_) = poll.poll(cx) {
150                            slf.state.replace(AsyncStreamState::Complete);
151                        } else {
152                            break if let Some(val) = channel_recv(&slf.channel) {
153                                Poll::Ready(Some(val))
154                            } else {
155                                Poll::Pending
156                            };
157                        }
158                    }
159                    AsyncStreamState::Complete => {
160                        break Poll::Ready(channel_recv(&slf.channel));
161                    }
162                }
163            }
164        }
165    }
166}
167
168impl<'a, T, I, F> Drop for AsyncStream<'a, T, I, F> {
169    fn drop(&mut self) {
170        unsafe {
171            self.channel.clear();
172            self.state.clear()
173        };
174    }
175}
176
177impl<'a, T, I, F> FusedStream for AsyncStream<'a, T, I, F>
178where
179    I: FnOnce(AsyncStreamScope<'a, T>) -> F,
180    F: Future<Output = ()>,
181{
182    fn is_terminated(&self) -> bool {
183        matches!(unsafe { self.state.as_ref() }, AsyncStreamState::Complete)
184    }
185}
186
187/// A [`Future`] which resolves when the dispatched value has been received.
188#[derive(Debug)]
189pub struct TryAsyncStreamSend<'a, T, E, F> {
190    channel: NonNull<Maybe<Option<Result<T, E>>>>,
191    fut: F,
192    _marker: PhantomData<&'a mut std::cell::Cell<T>>,
193}
194
195unsafe impl<T, E, F> Send for TryAsyncStreamSend<'_, T, E, F> {}
196
197impl<'a, T, E, F> TryAsyncStreamSend<'a, T, E, F> {
198    /// Construct a new `TryAsyncStreamSend` from an `AsyncStreamScope`.
199    pub fn new(sender: AsyncStreamScope<'a, Result<T, E>>, fut: F) -> Self {
200        Self {
201            channel: sender.channel,
202            fut,
203            _marker: PhantomData,
204        }
205    }
206}
207
208impl<'a, T, E, F> Future for TryAsyncStreamSend<'a, T, E, F>
209where
210    F: Future<Output = Result<(), E>>,
211{
212    type Output = ();
213    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214        unsafe {
215            let channel = self.channel;
216            let fut = self.map_unchecked_mut(|s| &mut s.fut);
217            fut.poll(cx).map(|result| {
218                if let Err(err) = result {
219                    channel_send(channel, Err(err));
220                }
221            })
222        }
223    }
224}
225
/// Build an async stream from a block of generator-style code.
///
/// Inside the block, `send!(value)` hands `value` to the stream and awaits its
/// receipt. The block is placed in an `async move` block, so it may `.await`
/// other futures, and it must evaluate to `()`.
#[macro_export]
macro_rules! stream {
    ($($body:tt)*) => {
        $crate::make_stream(move |mut __sender| async move {
            // Not every generator sends a value, so tolerate an unused helper.
            #[allow(unused)]
            macro_rules! send {
                ($value:expr) => {
                    __sender.send($value).await;
                };
            }
            $($body)*
        })
    };
}
241
/// Build an async stream of `Result<T, E>` from a block of generator-style code.
///
/// Inside the block, `send!(value)` yields `Ok(value)` to the stream and awaits
/// its receipt. The block must evaluate to `Result<(), E>`, so `?` may be used.
/// If it ends in `Err(e)`, that error is delivered as the final item of the stream.
#[macro_export]
macro_rules! try_stream {
    {$($block:tt)*} => {
        $crate::make_stream(move |mut __sender| {
            $crate::TryAsyncStreamSend::new(__sender.clone(), async move {
                // Match `stream!`: a block that never calls `send!` should not
                // trigger an unused-macro warning.
                #[allow(unused)]
                macro_rules! send {
                    ($v:expr) => {
                        __sender.send(Ok($v)).await;
                    }
                }
                $($block)*
            })
        })
    }
}