1mod cache;
61
62#[cfg(not(feature = "boxed-trait"))]
63use std::future::Future;
64use std::{
65 any::{Any, TypeId},
66 borrow::Cow,
67 collections::{HashMap, HashSet},
68 hash::Hash,
69 sync::{
70 Arc,
71 atomic::{AtomicBool, Ordering},
72 },
73 time::Duration,
74};
75
76pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
77use futures_channel::oneshot;
78use futures_util::task::{Spawn, SpawnExt};
79use rustc_hash::FxBuildHasher;
80#[cfg(feature = "tracing")]
81use tracing::{Instrument, info_span, instrument};
82
83use crate::runtime::Timer;
84
// Concurrent `scc` hash map using the Fx hasher; this file uses it for the
// `TypeId`-keyed request table held by `DataLoaderInner`.
type FxHashMap<K, V> = scc::HashMap<K, V, FxBuildHasher>;
86
87#[allow(clippy::type_complexity)]
88struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
89 use_cache_values: HashMap<K, T::Value>,
90 tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
91}
92
93struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
94 keys: HashSet<K>,
95 pending: Vec<(HashSet<K>, ResSender<K, T>)>,
96 cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
97 disable_cache: bool,
98}
99
100type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
101
102impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
103 fn new<C: CacheFactory>(cache_factory: &C) -> Self {
104 Self {
105 keys: Default::default(),
106 pending: Vec::new(),
107 cache_storage: cache_factory.create::<K, T::Value>(),
108 disable_cache: false,
109 }
110 }
111
112 fn take(&mut self) -> KeysAndSender<K, T> {
113 (
114 std::mem::take(&mut self.keys),
115 std::mem::take(&mut self.pending),
116 )
117 }
118}
119
/// Trait for batch loading values by key.
///
/// With the `boxed-trait` feature the trait is expanded through
/// `async_trait`; without it `load` returns an `impl Future + Send`.
#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
pub trait Loader<K>: Send + Sync + 'static
where
    K: Send + Sync + Hash + Eq + Clone + 'static,
{
    /// Type of the value produced for each key.
    type Value: Send + Sync + Clone + 'static;

    /// Type of the error returned when a batch fails. It must be `Clone`
    /// because one failure is delivered to every caller waiting on the batch.
    type Error: Send + Clone + 'static;

    /// Loads the values for `keys`, returning a map from key to value.
    #[cfg(feature = "boxed-trait")]
    async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;

    /// Loads the values for `keys`, returning a map from key to value.
    #[cfg(not(feature = "boxed-trait"))]
    fn load(
        &self,
        keys: &[K],
    ) -> impl Future<Output = Result<HashMap<K, Self::Value>, Self::Error>> + Send;
}
140
141struct DataLoaderInner<T> {
142 requests: FxHashMap<TypeId, Box<dyn Any + Sync + Send>>,
143 loader: T,
144}
145
146impl<T> DataLoaderInner<T> {
147 #[cfg_attr(feature = "tracing", instrument(skip_all))]
148 async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
149 where
150 K: Send + Sync + Hash + Eq + Clone + 'static,
151 T: Loader<K>,
152 {
153 let tid = TypeId::of::<K>();
154 let keys = keys.into_iter().collect::<Vec<_>>();
155
156 match self.loader.load(&keys).await {
157 Ok(values) => {
158 let mut entry = self.requests.get_async(&tid).await.unwrap();
160
161 let typed_requests = entry.get_mut().downcast_mut::<Requests<K, T>>().unwrap();
162
163 let disable_cache = typed_requests.disable_cache || disable_cache;
164 if !disable_cache {
165 for (key, value) in &values {
166 typed_requests
167 .cache_storage
168 .insert(Cow::Borrowed(key), Cow::Borrowed(value));
169 }
170 }
171
172 for (keys, sender) in senders {
174 let mut res = HashMap::new();
175 res.extend(sender.use_cache_values);
176 for key in &keys {
177 res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
178 }
179 sender.tx.send(Ok(res)).ok();
180 }
181 }
182 Err(err) => {
183 for (_, sender) in senders {
184 sender.tx.send(Err(err.clone())).ok();
185 }
186 }
187 }
188 }
189}
190
191pub struct DataLoader<T, C = NoCache> {
195 inner: Arc<DataLoaderInner<T>>,
196 cache_factory: C,
197 delay: Duration,
198 max_batch_size: usize,
199 disable_cache: AtomicBool,
200 spawner: Box<dyn Spawn + Send + Sync>,
201 timer: Arc<dyn Timer>,
202}
203
204impl<T> DataLoader<T, NoCache> {
205 pub fn new<S, TR>(loader: T, spawner: S, timer: TR) -> Self
207 where
208 S: Spawn + Send + Sync + 'static,
209 TR: Timer,
210 {
211 Self {
212 inner: Arc::new(DataLoaderInner {
213 requests: Default::default(),
214 loader,
215 }),
216 cache_factory: NoCache,
217 delay: Duration::from_millis(1),
218 max_batch_size: 1000,
219 disable_cache: false.into(),
220 spawner: Box::new(spawner),
221 timer: Arc::new(timer),
222 }
223 }
224}
225
226impl<T, C: CacheFactory> DataLoader<T, C> {
227 pub fn with_cache<S, TR>(loader: T, spawner: S, timer: TR, cache_factory: C) -> Self
229 where
230 S: Spawn + Send + Sync + 'static,
231 TR: Timer,
232 {
233 Self {
234 inner: Arc::new(DataLoaderInner {
235 requests: Default::default(),
236 loader,
237 }),
238 cache_factory,
239 delay: Duration::from_millis(1),
240 max_batch_size: 1000,
241 disable_cache: false.into(),
242 spawner: Box::new(spawner),
243 timer: Arc::new(timer),
244 }
245 }
246
247 #[must_use]
249 pub fn delay(self, delay: Duration) -> Self {
250 Self { delay, ..self }
251 }
252
253 #[must_use]
259 pub fn max_batch_size(self, max_batch_size: usize) -> Self {
260 Self {
261 max_batch_size,
262 ..self
263 }
264 }
265
266 #[inline]
268 pub fn loader(&self) -> &T {
269 &self.inner.loader
270 }
271
272 pub fn enable_all_cache(&self, enable: bool) {
274 self.disable_cache.store(!enable, Ordering::SeqCst);
275 }
276
277 pub async fn enable_cache<K>(&self, enable: bool)
279 where
280 K: Send + Sync + Hash + Eq + Clone + 'static,
281 T: Loader<K>,
282 {
283 let tid = TypeId::of::<K>();
284 let mut entry = self.inner.requests.get_async(&tid).await.unwrap();
285 let typed_requests = entry.get_mut().downcast_mut::<Requests<K, T>>().unwrap();
286 typed_requests.disable_cache = !enable;
287 }
288
289 #[cfg_attr(feature = "tracing", instrument(skip_all))]
291 pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
292 where
293 K: Send + Sync + Hash + Eq + Clone + 'static,
294 T: Loader<K>,
295 {
296 let mut values = self.load_many(std::iter::once(key.clone())).await?;
297 Ok(values.remove(&key))
298 }
299
300 #[cfg_attr(feature = "tracing", instrument(skip_all))]
302 pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
303 where
304 K: Send + Sync + Hash + Eq + Clone + 'static,
305 I: IntoIterator<Item = K>,
306 T: Loader<K>,
307 {
308 enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
309 ImmediateLoad(KeysAndSender<K, T>),
310 StartFetch,
311 Delay,
312 }
313
314 let tid = TypeId::of::<K>();
315
316 let (action, rx) = {
317 let mut entry = self
318 .inner
319 .requests
320 .entry_async(tid)
321 .await
322 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
323
324 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
325
326 let prev_count = typed_requests.keys.len();
327 let mut keys_set = HashSet::new();
328 let mut use_cache_values = HashMap::new();
329
330 if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
331 keys_set = keys.into_iter().collect();
332 } else {
333 for key in keys {
334 if let Some(value) = typed_requests.cache_storage.get(&key) {
335 use_cache_values.insert(key.clone(), value);
337 } else {
338 keys_set.insert(key);
339 }
340 }
341 }
342
343 if !use_cache_values.is_empty() && keys_set.is_empty() {
344 return Ok(use_cache_values);
345 } else if use_cache_values.is_empty() && keys_set.is_empty() {
346 return Ok(Default::default());
347 }
348
349 typed_requests.keys.extend(keys_set.clone());
350 let (tx, rx) = oneshot::channel();
351 typed_requests.pending.push((
352 keys_set,
353 ResSender {
354 use_cache_values,
355 tx,
356 },
357 ));
358
359 if typed_requests.keys.len() >= self.max_batch_size {
360 (Action::ImmediateLoad(typed_requests.take()), rx)
361 } else {
362 (
363 if !typed_requests.keys.is_empty() && prev_count == 0 {
364 Action::StartFetch
365 } else {
366 Action::Delay
367 },
368 rx,
369 )
370 }
371 };
372
373 match action {
374 Action::ImmediateLoad(keys) => {
375 let inner = self.inner.clone();
376 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
377 let task = async move { inner.do_load(disable_cache, keys).await };
378 #[cfg(feature = "tracing")]
379 let task = task
380 .instrument(info_span!("immediate_load"))
381 .in_current_span();
382
383 let _ = self.spawner.spawn(task);
384 }
385 Action::StartFetch => {
386 let inner = self.inner.clone();
387 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
388 let delay = self.delay;
389 let timer = self.timer.clone();
390
391 let task = async move {
392 timer.delay(delay).await;
393
394 let keys = {
395 let mut entry = inner.requests.get_async(&tid).await.unwrap();
396 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
397 typed_requests.take()
398 };
399
400 if !keys.0.is_empty() {
401 inner.do_load(disable_cache, keys).await
402 }
403 };
404 #[cfg(feature = "tracing")]
405 let task = task.instrument(info_span!("start_fetch")).in_current_span();
406 let _ = self.spawner.spawn(task);
407 }
408 Action::Delay => {}
409 }
410
411 rx.await.unwrap()
412 }
413
414 #[cfg_attr(feature = "tracing", instrument(skip_all))]
419 pub async fn feed_many<K, I>(&self, values: I)
420 where
421 K: Send + Sync + Hash + Eq + Clone + 'static,
422 I: IntoIterator<Item = (K, T::Value)>,
423 T: Loader<K>,
424 {
425 let tid = TypeId::of::<K>();
426 let mut entry = self
427 .inner
428 .requests
429 .entry_async(tid)
430 .await
431 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
432
433 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
434
435 for (key, value) in values {
436 typed_requests
437 .cache_storage
438 .insert(Cow::Owned(key), Cow::Owned(value));
439 }
440 }
441
442 #[cfg_attr(feature = "tracing", instrument(skip_all))]
447 pub async fn feed_one<K>(&self, key: K, value: T::Value)
448 where
449 K: Send + Sync + Hash + Eq + Clone + 'static,
450 T: Loader<K>,
451 {
452 self.feed_many(std::iter::once((key, value))).await;
453 }
454
455 #[cfg_attr(feature = "tracing", instrument(skip_all))]
460 pub fn clear_one<K>(&self, key: &K)
461 where
462 K: Send + Sync + Hash + Eq + Clone + 'static,
463 T: Loader<K>,
464 {
465 let tid = TypeId::of::<K>();
466 let mut entry = self
467 .inner
468 .requests
469 .entry_sync(tid)
470 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
471
472 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
473 typed_requests.cache_storage.remove(key);
474 }
475
476 #[cfg_attr(feature = "tracing", instrument(skip_all))]
481 pub fn clear<K>(&self)
482 where
483 K: Send + Sync + Hash + Eq + Clone + 'static,
484 T: Loader<K>,
485 {
486 let tid = TypeId::of::<K>();
487 let mut entry = self
488 .inner
489 .requests
490 .entry_sync(tid)
491 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
492
493 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
494 typed_requests.cache_storage.clear();
495 }
496
497 pub async fn get_cached_values<K>(&self) -> HashMap<K, T::Value>
499 where
500 K: Send + Sync + Hash + Eq + Clone + 'static,
501 T: Loader<K>,
502 {
503 let tid = TypeId::of::<K>();
504 match self.inner.requests.get_async(&tid).await {
505 None => HashMap::new(),
506 Some(requests) => {
507 let typed_requests = requests.get().downcast_ref::<Requests<K, T>>().unwrap();
508 typed_requests
509 .cache_storage
510 .iter()
511 .map(|(k, v)| (k.clone(), v.clone()))
512 .collect()
513 }
514 }
515 }
516}
517
#[cfg(test)]
mod tests {
    use rustc_hash::FxBuildHasher;

