1use crate::metrics::DataMatchingMetrics;
6use crate::types::{DataError, DataSource};
7use dashmap::DashMap;
8use futures::future::join_all;
9use lru::LruCache;
10use parking_lot::Mutex;
11use std::num::NonZeroUsize;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tracing::{debug, info};
15
/// A cached value plus the bookkeeping needed for TTL-based expiry.
#[derive(Debug, Clone)]
pub struct CachedData {
    /// The cached payload.
    pub data: DataSource,
    /// When this entry was created.
    pub cached_at: Instant,
    /// How long the entry stays valid after `cached_at`.
    pub ttl: Duration,
}
26
27impl CachedData {
28 pub fn new(data: DataSource, ttl: Duration) -> Self {
30 Self {
31 data,
32 cached_at: Instant::now(),
33 ttl,
34 }
35 }
36
37 pub fn is_expired(&self) -> bool {
39 self.cached_at.elapsed() > self.ttl
40 }
41
42 pub fn remaining_ttl(&self) -> Duration {
44 self.ttl.saturating_sub(self.cached_at.elapsed())
45 }
46}
47
/// TTL-aware LRU cache; all access is serialized through one mutex.
pub struct DataCache {
    // LRU store guarded by a parking_lot mutex (non-poisoning).
    cache: Mutex<LruCache<String, CachedData>>,
    // TTL applied by `insert` when no explicit TTL is given.
    default_ttl: Duration,
    // Hit/miss/eviction counters; `Arc` so callers can hold them independently.
    stats: Arc<CacheStats>,
}
57
/// Lock-free hit/miss/eviction counters.
///
/// Counters use relaxed atomics: they are advisory metrics, not
/// synchronization points.
#[derive(Debug, Default)]
pub struct CacheStats {
    pub hits: std::sync::atomic::AtomicU64,
    pub misses: std::sync::atomic::AtomicU64,
    pub evictions: std::sync::atomic::AtomicU64,
}
65
66impl CacheStats {
67 pub fn hit_rate(&self) -> f64 {
68 use std::sync::atomic::Ordering;
69 let hits = self.hits.load(Ordering::Relaxed);
70 let misses = self.misses.load(Ordering::Relaxed);
71 let total = hits + misses;
72 if total == 0 {
73 0.0
74 } else {
75 hits as f64 / total as f64
76 }
77 }
78}
79
80impl DataCache {
81 pub fn new(capacity: usize) -> Self {
83 Self {
84 cache: Mutex::new(LruCache::new(
85 NonZeroUsize::new(capacity).expect("capacity must be > 0"),
86 )),
87 default_ttl: Duration::from_secs(300), stats: Arc::new(CacheStats::default()),
89 }
90 }
91
92 pub fn with_ttl(mut self, ttl: Duration) -> Self {
94 self.default_ttl = ttl;
95 self
96 }
97
98 pub fn get(&self, key: &str) -> Option<DataSource> {
100 use std::sync::atomic::Ordering;
101
102 let mut cache = self.cache.lock();
103
104 if let Some(entry) = cache.get(key) {
105 if entry.is_expired() {
106 debug!(key = key, "Cache entry expired");
107 cache.pop(key);
108 self.stats.misses.fetch_add(1, Ordering::Relaxed);
109 return None;
110 }
111
112 debug!(key = key, remaining_ttl_ms = ?entry.remaining_ttl().as_millis(), "Cache hit");
113 self.stats.hits.fetch_add(1, Ordering::Relaxed);
114 return Some(entry.data.clone());
115 }
116
117 self.stats.misses.fetch_add(1, Ordering::Relaxed);
118 None
119 }
120
121 pub fn insert(&self, key: String, data: DataSource) {
123 self.insert_with_ttl(key, data, self.default_ttl);
124 }
125
126 pub fn insert_with_ttl(&self, key: String, data: DataSource, ttl: Duration) {
128 use std::sync::atomic::Ordering;
129
130 let mut cache = self.cache.lock();
131
132 if cache.len() >= cache.cap().get() {
134 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
135 }
136
137 cache.put(key, CachedData::new(data, ttl));
138 }
139
140 pub fn remove(&self, key: &str) -> Option<DataSource> {
142 self.cache.lock().pop(key).map(|e| e.data)
143 }
144
145 pub fn clear(&self) {
147 self.cache.lock().clear();
148 }
149
150 pub fn len(&self) -> usize {
152 self.cache.lock().len()
153 }
154
155 pub fn is_empty(&self) -> bool {
157 self.cache.lock().is_empty()
158 }
159
160 pub fn stats(&self) -> &CacheStats {
162 &self.stats
163 }
164
165 pub fn contains(&self, key: &str) -> bool {
167 let cache = self.cache.lock();
168 if let Some(entry) = cache.peek(key) {
169 !entry.is_expired()
170 } else {
171 false
172 }
173 }
174}
175
/// Sequential load-through pipeline: checks the cache, then runs the
/// optional loader on a miss.
pub struct DataPipeline {
    // Backing cache for loaded sources.
    cache: DataCache,
    // Optional synchronous loader invoked on cache miss; when `None`,
    // a placeholder `DataSource` is synthesized from the id.
    #[allow(clippy::type_complexity)]
    loader: Option<Arc<dyn Fn(&str) -> DataSource + Send + Sync>>,
}
184
185impl DataPipeline {
186 pub fn new() -> Self {
188 Self {
189 cache: DataCache::new(100),
190 loader: None,
191 }
192 }
193
194 pub fn with_capacity(mut self, capacity: usize) -> Self {
196 self.cache = DataCache::new(capacity);
197 self
198 }
199
200 pub fn with_ttl(mut self, ttl: Duration) -> Self {
202 self.cache = self.cache.with_ttl(ttl);
203 self
204 }
205
206 pub fn with_loader<F>(mut self, loader: F) -> Self
208 where
209 F: Fn(&str) -> DataSource + Send + Sync + 'static,
210 {
211 self.loader = Some(Arc::new(loader));
212 self
213 }
214
215 pub async fn process(&self, source_id: &str) -> Result<DataSource, DataError> {
217 if let Some(cached) = self.cache.get(source_id) {
219 info!(source_id = source_id, "Returning cached data");
220 return Ok(cached);
221 }
222
223 let data = self.load_source(source_id).await?;
225
226 self.cache.insert(source_id.to_string(), data.clone());
228
229 info!(source_id = source_id, "Loaded and cached data");
230 Ok(data)
231 }
232
233 pub async fn process_batch(&self, source_ids: &[String]) -> Vec<Result<DataSource, DataError>> {
235 let mut results = Vec::with_capacity(source_ids.len());
236
237 for source_id in source_ids {
238 results.push(self.process(source_id).await);
239 }
240
241 results
242 }
243
244 async fn load_source(&self, source_id: &str) -> Result<DataSource, DataError> {
246 tokio::time::sleep(Duration::from_millis(10)).await;
248
249 if let Some(ref loader) = self.loader {
250 Ok(loader(source_id))
251 } else {
252 Ok(DataSource::new(source_id, format!("Source {}", source_id)))
254 }
255 }
256
257 pub fn cache_stats(&self) -> &CacheStats {
259 self.cache.stats()
260 }
261
262 pub fn clear_cache(&self) {
264 self.cache.clear();
265 }
266
267 pub fn invalidate(&self, source_id: &str) {
269 self.cache.remove(source_id);
270 }
271}
272
impl Default for DataPipeline {
    /// Equivalent to [`DataPipeline::new`].
    fn default() -> Self {
        Self::new()
    }
}
278
/// Records a "not found" result so repeated lookups for a missing key can
/// be answered without hitting the loader.
#[derive(Debug, Clone)]
pub struct NegativeCacheEntry {
    /// When the miss was recorded.
    pub cached_at: Instant,
    /// How long the negative result stays valid.
    pub ttl: Duration,
}
287
288impl NegativeCacheEntry {
289 pub fn new(ttl: Duration) -> Self {
290 Self {
291 cached_at: Instant::now(),
292 ttl,
293 }
294 }
295
296 pub fn is_expired(&self) -> bool {
297 self.cached_at.elapsed() > self.ttl
298 }
299}
300
/// Sharded concurrent cache with a companion negative cache.
///
/// Built on `DashMap`, so reads/writes from multiple threads do not
/// contend on a single lock. Eviction is capacity-triggered, not LRU.
pub struct ConcurrentCache {
    // Positive entries (found sources).
    cache: DashMap<String, CachedData>,
    // Negative entries ("not found" results), sized at capacity / 4.
    negative_cache: DashMap<String, NegativeCacheEntry>,
    // Soft bound on the number of positive entries.
    capacity: usize,
    // TTL for positive entries.
    default_ttl: Duration,
    // TTL for negative entries.
    negative_ttl: Duration,
    // Shared metrics sink for cache/eviction events.
    metrics: Arc<DataMatchingMetrics>,
}
316
317impl ConcurrentCache {
318 pub fn new(capacity: usize, metrics: Arc<DataMatchingMetrics>) -> Self {
320 Self {
321 cache: DashMap::with_capacity(capacity),
322 negative_cache: DashMap::with_capacity(capacity / 4),
323 capacity,
324 default_ttl: Duration::from_secs(300),
325 negative_ttl: Duration::from_secs(60),
326 metrics,
327 }
328 }
329
330 pub fn with_ttl(mut self, ttl: Duration) -> Self {
332 self.default_ttl = ttl;
333 self
334 }
335
336 pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
338 self.negative_ttl = ttl;
339 self
340 }
341
342 pub fn get(&self, key: &str) -> CacheResult {
344 if let Some(entry) = self.negative_cache.get(key) {
346 if !entry.is_expired() {
347 self.metrics.record_cache(true, true);
348 return CacheResult::NegativeHit;
349 } else {
350 drop(entry);
351 self.negative_cache.remove(key);
352 }
353 }
354
355 if let Some(entry) = self.cache.get(key) {
357 if !entry.is_expired() {
358 self.metrics.record_cache(true, false);
359 return CacheResult::Hit(entry.data.clone());
360 } else {
361 drop(entry);
362 self.cache.remove(key);
363 }
364 }
365
366 self.metrics.record_cache(false, false);
367 CacheResult::Miss
368 }
369
370 pub fn insert(&self, key: String, data: DataSource) {
372 if self.cache.len() >= self.capacity {
374 if let Some(entry) = self.cache.iter().next() {
375 let key_to_remove = entry.key().clone();
376 drop(entry);
377 self.cache.remove(&key_to_remove);
378 self.metrics.record_eviction();
379 }
380 }
381
382 self.negative_cache.remove(&key);
384
385 self.cache
386 .insert(key, CachedData::new(data, self.default_ttl));
387 }
388
389 pub fn insert_negative(&self, key: String) {
391 if self.negative_cache.len() >= self.capacity / 4 {
392 if let Some(entry) = self.negative_cache.iter().next() {
393 let key_to_remove = entry.key().clone();
394 drop(entry);
395 self.negative_cache.remove(&key_to_remove);
396 }
397 }
398
399 self.negative_cache
400 .insert(key, NegativeCacheEntry::new(self.negative_ttl));
401 }
402
403 pub fn remove(&self, key: &str) {
405 self.cache.remove(key);
406 self.negative_cache.remove(key);
407 }
408
409 pub fn clear(&self) {
411 self.cache.clear();
412 self.negative_cache.clear();
413 }
414
415 pub fn len(&self) -> usize {
417 self.cache.len()
418 }
419
420 pub fn is_empty(&self) -> bool {
422 self.cache.is_empty()
423 }
424
425 pub fn negative_len(&self) -> usize {
427 self.negative_cache.len()
428 }
429}
430
/// Outcome of a `ConcurrentCache` lookup.
#[derive(Debug, Clone)]
pub enum CacheResult {
    /// The key was found and the entry is still valid.
    Hit(DataSource),
    /// A still-valid "not found" result is cached for the key.
    NegativeHit,
    /// The key is in neither cache (or its entries were expired).
    Miss,
}
441
/// Pipeline that processes batches concurrently over a shared
/// `ConcurrentCache`, with chunk-based concurrency limiting.
pub struct ParallelPipeline {
    // Shared cache; `Arc` so per-item tasks can hold their own handle.
    cache: Arc<ConcurrentCache>,
    // Optional synchronous loader invoked on cache miss.
    loader: Option<Arc<dyn Fn(String) -> DataSource + Send + Sync>>,
    // Metrics sink shared with the cache.
    metrics: Arc<DataMatchingMetrics>,
    // Upper bound on concurrently-awaited items per chunk.
    max_concurrency: usize,
}
453
454impl ParallelPipeline {
455 pub fn new(capacity: usize) -> Self {
457 let metrics = Arc::new(DataMatchingMetrics::new());
458 Self {
459 cache: Arc::new(ConcurrentCache::new(capacity, metrics.clone())),
460 loader: None,
461 metrics,
462 max_concurrency: 10,
463 }
464 }
465
466 pub fn with_max_concurrency(mut self, max: usize) -> Self {
468 self.max_concurrency = max;
469 self
470 }
471
472 pub fn with_ttl(mut self, ttl: Duration) -> Self {
474 self.cache =
475 Arc::new(ConcurrentCache::new(self.cache.capacity, self.metrics.clone()).with_ttl(ttl));
476 self
477 }
478
479 pub fn with_loader<F>(mut self, loader: F) -> Self
481 where
482 F: Fn(String) -> DataSource + Send + Sync + 'static,
483 {
484 self.loader = Some(Arc::new(loader));
485 self
486 }
487
488 pub async fn process(&self, source_id: &str) -> Result<DataSource, DataError> {
490 let start = Instant::now();
491
492 match self.cache.get(source_id) {
493 CacheResult::Hit(data) => {
494 self.metrics.record_query(true, start.elapsed());
495 return Ok(data);
496 }
497 CacheResult::NegativeHit => {
498 self.metrics.record_query(false, start.elapsed());
499 return Err(DataError::SourceNotFound(source_id.to_string()));
500 }
501 CacheResult::Miss => {}
502 }
503
504 let result = self.load_source(source_id).await;
505
506 match result {
507 Ok(data) => {
508 self.cache.insert(source_id.to_string(), data.clone());
509 self.metrics.record_query(true, start.elapsed());
510 Ok(data)
511 }
512 Err(e) => {
513 self.cache.insert_negative(source_id.to_string());
514 self.metrics.record_query(false, start.elapsed());
515 Err(e)
516 }
517 }
518 }
519
520 pub async fn process_parallel(
522 &self,
523 source_ids: Vec<String>,
524 ) -> Vec<Result<DataSource, DataError>> {
525 let chunks: Vec<_> = source_ids
526 .chunks(self.max_concurrency)
527 .map(|c| c.to_vec())
528 .collect();
529
530 let mut all_results = Vec::with_capacity(source_ids.len());
531
532 for chunk in chunks {
533 let tasks: Vec<_> = chunk
534 .into_iter()
535 .map(|id| {
536 let cache = self.cache.clone();
537 let loader = self.loader.clone();
538 let metrics = self.metrics.clone();
539 async move {
540 let start = Instant::now();
541
542 match cache.get(&id) {
543 CacheResult::Hit(data) => {
544 metrics.record_query(true, start.elapsed());
545 return Ok(data);
546 }
547 CacheResult::NegativeHit => {
548 metrics.record_query(false, start.elapsed());
549 return Err(DataError::SourceNotFound(id));
550 }
551 CacheResult::Miss => {}
552 }
553
554 tokio::time::sleep(Duration::from_millis(10)).await;
555
556 if let Some(ref loader) = loader {
557 let data = loader(id.clone());
558 cache.insert(id, data.clone());
559 metrics.record_query(true, start.elapsed());
560 Ok(data)
561 } else {
562 let data = DataSource::new(&id, format!("Source {}", id));
563 cache.insert(id, data.clone());
564 metrics.record_query(true, start.elapsed());
565 Ok(data)
566 }
567 }
568 })
569 .collect();
570
571 let chunk_results = join_all(tasks).await;
572 all_results.extend(chunk_results);
573 }
574
575 all_results
576 }
577
578 async fn load_source(&self, source_id: &str) -> Result<DataSource, DataError> {
580 tokio::time::sleep(Duration::from_millis(10)).await;
581
582 if let Some(ref loader) = self.loader {
583 Ok(loader(source_id.to_string()))
584 } else {
585 Ok(DataSource::new(source_id, format!("Source {}", source_id)))
586 }
587 }
588
589 pub fn metrics(&self) -> &DataMatchingMetrics {
591 &self.metrics
592 }
593
594 pub fn clear_cache(&self) {
596 self.cache.clear();
597 }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    // Insert + contains + get round-trip on the mutex-guarded LRU cache.
    #[test]
    fn test_cache_basic() {
        let cache = DataCache::new(2);

        let source = DataSource::new("test", "Test Source");
        cache.insert("a".to_string(), source.clone());

        assert!(cache.contains("a"));
        assert!(!cache.contains("b"));

        let retrieved = cache.get("a");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "test");
    }

    // Touching "a" promotes it, so inserting "c" into the size-2 cache
    // must evict "b" (the least recently used entry).
    #[test]
    fn test_cache_lru_eviction() {
        let cache = DataCache::new(2);

        cache.insert("a".to_string(), DataSource::new("a", "A"));
        cache.insert("b".to_string(), DataSource::new("b", "B"));

        let _ = cache.get("a");

        cache.insert("c".to_string(), DataSource::new("c", "C"));

        assert!(cache.contains("a"), "a should still exist (recently used)");
        assert!(cache.contains("c"), "c should exist (just inserted)");
        assert!(!cache.contains("b"), "b should be evicted (LRU)");
    }

    // An entry with a 50ms TTL must be gone after sleeping 60ms.
    // NOTE(review): real-clock sleep; could flake on a heavily loaded
    // machine if the scheduler delays the first `get` past 50ms.
    #[test]
    fn test_cache_ttl_expiration() {
        let cache = DataCache::new(10).with_ttl(Duration::from_millis(50));

        cache.insert("short".to_string(), DataSource::new("short", "Short TTL"));

        assert!(cache.get("short").is_some());

        std::thread::sleep(Duration::from_millis(60));

        assert!(cache.get("short").is_none());
    }

    // Two hits on "a" plus one miss on "b" -> hits=2, misses=1, rate 2/3.
    #[test]
    fn test_cache_stats() {
        let cache = DataCache::new(10);

        cache.insert("a".to_string(), DataSource::new("a", "A"));

        let _ = cache.get("a"); let _ = cache.get("a"); let _ = cache.get("b"); use std::sync::atomic::Ordering;
        assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 2);
        assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 1);
        assert!(cache.stats().hit_rate() > 0.6);
    }

    // Second `process` for the same id should skip the 10ms simulated
    // load and therefore be faster than the first.
    // NOTE(review): wall-clock comparison — flaky in principle, though the
    // 10ms sleep on the first call gives a wide margin.
    #[tokio::test]
    async fn test_pipeline_caching() {
        let pipeline = DataPipeline::new()
            .with_ttl(Duration::from_secs(60))
            .with_loader(|id| DataSource::new(id, format!("Loaded {}", id)));

        let start = Instant::now();
        let _ = pipeline.process("test").await.unwrap();
        let first_duration = start.elapsed();

        let start2 = Instant::now();
        let _ = pipeline.process("test").await.unwrap();
        let second_duration = start2.elapsed();

        assert!(
            second_duration < first_duration,
            "Cache hit should be faster: {:?} vs {:?}",
            second_duration,
            first_duration
        );
    }

    // Sequential batch returns one Ok result per input id.
    #[tokio::test]
    async fn test_pipeline_batch() {
        let pipeline =
            DataPipeline::new().with_loader(|id| DataSource::new(id, format!("Source {}", id)));

        let ids = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let results = pipeline.process_batch(&ids).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    // `invalidate` removes the cached entry for the given id.
    #[tokio::test]
    async fn test_pipeline_invalidation() {
        let pipeline =
            DataPipeline::new().with_loader(|id| DataSource::new(id, format!("Source {}", id)));

        let _ = pipeline.process("test").await.unwrap();
        assert!(pipeline.cache.contains("test"));

        pipeline.invalidate("test");
        assert!(!pipeline.cache.contains("test"));
    }

    // Hit / Miss classification on the DashMap-backed cache.
    #[test]
    fn test_concurrent_cache_basic() {
        let metrics = Arc::new(DataMatchingMetrics::new());
        let cache = ConcurrentCache::new(10, metrics);

        let source = DataSource::new("test", "Test Source");
        cache.insert("a".to_string(), source);

        match cache.get("a") {
            CacheResult::Hit(data) => assert_eq!(data.id, "test"),
            _ => panic!("Expected cache hit"),
        }

        match cache.get("nonexistent") {
            CacheResult::Miss => {}
            _ => panic!("Expected cache miss"),
        }
    }

    // A negative entry answers NegativeHit until a positive insert for the
    // same key supersedes it.
    #[test]
    fn test_concurrent_cache_negative() {
        let metrics = Arc::new(DataMatchingMetrics::new());
        let cache = ConcurrentCache::new(10, metrics);

        cache.insert_negative("missing".to_string());

        match cache.get("missing") {
            CacheResult::NegativeHit => {}
            _ => panic!("Expected negative cache hit"),
        }

        cache.insert("missing".to_string(), DataSource::new("missing", "Found"));

        match cache.get("missing") {
            CacheResult::Hit(data) => assert_eq!(data.id, "missing"),
            _ => panic!("Expected cache hit after insert"),
        }
    }

    // 20 ids processed in chunks of 5; every result must be Ok and order
    // must be preserved (length check only here).
    #[tokio::test]
    async fn test_parallel_pipeline() {
        let pipeline = ParallelPipeline::new(100)
            .with_max_concurrency(5)
            .with_loader(|id| DataSource::new(&id, format!("Source {}", id)));

        let ids: Vec<String> = (0..20).map(|i| format!("source_{}", i)).collect();
        let results = pipeline.process_parallel(ids).await;

        assert_eq!(results.len(), 20);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    // A cached `process` call should return well under the 10ms simulated
    // load latency.
    // NOTE(review): hard 5ms wall-clock bound — this can flake under CI
    // load; asserting via cache metrics would be more robust.
    #[tokio::test]
    async fn test_parallel_pipeline_caching() {
        let pipeline = ParallelPipeline::new(100)
            .with_loader(|id| DataSource::new(&id, format!("Source {}", id)));

        let _ = pipeline.process("test").await.unwrap();

        let start = Instant::now();
        let _ = pipeline.process("test").await.unwrap();
        let cached_duration = start.elapsed();

        assert!(cached_duration < Duration::from_millis(5));
    }
}