Skip to main content

heliosdb_proxy/distribcache/ai/
semantic.rs

1//! Semantic query cache
2//!
3//! Caches query results based on semantic similarity of queries.
4//! Uses embedding-based lookup to find similar queries and return cached results.
5//!
6//! # Branch-Aware Caching
7//!
8//! The cache supports branch-aware lookups for time-travel queries:
9//! - Entries can be scoped to specific branches
10//! - Lookups can filter by branch context
11//! - Cross-branch semantic search is supported
12//!
13//! # AI/Agent Optimizations
14//!
15//! - Session affinity for agent conversations
16//! - Workload-aware TTL adjustments
17//! - RAG-specific caching strategies
18
19use dashmap::DashMap;
20use std::collections::BinaryHeap;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
/// Identifier assigned to every vector stored in the semantic index.
pub type VectorId = u64;

/// Name of a branch; used to scope branch-aware cache entries.
pub type BranchId = String;

/// Identifier of an agent session; used for conversation affinity.
pub type SessionId = String;
32
33/// Branch context for cache entries
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
35pub struct BranchContext {
36    /// Branch name (e.g., "main", "feature-x")
37    pub branch: BranchId,
38    /// Optional snapshot timestamp for time-travel queries
39    pub snapshot_at: Option<u64>,
40}
41
42impl BranchContext {
43    /// Create a new branch context
44    pub fn new(branch: impl Into<String>) -> Self {
45        Self {
46            branch: branch.into(),
47            snapshot_at: None,
48        }
49    }
50
51    /// Create branch context with snapshot time
52    pub fn with_snapshot(branch: impl Into<String>, snapshot: u64) -> Self {
53        Self {
54            branch: branch.into(),
55            snapshot_at: Some(snapshot),
56        }
57    }
58
59    /// Main branch context
60    pub fn main() -> Self {
61        Self::new("main")
62    }
63
64    /// Check if this context is compatible with another (for cache hits)
65    pub fn is_compatible(&self, other: &BranchContext) -> bool {
66        if self.branch != other.branch {
67            return false;
68        }
69        // Snapshot compatibility: either both None, or entry snapshot <= query snapshot
70        match (self.snapshot_at, other.snapshot_at) {
71            (None, None) => true,
72            (Some(entry_snap), Some(query_snap)) => entry_snap <= query_snap,
73            (None, Some(_)) => true, // Current data valid for any snapshot
74            (Some(_), None) => false, // Historical entry not valid for current query
75        }
76    }
77}
78
79impl Default for BranchContext {
80    fn default() -> Self {
81        Self::main()
82    }
83}
84
/// Kind of AI workload a cache entry belongs to.
///
/// `SemanticEntry::workload_ttl` derives an entry's effective TTL from it,
/// and context-aware lookups prefer entries of the same workload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AIWorkloadContext {
    /// Retrieval step of a RAG pipeline (5-minute TTL).
    RAGRetrieval,
    /// Generation step of a RAG pipeline (30-minute TTL).
    RAGGeneration,
    /// Turn in an agent conversation (1-hour TTL, session-aware lookups).
    AgentConversation,
    /// Result of a deterministic tool call (24-hour TTL).
    ToolResult,
    /// Anything else; uses the entry's own TTL.
    General,
}

