Skip to main content

heliosdb_proxy/distribcache/
mod.rs

1//! Helios-DistribCache - Intelligent Distributed Caching Layer
2//!
3//! A multi-tier distributed caching system with workload-aware strategies,
4//! intelligent prefetching, and AI/Agent optimizations.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────────────────────────────┐
10//! │                     WORKLOAD CLASSIFIER                          │
11//! │  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐   │
12//! │  │  OLTP   │ │  OLAP   │ │ Vector  │ │AIAgent  │ │   RAG   │   │
13//! │  └─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘   │
14//! └─────────────────────────────────────────────────────────────────┘
15//!                               │
16//! ┌─────────────────────────────────────────────────────────────────┐
17//! │                     MULTI-TIER CACHE                             │
18//! │  ┌───────────────────────────────────────────────────────────┐  │
19//! │  │ L1: Hot Cache (In-Memory, <100μs)                         │  │
20//! │  └───────────────────────────────────────────────────────────┘  │
21//! │                          │ miss                                  │
22//! │  ┌───────────────────────────────────────────────────────────┐  │
23//! │  │ L2: Warm Cache (Local SSD, <1ms)                          │  │
24//! │  └───────────────────────────────────────────────────────────┘  │
25//! │                          │ miss                                  │
26//! │  ┌───────────────────────────────────────────────────────────┐  │
27//! │  │ L3: Distributed Cache (Mesh, <10ms)                       │  │
28//! │  └───────────────────────────────────────────────────────────┘  │
29//! └─────────────────────────────────────────────────────────────────┘
30//! ```
31//!
32//! # Features
33//!
34//! - **Multi-Tier Caching**: L1 hot (memory), L2 warm (SSD), L3 distributed (mesh)
35//! - **Workload Classification**: OLTP, OLAP, Vector, AI Agent, RAG pipelines
36//! - **Intelligent Prefetching**: Pattern-based and temporal prediction
37//! - **WAL-Based Invalidation**: Real-time cache coherency via WAL streaming
38//! - **Heatmap Analytics**: Visual cache utilization and recommendations
39//! - **AI/Agent Caches**: Conversation context, RAG chunks, tool results, semantic queries
40
41pub mod ai;
42pub mod classifier;
43pub mod config;
44pub mod heatmap;
45pub mod invalidator;
46pub mod metrics;
47pub mod prefetcher;
48pub mod scheduler;
49pub mod tiers;
50
51pub use ai::{
52    cosine_similarity,
53    AIIntegrationConfig,
54    // Cross-feature AI integration
55    AIIntegrationCoordinator,
56    AIIntegrationStatsSnapshot,
57    AIWorkloadContext,
58    AIWorkloadDetection,
59    // Branch-aware and session-aware types (SessionId is defined locally as newtype)
60    BranchContext,
61    BranchId,
62    CachePriority,
63    CacheRecommendation,
64    Chunk,
65    ChunkId,
66    ConversationCacheStats,
67    ConversationContext,
68    ConversationContextCache,
69    Embedding,
70    RagCacheStatsSnapshot,
71    RagChunkCache,
72    RecommendedTier,
73    SemanticCacheStatsSnapshot,
74    SemanticEntry,
75    SemanticIndex,
76    SemanticIndexConfig,
77    SemanticQueryCache,
78    SessionTrackingInfo,
79    SimilarityResult,
80    ToolCacheStatsSnapshot,
81    ToolCallKey,
82    ToolResult,
83    ToolResultCache,
84    Turn,
85    VectorId,
86};
87pub use classifier::*;
88pub use config::*;
89pub use heatmap::*;
90pub use invalidator::*;
91pub use metrics::{DistribCacheMetrics, ErrorType, InvalidationSource};
92pub use prefetcher::*;
93pub use scheduler::*;
94pub use tiers::{
95    CacheEntry, CacheKey, CacheTier, CompressionType, DistributedCache, EvictionPolicy, HotCache,
96    TierStats, WarmCache,
97};
98
99use std::sync::atomic::{AtomicU64, Ordering};
100use std::sync::Arc;
101use std::time::{Duration, Instant};
102use thiserror::Error;
103
/// Cache errors
///
/// Error type carried by [`CacheResult`]. The `Display` text of each variant is
/// generated by `thiserror` from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum CacheError {
    /// No entry was found. `HeliosDistribCache::get` returns this after L1 and
    /// every enabled lower tier failed to produce an entry.
    #[error("Cache miss")]
    Miss,

    /// The entry has expired.
    /// NOTE(review): never constructed in this file — presumably raised by the
    /// tier implementations; confirm.
    #[error("Entry expired")]
    Expired,

    /// The entry exceeds a size limit. Per the message format, field 0 is the
    /// entry size in bytes and field 1 is the maximum allowed.
    #[error("Entry too large: {0} bytes (max: {1})")]
    TooLarge(usize, usize),

    /// A cache tier could not be used; the string is included in the message.
    #[error("Tier unavailable: {0}")]
    TierUnavailable(String),

    /// A peer could not be found; the string is included in the message.
    #[error("Peer not found: {0}")]
    PeerNotFound(String),

    /// Serialization failure, with a description.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// Compression failure, with a description.
    #[error("Compression error: {0}")]
    Compression(String),

    /// Storage failure, with a description.
    #[error("Storage error: {0}")]
    Storage(String),

    /// Network failure, with a description.
    #[error("Network error: {0}")]
    Network(String),

    /// Invalidation failure, with a description.
    #[error("Invalidation error: {0}")]
    Invalidation(String),

    /// Configuration problem, with a description.
    #[error("Configuration error: {0}")]
    Configuration(String),

    /// Connection failure, with a description.
    #[error("Connection error: {0}")]
    ConnectionError(String),
}
143
/// Result alias used throughout the cache layer; the error type is always [`CacheError`].
pub type CacheResult<T> = std::result::Result<T, CacheError>;
145
/// Query fingerprint for cache key generation
///
/// Built with `QueryFingerprint::from_query`. `Hash`, `PartialEq` and `Eq` are
/// derived, so all three fields take part in comparisons and hashing.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct QueryFingerprint {
    /// Normalized query template (produced by `normalize_query`: upper-cased,
    /// literals replaced by `?`, whitespace runs collapsed)
    pub template: String,
    /// Table names referenced, extracted from `template` in first-seen order
    /// without duplicates
    pub tables: Vec<String>,
    /// Parameter hash (for parameterized queries); `None` until `with_params`
    /// is called
    pub param_hash: Option<u64>,
}
156
157impl QueryFingerprint {
158    /// Create a new fingerprint from a query
159    pub fn from_query(query: &str) -> Self {
160        let template = Self::normalize_query(query);
161        let tables = Self::extract_tables(&template);
162        let param_hash = None; // Set separately if parameterized
163
164        Self {
165            template,
166            tables,
167            param_hash,
168        }
169    }
170
171    /// Create fingerprint with parameter binding
172    pub fn with_params(mut self, params: &[&str]) -> Self {
173        use std::collections::hash_map::DefaultHasher;
174        use std::hash::{Hash, Hasher};
175
176        let mut hasher = DefaultHasher::new();
177        for param in params {
178            param.hash(&mut hasher);
179        }
180        self.param_hash = Some(hasher.finish());
181        self
182    }
183
184    /// Normalize query by removing literals and whitespace
185    fn normalize_query(query: &str) -> String {
186        let upper = query.to_uppercase();
187        // Simple normalization - replace string literals and numbers
188        let mut result = String::new();
189        let mut in_string = false;
190        let mut quote_char = ' ';
191
192        for ch in upper.chars() {
193            if in_string {
194                if ch == quote_char {
195                    in_string = false;
196                    result.push('?');
197                }
198            } else if ch == '\'' || ch == '"' {
199                in_string = true;
200                quote_char = ch;
201            } else if ch.is_numeric() {
202                if !result.ends_with('?') {
203                    result.push('?');
204                }
205            } else if ch.is_whitespace() {
206                if !result.ends_with(' ') {
207                    result.push(' ');
208                }
209            } else {
210                result.push(ch);
211            }
212        }
213
214        result.trim().to_string()
215    }
216
217    /// Extract table names from query
218    fn extract_tables(query: &str) -> Vec<String> {
219        let mut tables = Vec::new();
220        let words: Vec<&str> = query.split_whitespace().collect();
221
222        for (i, word) in words.iter().enumerate() {
223            if *word == "FROM" || *word == "JOIN" || *word == "INTO" || *word == "UPDATE" {
224                if let Some(table) = words.get(i + 1) {
225                    let table_name = table.trim_matches(|c| c == '(' || c == ')' || c == ',');
226                    if !table_name.is_empty() && !tables.contains(&table_name.to_string()) {
227                        tables.push(table_name.to_string());
228                    }
229                }
230            }
231        }
232
233        tables
234    }
235
236    /// Convert to bytes for storage
237    pub fn to_bytes(&self) -> Vec<u8> {
238        let mut bytes = Vec::new();
239        bytes.extend_from_slice(self.template.as_bytes());
240        if let Some(hash) = self.param_hash {
241            bytes.extend_from_slice(&hash.to_le_bytes());
242        }
243        bytes
244    }
245}
246
/// Identifier of a client session, used for session affinity.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SessionId(pub String);

