// embeddenator_retrieval/distributed.rs

1//! Distributed Search Infrastructure
2//!
3//! This module provides primitives for distributed semantic search across
4//! multiple nodes/shards. It handles:
5//!
6//! - **Sharding**: Partitioning data across nodes
7//! - **Query routing**: Fan-out queries to relevant shards
8//! - **Result aggregation**: Merge results from multiple shards
9//! - **Topology management**: Track available nodes
10//!
11//! # Architecture
12//!
13//! ```text
14//! ┌─────────────────────────────────────────────────────────────┐
15//! │                    Distributed Search                        │
16//! │  ┌─────────────┐                                            │
17//! │  │   Query     │                                            │
18//! │  └──────┬──────┘                                            │
19//! │         │                                                    │
20//! │         ▼                                                    │
21//! │  ┌─────────────┐    ┌──────────────────────────────────┐   │
22//! │  │   Router    │───▶│         Shard Cluster            │   │
23//! │  └──────┬──────┘    │  ┌─────┐ ┌─────┐ ┌─────┐        │   │
24//! │         │           │  │ S0  │ │ S1  │ │ S2  │ ...    │   │
25//! │         ▼           │  └──┬──┘ └──┬──┘ └──┬──┘        │   │
26//! │  ┌─────────────┐    └────│───────│───────│───────────┘   │
27//! │  │ Aggregator  │◀────────┴───────┴───────┘                │
28//! │  └──────┬──────┘                                            │
29//! │         │                                                    │
30//! │         ▼                                                    │
31//! │  ┌─────────────┐                                            │
32//! │  │   Results   │                                            │
33//! │  └─────────────┘                                            │
34//! └─────────────────────────────────────────────────────────────┘
35//! ```
36//!
37//! # Example
38//!
39//! ```rust,ignore
40//! use embeddenator_retrieval::distributed::{
41//!     DistributedSearch, Shard, ShardId, DistributedConfig,
42//! };
43//! use embeddenator_vsa::SparseVec;
44//!
45//! // Create shards (each could be on a different node)
46//! let mut shard0 = Shard::new(ShardId(0));
47//! shard0.add(1, SparseVec::from_data(b"document one"));
48//! shard0.finalize();
49//!
50//! let mut shard1 = Shard::new(ShardId(1));
51//! shard1.add(2, SparseVec::from_data(b"document two"));
52//! shard1.finalize();
53//!
54//! // Create distributed search coordinator
55//! let mut search = DistributedSearch::new(DistributedConfig::default());
56//! search.add_shard(shard0);
57//! search.add_shard(shard1);
58//!
59//! // Execute distributed query
60//! let query = SparseVec::from_data(b"document");
61//! let (results, stats) = search.query(&query, 10)?;
62//! ```
63
64use std::collections::HashMap;
65use std::sync::atomic::{AtomicU64, Ordering};
66use std::sync::{Arc, RwLock};
67
68use rayon::prelude::*;
69use serde::{Deserialize, Serialize};
70
71use crate::retrieval::TernaryInvertedIndex;
72use crate::search::{two_stage_search, SearchConfig};
73use embeddenator_vsa::SparseVec;
74
75/// Unique identifier for a shard
76#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
77pub struct ShardId(pub u32);
78
79impl ShardId {
80    /// Create from integer
81    pub fn from_u32(id: u32) -> Self {
82        Self(id)
83    }
84}
85
/// Shard status
///
/// Only `Healthy` and `Degraded` shards accept queries (see
/// `Shard::is_available`); `Offline` and `Rebuilding` shards are skipped
/// by the coordinator.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum ShardStatus {
    /// Shard is healthy and accepting queries
    #[default]
    Healthy,
    /// Shard is degraded (slow but functional)
    Degraded,
    /// Shard is offline
    Offline,
    /// Shard is rebuilding index
    Rebuilding,
}
99
/// A single shard containing a partition of the search corpus
///
/// A shard owns both an inverted index (for the approximate first search
/// stage) and the full vectors (for exact rescoring). `finalize()` must be
/// called after the last `add()` before the shard is queried.
#[derive(Debug)]
pub struct Shard {
    /// Unique shard identifier
    pub id: ShardId,
    /// Status of this shard
    pub status: ShardStatus,
    /// Inverted index for this shard's data
    index: TernaryInvertedIndex,
    /// Vectors in this shard, keyed by document ID
    vectors: HashMap<usize, SparseVec>,
    /// Document count
    doc_count: usize,
    /// Query counter (atomic so `query` can take `&self`)
    query_count: AtomicU64,
}
116
117impl Shard {
118    /// Create a new empty shard
119    pub fn new(id: ShardId) -> Self {
120        Self {
121            id,
122            status: ShardStatus::Healthy,
123            index: TernaryInvertedIndex::new(),
124            vectors: HashMap::new(),
125            doc_count: 0,
126            query_count: AtomicU64::new(0),
127        }
128    }
129
130    /// Add a document to this shard
131    pub fn add(&mut self, doc_id: usize, vec: SparseVec) {
132        self.index.add(doc_id, &vec);
133        self.vectors.insert(doc_id, vec);
134        self.doc_count += 1;
135    }
136
137    /// Finalize the shard's index for querying
138    pub fn finalize(&mut self) {
139        self.index.finalize();
140    }
141
142    /// Query this shard locally
143    pub fn query(&self, query: &SparseVec, config: &SearchConfig, k: usize) -> Vec<ShardResult> {
144        self.query_count.fetch_add(1, Ordering::Relaxed);
145
146        let results = two_stage_search(query, &self.index, &self.vectors, config, k);
147
148        results
149            .into_iter()
150            .map(|r| ShardResult {
151                shard_id: self.id,
152                doc_id: r.id,
153                score: r.score,
154                approx_score: r.approx_score,
155            })
156            .collect()
157    }
158
159    /// Get document count
160    pub fn doc_count(&self) -> usize {
161        self.doc_count
162    }
163
164    /// Get query count
165    pub fn query_count(&self) -> u64 {
166        self.query_count.load(Ordering::Relaxed)
167    }
168
169    /// Check if shard is available for queries
170    pub fn is_available(&self) -> bool {
171        matches!(self.status, ShardStatus::Healthy | ShardStatus::Degraded)
172    }
173
174    /// Update shard status
175    pub fn set_status(&mut self, status: ShardStatus) {
176        self.status = status;
177    }
178}
179
/// Result from a single shard query
///
/// Intermediate representation produced by `Shard::query`; the coordinator
/// merges, sorts, and deduplicates these into `DistributedResult`s.
#[derive(Clone, Debug, PartialEq)]
pub struct ShardResult {
    /// Which shard this result came from
    pub shard_id: ShardId,
    /// Document ID within the corpus
    pub doc_id: usize,
    /// Similarity score (exact, from the rescoring stage)
    pub score: f64,
    /// Approximate score from index (first-stage candidate score)
    pub approx_score: i32,
}
192
/// Aggregated result from distributed query
///
/// Final, globally ranked result returned by `DistributedSearch::query`.
#[derive(Clone, Debug, PartialEq)]
pub struct DistributedResult {
    /// Document ID
    pub doc_id: usize,
    /// Source shard (for duplicated doc_ids, the highest-scoring shard)
    pub shard_id: ShardId,
    /// Final score
    pub score: f64,
    /// Global rank (1-indexed)
    pub rank: usize,
}
205
/// Configuration for distributed search
#[derive(Clone, Debug)]
pub struct DistributedConfig {
    /// Search configuration for each shard
    pub search_config: SearchConfig,
    /// Multiplier for k when querying shards (to ensure enough candidates
    /// survive cross-shard deduplication); effective per-shard k is never
    /// below the requested k
    pub shard_k_multiplier: f64,
    /// Timeout per shard query (milliseconds)
    // NOTE(review): not currently enforced by `DistributedSearch::query`;
    // reserved alongside `DistributedError::Timeout`.
    pub shard_timeout_ms: u64,
    /// Minimum shards required for valid result
    pub min_shards: usize,
    /// Enable parallel shard queries (rayon fan-out)
    pub parallel_shards: bool,
}
220
221impl Default for DistributedConfig {
222    fn default() -> Self {
223        Self {
224            search_config: SearchConfig::default(),
225            shard_k_multiplier: 2.0,
226            shard_timeout_ms: 5000,
227            min_shards: 1,
228            parallel_shards: true,
229        }
230    }
231}
232
/// Statistics from a distributed query
#[derive(Clone, Debug, Default)]
pub struct QueryStats {
    /// Total shards queried (shards that were available at query time)
    pub shards_queried: usize,
    /// Shards that responded successfully
    pub shards_responded: usize,
    /// Total results before aggregation (sum across all shards)
    pub total_candidates: usize,
    /// Results after deduplication (and truncation to k)
    pub unique_results: usize,
    /// Query time in milliseconds (wall clock, whole query)
    pub query_time_ms: u64,
}
247
/// Error type for distributed operations
#[derive(Debug, Clone)]
pub enum DistributedError {
    /// Not enough shards available
    InsufficientShards { available: usize, required: usize },
    /// All shard queries failed
    AllShardsFailed,
    /// Query timeout (reserved for future use; timeout handling is not yet implemented)
    Timeout,
    /// Invalid configuration
    InvalidConfig(String),
}

