Skip to main content

async_graphql/dataloader/
mod.rs

//! Batch loading support, used to solve the N+1 problem.
//!
//! # Examples
//!
//! ```rust
//! use async_graphql::*;
//! use async_graphql::dataloader::*;
//! use std::collections::{HashSet, HashMap};
//! use std::convert::Infallible;
//! use async_graphql::dataloader::Loader;
//! use async_graphql::runtime::{TokioSpawner, TokioTimer};
//!
//! /// This loader simply converts the integer key into a string value.
//! struct MyLoader;
//!
//! #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
//! impl Loader<i32> for MyLoader {
//!     type Value = String;
//!     type Error = Infallible;
//!
//!     async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
//!         // Use `MyLoader` to load data.
//!         Ok(keys.iter().copied().map(|n| (n, n.to_string())).collect())
//!     }
//! }
//!
//! struct Query;
//!
//! #[Object]
//! impl Query {
//!     async fn value(&self, ctx: &Context<'_>, n: i32) -> Option<String> {
//!         ctx.data_unchecked::<DataLoader<MyLoader>>().load_one(n).await.unwrap()
//!     }
//! }
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async move {
//! let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
//! let query = r#"
//!     {
//!         v1: value(n: 1)
//!         v2: value(n: 2)
//!         v3: value(n: 3)
//!         v4: value(n: 4)
//!         v5: value(n: 5)
//!     }
//! "#;
//! let request = Request::new(query).data(DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default()));
//! let res = schema.execute(request).await.into_result().unwrap().data;
//!
//! assert_eq!(res, value!({
//!     "v1": "1",
//!     "v2": "2",
//!     "v3": "3",
//!     "v4": "4",
//!     "v5": "5",
//! }));
//! # });
//! ```

60mod cache;
61
62#[cfg(not(feature = "boxed-trait"))]
63use std::future::Future;
64use std::{
65    any::{Any, TypeId},
66    borrow::Cow,
67    collections::{HashMap, HashSet},
68    hash::Hash,
69    sync::{
70        Arc,
71        atomic::{AtomicBool, Ordering},
72    },
73    time::Duration,
74};
75
76pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
77use futures_channel::oneshot;
78use futures_util::task::{Spawn, SpawnExt};
79use rustc_hash::FxBuildHasher;
80#[cfg(feature = "tracing")]
81use tracing::{Instrument, info_span, instrument};
82
83use crate::runtime::Timer;
84
85type FxHashMap<K, V> = scc::HashMap<K, V, FxBuildHasher>;
86
87#[allow(clippy::type_complexity)]
88struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
89    use_cache_values: HashMap<K, T::Value>,
90    tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
91}
92
93struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
94    keys: HashSet<K>,
95    pending: Vec<(HashSet<K>, ResSender<K, T>)>,
96    cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
97    disable_cache: bool,
98}
99
100type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
101
102impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
103    fn new<C: CacheFactory>(cache_factory: &C) -> Self {
104        Self {
105            keys: Default::default(),
106            pending: Vec::new(),
107            cache_storage: cache_factory.create::<K, T::Value>(),
108            disable_cache: false,
109        }
110    }
111
112    fn take(&mut self) -> KeysAndSender<K, T> {
113        (
114            std::mem::take(&mut self.keys),
115            std::mem::take(&mut self.pending),
116        )
117    }
118}
119
/// Trait for batch loading.
///
/// `DataLoader` collects keys from concurrent requests and calls [`Loader::load`]
/// once per batch. The slice it passes contains no duplicate keys.
///
/// Keys absent from the returned map are treated as "not found": for example,
/// `DataLoader::load_one` yields `None` for them. When `load` fails, the error
/// is cloned and delivered to every caller waiting on that batch, which is why
/// `Error` must be `Clone`.
#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
pub trait Loader<K>: Send + Sync + 'static
where
    K: Send + Sync + Hash + Eq + Clone + 'static,
{
    /// type of value.
    type Value: Send + Sync + Clone + 'static;

    /// Type of error.
    type Error: Send + Clone + 'static;

    /// Load the data set specified by the `keys`.
    #[cfg(feature = "boxed-trait")]
    async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;

    /// Load the data set specified by the `keys`.
    #[cfg(not(feature = "boxed-trait"))]
    fn load(
        &self,
        keys: &[K],
    ) -> impl Future<Output = Result<HashMap<K, Self::Value>, Self::Error>> + Send;
}
140
141struct DataLoaderInner<T> {
142    requests: FxHashMap<TypeId, Box<dyn Any + Sync + Send>>,
143    loader: T,
144}
145
146impl<T> DataLoaderInner<T> {
147    #[cfg_attr(feature = "tracing", instrument(skip_all))]
148    async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
149    where
150        K: Send + Sync + Hash + Eq + Clone + 'static,
151        T: Loader<K>,
152    {
153        let tid = TypeId::of::<K>();
154        let keys = keys.into_iter().collect::<Vec<_>>();
155
156        match self.loader.load(&keys).await {
157            Ok(values) => {
158                // update cache
159                let mut entry = self.requests.get_async(&tid).await.unwrap();
160
161                let typed_requests = entry.get_mut().downcast_mut::<Requests<K, T>>().unwrap();
162
163                let disable_cache = typed_requests.disable_cache || disable_cache;
164                if !disable_cache {
165                    for (key, value) in &values {
166                        typed_requests
167                            .cache_storage
168                            .insert(Cow::Borrowed(key), Cow::Borrowed(value));
169                    }
170                }
171
172                // send response
173                for (keys, sender) in senders {
174                    let mut res = HashMap::new();
175                    res.extend(sender.use_cache_values);
176                    for key in &keys {
177                        res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
178                    }
179                    sender.tx.send(Ok(res)).ok();
180                }
181            }
182            Err(err) => {
183                for (_, sender) in senders {
184                    sender.tx.send(Err(err.clone())).ok();
185                }
186            }
187        }
188    }
189}
190
191/// Data loader.
192///
193/// Reference: <https://github.com/facebook/dataloader>
194pub struct DataLoader<T, C = NoCache> {
195    inner: Arc<DataLoaderInner<T>>,
196    cache_factory: C,
197    delay: Duration,
198    max_batch_size: usize,
199    disable_cache: AtomicBool,
200    spawner: Box<dyn Spawn + Send + Sync>,
201    timer: Arc<dyn Timer>,
202}
203
204impl<T> DataLoader<T, NoCache> {
205    /// Use `Loader` to create a [DataLoader] that does not cache records.
206    pub fn new<S, TR>(loader: T, spawner: S, timer: TR) -> Self
207    where
208        S: Spawn + Send + Sync + 'static,
209        TR: Timer,
210    {
211        Self {
212            inner: Arc::new(DataLoaderInner {
213                requests: Default::default(),
214                loader,
215            }),
216            cache_factory: NoCache,
217            delay: Duration::from_millis(1),
218            max_batch_size: 1000,
219            disable_cache: false.into(),
220            spawner: Box::new(spawner),
221            timer: Arc::new(timer),
222        }
223    }
224}
225
226impl<T, C: CacheFactory> DataLoader<T, C> {
227    /// Use `Loader` to create a [DataLoader] with a cache factory.
228    pub fn with_cache<S, TR>(loader: T, spawner: S, timer: TR, cache_factory: C) -> Self
229    where
230        S: Spawn + Send + Sync + 'static,
231        TR: Timer,
232    {
233        Self {
234            inner: Arc::new(DataLoaderInner {
235                requests: Default::default(),
236                loader,
237            }),
238            cache_factory,
239            delay: Duration::from_millis(1),
240            max_batch_size: 1000,
241            disable_cache: false.into(),
242            spawner: Box::new(spawner),
243            timer: Arc::new(timer),
244        }
245    }
246
247    /// Specify the delay time for loading data, the default is `1ms`.
248    #[must_use]
249    pub fn delay(self, delay: Duration) -> Self {
250        Self { delay, ..self }
251    }
252
253    /// pub fn Specify the max batch size for loading data, the default is
254    /// `1000`.
255    ///
256    /// If the keys waiting to be loaded reach the threshold, they are loaded
257    /// immediately.
258    #[must_use]
259    pub fn max_batch_size(self, max_batch_size: usize) -> Self {
260        Self {
261            max_batch_size,
262            ..self
263        }
264    }
265
266    /// Get the loader.
267    #[inline]
268    pub fn loader(&self) -> &T {
269        &self.inner.loader
270    }
271
272    /// Enable/Disable cache of all loaders.
273    pub fn enable_all_cache(&self, enable: bool) {
274        self.disable_cache.store(!enable, Ordering::SeqCst);
275    }
276
277    /// Enable/Disable cache of specified loader.
278    pub async fn enable_cache<K>(&self, enable: bool)
279    where
280        K: Send + Sync + Hash + Eq + Clone + 'static,
281        T: Loader<K>,
282    {
283        let tid = TypeId::of::<K>();
284        let mut entry = self.inner.requests.get_async(&tid).await.unwrap();
285        let typed_requests = entry.get_mut().downcast_mut::<Requests<K, T>>().unwrap();
286        typed_requests.disable_cache = !enable;
287    }
288
289    /// Use this `DataLoader` load a data.
290    #[cfg_attr(feature = "tracing", instrument(skip_all))]
291    pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
292    where
293        K: Send + Sync + Hash + Eq + Clone + 'static,
294        T: Loader<K>,
295    {
296        let mut values = self.load_many(std::iter::once(key.clone())).await?;
297        Ok(values.remove(&key))
298    }
299
300    /// Use this `DataLoader` to load some data.
301    #[cfg_attr(feature = "tracing", instrument(skip_all))]
302    pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
303    where
304        K: Send + Sync + Hash + Eq + Clone + 'static,
305        I: IntoIterator<Item = K>,
306        T: Loader<K>,
307    {
308        enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
309            ImmediateLoad(KeysAndSender<K, T>),
310            StartFetch,
311            Delay,
312        }
313
314        let tid = TypeId::of::<K>();
315
316        let (action, rx) = {
317            let mut entry = self
318                .inner
319                .requests
320                .entry_async(tid)
321                .await
322                .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
323
324            let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
325
326            let prev_count = typed_requests.keys.len();
327            let mut keys_set = HashSet::new();
328            let mut use_cache_values = HashMap::new();
329
330            if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
331                keys_set = keys.into_iter().collect();
332            } else {
333                for key in keys {
334                    if let Some(value) = typed_requests.cache_storage.get(&key) {
335                        // Already in cache
336                        use_cache_values.insert(key.clone(), value);
337                    } else {
338                        keys_set.insert(key);
339                    }
340                }
341            }
342
343            if !use_cache_values.is_empty() && keys_set.is_empty() {
344                return Ok(use_cache_values);
345            } else if use_cache_values.is_empty() && keys_set.is_empty() {
346                return Ok(Default::default());
347            }
348
349            typed_requests.keys.extend(keys_set.clone());
350            let (tx, rx) = oneshot::channel();
351            typed_requests.pending.push((
352                keys_set,
353                ResSender {
354                    use_cache_values,
355                    tx,
356                },
357            ));
358
359            if typed_requests.keys.len() >= self.max_batch_size {
360                (Action::ImmediateLoad(typed_requests.take()), rx)
361            } else {
362                (
363                    if !typed_requests.keys.is_empty() && prev_count == 0 {
364                        Action::StartFetch
365                    } else {
366                        Action::Delay
367                    },
368                    rx,
369                )
370            }
371        };
372
373        match action {
374            Action::ImmediateLoad(keys) => {
375                let inner = self.inner.clone();
376                let disable_cache = self.disable_cache.load(Ordering::SeqCst);
377                let task = async move { inner.do_load(disable_cache, keys).await };
378                #[cfg(feature = "tracing")]
379                let task = task
380                    .instrument(info_span!("immediate_load"))
381                    .in_current_span();
382
383                let _ = self.spawner.spawn(task);
384            }
385            Action::StartFetch => {
386                let inner = self.inner.clone();
387                let disable_cache = self.disable_cache.load(Ordering::SeqCst);
388                let delay = self.delay;
389                let timer = self.timer.clone();
390
391                let task = async move {
392                    timer.delay(delay).await;
393
394                    let keys = {
395                        let mut entry = inner.requests.get_async(&tid).await.unwrap();
396                        let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
397                        typed_requests.take()
398                    };
399
400                    if !keys.0.is_empty() {
401                        inner.do_load(disable_cache, keys).await
402                    }
403                };
404                #[cfg(feature = "tracing")]
405                let task = task.instrument(info_span!("start_fetch")).in_current_span();
406                let _ = self.spawner.spawn(task);
407            }
408            Action::Delay => {}
409        }
410
411        rx.await.unwrap()
412    }
413
414    /// Feed some data into the cache.
415    ///
416    /// **NOTE: If the cache type is [NoCache], this function will not take
417    /// effect. **
418    #[cfg_attr(feature = "tracing", instrument(skip_all))]
419    pub async fn feed_many<K, I>(&self, values: I)
420    where
421        K: Send + Sync + Hash + Eq + Clone + 'static,
422        I: IntoIterator<Item = (K, T::Value)>,
423        T: Loader<K>,
424    {
425        let tid = TypeId::of::<K>();
426        let mut entry = self
427            .inner
428            .requests
429            .entry_async(tid)
430            .await
431            .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
432
433        let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
434
435        for (key, value) in values {
436            typed_requests
437                .cache_storage
438                .insert(Cow::Owned(key), Cow::Owned(value));
439        }
440    }
441
442    /// Feed some data into the cache.
443    ///
444    /// **NOTE: If the cache type is [NoCache], this function will not take
445    /// effect. **
446    #[cfg_attr(feature = "tracing", instrument(skip_all))]
447    pub async fn feed_one<K>(&self, key: K, value: T::Value)
448    where
449        K: Send + Sync + Hash + Eq + Clone + 'static,
450        T: Loader<K>,
451    {
452        self.feed_many(std::iter::once((key, value))).await;
453    }
454
455    /// Clear a specific entry from the cache.
456    ///
457    /// **NOTE: if the cache type is [NoCache], this function will not take
458    /// effect. **
459    #[cfg_attr(feature = "tracing", instrument(skip_all))]
460    pub fn clear_one<K>(&self, key: &K)
461    where
462        K: Send + Sync + Hash + Eq + Clone + 'static,
463        T: Loader<K>,
464    {
465        let tid = TypeId::of::<K>();
466        let mut entry = self
467            .inner
468            .requests
469            .entry_sync(tid)
470            .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
471
472        let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
473        typed_requests.cache_storage.remove(key);
474    }
475
476    /// Clears the cache.
477    ///
478    /// **NOTE: If the cache type is [NoCache], this function will not take
479    /// effect. **
480    #[cfg_attr(feature = "tracing", instrument(skip_all))]
481    pub fn clear<K>(&self)
482    where
483        K: Send + Sync + Hash + Eq + Clone + 'static,
484        T: Loader<K>,
485    {
486        let tid = TypeId::of::<K>();
487        let mut entry = self
488            .inner
489            .requests
490            .entry_sync(tid)
491            .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
492
493        let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
494        typed_requests.cache_storage.clear();
495    }
496
497    /// Gets all values in the cache.
498    pub async fn get_cached_values<K>(&self) -> HashMap<K, T::Value>
499    where
500        K: Send + Sync + Hash + Eq + Clone + 'static,
501        T: Loader<K>,
502    {
503        let tid = TypeId::of::<K>();
504        match self.inner.requests.get_async(&tid).await {
505            None => HashMap::new(),
506            Some(requests) => {
507                let typed_requests = requests.get().downcast_ref::<Requests<K, T>>().unwrap();
508                typed_requests
509                    .cache_storage
510                    .iter()
511                    .map(|(k, v)| (k.clone(), v.clone()))
512                    .collect()
513            }
514        }
515    }
516}
517
#[cfg(test)]
mod tests {
    use rustc_hash::FxBuildHasher;

