// cortexai_data/pipeline.rs
1//! Data pipeline with LRU caching
2//!
3//! Async data processing with TTL-based cache eviction and parallel batch support.
4
5use crate::metrics::DataMatchingMetrics;
6use crate::types::{DataError, DataSource};
7use dashmap::DashMap;
8use futures::future::join_all;
9use lru::LruCache;
10use parking_lot::Mutex;
11use std::num::NonZeroUsize;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tracing::{debug, info};
15
16/// Cached data entry with TTL
17#[derive(Debug, Clone)]
18pub struct CachedData {
19    /// The cached data source
20    pub data: DataSource,
21    /// When this entry was cached
22    pub cached_at: Instant,
23    /// Time-to-live for this entry
24    pub ttl: Duration,
25}
26
27impl CachedData {
28    /// Create a new cached entry
29    pub fn new(data: DataSource, ttl: Duration) -> Self {
30        Self {
31            data,
32            cached_at: Instant::now(),
33            ttl,
34        }
35    }
36
37    /// Check if this entry has expired
38    pub fn is_expired(&self) -> bool {
39        self.cached_at.elapsed() > self.ttl
40    }
41
42    /// Get remaining TTL
43    pub fn remaining_ttl(&self) -> Duration {
44        self.ttl.saturating_sub(self.cached_at.elapsed())
45    }
46}
47
/// LRU cache for data sources.
///
/// Wraps an [`LruCache`] behind a mutex so it can be shared across threads
/// through `&self` methods; hit/miss/eviction counts are tracked in `stats`.
pub struct DataCache {
    /// The LRU cache, keyed by source id
    cache: Mutex<LruCache<String, CachedData>>,
    /// Default TTL applied by `insert` (entries can override via `insert_with_ttl`)
    default_ttl: Duration,
    /// Cache statistics (shared so callers can hold them past the cache's borrow)
    stats: Arc<CacheStats>,
}
57
/// Hit / miss / eviction counters, updated with relaxed atomics.
#[derive(Debug, Default)]
pub struct CacheStats {
    pub hits: std::sync::atomic::AtomicU64,
    pub misses: std::sync::atomic::AtomicU64,
    pub evictions: std::sync::atomic::AtomicU64,
}

