1mod cache;
61
62#[cfg(not(feature = "boxed-trait"))]
63use std::future::Future;
64use std::{
65 any::{Any, TypeId},
66 borrow::Cow,
67 collections::{HashMap, HashSet},
68 hash::Hash,
69 sync::{
70 Arc,
71 atomic::{AtomicBool, Ordering},
72 },
73 time::Duration,
74};
75
76pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
77use futures_channel::oneshot;
78use futures_util::task::{Spawn, SpawnExt};
79use rustc_hash::FxBuildHasher;
80#[cfg(feature = "tracing")]
81use tracing::{Instrument, info_span, instrument};
82
83use crate::runtime::Timer;
84
/// `scc::HashMap` keyed with the Fx hasher.
///
/// The only keys stored here are `TypeId`s, which are not attacker-controlled, so a
/// fast non-DoS-resistant hasher is appropriate.
type FxHashMap<K, V> = scc::HashMap<K, V, FxBuildHasher>;
86
87#[allow(clippy::type_complexity)]
88struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
89 use_cache_values: HashMap<K, T::Value>,
90 tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
91}
92
93struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
94 keys: HashSet<K>,
95 pending: Vec<(HashSet<K>, ResSender<K, T>)>,
96 cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
97 disable_cache: bool,
98}
99
/// A drained batch: the union of all queued keys (what gets passed to the loader)
/// and, for every waiting request, the keys it asked for plus its reply channel.
type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
101
102impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
103 fn new<C: CacheFactory>(cache_factory: &C) -> Self {
104 Self {
105 keys: Default::default(),
106 pending: Vec::new(),
107 cache_storage: cache_factory.create::<K, T::Value>(),
108 disable_cache: false,
109 }
110 }
111
112 fn take(&mut self) -> KeysAndSender<K, T> {
113 (
114 std::mem::take(&mut self.keys),
115 std::mem::take(&mut self.pending),
116 )
117 }
118}
119
/// Loads values for a batch of keys of type `K`.
///
/// A [`DataLoader`] collects keys requested concurrently and hands them to
/// `load` in batches. A single type may implement `Loader<K>` for several key
/// types; each key type gets its own queue and cache inside the `DataLoader`.
#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
pub trait Loader<K: Send + Sync + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
    /// Type of the value loaded for each key.
    type Value: Send + Sync + Clone + 'static;

    /// Error returned when a batch fails to load.
    ///
    /// Must be `Clone` because one failed batch delivers a copy of the error to
    /// every caller that was waiting on it.
    type Error: Send + Clone + 'static;

    /// Loads the values for `keys`.
    ///
    /// `keys` contains no duplicates. Keys missing from the returned map are
    /// reported to callers as not found (`load_one` yields `Ok(None)`). A batch
    /// can be larger than the configured `max_batch_size` when a single
    /// `load_many` call requests more keys than that.
    #[cfg(feature = "boxed-trait")]
    async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;

    /// Loads the values for `keys`.
    ///
    /// Same contract as the `boxed-trait` variant; this form returns an unboxed
    /// `Send` future.
    #[cfg(not(feature = "boxed-trait"))]
    fn load(
        &self,
        keys: &[K],
    ) -> impl Future<Output = Result<HashMap<K, Self::Value>, Self::Error>> + Send;
}
140
/// State shared between a `DataLoader` and the batch-loading tasks it spawns.
struct DataLoaderInner<T> {
    /// One boxed `Requests<K, T>` per key type, keyed by `TypeId::of::<K>()`.
    requests: FxHashMap<TypeId, Box<dyn Any + Sync + Send>>,
    /// The user-supplied batch loader.
    loader: T,
}
145
146impl<T> DataLoaderInner<T> {
147 #[cfg_attr(feature = "tracing", instrument(skip_all))]
148 async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
149 where
150 K: Send + Sync + Hash + Eq + Clone + 'static,
151 T: Loader<K>,
152 {
153 let tid = TypeId::of::<K>();
154 let keys = keys.into_iter().collect::<Vec<_>>();
155
156 match self.loader.load(&keys).await {
157 Ok(values) => {
158 let mut entry = self.requests.get_async(&tid).await.unwrap();
160
161 let typed_requests = entry.get_mut().downcast_mut::<Requests<K, T>>().unwrap();
162
163 let disable_cache = typed_requests.disable_cache || disable_cache;
164 if !disable_cache {
165 for (key, value) in &values {
166 typed_requests
167 .cache_storage
168 .insert(Cow::Borrowed(key), Cow::Borrowed(value));
169 }
170 }
171
172 for (keys, sender) in senders {
174 let mut res = HashMap::new();
175 res.extend(sender.use_cache_values);
176 for key in &keys {
177 res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
178 }
179 sender.tx.send(Ok(res)).ok();
180 }
181 }
182 Err(err) => {
183 for (_, sender) in senders {
184 sender.tx.send(Err(err.clone())).ok();
185 }
186 }
187 }
188 }
189}
190
/// Batches (and optionally caches) key lookups performed through a [`Loader`].
///
/// The loader `T` may implement `Loader<K>` for several key types `K`; each key
/// type gets its own request queue and cache. `C` is the [`CacheFactory`] used to
/// create those per-type caches and defaults to [`NoCache`].
pub struct DataLoader<T, C = NoCache> {
    /// State shared with the spawned batch-loading tasks.
    inner: Arc<DataLoaderInner<T>>,
    /// Creates the cache storage for each key type on first use.
    cache_factory: C,
    /// How long a newly started batch waits for more keys before loading.
    delay: Duration,
    /// Number of queued keys at which a batch is dispatched without waiting.
    max_batch_size: usize,
    /// Global switch: when `true`, `load_many` neither reads nor populates caches.
    disable_cache: AtomicBool,
    /// Runs the batch-loading tasks.
    spawner: Box<dyn Spawn + Send + Sync>,
    /// Provides the batching delay.
    timer: Arc<dyn Timer>,
}
203
204impl<T> DataLoader<T, NoCache> {
205 pub fn new<S, TR>(loader: T, spawner: S, timer: TR) -> Self
207 where
208 S: Spawn + Send + Sync + 'static,
209 TR: Timer,
210 {
211 Self {
212 inner: Arc::new(DataLoaderInner {
213 requests: Default::default(),
214 loader,
215 }),
216 cache_factory: NoCache,
217 delay: Duration::from_millis(1),
218 max_batch_size: 1000,
219 disable_cache: false.into(),
220 spawner: Box::new(spawner),
221 timer: Arc::new(timer),
222 }
223 }
224}
225
226impl<T, C: CacheFactory> DataLoader<T, C> {
227 pub fn with_cache<S, TR>(loader: T, spawner: S, timer: TR, cache_factory: C) -> Self
229 where
230 S: Spawn + Send + Sync + 'static,
231 TR: Timer,
232 {
233 Self {
234 inner: Arc::new(DataLoaderInner {
235 requests: Default::default(),
236 loader,
237 }),
238 cache_factory,
239 delay: Duration::from_millis(1),
240 max_batch_size: 1000,
241 disable_cache: false.into(),
242 spawner: Box::new(spawner),
243 timer: Arc::new(timer),
244 }
245 }
246
247 #[must_use]
249 pub fn delay(self, delay: Duration) -> Self {
250 Self { delay, ..self }
251 }
252
253 #[must_use]
259 pub fn max_batch_size(self, max_batch_size: usize) -> Self {
260 Self {
261 max_batch_size,
262 ..self
263 }
264 }
265
266 #[inline]
268 pub fn loader(&self) -> &T {
269 &self.inner.loader
270 }
271
272 pub fn enable_all_cache(&self, enable: bool) {
274 self.disable_cache.store(!enable, Ordering::SeqCst);
275 }
276
277 pub async fn enable_cache<K>(&self, enable: bool)
279 where
280 K: Send + Sync + Hash + Eq + Clone + 'static,
281 T: Loader<K>,
282 {
283 let tid = TypeId::of::<K>();
284 let mut entry = self.inner.requests.get_async(&tid).await.unwrap();
285 let typed_requests = entry.get_mut().downcast_mut::<Requests<K, T>>().unwrap();
286 typed_requests.disable_cache = !enable;
287 }
288
289 #[cfg_attr(feature = "tracing", instrument(skip_all))]
291 pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
292 where
293 K: Send + Sync + Hash + Eq + Clone + 'static,
294 T: Loader<K>,
295 {
296 let mut values = self.load_many(std::iter::once(key.clone())).await?;
297 Ok(values.remove(&key))
298 }
299
300 #[cfg_attr(feature = "tracing", instrument(skip_all))]
302 pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
303 where
304 K: Send + Sync + Hash + Eq + Clone + 'static,
305 I: IntoIterator<Item = K>,
306 T: Loader<K>,
307 {
308 enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
309 ImmediateLoad(KeysAndSender<K, T>),
310 StartFetch,
311 Delay,
312 }
313
314 let tid = TypeId::of::<K>();
315
316 let (action, rx) = {
317 let mut entry = self
318 .inner
319 .requests
320 .entry_async(tid)
321 .await
322 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
323
324 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
325
326 let prev_count = typed_requests.keys.len();
327 let mut keys_set = HashSet::new();
328 let mut use_cache_values = HashMap::new();
329
330 if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
331 keys_set = keys.into_iter().collect();
332 } else {
333 for key in keys {
334 if let Some(value) = typed_requests.cache_storage.get(&key) {
335 use_cache_values.insert(key.clone(), value);
337 } else {
338 keys_set.insert(key);
339 }
340 }
341 }
342
343 if !use_cache_values.is_empty() && keys_set.is_empty() {
344 return Ok(use_cache_values);
345 } else if use_cache_values.is_empty() && keys_set.is_empty() {
346 return Ok(Default::default());
347 }
348
349 typed_requests.keys.extend(keys_set.clone());
350 let (tx, rx) = oneshot::channel();
351 typed_requests.pending.push((
352 keys_set,
353 ResSender {
354 use_cache_values,
355 tx,
356 },
357 ));
358
359 if typed_requests.keys.len() >= self.max_batch_size {
360 (Action::ImmediateLoad(typed_requests.take()), rx)
361 } else {
362 (
363 if !typed_requests.keys.is_empty() && prev_count == 0 {
364 Action::StartFetch
365 } else {
366 Action::Delay
367 },
368 rx,
369 )
370 }
371 };
372
373 match action {
374 Action::ImmediateLoad(keys) => {
375 let inner = self.inner.clone();
376 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
377 let task = async move { inner.do_load(disable_cache, keys).await };
378 #[cfg(feature = "tracing")]
379 let task = task
380 .instrument(info_span!("immediate_load"))
381 .in_current_span();
382
383 let _ = self.spawner.spawn(task);
384 }
385 Action::StartFetch => {
386 let inner = self.inner.clone();
387 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
388 let delay = self.delay;
389 let timer = self.timer.clone();
390
391 let task = async move {
392 timer.delay(delay).await;
393
394 let keys = {
395 let mut entry = inner.requests.get_async(&tid).await.unwrap();
396 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
397 typed_requests.take()
398 };
399
400 if !keys.0.is_empty() {
401 inner.do_load(disable_cache, keys).await
402 }
403 };
404 #[cfg(feature = "tracing")]
405 let task = task.instrument(info_span!("start_fetch")).in_current_span();
406 let _ = self.spawner.spawn(task);
407 }
408 Action::Delay => {}
409 }
410
411 rx.await.unwrap()
412 }
413
414 #[cfg_attr(feature = "tracing", instrument(skip_all))]
419 pub async fn feed_many<K, I>(&self, values: I)
420 where
421 K: Send + Sync + Hash + Eq + Clone + 'static,
422 I: IntoIterator<Item = (K, T::Value)>,
423 T: Loader<K>,
424 {
425 let tid = TypeId::of::<K>();
426 let mut entry = self
427 .inner
428 .requests
429 .entry_async(tid)
430 .await
431 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
432
433 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
434
435 for (key, value) in values {
436 typed_requests
437 .cache_storage
438 .insert(Cow::Owned(key), Cow::Owned(value));
439 }
440 }
441
442 #[cfg_attr(feature = "tracing", instrument(skip_all))]
447 pub async fn feed_one<K>(&self, key: K, value: T::Value)
448 where
449 K: Send + Sync + Hash + Eq + Clone + 'static,
450 T: Loader<K>,
451 {
452 self.feed_many(std::iter::once((key, value))).await;
453 }
454
455 #[cfg_attr(feature = "tracing", instrument(skip_all))]
460 pub fn clear<K>(&self)
461 where
462 K: Send + Sync + Hash + Eq + Clone + 'static,
463 T: Loader<K>,
464 {
465 let tid = TypeId::of::<K>();
466 let mut entry = self
467 .inner
468 .requests
469 .entry_sync(tid)
470 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)));
471
472 let typed_requests = entry.downcast_mut::<Requests<K, T>>().unwrap();
473 typed_requests.cache_storage.clear();
474 }
475
476 pub async fn get_cached_values<K>(&self) -> HashMap<K, T::Value>
478 where
479 K: Send + Sync + Hash + Eq + Clone + 'static,
480 T: Loader<K>,
481 {
482 let tid = TypeId::of::<K>();
483 match self.inner.requests.get_async(&tid).await {
484 None => HashMap::new(),
485 Some(requests) => {
486 let typed_requests = requests.get().downcast_ref::<Requests<K, T>>().unwrap();
487 typed_requests
488 .cache_storage
489 .iter()
490 .map(|(k, v)| (k.clone(), v.clone()))
491 .collect()
492 }
493 }
494 }
495}
496
#[cfg(test)]
mod tests {
    use rustc_hash::FxBuildHasher;

