1mod cache;
60
61#[cfg(not(feature = "boxed-trait"))]
62use std::future::Future;
63use std::{
64 any::{Any, TypeId},
65 borrow::Cow,
66 collections::{HashMap, HashSet},
67 hash::Hash,
68 sync::{
69 Arc, Mutex,
70 atomic::{AtomicBool, Ordering},
71 },
72 time::Duration,
73};
74
75pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
76use fnv::FnvHashMap;
77use futures_channel::oneshot;
78use futures_timer::Delay;
79use futures_util::future::BoxFuture;
80#[cfg(feature = "tracing")]
81use tracing::{Instrument, info_span, instrument};
82
83#[allow(clippy::type_complexity)]
84struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
85 use_cache_values: HashMap<K, T::Value>,
86 tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
87}
88
89struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
90 keys: HashSet<K>,
91 pending: Vec<(HashSet<K>, ResSender<K, T>)>,
92 cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
93 disable_cache: bool,
94}
95
96type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
97
98impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
99 fn new<C: CacheFactory>(cache_factory: &C) -> Self {
100 Self {
101 keys: Default::default(),
102 pending: Vec::new(),
103 cache_storage: cache_factory.create::<K, T::Value>(),
104 disable_cache: false,
105 }
106 }
107
108 fn take(&mut self) -> KeysAndSender<K, T> {
109 (
110 std::mem::take(&mut self.keys),
111 std::mem::take(&mut self.pending),
112 )
113 }
114}
115
/// Trait for batch loading.
///
/// Within one batch, the keys passed to `load` contain no duplicates. Keys
/// that are missing from the returned map are reported to callers as absent
/// (`None` from `load_one`, no entry from `load_many`).
#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
pub trait Loader<K>: Send + Sync + 'static
where
    K: Send + Sync + Hash + Eq + Clone + 'static,
{
    /// Type of the values produced by the loader.
    type Value: Send + Sync + Clone + 'static;

    /// Type of the error produced by the loader.
    ///
    /// When a batch fails, every caller waiting on it receives a clone.
    type Error: Send + Clone + 'static;

    /// Loads the values for `keys`.
    #[cfg(feature = "boxed-trait")]
    async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;

    /// Loads the values for `keys`.
    #[cfg(not(feature = "boxed-trait"))]
    fn load(
        &self,
        keys: &[K],
    ) -> impl Future<Output = Result<HashMap<K, Self::Value>, Self::Error>> + Send;
}
136
137struct DataLoaderInner<T> {
138 requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
139 loader: T,
140}
141
142impl<T> DataLoaderInner<T> {
143 #[cfg_attr(feature = "tracing", instrument(skip_all))]
144 async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
145 where
146 K: Send + Sync + Hash + Eq + Clone + 'static,
147 T: Loader<K>,
148 {
149 let tid = TypeId::of::<K>();
150 let keys = keys.into_iter().collect::<Vec<_>>();
151
152 match self.loader.load(&keys).await {
153 Ok(values) => {
154 let mut request = self.requests.lock().unwrap();
156 let typed_requests = request
157 .get_mut(&tid)
158 .unwrap()
159 .downcast_mut::<Requests<K, T>>()
160 .unwrap();
161 let disable_cache = typed_requests.disable_cache || disable_cache;
162 if !disable_cache {
163 for (key, value) in &values {
164 typed_requests
165 .cache_storage
166 .insert(Cow::Borrowed(key), Cow::Borrowed(value));
167 }
168 }
169
170 for (keys, sender) in senders {
172 let mut res = HashMap::new();
173 res.extend(sender.use_cache_values);
174 for key in &keys {
175 res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
176 }
177 sender.tx.send(Ok(res)).ok();
178 }
179 }
180 Err(err) => {
181 for (_, sender) in senders {
182 sender.tx.send(Err(err.clone())).ok();
183 }
184 }
185 }
186 }
187}
188
189pub struct DataLoader<T, C = NoCache> {
193 inner: Arc<DataLoaderInner<T>>,
194 cache_factory: C,
195 delay: Duration,
196 max_batch_size: usize,
197 disable_cache: AtomicBool,
198 spawner: Box<dyn Fn(BoxFuture<'static, ()>) + Send + Sync>,
199}
200
201impl<T> DataLoader<T, NoCache> {
202 pub fn new<S, R>(loader: T, spawner: S) -> Self
204 where
205 S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
206 {
207 Self {
208 inner: Arc::new(DataLoaderInner {
209 requests: Mutex::new(Default::default()),
210 loader,
211 }),
212 cache_factory: NoCache,
213 delay: Duration::from_millis(1),
214 max_batch_size: 1000,
215 disable_cache: false.into(),
216 spawner: Box::new(move |fut| {
217 spawner(fut);
218 }),
219 }
220 }
221}
222
223impl<T, C: CacheFactory> DataLoader<T, C> {
224 pub fn with_cache<S, R>(loader: T, spawner: S, cache_factory: C) -> Self
226 where
227 S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
228 {
229 Self {
230 inner: Arc::new(DataLoaderInner {
231 requests: Mutex::new(Default::default()),
232 loader,
233 }),
234 cache_factory,
235 delay: Duration::from_millis(1),
236 max_batch_size: 1000,
237 disable_cache: false.into(),
238 spawner: Box::new(move |fut| {
239 spawner(fut);
240 }),
241 }
242 }
243
244 #[must_use]
246 pub fn delay(self, delay: Duration) -> Self {
247 Self { delay, ..self }
248 }
249
250 #[must_use]
256 pub fn max_batch_size(self, max_batch_size: usize) -> Self {
257 Self {
258 max_batch_size,
259 ..self
260 }
261 }
262
263 #[inline]
265 pub fn loader(&self) -> &T {
266 &self.inner.loader
267 }
268
269 pub fn enable_all_cache(&self, enable: bool) {
271 self.disable_cache.store(!enable, Ordering::SeqCst);
272 }
273
274 pub fn enable_cache<K>(&self, enable: bool)
276 where
277 K: Send + Sync + Hash + Eq + Clone + 'static,
278 T: Loader<K>,
279 {
280 let tid = TypeId::of::<K>();
281 let mut requests = self.inner.requests.lock().unwrap();
282 let typed_requests = requests
283 .get_mut(&tid)
284 .unwrap()
285 .downcast_mut::<Requests<K, T>>()
286 .unwrap();
287 typed_requests.disable_cache = !enable;
288 }
289
290 #[cfg_attr(feature = "tracing", instrument(skip_all))]
292 pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
293 where
294 K: Send + Sync + Hash + Eq + Clone + 'static,
295 T: Loader<K>,
296 {
297 let mut values = self.load_many(std::iter::once(key.clone())).await?;
298 Ok(values.remove(&key))
299 }
300
301 #[cfg_attr(feature = "tracing", instrument(skip_all))]
303 pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
304 where
305 K: Send + Sync + Hash + Eq + Clone + 'static,
306 I: IntoIterator<Item = K>,
307 T: Loader<K>,
308 {
309 enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
310 ImmediateLoad(KeysAndSender<K, T>),
311 StartFetch,
312 Delay,
313 }
314
315 let tid = TypeId::of::<K>();
316
317 let (action, rx) = {
318 let mut requests = self.inner.requests.lock().unwrap();
319 let typed_requests = requests
320 .entry(tid)
321 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
322 .downcast_mut::<Requests<K, T>>()
323 .unwrap();
324 let prev_count = typed_requests.keys.len();
325 let mut keys_set = HashSet::new();
326 let mut use_cache_values = HashMap::new();
327
328 if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
329 keys_set = keys.into_iter().collect();
330 } else {
331 for key in keys {
332 if let Some(value) = typed_requests.cache_storage.get(&key) {
333 use_cache_values.insert(key.clone(), value.clone());
335 } else {
336 keys_set.insert(key);
337 }
338 }
339 }
340
341 if !use_cache_values.is_empty() && keys_set.is_empty() {
342 return Ok(use_cache_values);
343 } else if use_cache_values.is_empty() && keys_set.is_empty() {
344 return Ok(Default::default());
345 }
346
347 typed_requests.keys.extend(keys_set.clone());
348 let (tx, rx) = oneshot::channel();
349 typed_requests.pending.push((
350 keys_set,
351 ResSender {
352 use_cache_values,
353 tx,
354 },
355 ));
356
357 if typed_requests.keys.len() >= self.max_batch_size {
358 (Action::ImmediateLoad(typed_requests.take()), rx)
359 } else {
360 (
361 if !typed_requests.keys.is_empty() && prev_count == 0 {
362 Action::StartFetch
363 } else {
364 Action::Delay
365 },
366 rx,
367 )
368 }
369 };
370
371 match action {
372 Action::ImmediateLoad(keys) => {
373 let inner = self.inner.clone();
374 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
375 let task = async move { inner.do_load(disable_cache, keys).await };
376 #[cfg(feature = "tracing")]
377 let task = task
378 .instrument(info_span!("immediate_load"))
379 .in_current_span();
380
381 (self.spawner)(Box::pin(task));
382 }
383 Action::StartFetch => {
384 let inner = self.inner.clone();
385 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
386 let delay = self.delay;
387
388 let task = async move {
389 Delay::new(delay).await;
390
391 let keys = {
392 let mut request = inner.requests.lock().unwrap();
393 let typed_requests = request
394 .get_mut(&tid)
395 .unwrap()
396 .downcast_mut::<Requests<K, T>>()
397 .unwrap();
398 typed_requests.take()
399 };
400
401 if !keys.0.is_empty() {
402 inner.do_load(disable_cache, keys).await
403 }
404 };
405 #[cfg(feature = "tracing")]
406 let task = task.instrument(info_span!("start_fetch")).in_current_span();
407 (self.spawner)(Box::pin(task))
408 }
409 Action::Delay => {}
410 }
411
412 rx.await.unwrap()
413 }
414
415 #[cfg_attr(feature = "tracing", instrument(skip_all))]
420 pub async fn feed_many<K, I>(&self, values: I)
421 where
422 K: Send + Sync + Hash + Eq + Clone + 'static,
423 I: IntoIterator<Item = (K, T::Value)>,
424 T: Loader<K>,
425 {
426 let tid = TypeId::of::<K>();
427 let mut requests = self.inner.requests.lock().unwrap();
428 let typed_requests = requests
429 .entry(tid)
430 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
431 .downcast_mut::<Requests<K, T>>()
432 .unwrap();
433 for (key, value) in values {
434 typed_requests
435 .cache_storage
436 .insert(Cow::Owned(key), Cow::Owned(value));
437 }
438 }
439
440 #[cfg_attr(feature = "tracing", instrument(skip_all))]
445 pub async fn feed_one<K>(&self, key: K, value: T::Value)
446 where
447 K: Send + Sync + Hash + Eq + Clone + 'static,
448 T: Loader<K>,
449 {
450 self.feed_many(std::iter::once((key, value))).await;
451 }
452
453 #[cfg_attr(feature = "tracing", instrument(skip_all))]
458 pub fn clear<K>(&self)
459 where
460 K: Send + Sync + Hash + Eq + Clone + 'static,
461 T: Loader<K>,
462 {
463 let tid = TypeId::of::<K>();
464 let mut requests = self.inner.requests.lock().unwrap();
465 let typed_requests = requests
466 .entry(tid)
467 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
468 .downcast_mut::<Requests<K, T>>()
469 .unwrap();
470 typed_requests.cache_storage.clear();
471 }
472
473 pub fn get_cached_values<K>(&self) -> HashMap<K, T::Value>
475 where
476 K: Send + Sync + Hash + Eq + Clone + 'static,
477 T: Loader<K>,
478 {
479 let tid = TypeId::of::<K>();
480 let requests = self.inner.requests.lock().unwrap();
481 match requests.get(&tid) {
482 None => HashMap::new(),
483 Some(requests) => {
484 let typed_requests = requests.downcast_ref::<Requests<K, T>>().unwrap();
485 typed_requests
486 .cache_storage
487 .iter()
488 .map(|(k, v)| (k.clone(), v.clone()))
489 .collect()
490 }
491 }
492 }
493}
494
#[cfg(test)]
mod tests {
    use fnv::FnvBuildHasher;

