// async_graphql_dataloader/batcher.rs

use std::{collections::HashMap, sync::Arc, time::Duration}; // NOTE(review): `Duration` appears unused in this file — confirm before removing.
use tokio::sync::{Mutex, oneshot, RwLock};
use crate::loader::BatchLoad;
use crate::error::DataLoaderError;

/// One waiter's reply channel: receives either the loaded value or an error.
/// (Despite the name, this is the *sender* half of the oneshot pair.)
type BatchResult<V> = oneshot::Sender<Result<V, DataLoaderError>>;
/// Keys awaiting dispatch, each mapped to the senders of its waiters.
type PendingBatch<K, V> = HashMap<K, Vec<BatchResult<V>>>;
/// Coalesces concurrent `schedule` calls for the same key into a single
/// `loader.load` invocation, fanning the result back out to every waiter.
pub struct Batcher<L: BatchLoad> {
    /// User-supplied loader that resolves a slice of keys in one call.
    loader: Arc<L>,
    /// Waiters registered per key, drained when a batch is dispatched.
    pending: Mutex<PendingBatch<L::Key, L::Value>>,
    /// Shared counters; handed out via `metrics()`.
    metrics: Arc<Metrics>,
}
/// Shared counters describing batcher activity.
///
/// NOTE(review): `Arc<RwLock<u64>>` per counter works but `AtomicU64` would
/// avoid the locks; these fields are public, so changing their type would
/// break callers — left as-is.
#[derive(Clone, Debug)]
pub struct Metrics {
    // Number of batches handed to the loader.
    pub batches_dispatched: Arc<RwLock<u64>>,
    // Total keys passed to the loader across all batches.
    pub keys_processed: Arc<RwLock<u64>>,
}
22impl Metrics {
23 pub fn new() -> Self {
24 Self {
25 batches_dispatched: Arc::new(RwLock::new(0)),
26 keys_processed: Arc::new(RwLock::new(0)),
27 }
28 }
29
30 pub async fn get_stats(&self) -> BatchStats {
31 BatchStats {
32 batches_dispatched: *self.batches_dispatched.read().await,
33 keys_processed: *self.keys_processed.read().await,
34 }
35 }
36}
37
/// Plain-value snapshot of [`Metrics`], produced by `Metrics::get_stats`.
#[derive(Debug, Clone)]
pub struct BatchStats {
    // Batches handed to the loader so far.
    pub batches_dispatched: u64,
    // Keys passed to the loader across all batches.
    pub keys_processed: u64,
}
44impl<L> Batcher<L>
45where
46 L: BatchLoad + 'static,
47 L::Key: Clone + Eq + std::hash::Hash + std::fmt::Debug,
48 L::Value: Clone,
49 L::Error: From<String> + std::fmt::Display,
50{
51 pub fn new(loader: Arc<L>) -> Self {
52 Self {
53 loader,
54 pending: Mutex::new(HashMap::new()),
55 metrics: Arc::new(Metrics::new()),
56 }
57 }
58
59 pub fn metrics(&self) -> Arc<Metrics> {
60 Arc::clone(&self.metrics)
61 }
62
63 pub async fn schedule(&self, key: L::Key) -> Result<L::Value, DataLoaderError> {
64 let (tx, rx) = oneshot::channel();
65
66 let should_dispatch = {
67 let mut pending = self.pending.lock().await;
68 let entry = pending.entry(key.clone()).or_insert_with(Vec::new);
69 entry.push(tx);
70
71 true };
74
75 if should_dispatch {
76 self.dispatch_batch_for_key(key.clone()).await;
77 }
78
79 match rx.await {
80 Ok(result) => result,
81 Err(_) => Err(DataLoaderError::ChannelClosed),
82 }
83 }
84
85 async fn dispatch_batch_for_key(&self, key: L::Key) {
86 let batch = {
87 let mut pending = self.pending.lock().await;
88 pending.remove(&key).map(|senders| {
89 vec![(key, senders)]
90 }).unwrap_or_else(Vec::new)
91 };
92
93 if !batch.is_empty() {
94 self.process_batch(batch).await;
95 }
96 }
97
98 async fn process_batch(&self, batch: Vec<(L::Key, Vec<BatchResult<L::Value>>)>) {
99 let keys: Vec<L::Key> = batch.iter().map(|(key, _)| key.clone()).collect();
100
101 if keys.is_empty() {
102 return;
103 }
104
105 *self.metrics.keys_processed.write().await += keys.len() as u64;
106 *self.metrics.batches_dispatched.write().await += 1;
107
108 let results = self.loader.load(&keys).await;
109
110 for (key, senders) in batch {
111 let result = match results.get(&key) {
112 Some(Ok(value)) => Ok(value.clone()),
113 Some(Err(err)) => Err(DataLoaderError::BatchError(format!("{}", err))),
114 None => Err(DataLoaderError::KeyNotFound),
115 };
116
117 for sender in senders {
118 let _ = sender.send(result.clone());
119 }
120 }
121 }
122}
123
124impl<L> Clone for Batcher<L>
125where
126 L: BatchLoad,
127{
128 fn clone(&self) -> Self {
129 Self {
130 loader: Arc::clone(&self.loader),
131 pending: Mutex::new(HashMap::new()),
132 metrics: Arc::clone(&self.metrics),
133 }
134 }
135}