Skip to main content

heliosdb_proxy/distribcache/ai/
semantic.rs

1//! Semantic query cache
2//!
3//! Caches query results based on semantic similarity of queries.
4//! Uses embedding-based lookup to find similar queries and return cached results.
5//!
6//! # Branch-Aware Caching
7//!
8//! The cache supports branch-aware lookups for time-travel queries:
9//! - Entries can be scoped to specific branches
10//! - Lookups can filter by branch context
11//! - Cross-branch semantic search is supported
12//!
13//! # AI/Agent Optimizations
14//!
15//! - Session affinity for agent conversations
16//! - Workload-aware TTL adjustments
17//! - RAG-specific caching strategies
18
19use dashmap::DashMap;
20use std::collections::BinaryHeap;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
24/// Vector ID for semantic indexing
25pub type VectorId = u64;
26
27/// Branch identifier for branch-aware caching
28pub type BranchId = String;
29
30/// Session identifier for agent sessions
31pub type SessionId = String;
32
33/// Branch context for cache entries
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
35pub struct BranchContext {
36    /// Branch name (e.g., "main", "feature-x")
37    pub branch: BranchId,
38    /// Optional snapshot timestamp for time-travel queries
39    pub snapshot_at: Option<u64>,
40}
41
42impl BranchContext {
43    /// Create a new branch context
44    pub fn new(branch: impl Into<String>) -> Self {
45        Self {
46            branch: branch.into(),
47            snapshot_at: None,
48        }
49    }
50
51    /// Create branch context with snapshot time
52    pub fn with_snapshot(branch: impl Into<String>, snapshot: u64) -> Self {
53        Self {
54            branch: branch.into(),
55            snapshot_at: Some(snapshot),
56        }
57    }
58
59    /// Main branch context
60    pub fn main() -> Self {
61        Self::new("main")
62    }
63
64    /// Check if this context is compatible with another (for cache hits)
65    pub fn is_compatible(&self, other: &BranchContext) -> bool {
66        if self.branch != other.branch {
67            return false;
68        }
69        // Snapshot compatibility: either both None, or entry snapshot <= query snapshot
70        match (self.snapshot_at, other.snapshot_at) {
71            (None, None) => true,
72            (Some(entry_snap), Some(query_snap)) => entry_snap <= query_snap,
73            (None, Some(_)) => true,  // Current data valid for any snapshot
74            (Some(_), None) => false, // Historical entry not valid for current query
75        }
76    }
77}
78
79impl Default for BranchContext {
80    fn default() -> Self {
81        Self::main()
82    }
83}
84
/// AI workload context for cache optimization
///
/// The variant stored on a `SemanticEntry` selects its effective TTL in
/// `SemanticEntry::workload_ttl` and is compared for equality by
/// `SemanticQueryCache::lookup_with_context`. `General` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AIWorkloadContext {
    /// RAG retrieval phase - fast, high-throughput
    RAGRetrieval,
    /// RAG generation phase - slower, lower frequency
    RAGGeneration,
    /// Agent conversation - session-aware
    AgentConversation,
    /// Tool call caching - deterministic
    ToolResult,
    /// General semantic query (the `Default` variant)
    #[default]
    General,
}
100
101/// Query embedding vector
102pub type Embedding = Vec<f32>;
103
/// Semantic query entry with branch and session awareness
///
/// One cached query: its text, embedding and JSON result, together with the
/// metadata used for expiry, branch/session filtering and table-based
/// invalidation.
#[derive(Debug, Clone)]
pub struct SemanticEntry {
    /// Entry ID
    pub id: VectorId,
    /// Original query
    pub query: String,
    /// Query embedding
    pub embedding: Embedding,
    /// Cached result
    pub result: serde_json::Value,
    /// Creation time (expiry is measured from this instant)
    pub created_at: Instant,
    /// TTL; `workload_ttl` uses it only for the `General` workload
    pub ttl: Duration,
    /// Access count (starts at 0 in `new`)
    // NOTE(review): nothing in the visible part of this file increments it — confirm whether callers do.
    pub access_count: u64,
    /// Branch context (for branch-aware caching); `None` means no branch restriction
    pub branch_context: Option<BranchContext>,
    /// Session ID (for agent conversation affinity); `None` matches every session
    pub session_id: Option<SessionId>,
    /// AI workload type
    pub workload: AIWorkloadContext,
    /// Tables referenced by this query (for invalidation)
    pub tables: Vec<String>,
}
130
131impl SemanticEntry {
132    /// Create a new semantic entry
133    pub fn new(
134        id: VectorId,
135        query: impl Into<String>,
136        embedding: Embedding,
137        result: serde_json::Value,
138    ) -> Self {
139        Self {
140            id,
141            query: query.into(),
142            embedding,
143            result,
144            created_at: Instant::now(),
145            ttl: Duration::from_secs(3600), // Default 1 hour
146            access_count: 0,
147            branch_context: None,
148            session_id: None,
149            workload: AIWorkloadContext::default(),
150            tables: Vec::new(),
151        }
152    }
153
154    /// Set TTL
155    pub fn with_ttl(mut self, ttl: Duration) -> Self {
156        self.ttl = ttl;
157        self
158    }
159
160    /// Set branch context
161    pub fn with_branch(mut self, branch: BranchContext) -> Self {
162        self.branch_context = Some(branch);
163        self
164    }
165
166    /// Set session ID for agent affinity
167    pub fn with_session(mut self, session: impl Into<String>) -> Self {
168        self.session_id = Some(session.into());
169        self
170    }
171
172    /// Set AI workload context
173    pub fn with_workload(mut self, workload: AIWorkloadContext) -> Self {
174        self.workload = workload;
175        self
176    }
177
178    /// Set referenced tables for invalidation tracking
179    pub fn with_tables(mut self, tables: Vec<String>) -> Self {
180        self.tables = tables;
181        self
182    }
183
184    /// Get workload-adjusted TTL
185    pub fn workload_ttl(&self) -> Duration {
186        match self.workload {
187            AIWorkloadContext::RAGRetrieval => Duration::from_secs(300), // 5 min - fast refresh
188            AIWorkloadContext::RAGGeneration => Duration::from_secs(1800), // 30 min - slower refresh
189            AIWorkloadContext::AgentConversation => Duration::from_secs(3600), // 1 hour - session lifetime
190            AIWorkloadContext::ToolResult => Duration::from_secs(86400), // 24 hours - deterministic
191            AIWorkloadContext::General => self.ttl,
192        }
193    }
194
195    /// Check if expired (considering workload-adjusted TTL)
196    pub fn is_expired(&self) -> bool {
197        self.created_at.elapsed() > self.workload_ttl()
198    }
199
200    /// Check if entry matches branch context
201    pub fn matches_branch(&self, query_branch: &BranchContext) -> bool {
202        match &self.branch_context {
203            None => true, // No branch restriction
204            Some(entry_branch) => entry_branch.is_compatible(query_branch),
205        }
206    }
207
208    /// Check if entry belongs to session
209    pub fn matches_session(&self, session: &SessionId) -> bool {
210        match &self.session_id {
211            None => true,
212            Some(entry_session) => entry_session == session,
213        }
214    }
215
216    /// Approximate size in bytes
217    pub fn size(&self) -> usize {
218        self.query.len()
219            + self.embedding.len() * 4
220            + self.result.to_string().len()
221            + self.tables.iter().map(|t| t.len()).sum::<usize>()
222            + self.session_id.as_ref().map(|s| s.len()).unwrap_or(0)
223            + self
224                .branch_context
225                .as_ref()
226                .map(|b| b.branch.len() + 8)
227                .unwrap_or(0)
228            + 96
229    }
230}
231
232/// Similarity search result
233#[derive(Debug, Clone)]
234pub struct SimilarityResult {
235    /// Entry ID
236    pub id: VectorId,
237    /// Similarity score (0.0 - 1.0)
238    pub similarity: f32,
239    /// The entry
240    pub entry: SemanticEntry,
241}
242
243impl PartialEq for SimilarityResult {
244    fn eq(&self, other: &Self) -> bool {
245        self.similarity == other.similarity
246    }
247}
248
249impl Eq for SimilarityResult {}
250
251impl PartialOrd for SimilarityResult {
252    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
253        Some(self.cmp(other))
254    }
255}
256
257impl Ord for SimilarityResult {
258    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
259        // Reverse order for min-heap (we want highest similarity)
260        other
261            .similarity
262            .partial_cmp(&self.similarity)
263            .unwrap_or(std::cmp::Ordering::Equal)
264    }
265}
266
/// Simple HNSW-like index for semantic search
/// (Simplified implementation - production would use a proper HNSW library)
///
/// In this file the index is a concurrent map from ID to embedding that is
/// scanned linearly on every search.
pub struct SemanticIndex {
    /// All vectors with their IDs
    vectors: DashMap<VectorId, Embedding>,