    use super::*;
    use crate::runtime::{TokioSpawner, TokioTimer};

    /// Identity loader: maps every key to itself and asserts that no batch holds
    /// more than 10 keys (the `max_batch_size` several tests configure).
    struct MyLoader;

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i32> for MyLoader {
        type Value = i32;
        type Error = ();

        async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().copied().map(|k| (k, k)).collect())
        }
    }

    // Second key type on the same loader, to check per-type queues are independent.
    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i64> for MyLoader {
        type Value = i64;
        type Error = ();

        async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().copied().map(|k| (k, k)).collect())
        }
    }

    /// 100 concurrent single-key loads per key type with `max_batch_size(10)`:
    /// every caller gets its value and no batch trips the loader's size assertion.
    #[tokio::test]
    async fn test_dataloader() {
        let loader = Arc::new(
            DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default())
                .max_batch_size(10),
        );
        assert_eq!(
            futures_util::future::try_join_all((0..100i32).map({
                let loader = loader.clone();
                move |n| {
                    let loader = loader.clone();
                    async move { loader.load_one(n).await }
                }
            }))
            .await
            .unwrap(),
            (0..100).map(Option::Some).collect::<Vec<_>>()
        );

        assert_eq!(
            futures_util::future::try_join_all((0..100i64).map({
                let loader = loader.clone();
                move |n| {
                    let loader = loader.clone();
                    async move { loader.load_one(n).await }
                }
            }))
            .await
            .unwrap(),
            (0..100).map(Option::Some).collect::<Vec<_>>()
        );
    }

    /// The same key requested by several concurrent callers is loaded once per
    /// batch, yet every caller still receives the value.
    #[tokio::test]
    async fn test_duplicate_keys() {
        let loader = Arc::new(
            DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default())
                .max_batch_size(10),
        );
        assert_eq!(
            futures_util::future::try_join_all([1, 3, 5, 1, 7, 8, 3, 7].iter().copied().map({
                let loader = loader.clone();
                move |n| {
                    let loader = loader.clone();
                    async move { loader.load_one(n).await }
                }
            }))
            .await
            .unwrap(),
            [1, 3, 5, 1, 7, 8, 3, 7]
                .iter()
                .copied()
                .map(Option::Some)
                .collect::<Vec<_>>()
        );
    }

    /// Requesting no keys returns an empty map.
    #[tokio::test]
    async fn test_dataloader_load_empty() {
        let loader = DataLoader::new(MyLoader, TokioSpawner::current(), TokioTimer::default());
        assert!(loader.load_many::<i32, _>(vec![]).await.unwrap().is_empty());
    }

    /// Fed values shadow what the loader would return until the cache is cleared.
    #[tokio::test]
    async fn test_dataloader_with_cache() {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            HashMapCache::default(),
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // All keys come from the cache.
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
        );

        // Key 1 comes from the cache; 5 and 6 come from the loader.
        assert_eq!(
            loader.load_many(vec![1, 5, 6]).await.unwrap(),
            vec![(1, 10), (5, 5), (6, 6)].into_iter().collect()
        );

        // All keys come from the loader.
        assert_eq!(
            loader.load_many(vec![8, 9, 10]).await.unwrap(),
            vec![(8, 8), (9, 9), (10, 10)].into_iter().collect()
        );

        // After clearing, the fed values are gone and the loader answers again.
        loader.clear::<i32>();
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
        );
    }

    /// Same scenario as `test_dataloader_with_cache`, using an Fx-hashed cache.
    #[tokio::test]
    async fn test_dataloader_with_cache_hashmap_fx() {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            HashMapCache::<FxBuildHasher>::new(),
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // All keys come from the cache.
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
        );

        // Key 1 comes from the cache; 5 and 6 come from the loader.
        assert_eq!(
            loader.load_many(vec![1, 5, 6]).await.unwrap(),
            vec![(1, 10), (5, 5), (6, 6)].into_iter().collect()
        );

        // All keys come from the loader.
        assert_eq!(
            loader.load_many(vec![8, 9, 10]).await.unwrap(),
            vec![(8, 8), (9, 9), (10, 10)].into_iter().collect()
        );

        // After clearing, the fed values are gone and the loader answers again.
        loader.clear::<i32>();
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
        );
    }

    /// The global switch bypasses the cache while off and restores it when back on.
    #[tokio::test]
    async fn test_dataloader_disable_all_cache() {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            HashMapCache::default(),
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // Caching disabled: all keys come from the loader.
        loader.enable_all_cache(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
        );

        // Caching re-enabled: the fed values are still there and win.
        loader.enable_all_cache(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
        );
    }

    /// The per-type switch behaves like the global one for `i32` keys.
    #[tokio::test]
    async fn test_dataloader_disable_cache() {
        let loader = DataLoader::with_cache(
            MyLoader,
            TokioSpawner::current(),
            TokioTimer::default(),
            HashMapCache::default(),
        );
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // Caching disabled for i32: all keys come from the loader.
        loader.enable_cache::<i32>(false).await;
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
        );

        // Caching re-enabled for i32: the fed values win again.
        loader.enable_cache::<i32>(true).await;
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
        );
    }

    /// Aborting a caller while its batch is still pending must not prevent a later
    /// `load_many` from completing.
    #[tokio::test]
    async fn test_dataloader_dead_lock() {
        struct MyDelayLoader;

        #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
        impl Loader<i32> for MyDelayLoader {
            type Value = i32;
            type Error = ();

            async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(keys.iter().copied().map(|k| (k, k)).collect())
            }
        }

        let loader = Arc::new(
            DataLoader::with_cache(
                MyDelayLoader,
                TokioSpawner::current(),
                TokioTimer::default(),
                NoCache,
            )
            .delay(Duration::from_secs(1)),
        );
        let handle = tokio::spawn({
            let loader = loader.clone();
            async move {
                loader.load_many(vec![1, 2, 3]).await.unwrap();
            }
        });

        // Abort the first caller halfway through the 1 s batching delay.
        tokio::time::sleep(Duration::from_millis(500)).await;
        handle.abort();
        loader.load_many(vec![4, 5, 6]).await.unwrap();
    }
}