1mod cache;
60
61#[cfg(not(feature = "boxed-trait"))]
62use std::future::Future;
63use std::{
64 any::{Any, TypeId},
65 borrow::Cow,
66 collections::{HashMap, HashSet},
67 hash::Hash,
68 sync::{
69 Arc, Mutex,
70 atomic::{AtomicBool, Ordering},
71 },
72 time::Duration,
73};
74
75pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
76use fnv::FnvHashMap;
77use futures_channel::oneshot;
78use futures_util::future::BoxFuture;
79#[cfg(feature = "tracing")]
80use tracing::{Instrument, info_span, instrument};
81
82use crate::util::Delay;
83
84#[allow(clippy::type_complexity)]
85struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
86 use_cache_values: HashMap<K, T::Value>,
87 tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
88}
89
90struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
91 keys: HashSet<K>,
92 pending: Vec<(HashSet<K>, ResSender<K, T>)>,
93 cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
94 disable_cache: bool,
95}
96
97type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
98
99impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
100 fn new<C: CacheFactory>(cache_factory: &C) -> Self {
101 Self {
102 keys: Default::default(),
103 pending: Vec::new(),
104 cache_storage: cache_factory.create::<K, T::Value>(),
105 disable_cache: false,
106 }
107 }
108
109 fn take(&mut self) -> KeysAndSender<K, T> {
110 (
111 std::mem::take(&mut self.keys),
112 std::mem::take(&mut self.pending),
113 )
114 }
115}
116
/// A batch loader: resolves many keys of type `K` in a single call.
#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
pub trait Loader<K>: Send + Sync + 'static
where
    K: Send + Sync + Hash + Eq + Clone + 'static,
{
    /// The value produced for each key.
    type Value: Send + Sync + Clone + 'static;

    /// The error returned when a batch fails; a clone is handed to every caller
    /// waiting on that batch.
    type Error: Send + Clone + 'static;

    /// Loads values for `keys`. Keys absent from the returned map are reported to
    /// callers as missing.
    #[cfg(feature = "boxed-trait")]
    async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;

    /// Loads values for `keys`. Keys absent from the returned map are reported to
    /// callers as missing.
    #[cfg(not(feature = "boxed-trait"))]
    fn load(
        &self,
        keys: &[K],
    ) -> impl Future<Output = Result<HashMap<K, Self::Value>, Self::Error>> + Send;
}
137
138struct DataLoaderInner<T> {
139 requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
140 loader: T,
141}
142
143impl<T> DataLoaderInner<T> {
144 #[cfg_attr(feature = "tracing", instrument(skip_all))]
145 async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
146 where
147 K: Send + Sync + Hash + Eq + Clone + 'static,
148 T: Loader<K>,
149 {
150 let tid = TypeId::of::<K>();
151 let keys = keys.into_iter().collect::<Vec<_>>();
152
153 match self.loader.load(&keys).await {
154 Ok(values) => {
155 let mut request = self.requests.lock().unwrap();
157 let typed_requests = request
158 .get_mut(&tid)
159 .unwrap()
160 .downcast_mut::<Requests<K, T>>()
161 .unwrap();
162 let disable_cache = typed_requests.disable_cache || disable_cache;
163 if !disable_cache {
164 for (key, value) in &values {
165 typed_requests
166 .cache_storage
167 .insert(Cow::Borrowed(key), Cow::Borrowed(value));
168 }
169 }
170
171 for (keys, sender) in senders {
173 let mut res = HashMap::new();
174 res.extend(sender.use_cache_values);
175 for key in &keys {
176 res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
177 }
178 sender.tx.send(Ok(res)).ok();
179 }
180 }
181 Err(err) => {
182 for (_, sender) in senders {
183 sender.tx.send(Err(err.clone())).ok();
184 }
185 }
186 }
187 }
188}
189
190pub struct DataLoader<T, C = NoCache> {
194 inner: Arc<DataLoaderInner<T>>,
195 cache_factory: C,
196 delay: Duration,
197 max_batch_size: usize,
198 disable_cache: AtomicBool,
199 spawner: Box<dyn Fn(BoxFuture<'static, ()>) + Send + Sync>,
200}
201
202impl<T> DataLoader<T, NoCache> {
203 pub fn new<S, R>(loader: T, spawner: S) -> Self
205 where
206 S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
207 {
208 Self {
209 inner: Arc::new(DataLoaderInner {
210 requests: Mutex::new(Default::default()),
211 loader,
212 }),
213 cache_factory: NoCache,
214 delay: Duration::from_millis(1),
215 max_batch_size: 1000,
216 disable_cache: false.into(),
217 spawner: Box::new(move |fut| {
218 spawner(fut);
219 }),
220 }
221 }
222}
223
224impl<T, C: CacheFactory> DataLoader<T, C> {
225 pub fn with_cache<S, R>(loader: T, spawner: S, cache_factory: C) -> Self
227 where
228 S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
229 {
230 Self {
231 inner: Arc::new(DataLoaderInner {
232 requests: Mutex::new(Default::default()),
233 loader,
234 }),
235 cache_factory,
236 delay: Duration::from_millis(1),
237 max_batch_size: 1000,
238 disable_cache: false.into(),
239 spawner: Box::new(move |fut| {
240 spawner(fut);
241 }),
242 }
243 }
244
245 #[must_use]
247 pub fn delay(self, delay: Duration) -> Self {
248 Self { delay, ..self }
249 }
250
251 #[must_use]
257 pub fn max_batch_size(self, max_batch_size: usize) -> Self {
258 Self {
259 max_batch_size,
260 ..self
261 }
262 }
263
264 #[inline]
266 pub fn loader(&self) -> &T {
267 &self.inner.loader
268 }
269
270 pub fn enable_all_cache(&self, enable: bool) {
272 self.disable_cache.store(!enable, Ordering::SeqCst);
273 }
274
275 pub fn enable_cache<K>(&self, enable: bool)
277 where
278 K: Send + Sync + Hash + Eq + Clone + 'static,
279 T: Loader<K>,
280 {
281 let tid = TypeId::of::<K>();
282 let mut requests = self.inner.requests.lock().unwrap();
283 let typed_requests = requests
284 .get_mut(&tid)
285 .unwrap()
286 .downcast_mut::<Requests<K, T>>()
287 .unwrap();
288 typed_requests.disable_cache = !enable;
289 }
290
291 #[cfg_attr(feature = "tracing", instrument(skip_all))]
293 pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
294 where
295 K: Send + Sync + Hash + Eq + Clone + 'static,
296 T: Loader<K>,
297 {
298 let mut values = self.load_many(std::iter::once(key.clone())).await?;
299 Ok(values.remove(&key))
300 }
301
302 #[cfg_attr(feature = "tracing", instrument(skip_all))]
304 pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
305 where
306 K: Send + Sync + Hash + Eq + Clone + 'static,
307 I: IntoIterator<Item = K>,
308 T: Loader<K>,
309 {
310 enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
311 ImmediateLoad(KeysAndSender<K, T>),
312 StartFetch,
313 Delay,
314 }
315
316 let tid = TypeId::of::<K>();
317
318 let (action, rx) = {
319 let mut requests = self.inner.requests.lock().unwrap();
320 let typed_requests = requests
321 .entry(tid)
322 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
323 .downcast_mut::<Requests<K, T>>()
324 .unwrap();
325 let prev_count = typed_requests.keys.len();
326 let mut keys_set = HashSet::new();
327 let mut use_cache_values = HashMap::new();
328
329 if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
330 keys_set = keys.into_iter().collect();
331 } else {
332 for key in keys {
333 if let Some(value) = typed_requests.cache_storage.get(&key) {
334 use_cache_values.insert(key.clone(), value.clone());
336 } else {
337 keys_set.insert(key);
338 }
339 }
340 }
341
342 if !use_cache_values.is_empty() && keys_set.is_empty() {
343 return Ok(use_cache_values);
344 } else if use_cache_values.is_empty() && keys_set.is_empty() {
345 return Ok(Default::default());
346 }
347
348 typed_requests.keys.extend(keys_set.clone());
349 let (tx, rx) = oneshot::channel();
350 typed_requests.pending.push((
351 keys_set,
352 ResSender {
353 use_cache_values,
354 tx,
355 },
356 ));
357
358 if typed_requests.keys.len() >= self.max_batch_size {
359 (Action::ImmediateLoad(typed_requests.take()), rx)
360 } else {
361 (
362 if !typed_requests.keys.is_empty() && prev_count == 0 {
363 Action::StartFetch
364 } else {
365 Action::Delay
366 },
367 rx,
368 )
369 }
370 };
371
372 match action {
373 Action::ImmediateLoad(keys) => {
374 let inner = self.inner.clone();
375 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
376 let task = async move { inner.do_load(disable_cache, keys).await };
377 #[cfg(feature = "tracing")]
378 let task = task
379 .instrument(info_span!("immediate_load"))
380 .in_current_span();
381
382 (self.spawner)(Box::pin(task));
383 }
384 Action::StartFetch => {
385 let inner = self.inner.clone();
386 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
387 let delay = self.delay;
388
389 let task = async move {
390 Delay::new(delay).await;
391
392 let keys = {
393 let mut request = inner.requests.lock().unwrap();
394 let typed_requests = request
395 .get_mut(&tid)
396 .unwrap()
397 .downcast_mut::<Requests<K, T>>()
398 .unwrap();
399 typed_requests.take()
400 };
401
402 if !keys.0.is_empty() {
403 inner.do_load(disable_cache, keys).await
404 }
405 };
406 #[cfg(feature = "tracing")]
407 let task = task.instrument(info_span!("start_fetch")).in_current_span();
408 (self.spawner)(Box::pin(task))
409 }
410 Action::Delay => {}
411 }
412
413 rx.await.unwrap()
414 }
415
416 #[cfg_attr(feature = "tracing", instrument(skip_all))]
421 pub async fn feed_many<K, I>(&self, values: I)
422 where
423 K: Send + Sync + Hash + Eq + Clone + 'static,
424 I: IntoIterator<Item = (K, T::Value)>,
425 T: Loader<K>,
426 {
427 let tid = TypeId::of::<K>();
428 let mut requests = self.inner.requests.lock().unwrap();
429 let typed_requests = requests
430 .entry(tid)
431 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
432 .downcast_mut::<Requests<K, T>>()
433 .unwrap();
434 for (key, value) in values {
435 typed_requests
436 .cache_storage
437 .insert(Cow::Owned(key), Cow::Owned(value));
438 }
439 }
440
441 #[cfg_attr(feature = "tracing", instrument(skip_all))]
446 pub async fn feed_one<K>(&self, key: K, value: T::Value)
447 where
448 K: Send + Sync + Hash + Eq + Clone + 'static,
449 T: Loader<K>,
450 {
451 self.feed_many(std::iter::once((key, value))).await;
452 }
453
454 #[cfg_attr(feature = "tracing", instrument(skip_all))]
459 pub fn clear<K>(&self)
460 where
461 K: Send + Sync + Hash + Eq + Clone + 'static,
462 T: Loader<K>,
463 {
464 let tid = TypeId::of::<K>();
465 let mut requests = self.inner.requests.lock().unwrap();
466 let typed_requests = requests
467 .entry(tid)
468 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
469 .downcast_mut::<Requests<K, T>>()
470 .unwrap();
471 typed_requests.cache_storage.clear();
472 }
473
474 pub fn get_cached_values<K>(&self) -> HashMap<K, T::Value>
476 where
477 K: Send + Sync + Hash + Eq + Clone + 'static,
478 T: Loader<K>,
479 {
480 let tid = TypeId::of::<K>();
481 let requests = self.inner.requests.lock().unwrap();
482 match requests.get(&tid) {
483 None => HashMap::new(),
484 Some(requests) => {
485 let typed_requests = requests.downcast_ref::<Requests<K, T>>().unwrap();
486 typed_requests
487 .cache_storage
488 .iter()
489 .map(|(k, v)| (k.clone(), v.clone()))
490 .collect()
491 }
492 }
493 }
494}
495
#[cfg(test)]
mod tests {
    use fnv::FnvBuildHasher;

