// hala_future/batching.rs

1use std::{
2    future::Future,
3    ops,
4    pin::Pin,
5    ptr::null_mut,
6    sync::{
7        atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering},
8        Arc,
9    },
10    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
11};
12
13use dashmap::DashMap;
14use futures::future::{poll_fn, BoxFuture};
15
16use hala_lockfree::queue::Queue;
17
/// A set to handle the registration/deregistration of pending futures.
struct PendingFutures<R> {
    // Maps a batcher-assigned future id to its pinned, boxed future.
    futures: DashMap<usize, BoxFuture<'static, R>>,
}

impl<R> PendingFutures<R> {
    /// Registers `fut` under `id`, replacing any future previously stored
    /// under the same id.
    fn insert(&self, id: usize, fut: BoxFuture<'static, R>) {
        self.futures.insert(id, fut);
    }

    /// Deregisters and returns the future stored under `id`, if any.
    fn remove(&self, id: usize) -> Option<BoxFuture<'static, R>> {
        self.futures.remove(&id).map(|(_, fut)| fut)
    }
}
32
33impl<R> Default for PendingFutures<R> {
34    fn default() -> Self {
35        Self {
36            futures: DashMap::new(),
37        }
38    }
39}
40
#[derive(Default)]
struct WakerHost {
    // Heap-allocated `Waker` owned through a raw pointer; null means "no
    // waker registered" (the derived Default starts as null). The pointer is
    // swapped atomically so registration and wakeup can race safely.
    waker: AtomicPtr<Waker>,
}
45
46impl WakerHost {
47    fn wake(&self) {
48        if let Some(waker) = self.remove_waker() {
49            waker.wake();
50        }
51    }
52
53    fn remove_waker(&self) -> Option<Box<Waker>> {
54        loop {
55            let waker_ptr = self.waker.load(Ordering::Acquire);
56
57            if waker_ptr == null_mut() {
58                return None;
59            }
60
61            if self
62                .waker
63                .compare_exchange_weak(waker_ptr, null_mut(), Ordering::AcqRel, Ordering::Relaxed)
64                .is_err()
65            {
66                continue;
67            }
68
69            return Some(unsafe { Box::from_raw(waker_ptr) });
70        }
71    }
72
73    fn add_waker(&self, waker: Waker) {
74        let waker_ptr = Box::into_raw(Box::new(waker));
75
76        let old = self.waker.swap(waker_ptr, Ordering::AcqRel);
77
78        if old != null_mut() {
79            let waker = unsafe { Box::from_raw(old) };
80
81            drop(waker);
82
83            // TODO: check the data race!!!
84            log::trace!("Batching is awakened unintentionally !!!.");
85        }
86    }
87}
88
#[derive(Clone)]
struct BatcherWaker {
    /// Id of the wrapped future that this waker reports as ready.
    future_id: usize,
    /// Current set of ready futures
    ready_futures: Arc<Queue<usize>>,
    /// Raw batch future waker
    raw_waker: Arc<WakerHost>,
}
97
98#[inline(always)]
99unsafe fn batch_future_waker_clone(data: *const ()) -> RawWaker {
100    let waker = Box::from_raw(data as *mut BatcherWaker);
101
102    let waker_cloned = waker.clone();
103
104    _ = Box::into_raw(waker);
105
106    RawWaker::new(Box::into_raw(waker_cloned) as *const (), &WAKER_VTABLE)
107}
108
109#[inline(always)]
110unsafe fn batch_future_waker_wake(data: *const ()) {
111    let waker = Box::from_raw(data as *mut BatcherWaker);
112
113    waker.ready_futures.push(waker.future_id);
114
115    waker.raw_waker.wake();
116}
117
118#[inline(always)]
119unsafe fn batch_future_waker_wake_by_ref(data: *const ()) {
120    let waker = Box::from_raw(data as *mut BatcherWaker);
121
122    waker.ready_futures.push(waker.future_id);
123
124    waker.raw_waker.wake();
125
126    _ = Box::into_raw(waker);
127}
128
129#[inline(always)]
130unsafe fn batch_future_waker_drop(data: *const ()) {
131    _ = Box::from_raw(data as *mut BatcherWaker);
132}
133
/// Vtable wiring the batcher waker hooks into a `std::task::RawWaker`.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    batch_future_waker_clone,
    batch_future_waker_wake,
    batch_future_waker_wake_by_ref,
    batch_future_waker_drop,
);
140
141fn new_batcher_waker<Fut>(future_id: usize, batch_future: FutureBatcher<Fut>) -> Waker {
142    let boxed = Box::new(BatcherWaker {
143        future_id,
144        ready_futures: batch_future.wakeup_futures,
145        raw_waker: batch_future.raw_waker,
146    });
147
148    unsafe {
149        Waker::from_raw(RawWaker::new(
150            Box::into_raw(boxed) as *const (),
151            &WAKER_VTABLE,
152        ))
153    }
154}
155
/// A lockfree processor to batch poll the same type of futures.
///
/// All clones share the same state (every field is an `Arc`); only one
/// [`Wait`] may be polled at a time (enforced by `await_counter`).
pub struct FutureBatcher<R> {
    /// The generator for the wrapped future id.
    idgen: Arc<AtomicUsize>,
    /// Current set of pending futures
    pending_futures: Arc<PendingFutures<R>>,
    /// Current set of ready futures
    wakeup_futures: Arc<Queue<usize>>,
    /// Raw batch future waker
    raw_waker: Arc<WakerHost>,
    /// poll thread counter.
    await_counter: Arc<AtomicUsize>,
    /// closed flag.
    closed: Arc<AtomicBool>,
}
171
// NOTE(review): these manual impls bypass the compiler's auto-trait
// analysis. Every field is an `Arc` of a type that appears to already be
// Send + Sync (the stored futures are `BoxFuture`, i.e. `dyn Future + Send`),
// so these are likely redundant — confirm and consider removing them.
unsafe impl<R> Send for FutureBatcher<R> {}
unsafe impl<R> Sync for FutureBatcher<R> {}
174
175impl<R> Clone for FutureBatcher<R> {
176    fn clone(&self) -> Self {
177        Self {
178            idgen: self.idgen.clone(),
179            pending_futures: self.pending_futures.clone(),
180            wakeup_futures: self.wakeup_futures.clone(),
181            raw_waker: self.raw_waker.clone(),
182            await_counter: self.await_counter.clone(),
183            closed: self.closed.clone(),
184        }
185    }
186}
187
188impl<R> Default for FutureBatcher<R> {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
impl<R> FutureBatcher<R> {
    /// Creates an empty batcher with no pending futures.
    pub fn new() -> Self {
        Self {
            idgen: Default::default(),
            pending_futures: Default::default(),
            wakeup_futures: Default::default(),
            raw_waker: Default::default(),
            await_counter: Default::default(),
            closed: Default::default(),
        }
    }

    /// Push a new task future.
    ///
    /// Returns the id assigned to `fut`. The future is registered as
    /// pending, queued for an initial poll, and then the batch poller (if
    /// one is parked in `wait`) is woken — in that order, so the poller
    /// always finds the future it is notified about.
    pub fn push<Fut>(&self, fut: Fut) -> usize
    where
        Fut: Future<Output = R> + Send + 'static,
    {
        let id = self.idgen.fetch_add(1, Ordering::AcqRel);

        self.pending_futures.insert(id, Box::pin(fut));
        self.wakeup_futures.push(id);

        self.raw_waker.wake();

        id
    }

    /// Use a fn_poll instead of a [`Future`] object
    pub fn push_fn<F>(&self, f: F) -> usize
    where
        F: FnMut(&mut Context<'_>) -> std::task::Poll<R> + Send + 'static,
    {
        self.push(poll_fn(f))
    }

    /// Create a future task to batch poll
    ///
    /// The returned future resolves to `Some(result)` of the first wrapped
    /// future to complete, or `None` once the batcher is closed. Only one
    /// `Wait` may be polled at a time.
    pub fn wait(&self) -> Wait<R> {
        Wait {
            batch: self.clone(),
        }
    }

    /// Closes the batcher. Only the first call flips the flag; it also
    /// wakes a parked poller so outstanding `wait()`s resolve to `None`.
    pub fn close(&self) {
        if self
            .closed
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok()
        {
            self.raw_waker.wake();
        }
    }
}
247
/// Future returned by `FutureBatcher::wait`; resolves to `Some(result)` of
/// the first wrapped future to complete, or `None` once the batcher closes.
pub struct Wait<R> {
    // The shared batcher state this wait polls against.
    batch: FutureBatcher<R>,
}
251
252impl<R> ops::Deref for Wait<R> {
253    type Target = FutureBatcher<R>;
254    fn deref(&self) -> &Self::Target {
255        &self.batch
256    }
257}
258
259impl<R> Future for Wait<R> {
260    type Output = Option<R>;
261
262    fn poll(
263        self: Pin<&mut Self>,
264        cx: &mut std::task::Context<'_>,
265    ) -> std::task::Poll<Self::Output> {
266        assert_eq!(
267            self.await_counter.fetch_add(1, Ordering::SeqCst),
268            0,
269            "Only one thread can call this batch poll"
270        );
271
272        if self.closed.load(Ordering::Acquire) {
273            return Poll::Ready(None);
274        }
275
276        // save system waker to prepare wakeup self again.
277        self.raw_waker.add_waker(cx.waker().clone());
278
279        while let Some(future_id) = self.wakeup_futures.pop() {
280            if self.closed.load(Ordering::Acquire) {
281                return Poll::Ready(None);
282            }
283
284            // remove future from pending mapping.
285            let future = self.pending_futures.remove(future_id);
286
287            // The batcher waker may be register more than once.
288            // thus it is possible that a future has been awakened more than once,
289            // and that previous awakening processing has deleted that future.
290            if future.is_none() {
291                continue;
292            }
293
294            let mut future = future.unwrap();
295
296            // Create a new wrapped waker.
297            let waker = new_batcher_waker(future_id, self.clone());
298
299            // poll if
300            match future.as_mut().poll(&mut Context::from_waker(&waker)) {
301                std::task::Poll::Pending => {
302                    self.pending_futures.insert(future_id, future);
303
304                    continue;
305                }
306                std::task::Poll::Ready(r) => {
307                    self.raw_waker.remove_waker();
308
309                    assert_eq!(
310                        self.await_counter.fetch_sub(1, Ordering::SeqCst),
311                        1,
312                        "Only one thread can call this batch poll"
313                    );
314                    return std::task::Poll::Ready(Some(r));
315                }
316            }
317        }
318
319        assert_eq!(
320            self.await_counter.fetch_sub(1, Ordering::SeqCst),
321            1,
322            "Only one thread can call this batch poll"
323        );
324
325        if self.closed.load(Ordering::Acquire) {
326            return Poll::Ready(None);
327        }
328
329        return std::task::Poll::Pending;
330    }
331}
332
#[cfg(test)]
mod tests {

    use std::{io, sync::mpsc};

    use futures::{executor::ThreadPool, future::poll_fn, task::SpawnExt};

    use super::*;

    /// Immediately-ready pushed futures resolve through `wait`.
    #[futures_test::test]
    async fn test_basic_case() {
        let batch_future = FutureBatcher::<io::Result<()>>::new();

        let loops = 100000;

        for _ in 0..loops {
            batch_future.push(async { Ok(()) });
            batch_future.push(async move { Ok(()) });

            batch_future.wait().await.unwrap().unwrap();

            batch_future.wait().await.unwrap().unwrap();
        }
    }

    /// A `wait` parked on another thread is woken by a later `push`.
    #[futures_test::test]
    async fn test_push_wakeup() {
        let pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let batch_future = FutureBatcher::<io::Result<()>>::new();

        let loops = 100000;

        for _ in 0..loops {
            let batch_future_cloned = batch_future.clone();

            let handle = pool
                .spawn_with_handle(async move {
                    batch_future_cloned.wait().await.unwrap().unwrap();
                })
                .unwrap();

            batch_future.push(async move { Ok(()) });

            handle.await;
        }
    }

    /// A pending wrapped future woken through its captured waker completes
    /// the parked `wait`.
    #[futures_test::test]
    async fn test_future_wakeup() {
        let pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let batch_future = FutureBatcher::<io::Result<()>>::new();

        for _ in 0..10000 {
            let (sender, receiver) = mpsc::channel();

            let mut sent = false;

            batch_future.push(poll_fn(move |cx| {
                if sent {
                    return std::task::Poll::Ready(Ok(()));
                }

                sender.send(cx.waker().clone()).unwrap();

                sent = true;

                std::task::Poll::Pending
            }));

            // Typo fix: local was named `batch_futre_cloned`.
            let batch_future_cloned = batch_future.clone();

            let handle = pool
                .spawn_with_handle(async move {
                    batch_future_cloned.wait().await.unwrap().unwrap();
                })
                .unwrap();

            let waker = receiver.recv().unwrap();

            waker.wake();

            handle.await;
        }
    }
}