    use super::*;
    use crate::runtime::{TokioSpawner, TokioTimer};

    /// Identity loader for `i32` and `i64` keys. It fails the test if a batch
    /// holds more than 10 keys.
    struct MyLoader;

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i32> for MyLoader {
        type Value = i32;
        type Error = ();

        async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().copied().map(|k| (k, k)).collect())
        }
    }

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i64> for MyLoader {
        type Value = i64;
        type Error = ();

        async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().copied().map(|k| (k, k)).collect())
        }
    }

    /// A non-caching loader limited to batches of 10 keys.
    fn batching_loader() -> DataLoader<MyLoader> {
        DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default())
            .max_batch_size(10)
    }

    /// Issues one `load_one` per key, with all calls polled concurrently, and
    /// returns the results in key order.
    async fn load_each<K>(
        loader: &DataLoader<MyLoader>,
        keys: impl IntoIterator<Item = K>,
    ) -> Vec<Option<K>>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        MyLoader: Loader<K, Value = K, Error = ()>,
    {
        futures_util::future::try_join_all(keys.into_iter().map(move |key| loader.load_one(key)))
            .await
            .unwrap()
    }

    /// Shared caching scenario: seeded values win over the loader until the
    /// `i32` cache is cleared.
    async fn assert_cache_round_trip<F: CacheFactory>(cache_factory: F) {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            cache_factory,
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // All from the cache
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 10), (2, 20), (3, 30)])
        );

        // Part from the cache
        assert_eq!(
            loader.load_many(vec![1, 5, 6]).await.unwrap(),
            HashMap::from([(1, 10), (5, 5), (6, 6)])
        );

        // All from the loader
        assert_eq!(
            loader.load_many(vec![8, 9, 10]).await.unwrap(),
            HashMap::from([(8, 8), (9, 9), (10, 10)])
        );

        // Clear cache
        loader.clear::<i32>();
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 1), (2, 2), (3, 3)])
        );
    }

    #[tokio::test]
    async fn test_dataloader() {
        let loader = batching_loader();

        let expected_i32: Vec<_> = (0..100).map(Some).collect();
        assert_eq!(load_each(&loader, 0..100i32).await, expected_i32);

        let expected_i64: Vec<_> = (0..100).map(Some).collect();
        assert_eq!(load_each(&loader, 0..100i64).await, expected_i64);
    }

    #[tokio::test]
    async fn test_duplicate_keys() {
        let keys = [1, 3, 5, 1, 7, 8, 3, 7];
        let loader = batching_loader();
        assert_eq!(load_each(&loader, keys).await, keys.map(Some).to_vec());
    }

    #[tokio::test]
    async fn test_dataloader_load_empty() {
        let loader = DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default());
        assert!(loader.load_many::<i32, _>(vec![]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_dataloader_with_cache() {
        assert_cache_round_trip(HashMapCache::default()).await;
    }

    #[tokio::test]
    async fn test_dataloader_with_cache_hashmap_fx() {
        assert_cache_round_trip(HashMapCache::<FxBuildHasher>::new()).await;
    }

    #[tokio::test]
    async fn test_dataloader_disable_all_cache() {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            HashMapCache::default(),
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // All from the loader
        loader.enable_all_cache(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 1), (2, 2), (3, 3)])
        );

        // All from the cache
        loader.enable_all_cache(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_evict_one_from_cache() {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            HashMapCache::default(),
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // All from the cache
        loader.enable_all_cache(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 10), (2, 20), (3, 30)])
        );

        // Remove one from cache
        loader.clear_one(&1);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 1), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_disable_cache() {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            HashMapCache::default(),
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // All from the loader
        loader.enable_cache::<i32>(false).await;
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 1), (2, 2), (3, 3)])
        );

        // All from the cache
        loader.enable_cache::<i32>(true).await;
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_dead_lock() {
        /// Loader that sleeps for one second per batch.
        struct MyDelayLoader;

        #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
        impl Loader<i32> for MyDelayLoader {
            type Value = i32;
            type Error = ();

            async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(keys.iter().copied().map(|k| (k, k)).collect())
            }
        }

        let loader = Arc::new(
            DataLoader::with_cache(
                MyDelayLoader,
                TokioSpawner::current(),
                TokioTimer::default(),
                NoCache,
            )
            .delay(Duration::from_secs(1)),
        );

        // Open a batch from a task that is aborted while it is still waiting...
        let background = Arc::clone(&loader);
        let handle = tokio::spawn(async move {
            background.load_many(vec![1, 2, 3]).await.unwrap();
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        handle.abort();

        // ...and the loader must still serve new requests afterwards.
        loader.load_many(vec![4, 5, 6]).await.unwrap();
    }
}