1use crate::filter::{Filter, Metadata};
38use crate::search::VectorSearchIndex;
39use crate::types::{SearchConfig, SearchResult};
40use anyhow::{anyhow, Result};
41use serde::{Deserialize, Serialize};
42use std::collections::{BTreeMap, HashMap};
43use std::hash::{Hash, Hasher};
44use std::sync::{Arc, RwLock};
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ShardConfig {
49 pub num_shards: usize,
51 pub num_replicas: usize,
53 pub virtual_nodes: usize,
55}
56
57impl ShardConfig {
58 pub fn new(num_shards: usize, num_replicas: usize) -> Self {
70 assert!(num_shards >= 1, "num_shards must be at least 1");
71 assert!(num_replicas >= 1, "num_replicas must be at least 1");
72 Self {
73 num_shards,
74 num_replicas,
75 virtual_nodes: 150, }
77 }
78
79 pub fn with_virtual_nodes(mut self, virtual_nodes: usize) -> Self {
81 self.virtual_nodes = virtual_nodes;
82 self
83 }
84}
85
86impl Default for ShardConfig {
87 fn default() -> Self {
88 Self::new(1, 1) }
90}
91
/// Consistent-hash ring mapping string keys to shard ids.
///
/// Each shard occupies `virtual_nodes` positions on the ring. A key belongs to
/// the first ring position at or after the key's hash, wrapping around to the
/// lowest position when the hash is past the last one.
#[derive(Debug)]
pub struct ConsistentHash {
    /// Ring position (hash of a virtual-node label) -> shard id.
    ring: BTreeMap<u64, usize>,
    /// Virtual nodes per shard used to build `ring`.
    #[allow(dead_code)] // kept for the derived `Debug` output; not read after construction
    virtual_nodes: usize,
}

impl ConsistentHash {
    /// Builds a ring with `virtual_nodes` positions for each of `num_shards` shards.
    ///
    /// If either argument is 0 the ring is empty, and every key maps to shard 0.
    /// If two virtual nodes hash to the same position, the later one wins.
    pub fn new(num_shards: usize, virtual_nodes: usize) -> Self {
        let mut ring = BTreeMap::new();
        for shard_id in 0..num_shards {
            for vnode in 0..virtual_nodes {
                let label = format!("shard-{}-vnode-{}", shard_id, vnode);
                ring.insert(Self::hash_key(&label), shard_id);
            }
        }

        Self {
            ring,
            virtual_nodes,
        }
    }

    /// Returns the primary shard for `key` (0 when the ring is empty).
    pub fn get_shard(&self, key: &str) -> usize {
        self.walk_from(Self::hash_key(key)).next().unwrap_or(0)
    }

    /// Returns up to `num_replicas` distinct shards for `key`, in ring order.
    ///
    /// The first element, when present, equals `get_shard(key)`. Fewer than
    /// `num_replicas` shards are returned when the ring holds fewer distinct
    /// shards. Asking for 0 replicas yields an empty vec. An empty ring yields
    /// `[0]` for any positive request.
    pub fn get_replicas(&self, key: &str, num_replicas: usize) -> Vec<usize> {
        if num_replicas == 0 {
            return Vec::new();
        }
        if self.ring.is_empty() {
            return vec![0];
        }

        let mut seen = std::collections::HashSet::new();
        let mut replicas = Vec::new();
        for shard_id in self.walk_from(Self::hash_key(key)) {
            // `insert` returns false for shards already collected via another vnode.
            if seen.insert(shard_id) {
                replicas.push(shard_id);
                if replicas.len() == num_replicas {
                    break;
                }
            }
        }
        replicas
    }

