1use std::collections::HashMap;
65use std::sync::atomic::{AtomicU64, Ordering};
66use std::sync::{Arc, RwLock};
67
68use rayon::prelude::*;
69use serde::{Deserialize, Serialize};
70
71use crate::retrieval::TernaryInvertedIndex;
72use crate::search::{two_stage_search, SearchConfig};
73use embeddenator_vsa::SparseVec;
74
75#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
77pub struct ShardId(pub u32);
78
79impl ShardId {
80 pub fn from_u32(id: u32) -> Self {
82 Self(id)
83 }
84}
85
86#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
88pub enum ShardStatus {
89 #[default]
91 Healthy,
92 Degraded,
94 Offline,
96 Rebuilding,
98}
99
100#[derive(Debug)]
102pub struct Shard {
103 pub id: ShardId,
105 pub status: ShardStatus,
107 index: TernaryInvertedIndex,
109 vectors: HashMap<usize, SparseVec>,
111 doc_count: usize,
113 query_count: AtomicU64,
115}
116
117impl Shard {
118 pub fn new(id: ShardId) -> Self {
120 Self {
121 id,
122 status: ShardStatus::Healthy,
123 index: TernaryInvertedIndex::new(),
124 vectors: HashMap::new(),
125 doc_count: 0,
126 query_count: AtomicU64::new(0),
127 }
128 }
129
130 pub fn add(&mut self, doc_id: usize, vec: SparseVec) {
132 self.index.add(doc_id, &vec);
133 self.vectors.insert(doc_id, vec);
134 self.doc_count += 1;
135 }
136
137 pub fn finalize(&mut self) {
139 self.index.finalize();
140 }
141
142 pub fn query(&self, query: &SparseVec, config: &SearchConfig, k: usize) -> Vec<ShardResult> {
144 self.query_count.fetch_add(1, Ordering::Relaxed);
145
146 let results = two_stage_search(query, &self.index, &self.vectors, config, k);
147
148 results
149 .into_iter()
150 .map(|r| ShardResult {
151 shard_id: self.id,
152 doc_id: r.id,
153 score: r.score,
154 approx_score: r.approx_score,
155 })
156 .collect()
157 }
158
159 pub fn doc_count(&self) -> usize {
161 self.doc_count
162 }
163
164 pub fn query_count(&self) -> u64 {
166 self.query_count.load(Ordering::Relaxed)
167 }
168
169 pub fn is_available(&self) -> bool {
171 matches!(self.status, ShardStatus::Healthy | ShardStatus::Degraded)
172 }
173
174 pub fn set_status(&mut self, status: ShardStatus) {
176 self.status = status;
177 }
178}
179
180#[derive(Clone, Debug, PartialEq)]
182pub struct ShardResult {
183 pub shard_id: ShardId,
185 pub doc_id: usize,
187 pub score: f64,
189 pub approx_score: i32,
191}
192
193#[derive(Clone, Debug, PartialEq)]
195pub struct DistributedResult {
196 pub doc_id: usize,
198 pub shard_id: ShardId,
200 pub score: f64,
202 pub rank: usize,
204}
205
206#[derive(Clone, Debug)]
208pub struct DistributedConfig {
209 pub search_config: SearchConfig,
211 pub shard_k_multiplier: f64,
213 pub shard_timeout_ms: u64,
215 pub min_shards: usize,
217 pub parallel_shards: bool,
219}
220
221impl Default for DistributedConfig {
222 fn default() -> Self {
223 Self {
224 search_config: SearchConfig::default(),
225 shard_k_multiplier: 2.0,
226 shard_timeout_ms: 5000,
227 min_shards: 1,
228 parallel_shards: true,
229 }
230 }
231}
232
/// Diagnostics collected while answering one [`DistributedSearch::query`] call.
#[derive(Clone, Debug, Default)]
pub struct QueryStats {
    /// Shards that passed the availability check and were sent the query.
    pub shards_queried: usize,
    /// Shards that actually returned a result list.
    pub shards_responded: usize,
    /// Hits returned by all shards before de-duplication and truncation.
    pub total_candidates: usize,
    /// Results handed back to the caller after de-duplication and truncation to `k`.
    pub unique_results: usize,
    /// Wall-clock duration of the query in whole milliseconds.
    pub query_time_ms: u64,
}
247
/// Failure modes of [`DistributedSearch::query`].
#[derive(Debug, Clone)]
pub enum DistributedError {
    /// Fewer shards were available than `DistributedConfig::min_shards` requires.
    InsufficientShards { available: usize, required: usize },
    /// No queried shard produced a response.
    AllShardsFailed,
    /// The query exceeded its time budget.
    Timeout,
    /// The configuration is unusable; the payload explains why.
    InvalidConfig(String),
}

