async_dataloader/
lib.rs

//! A powerful tool for avoiding N+1 queries with async/await, based on the DataLoader pattern.
//!
//! `async_dataloader` batches the loads which occur during a single "poll" of a [`batched`] future, without requiring an artificial delay.
//!
//! Design inspired by <https://github.com/exAspArk/batch-loader> and <https://github.com/graphql/dataloader>.
6//!
7//! # Usage
8//!
9//! ```
10//! use async_dataloader::{def_batch_loader, batched};
11//!
12//! def_batch_loader! {
13//!     pub async fn loader(inputs: u64) -> String {
14//!         inputs.map(|input| {
15//!              input.to_string()
16//!         })
17//!     }
18//! }
19//!
20//! # futures::executor::block_on(async {
21//! batched(async {
22//!     assert_eq!(*loader(1).await, "1".to_owned());
23//! }).await
24//! # })
25//! ```
26
27
28use std::{any::{Any, TypeId}, cell::{RefCell, RefMut}, collections::{HashMap}, future::Future, marker::PhantomData, pin::Pin, rc::Rc, task::{Poll}, unreachable};
29
30use futures::{FutureExt};
31use futures::channel::oneshot;
32use slab::Slab;
33
34
35/// Allows using batch loaders from within the passed future.
36pub fn batched<F: Future>(fut: F) -> Batched<F> {
37    Batched {
38        fut,
39        batch_futures: Slab::new(),
40        ctx: Rc::new(RefCell::new(BatchContext {
41            accumulating: HashMap::new(),
42            postpone_loading: 0,
43            user_ctx: HashMap::new()
44        }))
45    }
46}
47
48type ResultSender = futures::channel::oneshot::Sender<Box<dyn Any>>;
49
50#[doc(hidden)]
51pub mod __internal {
52    use std::{future::Future, pin::Pin, task::Poll};
53
54    use super::{ResultSender};
55
56    pub struct LoadBatch<Outputs: Iterator, F: Future<Output = Outputs>> {
57        pub fut: F,
58        pub result_senders: Vec<ResultSender>
59    }
60
61    impl<Outputs: Iterator, F: Future<Output = Outputs>> Future for LoadBatch<Outputs, F> where Outputs::Item: 'static {
62        type Output = ();
63
64        fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
65            let fut;
66            let senders;
67
68            unsafe {
69                let this = self.get_unchecked_mut();
70
71                senders = &mut this.result_senders;
72                fut = Pin::new_unchecked(&mut this.fut);
73            };
74
75            // check if no one cares for the results of this batch
76            if senders.iter().all(|res| res.is_canceled() ) { return Poll::Ready(()) }
77
78            match fut.poll(cx) {
79                Poll::Ready(outputs) => {
80                    for (output, sender) in outputs.zip(senders.drain(..)) {
81                        let _ = sender.send(Box::new(output));
82                    }
83                    Poll::Ready(())
84                }
85                Poll::Pending => Poll::Pending
86            }
87        }
88    }
89}
90
/// Define a batch loader.
///
/// Accepts a single item of the form
/// `async fn name(inputs: Input) -> Output { ... }` (attributes and a visibility are allowed).
///
/// * Inside the body, `inputs` is an `impl Iterator<Item = Box<Input>>` over every input
///   that was scheduled into the same batch.
/// * The body must evaluate to an iterator of `Output`. Outputs are matched to inputs by
///   position, so yield them in input order.
///
/// The macro expands to `fn name(input: Input) -> BatchLoad<Input, Output>`; the returned
/// future resolves to `Box<Output>` once the batch containing `input` has been loaded.
#[macro_export]
macro_rules! def_batch_loader {
    (
        $(#[$attr:meta])*
        $vis:vis async fn $name:ident($inputs:ident: $input_ty:ty) -> $output_ty:ty $block:block
    ) => {
        $(#[$attr])* $vis fn $name( input: $input_ty ) -> $crate::BatchLoad::<$input_ty, $output_ty> {
            // A type-erased load function, which conforms to LoadFn.
            // Standard-library items are named through absolute `::std` paths so that an
            // item called `std`, `Box` or `Iterator` at the call site cannot hijack them.
            fn load_batch( batch: $crate::Batch ) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ()>>> {
                // The user-provided batch loader
                #[inline(always)]
                async fn loader(
                    $inputs: impl ::std::iter::Iterator<Item = ::std::boxed::Box<$input_ty>>
                ) -> impl ::std::iter::Iterator<Item = $output_ty> $block

                // Downcast inputs to the expected type.
                let inputs = batch.inputs.into_iter().map(|input| {
                    // It should be impossible to pass in an input of the wrong type through the public API
                    input
                        .downcast::<$input_ty>()
                        .expect("batch input has an unexpected type; inputs must only be queued by the generated loader function")
                });

                let fut = $crate::__internal::LoadBatch {
                    fut: loader(inputs),
                    result_senders: batch.result_senders
                };

                // Return the batch future as a Pin<Box<dyn Future<Output = ()>>>
                ::std::boxed::Box::pin(fut)
            }

            $crate::BatchLoad::New {
                load_fn: load_batch,
                input: ::std::boxed::Box::new(input),
                phantom: ::std::marker::PhantomData
            }
        }
    };
}
130
131type LoadFn = fn ( Batch ) -> Pin<Box<dyn Future<Output = ()>>>;
132
133/// Context provided when executing within a batched() future.
134pub struct BatchContext {
135    accumulating: HashMap<LoadFn, Batch>,
136
137    postpone_loading: usize,
138
139    user_ctx: HashMap<TypeId, Box<dyn Any>>
140}
141
142impl BatchContext {
143    /// Provide context of a given type. Exactly one value per type may be stored.
144    pub fn set_ctx(&mut self, val: Box<dyn Any>) -> Option<Box<dyn Any>> {
145        self.user_ctx.insert((*val).type_id(), val)
146    }
147    /// Get context of a given type. Exactly one value per type may be stored.
148    pub fn get_ctx<T: Any>(&self) -> Option<&T> {
149        self.user_ctx.get(&TypeId::of::<T>()).map(|a| a.downcast_ref().unwrap())
150    }
151    /// Get context of a given type. Exactly one value per type may be stored.
152    pub fn mut_ctx<'a, T: Any>(&'a mut self) -> Option<&'a mut T> {
153        self.user_ctx.get_mut(&TypeId::of::<T>()).map(|a| a.downcast_mut().unwrap())
154    }
155}
156
157thread_local! {
158    static BATCH_CONTEXT: RefCell<Option<Rc<RefCell<BatchContext>>>> = RefCell::new(None);
159}
160
161// Batched inputs and result senders
162#[doc(hidden)]
163pub struct Batch {
164    pub inputs: Vec<Box<dyn Any>>,
165    pub result_senders: Vec<ResultSender>
166}
167
168impl Batch {
169    fn empty() -> Self {
170        Batch { inputs: vec![], result_senders: vec![] }
171    }
172    fn push(&mut self, input: Box<dyn Any>, result: ResultSender) {
173        self.inputs.push(input);
174        self.result_senders.push(result);
175    }
176}
177
178/// Future returned by a batch loader
179pub enum BatchLoad<Input, Output: ?Sized> {
180    New {
181        load_fn: LoadFn,
182        input: Box<Input>,
183        phantom: PhantomData<Box<Output>>
184    },
185    Pending(oneshot::Receiver<Box<dyn Any>>)
186}
187
188impl<Input: 'static, Output: ?Sized> BatchLoad<Input, Output> {
189    /// Schedules this input to be loaded within the current batch context.
190    ///
191    /// Rust futures are lazy, meaning they have do nothing until polled.
192    /// Calling this method will cause the load to be added to the next batch,
193    /// even if it the future is not polled until later.
194    pub fn schedule(&mut self) {
195        if let Self::New {..} = self {
196            let (tx, rx) = futures::channel::oneshot::channel();
197
198            let (load_fn, input) = match std::mem::replace(self, BatchLoad::Pending(rx)) {
199                Self::New { load_fn, input, .. } => (load_fn, input),
200                _ => unreachable!()
201            };
202
203            with_batch_ctx(|ctx| {
204                let batch = ctx.accumulating.entry(load_fn).or_insert(Batch::empty());
205
206                batch.push(input, tx);
207            });
208        }
209    }
210}
211
212impl<Input: 'static, Output: 'static> Future for BatchLoad<Input, Output> {
213    type Output = Box<Output>;
214
215    #[track_caller]
216    #[inline]
217    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
218        let this = self.get_mut();
219
220        if let Self::New {..} = this {
221            this.schedule();
222        }
223
224        let rx = if let Self::Pending(rx) = this { rx } else { unreachable!() };
225
226        let poll = rx.poll_unpin(cx).map(|res| res.expect("Batch loading context was cancelled"));
227
228        poll.map(|val| {
229            val.downcast().unwrap()
230        })
231    }
232}
233
234/// A Future which provides a BatchContext to its child while executing
235pub struct Batched<F: Future> {
236    fut: F,
237
238    ctx: Rc<RefCell<BatchContext>>,
239
240    batch_futures: Slab<Pin<Box<dyn Future<Output = ()>>>>
241}
242
243impl<F: Future> Batched<F> {
244    /// Access the BatchContext from outside of async execution
245    pub fn ctx<'a>(&'a mut self) -> RefMut<'a, BatchContext> {
246        self.ctx.borrow_mut()
247    }
248}
249
250/// Provides the batch context through thread local storage
251#[inline]
252fn provide_batch_ctx<T>(ctx: Rc<RefCell<BatchContext>>, cb: impl FnOnce() -> T) -> T {
253    let existing_ctx = BATCH_CONTEXT.with(|cell| {
254        cell.replace(Some(ctx))
255    });
256
257    let val = (cb)();
258
259    BATCH_CONTEXT.with(|cell| {
260        cell.replace(existing_ctx)
261    });
262
263    val
264}
265
266/// Retrieves the batch context from thread local storage
267pub fn with_batch_ctx<T>(cb: impl FnOnce(&mut BatchContext) -> T) -> T {
268    BATCH_CONTEXT.with(|cell| {
269        let ctx = cell.borrow();
270        let ctx = ctx.as_ref().expect("Tried to call a batched loader outside of a batching context.");
271        let mut ctx = (&*ctx).borrow_mut();
272        cb(&mut ctx)
273    })
274}
275
276
277#[doc(hidden)]
278pub struct DelayGuard<'a>( PhantomData<Rc<RefCell<&'a ()>>> );
279
280impl<'a> Drop for DelayGuard<'a> {
281    fn drop(&mut self) {
282        with_batch_ctx(|ctx| {
283            ctx.postpone_loading -= 1;
284        });
285    }
286}
287
288/// Provides a guard which prevents loading new batches until dropped.
289///
290/// ```
291/// # use async_dataloader::{*};
292/// # use futures::FutureExt;
293/// #
294/// # futures::executor::block_on(async {
295/// #
296/// #
297/// # async fn yield_now() {
298/// #     struct YieldNow {
299/// #         yielded: bool,
300/// #     }
301/// #
302/// #     impl std::future::Future for YieldNow {
303/// #         type Output = ();
304/// #         fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
305/// #             if self.yielded {
306/// #                 std::task::Poll::Ready(())
307/// #             } else {
308/// #                 cx.waker().wake_by_ref();
309/// #                 self.yielded = true;
310/// #                 std::task::Poll::Pending
311/// #             }
312/// #         }
313/// #     }
314/// #
315/// #     YieldNow { yielded: false }.await;
316/// # }
317///
318/// def_batch_loader! {
319///     pub async fn loader(inputs: u64) -> (Vec<u64>, String) {
320///         let inputs: Vec<_> = inputs.map(|a| *a).collect();
321///         let inputs_copy = inputs.clone();
322///
323///         inputs.into_iter().map(move |input| {
324///             (inputs_copy.clone(), input.to_string())
325///         })
326///     }
327/// }
328///
329/// batched(async {
330///     let mut one = loader(1);
331///     let mut two = loader(2);
332///     let mut three = loader(3);
333///
334///     one.schedule();
335///
336///     // yielding without delay_loading_batches will cause the batch to load
337///     yield_now().await;
338///
339///     assert_eq!(one.await, Box::new((vec![1], "1".to_owned())));
340///
341///     // delay_loading_batches enables accumulating batches across yields
342///     let guard = delay_loading_batches();
343///     two.schedule();
344///     yield_now().await;
345///     drop(guard);
346/// 
347///     let three = three.await;
348/// 
349///     assert_eq!(three, Box::new((vec![2, 3], "3".to_owned())));
350/// }).await;
351/// # });
352/// ```
353/// ## Panics
354///
355/// Must be called from within a batched() context.
356pub fn delay_loading_batches<'a>() -> DelayGuard<'a> {
357    with_batch_ctx(|ctx| {
358        ctx.postpone_loading += 1;
359    });
360    DelayGuard(PhantomData)
361}
362
363impl<F: Future> Drop for Batched<F> {
364    fn drop(&mut self) {
365        provide_batch_ctx(self.ctx.clone(), move || {
366            let Self { .. } = self;
367        });
368    }
369}
370
371impl<F: Future> Future for Batched<F> {
372    type Output = F::Output;
373
374    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
375        let fut;
376        let batch_futures;
377        let ctx;
378
379        unsafe {
380            let this = self.get_unchecked_mut();
381
382            batch_futures = &mut this.batch_futures;
383            fut = Pin::new_unchecked(&mut this.fut);
384            ctx = &this.ctx;
385        };
386
387        let poll = provide_batch_ctx(ctx.clone(), || {
388            let poll = fut.poll(cx);
389
390            let mut ready_futures = vec![];
391            
392            for (idx, batch_fut) in batch_futures.iter_mut() {
393                match batch_fut.as_mut().poll(cx) {
394                    Poll::Ready(_) => ready_futures.push(idx),
395                    Poll::Pending => { }
396                }
397            }
398
399            for idx in ready_futures {
400                batch_futures.remove(idx);
401            }
402
403            poll
404        });
405
406        loop {
407            let batches = {
408                let mut ctx = (**ctx).borrow_mut();
409
410                if ctx.accumulating.is_empty() { break }
411
412                if ctx.postpone_loading > 0 { break }
413
414                std::mem::replace(&mut ctx.accumulating, HashMap::new())
415            };
416
417            provide_batch_ctx(ctx.clone(), || {
418                for (loader, batch) in batches.into_iter() {
419                    let mut fut = (loader)(batch);
420
421                    if let Poll::Pending = fut.as_mut().poll(cx) {
422                        batch_futures.insert(fut);
423                    }
424                }
425            })
426        }
427
428        match poll {
429            Poll::Ready(val) if batch_futures.is_empty() => {
430                Poll::Ready(val)
431            },
432            _ => Poll::Pending
433        }
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::{batched, def_batch_loader, delay_loading_batches};
    use futures::FutureExt;

    /// Returns `Pending` once, after waking the task so that it is polled again.
    async fn yield_now() {
        struct YieldNow {
            yielded: bool,
        }

        impl std::future::Future for YieldNow {
            type Output = ();

            fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
                if !self.yielded {
                    cx.waker().wake_by_ref();
                    self.yielded = true;
                    return std::task::Poll::Pending;
                }

                std::task::Poll::Ready(())
            }
        }

        YieldNow { yielded: false }.await;
    }

    // Every output carries the complete list of inputs of its batch, which lets the
    // tests below assert on which loads were batched together and in what order.
    def_batch_loader! {
        /// Hello there!
        pub async fn load_foobar_batched(inputs: u64) -> (Vec<u64>, String) {
            let batch: Vec<_> = inputs.map(|boxed| *boxed).collect();
            let batch_copy = batch.clone();

            yield_now().await;

            batch.into_iter().map(move |input| {
                (batch_copy.clone(), input.to_string())
            })
        }
    }

    // Two loads polled by the same `select_biased!` end up in one batch.
    #[test]
    fn test() {
        futures::executor::block_on(async {
            batched(async {
                let fifty_four = load_foobar_batched(54).fuse();
                let thirty_two = load_foobar_batched(32).fuse();

                futures::pin_mut!(fifty_four, thirty_two);

                futures::select_biased! {
                    tt = thirty_two => {
                        assert_eq!(tt, Box::new((vec![32, 54], "32".to_owned())));
                    },
                    ff = fifty_four => {
                        assert_eq!(ff, Box::new((vec![32, 54], "54".to_owned())));
                    }
                }
            }).await;
        });
    }

    // A load scheduled by hand joins the batch of the next awaited load.
    #[test]
    fn test_schedule() {
        futures::executor::block_on(async {
            batched(async {
                let twelve = load_foobar_batched(12).await;
                assert_eq!(twelve, Box::new((vec![12], "12".to_owned())));

                let mut fifty_four = load_foobar_batched(54);
                let thirty_two = load_foobar_batched(32);

                fifty_four.schedule();

                assert_eq!(thirty_two.await, Box::new((vec![54, 32], "32".to_owned())));
                assert_eq!(fifty_four.await, Box::new((vec![54, 32], "54".to_owned())));
            }).await;
        });
    }


    // A loader can read and update the user context registered on the scope.
    #[test]
    fn test_ctx() {
        futures::executor::block_on(async {
            struct Count(usize);

            def_batch_loader! {
                pub async fn counter(inputs: &'static str) -> (&'static str, usize) {
                    inputs.map(|input| {
                        let seen = super::with_batch_ctx(|ctx| {
                            let counter = ctx.mut_ctx::<Count>().unwrap();

                            counter.0 += 1;

                            counter.0
                        });

                        (*input, seen)
                    })
                }
            }

            let mut scope = batched(async {
                assert_eq!( counter("hello").await, Box::new(("hello", 1)) );
                assert_eq!( counter("hello there").await, Box::new(("hello there", 2)) );
            });

            scope.ctx().set_ctx(Box::new(Count(0)));

            scope.await;
        });
    }

    // Creating a delay guard and dropping it right away leaves the scope usable.
    #[test]
    fn test_drop_delay() {
        futures::executor::block_on(async {
            batched(async {
                let one = load_foobar_batched(1).fuse();

                futures::pin_mut!(one);

                futures::select_biased! {
                    one = one => {
                        assert_eq!(one, Box::new((vec![1], "1".to_owned())));
                    }
                }

                // Not used by this test at present.
                pub struct PendingOnce {
                    is_ready: bool,
                }

                impl std::future::Future for PendingOnce {
                    type Output = ();

                    fn poll(mut self: std::pin::Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
                        if !self.is_ready {
                            self.is_ready = true;
                            return std::task::Poll::Pending;
                        }

                        std::task::Poll::Ready(())
                    }
                }

                // `let _ =` does not bind the guard, so it is dropped immediately.
                let _ = delay_loading_batches();
            }).await;
        });
    }
}