1use std::collections::HashMap;
6use std::hash::Hash;
7use std::time::{Duration, Instant};
8
/// Tuning knobs for a `DataLoader`.
///
/// Start from [`DataLoaderConfig::new`] (equivalent to `Default::default()`)
/// and adjust fields with the chained setter methods.
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Batching window. Not read by `DataLoader` in this module; presumably
    /// consumed by whatever schedules `execute_batch` — TODO confirm.
    pub batch_window: Duration,
    /// Maximum number of keys handed to the batch function per call.
    pub max_batch_size: usize,
    /// Whether primed and batch-loaded values are kept in the cache.
    pub cache_enabled: bool,
    /// How long a cached value remains valid after it is stored.
    pub cache_ttl: Duration,
    /// Whether duplicate pending keys are collapsed before batching.
    pub dedupe: bool,
}

impl Default for DataLoaderConfig {
    /// 10 ms window, batches of at most 100 keys, caching on with a 60 s TTL,
    /// and de-duplication on.
    fn default() -> Self {
        Self {
            batch_window: Duration::from_millis(10),
            max_batch_size: 100,
            cache_enabled: true,
            cache_ttl: Duration::from_secs(60),
            dedupe: true,
        }
    }
}

impl DataLoaderConfig {
    /// Returns the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`batch_window`](Self::batch_window).
    pub fn batch_window(mut self, duration: Duration) -> Self {
        self.batch_window = duration;
        self
    }

    /// Sets [`max_batch_size`](Self::max_batch_size).
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size;
        self
    }

    /// Enables or disables caching ([`cache_enabled`](Self::cache_enabled)).
    pub fn cache(mut self, enabled: bool) -> Self {
        self.cache_enabled = enabled;
        self
    }

    /// Sets [`cache_ttl`](Self::cache_ttl).
    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
        self.cache_ttl = ttl;
        self
    }

    /// Enables or disables de-duplication of pending keys
    /// ([`dedupe`](Self::dedupe)); previously the only field without a setter.
    pub fn dedupe(mut self, enabled: bool) -> Self {
        self.dedupe = enabled;
        self
    }
}
66
/// Outcome of a batch: the values that were loaded plus any keys flagged as
/// missing.
#[derive(Debug, Clone)]
pub struct BatchResult<K, V> {
    /// Loaded values, keyed by the key they were requested under.
    pub results: HashMap<K, V>,
    /// Keys recorded as missing (set through `with_missing`).
    pub missing: Vec<K>,
}

impl<K: Eq + Hash, V> BatchResult<K, V> {
    /// Wraps `results`; the missing list starts out empty.
    pub fn new(results: HashMap<K, V>) -> Self {
        let missing = Vec::new();
        Self { results, missing }
    }

    /// A result with neither values nor missing keys.
    pub fn empty() -> Self {
        Self::new(HashMap::new())
    }

    /// Replaces the missing list with `missing`, consuming and returning `self`.
    pub fn with_missing(self, missing: Vec<K>) -> Self {
        Self { missing, ..self }
    }

    /// Looks up the loaded value for `key`, if any.
    pub fn get(&self, key: &K) -> Option<&V> {
        self.results.get(key)
    }

    /// True when `key` appears in the missing list (linear scan).
    pub fn is_missing(&self, key: &K) -> bool
    where
        K: PartialEq,
    {
        self.missing.iter().any(|candidate| candidate == key)
    }
}
112
/// A cached value together with the moment it stops being valid.
#[derive(Debug, Clone)]
struct CacheEntry<V> {
    value: V,
    /// Deadline after which the entry is stale. `None` means `now + ttl` was
    /// not representable as an `Instant` (e.g. `Duration::MAX`); such an entry
    /// never expires.
    expires_at: Option<Instant>,
}

