1use dashmap::DashMap;
20use std::collections::BinaryHeap;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
/// Identifier assigned to a stored vector and its cache entry.
pub type VectorId = u64;

/// Name of a data branch, e.g. `"main"` or `"feature-x"`.
pub type BranchId = String;

/// Identifier of an agent/user session that scopes cached responses.
pub type SessionId = String;

/// The branch (and optionally a snapshot on it) that a cached result was
/// computed against, or that a query is being evaluated against.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BranchContext {
    /// Branch name.
    pub branch: BranchId,
    /// Optional snapshot position on `branch` (presumably a version or
    /// timestamp — TODO confirm units); `None` means "the live branch".
    pub snapshot_at: Option<u64>,
}

impl BranchContext {
    /// Context for the live (unpinned) head of `branch`.
    pub fn new(branch: impl Into<String>) -> Self {
        BranchContext {
            branch: branch.into(),
            snapshot_at: None,
        }
    }

    /// Context pinned to `snapshot` on `branch`.
    pub fn with_snapshot(branch: impl Into<String>, snapshot: u64) -> Self {
        BranchContext {
            snapshot_at: Some(snapshot),
            ..Self::new(branch)
        }
    }

    /// Context for the live `"main"` branch.
    pub fn main() -> Self {
        Self::new("main")
    }

    /// Whether a result cached under `self` (the entry's context) may be
    /// served to a query running under `other` (the query's context).
    ///
    /// Rules:
    /// - the branch names must be equal;
    /// - entry pinned at S, query pinned at Q: compatible iff `S <= Q`;
    /// - unpinned entry: compatible with any query on the branch;
    /// - pinned entry, unpinned query: not compatible.
    ///
    /// NOTE(review): an unpinned entry is accepted for a pinned query, which
    /// could serve data newer than the query's snapshot — confirm intended.
    pub fn is_compatible(&self, other: &BranchContext) -> bool {
        self.branch == other.branch
            && match (self.snapshot_at, other.snapshot_at) {
                (Some(entry_snap), Some(query_snap)) => entry_snap <= query_snap,
                (Some(_), None) => false,
                (None, _) => true,
            }
    }
}

impl Default for BranchContext {
    /// The live `"main"` branch.
    fn default() -> Self {
        BranchContext::new("main")
    }
}
84
/// Kind of AI workload that produced a cached entry.
///
/// The workload selects the entry's time-to-live in
/// `SemanticEntry::workload_ttl`. It is also matched by
/// `SemanticQueryCache::lookup_with_context`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AIWorkloadContext {
    /// Retrieval step of a RAG pipeline (fixed TTL: 300 s).
    RAGRetrieval,
    /// Generation step of a RAG pipeline (fixed TTL: 1800 s).
    RAGGeneration,
    /// Agent conversation turn (fixed TTL: 3600 s).
    AgentConversation,
    /// Result of a tool invocation (fixed TTL: 86400 s).
    ToolResult,
    /// Anything else; uses the entry's own `ttl` field.
    #[default]
    General,
}

