// heliosdb_proxy/distribcache/mod.rs

1//! Helios-DistribCache - Intelligent Distributed Caching Layer
2//!
3//! A multi-tier distributed caching system with workload-aware strategies,
4//! intelligent prefetching, and AI/Agent optimizations.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────────────────────────────┐
10//! │                     WORKLOAD CLASSIFIER                          │
11//! │  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐   │
12//! │  │  OLTP   │ │  OLAP   │ │ Vector  │ │AIAgent  │ │   RAG   │   │
13//! │  └─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘   │
14//! └─────────────────────────────────────────────────────────────────┘
15//!                               │
16//! ┌─────────────────────────────────────────────────────────────────┐
17//! │                     MULTI-TIER CACHE                             │
18//! │  ┌───────────────────────────────────────────────────────────┐  │
19//! │  │ L1: Hot Cache (In-Memory, <100μs)                         │  │
20//! │  └───────────────────────────────────────────────────────────┘  │
21//! │                          │ miss                                  │
22//! │  ┌───────────────────────────────────────────────────────────┐  │
23//! │  │ L2: Warm Cache (Local SSD, <1ms)                          │  │
24//! │  └───────────────────────────────────────────────────────────┘  │
25//! │                          │ miss                                  │
26//! │  ┌───────────────────────────────────────────────────────────┐  │
27//! │  │ L3: Distributed Cache (Mesh, <10ms)                       │  │
28//! │  └───────────────────────────────────────────────────────────┘  │
29//! └─────────────────────────────────────────────────────────────────┘
30//! ```
31//!
32//! # Features
33//!
34//! - **Multi-Tier Caching**: L1 hot (memory), L2 warm (SSD), L3 distributed (mesh)
35//! - **Workload Classification**: OLTP, OLAP, Vector, AI Agent, RAG pipelines
36//! - **Intelligent Prefetching**: Pattern-based and temporal prediction
37//! - **WAL-Based Invalidation**: Real-time cache coherency via WAL streaming
38//! - **Heatmap Analytics**: Visual cache utilization and recommendations
39//! - **AI/Agent Caches**: Conversation context, RAG chunks, tool results, semantic queries
40
41pub mod config;
42pub mod classifier;
43pub mod tiers;
44pub mod prefetcher;
45pub mod invalidator;
46pub mod heatmap;
47pub mod scheduler;
48pub mod ai;
49pub mod metrics;
50
51pub use config::*;
52pub use classifier::*;
53pub use tiers::{
54    HotCache, WarmCache, DistributedCache,
55    CacheEntry, CacheKey, CacheTier, TierStats,
56    EvictionPolicy, CompressionType,
57};
58pub use prefetcher::*;
59pub use invalidator::*;
60pub use heatmap::*;
61pub use scheduler::*;
62pub use ai::{
63    ConversationContextCache, ConversationContext, Turn, ConversationCacheStats,
64    RagChunkCache, Chunk, ChunkId, RagCacheStatsSnapshot,
65    ToolResultCache, ToolCallKey, ToolResult, ToolCacheStatsSnapshot,
66    SemanticQueryCache, SemanticEntry, SemanticCacheStatsSnapshot, cosine_similarity,
67    // Branch-aware and session-aware types (SessionId is defined locally as newtype)
68    BranchContext, BranchId, AIWorkloadContext, VectorId, Embedding,
69    SemanticIndex, SemanticIndexConfig, SimilarityResult,
70    // Cross-feature AI integration
71    AIIntegrationCoordinator, AIIntegrationConfig, AIIntegrationStatsSnapshot,
72    AIWorkloadDetection, SessionTrackingInfo, CacheRecommendation,
73    CachePriority, RecommendedTier,
74};
75pub use metrics::{DistribCacheMetrics, InvalidationSource, ErrorType};
76
77use std::sync::atomic::{AtomicU64, Ordering};
78use std::sync::Arc;
79use std::time::{Duration, Instant};
80use thiserror::Error;
81
82/// Cache errors
83#[derive(Debug, Error)]
84pub enum CacheError {
85    #[error("Cache miss")]
86    Miss,
87
88    #[error("Entry expired")]
89    Expired,
90
91    #[error("Entry too large: {0} bytes (max: {1})")]
92    TooLarge(usize, usize),
93
94    #[error("Tier unavailable: {0}")]
95    TierUnavailable(String),
96
97    #[error("Peer not found: {0}")]
98    PeerNotFound(String),
99
100    #[error("Serialization error: {0}")]
101    Serialization(String),
102
103    #[error("Compression error: {0}")]
104    Compression(String),
105
106    #[error("Storage error: {0}")]
107    Storage(String),
108
109    #[error("Network error: {0}")]
110    Network(String),
111
112    #[error("Invalidation error: {0}")]
113    Invalidation(String),
114
115    #[error("Configuration error: {0}")]
116    Configuration(String),
117
118    #[error("Connection error: {0}")]
119    ConnectionError(String),
120}
121
122pub type CacheResult<T> = std::result::Result<T, CacheError>;
123
124/// Query fingerprint for cache key generation
125#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
126pub struct QueryFingerprint {
127    /// Normalized query template
128    pub template: String,
129    /// Table names referenced
130    pub tables: Vec<String>,
131    /// Parameter hash (for parameterized queries)
132    pub param_hash: Option<u64>,
133}
134
135impl QueryFingerprint {
136    /// Create a new fingerprint from a query
137    pub fn from_query(query: &str) -> Self {
138        let template = Self::normalize_query(query);
139        let tables = Self::extract_tables(&template);
140        let param_hash = None; // Set separately if parameterized
141
142        Self {
143            template,
144            tables,
145            param_hash,
146        }
147    }
148
149    /// Create fingerprint with parameter binding
150    pub fn with_params(mut self, params: &[&str]) -> Self {
151        use std::hash::{Hash, Hasher};
152        use std::collections::hash_map::DefaultHasher;
153
154        let mut hasher = DefaultHasher::new();
155        for param in params {
156            param.hash(&mut hasher);
157        }
158        self.param_hash = Some(hasher.finish());
159        self
160    }
161
162    /// Normalize query by removing literals and whitespace
163    fn normalize_query(query: &str) -> String {
164        let upper = query.to_uppercase();
165        // Simple normalization - replace string literals and numbers
166        let mut result = String::new();
167        let mut in_string = false;
168        let mut quote_char = ' ';
169
170        for ch in upper.chars() {
171            if in_string {
172                if ch == quote_char {
173                    in_string = false;
174                    result.push('?');
175                }
176            } else if ch == '\'' || ch == '"' {
177                in_string = true;
178                quote_char = ch;
179            } else if ch.is_numeric() {
180                if !result.ends_with('?') {
181                    result.push('?');
182                }
183            } else if ch.is_whitespace() {
184                if !result.ends_with(' ') {
185                    result.push(' ');
186                }
187            } else {
188                result.push(ch);
189            }
190        }
191
192        result.trim().to_string()
193    }
194
195    /// Extract table names from query
196    fn extract_tables(query: &str) -> Vec<String> {
197        let mut tables = Vec::new();
198        let words: Vec<&str> = query.split_whitespace().collect();
199
200        for (i, word) in words.iter().enumerate() {
201            if *word == "FROM" || *word == "JOIN" || *word == "INTO" || *word == "UPDATE" {
202                if let Some(table) = words.get(i + 1) {
203                    let table_name = table.trim_matches(|c| c == '(' || c == ')' || c == ',');
204                    if !table_name.is_empty() && !tables.contains(&table_name.to_string()) {
205                        tables.push(table_name.to_string());
206                    }
207                }
208            }
209        }
210
211        tables
212    }
213
214    /// Convert to bytes for storage
215    pub fn to_bytes(&self) -> Vec<u8> {
216        let mut bytes = Vec::new();
217        bytes.extend_from_slice(self.template.as_bytes());
218        if let Some(hash) = self.param_hash {
219            bytes.extend_from_slice(&hash.to_le_bytes());
220        }
221        bytes
222    }
223}
224
/// Opaque identifier of a client session, used for session affinity.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SessionId(pub String);