    /// Index configuration (stored but not read by the code in this file,
    /// hence the `dead_code` allowance)
    #[allow(dead_code)]
    config: SemanticIndexConfig,

    /// Next ID to hand out; `new` starts it at 1
    next_id: AtomicU64,
}
280
/// Semantic index configuration
///
/// NOTE(review): the brute-force `SemanticIndex::search` in this file does
/// not read any of these fields; they appear to be reserved for a real HNSW
/// backend — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct SemanticIndexConfig {
    /// Maximum connections per node (M parameter)
    pub max_connections: usize,
    /// Search expansion factor (ef parameter)
    pub ef_search: usize,
    /// Embedding dimension
    pub dimension: usize,
}

impl Default for SemanticIndexConfig {
    /// Defaults: `max_connections = 16`, `ef_search = 100`, `dimension = 384`.
    fn default() -> Self {
        Self {
            max_connections: 16,
            ef_search: 100,
            dimension: 384, // Common embedding size
        }
    }
}
301
302impl SemanticIndex {
303    /// Create a new semantic index
304    pub fn new(config: SemanticIndexConfig) -> Self {
305        Self {
306            vectors: DashMap::new(),
307            config,
308            next_id: AtomicU64::new(1),
309        }
310    }
311
312    /// Insert a vector and return its ID
313    pub fn insert(&self, embedding: Embedding) -> VectorId {
314        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
315        self.vectors.insert(id, embedding);
316        id
317    }
318
319    /// Remove a vector
320    pub fn remove(&self, id: VectorId) {
321        self.vectors.remove(&id);
322    }
323
324    /// Search for k nearest neighbors
325    pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorId, f32)> {
326        // Brute force search (production would use HNSW)
327        let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, VectorId)> = BinaryHeap::new();
328
329        for entry in self.vectors.iter() {
330            let similarity = cosine_similarity(query, entry.value());
331            // Convert to integer for ordering (multiply by 1M for precision)
332            let sim_int = (similarity * 1_000_000.0) as i64;
333            heap.push((std::cmp::Reverse(sim_int), *entry.key()));
334
335            if heap.len() > k {
336                heap.pop();
337            }
338        }
339
340        // Extract results in descending similarity order
341        let mut results: Vec<_> = heap
342            .into_iter()
343            .map(|(std::cmp::Reverse(sim), id)| (id, sim as f32 / 1_000_000.0))
344            .collect();
345
346        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
347        results
348    }
349
350    /// Get vector count
351    pub fn len(&self) -> usize {
352        self.vectors.len()
353    }
354
355    /// Check if empty
356    pub fn is_empty(&self) -> bool {
357        self.vectors.is_empty()
358    }
359
360    /// Clear all vectors
361    pub fn clear(&self) {
362        self.vectors.clear();
363    }
364}
365
/// Compute cosine similarity between two vectors
///
/// Returns a value in `[-1.0, 1.0]`, or `0.0` when the slices differ in
/// length, are empty, or either has zero magnitude.
///
/// The sums are accumulated in `f64` and the two norms are square-rooted
/// separately. Squaring and multiplying in `f32` underflows to zero (or
/// overflows to infinity) for vectors with very small or very large
/// components, which made identical vectors report a similarity of `0.0`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denominator = norm_a.sqrt() * norm_b.sqrt();
    if denominator == 0.0 {
        0.0
    } else {
        // Clamp away rounding overshoot such as 1.0000000002
        (dot / denominator).clamp(-1.0, 1.0) as f32
    }
}
389
/// Semantic query cache
///
/// Pairs a similarity index over query embeddings with a map of cached
/// entries; both are keyed by the same `VectorId`.
pub struct SemanticQueryCache {
    /// Query index (embedding lookup by similarity)
    index: SemanticIndex,

