1pub mod config;
42pub mod classifier;
43pub mod tiers;
44pub mod prefetcher;
45pub mod invalidator;
46pub mod heatmap;
47pub mod scheduler;
48pub mod ai;
49pub mod metrics;
50
51pub use config::*;
52pub use classifier::*;
53pub use tiers::{
54 HotCache, WarmCache, DistributedCache,
55 CacheEntry, CacheKey, CacheTier, TierStats,
56 EvictionPolicy, CompressionType,
57};
58pub use prefetcher::*;
59pub use invalidator::*;
60pub use heatmap::*;
61pub use scheduler::*;
62pub use ai::{
63 ConversationContextCache, ConversationContext, Turn, ConversationCacheStats,
64 RagChunkCache, Chunk, ChunkId, RagCacheStatsSnapshot,
65 ToolResultCache, ToolCallKey, ToolResult, ToolCacheStatsSnapshot,
66 SemanticQueryCache, SemanticEntry, SemanticCacheStatsSnapshot, cosine_similarity,
67 BranchContext, BranchId, AIWorkloadContext, VectorId, Embedding,
69 SemanticIndex, SemanticIndexConfig, SimilarityResult,
70 AIIntegrationCoordinator, AIIntegrationConfig, AIIntegrationStatsSnapshot,
72 AIWorkloadDetection, SessionTrackingInfo, CacheRecommendation,
73 CachePriority, RecommendedTier,
74};
75pub use metrics::{DistribCacheMetrics, InvalidationSource, ErrorType};
76
77use std::sync::atomic::{AtomicU64, Ordering};
78use std::sync::Arc;
79use std::time::{Duration, Instant};
80use thiserror::Error;
81
82#[derive(Debug, Error)]
84pub enum CacheError {
85 #[error("Cache miss")]
86 Miss,
87
88 #[error("Entry expired")]
89 Expired,
90
91 #[error("Entry too large: {0} bytes (max: {1})")]
92 TooLarge(usize, usize),
93
94 #[error("Tier unavailable: {0}")]
95 TierUnavailable(String),
96
97 #[error("Peer not found: {0}")]
98 PeerNotFound(String),
99
100 #[error("Serialization error: {0}")]
101 Serialization(String),
102
103 #[error("Compression error: {0}")]
104 Compression(String),
105
106 #[error("Storage error: {0}")]
107 Storage(String),
108
109 #[error("Network error: {0}")]
110 Network(String),
111
112 #[error("Invalidation error: {0}")]
113 Invalidation(String),
114
115 #[error("Configuration error: {0}")]
116 Configuration(String),
117
118 #[error("Connection error: {0}")]
119 ConnectionError(String),
120}
121
122pub type CacheResult<T> = std::result::Result<T, CacheError>;
123
/// Lookup key derived from a SQL query and passed to every cache tier.
///
/// Equality and hashing cover all three fields, and `to_bytes` serializes the
/// template followed by `param_hash` (when present).
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct QueryFingerprint {
    /// Normalized query text: upper-cased, whitespace collapsed, literals replaced by `?`.
    pub template: String,
    /// Table names found after `FROM`/`JOIN`/`INTO`/`UPDATE` in the template.
    pub tables: Vec<String>,
    /// Optional hash of values that distinguish queries sharing a template
    /// (set by `with_params`); `None` means the template alone identifies the query.
    pub param_hash: Option<u64>,
}
134
135impl QueryFingerprint {
136 pub fn from_query(query: &str) -> Self {
138 let template = Self::normalize_query(query);
139 let tables = Self::extract_tables(&template);
140 let param_hash = None; Self {
143 template,
144 tables,
145 param_hash,
146 }
147 }
148
149 pub fn with_params(mut self, params: &[&str]) -> Self {
151 use std::hash::{Hash, Hasher};
152 use std::collections::hash_map::DefaultHasher;
153
154 let mut hasher = DefaultHasher::new();
155 for param in params {
156 param.hash(&mut hasher);
157 }
158 self.param_hash = Some(hasher.finish());
159 self
160 }
161
162 fn normalize_query(query: &str) -> String {
164 let upper = query.to_uppercase();
165 let mut result = String::new();
167 let mut in_string = false;
168 let mut quote_char = ' ';
169
170 for ch in upper.chars() {
171 if in_string {
172 if ch == quote_char {
173 in_string = false;
174 result.push('?');
175 }
176 } else if ch == '\'' || ch == '"' {
177 in_string = true;
178 quote_char = ch;
179 } else if ch.is_numeric() {
180 if !result.ends_with('?') {
181 result.push('?');
182 }
183 } else if ch.is_whitespace() {
184 if !result.ends_with(' ') {
185 result.push(' ');
186 }
187 } else {
188 result.push(ch);
189 }
190 }
191
192 result.trim().to_string()
193 }
194
195 fn extract_tables(query: &str) -> Vec<String> {
197 let mut tables = Vec::new();
198 let words: Vec<&str> = query.split_whitespace().collect();
199
200 for (i, word) in words.iter().enumerate() {
201 if *word == "FROM" || *word == "JOIN" || *word == "INTO" || *word == "UPDATE" {
202 if let Some(table) = words.get(i + 1) {
203 let table_name = table.trim_matches(|c| c == '(' || c == ')' || c == ',');
204 if !table_name.is_empty() && !tables.contains(&table_name.to_string()) {
205 tables.push(table_name.to_string());
206 }
207 }
208 }
209 }
210
211 tables
212 }
213
214 pub fn to_bytes(&self) -> Vec<u8> {
216 let mut bytes = Vec::new();
217 bytes.extend_from_slice(self.template.as_bytes());
218 if let Some(hash) = self.param_hash {
219 bytes.extend_from_slice(&hash.to_le_bytes());
220 }
221 bytes
222 }
223}
224
/// Identifier of the client session that issued a query.
///
/// A thin newtype over `String`; the inner value is public.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SessionId(pub String);