    use super::*;

    /// Echo loader: maps every key to itself and asserts no batch exceeds 10 keys.
    struct MyLoader;

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i32> for MyLoader {
        type Value = i32;
        type Error = ();

        async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&k| (k, k)).collect())
        }
    }

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i64> for MyLoader {
        type Value = i64;
        type Error = ();

        async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&k| (k, k)).collect())
        }
    }

    /// Issues one `load_one` per key, all polled concurrently, and returns the results
    /// in key order.
    async fn load_each<K>(
        loader: &Arc<DataLoader<MyLoader>>,
        keys: Vec<K>,
    ) -> Result<Vec<Option<<MyLoader as Loader<K>>::Value>>, <MyLoader as Loader<K>>::Error>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        MyLoader: Loader<K>,
    {
        futures_util::future::try_join_all(keys.into_iter().map(|key| {
            let loader = Arc::clone(loader);
            async move { loader.load_one(key).await }
        }))
        .await
    }

    /// Builds the expected `i32 -> i32` result map.
    fn pairs(entries: &[(i32, i32)]) -> HashMap<i32, i32> {
        entries.iter().copied().collect()
    }

    /// Shared body of the cache round-trip tests: fed values are served from the
    /// cache until `clear` is called.
    async fn check_cache_round_trip<C: CacheFactory>(loader: DataLoader<MyLoader, C>) {
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // Fully cached.
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            pairs(&[(1, 10), (2, 20), (3, 30)])
        );

        // Partially cached.
        assert_eq!(
            loader.load_many(vec![1, 5, 6]).await.unwrap(),
            pairs(&[(1, 10), (5, 5), (6, 6)])
        );

        // Not cached at all.
        assert_eq!(
            loader.load_many(vec![8, 9, 10]).await.unwrap(),
            pairs(&[(8, 8), (9, 9), (10, 10)])
        );

        // After clearing, everything is loaded again.
        loader.clear::<i32>();
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            pairs(&[(1, 1), (2, 2), (3, 3)])
        );
    }

    #[tokio::test]
    async fn test_dataloader() {
        let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));

        let expected_i32: Vec<Option<i32>> = (0..100).map(Some).collect();
        assert_eq!(
            load_each(&loader, (0..100i32).collect()).await,
            Ok(expected_i32)
        );

        let expected_i64: Vec<Option<i64>> = (0..100).map(Some).collect();
        assert_eq!(
            load_each(&loader, (0..100i64).collect()).await,
            Ok(expected_i64)
        );
    }

    #[tokio::test]
    async fn test_duplicate_keys() {
        let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));
        let keys = vec![1, 3, 5, 1, 7, 8, 3, 7];
        let expected: Vec<Option<i32>> = keys.iter().copied().map(Some).collect();
        assert_eq!(load_each(&loader, keys).await, Ok(expected));
    }

    #[tokio::test]
    async fn test_dataloader_load_empty() {
        let loader = DataLoader::new(MyLoader, tokio::spawn);
        assert!(loader.load_many::<i32, _>(vec![]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_dataloader_with_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        check_cache_round_trip(loader).await;
    }

    #[tokio::test]
    async fn test_dataloader_with_cache_hashmap_fnv() {
        let loader = DataLoader::with_cache(
            MyLoader,
            tokio::spawn,
            HashMapCache::<FnvBuildHasher>::new(),
        );
        check_cache_round_trip(loader).await;
    }

    #[tokio::test]
    async fn test_dataloader_disable_all_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // With the global switch off, the loader is consulted instead of the cache.
        loader.enable_all_cache(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            pairs(&[(1, 1), (2, 2), (3, 3)])
        );

        // Switching it back on serves the fed values again.
        loader.enable_all_cache(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            pairs(&[(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_disable_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // With the `i32` cache off, the loader is consulted instead of the cache.
        loader.enable_cache::<i32>(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            pairs(&[(1, 1), (2, 2), (3, 3)])
        );

        // Switching it back on serves the fed values again.
        loader.enable_cache::<i32>(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            pairs(&[(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_dead_lock() {
        /// Loader that sleeps for one second before echoing its keys.
        struct MyDelayLoader;

        #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
        impl Loader<i32> for MyDelayLoader {
            type Value = i32;
            type Error = ();

            async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(keys.iter().map(|&k| (k, k)).collect())
            }
        }

        let loader = Arc::new(
            DataLoader::with_cache(MyDelayLoader, tokio::spawn, NoCache)
                .delay(Duration::from_secs(1)),
        );

        // Start a load, then abort its caller while the batch is still pending.
        let first = Arc::clone(&loader);
        let handle = tokio::spawn(async move {
            first.load_many(vec![1, 2, 3]).await.unwrap();
        });
        tokio::time::sleep(Duration::from_millis(500)).await;
        handle.abort();

        // A later load must still complete instead of hanging.
        loader.load_many(vec![4, 5, 6]).await.unwrap();
    }
}