impl<V> CacheEntry<V> {
    /// Stores `value` with a deadline `ttl` from now.
    fn new(value: V, ttl: Duration) -> Self {
        Self {
            value,
            // `Instant + Duration` panics on overflow. A TTL that large is
            // effectively "forever", so record it as never expiring.
            expires_at: Instant::now().checked_add(ttl),
        }
    }

    /// True once the current instant has reached the deadline. A zero TTL is
    /// therefore expired immediately.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() >= deadline,
            None => false,
        }
    }
}
132
133#[derive(Debug)]
138pub struct DataLoader<K, V>
139where
140 K: Eq + Hash + Clone,
141 V: Clone,
142{
143 config: DataLoaderConfig,
145 cache: std::sync::Mutex<HashMap<K, CacheEntry<V>>>,
147 pending: std::sync::Mutex<Vec<K>>,
149 stats: std::sync::Mutex<DataLoaderStats>,
151}
152
/// Counters describing how a `DataLoader` has been used.
#[derive(Clone, Debug, Default)]
pub struct DataLoaderStats {
    /// Number of `load` calls.
    pub total_loads: u64,
    /// `load` calls answered from the cache.
    pub cache_hits: u64,
    /// `load` calls that found no valid cached value (counted only when
    /// caching is enabled).
    pub cache_misses: u64,
    /// Number of batch-function invocations.
    pub batch_loads: u64,
    /// Running mean of keys per batch-function invocation.
    pub avg_batch_size: f64,
}

