Skip to main content

oxirs_graphrag/federation/
distributed.rs

1//! Distributed federation: route GraphRAG queries to multiple remote nodes,
2//! merge results, and manage health state.
3//!
4//! # Architecture
5//!
6//! ```text
7//! FederatedGraphRag
8//!     │
9//!     ├── FederationRouter  (strategy + health-check policy)
10//!     ├── LocalRagEngine    (fallback / self-hosted stub)
11//!     └── Vec<FederationNode>  (remote peers)
12//!
13//! query(FederatedQuery)
14//!     → select nodes via FederationStrategy
15//!     → dispatch (simulated sync) to each node
16//!     → merge Vec<RagResult> into FederatedResult
17//! ```
18
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::time::Instant;
22
23// ─────────────────────────────────────────────────────────────────────────────
24// Core types
25// ─────────────────────────────────────────────────────────────────────────────
26
27/// A single result item returned by a RAG node.
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
29pub struct RagResult {
30    /// Retrieved text passage or triple serialization
31    pub text: String,
32    /// Relevance score in [0, 1]
33    pub score: f64,
34    /// Identifier of the node that produced this result
35    pub source: String,
36}
37
38/// A remote (or local) RAG node in the federation.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct FederationNode {
41    /// Unique node identifier
42    pub id: String,
43    /// Network endpoint (URL or address)
44    pub endpoint: String,
45    /// Capabilities advertised by this node (e.g. "temporal", "vector")
46    pub capabilities: Vec<String>,
47    /// Observed round-trip latency in milliseconds
48    pub latency_ms: u64,
49    /// Whether the node is currently reachable
50    pub is_healthy: bool,
51}
52
53/// Strategy for selecting which nodes to query.
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
55pub enum FederationStrategy {
56    /// Send the query to every healthy node and merge results
57    BroadcastAll,
58    /// Route only to nodes whose `capabilities` overlap with query needs
59    RouteByCoverage,
60    /// Distribute queries evenly across healthy nodes
61    LoadBalance,
62    /// Try nodes in latency order; stop at the first successful response
63    FailoverChain,
64}
65
66/// Policy and health-check configuration for the router.
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct FederationRouter {
69    pub strategy: FederationStrategy,
70    pub health_check_interval_ms: u64,
71}
72
73impl FederationRouter {
74    /// Create a router with the given strategy.
75    pub fn new(strategy: FederationStrategy) -> Self {
76        Self {
77            strategy,
78            health_check_interval_ms: 30_000,
79        }
80    }
81
82    /// Select which nodes to query given the full node list and query context.
83    pub fn select_nodes<'a>(
84        &self,
85        nodes: &'a [FederationNode],
86        query: &FederatedQuery,
87        counter: &mut u64,
88    ) -> Vec<&'a FederationNode> {
89        let healthy: Vec<&FederationNode> = nodes.iter().filter(|n| n.is_healthy).collect();
90
91        match &self.strategy {
92            FederationStrategy::BroadcastAll => healthy,
93
94            FederationStrategy::RouteByCoverage => {
95                // Use nodes that advertise "temporal" capability when query has a timestamp
96                if query.timestamp.is_some() {
97                    let temporal: Vec<_> = healthy
98                        .iter()
99                        .copied()
100                        .filter(|n| n.capabilities.iter().any(|c| c == "temporal"))
101                        .collect();
102                    if !temporal.is_empty() {
103                        return temporal;
104                    }
105                }
106                healthy
107            }
108
109            FederationStrategy::LoadBalance => {
110                if healthy.is_empty() {
111                    return vec![];
112                }
113                // Round-robin: pick node at index (counter % healthy.len())
114                let idx = (*counter as usize) % healthy.len();
115                *counter = counter.wrapping_add(1);
116                vec![healthy[idx]]
117            }
118
119            FederationStrategy::FailoverChain => {
120                // Sort by latency, pick the fastest healthy node
121                let mut sorted = healthy.clone();
122                sorted.sort_by_key(|n| n.latency_ms);
123                sorted.into_iter().take(1).collect()
124            }
125        }
126    }
127}
128
129// ─────────────────────────────────────────────────────────────────────────────
130// Stub local RAG engine
131// ─────────────────────────────────────────────────────────────────────────────
132
/// In-memory, keyword-matching RAG engine used as a local fallback.
///
/// Holds a small corpus of passages, each paired with a base relevance score.
/// Queries return the passages that contain at least one query keyword, scored
/// by how many of the keywords they contain.
#[derive(Default, Debug)]
pub struct LocalRagEngine {
    /// Stored passages as `(text, base_score)` pairs, in insertion order.
    corpus: Vec<(String, f64)>,
}
141
142impl LocalRagEngine {
143    pub fn new() -> Self {
144        Self::default()
145    }
146
147    /// Add a passage to the local corpus with a pre-assigned relevance base score.
148    pub fn add_passage(&mut self, text: impl Into<String>, base_score: f64) {
149        self.corpus.push((text.into(), base_score.clamp(0.0, 1.0)));
150    }
151
152    /// Query the corpus and return up to `top_k` results.
153    pub fn query(&self, q: &str, top_k: usize, source: &str) -> Vec<RagResult> {
154        let keywords: Vec<&str> = q.split_whitespace().collect();
155        let mut scored: Vec<RagResult> = self
156            .corpus
157            .iter()
158            .filter_map(|(text, base)| {
159                let matched = keywords
160                    .iter()
161                    .filter(|kw| text.to_lowercase().contains(&kw.to_lowercase()))
162                    .count();
163                if matched == 0 {
164                    return None;
165                }
166                let kw_score = matched as f64 / keywords.len().max(1) as f64;
167                Some(RagResult {
168                    text: text.clone(),
169                    score: (base + kw_score) / 2.0,
170                    source: source.to_string(),
171                })
172            })
173            .collect();
174
175        scored.sort_by(|a, b| {
176            b.score
177                .partial_cmp(&a.score)
178                .unwrap_or(std::cmp::Ordering::Equal)
179        });
180        scored.truncate(top_k);
181        scored
182    }
183}
184
185// ─────────────────────────────────────────────────────────────────────────────
186// Query and result types
187// ─────────────────────────────────────────────────────────────────────────────
188
189/// A query to submit to the federation.
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct FederatedQuery {
192    /// Natural-language question or keyword query
193    pub query: String,
194    /// Optional point-in-time constraint (Unix-ms)
195    pub timestamp: Option<i64>,
196    /// Maximum number of results to return
197    pub top_k: usize,
198    /// Abort if no response within this many milliseconds (advisory)
199    pub timeout_ms: u64,
200}
201
202/// Aggregated results from the federation.
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct FederatedResult {
205    /// Merged and de-duplicated result list (sorted by score descending)
206    pub results: Vec<RagResult>,
207    /// IDs of the nodes that contributed results
208    pub sources: Vec<String>,
209    /// Observed total latency (wall-clock ms)
210    pub total_latency_ms: u64,
211    /// Number of nodes consulted
212    pub node_count: usize,
213}
214
215// ─────────────────────────────────────────────────────────────────────────────
216// FederatedGraphRag
217// ─────────────────────────────────────────────────────────────────────────────
218
219/// Multi-node GraphRAG federation manager.
220///
221/// In production this would issue async HTTP requests to remote endpoints;
222/// here remote nodes are simulated by the local engine (each node shares the
223/// same local corpus but is given a distinct source label).
224pub struct FederatedGraphRag {
225    nodes: Vec<FederationNode>,
226    local_rag: LocalRagEngine,
227    router: FederationRouter,
228    /// Round-robin counter used by LoadBalance strategy
229    lb_counter: u64,
230}
231
232impl FederatedGraphRag {
233    /// Create a new federation with the given routing strategy.
234    pub fn new(strategy: FederationStrategy) -> Self {
235        Self {
236            nodes: Vec::new(),
237            local_rag: LocalRagEngine::new(),
238            router: FederationRouter::new(strategy),
239            lb_counter: 0,
240        }
241    }
242
243    /// Add a remote node to the federation.
244    pub fn add_node(&mut self, node: FederationNode) {
245        self.nodes.push(node);
246    }
247
248    /// Remove a node by its ID.  Returns `true` if the node existed.
249    pub fn remove_node(&mut self, node_id: &str) -> bool {
250        let before = self.nodes.len();
251        self.nodes.retain(|n| n.id != node_id);
252        self.nodes.len() < before
253    }
254
255    /// Execute a federated query and return the merged result.
256    pub fn query(&mut self, q: &FederatedQuery) -> FederatedResult {
257        let start = Instant::now();
258
259        let selected: Vec<String> = self
260            .router
261            .select_nodes(&self.nodes, q, &mut self.lb_counter)
262            .iter()
263            .map(|n| n.id.clone())
264            .collect();
265
266        let mut all_results: Vec<RagResult> = Vec::new();
267        let mut sources: Vec<String> = Vec::new();
268
269        // Simulate per-node queries via the local engine
270        for node_id in &selected {
271            let node_results = self.local_rag.query(&q.query, q.top_k, node_id);
272            if !node_results.is_empty() {
273                sources.push(node_id.clone());
274                all_results.extend(node_results);
275            }
276        }
277
278        // Merge: de-duplicate by text, keep highest score
279        let mut seen: HashMap<String, usize> = HashMap::new();
280        let mut merged: Vec<RagResult> = Vec::new();
281        for r in all_results {
282            match seen.get(&r.text) {
283                Some(&idx) if merged[idx].score >= r.score => {}
284                _ => {
285                    let idx = merged.len();
286                    seen.insert(r.text.clone(), idx);
287                    merged.push(r);
288                }
289            }
290        }
291
292        merged.sort_by(|a, b| {
293            b.score
294                .partial_cmp(&a.score)
295                .unwrap_or(std::cmp::Ordering::Equal)
296        });
297        merged.truncate(q.top_k);
298
299        FederatedResult {
300            results: merged,
301            sources,
302            total_latency_ms: start.elapsed().as_millis() as u64,
303            node_count: selected.len(),
304        }
305    }
306
307    /// Return references to all currently healthy nodes.
308    pub fn healthy_nodes(&self) -> Vec<&FederationNode> {
309        self.nodes.iter().filter(|n| n.is_healthy).collect()
310    }
311
312    /// Mark a node as unhealthy (e.g. after a failed health check).
313    pub fn mark_unhealthy(&mut self, node_id: &str) {
314        if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
315            node.is_healthy = false;
316        }
317    }
318
319    /// Rebalance: restore all nodes to healthy (simulates health-check recovery).
320    pub fn rebalance(&mut self) {
321        for node in &mut self.nodes {
322            node.is_healthy = true;
323        }
324    }
325
326    /// Add a passage to the local corpus (used as backing store for all nodes
327    /// in the simulation).
328    pub fn add_corpus_passage(&mut self, text: impl Into<String>, base_score: f64) {
329        self.local_rag.add_passage(text, base_score);
330    }
331
332    /// Number of nodes in the federation (healthy + unhealthy).
333    pub fn node_count(&self) -> usize {
334        self.nodes.len()
335    }
336}
337
338// ─────────────────────────────────────────────────────────────────────────────
339// Index types and builder
340// ─────────────────────────────────────────────────────────────────────────────
341
342/// A per-node index fragment.
343#[derive(Debug, Clone, Serialize, Deserialize)]
344pub struct LocalIndex {
345    pub node_id: String,
346    /// (key, score) pairs
347    pub entries: Vec<(String, f64)>,
348}
349
350/// A merged index covering all nodes.
351#[derive(Debug, Clone, Serialize, Deserialize)]
352pub struct MergedIndex {
353    /// (key, score, originating_node_id)
354    pub entries: Vec<(String, f64, String)>,
355}
356
357/// A shard of a merged index for distribution.
358#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct IndexShard {
360    pub shard_id: usize,
361    /// (key, score, originating_node_id)
362    pub entries: Vec<(String, f64, String)>,
363}
364
365/// Utility for building and sharding federation indices.
366pub struct FederatedIndexBuilder;
367
368impl FederatedIndexBuilder {
369    /// Merge multiple per-node indices into a single sorted index.
370    ///
371    /// Duplicate keys are resolved by keeping the highest score across all nodes.
372    pub fn merge_indices(indices: Vec<LocalIndex>) -> MergedIndex {
373        let mut best: HashMap<String, (f64, String)> = HashMap::new();
374
375        for local in indices {
376            for (key, score) in local.entries {
377                let entry = best
378                    .entry(key.clone())
379                    .or_insert((f64::NEG_INFINITY, local.node_id.clone()));
380                if score > entry.0 {
381                    *entry = (score, local.node_id.clone());
382                }
383            }
384        }
385
386        let mut entries: Vec<(String, f64, String)> =
387            best.into_iter().map(|(k, (s, n))| (k, s, n)).collect();
388
389        // Sort by score descending, then key ascending for determinism
390        entries.sort_by(|(ka, sa, _), (kb, sb, _)| {
391            sb.partial_cmp(sa)
392                .unwrap_or(std::cmp::Ordering::Equal)
393                .then_with(|| ka.cmp(kb))
394        });
395
396        MergedIndex { entries }
397    }
398
399    /// Partition a merged index into `shard_count` roughly equal shards.
400    pub fn shard_index(index: &MergedIndex, shard_count: usize) -> Vec<IndexShard> {
401        if shard_count == 0 {
402            return vec![];
403        }
404
405        let mut shards: Vec<IndexShard> = (0..shard_count)
406            .map(|id| IndexShard {
407                shard_id: id,
408                entries: Vec::new(),
409            })
410            .collect();
411
412        for (i, entry) in index.entries.iter().enumerate() {
413            shards[i % shard_count].entries.push(entry.clone());
414        }
415
416        shards
417    }
418}
419
420// ─────────────────────────────────────────────────────────────────────────────
421// Tests
422// ─────────────────────────────────────────────────────────────────────────────
423
#[cfg(test)]
mod tests {
    use super::*;