impl Default for AIWorkloadContext {
    /// Entries are general-purpose unless a workload is specified.
    fn default() -> Self {
        AIWorkloadContext::General
    }
}
105
106/// Query embedding vector
107pub type Embedding = Vec<f32>;
108
109/// Semantic query entry with branch and session awareness
110#[derive(Debug, Clone)]
111pub struct SemanticEntry {
112    /// Entry ID
113    pub id: VectorId,
114    /// Original query
115    pub query: String,
116    /// Query embedding
117    pub embedding: Embedding,
118    /// Cached result
119    pub result: serde_json::Value,
120    /// Creation time
121    pub created_at: Instant,
122    /// TTL
123    pub ttl: Duration,
124    /// Access count
125    pub access_count: u64,
126    /// Branch context (for branch-aware caching)
127    pub branch_context: Option<BranchContext>,
128    /// Session ID (for agent conversation affinity)
129    pub session_id: Option<SessionId>,
130    /// AI workload type
131    pub workload: AIWorkloadContext,
132    /// Tables referenced by this query (for invalidation)
133    pub tables: Vec<String>,
134}
135
136impl SemanticEntry {
137    /// Create a new semantic entry
138    pub fn new(
139        id: VectorId,
140        query: impl Into<String>,
141        embedding: Embedding,
142        result: serde_json::Value,
143    ) -> Self {
144        Self {
145            id,
146            query: query.into(),
147            embedding,
148            result,
149            created_at: Instant::now(),
150            ttl: Duration::from_secs(3600), // Default 1 hour
151            access_count: 0,
152            branch_context: None,
153            session_id: None,
154            workload: AIWorkloadContext::default(),
155            tables: Vec::new(),
156        }
157    }
158
159    /// Set TTL
160    pub fn with_ttl(mut self, ttl: Duration) -> Self {
161        self.ttl = ttl;
162        self
163    }
164
165    /// Set branch context
166    pub fn with_branch(mut self, branch: BranchContext) -> Self {
167        self.branch_context = Some(branch);
168        self
169    }
170
171    /// Set session ID for agent affinity
172    pub fn with_session(mut self, session: impl Into<String>) -> Self {
173        self.session_id = Some(session.into());
174        self
175    }
176
177    /// Set AI workload context
178    pub fn with_workload(mut self, workload: AIWorkloadContext) -> Self {
179        self.workload = workload;
180        self
181    }
182
183    /// Set referenced tables for invalidation tracking
184    pub fn with_tables(mut self, tables: Vec<String>) -> Self {
185        self.tables = tables;
186        self
187    }
188
189    /// Get workload-adjusted TTL
190    pub fn workload_ttl(&self) -> Duration {
191        match self.workload {
192            AIWorkloadContext::RAGRetrieval => Duration::from_secs(300), // 5 min - fast refresh
193            AIWorkloadContext::RAGGeneration => Duration::from_secs(1800), // 30 min - slower refresh
194            AIWorkloadContext::AgentConversation => Duration::from_secs(3600), // 1 hour - session lifetime
195            AIWorkloadContext::ToolResult => Duration::from_secs(86400), // 24 hours - deterministic
196            AIWorkloadContext::General => self.ttl,
197        }
198    }
199
200    /// Check if expired (considering workload-adjusted TTL)
201    pub fn is_expired(&self) -> bool {
202        self.created_at.elapsed() > self.workload_ttl()
203    }
204
205    /// Check if entry matches branch context
206    pub fn matches_branch(&self, query_branch: &BranchContext) -> bool {
207        match &self.branch_context {
208            None => true, // No branch restriction
209            Some(entry_branch) => entry_branch.is_compatible(query_branch),
210        }
211    }
212
213    /// Check if entry belongs to session
214    pub fn matches_session(&self, session: &SessionId) -> bool {
215        match &self.session_id {
216            None => true,
217            Some(entry_session) => entry_session == session,
218        }
219    }
220
221    /// Approximate size in bytes
222    pub fn size(&self) -> usize {
223        self.query.len() +
224        self.embedding.len() * 4 +
225        self.result.to_string().len() +
226        self.tables.iter().map(|t| t.len()).sum::<usize>() +
227        self.session_id.as_ref().map(|s| s.len()).unwrap_or(0) +
228        self.branch_context.as_ref().map(|b| b.branch.len() + 8).unwrap_or(0) +
229        96
230    }
231}
232
233/// Similarity search result
234#[derive(Debug, Clone)]
235pub struct SimilarityResult {
236    /// Entry ID
237    pub id: VectorId,
238    /// Similarity score (0.0 - 1.0)
239    pub similarity: f32,
240    /// The entry
241    pub entry: SemanticEntry,
242}
243
244impl PartialEq for SimilarityResult {
245    fn eq(&self, other: &Self) -> bool {
246        self.similarity == other.similarity
247    }
248}
249
250impl Eq for SimilarityResult {}
251
252impl PartialOrd for SimilarityResult {
253    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
254        Some(self.cmp(other))
255    }
256}
257
258impl Ord for SimilarityResult {
259    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
260        // Reverse order for min-heap (we want highest similarity)
261        other.similarity.partial_cmp(&self.similarity)
262            .unwrap_or(std::cmp::Ordering::Equal)
263    }
264}
265
266/// Simple HNSW-like index for semantic search
267/// (Simplified implementation - production would use a proper HNSW library)
268pub struct SemanticIndex {
269    /// All vectors with their IDs
270    vectors: DashMap<VectorId, Embedding>,
271
272    /// Index configuration
273    config: SemanticIndexConfig,
274
275    /// Next ID
276    next_id: AtomicU64,
277}
278
279/// Semantic index configuration
280#[derive(Debug, Clone)]
281pub struct SemanticIndexConfig {
282    /// Maximum connections per node (M parameter)
283    pub max_connections: usize,
284    /// Search expansion factor (ef parameter)
285    pub ef_search: usize,
286    /// Embedding dimension
287    pub dimension: usize,
288}
289
290impl Default for SemanticIndexConfig {
291    fn default() -> Self {
292        Self {
293            max_connections: 16,
294            ef_search: 100,
295            dimension: 384, // Common embedding size
296        }
297    }
298}
299
300impl SemanticIndex {
301    /// Create a new semantic index
302    pub fn new(config: SemanticIndexConfig) -> Self {
303        Self {
304            vectors: DashMap::new(),
305            config,
306            next_id: AtomicU64::new(1),
307        }
308    }
309
310    /// Insert a vector and return its ID
311    pub fn insert(&self, embedding: Embedding) -> VectorId {
312        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
313        self.vectors.insert(id, embedding);
314        id
315    }
316
317    /// Remove a vector
318    pub fn remove(&self, id: VectorId) {
319        self.vectors.remove(&id);
320    }
321
322    /// Search for k nearest neighbors
323    pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorId, f32)> {
324        // Brute force search (production would use HNSW)
325        let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, VectorId)> = BinaryHeap::new();
326
327        for entry in self.vectors.iter() {
328            let similarity = cosine_similarity(query, entry.value());
329            // Convert to integer for ordering (multiply by 1M for precision)
330            let sim_int = (similarity * 1_000_000.0) as i64;
331            heap.push((std::cmp::Reverse(sim_int), *entry.key()));
332
333            if heap.len() > k {
334                heap.pop();
335            }
336        }
337
338        // Extract results in descending similarity order
339        let mut results: Vec<_> = heap.into_iter()
340            .map(|(std::cmp::Reverse(sim), id)| (id, sim as f32 / 1_000_000.0))
341            .collect();
342
343        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
344        results
345    }
346
347    /// Get vector count
348    pub fn len(&self) -> usize {
349        self.vectors.len()
350    }
351
352    /// Check if empty
353    pub fn is_empty(&self) -> bool {
354        self.vectors.is_empty()
355    }
356
357    /// Clear all vectors
358    pub fn clear(&self) {
359        self.vectors.clear();
360    }
361}
362
/// Compute cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when the slices differ in
/// length, are empty, either vector has zero magnitude, or a component is
/// non-finite.
///
/// Sums are accumulated in `f64`, and each norm is square-rooted on its own.
/// Computing `sqrt(|a|² · |b|²)` entirely in `f32` underflows to zero for tiny
/// components, so identical vectors scored 0.0. It also overflows to infinity
/// for large components.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denominator = norm_a.sqrt() * norm_b.sqrt();
    if denominator == 0.0 || !denominator.is_finite() {
        return 0.0;
    }

    // Rounding can push |cos| marginally past 1; keep the result in range.
    (dot / denominator).clamp(-1.0, 1.0) as f32
}
386
387/// Semantic query cache
388pub struct SemanticQueryCache {
389    /// Query index
390    index: SemanticIndex,
391
392    /// Cached entries
393    entries: DashMap<VectorId, SemanticEntry>,
394
395    /// Similarity threshold for cache hit
396    threshold: f32,
397
398    /// Maximum entries
399    max_entries: usize,
400
401    /// Statistics
402    stats: SemanticCacheStats,
403}
404
405/// Semantic cache statistics
406#[derive(Debug, Default)]
407struct SemanticCacheStats {
408    hits: AtomicU64,
409    misses: AtomicU64,
410    semantic_hits: AtomicU64,
411    exact_hits: AtomicU64,
412    insertions: AtomicU64,
413    evictions: AtomicU64,
414}
415
416impl SemanticQueryCache {
417    /// Create a new semantic query cache with default max entries
418    pub fn new(threshold: f32) -> Self {
419        Self::with_capacity(threshold, 10000)
420    }
421
422    /// Create a new semantic query cache with specified capacity
423    pub fn with_capacity(threshold: f32, max_entries: usize) -> Self {
424        Self {
425            index: SemanticIndex::new(SemanticIndexConfig::default()),
426            entries: DashMap::new(),
427            threshold,
428            max_entries,
429            stats: SemanticCacheStats::default(),
430        }
431    }
432
433    /// Create with custom index config
434    pub fn with_config(threshold: f32, max_entries: usize, index_config: SemanticIndexConfig) -> Self {
435        Self {
436            index: SemanticIndex::new(index_config),
437            entries: DashMap::new(),
438            threshold,
439            max_entries,
440            stats: SemanticCacheStats::default(),
441        }
442    }
443
444    /// Lookup by semantic similarity
445    pub fn lookup(&self, embedding: &[f32]) -> Option<SimilarityResult> {
446        // Search for nearest neighbor
447        let results = self.index.search(embedding, 1);
448
449        if let Some((id, similarity)) = results.first() {
450            if *similarity >= self.threshold {
451                if let Some(entry) = self.entries.get(id) {
452                    if !entry.is_expired() {
453                        self.stats.hits.fetch_add(1, Ordering::Relaxed);
454
455                        if *similarity > 0.999 {
456                            self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
457                        } else {
458                            self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
459                        }
460
461                        return Some(SimilarityResult {
462                            id: *id,
463                            similarity: *similarity,
464                            entry: entry.clone(),
465                        });
466                    } else {
467                        // Remove expired entry
468                        drop(entry);
469                        self.remove(*id);
470                    }
471                }
472            }
473        }
474
475        self.stats.misses.fetch_add(1, Ordering::Relaxed);
476        None
477    }
478
479    /// Lookup with custom threshold
480    pub fn lookup_with_threshold(&self, embedding: &[f32], threshold: f32) -> Option<SimilarityResult> {
481        let results = self.index.search(embedding, 1);
482
483        if let Some((id, similarity)) = results.first() {
484            if *similarity >= threshold {
485                if let Some(entry) = self.entries.get(id) {
486                    if !entry.is_expired() {
487                        return Some(SimilarityResult {
488                            id: *id,
489                            similarity: *similarity,
490                            entry: entry.clone(),
491                        });
492                    }
493                }
494            }
495        }
496
497        None
498    }
499
500    /// Find k most similar queries
501    pub fn find_similar(&self, embedding: &[f32], k: usize) -> Vec<SimilarityResult> {
502        let results = self.index.search(embedding, k);
503
504        results.into_iter()
505            .filter_map(|(id, similarity)| {
506                self.entries.get(&id).and_then(|entry| {
507                    if !entry.is_expired() {
508                        Some(SimilarityResult {
509                            id,
510                            similarity,
511                            entry: entry.clone(),
512                        })
513                    } else {
514                        None
515                    }
516                })
517            })
518            .collect()
519    }
520
521    /// Lookup with branch context filtering
522    ///
523    /// Returns cached entry only if it's compatible with the given branch context.
524    /// This enables branch-aware caching for time-travel queries.
525    pub fn lookup_with_branch(
526        &self,
527        embedding: &[f32],
528        branch: &BranchContext,
529    ) -> Option<SimilarityResult> {
530        // Search for multiple candidates to filter by branch
531        let results = self.index.search(embedding, 10);
532
533        for (id, similarity) in results {
534            if similarity < self.threshold {
535                break; // Results are sorted by similarity
536            }
537
538            if let Some(entry) = self.entries.get(&id) {
539                if !entry.is_expired() && entry.matches_branch(branch) {
540                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
541                    if similarity > 0.999 {
542                        self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
543                    } else {
544                        self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
545                    }
546
547                    return Some(SimilarityResult {
548                        id,
549                        similarity,
550                        entry: entry.clone(),
551                    });
552                }
553            }
554        }
555
556        self.stats.misses.fetch_add(1, Ordering::Relaxed);
557        None
558    }
559
560    /// Lookup with session affinity for agent conversations
561    ///
562    /// Prioritizes entries from the same session for better conversation context.
563    pub fn lookup_with_session(
564        &self,
565        embedding: &[f32],
566        session: &SessionId,
567    ) -> Option<SimilarityResult> {
568        let results = self.index.search(embedding, 20);
569
570        // First pass: look for same-session matches
571        for (id, similarity) in &results {
572            if *similarity < self.threshold {
573                break;
574            }
575
576            if let Some(entry) = self.entries.get(id) {
577                if !entry.is_expired() && entry.matches_session(session) {
578                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
579                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
580
581                    return Some(SimilarityResult {
582                        id: *id,
583                        similarity: *similarity,
584                        entry: entry.clone(),
585                    });
586                }
587            }
588        }
589
590        // Second pass: any matching entry (cross-session)
591        for (id, similarity) in &results {
592            if *similarity < self.threshold {
593                break;
594            }
595
596            if let Some(entry) = self.entries.get(id) {
597                if !entry.is_expired() {
598                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
599                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
600
601                    return Some(SimilarityResult {
602                        id: *id,
603                        similarity: *similarity,
604                        entry: entry.clone(),
605                    });
606                }
607            }
608        }
609
610        self.stats.misses.fetch_add(1, Ordering::Relaxed);
611        None
612    }
613
614    /// Lookup with full AI context (branch + session + workload)
615    ///
616    /// Most comprehensive lookup that considers:
617    /// - Branch context for time-travel
618    /// - Session affinity for agent conversations
619    /// - Workload type for TTL and priority
620    pub fn lookup_with_context(
621        &self,
622        embedding: &[f32],
623        branch: Option<&BranchContext>,
624        session: Option<&SessionId>,
625        workload: AIWorkloadContext,
626    ) -> Option<SimilarityResult> {
627        let results = self.index.search(embedding, 20);
628
629        // Priority 1: Same session + same branch + same workload
630        for (id, similarity) in &results {
631            if *similarity < self.threshold {
632                break;
633            }
634
635            if let Some(entry) = self.entries.get(id) {
636                let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
637                let session_match = session.map(|s| entry.matches_session(s)).unwrap_or(false);
638                let workload_match = entry.workload == workload;
639
640                if !entry.is_expired() && branch_match && session_match && workload_match {
641                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
642                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
643                    return Some(SimilarityResult {
644                        id: *id,
645                        similarity: *similarity,
646                        entry: entry.clone(),
647                    });
648                }
649            }
650        }
651
652        // Priority 2: Same branch + same workload
653        for (id, similarity) in &results {
654            if *similarity < self.threshold {
655                break;
656            }
657
658            if let Some(entry) = self.entries.get(id) {
659                let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
660                let workload_match = entry.workload == workload;
661
662                if !entry.is_expired() && branch_match && workload_match {
663                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
664                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
665                    return Some(SimilarityResult {
666                        id: *id,
667                        similarity: *similarity,
668                        entry: entry.clone(),
669                    });
670                }
671            }
672        }
673
674        // Priority 3: Same branch only
675        for (id, similarity) in &results {
676            if *similarity < self.threshold {
677                break;
678            }
679
680            if let Some(entry) = self.entries.get(id) {
681                let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
682
683                if !entry.is_expired() && branch_match {
684                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
685                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
686                    return Some(SimilarityResult {
687                        id: *id,
688                        similarity: *similarity,
689                        entry: entry.clone(),
690                    });
691                }
692            }
693        }
694
695        self.stats.misses.fetch_add(1, Ordering::Relaxed);
696        None
697    }
698
699    /// Find similar entries within a branch
700    pub fn find_similar_in_branch(
701        &self,
702        embedding: &[f32],
703        branch: &BranchContext,
704        k: usize,
705    ) -> Vec<SimilarityResult> {
706        // Search for more candidates since we'll filter
707        let results = self.index.search(embedding, k * 3);
708
709        results.into_iter()
710            .filter_map(|(id, similarity)| {
711                self.entries.get(&id).and_then(|entry| {
712                    if !entry.is_expired() && entry.matches_branch(branch) {
713                        Some(SimilarityResult {
714                            id,
715                            similarity,
716                            entry: entry.clone(),
717                        })
718                    } else {
719                        None
720                    }
721                })
722            })
723            .take(k)
724            .collect()
725    }
726
727    /// Invalidate entries by table name
728    ///
729    /// Used when WAL invalidation detects changes to a table.
730    pub fn invalidate_by_table(&self, table: &str) -> usize {
731        let to_remove: Vec<_> = self.entries.iter()
732            .filter(|e| e.tables.iter().any(|t| t == table))
733            .map(|e| *e.key())
734            .collect();
735
736        let count = to_remove.len();
737        for id in to_remove {
738            self.remove(id);
739        }
740        count
741    }
742
743    /// Invalidate entries by branch
744    pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
745        let to_remove: Vec<_> = self.entries.iter()
746            .filter(|e| {
747                e.branch_context.as_ref()
748                    .map(|b| &b.branch == branch)
749                    .unwrap_or(false)
750            })
751            .map(|e| *e.key())
752            .collect();
753
754        let count = to_remove.len();
755        for id in to_remove {
756            self.remove(id);
757        }
758        count
759    }
760
761    /// Insert a new entry
762    pub fn insert(&self, query: impl Into<String>, embedding: Embedding, result: serde_json::Value) -> VectorId {
763        // Evict if at capacity
764        while self.entries.len() >= self.max_entries {
765            self.evict_one();
766        }
767
768        // Insert into index
769        let id = self.index.insert(embedding.clone());
770
771        // Create and store entry
772        let entry = SemanticEntry::new(id, query, embedding, result);
773        self.entries.insert(id, entry);
774
775        self.stats.insertions.fetch_add(1, Ordering::Relaxed);
776        id
777    }
778
779    /// Insert with TTL
780    pub fn insert_with_ttl(
781        &self,
782        query: impl Into<String>,
783        embedding: Embedding,
784        result: serde_json::Value,
785        ttl: Duration,
786    ) -> VectorId {
787        while self.entries.len() >= self.max_entries {
788            self.evict_one();
789        }
790
791        let id = self.index.insert(embedding.clone());
792        let entry = SemanticEntry::new(id, query, embedding, result).with_ttl(ttl);
793        self.entries.insert(id, entry);
794
795        self.stats.insertions.fetch_add(1, Ordering::Relaxed);
796        id
797    }
798
799    /// Insert with full AI context (branch, session, workload, tables)
800    ///
801    /// This is the recommended insertion method for AI/Agent workloads
802    /// as it enables branch-aware caching, session affinity, and
803    /// workload-specific TTL management.
804    pub fn insert_with_context(
805        &self,
806        query: impl Into<String>,
807        embedding: Embedding,
808        result: serde_json::Value,
809        branch: Option<BranchContext>,
810        session: Option<SessionId>,
811        workload: AIWorkloadContext,
812        tables: Vec<String>,
813    ) -> VectorId {
814        while self.entries.len() >= self.max_entries {
815            self.evict_one();
816        }
817
818        let id = self.index.insert(embedding.clone());
819        let mut entry = SemanticEntry::new(id, query, embedding, result)
820            .with_workload(workload)
821            .with_tables(tables);
822
823        if let Some(b) = branch {
824            entry = entry.with_branch(b);
825        }
826        if let Some(s) = session {
827            entry = entry.with_session(s);
828        }
829
830        self.entries.insert(id, entry);
831        self.stats.insertions.fetch_add(1, Ordering::Relaxed);
832        id
833    }
834
835    /// Insert for RAG retrieval workload
836    ///
837    /// Optimized TTL for fast-refresh retrieval phase.
838    pub fn insert_rag_retrieval(
839        &self,
840        query: impl Into<String>,
841        embedding: Embedding,
842        result: serde_json::Value,
843        tables: Vec<String>,
844    ) -> VectorId {
845        self.insert_with_context(
846            query,
847            embedding,
848            result,
849            None,
850            None,
851            AIWorkloadContext::RAGRetrieval,
852            tables,
853        )
854    }
855
856    /// Insert for agent conversation
857    ///
858    /// Session-aware with longer TTL for conversation context.
859    pub fn insert_agent_response(
860        &self,
861        query: impl Into<String>,
862        embedding: Embedding,
863        result: serde_json::Value,
864        session: SessionId,
865        branch: Option<BranchContext>,
866    ) -> VectorId {
867        self.insert_with_context(
868            query,
869            embedding,
870            result,
871            branch,
872            Some(session),
873            AIWorkloadContext::AgentConversation,
874            Vec::new(),
875        )
876    }
877
878    /// Insert deterministic tool result
879    ///
880    /// Long TTL for deterministic tool calls (e.g., math, date formatting).
881    pub fn insert_tool_result(
882        &self,
883        query: impl Into<String>,
884        embedding: Embedding,
885        result: serde_json::Value,
886    ) -> VectorId {
887        self.insert_with_context(
888            query,
889            embedding,
890            result,
891            None,
892            None,
893            AIWorkloadContext::ToolResult,
894            Vec::new(),
895        )
896    }
897
898    /// Remove an entry
899    pub fn remove(&self, id: VectorId) {
900        self.index.remove(id);
901        self.entries.remove(&id);
902    }
903
904    /// Evict one entry (oldest by creation time)
905    fn evict_one(&self) {
906        let mut oldest_id = None;
907        let mut oldest_time = Instant::now();
908
909        for entry in self.entries.iter() {
910            if entry.created_at < oldest_time {
911                oldest_time = entry.created_at;
912                oldest_id = Some(*entry.key());
913            }
914        }
915
916        if let Some(id) = oldest_id {
917            self.remove(id);
918            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
919        }
920    }
921
922    /// Remove expired entries
923    pub fn cleanup_expired(&self) {
924        let expired: Vec<_> = self.entries.iter()
925            .filter(|e| e.is_expired())
926            .map(|e| *e.key())
927            .collect();
928
929        for id in expired {
930            self.remove(id);
931        }
932    }
933
934    /// Clear all entries
935    pub fn clear(&self) {
936        self.index.clear();
937        self.entries.clear();
938    }
939
940    /// Get entry count
941    pub fn len(&self) -> usize {
942        self.entries.len()
943    }
944
945    /// Check if empty
946    pub fn is_empty(&self) -> bool {
947        self.entries.is_empty()
948    }
949
950    /// Get statistics
951    pub fn stats(&self) -> SemanticCacheStatsSnapshot {
952        let hits = self.stats.hits.load(Ordering::Relaxed);
953        let misses = self.stats.misses.load(Ordering::Relaxed);
954        let total = hits + misses;
955
956        SemanticCacheStatsSnapshot {
957            entries: self.entries.len(),
958            threshold: self.threshold,
959            hits,
960            misses,
961            hit_rate: if total > 0 { hits as f64 / total as f64 } else { 0.0 },
962            semantic_hits: self.stats.semantic_hits.load(Ordering::Relaxed),
963            exact_hits: self.stats.exact_hits.load(Ordering::Relaxed),
964            insertions: self.stats.insertions.load(Ordering::Relaxed),
965            evictions: self.stats.evictions.load(Ordering::Relaxed),
966        }
967    }
968}
969
/// Point-in-time copy of the semantic cache's counters and size.
///
/// Produced by the cache's `stats()` method. Counter fields hold the values
/// read at snapshot time; they do not update afterwards.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SemanticCacheStatsSnapshot {
    /// Number of entries stored in the cache when the snapshot was taken.
    pub entries: usize,
    /// Similarity threshold the cache was configured with for lookups.
    pub threshold: f32,
    /// Lookups that returned a cached result.
    pub hits: u64,
    /// Lookups that found no acceptable cached result.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when no lookups were recorded.
    pub hit_rate: f64,
    /// Hits served by a similar (not identical) embedding match.
    pub semantic_hits: u64,
    /// Hits served by an exact embedding match.
    pub exact_hits: u64,
    /// Entries inserted into the cache.
    pub insertions: u64,
    /// Entries removed by capacity eviction; expired-entry cleanup is not
    /// counted here.
    pub evictions: u64,
}
983
#[cfg(test)]
mod tests {
    //! Unit tests for the semantic index, the semantic query cache, and its
    //! branch-, session-, and workload-aware lookup and invalidation paths.

