1use std::{
2 future::Future,
3 pin::Pin,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8 task::{Context, Poll, Wake, Waker},
9};
10
11use atomic_waker::AtomicWaker;
12use futures_core::Stream;
13
14struct FlagWaker {
15 waker: Arc<AtomicWaker>,
16 awoken: AtomicBool,
17}
18
19impl Wake for FlagWaker {
20 fn wake(self: Arc<Self>) {
21 self.set_awoken();
22 self.waker.wake();
23 }
24}
25
26impl FlagWaker {
27 pub(crate) fn new(waker: Arc<AtomicWaker>) -> Self {
28 Self {
29 waker,
30 awoken: AtomicBool::new(true),
32 }
33 }
34
35 pub(crate) fn waker_pair(waker: Arc<AtomicWaker>) -> (Arc<Self>, Waker) {
36 let this = Arc::new(Self::new(waker));
37 let waker = this.clone().into();
38 (this, waker)
39 }
40
41 pub(crate) fn check_awoken(&self) -> bool {
42 self.awoken.swap(false, Ordering::Relaxed)
43 }
44
45 pub(crate) fn set_awoken(&self) {
46 self.awoken.store(true, Ordering::Relaxed);
47 }
48}
49
50type PinFut<'a, T> = Pin<&'a mut dyn Future<Output = T>>;
51type PinStream<'a, T> = Pin<&'a mut dyn Stream<Item = T>>;
52
53enum Inflight<'a, T> {
54 Fut(PinFut<'a, T>),
55 Done(T),
56}
57
58impl<T> Inflight<'_, T> {
59 fn unwrap_done(self) -> T {
60 match self {
61 Inflight::Fut(_) => panic!("expected inflight future to be done"),
62 Inflight::Done(val) => val,
63 }
64 }
65}
66
67#[doc(hidden)]
68#[must_use = "Futures do nothing unless polled"]
69pub struct JoinFuture<'a, T, const N: usize> {
70 base_waker: Arc<AtomicWaker>,
71 inflight: Option<[Inflight<'a, T>; N]>,
72 wakers: [(Arc<FlagWaker>, Waker); N],
73}
74
75impl<'a, T, const N: usize> JoinFuture<'a, T, N> {
76 pub fn new(futures: [PinFut<'a, T>; N]) -> Self {
77 let base_waker = Arc::new(AtomicWaker::new());
78 Self {
79 inflight: Some(futures.map(Inflight::Fut)),
80 wakers: std::array::from_fn(|_| FlagWaker::waker_pair(base_waker.clone())),
81 base_waker,
82 }
83 }
84}
85
86impl<T: Unpin, const N: usize> Future for JoinFuture<'_, T, N> {
87 type Output = [T; N];
88
89 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
90 let this = self.get_mut();
91 this.base_waker.register(cx.waker());
92 poll_join(this.inflight.as_mut().unwrap(), &mut this.wakers)
93 .map(|_| this.inflight.take().unwrap().map(Inflight::unwrap_done))
94 }
95}
96
97fn poll_join<T>(inflights: &mut [Inflight<T>], wakers: &mut [(Arc<FlagWaker>, Waker)]) -> Poll<()> {
98 let mut out = Poll::Ready(());
99 for (inflight, (waker_data, waker)) in inflights.iter_mut().zip(wakers.iter_mut()) {
100 if let Inflight::Fut(fut) = inflight {
101 if waker_data.check_awoken() {
102 if let Poll::Ready(out) = fut.as_mut().poll(&mut Context::from_waker(waker)) {
103 *inflight = Inflight::Done(out);
104 continue;
105 }
106 }
107 out = Poll::Pending;
108 }
109 }
110 out
111}
112
/// Builds an `async` block that awaits all given futures together and
/// evaluates to an array of their outputs, in argument order.
///
/// Each argument is pinned locally, so the futures do not need to be
/// pinned by the caller. All futures must share one output type because they
/// are stored in a single array.
#[macro_export]
macro_rules! join {
    ($($fut:expr),+ $(,)?) => {
        // `::std` (leading `::`) keeps the path from being captured by an item
        // named `std` at the call site.
        async { $crate::JoinFuture::new([$(::std::pin::pin!($fut)),+]).await }
    };
}
143
144#[doc(hidden)]
145#[must_use = "Streams do nothing unless polled"]
146pub struct MergeFutureStream<'a, T, const N: usize> {
147 base_waker: Arc<AtomicWaker>,
148 futures: [Option<PinFut<'a, T>>; N],
149 wakers: [(Arc<FlagWaker>, Waker); N],
150 idx: usize,
151 none_count: usize,
152}
153
154impl<'a, T, const N: usize> MergeFutureStream<'a, T, N> {
155 pub fn new(futures: [PinFut<'a, T>; N]) -> Self {
156 let base_waker = Arc::new(AtomicWaker::new());
157 Self {
158 futures: futures.map(Some),
159 wakers: std::array::from_fn(|_| FlagWaker::waker_pair(base_waker.clone())),
160 idx: 0,
161 none_count: 0,
162 base_waker,
163 }
164 }
165}
166
167impl<T, const N: usize> Stream for MergeFutureStream<'_, T, N> {
168 type Item = T;
169
170 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171 let this = self.get_mut();
172 this.base_waker.register(cx.waker());
173 poll_merged(
174 &mut this.futures,
175 &mut this.wakers,
176 &mut this.idx,
177 &mut this.none_count,
178 |fut, cx| fut.as_mut().poll(cx),
179 |x| Some(x),
180 |_| true,
181 )
182 }
183}
184
185#[allow(clippy::too_many_arguments)]
186fn poll_merged<P, O, T, PF, OF, NF>(
187 pollers: &mut [Option<P>],
188 wakers: &mut [(Arc<FlagWaker>, Waker)],
189 idx: &mut usize,
190 none_count: &mut usize,
191 mut poll_fn: PF,
192 mut opt_fn: OF,
193 mut none_fn: NF,
194) -> Poll<Option<T>>
195where
196 PF: FnMut(&mut P, &mut Context) -> Poll<O>,
197 OF: FnMut(O) -> Option<T>,
198 NF: FnMut(&O) -> bool,
199{
200 let len = pollers.len();
201
202 let (futs_past, futs_remain) = pollers.split_at_mut(*idx);
203 let (wakers_past, wakers_remain) = wakers.split_at_mut(*idx);
204 let iter_past = futs_past.iter_mut().zip(wakers_past.iter_mut());
205 let iter_remain = futs_remain.iter_mut().zip(wakers_remain.iter_mut());
206 let iter = iter_remain.chain(iter_past);
208
209 for (poller_opt, (waker_data, waker)) in iter {
210 if let Some(poller) = poller_opt {
211 if waker_data.check_awoken() {
212 if let Poll::Ready(out) = poll_fn(poller, &mut Context::from_waker(waker)) {
213 if none_fn(&out) {
214 *poller_opt = None;
215 *none_count += 1;
216 }
217 if let Some(ret) = opt_fn(out) {
218 waker_data.set_awoken();
221 return Poll::Ready(Some(ret));
222 }
223 }
224 }
225 }
226 *idx = (*idx + 1) % len;
228 if *none_count == len {
230 return Poll::Ready(None);
231 }
232 }
233 Poll::Pending
234}
235
/// Builds a `MergeFutureStream` from the given expressions.
///
/// The arguments are passed to `MergeFutureStream::new` as an array without
/// any pinning, so each one must already be a pinned mutable reference to a
/// future.
#[macro_export]
macro_rules! merge_futures {
    ($($future:expr),+ $(,)?) => {
        $crate::MergeFutureStream::new([$($future),+])
    };
}
282
283#[doc(hidden)]
284#[must_use = "Streams do nothing unless polled"]
285pub struct MergeStream<'a, T, const N: usize> {
286 base_waker: Arc<AtomicWaker>,
287 streams: [Option<PinStream<'a, T>>; N],
288 wakers: [(Arc<FlagWaker>, Waker); N],
289 idx: usize,
290 none_count: usize,
291}
292
293impl<'a, T, const N: usize> MergeStream<'a, T, N> {
294 pub fn new(streams: [PinStream<'a, T>; N]) -> Self {
295 let base_waker = Arc::new(AtomicWaker::new());
296 Self {
297 streams: streams.map(Some),
298 wakers: std::array::from_fn(|_| FlagWaker::waker_pair(base_waker.clone())),
299 idx: 0,
300 none_count: 0,
301 base_waker,
302 }
303 }
304}
305
306impl<T, const N: usize> Stream for MergeStream<'_, T, N> {
307 type Item = T;
308
309 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
310 let this = self.get_mut();
311 this.base_waker.register(cx.waker());
312 poll_merged(
313 &mut this.streams,
314 &mut this.wakers,
315 &mut this.idx,
316 &mut this.none_count,
317 |fut, cx| fut.as_mut().poll_next(cx),
318 |o| o,
319 |o| o.is_none(),
320 )
321 }
322}
323
/// Builds a `MergeStream` from the given expressions.
///
/// The arguments are passed to `MergeStream::new` as an array without any
/// pinning, so each one must already be a pinned mutable reference to a
/// stream.
#[macro_export]
macro_rules! merge_streams {
    ($($stream:expr),+ $(,)?) => {
        $crate::MergeStream::new([$($stream),+])
    };
}
360
#[cfg(test)]
mod tests {
    use crate::test::MockWaker;

    use super::*;

    /// A flag waker must raise its flag on every wake-up and forward the
    /// wake-up to whichever waker is currently registered on the shared
    /// `AtomicWaker`.
    #[test]
    fn flag_waker_multiple_wakers() {
        let first = Arc::new(MockWaker::default());
        let second = Arc::new(MockWaker::default());
        let base = Arc::new(AtomicWaker::new());
        let (flag, handle) = FlagWaker::waker_pair(Arc::clone(&base));

        // Round one sees the initially raised flag; round two sees the flag
        // raised by the previous round's wake-up.
        for mock in [&first, &second] {
            assert!(flag.check_awoken());
            // Checking must have cleared the flag.
            assert!(!flag.awoken.load(Ordering::Relaxed));
            base.register(&Arc::clone(mock).into());
            handle.wake_by_ref();
            assert!(mock.get());
        }
    }
}