1use core::{
7 future::Future,
8 marker::PhantomData,
9 pin::Pin,
10 ptr::NonNull,
11 task::{Context, Poll},
12};
13
14use futures_core::{FusedStream, Stream};
15
16use crate::util::Maybe;
17
18#[inline]
19unsafe fn channel_send<T>(channel: NonNull<Maybe<Option<T>>>, value: T) {
20 if channel.as_ref().replace(Some(value)).is_some() {
21 panic!("Invalid use of stream sender");
22 }
23}
24
25#[inline]
26unsafe fn channel_recv<T>(channel: &Maybe<Option<T>>) -> Option<T> {
27 channel.replace(None)
28}
29
/// Handle passed to a stream's body for yielding items.
///
/// It holds a raw pointer to a one-item channel slot (in this file, the
/// `channel` field of the owning `AsyncStream`, see `poll_next`);
/// [`send`](AsyncStreamScope::send) writes into that slot.
#[derive(Debug)]
pub struct AsyncStreamScope<'a, T> {
    // Raw pointer to the channel slot, captured in `new`.
    channel: NonNull<Maybe<Option<T>>>,
    // `&'a mut Cell<T>` ties the scope to `'a` and makes it invariant in `T`.
    _marker: PhantomData<&'a mut std::cell::Cell<T>>,
}
36
37impl<T> AsyncStreamScope<'_, T> {
38 pub(crate) unsafe fn new(channel: &mut Maybe<Option<T>>) -> Self {
39 Self {
40 channel: NonNull::new_unchecked(channel),
41 _marker: PhantomData,
42 }
43 }
44
45 pub fn send<'a, 'b>(&'b mut self, value: T) -> AsyncStreamSend<'a, T>
48 where
49 'b: 'a,
50 {
51 unsafe {
52 channel_send(self.channel, value);
53 AsyncStreamSend {
54 channel: self.channel.as_ref(),
55 first: true,
56 _marker: PhantomData,
57 }
58 }
59 }
60}
61
62impl<T> Clone for AsyncStreamScope<'_, T> {
63 fn clone(&self) -> Self {
64 Self {
65 channel: self.channel,
66 _marker: PhantomData,
67 }
68 }
69}
70
/// Future returned by [`AsyncStreamScope::send`].
///
/// The item has already been written to the channel when this future is
/// created; polling it stays pending while that item is still in the channel.
#[derive(Debug)]
pub struct AsyncStreamSend<'a, T> {
    // Shared view of the channel slot the item was written into.
    channel: &'a Maybe<Option<T>>,
    // `true` until the first poll; the first poll always returns `Pending`.
    first: bool,
    _marker: PhantomData<&'a mut std::cell::Cell<T>>,
}
78
79impl<T> Future for AsyncStreamSend<'_, T> {
80 type Output = ();
81
82 fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
83 if self.first || unsafe { self.channel.as_ref().as_ref().is_some() } {
84 self.first = false;
87 Poll::Pending
88 } else {
89 Poll::Ready(())
90 }
91 }
92}
93
// SAFETY: `AsyncStreamSend` holds only a shared reference to the channel slot
// and merely checks whether it is occupied; it never moves a `T` itself.
// NOTE(review): there is no `T: Send` bound, and sending `&Maybe<Option<T>>`
// to another thread would normally require `Maybe<Option<T>>: Sync`. This is
// presumably meant to keep futures suspended in `send` from becoming `!Send`;
// confirm it is sound for every `T`.
unsafe impl<T> Send for AsyncStreamSend<'_, T> {}
95
96#[derive(Debug)]
98pub struct AsyncStream<'a, T, I, F> {
99 state: Maybe<AsyncStreamState<I, F>>,
100 channel: Maybe<Option<T>>,
101 _marker: PhantomData<&'a mut std::cell::Cell<T>>,
102}
103
104#[derive(Debug)]
105enum AsyncStreamState<I, F> {
106 Init(I),
107 Poll(F),
108 Complete,
109}
110
111pub fn make_stream<'a, T, I, F>(init: I) -> AsyncStream<'a, T, I, F>
113where
114 I: FnOnce(AsyncStreamScope<'a, T>) -> F + 'a,
115 F: Future<Output = ()> + 'a,
116{
117 AsyncStream {
118 state: AsyncStreamState::Init(init).into(),
119 channel: None.into(),
120 _marker: PhantomData,
121 }
122}
123
124impl<'a, T, I, F> Stream for AsyncStream<'a, T, I, F>
125where
126 I: FnOnce(AsyncStreamScope<'a, T>) -> F,
127 F: Future<Output = ()>,
128{
129 type Item = T;
130
131 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
132 unsafe {
133 let slf = Pin::get_unchecked_mut(self);
134 loop {
135 match slf.state.as_ref() {
136 AsyncStreamState::Init(_) => {
137 let init = match slf.state.load() {
138 AsyncStreamState::Init(init) => init,
139 _ => unreachable!(),
140 };
141 let fut = init(AsyncStreamScope::new(&mut slf.channel));
142 slf.state.store(AsyncStreamState::Poll(fut));
143 }
144 AsyncStreamState::Poll(_) => {
145 let poll = match slf.state.as_mut() {
146 AsyncStreamState::Poll(poll) => Pin::new_unchecked(poll),
147 _ => unreachable!(),
148 };
149 if let Poll::Ready(_) = poll.poll(cx) {
150 slf.state.replace(AsyncStreamState::Complete);
151 } else {
152 break if let Some(val) = channel_recv(&slf.channel) {
153 Poll::Ready(Some(val))
154 } else {
155 Poll::Pending
156 };
157 }
158 }
159 AsyncStreamState::Complete => {
160 break Poll::Ready(channel_recv(&slf.channel));
161 }
162 }
163 }
164 }
165 }
166}
167
168impl<'a, T, I, F> Drop for AsyncStream<'a, T, I, F> {
169 fn drop(&mut self) {
170 unsafe {
171 self.channel.clear();
172 self.state.clear()
173 };
174 }
175}
176
177impl<'a, T, I, F> FusedStream for AsyncStream<'a, T, I, F>
178where
179 I: FnOnce(AsyncStreamScope<'a, T>) -> F,
180 F: Future<Output = ()>,
181{
182 fn is_terminated(&self) -> bool {
183 matches!(unsafe { self.state.as_ref() }, AsyncStreamState::Complete)
184 }
185}
186
/// Adapter used by `try_stream!`: drives a fallible future and, if it finishes
/// with `Err`, forwards the error into the stream's channel as a final item.
#[derive(Debug)]
pub struct TryAsyncStreamSend<'a, T, E, F> {
    // Channel slot pointer, copied out of the `AsyncStreamScope` given to `new`.
    channel: NonNull<Maybe<Option<Result<T, E>>>>,
    // The wrapped future; `poll` pins it in place via `map_unchecked_mut`.
    fut: F,
    _marker: PhantomData<&'a mut std::cell::Cell<T>>,
}
194
// SAFETY: NOTE(review) — this impl is unconditional: it does not require `T`,
// `E` or `F` to be `Send`. Because `F` is the user's `try_stream!` body, it
// also hides from the compiler any non-`Send` value (e.g. an `Rc` or a
// `MutexGuard`) held across an `.await` in that body. It presumably exists to
// mask the `!Send` `AsyncStreamScope` (a `NonNull` holder) captured by the
// body. Confirm whether bounds such as `T: Send, E: Send, F: Send` together
// with a `Send` scope should replace it.
unsafe impl<T, E, F> Send for TryAsyncStreamSend<'_, T, E, F> {}
196
197impl<'a, T, E, F> TryAsyncStreamSend<'a, T, E, F> {
198 pub fn new(sender: AsyncStreamScope<'a, Result<T, E>>, fut: F) -> Self {
200 Self {
201 channel: sender.channel,
202 fut,
203 _marker: PhantomData,
204 }
205 }
206}
207
208impl<'a, T, E, F> Future for TryAsyncStreamSend<'a, T, E, F>
209where
210 F: Future<Output = Result<(), E>>,
211{
212 type Output = ();
213 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214 unsafe {
215 let channel = self.channel;
216 let fut = self.map_unchecked_mut(|s| &mut s.fut);
217 fut.poll(cx).map(|result| {
218 if let Err(err) = result {
219 channel_send(channel, Err(err));
220 }
221 })
222 }
223 }
224}
225
/// Builds a stream from a block of async code.
///
/// Inside the block, `send!(value)` hands `value` to the stream as its next
/// item and awaits until the stream has taken it. The block becomes the body
/// of an `async move` block, so it captures variables by move.
///
/// ```ignore
/// let s = stream! {
///     for i in 0..3 {
///         send!(i);
///     }
/// };
/// ```
#[macro_export]
macro_rules! stream {
    {$($block:tt)*} => {
        $crate::make_stream(move |mut __sender| async move {
            // `send!` may go unused if the block never yields an item.
            #[allow(unused)]
            macro_rules! send {
                ($v:expr) => {
                    __sender.send($v).await;
                }
            }
            $($block)*
        })
    }
}
241
/// Builds a fallible stream whose items are `Result<T, E>`.
///
/// Inside the block, `send!(value)` yields `Ok(value)`. The block must
/// evaluate to `Result<(), E>`, so `?` can be used inside it. If the block
/// ends with `Err(e)`, `Err(e)` is yielded as the last item before the stream
/// ends.
#[macro_export]
macro_rules! try_stream {
    {$($block:tt)*} => {
        $crate::make_stream(move |mut __sender| {
            $crate::TryAsyncStreamSend::new(__sender.clone(), async move {
                // A body that never calls `send!` must not trigger an
                // unused-macro warning.
                #[allow(unused)]
                macro_rules! send {
                    ($v:expr) => {
                        __sender.send(Ok($v)).await;
                    }
                }
                $($block)*
            })
        })
    }
}