impl std::fmt::Display for DistributedError {
    /// Render a human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientShards { available, required } => write!(
                f,
                "Insufficient shards: {} available, {} required",
                available, required
            ),
            Self::AllShardsFailed => write!(f, "All shard queries failed"),
            Self::Timeout => write!(f, "Query timeout"),
            Self::InvalidConfig(msg) => write!(f, "Invalid config: {}", msg),
        }
    }
}

impl std::error::Error for DistributedError {}
282
/// Distributed search coordinator
///
/// Manages multiple shards and coordinates distributed queries.
///
/// Shards are held behind `Arc<RwLock<_>>` so queries can run concurrently
/// (read locks) while status updates take a write lock.
#[derive(Default)]
pub struct DistributedSearch {
    /// Configuration
    config: DistributedConfig,
    /// Registered shards
    shards: Vec<Arc<RwLock<Shard>>>,
    /// Total queries executed (atomic so `query` can take `&self`)
    total_queries: AtomicU64,
}
295
impl DistributedSearch {
    /// Create a new distributed search coordinator with no shards.
    pub fn new(config: DistributedConfig) -> Self {
        Self {
            config,
            shards: Vec::new(),
            total_queries: AtomicU64::new(0),
        }
    }

    /// Add a shard to the cluster.
    ///
    /// Callers must ensure that each shard registered with this coordinator
    /// has a unique [`ShardId`] and that the same shard is not added more
    /// than once. Adding multiple shards with the same `ShardId`, or
    /// registering the same shard repeatedly, may lead to incorrect query
    /// results and statistics.
    pub fn add_shard(&mut self, shard: Shard) {
        self.shards.push(Arc::new(RwLock::new(shard)));
    }