    /// Cached entries, keyed by the ID the index assigned
    entries: DashMap<VectorId, SemanticEntry>,

    /// Similarity threshold for cache hit (a candidate must score at least this)
    threshold: f32,

    /// Maximum entries; inserts evict the oldest entry once this is reached
    max_entries: usize,

    /// Statistics
    stats: SemanticCacheStats,
}
407
/// Semantic cache statistics
///
/// Live counters updated with relaxed atomics; `SemanticQueryCache::stats`
/// copies them into a `SemanticCacheStatsSnapshot`.
#[derive(Debug, Default)]
struct SemanticCacheStats {
    /// Lookups that returned an entry
    hits: AtomicU64,
    /// Lookups that returned nothing
    misses: AtomicU64,
    /// Hits not classified as exact
    semantic_hits: AtomicU64,
    /// Hits whose similarity exceeded 0.999
    exact_hits: AtomicU64,
    /// Entries inserted
    insertions: AtomicU64,
    /// Entries removed by capacity eviction
    evictions: AtomicU64,
}
418
419impl SemanticQueryCache {
420    /// Create a new semantic query cache with default max entries
421    pub fn new(threshold: f32) -> Self {
422        Self::with_capacity(threshold, 10000)
423    }
424
425    /// Create a new semantic query cache with specified capacity
426    pub fn with_capacity(threshold: f32, max_entries: usize) -> Self {
427        Self {
428            index: SemanticIndex::new(SemanticIndexConfig::default()),
429            entries: DashMap::new(),
430            threshold,
431            max_entries,
432            stats: SemanticCacheStats::default(),
433        }
434    }
435
436    /// Create with custom index config
437    pub fn with_config(
438        threshold: f32,
439        max_entries: usize,
440        index_config: SemanticIndexConfig,
441    ) -> Self {
442        Self {
443            index: SemanticIndex::new(index_config),
444            entries: DashMap::new(),
445            threshold,
446            max_entries,
447            stats: SemanticCacheStats::default(),
448        }
449    }
450
451    /// Lookup by semantic similarity
452    pub fn lookup(&self, embedding: &[f32]) -> Option<SimilarityResult> {
453        // Search for nearest neighbor
454        let results = self.index.search(embedding, 1);
455
456        if let Some((id, similarity)) = results.first() {
457            if *similarity >= self.threshold {
458                if let Some(entry) = self.entries.get(id) {
459                    if !entry.is_expired() {
460                        self.stats.hits.fetch_add(1, Ordering::Relaxed);
461
462                        if *similarity > 0.999 {
463                            self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
464                        } else {
465                            self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
466                        }
467
468                        return Some(SimilarityResult {
469                            id: *id,
470                            similarity: *similarity,
471                            entry: entry.clone(),
472                        });
473                    } else {
474                        // Remove expired entry
475                        drop(entry);
476                        self.remove(*id);
477                    }
478                }
479            }
480        }
481
482        self.stats.misses.fetch_add(1, Ordering::Relaxed);
483        None
484    }
485
486    /// Lookup with custom threshold
487    pub fn lookup_with_threshold(
488        &self,
489        embedding: &[f32],
490        threshold: f32,
491    ) -> Option<SimilarityResult> {
492        let results = self.index.search(embedding, 1);
493
494        if let Some((id, similarity)) = results.first() {
495            if *similarity >= threshold {
496                if let Some(entry) = self.entries.get(id) {
497                    if !entry.is_expired() {
498                        return Some(SimilarityResult {
499                            id: *id,
500                            similarity: *similarity,
501                            entry: entry.clone(),
502                        });
503                    }
504                }
505            }
506        }
507
508        None
509    }
510
511    /// Find k most similar queries
512    pub fn find_similar(&self, embedding: &[f32], k: usize) -> Vec<SimilarityResult> {
513        let results = self.index.search(embedding, k);
514
515        results
516            .into_iter()
517            .filter_map(|(id, similarity)| {
518                self.entries.get(&id).and_then(|entry| {
519                    if !entry.is_expired() {
520                        Some(SimilarityResult {
521                            id,
522                            similarity,
523                            entry: entry.clone(),
524                        })
525                    } else {
526                        None
527                    }
528                })
529            })
530            .collect()
531    }
532
533    /// Lookup with branch context filtering
534    ///
535    /// Returns cached entry only if it's compatible with the given branch context.
536    /// This enables branch-aware caching for time-travel queries.
537    pub fn lookup_with_branch(
538        &self,
539        embedding: &[f32],
540        branch: &BranchContext,
541    ) -> Option<SimilarityResult> {
542        // Search for multiple candidates to filter by branch
543        let results = self.index.search(embedding, 10);
544
545        for (id, similarity) in results {
546            if similarity < self.threshold {
547                break; // Results are sorted by similarity
548            }
549
550            if let Some(entry) = self.entries.get(&id) {
551                if !entry.is_expired() && entry.matches_branch(branch) {
552                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
553                    if similarity > 0.999 {
554                        self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
555                    } else {
556                        self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
557                    }
558
559                    return Some(SimilarityResult {
560                        id,
561                        similarity,
562                        entry: entry.clone(),
563                    });
564                }
565            }
566        }
567
568        self.stats.misses.fetch_add(1, Ordering::Relaxed);
569        None
570    }
571
572    /// Lookup with session affinity for agent conversations
573    ///
574    /// Prioritizes entries from the same session for better conversation context.
575    pub fn lookup_with_session(
576        &self,
577        embedding: &[f32],
578        session: &SessionId,
579    ) -> Option<SimilarityResult> {
580        let results = self.index.search(embedding, 20);
581
582        // First pass: look for same-session matches
583        for (id, similarity) in &results {
584            if *similarity < self.threshold {
585                break;
586            }
587
588            if let Some(entry) = self.entries.get(id) {
589                if !entry.is_expired() && entry.matches_session(session) {
590                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
591                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
592
593                    return Some(SimilarityResult {
594                        id: *id,
595                        similarity: *similarity,
596                        entry: entry.clone(),
597                    });
598                }
599            }
600        }
601
602        // Second pass: any matching entry (cross-session)
603        for (id, similarity) in &results {
604            if *similarity < self.threshold {
605                break;
606            }
607
608            if let Some(entry) = self.entries.get(id) {
609                if !entry.is_expired() {
610                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
611                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
612
613                    return Some(SimilarityResult {
614                        id: *id,
615                        similarity: *similarity,
616                        entry: entry.clone(),
617                    });
618                }
619            }
620        }
621
622        self.stats.misses.fetch_add(1, Ordering::Relaxed);
623        None
624    }
625
626    /// Lookup with full AI context (branch + session + workload)
627    ///
628    /// Most comprehensive lookup that considers:
629    /// - Branch context for time-travel
630    /// - Session affinity for agent conversations
631    /// - Workload type for TTL and priority
632    pub fn lookup_with_context(
633        &self,
634        embedding: &[f32],
635        branch: Option<&BranchContext>,
636        session: Option<&SessionId>,
637        workload: AIWorkloadContext,
638    ) -> Option<SimilarityResult> {
639        let results = self.index.search(embedding, 20);
640
641        // Priority 1: Same session + same branch + same workload
642        for (id, similarity) in &results {
643            if *similarity < self.threshold {
644                break;
645            }
646
647            if let Some(entry) = self.entries.get(id) {
648                let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
649                let session_match = session.map(|s| entry.matches_session(s)).unwrap_or(false);
650                let workload_match = entry.workload == workload;
651
652                if !entry.is_expired() && branch_match && session_match && workload_match {
653                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
654                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
655                    return Some(SimilarityResult {
656                        id: *id,
657                        similarity: *similarity,
658                        entry: entry.clone(),
659                    });
660                }
661            }
662        }
663
664        // Priority 2: Same branch + same workload
665        for (id, similarity) in &results {
666            if *similarity < self.threshold {
667                break;
668            }
669
670            if let Some(entry) = self.entries.get(id) {
671                let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
672                let workload_match = entry.workload == workload;
673
674                if !entry.is_expired() && branch_match && workload_match {
675                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
676                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
677                    return Some(SimilarityResult {
678                        id: *id,
679                        similarity: *similarity,
680                        entry: entry.clone(),
681                    });
682                }
683            }
684        }
685
686        // Priority 3: Same branch only
687        for (id, similarity) in &results {
688            if *similarity < self.threshold {
689                break;
690            }
691
692            if let Some(entry) = self.entries.get(id) {
693                let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
694
695                if !entry.is_expired() && branch_match {
696                    self.stats.hits.fetch_add(1, Ordering::Relaxed);
697                    self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
698                    return Some(SimilarityResult {
699                        id: *id,
700                        similarity: *similarity,
701                        entry: entry.clone(),
702                    });
703                }
704            }
705        }
706
707        self.stats.misses.fetch_add(1, Ordering::Relaxed);
708        None
709    }
710
711    /// Find similar entries within a branch
712    pub fn find_similar_in_branch(
713        &self,
714        embedding: &[f32],
715        branch: &BranchContext,
716        k: usize,
717    ) -> Vec<SimilarityResult> {
718        // Search for more candidates since we'll filter
719        let results = self.index.search(embedding, k * 3);
720
721        results
722            .into_iter()
723            .filter_map(|(id, similarity)| {
724                self.entries.get(&id).and_then(|entry| {
725                    if !entry.is_expired() && entry.matches_branch(branch) {
726                        Some(SimilarityResult {
727                            id,
728                            similarity,
729                            entry: entry.clone(),
730                        })
731                    } else {
732                        None
733                    }
734                })
735            })
736            .take(k)
737            .collect()
738    }
739
740    /// Invalidate entries by table name
741    ///
742    /// Used when WAL invalidation detects changes to a table.
743    pub fn invalidate_by_table(&self, table: &str) -> usize {
744        let to_remove: Vec<_> = self
745            .entries
746            .iter()
747            .filter(|e| e.tables.iter().any(|t| t == table))
748            .map(|e| *e.key())
749            .collect();
750
751        let count = to_remove.len();
752        for id in to_remove {
753            self.remove(id);
754        }
755        count
756    }
757
758    /// Invalidate entries by branch
759    pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
760        let to_remove: Vec<_> = self
761            .entries
762            .iter()
763            .filter(|e| {
764                e.branch_context
765                    .as_ref()
766                    .map(|b| &b.branch == branch)
767                    .unwrap_or(false)
768            })
769            .map(|e| *e.key())
770            .collect();
771
772        let count = to_remove.len();
773        for id in to_remove {
774            self.remove(id);
775        }
776        count
777    }
778
779    /// Insert a new entry
780    pub fn insert(
781        &self,
782        query: impl Into<String>,
783        embedding: Embedding,
784        result: serde_json::Value,
785    ) -> VectorId {
786        // Evict if at capacity
787        while self.entries.len() >= self.max_entries {
788            self.evict_one();
789        }
790
791        // Insert into index
792        let id = self.index.insert(embedding.clone());
793
794        // Create and store entry
795        let entry = SemanticEntry::new(id, query, embedding, result);
796        self.entries.insert(id, entry);
797
798        self.stats.insertions.fetch_add(1, Ordering::Relaxed);
799        id
800    }
801
802    /// Insert with TTL
803    pub fn insert_with_ttl(
804        &self,
805        query: impl Into<String>,
806        embedding: Embedding,
807        result: serde_json::Value,
808        ttl: Duration,
809    ) -> VectorId {
810        while self.entries.len() >= self.max_entries {
811            self.evict_one();
812        }
813
814        let id = self.index.insert(embedding.clone());
815        let entry = SemanticEntry::new(id, query, embedding, result).with_ttl(ttl);
816        self.entries.insert(id, entry);
817
818        self.stats.insertions.fetch_add(1, Ordering::Relaxed);
819        id
820    }
821
822    /// Insert with full AI context (branch, session, workload, tables)
823    ///
824    /// This is the recommended insertion method for AI/Agent workloads
825    /// as it enables branch-aware caching, session affinity, and
826    /// workload-specific TTL management.
827    #[allow(clippy::too_many_arguments)]
828    pub fn insert_with_context(
829        &self,
830        query: impl Into<String>,
831        embedding: Embedding,
832        result: serde_json::Value,
833        branch: Option<BranchContext>,
834        session: Option<SessionId>,
835        workload: AIWorkloadContext,
836        tables: Vec<String>,
837    ) -> VectorId {
838        while self.entries.len() >= self.max_entries {
839            self.evict_one();
840        }
841
842        let id = self.index.insert(embedding.clone());
843        let mut entry = SemanticEntry::new(id, query, embedding, result)
844            .with_workload(workload)
845            .with_tables(tables);
846
847        if let Some(b) = branch {
848            entry = entry.with_branch(b);
849        }
850        if let Some(s) = session {
851            entry = entry.with_session(s);
852        }
853
854        self.entries.insert(id, entry);
855        self.stats.insertions.fetch_add(1, Ordering::Relaxed);
856        id
857    }
858
859    /// Insert for RAG retrieval workload
860    ///
861    /// Optimized TTL for fast-refresh retrieval phase.
862    pub fn insert_rag_retrieval(
863        &self,
864        query: impl Into<String>,
865        embedding: Embedding,
866        result: serde_json::Value,
867        tables: Vec<String>,
868    ) -> VectorId {
869        self.insert_with_context(
870            query,
871            embedding,
872            result,
873            None,
874            None,
875            AIWorkloadContext::RAGRetrieval,
876            tables,
877        )
878    }
879
880    /// Insert for agent conversation
881    ///
882    /// Session-aware with longer TTL for conversation context.
883    pub fn insert_agent_response(
884        &self,
885        query: impl Into<String>,
886        embedding: Embedding,
887        result: serde_json::Value,
888        session: SessionId,
889        branch: Option<BranchContext>,
890    ) -> VectorId {
891        self.insert_with_context(
892            query,
893            embedding,
894            result,
895            branch,
896            Some(session),
897            AIWorkloadContext::AgentConversation,
898            Vec::new(),
899        )
900    }
901
902    /// Insert deterministic tool result
903    ///
904    /// Long TTL for deterministic tool calls (e.g., math, date formatting).
905    pub fn insert_tool_result(
906        &self,
907        query: impl Into<String>,
908        embedding: Embedding,
909        result: serde_json::Value,
910    ) -> VectorId {
911        self.insert_with_context(
912            query,
913            embedding,
914            result,
915            None,
916            None,
917            AIWorkloadContext::ToolResult,
918            Vec::new(),
919        )
920    }
921
    /// Remove an entry
    ///
    /// Drops the vector from the similarity index first, then the cached
    /// entry itself. Both removal results are discarded, so an id that is
    /// not present is simply ignored.
    pub fn remove(&self, id: VectorId) {
        self.index.remove(id);
        self.entries.remove(&id);
    }
