// oxify_vector/distributed.rs

//! Distributed vector search with sharding and replication.
//!
//! This module provides distributed search capabilities for scaling to billions of vectors:
//!
//! - **Horizontal Sharding**: Split vectors across multiple shards
//! - **Consistent Hashing**: Load balancing for shard assignments
//! - **Query Routing**: Fan-out queries to every shard
//! - **Result Merging**: Combine and re-rank results from multiple shards
//! - **Replication**: Fault tolerance with configurable replica count
//!
//! ## Example
//!
//! ```rust
//! use oxify_vector::distributed::{DistributedIndex, ShardConfig, ConsistentHash};
//! use oxify_vector::{SearchConfig, DistanceMetric};
//! use std::collections::HashMap;
//!
//! # fn example() -> anyhow::Result<()> {
//! // Configure distributed index with 3 shards
//! let shard_config = ShardConfig::new(3, 2); // 3 shards, 2 replicas
//! let search_config = SearchConfig::default();
//! let mut index = DistributedIndex::new(shard_config, search_config);
//!
//! // Build index with automatic sharding
//! let mut embeddings = HashMap::new();
//! embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
//! embeddings.insert("doc2".to_string(), vec![0.2, 0.3, 0.4]);
//! index.build(&embeddings)?;
//!
//! // Search across all shards
//! let query = vec![0.15, 0.25, 0.35];
//! let results = index.search(&query, 5)?;
//! # Ok(())
//! # }
//! ```

use crate::filter::{Filter, Metadata};
use crate::search::VectorSearchIndex;
use crate::types::{SearchConfig, SearchResult};
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};

46/// Configuration for distributed sharding.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ShardConfig {
49    /// Number of shards to split data across
50    pub num_shards: usize,
51    /// Number of replicas for fault tolerance
52    pub num_replicas: usize,
53    /// Number of virtual nodes per physical shard (for consistent hashing)
54    pub virtual_nodes: usize,
55}
56
57impl ShardConfig {
58    /// Create a new shard configuration.
59    ///
60    /// # Arguments
61    /// * `num_shards` - Number of shards (must be >= 1)
62    /// * `num_replicas` - Number of replicas (must be >= 1)
63    ///
64    /// # Example
65    /// ```
66    /// use oxify_vector::distributed::ShardConfig;
67    /// let config = ShardConfig::new(3, 2); // 3 shards, 2 replicas
68    /// ```
69    pub fn new(num_shards: usize, num_replicas: usize) -> Self {
70        assert!(num_shards >= 1, "num_shards must be at least 1");
71        assert!(num_replicas >= 1, "num_replicas must be at least 1");
72        Self {
73            num_shards,
74            num_replicas,
75            virtual_nodes: 150, // Default: 150 virtual nodes per shard
76        }
77    }
78
79    /// Set the number of virtual nodes for consistent hashing.
80    pub fn with_virtual_nodes(mut self, virtual_nodes: usize) -> Self {
81        self.virtual_nodes = virtual_nodes;
82        self
83    }
84}
85
86impl Default for ShardConfig {
87    fn default() -> Self {
88        Self::new(1, 1) // Single shard, single replica by default
89    }
90}
91
/// Consistent hashing for load balancing across shards.
///
/// Every physical shard is placed on a hash ring `virtual_nodes` times; more
/// virtual nodes spread keys more evenly. A key belongs to the first ring
/// position at or after the key's own hash, wrapping around to the start.
#[derive(Debug)]
pub struct ConsistentHash {
    /// Ring positions (hash values) mapped to the shard ID that owns them.
    ring: BTreeMap<u64, usize>,
    /// Number of virtual nodes per shard.
    // Recorded at construction and visible through `Debug`; no method reads it.
    #[allow(dead_code)]
    virtual_nodes: usize,
}

