Skip to main content

heliosdb_proxy/graphql/
dataloader.rs

1//! DataLoader
2//!
3//! Batching and caching for N+1 query prevention.
4
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::time::{Duration, Instant};
8
9/// DataLoader configuration
10#[derive(Debug, Clone)]
11pub struct DataLoaderConfig {
12    /// Batch window duration
13    pub batch_window: Duration,
14    /// Maximum batch size
15    pub max_batch_size: usize,
16    /// Enable caching
17    pub cache_enabled: bool,
18    /// Cache TTL
19    pub cache_ttl: Duration,
20    /// Enable deduplication
21    pub dedupe: bool,
22}
23
24impl Default for DataLoaderConfig {
25    fn default() -> Self {
26        Self {
27            batch_window: Duration::from_millis(10),
28            max_batch_size: 100,
29            cache_enabled: true,
30            cache_ttl: Duration::from_secs(60),
31            dedupe: true,
32        }
33    }
34}
35
36impl DataLoaderConfig {
37    /// Create a new configuration
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Set batch window
43    pub fn batch_window(mut self, duration: Duration) -> Self {
44        self.batch_window = duration;
45        self
46    }
47
48    /// Set max batch size
49    pub fn max_batch_size(mut self, size: usize) -> Self {
50        self.max_batch_size = size;
51        self
52    }
53
54    /// Enable/disable caching
55    pub fn cache(mut self, enabled: bool) -> Self {
56        self.cache_enabled = enabled;
57        self
58    }
59
60    /// Set cache TTL
61    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
62        self.cache_ttl = ttl;
63        self
64    }
65}
66
67/// Batch result from loader function
68#[derive(Debug, Clone)]
69pub struct BatchResult<K, V> {
70    /// Results mapped by key
71    pub results: HashMap<K, V>,
72    /// Keys that returned no results
73    pub missing: Vec<K>,
74}
75
76impl<K: Eq + Hash, V> BatchResult<K, V> {
77    /// Create a new batch result
78    pub fn new(results: HashMap<K, V>) -> Self {
79        Self {
80            results,
81            missing: Vec::new(),
82        }
83    }
84
85    /// Create an empty result
86    pub fn empty() -> Self {
87        Self {
88            results: HashMap::new(),
89            missing: Vec::new(),
90        }
91    }
92
93    /// Add missing keys
94    pub fn with_missing(mut self, missing: Vec<K>) -> Self {
95        self.missing = missing;
96        self
97    }
98
99    /// Get a value by key
100    pub fn get(&self, key: &K) -> Option<&V> {
101        self.results.get(key)
102    }
103
104    /// Check if a key is missing
105    pub fn is_missing(&self, key: &K) -> bool
106    where
107        K: PartialEq,
108    {
109        self.missing.contains(key)
110    }
111}
112
113/// Cache entry with TTL
114#[derive(Debug, Clone)]
115struct CacheEntry<V> {
116    value: V,
117    expires_at: Instant,
118}
119
120impl<V> CacheEntry<V> {
121    fn new(value: V, ttl: Duration) -> Self {
122        Self {
123            value,
124            expires_at: Instant::now() + ttl,
125        }
126    }
127
128    fn is_expired(&self) -> bool {
129        Instant::now() >= self.expires_at
130    }
131}
132
133/// DataLoader for batching and caching
134///
135/// Prevents N+1 queries by batching multiple individual loads
136/// into a single batch load.
137#[derive(Debug)]
138pub struct DataLoader<K, V>
139where
140    K: Eq + Hash + Clone,
141    V: Clone,
142{
143    /// Configuration
144    config: DataLoaderConfig,
145    /// Cache
146    cache: std::sync::Mutex<HashMap<K, CacheEntry<V>>>,
147    /// Pending requests
148    pending: std::sync::Mutex<Vec<K>>,
149    /// Statistics
150    stats: std::sync::Mutex<DataLoaderStats>,
151}
152
153/// DataLoader statistics
154#[derive(Debug, Clone, Default)]
155pub struct DataLoaderStats {
156    /// Total loads requested
157    pub total_loads: u64,
158    /// Cache hits
159    pub cache_hits: u64,
160    /// Cache misses
161    pub cache_misses: u64,
162    /// Batch loads executed
163    pub batch_loads: u64,
164    /// Average batch size
165    pub avg_batch_size: f64,
166}
167
168impl DataLoaderStats {
169    /// Get cache hit rate
170    pub fn hit_rate(&self) -> f64 {
171        if self.total_loads == 0 {
172            0.0
173        } else {
174            self.cache_hits as f64 / self.total_loads as f64
175        }
176    }
177}
178
179impl<K, V> DataLoader<K, V>
180where
181    K: Eq + Hash + Clone + Send + Sync,
182    V: Clone + Send + Sync,
183{
184    /// Create a new DataLoader
185    pub fn new(config: DataLoaderConfig) -> Self {
186        Self {
187            config,
188            cache: std::sync::Mutex::new(HashMap::new()),
189            pending: std::sync::Mutex::new(Vec::new()),
190            stats: std::sync::Mutex::new(DataLoaderStats::default()),
191        }
192    }
193
194    /// Load a single value
195    pub fn load(&self, key: K) -> Option<V> {
196        self.update_stats(|s| s.total_loads += 1);
197
198        // Check cache first
199        if self.config.cache_enabled {
200            if let Some(value) = self.get_cached(&key) {
201                self.update_stats(|s| s.cache_hits += 1);
202                return Some(value);
203            }
204            self.update_stats(|s| s.cache_misses += 1);
205        }
206
207        // Add to pending
208        self.pending.lock().unwrap().push(key);
209
210        None
211    }
212
213    /// Load multiple values
214    pub fn load_many(&self, keys: Vec<K>) -> HashMap<K, Option<V>> {
215        let mut results = HashMap::new();
216
217        for key in keys {
218            results.insert(key.clone(), self.load(key));
219        }
220
221        results
222    }
223
224    /// Prime the cache with a value
225    pub fn prime(&self, key: K, value: V) {
226        if self.config.cache_enabled {
227            let entry = CacheEntry::new(value, self.config.cache_ttl);
228            self.cache.lock().unwrap().insert(key, entry);
229        }
230    }
231
232    /// Clear the cache
233    pub fn clear(&self) {
234        self.cache.lock().unwrap().clear();
235    }
236
237    /// Clear a single key from the cache
238    pub fn clear_key(&self, key: &K) {
239        self.cache.lock().unwrap().remove(key);
240    }
241
242    /// Execute pending batch
243    pub fn execute_batch<F>(&self, mut loader: F) -> BatchResult<K, V>
244    where
245        F: FnMut(Vec<K>) -> HashMap<K, V>,
246    {
247        // Take pending keys
248        let keys: Vec<K> = {
249            let mut pending = self.pending.lock().unwrap();
250            std::mem::take(&mut *pending)
251        };
252
253        if keys.is_empty() {
254            return BatchResult::empty();
255        }
256
257        // Deduplicate if enabled
258        let unique_keys: Vec<K> = if self.config.dedupe {
259            let mut seen = std::collections::HashSet::new();
260            keys.into_iter()
261                .filter(|k| seen.insert(k.clone()))
262                .collect()
263        } else {
264            keys
265        };
266
267        // Split into batches if needed
268        let _batch_count = (unique_keys.len() + self.config.max_batch_size - 1)
269            / self.config.max_batch_size;
270
271        let mut all_results = HashMap::new();
272
273        for batch in unique_keys.chunks(self.config.max_batch_size) {
274            let batch_keys: Vec<K> = batch.to_vec();
275            let batch_size = batch_keys.len();
276
277            // Execute loader
278            let results = loader(batch_keys);
279
280            self.update_stats(|s| {
281                s.batch_loads += 1;
282                let total_batches = s.batch_loads as f64;
283                s.avg_batch_size = ((s.avg_batch_size * (total_batches - 1.0)) + batch_size as f64)
284                    / total_batches;
285            });
286
287            // Cache results
288            if self.config.cache_enabled {
289                let mut cache = self.cache.lock().unwrap();
290                for (k, v) in &results {
291                    cache.insert(k.clone(), CacheEntry::new(v.clone(), self.config.cache_ttl));
292                }
293            }
294
295            all_results.extend(results);
296        }
297
298        BatchResult::new(all_results)
299    }
300
301    /// Get cached value if valid
302    fn get_cached(&self, key: &K) -> Option<V> {
303        let mut cache = self.cache.lock().unwrap();
304
305        if let Some(entry) = cache.get(key) {
306            if !entry.is_expired() {
307                return Some(entry.value.clone());
308            } else {
309                cache.remove(key);
310            }
311        }
312
313        None
314    }
315
316    /// Update statistics
317    fn update_stats<F>(&self, f: F)
318    where
319        F: FnOnce(&mut DataLoaderStats),
320    {
321        let mut stats = self.stats.lock().unwrap();
322        f(&mut stats);
323    }
324
325    /// Get statistics
326    pub fn stats(&self) -> DataLoaderStats {
327        self.stats.lock().unwrap().clone()
328    }
329
330    /// Get configuration
331    pub fn config(&self) -> &DataLoaderConfig {
332        &self.config
333    }
334
335    /// Clean expired cache entries
336    pub fn clean_expired(&self) {
337        let mut cache = self.cache.lock().unwrap();
338        cache.retain(|_, entry| !entry.is_expired());
339    }
340}
341
342impl<K, V> Clone for DataLoader<K, V>
343where
344    K: Eq + Hash + Clone,
345    V: Clone,
346{
347    fn clone(&self) -> Self {
348        Self {
349            config: self.config.clone(),
350            cache: std::sync::Mutex::new(self.cache.lock().unwrap().clone()),
351            pending: std::sync::Mutex::new(self.pending.lock().unwrap().clone()),
352            stats: std::sync::Mutex::new(self.stats.lock().unwrap().clone()),
353        }
354    }
355}
356
357/// DataLoader factory for creating typed loaders
358#[derive(Debug)]
359pub struct DataLoaderFactory {
360    /// Default configuration
361    default_config: DataLoaderConfig,
362}
363
364impl DataLoaderFactory {
365    /// Create a new factory
366    pub fn new(config: DataLoaderConfig) -> Self {
367        Self {
368            default_config: config,
369        }
370    }
371
372    /// Create a DataLoader with default config
373    pub fn create<K, V>(&self) -> DataLoader<K, V>
374    where
375        K: Eq + Hash + Clone + Send + Sync,
376        V: Clone + Send + Sync,
377    {
378        DataLoader::new(self.default_config.clone())
379    }
380
381    /// Create a DataLoader with custom config
382    pub fn create_with_config<K, V>(&self, config: DataLoaderConfig) -> DataLoader<K, V>
383    where
384        K: Eq + Hash + Clone + Send + Sync,
385        V: Clone + Send + Sync,
386    {
387        DataLoader::new(config)
388    }
389}
390
391impl Default for DataLoaderFactory {
392    fn default() -> Self {
393        Self::new(DataLoaderConfig::default())
394    }
395}
396
397/// Type alias for ID-based loaders
398pub type IdLoader<V> = DataLoader<String, V>;
399
400#[cfg(test)]
401mod tests {
402    use super::*;
403
404    #[test]
405    fn test_dataloader_config() {
406        let config = DataLoaderConfig::new()
407            .batch_window(Duration::from_millis(20))
408            .max_batch_size(50)
409            .cache(true)
410            .cache_ttl(Duration::from_secs(120));
411
412        assert_eq!(config.batch_window, Duration::from_millis(20));
413        assert_eq!(config.max_batch_size, 50);
414        assert!(config.cache_enabled);
415        assert_eq!(config.cache_ttl, Duration::from_secs(120));
416    }
417
418    #[test]
419    fn test_dataloader_prime_and_load() {
420        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());
421
422        loader.prime("key1".to_string(), "value1".to_string());
423
424        let result = loader.load("key1".to_string());
425        assert_eq!(result, Some("value1".to_string()));
426
427        let stats = loader.stats();
428        assert_eq!(stats.cache_hits, 1);
429    }
430
431    #[test]
432    fn test_dataloader_batch_execution() {
433        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());
434
435        // Add pending keys
436        loader.load("key1".to_string());
437        loader.load("key2".to_string());
438        loader.load("key3".to_string());
439
440        // Execute batch
441        let result = loader.execute_batch(|keys| {
442            keys.into_iter()
443                .map(|k| (k.clone(), format!("value_{}", k)))
444                .collect()
445        });
446
447        assert_eq!(result.results.len(), 3);
448        assert_eq!(result.get(&"key1".to_string()), Some(&"value_key1".to_string()));
449
450        let stats = loader.stats();
451        assert_eq!(stats.batch_loads, 1);
452    }
453
454    #[test]
455    fn test_dataloader_deduplication() {
456        let loader: DataLoader<String, i32> = DataLoader::new(
457            DataLoaderConfig::default().max_batch_size(100)
458        );
459
460        // Add duplicate keys
461        loader.load("key1".to_string());
462        loader.load("key1".to_string());
463        loader.load("key2".to_string());
464        loader.load("key1".to_string());
465
466        let mut batch_keys_count = 0;
467        let result = loader.execute_batch(|keys| {
468            batch_keys_count = keys.len();
469            keys.into_iter().map(|k| (k, 1)).collect()
470        });
471
472        // Should only have 2 unique keys
473        assert_eq!(batch_keys_count, 2);
474        assert_eq!(result.results.len(), 2);
475    }
476
477    #[test]
478    fn test_dataloader_batch_splitting() {
479        let loader: DataLoader<i32, i32> = DataLoader::new(
480            DataLoaderConfig::default().max_batch_size(2)
481        );
482
483        // Add 5 keys
484        for i in 0..5 {
485            loader.load(i);
486        }
487
488        let result = loader.execute_batch(|keys| {
489            keys.into_iter().map(|k| (k, k * 10)).collect()
490        });
491
492        assert_eq!(result.results.len(), 5);
493
494        let stats = loader.stats();
495        assert_eq!(stats.batch_loads, 3); // 5 keys / 2 per batch = 3 batches
496    }
497
498    #[test]
499    fn test_dataloader_clear() {
500        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());
501
502        loader.prime("key1".to_string(), "value1".to_string());
503        loader.prime("key2".to_string(), "value2".to_string());
504
505        assert!(loader.load("key1".to_string()).is_some());
506
507        loader.clear();
508
509        // After clear, should be cache miss
510        assert!(loader.load("key1".to_string()).is_none());
511    }
512
513    #[test]
514    fn test_dataloader_clear_key() {
515        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());
516
517        loader.prime("key1".to_string(), "value1".to_string());
518        loader.prime("key2".to_string(), "value2".to_string());
519
520        loader.clear_key(&"key1".to_string());
521
522        assert!(loader.load("key1".to_string()).is_none());
523        assert!(loader.load("key2".to_string()).is_some());
524    }
525
526    #[test]
527    fn test_dataloader_stats() {
528        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());
529
530        loader.prime("cached".to_string(), "value".to_string());
531
532        // Cache hit
533        loader.load("cached".to_string());
534        // Cache miss
535        loader.load("not_cached".to_string());
536
537        let stats = loader.stats();
538        assert_eq!(stats.total_loads, 2);
539        assert_eq!(stats.cache_hits, 1);
540        assert_eq!(stats.cache_misses, 1);
541        assert_eq!(stats.hit_rate(), 0.5);
542    }
543
544    #[test]
545    fn test_dataloader_cache_disabled() {
546        let loader: DataLoader<String, String> = DataLoader::new(
547            DataLoaderConfig::default().cache(false)
548        );
549
550        loader.prime("key1".to_string(), "value1".to_string());
551
552        // With cache disabled, prime doesn't work
553        let result = loader.load("key1".to_string());
554        assert!(result.is_none());
555    }
556
557    #[test]
558    fn test_batch_result() {
559        let mut results = HashMap::new();
560        results.insert("a".to_string(), 1);
561        results.insert("b".to_string(), 2);
562
563        let batch = BatchResult::new(results)
564            .with_missing(vec!["c".to_string()]);
565
566        assert_eq!(batch.get(&"a".to_string()), Some(&1));
567        assert_eq!(batch.get(&"c".to_string()), None);
568        assert!(batch.is_missing(&"c".to_string()));
569        assert!(!batch.is_missing(&"a".to_string()));
570    }
571
572    #[test]
573    fn test_dataloader_factory() {
574        let factory = DataLoaderFactory::new(
575            DataLoaderConfig::default().max_batch_size(50)
576        );
577
578        let loader: DataLoader<String, i32> = factory.create();
579        assert_eq!(loader.config().max_batch_size, 50);
580
581        let custom_loader: DataLoader<String, i32> = factory.create_with_config(
582            DataLoaderConfig::default().max_batch_size(100)
583        );
584        assert_eq!(custom_loader.config().max_batch_size, 100);
585    }
586
587    #[test]
588    fn test_dataloader_load_many() {
589        let loader: DataLoader<String, String> = DataLoader::new(DataLoaderConfig::default());
590
591        loader.prime("key1".to_string(), "value1".to_string());
592
593        let results = loader.load_many(vec![
594            "key1".to_string(),
595            "key2".to_string(),
596        ]);
597
598        assert_eq!(results.get(&"key1".to_string()), Some(&Some("value1".to_string())));
599        assert_eq!(results.get(&"key2".to_string()), Some(&None));
600    }
601}