1use std::{any::{Any, TypeId}, cell::{RefCell, RefMut}, collections::{HashMap}, future::Future, marker::PhantomData, pin::Pin, rc::Rc, task::{Poll}, unreachable};
29
30use futures::{FutureExt};
31use futures::channel::oneshot;
32use slab::Slab;
33
34
35pub fn batched<F: Future>(fut: F) -> Batched<F> {
37 Batched {
38 fut,
39 batch_futures: Slab::new(),
40 ctx: Rc::new(RefCell::new(BatchContext {
41 accumulating: HashMap::new(),
42 postpone_loading: 0,
43 user_ctx: HashMap::new()
44 }))
45 }
46}
47
48type ResultSender = futures::channel::oneshot::Sender<Box<dyn Any>>;
49
50#[doc(hidden)]
51pub mod __internal {
52 use std::{future::Future, pin::Pin, task::Poll};
53
54 use super::{ResultSender};
55
56 pub struct LoadBatch<Outputs: Iterator, F: Future<Output = Outputs>> {
57 pub fut: F,
58 pub result_senders: Vec<ResultSender>
59 }
60
61 impl<Outputs: Iterator, F: Future<Output = Outputs>> Future for LoadBatch<Outputs, F> where Outputs::Item: 'static {
62 type Output = ();
63
64 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
65 let fut;
66 let senders;
67
68 unsafe {
69 let this = self.get_unchecked_mut();
70
71 senders = &mut this.result_senders;
72 fut = Pin::new_unchecked(&mut this.fut);
73 };
74
75 if senders.iter().all(|res| res.is_canceled() ) { return Poll::Ready(()) }
77
78 match fut.poll(cx) {
79 Poll::Ready(outputs) => {
80 for (output, sender) in outputs.zip(senders.drain(..)) {
81 let _ = sender.send(Box::new(output));
82 }
83 Poll::Ready(())
84 }
85 Poll::Pending => Poll::Pending
86 }
87 }
88 }
89}
90
/// Defines a batched loader.
///
/// `async fn name(inputs: In) -> Out { ... }` expands to a plain
/// `fn name(input: In) -> BatchLoad<In, Out>`. The body receives every input collected for
/// one batch as an `impl Iterator<Item = Box<In>>` and must evaluate to an iterator of
/// outputs; the n-th output is delivered to whoever supplied the n-th input.
///
/// Paths in the expansion are absolute (`::std::...`, `$crate::...`) because macro items
/// resolve at the call site: a local `std` module or a user type named `Box` in the calling
/// module must not change what the expansion refers to.
#[macro_export]
macro_rules! def_batch_loader {
    (
        $(#[$attr:meta])*
        $vis:vis async fn $name:ident($inputs:ident: $input_ty:ty) -> $output_ty:ty $block:block
    ) => {
        $(#[$attr])* $vis fn $name(input: $input_ty) -> $crate::BatchLoad::<$input_ty, $output_ty> {
            fn load_batch(
                batch: $crate::Batch,
            ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ()>>> {
                #[inline(always)]
                async fn loader(
                    $inputs: impl ::std::iter::Iterator<Item = ::std::boxed::Box<$input_ty>>
                ) -> impl ::std::iter::Iterator<Item = $output_ty> $block

                // Inputs for `load_batch` come from the `BatchLoad::New` built below, which
                // boxes them as `$input_ty`, so this downcast is expected to succeed.
                let inputs = batch.inputs.into_iter().map(|input| {
                    input.downcast::<$input_ty>().unwrap()
                });

                let fut = $crate::__internal::LoadBatch {
                    fut: loader(inputs),
                    result_senders: batch.result_senders,
                };

                ::std::boxed::Box::pin(fut)
            }

            $crate::BatchLoad::New {
                load_fn: load_batch,
                input: ::std::boxed::Box::new(input),
                phantom: ::std::marker::PhantomData,
            }
        }
    };
}
130
131type LoadFn = fn ( Batch ) -> Pin<Box<dyn Future<Output = ()>>>;
132
133pub struct BatchContext {
135 accumulating: HashMap<LoadFn, Batch>,
136
137 postpone_loading: usize,
138
139 user_ctx: HashMap<TypeId, Box<dyn Any>>
140}
141
142impl BatchContext {
143 pub fn set_ctx(&mut self, val: Box<dyn Any>) -> Option<Box<dyn Any>> {
145 self.user_ctx.insert((*val).type_id(), val)
146 }
147 pub fn get_ctx<T: Any>(&self) -> Option<&T> {
149 self.user_ctx.get(&TypeId::of::<T>()).map(|a| a.downcast_ref().unwrap())
150 }
151 pub fn mut_ctx<'a, T: Any>(&'a mut self) -> Option<&'a mut T> {
153 self.user_ctx.get_mut(&TypeId::of::<T>()).map(|a| a.downcast_mut().unwrap())
154 }
155}
156
157thread_local! {
158 static BATCH_CONTEXT: RefCell<Option<Rc<RefCell<BatchContext>>>> = RefCell::new(None);
159}
160
161#[doc(hidden)]
163pub struct Batch {
164 pub inputs: Vec<Box<dyn Any>>,
165 pub result_senders: Vec<ResultSender>
166}
167
168impl Batch {
169 fn empty() -> Self {
170 Batch { inputs: vec![], result_senders: vec![] }
171 }
172 fn push(&mut self, input: Box<dyn Any>, result: ResultSender) {
173 self.inputs.push(input);
174 self.result_senders.push(result);
175 }
176}
177
178pub enum BatchLoad<Input, Output: ?Sized> {
180 New {
181 load_fn: LoadFn,
182 input: Box<Input>,
183 phantom: PhantomData<Box<Output>>
184 },
185 Pending(oneshot::Receiver<Box<dyn Any>>)
186}
187
188impl<Input: 'static, Output: ?Sized> BatchLoad<Input, Output> {
189 pub fn schedule(&mut self) {
195 if let Self::New {..} = self {
196 let (tx, rx) = futures::channel::oneshot::channel();
197
198 let (load_fn, input) = match std::mem::replace(self, BatchLoad::Pending(rx)) {
199 Self::New { load_fn, input, .. } => (load_fn, input),
200 _ => unreachable!()
201 };
202
203 with_batch_ctx(|ctx| {
204 let batch = ctx.accumulating.entry(load_fn).or_insert(Batch::empty());
205
206 batch.push(input, tx);
207 });
208 }
209 }
210}
211
212impl<Input: 'static, Output: 'static> Future for BatchLoad<Input, Output> {
213 type Output = Box<Output>;
214
215 #[track_caller]
216 #[inline]
217 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
218 let this = self.get_mut();
219
220 if let Self::New {..} = this {
221 this.schedule();
222 }
223
224 let rx = if let Self::Pending(rx) = this { rx } else { unreachable!() };
225
226 let poll = rx.poll_unpin(cx).map(|res| res.expect("Batch loading context was cancelled"));
227
228 poll.map(|val| {
229 val.downcast().unwrap()
230 })
231 }
232}
233
234pub struct Batched<F: Future> {
236 fut: F,
237
238 ctx: Rc<RefCell<BatchContext>>,
239
240 batch_futures: Slab<Pin<Box<dyn Future<Output = ()>>>>
241}
242
243impl<F: Future> Batched<F> {
244 pub fn ctx<'a>(&'a mut self) -> RefMut<'a, BatchContext> {
246 self.ctx.borrow_mut()
247 }
248}
249
250#[inline]
252fn provide_batch_ctx<T>(ctx: Rc<RefCell<BatchContext>>, cb: impl FnOnce() -> T) -> T {
253 let existing_ctx = BATCH_CONTEXT.with(|cell| {
254 cell.replace(Some(ctx))
255 });
256
257 let val = (cb)();
258
259 BATCH_CONTEXT.with(|cell| {
260 cell.replace(existing_ctx)
261 });
262
263 val
264}
265
266pub fn with_batch_ctx<T>(cb: impl FnOnce(&mut BatchContext) -> T) -> T {
268 BATCH_CONTEXT.with(|cell| {
269 let ctx = cell.borrow();
270 let ctx = ctx.as_ref().expect("Tried to call a batched loader outside of a batching context.");
271 let mut ctx = (&*ctx).borrow_mut();
272 cb(&mut ctx)
273 })
274}
275
276
277#[doc(hidden)]
278pub struct DelayGuard<'a>( PhantomData<Rc<RefCell<&'a ()>>> );
279
280impl<'a> Drop for DelayGuard<'a> {
281 fn drop(&mut self) {
282 with_batch_ctx(|ctx| {
283 ctx.postpone_loading -= 1;
284 });
285 }
286}
287
288pub fn delay_loading_batches<'a>() -> DelayGuard<'a> {
357 with_batch_ctx(|ctx| {
358 ctx.postpone_loading += 1;
359 });
360 DelayGuard(PhantomData)
361}
362
363impl<F: Future> Drop for Batched<F> {
364 fn drop(&mut self) {
365 provide_batch_ctx(self.ctx.clone(), move || {
366 let Self { .. } = self;
367 });
368 }
369}
370
371impl<F: Future> Future for Batched<F> {
372 type Output = F::Output;
373
374 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
375 let fut;
376 let batch_futures;
377 let ctx;
378
379 unsafe {
380 let this = self.get_unchecked_mut();
381
382 batch_futures = &mut this.batch_futures;
383 fut = Pin::new_unchecked(&mut this.fut);
384 ctx = &this.ctx;
385 };
386
387 let poll = provide_batch_ctx(ctx.clone(), || {
388 let poll = fut.poll(cx);
389
390 let mut ready_futures = vec![];
391
392 for (idx, batch_fut) in batch_futures.iter_mut() {
393 match batch_fut.as_mut().poll(cx) {
394 Poll::Ready(_) => ready_futures.push(idx),
395 Poll::Pending => { }
396 }
397 }
398
399 for idx in ready_futures {
400 batch_futures.remove(idx);
401 }
402
403 poll
404 });
405
406 loop {
407 let batches = {
408 let mut ctx = (**ctx).borrow_mut();
409
410 if ctx.accumulating.is_empty() { break }
411
412 if ctx.postpone_loading > 0 { break }
413
414 std::mem::replace(&mut ctx.accumulating, HashMap::new())
415 };
416
417 provide_batch_ctx(ctx.clone(), || {
418 for (loader, batch) in batches.into_iter() {
419 let mut fut = (loader)(batch);
420
421 if let Poll::Pending = fut.as_mut().poll(cx) {
422 batch_futures.insert(fut);
423 }
424 }
425 })
426 }
427
428 match poll {
429 Poll::Ready(val) if batch_futures.is_empty() => {
430 Poll::Ready(val)
431 },
432 _ => Poll::Pending
433 }
434 }
435}
436
#[cfg(test)]
mod tests {
    /// Returns `Pending` once (waking the task right away), then `Ready`, so the enclosing
    /// `Batched` gets polled again in between.
    async fn yield_now() {
        struct YieldNow {
            yielded: bool,
        }

        impl std::future::Future for YieldNow {
            type Output = ();
            fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
                if self.yielded {
                    std::task::Poll::Ready(())
                } else {
                    cx.waker().wake_by_ref();
                    self.yielded = true;
                    std::task::Poll::Pending
                }
            }
        }