impl DataLoaderStats {
    /// Fraction of `total_loads` that were cache hits; `0.0` when nothing has
    /// been loaded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_loads {
            0 => 0.0,
            loads => self.cache_hits as f64 / loads as f64,
        }
    }
}
178
179impl<K, V> DataLoader<K, V>
180where
181 K: Eq + Hash + Clone + Send + Sync,
182 V: Clone + Send + Sync,
183{
184 pub fn new(config: DataLoaderConfig) -> Self {
186 Self {
187 config,
188 cache: std::sync::Mutex::new(HashMap::new()),
189 pending: std::sync::Mutex::new(Vec::new()),
190 stats: std::sync::Mutex::new(DataLoaderStats::default()),
191 }
192 }
193
194 pub fn load(&self, key: K) -> Option<V> {
196 self.update_stats(|s| s.total_loads += 1);
197
198 if self.config.cache_enabled {
200 if let Some(value) = self.get_cached(&key) {
201 self.update_stats(|s| s.cache_hits += 1);
202 return Some(value);
203 }
204 self.update_stats(|s| s.cache_misses += 1);
205 }
206
207 self.pending.lock().unwrap().push(key);
209
210 None
211 }
212
213 pub fn load_many(&self, keys: Vec<K>) -> HashMap<K, Option<V>> {
215 let mut results = HashMap::new();
216
217 for key in keys {
218 results.insert(key.clone(), self.load(key));
219 }
220
221 results
222 }
223
224 pub fn prime(&self, key: K, value: V) {
226 if self.config.cache_enabled {
227 let entry = CacheEntry::new(value, self.config.cache_ttl);
228 self.cache.lock().unwrap().insert(key, entry);
229 }
230 }
231
232 pub fn clear(&self) {
234 self.cache.lock().unwrap().clear();
235 }
236
237 pub fn clear_key(&self, key: &K) {
239 self.cache.lock().unwrap().remove(key);
240 }
241
242 pub fn execute_batch<F>(&self, mut loader: F) -> BatchResult<K, V>
244 where
245 F: FnMut(Vec<K>) -> HashMap<K, V>,
246 {
247 let keys: Vec<K> = {
249 let mut pending = self.pending.lock().unwrap();
250 std::mem::take(&mut *pending)
251 };
252
253 if keys.is_empty() {
254 return BatchResult::empty();
255 }
256
257 let unique_keys: Vec<K> = if self.config.dedupe {
259 let mut seen = std::collections::HashSet::new();
260 keys.into_iter()
261 .filter(|k| seen.insert(k.clone()))
262 .collect()
263 } else {
264 keys
265 };
266
267 let _batch_count = (unique_keys.len() + self.config.max_batch_size - 1)
269 / self.config.max_batch_size;
270
271 let mut all_results = HashMap::new();
272
273 for batch in unique_keys.chunks(self.config.max_batch_size) {
274 let batch_keys: Vec<K> = batch.to_vec();
275 let batch_size = batch_keys.len();
276
277 let results = loader(batch_keys);
279
280 self.update_stats(|s| {
281 s.batch_loads += 1;
282 let total_batches = s.batch_loads as f64;
283 s.avg_batch_size = ((s.avg_batch_size * (total_batches - 1.0)) + batch_size as f64)
284 / total_batches;
285 });
286
287 if self.config.cache_enabled {
289 let mut cache = self.cache.lock().unwrap();
290 for (k, v) in &results {
291 cache.insert(k.clone(), CacheEntry::new(v.clone(), self.config.cache_ttl));
292 }
293 }
294
295 all_results.extend(results);
296 }
297
298 BatchResult::new(all_results)
299 }
300
301 fn get_cached(&self, key: &K) -> Option<V> {
303 let mut cache = self.cache.lock().unwrap();
304
305 if let Some(entry) = cache.get(key) {
306 if !entry.is_expired() {
307 return Some(entry.value.clone());
308 } else {
309 cache.remove(key);
310 }
311 }
312
313 None
314 }
315
316 fn update_stats<F>(&self, f: F)
318 where
319 F: FnOnce(&mut DataLoaderStats),
320 {
321 let mut stats = self.stats.lock().unwrap();
322 f(&mut stats);
323 }
324
325 pub fn stats(&self) -> DataLoaderStats {
327 self.stats.lock().unwrap().clone()
328 }
329
330 pub fn config(&self) -> &DataLoaderConfig {
332 &self.config
333 }
334
335 pub fn clean_expired(&self) {
337 let mut cache = self.cache.lock().unwrap();
338 cache.retain(|_, entry| !entry.is_expired());
339 }
340}
341
342impl<K, V> Clone for DataLoader<K, V>
343where
344 K: Eq + Hash + Clone,
345 V: Clone,
346{
347 fn clone(&self) -> Self {
348 Self {
349 config: self.config.clone(),
350 cache: std::sync::Mutex::new(self.cache.lock().unwrap().clone()),
351 pending: std::sync::Mutex::new(self.pending.lock().unwrap().clone()),
352 stats: std::sync::Mutex::new(self.stats.lock().unwrap().clone()),
353 }
354 }
355}
356
357#[derive(Debug)]
359pub struct DataLoaderFactory {
360 default_config: DataLoaderConfig,
362}
363
364impl DataLoaderFactory {
365 pub fn new(config: DataLoaderConfig) -> Self {
367 Self {
368 default_config: config,
369 }
370 }
371
372 pub fn create<K, V>(&self) -> DataLoader<K, V>
374 where
375 K: Eq + Hash + Clone + Send + Sync,
376 V: Clone + Send + Sync,
377 {
378 DataLoader::new(self.default_config.clone())
379 }
380
381 pub fn create_with_config<K, V>(&self, config: DataLoaderConfig) -> DataLoader<K, V>
383 where
384 K: Eq + Hash + Clone + Send + Sync,
385 V: Clone + Send + Sync,
386 {
387 DataLoader::new(config)
388 }
389}
390
391impl Default for DataLoaderFactory {
392 fn default() -> Self {
393 Self::new(DataLoaderConfig::default())
394 }
395}
396
397pub type IdLoader<V> = DataLoader<String, V>;
399
#[cfg(test)]
mod tests {
    //! Unit tests for the data loader. Besides the original cases, these also
    //! cover TTL expiry, `clean_expired`, the non-dedupe path, clone
    //! independence, empty batches, and cache population by `execute_batch`.
    use super::*;