    /// Get number of registered shards
    pub fn shard_count(&self) -> usize {
        self.shards.len()
    }

    /// Get number of available shards
    ///
    /// A shard with a poisoned lock is treated as unavailable.
    pub fn available_shard_count(&self) -> usize {
        self.shards
            .iter()
            .filter(|s| s.read().map(|s| s.is_available()).unwrap_or(false))
            .count()
    }

    /// Execute a distributed query
    ///
    /// Fans the query out to every available shard (in parallel when
    /// configured), merges the per-shard results, deduplicates by `doc_id`
    /// keeping the highest score, and returns the top `k` globally ranked.
    ///
    /// # Errors
    ///
    /// - `InsufficientShards` if fewer than `config.min_shards` shards are
    ///   available at the time of the query.
    /// - `AllShardsFailed` if every shard query failed (e.g. poisoned locks).
    ///
    /// Note: `config.shard_timeout_ms` is not currently enforced here.
    pub fn query(
        &self,
        query: &SparseVec,
        k: usize,
    ) -> Result<(Vec<DistributedResult>, QueryStats), DistributedError> {
        let start = std::time::Instant::now();
        self.total_queries.fetch_add(1, Ordering::Relaxed);

        // Short-circuit when k=0 to avoid unnecessary work
        if k == 0 {
            return Ok((
                Vec::new(),
                QueryStats {
                    shards_queried: 0,
                    shards_responded: 0,
                    total_candidates: 0,
                    unique_results: 0,
                    query_time_ms: start.elapsed().as_millis() as u64,
                },
            ));
        }

        // Check shard availability.
        // NOTE: this is a snapshot — a shard could change status between
        // this check and the fan-out below; such a shard simply contributes
        // whatever its query returns (or is dropped if its lock fails).
        let available_shards: Vec<_> = self
            .shards
            .iter()
            .filter(|s| s.read().map(|s| s.is_available()).unwrap_or(false))
            .collect();

        if available_shards.len() < self.config.min_shards {
            return Err(DistributedError::InsufficientShards {
                available: available_shards.len(),
                required: self.config.min_shards,
            });
        }

        // Calculate k for each shard with overflow protection.
        // Over-fetching (multiplier >= 1) ensures enough candidates survive
        // cross-shard deduplication; .max(k) guards against multipliers < 1.
        let shard_k =
            ((k as f64 * self.config.shard_k_multiplier).min(usize::MAX as f64) as usize).max(k);

        // Query shards (parallel or sequential), tracking actual responses.
        // filter_map drops shards whose read lock is poisoned.
        let shard_results: Vec<Vec<ShardResult>> = if self.config.parallel_shards {
            available_shards
                .par_iter()
                .filter_map(|shard| {
                    shard
                        .read()
                        .ok()
                        .map(|s| s.query(query, &self.config.search_config, shard_k))
                })
                .collect()
        } else {
            available_shards
                .iter()
                .filter_map(|shard| {
                    shard
                        .read()
                        .ok()
                        .map(|s| s.query(query, &self.config.search_config, shard_k))
                })
                .collect()
        };

        // Track actual responses vs queried
        let shards_responded = shard_results.len();

        if shard_results.is_empty() {
            return Err(DistributedError::AllShardsFailed);
        }

        // Aggregate results
        let total_candidates: usize = shard_results.iter().map(|r| r.len()).sum();
        let mut all_results: Vec<ShardResult> = shard_results.into_iter().flatten().collect();

        // Sort by score descending, then by doc_id for deterministic ordering
        all_results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });

        // Deduplicate by doc_id (keep highest score).
        // Correct because the list is already sorted by score descending, so
        // the first occurrence of a doc_id is its best-scoring copy.
        let mut seen = std::collections::HashSet::new();
        let unique_results: Vec<DistributedResult> = all_results
            .into_iter()
            .filter(|r| seen.insert(r.doc_id))
            .take(k)
            .enumerate()
            .map(|(idx, r)| DistributedResult {
                doc_id: r.doc_id,
                shard_id: r.shard_id,
                score: r.score,
                rank: idx + 1,
            })
            .collect();

        let stats = QueryStats {
            shards_queried: available_shards.len(),
            shards_responded,
            total_candidates,
            unique_results: unique_results.len(),
            query_time_ms: start.elapsed().as_millis() as u64,
        };

        Ok((unique_results, stats))
    }

    /// Get total queries executed
    pub fn total_queries(&self) -> u64 {
        self.total_queries.load(Ordering::Relaxed)
    }

    /// Get configuration
    pub fn config(&self) -> &DistributedConfig {
        &self.config
    }
}
449
/// Sharding strategy for partitioning data
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum ShardingStrategy {
    /// Round-robin assignment (ignores doc_id; cycles a shared counter)
    #[default]
    RoundRobin,
    /// Hash-based assignment (deterministic function of doc_id)
    // NOTE(review): implemented as a simple multiplicative hash, not true
    // consistent hashing — reassignments are not minimized when the shard
    // count changes.
    HashBased,
    /// Range-based assignment (by document ID)
    RangeBased,
}
461
/// Shard assignment helper
///
/// Maps document IDs to shard IDs according to a `ShardingStrategy`.
pub struct ShardAssigner {
    /// Strategy used by `assign`
    strategy: ShardingStrategy,
    /// Number of shards to distribute across (must be non-zero)
    num_shards: u32,
    /// Monotonic counter used only by the RoundRobin strategy
    counter: AtomicU64,
}
468
469impl ShardAssigner {
470    /// Create a new shard assigner
471    pub fn new(strategy: ShardingStrategy, num_shards: u32) -> Self {
472        Self {
473            strategy,
474            num_shards,
475            counter: AtomicU64::new(0),
476        }
477    }
478
479    /// Assign a document to a shard
480    pub fn assign(&self, doc_id: usize) -> ShardId {
481        match self.strategy {
482            ShardingStrategy::RoundRobin => {
483                let idx = self.counter.fetch_add(1, Ordering::Relaxed);
484                ShardId((idx as u32) % self.num_shards)
485            }
486            ShardingStrategy::HashBased => {
487                // Simple hash function
488                let hash = doc_id.wrapping_mul(0x9e3779b9) >> 16;
489                ShardId((hash as u32) % self.num_shards)
490            }
491            ShardingStrategy::RangeBased => {
492                // Assume doc_ids are sequential
493                let range_size = usize::MAX / self.num_shards as usize;
494                ShardId((doc_id / range_size).min(self.num_shards as usize - 1) as u32)
495            }
496        }
497    }
498}
499
/// Builder for creating a distributed search cluster
///
/// Accumulates documents into locally owned shards (routing each document
/// via the configured sharding strategy), then `build()` finalizes every
/// shard and hands them to a `DistributedSearch` coordinator.
pub struct DistributedSearchBuilder {
    /// Coordinator configuration passed through to `DistributedSearch`
    config: DistributedConfig,
    /// Number of shards created up front
    num_shards: u32,
    /// Selected strategy (kept in sync with `assigner`)
    sharding_strategy: ShardingStrategy,
    /// Shards under construction, indexed by ShardId value
    shards: Vec<Shard>,
    /// Stateful assigner (owns the RoundRobin counter)
    assigner: ShardAssigner,
}
508
509impl DistributedSearchBuilder {
510    /// Create a new builder
511    ///
512    /// # Panics
513    ///
514    /// Panics if `num_shards` is 0.
515    pub fn new(num_shards: u32) -> Self {
516        assert!(num_shards > 0, "num_shards must be greater than 0");
517        let shards = (0..num_shards).map(|i| Shard::new(ShardId(i))).collect();
518        let assigner = ShardAssigner::new(ShardingStrategy::default(), num_shards);
519        Self {
520            config: DistributedConfig::default(),
521            num_shards,
522            sharding_strategy: ShardingStrategy::default(),
523            shards,
524            assigner,
525        }
526    }
527
528    /// Set the search configuration
529    pub fn with_config(mut self, config: DistributedConfig) -> Self {
530        self.config = config;
531        self
532    }
533
534    /// Set the sharding strategy
535    pub fn with_strategy(mut self, strategy: ShardingStrategy) -> Self {
536        self.sharding_strategy = strategy;
537        // Recreate assigner with new strategy, preserving counter state
538        self.assigner = ShardAssigner::new(strategy, self.num_shards);
539        self
540    }
541
542    /// Add a document to the cluster (assigns to appropriate shard)
543    pub fn add_document(&mut self, doc_id: usize, vec: SparseVec) {
544        // Use the stored assigner to maintain RoundRobin counter state
545        let shard_id = self.assigner.assign(doc_id);
546        if let Some(shard) = self.shards.get_mut(shard_id.0 as usize) {
547            shard.add(doc_id, vec);
548        }
549    }
550
551    /// Build the distributed search cluster
552    pub fn build(mut self) -> DistributedSearch {
553        // Finalize all shards
554        for shard in &mut self.shards {
555            shard.finalize();
556        }
557
558        let mut search = DistributedSearch::new(self.config);
559        for shard in self.shards {
560            search.add_shard(shard);
561        }
562        search
563    }
564}
565
#[cfg(test)]
mod tests {
    use super::*;
    use embeddenator_vsa::ReversibleVSAConfig;

