Skip to main content

heliosdb_proxy/graphql/
dataloader.rs

//! DataLoader
//!
//! Batching and caching for N+1 query prevention.
4
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::time::{Duration, Instant};
8
/// Tuning knobs for a `DataLoader`.
///
/// Start from [`DataLoaderConfig::new`] (equivalent to [`Default::default`])
/// and adjust individual settings with the chained setters; every public
/// field has a matching setter of the same meaning.
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Batch window duration
    pub batch_window: Duration,
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Enable caching
    pub cache_enabled: bool,
    /// Cache TTL
    pub cache_ttl: Duration,
    /// Enable deduplication
    pub dedupe: bool,
}

impl Default for DataLoaderConfig {
    /// Defaults: 10 ms batch window, at most 100 keys per batch, caching
    /// enabled with a 60 s TTL, and deduplication enabled.
    fn default() -> Self {
        Self {
            batch_window: Duration::from_millis(10),
            max_batch_size: 100,
            cache_enabled: true,
            cache_ttl: Duration::from_secs(60),
            dedupe: true,
        }
    }
}

impl DataLoaderConfig {
    /// Create a configuration populated with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set batch window
    #[must_use]
    pub fn batch_window(mut self, duration: Duration) -> Self {
        self.batch_window = duration;
        self
    }

    /// Set max batch size
    ///
    /// A size of `0` is stored as `1`: a batch that may hold no keys can
    /// never make progress, and slice chunking panics on a zero chunk size.
    #[must_use]
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size.max(1);
        self
    }

    /// Enable/disable caching
    #[must_use]
    pub fn cache(mut self, enabled: bool) -> Self {
        self.cache_enabled = enabled;
        self
    }

    /// Set cache TTL
    #[must_use]
    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
        self.cache_ttl = ttl;
        self
    }

    /// Enable/disable deduplication of queued keys
    #[must_use]
    pub fn dedupe(mut self, enabled: bool) -> Self {
        self.dedupe = enabled;
        self
    }
}
66
/// Outcome of a batch load: the values that were found, keyed by the
/// requested key, plus a list of keys flagged as having no result.
#[derive(Debug, Clone)]
pub struct BatchResult<K, V> {
    /// Results mapped by key
    pub results: HashMap<K, V>,
    /// Keys that returned no results
    pub missing: Vec<K>,
}

impl<K, V> BatchResult<K, V>
where
    K: Eq + Hash,
{
    /// Wrap a map of loaded values; the missing-key list starts out empty.
    pub fn new(results: HashMap<K, V>) -> Self {
        let missing = Vec::new();
        Self { results, missing }
    }

    /// A result with no values and no missing keys.
    pub fn empty() -> Self {
        Self::new(HashMap::new())
    }

    /// Replace the missing-key list, returning the updated result.
    pub fn with_missing(self, missing: Vec<K>) -> Self {
        Self { missing, ..self }
    }

    /// Look up the loaded value for `key`, if there is one.
    pub fn get(&self, key: &K) -> Option<&V> {
        let found = &self.results;
        found.get(key)
    }

    /// Report whether `key` appears in the missing-key list.
    pub fn is_missing(&self, key: &K) -> bool
    where
        K: PartialEq,
    {
        self.missing.iter().any(|candidate| candidate == key)
    }
}
112
/// Cache entry with TTL
///
/// Pairs a cached value with the instant at which it stops being valid.
#[derive(Debug, Clone)]
struct CacheEntry<V> {
    /// The cached value.
    value: V,
    /// Deadline after which the entry is stale; `None` means the deadline is
    /// not representable (TTL too large), so the entry never expires.
    expires_at: Option<Instant>,
}

impl<V> CacheEntry<V> {
    /// Create an entry that expires `ttl` after the current instant.
    fn new(value: V, ttl: Duration) -> Self {
        Self {
            value,
            // `Instant + Duration` panics on overflow, which a very large TTL
            // (e.g. `Duration::MAX` used as "keep forever") would trigger.
            // `checked_add` yields `None` instead, treated as "never expires".
            expires_at: Instant::now().checked_add(ttl),
        }
    }