    use super::*;
    use crate::runtime::{TokioSpawner, TokioTimer};

    /// Identity loader (`k => k`) that asserts no batch exceeds ten keys.
    struct MyLoader;

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i32> for MyLoader {
        type Value = i32;
        type Error = ();

        async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&k| (k, k)).collect())
        }
    }

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i64> for MyLoader {
        type Value = i64;
        type Error = ();

        async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&k| (k, k)).collect())
        }
    }

    /// Builds the expected result map from literal pairs.
    fn pairs(entries: &[(i32, i32)]) -> HashMap<i32, i32> {
        entries.iter().copied().collect()
    }

    /// Builds a cached loader pre-fed with `1 => 10`, `2 => 20`, `3 => 30`.
    async fn seeded_loader<C: CacheFactory>(cache_factory: C) -> DataLoader<MyLoader, C> {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            cache_factory,
        );
        loader
            .feed_many(vec![(1i32, 10i32), (2, 20), (3, 30)])
            .await;
        loader
    }

    /// Shared scenario: fed values win over the loader until the cache is
    /// cleared, while uncached keys go through the loader.
    async fn check_cache_roundtrip<C: CacheFactory>(cache_factory: C) {
        let loader = seeded_loader(cache_factory).await;

        // All keys are served from the cache.
        let all_cached = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(all_cached, pairs(&[(1, 10), (2, 20), (3, 30)]));

        // A mix of one cached key and two loaded keys.
        let mixed = loader.load_many(vec![1, 5, 6]).await.unwrap();
        assert_eq!(mixed, pairs(&[(1, 10), (5, 5), (6, 6)]));

        // No key is cached.
        let uncached = loader.load_many(vec![8, 9, 10]).await.unwrap();
        assert_eq!(uncached, pairs(&[(8, 8), (9, 9), (10, 10)]));

        // After clearing, the fed values are gone.
        loader.clear::<i32>();
        let after_clear = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(after_clear, pairs(&[(1, 1), (2, 2), (3, 3)]));
    }

    #[tokio::test]
    async fn test_dataloader() {
        let loader = Arc::new(
            DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default())
                .max_batch_size(10),
        );

        let loaded_i32 = futures_util::future::try_join_all((0..100i32).map(|n| {
            let loader = Arc::clone(&loader);
            async move { loader.load_one(n).await }
        }))
        .await
        .unwrap();
        assert_eq!(loaded_i32, (0..100).map(Option::Some).collect::<Vec<_>>());

        let loaded_i64 = futures_util::future::try_join_all((0..100i64).map(|n| {
            let loader = Arc::clone(&loader);
            async move { loader.load_one(n).await }
        }))
        .await
        .unwrap();
        assert_eq!(loaded_i64, (0..100).map(Option::Some).collect::<Vec<_>>());
    }

    #[tokio::test]
    async fn test_duplicate_keys() {
        let loader = Arc::new(
            DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default())
                .max_batch_size(10),
        );
        let requested: [i32; 8] = [1, 3, 5, 1, 7, 8, 3, 7];

        let loaded = futures_util::future::try_join_all(requested.iter().copied().map(|n| {
            let loader = Arc::clone(&loader);
            async move { loader.load_one(n).await }
        }))
        .await
        .unwrap();

        let expected = requested
            .iter()
            .copied()
            .map(Option::Some)
            .collect::<Vec<_>>();
        assert_eq!(loaded, expected);
    }

    #[tokio::test]
    async fn test_dataloader_load_empty() {
        let loader = DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default());
        let loaded = loader.load_many::<i32, _>(vec![]).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn test_dataloader_with_cache() {
        check_cache_roundtrip(HashMapCache::default()).await;
    }

    #[tokio::test]
    async fn test_dataloader_with_cache_hashmap_fx() {
        check_cache_roundtrip(HashMapCache::<FxBuildHasher>::new()).await;
    }

    #[tokio::test]
    async fn test_dataloader_disable_all_cache() {
        let loader = seeded_loader(HashMapCache::default()).await;

        // Cache disabled: the fed values are ignored.
        loader.enable_all_cache(false);
        let bypassed = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(bypassed, pairs(&[(1, 1), (2, 2), (3, 3)]));

        // Cache re-enabled: the fed values are still there.
        loader.enable_all_cache(true);
        let cached = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(cached, pairs(&[(1, 10), (2, 20), (3, 30)]));
    }

    #[tokio::test]
    async fn test_dataloader_evict_one_from_cache() {
        let loader = seeded_loader(HashMapCache::default()).await;

        loader.enable_all_cache(true);
        let cached = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(cached, pairs(&[(1, 10), (2, 20), (3, 30)]));

        // Only the evicted key goes back through the loader.
        loader.clear_one(&1);
        let after_evict = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(after_evict, pairs(&[(1, 1), (2, 20), (3, 30)]));
    }

    #[tokio::test]
    async fn test_dataloader_disable_cache() {
        let loader = seeded_loader(HashMapCache::default()).await;

        // Cache disabled for `i32` only: the fed values are ignored.
        loader.enable_cache::<i32>(false).await;
        let bypassed = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(bypassed, pairs(&[(1, 1), (2, 2), (3, 3)]));

        // Cache re-enabled: the fed values are still there.
        loader.enable_cache::<i32>(true).await;
        let cached = loader.load_many(vec![1, 2, 3]).await.unwrap();
        assert_eq!(cached, pairs(&[(1, 10), (2, 20), (3, 30)]));
    }

    #[tokio::test]
    async fn test_dataloader_dead_lock() {
        /// Identity loader that takes one second per batch.
        struct MyDelayLoader;

        #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
        impl Loader<i32> for MyDelayLoader {
            type Value = i32;
            type Error = ();

            async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(keys.iter().map(|&k| (k, k)).collect())
            }
        }

        let loader = Arc::new(
            DataLoader::with_cache(
                MyDelayLoader,
                TokioSpawner::current(),
                TokioTimer::default(),
                NoCache,
            )
            .delay(Duration::from_secs(1)),
        );

        let background_loader = Arc::clone(&loader);
        let handle = tokio::spawn(async move {
            background_loader.load_many(vec![1, 2, 3]).await.unwrap();
        });

        // Abort the first caller while its batch is still waiting on the
        // delay; a later load must still complete.
        tokio::time::sleep(Duration::from_millis(500)).await;
        handle.abort();
        loader.load_many(vec![4, 5, 6]).await.unwrap();
    }
}