1use std::{
2 future::Future,
3 ops,
4 pin::Pin,
5 ptr::null_mut,
6 sync::{
7 atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering},
8 Arc,
9 },
10 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
11};
12
13use dashmap::DashMap;
14use futures::future::{poll_fn, BoxFuture};
15
16use hala_lockfree::queue::Queue;
17
18struct PendingFutures<R> {
20 futures: DashMap<usize, BoxFuture<'static, R>>,
21}
22
23impl<R> PendingFutures<R> {
24 fn insert(&self, id: usize, fut: BoxFuture<'static, R>) {
25 self.futures.insert(id, fut);
26 }
27
28 fn remove(&self, id: usize) -> Option<BoxFuture<'static, R>> {
29 self.futures.remove(&id).map(|(_, fut)| fut)
30 }
31}
32
33impl<R> Default for PendingFutures<R> {
34 fn default() -> Self {
35 Self {
36 futures: DashMap::new(),
37 }
38 }
39}
40
41#[derive(Default)]
42struct WakerHost {
43 waker: AtomicPtr<Waker>,
44}
45
46impl WakerHost {
47 fn wake(&self) {
48 if let Some(waker) = self.remove_waker() {
49 waker.wake();
50 }
51 }
52
53 fn remove_waker(&self) -> Option<Box<Waker>> {
54 loop {
55 let waker_ptr = self.waker.load(Ordering::Acquire);
56
57 if waker_ptr == null_mut() {
58 return None;
59 }
60
61 if self
62 .waker
63 .compare_exchange_weak(waker_ptr, null_mut(), Ordering::AcqRel, Ordering::Relaxed)
64 .is_err()
65 {
66 continue;
67 }
68
69 return Some(unsafe { Box::from_raw(waker_ptr) });
70 }
71 }
72
73 fn add_waker(&self, waker: Waker) {
74 let waker_ptr = Box::into_raw(Box::new(waker));
75
76 let old = self.waker.swap(waker_ptr, Ordering::AcqRel);
77
78 if old != null_mut() {
79 let waker = unsafe { Box::from_raw(old) };
80
81 drop(waker);
82
83 log::trace!("Batching is awakened unintentionally !!!.");
85 }
86 }
87}
88
89#[derive(Clone)]
90struct BatcherWaker {
91 future_id: usize,
92 ready_futures: Arc<Queue<usize>>,
94 raw_waker: Arc<WakerHost>,
96}
97
98#[inline(always)]
99unsafe fn batch_future_waker_clone(data: *const ()) -> RawWaker {
100 let waker = Box::from_raw(data as *mut BatcherWaker);
101
102 let waker_cloned = waker.clone();
103
104 _ = Box::into_raw(waker);
105
106 RawWaker::new(Box::into_raw(waker_cloned) as *const (), &WAKER_VTABLE)
107}
108
109#[inline(always)]
110unsafe fn batch_future_waker_wake(data: *const ()) {
111 let waker = Box::from_raw(data as *mut BatcherWaker);
112
113 waker.ready_futures.push(waker.future_id);
114
115 waker.raw_waker.wake();
116}
117
118#[inline(always)]
119unsafe fn batch_future_waker_wake_by_ref(data: *const ()) {
120 let waker = Box::from_raw(data as *mut BatcherWaker);
121
122 waker.ready_futures.push(waker.future_id);
123
124 waker.raw_waker.wake();
125
126 _ = Box::into_raw(waker);
127}
128
129#[inline(always)]
130unsafe fn batch_future_waker_drop(data: *const ()) {
131 _ = Box::from_raw(data as *mut BatcherWaker);
132}
133
134const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
135 batch_future_waker_clone,
136 batch_future_waker_wake,
137 batch_future_waker_wake_by_ref,
138 batch_future_waker_drop,
139);
140
141fn new_batcher_waker<Fut>(future_id: usize, batch_future: FutureBatcher<Fut>) -> Waker {
142 let boxed = Box::new(BatcherWaker {
143 future_id,
144 ready_futures: batch_future.wakeup_futures,
145 raw_waker: batch_future.raw_waker,
146 });
147
148 unsafe {
149 Waker::from_raw(RawWaker::new(
150 Box::into_raw(boxed) as *const (),
151 &WAKER_VTABLE,
152 ))
153 }
154}
155
156pub struct FutureBatcher<R> {
158 idgen: Arc<AtomicUsize>,
160 pending_futures: Arc<PendingFutures<R>>,
162 wakeup_futures: Arc<Queue<usize>>,
164 raw_waker: Arc<WakerHost>,
166 await_counter: Arc<AtomicUsize>,
168 closed: Arc<AtomicBool>,
170}
171
172unsafe impl<R> Send for FutureBatcher<R> {}
173unsafe impl<R> Sync for FutureBatcher<R> {}
174
175impl<R> Clone for FutureBatcher<R> {
176 fn clone(&self) -> Self {
177 Self {
178 idgen: self.idgen.clone(),
179 pending_futures: self.pending_futures.clone(),
180 wakeup_futures: self.wakeup_futures.clone(),
181 raw_waker: self.raw_waker.clone(),
182 await_counter: self.await_counter.clone(),
183 closed: self.closed.clone(),
184 }
185 }
186}
187
188impl<R> Default for FutureBatcher<R> {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194impl<R> FutureBatcher<R> {
195 pub fn new() -> Self {
196 Self {
197 idgen: Default::default(),
198 pending_futures: Default::default(),
199 wakeup_futures: Default::default(),
200 raw_waker: Default::default(),
201 await_counter: Default::default(),
202 closed: Default::default(),
203 }
204 }
205
206 pub fn push<Fut>(&self, fut: Fut) -> usize
209 where
210 Fut: Future<Output = R> + Send + 'static,
211 {
212 let id = self.idgen.fetch_add(1, Ordering::AcqRel);
213
214 self.pending_futures.insert(id, Box::pin(fut));
215 self.wakeup_futures.push(id);
216
217 self.raw_waker.wake();
218
219 id
220 }
221
222 pub fn push_fn<F>(&self, f: F) -> usize
224 where
225 F: FnMut(&mut Context<'_>) -> std::task::Poll<R> + Send + 'static,
226 {
227 self.push(poll_fn(f))
228 }
229
230 pub fn wait(&self) -> Wait<R> {
232 Wait {
233 batch: self.clone(),
234 }
235 }
236
237 pub fn close(&self) {
238 if self
239 .closed
240 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
241 .is_ok()
242 {
243 self.raw_waker.wake();
244 }
245 }
246}
247
248pub struct Wait<R> {
249 batch: FutureBatcher<R>,
250}
251
252impl<R> ops::Deref for Wait<R> {
253 type Target = FutureBatcher<R>;
254 fn deref(&self) -> &Self::Target {
255 &self.batch
256 }
257}
258
259impl<R> Future for Wait<R> {
260 type Output = Option<R>;
261
262 fn poll(
263 self: Pin<&mut Self>,
264 cx: &mut std::task::Context<'_>,
265 ) -> std::task::Poll<Self::Output> {
266 assert_eq!(
267 self.await_counter.fetch_add(1, Ordering::SeqCst),
268 0,
269 "Only one thread can call this batch poll"
270 );
271
272 if self.closed.load(Ordering::Acquire) {
273 return Poll::Ready(None);
274 }
275
276 self.raw_waker.add_waker(cx.waker().clone());
278
279 while let Some(future_id) = self.wakeup_futures.pop() {
280 if self.closed.load(Ordering::Acquire) {
281 return Poll::Ready(None);
282 }
283
284 let future = self.pending_futures.remove(future_id);
286
287 if future.is_none() {
291 continue;
292 }
293
294 let mut future = future.unwrap();
295
296 let waker = new_batcher_waker(future_id, self.clone());
298
299 match future.as_mut().poll(&mut Context::from_waker(&waker)) {
301 std::task::Poll::Pending => {
302 self.pending_futures.insert(future_id, future);
303
304 continue;
305 }
306 std::task::Poll::Ready(r) => {
307 self.raw_waker.remove_waker();
308
309 assert_eq!(
310 self.await_counter.fetch_sub(1, Ordering::SeqCst),
311 1,
312 "Only one thread can call this batch poll"
313 );
314 return std::task::Poll::Ready(Some(r));
315 }
316 }
317 }
318
319 assert_eq!(
320 self.await_counter.fetch_sub(1, Ordering::SeqCst),
321 1,
322 "Only one thread can call this batch poll"
323 );
324
325 if self.closed.load(Ordering::Acquire) {
326 return Poll::Ready(None);
327 }
328
329 return std::task::Poll::Pending;
330 }
331}
332
#[cfg(test)]
mod tests {

    use std::{io, sync::mpsc};

    use futures::{executor::ThreadPool, future::poll_fn, task::SpawnExt};

    use super::*;

    /// Pushes two immediately-ready futures, then awaits both results, many
    /// times in a row on a single task.
    #[futures_test::test]
    async fn test_basic_case() {
        let batcher = FutureBatcher::<io::Result<()>>::new();

        let rounds = 100000;

        for _ in 0..rounds {
            batcher.push(async { Ok(()) });
            batcher.push(async move { Ok(()) });

            for _ in 0..2 {
                batcher.wait().await.unwrap().unwrap();
            }
        }
    }

    /// A waiter running on a pool thread must be woken by a later `push`.
    #[futures_test::test]
    async fn test_push_wakeup() {
        let pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let batcher = FutureBatcher::<io::Result<()>>::new();

        let rounds = 100000;

        for _ in 0..rounds {
            let waiter = batcher.clone();

            let handle = pool
                .spawn_with_handle(async move {
                    waiter.wait().await.unwrap().unwrap();
                })
                .unwrap();

            batcher.push(async move { Ok(()) });

            handle.await;
        }
    }

    /// A batched future that returns `Pending` must be re-polled once its own
    /// waker is fired from another thread.
    #[futures_test::test]
    async fn test_future_wakeup() {
        let pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let batcher = FutureBatcher::<io::Result<()>>::new();

        for _ in 0..10000 {
            let (waker_tx, waker_rx) = mpsc::channel();

            let mut polled_once = false;

            batcher.push(poll_fn(move |cx| {
                if polled_once {
                    return std::task::Poll::Ready(Ok(()));
                }

                waker_tx.send(cx.waker().clone()).unwrap();

                polled_once = true;

                std::task::Poll::Pending
            }));

            let waiter = batcher.clone();

            let handle = pool
                .spawn_with_handle(async move {
                    waiter.wait().await.unwrap().unwrap();
                })
                .unwrap();

            let inner_waker = waker_rx.recv().unwrap();

            inner_waker.wake();

            handle.await;
        }
    }
}