1use dashmap::DashMap;
20use std::collections::BinaryHeap;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
/// Identifier of a vector / cache entry.
pub type VectorId = u64;

/// Name of a data branch, e.g. `"main"`.
pub type BranchId = String;

/// Identifier of an agent or client session.
pub type SessionId = String;

/// Branch (and optionally a snapshot on it) that a cache entry or a query is scoped to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BranchContext {
    /// Branch name.
    pub branch: BranchId,
    /// Snapshot marker; `None` means "head of the branch".
    /// NOTE(review): whether this is a version number or a timestamp is not visible here.
    pub snapshot_at: Option<u64>,
}

impl BranchContext {
    /// Shared constructor behind the public builders.
    fn at(branch: String, snapshot_at: Option<u64>) -> Self {
        BranchContext { branch, snapshot_at }
    }

    /// Context for the head of `branch` (no snapshot pinned).
    pub fn new(branch: impl Into<String>) -> Self {
        Self::at(branch.into(), None)
    }

    /// Context pinned to `snapshot` on `branch`.
    pub fn with_snapshot(branch: impl Into<String>, snapshot: u64) -> Self {
        Self::at(branch.into(), Some(snapshot))
    }

    /// Context for the head of the `"main"` branch.
    pub fn main() -> Self {
        Self::at(String::from("main"), None)
    }

    /// Whether an entry scoped to `self` may serve a query scoped to `other`.
    ///
    /// Branch names must be equal. Then:
    /// - an entry without a snapshot serves any query;
    /// - a snapshotted entry serves only snapshotted queries at the same or a later snapshot.
    pub fn is_compatible(&self, other: &BranchContext) -> bool {
        self.branch == other.branch
            && match (self.snapshot_at, other.snapshot_at) {
                (None, _) => true,
                (Some(_), None) => false,
                (Some(entry_snap), Some(query_snap)) => entry_snap <= query_snap,
            }
    }
}

impl Default for BranchContext {
    /// Defaults to the head of `"main"`.
    fn default() -> Self {
        Self::new("main")
    }
}
84
/// Kind of AI workload that produced a cache entry.
///
/// `SemanticEntry::workload_ttl` maps each variant to an expiry policy.
/// Only `General` honours the entry's own `ttl`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AIWorkloadContext {
    /// Retrieval step of a RAG pipeline.
    RAGRetrieval,
    /// Generation step of a RAG pipeline.
    RAGGeneration,
    /// Turn in an agent conversation.
    AgentConversation,
    /// Output of a tool invocation.
    ToolResult,
    /// Anything else; the default.
    #[default]
    General,
}
105
106pub type Embedding = Vec<f32>;
108
109#[derive(Debug, Clone)]
111pub struct SemanticEntry {
112 pub id: VectorId,
114 pub query: String,
116 pub embedding: Embedding,
118 pub result: serde_json::Value,
120 pub created_at: Instant,
122 pub ttl: Duration,
124 pub access_count: u64,
126 pub branch_context: Option<BranchContext>,
128 pub session_id: Option<SessionId>,
130 pub workload: AIWorkloadContext,
132 pub tables: Vec<String>,
134}
135
136impl SemanticEntry {
137 pub fn new(
139 id: VectorId,
140 query: impl Into<String>,
141 embedding: Embedding,
142 result: serde_json::Value,
143 ) -> Self {
144 Self {
145 id,
146 query: query.into(),
147 embedding,
148 result,
149 created_at: Instant::now(),
150 ttl: Duration::from_secs(3600), access_count: 0,
152 branch_context: None,
153 session_id: None,
154 workload: AIWorkloadContext::default(),
155 tables: Vec::new(),
156 }
157 }
158
159 pub fn with_ttl(mut self, ttl: Duration) -> Self {
161 self.ttl = ttl;
162 self
163 }
164
165 pub fn with_branch(mut self, branch: BranchContext) -> Self {
167 self.branch_context = Some(branch);
168 self
169 }
170
171 pub fn with_session(mut self, session: impl Into<String>) -> Self {
173 self.session_id = Some(session.into());
174 self
175 }
176
177 pub fn with_workload(mut self, workload: AIWorkloadContext) -> Self {
179 self.workload = workload;
180 self
181 }
182
183 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
185 self.tables = tables;
186 self
187 }
188
189 pub fn workload_ttl(&self) -> Duration {
191 match self.workload {
192 AIWorkloadContext::RAGRetrieval => Duration::from_secs(300), AIWorkloadContext::RAGGeneration => Duration::from_secs(1800), AIWorkloadContext::AgentConversation => Duration::from_secs(3600), AIWorkloadContext::ToolResult => Duration::from_secs(86400), AIWorkloadContext::General => self.ttl,
197 }
198 }
199
200 pub fn is_expired(&self) -> bool {
202 self.created_at.elapsed() > self.workload_ttl()
203 }
204
205 pub fn matches_branch(&self, query_branch: &BranchContext) -> bool {
207 match &self.branch_context {
208 None => true, Some(entry_branch) => entry_branch.is_compatible(query_branch),
210 }
211 }
212
213 pub fn matches_session(&self, session: &SessionId) -> bool {
215 match &self.session_id {
216 None => true,
217 Some(entry_session) => entry_session == session,
218 }
219 }
220
221 pub fn size(&self) -> usize {
223 self.query.len() +
224 self.embedding.len() * 4 +
225 self.result.to_string().len() +
226 self.tables.iter().map(|t| t.len()).sum::<usize>() +
227 self.session_id.as_ref().map(|s| s.len()).unwrap_or(0) +
228 self.branch_context.as_ref().map(|b| b.branch.len() + 8).unwrap_or(0) +
229 96
230 }
231}
232
233#[derive(Debug, Clone)]
235pub struct SimilarityResult {
236 pub id: VectorId,
238 pub similarity: f32,
240 pub entry: SemanticEntry,
242}
243
244impl PartialEq for SimilarityResult {
245 fn eq(&self, other: &Self) -> bool {
246 self.similarity == other.similarity
247 }
248}
249
250impl Eq for SimilarityResult {}
251
252impl PartialOrd for SimilarityResult {
253 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
254 Some(self.cmp(other))
255 }
256}
257
258impl Ord for SimilarityResult {
259 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
260 other.similarity.partial_cmp(&self.similarity)
262 .unwrap_or(std::cmp::Ordering::Equal)
263 }
264}
265
266pub struct SemanticIndex {
269 vectors: DashMap<VectorId, Embedding>,
271
272 config: SemanticIndexConfig,
274
275 next_id: AtomicU64,
277}
278
/// Tuning parameters for a `SemanticIndex`.
///
/// NOTE(review): the names suggest graph-index (HNSW-style) parameters, but
/// the visible index is a brute-force scan that does not read them.
#[derive(Debug, Clone)]
pub struct SemanticIndexConfig {
    /// Maximum neighbours per node (default 16).
    pub max_connections: usize,
    /// Search-time candidate list size (default 100).
    pub ef_search: usize,
    /// Expected embedding dimension (default 384). This is not enforced on insert.
    pub dimension: usize,
}