impl SessionId {
    /// Wraps any value convertible into a `String` as a session identifier.
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        SessionId(raw)
    }
}
234
235#[derive(Debug, Clone)]
237pub struct QueryContext {
238 pub session_id: SessionId,
240 pub workload_hint: Option<WorkloadType>,
242 pub branch: Option<String>,
244 pub as_of: Option<u64>,
246 pub is_prepared: bool,
248 pub timestamp: Instant,
250}
251
252impl QueryContext {
253 pub fn new(session_id: impl Into<String>) -> Self {
254 Self {
255 session_id: SessionId::new(session_id),
256 workload_hint: None,
257 branch: None,
258 as_of: None,
259 is_prepared: false,
260 timestamp: Instant::now(),
261 }
262 }
263
264 pub fn with_workload_hint(mut self, hint: WorkloadType) -> Self {
265 self.workload_hint = Some(hint);
266 self
267 }
268
269 pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
270 self.branch = Some(branch.into());
271 self
272 }
273
274 pub fn with_as_of(mut self, timestamp: u64) -> Self {
275 self.as_of = Some(timestamp);
276 self
277 }
278}
279
/// Multi-tier cache keyed by `QueryFingerprint`.
///
/// Combines L1/L2/L3 tiers with workload classification, prefetching,
/// WAL-driven invalidation, an access heatmap and AI-workload caches.
pub struct HeliosDistribCache {
    /// Classifies queries into a `WorkloadType`; `insert` uses the result to pick the TTL and L3 eligibility.
    classifier: WorkloadClassifier,

    /// L1 tier sized from `config.l1_size_mb`; consulted first by `get` and written by every `insert`.
    l1_hot: Arc<HotCache>,

    /// L2 tier sized from `config.l2_size_gb` and built with `config.l2_path`; used only when `l2_enabled`.
    l2_warm: Arc<WarmCache>,

    /// L3 tier built from `config.l3_peers`; used only when `l3_enabled`, and its get/insert are awaited.
    l3_distributed: Arc<DistributedCache>,

    /// Records inserted queries per session and is asked to prefetch on misses (when `prefetch_enabled`).
    prefetcher: Arc<PredictivePrefetcher>,

    /// Started against `config.wal_endpoint` in `start`, stopped in `stop`.
    invalidator: Arc<WalInvalidator>,

    /// Access heatmap fed by both hits and misses.
    heatmap: Arc<CacheHeatmap>,