    use super::*;
    use serde_json::json;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        let d = vec![0.707, 0.707, 0.0];
        let sim = cosine_similarity(&a, &d);
        assert!((sim - 0.707).abs() < 0.01);
    }

    #[test]
    fn test_semantic_index() {
        let index = SemanticIndex::new(SemanticIndexConfig::default());

        let id1 = index.insert(vec![1.0, 0.0, 0.0]);
        let id2 = index.insert(vec![0.9, 0.1, 0.0]);
        let id3 = index.insert(vec![0.0, 1.0, 0.0]);

        // Search for vector similar to [1.0, 0.0, 0.0]
        let results = index.search(&[1.0, 0.0, 0.0], 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1); // Most similar
        assert_eq!(results[1].0, id2); // Second most similar
        // The orthogonal vector must not make the top-2.
        assert!(results.iter().all(|r| r.0 != id3));
    }

    #[test]
    fn test_semantic_cache_insert_lookup() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        let id = cache.insert(
            "SELECT * FROM users WHERE name = 'test'",
            embedding.clone(),
            json!({"count": 5}),
        );

        // Exact lookup
        let result = cache.lookup(&embedding);
        assert!(result.is_some());
        let res = result.unwrap();
        assert_eq!(res.id, id);
        assert!(res.similarity > 0.999);
    }

    #[test]
    fn test_semantic_similarity_lookup() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        // Insert original query
        cache.insert(
            "SELECT * FROM users WHERE id = 1",
            vec![1.0, 0.0, 0.0],
            json!({"user": "alice"}),
        );

        // Lookup with similar query embedding
        let similar_embedding = vec![0.95, 0.05, 0.0];
        let result = cache.lookup(&similar_embedding);

        assert!(result.is_some());
        let res = result.unwrap();
        assert!(res.similarity >= 0.9);
    }

    #[test]
    fn test_threshold_rejection() {
        let cache = SemanticQueryCache::with_capacity(0.95, 100);

        cache.insert(
            "SELECT * FROM orders",
            vec![1.0, 0.0, 0.0],
            json!({"total": 100}),
        );

        // Query too different (below threshold)
        let different_embedding = vec![0.7, 0.7, 0.0];
        let result = cache.lookup(&different_embedding);
        assert!(result.is_none());
    }

    #[test]
    fn test_find_similar() {
        let cache = SemanticQueryCache::with_capacity(0.5, 100);

        cache.insert("query1", vec![1.0, 0.0, 0.0], json!(1));
        cache.insert("query2", vec![0.9, 0.1, 0.0], json!(2));
        cache.insert("query3", vec![0.8, 0.2, 0.0], json!(3));
        cache.insert("query4", vec![0.0, 1.0, 0.0], json!(4));

        let similar = cache.find_similar(&[1.0, 0.0, 0.0], 3);

        assert_eq!(similar.len(), 3);
        // First should be most similar
        assert!(similar[0].similarity > similar[1].similarity);
        assert!(similar[1].similarity > similar[2].similarity);
    }

    #[test]
    fn test_expiration() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_ttl(
            "expiring query",
            embedding.clone(),
            json!({"expires": true}),
            Duration::from_millis(1),
        );

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(10));

        let result = cache.lookup(&embedding);
        assert!(result.is_none());
    }

    #[test]
    fn test_eviction() {
        let cache = SemanticQueryCache::with_capacity(0.9, 3);

        // Fill cache
        for i in 0..3 {
            cache.insert(
                format!("query{}", i),
                vec![i as f32, 0.0, 0.0],
                json!(i),
            );
        }

        assert_eq!(cache.len(), 3);

        // Insert one more (should evict)
        cache.insert("query3", vec![3.0, 0.0, 0.0], json!(3));

        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn test_stats() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert("test query", embedding.clone(), json!(1));

        // Hits
        cache.lookup(&embedding);
        cache.lookup(&embedding);

        // Miss
        cache.lookup(&[0.0, 1.0, 0.0]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.exact_hits, 2);
        assert_eq!(stats.insertions, 1);
    }

    #[test]
    fn test_branch_context_compatibility() {
        let main = BranchContext::main();
        let feature = BranchContext::new("feature-x");
        let snapshot = BranchContext::with_snapshot("main", 1000);
        let later_snapshot = BranchContext::with_snapshot("main", 2000);

        // Same branch
        assert!(main.is_compatible(&main));
        assert!(!main.is_compatible(&feature));

        // Snapshot compatibility
        assert!(snapshot.is_compatible(&later_snapshot)); // Entry @ 1000 valid for query @ 2000
        assert!(!later_snapshot.is_compatible(&snapshot)); // Entry @ 2000 not valid for query @ 1000
    }

    #[test]
    fn test_lookup_with_branch() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        // Insert entry for main branch
        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_context(
            "SELECT * FROM users",
            embedding.clone(),
            json!({"users": []}),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        // Insert entry for feature branch
        let embedding2 = vec![0.95, 0.05, 0.0];
        cache.insert_with_context(
            "SELECT * FROM users",
            embedding2.clone(),
            json!({"users": ["new_user"]}),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        // Lookup for main should find main entry
        let main_result = cache.lookup_with_branch(&embedding, &BranchContext::main());
        assert!(main_result.is_some());
        assert_eq!(main_result.unwrap().entry.branch_context.as_ref().unwrap().branch, "main");

        // Lookup for feature should find feature entry
        let feature_result = cache.lookup_with_branch(&embedding2, &BranchContext::new("feature-x"));
        assert!(feature_result.is_some());
        assert_eq!(feature_result.unwrap().entry.branch_context.as_ref().unwrap().branch, "feature-x");
    }

    #[test]
    fn test_lookup_with_session() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);
        let session1 = "session-001".to_string();
        let session2 = "session-002".to_string();

        // Insert for session 1
        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_agent_response(
            "What is the weather?",
            embedding.clone(),
            json!({"weather": "sunny"}),
            session1.clone(),
            None,
        );

        // Insert similar query for session 2 (last use of `session2`, so move it)
        let embedding2 = vec![0.98, 0.02, 0.0];
        cache.insert_agent_response(
            "How's the weather?",
            embedding2,
            json!({"weather": "cloudy"}),
            session2,
            None,
        );

        // Lookup with session 1 should prefer session 1 entry
        let result = cache.lookup_with_session(&embedding, &session1);
        assert!(result.is_some());
        assert_eq!(result.unwrap().entry.session_id.as_ref().unwrap(), &session1);
    }

    #[test]
    fn test_lookup_with_context() {
        let cache = SemanticQueryCache::with_capacity(0.8, 100);
        let session = "agent-session".to_string();
        let branch = BranchContext::main();

        // Insert with full context
        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_context(
            "Find users with orders",
            embedding.clone(),
            json!({"users": 42}),
            Some(branch.clone()),
            Some(session.clone()),
            AIWorkloadContext::RAGRetrieval,
            vec!["users".to_string(), "orders".to_string()],
        );

        // Full context match
        let result = cache.lookup_with_context(
            &embedding,
            Some(&branch),
            Some(&session),
            AIWorkloadContext::RAGRetrieval,
        );
        assert!(result.is_some());

        // Different workload still matches (lower priority)
        let result2 = cache.lookup_with_context(
            &embedding,
            Some(&branch),
            None,
            AIWorkloadContext::General,
        );
        assert!(result2.is_some());
    }

    #[test]
    fn test_invalidate_by_table() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        // Insert entries referencing different tables
        cache.insert_with_context(
            "SELECT * FROM users",
            vec![1.0, 0.0, 0.0],
            json!(1),
            None,
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );
        cache.insert_with_context(
            "SELECT * FROM orders",
            vec![0.0, 1.0, 0.0],
            json!(2),
            None,
            None,
            AIWorkloadContext::General,
            vec!["orders".to_string()],
        );
        cache.insert_with_context(
            "SELECT * FROM users JOIN orders",
            vec![0.5, 0.5, 0.0],
            json!(3),
            None,
            None,
            AIWorkloadContext::General,
            vec!["users".to_string(), "orders".to_string()],
        );

        assert_eq!(cache.len(), 3);

        // Invalidate users table
        let removed = cache.invalidate_by_table("users");
        assert_eq!(removed, 2); // users and users+orders entries
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_invalidate_branch() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        // Insert entries for different branches
        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0, 0.0],
            json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );
        cache.insert_with_context(
            "query2",
            vec![0.0, 1.0, 0.0],
            json!(2),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );
        cache.insert_with_context(
            "query3",
            vec![0.0, 0.0, 1.0],
            json!(3),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );

        assert_eq!(cache.len(), 3);

        // Invalidate feature-x branch
        let removed = cache.invalidate_branch(&"feature-x".to_string());
        assert_eq!(removed, 2);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_workload_ttl() {
        // RAG retrieval - short TTL (5 min)
        let rag_entry = SemanticEntry::new(1, "rag query", vec![], json!({}))
            .with_workload(AIWorkloadContext::RAGRetrieval);
        assert_eq!(rag_entry.workload_ttl(), Duration::from_secs(300));

        // Tool result - long TTL (24 hours)
        let tool_entry = SemanticEntry::new(2, "tool query", vec![], json!({}))
            .with_workload(AIWorkloadContext::ToolResult);
        assert_eq!(tool_entry.workload_ttl(), Duration::from_secs(86400));

        // Agent conversation - medium TTL (1 hour)
        let agent_entry = SemanticEntry::new(3, "agent query", vec![], json!({}))
            .with_workload(AIWorkloadContext::AgentConversation);
        assert_eq!(agent_entry.workload_ttl(), Duration::from_secs(3600));
    }

    #[test]
    fn test_find_similar_in_branch() {
        let cache = SemanticQueryCache::with_capacity(0.5, 100);
        let main = BranchContext::main();
        let feature = BranchContext::new("feature-x");

        // Insert entries in main
        for i in 0..3 {
            cache.insert_with_context(
                format!("main query {}", i),
                vec![1.0 - (i as f32 * 0.1), i as f32 * 0.1, 0.0],
                json!(i),
                Some(main.clone()),
                None,
                AIWorkloadContext::General,
                Vec::new(),
            );
        }

        // Insert entries in feature
        for i in 0..2 {
            cache.insert_with_context(
                format!("feature query {}", i),
                vec![0.5, 0.5 + (i as f32 * 0.1), 0.0],
                json!(100 + i),
                Some(feature.clone()),
                None,
                AIWorkloadContext::General,
                Vec::new(),
            );
        }

        // Find similar in main branch only
        let main_results = cache.find_similar_in_branch(&[1.0, 0.0, 0.0], &main, 5);
        assert_eq!(main_results.len(), 3);
        for r in &main_results {
            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "main");
        }

        // Find similar in feature branch only
        let feature_results = cache.find_similar_in_branch(&[0.5, 0.5, 0.0], &feature, 5);
        assert_eq!(feature_results.len(), 2);
        for r in &feature_results {
            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "feature-x");
        }
    }
}