    // Helper: encode raw bytes into a SparseVec with default VSA settings.
    fn create_test_vec(data: &[u8]) -> SparseVec {
        let config = ReversibleVSAConfig::default();
        SparseVec::encode_data(data, &config, None)
    }

    #[test]
    fn test_shard_id() {
        let id = ShardId(42);
        assert_eq!(id.0, 42);
        assert_eq!(ShardId::from_u32(42), id);
    }

    #[test]
    fn test_shard_basic() {
        let mut shard = Shard::new(ShardId(0));
        assert_eq!(shard.doc_count(), 0);
        assert!(shard.is_available());

        shard.add(1, create_test_vec(b"document one"));
        shard.add(2, create_test_vec(b"document two"));
        shard.finalize();

        assert_eq!(shard.doc_count(), 2);
    }

    #[test]
    fn test_shard_query() {
        let mut shard = Shard::new(ShardId(0));
        shard.add(1, create_test_vec(b"hello world"));
        shard.add(2, create_test_vec(b"goodbye world"));
        shard.finalize();

        let query = create_test_vec(b"hello");
        let config = SearchConfig::default();
        let results = shard.query(&query, &config, 2);

        // Results are tagged with the source shard and the query is counted.
        assert!(!results.is_empty());
        assert_eq!(results[0].shard_id, ShardId(0));
        assert_eq!(shard.query_count(), 1);
    }

