local_runtime/
concurrency.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::{
5        atomic::{AtomicBool, Ordering},
6        Arc,
7    },
8    task::{Context, Poll, Wake, Waker},
9};
10
11use atomic_waker::AtomicWaker;
12use futures_core::Stream;
13
14struct FlagWaker {
15    waker: Arc<AtomicWaker>,
16    awoken: AtomicBool,
17}
18
19impl Wake for FlagWaker {
20    fn wake(self: Arc<Self>) {
21        self.set_awoken();
22        self.waker.wake();
23    }
24}
25
26impl FlagWaker {
27    pub(crate) fn new(waker: Arc<AtomicWaker>) -> Self {
28        Self {
29            waker,
30            // Initialize the flag to true so that the future gets polled the first time
31            awoken: AtomicBool::new(true),
32        }
33    }
34
35    pub(crate) fn waker_pair(waker: Arc<AtomicWaker>) -> (Arc<Self>, Waker) {
36        let this = Arc::new(Self::new(waker));
37        let waker = this.clone().into();
38        (this, waker)
39    }
40
41    pub(crate) fn check_awoken(&self) -> bool {
42        self.awoken.swap(false, Ordering::Relaxed)
43    }
44
45    pub(crate) fn set_awoken(&self) {
46        self.awoken.store(true, Ordering::Relaxed);
47    }
48}
49
50type PinFut<'a, T> = Pin<&'a mut dyn Future<Output = T>>;
51type PinStream<'a, T> = Pin<&'a mut dyn Stream<Item = T>>;
52
53enum Inflight<'a, T> {
54    Fut(PinFut<'a, T>),
55    Done(T),
56}
57
58impl<T> Inflight<'_, T> {
59    fn unwrap_done(self) -> T {
60        match self {
61            Inflight::Fut(_) => panic!("expected inflight future to be done"),
62            Inflight::Done(val) => val,
63        }
64    }
65}
66
67#[doc(hidden)]
68#[must_use = "Futures do nothing unless polled"]
69pub struct JoinFuture<'a, T, const N: usize> {
70    base_waker: Arc<AtomicWaker>,
71    inflight: Option<[Inflight<'a, T>; N]>,
72    wakers: [(Arc<FlagWaker>, Waker); N],
73}
74
75impl<'a, T, const N: usize> JoinFuture<'a, T, N> {
76    pub fn new(futures: [PinFut<'a, T>; N]) -> Self {
77        let base_waker = Arc::new(AtomicWaker::new());
78        Self {
79            inflight: Some(futures.map(Inflight::Fut)),
80            wakers: std::array::from_fn(|_| FlagWaker::waker_pair(base_waker.clone())),
81            base_waker,
82        }
83    }
84}
85
86impl<T: Unpin, const N: usize> Future for JoinFuture<'_, T, N> {
87    type Output = [T; N];
88
89    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
90        let this = self.get_mut();
91        this.base_waker.register(cx.waker());
92        poll_join(this.inflight.as_mut().unwrap(), &mut this.wakers)
93            .map(|_| this.inflight.take().unwrap().map(Inflight::unwrap_done))
94    }
95}
96
97fn poll_join<T>(inflights: &mut [Inflight<T>], wakers: &mut [(Arc<FlagWaker>, Waker)]) -> Poll<()> {
98    let mut out = Poll::Ready(());
99    for (inflight, (waker_data, waker)) in inflights.iter_mut().zip(wakers.iter_mut()) {
100        if let Inflight::Fut(fut) = inflight {
101            if waker_data.check_awoken() {
102                if let Poll::Ready(out) = fut.as_mut().poll(&mut Context::from_waker(waker)) {
103                    *inflight = Inflight::Done(out);
104                    continue;
105                }
106            }
107            out = Poll::Pending;
108        }
109    }
110    out
111}
112
/// Poll multiple futures concurrently, returning a future that outputs an array of all results
/// once all futures have completed.
///
/// The results are in the same order as the macro arguments.
///
/// # Minimal polling
///
/// This future will only poll each inner future when it is awoken, rather than polling all inner
/// futures on each iteration.
///
/// # Pinning
///
/// The arguments are pinned by the macro itself, so unlike with `merge_futures!` they do not need
/// to be pinned beforehand.
///
/// # Caveat
///
/// The futures must all have the same output type, which must be `Unpin`.
///
/// # Examples
///
/// ```
/// use local_runtime::join;
///
/// # local_runtime::block_on(async {
/// let a = async { 1 };
/// let b = async { 2 };
/// let c = async { 3 };
/// assert_eq!(join!(a, b, c).await, [1, 2, 3]);
/// # })
/// ```
#[macro_export]
macro_rules! join {
    ($($fut:expr),+ $(,)?) => {
        // Absolute path so the expansion does not depend on what `std`/`core` name at the call
        // site.
        async { $crate::JoinFuture::new([$(::core::pin::pin!($fut)),+]).await }
    };
}
143
144#[doc(hidden)]
145#[must_use = "Streams do nothing unless polled"]
146pub struct MergeFutureStream<'a, T, const N: usize> {
147    base_waker: Arc<AtomicWaker>,
148    futures: [Option<PinFut<'a, T>>; N],
149    wakers: [(Arc<FlagWaker>, Waker); N],
150    idx: usize,
151    none_count: usize,
152}
153
154impl<'a, T, const N: usize> MergeFutureStream<'a, T, N> {
155    pub fn new(futures: [PinFut<'a, T>; N]) -> Self {
156        let base_waker = Arc::new(AtomicWaker::new());
157        Self {
158            futures: futures.map(Some),
159            wakers: std::array::from_fn(|_| FlagWaker::waker_pair(base_waker.clone())),
160            idx: 0,
161            none_count: 0,
162            base_waker,
163        }
164    }
165}
166
167impl<T, const N: usize> Stream for MergeFutureStream<'_, T, N> {
168    type Item = T;
169
170    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171        let this = self.get_mut();
172        this.base_waker.register(cx.waker());
173        poll_merged(
174            &mut this.futures,
175            &mut this.wakers,
176            &mut this.idx,
177            &mut this.none_count,
178            |fut, cx| fut.as_mut().poll(cx),
179            |x| Some(x),
180            |_| true,
181        )
182    }
183}
184
185#[allow(clippy::too_many_arguments)]
186fn poll_merged<P, O, T, PF, OF, NF>(
187    pollers: &mut [Option<P>],
188    wakers: &mut [(Arc<FlagWaker>, Waker)],
189    idx: &mut usize,
190    none_count: &mut usize,
191    mut poll_fn: PF,
192    mut opt_fn: OF,
193    mut none_fn: NF,
194) -> Poll<Option<T>>
195where
196    PF: FnMut(&mut P, &mut Context) -> Poll<O>,
197    OF: FnMut(O) -> Option<T>,
198    NF: FnMut(&O) -> bool,
199{
200    let len = pollers.len();
201
202    let (futs_past, futs_remain) = pollers.split_at_mut(*idx);
203    let (wakers_past, wakers_remain) = wakers.split_at_mut(*idx);
204    let iter_past = futs_past.iter_mut().zip(wakers_past.iter_mut());
205    let iter_remain = futs_remain.iter_mut().zip(wakers_remain.iter_mut());
206    // Prioritize the futures we haven't seen yet
207    let iter = iter_remain.chain(iter_past);
208
209    for (poller_opt, (waker_data, waker)) in iter {
210        if let Some(poller) = poller_opt {
211            if waker_data.check_awoken() {
212                if let Poll::Ready(out) = poll_fn(poller, &mut Context::from_waker(waker)) {
213                    if none_fn(&out) {
214                        *poller_opt = None;
215                        *none_count += 1;
216                    }
217                    if let Some(ret) = opt_fn(out) {
218                        // Set the awoken flag so that the next time we poll, we'll start by
219                        // polling the future/stream that just yielded a value
220                        waker_data.set_awoken();
221                        return Poll::Ready(Some(ret));
222                    }
223                }
224            }
225        }
226        // Update index
227        *idx = (*idx + 1) % len;
228        // If all the futures/streams have terminated, end the stream by returning none
229        if *none_count == len {
230            return Poll::Ready(None);
231        }
232    }
233    Poll::Pending
234}
235
/// Poll the futures concurrently and return their outputs as a stream.
///
/// Produces a stream that yields `N` values, where `N` is the number of merged futures. The
/// outputs will be returned in the order in which the futures completed.
///
/// # Minimal polling
///
/// This stream will only poll each inner future when it is awoken, rather than polling all
/// inner futures on each iteration.
///
/// # Pinning
///
/// The input futures to this macro must be pinned to the local context via [`pin`](std::pin::pin).
///
/// # Examples
///
/// ```
/// use std::time::Duration;
/// use std::pin::pin;
/// use futures_lite::StreamExt;
/// use local_runtime::time::sleep;
/// use local_runtime::merge_futures;
///
/// # local_runtime::block_on(async {
/// let a = pin!(async { 1 });
/// let b = pin!(async {
///     sleep(Duration::from_millis(5)).await;
///     2
/// });
/// let c = pin!(async {
///     sleep(Duration::from_millis(3)).await;
///     3
/// });
/// let mut stream = merge_futures!(a, b, c);
/// while let Some(x) = stream.next().await {
///     // Expect the values to be: 1, 3, 2 (`c` sleeps for less time than `b`)
///     println!("Future returned: {x}");
/// }
/// # })
/// ```
#[macro_export]
macro_rules! merge_futures {
    ($($fut:expr),+ $(,)?) => {
        $crate::MergeFutureStream::new([$($fut),+])
    };
}
282
283#[doc(hidden)]
284#[must_use = "Streams do nothing unless polled"]
285pub struct MergeStream<'a, T, const N: usize> {
286    base_waker: Arc<AtomicWaker>,
287    streams: [Option<PinStream<'a, T>>; N],
288    wakers: [(Arc<FlagWaker>, Waker); N],
289    idx: usize,
290    none_count: usize,
291}
292
293impl<'a, T, const N: usize> MergeStream<'a, T, N> {
294    pub fn new(streams: [PinStream<'a, T>; N]) -> Self {
295        let base_waker = Arc::new(AtomicWaker::new());
296        Self {
297            streams: streams.map(Some),
298            wakers: std::array::from_fn(|_| FlagWaker::waker_pair(base_waker.clone())),
299            idx: 0,
300            none_count: 0,
301            base_waker,
302        }
303    }
304}
305
306impl<T, const N: usize> Stream for MergeStream<'_, T, N> {
307    type Item = T;
308
309    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
310        let this = self.get_mut();
311        this.base_waker.register(cx.waker());
312        poll_merged(
313            &mut this.streams,
314            &mut this.wakers,
315            &mut this.idx,
316            &mut this.none_count,
317            |fut, cx| fut.as_mut().poll_next(cx),
318            |o| o,
319            |o| o.is_none(),
320        )
321    }
322}
323
/// Run the streams concurrently and return their outputs one at a time.
///
/// Produces a stream that yields the outputs of the inner streams as they become available,
/// effectively interleaving the inner streams. The merged stream ends once every inner stream
/// has ended.
///
/// # Minimal polling
///
/// This stream will only poll each inner stream when it is awoken, rather than polling all inner
/// streams on each iteration.
///
/// # Pinning
///
/// The input streams to this macro must be pinned to the local context via [`pin`](std::pin::pin).
///
/// # Examples
///
/// ```
/// use std::time::Duration;
/// use std::pin::pin;
/// use futures_lite::StreamExt;
/// use local_runtime::time::Periodic;
/// use local_runtime::merge_streams;
///
/// # local_runtime::block_on(async {
/// let a = pin!(Periodic::periodic(Duration::from_millis(70)).map(|_| 1u8));
/// let b = pin!(Periodic::periodic(Duration::from_millis(30)).map(|_| 2u8));
/// let stream = merge_streams!(a, b);
/// assert_eq!(stream.take(6).collect::<Vec<_>>().await, &[2, 2, 1, 2, 2, 1]);
/// # })
/// ```
#[macro_export]
macro_rules! merge_streams {
    ($($stream:expr),+ $(,)?) => {
        $crate::MergeStream::new([$($stream),+])
    };
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::MockWaker;

    #[test]
    fn flag_waker_multiple_wakers() {
        // The flag waker must keep working when the waker stored in the AtomicWaker is replaced.
        let mock_wakers = [
            Arc::new(MockWaker::default()),
            Arc::new(MockWaker::default()),
        ];
        let atomic_waker = Arc::new(AtomicWaker::new());
        let (flag, flag_waker) = FlagWaker::waker_pair(Arc::clone(&atomic_waker));

        for mock in &mock_wakers {
            // The flag is set (on creation, or by the previous wake) and checking it clears it.
            assert!(flag.check_awoken());
            assert!(!flag.awoken.load(Ordering::Relaxed));

            // Swap in a fresh inner waker; waking the flag waker must reach it.
            atomic_waker.register(&Arc::clone(mock).into());
            flag_waker.wake_by_ref();
            assert!(mock.get());
        }
    }
}