/// Dense embedding vector; entries are compared with `cosine_similarity`.
pub type Embedding = Vec<f32>;
103
104#[derive(Debug, Clone)]
106pub struct SemanticEntry {
107 pub id: VectorId,
109 pub query: String,
111 pub embedding: Embedding,
113 pub result: serde_json::Value,
115 pub created_at: Instant,
117 pub ttl: Duration,
119 pub access_count: u64,
121 pub branch_context: Option<BranchContext>,
123 pub session_id: Option<SessionId>,
125 pub workload: AIWorkloadContext,
127 pub tables: Vec<String>,
129}
130
131impl SemanticEntry {
132 pub fn new(
134 id: VectorId,
135 query: impl Into<String>,
136 embedding: Embedding,
137 result: serde_json::Value,
138 ) -> Self {
139 Self {
140 id,
141 query: query.into(),
142 embedding,
143 result,
144 created_at: Instant::now(),
145 ttl: Duration::from_secs(3600), access_count: 0,
147 branch_context: None,
148 session_id: None,
149 workload: AIWorkloadContext::default(),
150 tables: Vec::new(),
151 }
152 }
153
154 pub fn with_ttl(mut self, ttl: Duration) -> Self {
156 self.ttl = ttl;
157 self
158 }
159
160 pub fn with_branch(mut self, branch: BranchContext) -> Self {
162 self.branch_context = Some(branch);
163 self
164 }
165
166 pub fn with_session(mut self, session: impl Into<String>) -> Self {
168 self.session_id = Some(session.into());
169 self
170 }
171
172 pub fn with_workload(mut self, workload: AIWorkloadContext) -> Self {
174 self.workload = workload;
175 self
176 }
177
178 pub fn with_tables(mut self, tables: Vec<String>) -> Self {
180 self.tables = tables;
181 self
182 }
183
184 pub fn workload_ttl(&self) -> Duration {
186 match self.workload {
187 AIWorkloadContext::RAGRetrieval => Duration::from_secs(300), AIWorkloadContext::RAGGeneration => Duration::from_secs(1800), AIWorkloadContext::AgentConversation => Duration::from_secs(3600), AIWorkloadContext::ToolResult => Duration::from_secs(86400), AIWorkloadContext::General => self.ttl,
192 }
193 }
194
195 pub fn is_expired(&self) -> bool {
197 self.created_at.elapsed() > self.workload_ttl()
198 }
199
200 pub fn matches_branch(&self, query_branch: &BranchContext) -> bool {
202 match &self.branch_context {
203 None => true, Some(entry_branch) => entry_branch.is_compatible(query_branch),
205 }
206 }
207
208 pub fn matches_session(&self, session: &SessionId) -> bool {
210 match &self.session_id {
211 None => true,
212 Some(entry_session) => entry_session == session,
213 }
214 }
215
216 pub fn size(&self) -> usize {
218 self.query.len()
219 + self.embedding.len() * 4
220 + self.result.to_string().len()
221 + self.tables.iter().map(|t| t.len()).sum::<usize>()
222 + self.session_id.as_ref().map(|s| s.len()).unwrap_or(0)
223 + self
224 .branch_context
225 .as_ref()
226 .map(|b| b.branch.len() + 8)
227 .unwrap_or(0)
228 + 96
229 }
230}
231
232#[derive(Debug, Clone)]
234pub struct SimilarityResult {
235 pub id: VectorId,
237 pub similarity: f32,
239 pub entry: SemanticEntry,
241}
242
243impl PartialEq for SimilarityResult {
244 fn eq(&self, other: &Self) -> bool {
245 self.similarity == other.similarity
246 }
247}
248
249impl Eq for SimilarityResult {}
250
251impl PartialOrd for SimilarityResult {
252 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
253 Some(self.cmp(other))
254 }
255}
256
257impl Ord for SimilarityResult {
258 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
259 other
261 .similarity
262 .partial_cmp(&self.similarity)
263 .unwrap_or(std::cmp::Ordering::Equal)
264 }
265}
266
/// Vector index over cached embeddings.
///
/// `search` scans every stored embedding (brute force). `config` is retained,
/// but the search does not read it.
pub struct SemanticIndex {
    /// Stored embeddings keyed by their assigned id.
    vectors: DashMap<VectorId, Embedding>,

    // Stored for API compatibility only; nothing reads it yet, hence the
    // dead_code allowance.
    #[allow(dead_code)]
    config: SemanticIndexConfig,

    /// Next id to hand out (`new` starts it at 1).
    next_id: AtomicU64,
}
280
/// Tuning parameters for `SemanticIndex`.
///
/// NOTE(review): the current brute-force index stores this struct but reads
/// none of its fields. `max_connections`/`ef_search` look like HNSW-style
/// parameters — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct SemanticIndexConfig {
    /// Maximum graph connections per node (default 16).
    pub max_connections: usize,
    /// Search breadth (default 100).
    pub ef_search: usize,
    /// Expected embedding dimension (default 384).
    pub dimension: usize,
}

impl Default for SemanticIndexConfig {
    fn default() -> Self {
        let (max_connections, ef_search, dimension) = (16, 100, 384);
        SemanticIndexConfig {
            max_connections,
            ef_search,
            dimension,
        }
    }
}
301
302impl SemanticIndex {
303 pub fn new(config: SemanticIndexConfig) -> Self {
305 Self {
306 vectors: DashMap::new(),
307 config,
308 next_id: AtomicU64::new(1),
309 }
310 }
311
312 pub fn insert(&self, embedding: Embedding) -> VectorId {
314 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
315 self.vectors.insert(id, embedding);
316 id
317 }
318
319 pub fn remove(&self, id: VectorId) {
321 self.vectors.remove(&id);
322 }
323
324 pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorId, f32)> {
326 let mut heap: BinaryHeap<(std::cmp::Reverse<i64>, VectorId)> = BinaryHeap::new();
328
329 for entry in self.vectors.iter() {
330 let similarity = cosine_similarity(query, entry.value());
331 let sim_int = (similarity * 1_000_000.0) as i64;
333 heap.push((std::cmp::Reverse(sim_int), *entry.key()));
334
335 if heap.len() > k {
336 heap.pop();
337 }
338 }
339
340 let mut results: Vec<_> = heap
342 .into_iter()
343 .map(|(std::cmp::Reverse(sim), id)| (id, sim as f32 / 1_000_000.0))
344 .collect();
345
346 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
347 results
348 }
349
350 pub fn len(&self) -> usize {
352 self.vectors.len()
353 }
354
355 pub fn is_empty(&self) -> bool {
357 self.vectors.is_empty()
358 }
359
360 pub fn clear(&self) {
362 self.vectors.clear();
363 }
364}
365
/// Cosine similarity between `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has
/// zero magnitude. NaN components propagate as a NaN result.
///
/// Products are accumulated in `f64` and the norms are square-rooted
/// separately. This means very large components no longer overflow to
/// infinity (which produced NaN), and very small ones no longer underflow to
/// zero (which produced a spurious `0.0`). The final clamp absorbs rounding
/// that lands just outside the valid range.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    // Root each norm separately: multiplying two large squared norms first
    // can overflow.
    let denominator = norm_a.sqrt() * norm_b.sqrt();
    if denominator == 0.0 {
        0.0
    } else {
        ((dot / denominator) as f32).clamp(-1.0, 1.0)
    }
}
389
/// Query-result cache keyed by embedding similarity rather than exact text.
///
/// A lookup hits when a stored entry's embedding has cosine similarity at
/// least `threshold` to the probe and the entry has not expired.
pub struct SemanticQueryCache {
    /// Vector index used to find nearest neighbours of a probe embedding.
    index: SemanticIndex,

