1pub mod ai;
42pub mod classifier;
43pub mod config;
44pub mod heatmap;
45pub mod invalidator;
46pub mod metrics;
47pub mod prefetcher;
48pub mod scheduler;
49pub mod tiers;
50
51pub use ai::{
52 cosine_similarity,
53 AIIntegrationConfig,
54 AIIntegrationCoordinator,
56 AIIntegrationStatsSnapshot,
57 AIWorkloadContext,
58 AIWorkloadDetection,
59 BranchContext,
61 BranchId,
62 CachePriority,
63 CacheRecommendation,
64 Chunk,
65 ChunkId,
66 ConversationCacheStats,
67 ConversationContext,
68 ConversationContextCache,
69 Embedding,
70 RagCacheStatsSnapshot,
71 RagChunkCache,
72 RecommendedTier,
73 SemanticCacheStatsSnapshot,
74 SemanticEntry,
75 SemanticIndex,
76 SemanticIndexConfig,
77 SemanticQueryCache,
78 SessionTrackingInfo,
79 SimilarityResult,
80 ToolCacheStatsSnapshot,
81 ToolCallKey,
82 ToolResult,
83 ToolResultCache,
84 Turn,
85 VectorId,
86};
87pub use classifier::*;
88pub use config::*;
89pub use heatmap::*;
90pub use invalidator::*;
91pub use metrics::{DistribCacheMetrics, ErrorType, InvalidationSource};
92pub use prefetcher::*;
93pub use scheduler::*;
94pub use tiers::{
95 CacheEntry, CacheKey, CacheTier, CompressionType, DistributedCache, EvictionPolicy, HotCache,
96 TierStats, WarmCache,
97};
98
99use std::sync::atomic::{AtomicU64, Ordering};
100use std::sync::Arc;
101use std::time::{Duration, Instant};
102use thiserror::Error;
103
104#[derive(Debug, Error)]
106pub enum CacheError {
107 #[error("Cache miss")]
108 Miss,
109
110 #[error("Entry expired")]
111 Expired,
112
113 #[error("Entry too large: {0} bytes (max: {1})")]
114 TooLarge(usize, usize),
115
116 #[error("Tier unavailable: {0}")]
117 TierUnavailable(String),
118
119 #[error("Peer not found: {0}")]
120 PeerNotFound(String),
121
122 #[error("Serialization error: {0}")]
123 Serialization(String),
124
125 #[error("Compression error: {0}")]
126 Compression(String),
127
128 #[error("Storage error: {0}")]
129 Storage(String),
130
131 #[error("Network error: {0}")]
132 Network(String),
133
134 #[error("Invalidation error: {0}")]
135 Invalidation(String),
136
137 #[error("Configuration error: {0}")]
138 Configuration(String),
139
140 #[error("Connection error: {0}")]
141 ConnectionError(String),
142}
143
144pub type CacheResult<T> = std::result::Result<T, CacheError>;
145
146#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
148pub struct QueryFingerprint {
149 pub template: String,
151 pub tables: Vec<String>,
153 pub param_hash: Option<u64>,
155}
156
157impl QueryFingerprint {
158 pub fn from_query(query: &str) -> Self {
160 let template = Self::normalize_query(query);
161 let tables = Self::extract_tables(&template);
162 let param_hash = None; Self {
165 template,
166 tables,
167 param_hash,
168 }
169 }
170
171 pub fn with_params(mut self, params: &[&str]) -> Self {
173 use std::collections::hash_map::DefaultHasher;
174 use std::hash::{Hash, Hasher};
175
176 let mut hasher = DefaultHasher::new();
177 for param in params {
178 param.hash(&mut hasher);
179 }
180 self.param_hash = Some(hasher.finish());
181 self
182 }
183
184 fn normalize_query(query: &str) -> String {
186 let upper = query.to_uppercase();
187 let mut result = String::new();
189 let mut in_string = false;
190 let mut quote_char = ' ';
191
192 for ch in upper.chars() {
193 if in_string {
194 if ch == quote_char {
195 in_string = false;
196 result.push('?');
197 }
198 } else if ch == '\'' || ch == '"' {
199 in_string = true;
200 quote_char = ch;
201 } else if ch.is_numeric() {
202 if !result.ends_with('?') {
203 result.push('?');
204 }
205 } else if ch.is_whitespace() {
206 if !result.ends_with(' ') {
207 result.push(' ');
208 }
209 } else {
210 result.push(ch);
211 }
212 }
213
214 result.trim().to_string()
215 }
216
217 fn extract_tables(query: &str) -> Vec<String> {
219 let mut tables = Vec::new();
220 let words: Vec<&str> = query.split_whitespace().collect();
221
222 for (i, word) in words.iter().enumerate() {
223 if *word == "FROM" || *word == "JOIN" || *word == "INTO" || *word == "UPDATE" {
224 if let Some(table) = words.get(i + 1) {
225 let table_name = table.trim_matches(|c| c == '(' || c == ')' || c == ',');
226 if !table_name.is_empty() && !tables.contains(&table_name.to_string()) {
227 tables.push(table_name.to_string());
228 }
229 }
230 }
231 }
232
233 tables
234 }
235
236 pub fn to_bytes(&self) -> Vec<u8> {
238 let mut bytes = Vec::new();
239 bytes.extend_from_slice(self.template.as_bytes());
240 if let Some(hash) = self.param_hash {
241 bytes.extend_from_slice(&hash.to_le_bytes());
242 }
243 bytes
244 }
245}
246
/// Identifier of a client session; a thin wrapper around the raw string id.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SessionId(pub String);
250
251impl SessionId {
252 pub fn new(id: impl Into<String>) -> Self {
253 Self(id.into())
254 }
255}
256
257#[derive(Debug, Clone)]
259pub struct QueryContext {
260 pub session_id: SessionId,
262 pub workload_hint: Option<WorkloadType>,
264 pub branch: Option<String>,
266 pub as_of: Option<u64>,
268 pub is_prepared: bool,
270 pub timestamp: Instant,
272}
273
274impl QueryContext {
275 pub fn new(session_id: impl Into<String>) -> Self {
276 Self {
277 session_id: SessionId::new(session_id),
278 workload_hint: None,
279 branch: None,
280 as_of: None,
281 is_prepared: false,
282 timestamp: Instant::now(),
283 }
284 }
285
286 pub fn with_workload_hint(mut self, hint: WorkloadType) -> Self {
287 self.workload_hint = Some(hint);
288 self
289 }
290
291 pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
292 self.branch = Some(branch.into());
293 self
294 }
295
296 pub fn with_as_of(mut self, timestamp: u64) -> Self {
297 self.as_of = Some(timestamp);
298 self
299 }
300}
301
302pub struct HeliosDistribCache {
304 classifier: WorkloadClassifier,
306
307 l1_hot: Arc<HotCache>,
309
310 l2_warm: Arc<WarmCache>,
312
313 l3_distributed: Arc<DistributedCache>,
315
316 prefetcher: Arc<PredictivePrefetcher>,
318
319 invalidator: Arc<WalInvalidator>,
321
322 heatmap: Arc<CacheHeatmap>,
324
325 scheduler: Arc<WorkloadScheduler>,
327
328 conversation_cache: Arc<ConversationContextCache>,
330 rag_cache: Arc<RagChunkCache>,
331 tool_cache: Arc<ToolResultCache>,
332 semantic_cache: Arc<SemanticQueryCache>,
333
334 #[allow(dead_code)]
336 metrics: Arc<DistribCacheMetrics>,
337
338 config: DistribCacheConfig,
340
341 stats: CacheStatistics,
343}
344
/// Atomic counters describing cache activity; every counter starts at zero.
#[derive(Debug, Default)]
struct CacheStatistics {
    /// Number of lookups performed.
    total_lookups: AtomicU64,
    /// Lookups answered by the L1 tier.
    l1_hits: AtomicU64,
    /// Lookups answered by the L2 tier.
    l2_hits: AtomicU64,
    /// Lookups answered by the L3 tier.
    l3_hits: AtomicU64,
    /// Lookups that no tier could answer.
    total_misses: AtomicU64,
    /// Accumulated time saved, in microseconds.
    time_saved_us: AtomicU64,
    /// Number of queries that did not need to be executed.
    queries_avoided: AtomicU64,
}
356
357impl HeliosDistribCache {
358 pub fn new(config: DistribCacheConfig) -> Self {
360 let l1_hot = Arc::new(HotCache::new(
361 config.l1_size_mb * 1024 * 1024,
362 config.l1_max_entry_size,
363 config.l1_eviction_policy,
364 ));
365
366 let l2_warm = Arc::new(WarmCache::new(
367 config.l2_size_gb * 1024 * 1024 * 1024,
368 config.l2_path.clone(),
369 config.l2_compression,
370 ));
371
372 let l3_distributed = Arc::new(DistributedCache::new(
373 config.l3_replication_factor,
374 config.l3_peers.clone(),
375 ));
376
377 let classifier = WorkloadClassifier::new(config.clone());
378 let prefetcher = Arc::new(PredictivePrefetcher::new(config.clone()));
379 let invalidator = Arc::new(WalInvalidator::new(config.clone()));
380 let heatmap = Arc::new(CacheHeatmap::new());
381 let scheduler = Arc::new(WorkloadScheduler::new(config.clone()));
382 let metrics = Arc::new(DistribCacheMetrics::new());
383
384 let conversation_cache = Arc::new(ConversationContextCache::new(1000, 50));
386 let rag_cache = Arc::new(RagChunkCache::new(config.l1_size_mb / 4));
387 let tool_cache = Arc::new(ToolResultCache::new());
388 let semantic_cache = Arc::new(SemanticQueryCache::new(0.85));
389
390 Self {
391 classifier,
392 l1_hot,
393 l2_warm,
394 l3_distributed,
395 prefetcher,
396 invalidator,
397 heatmap,
398 scheduler,
399 conversation_cache,
400 rag_cache,
401 tool_cache,
402 semantic_cache,
403 metrics,
404 config,
405 stats: CacheStatistics::default(),
406 }
407 }
408
409 pub async fn get(
411 &self,
412 fingerprint: &QueryFingerprint,
413 context: &QueryContext,
414 ) -> CacheResult<CacheEntry> {
415 self.stats.total_lookups.fetch_add(1, Ordering::Relaxed);
416 let start = Instant::now();
417
418 let _workload = self
420 .classifier
421 .classify_query(&fingerprint.template, context);
422
423 if let Some(entry) = self.l1_hot.get(fingerprint, context.session_id.clone()) {
425 self.stats.l1_hits.fetch_add(1, Ordering::Relaxed);
426 self.record_hit(fingerprint, CacheTier::L1, start.elapsed());
427 return Ok(entry);
428 }
429
430 if self.config.l2_enabled {
432 if let Some(entry) = self.l2_warm.get(fingerprint) {
433 self.stats.l2_hits.fetch_add(1, Ordering::Relaxed);
434 self.l1_hot.insert(
436 fingerprint.clone(),
437 entry.clone(),
438 Some(context.session_id.clone()),
439 );
440 self.record_hit(fingerprint, CacheTier::L2, start.elapsed());
441 return Ok(entry);
442 }
443 }
444
445 if self.config.l3_enabled {
447 if let Some(entry) = self.l3_distributed.get(fingerprint).await {
448 self.stats.l3_hits.fetch_add(1, Ordering::Relaxed);
449 self.l1_hot.insert(
451 fingerprint.clone(),
452 entry.clone(),
453 Some(context.session_id.clone()),
454 );
455 if self.config.l2_enabled {
456 self.l2_warm.insert(fingerprint.clone(), entry.clone());
457 }
458 self.record_hit(fingerprint, CacheTier::L3, start.elapsed());
459 return Ok(entry);
460 }
461 }
462
463 self.stats.total_misses.fetch_add(1, Ordering::Relaxed);
465 self.heatmap
466 .record_access(fingerprint, false, Duration::ZERO);
467
468 if self.config.prefetch_enabled {
470 self.prefetcher
471 .predict_and_prefetch(fingerprint, &context.session_id);
472 }
473
474 Err(CacheError::Miss)
475 }
476
477 pub async fn insert(
479 &self,
480 fingerprint: QueryFingerprint,
481 entry: CacheEntry,
482 context: &QueryContext,
483 ) -> CacheResult<()> {
484 let workload = self
485 .classifier
486 .classify_query(&fingerprint.template, context);
487 let ttl = self.get_ttl_for_workload(workload);
488
489 let entry = entry.with_ttl(ttl);
490
491 self.l1_hot.insert(
493 fingerprint.clone(),
494 entry.clone(),
495 Some(context.session_id.clone()),
496 );
497
498 if self.config.l2_enabled && entry.size() > 1024 && ttl > Duration::from_secs(60) {
500 self.l2_warm.insert(fingerprint.clone(), entry.clone());
501 }
502
503 if self.config.l3_enabled && !matches!(workload, WorkloadType::OLTP) {
505 self.l3_distributed
506 .insert(fingerprint.clone(), entry.clone())
507 .await;
508 }
509
510 if self.config.prefetch_enabled {
512 self.prefetcher.record(&context.session_id, fingerprint);
513 }
514
515 Ok(())
516 }
517
518 pub fn invalidate_table(&self, table: &str) {
520 self.l1_hot.invalidate_by_table(table);
521 if self.config.l2_enabled {
522 self.l2_warm.invalidate_by_table(table);
523 }
524 }
526
527 pub fn invalidate(&self, fingerprint: &QueryFingerprint) {
529 self.l1_hot.invalidate(fingerprint);
530 if self.config.l2_enabled {
531 self.l2_warm.invalidate(fingerprint);
532 }
533 }
534
535 fn get_ttl_for_workload(&self, workload: WorkloadType) -> Duration {
537 match workload {
538 WorkloadType::OLTP => self.config.oltp_cache_ttl,
539 WorkloadType::OLAP => self.config.olap_cache_ttl,
540 WorkloadType::Vector => self.config.vector_cache_ttl,
541 WorkloadType::AIAgent => self.config.ai_agent_cache_ttl,
542 WorkloadType::RAG => self.config.rag_cache_ttl,
543 WorkloadType::Mixed => self.config.default_cache_ttl,
544 }
545 }
546
547 fn record_hit(&self, fingerprint: &QueryFingerprint, tier: CacheTier, _latency: Duration) {
549 let time_saved = match tier {
550 CacheTier::L1 => Duration::from_millis(10), CacheTier::L2 => Duration::from_millis(9),
552 CacheTier::L3 => Duration::from_millis(5),
553 };
554
555 self.stats
556 .time_saved_us
557 .fetch_add(time_saved.as_micros() as u64, Ordering::Relaxed);
558 self.stats.queries_avoided.fetch_add(1, Ordering::Relaxed);
559 self.heatmap.record_access(fingerprint, true, time_saved);
560 }
561
562 pub fn conversation_cache(&self) -> &ConversationContextCache {
564 &self.conversation_cache
565 }
566
567 pub fn rag_cache(&self) -> &RagChunkCache {
569 &self.rag_cache
570 }
571
572 pub fn tool_cache(&self) -> &ToolResultCache {
574 &self.tool_cache
575 }
576
577 pub fn semantic_cache(&self) -> &SemanticQueryCache {
579 &self.semantic_cache
580 }
581
582 pub fn stats(&self) -> DistribCacheStats {
584 let total = self.stats.total_lookups.load(Ordering::Relaxed);
585 let l1_hits = self.stats.l1_hits.load(Ordering::Relaxed);
586 let l2_hits = self.stats.l2_hits.load(Ordering::Relaxed);
587 let l3_hits = self.stats.l3_hits.load(Ordering::Relaxed);
588 let _misses = self.stats.total_misses.load(Ordering::Relaxed);
589
590 DistribCacheStats {
591 l1: self.l1_hot.stats(),
592 l2: self.l2_warm.stats(),
593 l3: self.l3_distributed.stats(),
594 overall_hit_ratio: if total > 0 {
595 (l1_hits + l2_hits + l3_hits) as f64 / total as f64
596 } else {
597 0.0
598 },
599 time_saved_seconds: self.stats.time_saved_us.load(Ordering::Relaxed) as f64
600 / 1_000_000.0,
601 queries_avoided: self.stats.queries_avoided.load(Ordering::Relaxed),
602 }
603 }
604
605 pub fn heatmap(&self) -> HeatmapData {
607 self.heatmap.generate_heatmap()
608 }
609
610 pub fn workload_distribution(&self) -> WorkloadDistribution {
612 self.scheduler.get_distribution()
613 }
614
615 pub async fn start(&self) -> CacheResult<()> {
617 if let Some(wal_endpoint) = &self.config.wal_endpoint {
619 self.invalidator.start(wal_endpoint).await?;
620 }
621
622 if self.config.prefetch_enabled {
624 self.prefetcher.start().await;
625 }
626
627 Ok(())
628 }
629
630 pub async fn stop(&self) -> CacheResult<()> {
632 self.invalidator.stop().await;
633 self.prefetcher.stop().await;
634 Ok(())
635 }
636}
637
638#[derive(Debug, Clone)]
640pub struct DistribCacheStats {
641 pub l1: TierStats,
643 pub l2: TierStats,
645 pub l3: TierStats,
647 pub overall_hit_ratio: f64,
649 pub time_saved_seconds: f64,
651 pub queries_avoided: u64,
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// A fingerprint keeps the upper-cased statement and its table, with no parameter hash.
    #[test]
    fn test_query_fingerprint() {
        let fp = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 42");
        assert!(fp.template.contains("SELECT"));
        assert!(fp.template.contains("USERS"));
        assert!(fp.tables.contains(&"USERS".to_string()));
        assert!(fp.param_hash.is_none());
    }

    /// Queries differing only in a numeric literal produce the same fingerprint.
    #[test]
    fn test_query_fingerprint_normalization() {
        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 1");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 2");
        assert_eq!(fp1.template, fp2.template);
        assert_eq!(fp1.tables, fp2.tables);
        assert_eq!(fp1, fp2);
    }

    /// Queries differing only in a string literal or in spacing share a template.
    #[test]
    fn test_query_fingerprint_string_literals() {
        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE name = 'alice'");
        let fp2 = QueryFingerprint::from_query("select *   from users where name = 'bob'");
        assert_eq!(fp1.template, fp2.template);
        assert!(!fp1.template.contains("ALICE"));
        assert!(fp1.template.ends_with('?'));
    }

    /// `with_params` is deterministic, distinguishes parameter values, and only
    /// touches `param_hash`; `to_bytes` appends the 8-byte hash to the template.
    #[test]
    fn test_query_fingerprint_params() {
        let base = QueryFingerprint::from_query("SELECT * FROM users WHERE id = 1");
        let a = base.clone().with_params(&["1"]);
        let b = base.clone().with_params(&["1"]);
        let c = base.clone().with_params(&["2"]);

        assert!(a.param_hash.is_some());
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert_eq!(a.template, base.template);
        assert_eq!(a.tables, base.tables);

        assert_eq!(base.to_bytes(), base.template.as_bytes());
        assert_eq!(a.to_bytes().len(), base.template.len() + 8);
        assert!(a.to_bytes().starts_with(base.template.as_bytes()));
    }

    #[test]
    fn test_session_id() {
        let sid = SessionId::new("test-session");
        assert_eq!(sid.0, "test-session");
    }

    /// Builders set their own field and leave the remaining defaults alone.
    #[test]
    fn test_query_context() {
        let ctx = QueryContext::new("session-1")
            .with_workload_hint(WorkloadType::OLTP)
            .with_branch("feature-x");

        assert_eq!(ctx.session_id.0, "session-1");
        assert_eq!(ctx.workload_hint, Some(WorkloadType::OLTP));
        assert_eq!(ctx.branch, Some("feature-x".to_string()));
        assert_eq!(ctx.as_of, None);
        assert!(!ctx.is_prepared);
    }

    #[test]
    fn test_query_context_as_of() {
        let ctx = QueryContext::new("session-2").with_as_of(42);

        assert_eq!(ctx.as_of, Some(42));
        assert!(ctx.workload_hint.is_none());
        assert!(ctx.branch.is_none());
    }
}