1#![doc = include_str!("../README.md")]
2
3use std::{
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::{Poll, Waker},
7};
8
9use futures_util::{Future, FutureExt, Stream};
10use pin_project_lite::pin_project;
11
12pub struct StreamEmitter<T> {
14 inner: Arc<Mutex<Inner<T>>>,
15}
16
17pub struct TryStreamEmitter<T, E> {
19 inner: Arc<Mutex<Inner<Result<T, E>>>>,
20}
21
/// State shared between a stream and its emitter.
struct Inner<T> {
    /// Item stored by the emitter that `poll_next` has not returned yet.
    value: Option<T>,
    /// Waker of the task polling the stream: installed by `poll_next`, taken by `emit`.
    waker: Option<Waker>,
}
26
27pin_project! {
28 pub struct FnStream<T, Fut: Future<Output = ()>> {
30 #[pin]
31 fut: Fut,
32 inner: Arc<Mutex<Inner<T>>>,
33 }
34}
35
36pub fn fn_stream<T, Fut: Future<Output = ()>>(
56 func: impl FnOnce(StreamEmitter<T>) -> Fut,
57) -> FnStream<T, Fut> {
58 FnStream::new(func)
59}
60
61impl<T, Fut: Future<Output = ()>> FnStream<T, Fut> {
62 fn new<F: FnOnce(StreamEmitter<T>) -> Fut>(func: F) -> Self {
63 let inner = Arc::new(Mutex::new(Inner {
64 value: None,
65 waker: None,
66 }));
67 let emitter = StreamEmitter {
68 inner: inner.clone(),
69 };
70 let fut = func(emitter);
71 Self { fut, inner }
72 }
73}
74
75impl<T, Fut: Future<Output = ()>> Stream for FnStream<T, Fut> {
76 type Item = T;
77
78 fn poll_next(
79 self: Pin<&mut Self>,
80 cx: &mut std::task::Context<'_>,
81 ) -> Poll<Option<Self::Item>> {
82 let mut this = self.project();
83
84 this.inner.lock().expect("Mutex was poisoned").waker = Some(cx.waker().clone());
85 let r = this.fut.poll_unpin(cx);
86 match r {
87 std::task::Poll::Ready(()) => Poll::Ready(None),
88 std::task::Poll::Pending => {
89 let value = this.inner.lock().expect("Mutex was poisoned").value.take();
90 match value {
91 None => Poll::Pending,
92 Some(value) => Poll::Ready(Some(value)),
93 }
94 }
95 }
96 }
97}
98
99pub fn try_fn_stream<T, E, Fut: Future<Output = Result<(), E>>>(
128 func: impl FnOnce(TryStreamEmitter<T, E>) -> Fut,
129) -> TryFnStream<T, E, Fut> {
130 TryFnStream::new(func)
131}
132
133pin_project! {
134 pub struct TryFnStream<T, E, Fut: Future<Output = Result<(), E>>> {
136 is_err: bool,
137 #[pin]
138 fut: Fut,
139 inner: Arc<Mutex<Inner<Result<T, E>>>>,
140 }
141}
142
143impl<T, E, Fut: Future<Output = Result<(), E>>> TryFnStream<T, E, Fut> {
144 fn new<F: FnOnce(TryStreamEmitter<T, E>) -> Fut>(func: F) -> Self {
145 let inner = Arc::new(Mutex::new(Inner {
146 value: None,
147 waker: None,
148 }));
149 let emitter = TryStreamEmitter {
150 inner: inner.clone(),
151 };
152 let fut = func(emitter);
153 Self {
154 is_err: false,
155 fut,
156 inner,
157 }
158 }
159}
160
161impl<T, E, Fut: Future<Output = Result<(), E>>> Stream for TryFnStream<T, E, Fut> {
162 type Item = Result<T, E>;
163
164 fn poll_next(
165 self: Pin<&mut Self>,
166 cx: &mut std::task::Context<'_>,
167 ) -> Poll<Option<Self::Item>> {
168 if self.is_err {
169 return Poll::Ready(None);
170 }
171 let mut this = self.project();
172 this.inner.lock().expect("Mutex was poisoned").waker = Some(cx.waker().clone());
173 let r = this.fut.poll_unpin(cx);
174 match r {
175 std::task::Poll::Ready(Ok(())) => Poll::Ready(None),
176 std::task::Poll::Ready(Err(e)) => {
177 *this.is_err = true;
178 Poll::Ready(Some(Err(e)))
179 }
180 std::task::Poll::Pending => {
181 let value = this.inner.lock().expect("Mutex was poisoned").value.take();
182 match value {
183 None => Poll::Pending,
184 Some(value) => Poll::Ready(Some(value)),
185 }
186 }
187 }
188 }
189}
190
191impl<T> StreamEmitter<T> {
192 #[must_use = "Ensure that emit() is awaited"]
199 pub fn emit(&self, value: T) -> CollectFuture {
200 let mut inner = self.inner.lock().expect("Mutex was poisoned");
201 let inner = &mut *inner;
202 if inner.value.is_some() {
203 panic!("StreamEmitter::emit() was called without `.await`'ing result of previous emit")
204 }
205 inner.value = Some(value);
206 inner
207 .waker
208 .take()
209 .expect("StreamEmitter::emit() should only be called in context of Future::poll()")
210 .wake();
211 CollectFuture { polled: false }
212 }
213}
214
215impl<T, E> TryStreamEmitter<T, E> {
216 fn internal_emit(&self, res: Result<T, E>) -> CollectFuture {
217 let mut inner = self.inner.lock().expect("Mutex was poisoned");
218 let inner = &mut *inner;
219 if inner.value.is_some() {
220 panic!(
221 "TreStreamEmitter::emit/emit_err() was called without `.await`'ing result of previous collect"
222 )
223 }
224 inner.value = Some(res);
225 inner
226 .waker
227 .take()
228 .expect("TreStreamEmitter::emit/emit_err() should only be called in context of Future::poll()")
229 .wake();
230 CollectFuture { polled: false }
231 }
232
233 #[must_use = "Ensure that emit() is awaited"]
240 pub fn emit(&self, value: T) -> CollectFuture {
241 self.internal_emit(Ok(value))
242 }
243
244 #[must_use = "Ensure that emit_err() is awaited"]
251 pub fn emit_err(&self, err: E) -> CollectFuture {
252 self.internal_emit(Err(err))
253 }
254}
255
/// Future returned by [`StreamEmitter::emit`], [`TryStreamEmitter::emit`] and
/// [`TryStreamEmitter::emit_err`].
///
/// It is pending on its first poll and ready on every later poll. Awaiting it therefore
/// suspends the producer future exactly once, which lets the stream return the emitted item.
#[derive(Debug)]
pub struct CollectFuture {
    // Whether this future has already been polled once.
    polled: bool,
}
260
261impl Future for CollectFuture {
262 type Output = ();
263
264 fn poll(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
265 if self.polled {
266 Poll::Ready(())
267 } else {
268 self.get_mut().polled = true;
269 Poll::Pending
270 }
271 }
272}
273
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures_util::{pin_mut, StreamExt};

    use super::*;

    #[test]
    fn infallible_works() {
        futures_executor::block_on(async {
            let stream = fn_stream(|collector| async move {
                eprintln!("stream 1");
                collector.emit(1).await;
                eprintln!("stream 2");
                collector.emit(2).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    #[test]
    fn infallible_lifetime() {
        let a = 1;
        futures_executor::block_on(async {
            let b = 2;
            let a = &a;
            let b = &b;
            let stream = fn_stream(|collector| async move {
                eprintln!("stream 1");
                collector.emit(a).await;
                eprintln!("stream 2");
                collector.emit(b).await;
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(a), stream.next().await);
            assert_eq!(Some(b), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // `expected` pins the panic to the double-emit check; a bare `#[should_panic]` would also
    // pass if one of the assertions below failed for an unrelated reason.
    #[test]
    #[should_panic(expected = "StreamEmitter::emit() was called without")]
    fn infallible_panics_on_multiple_collects() {
        futures_executor::block_on(async {
            #[allow(unused_must_use)]
            let stream = fn_stream(|collector| async move {
                eprintln!("stream 1");
                collector.emit(1);
                collector.emit(2);
                eprintln!("stream 3");
            });
            pin_mut!(stream);
            assert_eq!(Some(1), stream.next().await);
            assert_eq!(Some(2), stream.next().await);
            assert_eq!(None, stream.next().await);
        });
    }

    // Fallible counterpart of the test above. The substring matches the double-emit panic
    // raised by `TryStreamEmitter::emit`/`emit_err`.
    #[test]
    #[should_panic(expected = "emit/emit_err() was called without")]
    fn fallible_panics_on_multiple_emits() {
        futures_executor::block_on(async {
            #[allow(unused_must_use)]
            let stream = try_fn_stream(|collector| async move {
                collector.emit(1);
                collector.emit(2);
                Ok::<(), std::io::Error>(())
            });
            pin_mut!(stream);
            let _ = stream.next().await;
        });
    }

    #[test]
    fn fallible_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|collector| async move {
                eprintln!("try stream 1");
                collector.emit(1).await;
                eprintln!("try stream 2");
                collector.emit(2).await;
                eprintln!("try stream 3");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    #[test]
    fn fallible_emit_err_works() {
        futures_executor::block_on(async {
            let stream = try_fn_stream(|collector| async move {
                eprintln!("try stream 1");
                collector.emit(1).await;
                eprintln!("try stream 2");
                collector.emit(2).await;
                eprintln!("try stream 3");
                collector
                    .emit_err(std::io::Error::from(ErrorKind::Other))
                    .await;
                eprintln!("try stream 4");
                Err(std::io::Error::from(ErrorKind::Other))
            });
            pin_mut!(stream);
            assert_eq!(1, stream.next().await.unwrap().unwrap());
            assert_eq!(2, stream.next().await.unwrap().unwrap());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.unwrap().is_err());
            assert!(stream.next().await.is_none());
        });
    }

    #[test]
    fn method_async() {
        struct St {
            a: String,
        }

        impl St {
            async fn f1(&self) -> impl Stream<Item = &str> {
                self.f2().await
            }

            async fn f2(&self) -> impl Stream<Item = &str> {
                fn_stream(|collector| async move {
                    collector.emit(self.a.as_str()).await;
                    collector.emit(self.a.as_str()).await;
                    collector.emit(self.a.as_str()).await;
                })
            }
        }

        futures_executor::block_on(async {
            let l = St {
                a: "qwe".to_owned(),
            };
            let s = l.f1().await;
            let z: Vec<&str> = s.collect().await;
            assert_eq!(z, ["qwe", "qwe", "qwe"]);
        })
    }
}