impl ConsistentHash {
    /// Create a new consistent hash ring.
    ///
    /// # Arguments
    /// * `num_shards` - Number of physical shards
    /// * `virtual_nodes` - Virtual nodes per shard (more = better distribution)
    ///
    /// If either argument is `0` the ring is empty, and every lookup falls back
    /// to shard `0`.
    pub fn new(num_shards: usize, virtual_nodes: usize) -> Self {
        let mut ring = BTreeMap::new();

        for shard_id in 0..num_shards {
            for vnode in 0..virtual_nodes {
                let key = format!("shard-{}-vnode-{}", shard_id, vnode);
                // In the (astronomically unlikely) case of a hash collision the
                // later virtual node overwrites the earlier one.
                ring.insert(Self::hash_key(&key), shard_id);
            }
        }

        Self {
            ring,
            virtual_nodes,
        }
    }

    /// Get the primary shard ID for `key`.
    ///
    /// Returns `0` when the ring is empty.
    pub fn get_shard(&self, key: &str) -> usize {
        let hash = Self::hash_key(key);
        self.ring
            .range(hash..)
            .next()
            // Past the last ring position: wrap around to the first one.
            .or_else(|| self.ring.iter().next())
            .map_or(0, |(_, &shard_id)| shard_id)
    }

    /// Get up to `num_replicas` distinct shard IDs for `key`, primary first.
    ///
    /// Shards are collected by walking the ring clockwise from the key's
    /// position, so the first entry always equals [`get_shard`](Self::get_shard).
    /// Fewer IDs are returned when the ring holds fewer distinct shards.
    ///
    /// `num_replicas == 0` yields an empty vector. Otherwise an empty ring
    /// yields `vec![0]`.
    pub fn get_replicas(&self, key: &str, num_replicas: usize) -> Vec<usize> {
        if num_replicas == 0 {
            return Vec::new();
        }
        if self.ring.is_empty() {
            return vec![0];
        }

        let hash = Self::hash_key(key);
        let mut seen = HashSet::new();
        let mut replicas = Vec::new();

        // Exactly one lap around the ring: from the key's position to the end,
        // then from the start back up to (but excluding) the key's position.
        for (_, &shard_id) in self.ring.range(hash..).chain(self.ring.range(..hash)) {
            if seen.insert(shard_id) {
                replicas.push(shard_id);
                if replicas.len() == num_replicas {
                    break;
                }
            }
        }

        replicas
    }

    /// Hash a key onto the ring using the standard library's `DefaultHasher`.
    ///
    /// `DefaultHasher::new()` uses fixed keys, so hashes are deterministic for
    /// a given build. However, the standard library does not specify the
    /// algorithm and may change it between Rust releases. Processes compiled
    /// with different toolchains may therefore disagree on shard placement.
    // NOTE(review): switch to a hasher with a specified, stable algorithm before
    // persisting shard assignments or sharing them across separately built binaries.
    fn hash_key(key: &str) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }
}

