use crate::hnsw::{DistanceMetric, SearchResult, VectorIndex};
use ipfrs_core::{Cid, Result};
use lru::LruCache;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::{Arc, RwLock};

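/// Configuration for a [`SemanticRouter`]'s HNSW index and query cache.
///
/// # Example
///
/// A minimal sketch of building a config from one of the presets below and
/// tweaking it with the builder methods (`ignore`d because crate setup is
/// assumed):
///
/// ```ignore
/// let config = RouterConfig::high_recall(768).with_ef_search(128);
/// assert_eq!(config.ef_search, 128);
/// ```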
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Dimensionality of the indexed embeddings.
    pub dimension: usize,
    /// Distance metric used for nearest-neighbor search.
    pub metric: DistanceMetric,
    /// Maximum connections per node in the HNSW graph.
    pub max_connections: usize,
    /// Candidate-list size used while building the index.
    pub ef_construction: usize,
    /// Candidate-list size used while searching.
    pub ef_search: usize,
    /// Capacity of the LRU query-result cache.
    pub cache_size: usize,
}

impl Default for RouterConfig {
    fn default() -> Self {
        Self {
            dimension: 768,
            metric: DistanceMetric::Cosine,
            max_connections: 16,
            ef_construction: 200,
            ef_search: 50,
            cache_size: 1000,
        }
    }
}

impl RouterConfig {
    /// Preset tuned for low query latency.
    pub fn low_latency(dimension: usize) -> Self {
        Self {
            dimension,
            metric: DistanceMetric::Cosine,
            max_connections: 12,
            ef_construction: 150,
            ef_search: 32,
            cache_size: 2000,
        }
    }

    /// Preset tuned for high recall at the cost of latency and memory.
    pub fn high_recall(dimension: usize) -> Self {
        Self {
            dimension,
            metric: DistanceMetric::Cosine,
            max_connections: 32,
            ef_construction: 400,
            ef_search: 200,
            cache_size: 1000,
        }
    }

    /// Preset tuned for a small memory footprint.
    pub fn memory_efficient(dimension: usize) -> Self {
        Self {
            dimension,
            metric: DistanceMetric::Cosine,
            max_connections: 8,
            ef_construction: 100,
            ef_search: 50,
            cache_size: 500,
        }
    }

    /// Preset tuned for large indexes.
    pub fn large_scale(dimension: usize) -> Self {
        Self {
            dimension,
            metric: DistanceMetric::Cosine,
            max_connections: 24,
            ef_construction: 300,
            ef_search: 100,
            cache_size: 5000,
        }
    }

    /// Balanced preset; identical to [`RouterConfig::default`] except for the dimension.
    pub fn balanced(dimension: usize) -> Self {
        Self {
            dimension,
            metric: DistanceMetric::Cosine,
            max_connections: 16,
            ef_construction: 200,
            ef_search: 50,
            cache_size: 1000,
        }
    }

    /// Balanced preset with a custom distance metric.
    pub fn with_metric(dimension: usize, metric: DistanceMetric) -> Self {
        Self {
            dimension,
            metric,
            ..Self::balanced(dimension)
        }
    }

    /// Override the query-cache capacity.
    pub fn with_cache_size(mut self, size: usize) -> Self {
        self.cache_size = size;
        self
    }

    /// Override the search-time candidate-list size.
    pub fn with_ef_search(mut self, ef_search: usize) -> Self {
        self.ef_search = ef_search;
        self
    }
}

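/// Post-search filter applied to the [`SearchResult`]s returned by the router.
///
/// # Example
///
/// A sketch of combining a score threshold, a CID prefix, and a result limit
/// via the constructors and combinators defined below (`ignore`d because crate
/// setup is assumed):
///
/// ```ignore
/// let filter = QueryFilter::threshold(0.8)
///     .and(QueryFilter::prefix("bafy".to_string()))
///     .limit(5);
/// assert_eq!(filter.max_results, Some(5));
/// ```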
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QueryFilter {
    /// Keep only results with a score greater than or equal to this value.
    pub min_score: Option<f32>,
    /// Keep only results with a score less than or equal to this value.
    pub max_score: Option<f32>,
    /// Maximum number of results to return.
    pub max_results: Option<usize>,
    /// Keep only results whose CID string starts with this prefix.
    pub cid_prefix: Option<String>,
}

impl Default for QueryFilter {
    fn default() -> Self {
        Self {
            min_score: None,
            max_score: None,
            max_results: Some(10),
            cid_prefix: None,
        }
    }
}

impl QueryFilter {
    /// Keep only results whose score falls in `[min, max]`.
    pub fn range(min: f32, max: f32) -> Self {
        Self {
            min_score: Some(min),
            max_score: Some(max),
            max_results: None,
            cid_prefix: None,
        }
    }

    /// Keep only results with a score of at least `min`.
    pub fn threshold(min: f32) -> Self {
        Self {
            min_score: Some(min),
            max_score: None,
            max_results: None,
            cid_prefix: None,
        }
    }

    /// Keep only results whose CID string starts with `prefix`.
    pub fn prefix(prefix: String) -> Self {
        Self {
            min_score: None,
            max_score: None,
            max_results: None,
            cid_prefix: Some(prefix),
        }
    }

    /// Combine two filters, keeping the stricter score and result-count
    /// constraints; `other`'s CID prefix wins if set.
    pub fn and(mut self, other: QueryFilter) -> Self {
        if let Some(min) = other.min_score {
            self.min_score = Some(self.min_score.unwrap_or(f32::MIN).max(min));
        }
        if let Some(max) = other.max_score {
            self.max_score = Some(self.max_score.unwrap_or(f32::MAX).min(max));
        }
        if let Some(max_results) = other.max_results {
            self.max_results = Some(self.max_results.unwrap_or(usize::MAX).min(max_results));
        }
        if other.cid_prefix.is_some() {
            self.cid_prefix = other.cid_prefix;
        }
        self
    }

    /// Cap the number of results at `max`.
    pub fn limit(mut self, max: usize) -> Self {
        self.max_results = Some(max);
        self
    }
}

type QueryCacheKey = u64;

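/// Semantic content router backed by an HNSW [`VectorIndex`] with an LRU
/// query cache.
///
/// # Example
///
/// A minimal sketch of indexing and querying; it assumes an async context,
/// a valid `cid`, and an `embedding` whose length matches the configured
/// dimension (`ignore`d because crate setup is assumed):
///
/// ```ignore
/// let router = SemanticRouter::new(RouterConfig::balanced(768))?;
/// router.add(&cid, &embedding)?;
/// let hits = router.query(&embedding, 5).await?;
/// assert!(hits.len() <= 5);
/// ```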
pub struct SemanticRouter {
    index: Arc<RwLock<VectorIndex>>,
    config: RouterConfig,
    query_cache: Arc<RwLock<LruCache<QueryCacheKey, Vec<SearchResult>>>>,
}

impl SemanticRouter {
    pub fn new(config: RouterConfig) -> Result<Self> {
        let index = VectorIndex::new(
            config.dimension,
            config.metric,
            config.max_connections,
            config.ef_construction,
        )?;

        let cache_size =
            NonZeroUsize::new(config.cache_size).unwrap_or(NonZeroUsize::new(1000).unwrap());
        let query_cache = LruCache::new(cache_size);

        Ok(Self {
            index: Arc::new(RwLock::new(index)),
            config,
            query_cache: Arc::new(RwLock::new(query_cache)),
        })
    }

    pub fn with_defaults() -> Result<Self> {
        Self::new(RouterConfig::default())
    }

    pub fn add(&self, cid: &Cid, embedding: &[f32]) -> Result<()> {
        self.index.write().unwrap().insert(cid, embedding)
    }

    pub fn add_batch(&self, items: &[(Cid, Vec<f32>)]) -> Result<()> {
        self.index.write().unwrap().insert_batch(items)
    }

    pub fn remove(&self, cid: &Cid) -> Result<()> {
        self.index.write().unwrap().delete(cid)
    }

    pub fn contains(&self, cid: &Cid) -> bool {
        self.index.read().unwrap().contains(cid)
    }

    pub async fn query(&self, query_embedding: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        self.query_with_filter(query_embedding, k, QueryFilter::default())
            .await
    }

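    /// Query with an `ef_search` value picked by the index's
    /// `compute_optimal_ef_search` heuristic for the requested `k`, rather
    /// than the configured default.
    ///
    /// A sketch of intended usage (`ignore`d; crate setup and an async
    /// context are assumed):
    ///
    /// ```ignore
    /// // Let the index pick the search beam width for this query.
    /// let hits = router.query_auto(&embedding, 10).await?;
    /// ```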
    pub async fn query_auto(&self, query_embedding: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        let optimal_ef_search = self.index.read().unwrap().compute_optimal_ef_search(k);
        self.query_with_ef(query_embedding, k, optimal_ef_search)
            .await
    }

    pub async fn query_with_ef(
        &self,
        query_embedding: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Result<Vec<SearchResult>> {
        let cache_key = Self::compute_cache_key(query_embedding, k, &QueryFilter::default());

        if let Some(cached) = self.query_cache.write().unwrap().get(&cache_key) {
            return Ok(cached.clone());
        }

        let results = self
            .index
            .read()
            .unwrap()
            .search(query_embedding, k, ef_search)?;

        self.query_cache
            .write()
            .unwrap()
            .put(cache_key, results.clone());

        Ok(results)
    }

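    /// Query and then apply a [`QueryFilter`] to the raw search results.
    ///
    /// When `min_score` or `cid_prefix` is set, the router over-fetches
    /// (`2 * k` candidates) before filtering and skips the query cache.
    ///
    /// A sketch (`ignore`d; crate setup and an async context are assumed):
    ///
    /// ```ignore
    /// let filter = QueryFilter::threshold(0.7).limit(3);
    /// let hits = router.query_with_filter(&embedding, 10, filter).await?;
    /// assert!(hits.len() <= 3);
    /// ```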
    pub async fn query_with_filter(
        &self,
        query_embedding: &[f32],
        k: usize,
        filter: QueryFilter,
    ) -> Result<Vec<SearchResult>> {
        let cache_key = Self::compute_cache_key(query_embedding, k, &filter);

        if filter.min_score.is_none() && filter.cid_prefix.is_none() {
            if let Some(cached) = self.query_cache.write().unwrap().get(&cache_key) {
                return Ok(cached.clone());
            }
        }

        let fetch_k = if filter.min_score.is_some() || filter.cid_prefix.is_some() {
            k * 2
        } else {
            k
        };

        let mut results =
            self.index
                .read()
                .unwrap()
                .search(query_embedding, fetch_k, self.config.ef_search)?;

        if let Some(min_score) = filter.min_score {
            results.retain(|r| r.score >= min_score);
        }

        if let Some(max_score) = filter.max_score {
            results.retain(|r| r.score <= max_score);
        }

        if let Some(ref prefix) = filter.cid_prefix {
            results.retain(|r| r.cid.to_string().starts_with(prefix));
        }

        if let Some(max_results) = filter.max_results {
            results.truncate(max_results);
        }

        if filter.min_score.is_none() && filter.cid_prefix.is_none() {
            self.query_cache
                .write()
                .unwrap()
                .put(cache_key, results.clone());
        }

        Ok(results)
    }

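    /// Build an LRU cache key by hashing every 8th embedding component
    /// (quantized to three decimal places), the requested `k`, and the
    /// filter's `max_results`. Note that `ef_search` and score/prefix
    /// constraints are not part of the key, so cached results are reused
    /// across those settings.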
    fn compute_cache_key(embedding: &[f32], k: usize, filter: &QueryFilter) -> QueryCacheKey {
        let mut hasher = DefaultHasher::new();

        for (i, &val) in embedding.iter().enumerate().step_by(8) {
            (i, (val * 1000.0) as i32).hash(&mut hasher);
        }

        k.hash(&mut hasher);
        filter.max_results.hash(&mut hasher);

        hasher.finish()
    }

    pub fn clear_cache(&self) {
        self.query_cache.write().unwrap().clear();
    }

    pub fn cache_stats(&self) -> CacheStats {
        let cache = self.query_cache.read().unwrap();
        CacheStats {
            size: cache.len(),
            capacity: cache.cap().get(),
        }
    }

    pub fn stats(&self) -> RouterStats {
        let index = self.index.read().unwrap();
        RouterStats {
            num_vectors: index.len(),
            dimension: index.dimension(),
            metric: index.metric(),
        }
    }

    pub fn optimization_recommendations(&self) -> OptimizationRecommendations {
        let index = self.index.read().unwrap();
        let (m, ef_construction) = index.compute_optimal_parameters();

        OptimizationRecommendations {
            recommended_m: m,
            recommended_ef_construction: ef_construction,
            current_size: index.len(),
        }
    }

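    /// Persist the underlying index to disk.
    ///
    /// A sketch of a save/load round trip (`ignore`d; crate setup and an async
    /// context are assumed, and the path is illustrative):
    ///
    /// ```ignore
    /// router.save_index("/tmp/semantic.idx").await?;
    /// router.load_index("/tmp/semantic.idx").await?; // also clears the query cache
    /// ```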
    pub async fn save_index<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
        self.index.read().unwrap().save(path.as_ref())
    }

    pub async fn load_index<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
        let loaded_index = VectorIndex::load(path.as_ref())?;
        *self.index.write().unwrap() = loaded_index;
        self.clear_cache();
        Ok(())
    }

    pub fn clear(&self) -> Result<()> {
        let new_index = VectorIndex::new(
            self.config.dimension,
            self.config.metric,
            self.config.max_connections,
            self.config.ef_construction,
        )?;

        *self.index.write().unwrap() = new_index;

        self.query_cache.write().unwrap().clear();

        Ok(())
    }

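    /// Query with a filter and also compute summary [`SearchAggregations`]
    /// (count, score statistics, and a 10-bucket score histogram) over the
    /// returned results.
    ///
    /// A sketch (`ignore`d; crate setup and an async context are assumed):
    ///
    /// ```ignore
    /// let (hits, aggs) = router
    ///     .query_with_aggregations(&embedding, 10, QueryFilter::default())
    ///     .await?;
    /// assert_eq!(aggs.total_count, hits.len());
    /// ```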
    pub async fn query_with_aggregations(
        &self,
        query_embedding: &[f32],
        k: usize,
        filter: QueryFilter,
    ) -> Result<(Vec<SearchResult>, SearchAggregations)> {
        let results = self.query_with_filter(query_embedding, k, filter).await?;
        let aggregations = SearchAggregations::from_results(&results);
        Ok((results, aggregations))
    }

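    /// Run several queries in parallel (via rayon) with the default filter.
    ///
    /// A sketch, assuming `embedding_a` and `embedding_b` are precomputed
    /// query vectors (`ignore`d; crate setup and an async context are assumed):
    ///
    /// ```ignore
    /// let queries: Vec<Vec<f32>> = vec![embedding_a.clone(), embedding_b.clone()];
    /// let per_query_hits = router.query_batch(&queries, 5).await?;
    /// assert_eq!(per_query_hits.len(), queries.len());
    /// ```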
    pub async fn query_batch(
        &self,
        query_embeddings: &[Vec<f32>],
        k: usize,
    ) -> Result<Vec<Vec<SearchResult>>> {
        self.query_batch_with_filter(query_embeddings, k, QueryFilter::default())
            .await
    }

    pub async fn query_batch_with_filter(
        &self,
        query_embeddings: &[Vec<f32>],
        k: usize,
        filter: QueryFilter,
    ) -> Result<Vec<Vec<SearchResult>>> {
        use rayon::prelude::*;

        let results: Result<Vec<Vec<SearchResult>>> = query_embeddings
            .par_iter()
            .map(|embedding| {
                let cache_key = Self::compute_cache_key(embedding, k, &filter);

                if filter.min_score.is_none() && filter.cid_prefix.is_none() {
                    if let Some(cached) = self.query_cache.write().unwrap().get(&cache_key) {
                        return Ok(cached.clone());
                    }
                }

                let fetch_k = if filter.min_score.is_some() || filter.cid_prefix.is_some() {
                    k * 2
                } else {
                    k
                };

                let mut results =
                    self.index
                        .read()
                        .unwrap()
                        .search(embedding, fetch_k, self.config.ef_search)?;

                if let Some(min_score) = filter.min_score {
                    results.retain(|r| r.score >= min_score);
                }

                if let Some(max_score) = filter.max_score {
                    results.retain(|r| r.score <= max_score);
                }

                if let Some(ref prefix) = filter.cid_prefix {
                    results.retain(|r| r.cid.to_string().starts_with(prefix));
                }

                if let Some(max_results) = filter.max_results {
                    results.truncate(max_results);
                }

                if filter.min_score.is_none() && filter.cid_prefix.is_none() {
                    self.query_cache
                        .write()
                        .unwrap()
                        .put(cache_key, results.clone());
                }

                Ok(results)
            })
            .collect();

        results
    }

    pub async fn query_batch_with_ef(
        &self,
        query_embeddings: &[Vec<f32>],
        k: usize,
        ef_search: usize,
    ) -> Result<Vec<Vec<SearchResult>>> {
        use rayon::prelude::*;

        let results: Result<Vec<Vec<SearchResult>>> = query_embeddings
            .par_iter()
            .map(|embedding| {
                let cache_key = Self::compute_cache_key(embedding, k, &QueryFilter::default());

                if let Some(cached) = self.query_cache.write().unwrap().get(&cache_key) {
                    return Ok(cached.clone());
                }

                let results = self.index.read().unwrap().search(embedding, k, ef_search)?;

                self.query_cache
                    .write()
                    .unwrap()
                    .put(cache_key, results.clone());

                Ok(results)
            })
            .collect();

        results
    }

    pub fn batch_stats(&self, batch_results: &[Vec<SearchResult>]) -> BatchStats {
        let total_queries = batch_results.len();
        let total_results: usize = batch_results.iter().map(|r| r.len()).sum();
        let avg_results_per_query = if total_queries > 0 {
            total_results as f32 / total_queries as f32
        } else {
            0.0
        };

        let all_scores: Vec<f32> = batch_results
            .iter()
            .flat_map(|results| results.iter().map(|r| r.score))
            .collect();

        let avg_score = if !all_scores.is_empty() {
            all_scores.iter().sum::<f32>() / all_scores.len() as f32
        } else {
            0.0
        };

        BatchStats {
            total_queries,
            total_results,
            avg_results_per_query,
            avg_score,
        }
    }
}

#[derive(Debug, Clone)]
pub struct RouterStats {
    pub num_vectors: usize,
    pub dimension: usize,
    pub metric: DistanceMetric,
}

#[derive(Debug, Clone)]
pub struct CacheStats {
    pub size: usize,
    pub capacity: usize,
}

#[derive(Debug, Clone)]
pub struct BatchStats {
    pub total_queries: usize,
    pub total_results: usize,
    pub avg_results_per_query: f32,
    pub avg_score: f32,
}

#[derive(Debug, Clone)]
pub struct OptimizationRecommendations {
    pub recommended_m: usize,
    pub recommended_ef_construction: usize,
    pub current_size: usize,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchAggregations {
    pub total_count: usize,
    pub avg_score: f32,
    pub min_score: f32,
    pub max_score: f32,
    pub score_buckets: Vec<ScoreBucket>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ScoreBucket {
    pub range: (f32, f32),
    pub count: usize,
}

impl SearchAggregations {
    pub fn from_results(results: &[SearchResult]) -> Self {
        if results.is_empty() {
            return Self {
                total_count: 0,
                avg_score: 0.0,
                min_score: 0.0,
                max_score: 0.0,
                score_buckets: Vec::new(),
            };
        }

        let total_count = results.len();
        let sum: f32 = results.iter().map(|r| r.score).sum();
        let avg_score = sum / total_count as f32;
        let min_score = results
            .iter()
            .map(|r| r.score)
            .min_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap();
        let max_score = results
            .iter()
            .map(|r| r.score)
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap();

        let bucket_count = 10;
        let range = max_score - min_score;
        let bucket_size = if range > 0.0 {
            range / bucket_count as f32
        } else {
            1.0
        };

        let mut buckets = vec![0; bucket_count];
        for result in results {
            let bucket_idx = if range > 0.0 {
                ((result.score - min_score) / bucket_size).floor() as usize
            } else {
                0
            };
            let bucket_idx = bucket_idx.min(bucket_count - 1);
            buckets[bucket_idx] += 1;
        }

        let score_buckets = buckets
            .into_iter()
            .enumerate()
            .map(|(i, count)| ScoreBucket {
                range: (
                    min_score + i as f32 * bucket_size,
                    min_score + (i + 1) as f32 * bucket_size,
                ),
                count,
            })
            .collect();

        Self {
            total_count,
            avg_score,
            min_score,
            max_score,
            score_buckets,
        }
    }
}

impl Default for SemanticRouter {
    fn default() -> Self {
        Self::with_defaults().expect("Failed to create default SemanticRouter")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_router_creation() {
        let router = SemanticRouter::with_defaults();
        assert!(router.is_ok());
    }

    #[tokio::test]
    async fn test_add_and_query() {
        let router = SemanticRouter::with_defaults().unwrap();

        let cid1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap();
        let embedding1 = vec![0.5; 768];

        router.add(&cid1, &embedding1).unwrap();

        let results = router.query(&embedding1, 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].cid, cid1);
    }

    #[tokio::test]
    async fn test_filtering() {
        let router = SemanticRouter::with_defaults().unwrap();

        let cid1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap();
        let embedding1 = vec![0.5; 768];

        router.add(&cid1, &embedding1).unwrap();

        let filter = QueryFilter {
            min_score: Some(0.9),
            max_score: None,
            max_results: Some(10),
            cid_prefix: None,
        };

        let results = router
            .query_with_filter(&embedding1, 10, filter)
            .await
            .unwrap();

        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_integration_with_blocks() {
        use bytes::Bytes;
        use ipfrs_core::Block;

        let router = SemanticRouter::new(RouterConfig {
            dimension: 3,
            ..Default::default()
        })
        .unwrap();

        let data1 = Bytes::from_static(b"Hello, semantic search!");
        let data2 = Bytes::from_static(b"Goodbye, semantic search!");
        let data3 = Bytes::from_static(b"Hello, world!");

        let block1 = Block::new(data1).unwrap();
        let block2 = Block::new(data2).unwrap();
        let block3 = Block::new(data3).unwrap();

        let embedding1 = vec![1.0, 0.0, 0.0];
        let embedding2 = vec![0.0, 1.0, 0.0];
        let embedding3 = vec![0.9, 0.1, 0.0];

        router.add(block1.cid(), &embedding1).unwrap();
        router.add(block2.cid(), &embedding2).unwrap();
        router.add(block3.cid(), &embedding3).unwrap();

        let query_embedding = vec![1.0, 0.0, 0.0];
        let results = router.query(&query_embedding, 2).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].cid, *block1.cid());
    }

    #[tokio::test]
    async fn test_integration_with_tensor_metadata() {
        use ipfrs_core::{TensorDtype, TensorMetadata, TensorShape};

        let router = SemanticRouter::new(RouterConfig {
            dimension: 2,
            ..Default::default()
        })
        .unwrap();

        let shape1 = TensorShape::new(vec![1, 768]);
        let mut metadata1 = TensorMetadata::new(shape1, TensorDtype::F32);
        metadata1.name = Some("vision_embedding".to_string());
        metadata1
            .metadata
            .insert("semantic_tag".to_string(), "vision".to_string());

        let shape2 = TensorShape::new(vec![1, 768]);
        let mut metadata2 = TensorMetadata::new(shape2, TensorDtype::F32);
        metadata2.name = Some("text_embedding".to_string());
        metadata2
            .metadata
            .insert("semantic_tag".to_string(), "text".to_string());

        let vision_embedding = vec![1.0, 0.0];
        let text_embedding = vec![0.0, 1.0];

        let cid1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap();
        let cid2 = "bafybeibazl2z6vqxqqzmhmvx2hfpxqtwggqgbbyy3sxkq4vzq6cqsvwbjy"
            .parse::<Cid>()
            .unwrap();

        router.add(&cid1, &vision_embedding).unwrap();
        router.add(&cid2, &text_embedding).unwrap();

        let results = router.query(&vision_embedding, 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].cid, cid1);
    }

    #[tokio::test]
    async fn test_large_scale_indexing() {
        use rand::Rng;

        let dimension = 128;

        let router = SemanticRouter::new(RouterConfig {
            dimension,
            ..Default::default()
        })
        .unwrap();

        let mut rng = rand::rng();
        let num_items = 1000;

        let mut indexed_cids = Vec::new();

        for i in 0..num_items {
            use multihash_codetable::{Code, MultihashDigest};
            let data = format!("large_scale_test_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);

            let embedding: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            router.add(&cid, &embedding).unwrap();
            indexed_cids.push((cid, embedding));
        }

        let stats = router.stats();
        assert_eq!(stats.num_vectors, num_items);

        let (test_cid, test_embedding) = &indexed_cids[42];
        let results = router.query(test_embedding, 1).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].cid, *test_cid);
    }

    #[tokio::test]
    async fn test_cache_effectiveness() {
        let router = SemanticRouter::with_defaults().unwrap();

        let cid1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap();
        let embedding1 = vec![0.5; 768];

        router.add(&cid1, &embedding1).unwrap();

        for _ in 0..10 {
            let _ = router.query(&embedding1, 1).await.unwrap();
        }

        let cache_stats = router.cache_stats();
        assert_eq!(cache_stats.size, 1, "Cache should have 1 unique query");
        assert!(cache_stats.capacity > 0, "Cache should have capacity");
    }

    #[tokio::test]
    async fn test_batch_query() {
        use rand::Rng;

        let dimension = 128;

        let router = SemanticRouter::new(RouterConfig {
            dimension,
            ..Default::default()
        })
        .unwrap();

        let mut rng = rand::rng();
        let num_items = 100;

        for i in 0..num_items {
            use multihash_codetable::{Code, MultihashDigest};
            let data = format!("batch_test_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);

            let embedding: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            router.add(&cid, &embedding).unwrap();
        }

        let batch_size = 10;
        let query_batch: Vec<Vec<f32>> = (0..batch_size)
            .map(|_| {
                (0..dimension)
                    .map(|_| rng.random_range(-1.0..1.0))
                    .collect()
            })
            .collect();

        let results = router.query_batch(&query_batch, 5).await.unwrap();

        assert_eq!(results.len(), batch_size);
        for result in &results {
            assert!(!result.is_empty());
            assert!(result.len() <= 5);
        }

        let stats = router.batch_stats(&results);
        assert_eq!(stats.total_queries, batch_size);
        assert!(stats.total_results > 0);
        assert!(stats.avg_results_per_query > 0.0);
    }

    #[tokio::test]
    async fn test_batch_query_with_filter() {
        use rand::Rng;

        let dimension = 64;

        let router = SemanticRouter::new(RouterConfig {
            dimension,
            ..Default::default()
        })
        .unwrap();

        let mut rng = rand::rng();
        let num_items = 50;

        for i in 0..num_items {
            use multihash_codetable::{Code, MultihashDigest};
            let data = format!("filter_batch_test_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);

            let embedding: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            router.add(&cid, &embedding).unwrap();
        }

        let batch_size = 5;
        let query_batch: Vec<Vec<f32>> = (0..batch_size)
            .map(|_| {
                (0..dimension)
                    .map(|_| rng.random_range(-1.0..1.0))
                    .collect()
            })
            .collect();

        let filter = QueryFilter {
            min_score: Some(0.0),
            max_results: Some(3),
            ..Default::default()
        };

        let results = router
            .query_batch_with_filter(&query_batch, 5, filter)
            .await
            .unwrap();

        assert_eq!(results.len(), batch_size);
        for result in &results {
            assert!(result.len() <= 3);
        }
    }

    #[tokio::test]
    async fn test_batch_query_with_ef() {
        use rand::Rng;

        let dimension = 64;

        let router = SemanticRouter::new(RouterConfig {
            dimension,
            ..Default::default()
        })
        .unwrap();

        let mut rng = rand::rng();
        let num_items = 50;

        for i in 0..num_items {
            use multihash_codetable::{Code, MultihashDigest};
            let data = format!("ef_batch_test_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);

            let embedding: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            router.add(&cid, &embedding).unwrap();
        }

        let batch_size = 5;
        let query_batch: Vec<Vec<f32>> = (0..batch_size)
            .map(|_| {
                (0..dimension)
                    .map(|_| rng.random_range(-1.0..1.0))
                    .collect()
            })
            .collect();

        let results = router
            .query_batch_with_ef(&query_batch, 3, 100)
            .await
            .unwrap();

        assert_eq!(results.len(), batch_size);
        for result in &results {
            assert!(!result.is_empty());
            assert!(result.len() <= 3);
        }
    }
}