impl Default for SemanticIndexConfig {
    fn default() -> Self {
        // 384 presumably matches a common sentence-embedding model size; confirm.
        const DEFAULT_DIMENSION: usize = 384;
        let ef_search = 100;
        let max_connections = 16;
        SemanticIndexConfig {
            dimension: DEFAULT_DIMENSION,
            ef_search,
            max_connections,
        }
    }
}
299
300impl SemanticIndex {
301 pub fn new(config: SemanticIndexConfig) -> Self {
303 Self {
304 vectors: DashMap::new(),
305 config,
306 next_id: AtomicU64::new(1),
307 }
308 }
309
310 pub fn insert(&self, embedding: Embedding) -> VectorId {
312 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
313 self.vectors.insert(id, embedding);
314 id
315 }
316
317 pub fn remove(&self, id: VectorId) {
319 self.vectors.remove(&id);
320 }
321
322 pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorId, f32)> {
324 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, VectorId)> = BinaryHeap::new();
326
327 for entry in self.vectors.iter() {
328 let similarity = cosine_similarity(query, entry.value());
329 let sim_int = (similarity * 1_000_000.0) as i64;
331 heap.push((std::cmp::Reverse(sim_int), *entry.key()));
332
333 if heap.len() > k {
334 heap.pop();
335 }
336 }
337
338 let mut results: Vec<_> = heap.into_iter()
340 .map(|(std::cmp::Reverse(sim), id)| (id, sim as f32 / 1_000_000.0))
341 .collect();
342
343 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
344 results
345 }
346
347 pub fn len(&self) -> usize {
349 self.vectors.len()
350 }
351
352 pub fn is_empty(&self) -> bool {
354 self.vectors.is_empty()
355 }
356
357 pub fn clear(&self) {
359 self.vectors.clear();
360 }
361}
362
/// Cosine similarity of `a` and `b`, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when:
/// - the slices differ in length,
/// - the slices are empty, or
/// - either vector has zero magnitude.
///
/// Sums are accumulated in `f64`. This means components whose squares would
/// overflow `f32` (roughly beyond ±1e19) or underflow it (roughly below
/// ±1e-19) still give a correct result. The result is clamped to absorb
/// rounding just past ±1. NaN components propagate to a NaN result.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    // Taking square roots separately avoids squaring the norms again.
    let denominator = norm_a.sqrt() * norm_b.sqrt();
    if denominator == 0.0 {
        0.0
    } else {
        ((dot / denominator) as f32).clamp(-1.0, 1.0)
    }
}
386
387pub struct SemanticQueryCache {
389 index: SemanticIndex,
391
392 entries: DashMap<VectorId, SemanticEntry>,
394
395 threshold: f32,
397
398 max_entries: usize,
400
401 stats: SemanticCacheStats,
403}
404
/// Internal counters for `SemanticQueryCache`.
///
/// Every update in this file uses `Ordering::Relaxed`. `stats()` loads each
/// counter separately, so a snapshot is not an atomic view of all of them.
#[derive(Debug, Default)]
struct SemanticCacheStats {
    /// Lookups recorded as hits; each also bumps exactly one of
    /// `semantic_hits` / `exact_hits`.
    hits: AtomicU64,
    /// Lookups recorded as misses.
    misses: AtomicU64,
    /// Hits classified as approximate (semantic) matches.
    semantic_hits: AtomicU64,
    /// Hits classified as exact matches.
    exact_hits: AtomicU64,
    /// Entries inserted.
    insertions: AtomicU64,
    /// Entries removed to make room for an insert.
    evictions: AtomicU64,
}
415
416impl SemanticQueryCache {
417 pub fn new(threshold: f32) -> Self {
419 Self::with_capacity(threshold, 10000)
420 }
421
422 pub fn with_capacity(threshold: f32, max_entries: usize) -> Self {
424 Self {
425 index: SemanticIndex::new(SemanticIndexConfig::default()),
426 entries: DashMap::new(),
427 threshold,
428 max_entries,
429 stats: SemanticCacheStats::default(),
430 }
431 }
432
433 pub fn with_config(threshold: f32, max_entries: usize, index_config: SemanticIndexConfig) -> Self {
435 Self {
436 index: SemanticIndex::new(index_config),
437 entries: DashMap::new(),
438 threshold,
439 max_entries,
440 stats: SemanticCacheStats::default(),
441 }
442 }
443
444 pub fn lookup(&self, embedding: &[f32]) -> Option<SimilarityResult> {
446 let results = self.index.search(embedding, 1);
448
449 if let Some((id, similarity)) = results.first() {
450 if *similarity >= self.threshold {
451 if let Some(entry) = self.entries.get(id) {
452 if !entry.is_expired() {
453 self.stats.hits.fetch_add(1, Ordering::Relaxed);
454
455 if *similarity > 0.999 {
456 self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
457 } else {
458 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
459 }
460
461 return Some(SimilarityResult {
462 id: *id,
463 similarity: *similarity,
464 entry: entry.clone(),
465 });
466 } else {
467 drop(entry);
469 self.remove(*id);
470 }
471 }
472 }
473 }
474
475 self.stats.misses.fetch_add(1, Ordering::Relaxed);
476 None
477 }
478
479 pub fn lookup_with_threshold(&self, embedding: &[f32], threshold: f32) -> Option<SimilarityResult> {
481 let results = self.index.search(embedding, 1);
482
483 if let Some((id, similarity)) = results.first() {
484 if *similarity >= threshold {
485 if let Some(entry) = self.entries.get(id) {
486 if !entry.is_expired() {
487 return Some(SimilarityResult {
488 id: *id,
489 similarity: *similarity,
490 entry: entry.clone(),
491 });
492 }
493 }
494 }
495 }
496
497 None
498 }
499
500 pub fn find_similar(&self, embedding: &[f32], k: usize) -> Vec<SimilarityResult> {
502 let results = self.index.search(embedding, k);
503
504 results.into_iter()
505 .filter_map(|(id, similarity)| {
506 self.entries.get(&id).and_then(|entry| {
507 if !entry.is_expired() {
508 Some(SimilarityResult {
509 id,
510 similarity,
511 entry: entry.clone(),
512 })
513 } else {
514 None
515 }
516 })
517 })
518 .collect()
519 }
520
521 pub fn lookup_with_branch(
526 &self,
527 embedding: &[f32],
528 branch: &BranchContext,
529 ) -> Option<SimilarityResult> {
530 let results = self.index.search(embedding, 10);
532
533 for (id, similarity) in results {
534 if similarity < self.threshold {
535 break; }
537
538 if let Some(entry) = self.entries.get(&id) {
539 if !entry.is_expired() && entry.matches_branch(branch) {
540 self.stats.hits.fetch_add(1, Ordering::Relaxed);
541 if similarity > 0.999 {
542 self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
543 } else {
544 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
545 }
546
547 return Some(SimilarityResult {
548 id,
549 similarity,
550 entry: entry.clone(),
551 });
552 }
553 }
554 }
555
556 self.stats.misses.fetch_add(1, Ordering::Relaxed);
557 None
558 }
559
560 pub fn lookup_with_session(
564 &self,
565 embedding: &[f32],
566 session: &SessionId,
567 ) -> Option<SimilarityResult> {
568 let results = self.index.search(embedding, 20);
569
570 for (id, similarity) in &results {
572 if *similarity < self.threshold {
573 break;
574 }
575
576 if let Some(entry) = self.entries.get(id) {
577 if !entry.is_expired() && entry.matches_session(session) {
578 self.stats.hits.fetch_add(1, Ordering::Relaxed);
579 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
580
581 return Some(SimilarityResult {
582 id: *id,
583 similarity: *similarity,
584 entry: entry.clone(),
585 });
586 }
587 }
588 }
589
590 for (id, similarity) in &results {
592 if *similarity < self.threshold {
593 break;
594 }
595
596 if let Some(entry) = self.entries.get(id) {
597 if !entry.is_expired() {
598 self.stats.hits.fetch_add(1, Ordering::Relaxed);
599 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
600
601 return Some(SimilarityResult {
602 id: *id,
603 similarity: *similarity,
604 entry: entry.clone(),
605 });
606 }
607 }
608 }
609
610 self.stats.misses.fetch_add(1, Ordering::Relaxed);
611 None
612 }
613
614 pub fn lookup_with_context(
621 &self,
622 embedding: &[f32],
623 branch: Option<&BranchContext>,
624 session: Option<&SessionId>,
625 workload: AIWorkloadContext,
626 ) -> Option<SimilarityResult> {
627 let results = self.index.search(embedding, 20);
628
629 for (id, similarity) in &results {
631 if *similarity < self.threshold {
632 break;
633 }
634
635 if let Some(entry) = self.entries.get(id) {
636 let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
637 let session_match = session.map(|s| entry.matches_session(s)).unwrap_or(false);
638 let workload_match = entry.workload == workload;
639
640 if !entry.is_expired() && branch_match && session_match && workload_match {
641 self.stats.hits.fetch_add(1, Ordering::Relaxed);
642 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
643 return Some(SimilarityResult {
644 id: *id,
645 similarity: *similarity,
646 entry: entry.clone(),
647 });
648 }
649 }
650 }
651
652 for (id, similarity) in &results {
654 if *similarity < self.threshold {
655 break;
656 }
657
658 if let Some(entry) = self.entries.get(id) {
659 let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
660 let workload_match = entry.workload == workload;
661
662 if !entry.is_expired() && branch_match && workload_match {
663 self.stats.hits.fetch_add(1, Ordering::Relaxed);
664 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
665 return Some(SimilarityResult {
666 id: *id,
667 similarity: *similarity,
668 entry: entry.clone(),
669 });
670 }
671 }
672 }
673
674 for (id, similarity) in &results {
676 if *similarity < self.threshold {
677 break;
678 }
679
680 if let Some(entry) = self.entries.get(id) {
681 let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
682
683 if !entry.is_expired() && branch_match {
684 self.stats.hits.fetch_add(1, Ordering::Relaxed);
685 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
686 return Some(SimilarityResult {
687 id: *id,
688 similarity: *similarity,
689 entry: entry.clone(),
690 });
691 }
692 }
693 }
694
695 self.stats.misses.fetch_add(1, Ordering::Relaxed);
696 None
697 }
698
699 pub fn find_similar_in_branch(
701 &self,
702 embedding: &[f32],
703 branch: &BranchContext,
704 k: usize,
705 ) -> Vec<SimilarityResult> {
706 let results = self.index.search(embedding, k * 3);
708
709 results.into_iter()
710 .filter_map(|(id, similarity)| {
711 self.entries.get(&id).and_then(|entry| {
712 if !entry.is_expired() && entry.matches_branch(branch) {
713 Some(SimilarityResult {
714 id,
715 similarity,
716 entry: entry.clone(),
717 })
718 } else {
719 None
720 }
721 })
722 })
723 .take(k)
724 .collect()
725 }
726
727 pub fn invalidate_by_table(&self, table: &str) -> usize {
731 let to_remove: Vec<_> = self.entries.iter()
732 .filter(|e| e.tables.iter().any(|t| t == table))
733 .map(|e| *e.key())
734 .collect();
735
736 let count = to_remove.len();
737 for id in to_remove {
738 self.remove(id);
739 }
740 count
741 }
742
743 pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
745 let to_remove: Vec<_> = self.entries.iter()
746 .filter(|e| {
747 e.branch_context.as_ref()
748 .map(|b| &b.branch == branch)
749 .unwrap_or(false)
750 })
751 .map(|e| *e.key())
752 .collect();
753
754 let count = to_remove.len();
755 for id in to_remove {
756 self.remove(id);
757 }
758 count
759 }
760
761 pub fn insert(&self, query: impl Into<String>, embedding: Embedding, result: serde_json::Value) -> VectorId {
763 while self.entries.len() >= self.max_entries {
765 self.evict_one();
766 }
767
768 let id = self.index.insert(embedding.clone());
770
771 let entry = SemanticEntry::new(id, query, embedding, result);
773 self.entries.insert(id, entry);
774
775 self.stats.insertions.fetch_add(1, Ordering::Relaxed);
776 id
777 }
778
779 pub fn insert_with_ttl(
781 &self,
782 query: impl Into<String>,
783 embedding: Embedding,
784 result: serde_json::Value,
785 ttl: Duration,
786 ) -> VectorId {
787 while self.entries.len() >= self.max_entries {
788 self.evict_one();
789 }
790
791 let id = self.index.insert(embedding.clone());
792 let entry = SemanticEntry::new(id, query, embedding, result).with_ttl(ttl);
793 self.entries.insert(id, entry);
794
795 self.stats.insertions.fetch_add(1, Ordering::Relaxed);
796 id
797 }
798
799 pub fn insert_with_context(
805 &self,
806 query: impl Into<String>,
807 embedding: Embedding,
808 result: serde_json::Value,
809 branch: Option<BranchContext>,
810 session: Option<SessionId>,
811 workload: AIWorkloadContext,
812 tables: Vec<String>,
813 ) -> VectorId {
814 while self.entries.len() >= self.max_entries {
815 self.evict_one();
816 }
817
818 let id = self.index.insert(embedding.clone());
819 let mut entry = SemanticEntry::new(id, query, embedding, result)
820 .with_workload(workload)
821 .with_tables(tables);
822
823 if let Some(b) = branch {
824 entry = entry.with_branch(b);
825 }
826 if let Some(s) = session {
827 entry = entry.with_session(s);
828 }
829
830 self.entries.insert(id, entry);
831 self.stats.insertions.fetch_add(1, Ordering::Relaxed);
832 id
833 }
834
835 pub fn insert_rag_retrieval(
839 &self,
840 query: impl Into<String>,
841 embedding: Embedding,
842 result: serde_json::Value,
843 tables: Vec<String>,
844 ) -> VectorId {
845 self.insert_with_context(
846 query,
847 embedding,
848 result,
849 None,
850 None,
851 AIWorkloadContext::RAGRetrieval,
852 tables,
853 )
854 }
855
856 pub fn insert_agent_response(
860 &self,
861 query: impl Into<String>,
862 embedding: Embedding,
863 result: serde_json::Value,
864 session: SessionId,
865 branch: Option<BranchContext>,
866 ) -> VectorId {
867 self.insert_with_context(
868 query,
869 embedding,
870 result,
871 branch,
872 Some(session),
873 AIWorkloadContext::AgentConversation,
874 Vec::new(),
875 )
876 }
877
878 pub fn insert_tool_result(
882 &self,
883 query: impl Into<String>,
884 embedding: Embedding,
885 result: serde_json::Value,
886 ) -> VectorId {
887 self.insert_with_context(
888 query,
889 embedding,
890 result,
891 None,
892 None,
893 AIWorkloadContext::ToolResult,
894 Vec::new(),
895 )
896 }
897
898 pub fn remove(&self, id: VectorId) {
900 self.index.remove(id);
901 self.entries.remove(&id);
902 }
903
904 fn evict_one(&self) {
906 let mut oldest_id = None;
907 let mut oldest_time = Instant::now();
908
909 for entry in self.entries.iter() {
910 if entry.created_at < oldest_time {
911 oldest_time = entry.created_at;
912 oldest_id = Some(*entry.key());
913 }
914 }
915
916 if let Some(id) = oldest_id {
917 self.remove(id);
918 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
919 }
920 }
921
922 pub fn cleanup_expired(&self) {
924 let expired: Vec<_> = self.entries.iter()
925 .filter(|e| e.is_expired())
926 .map(|e| *e.key())
927 .collect();
928
929 for id in expired {
930 self.remove(id);
931 }
932 }
933
934 pub fn clear(&self) {
936 self.index.clear();
937 self.entries.clear();
938 }
939
940 pub fn len(&self) -> usize {
942 self.entries.len()
943 }
944
945 pub fn is_empty(&self) -> bool {
947 self.entries.is_empty()
948 }
949
950 pub fn stats(&self) -> SemanticCacheStatsSnapshot {
952 let hits = self.stats.hits.load(Ordering::Relaxed);
953 let misses = self.stats.misses.load(Ordering::Relaxed);
954 let total = hits + misses;
955
956 SemanticCacheStatsSnapshot {
957 entries: self.entries.len(),
958 threshold: self.threshold,
959 hits,
960 misses,
961 hit_rate: if total > 0 { hits as f64 / total as f64 } else { 0.0 },
962 semantic_hits: self.stats.semantic_hits.load(Ordering::Relaxed),
963 exact_hits: self.stats.exact_hits.load(Ordering::Relaxed),
964 insertions: self.stats.insertions.load(Ordering::Relaxed),
965 evictions: self.stats.evictions.load(Ordering::Relaxed),
966 }
967 }
968}
969
/// Point-in-time copy of a cache's counters, as returned by `SemanticQueryCache::stats`.
///
/// `lookup_with_threshold`, `find_similar` and `find_similar_in_branch` record
/// no statistics.
#[derive(Debug, Clone)]
pub struct SemanticCacheStatsSnapshot {
    /// Number of entries stored when the snapshot was taken.
    pub entries: usize,
    /// Similarity threshold configured on the cache.
    pub threshold: f32,
    /// Lookups recorded as hits.
    pub hits: u64,
    /// Lookups recorded as misses.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been recorded.
    pub hit_rate: f64,
    /// Hits classified as approximate (semantic) matches.
    pub semantic_hits: u64,
    /// Hits classified as exact matches (similarity > 0.999 in `lookup`).
    pub exact_hits: u64,
    /// Entries inserted.
    pub insertions: u64,
    /// Entries evicted to make room for inserts.
    pub evictions: u64,
}
983
// Unit tests: similarity math, index search, threshold/expiry/eviction
// behaviour, statistics, and branch/session/workload scoping.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Identical, orthogonal and ~45-degree vectors.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        let d = vec![0.707, 0.707, 0.0];
        let sim = cosine_similarity(&a, &d);
        assert!((sim - 0.707).abs() < 0.01);
    }

    // Top-k search returns the k nearest ids, best first.
    #[test]
    fn test_semantic_index() {
        let index = SemanticIndex::new(SemanticIndexConfig::default());

        let id1 = index.insert(vec![1.0, 0.0, 0.0]);
        let id2 = index.insert(vec![0.9, 0.1, 0.0]);
        // Orthogonal distractor; must not appear in the top 2.
        let id3 = index.insert(vec![0.0, 1.0, 0.0]);

        let results = index.search(&[1.0, 0.0, 0.0], 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[1].0, id2);
    }

    // Looking up the exact embedding that was inserted is an (exact) hit.
    #[test]
    fn test_semantic_cache_insert_lookup() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        let id = cache.insert(
            "SELECT * FROM users WHERE name = 'test'",
            embedding.clone(),
            json!({"count": 5}),
        );

        let result = cache.lookup(&embedding);
        assert!(result.is_some());
        let res = result.unwrap();
        assert_eq!(res.id, id);
        assert!(res.similarity > 0.999);
    }

    // A nearby (not identical) embedding above the threshold is a hit.
    #[test]
    fn test_semantic_similarity_lookup() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        cache.insert(
            "SELECT * FROM users WHERE id = 1",
            vec![1.0, 0.0, 0.0],
            json!({"user": "alice"}),
        );

        let similar_embedding = vec![0.95, 0.05, 0.0];
        let result = cache.lookup(&similar_embedding);

        assert!(result.is_some());
        let res = result.unwrap();
        assert!(res.similarity >= 0.9);
    }

    // Similarity ~0.707 is below the 0.95 threshold, so the lookup misses.
    #[test]
    fn test_threshold_rejection() {
        let cache = SemanticQueryCache::with_capacity(0.95, 100);

        cache.insert(
            "SELECT * FROM orders",
            vec![1.0, 0.0, 0.0],
            json!({"total": 100}),
        );

        let different_embedding = vec![0.7, 0.7, 0.0];
        let result = cache.lookup(&different_embedding);
        assert!(result.is_none());
    }

    // find_similar returns k results in strictly decreasing similarity order.
    #[test]
    fn test_find_similar() {
        let cache = SemanticQueryCache::with_capacity(0.5, 100);

        cache.insert("query1", vec![1.0, 0.0, 0.0], json!(1));
        cache.insert("query2", vec![0.9, 0.1, 0.0], json!(2));
        cache.insert("query3", vec![0.8, 0.2, 0.0], json!(3));
        cache.insert("query4", vec![0.0, 1.0, 0.0], json!(4));

        let similar = cache.find_similar(&[1.0, 0.0, 0.0], 3);

        assert_eq!(similar.len(), 3);
        assert!(similar[0].similarity > similar[1].similarity);
        assert!(similar[1].similarity > similar[2].similarity);
    }

    // An entry past its TTL is not returned by lookup.
    #[test]
    fn test_expiration() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_ttl(
            "expiring query",
            embedding.clone(),
            json!({"expires": true}),
            Duration::from_millis(1),
        );

        // Sleep well past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(10));

        let result = cache.lookup(&embedding);
        assert!(result.is_none());
    }

    // Inserting into a full cache evicts one entry, keeping the size at capacity.
    #[test]
    fn test_eviction() {
        let cache = SemanticQueryCache::with_capacity(0.9, 3);

        for i in 0..3 {
            cache.insert(
                format!("query{}", i),
                vec![i as f32, 0.0, 0.0],
                json!(i),
            );
        }

        assert_eq!(cache.len(), 3);

        cache.insert("query3", vec![3.0, 0.0, 0.0], json!(3));

        assert_eq!(cache.len(), 3);
    }

    // Two exact hits and one miss are reflected in the counters.
    #[test]
    fn test_stats() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert("test query", embedding.clone(), json!(1));

        cache.lookup(&embedding);
        cache.lookup(&embedding);

        // Orthogonal query: similarity 0, a miss.
        cache.lookup(&[0.0, 1.0, 0.0]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.exact_hits, 2);
        assert_eq!(stats.insertions, 1);
    }

    // Same branch required; earlier snapshots may serve later queries, not vice versa.
    #[test]
    fn test_branch_context_compatibility() {
        let main = BranchContext::main();
        let feature = BranchContext::new("feature-x");
        let snapshot = BranchContext::with_snapshot("main", 1000);
        let later_snapshot = BranchContext::with_snapshot("main", 2000);

        assert!(main.is_compatible(&main));
        assert!(!main.is_compatible(&feature));

        assert!(snapshot.is_compatible(&later_snapshot));
        assert!(!later_snapshot.is_compatible(&snapshot));
    }

    // Each branch lookup returns the entry scoped to that branch.
    #[test]
    fn test_lookup_with_branch() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_context(
            "SELECT * FROM users",
            embedding.clone(),
            json!({"users": []}),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        // Very similar embedding, but scoped to a different branch.
        let embedding2 = vec![0.95, 0.05, 0.0];
        cache.insert_with_context(
            "SELECT * FROM users",
            embedding2.clone(),
            json!({"users": ["new_user"]}),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        let main_result = cache.lookup_with_branch(&embedding, &BranchContext::main());
        assert!(main_result.is_some());
        assert_eq!(main_result.unwrap().entry.branch_context.as_ref().unwrap().branch, "main");

        let feature_result = cache.lookup_with_branch(&embedding2, &BranchContext::new("feature-x"));
        assert!(feature_result.is_some());
        assert_eq!(feature_result.unwrap().entry.branch_context.as_ref().unwrap().branch, "feature-x");
    }

    // A session lookup prefers the entry stored for that session.
    #[test]
    fn test_lookup_with_session() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);
        let session1 = "session-001".to_string();
        let session2 = "session-002".to_string();

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_agent_response(
            "What is the weather?",
            embedding.clone(),
            json!({"weather": "sunny"}),
            session1.clone(),
            None,
        );

        let embedding2 = vec![0.98, 0.02, 0.0];
        cache.insert_agent_response(
            "How's the weather?",
            embedding2,
            json!({"weather": "cloudy"}),
            session2.clone(),
            None,
        );

        let result = cache.lookup_with_session(&embedding, &session1);
        assert!(result.is_some());
        assert_eq!(result.unwrap().entry.session_id.as_ref().unwrap(), &session1);
    }

    // A full-context match hits.
    // With no session and a different workload, the branch-only tier still hits.
    #[test]
    fn test_lookup_with_context() {
        let cache = SemanticQueryCache::with_capacity(0.8, 100);
        let session = "agent-session".to_string();
        let branch = BranchContext::main();

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_context(
            "Find users with orders",
            embedding.clone(),
            json!({"users": 42}),
            Some(branch.clone()),
            Some(session.clone()),
            AIWorkloadContext::RAGRetrieval,
            vec!["users".to_string(), "orders".to_string()],
        );

        let result = cache.lookup_with_context(
            &embedding,
            Some(&branch),
            Some(&session),
            AIWorkloadContext::RAGRetrieval,
        );
        assert!(result.is_some());

        let result2 = cache.lookup_with_context(
            &embedding,
            Some(&branch),
            None,
            AIWorkloadContext::General,
        );
        assert!(result2.is_some());
    }

    // Invalidating "users" removes both entries that depend on it.
    #[test]
    fn test_invalidate_by_table() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        cache.insert_with_context(
            "SELECT * FROM users",
            vec![1.0, 0.0, 0.0],
            json!(1),
            None,
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );
        cache.insert_with_context(
            "SELECT * FROM orders",
            vec![0.0, 1.0, 0.0],
            json!(2),
            None,
            None,
            AIWorkloadContext::General,
            vec!["orders".to_string()],
        );
        cache.insert_with_context(
            "SELECT * FROM users JOIN orders",
            vec![0.5, 0.5, 0.0],
            json!(3),
            None,
            None,
            AIWorkloadContext::General,
            vec!["users".to_string(), "orders".to_string()],
        );

        assert_eq!(cache.len(), 3);

        let removed = cache.invalidate_by_table("users");
        assert_eq!(removed, 2);
        assert_eq!(cache.len(), 1);
    }

    // Invalidating a branch removes only that branch's entries.
    #[test]
    fn test_invalidate_branch() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0, 0.0],
            json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );
        cache.insert_with_context(
            "query2",
            vec![0.0, 1.0, 0.0],
            json!(2),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );
        cache.insert_with_context(
            "query3",
            vec![0.0, 0.0, 1.0],
            json!(3),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );

        assert_eq!(cache.len(), 3);

        let removed = cache.invalidate_branch(&"feature-x".to_string());
        assert_eq!(removed, 2);
        assert_eq!(cache.len(), 1);
    }

    // Fixed per-workload TTLs.
    #[test]
    fn test_workload_ttl() {
        let rag_entry = SemanticEntry::new(1, "rag query", vec![], json!({}))
            .with_workload(AIWorkloadContext::RAGRetrieval);
        assert_eq!(rag_entry.workload_ttl(), Duration::from_secs(300));

        let tool_entry = SemanticEntry::new(2, "tool query", vec![], json!({}))
            .with_workload(AIWorkloadContext::ToolResult);
        assert_eq!(tool_entry.workload_ttl(), Duration::from_secs(86400));

        let agent_entry = SemanticEntry::new(3, "agent query", vec![], json!({}))
            .with_workload(AIWorkloadContext::AgentConversation);
        assert_eq!(agent_entry.workload_ttl(), Duration::from_secs(3600));
    }

    // Branch-filtered top-k returns only entries from the requested branch.
    #[test]
    fn test_find_similar_in_branch() {
        let cache = SemanticQueryCache::with_capacity(0.5, 100);
        let main = BranchContext::main();
        let feature = BranchContext::new("feature-x");

        for i in 0..3 {
            cache.insert_with_context(
                format!("main query {}", i),
                vec![1.0 - (i as f32 * 0.1), i as f32 * 0.1, 0.0],
                json!(i),
                Some(main.clone()),
                None,
                AIWorkloadContext::General,
                Vec::new(),
            );
        }

        for i in 0..2 {
            cache.insert_with_context(
                format!("feature query {}", i),
                vec![0.5, 0.5 + (i as f32 * 0.1), 0.0],
                json!(100 + i),
                Some(feature.clone()),
                None,
                AIWorkloadContext::General,
                Vec::new(),
            );
        }

        let main_results = cache.find_similar_in_branch(&[1.0, 0.0, 0.0], &main, 5);
        assert_eq!(main_results.len(), 3);
        for r in &main_results {
            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "main");
        }

        let feature_results = cache.find_similar_in_branch(&[0.5, 0.5, 0.0], &feature, 5);
        assert_eq!(feature_results.len(), 2);
        for r in &feature_results {
            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "feature-x");
        }
    }
}