impl SessionId {
    /// Wrap anything convertible into a `String` as a session identifier.
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        SessionId(raw)
    }
}
256
257/// Query context for cache decisions
258#[derive(Debug, Clone)]
259pub struct QueryContext {
260    /// Session identifier
261    pub session_id: SessionId,
262    /// Workload hint (if explicitly specified)
263    pub workload_hint: Option<WorkloadType>,
264    /// Branch name (for branch-aware caching)
265    pub branch: Option<String>,
266    /// Time travel timestamp (for historical queries)
267    pub as_of: Option<u64>,
268    /// Whether this is a prepared statement
269    pub is_prepared: bool,
270    /// Request timestamp
271    pub timestamp: Instant,
272}
273
274impl QueryContext {
275    pub fn new(session_id: impl Into<String>) -> Self {
276        Self {
277            session_id: SessionId::new(session_id),
278            workload_hint: None,
279            branch: None,
280            as_of: None,
281            is_prepared: false,
282            timestamp: Instant::now(),
283        }
284    }
285
286    pub fn with_workload_hint(mut self, hint: WorkloadType) -> Self {
287        self.workload_hint = Some(hint);
288        self
289    }
290
291    pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
292        self.branch = Some(branch.into());
293        self
294    }
295
296    pub fn with_as_of(mut self, timestamp: u64) -> Self {
297        self.as_of = Some(timestamp);
298        self
299    }
300}
301
/// Helios-DistribCache - Main distributed cache instance
///
/// Owns the three cache tiers, the supporting services (classifier,
/// prefetcher, invalidator, heatmap, scheduler), the AI/Agent caches and the
/// lookup counters. All fields are private; access goes through the methods
/// of the `impl` block.
pub struct HeliosDistribCache {
    /// Workload classifier
    classifier: WorkloadClassifier,