    /// Whether the deadline has been reached; the deadline itself counts as
    /// expired, so a zero TTL produces an entry that is stale immediately.
    fn is_expired(&self) -> bool {
        self.expires_at
            .is_some_and(|deadline| Instant::now() >= deadline)
    }
}
132
133/// DataLoader for batching and caching
134///
135/// Prevents N+1 queries by batching multiple individual loads
136/// into a single batch load.
137#[derive(Debug)]
138pub struct DataLoader<K, V>
139where
140    K: Eq + Hash + Clone,
141    V: Clone,
142{
143    /// Configuration
144    config: DataLoaderConfig,
145    /// Cache
146    cache: std::sync::Mutex<HashMap<K, CacheEntry<V>>>,
147    /// Pending requests
148    pending: std::sync::Mutex<Vec<K>>,
149    /// Statistics
150    stats: std::sync::Mutex<DataLoaderStats>,
151}
152
/// Usage counters for a loader; every counter starts at zero.
#[derive(Debug, Clone, Default)]
pub struct DataLoaderStats {
    /// Total loads requested
    pub total_loads: u64,
    /// Cache hits
    pub cache_hits: u64,
    /// Cache misses
    pub cache_misses: u64,
    /// Batch loads executed
    pub batch_loads: u64,
    /// Average batch size
    pub avg_batch_size: f64,
}