    /// Cached entries keyed by the id returned from `index.insert`.
    entries: DashMap<VectorId, SemanticEntry>,

    /// Minimum similarity for a lookup to count as a hit.
    threshold: f32,

    /// Capacity limit; inserts evict the oldest entries while
    /// `entries.len() >= max_entries`.
    max_entries: usize,

    /// Hit/miss/insert/evict counters.
    stats: SemanticCacheStats,
}

/// Internal counters for `SemanticQueryCache`; read through `stats()`.
#[derive(Debug, Default)]
struct SemanticCacheStats {
    /// Lookups that returned an entry.
    hits: AtomicU64,
    /// Counting lookups that returned nothing.
    misses: AtomicU64,
    /// Hits classified as semantic (non-exact) by the recording lookup.
    semantic_hits: AtomicU64,
    /// Hits classified as exact by the recording lookup.
    exact_hits: AtomicU64,
    /// Entries inserted.
    insertions: AtomicU64,
    /// Entries evicted to make room for inserts.
    evictions: AtomicU64,
}
418
419impl SemanticQueryCache {
420 pub fn new(threshold: f32) -> Self {
422 Self::with_capacity(threshold, 10000)
423 }
424
425 pub fn with_capacity(threshold: f32, max_entries: usize) -> Self {
427 Self {
428 index: SemanticIndex::new(SemanticIndexConfig::default()),
429 entries: DashMap::new(),
430 threshold,
431 max_entries,
432 stats: SemanticCacheStats::default(),
433 }
434 }
435
436 pub fn with_config(
438 threshold: f32,
439 max_entries: usize,
440 index_config: SemanticIndexConfig,
441 ) -> Self {
442 Self {
443 index: SemanticIndex::new(index_config),
444 entries: DashMap::new(),
445 threshold,
446 max_entries,
447 stats: SemanticCacheStats::default(),
448 }
449 }
450
451 pub fn lookup(&self, embedding: &[f32]) -> Option<SimilarityResult> {
453 let results = self.index.search(embedding, 1);
455
456 if let Some((id, similarity)) = results.first() {
457 if *similarity >= self.threshold {
458 if let Some(entry) = self.entries.get(id) {
459 if !entry.is_expired() {
460 self.stats.hits.fetch_add(1, Ordering::Relaxed);
461
462 if *similarity > 0.999 {
463 self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
464 } else {
465 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
466 }
467
468 return Some(SimilarityResult {
469 id: *id,
470 similarity: *similarity,
471 entry: entry.clone(),
472 });
473 } else {
474 drop(entry);
476 self.remove(*id);
477 }
478 }
479 }
480 }
481
482 self.stats.misses.fetch_add(1, Ordering::Relaxed);
483 None
484 }
485
486 pub fn lookup_with_threshold(
488 &self,
489 embedding: &[f32],
490 threshold: f32,
491 ) -> Option<SimilarityResult> {
492 let results = self.index.search(embedding, 1);
493
494 if let Some((id, similarity)) = results.first() {
495 if *similarity >= threshold {
496 if let Some(entry) = self.entries.get(id) {
497 if !entry.is_expired() {
498 return Some(SimilarityResult {
499 id: *id,
500 similarity: *similarity,
501 entry: entry.clone(),
502 });
503 }
504 }
505 }
506 }
507
508 None
509 }
510
511 pub fn find_similar(&self, embedding: &[f32], k: usize) -> Vec<SimilarityResult> {
513 let results = self.index.search(embedding, k);
514
515 results
516 .into_iter()
517 .filter_map(|(id, similarity)| {
518 self.entries.get(&id).and_then(|entry| {
519 if !entry.is_expired() {
520 Some(SimilarityResult {
521 id,
522 similarity,
523 entry: entry.clone(),
524 })
525 } else {
526 None
527 }
528 })
529 })
530 .collect()
531 }
532
533 pub fn lookup_with_branch(
538 &self,
539 embedding: &[f32],
540 branch: &BranchContext,
541 ) -> Option<SimilarityResult> {
542 let results = self.index.search(embedding, 10);
544
545 for (id, similarity) in results {
546 if similarity < self.threshold {
547 break; }
549
550 if let Some(entry) = self.entries.get(&id) {
551 if !entry.is_expired() && entry.matches_branch(branch) {
552 self.stats.hits.fetch_add(1, Ordering::Relaxed);
553 if similarity > 0.999 {
554 self.stats.exact_hits.fetch_add(1, Ordering::Relaxed);
555 } else {
556 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
557 }
558
559 return Some(SimilarityResult {
560 id,
561 similarity,
562 entry: entry.clone(),
563 });
564 }
565 }
566 }
567
568 self.stats.misses.fetch_add(1, Ordering::Relaxed);
569 None
570 }
571
572 pub fn lookup_with_session(
576 &self,
577 embedding: &[f32],
578 session: &SessionId,
579 ) -> Option<SimilarityResult> {
580 let results = self.index.search(embedding, 20);
581
582 for (id, similarity) in &results {
584 if *similarity < self.threshold {
585 break;
586 }
587
588 if let Some(entry) = self.entries.get(id) {
589 if !entry.is_expired() && entry.matches_session(session) {
590 self.stats.hits.fetch_add(1, Ordering::Relaxed);
591 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
592
593 return Some(SimilarityResult {
594 id: *id,
595 similarity: *similarity,
596 entry: entry.clone(),
597 });
598 }
599 }
600 }
601
602 for (id, similarity) in &results {
604 if *similarity < self.threshold {
605 break;
606 }
607
608 if let Some(entry) = self.entries.get(id) {
609 if !entry.is_expired() {
610 self.stats.hits.fetch_add(1, Ordering::Relaxed);
611 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
612
613 return Some(SimilarityResult {
614 id: *id,
615 similarity: *similarity,
616 entry: entry.clone(),
617 });
618 }
619 }
620 }
621
622 self.stats.misses.fetch_add(1, Ordering::Relaxed);
623 None
624 }
625
626 pub fn lookup_with_context(
633 &self,
634 embedding: &[f32],
635 branch: Option<&BranchContext>,
636 session: Option<&SessionId>,
637 workload: AIWorkloadContext,
638 ) -> Option<SimilarityResult> {
639 let results = self.index.search(embedding, 20);
640
641 for (id, similarity) in &results {
643 if *similarity < self.threshold {
644 break;
645 }
646
647 if let Some(entry) = self.entries.get(id) {
648 let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
649 let session_match = session.map(|s| entry.matches_session(s)).unwrap_or(false);
650 let workload_match = entry.workload == workload;
651
652 if !entry.is_expired() && branch_match && session_match && workload_match {
653 self.stats.hits.fetch_add(1, Ordering::Relaxed);
654 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
655 return Some(SimilarityResult {
656 id: *id,
657 similarity: *similarity,
658 entry: entry.clone(),
659 });
660 }
661 }
662 }
663
664 for (id, similarity) in &results {
666 if *similarity < self.threshold {
667 break;
668 }
669
670 if let Some(entry) = self.entries.get(id) {
671 let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
672 let workload_match = entry.workload == workload;
673
674 if !entry.is_expired() && branch_match && workload_match {
675 self.stats.hits.fetch_add(1, Ordering::Relaxed);
676 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
677 return Some(SimilarityResult {
678 id: *id,
679 similarity: *similarity,
680 entry: entry.clone(),
681 });
682 }
683 }
684 }
685
686 for (id, similarity) in &results {
688 if *similarity < self.threshold {
689 break;
690 }
691
692 if let Some(entry) = self.entries.get(id) {
693 let branch_match = branch.map(|b| entry.matches_branch(b)).unwrap_or(true);
694
695 if !entry.is_expired() && branch_match {
696 self.stats.hits.fetch_add(1, Ordering::Relaxed);
697 self.stats.semantic_hits.fetch_add(1, Ordering::Relaxed);
698 return Some(SimilarityResult {
699 id: *id,
700 similarity: *similarity,
701 entry: entry.clone(),
702 });
703 }
704 }
705 }
706
707 self.stats.misses.fetch_add(1, Ordering::Relaxed);
708 None
709 }
710
711 pub fn find_similar_in_branch(
713 &self,
714 embedding: &[f32],
715 branch: &BranchContext,
716 k: usize,
717 ) -> Vec<SimilarityResult> {
718 let results = self.index.search(embedding, k * 3);
720
721 results
722 .into_iter()
723 .filter_map(|(id, similarity)| {
724 self.entries.get(&id).and_then(|entry| {
725 if !entry.is_expired() && entry.matches_branch(branch) {
726 Some(SimilarityResult {
727 id,
728 similarity,
729 entry: entry.clone(),
730 })
731 } else {
732 None
733 }
734 })
735 })
736 .take(k)
737 .collect()
738 }
739
740 pub fn invalidate_by_table(&self, table: &str) -> usize {
744 let to_remove: Vec<_> = self
745 .entries
746 .iter()
747 .filter(|e| e.tables.iter().any(|t| t == table))
748 .map(|e| *e.key())
749 .collect();
750
751 let count = to_remove.len();
752 for id in to_remove {
753 self.remove(id);
754 }
755 count
756 }
757
758 pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
760 let to_remove: Vec<_> = self
761 .entries
762 .iter()
763 .filter(|e| {
764 e.branch_context
765 .as_ref()
766 .map(|b| &b.branch == branch)
767 .unwrap_or(false)
768 })
769 .map(|e| *e.key())
770 .collect();
771
772 let count = to_remove.len();
773 for id in to_remove {
774 self.remove(id);
775 }
776 count
777 }
778
779 pub fn insert(
781 &self,
782 query: impl Into<String>,
783 embedding: Embedding,
784 result: serde_json::Value,
785 ) -> VectorId {
786 while self.entries.len() >= self.max_entries {
788 self.evict_one();
789 }
790
791 let id = self.index.insert(embedding.clone());
793
794 let entry = SemanticEntry::new(id, query, embedding, result);
796 self.entries.insert(id, entry);
797
798 self.stats.insertions.fetch_add(1, Ordering::Relaxed);
799 id
800 }
801
802 pub fn insert_with_ttl(
804 &self,
805 query: impl Into<String>,
806 embedding: Embedding,
807 result: serde_json::Value,
808 ttl: Duration,
809 ) -> VectorId {
810 while self.entries.len() >= self.max_entries {
811 self.evict_one();
812 }
813
814 let id = self.index.insert(embedding.clone());
815 let entry = SemanticEntry::new(id, query, embedding, result).with_ttl(ttl);
816 self.entries.insert(id, entry);
817
818 self.stats.insertions.fetch_add(1, Ordering::Relaxed);
819 id
820 }
821
822 #[allow(clippy::too_many_arguments)]
828 pub fn insert_with_context(
829 &self,
830 query: impl Into<String>,
831 embedding: Embedding,
832 result: serde_json::Value,
833 branch: Option<BranchContext>,
834 session: Option<SessionId>,
835 workload: AIWorkloadContext,
836 tables: Vec<String>,
837 ) -> VectorId {
838 while self.entries.len() >= self.max_entries {
839 self.evict_one();
840 }
841
842 let id = self.index.insert(embedding.clone());
843 let mut entry = SemanticEntry::new(id, query, embedding, result)
844 .with_workload(workload)
845 .with_tables(tables);
846
847 if let Some(b) = branch {
848 entry = entry.with_branch(b);
849 }
850 if let Some(s) = session {
851 entry = entry.with_session(s);
852 }
853
854 self.entries.insert(id, entry);
855 self.stats.insertions.fetch_add(1, Ordering::Relaxed);
856 id
857 }
858
859 pub fn insert_rag_retrieval(
863 &self,
864 query: impl Into<String>,
865 embedding: Embedding,
866 result: serde_json::Value,
867 tables: Vec<String>,
868 ) -> VectorId {
869 self.insert_with_context(
870 query,
871 embedding,
872 result,
873 None,
874 None,
875 AIWorkloadContext::RAGRetrieval,
876 tables,
877 )
878 }
879
880 pub fn insert_agent_response(
884 &self,
885 query: impl Into<String>,
886 embedding: Embedding,
887 result: serde_json::Value,
888 session: SessionId,
889 branch: Option<BranchContext>,
890 ) -> VectorId {
891 self.insert_with_context(
892 query,
893 embedding,
894 result,
895 branch,
896 Some(session),
897 AIWorkloadContext::AgentConversation,
898 Vec::new(),
899 )
900 }
901
902 pub fn insert_tool_result(
906 &self,
907 query: impl Into<String>,
908 embedding: Embedding,
909 result: serde_json::Value,
910 ) -> VectorId {
911 self.insert_with_context(
912 query,
913 embedding,
914 result,
915 None,
916 None,
917 AIWorkloadContext::ToolResult,
918 Vec::new(),
919 )
920 }
921
922 pub fn remove(&self, id: VectorId) {
924 self.index.remove(id);
925 self.entries.remove(&id);
926 }
927
928 fn evict_one(&self) {
930 let mut oldest_id = None;
931 let mut oldest_time = Instant::now();
932
933 for entry in self.entries.iter() {
934 if entry.created_at < oldest_time {
935 oldest_time = entry.created_at;
936 oldest_id = Some(*entry.key());
937 }
938 }
939
940 if let Some(id) = oldest_id {
941 self.remove(id);
942 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
943 }
944 }
945
946 pub fn cleanup_expired(&self) {
948 let expired: Vec<_> = self
949 .entries
950 .iter()
951 .filter(|e| e.is_expired())
952 .map(|e| *e.key())
953 .collect();
954
955 for id in expired {
956 self.remove(id);
957 }
958 }
959
960 pub fn clear(&self) {
962 self.index.clear();
963 self.entries.clear();
964 }
965
966 pub fn len(&self) -> usize {
968 self.entries.len()
969 }
970
971 pub fn is_empty(&self) -> bool {
973 self.entries.is_empty()
974 }
975
976 pub fn stats(&self) -> SemanticCacheStatsSnapshot {
978 let hits = self.stats.hits.load(Ordering::Relaxed);
979 let misses = self.stats.misses.load(Ordering::Relaxed);
980 let total = hits + misses;
981
982 SemanticCacheStatsSnapshot {
983 entries: self.entries.len(),
984 threshold: self.threshold,
985 hits,
986 misses,
987 hit_rate: if total > 0 {
988 hits as f64 / total as f64
989 } else {
990 0.0
991 },
992 semantic_hits: self.stats.semantic_hits.load(Ordering::Relaxed),
993 exact_hits: self.stats.exact_hits.load(Ordering::Relaxed),
994 insertions: self.stats.insertions.load(Ordering::Relaxed),
995 evictions: self.stats.evictions.load(Ordering::Relaxed),
996 }
997 }
998}
999
/// Point-in-time copy of a `SemanticQueryCache`'s counters, returned by
/// `SemanticQueryCache::stats`.
#[derive(Debug, Clone)]
pub struct SemanticCacheStatsSnapshot {
    /// Number of entries held when the snapshot was taken.
    pub entries: usize,
    /// The cache's hit threshold.
    pub threshold: f32,
    /// Lookups that returned an entry.
    pub hits: u64,
    /// Counting lookups that returned nothing.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` when both are zero.
    pub hit_rate: f64,
    /// Hits classified as semantic (non-exact).
    pub semantic_hits: u64,
    /// Hits classified as exact.
    pub exact_hits: u64,
    /// Entries inserted.
    pub insertions: u64,
    /// Entries evicted to make room for inserts.
    pub evictions: u64,
}
1013
// Unit tests for the vector index, the similarity function and the
// semantic cache's lookup/insert/invalidate paths.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Identical, orthogonal and 45-degree vectors.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        let d = vec![0.707, 0.707, 0.0];
        let sim = cosine_similarity(&a, &d);
        assert!((sim - 0.707).abs() < 0.01);
    }

    // Top-2 search must return the exact match first, then the near match;
    // the orthogonal vector (id3) must be excluded.
    #[test]
    fn test_semantic_index() {
        let index = SemanticIndex::new(SemanticIndexConfig::default());

        let id1 = index.insert(vec![1.0, 0.0, 0.0]);
        let id2 = index.insert(vec![0.9, 0.1, 0.0]);
        let id3 = index.insert(vec![0.0, 1.0, 0.0]);

        let results = index.search(&[1.0, 0.0, 0.0], 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[1].0, id2);
    }

    // Looking up the exact embedding that was inserted is an exact hit.
    #[test]
    fn test_semantic_cache_insert_lookup() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        let id = cache.insert(
            "SELECT * FROM users WHERE name = 'test'",
            embedding.clone(),
            json!({"count": 5}),
        );

        let result = cache.lookup(&embedding);
        assert!(result.is_some());
        let res = result.unwrap();
        assert_eq!(res.id, id);
        assert!(res.similarity > 0.999);
    }

    // A nearby (not identical) embedding still hits above the 0.9 threshold.
    #[test]
    fn test_semantic_similarity_lookup() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        cache.insert(
            "SELECT * FROM users WHERE id = 1",
            vec![1.0, 0.0, 0.0],
            json!({"user": "alice"}),
        );

        let similar_embedding = vec![0.95, 0.05, 0.0];
        let result = cache.lookup(&similar_embedding);

        assert!(result.is_some());
        let res = result.unwrap();
        assert!(res.similarity >= 0.9);
    }

    // Similarity ~0.707 is below the 0.95 threshold, so the lookup misses.
    #[test]
    fn test_threshold_rejection() {
        let cache = SemanticQueryCache::with_capacity(0.95, 100);

        cache.insert(
            "SELECT * FROM orders",
            vec![1.0, 0.0, 0.0],
            json!({"total": 100}),
        );

        let different_embedding = vec![0.7, 0.7, 0.0];
        let result = cache.lookup(&different_embedding);
        assert!(result.is_none());
    }

    // find_similar returns k results in strictly decreasing similarity.
    #[test]
    fn test_find_similar() {
        let cache = SemanticQueryCache::with_capacity(0.5, 100);

        cache.insert("query1", vec![1.0, 0.0, 0.0], json!(1));
        cache.insert("query2", vec![0.9, 0.1, 0.0], json!(2));
        cache.insert("query3", vec![0.8, 0.2, 0.0], json!(3));
        cache.insert("query4", vec![0.0, 1.0, 0.0], json!(4));

        let similar = cache.find_similar(&[1.0, 0.0, 0.0], 3);

        assert_eq!(similar.len(), 3);
        assert!(similar[0].similarity > similar[1].similarity);
        assert!(similar[1].similarity > similar[2].similarity);
    }

    // An entry with a 1 ms TTL is no longer returned after 10 ms.
    #[test]
    fn test_expiration() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_ttl(
            "expiring query",
            embedding.clone(),
            json!({"expires": true}),
            Duration::from_millis(1),
        );

        std::thread::sleep(Duration::from_millis(10));

        let result = cache.lookup(&embedding);
        assert!(result.is_none());
    }

    // Capacity is 3, so the fourth insert must evict one entry first.
    #[test]
    fn test_eviction() {
        let cache = SemanticQueryCache::with_capacity(0.9, 3);

        for i in 0..3 {
            cache.insert(format!("query{}", i), vec![i as f32, 0.0, 0.0], json!(i));
        }

        assert_eq!(cache.len(), 3);

        cache.insert("query3", vec![3.0, 0.0, 0.0], json!(3));

        assert_eq!(cache.len(), 3);
    }

    // Two exact-match lookups, then an orthogonal probe (similarity 0 < 0.9)
    // that misses.
    #[test]
    fn test_stats() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert("test query", embedding.clone(), json!(1));

        cache.lookup(&embedding);
        cache.lookup(&embedding);

        cache.lookup(&[0.0, 1.0, 0.0]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.exact_hits, 2);
        assert_eq!(stats.insertions, 1);
    }

    // Branch names must match; an earlier snapshot serves a later one but
    // not the reverse.
    #[test]
    fn test_branch_context_compatibility() {
        let main = BranchContext::main();
        let feature = BranchContext::new("feature-x");
        let snapshot = BranchContext::with_snapshot("main", 1000);
        let later_snapshot = BranchContext::with_snapshot("main", 2000);

        assert!(main.is_compatible(&main));
        assert!(!main.is_compatible(&feature));

        assert!(snapshot.is_compatible(&later_snapshot));
        assert!(!later_snapshot.is_compatible(&snapshot));
    }

    // Each branch's lookup returns the entry cached for that branch, even
    // though both entries are similar enough to hit.
    #[test]
    fn test_lookup_with_branch() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_context(
            "SELECT * FROM users",
            embedding.clone(),
            json!({"users": []}),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        let embedding2 = vec![0.95, 0.05, 0.0];
        cache.insert_with_context(
            "SELECT * FROM users",
            embedding2.clone(),
            json!({"users": ["new_user"]}),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        let main_result = cache.lookup_with_branch(&embedding, &BranchContext::main());
        assert!(main_result.is_some());
        assert_eq!(
            main_result
                .unwrap()
                .entry
                .branch_context
                .as_ref()
                .unwrap()
                .branch,
            "main"
        );

        let feature_result =
            cache.lookup_with_branch(&embedding2, &BranchContext::new("feature-x"));
        assert!(feature_result.is_some());
        assert_eq!(
            feature_result
                .unwrap()
                .entry
                .branch_context
                .as_ref()
                .unwrap()
                .branch,
            "feature-x"
        );
    }

    // The session-scoped pass should return session1's own entry.
    #[test]
    fn test_lookup_with_session() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);
        let session1 = "session-001".to_string();
        let session2 = "session-002".to_string();

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_agent_response(
            "What is the weather?",
            embedding.clone(),
            json!({"weather": "sunny"}),
            session1.clone(),
            None,
        );

        let embedding2 = vec![0.98, 0.02, 0.0];
        cache.insert_agent_response(
            "How's the weather?",
            embedding2,
            json!({"weather": "cloudy"}),
            session2.clone(),
            None,
        );

        let result = cache.lookup_with_session(&embedding, &session1);
        assert!(result.is_some());
        assert_eq!(
            result.unwrap().entry.session_id.as_ref().unwrap(),
            &session1
        );
    }

    // First lookup matches on every dimension. The second has no session and
    // a different workload, so only the branch-only fallback pass can match.
    #[test]
    fn test_lookup_with_context() {
        let cache = SemanticQueryCache::with_capacity(0.8, 100);
        let session = "agent-session".to_string();
        let branch = BranchContext::main();

        let embedding = vec![1.0, 0.0, 0.0];
        cache.insert_with_context(
            "Find users with orders",
            embedding.clone(),
            json!({"users": 42}),
            Some(branch.clone()),
            Some(session.clone()),
            AIWorkloadContext::RAGRetrieval,
            vec!["users".to_string(), "orders".to_string()],
        );

        let result = cache.lookup_with_context(
            &embedding,
            Some(&branch),
            Some(&session),
            AIWorkloadContext::RAGRetrieval,
        );
        assert!(result.is_some());

        let result2 =
            cache.lookup_with_context(&embedding, Some(&branch), None, AIWorkloadContext::General);
        assert!(result2.is_some());
    }

    // Invalidating "users" removes both entries that list it and keeps the
    // orders-only entry.
    #[test]
    fn test_invalidate_by_table() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        cache.insert_with_context(
            "SELECT * FROM users",
            vec![1.0, 0.0, 0.0],
            json!(1),
            None,
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );
        cache.insert_with_context(
            "SELECT * FROM orders",
            vec![0.0, 1.0, 0.0],
            json!(2),
            None,
            None,
            AIWorkloadContext::General,
            vec!["orders".to_string()],
        );
        cache.insert_with_context(
            "SELECT * FROM users JOIN orders",
            vec![0.5, 0.5, 0.0],
            json!(3),
            None,
            None,
            AIWorkloadContext::General,
            vec!["users".to_string(), "orders".to_string()],
        );

        assert_eq!(cache.len(), 3);

        let removed = cache.invalidate_by_table("users");
        assert_eq!(removed, 2);
        assert_eq!(cache.len(), 1);
    }

    // Invalidating "feature-x" removes its two entries and keeps main's.
    #[test]
    fn test_invalidate_branch() {
        let cache = SemanticQueryCache::with_capacity(0.9, 100);

        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0, 0.0],
            json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );
        cache.insert_with_context(
            "query2",
            vec![0.0, 1.0, 0.0],
            json!(2),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );
        cache.insert_with_context(
            "query3",
            vec![0.0, 0.0, 1.0],
            json!(3),
            Some(BranchContext::new("feature-x")),
            None,
            AIWorkloadContext::General,
            Vec::new(),
        );

        assert_eq!(cache.len(), 3);

        let removed = cache.invalidate_branch(&"feature-x".to_string());
        assert_eq!(removed, 2);
        assert_eq!(cache.len(), 1);
    }

    // Fixed per-workload TTLs override the entry's own `ttl`.
    #[test]
    fn test_workload_ttl() {
        let rag_entry = SemanticEntry::new(1, "rag query", vec![], json!({}))
            .with_workload(AIWorkloadContext::RAGRetrieval);
        assert_eq!(rag_entry.workload_ttl(), Duration::from_secs(300));

        let tool_entry = SemanticEntry::new(2, "tool query", vec![], json!({}))
            .with_workload(AIWorkloadContext::ToolResult);
        assert_eq!(tool_entry.workload_ttl(), Duration::from_secs(86400));

        let agent_entry = SemanticEntry::new(3, "agent query", vec![], json!({}))
            .with_workload(AIWorkloadContext::AgentConversation);
        assert_eq!(agent_entry.workload_ttl(), Duration::from_secs(3600));
    }

    // Branch-filtered similarity search returns only entries from the
    // requested branch (3 on main, 2 on feature-x).
    #[test]
    fn test_find_similar_in_branch() {
        let cache = SemanticQueryCache::with_capacity(0.5, 100);
        let main = BranchContext::main();
        let feature = BranchContext::new("feature-x");

        for i in 0..3 {
            cache.insert_with_context(
                format!("main query {}", i),
                vec![1.0 - (i as f32 * 0.1), i as f32 * 0.1, 0.0],
                json!(i),
                Some(main.clone()),
                None,
                AIWorkloadContext::General,
                Vec::new(),
            );
        }

        for i in 0..2 {
            cache.insert_with_context(
                format!("feature query {}", i),
                vec![0.5, 0.5 + (i as f32 * 0.1), 0.0],
                json!(100 + i),
                Some(feature.clone()),
                None,
                AIWorkloadContext::General,
                Vec::new(),
            );
        }

        let main_results = cache.find_similar_in_branch(&[1.0, 0.0, 0.0], &main, 5);
        assert_eq!(main_results.len(), 3);
        for r in &main_results {
            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "main");
        }

        let feature_results = cache.find_similar_in_branch(&[0.5, 0.5, 0.0], &feature, 5);
        assert_eq!(feature_results.len(), 2);
        for r in &feature_results {
            assert_eq!(r.entry.branch_context.as_ref().unwrap().branch, "feature-x");
        }
    }
}