186/// A single shard containing a vector search index.
187#[derive(Debug)]
188struct Shard {
189    /// Shard ID
190    #[allow(dead_code)]
191    id: usize,
192    /// Vector search index for this shard
193    index: VectorSearchIndex,
194    /// Number of vectors in this shard
195    size: usize,
196}
197
198impl Shard {
199    fn new(id: usize, config: SearchConfig) -> Self {
200        Self {
201            id,
202            index: VectorSearchIndex::new(config),
203            size: 0,
204        }
205    }
206
207    fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
208        // Only build if there are embeddings (allow empty shards)
209        if !embeddings.is_empty() {
210            self.index.build(embeddings)?;
211            self.size = embeddings.len();
212        }
213        Ok(())
214    }
215
216    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
217        // Return empty results if shard is empty
218        if self.size == 0 {
219            return Ok(Vec::new());
220        }
221        self.index.search(query, k)
222    }
223
224    fn filtered_search(
225        &self,
226        query: &[f32],
227        k: usize,
228        filter: &Filter,
229    ) -> Result<Vec<SearchResult>> {
230        if self.size == 0 {
231            return Ok(Vec::new());
232        }
233        self.index.filtered_search(query, k, filter)
234    }
235
236    fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
237        self.index.set_metadata(entity_id, metadata);
238    }
239
240    fn get_metadata(&self, entity_id: &str) -> Option<&Metadata> {
241        self.index.get_metadata(entity_id)
242    }
243}
244
245/// Distributed vector search index with sharding and replication.
246///
247/// Splits vectors across multiple shards for horizontal scaling.
248/// Supports replication for fault tolerance.
249pub struct DistributedIndex {
250    /// Shard configuration
251    shard_config: ShardConfig,
252    /// Search configuration for each shard
253    #[allow(dead_code)]
254    search_config: SearchConfig,
255    /// Shards (shard_id -> Shard)
256    shards: Vec<Arc<RwLock<Shard>>>,
257    /// Consistent hash ring for shard assignment
258    hash_ring: ConsistentHash,
259    /// Total number of vectors across all shards
260    total_size: Arc<RwLock<usize>>,
261}
262
263impl DistributedIndex {
264    /// Create a new distributed index.
265    pub fn new(shard_config: ShardConfig, search_config: SearchConfig) -> Self {
266        let hash_ring = ConsistentHash::new(shard_config.num_shards, shard_config.virtual_nodes);
267
268        let mut shards = Vec::new();
269        for i in 0..shard_config.num_shards {
270            let shard = Shard::new(i, search_config.clone());
271            shards.push(Arc::new(RwLock::new(shard)));
272        }
273
274        Self {
275            shard_config,
276            search_config,
277            shards,
278            hash_ring,
279            total_size: Arc::new(RwLock::new(0)),
280        }
281    }
282
283    /// Build the distributed index from embeddings.
284    ///
285    /// Automatically distributes vectors across shards using consistent hashing.
286    pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
287        // Partition embeddings by shard
288        let mut shard_embeddings: Vec<HashMap<String, Vec<f32>>> =
289            vec![HashMap::new(); self.shard_config.num_shards];
290
291        for (entity_id, embedding) in embeddings {
292            let shard_id = self.hash_ring.get_shard(entity_id);
293            shard_embeddings[shard_id].insert(entity_id.clone(), embedding.clone());
294        }
295
296        // Build each shard in parallel
297        #[cfg(feature = "parallel")]
298        {
299            use rayon::prelude::*;
300            self.shards
301                .par_iter()
302                .zip(shard_embeddings.par_iter())
303                .try_for_each(|(shard, embs)| -> Result<()> {
304                    let mut shard = shard.write().map_err(|e| anyhow!("Lock error: {}", e))?;
305                    shard.build(embs)?;
306                    Ok(())
307                })?;
308        }
309
310        #[cfg(not(feature = "parallel"))]
311        {
312            for (shard, embs) in self.shards.iter().zip(shard_embeddings.iter()) {
313                let mut shard = shard.write().map_err(|e| anyhow!("Lock error: {}", e))?;
314                shard.build(embs)?;
315            }
316        }
317
318        // Update total size
319        let mut total = self
320            .total_size
321            .write()
322            .map_err(|e| anyhow!("Lock error: {}", e))?;
323        *total = embeddings.len();
324
325        Ok(())
326    }
327
328    /// Search across all shards and merge results.
329    ///
330    /// Performs a fan-out search to all shards in parallel, then merges and re-ranks results.
331    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
332        // Search all shards in parallel
333        #[cfg(feature = "parallel")]
334        let shard_results = {
335            use rayon::prelude::*;
336            self.shards
337                .par_iter()
338                .map(|shard| -> Result<Vec<SearchResult>> {
339                    let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
340                    shard.search(query, k)
341                })
342                .collect::<Result<Vec<Vec<SearchResult>>>>()?
343        };
344
345        #[cfg(not(feature = "parallel"))]
346        let shard_results = {
347            let mut results = Vec::new();
348            for shard in &self.shards {
349                let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
350                let result = shard.search(query, k)?;
351                results.push(result);
352            }
353            results
354        };
355
356        // Merge and re-rank results from all shards
357        let merged = Self::merge_results(shard_results, k);
358
359        Ok(merged)
360    }
361
362    /// Batch search for multiple queries across all shards.
363    ///
364    /// Performs multiple searches in parallel and returns results for each query.
365    pub fn batch_search(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<SearchResult>>> {
366        #[cfg(feature = "parallel")]
367        {
368            use rayon::prelude::*;
369            queries
370                .par_iter()
371                .map(|query| self.search(query, k))
372                .collect()
373        }
374
375        #[cfg(not(feature = "parallel"))]
376        {
377            queries.iter().map(|query| self.search(query, k)).collect()
378        }
379    }
380
381    /// Search with metadata filtering across all shards.
382    ///
383    /// Applies the filter to each shard before merging results.
384    pub fn filtered_search(
385        &self,
386        query: &[f32],
387        k: usize,
388        filter: &Filter,
389    ) -> Result<Vec<SearchResult>> {
390        // Search all shards in parallel
391        #[cfg(feature = "parallel")]
392        let shard_results = {
393            use rayon::prelude::*;
394            self.shards
395                .par_iter()
396                .map(|shard| -> Result<Vec<SearchResult>> {
397                    let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
398                    shard.filtered_search(query, k, filter)
399                })
400                .collect::<Result<Vec<Vec<SearchResult>>>>()?
401        };
402
403        #[cfg(not(feature = "parallel"))]
404        let shard_results = {
405            let mut results = Vec::new();
406            for shard in &self.shards {
407                let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
408                let result = shard.filtered_search(query, k, filter)?;
409                results.push(result);
410            }
411            results
412        };
413
414        // Merge and re-rank results from all shards
415        let merged = Self::merge_results(shard_results, k);
416
417        Ok(merged)
418    }
419
420    /// Set metadata for an entity across all replica shards.
421    pub fn set_metadata(&mut self, entity_id: &str, metadata: Metadata) {
422        let replica_shards = self
423            .hash_ring
424            .get_replicas(entity_id, self.shard_config.num_replicas);
425
426        for shard_id in replica_shards {
427            if let Ok(mut shard) = self.shards[shard_id].write() {
428                shard.set_metadata(entity_id, metadata.clone());
429            }
430        }
431    }
432
433    /// Get metadata for an entity from the primary shard.
434    pub fn get_metadata(&self, entity_id: &str) -> Option<Metadata> {
435        let shard_id = self.hash_ring.get_shard(entity_id);
436        if let Ok(shard) = self.shards[shard_id].read() {
437            shard.get_metadata(entity_id).cloned()
438        } else {
439            None
440        }
441    }
442
443    /// Set metadata for multiple entities in batch.
444    pub fn batch_set_metadata(&mut self, metadata_map: &HashMap<String, Metadata>) {
445        for (entity_id, metadata) in metadata_map {
446            self.set_metadata(entity_id, metadata.clone());
447        }
448    }
449
450    /// Get statistics about the distributed index.
451    pub fn get_stats(&self) -> Result<DistributedStats> {
452        let mut shard_sizes = Vec::new();
453        let mut total_vectors = 0;
454
455        for shard in &self.shards {
456            let shard = shard.read().map_err(|e| anyhow!("Lock error: {}", e))?;
457            shard_sizes.push(shard.size);
458            total_vectors += shard.size;
459        }
460
461        let avg_shard_size = if !shard_sizes.is_empty() {
462            shard_sizes.iter().sum::<usize>() as f64 / shard_sizes.len() as f64
463        } else {
464            0.0
465        };
466
467        let max_shard_size = shard_sizes.iter().copied().max().unwrap_or(0);
468        let min_shard_size = shard_sizes.iter().copied().min().unwrap_or(0);
469
470        Ok(DistributedStats {
471            num_shards: self.shard_config.num_shards,
472            num_replicas: self.shard_config.num_replicas,
473            total_vectors,
474            shard_sizes,
475            avg_shard_size,
476            max_shard_size,
477            min_shard_size,
478            balance_ratio: if max_shard_size > 0 {
479                min_shard_size as f64 / max_shard_size as f64
480            } else {
481                1.0
482            },
483        })
484    }
485
486    /// Merge results from multiple shards and return top-k.
487    ///
488    /// Uses a simple merge strategy based on scores.
489    fn merge_results(shard_results: Vec<Vec<SearchResult>>, k: usize) -> Vec<SearchResult> {
490        let mut all_results = Vec::new();
491
492        // Flatten all shard results
493        for results in shard_results {
494            all_results.extend(results);
495        }
496
497        // Sort by score (descending for cosine/dot, ascending for distance)
498        all_results.sort_by(|a, b| {
499            b.score
500                .partial_cmp(&a.score)
501                .unwrap_or(std::cmp::Ordering::Equal)
502        });
503
504        // Take top-k and deduplicate by entity_id
505        let mut seen = std::collections::HashSet::new();
506        let mut merged = Vec::new();
507
508        for result in all_results {
509            if !seen.contains(&result.entity_id) {
510                seen.insert(result.entity_id.clone());
511                merged.push(result);
512                if merged.len() >= k {
513                    break;
514                }
515            }
516        }
517
518        merged
519    }
520}
521
522/// Statistics for a distributed index.
523#[derive(Debug, Clone, Serialize, Deserialize)]
524pub struct DistributedStats {
525    /// Number of shards
526    pub num_shards: usize,
527    /// Number of replicas per vector
528    pub num_replicas: usize,
529    /// Total vectors across all shards (counting each unique vector once)
530    pub total_vectors: usize,
531    /// Size of each shard
532    pub shard_sizes: Vec<usize>,
533    /// Average shard size
534    pub avg_shard_size: f64,
535    /// Maximum shard size
536    pub max_shard_size: usize,
537    /// Minimum shard size
538    pub min_shard_size: usize,
539    /// Load balance ratio (min/max, closer to 1.0 is better)
540    pub balance_ratio: f64,
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::filter::FilterValue;

    /// Index with `num_shards` shards, `num_replicas` copies and the default search config.
    fn new_index(num_shards: usize, num_replicas: usize) -> DistributedIndex {
        DistributedIndex::new(
            ShardConfig::new(num_shards, num_replicas),
            SearchConfig::default(),
        )
    }

    /// Three orthogonal unit vectors keyed `doc1`..`doc3`.
    fn axis_embeddings() -> HashMap<String, Vec<f32>> {
        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![1.0, 0.0, 0.0]);
        embeddings.insert("doc2".to_string(), vec![0.0, 1.0, 0.0]);
        embeddings.insert("doc3".to_string(), vec![0.0, 0.0, 1.0]);
        embeddings
    }

    #[test]
    fn test_shard_config() {
        let config = ShardConfig::new(3, 2);
        assert_eq!(config.num_shards, 3);
        assert_eq!(config.num_replicas, 2);
        assert_eq!(config.virtual_nodes, 150);

        let config = config.with_virtual_nodes(200);
        assert_eq!(config.virtual_nodes, 200);
    }

    #[test]
    fn test_shard_config_default() {
        let config = ShardConfig::default();
        assert_eq!(config.num_shards, 1);
        assert_eq!(config.num_replicas, 1);
        assert_eq!(config.virtual_nodes, 150);
    }

    #[test]
    #[should_panic(expected = "num_shards must be at least 1")]
    fn test_shard_config_rejects_zero_shards() {
        let _ = ShardConfig::new(0, 1);
    }

    #[test]
    fn test_consistent_hash() {
        let hash = ConsistentHash::new(3, 10);

        // Same key should always map to same shard
        assert_eq!(hash.get_shard("doc1"), hash.get_shard("doc1"));

        // Different keys should distribute across shards
        let mut shard_counts = vec![0; 3];
        for i in 0..100 {
            let key = format!("doc{}", i);
            shard_counts[hash.get_shard(&key)] += 1;
        }

        // With 100 keys and 3 shards, expect ~33 per shard; allow ±15.
        for count in shard_counts {
            assert!(
                (18..=48).contains(&count),
                "Imbalanced distribution: {}",
                count
            );
        }
    }

    #[test]
    fn test_consistent_hash_replicas() {
        let hash = ConsistentHash::new(5, 10);

        let replicas = hash.get_replicas("doc1", 3);
        assert_eq!(replicas.len(), 3);

        // All replicas should be different shards
        let unique: HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);

        // The clockwise walk starts at the primary shard.
        assert_eq!(replicas[0], hash.get_shard("doc1"));
    }

    #[test]
    fn test_consistent_hash_replicas_capped_by_shard_count() {
        let hash = ConsistentHash::new(3, 10);
        let mut replicas = hash.get_replicas("doc1", 10);
        replicas.sort_unstable();
        assert_eq!(replicas, vec![0, 1, 2]);
    }

    #[test]
    fn test_distributed_index_creation() {
        let index = new_index(2, 1);
        assert_eq!(index.shards.len(), 2);
    }

    #[test]
    fn test_distributed_index_build() {
        let mut index = new_index(2, 1);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        embeddings.insert("doc2".to_string(), vec![0.2, 0.3, 0.4]);
        embeddings.insert("doc3".to_string(), vec![0.3, 0.4, 0.5]);

        assert!(index.build(&embeddings).is_ok());

        let stats = index.get_stats().unwrap();
        assert_eq!(stats.num_shards, 2);
        assert_eq!(stats.total_vectors, 3);
    }

    #[test]
    fn test_distributed_search() {
        let mut index = new_index(2, 1);
        index.build(&axis_embeddings()).unwrap();

        let query = vec![0.9, 0.1, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert!(!results.is_empty(), "expected at least one result");
        assert!(results.len() <= 2);
        assert_eq!(results[0].entity_id, "doc1");
    }

    #[test]
    fn test_distributed_stats() {
        let mut index = new_index(3, 1);

        let embeddings: HashMap<String, Vec<f32>> = (0..10)
            .map(|i| (format!("doc{}", i), vec![i as f32 * 0.1, 0.2, 0.3]))
            .collect();
        index.build(&embeddings).unwrap();

        let stats = index.get_stats().unwrap();
        assert_eq!(stats.num_shards, 3);
        assert_eq!(stats.num_replicas, 1);
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.shard_sizes.len(), 3);
        // With a single replica every vector lives on exactly one shard.
        assert_eq!(stats.shard_sizes.iter().sum::<usize>(), 10);
        // With only 10 vectors and 3 shards some shards may be empty.
        assert!(stats.balance_ratio >= 0.0 && stats.balance_ratio <= 1.0);
    }

    #[test]
    fn test_merge_results() {
        let shard1_results = vec![
            SearchResult {
                entity_id: "doc1".to_string(),
                score: 0.9,
                distance: 0.1,
                rank: 0,
            },
            SearchResult {
                entity_id: "doc2".to_string(),
                score: 0.7,
                distance: 0.3,
                rank: 1,
            },
        ];

        let shard2_results = vec![
            SearchResult {
                entity_id: "doc3".to_string(),
                score: 0.85,
                distance: 0.15,
                rank: 0,
            },
            SearchResult {
                entity_id: "doc4".to_string(),
                score: 0.6,
                distance: 0.4,
                rank: 1,
            },
        ];

        let merged = DistributedIndex::merge_results(vec![shard1_results, shard2_results], 3);

        assert_eq!(merged.len(), 3);
        assert_eq!(merged[0].entity_id, "doc1"); // 0.9
        assert_eq!(merged[1].entity_id, "doc3"); // 0.85
        assert_eq!(merged[2].entity_id, "doc2"); // 0.7
    }

    #[test]
    fn test_merge_results_deduplication() {
        let shard1_results = vec![SearchResult {
            entity_id: "doc1".to_string(),
            score: 0.9,
            distance: 0.1,
            rank: 0,
        }];

        let shard2_results = vec![SearchResult {
            entity_id: "doc1".to_string(),
            score: 0.85,
            distance: 0.15,
            rank: 0,
        }];

        let merged = DistributedIndex::merge_results(vec![shard1_results, shard2_results], 5);

        // Should only have one copy of doc1, and the higher score wins.
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].entity_id, "doc1");
        assert_eq!(merged[0].score, 0.9);
    }

    #[test]
    fn test_distributed_replication() {
        let mut index = new_index(3, 2); // 3 shards, 2 replicas

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        index.build(&embeddings).unwrap();

        // However many shards hold a copy, the vector is reported once.
        let query = vec![0.1, 0.2, 0.3];
        let results = index.search(&query, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_id, "doc1");

        let stats = index.get_stats().unwrap();
        assert_eq!(stats.total_vectors, 1);
    }

    #[test]
    fn test_distributed_batch_search() {
        let mut index = new_index(2, 1);
        index.build(&axis_embeddings()).unwrap();

        let queries = vec![
            vec![0.9, 0.1, 0.0],
            vec![0.0, 0.9, 0.1],
            vec![0.0, 0.0, 0.9],
        ];

        let results = index.batch_search(&queries, 1).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0][0].entity_id, "doc1");
        assert_eq!(results[1][0].entity_id, "doc2");
        assert_eq!(results[2][0].entity_id, "doc3");
    }

    #[test]
    fn test_distributed_filtered_search() {
        let mut index = new_index(2, 1);
        index.build(&axis_embeddings()).unwrap();

        let mut metadata1 = HashMap::new();
        metadata1.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );
        index.set_metadata("doc1", metadata1);

        let mut metadata2 = HashMap::new();
        metadata2.insert("type".to_string(), FilterValue::String("book".to_string()));
        index.set_metadata("doc2", metadata2);

        let mut metadata3 = HashMap::new();
        metadata3.insert(
            "type".to_string(),
            FilterValue::String("article".to_string()),
        );
        index.set_metadata("doc3", metadata3);

        // Filter for articles only
        let filter = Filter::new().eq("type", "article");

        let query = vec![0.5, 0.5, 0.5];
        let results = index.filtered_search(&query, 10, &filter).unwrap();

        // Should only return doc1 and doc3 (articles)
        assert!(results.len() <= 2);
        for result in &results {
            assert!(result.entity_id == "doc1" || result.entity_id == "doc3");
        }
    }

    #[test]
    fn test_distributed_metadata() {
        let mut index = new_index(2, 1);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        index.build(&embeddings).unwrap();

        let mut metadata = HashMap::new();
        metadata.insert("year".to_string(), FilterValue::Int(2026));
        index.set_metadata("doc1", metadata.clone());

        let retrieved = index
            .get_metadata("doc1")
            .expect("metadata for doc1 should be stored on its primary shard");
        assert_eq!(retrieved.get("year"), Some(&FilterValue::Int(2026)));
    }

    #[test]
    fn test_distributed_batch_metadata() {
        let mut index = new_index(2, 1);

        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
        embeddings.insert("doc2".to_string(), vec![0.2, 0.3, 0.4]);
        index.build(&embeddings).unwrap();

        let mut metadata_map = HashMap::new();

        let mut m1 = HashMap::new();
        m1.insert("year".to_string(), FilterValue::Int(2026));
        metadata_map.insert("doc1".to_string(), m1);

        let mut m2 = HashMap::new();
        m2.insert("year".to_string(), FilterValue::Int(2023));
        metadata_map.insert("doc2".to_string(), m2);

        index.batch_set_metadata(&metadata_map);

        assert!(index.get_metadata("doc1").is_some());
        assert!(index.get_metadata("doc2").is_some());
    }
}