    #[test]
    fn test_shard_status() {
        let mut shard = Shard::new(ShardId(0));
        assert!(shard.is_available());

        // Degraded shards still serve queries.
        shard.status = ShardStatus::Degraded;
        assert!(shard.is_available());

        shard.status = ShardStatus::Offline;
        assert!(!shard.is_available());
    }

    #[test]
    fn test_distributed_search_basic() {
        let mut shard0 = Shard::new(ShardId(0));
        let mut shard1 = Shard::new(ShardId(1));

        shard0.add(1, create_test_vec(b"document one"));
        shard0.add(2, create_test_vec(b"document two"));
        shard0.finalize();

        shard1.add(3, create_test_vec(b"document three"));
        shard1.add(4, create_test_vec(b"document four"));
        shard1.finalize();

        let mut search = DistributedSearch::new(DistributedConfig::default());
        search.add_shard(shard0);
        search.add_shard(shard1);

        assert_eq!(search.shard_count(), 2);
        assert_eq!(search.available_shard_count(), 2);

        let query = create_test_vec(b"document");
        let (results, stats) = search.query(&query, 5).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        assert_eq!(stats.shards_queried, 2);
        // Ranks are 1-indexed.
        assert_eq!(results[0].rank, 1);
    }

    #[test]
    fn test_distributed_search_deduplication() {
        // Create two shards with overlapping content
        let mut shard0 = Shard::new(ShardId(0));
        let mut shard1 = Shard::new(ShardId(1));

        let vec = create_test_vec(b"shared document");
        shard0.add(1, vec.clone());
        shard0.finalize();

        shard1.add(1, vec); // Same doc_id
        shard1.finalize();

        let mut search = DistributedSearch::new(DistributedConfig::default());
        search.add_shard(shard0);
        search.add_shard(shard1);

        let query = create_test_vec(b"shared");
        let (results, _) = search.query(&query, 10).unwrap();

        // Should have only one result (deduplicated)
        let count_doc1 = results.iter().filter(|r| r.doc_id == 1).count();
        assert_eq!(count_doc1, 1);
    }