    /// L1 hot cache (in-memory)
    l1_hot: Arc<HotCache>,

    /// L2 warm cache (SSD)
    l2_warm: Arc<WarmCache>,

    /// L3 distributed cache (mesh)
    l3_distributed: Arc<DistributedCache>,

    /// Predictive prefetcher
    prefetcher: Arc<PredictivePrefetcher>,

    /// WAL-based invalidator
    invalidator: Arc<WalInvalidator>,

    /// Cache heatmap analytics
    heatmap: Arc<CacheHeatmap>,

    /// Workload scheduler
    scheduler: Arc<WorkloadScheduler>,

    /// AI/Agent caches (each exposed through an accessor of the same name)
    conversation_cache: Arc<ConversationContextCache>,
    rag_cache: Arc<RagChunkCache>,
    tool_cache: Arc<ToolResultCache>,
    semantic_cache: Arc<SemanticQueryCache>,

    /// Metrics
    // Constructed in `new` but not read anywhere in this file, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    metrics: Arc<DistribCacheMetrics>,

    /// Configuration
    config: DistribCacheConfig,

    /// Statistics
    stats: CacheStatistics,
}
344
/// Cache statistics
///
/// Internal atomic counters updated by `HeliosDistribCache::get` and
/// `record_hit`; a derived view is exposed through `HeliosDistribCache::stats`.
#[derive(Debug, Default)]
struct CacheStatistics {
    /// Incremented once at the start of every `get` call.
    total_lookups: AtomicU64,
    /// Lookups answered by the L1 tier.
    l1_hits: AtomicU64,
    /// Lookups answered by the L2 tier.
    l2_hits: AtomicU64,
    /// Lookups answered by the L3 tier.
    l3_hits: AtomicU64,
    /// Lookups that no tier could answer.
    total_misses: AtomicU64,
    /// Accumulated estimated time saved, in microseconds (added in `record_hit`).
    time_saved_us: AtomicU64,
    /// Incremented once per recorded hit in `record_hit`.
    queries_avoided: AtomicU64,
}
356
357impl HeliosDistribCache {
358    /// Create a new distributed cache instance
359    pub fn new(config: DistribCacheConfig) -> Self {
360        let l1_hot = Arc::new(HotCache::new(
361            config.l1_size_mb * 1024 * 1024,
362            config.l1_max_entry_size,
363            config.l1_eviction_policy,
364        ));
365
366        let l2_warm = Arc::new(WarmCache::new(
367            config.l2_size_gb * 1024 * 1024 * 1024,
368            config.l2_path.clone(),
369            config.l2_compression,
370        ));
371
372        let l3_distributed = Arc::new(DistributedCache::new(
373            config.l3_replication_factor,
374            config.l3_peers.clone(),
375        ));
376
377        let classifier = WorkloadClassifier::new(config.clone());
378        let prefetcher = Arc::new(PredictivePrefetcher::new(config.clone()));
379        let invalidator = Arc::new(WalInvalidator::new(config.clone()));
380        let heatmap = Arc::new(CacheHeatmap::new());
381        let scheduler = Arc::new(WorkloadScheduler::new(config.clone()));
382        let metrics = Arc::new(DistribCacheMetrics::new());
383
384        // AI caches
385        let conversation_cache = Arc::new(ConversationContextCache::new(1000, 50));
386        let rag_cache = Arc::new(RagChunkCache::new(config.l1_size_mb / 4));
387        let tool_cache = Arc::new(ToolResultCache::new());
388        let semantic_cache = Arc::new(SemanticQueryCache::new(0.85));
389
390        Self {
391            classifier,
392            l1_hot,
393            l2_warm,
394            l3_distributed,
395            prefetcher,
396            invalidator,
397            heatmap,
398            scheduler,
399            conversation_cache,
400            rag_cache,
401            tool_cache,
402            semantic_cache,
403            metrics,
404            config,
405            stats: CacheStatistics::default(),
406        }
407    }
408
409    /// Get an entry from the cache (checking all tiers)
410    pub async fn get(
411        &self,
412        fingerprint: &QueryFingerprint,
413        context: &QueryContext,
414    ) -> CacheResult<CacheEntry> {
415        self.stats.total_lookups.fetch_add(1, Ordering::Relaxed);
416        let start = Instant::now();
417
418        // Classify workload for scheduling
419        let _workload = self
420            .classifier
421            .classify_query(&fingerprint.template, context);
422
423        // Check L1 hot cache first
424        if let Some(entry) = self.l1_hot.get(fingerprint, context.session_id.clone()) {
425            self.stats.l1_hits.fetch_add(1, Ordering::Relaxed);
426            self.record_hit(fingerprint, CacheTier::L1, start.elapsed());
427            return Ok(entry);
428        }
429
430        // Check L2 warm cache
431        if self.config.l2_enabled {
432            if let Some(entry) = self.l2_warm.get(fingerprint) {
433                self.stats.l2_hits.fetch_add(1, Ordering::Relaxed);
434                // Promote to L1
435                self.l1_hot.insert(
436                    fingerprint.clone(),
437                    entry.clone(),
438                    Some(context.session_id.clone()),
439                );
440                self.record_hit(fingerprint, CacheTier::L2, start.elapsed());
441                return Ok(entry);
442            }
443        }
444
445        // Check L3 distributed cache
446        if self.config.l3_enabled {
447            if let Some(entry) = self.l3_distributed.get(fingerprint).await {
448                self.stats.l3_hits.fetch_add(1, Ordering::Relaxed);
449                // Promote to L1 and L2
450                self.l1_hot.insert(
451                    fingerprint.clone(),
452                    entry.clone(),
453                    Some(context.session_id.clone()),
454                );
455                if self.config.l2_enabled {
456                    self.l2_warm.insert(fingerprint.clone(), entry.clone());
457                }
458                self.record_hit(fingerprint, CacheTier::L3, start.elapsed());
459                return Ok(entry);
460            }
461        }
462
463        // Cache miss
464        self.stats.total_misses.fetch_add(1, Ordering::Relaxed);
465        self.heatmap
466            .record_access(fingerprint, false, Duration::ZERO);
467
468        // Trigger prefetching for related queries
469        if self.config.prefetch_enabled {
470            self.prefetcher
471                .predict_and_prefetch(fingerprint, &context.session_id);
472        }
473
474        Err(CacheError::Miss)
475    }
476
477    /// Insert an entry into the cache
478    pub async fn insert(
479        &self,
480        fingerprint: QueryFingerprint,
481        entry: CacheEntry,
482        context: &QueryContext,
483    ) -> CacheResult<()> {
484        let workload = self
485            .classifier
486            .classify_query(&fingerprint.template, context);
487        let ttl = self.get_ttl_for_workload(workload);
488
489        let entry = entry.with_ttl(ttl);
490
491        // Insert into L1
492        self.l1_hot.insert(
493            fingerprint.clone(),
494            entry.clone(),
495            Some(context.session_id.clone()),
496        );
497
498        // Insert into L2 if entry is large enough and TTL warrants it
499        if self.config.l2_enabled && entry.size() > 1024 && ttl > Duration::from_secs(60) {
500            self.l2_warm.insert(fingerprint.clone(), entry.clone());
501        }
502
503        // Insert into L3 for shared caching
504        if self.config.l3_enabled && !matches!(workload, WorkloadType::OLTP) {
505            self.l3_distributed
506                .insert(fingerprint.clone(), entry.clone())
507                .await;
508        }
509
510        // Record for prefetcher learning
511        if self.config.prefetch_enabled {
512            self.prefetcher.record(&context.session_id, fingerprint);
513        }
514
515        Ok(())
516    }
517
518    /// Invalidate entries for a table
519    pub fn invalidate_table(&self, table: &str) {
520        self.l1_hot.invalidate_by_table(table);
521        if self.config.l2_enabled {
522            self.l2_warm.invalidate_by_table(table);
523        }
524        // L3 invalidation is handled by gossip protocol
525    }
526
527    /// Invalidate a specific entry
528    pub fn invalidate(&self, fingerprint: &QueryFingerprint) {
529        self.l1_hot.invalidate(fingerprint);
530        if self.config.l2_enabled {
531            self.l2_warm.invalidate(fingerprint);
532        }
533    }
534
535    /// Get TTL based on workload type
536    fn get_ttl_for_workload(&self, workload: WorkloadType) -> Duration {
537        match workload {
538            WorkloadType::OLTP => self.config.oltp_cache_ttl,
539            WorkloadType::OLAP => self.config.olap_cache_ttl,
540            WorkloadType::Vector => self.config.vector_cache_ttl,
541            WorkloadType::AIAgent => self.config.ai_agent_cache_ttl,
542            WorkloadType::RAG => self.config.rag_cache_ttl,
543            WorkloadType::Mixed => self.config.default_cache_ttl,
544        }
545    }
546
547    /// Record cache hit for metrics and heatmap
548    fn record_hit(&self, fingerprint: &QueryFingerprint, tier: CacheTier, _latency: Duration) {
549        let time_saved = match tier {
550            CacheTier::L1 => Duration::from_millis(10), // Assume 10ms DB query
551            CacheTier::L2 => Duration::from_millis(9),
552            CacheTier::L3 => Duration::from_millis(5),
553        };
554
555        self.stats
556            .time_saved_us
557            .fetch_add(time_saved.as_micros() as u64, Ordering::Relaxed);
558        self.stats.queries_avoided.fetch_add(1, Ordering::Relaxed);
559        self.heatmap.record_access(fingerprint, true, time_saved);
560    }
561
562    /// Get conversation context cache
563    pub fn conversation_cache(&self) -> &ConversationContextCache {
564        &self.conversation_cache
565    }
566
567    /// Get RAG chunk cache
568    pub fn rag_cache(&self) -> &RagChunkCache {
569        &self.rag_cache
570    }
571
572    /// Get tool result cache
573    pub fn tool_cache(&self) -> &ToolResultCache {
574        &self.tool_cache
575    }
576
577    /// Get semantic query cache
578    pub fn semantic_cache(&self) -> &SemanticQueryCache {
579        &self.semantic_cache
580    }
581
582    /// Get cache statistics
583    pub fn stats(&self) -> DistribCacheStats {
584        let total = self.stats.total_lookups.load(Ordering::Relaxed);
585        let l1_hits = self.stats.l1_hits.load(Ordering::Relaxed);
586        let l2_hits = self.stats.l2_hits.load(Ordering::Relaxed);
587        let l3_hits = self.stats.l3_hits.load(Ordering::Relaxed);
588        let _misses = self.stats.total_misses.load(Ordering::Relaxed);
589
590        DistribCacheStats {
591            l1: self.l1_hot.stats(),
592            l2: self.l2_warm.stats(),
593            l3: self.l3_distributed.stats(),
594            overall_hit_ratio: if total > 0 {
595                (l1_hits + l2_hits + l3_hits) as f64 / total as f64
596            } else {
597                0.0
598            },
599            time_saved_seconds: self.stats.time_saved_us.load(Ordering::Relaxed) as f64
600                / 1_000_000.0,
601            queries_avoided: self.stats.queries_avoided.load(Ordering::Relaxed),
602        }
603    }
604
605    /// Generate heatmap data
606    pub fn heatmap(&self) -> HeatmapData {
607        self.heatmap.generate_heatmap()
608    }
609
610    /// Get workload distribution
611    pub fn workload_distribution(&self) -> WorkloadDistribution {
612        self.scheduler.get_distribution()
613    }
614
615    /// Start background services (prefetcher, invalidator)
616    pub async fn start(&self) -> CacheResult<()> {
617        // Start WAL invalidator if configured
618        if let Some(wal_endpoint) = &self.config.wal_endpoint {
619            self.invalidator.start(wal_endpoint).await?;
620        }
621
622        // Start prefetcher background worker
623        if self.config.prefetch_enabled {
624            self.prefetcher.start().await;
625        }
626
627        Ok(())
628    }
629
630    /// Stop background services
631    pub async fn stop(&self) -> CacheResult<()> {
632        self.invalidator.stop().await;
633        self.prefetcher.stop().await;
634        Ok(())
635    }
636}
637
/// Cache statistics snapshot
///
/// Returned by `HeliosDistribCache::stats`.
#[derive(Debug, Clone)]
pub struct DistribCacheStats {
    /// L1 tier stats
    pub l1: TierStats,
    /// L2 tier stats
    pub l2: TierStats,
    /// L3 tier stats
    pub l3: TierStats,
    /// Overall hit ratio: (L1 + L2 + L3 hits) / total lookups, `0.0` when
    /// there have been no lookups
    pub overall_hit_ratio: f64,
    /// Total time saved in seconds (an estimate accumulated from fixed
    /// per-tier values, not a measurement)
    pub time_saved_seconds: f64,
    /// Total queries avoided (one per recorded cache hit)
    pub queries_avoided: u64,
}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// The template is upper-cased and the referenced table is extracted.
    #[test]
    fn test_query_fingerprint() {
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 42");

        assert!(fingerprint.template.contains("SELECT"));
        assert!(fingerprint.template.contains("USERS"));
        assert!(fingerprint.tables.iter().any(|table| table == "USERS"));
    }

    /// Queries that differ only in a numeric literal share one template.
    #[test]
    fn test_query_fingerprint_normalization() {
        let first = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 1");
        let second = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 2");

        assert_eq!(first.template, second.template);
    }

    /// `SessionId::new` stores the given string unchanged.
    #[test]
    fn test_session_id() {
        let SessionId(inner) = SessionId::new("test-session");

        assert_eq!(inner, "test-session");
    }

    /// Builder methods set the corresponding fields.
    #[test]
    fn test_query_context() {
        let context = QueryContext::new("session-1")
            .with_workload_hint(WorkloadType::OLTP)
            .with_branch("feature-x");

        assert_eq!(context.session_id.0, "session-1");
        assert_eq!(context.workload_hint, Some(WorkloadType::OLTP));
        assert_eq!(context.branch.as_deref(), Some("feature-x"));
    }
}