impl SessionId {
    /// Wrap any string-like value as a `SessionId`.
    pub fn new(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        SessionId(owned)
    }
}
234
235/// Query context for cache decisions
236#[derive(Debug, Clone)]
237pub struct QueryContext {
238    /// Session identifier
239    pub session_id: SessionId,
240    /// Workload hint (if explicitly specified)
241    pub workload_hint: Option<WorkloadType>,
242    /// Branch name (for branch-aware caching)
243    pub branch: Option<String>,
244    /// Time travel timestamp (for historical queries)
245    pub as_of: Option<u64>,
246    /// Whether this is a prepared statement
247    pub is_prepared: bool,
248    /// Request timestamp
249    pub timestamp: Instant,
250}
251
252impl QueryContext {
253    pub fn new(session_id: impl Into<String>) -> Self {
254        Self {
255            session_id: SessionId::new(session_id),
256            workload_hint: None,
257            branch: None,
258            as_of: None,
259            is_prepared: false,
260            timestamp: Instant::now(),
261        }
262    }
263
264    pub fn with_workload_hint(mut self, hint: WorkloadType) -> Self {
265        self.workload_hint = Some(hint);
266        self
267    }
268
269    pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
270        self.branch = Some(branch.into());
271        self
272    }
273
274    pub fn with_as_of(mut self, timestamp: u64) -> Self {
275        self.as_of = Some(timestamp);
276        self
277    }
278}
279
280/// Helios-DistribCache - Main distributed cache instance
281pub struct HeliosDistribCache {
282    /// Workload classifier
283    classifier: WorkloadClassifier,
284
285    /// L1 hot cache (in-memory)
286    l1_hot: Arc<HotCache>,
287
288    /// L2 warm cache (SSD)
289    l2_warm: Arc<WarmCache>,
290
291    /// L3 distributed cache (mesh)
292    l3_distributed: Arc<DistributedCache>,
293
294    /// Predictive prefetcher
295    prefetcher: Arc<PredictivePrefetcher>,
296
297    /// WAL-based invalidator
298    invalidator: Arc<WalInvalidator>,
299
300    /// Cache heatmap analytics
301    heatmap: Arc<CacheHeatmap>,
302
303    /// Workload scheduler
304    scheduler: Arc<WorkloadScheduler>,
305
306    /// AI/Agent caches
307    conversation_cache: Arc<ConversationContextCache>,
308    rag_cache: Arc<RagChunkCache>,
309    tool_cache: Arc<ToolResultCache>,
310    semantic_cache: Arc<SemanticQueryCache>,
311
312    /// Metrics
313    metrics: Arc<DistribCacheMetrics>,
314
315    /// Configuration
316    config: DistribCacheConfig,
317
318    /// Statistics
319    stats: CacheStatistics,
320}
321
322/// Cache statistics
323#[derive(Debug, Default)]
324struct CacheStatistics {
325    total_lookups: AtomicU64,
326    l1_hits: AtomicU64,
327    l2_hits: AtomicU64,
328    l3_hits: AtomicU64,
329    total_misses: AtomicU64,
330    time_saved_us: AtomicU64,
331    queries_avoided: AtomicU64,
332}
333
334impl HeliosDistribCache {
335    /// Create a new distributed cache instance
336    pub fn new(config: DistribCacheConfig) -> Self {
337        let l1_hot = Arc::new(HotCache::new(
338            config.l1_size_mb * 1024 * 1024,
339            config.l1_max_entry_size,
340            config.l1_eviction_policy,
341        ));
342
343        let l2_warm = Arc::new(WarmCache::new(
344            config.l2_size_gb * 1024 * 1024 * 1024,
345            config.l2_path.clone(),
346            config.l2_compression,
347        ));
348
349        let l3_distributed = Arc::new(DistributedCache::new(
350            config.l3_replication_factor,
351            config.l3_peers.clone(),
352        ));
353
354        let classifier = WorkloadClassifier::new(config.clone());
355        let prefetcher = Arc::new(PredictivePrefetcher::new(config.clone()));
356        let invalidator = Arc::new(WalInvalidator::new(config.clone()));
357        let heatmap = Arc::new(CacheHeatmap::new());
358        let scheduler = Arc::new(WorkloadScheduler::new(config.clone()));
359        let metrics = Arc::new(DistribCacheMetrics::new());
360
361        // AI caches
362        let conversation_cache = Arc::new(ConversationContextCache::new(1000, 50));
363        let rag_cache = Arc::new(RagChunkCache::new(config.l1_size_mb / 4));
364        let tool_cache = Arc::new(ToolResultCache::new());
365        let semantic_cache = Arc::new(SemanticQueryCache::new(0.85));
366
367        Self {
368            classifier,
369            l1_hot,
370            l2_warm,
371            l3_distributed,
372            prefetcher,
373            invalidator,
374            heatmap,
375            scheduler,
376            conversation_cache,
377            rag_cache,
378            tool_cache,
379            semantic_cache,
380            metrics,
381            config,
382            stats: CacheStatistics::default(),
383        }
384    }
385
386    /// Get an entry from the cache (checking all tiers)
387    pub async fn get(
388        &self,
389        fingerprint: &QueryFingerprint,
390        context: &QueryContext,
391    ) -> CacheResult<CacheEntry> {
392        self.stats.total_lookups.fetch_add(1, Ordering::Relaxed);
393        let start = Instant::now();
394
395        // Classify workload for scheduling
396        let _workload = self.classifier.classify_query(&fingerprint.template, context);
397
398        // Check L1 hot cache first
399        if let Some(entry) = self.l1_hot.get(fingerprint, context.session_id.clone()) {
400            self.stats.l1_hits.fetch_add(1, Ordering::Relaxed);
401            self.record_hit(fingerprint, CacheTier::L1, start.elapsed());
402            return Ok(entry);
403        }
404
405        // Check L2 warm cache
406        if self.config.l2_enabled {
407            if let Some(entry) = self.l2_warm.get(fingerprint) {
408                self.stats.l2_hits.fetch_add(1, Ordering::Relaxed);
409                // Promote to L1
410                self.l1_hot.insert(
411                    fingerprint.clone(),
412                    entry.clone(),
413                    Some(context.session_id.clone()),
414                );
415                self.record_hit(fingerprint, CacheTier::L2, start.elapsed());
416                return Ok(entry);
417            }
418        }
419
420        // Check L3 distributed cache
421        if self.config.l3_enabled {
422            if let Some(entry) = self.l3_distributed.get(fingerprint).await {
423                self.stats.l3_hits.fetch_add(1, Ordering::Relaxed);
424                // Promote to L1 and L2
425                self.l1_hot.insert(
426                    fingerprint.clone(),
427                    entry.clone(),
428                    Some(context.session_id.clone()),
429                );
430                if self.config.l2_enabled {
431                    self.l2_warm.insert(fingerprint.clone(), entry.clone());
432                }
433                self.record_hit(fingerprint, CacheTier::L3, start.elapsed());
434                return Ok(entry);
435            }
436        }
437
438        // Cache miss
439        self.stats.total_misses.fetch_add(1, Ordering::Relaxed);
440        self.heatmap.record_access(fingerprint, false, Duration::ZERO);
441
442        // Trigger prefetching for related queries
443        if self.config.prefetch_enabled {
444            self.prefetcher.predict_and_prefetch(fingerprint, &context.session_id);
445        }
446
447        Err(CacheError::Miss)
448    }
449
450    /// Insert an entry into the cache
451    pub async fn insert(
452        &self,
453        fingerprint: QueryFingerprint,
454        entry: CacheEntry,
455        context: &QueryContext,
456    ) -> CacheResult<()> {
457        let workload = self.classifier.classify_query(&fingerprint.template, context);
458        let ttl = self.get_ttl_for_workload(workload);
459
460        let entry = entry.with_ttl(ttl);
461
462        // Insert into L1
463        self.l1_hot.insert(
464            fingerprint.clone(),
465            entry.clone(),
466            Some(context.session_id.clone()),
467        );
468
469        // Insert into L2 if entry is large enough and TTL warrants it
470        if self.config.l2_enabled && entry.size() > 1024 && ttl > Duration::from_secs(60) {
471            self.l2_warm.insert(fingerprint.clone(), entry.clone());
472        }
473
474        // Insert into L3 for shared caching
475        if self.config.l3_enabled && !matches!(workload, WorkloadType::OLTP) {
476            self.l3_distributed.insert(fingerprint.clone(), entry.clone()).await;
477        }
478
479        // Record for prefetcher learning
480        if self.config.prefetch_enabled {
481            self.prefetcher.record(&context.session_id, fingerprint);
482        }
483
484        Ok(())
485    }
486
487    /// Invalidate entries for a table
488    pub fn invalidate_table(&self, table: &str) {
489        self.l1_hot.invalidate_by_table(table);
490        if self.config.l2_enabled {
491            self.l2_warm.invalidate_by_table(table);
492        }
493        // L3 invalidation is handled by gossip protocol
494    }
495
496    /// Invalidate a specific entry
497    pub fn invalidate(&self, fingerprint: &QueryFingerprint) {
498        self.l1_hot.invalidate(fingerprint);
499        if self.config.l2_enabled {
500            self.l2_warm.invalidate(fingerprint);
501        }
502    }
503
504    /// Get TTL based on workload type
505    fn get_ttl_for_workload(&self, workload: WorkloadType) -> Duration {
506        match workload {
507            WorkloadType::OLTP => self.config.oltp_cache_ttl,
508            WorkloadType::OLAP => self.config.olap_cache_ttl,
509            WorkloadType::Vector => self.config.vector_cache_ttl,
510            WorkloadType::AIAgent => self.config.ai_agent_cache_ttl,
511            WorkloadType::RAG => self.config.rag_cache_ttl,
512            WorkloadType::Mixed => self.config.default_cache_ttl,
513        }
514    }
515
516    /// Record cache hit for metrics and heatmap
517    fn record_hit(&self, fingerprint: &QueryFingerprint, tier: CacheTier, _latency: Duration) {
518        let time_saved = match tier {
519            CacheTier::L1 => Duration::from_millis(10), // Assume 10ms DB query
520            CacheTier::L2 => Duration::from_millis(9),
521            CacheTier::L3 => Duration::from_millis(5),
522        };
523
524        self.stats.time_saved_us.fetch_add(
525            time_saved.as_micros() as u64,
526            Ordering::Relaxed,
527        );
528        self.stats.queries_avoided.fetch_add(1, Ordering::Relaxed);
529        self.heatmap.record_access(fingerprint, true, time_saved);
530    }
531
532    /// Get conversation context cache
533    pub fn conversation_cache(&self) -> &ConversationContextCache {
534        &self.conversation_cache
535    }
536
537    /// Get RAG chunk cache
538    pub fn rag_cache(&self) -> &RagChunkCache {
539        &self.rag_cache
540    }
541
542    /// Get tool result cache
543    pub fn tool_cache(&self) -> &ToolResultCache {
544        &self.tool_cache
545    }
546
547    /// Get semantic query cache
548    pub fn semantic_cache(&self) -> &SemanticQueryCache {
549        &self.semantic_cache
550    }
551
552    /// Get cache statistics
553    pub fn stats(&self) -> DistribCacheStats {
554        let total = self.stats.total_lookups.load(Ordering::Relaxed);
555        let l1_hits = self.stats.l1_hits.load(Ordering::Relaxed);
556        let l2_hits = self.stats.l2_hits.load(Ordering::Relaxed);
557        let l3_hits = self.stats.l3_hits.load(Ordering::Relaxed);
558        let _misses = self.stats.total_misses.load(Ordering::Relaxed);
559
560        DistribCacheStats {
561            l1: self.l1_hot.stats(),
562            l2: self.l2_warm.stats(),
563            l3: self.l3_distributed.stats(),
564            overall_hit_ratio: if total > 0 {
565                (l1_hits + l2_hits + l3_hits) as f64 / total as f64
566            } else {
567                0.0
568            },
569            time_saved_seconds: self.stats.time_saved_us.load(Ordering::Relaxed) as f64 / 1_000_000.0,
570            queries_avoided: self.stats.queries_avoided.load(Ordering::Relaxed),
571        }
572    }
573
574    /// Generate heatmap data
575    pub fn heatmap(&self) -> HeatmapData {
576        self.heatmap.generate_heatmap()
577    }
578
579    /// Get workload distribution
580    pub fn workload_distribution(&self) -> WorkloadDistribution {
581        self.scheduler.get_distribution()
582    }
583
584    /// Start background services (prefetcher, invalidator)
585    pub async fn start(&self) -> CacheResult<()> {
586        // Start WAL invalidator if configured
587        if let Some(wal_endpoint) = &self.config.wal_endpoint {
588            self.invalidator.start(wal_endpoint).await?;
589        }
590
591        // Start prefetcher background worker
592        if self.config.prefetch_enabled {
593            self.prefetcher.start().await;
594        }
595
596        Ok(())
597    }
598
599    /// Stop background services
600    pub async fn stop(&self) -> CacheResult<()> {
601        self.invalidator.stop().await;
602        self.prefetcher.stop().await;
603        Ok(())
604    }
605}
606
607/// Cache statistics snapshot
608#[derive(Debug, Clone)]
609pub struct DistribCacheStats {
610    /// L1 tier stats
611    pub l1: TierStats,
612    /// L2 tier stats
613    pub l2: TierStats,
614    /// L3 tier stats
615    pub l3: TierStats,
616    /// Overall hit ratio
617    pub overall_hit_ratio: f64,
618    /// Total time saved in seconds
619    pub time_saved_seconds: f64,
620    /// Total queries avoided
621    pub queries_avoided: u64,
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_fingerprint() {
        let fp = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 42");
        assert!(fp.template.contains("SELECT"));
        assert!(fp.template.contains("USERS"));
        assert!(fp.tables.contains(&"USERS".to_string()));
    }

    #[test]
    fn test_query_fingerprint_normalization() {
        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 1");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 2");
        // Both should have the same template after normalization.
        assert_eq!(fp1.template, fp2.template);
    }

    #[test]
    fn test_fingerprint_collapses_whitespace_and_strings() {
        let fp = QueryFingerprint::from_query("SELECT  *\n  FROM users\tWHERE name = 'bob'");
        assert_eq!(fp.template, "SELECT * FROM USERS WHERE NAME = ?");
    }

    #[test]
    fn test_extract_tables_deduplicates() {
        let fp = QueryFingerprint::from_query(
            "SELECT * FROM orders JOIN customers ON orders.cid = customers.id JOIN orders",
        );
        assert_eq!(fp.tables, vec!["ORDERS".to_string(), "CUSTOMERS".to_string()]);
    }

    #[test]
    fn test_with_params_distinguishes_values() {
        let base = QueryFingerprint::from_query("SELECT * FROM users WHERE id = $1");
        let a = base.clone().with_params(&["1"]);
        let b = base.clone().with_params(&["2"]);
        assert_eq!(a.template, b.template);
        assert_ne!(a.param_hash, b.param_hash);
        assert_eq!(a, base.with_params(&["1"]));
    }

    #[test]
    fn test_to_bytes_appends_param_hash() {
        let fp = QueryFingerprint::from_query("SELECT * FROM users");
        assert_eq!(fp.param_hash, None);
        assert_eq!(fp.to_bytes(), fp.template.as_bytes());
        let with = fp.clone().with_params(&["x"]);
        assert_eq!(with.to_bytes().len(), fp.template.len() + 8);
    }

    #[test]
    fn test_session_id() {
        let sid = SessionId::new("test-session");
        assert_eq!(sid.0, "test-session");
    }

    #[test]
    fn test_query_context() {
        let ctx = QueryContext::new("session-1")
            .with_workload_hint(WorkloadType::OLTP)
            .with_branch("feature-x");

        assert_eq!(ctx.session_id.0, "session-1");
        assert_eq!(ctx.workload_hint, Some(WorkloadType::OLTP));
        assert_eq!(ctx.branch, Some("feature-x".to_string()));
    }

    #[test]
    fn test_query_context_defaults_and_as_of() {
        let ctx = QueryContext::new("s").with_as_of(1_700_000_000);
        assert_eq!(ctx.as_of, Some(1_700_000_000));
        assert!(ctx.workload_hint.is_none());
        assert!(ctx.branch.is_none());
        assert!(!ctx.is_prepared);
    }
}