    /// Source of `workload_distribution()`.
    scheduler: Arc<WorkloadScheduler>,

    // AI-workload caches, exposed through the matching accessor methods.
    conversation_cache: Arc<ConversationContextCache>,
    rag_cache: Arc<RagChunkCache>,
    tool_cache: Arc<ToolResultCache>,
    semantic_cache: Arc<SemanticQueryCache>,

    /// Metrics sink. NOTE(review): not referenced by the methods in this file's `impl`.
    metrics: Arc<DistribCacheMetrics>,

    /// Configuration the cache was built from; read on every lookup/insert for tier toggles and TTLs.
    config: DistribCacheConfig,

    /// Internal hit/miss counters behind `stats()`.
    stats: CacheStatistics,
}
321
/// Atomic counters updated by `HeliosDistribCache` lookups and read by its `stats()`.
#[derive(Debug, Default)]
struct CacheStatistics {
    /// Number of `get` calls.
    total_lookups: AtomicU64,
    /// Lookups served from L1.
    l1_hits: AtomicU64,
    /// Lookups served from L2.
    l2_hits: AtomicU64,
    /// Lookups served from L3.
    l3_hits: AtomicU64,
    /// Lookups that missed every enabled tier.
    total_misses: AtomicU64,
    /// Accumulated estimated time saved by hits, in microseconds.
    time_saved_us: AtomicU64,
    /// Number of hits across all tiers.
    queries_avoided: AtomicU64,
}
333
334impl HeliosDistribCache {
335 pub fn new(config: DistribCacheConfig) -> Self {
337 let l1_hot = Arc::new(HotCache::new(
338 config.l1_size_mb * 1024 * 1024,
339 config.l1_max_entry_size,
340 config.l1_eviction_policy,
341 ));
342
343 let l2_warm = Arc::new(WarmCache::new(
344 config.l2_size_gb * 1024 * 1024 * 1024,
345 config.l2_path.clone(),
346 config.l2_compression,
347 ));
348
349 let l3_distributed = Arc::new(DistributedCache::new(
350 config.l3_replication_factor,
351 config.l3_peers.clone(),
352 ));
353
354 let classifier = WorkloadClassifier::new(config.clone());
355 let prefetcher = Arc::new(PredictivePrefetcher::new(config.clone()));
356 let invalidator = Arc::new(WalInvalidator::new(config.clone()));
357 let heatmap = Arc::new(CacheHeatmap::new());
358 let scheduler = Arc::new(WorkloadScheduler::new(config.clone()));
359 let metrics = Arc::new(DistribCacheMetrics::new());
360
361 let conversation_cache = Arc::new(ConversationContextCache::new(1000, 50));
363 let rag_cache = Arc::new(RagChunkCache::new(config.l1_size_mb / 4));
364 let tool_cache = Arc::new(ToolResultCache::new());
365 let semantic_cache = Arc::new(SemanticQueryCache::new(0.85));
366
367 Self {
368 classifier,
369 l1_hot,
370 l2_warm,
371 l3_distributed,
372 prefetcher,
373 invalidator,
374 heatmap,
375 scheduler,
376 conversation_cache,
377 rag_cache,
378 tool_cache,
379 semantic_cache,
380 metrics,
381 config,
382 stats: CacheStatistics::default(),
383 }
384 }
385
386 pub async fn get(
388 &self,
389 fingerprint: &QueryFingerprint,
390 context: &QueryContext,
391 ) -> CacheResult<CacheEntry> {
392 self.stats.total_lookups.fetch_add(1, Ordering::Relaxed);
393 let start = Instant::now();
394
395 let _workload = self.classifier.classify_query(&fingerprint.template, context);
397
398 if let Some(entry) = self.l1_hot.get(fingerprint, context.session_id.clone()) {
400 self.stats.l1_hits.fetch_add(1, Ordering::Relaxed);
401 self.record_hit(fingerprint, CacheTier::L1, start.elapsed());
402 return Ok(entry);
403 }
404
405 if self.config.l2_enabled {
407 if let Some(entry) = self.l2_warm.get(fingerprint) {
408 self.stats.l2_hits.fetch_add(1, Ordering::Relaxed);
409 self.l1_hot.insert(
411 fingerprint.clone(),
412 entry.clone(),
413 Some(context.session_id.clone()),
414 );
415 self.record_hit(fingerprint, CacheTier::L2, start.elapsed());
416 return Ok(entry);
417 }
418 }
419
420 if self.config.l3_enabled {
422 if let Some(entry) = self.l3_distributed.get(fingerprint).await {
423 self.stats.l3_hits.fetch_add(1, Ordering::Relaxed);
424 self.l1_hot.insert(
426 fingerprint.clone(),
427 entry.clone(),
428 Some(context.session_id.clone()),
429 );
430 if self.config.l2_enabled {
431 self.l2_warm.insert(fingerprint.clone(), entry.clone());
432 }
433 self.record_hit(fingerprint, CacheTier::L3, start.elapsed());
434 return Ok(entry);
435 }
436 }
437
438 self.stats.total_misses.fetch_add(1, Ordering::Relaxed);
440 self.heatmap.record_access(fingerprint, false, Duration::ZERO);
441
442 if self.config.prefetch_enabled {
444 self.prefetcher.predict_and_prefetch(fingerprint, &context.session_id);
445 }
446
447 Err(CacheError::Miss)
448 }
449
450 pub async fn insert(
452 &self,
453 fingerprint: QueryFingerprint,
454 entry: CacheEntry,
455 context: &QueryContext,
456 ) -> CacheResult<()> {
457 let workload = self.classifier.classify_query(&fingerprint.template, context);
458 let ttl = self.get_ttl_for_workload(workload);
459
460 let entry = entry.with_ttl(ttl);
461
462 self.l1_hot.insert(
464 fingerprint.clone(),
465 entry.clone(),
466 Some(context.session_id.clone()),
467 );
468
469 if self.config.l2_enabled && entry.size() > 1024 && ttl > Duration::from_secs(60) {
471 self.l2_warm.insert(fingerprint.clone(), entry.clone());
472 }
473
474 if self.config.l3_enabled && !matches!(workload, WorkloadType::OLTP) {
476 self.l3_distributed.insert(fingerprint.clone(), entry.clone()).await;
477 }
478
479 if self.config.prefetch_enabled {
481 self.prefetcher.record(&context.session_id, fingerprint);
482 }
483
484 Ok(())
485 }
486
487 pub fn invalidate_table(&self, table: &str) {
489 self.l1_hot.invalidate_by_table(table);
490 if self.config.l2_enabled {
491 self.l2_warm.invalidate_by_table(table);
492 }
493 }
495
496 pub fn invalidate(&self, fingerprint: &QueryFingerprint) {
498 self.l1_hot.invalidate(fingerprint);
499 if self.config.l2_enabled {
500 self.l2_warm.invalidate(fingerprint);
501 }
502 }
503
504 fn get_ttl_for_workload(&self, workload: WorkloadType) -> Duration {
506 match workload {
507 WorkloadType::OLTP => self.config.oltp_cache_ttl,
508 WorkloadType::OLAP => self.config.olap_cache_ttl,
509 WorkloadType::Vector => self.config.vector_cache_ttl,
510 WorkloadType::AIAgent => self.config.ai_agent_cache_ttl,
511 WorkloadType::RAG => self.config.rag_cache_ttl,
512 WorkloadType::Mixed => self.config.default_cache_ttl,
513 }
514 }
515
516 fn record_hit(&self, fingerprint: &QueryFingerprint, tier: CacheTier, _latency: Duration) {
518 let time_saved = match tier {
519 CacheTier::L1 => Duration::from_millis(10), CacheTier::L2 => Duration::from_millis(9),
521 CacheTier::L3 => Duration::from_millis(5),
522 };
523
524 self.stats.time_saved_us.fetch_add(
525 time_saved.as_micros() as u64,
526 Ordering::Relaxed,
527 );
528 self.stats.queries_avoided.fetch_add(1, Ordering::Relaxed);
529 self.heatmap.record_access(fingerprint, true, time_saved);
530 }
531
532 pub fn conversation_cache(&self) -> &ConversationContextCache {
534 &self.conversation_cache
535 }
536
537 pub fn rag_cache(&self) -> &RagChunkCache {
539 &self.rag_cache
540 }
541
542 pub fn tool_cache(&self) -> &ToolResultCache {
544 &self.tool_cache
545 }
546
547 pub fn semantic_cache(&self) -> &SemanticQueryCache {
549 &self.semantic_cache
550 }
551
552 pub fn stats(&self) -> DistribCacheStats {
554 let total = self.stats.total_lookups.load(Ordering::Relaxed);
555 let l1_hits = self.stats.l1_hits.load(Ordering::Relaxed);
556 let l2_hits = self.stats.l2_hits.load(Ordering::Relaxed);
557 let l3_hits = self.stats.l3_hits.load(Ordering::Relaxed);
558 let _misses = self.stats.total_misses.load(Ordering::Relaxed);
559
560 DistribCacheStats {
561 l1: self.l1_hot.stats(),
562 l2: self.l2_warm.stats(),
563 l3: self.l3_distributed.stats(),
564 overall_hit_ratio: if total > 0 {
565 (l1_hits + l2_hits + l3_hits) as f64 / total as f64
566 } else {
567 0.0
568 },
569 time_saved_seconds: self.stats.time_saved_us.load(Ordering::Relaxed) as f64 / 1_000_000.0,
570 queries_avoided: self.stats.queries_avoided.load(Ordering::Relaxed),
571 }
572 }
573
574 pub fn heatmap(&self) -> HeatmapData {
576 self.heatmap.generate_heatmap()
577 }
578
579 pub fn workload_distribution(&self) -> WorkloadDistribution {
581 self.scheduler.get_distribution()
582 }
583
584 pub async fn start(&self) -> CacheResult<()> {
586 if let Some(wal_endpoint) = &self.config.wal_endpoint {
588 self.invalidator.start(wal_endpoint).await?;
589 }
590
591 if self.config.prefetch_enabled {
593 self.prefetcher.start().await;
594 }
595
596 Ok(())
597 }
598
599 pub async fn stop(&self) -> CacheResult<()> {
601 self.invalidator.stop().await;
602 self.prefetcher.stop().await;
603 Ok(())
604 }
605}
606
/// Point-in-time snapshot returned by `HeliosDistribCache::stats`.
#[derive(Debug, Clone)]
pub struct DistribCacheStats {
    /// Statistics reported by the L1 tier.
    pub l1: TierStats,
    /// Statistics reported by the L2 tier.
    pub l2: TierStats,
    /// Statistics reported by the L3 tier.
    pub l3: TierStats,
    /// (L1 + L2 + L3 hits) / lookups, or `0.0` when there have been no lookups.
    pub overall_hit_ratio: f64,
    /// Sum of the fixed per-tier time-saving estimates for all hits, in seconds.
    pub time_saved_seconds: f64,
    /// Number of lookups answered from any tier.
    pub queries_avoided: u64,
}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// The template is upper-cased and the table after FROM is extracted.
    #[test]
    fn test_query_fingerprint() {
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 42");
        let template = &fingerprint.template;

        assert!(template.contains("SELECT"));
        assert!(template.contains("USERS"));
        assert!(fingerprint.tables.iter().any(|table| table == "USERS"));
    }

    /// Queries differing only in a numeric literal share one template.
    #[test]
    fn test_query_fingerprint_normalization() {
        let templates: Vec<String> = [
            "SELECT * FROM users WHERE id = 1",
            "SELECT * FROM users WHERE id = 2",
        ]
        .iter()
        .map(|query| QueryFingerprint::from_query(query).template)
        .collect();

        assert_eq!(templates[0], templates[1]);
    }

    /// `SessionId::new` stores the given string unchanged.
    #[test]
    fn test_session_id() {
        let SessionId(inner) = SessionId::new("test-session");
        assert_eq!(inner, "test-session");
    }

    /// Builder methods set the hint and branch on top of `new`.
    #[test]
    fn test_query_context() {
        let context = QueryContext::new("session-1")
            .with_workload_hint(WorkloadType::OLTP)
            .with_branch("feature-x");

        assert_eq!(context.session_id, SessionId::new("session-1"));
        assert_eq!(context.workload_hint, Some(WorkloadType::OLTP));
        assert_eq!(context.branch.as_deref(), Some("feature-x"));
    }
}