927
928    /// Evict one entry (oldest by creation time)
929    fn evict_one(&self) {
930        let mut oldest_id = None;
931        let mut oldest_time = Instant::now();
932
933        for entry in self.entries.iter() {
934            if entry.created_at < oldest_time {
935                oldest_time = entry.created_at;
936                oldest_id = Some(*entry.key());
937            }
938        }
939
940        if let Some(id) = oldest_id {
941            self.remove(id);
942            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
943        }
944    }
945
946    /// Remove expired entries
947    pub fn cleanup_expired(&self) {
948        let expired: Vec<_> = self
949            .entries
950            .iter()
951            .filter(|e| e.is_expired())
952            .map(|e| *e.key())
953            .collect();
954
955        for id in expired {
956            self.remove(id);
957        }
958    }
959
    /// Clear all entries
    ///
    /// Empties the similarity index and then the entry map. Statistics
    /// counters are not reset.
    pub fn clear(&self) {
        self.index.clear();
        self.entries.clear();
    }

    /// Get entry count (number of cached entries, expired ones included
    /// until they are removed)
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Check if empty (no cached entries at all)
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
975
976    /// Get statistics
977    pub fn stats(&self) -> SemanticCacheStatsSnapshot {
978        let hits = self.stats.hits.load(Ordering::Relaxed);
979        let misses = self.stats.misses.load(Ordering::Relaxed);
980        let total = hits + misses;
981
982        SemanticCacheStatsSnapshot {
983            entries: self.entries.len(),
984            threshold: self.threshold,
985            hits,
986            misses,
987            hit_rate: if total > 0 {
988                hits as f64 / total as f64
989            } else {
990                0.0
991            },
992            semantic_hits: self.stats.semantic_hits.load(Ordering::Relaxed),
993            exact_hits: self.stats.exact_hits.load(Ordering::Relaxed),
994            insertions: self.stats.insertions.load(Ordering::Relaxed),
995            evictions: self.stats.evictions.load(Ordering::Relaxed),
996        }
997    }
998}
999
/// Semantic cache statistics snapshot
///
/// A plain-data copy of the cache counters at one moment in time. It derives
/// `PartialEq` so two snapshots can be compared directly, for example in
/// tests or when detecting that nothing changed between two samples.
#[derive(Debug, Clone, PartialEq)]
pub struct SemanticCacheStatsSnapshot {
    /// Number of entries held by the cache when the snapshot was taken.
    pub entries: usize,
    /// Similarity threshold the cache was configured with.
    pub threshold: f32,
    /// Value of the hit counter.
    pub hits: u64,
    /// Value of the miss counter.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when both counters are zero.
    pub hit_rate: f64,
    /// Value of the semantic-hit counter.
    pub semantic_hits: u64,
    /// Value of the exact-hit counter.
    pub exact_hits: u64,
    /// Value of the insertion counter.
    pub insertions: u64,
    /// Value of the eviction counter.
    pub evictions: u64,
}
1013
1014#[cfg(test)]
1015mod tests {
1016    use super::*;
1017    use serde_json::json;
1018
1019    #[test]
1020    fn test_cosine_similarity() {
1021        let a = vec![1.0, 0.0, 0.0];
1022        let b = vec![1.0, 0.0, 0.0];
1023        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
1024
1025        let c = vec![0.0, 1.0, 0.0];
1026        assert!(cosine_similarity(&a, &c).abs() < 0.001);
1027
1028        let d = vec![0.707, 0.707, 0.0];
1029        let sim = cosine_similarity(&a, &d);
1030        assert!((sim - 0.707).abs() < 0.01);
1031    }
1032
1033    #[test]
1034    fn test_semantic_index() {
1035        let index = SemanticIndex::new(SemanticIndexConfig::default());
1036
1037        let id1 = index.insert(vec![1.0, 0.0, 0.0]);
1038        let id2 = index.insert(vec![0.9, 0.1, 0.0]);
1039        let id3 = index.insert(vec![0.0, 1.0, 0.0]);
1040
1041        // Search for vector similar to [1.0, 0.0, 0.0]
1042        let results = index.search(&[1.0, 0.0, 0.0], 2);
1043
1044        assert_eq!(results.len(), 2);
1045        assert_eq!(results[0].0, id1); // Most similar
1046        assert_eq!(results[1].0, id2); // Second most similar
1047    }
1048
1049    #[test]
1050    fn test_semantic_cache_insert_lookup() {
1051        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1052
1053        let embedding = vec![1.0, 0.0, 0.0];
1054        let id = cache.insert(
1055            "SELECT * FROM users WHERE name = 'test'",
1056            embedding.clone(),
1057            json!({"count": 5}),
1058        );
1059
1060        // Exact lookup
1061        let result = cache.lookup(&embedding);
1062        assert!(result.is_some());
1063        let res = result.unwrap();
1064        assert_eq!(res.id, id);
1065        assert!(res.similarity > 0.999);
1066    }
1067
1068    #[test]
1069    fn test_semantic_similarity_lookup() {
1070        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1071
1072        // Insert original query
1073        cache.insert(
1074            "SELECT * FROM users WHERE id = 1",
1075            vec![1.0, 0.0, 0.0],
1076            json!({"user": "alice"}),
1077        );
1078
1079        // Lookup with similar query embedding
1080        let similar_embedding = vec![0.95, 0.05, 0.0];
1081        let result = cache.lookup(&similar_embedding);
1082
1083        assert!(result.is_some());
1084        let res = result.unwrap();
1085        assert!(res.similarity >= 0.9);
1086    }
1087
1088    #[test]
1089    fn test_threshold_rejection() {
1090        let cache = SemanticQueryCache::with_capacity(0.95, 100);
1091
1092        cache.insert(
1093            "SELECT * FROM orders",
1094            vec![1.0, 0.0, 0.0],
1095            json!({"total": 100}),
1096        );
1097
1098        // Query too different (below threshold)
1099        let different_embedding = vec![0.7, 0.7, 0.0];
1100        let result = cache.lookup(&different_embedding);
1101        assert!(result.is_none());
1102    }
1103
1104    #[test]
1105    fn test_find_similar() {
1106        let cache = SemanticQueryCache::with_capacity(0.5, 100);
1107
1108        cache.insert("query1", vec![1.0, 0.0, 0.0], json!(1));
1109        cache.insert("query2", vec![0.9, 0.1, 0.0], json!(2));
1110        cache.insert("query3", vec![0.8, 0.2, 0.0], json!(3));
1111        cache.insert("query4", vec![0.0, 1.0, 0.0], json!(4));
1112
1113        let similar = cache.find_similar(&[1.0, 0.0, 0.0], 3);
1114
1115        assert_eq!(similar.len(), 3);
1116        // First should be most similar
1117        assert!(similar[0].similarity > similar[1].similarity);
1118        assert!(similar[1].similarity > similar[2].similarity);
1119    }
1120
1121    #[test]
1122    fn test_expiration() {
1123        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1124
1125        let embedding = vec![1.0, 0.0, 0.0];
1126        cache.insert_with_ttl(
1127            "expiring query",
1128            embedding.clone(),
1129            json!({"expires": true}),
1130            Duration::from_millis(1),
1131        );
1132
1133        // Wait for expiration
1134        std::thread::sleep(Duration::from_millis(10));
1135
1136        let result = cache.lookup(&embedding);
1137        assert!(result.is_none());
1138    }
1139
1140    #[test]
1141    fn test_eviction() {
1142        let cache = SemanticQueryCache::with_capacity(0.9, 3);
1143
1144        // Fill cache
1145        for i in 0..3 {
1146            cache.insert(format!("query{}", i), vec![i as f32, 0.0, 0.0], json!(i));
1147        }
1148
1149        assert_eq!(cache.len(), 3);
1150
1151        // Insert one more (should evict)
1152        cache.insert("query3", vec![3.0, 0.0, 0.0], json!(3));
1153
1154        assert_eq!(cache.len(), 3);
1155    }
1156
1157    #[test]
1158    fn test_stats() {
1159        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1160
1161        let embedding = vec![1.0, 0.0, 0.0];
1162        cache.insert("test query", embedding.clone(), json!(1));
1163
1164        // Hits
1165        cache.lookup(&embedding);
1166        cache.lookup(&embedding);
1167
1168        // Miss
1169        cache.lookup(&[0.0, 1.0, 0.0]);
1170
1171        let stats = cache.stats();
1172        assert_eq!(stats.hits, 2);
1173        assert_eq!(stats.misses, 1);
1174        assert_eq!(stats.exact_hits, 2);
1175        assert_eq!(stats.insertions, 1);
1176    }
1177
1178    #[test]
1179    fn test_branch_context_compatibility() {
1180        let main = BranchContext::main();
1181        let feature = BranchContext::new("feature-x");
1182        let snapshot = BranchContext::with_snapshot("main", 1000);
1183        let later_snapshot = BranchContext::with_snapshot("main", 2000);
1184
1185        // Same branch
1186        assert!(main.is_compatible(&main));
1187        assert!(!main.is_compatible(&feature));
1188
1189        // Snapshot compatibility
1190        assert!(snapshot.is_compatible(&later_snapshot)); // Entry @ 1000 valid for query @ 2000
1191        assert!(!later_snapshot.is_compatible(&snapshot)); // Entry @ 2000 not valid for query @ 1000
1192    }
1193
1194    #[test]
1195    fn test_lookup_with_branch() {
1196        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1197
1198        // Insert entry for main branch
1199        let embedding = vec![1.0, 0.0, 0.0];
1200        cache.insert_with_context(
1201            "SELECT * FROM users",
1202            embedding.clone(),
1203            json!({"users": []}),
1204            Some(BranchContext::main()),
1205            None,
1206            AIWorkloadContext::General,
1207            vec!["users".to_string()],
1208        );
1209
1210        // Insert entry for feature branch
1211        let embedding2 = vec![0.95, 0.05, 0.0];
1212        cache.insert_with_context(
1213            "SELECT * FROM users",
1214            embedding2.clone(),
1215            json!({"users": ["new_user"]}),
1216            Some(BranchContext::new("feature-x")),
1217            None,
1218            AIWorkloadContext::General,
1219            vec!["users".to_string()],
1220        );
1221
1222        // Lookup for main should find main entry
1223        let main_result = cache.lookup_with_branch(&embedding, &BranchContext::main());
1224        assert!(main_result.is_some());
1225        assert_eq!(
1226            main_result
1227                .unwrap()
1228                .entry
1229                .branch_context
1230                .as_ref()
1231                .unwrap()
1232                .branch,
1233            "main"
1234        );
1235
1236        // Lookup for feature should find feature entry
1237        let feature_result =
1238            cache.lookup_with_branch(&embedding2, &BranchContext::new("feature-x"));
1239        assert!(feature_result.is_some());
1240        assert_eq!(
1241            feature_result
1242                .unwrap()
1243                .entry
1244                .branch_context
1245                .as_ref()
1246                .unwrap()
1247                .branch,
1248            "feature-x"
1249        );
1250    }
1251
1252    #[test]
1253    fn test_lookup_with_session() {
1254        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1255        let session1 = "session-001".to_string();
1256        let session2 = "session-002".to_string();
1257
1258        // Insert for session 1
1259        let embedding = vec![1.0, 0.0, 0.0];
1260        cache.insert_agent_response(
1261            "What is the weather?",
1262            embedding.clone(),
1263            json!({"weather": "sunny"}),
1264            session1.clone(),
1265            None,
1266        );
1267
1268        // Insert similar query for session 2
1269        let embedding2 = vec![0.98, 0.02, 0.0];
1270        cache.insert_agent_response(
1271            "How's the weather?",
1272            embedding2,
1273            json!({"weather": "cloudy"}),
1274            session2.clone(),
1275            None,
1276        );
1277
1278        // Lookup with session 1 should prefer session 1 entry
1279        let result = cache.lookup_with_session(&embedding, &session1);
1280        assert!(result.is_some());
1281        assert_eq!(
1282            result.unwrap().entry.session_id.as_ref().unwrap(),
1283            &session1
1284        );
1285    }
1286
1287    #[test]
1288    fn test_lookup_with_context() {
1289        let cache = SemanticQueryCache::with_capacity(0.8, 100);
1290        let session = "agent-session".to_string();
1291        let branch = BranchContext::main();
1292
1293        // Insert with full context
1294        let embedding = vec![1.0, 0.0, 0.0];
1295        cache.insert_with_context(
1296            "Find users with orders",
1297            embedding.clone(),
1298            json!({"users": 42}),
1299            Some(branch.clone()),
1300            Some(session.clone()),
1301            AIWorkloadContext::RAGRetrieval,
1302            vec!["users".to_string(), "orders".to_string()],
1303        );
1304
1305        // Full context match
1306        let result = cache.lookup_with_context(
1307            &embedding,
1308            Some(&branch),
1309            Some(&session),
1310            AIWorkloadContext::RAGRetrieval,
1311        );
1312        assert!(result.is_some());
1313
1314        // Different workload still matches (lower priority)
1315        let result2 =
1316            cache.lookup_with_context(&embedding, Some(&branch), None, AIWorkloadContext::General);
1317        assert!(result2.is_some());
1318    }
1319
1320    #[test]
1321    fn test_invalidate_by_table() {
1322        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1323
1324        // Insert entries referencing different tables
1325        cache.insert_with_context(
1326            "SELECT * FROM users",
1327            vec![1.0, 0.0, 0.0],
1328            json!(1),
1329            None,
1330            None,
1331            AIWorkloadContext::General,
1332            vec!["users".to_string()],
1333        );
1334        cache.insert_with_context(
1335            "SELECT * FROM orders",
1336            vec![0.0, 1.0, 0.0],
1337            json!(2),
1338            None,
1339            None,
1340            AIWorkloadContext::General,
1341            vec!["orders".to_string()],
1342        );
1343        cache.insert_with_context(
1344            "SELECT * FROM users JOIN orders",
1345            vec![0.5, 0.5, 0.0],
1346            json!(3),
1347            None,
1348            None,
1349            AIWorkloadContext::General,
1350            vec!["users".to_string(), "orders".to_string()],
1351        );
1352
1353        assert_eq!(cache.len(), 3);
1354
1355        // Invalidate users table
1356        let removed = cache.invalidate_by_table("users");
1357        assert_eq!(removed, 2); // users and users+orders entries
1358        assert_eq!(cache.len(), 1);
1359    }
1360
1361    #[test]
1362    fn test_invalidate_branch() {
1363        let cache = SemanticQueryCache::with_capacity(0.9, 100);
1364
1365        // Insert entries for different branches
1366        cache.insert_with_context(
1367            "query1",
1368            vec![1.0, 0.0, 0.0],
1369            json!(1),
1370            Some(BranchContext::main()),
1371            None,
1372            AIWorkloadContext::General,
1373            Vec::new(),
1374        );
1375        cache.insert_with_context(
1376            "query2",
1377            vec![0.0, 1.0, 0.0],
1378            json!(2),
1379            Some(BranchContext::new("feature-x")),
1380            None,
1381            AIWorkloadContext::General,
1382            Vec::new(),
1383        );
1384        cache.insert_with_context(
1385            "query3",
1386            vec![0.0, 0.0, 1.0],
1387            json!(3),
1388            Some(BranchContext::new("feature-x")),
1389            None,
1390            AIWorkloadContext::General,
1391            Vec::new(),
1392        );
1393
1394        assert_eq!(cache.len(), 3);
1395
1396        // Invalidate feature-x branch
1397        let removed = cache.invalidate_branch(&"feature-x".to_string());
1398        assert_eq!(removed, 2);
1399        assert_eq!(cache.len(), 1);
1400    }
1401
1402    #[test]
1403    fn test_workload_ttl() {
1404        // RAG retrieval - short TTL (5 min)
1405        let rag_entry = SemanticEntry::new(1, "rag query", vec![], json!({}))
1406            .with_workload(AIWorkloadContext::RAGRetrieval);
1407        assert_eq!(rag_entry.workload_ttl(), Duration::from_secs(300));
1408
1409        // Tool result - long TTL (24 hours)
1410        let tool_entry = SemanticEntry::new(2, "tool query", vec![], json!({}))
1411            .with_workload(AIWorkloadContext::ToolResult);
1412        assert_eq!(tool_entry.workload_ttl(), Duration::from_secs(86400));
1413
1414        // Agent conversation - medium TTL (1 hour)
1415        let agent_entry = SemanticEntry::new(3, "agent query", vec![], json!({}))
1416            .with_workload(AIWorkloadContext::AgentConversation);
1417        assert_eq!(agent_entry.workload_ttl(), Duration::from_secs(3600));
1418    }
1419
1420    #[test]
1421    fn test_find_similar_in_branch() {
1422        let cache = SemanticQueryCache::with_capacity(0.5, 100);
1423        let main = BranchContext::main();
1424        let feature = BranchContext::new("feature-x");
1425
1426        // Insert entries in main
1427        for i in 0..3 {
1428            cache.insert_with_context(
1429                format!("main query {}", i),
1430                vec![1.0 - (i as f32 * 0.1), i as f32 * 0.1, 0.0],
1431                json!(i),
1432                Some(main.clone()),
1433                None,
1434                AIWorkloadContext::General,
1435                Vec::new(),
1436            );
1437        }
1438
1439        // Insert entries in feature
1440        for i in 0..2 {
1441            cache.insert_with_context(
1442                format!("feature query {}", i),
1443                vec![0.5, 0.5 + (i as f32 * 0.1), 0.0],
1444                json!(100 + i),
1445                Some(feature.clone()),
1446                None,
1447                AIWorkloadContext::General,
1448                Vec::new(),
1449            );
1450        }
1451
1452        // Find similar in main branch only
1453        let main_results = cache.find_similar_in_branch(&[1.0, 0.0, 0.0], &main, 5);
1454        assert_eq!(main_results.len(), 3);
1455        for r in &main_results {
1456            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "main");
1457        }
1458
1459        // Find similar in feature branch only
1460        let feature_results = cache.find_similar_in_branch(&[0.5, 0.5, 0.0], &feature, 5);
1461        assert_eq!(feature_results.len(), 2);
1462        for r in &feature_results {
1463            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "feature-x");
1464        }
1465    }
1466}