    fn healthy_node(id: &str, latency: u64) -> FederationNode {
        FederationNode {
            id: id.to_string(),
            endpoint: format!("http://{id}.example.com"),
            capabilities: vec!["vector".to_string()],
            latency_ms: latency,
            is_healthy: true,
        }
    }

    fn temporal_node(id: &str) -> FederationNode {
        FederationNode {
            id: id.to_string(),
            endpoint: format!("http://{id}.example.com"),
            capabilities: vec!["temporal".to_string(), "vector".to_string()],
            latency_ms: 10,
            is_healthy: true,
        }
    }

    fn make_query(q: &str) -> FederatedQuery {
        FederatedQuery {
            query: q.to_string(),
            timestamp: None,
            top_k: 5,
            timeout_ms: 1000,
        }
    }

    fn ids(sources: &[String]) -> Vec<&str> {
        sources.iter().map(String::as_str).collect()
    }

    // ── FederationNode ────────────────────────────────────────────────────

    #[test]
    fn test_federation_node_fields() {
        let node = healthy_node("node1", 50);
        assert_eq!(node.id, "node1");
        assert!(node.is_healthy);
        assert_eq!(node.latency_ms, 50);
    }

    // ── FederatedGraphRag::add_node / remove_node ─────────────────────────

    #[test]
    fn test_add_and_remove_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 20));
        assert_eq!(fed.node_count(), 2);

        let removed = fed.remove_node("A");
        assert!(removed);
        assert_eq!(fed.node_count(), 1);
        assert_eq!(fed.nodes[0].id, "B");
    }

    #[test]
    fn test_remove_nonexistent_node_returns_false() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        assert!(!fed.remove_node("ghost"));
    }

    // ── healthy_nodes ─────────────────────────────────────────────────────

    #[test]
    fn test_healthy_nodes_filters_unhealthy() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 10));
        fed.mark_unhealthy("A");
        assert_eq!(fed.healthy_nodes().len(), 1);
        assert_eq!(fed.healthy_nodes()[0].id, "B");
    }

    #[test]
    fn test_healthy_nodes_empty_federation() {
        let fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        assert!(fed.healthy_nodes().is_empty());
    }

    // ── mark_unhealthy / rebalance ────────────────────────────────────────

    #[test]
    fn test_mark_unhealthy_sets_flag() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.mark_unhealthy("A");
        assert!(!fed.nodes[0].is_healthy);
    }

    #[test]
    fn test_mark_unhealthy_unknown_id_is_noop() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.mark_unhealthy("ghost");
        assert!(fed.nodes[0].is_healthy);
    }

    #[test]
    fn test_rebalance_restores_all_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 10));
        fed.mark_unhealthy("A");
        fed.mark_unhealthy("B");
        assert_eq!(fed.healthy_nodes().len(), 0);
        fed.rebalance();
        assert_eq!(fed.healthy_nodes().len(), 2);
    }

    // ── query: BroadcastAll ───────────────────────────────────────────────

    #[test]
    fn test_query_broadcast_all_returns_merged_results() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 20));
        fed.add_corpus_passage("Rust is a systems language", 0.9);

        let result = fed.query(&make_query("Rust language"));
        // Both nodes were consulted and both contributed.
        assert_eq!(result.node_count, 2);
        assert_eq!(ids(&result.sources), vec!["A", "B"]);
        // The shared passage appears once after de-duplication.
        assert_eq!(result.results.len(), 1);
        assert_eq!(result.results[0].text, "Rust is a systems language");
    }

    #[test]
    fn test_query_respects_top_k() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 20));
        for i in 0..10 {
            fed.add_corpus_passage(format!("passage {i} keyword"), 0.5);
        }
        let mut q = make_query("keyword");
        q.top_k = 3;
        let result = fed.query(&q);
        assert_eq!(result.results.len(), 3);
        for pair in result.results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_query_with_no_healthy_nodes_returns_empty() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.mark_unhealthy("A");
        let result = fed.query(&make_query("anything"));
        assert!(result.results.is_empty());
        assert!(result.sources.is_empty());
        assert_eq!(result.node_count, 0);
    }

    // ── query: FailoverChain ──────────────────────────────────────────────

    #[test]
    fn test_failover_chain_picks_fastest_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::FailoverChain);
        fed.add_node(healthy_node("slow", 200));
        fed.add_node(healthy_node("fast", 10));
        fed.add_corpus_passage("Semantic Web SPARQL", 0.8);

        let result = fed.query(&make_query("Semantic Web"));
        // Only one node consulted (fastest).
        assert_eq!(result.node_count, 1);
        assert_eq!(ids(&result.sources), vec!["fast"]);
    }

    #[test]
    fn test_failover_chain_skips_unhealthy_fastest_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::FailoverChain);
        fed.add_node(healthy_node("slow", 200));
        fed.add_node(healthy_node("fast", 10));
        fed.mark_unhealthy("fast");
        fed.add_corpus_passage("Semantic Web SPARQL", 0.8);

        let result = fed.query(&make_query("Semantic Web"));
        assert_eq!(result.node_count, 1);
        assert_eq!(ids(&result.sources), vec!["slow"]);
    }

    // ── query: RouteByCoverage ────────────────────────────────────────────

    #[test]
    fn test_route_by_coverage_uses_temporal_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::RouteByCoverage);
        fed.add_node(healthy_node("generic", 10));
        fed.add_node(temporal_node("temporal_node"));
        fed.add_corpus_passage("historical data", 0.85);

        let mut q = make_query("historical data");
        q.timestamp = Some(1_700_000_000_000); // some timestamp

        let result = fed.query(&q);
        // Only the temporal-capable node is routed to.
        assert_eq!(result.node_count, 1);
        assert_eq!(ids(&result.sources), vec!["temporal_node"]);
    }

    #[test]
    fn test_route_by_coverage_without_timestamp_uses_all_healthy_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::RouteByCoverage);
        fed.add_node(healthy_node("generic", 10));
        fed.add_node(temporal_node("temporal_node"));
        fed.add_corpus_passage("historical data", 0.85);

        let result = fed.query(&make_query("historical data"));
        assert_eq!(result.node_count, 2);
    }

    #[test]
    fn test_route_by_coverage_falls_back_without_temporal_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::RouteByCoverage);
        fed.add_node(healthy_node("g1", 10));
        fed.add_node(healthy_node("g2", 20));
        fed.add_corpus_passage("historical data", 0.85);

        let mut q = make_query("historical data");
        q.timestamp = Some(1_700_000_000_000);

        let result = fed.query(&q);
        assert_eq!(result.node_count, 2);
    }

    // ── query: LoadBalance ────────────────────────────────────────────────

    #[test]
    fn test_load_balance_rotates_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::LoadBalance);
        fed.add_node(healthy_node("N1", 10));
        fed.add_node(healthy_node("N2", 10));
        fed.add_corpus_passage("GraphRAG federation", 0.9);

        let q = make_query("GraphRAG");
        let r1 = fed.query(&q);
        let r2 = fed.query(&q);
        let r3 = fed.query(&q);

        assert_eq!(r1.node_count, 1);
        assert_eq!(r2.node_count, 1);
        // Round-robin: N1, N2, then back to N1.
        assert_eq!(ids(&r1.sources), vec!["N1"]);
        assert_eq!(ids(&r2.sources), vec!["N2"]);
        assert_eq!(ids(&r3.sources), vec!["N1"]);
    }

    #[test]
    fn test_load_balance_skips_unhealthy_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::LoadBalance);
        fed.add_node(healthy_node("N1", 10));
        fed.add_node(healthy_node("N2", 10));
        fed.add_node(healthy_node("N3", 10));
        fed.mark_unhealthy("N2");
        fed.add_corpus_passage("GraphRAG federation", 0.9);

        let q = make_query("GraphRAG");
        let r1 = fed.query(&q);
        let r2 = fed.query(&q);
        assert_eq!(ids(&r1.sources), vec!["N1"]);
        assert_eq!(ids(&r2.sources), vec!["N3"]);
    }

    #[test]
    fn test_load_balance_counter_not_advanced_without_healthy_nodes() {
        let router = FederationRouter::new(FederationStrategy::LoadBalance);
        let mut counter = 0u64;
        let picked = router.select_nodes(&[], &make_query("x"), &mut counter);
        assert!(picked.is_empty());
        assert_eq!(counter, 0);
    }

    // ── FederatedResult fields ────────────────────────────────────────────

    #[test]
    fn test_federated_result_latency_non_negative() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        let result = fed.query(&make_query("test"));
        // Latency is a u64, so it cannot be negative. Check that the rest of
        // the result is consistent for an empty corpus.
        assert_eq!(result.node_count, 1);
        assert!(result.results.is_empty());
        assert!(result.sources.is_empty());
        let _ = result.total_latency_ms;
    }

    // ── LocalRagEngine ────────────────────────────────────────────────────

    #[test]
    fn test_local_rag_returns_matching_passage() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("GraphRAG combines graph and retrieval", 0.8);
        eng.add_passage("Unrelated content here", 0.5);

        let results = eng.query("GraphRAG retrieval", 5, "local");
        assert!(!results.is_empty());
        assert!(results[0].text.contains("GraphRAG"));
    }

    #[test]
    fn test_local_rag_case_insensitive_score() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("GraphRAG combines graph and retrieval", 0.8);

        let results = eng.query("graphrag RETRIEVAL", 5, "local");
        assert_eq!(results.len(), 1);
        // Both keywords match: (0.8 + 1.0) / 2.
        assert!((results[0].score - 0.9).abs() < 1e-9);
        assert_eq!(results[0].source, "local");
    }

    #[test]
    fn test_local_rag_results_sorted_descending() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("alpha", 0.5);
        eng.add_passage("alpha beta", 0.5);

        let results = eng.query("alpha beta", 5, "local");
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].text, "alpha beta");
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_local_rag_clamps_base_score() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("x keyword", 5.0);
        let results = eng.query("keyword", 5, "local");
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_local_rag_empty_query_returns_empty() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("anything at all", 0.5);
        assert!(eng.query("   ", 5, "local").is_empty());
    }

    #[test]
    fn test_local_rag_top_k_limit() {
        let mut eng = LocalRagEngine::new();
        for i in 0..10 {
            eng.add_passage(format!("passage {i} keyword"), 0.5);
        }
        let results = eng.query("keyword", 3, "local");
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_local_rag_no_match_returns_empty() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("Completely unrelated text", 0.5);
        let results = eng.query("xyzzy", 5, "local");
        assert!(results.is_empty());
    }

    // ── FederatedIndexBuilder::merge_indices ──────────────────────────────

    #[test]
    fn test_merge_indices_picks_best_score() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![("key1".to_string(), 0.5), ("key2".to_string(), 0.9)],
        };
        let i2 = LocalIndex {
            node_id: "B".to_string(),
            entries: vec![("key1".to_string(), 0.8), ("key3".to_string(), 0.7)],
        };

        let merged = FederatedIndexBuilder::merge_indices(vec![i1, i2]);
        // key1: B wins (0.8 > 0.5)
        let key1 = merged
            .entries
            .iter()
            .find(|(k, _, _)| k == "key1")
            .expect("key1 should be present in the merged index");
        assert!((key1.1 - 0.8).abs() < 1e-9);
        assert_eq!(key1.2, "B");
        // key2 from A, key3 from B
        assert_eq!(merged.entries.len(), 3);
    }

    #[test]
    fn test_merge_indices_equal_scores_keep_first_node() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![("key1".to_string(), 0.5)],
        };
        let i2 = LocalIndex {
            node_id: "B".to_string(),
            entries: vec![("key1".to_string(), 0.5)],
        };
        let merged = FederatedIndexBuilder::merge_indices(vec![i1, i2]);
        assert_eq!(merged.entries.len(), 1);
        assert_eq!(merged.entries[0].2, "A");
    }

    #[test]
    fn test_merge_indices_sorted_descending() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![
                ("low".to_string(), 0.1),
                ("high".to_string(), 0.9),
                ("mid".to_string(), 0.5),
            ],
        };
        let merged = FederatedIndexBuilder::merge_indices(vec![i1]);
        for i in 1..merged.entries.len() {
            assert!(merged.entries[i - 1].1 >= merged.entries[i].1);
        }
    }

    #[test]
    fn test_merge_indices_breaks_score_ties_by_key() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![("b".to_string(), 0.5), ("a".to_string(), 0.5)],
        };
        let merged = FederatedIndexBuilder::merge_indices(vec![i1]);
        let keys: Vec<&str> = merged.entries.iter().map(|(k, _, _)| k.as_str()).collect();
        assert_eq!(keys, vec!["a", "b"]);
    }

    #[test]
    fn test_merge_indices_empty_returns_empty() {
        let merged = FederatedIndexBuilder::merge_indices(vec![]);
        assert!(merged.entries.is_empty());
    }

    // ── FederatedIndexBuilder::shard_index ───────────────────────────────

    #[test]
    fn test_shard_index_creates_correct_shard_count() {
        let merged = MergedIndex {
            entries: (0..10)
                .map(|i| (format!("key{i}"), i as f64 * 0.1, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        assert_eq!(shards.len(), 3);
    }

    #[test]
    fn test_shard_index_all_entries_distributed() {
        let merged = MergedIndex {
            entries: (0..9)
                .map(|i| (format!("key{i}"), 0.5, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        let total: usize = shards.iter().map(|s| s.entries.len()).sum();
        assert_eq!(total, 9);
    }

    #[test]
    fn test_shard_index_round_robin_assignment() {
        let merged = MergedIndex {
            entries: (0..5)
                .map(|i| (format!("k{i}"), 0.5, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 2);
        let keys = |s: &IndexShard| -> Vec<String> {
            s.entries.iter().map(|(k, _, _)| k.clone()).collect()
        };
        assert_eq!(keys(&shards[0]), vec!["k0", "k2", "k4"]);
        assert_eq!(keys(&shards[1]), vec!["k1", "k3"]);
    }

    #[test]
    fn test_shard_index_zero_shards_returns_empty() {
        let merged = MergedIndex {
            entries: vec![("k".to_string(), 0.5, "A".to_string())],
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 0);
        assert!(shards.is_empty());
    }

    #[test]
    fn test_shard_index_ids_are_sequential() {
        let merged = MergedIndex {
            entries: (0..6)
                .map(|i| (format!("k{i}"), 0.5, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        for (expected, shard) in shards.iter().enumerate() {
            assert_eq!(shard.shard_id, expected);
        }
    }
}