impl DataLoaderStats {
    /// Fraction of all requested loads that were cache hits.
    ///
    /// Yields `0.0` while no loads have been recorded, avoiding a division
    /// by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_loads {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }
}
178
179impl<K, V> DataLoader<K, V>
180where
181    K: Eq + Hash + Clone + Send + Sync,
182    V: Clone + Send + Sync,
183{
184    /// Create a new DataLoader
185    pub fn new(config: DataLoaderConfig) -> Self {
186        Self {
187            config,
188            cache: std::sync::Mutex::new(HashMap::new()),
189            pending: std::sync::Mutex::new(Vec::new()),
190            stats: std::sync::Mutex::new(DataLoaderStats::default()),
191        }
192    }
193
194    /// Load a single value
195    pub fn load(&self, key: K) -> Option<V> {
196        self.update_stats(|s| s.total_loads += 1);
197
198        // Check cache first
199        if self.config.cache_enabled {
200            if let Some(value) = self.get_cached(&key) {
201                self.update_stats(|s| s.cache_hits += 1);
202                return Some(value);
203            }
204            self.update_stats(|s| s.cache_misses += 1);
205        }
206
207        // Add to pending
208        self.pending.lock().unwrap().push(key);
209
210        None
211    }
212
213    /// Load multiple values
214    pub fn load_many(&self, keys: Vec<K>) -> HashMap<K, Option<V>> {
215        let mut results = HashMap::new();
216
217        for key in keys {
218            results.insert(key.clone(), self.load(key));
219        }
220
221        results
222    }
223
224    /// Prime the cache with a value
225    pub fn prime(&self, key: K, value: V) {
226        if self.config.cache_enabled {
227            let entry = CacheEntry::new(value, self.config.cache_ttl);
228            self.cache.lock().unwrap().insert(key, entry);
229        }
230    }
231
232    /// Clear the cache
233    pub fn clear(&self) {
234        self.cache.lock().unwrap().clear();
235    }
236
237    /// Clear a single key from the cache
238    pub fn clear_key(&self, key: &K) {
239        self.cache.lock().unwrap().remove(key);
240    }
241
242    /// Execute pending batch
243    pub fn execute_batch<F>(&self, mut loader: F) -> BatchResult<K, V>
244    where
245        F: FnMut(Vec<K>) -> HashMap<K, V>,
246    {
247        // Take pending keys
248        let keys: Vec<K> = {
249            let mut pending = self.pending.lock().unwrap();
250            std::mem::take(&mut *pending)
251        };
252
253        if keys.is_empty() {
254            return BatchResult::empty();
255        }
256
257        // Deduplicate if enabled
258        let unique_keys: Vec<K> = if self.config.dedupe {
259            let mut seen = std::collections::HashSet::new();
260            keys.into_iter()
261                .filter(|k| seen.insert(k.clone()))
262                .collect()
263        } else {
264            keys
265        };
266
267        // Split into batches if needed
268        let _batch_count = unique_keys.len().div_ceil(self.config.max_batch_size);
269
270        let mut all_results = HashMap::new();
271
272        for batch in unique_keys.chunks(self.config.max_batch_size) {
273            let batch_keys: Vec<K> = batch.to_vec();
274            let batch_size = batch_keys.len();
275
276            // Execute loader
277            let results = loader(batch_keys);
278
279            self.update_stats(|s| {
280                s.batch_loads += 1;
281                let total_batches = s.batch_loads as f64;
282                s.avg_batch_size = ((s.avg_batch_size * (total_batches - 1.0)) + batch_size as f64)
283                    / total_batches;
284            });
285
286            // Cache results
287            if self.config.cache_enabled {
288                let mut cache = self.cache.lock().unwrap();
289                for (k, v) in &results {
290                    cache.insert(k.clone(), CacheEntry::new(v.clone(), self.config.cache_ttl));
291                }
292            }
293
294            all_results.extend(results);
295        }
296
297        BatchResult::new(all_results)
298    }
299
300    /// Get cached value if valid
301    fn get_cached(&self, key: &K) -> Option<V> {
302        let mut cache = self.cache.lock().unwrap();
303
304        if let Some(entry) = cache.get(key) {
305            if !entry.is_expired() {
306                return Some(entry.value.clone());
307            } else {
308                cache.remove(key);
309            }
310        }
311
312        None
313    }
314
315    /// Update statistics
316    fn update_stats<F>(&self, f: F)
317    where
318        F: FnOnce(&mut DataLoaderStats),
319    {
320        let mut stats = self.stats.lock().unwrap();
321        f(&mut stats);
322    }
323
324    /// Get statistics
325    pub fn stats(&self) -> DataLoaderStats {
326        self.stats.lock().unwrap().clone()
327    }
328
329    /// Get configuration
330    pub fn config(&self) -> &DataLoaderConfig {
331        &self.config
332    }
333
334    /// Clean expired cache entries
335    pub fn clean_expired(&self) {
336        let mut cache = self.cache.lock().unwrap();
337        cache.retain(|_, entry| !entry.is_expired());
338    }
339}
340
341impl<K, V> Clone for DataLoader<K, V>
342where
343    K: Eq + Hash + Clone,
344    V: Clone,
345{
346    fn clone(&self) -> Self {
347        Self {
348            config: self.config.clone(),
349            cache: std::sync::Mutex::new(self.cache.lock().unwrap().clone()),
350            pending: std::sync::Mutex::new(self.pending.lock().unwrap().clone()),
351            stats: std::sync::Mutex::new(self.stats.lock().unwrap().clone()),
352        }
353    }
354}
355
356/// DataLoader factory for creating typed loaders
357#[derive(Debug)]
358pub struct DataLoaderFactory {
359    /// Default configuration
360    default_config: DataLoaderConfig,
361}
362
363impl DataLoaderFactory {
364    /// Create a new factory
365    pub fn new(config: DataLoaderConfig) -> Self {
366        Self {
367            default_config: config,
368        }
369    }
370
371    /// Create a DataLoader with default config
372    pub fn create<K, V>(&self) -> DataLoader<K, V>
373    where
374        K: Eq + Hash + Clone + Send + Sync,
375        V: Clone + Send + Sync,
376    {
377        DataLoader::new(self.default_config.clone())
378    }
379
380    /// Create a DataLoader with custom config
381    pub fn create_with_config<K, V>(&self, config: DataLoaderConfig) -> DataLoader<K, V>
382    where
383        K: Eq + Hash + Clone + Send + Sync,
384        V: Clone + Send + Sync,
385    {
386        DataLoader::new(config)
387    }
388}
389
390impl Default for DataLoaderFactory {
391    fn default() -> Self {
392        Self::new(DataLoaderConfig::default())
393    }
394}
395
396/// Type alias for ID-based loaders
397pub type IdLoader<V> = DataLoader<String, V>;
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `String` from a literal to keep the assertions short.
    fn s(text: &str) -> String {
        text.to_string()
    }

    /// String-to-string loader using the default configuration.
    fn default_loader() -> DataLoader<String, String> {
        DataLoader::new(DataLoaderConfig::default())
    }

    /// Every chained setter stores the value it was given.
    #[test]
    fn test_dataloader_config() {
        let window = Duration::from_millis(20);
        let ttl = Duration::from_secs(120);

        let config = DataLoaderConfig::new()
            .batch_window(window)
            .max_batch_size(50)
            .cache(true)
            .cache_ttl(ttl);

        assert_eq!(config.batch_window, window);
        assert_eq!(config.max_batch_size, 50);
        assert!(config.cache_enabled);
        assert_eq!(config.cache_ttl, ttl);
    }

    /// A primed key is served from the cache and counted as a hit.
    #[test]
    fn test_dataloader_prime_and_load() {
        let loader = default_loader();
        loader.prime(s("key1"), s("value1"));

        assert_eq!(loader.load(s("key1")), Some(s("value1")));
        assert_eq!(loader.stats().cache_hits, 1);
    }

    /// Queued keys are resolved together by a single loader call.
    #[test]
    fn test_dataloader_batch_execution() {
        let loader = default_loader();

        // Add pending keys
        for key in ["key1", "key2", "key3"] {
            loader.load(s(key));
        }

        // Execute batch
        let result = loader.execute_batch(|keys| {
            keys.into_iter()
                .map(|k| (k.clone(), format!("value_{}", k)))
                .collect()
        });

        assert_eq!(result.results.len(), 3);
        assert_eq!(result.get(&s("key1")), Some(&s("value_key1")));
        assert_eq!(loader.stats().batch_loads, 1);
    }

    /// Repeated keys reach the loader function only once.
    #[test]
    fn test_dataloader_deduplication() {
        let config = DataLoaderConfig::default().max_batch_size(100);
        let loader: DataLoader<String, i32> = DataLoader::new(config);

        // Add duplicate keys
        for key in ["key1", "key1", "key2", "key1"] {
            loader.load(s(key));
        }

        let mut batch_keys_count = 0;
        let result = loader.execute_batch(|keys| {
            batch_keys_count = keys.len();
            keys.into_iter().map(|k| (k, 1)).collect()
        });

        // Should only have 2 unique keys
        assert_eq!(batch_keys_count, 2);
        assert_eq!(result.results.len(), 2);
    }

    /// More keys than `max_batch_size` are split across several loader calls.
    #[test]
    fn test_dataloader_batch_splitting() {
        let config = DataLoaderConfig::default().max_batch_size(2);
        let loader: DataLoader<i32, i32> = DataLoader::new(config);

        // Add 5 keys
        (0..5).for_each(|i| {
            loader.load(i);
        });

        let result = loader.execute_batch(|keys| keys.into_iter().map(|k| (k, k * 10)).collect());
        assert_eq!(result.results.len(), 5);

        // 5 keys / 2 per batch = 3 batches
        assert_eq!(loader.stats().batch_loads, 3);
    }

    /// Clearing the cache turns a previous hit into a miss.
    #[test]
    fn test_dataloader_clear() {
        let loader = default_loader();
        loader.prime(s("key1"), s("value1"));
        loader.prime(s("key2"), s("value2"));

        assert!(loader.load(s("key1")).is_some());

        loader.clear();

        // After clear, should be cache miss
        assert!(loader.load(s("key1")).is_none());
    }

    /// Clearing one key leaves the other cached entries alone.
    #[test]
    fn test_dataloader_clear_key() {
        let loader = default_loader();
        loader.prime(s("key1"), s("value1"));
        loader.prime(s("key2"), s("value2"));

        loader.clear_key(&s("key1"));

        assert!(loader.load(s("key1")).is_none());
        assert!(loader.load(s("key2")).is_some());
    }

    /// One hit and one miss give a 50% hit rate.
    #[test]
    fn test_dataloader_stats() {
        let loader = default_loader();
        loader.prime(s("cached"), s("value"));

        // Cache hit
        loader.load(s("cached"));
        // Cache miss
        loader.load(s("not_cached"));

        let stats = loader.stats();
        assert_eq!(stats.total_loads, 2);
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    /// With caching disabled, priming has no effect on later loads.
    #[test]
    fn test_dataloader_cache_disabled() {
        let config = DataLoaderConfig::default().cache(false);
        let loader: DataLoader<String, String> = DataLoader::new(config);

        loader.prime(s("key1"), s("value1"));

        // With cache disabled, prime doesn't work
        assert!(loader.load(s("key1")).is_none());
    }

    /// `BatchResult` separates found values from keys flagged as missing.
    #[test]
    fn test_batch_result() {
        let results: HashMap<String, i32> = [(s("a"), 1), (s("b"), 2)].into_iter().collect();

        let batch = BatchResult::new(results).with_missing(vec![s("c")]);

        assert_eq!(batch.get(&s("a")), Some(&1));
        assert_eq!(batch.get(&s("c")), None);
        assert!(batch.is_missing(&s("c")));
        assert!(!batch.is_missing(&s("a")));
    }

    /// The factory applies its default config unless a custom one is given.
    #[test]
    fn test_dataloader_factory() {
        let factory = DataLoaderFactory::new(DataLoaderConfig::default().max_batch_size(50));

        let loader: DataLoader<String, i32> = factory.create();
        assert_eq!(loader.config().max_batch_size, 50);

        let custom_config = DataLoaderConfig::default().max_batch_size(100);
        let custom_loader: DataLoader<String, i32> = factory.create_with_config(custom_config);
        assert_eq!(custom_loader.config().max_batch_size, 100);
    }

    /// `load_many` reports cached keys as `Some` and the rest as `None`.
    #[test]
    fn test_dataloader_load_many() {
        let loader = default_loader();
        loader.prime(s("key1"), s("value1"));

        let results = loader.load_many(vec![s("key1"), s("key2")]);

        assert_eq!(results.get(&s("key1")), Some(&Some(s("value1"))));
        assert_eq!(results.get(&s("key2")), Some(&None));
    }
}