impl std::fmt::Display for DistributedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InsufficientShards {
                available,
                required,
            } => write!(
                f,
                "Insufficient shards: {} available, {} required",
                available, required
            ),
            Self::AllShardsFailed => f.write_str("All shard queries failed"),
            Self::Timeout => f.write_str("Query timeout"),
            Self::InvalidConfig(ref msg) => write!(f, "Invalid config: {}", msg),
        }
    }
}

impl std::error::Error for DistributedError {}
282
283#[derive(Default)]
287pub struct DistributedSearch {
288 config: DistributedConfig,
290 shards: Vec<Arc<RwLock<Shard>>>,
292 total_queries: AtomicU64,
294}
295
296impl DistributedSearch {
297 pub fn new(config: DistributedConfig) -> Self {
299 Self {
300 config,
301 shards: Vec::new(),
302 total_queries: AtomicU64::new(0),
303 }
304 }
305
306 pub fn add_shard(&mut self, shard: Shard) {
314 self.shards.push(Arc::new(RwLock::new(shard)));
315 }
316
317 pub fn shard_count(&self) -> usize {
319 self.shards.len()
320 }
321
322 pub fn available_shard_count(&self) -> usize {
324 self.shards
325 .iter()
326 .filter(|s| s.read().map(|s| s.is_available()).unwrap_or(false))
327 .count()
328 }
329
330 pub fn query(
332 &self,
333 query: &SparseVec,
334 k: usize,
335 ) -> Result<(Vec<DistributedResult>, QueryStats), DistributedError> {
336 let start = std::time::Instant::now();
337 self.total_queries.fetch_add(1, Ordering::Relaxed);
338
339 if k == 0 {
341 return Ok((
342 Vec::new(),
343 QueryStats {
344 shards_queried: 0,
345 shards_responded: 0,
346 total_candidates: 0,
347 unique_results: 0,
348 query_time_ms: start.elapsed().as_millis() as u64,
349 },
350 ));
351 }
352
353 let available_shards: Vec<_> = self
355 .shards
356 .iter()
357 .filter(|s| s.read().map(|s| s.is_available()).unwrap_or(false))
358 .collect();
359
360 if available_shards.len() < self.config.min_shards {
361 return Err(DistributedError::InsufficientShards {
362 available: available_shards.len(),
363 required: self.config.min_shards,
364 });
365 }
366
367 let shard_k =
369 ((k as f64 * self.config.shard_k_multiplier).min(usize::MAX as f64) as usize).max(k);
370
371 let shard_results: Vec<Vec<ShardResult>> = if self.config.parallel_shards {
373 available_shards
374 .par_iter()
375 .filter_map(|shard| {
376 shard
377 .read()
378 .ok()
379 .map(|s| s.query(query, &self.config.search_config, shard_k))
380 })
381 .collect()
382 } else {
383 available_shards
384 .iter()
385 .filter_map(|shard| {
386 shard
387 .read()
388 .ok()
389 .map(|s| s.query(query, &self.config.search_config, shard_k))
390 })
391 .collect()
392 };
393
394 let shards_responded = shard_results.len();
396
397 if shard_results.is_empty() {
398 return Err(DistributedError::AllShardsFailed);
399 }
400
401 let total_candidates: usize = shard_results.iter().map(|r| r.len()).sum();
403 let mut all_results: Vec<ShardResult> = shard_results.into_iter().flatten().collect();
404
405 all_results.sort_by(|a, b| {
407 b.score
408 .partial_cmp(&a.score)
409 .unwrap_or(std::cmp::Ordering::Equal)
410 .then_with(|| a.doc_id.cmp(&b.doc_id))
411 });
412
413 let mut seen = std::collections::HashSet::new();
415 let unique_results: Vec<DistributedResult> = all_results
416 .into_iter()
417 .filter(|r| seen.insert(r.doc_id))
418 .take(k)
419 .enumerate()
420 .map(|(idx, r)| DistributedResult {
421 doc_id: r.doc_id,
422 shard_id: r.shard_id,
423 score: r.score,
424 rank: idx + 1,
425 })
426 .collect();
427
428 let stats = QueryStats {
429 shards_queried: available_shards.len(),
430 shards_responded,
431 total_candidates,
432 unique_results: unique_results.len(),
433 query_time_ms: start.elapsed().as_millis() as u64,
434 };
435
436 Ok((unique_results, stats))
437 }
438
439 pub fn total_queries(&self) -> u64 {
441 self.total_queries.load(Ordering::Relaxed)
442 }
443
444 pub fn config(&self) -> &DistributedConfig {
446 &self.config
447 }
448}
449
/// How a [`ShardAssigner`] maps document ids to shards.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Cycle through shards in call order, ignoring the document id (the default).
    #[default]
    RoundRobin,
    /// Pick the shard from a multiplicative hash of the document id.
    HashBased,
    /// Split the whole `usize` id space into `num_shards` contiguous ranges.
    RangeBased,
}
461
462pub struct ShardAssigner {
464 strategy: ShardingStrategy,
465 num_shards: u32,
466 counter: AtomicU64,
467}
468
469impl ShardAssigner {
470 pub fn new(strategy: ShardingStrategy, num_shards: u32) -> Self {
472 Self {
473 strategy,
474 num_shards,
475 counter: AtomicU64::new(0),
476 }
477 }
478
479 pub fn assign(&self, doc_id: usize) -> ShardId {
481 match self.strategy {
482 ShardingStrategy::RoundRobin => {
483 let idx = self.counter.fetch_add(1, Ordering::Relaxed);
484 ShardId((idx as u32) % self.num_shards)
485 }
486 ShardingStrategy::HashBased => {
487 let hash = doc_id.wrapping_mul(0x9e3779b9) >> 16;
489 ShardId((hash as u32) % self.num_shards)
490 }
491 ShardingStrategy::RangeBased => {
492 let range_size = usize::MAX / self.num_shards as usize;
494 ShardId((doc_id / range_size).min(self.num_shards as usize - 1) as u32)
495 }
496 }
497 }
498}
499
500pub struct DistributedSearchBuilder {
502 config: DistributedConfig,
503 num_shards: u32,
504 sharding_strategy: ShardingStrategy,
505 shards: Vec<Shard>,
506 assigner: ShardAssigner,
507}
508
509impl DistributedSearchBuilder {
510 pub fn new(num_shards: u32) -> Self {
516 assert!(num_shards > 0, "num_shards must be greater than 0");
517 let shards = (0..num_shards).map(|i| Shard::new(ShardId(i))).collect();
518 let assigner = ShardAssigner::new(ShardingStrategy::default(), num_shards);
519 Self {
520 config: DistributedConfig::default(),
521 num_shards,
522 sharding_strategy: ShardingStrategy::default(),
523 shards,
524 assigner,
525 }
526 }
527
528 pub fn with_config(mut self, config: DistributedConfig) -> Self {
530 self.config = config;
531 self
532 }
533
534 pub fn with_strategy(mut self, strategy: ShardingStrategy) -> Self {
536 self.sharding_strategy = strategy;
537 self.assigner = ShardAssigner::new(strategy, self.num_shards);
539 self
540 }
541
542 pub fn add_document(&mut self, doc_id: usize, vec: SparseVec) {
544 let shard_id = self.assigner.assign(doc_id);
546 if let Some(shard) = self.shards.get_mut(shard_id.0 as usize) {
547 shard.add(doc_id, vec);
548 }
549 }
550
551 pub fn build(mut self) -> DistributedSearch {
553 for shard in &mut self.shards {
555 shard.finalize();
556 }
557
558 let mut search = DistributedSearch::new(self.config);
559 for shard in self.shards {
560 search.add_shard(shard);
561 }
562 search
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    /// Encodes `data` as a sparse vector using the default reversible VSA config.
    fn create_test_vec(data: &[u8]) -> SparseVec {
        let config = ReversibleVSAConfig::default();
        SparseVec::encode_data(data, &config, None)
    }

    #[test]
    fn test_shard_id() {
        let id = ShardId(42);
        assert_eq!(id.0, 42);
        assert_eq!(ShardId::from_u32(42), id);
    }

    #[test]
    fn test_shard_basic() {
        let mut shard = Shard::new(ShardId(0));
        assert_eq!(shard.doc_count(), 0);
        assert!(shard.is_available());

        shard.add(1, create_test_vec(b"document one"));
        shard.add(2, create_test_vec(b"document two"));
        shard.finalize();

        assert_eq!(shard.doc_count(), 2);
    }

    #[test]
    fn test_shard_query() {
        let mut shard = Shard::new(ShardId(0));
        shard.add(1, create_test_vec(b"hello world"));
        shard.add(2, create_test_vec(b"goodbye world"));
        shard.finalize();

        let query = create_test_vec(b"hello");
        let config = SearchConfig::default();
        let results = shard.query(&query, &config, 2);

        assert!(!results.is_empty());
        assert_eq!(results[0].shard_id, ShardId(0));
        assert_eq!(shard.query_count(), 1);
    }

    #[test]
    fn test_shard_status() {
        let mut shard = Shard::new(ShardId(0));
        assert!(shard.is_available());

        shard.status = ShardStatus::Degraded;
        assert!(shard.is_available());

        shard.status = ShardStatus::Offline;
        assert!(!shard.is_available());
    }

    #[test]
    fn test_distributed_search_basic() {
        let mut shard0 = Shard::new(ShardId(0));
        let mut shard1 = Shard::new(ShardId(1));

        shard0.add(1, create_test_vec(b"document one"));
        shard0.add(2, create_test_vec(b"document two"));
        shard0.finalize();

        shard1.add(3, create_test_vec(b"document three"));
        shard1.add(4, create_test_vec(b"document four"));
        shard1.finalize();

        let mut search = DistributedSearch::new(DistributedConfig::default());
        search.add_shard(shard0);
        search.add_shard(shard1);

        assert_eq!(search.shard_count(), 2);
        assert_eq!(search.available_shard_count(), 2);

        let query = create_test_vec(b"document");
        let (results, stats) = search.query(&query, 5).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        assert_eq!(stats.shards_queried, 2);
        assert_eq!(results[0].rank, 1);
    }

    #[test]
    fn test_available_shard_count_excludes_offline() {
        let mut offline = Shard::new(ShardId(1));
        offline.set_status(ShardStatus::Offline);

        let mut search = DistributedSearch::new(DistributedConfig::default());
        search.add_shard(Shard::new(ShardId(0)));
        search.add_shard(offline);

        assert_eq!(search.shard_count(), 2);
        assert_eq!(search.available_shard_count(), 1);
    }

    #[test]
    fn test_distributed_search_deduplication() {
        // The same document id lives on two shards; it must be reported once.
        let mut shard0 = Shard::new(ShardId(0));
        let mut shard1 = Shard::new(ShardId(1));

        let vec = create_test_vec(b"shared document");
        shard0.add(1, vec.clone());
        shard0.finalize();

        shard1.add(1, vec);
        shard1.finalize();

        let mut search = DistributedSearch::new(DistributedConfig::default());
        search.add_shard(shard0);
        search.add_shard(shard1);

        let query = create_test_vec(b"shared");
        let (results, _) = search.query(&query, 10).unwrap();

        let count_doc1 = results.iter().filter(|r| r.doc_id == 1).count();
        assert_eq!(count_doc1, 1);
    }

    #[test]
    fn test_distributed_search_insufficient_shards() {
        let search = DistributedSearch::new(DistributedConfig {
            min_shards: 3,
            ..Default::default()
        });

        let query = create_test_vec(b"test");
        let result = search.query(&query, 10);

        assert!(matches!(
            result,
            Err(DistributedError::InsufficientShards { .. })
        ));
    }

    #[test]
    fn test_shard_assigner_round_robin() {
        let assigner = ShardAssigner::new(ShardingStrategy::RoundRobin, 3);

        assert_eq!(assigner.assign(0), ShardId(0));
        assert_eq!(assigner.assign(1), ShardId(1));
        assert_eq!(assigner.assign(2), ShardId(2));
        // The rotation wraps back to the first shard.
        assert_eq!(assigner.assign(3), ShardId(0));
    }

    #[test]
    fn test_shard_assigner_hash_based() {
        let assigner = ShardAssigner::new(ShardingStrategy::HashBased, 4);

        // Hash-based assignment is a pure function of the document id.
        assert_eq!(assigner.assign(100), assigner.assign(100));

        // Every assignment must name an existing shard.
        for doc_id in [0, 1, 1000, usize::MAX / 3, usize::MAX] {
            let shard = assigner.assign(doc_id);
            assert!(shard.0 < 4, "doc {} mapped outside 0..4", doc_id);
        }
    }

    #[test]
    fn test_distributed_builder() {
        let mut builder = DistributedSearchBuilder::new(3)
            .with_strategy(ShardingStrategy::RoundRobin)
            .with_config(DistributedConfig::default());

        builder.add_document(1, create_test_vec(b"doc1"));
        builder.add_document(2, create_test_vec(b"doc2"));
        builder.add_document(3, create_test_vec(b"doc3"));
        builder.add_document(4, create_test_vec(b"doc4"));

        let search = builder.build();
        assert_eq!(search.shard_count(), 3);

        let query = create_test_vec(b"doc");
        let (results, _) = search.query(&query, 10).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_query_stats() {
        let mut builder = DistributedSearchBuilder::new(2);

        for i in 0..10 {
            let data = format!("document {}", i);
            builder.add_document(i, create_test_vec(data.as_bytes()));
        }

        let search = builder.build();
        let query = create_test_vec(b"document");
        let (_, stats) = search.query(&query, 5).unwrap();

        assert_eq!(stats.shards_queried, 2);
        assert_eq!(stats.shards_responded, 2);
        assert!(stats.total_candidates > 0);
        assert!(stats.unique_results <= 5);
    }

    #[test]
    fn test_parallel_distributed_search() {
        let config = DistributedConfig {
            parallel_shards: true,
            ..Default::default()
        };

        let mut builder = DistributedSearchBuilder::new(4).with_config(config);

        for i in 0..100 {
            let data = format!("document {} content for testing", i);
            builder.add_document(i, create_test_vec(data.as_bytes()));
        }

        let search = builder.build();
        let query = create_test_vec(b"document content");
        let (results, stats) = search.query(&query, 20).unwrap();

        assert!(!results.is_empty());
        assert_eq!(stats.shards_queried, 4);
    }

    #[test]
    fn test_all_shards_offline_is_insufficient() {
        // Offline shards are filtered out before any shard is queried. With every
        // shard offline, the query therefore fails the `min_shards` check and does
        // not report `AllShardsFailed`.
        let mut shard0 = Shard::new(ShardId(0));
        shard0.add(1, create_test_vec(b"document one"));
        shard0.finalize();
        shard0.set_status(ShardStatus::Offline);

        let mut shard1 = Shard::new(ShardId(1));
        shard1.add(2, create_test_vec(b"document two"));
        shard1.finalize();
        shard1.set_status(ShardStatus::Offline);

        let mut search = DistributedSearch::new(DistributedConfig {
            min_shards: 1,
            ..Default::default()
        });
        search.add_shard(shard0);
        search.add_shard(shard1);

        let query = create_test_vec(b"document");
        let result = search.query(&query, 10);

        assert!(matches!(
            result,
            Err(DistributedError::InsufficientShards { available: 0, .. })
        ));
    }

    #[test]
    fn test_shard_assigner_range_based() {
        let assigner = ShardAssigner::new(ShardingStrategy::RangeBased, 4);

        let shard_low = assigner.assign(0);
        let shard_mid = assigner.assign(usize::MAX / 2);
        let shard_high = assigner.assign(usize::MAX - 1);

        // The lowest id falls in the first range.
        assert_eq!(shard_low, ShardId(0));
        // The top of the id space falls in the last range.
        assert_eq!(shard_high, ShardId(3));
        // The midpoint lands in one of the two middle ranges.
        assert!(shard_mid.0 >= 1 && shard_mid.0 <= 2);
    }

    #[test]
    fn test_round_robin_distribution() {
        let mut builder =
            DistributedSearchBuilder::new(3).with_strategy(ShardingStrategy::RoundRobin);

        for i in 0..9 {
            builder.add_document(i, create_test_vec(format!("doc{}", i).as_bytes()));
        }

        let shard0_count = builder.shards[0].doc_count();
        let shard1_count = builder.shards[1].doc_count();
        let shard2_count = builder.shards[2].doc_count();

        assert_eq!(shard0_count, 3, "Shard 0 should have 3 documents");
        assert_eq!(shard1_count, 3, "Shard 1 should have 3 documents");
        assert_eq!(shard2_count, 3, "Shard 2 should have 3 documents");
    }

    #[test]
    fn test_query_k_zero() {
        let mut builder = DistributedSearchBuilder::new(2);
        builder.add_document(1, create_test_vec(b"test document"));
        let search = builder.build();

        let query = create_test_vec(b"test");
        let (results, stats) = search.query(&query, 0).unwrap();

        // k == 0 short-circuits: no shard is contacted, but the call is counted.
        assert!(results.is_empty());
        assert_eq!(stats.shards_queried, 0);
        assert_eq!(stats.shards_responded, 0);
        assert_eq!(search.total_queries(), 1);
    }

    #[test]
    #[should_panic(expected = "num_shards must be greater than 0")]
    fn test_builder_zero_shards_panics() {
        let _ = DistributedSearchBuilder::new(0);
    }

    #[test]
    fn test_shard_set_status() {
        let mut shard = Shard::new(ShardId(0));
        assert_eq!(shard.status, ShardStatus::Healthy);

        shard.set_status(ShardStatus::Degraded);
        assert_eq!(shard.status, ShardStatus::Degraded);
        assert!(shard.is_available());

        shard.set_status(ShardStatus::Rebuilding);
        assert_eq!(shard.status, ShardStatus::Rebuilding);
        assert!(!shard.is_available());
    }

    #[test]
    fn test_distributed_error_display() {
        let err = DistributedError::InsufficientShards {
            available: 1,
            required: 2,
        };
        assert_eq!(err.to_string(), "Insufficient shards: 1 available, 2 required");
        assert_eq!(
            DistributedError::AllShardsFailed.to_string(),
            "All shard queries failed"
        );
        assert_eq!(DistributedError::Timeout.to_string(), "Query timeout");
        assert_eq!(
            DistributedError::InvalidConfig("x".to_string()).to_string(),
            "Invalid config: x"
        );
    }
}