    #[test]
    fn test_distributed_search_insufficient_shards() {
        // min_shards = 3 but no shards registered at all.
        let search = DistributedSearch::new(DistributedConfig {
            min_shards: 3,
            ..Default::default()
        });

        let query = create_test_vec(b"test");
        let result = search.query(&query, 10);

        assert!(matches!(
            result,
            Err(DistributedError::InsufficientShards { .. })
        ));
    }

    #[test]
    fn test_shard_assigner_round_robin() {
        let assigner = ShardAssigner::new(ShardingStrategy::RoundRobin, 3);

        // Assignment order follows the internal counter, not doc_id.
        assert_eq!(assigner.assign(0), ShardId(0));
        assert_eq!(assigner.assign(1), ShardId(1));
        assert_eq!(assigner.assign(2), ShardId(2));
        assert_eq!(assigner.assign(3), ShardId(0)); // Wraps around
    }

    #[test]
    fn test_shard_assigner_hash_based() {
        let assigner = ShardAssigner::new(ShardingStrategy::HashBased, 4);

        // Same doc_id should always get same shard
        let shard1 = assigner.assign(100);
        let shard2 = assigner.assign(100);
        assert_eq!(shard1, shard2);

        // Different doc_ids likely get different shards (but not guaranteed)
        let _shard_a = assigner.assign(1);
        let _shard_b = assigner.assign(1000);
    }