    use super::*;

    /// Identity loader for `i32` and `i64` keys that rejects batches larger than 10.
    struct MyLoader;

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i32> for MyLoader {
        type Value = i32;
        type Error = ();

        async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&key| (key, key)).collect())
        }
    }

    #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
    impl Loader<i64> for MyLoader {
        type Value = i64;
        type Error = ();

        async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().map(|&key| (key, key)).collect())
        }
    }

    /// Builds an expected result map from `(key, value)` pairs.
    fn expected(pairs: &[(i32, i32)]) -> HashMap<i32, i32> {
        pairs.iter().copied().collect()
    }

    /// Issues one `load_one` per key concurrently, so the calls can be
    /// batched, and returns the results in key order.
    async fn load_each<K>(
        loader: &Arc<DataLoader<MyLoader>>,
        keys: impl IntoIterator<Item = K>,
    ) -> Vec<Option<K>>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        MyLoader: Loader<K, Value = K, Error = ()>,
    {
        let calls = keys.into_iter().map(|key| {
            let loader = Arc::clone(loader);
            async move { loader.load_one(key).await }
        });
        futures_util::future::try_join_all(calls).await.unwrap()
    }

    /// Cache scenario shared by the `HashMapCache` tests: fed values are
    /// served from the cache, misses go to the loader, and `clear` drops the
    /// fed values.
    async fn assert_cache_behaviour<C: CacheFactory>(loader: DataLoader<MyLoader, C>) {
        loader
            .feed_many::<i32, _>(vec![(1, 10), (2, 20), (3, 30)])
            .await;

        // Every key was fed, so the loader is not involved.
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            expected(&[(1, 10), (2, 20), (3, 30)])
        );

        // Key 1 comes from the cache; 5 and 6 come from the loader.
        assert_eq!(
            loader.load_many(vec![1, 5, 6]).await.unwrap(),
            expected(&[(1, 10), (5, 5), (6, 6)])
        );

        // Only misses.
        assert_eq!(
            loader.load_many(vec![8, 9, 10]).await.unwrap(),
            expected(&[(8, 8), (9, 9), (10, 10)])
        );

        // After clearing, the loader's identity values replace the fed ones.
        loader.clear::<i32>();
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            expected(&[(1, 1), (2, 2), (3, 3)])
        );
    }

    #[tokio::test]
    async fn test_dataloader() {
        let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));

        let want_i32: Vec<Option<i32>> = (0..100).map(Some).collect();
        assert_eq!(load_each(&loader, 0..100i32).await, want_i32);

        let want_i64: Vec<Option<i64>> = (0..100).map(Some).collect();
        assert_eq!(load_each(&loader, 0..100i64).await, want_i64);
    }

    #[tokio::test]
    async fn test_duplicate_keys() {
        let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));
        let keys: [i32; 8] = [1, 3, 5, 1, 7, 8, 3, 7];

        let want: Vec<Option<i32>> = keys.iter().copied().map(Some).collect();
        assert_eq!(load_each(&loader, keys.iter().copied()).await, want);
    }

    #[tokio::test]
    async fn test_dataloader_load_empty() {
        let loader = DataLoader::new(MyLoader, tokio::spawn);
        let loaded = loader.load_many::<i32, _>(Vec::new()).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn test_dataloader_with_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        assert_cache_behaviour(loader).await;
    }

    #[tokio::test]
    async fn test_dataloader_with_cache_hashmap_fnv() {
        let loader = DataLoader::with_cache(
            MyLoader,
            tokio::spawn,
            HashMapCache::<FnvBuildHasher>::new(),
        );
        assert_cache_behaviour(loader).await;
    }

    #[tokio::test]
    async fn test_dataloader_disable_all_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // With the global switch off, the fed values are ignored.
        loader.enable_all_cache(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            expected(&[(1, 1), (2, 2), (3, 3)])
        );

        // Turning it back on exposes the values fed earlier.
        loader.enable_all_cache(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            expected(&[(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_disable_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // With the per-type switch off, the fed values are ignored.
        loader.enable_cache::<i32>(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            expected(&[(1, 1), (2, 2), (3, 3)])
        );

        // Turning it back on exposes the values fed earlier.
        loader.enable_cache::<i32>(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            expected(&[(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_dead_lock() {
        /// Loader that sleeps for one second per batch.
        struct MyDelayLoader;

        #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
        impl Loader<i32> for MyDelayLoader {
            type Value = i32;
            type Error = ();

            async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(keys.iter().map(|&key| (key, key)).collect())
            }
        }

        let loader = Arc::new(
            DataLoader::with_cache(MyDelayLoader, tokio::spawn, NoCache)
                .delay(Duration::from_secs(1)),
        );

        // The batch is loaded by a separately spawned task, so aborting a
        // waiting caller must not stop later loads from completing.
        let first_caller = tokio::spawn({
            let loader = Arc::clone(&loader);
            async move {
                loader.load_many(vec![1, 2, 3]).await.unwrap();
            }
        });

        tokio::time::sleep(Duration::from_millis(500)).await;
        first_caller.abort();
        loader.load_many(vec![4, 5, 6]).await.unwrap();
    }
}