        YieldNow { yielded: false }.await;
    }

    use super::{batched, def_batch_loader, delay_loading_batches};
    use futures::FutureExt;

    def_batch_loader! {
        pub async fn load_foobar_batched(inputs: u64) -> (Vec<u64>, String) {
            let inputs: Vec<_> = inputs.map(|a| *a).collect();
            let inputs_copy = inputs.clone();

            yield_now().await;

            inputs.into_iter().map(move |input| {
                (inputs_copy.clone(), input.to_string())
            })
        }
    }

    #[test]
    fn test() {
        futures::executor::block_on(async {
            batched(async {
                let fifty_four = load_foobar_batched(54).fuse();
                let thirty_two = load_foobar_batched(32).fuse();

                futures::pin_mut!(fifty_four, thirty_two);

                futures::select_biased! {
                    tt = thirty_two => {
                        assert_eq!(tt, Box::new((vec![32, 54], "32".to_owned())));
                    },
                    ff = fifty_four => {
                        assert_eq!(ff, Box::new((vec![32, 54], "54".to_owned())));
                    }
                }
            }).await;
        });
    }

    #[test]
    fn test_schedule() {
        futures::executor::block_on(async {
            batched(async {
                assert_eq!(load_foobar_batched(12).await, Box::new((vec![12], "12".to_owned())));

                let mut fifty_four = load_foobar_batched(54);
                let thirty_two = load_foobar_batched(32);

                fifty_four.schedule();

                assert_eq!(thirty_two.await, Box::new((vec![54, 32], "32".to_owned())));
                assert_eq!(fifty_four.await, Box::new((vec![54, 32], "54".to_owned())));
            }).await;
        });
    }

    #[test]
    fn test_ctx() {
        futures::executor::block_on(async {
            struct Count(usize);

            def_batch_loader! {
                pub async fn counter(inputs: &'static str) -> (&'static str, usize) {
                    inputs.map(|input| {
                        let count = super::with_batch_ctx(|ctx| {
                            let count = ctx.mut_ctx::<Count>().unwrap();

                            count.0 += 1;

                            count.0
                        });

                        (*input, count)
                    })
                }
            }

            let mut scope = batched(async {
                assert_eq!( counter("hello").await, Box::new(("hello", 1)) );
                assert_eq!( counter("hello there").await, Box::new(("hello there", 2)) );
            });

            scope.ctx().set_ctx(Box::new(Count(0)));

            scope.await;
        });
    }

    #[test]
    fn test_drop_delay() {
        futures::executor::block_on(async {
            batched(async {
                let one = load_foobar_batched(1).fuse();

                futures::pin_mut!(one);

                futures::select_biased! {
                    one = one => {
                        assert_eq!(one, Box::new((vec![1], "1".to_owned())));
                    }
                }

                // `let _ =` drops the guard immediately: creating and dropping a guard inside
                // a batching context must work.
                let _ = delay_loading_batches();
            }).await;
        });
    }

    #[test]
    fn test_delay_batches_across_polls() {
        futures::executor::block_on(async {
            batched(async {
                let guard = delay_loading_batches();

                let mut one = load_foobar_batched(1);
                one.schedule();

                // `Batched` finishes a poll here; without the guard it would dispatch `[1]`.
                yield_now().await;

                let mut two = load_foobar_batched(2);
                two.schedule();

                drop(guard);

                assert_eq!(one.await, Box::new((vec![1, 2], "1".to_owned())));
                assert_eq!(two.await, Box::new((vec![1, 2], "2".to_owned())));
            }).await;
        });
    }
}