    #[test]
    fn test_distributed_builder() {
        let mut builder = DistributedSearchBuilder::new(3)
            .with_strategy(ShardingStrategy::RoundRobin)
            .with_config(DistributedConfig::default());

        builder.add_document(1, create_test_vec(b"doc1"));
        builder.add_document(2, create_test_vec(b"doc2"));
        builder.add_document(3, create_test_vec(b"doc3"));
        builder.add_document(4, create_test_vec(b"doc4"));

        let search = builder.build();
        assert_eq!(search.shard_count(), 3);

        let query = create_test_vec(b"doc");
        let (results, _) = search.query(&query, 10).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_query_stats() {
        let mut builder = DistributedSearchBuilder::new(2);

        for i in 0..10 {
            let data = format!("document {}", i);
            builder.add_document(i, create_test_vec(data.as_bytes()));
        }

        let search = builder.build();
        let query = create_test_vec(b"document");
        let (_, stats) = search.query(&query, 5).unwrap();

        assert_eq!(stats.shards_queried, 2);
        assert_eq!(stats.shards_responded, 2);
        assert!(stats.total_candidates > 0);
        assert!(stats.unique_results <= 5);
    }

    #[test]
    fn test_parallel_distributed_search() {
        let config = DistributedConfig {
            parallel_shards: true,
            ..Default::default()
        };

        let mut builder = DistributedSearchBuilder::new(4).with_config(config);

        for i in 0..100 {
            let data = format!("document {} content for testing", i);
            builder.add_document(i, create_test_vec(data.as_bytes()));
        }

        let search = builder.build();
        let query = create_test_vec(b"document content");
        let (results, stats) = search.query(&query, 20).unwrap();

        assert!(!results.is_empty());
        assert_eq!(stats.shards_queried, 4);
    }

    #[test]
    fn test_all_shards_failed() {
        // Create search with shards that all become unavailable
        let mut shard0 = Shard::new(ShardId(0));
        shard0.add(1, create_test_vec(b"document one"));
        shard0.finalize();
        shard0.set_status(ShardStatus::Offline);

        let mut shard1 = Shard::new(ShardId(1));
        shard1.add(2, create_test_vec(b"document two"));
        shard1.finalize();
        shard1.set_status(ShardStatus::Offline);

        let mut search = DistributedSearch::new(DistributedConfig {
            min_shards: 1, // Require at least 1 shard
            ..Default::default()
        });
        search.add_shard(shard0);
        search.add_shard(shard1);

        let query = create_test_vec(b"document");
        let result = search.query(&query, 10);

        // Should fail because no shards are available (all offline)
        assert!(matches!(
            result,
            Err(DistributedError::InsufficientShards { available: 0, .. })
        ));
    }

    #[test]
    fn test_shard_assigner_range_based() {
        let assigner = ShardAssigner::new(ShardingStrategy::RangeBased, 4);

        // Documents with low IDs should go to early shards
        let shard_low = assigner.assign(0);
        let shard_mid = assigner.assign(usize::MAX / 2);
        let shard_high = assigner.assign(usize::MAX - 1);

        // Low IDs should go to shard 0
        assert_eq!(shard_low, ShardId(0));
        // Very high IDs should go to the last shard
        assert_eq!(shard_high, ShardId(3));
        // Middle IDs should go to middle shards
        assert!(shard_mid.0 >= 1 && shard_mid.0 <= 2);
    }

    #[test]
    fn test_round_robin_distribution() {
        // Verify that RoundRobin actually distributes across shards
        let mut builder =
            DistributedSearchBuilder::new(3).with_strategy(ShardingStrategy::RoundRobin);

        // Add 9 documents (should be 3 per shard with RoundRobin)
        for i in 0..9 {
            builder.add_document(i, create_test_vec(format!("doc{}", i).as_bytes()));
        }

        // Check that documents are distributed across shards
        let shard0_count = builder.shards[0].doc_count();
        let shard1_count = builder.shards[1].doc_count();
        let shard2_count = builder.shards[2].doc_count();

        // Each shard should have exactly 3 documents with perfect round-robin
        assert_eq!(shard0_count, 3, "Shard 0 should have 3 documents");
        assert_eq!(shard1_count, 3, "Shard 1 should have 3 documents");
        assert_eq!(shard2_count, 3, "Shard 2 should have 3 documents");
    }

    #[test]
    fn test_query_k_zero() {
        let mut builder = DistributedSearchBuilder::new(2);
        builder.add_document(1, create_test_vec(b"test document"));
        let search = builder.build();

        let query = create_test_vec(b"test");
        let (results, stats) = search.query(&query, 0).unwrap();

        // Should return empty results without querying any shards
        assert!(results.is_empty());
        assert_eq!(stats.shards_queried, 0);
    }

    #[test]
    #[should_panic(expected = "num_shards must be greater than 0")]
    fn test_builder_zero_shards_panics() {
        let _ = DistributedSearchBuilder::new(0);
    }

    #[test]
    fn test_shard_set_status() {
        let mut shard = Shard::new(ShardId(0));
        assert_eq!(shard.status, ShardStatus::Healthy);

        shard.set_status(ShardStatus::Degraded);
        assert_eq!(shard.status, ShardStatus::Degraded);
        assert!(shard.is_available());

        // Rebuilding shards are temporarily unavailable for queries.
        shard.set_status(ShardStatus::Rebuilding);
        assert_eq!(shard.status, ShardStatus::Rebuilding);
        assert!(!shard.is_available());
    }
}