impl CacheStats {
    /// Fraction of lookups that were hits; `0.0` when nothing was recorded yet.
    pub fn hit_rate(&self) -> f64 {
        use std::sync::atomic::Ordering;
        let (h, m) = (
            self.hits.load(Ordering::Relaxed),
            self.misses.load(Ordering::Relaxed),
        );
        match h + m {
            0 => 0.0,
            total => h as f64 / total as f64,
        }
    }
}
79
80impl DataCache {
81    /// Create a new cache with given capacity
82    pub fn new(capacity: usize) -> Self {
83        Self {
84            cache: Mutex::new(LruCache::new(
85                NonZeroUsize::new(capacity).expect("capacity must be > 0"),
86            )),
87            default_ttl: Duration::from_secs(300), // 5 minutes default
88            stats: Arc::new(CacheStats::default()),
89        }
90    }
91
92    /// Set default TTL
93    pub fn with_ttl(mut self, ttl: Duration) -> Self {
94        self.default_ttl = ttl;
95        self
96    }
97
98    /// Get from cache
99    pub fn get(&self, key: &str) -> Option<DataSource> {
100        use std::sync::atomic::Ordering;
101
102        let mut cache = self.cache.lock();
103
104        if let Some(entry) = cache.get(key) {
105            if entry.is_expired() {
106                debug!(key = key, "Cache entry expired");
107                cache.pop(key);
108                self.stats.misses.fetch_add(1, Ordering::Relaxed);
109                return None;
110            }
111
112            debug!(key = key, remaining_ttl_ms = ?entry.remaining_ttl().as_millis(), "Cache hit");
113            self.stats.hits.fetch_add(1, Ordering::Relaxed);
114            return Some(entry.data.clone());
115        }
116
117        self.stats.misses.fetch_add(1, Ordering::Relaxed);
118        None
119    }
120
121    /// Insert into cache
122    pub fn insert(&self, key: String, data: DataSource) {
123        self.insert_with_ttl(key, data, self.default_ttl);
124    }
125
126    /// Insert with custom TTL
127    pub fn insert_with_ttl(&self, key: String, data: DataSource, ttl: Duration) {
128        use std::sync::atomic::Ordering;
129
130        let mut cache = self.cache.lock();
131
132        // Check if we're evicting an entry
133        if cache.len() >= cache.cap().get() {
134            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
135        }
136
137        cache.put(key, CachedData::new(data, ttl));
138    }
139
140    /// Remove from cache
141    pub fn remove(&self, key: &str) -> Option<DataSource> {
142        self.cache.lock().pop(key).map(|e| e.data)
143    }
144
145    /// Clear all entries
146    pub fn clear(&self) {
147        self.cache.lock().clear();
148    }
149
150    /// Get cache size
151    pub fn len(&self) -> usize {
152        self.cache.lock().len()
153    }
154
155    /// Check if cache is empty
156    pub fn is_empty(&self) -> bool {
157        self.cache.lock().is_empty()
158    }
159
160    /// Get cache statistics
161    pub fn stats(&self) -> &CacheStats {
162        &self.stats
163    }
164
165    /// Check if key exists and is not expired
166    pub fn contains(&self, key: &str) -> bool {
167        let cache = self.cache.lock();
168        if let Some(entry) = cache.peek(key) {
169            !entry.is_expired()
170        } else {
171            false
172        }
173    }
174}
175
/// Data pipeline for async processing.
///
/// Wraps a [`DataCache`] plus an optional user-supplied loader closure that
/// is invoked on cache misses.
pub struct DataPipeline {
    /// Cache for processed data
    cache: DataCache,
    /// Data loader function (customizable); when `None`, a stub source is built instead
    #[allow(clippy::type_complexity)]
    loader: Option<Arc<dyn Fn(&str) -> DataSource + Send + Sync>>,
}
184
185impl DataPipeline {
186    /// Create a new pipeline
187    pub fn new() -> Self {
188        Self {
189            cache: DataCache::new(100),
190            loader: None,
191        }
192    }
193
194    /// Set cache capacity
195    pub fn with_capacity(mut self, capacity: usize) -> Self {
196        self.cache = DataCache::new(capacity);
197        self
198    }
199
200    /// Set cache TTL
201    pub fn with_ttl(mut self, ttl: Duration) -> Self {
202        self.cache = self.cache.with_ttl(ttl);
203        self
204    }
205
206    /// Set custom data loader
207    pub fn with_loader<F>(mut self, loader: F) -> Self
208    where
209        F: Fn(&str) -> DataSource + Send + Sync + 'static,
210    {
211        self.loader = Some(Arc::new(loader));
212        self
213    }
214
215    /// Process a data source (with caching)
216    pub async fn process(&self, source_id: &str) -> Result<DataSource, DataError> {
217        // Check cache first
218        if let Some(cached) = self.cache.get(source_id) {
219            info!(source_id = source_id, "Returning cached data");
220            return Ok(cached);
221        }
222
223        // Load data
224        let data = self.load_source(source_id).await?;
225
226        // Cache it
227        self.cache.insert(source_id.to_string(), data.clone());
228
229        info!(source_id = source_id, "Loaded and cached data");
230        Ok(data)
231    }
232
233    /// Process multiple sources
234    pub async fn process_batch(&self, source_ids: &[String]) -> Vec<Result<DataSource, DataError>> {
235        let mut results = Vec::with_capacity(source_ids.len());
236
237        for source_id in source_ids {
238            results.push(self.process(source_id).await);
239        }
240
241        results
242    }
243
244    /// Load a data source
245    async fn load_source(&self, source_id: &str) -> Result<DataSource, DataError> {
246        // Simulate async load
247        tokio::time::sleep(Duration::from_millis(10)).await;
248
249        if let Some(ref loader) = self.loader {
250            Ok(loader(source_id))
251        } else {
252            // Default: return empty source
253            Ok(DataSource::new(source_id, format!("Source {}", source_id)))
254        }
255    }
256
257    /// Get cache statistics
258    pub fn cache_stats(&self) -> &CacheStats {
259        self.cache.stats()
260    }
261
262    /// Clear cache
263    pub fn clear_cache(&self) {
264        self.cache.clear();
265    }
266
267    /// Invalidate specific cache entry
268    pub fn invalidate(&self, source_id: &str) {
269        self.cache.remove(source_id);
270    }
271}
272
273impl Default for DataPipeline {
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
/// Negative cache entry - remembers that something was not found
#[derive(Debug, Clone)]
pub struct NegativeCacheEntry {
    /// When this entry was cached
    pub cached_at: Instant,
    /// TTL for negative cache (usually shorter than positive)
    pub ttl: Duration,
}

impl NegativeCacheEntry {
    /// Record a "not found" result; the TTL clock starts immediately.
    pub fn new(ttl: Duration) -> Self {
        let cached_at = Instant::now();
        NegativeCacheEntry { cached_at, ttl }
    }

    /// True once more time than `ttl` has passed since creation.
    pub fn is_expired(&self) -> bool {
        self.ttl < self.cached_at.elapsed()
    }
}
300
/// Concurrent cache using DashMap for high-throughput scenarios.
///
/// Pairs a positive cache with a negative cache that remembers "not found"
/// lookups so repeated misses fail fast.
pub struct ConcurrentCache {
    /// Main data cache
    cache: DashMap<String, CachedData>,
    /// Negative cache (remembers "not found" results); sized at capacity / 4
    negative_cache: DashMap<String, NegativeCacheEntry>,
    /// Maximum capacity of the main cache (enforced manually on insert)
    capacity: usize,
    /// Default TTL for entries
    default_ttl: Duration,
    /// TTL for negative cache entries (shorter, so transient misses recover)
    negative_ttl: Duration,
    /// Metrics
    metrics: Arc<DataMatchingMetrics>,
}
316
317impl ConcurrentCache {
318    /// Create a new concurrent cache
319    pub fn new(capacity: usize, metrics: Arc<DataMatchingMetrics>) -> Self {
320        Self {
321            cache: DashMap::with_capacity(capacity),
322            negative_cache: DashMap::with_capacity(capacity / 4),
323            capacity,
324            default_ttl: Duration::from_secs(300),
325            negative_ttl: Duration::from_secs(60),
326            metrics,
327        }
328    }
329
330    /// Set default TTL
331    pub fn with_ttl(mut self, ttl: Duration) -> Self {
332        self.default_ttl = ttl;
333        self
334    }
335
336    /// Set negative cache TTL
337    pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
338        self.negative_ttl = ttl;
339        self
340    }
341
342    /// Get from cache (checks negative cache too)
343    pub fn get(&self, key: &str) -> CacheResult {
344        // Check negative cache first
345        if let Some(entry) = self.negative_cache.get(key) {
346            if !entry.is_expired() {
347                self.metrics.record_cache(true, true);
348                return CacheResult::NegativeHit;
349            } else {
350                drop(entry);
351                self.negative_cache.remove(key);
352            }
353        }
354
355        // Check main cache
356        if let Some(entry) = self.cache.get(key) {
357            if !entry.is_expired() {
358                self.metrics.record_cache(true, false);
359                return CacheResult::Hit(entry.data.clone());
360            } else {
361                drop(entry);
362                self.cache.remove(key);
363            }
364        }
365
366        self.metrics.record_cache(false, false);
367        CacheResult::Miss
368    }
369
370    /// Insert into cache
371    pub fn insert(&self, key: String, data: DataSource) {
372        // Evict if at capacity
373        if self.cache.len() >= self.capacity {
374            if let Some(entry) = self.cache.iter().next() {
375                let key_to_remove = entry.key().clone();
376                drop(entry);
377                self.cache.remove(&key_to_remove);
378                self.metrics.record_eviction();
379            }
380        }
381
382        // Remove from negative cache if present
383        self.negative_cache.remove(&key);
384
385        self.cache
386            .insert(key, CachedData::new(data, self.default_ttl));
387    }
388
389    /// Insert a negative cache entry
390    pub fn insert_negative(&self, key: String) {
391        if self.negative_cache.len() >= self.capacity / 4 {
392            if let Some(entry) = self.negative_cache.iter().next() {
393                let key_to_remove = entry.key().clone();
394                drop(entry);
395                self.negative_cache.remove(&key_to_remove);
396            }
397        }
398
399        self.negative_cache
400            .insert(key, NegativeCacheEntry::new(self.negative_ttl));
401    }
402
403    /// Remove from both caches
404    pub fn remove(&self, key: &str) {
405        self.cache.remove(key);
406        self.negative_cache.remove(key);
407    }
408
409    /// Clear all caches
410    pub fn clear(&self) {
411        self.cache.clear();
412        self.negative_cache.clear();
413    }
414
415    /// Get cache size
416    pub fn len(&self) -> usize {
417        self.cache.len()
418    }
419
420    /// Check if cache is empty
421    pub fn is_empty(&self) -> bool {
422        self.cache.is_empty()
423    }
424
425    /// Get negative cache size
426    pub fn negative_len(&self) -> usize {
427        self.negative_cache.len()
428    }
429}
430
/// Result of a cache lookup against a [`ConcurrentCache`].
#[derive(Debug, Clone)]
pub enum CacheResult {
    /// Data found in cache (fresh, non-expired entry)
    Hit(DataSource),
    /// Key was previously looked up and not found (negative-cache hit)
    NegativeHit,
    /// Key not in either cache — caller should attempt a load
    Miss,
}
441
/// High-performance parallel data pipeline.
///
/// Shares a [`ConcurrentCache`] across concurrently running lookups and
/// bounds batch concurrency via `max_concurrency`.
pub struct ParallelPipeline {
    /// Concurrent cache (Arc so per-item futures can share it)
    cache: Arc<ConcurrentCache>,
    /// Async data loader; when `None`, a stub source is built on miss
    loader: Option<Arc<dyn Fn(String) -> DataSource + Send + Sync>>,
    /// Metrics
    metrics: Arc<DataMatchingMetrics>,
    /// Maximum concurrent tasks per batch window
    max_concurrency: usize,
}
453
454impl ParallelPipeline {
455    /// Create a new parallel pipeline
456    pub fn new(capacity: usize) -> Self {
457        let metrics = Arc::new(DataMatchingMetrics::new());
458        Self {
459            cache: Arc::new(ConcurrentCache::new(capacity, metrics.clone())),
460            loader: None,
461            metrics,
462            max_concurrency: 10,
463        }
464    }
465
466    /// Set maximum concurrency
467    pub fn with_max_concurrency(mut self, max: usize) -> Self {
468        self.max_concurrency = max;
469        self
470    }
471
472    /// Set cache TTL
473    pub fn with_ttl(mut self, ttl: Duration) -> Self {
474        self.cache =
475            Arc::new(ConcurrentCache::new(self.cache.capacity, self.metrics.clone()).with_ttl(ttl));
476        self
477    }
478
479    /// Set custom data loader
480    pub fn with_loader<F>(mut self, loader: F) -> Self
481    where
482        F: Fn(String) -> DataSource + Send + Sync + 'static,
483    {
484        self.loader = Some(Arc::new(loader));
485        self
486    }
487
488    /// Process a single source
489    pub async fn process(&self, source_id: &str) -> Result<DataSource, DataError> {
490        let start = Instant::now();
491
492        match self.cache.get(source_id) {
493            CacheResult::Hit(data) => {
494                self.metrics.record_query(true, start.elapsed());
495                return Ok(data);
496            }
497            CacheResult::NegativeHit => {
498                self.metrics.record_query(false, start.elapsed());
499                return Err(DataError::SourceNotFound(source_id.to_string()));
500            }
501            CacheResult::Miss => {}
502        }
503
504        let result = self.load_source(source_id).await;
505
506        match result {
507            Ok(data) => {
508                self.cache.insert(source_id.to_string(), data.clone());
509                self.metrics.record_query(true, start.elapsed());
510                Ok(data)
511            }
512            Err(e) => {
513                self.cache.insert_negative(source_id.to_string());
514                self.metrics.record_query(false, start.elapsed());
515                Err(e)
516            }
517        }
518    }
519
520    /// Process multiple sources in parallel
521    pub async fn process_parallel(
522        &self,
523        source_ids: Vec<String>,
524    ) -> Vec<Result<DataSource, DataError>> {
525        let chunks: Vec<_> = source_ids
526            .chunks(self.max_concurrency)
527            .map(|c| c.to_vec())
528            .collect();
529
530        let mut all_results = Vec::with_capacity(source_ids.len());
531
532        for chunk in chunks {
533            let tasks: Vec<_> = chunk
534                .into_iter()
535                .map(|id| {
536                    let cache = self.cache.clone();
537                    let loader = self.loader.clone();
538                    let metrics = self.metrics.clone();
539                    async move {
540                        let start = Instant::now();
541
542                        match cache.get(&id) {
543                            CacheResult::Hit(data) => {
544                                metrics.record_query(true, start.elapsed());
545                                return Ok(data);
546                            }
547                            CacheResult::NegativeHit => {
548                                metrics.record_query(false, start.elapsed());
549                                return Err(DataError::SourceNotFound(id));
550                            }
551                            CacheResult::Miss => {}
552                        }
553
554                        tokio::time::sleep(Duration::from_millis(10)).await;
555
556                        if let Some(ref loader) = loader {
557                            let data = loader(id.clone());
558                            cache.insert(id, data.clone());
559                            metrics.record_query(true, start.elapsed());
560                            Ok(data)
561                        } else {
562                            let data = DataSource::new(&id, format!("Source {}", id));
563                            cache.insert(id, data.clone());
564                            metrics.record_query(true, start.elapsed());
565                            Ok(data)
566                        }
567                    }
568                })
569                .collect();
570
571            let chunk_results = join_all(tasks).await;
572            all_results.extend(chunk_results);
573        }
574
575        all_results
576    }
577
578    /// Load a data source
579    async fn load_source(&self, source_id: &str) -> Result<DataSource, DataError> {
580        tokio::time::sleep(Duration::from_millis(10)).await;
581
582        if let Some(ref loader) = self.loader {
583            Ok(loader(source_id.to_string()))
584        } else {
585            Ok(DataSource::new(source_id, format!("Source {}", source_id)))
586        }
587    }
588
589    /// Get metrics
590    pub fn metrics(&self) -> &DataMatchingMetrics {
591        &self.metrics
592    }
593
594    /// Clear cache
595    pub fn clear_cache(&self) {
596        self.cache.clear();
597    }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip: insert then get/contains on DataCache.
    #[test]
    fn test_cache_basic() {
        let cache = DataCache::new(2);

        let source = DataSource::new("test", "Test Source");
        cache.insert("a".to_string(), source.clone());

        assert!(cache.contains("a"));
        assert!(!cache.contains("b"));

        let retrieved = cache.get("a");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "test");
    }

    // LRU semantics: touching a key protects it from the next eviction.
    #[test]
    fn test_cache_lru_eviction() {
        let cache = DataCache::new(2);

        cache.insert("a".to_string(), DataSource::new("a", "A"));
        cache.insert("b".to_string(), DataSource::new("b", "B"));

        // Access 'a' to make it recently used
        let _ = cache.get("a");

        // Insert 'c' - should evict 'b'
        cache.insert("c".to_string(), DataSource::new("c", "C"));

        assert!(cache.contains("a"), "a should still exist (recently used)");
        assert!(cache.contains("c"), "c should exist (just inserted)");
        assert!(!cache.contains("b"), "b should be evicted (LRU)");
    }

    // TTL: an entry disappears after its TTL elapses.
    // NOTE(review): sleeps 60ms against a 50ms TTL — tight margin on slow CI.
    #[test]
    fn test_cache_ttl_expiration() {
        let cache = DataCache::new(10).with_ttl(Duration::from_millis(50));

        cache.insert("short".to_string(), DataSource::new("short", "Short TTL"));

        // Should exist immediately
        assert!(cache.get("short").is_some());

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(60));

        // Should be expired
        assert!(cache.get("short").is_none());
    }

    // Hit/miss counters and the derived hit rate (2 hits / 3 lookups ≈ 0.67).
    #[test]
    fn test_cache_stats() {
        let cache = DataCache::new(10);

        cache.insert("a".to_string(), DataSource::new("a", "A"));

        let _ = cache.get("a"); // Hit
        let _ = cache.get("a"); // Hit
        let _ = cache.get("b"); // Miss

        use std::sync::atomic::Ordering;
        assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 2);
        assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 1);
        assert!(cache.stats().hit_rate() > 0.6);
    }

    // A cache hit skips load_source's simulated 10ms delay, so the second
    // call must be faster than the first. Timing-based — potentially flaky.
    #[tokio::test]
    async fn test_pipeline_caching() {
        let pipeline = DataPipeline::new()
            .with_ttl(Duration::from_secs(60))
            .with_loader(|id| DataSource::new(id, format!("Loaded {}", id)));

        // First call - cache miss
        let start = Instant::now();
        let _ = pipeline.process("test").await.unwrap();
        let first_duration = start.elapsed();

        // Second call - cache hit
        let start2 = Instant::now();
        let _ = pipeline.process("test").await.unwrap();
        let second_duration = start2.elapsed();

        // Cache hit should be faster (no simulated load delay)
        assert!(
            second_duration < first_duration,
            "Cache hit should be faster: {:?} vs {:?}",
            second_duration,
            first_duration
        );
    }

    // Batch processing returns one Ok result per input id.
    #[tokio::test]
    async fn test_pipeline_batch() {
        let pipeline =
            DataPipeline::new().with_loader(|id| DataSource::new(id, format!("Source {}", id)));

        let ids = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let results = pipeline.process_batch(&ids).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    // invalidate() removes a previously cached entry.
    #[tokio::test]
    async fn test_pipeline_invalidation() {
        let pipeline =
            DataPipeline::new().with_loader(|id| DataSource::new(id, format!("Source {}", id)));

        // Load and cache
        let _ = pipeline.process("test").await.unwrap();
        assert!(pipeline.cache.contains("test"));

        // Invalidate
        pipeline.invalidate("test");
        assert!(!pipeline.cache.contains("test"));
    }

    // ConcurrentCache returns Hit for stored keys and Miss for unknown ones.
    #[test]
    fn test_concurrent_cache_basic() {
        let metrics = Arc::new(DataMatchingMetrics::new());
        let cache = ConcurrentCache::new(10, metrics);

        let source = DataSource::new("test", "Test Source");
        cache.insert("a".to_string(), source);

        match cache.get("a") {
            CacheResult::Hit(data) => assert_eq!(data.id, "test"),
            _ => panic!("Expected cache hit"),
        }

        match cache.get("nonexistent") {
            CacheResult::Miss => {}
            _ => panic!("Expected cache miss"),
        }
    }

    // Negative-cache flow: NegativeHit until real data supersedes the entry.
    #[test]
    fn test_concurrent_cache_negative() {
        let metrics = Arc::new(DataMatchingMetrics::new());
        let cache = ConcurrentCache::new(10, metrics);

        // Insert negative entry
        cache.insert_negative("missing".to_string());

        // Should get negative hit
        match cache.get("missing") {
            CacheResult::NegativeHit => {}
            _ => panic!("Expected negative cache hit"),
        }

        // Insert actual data - should remove from negative cache
        cache.insert("missing".to_string(), DataSource::new("missing", "Found"));

        match cache.get("missing") {
            CacheResult::Hit(data) => assert_eq!(data.id, "missing"),
            _ => panic!("Expected cache hit after insert"),
        }
    }

    // 20 sources with concurrency 5: all succeed, order preserved by count.
    #[tokio::test]
    async fn test_parallel_pipeline() {
        let pipeline = ParallelPipeline::new(100)
            .with_max_concurrency(5)
            .with_loader(|id| DataSource::new(&id, format!("Source {}", id)));

        let ids: Vec<String> = (0..20).map(|i| format!("source_{}", i)).collect();
        let results = pipeline.process_parallel(ids).await;

        assert_eq!(results.len(), 20);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    // A cached process() call must finish well under the 10ms load delay.
    // NOTE(review): wall-clock threshold — potentially flaky on loaded CI.
    #[tokio::test]
    async fn test_parallel_pipeline_caching() {
        let pipeline = ParallelPipeline::new(100)
            .with_loader(|id| DataSource::new(&id, format!("Source {}", id)));

        // First call
        let _ = pipeline.process("test").await.unwrap();

        // Second call should be cached
        let start = Instant::now();
        let _ = pipeline.process("test").await.unwrap();
        let cached_duration = start.elapsed();

        // Should be very fast (cached)
        assert!(cached_duration < Duration::from_millis(5));
    }
}