    /// Visits every ring position exactly once, clockwise, starting at the
    /// first position `>= start` and wrapping around to the beginning.
    fn walk_from(&self, start: u64) -> impl Iterator<Item = usize> + '_ {
        self.ring
            .range(start..)
            .chain(self.ring.range(..start))
            .map(|(_, &shard_id)| shard_id)
    }

    /// Hashes `key` to a ring position.
    ///
    /// NOTE(review): `DefaultHasher` is deterministic within one build, but its
    /// algorithm is unspecified and may change between Rust releases. Shard
    /// assignments are therefore not guaranteed to be stable across toolchain
    /// upgrades or between processes built with different compilers.
    fn hash_key(key: &str) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }
}
185
186#[derive(Debug)]
188struct Shard {
189 #[allow(dead_code)]
191 id: usize,
192 index: VectorSearchIndex,
194 size: usize,
196}
197
198impl Shard {
199 fn new(id: usize, config: SearchConfig) -> Self {
200 Self {
201 id,
202 index: VectorSearchIndex::new(config),
203 size: 0,
204 }
205 }
206
207 fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
208 if !embeddings.is_empty() {
210 self.index.build(embeddings)?;
211 self.size = embeddings.len();
212 }
213 Ok(())
214 }
215
216 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
217 if self.size == 0 {
219 return Ok(Vec::new());
220 }
221 self.index.search(query, k)
222 }
223
224 fn filtered_search(
225 &self,
226 query: &[f32],
227 k: usize,
228 filter: &Filter,
229 ) -> Result<Vec<SearchResult>> {
230 if self.size == 0 {
231 return Ok(Vec::new());
232 }
233 self.index.filtered_search(query, k, filter)
234 }
235
236 fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
237 self.index.set_metadata(entity_id, metadata);
238 }
239
240 fn get_metadata(&self, entity_id: &str) -> Option<&Metadata> {
241 self.index.get_metadata(entity_id)
242 }
243}
244
/// Vector search index partitioned across in-process shards by consistent hashing.
///
/// Every query is sent to all shards, and the per-shard results are merged
/// into a single top-k list.
pub struct DistributedIndex {
    /// Shard count, replication factor and virtual-node count this index was built with.
    shard_config: ShardConfig,
    /// Search configuration each shard's index was created with.
    #[allow(dead_code)] // retained after construction but not currently read
    search_config: SearchConfig,
    /// One lock-protected shard per shard id; element `i` is shard `i`.
    shards: Vec<Arc<RwLock<Shard>>>,
    /// Routes entity ids to shard ids.
    hash_ring: ConsistentHash,
    /// Number of embeddings passed to the most recent `build` call.
    total_size: Arc<RwLock<usize>>,
}
262
263impl DistributedIndex {
264 pub fn new(shard_config: ShardConfig, search_config: SearchConfig) -> Self {
266 let hash_ring = ConsistentHash::new(shard_config.num_shards, shard_config.virtual_nodes);
267
268 let mut shards = Vec::new();
269 for i in 0..shard_config.num_shards {
270 let shard = Shard::new(i, search_config.clone());
271 shards.push(Arc::new(RwLock::new(shard)));
272 }
273
274 Self {
275 shard_config,
276 search_config,
277 shards,
278 hash_ring,
279 total_size: Arc::new(RwLock::new(0)),
280 }
281 }
282
283 pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
287 let mut shard_embeddings: Vec<HashMap<String, Vec<f32>>> =
289 vec![HashMap::new(); self.shard_config.num_shards];
290
291 for (entity_id, embedding) in embeddings {
292 let shard_id = self.hash_ring.get_shard(entity_id);
293 shard_embeddings[shard_id].insert(entity_id.clone(), embedding.clone());
294 }
295
296 #[cfg(feature = "parallel")]
298 {
299 use rayon::prelude::*;
300 self.shards
301 .par_iter()
302 .zip(shard_embeddings.par_iter())
303 .try_for_each(|(shard, embs)| -> Result<()> {
304 let mut shard = shard.write().map_err(|e| anyhow!("Lock error: {}", e))?;
305 shard.build(embs)?;
306 Ok(())
307 })?;
308 }
309
310 #[cfg(not(feature = "parallel"))]
311 {
312 for (shard, embs) in self.shards.iter().zip(shard_embeddings.iter()) {
313 let mut shard = shard.write().map_err(|e| anyhow!("Lock error: {}", e))?;
314 shard.build(embs)?;
315 }
316 }
317
318 let mut total = self
320 .total_size
321 .write()
322 .map_err(|e| anyhow!("Lock error: {}", e))?;
323 *total = embeddings.len();
324
325 Ok(())
326 }
327
328 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
332 #[cfg(feature = "parallel")]
334 let shard_results = {
335 use rayon::prelude::*;
336 self.shards
337 .par_iter()
338 .map(|shard| -> Result<Vec<SearchResult>> {
339 let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
340 shard.search(query, k)
341 })
342 .collect::<Result<Vec<Vec<SearchResult>>>>()?
343 };
344
345 #[cfg(not(feature = "parallel"))]
346 let shard_results = {
347 let mut results = Vec::new();
348 for shard in &self.shards {
349 let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
350 let result = shard.search(query, k)?;
351 results.push(result);
352 }
353 results
354 };
355
356 let merged = Self::merge_results(shard_results, k);
358
359 Ok(merged)
360 }
361
362 pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
366 #[cfg(feature = "parallel")]
367 {
368 use rayon::prelude::*;
369 queries
370 .par_iter()
371 .map(|query| self.search(query, k))
372 .collect()
373 }
374
375 #[cfg(not(feature = "parallel"))]
376 {
377 queries.iter().map(|query| self.search(query, k)).collect()
378 }
379 }
380
381 pub fn filtered_search(
385 &self,
386 query: &[f32],
387 k: usize,
388 filter: &Filter,
389 ) -> Result<Vec<SearchResult>> {
390 #[cfg(feature = "parallel")]
392 let shard_results = {
393 use rayon::prelude::*;
394 self.shards
395 .par_iter()
396 .map(|shard| -> Result<Vec<SearchResult>> {
397 let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
398 shard.filtered_search(query, k, filter)
399 })
400 .collect::<Result<Vec<Vec<SearchResult>>>>()?
401 };
402
403 #[cfg(not(feature = "parallel"))]
404 let shard_results = {
405 let mut results = Vec::new();
406 for shard in &self.shards {
407 let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
408 let result = shard.filtered_search(query, k, filter)?;
409 results.push(result);
410 }
411 results
412 };
413
414 let merged = Self::merge_results(shard_results, k);
416
417 Ok(merged)
418 }
419
420 pub fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
422 let replica_shards = self
423 .hash_ring
424 .get_replicas(entity_id, self.shard_config.num_replicas);
425
426 for shard_id in replica_shards {
427 if let Ok(mut shard) = self.shards[shard_id].write() {
428 shard.set_metadata(entity_id, metadata.clone());
429 }
430 }
431 }
432
433 pub fn get_metadata(&self, entity_id: &str) -> Option<Metadata> {
435 let shard_id = self.hash_ring.get_shard(entity_id);
436 if let Ok(shard) = self.shards[shard_id].read() {
437 shard.get_metadata(entity_id).cloned()
438 } else {
439 None
440 }
441 }
442
443 pub fn batch_set_metadata(&mut self, metadata_map: &HashMap<String, Metadata>) {
445 for (entity_id, metadata) in metadata_map {
446 self.set_metadata(entity_id, metadata.clone());
447 }
448 }
449
450 pub fn get_stats(&self) -> Result<DistributedStats> {
452 let mut shard_sizes = Vec::new();
453 let mut total_vectors = 0;
454
455 for shard in &self.shards {
456 let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
457 shard_sizes.push(shard.size);
458 total_vectors += shard.size;
459 }
460
461 let avg_shard_size = if !shard_sizes.is_empty() {
462 shard_sizes.iter().sum::<usize>() as f64 / shard_sizes.len() as f64
463 } else {
464 0.0
465 };
466
467 let max_shard_size = shard_sizes.iter().copied().max().unwrap_or(0);
468 let min_shard_size = shard_sizes.iter().copied().min().unwrap_or(0);
469
470 Ok(DistributedStats {
471 num_shards: self.shard_config.num_shards,
472 num_replicas: self.shard_config.num_replicas,
473 total_vectors,
474 shard_sizes,
475 avg_shard_size,
476 max_shard_size,
477 min_shard_size,
478 balance_ratio: if max_shard_size > 0 {
479 min_shard_size as f64 / max_shard_size as f64
480 } else {
481 1.0
482 },
483 })
484 }
485
486 fn merge_results(shard_results: Vec<Vec<SearchResult>>, k: usize) -> Vec<SearchResult> {
490 let mut all_results = Vec::new();
491
492 for results in shard_results {
494 all_results.extend(results);
495 }
496
497 all_results.sort_by(|a, b| {
499 b.score
500 .partial_cmp(&a.score)
501 .unwrap_or(std::cmp::Ordering::Equal)
502 });
503
504 let mut seen = std::collections::HashSet::new();
506 let mut merged = Vec::new();
507
508 for result in all_results {
509 if !seen.contains(&result.entity_id) {
510 seen.insert(result.entity_id.clone());
511 merged.push(result);
512 if merged.len() >= k {
513 break;
514 }
515 }
516 }
517
518 merged
519 }
520}
521
/// Snapshot of shard sizes and balance, as produced by `DistributedIndex::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedStats {
    /// Configured number of shards.
    pub num_shards: usize,
    /// Configured replication factor.
    pub num_replicas: usize,
    /// Sum of all entries in `shard_sizes`.
    pub total_vectors: usize,
    /// Vector count of each shard, in shard-id order.
    pub shard_sizes: Vec<usize>,
    /// Mean of `shard_sizes` (0.0 when there are no shards).
    pub avg_shard_size: f64,
    /// Largest entry in `shard_sizes` (0 when there are no shards).
    pub max_shard_size: usize,
    /// Smallest entry in `shard_sizes` (0 when there are no shards).
    pub min_shard_size: usize,
    /// `min_shard_size / max_shard_size`, or 1.0 when `max_shard_size` is 0.
    /// 1.0 means perfectly balanced.
    pub balance_ratio: f64,
}
542
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shard_config() {
        let config = ShardConfig::new(3, 2);
        assert_eq!(config.num_shards, 3);
        assert_eq!(config.num_replicas, 2);
        assert_eq!(config.virtual_nodes, 150);

        let config = config.with_virtual_nodes(200);
        assert_eq!(config.virtual_nodes, 200);
    }

    #[test]
    #[should_panic(expected = "num_shards must be at least 1")]
    fn test_shard_config_rejects_zero_shards() {
        let _ = ShardConfig::new(0, 1);
    }

    #[test]
    fn test_consistent_hash() {
        let hash = ConsistentHash::new(3, 10);

        // Routing must be deterministic for the same key.
        let shard1 = hash.get_shard("doc1");
        let shard2 = hash.get_shard("doc1");
        assert_eq!(shard1, shard2);

        // 100 keys over 3 shards is ~33 each. The bounds are loose because
        // 10 virtual nodes per shard gives a fairly uneven ring.
        let mut shard_counts = vec![0; 3];
        for i in 0..100 {
            let key = format!("doc{}", i);
            shard_counts[hash.get_shard(&key)] += 1;
        }

        for count in shard_counts {
            assert!(
                (18..=48).contains(&count),
                "Imbalanced distribution: {}",
                count
            );
        }
    }

    #[test]
    fn test_consistent_hash_replicas() {
        let hash = ConsistentHash::new(5, 10);

        let replicas = hash.get_replicas("doc1", 3);
        assert_eq!(replicas.len(), 3);
        // The first replica is always the primary shard.
        assert_eq!(replicas[0], hash.get_shard("doc1"));

        let unique: std::collections::HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_consistent_hash_replicas_capped_by_shard_count() {
        let hash = ConsistentHash::new(3, 10);
        let mut replicas = hash.get_replicas("doc1", 10);
        replicas.sort_unstable();
        assert_eq!(replicas, vec![0, 1, 2]);
    }

    #[test]
    fn test_distributed_index_creation() {
        let shard_config = ShardConfig::new(2, 1);
        let search_config = SearchConfig::default();
        let index = DistributedIndex::new(shard_config, search_config);

        assert_eq!(index.shards.len(), 2);
    }

    #[test]
    fn test_distributed_index_build() {
        let shard_config = ShardConfig::new(2, 1);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        embeddings.insert("doc2".to_string(), vec![0.2, 0.3, 0.4]);
        embeddings.insert("doc3".to_string(), vec![0.3, 0.4, 0.5]);

        assert!(index.build(&embeddings).is_ok());

        let stats = index.get_stats().unwrap();
        assert_eq!(stats.num_shards, 2);
        assert_eq!(stats.shard_sizes.len(), 2);
        // Every embedding is routed to exactly one (primary) shard.
        assert_eq!(stats.total_vectors, 3);
    }

    #[test]
    fn test_distributed_search() {
        let shard_config = ShardConfig::new(2, 1);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings.insert("doc2".to_string(), vec![0.0, 1.0, 0.0]);
        embeddings.insert("doc3".to_string(), vec![0.0, 0.0, 1.0]);

        index.build(&embeddings).unwrap();

        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert!(!results.is_empty(), "expected at least one result");
        assert!(results.len() <= 2);
        assert_eq!(results[0].entity_id, "doc1");
    }

    #[test]
    fn test_distributed_stats() {
        let shard_config = ShardConfig::new(3, 1);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        for i in 0..10 {
            let key = format!("doc{}", i);
            let embedding = vec![i as f32 * 0.1, 0.2, 0.3];
            embeddings.insert(key, embedding);
        }

        index.build(&embeddings).unwrap();

        let stats = index.get_stats().unwrap();
        assert_eq!(stats.num_shards, 3);
        assert_eq!(stats.num_replicas, 1);
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.shard_sizes.iter().sum::<usize>(), 10);
        assert!(stats.balance_ratio >= 0.0 && stats.balance_ratio <= 1.0);
    }

    #[test]
    fn test_merge_results() {
        let shard1_results = vec![
            SearchResult {
                entity_id: "doc1".to_string(),
                score: 0.9,
                distance: 0.1,
                rank: 0,
            },
            SearchResult {
                entity_id: "doc2".to_string(),
                score: 0.7,
                distance: 0.3,
                rank: 1,
            },
        ];

        let shard2_results = vec![
            SearchResult {
                entity_id: "doc3".to_string(),
                score: 0.85,
                distance: 0.15,
                rank: 0,
            },
            SearchResult {
                entity_id: "doc4".to_string(),
                score: 0.6,
                distance: 0.4,
                rank: 1,
            },
        ];

        let merged = DistributedIndex::merge_results(vec![shard1_results, shard2_results], 3);

        assert_eq!(merged.len(), 3);
        // Globally sorted by score, interleaving both shards.
        assert_eq!(merged[0].entity_id, "doc1");
        assert_eq!(merged[1].entity_id, "doc3");
        assert_eq!(merged[2].entity_id, "doc2");
    }

    #[test]
    fn test_merge_results_deduplication() {
        let shard1_results = vec![SearchResult {
            entity_id: "doc1".to_string(),
            score: 0.9,
            distance: 0.1,
            rank: 0,
        }];

        let shard2_results = vec![SearchResult {
            entity_id: "doc1".to_string(),
            score: 0.85,
            distance: 0.15,
            rank: 0,
        }];

        let merged = DistributedIndex::merge_results(vec![shard1_results, shard2_results], 5);

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].entity_id, "doc1");
        // The higher-scoring duplicate wins.
        assert_eq!(merged[0].score, 0.9);
    }

    #[test]
    fn test_merge_results_k_exceeds_available() {
        let shard1_results = vec![SearchResult {
            entity_id: "doc1".to_string(),
            score: 0.5,
            distance: 0.5,
            rank: 0,
        }];

        let merged = DistributedIndex::merge_results(vec![shard1_results, Vec::new()], 10);

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].entity_id, "doc1");
    }

    #[test]
    fn test_distributed_replication() {
        let shard_config = ShardConfig::new(3, 2);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        index.build(&embeddings).unwrap();

        // The entity must come back exactly once, even with replication on.
        let query = vec![0.1, 0.2, 0.3];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_id, "doc1");
    }

    #[test]
    fn test_distributed_batch_search() {
        let shard_config = ShardConfig::new(2, 1);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings.insert("doc2".to_string(), vec![0.0, 1.0, 0.0]);
        embeddings.insert("doc3".to_string(), vec![0.0, 0.0, 1.0]);
        index.build(&embeddings).unwrap();

        let queries = vec![
            vec![0.9, 0.1, 0.0],
            vec![0.0, 0.9, 0.1],
            vec![0.0, 0.0, 0.9],
        ];

        let results = index.batch_search(&queries, 1).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0][0].entity_id, "doc1");
        assert_eq!(results[1][0].entity_id, "doc2");
        assert_eq!(results[2][0].entity_id, "doc3");
    }

    #[test]
    fn test_distributed_filtered_search() {
        use crate::filter::FilterValue;

        let shard_config = ShardConfig::new(2, 1);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings.insert("doc2".to_string(), vec![0.0, 1.0, 0.0]);
        embeddings.insert("doc3".to_string(), vec![0.0, 0.0, 1.0]);
        index.build(&embeddings).unwrap();

        let mut metadata1 = HashMap::new();
        metadata1.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );
        index.set_metadata("doc1", metadata1);

        let mut metadata2 = HashMap::new();
        metadata2.insert("type".to_string(), FilterValue::String("book".to_string()));
        index.set_metadata("doc2", metadata2);

        let mut metadata3 = HashMap::new();
        metadata3.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );
        index.set_metadata("doc3", metadata3);

        let filter = Filter::new().eq("type", "article");

        let query = vec![0.5, 0.5, 0.5];
        let results = index.filtered_search(&query, 10, &filter).unwrap();

        // Only the two "article" documents may match.
        assert!(results.len() <= 2);
        for result in &results {
            assert!(result.entity_id == "doc1" || result.entity_id == "doc3");
        }
    }

    #[test]
    fn test_distributed_metadata() {
        use crate::filter::FilterValue;

        let shard_config = ShardConfig::new(2, 1);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        index.build(&embeddings).unwrap();

        let mut metadata = HashMap::new();
        metadata.insert("year".to_string(), FilterValue::Int(2026));
        index.set_metadata("doc1", metadata.clone());

        let retrieved = index.get_metadata("doc1");
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.get("year"), Some(&FilterValue::Int(2026)));
    }

    #[test]
    fn test_distributed_batch_metadata() {
        use crate::filter::FilterValue;

        let shard_config = ShardConfig::new(2, 1);
        let search_config = SearchConfig::default();
        let mut index = DistributedIndex::new(shard_config, search_config);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        embeddings.insert("doc2".to_string(), vec![0.2, 0.3, 0.4]);
        index.build(&embeddings).unwrap();

        let mut metadata_map = HashMap::new();

        let mut m1 = HashMap::new();
        m1.insert("year".to_string(), FilterValue::Int(2026));
        metadata_map.insert("doc1".to_string(), m1);

        let mut m2 = HashMap::new();
        m2.insert("year".to_string(), FilterValue::Int(2023));
        metadata_map.insert("doc2".to_string(), m2);

        index.batch_set_metadata(&metadata_map);

        // Check the stored values, not just their presence.
        let doc1 = index.get_metadata("doc1").expect("doc1 metadata missing");
        assert_eq!(doc1.get("year"), Some(&FilterValue::Int(2026)));
        let doc2 = index.get_metadata("doc2").expect("doc2 metadata missing");
        assert_eq!(doc2.get("year"), Some(&FilterValue::Int(2023)));
    }
}