1use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::time::Instant;
22
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
29pub struct RagResult {
30 pub text: String,
32 pub score: f64,
34 pub source: String,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct FederationNode {
41 pub id: String,
43 pub endpoint: String,
45 pub capabilities: Vec<String>,
47 pub latency_ms: u64,
49 pub is_healthy: bool,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
55pub enum FederationStrategy {
56 BroadcastAll,
58 RouteByCoverage,
60 LoadBalance,
62 FailoverChain,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct FederationRouter {
69 pub strategy: FederationStrategy,
70 pub health_check_interval_ms: u64,
71}
72
73impl FederationRouter {
74 pub fn new(strategy: FederationStrategy) -> Self {
76 Self {
77 strategy,
78 health_check_interval_ms: 30_000,
79 }
80 }
81
82 pub fn select_nodes<'a>(
84 &self,
85 nodes: &'a [FederationNode],
86 query: &FederatedQuery,
87 counter: &mut u64,
88 ) -> Vec<&'a FederationNode> {
89 let healthy: Vec<&FederationNode> = nodes.iter().filter(|n| n.is_healthy).collect();
90
91 match &self.strategy {
92 FederationStrategy::BroadcastAll => healthy,
93
94 FederationStrategy::RouteByCoverage => {
95 if query.timestamp.is_some() {
97 let temporal: Vec<_> = healthy
98 .iter()
99 .copied()
100 .filter(|n| n.capabilities.iter().any(|c| c == "temporal"))
101 .collect();
102 if !temporal.is_empty() {
103 return temporal;
104 }
105 }
106 healthy
107 }
108
109 FederationStrategy::LoadBalance => {
110 if healthy.is_empty() {
111 return vec![];
112 }
113 let idx = (*counter as usize) % healthy.len();
115 *counter = counter.wrapping_add(1);
116 vec![healthy[idx]]
117 }
118
119 FederationStrategy::FailoverChain => {
120 let mut sorted = healthy.clone();
122 sorted.sort_by_key(|n| n.latency_ms);
123 sorted.into_iter().take(1).collect()
124 }
125 }
126 }
127}
128
/// In-memory keyword-matching retrieval engine.
#[derive(Default, Debug)]
pub struct LocalRagEngine {
    /// Stored passages as `(text, base_score)` pairs, in insertion order.
    corpus: Vec<(String, f64)>,
}
141
142impl LocalRagEngine {
143 pub fn new() -> Self {
144 Self::default()
145 }
146
147 pub fn add_passage(&mut self, text: impl Into<String>, base_score: f64) {
149 self.corpus.push((text.into(), base_score.clamp(0.0, 1.0)));
150 }
151
152 pub fn query(&self, q: &str, top_k: usize, source: &str) -> Vec<RagResult> {
154 let keywords: Vec<&str> = q.split_whitespace().collect();
155 let mut scored: Vec<RagResult> = self
156 .corpus
157 .iter()
158 .filter_map(|(text, base)| {
159 let matched = keywords
160 .iter()
161 .filter(|kw| text.to_lowercase().contains(&kw.to_lowercase()))
162 .count();
163 if matched == 0 {
164 return None;
165 }
166 let kw_score = matched as f64 / keywords.len().max(1) as f64;
167 Some(RagResult {
168 text: text.clone(),
169 score: (base + kw_score) / 2.0,
170 source: source.to_string(),
171 })
172 })
173 .collect();
174
175 scored.sort_by(|a, b| {
176 b.score
177 .partial_cmp(&a.score)
178 .unwrap_or(std::cmp::Ordering::Equal)
179 });
180 scored.truncate(top_k);
181 scored
182 }
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct FederatedQuery {
192 pub query: String,
194 pub timestamp: Option<i64>,
196 pub top_k: usize,
198 pub timeout_ms: u64,
200}
201
202#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct FederatedResult {
205 pub results: Vec<RagResult>,
207 pub sources: Vec<String>,
209 pub total_latency_ms: u64,
211 pub node_count: usize,
213}
214
215pub struct FederatedGraphRag {
225 nodes: Vec<FederationNode>,
226 local_rag: LocalRagEngine,
227 router: FederationRouter,
228 lb_counter: u64,
230}
231
232impl FederatedGraphRag {
233 pub fn new(strategy: FederationStrategy) -> Self {
235 Self {
236 nodes: Vec::new(),
237 local_rag: LocalRagEngine::new(),
238 router: FederationRouter::new(strategy),
239 lb_counter: 0,
240 }
241 }
242
243 pub fn add_node(&mut self, node: FederationNode) {
245 self.nodes.push(node);
246 }
247
248 pub fn remove_node(&mut self, node_id: &str) -> bool {
250 let before = self.nodes.len();
251 self.nodes.retain(|n| n.id != node_id);
252 self.nodes.len() < before
253 }
254
255 pub fn query(&mut self, q: &FederatedQuery) -> FederatedResult {
257 let start = Instant::now();
258
259 let selected: Vec<String> = self
260 .router
261 .select_nodes(&self.nodes, q, &mut self.lb_counter)
262 .iter()
263 .map(|n| n.id.clone())
264 .collect();
265
266 let mut all_results: Vec<RagResult> = Vec::new();
267 let mut sources: Vec<String> = Vec::new();
268
269 for node_id in &selected {
271 let node_results = self.local_rag.query(&q.query, q.top_k, node_id);
272 if !node_results.is_empty() {
273 sources.push(node_id.clone());
274 all_results.extend(node_results);
275 }
276 }
277
278 let mut seen: HashMap<String, usize> = HashMap::new();
280 let mut merged: Vec<RagResult> = Vec::new();
281 for r in all_results {
282 match seen.get(&r.text) {
283 Some(&idx) if merged[idx].score >= r.score => {}
284 _ => {
285 let idx = merged.len();
286 seen.insert(r.text.clone(), idx);
287 merged.push(r);
288 }
289 }
290 }
291
292 merged.sort_by(|a, b| {
293 b.score
294 .partial_cmp(&a.score)
295 .unwrap_or(std::cmp::Ordering::Equal)
296 });
297 merged.truncate(q.top_k);
298
299 FederatedResult {
300 results: merged,
301 sources,
302 total_latency_ms: start.elapsed().as_millis() as u64,
303 node_count: selected.len(),
304 }
305 }
306
307 pub fn healthy_nodes(&self) -> Vec<&FederationNode> {
309 self.nodes.iter().filter(|n| n.is_healthy).collect()
310 }
311
312 pub fn mark_unhealthy(&mut self, node_id: &str) {
314 if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
315 node.is_healthy = false;
316 }
317 }
318
319 pub fn rebalance(&mut self) {
321 for node in &mut self.nodes {
322 node.is_healthy = true;
323 }
324 }
325
326 pub fn add_corpus_passage(&mut self, text: impl Into<String>, base_score: f64) {
329 self.local_rag.add_passage(text, base_score);
330 }
331
332 pub fn node_count(&self) -> usize {
334 self.nodes.len()
335 }
336}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
344pub struct LocalIndex {
345 pub node_id: String,
346 pub entries: Vec<(String, f64)>,
348}
349
350#[derive(Debug, Clone, Serialize, Deserialize)]
352pub struct MergedIndex {
353 pub entries: Vec<(String, f64, String)>,
355}
356
357#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct IndexShard {
360 pub shard_id: usize,
361 pub entries: Vec<(String, f64, String)>,
363}
364
/// Stateless namespace for the index merge and sharding helpers.
///
/// The common traits are derived so the type can be printed, copied and
/// default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct FederatedIndexBuilder;
367
368impl FederatedIndexBuilder {
369 pub fn merge_indices(indices: Vec<LocalIndex>) -> MergedIndex {
373 let mut best: HashMap<String, (f64, String)> = HashMap::new();
374
375 for local in indices {
376 for (key, score) in local.entries {
377 let entry = best
378 .entry(key.clone())
379 .or_insert((f64::NEG_INFINITY, local.node_id.clone()));
380 if score > entry.0 {
381 *entry = (score, local.node_id.clone());
382 }
383 }
384 }
385
386 let mut entries: Vec<(String, f64, String)> =
387 best.into_iter().map(|(k, (s, n))| (k, s, n)).collect();
388
389 entries.sort_by(|(ka, sa, _), (kb, sb, _)| {
391 sb.partial_cmp(sa)
392 .unwrap_or(std::cmp::Ordering::Equal)
393 .then_with(|| ka.cmp(kb))
394 });
395
396 MergedIndex { entries }
397 }
398
399 pub fn shard_index(index: &MergedIndex, shard_count: usize) -> Vec<IndexShard> {
401 if shard_count == 0 {
402 return vec![];
403 }
404
405 let mut shards: Vec<IndexShard> = (0..shard_count)
406 .map(|id| IndexShard {
407 shard_id: id,
408 entries: Vec::new(),
409 })
410 .collect();
411
412 for (i, entry) in index.entries.iter().enumerate() {
413 shards[i % shard_count].entries.push(entry.clone());
414 }
415
416 shards
417 }
418}
419
/// Unit tests for routing, local retrieval, federated querying, and the index
/// merge/shard helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// A healthy node with only the `"vector"` capability.
    fn healthy_node(id: &str, latency: u64) -> FederationNode {
        FederationNode {
            id: id.to_string(),
            endpoint: format!("http://{id}.example.com"),
            capabilities: vec!["vector".to_string()],
            latency_ms: latency,
            is_healthy: true,
        }
    }

    /// A healthy node that also advertises the `"temporal"` capability.
    fn temporal_node(id: &str) -> FederationNode {
        FederationNode {
            id: id.to_string(),
            endpoint: format!("http://{id}.example.com"),
            capabilities: vec!["temporal".to_string(), "vector".to_string()],
            latency_ms: 10,
            is_healthy: true,
        }
    }

    /// A query without a timestamp, asking for at most five results.
    fn make_query(q: &str) -> FederatedQuery {
        FederatedQuery {
            query: q.to_string(),
            timestamp: None,
            top_k: 5,
            timeout_ms: 1000,
        }
    }

    // ---- node bookkeeping ----

    #[test]
    fn test_federation_node_fields() {
        let node = healthy_node("node1", 50);
        assert_eq!(node.id, "node1");
        assert!(node.is_healthy);
        assert_eq!(node.latency_ms, 50);
    }

    #[test]
    fn test_add_and_remove_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 20));
        assert_eq!(fed.node_count(), 2);

        let removed = fed.remove_node("A");
        assert!(removed);
        assert_eq!(fed.node_count(), 1);
    }

    #[test]
    fn test_remove_nonexistent_node_returns_false() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        assert!(!fed.remove_node("ghost"));
    }

    // ---- health tracking ----

    #[test]
    fn test_healthy_nodes_filters_unhealthy() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 10));
        fed.mark_unhealthy("A");
        assert_eq!(fed.healthy_nodes().len(), 1);
        assert_eq!(fed.healthy_nodes()[0].id, "B");
    }

    #[test]
    fn test_healthy_nodes_empty_federation() {
        let fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        assert!(fed.healthy_nodes().is_empty());
    }

    #[test]
    fn test_mark_unhealthy_sets_flag() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.mark_unhealthy("A");
        assert!(!fed.nodes[0].is_healthy);
    }

    #[test]
    fn test_rebalance_restores_all_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 10));
        fed.mark_unhealthy("A");
        fed.mark_unhealthy("B");
        assert_eq!(fed.healthy_nodes().len(), 0);
        fed.rebalance();
        assert_eq!(fed.healthy_nodes().len(), 2);
    }

    // ---- federated querying ----

    #[test]
    fn test_query_broadcast_all_returns_merged_results() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 20));
        fed.add_corpus_passage("Rust is a systems language", 0.9);

        let result = fed.query(&make_query("Rust language"));
        assert_eq!(result.node_count, 2);
        assert_eq!(result.sources, vec!["A".to_string(), "B".to_string()]);
        // Both nodes return the same passage; it must appear only once.
        assert_eq!(result.results.len(), 1);
        assert_eq!(result.results[0].text, "Rust is a systems language");
    }

    #[test]
    fn test_query_with_no_healthy_nodes_returns_empty() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.mark_unhealthy("A");
        let result = fed.query(&make_query("anything"));
        assert!(result.results.is_empty());
        assert!(result.sources.is_empty());
        assert_eq!(result.node_count, 0);
    }

    #[test]
    fn test_failover_chain_picks_fastest_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::FailoverChain);
        fed.add_node(healthy_node("slow", 200));
        fed.add_node(healthy_node("fast", 10));
        fed.add_corpus_passage("Semantic Web SPARQL", 0.8);

        let result = fed.query(&make_query("Semantic Web"));
        assert_eq!(result.node_count, 1);
        assert_eq!(result.sources, vec!["fast".to_string()]);
    }

    #[test]
    fn test_route_by_coverage_uses_temporal_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::RouteByCoverage);
        fed.add_node(healthy_node("generic", 10));
        fed.add_node(temporal_node("temporal_node"));
        fed.add_corpus_passage("historical data", 0.85);

        let mut q = make_query("historical data");
        q.timestamp = Some(1_700_000_000_000);
        let result = fed.query(&q);
        // Only the temporal node may be consulted for a timestamped query.
        assert_eq!(result.node_count, 1);
        assert_eq!(result.sources, vec!["temporal_node".to_string()]);
    }

    #[test]
    fn test_load_balance_rotates_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::LoadBalance);
        fed.add_node(healthy_node("N1", 10));
        fed.add_node(healthy_node("N2", 10));
        fed.add_corpus_passage("GraphRAG federation", 0.9);

        let q = make_query("GraphRAG");
        let r1 = fed.query(&q);
        let r2 = fed.query(&q);
        let r3 = fed.query(&q);

        assert_eq!(r1.node_count, 1);
        assert_eq!(r2.node_count, 1);
        assert_eq!(r3.node_count, 1);
        // Consecutive queries must walk the node list and then wrap around.
        assert_eq!(r1.sources, vec!["N1".to_string()]);
        assert_eq!(r2.sources, vec!["N2".to_string()]);
        assert_eq!(r3.sources, vec!["N1".to_string()]);
    }

    #[test]
    fn test_federated_result_latency_non_negative() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        let result = fed.query(&make_query("test"));
        // The corpus is empty, so the node is selected but contributes nothing.
        assert_eq!(result.node_count, 1);
        assert!(result.results.is_empty());
        assert!(result.sources.is_empty());
        // The value is unsigned, so only a generous sanity bound is checked.
        assert!(result.total_latency_ms < 60_000);
    }

    // ---- local engine ----

    #[test]
    fn test_local_rag_returns_matching_passage() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("GraphRAG combines graph and retrieval", 0.8);
        eng.add_passage("Unrelated content here", 0.5);

        let results = eng.query("GraphRAG retrieval", 5, "local");
        assert_eq!(results.len(), 1);
        assert!(results[0].text.contains("GraphRAG"));
        assert_eq!(results[0].source, "local");
    }

    #[test]
    fn test_local_rag_top_k_limit() {
        let mut eng = LocalRagEngine::new();
        for i in 0..10 {
            eng.add_passage(format!("passage {i} keyword"), 0.5);
        }
        let results = eng.query("keyword", 3, "local");
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_local_rag_no_match_returns_empty() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("Completely unrelated text", 0.5);
        let results = eng.query("xyzzy", 5, "local");
        assert!(results.is_empty());
    }

    // ---- index merging ----

    #[test]
    fn test_merge_indices_picks_best_score() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![("key1".to_string(), 0.5), ("key2".to_string(), 0.9)],
        };
        let i2 = LocalIndex {
            node_id: "B".to_string(),
            entries: vec![("key1".to_string(), 0.8), ("key3".to_string(), 0.7)],
        };

        let merged = FederatedIndexBuilder::merge_indices(vec![i1, i2]);
        let key1 = merged
            .entries
            .iter()
            .find(|(k, _, _)| k == "key1")
            .expect("key1 is present in both inputs");
        assert!((key1.1 - 0.8).abs() < 1e-9);
        assert_eq!(key1.2, "B");
        assert_eq!(merged.entries.len(), 3);
    }

    #[test]
    fn test_merge_indices_sorted_descending() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![
                ("low".to_string(), 0.1),
                ("high".to_string(), 0.9),
                ("mid".to_string(), 0.5),
            ],
        };
        let merged = FederatedIndexBuilder::merge_indices(vec![i1]);
        let keys: Vec<&str> = merged.entries.iter().map(|e| e.0.as_str()).collect();
        assert_eq!(keys, vec!["high", "mid", "low"]);
    }

    #[test]
    fn test_merge_indices_empty_returns_empty() {
        let merged = FederatedIndexBuilder::merge_indices(vec![]);
        assert!(merged.entries.is_empty());
    }

    // ---- index sharding ----

    #[test]
    fn test_shard_index_creates_correct_shard_count() {
        let merged = MergedIndex {
            entries: (0..10)
                .map(|i| (format!("key{i}"), i as f64 * 0.1, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        assert_eq!(shards.len(), 3);
    }

    #[test]
    fn test_shard_index_all_entries_distributed() {
        let merged = MergedIndex {
            entries: (0..9)
                .map(|i| (format!("key{i}"), 0.5, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        let total: usize = shards.iter().map(|s| s.entries.len()).sum();
        assert_eq!(total, 9);
        // Nine entries over three shards must split evenly.
        assert!(shards.iter().all(|s| s.entries.len() == 3));
    }

    #[test]
    fn test_shard_index_zero_shards_returns_empty() {
        let merged = MergedIndex {
            entries: vec![("k".to_string(), 0.5, "A".to_string())],
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 0);
        assert!(shards.is_empty());
    }

    #[test]
    fn test_shard_index_ids_are_sequential() {
        let merged = MergedIndex {
            entries: (0..6)
                .map(|i| (format!("k{i}"), 0.5, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        for (expected, shard) in shards.iter().enumerate() {
            assert_eq!(shard.shard_id, expected);
        }
    }
}