    #[test]
    fn test_dataloader_config() {
        let config = DataLoaderConfig::new()
            .batch_window(Duration::from_millis(20))
            .max_batch_size(50)
            .cache(true)
            .cache_ttl(Duration::from_secs(120));

        assert_eq!(config.batch_window, Duration::from_millis(20));
        assert_eq!(config.max_batch_size, 50);
        assert!(config.cache_enabled);
        assert_eq!(config.cache_ttl, Duration::from_secs(120));
    }

    #[test]
    fn test_dataloader_prime_and_load() {
        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());

        loader.prime("key1".to_string(), "value1".to_string());

        let result = loader.load("key1".to_string());
        assert_eq!(result, Some("value1".to_string()));

        let stats = loader.stats();
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    fn test_dataloader_batch_execution() {
        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());

        // Misses are queued for the next batch.
        loader.load("key1".to_string());
        loader.load("key2".to_string());
        loader.load("key3".to_string());

        let result = loader.execute_batch(|keys| {
            keys.into_iter()
                .map(|k| (k.clone(), format!("value_{}", k)))
                .collect()
        });

        assert_eq!(result.results.len(), 3);
        assert_eq!(result.get(&"key1".to_string()), Some(&"value_key1".to_string()));

        let stats = loader.stats();
        assert_eq!(stats.batch_loads, 1);
    }

    #[test]
    fn test_dataloader_deduplication() {
        let loader: DataLoader<String, i32> =
            DataLoader::new(DataLoaderConfig::default().max_batch_size(100));

        loader.load("key1".to_string());
        loader.load("key1".to_string());
        loader.load("key2".to_string());
        loader.load("key1".to_string());

        let mut batch_keys_count = 0;
        let result = loader.execute_batch(|keys| {
            batch_keys_count = keys.len();
            keys.into_iter().map(|k| (k, 1)).collect()
        });

        // Four loads of two distinct keys reach the batch function as two keys.
        assert_eq!(batch_keys_count, 2);
        assert_eq!(result.results.len(), 2);
    }

    #[test]
    fn test_dataloader_without_dedupe() {
        let loader: DataLoader<String, i32> = DataLoader::new(DataLoaderConfig {
            dedupe: false,
            ..DataLoaderConfig::default()
        });

        loader.load("key1".to_string());
        loader.load("key1".to_string());

        let mut batch_keys_count = 0;
        let result = loader.execute_batch(|keys| {
            batch_keys_count = keys.len();
            keys.into_iter().map(|k| (k, 1)).collect()
        });

        assert_eq!(batch_keys_count, 2);
        assert_eq!(result.results.len(), 1);
    }

    #[test]
    fn test_dataloader_batch_splitting() {
        let loader: DataLoader<i32, i32> =
            DataLoader::new(DataLoaderConfig::default().max_batch_size(2));

        for i in 0..5 {
            loader.load(i);
        }

        let result = loader.execute_batch(|keys| keys.into_iter().map(|k| (k, k * 10)).collect());

        assert_eq!(result.results.len(), 5);

        // Five keys in chunks of two: 2 + 2 + 1.
        let stats = loader.stats();
        assert_eq!(stats.batch_loads, 3);
    }

    #[test]
    fn test_execute_batch_with_nothing_pending() {
        let loader: DataLoader<i32, i32> = DataLoader::new(DataLoaderConfig::default());

        let mut called = false;
        let result = loader.execute_batch(|keys| {
            called = true;
            keys.into_iter().map(|k| (k, k)).collect()
        });

        assert!(!called);
        assert!(result.results.is_empty());
        assert_eq!(loader.stats().batch_loads, 0);
    }

    #[test]
    fn test_dataloader_batch_populates_cache() {
        let loader: DataLoader<i32, i32> = DataLoader::new(DataLoaderConfig::default());

        assert_eq!(loader.load(7), None);
        loader.execute_batch(|keys| keys.into_iter().map(|k| (k, k * 2)).collect());

        assert_eq!(loader.load(7), Some(14));
    }

    #[test]
    fn test_dataloader_clear() {
        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());

        loader.prime("key1".to_string(), "value1".to_string());
        loader.prime("key2".to_string(), "value2".to_string());

        assert!(loader.load("key1".to_string()).is_some());

        loader.clear();

        assert!(loader.load("key1".to_string()).is_none());
    }

    #[test]
    fn test_dataloader_clear_key() {
        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());

        loader.prime("key1".to_string(), "value1".to_string());
        loader.prime("key2".to_string(), "value2".to_string());

        loader.clear_key(&"key1".to_string());

        assert!(loader.load("key1".to_string()).is_none());
        assert!(loader.load("key2".to_string()).is_some());
    }

    #[test]
    fn test_dataloader_zero_ttl_expires_immediately() {
        let loader: DataLoader<String, String> =
            DataLoader::new(DataLoaderConfig::default().cache_ttl(Duration::ZERO));

        loader.prime("key1".to_string(), "value1".to_string());

        assert!(loader.load("key1".to_string()).is_none());
        assert_eq!(loader.stats().cache_misses, 1);
    }

    #[test]
    fn test_dataloader_clean_expired() {
        let loader: DataLoader<String, String> =
            DataLoader::new(DataLoaderConfig::default().cache_ttl(Duration::ZERO));

        loader.prime("key1".to_string(), "value1".to_string());
        loader.clean_expired();

        assert!(loader.cache.lock().unwrap().is_empty());
    }

    #[test]
    fn test_dataloader_clone_is_independent() {
        let original: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());
        original.prime("shared".to_string(), "v".to_string());

        let copy = original.clone();
        original.clear();

        assert_eq!(copy.load("shared".to_string()), Some("v".to_string()));
        assert!(original.load("shared".to_string()).is_none());
    }

    #[test]
    fn test_dataloader_stats() {
        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());

        loader.prime("cached".to_string(), "value".to_string());

        // One hit, one miss.
        loader.load("cached".to_string());
        loader.load("not_cached".to_string());

        let stats = loader.stats();
        assert_eq!(stats.total_loads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[test]
    fn test_dataloader_cache_disabled() {
        let loader: DataLoader<String, String> =
            DataLoader::new(DataLoaderConfig::default().cache(false));

        loader.prime("key1".to_string(), "value1".to_string());

        // Priming is a no-op without a cache, so the load misses.
        let result = loader.load("key1".to_string());
        assert!(result.is_none());
    }

    #[test]
    fn test_batch_result() {
        let mut results = HashMap::new();
        results.insert("a".to_string(), 1);
        results.insert("b".to_string(), 2);

        let batch = BatchResult::new(results).with_missing(vec!["c".to_string()]);

        assert_eq!(batch.get(&"a".to_string()), Some(&1));
        assert_eq!(batch.get(&"c".to_string()), None);
        assert!(batch.is_missing(&"c".to_string()));
        assert!(!batch.is_missing(&"a".to_string()));
    }

    #[test]
    fn test_dataloader_factory() {
        let factory = DataLoaderFactory::new(DataLoaderConfig::default().max_batch_size(50));

        let loader: DataLoader<String, i32> = factory.create();
        assert_eq!(loader.config().max_batch_size, 50);

        let custom_loader: DataLoader<String, i32> =
            factory.create_with_config(DataLoaderConfig::default().max_batch_size(100));
        assert_eq!(custom_loader.config().max_batch_size, 100);
    }

    #[test]
    fn test_dataloader_load_many() {
        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());

        loader.prime("key1".to_string(), "value1".to_string());

        let results = loader.load_many(vec!["key1".to_string(), "key2".to_string()]);

        assert_eq!(results.get(&"key1".to_string()), Some(&Some("value1".to_string())));
        assert_eq!(results.get(&"key2".to_string()), Some(&None));
    }
}