use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
/// One retrieved passage together with its relevance score and origin label.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RagResult {
    /// Text of the passage that matched the query.
    pub text: String,
    /// Relevance score; result lists in this file are sorted by it in
    /// descending order.
    pub score: f64,
    /// Label of the producer of this result. `FederatedGraphRag::query`
    /// passes the id of the answering node here.
    pub source: String,
}
/// Description of one member of the federation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationNode {
    /// Identifier used to look the node up, remove it, and label its results.
    pub id: String,
    /// Address of the node. Not read by any code in this file
    /// (presumably a URL; the tests use `http://…` strings).
    pub endpoint: String,
    /// Capability tags. `RouteByCoverage` looks for the exact tag `"temporal"`.
    pub capabilities: Vec<String>,
    /// Latency figure in milliseconds; `FailoverChain` picks the smallest one.
    pub latency_ms: u64,
    /// Unhealthy nodes are never returned by the router.
    pub is_healthy: bool,
}
/// How `FederationRouter::select_nodes` picks target nodes for a query.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FederationStrategy {
    /// Send the query to every healthy node.
    BroadcastAll,
    /// For queries carrying a timestamp, prefer healthy nodes tagged
    /// `"temporal"`; otherwise (or if none exist) use every healthy node.
    RouteByCoverage,
    /// Send each query to a single healthy node, rotating via a counter.
    LoadBalance,
    /// Send the query to the single healthy node with the lowest latency.
    FailoverChain,
}
/// Routing configuration: which strategy to apply when selecting nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationRouter {
    /// Strategy applied by `select_nodes`.
    pub strategy: FederationStrategy,
    /// Health-check interval in milliseconds. `new` sets it to 30 000;
    /// no code in this file reads it.
    pub health_check_interval_ms: u64,
}
impl FederationRouter {
    /// Creates a router that applies `strategy`, with the health-check
    /// interval set to 30 000 ms.
    pub fn new(strategy: FederationStrategy) -> Self {
        Self {
            strategy,
            health_check_interval_ms: 30_000,
        }
    }

    /// Chooses which of `nodes` should receive `query`.
    ///
    /// Only nodes with `is_healthy == true` are ever returned, in the order
    /// they appear in `nodes`:
    ///
    /// * `BroadcastAll` – every healthy node.
    /// * `RouteByCoverage` – if `query.timestamp` is set and at least one
    ///   healthy node has the `"temporal"` capability, only those nodes;
    ///   otherwise every healthy node.
    /// * `LoadBalance` – the single healthy node at position
    ///   `counter % healthy_count`; `counter` is then incremented (wrapping).
    ///   With no healthy nodes the result is empty and `counter` is untouched.
    /// * `FailoverChain` – the single healthy node with the smallest
    ///   `latency_ms`; on a tie the earliest node in `nodes` wins.
    pub fn select_nodes<'a>(
        &self,
        nodes: &'a [FederationNode],
        query: &FederatedQuery,
        counter: &mut u64,
    ) -> Vec<&'a FederationNode> {
        let healthy: Vec<&'a FederationNode> = nodes.iter().filter(|n| n.is_healthy).collect();
        match &self.strategy {
            FederationStrategy::BroadcastAll => healthy,
            FederationStrategy::RouteByCoverage => {
                if query.timestamp.is_some() {
                    let temporal: Vec<&'a FederationNode> = healthy
                        .iter()
                        .copied()
                        .filter(|n| n.capabilities.iter().any(|c| c == "temporal"))
                        .collect();
                    if !temporal.is_empty() {
                        return temporal;
                    }
                }
                healthy
            }
            FederationStrategy::LoadBalance => {
                if healthy.is_empty() {
                    return Vec::new();
                }
                // Reduce in u64 first: casting the counter to usize before
                // the modulo would truncate it on 32-bit targets. The
                // remainder is < len, so the final cast is lossless.
                let idx = (*counter % healthy.len() as u64) as usize;
                *counter = counter.wrapping_add(1);
                vec![healthy[idx]]
            }
            FederationStrategy::FailoverChain => {
                // `min_by_key` returns the first minimum, matching what a
                // stable sort followed by `take(1)` would yield, without
                // copying and sorting the whole list.
                healthy
                    .into_iter()
                    .min_by_key(|n| n.latency_ms)
                    .into_iter()
                    .collect()
            }
        }
    }
}
/// In-memory passage store with a simple keyword-overlap ranking.
#[derive(Debug, Default)]
pub struct LocalRagEngine {
    /// `(passage text, base score)` pairs in insertion order. Base scores are
    /// always finite and within `[0.0, 1.0]` (enforced by `add_passage`).
    corpus: Vec<(String, f64)>,
}

impl LocalRagEngine {
    /// Creates an engine with an empty corpus.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a passage to the corpus.
    ///
    /// `base_score` is clamped to `[0.0, 1.0]`. A NaN score is stored as
    /// `0.0`: `f64::clamp` passes NaN through, and a NaN would otherwise
    /// propagate into every result score and break result ordering.
    pub fn add_passage(&mut self, text: impl Into<String>, base_score: f64) {
        let base_score = if base_score.is_nan() {
            0.0
        } else {
            base_score.clamp(0.0, 1.0)
        };
        self.corpus.push((text.into(), base_score));
    }

    /// Returns up to `top_k` passages matching `q`, best first.
    ///
    /// `q` is split on whitespace into keywords. A passage matches when it
    /// contains at least one keyword as a case-insensitive substring. Its
    /// score is the mean of its base score and the fraction of keywords it
    /// contains. Ties keep corpus insertion order. Every result is labelled
    /// with `source`.
    pub fn query(&self, q: &str, top_k: usize, source: &str) -> Vec<RagResult> {
        // Lowercase each keyword once instead of once per passage.
        let keywords: Vec<String> = q.split_whitespace().map(str::to_lowercase).collect();
        if keywords.is_empty() || top_k == 0 {
            return Vec::new();
        }
        let keyword_total = keywords.len() as f64;

        let mut scored: Vec<RagResult> = self
            .corpus
            .iter()
            .filter_map(|(text, base)| {
                // Lowercase the passage once instead of once per keyword.
                let haystack = text.to_lowercase();
                let matched = keywords
                    .iter()
                    .filter(|kw| haystack.contains(kw.as_str()))
                    .count();
                if matched == 0 {
                    return None;
                }
                let kw_score = matched as f64 / keyword_total;
                Some(RagResult {
                    text: text.clone(),
                    score: (base + kw_score) / 2.0,
                    source: source.to_string(),
                })
            })
            .collect();

        // Scores are finite here, so `total_cmp` gives the usual numeric
        // order while guaranteeing the total order `sort_by` requires.
        scored.sort_by(|a, b| b.score.total_cmp(&a.score));
        scored.truncate(top_k);
        scored
    }
}
/// A query submitted to the federation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedQuery {
    /// Free-text query; the local engine splits it on whitespace into keywords.
    pub query: String,
    /// Optional timestamp. Only its presence is inspected in this file: it
    /// makes `RouteByCoverage` prefer nodes tagged `"temporal"`.
    pub timestamp: Option<i64>,
    /// Maximum number of results, applied per node and to the merged list.
    pub top_k: usize,
    /// Timeout in milliseconds. Not read by any code in this file.
    pub timeout_ms: u64,
}
/// Outcome of `FederatedGraphRag::query`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedResult {
    /// Merged results, sorted by descending score and cut to `top_k`.
    pub results: Vec<RagResult>,
    /// Ids of the selected nodes that returned at least one result.
    pub sources: Vec<String>,
    /// Wall-clock time spent inside `query`, in milliseconds.
    pub total_latency_ms: u64,
    /// Number of nodes the router selected for the query.
    pub node_count: usize,
}
/// A federation of nodes answering queries from one shared in-process corpus.
///
/// The router decides which nodes take part in a query; each selected node is
/// then answered by the single `LocalRagEngine` held here, with the node id
/// used as the result's source label.
pub struct FederatedGraphRag {
    /// Registered nodes, in insertion order.
    nodes: Vec<FederationNode>,
    /// Corpus and ranking engine shared by every node.
    local_rag: LocalRagEngine,
    /// Node-selection policy.
    router: FederationRouter,
    /// Rotation counter handed to the router (used by `LoadBalance`).
    lb_counter: u64,
}

impl FederatedGraphRag {
    /// Creates an empty federation that routes with `strategy`.
    pub fn new(strategy: FederationStrategy) -> Self {
        Self {
            nodes: Vec::new(),
            local_rag: LocalRagEngine::new(),
            router: FederationRouter::new(strategy),
            lb_counter: 0,
        }
    }

    /// Registers `node`. Ids are not checked for uniqueness.
    pub fn add_node(&mut self, node: FederationNode) {
        self.nodes.push(node);
    }

    /// Removes every node whose id equals `node_id`.
    ///
    /// Returns `true` if at least one node was removed.
    pub fn remove_node(&mut self, node_id: &str) -> bool {
        let before = self.nodes.len();
        self.nodes.retain(|n| n.id != node_id);
        self.nodes.len() < before
    }

    /// Runs `q` against the nodes chosen by the router and merges the results.
    ///
    /// Results are de-duplicated by passage text: for each text only the
    /// highest-scoring result is kept (the earliest one on a tie). The merged
    /// list is sorted by descending score and truncated to `q.top_k`.
    /// `sources` lists the selected nodes that returned at least one result,
    /// and `node_count` is the number of nodes selected.
    pub fn query(&mut self, q: &FederatedQuery) -> FederatedResult {
        let start = Instant::now();
        let selected = self
            .router
            .select_nodes(&self.nodes, q, &mut self.lb_counter);

        // Maps passage text to its slot in `merged`.
        let mut slot_by_text: HashMap<String, usize> = HashMap::new();
        let mut merged: Vec<RagResult> = Vec::new();
        let mut sources: Vec<String> = Vec::new();

        for node in &selected {
            let node_results = self.local_rag.query(&q.query, q.top_k, &node.id);
            if node_results.is_empty() {
                continue;
            }
            sources.push(node.id.clone());
            for r in node_results {
                match slot_by_text.get(&r.text).copied() {
                    // A better score for a known text replaces the stored
                    // entry in place; pushing it instead would leave a
                    // duplicate passage in the output.
                    Some(idx) => {
                        if r.score > merged[idx].score {
                            merged[idx] = r;
                        }
                    }
                    None => {
                        slot_by_text.insert(r.text.clone(), merged.len());
                        merged.push(r);
                    }
                }
            }
        }

        // Descending by score. A NaN score is ranked last so the comparator
        // stays a total order, as `sort_by` requires.
        let rank = |r: &RagResult| {
            if r.score.is_nan() {
                f64::NEG_INFINITY
            } else {
                r.score
            }
        };
        merged.sort_by(|a, b| rank(b).total_cmp(&rank(a)));
        merged.truncate(q.top_k);

        FederatedResult {
            results: merged,
            sources,
            // u64 milliseconds cover hundreds of millions of years, so the
            // narrowing cast cannot truncate in practice.
            total_latency_ms: start.elapsed().as_millis() as u64,
            node_count: selected.len(),
        }
    }

    /// Returns the nodes currently marked healthy, in insertion order.
    pub fn healthy_nodes(&self) -> Vec<&FederationNode> {
        self.nodes.iter().filter(|n| n.is_healthy).collect()
    }

    /// Marks the first node whose id equals `node_id` as unhealthy.
    /// Does nothing if no such node exists.
    pub fn mark_unhealthy(&mut self, node_id: &str) {
        if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
            node.is_healthy = false;
        }
    }

    /// Marks every registered node as healthy again.
    pub fn rebalance(&mut self) {
        for node in &mut self.nodes {
            node.is_healthy = true;
        }
    }

    /// Adds a passage to the shared corpus (see `LocalRagEngine::add_passage`).
    pub fn add_corpus_passage(&mut self, text: impl Into<String>, base_score: f64) {
        self.local_rag.add_passage(text, base_score);
    }

    /// Number of registered nodes, healthy or not.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
}
/// Index contributed by a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalIndex {
    /// Id of the node that owns these entries.
    pub node_id: String,
    /// `(key, score)` pairs.
    pub entries: Vec<(String, f64)>,
}
/// Result of combining several `LocalIndex` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergedIndex {
    /// `(key, score, node id)` triples.
    /// `FederatedIndexBuilder::merge_indices` produces one triple per key.
    pub entries: Vec<(String, f64, String)>,
}
/// One partition of a `MergedIndex`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexShard {
    /// Zero-based position of this shard.
    pub shard_id: usize,
    /// `(key, score, node id)` triples assigned to this shard.
    pub entries: Vec<(String, f64, String)>,
}
pub struct FederatedIndexBuilder;
impl FederatedIndexBuilder {
pub fn merge_indices(indices: Vec<LocalIndex>) -> MergedIndex {
let mut best: HashMap<String, (f64, String)> = HashMap::new();
for local in indices {
for (key, score) in local.entries {
let entry = best
.entry(key.clone())
.or_insert((f64::NEG_INFINITY, local.node_id.clone()));
if score > entry.0 {
*entry = (score, local.node_id.clone());
}
}
}
let mut entries: Vec<(String, f64, String)> =
best.into_iter().map(|(k, (s, n))| (k, s, n)).collect();
entries.sort_by(|(ka, sa, _), (kb, sb, _)| {
sb.partial_cmp(sa)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| ka.cmp(kb))
});
MergedIndex { entries }
}
pub fn shard_index(index: &MergedIndex, shard_count: usize) -> Vec<IndexShard> {
if shard_count == 0 {
return vec![];
}
let mut shards: Vec<IndexShard> = (0..shard_count)
.map(|id| IndexShard {
shard_id: id,
entries: Vec::new(),
})
.collect();
for (i, entry) in index.entries.iter().enumerate() {
shards[i % shard_count].entries.push(entry.clone());
}
shards
}
}
/// Unit tests for node management, routing strategies, the local engine and
/// the index builder.
#[cfg(test)]
mod tests {
    use super::*;

    /// Healthy node with only the "vector" capability.
    fn healthy_node(id: &str, latency: u64) -> FederationNode {
        FederationNode {
            id: id.to_string(),
            endpoint: format!("http://{id}.example.com"),
            capabilities: vec!["vector".to_string()],
            latency_ms: latency,
            is_healthy: true,
        }
    }

    /// Healthy node that also advertises the "temporal" capability.
    fn temporal_node(id: &str) -> FederationNode {
        FederationNode {
            id: id.to_string(),
            endpoint: format!("http://{id}.example.com"),
            capabilities: vec!["temporal".to_string(), "vector".to_string()],
            latency_ms: 10,
            is_healthy: true,
        }
    }

    /// Query without a timestamp, asking for at most five results.
    fn make_query(q: &str) -> FederatedQuery {
        FederatedQuery {
            query: q.to_string(),
            timestamp: None,
            top_k: 5,
            timeout_ms: 1000,
        }
    }

    #[test]
    fn test_federation_node_fields() {
        let node = healthy_node("node1", 50);
        assert_eq!(node.id, "node1");
        assert!(node.is_healthy);
        assert_eq!(node.latency_ms, 50);
    }

    #[test]
    fn test_add_and_remove_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 20));
        assert_eq!(fed.node_count(), 2);
        let removed = fed.remove_node("A");
        assert!(removed);
        assert_eq!(fed.node_count(), 1);
    }

    #[test]
    fn test_remove_nonexistent_node_returns_false() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        assert!(!fed.remove_node("ghost"));
    }

    #[test]
    fn test_healthy_nodes_filters_unhealthy() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 10));
        fed.mark_unhealthy("A");
        assert_eq!(fed.healthy_nodes().len(), 1);
        assert_eq!(fed.healthy_nodes()[0].id, "B");
    }

    #[test]
    fn test_healthy_nodes_empty_federation() {
        let fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        assert!(fed.healthy_nodes().is_empty());
    }

    #[test]
    fn test_mark_unhealthy_sets_flag() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.mark_unhealthy("A");
        assert!(!fed.nodes[0].is_healthy);
    }

    #[test]
    fn test_rebalance_restores_all_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 10));
        fed.mark_unhealthy("A");
        fed.mark_unhealthy("B");
        assert_eq!(fed.healthy_nodes().len(), 0);
        fed.rebalance();
        assert_eq!(fed.healthy_nodes().len(), 2);
    }

    #[test]
    fn test_query_broadcast_all_returns_merged_results() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.add_node(healthy_node("B", 20));
        fed.add_corpus_passage("Rust is a systems language", 0.9);
        let result = fed.query(&make_query("Rust language"));
        assert_eq!(result.node_count, 2);
        // Both nodes return the same passage; the merge keeps one copy.
        assert_eq!(result.results.len(), 1);
        assert_eq!(result.sources, vec!["A", "B"]);
    }

    #[test]
    fn test_query_with_no_healthy_nodes_returns_empty() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        fed.mark_unhealthy("A");
        let result = fed.query(&make_query("anything"));
        assert!(result.results.is_empty());
        assert_eq!(result.node_count, 0);
    }

    #[test]
    fn test_failover_chain_picks_fastest_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::FailoverChain);
        fed.add_node(healthy_node("slow", 200));
        fed.add_node(healthy_node("fast", 10));
        fed.add_corpus_passage("Semantic Web SPARQL", 0.8);
        let result = fed.query(&make_query("Semantic Web"));
        assert_eq!(result.node_count, 1);
        assert_eq!(result.sources[0], "fast");
    }

    #[test]
    fn test_route_by_coverage_uses_temporal_node() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::RouteByCoverage);
        fed.add_node(healthy_node("generic", 10));
        fed.add_node(temporal_node("temporal_node"));
        fed.add_corpus_passage("historical data", 0.85);
        let mut q = make_query("historical data");
        q.timestamp = Some(1_700_000_000_000);
        let result = fed.query(&q);
        // A timestamped query must go to the temporal node only.
        assert_eq!(result.node_count, 1);
        assert_eq!(result.sources, vec!["temporal_node"]);
    }

    #[test]
    fn test_load_balance_rotates_nodes() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::LoadBalance);
        fed.add_node(healthy_node("N1", 10));
        fed.add_node(healthy_node("N2", 10));
        fed.add_corpus_passage("GraphRAG federation", 0.9);
        let q = make_query("GraphRAG");
        let r1 = fed.query(&q);
        let r2 = fed.query(&q);
        let r3 = fed.query(&q);
        assert_eq!(r1.node_count, 1);
        assert_eq!(r2.node_count, 1);
        // Consecutive queries must hit different nodes and then wrap around.
        assert_eq!(r1.sources, vec!["N1"]);
        assert_eq!(r2.sources, vec!["N2"]);
        assert_eq!(r3.sources, vec!["N1"]);
    }

    #[test]
    fn test_federated_result_latency_non_negative() {
        let mut fed = FederatedGraphRag::new(FederationStrategy::BroadcastAll);
        fed.add_node(healthy_node("A", 10));
        let result = fed.query(&make_query("test"));
        // The latency is unsigned, so this only checks that the field is
        // populated; the remaining assertions pin the empty-corpus outcome.
        let _ = result.total_latency_ms;
        assert_eq!(result.node_count, 1);
        assert!(result.results.is_empty());
        assert!(result.sources.is_empty());
    }

    #[test]
    fn test_local_rag_returns_matching_passage() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("GraphRAG combines graph and retrieval", 0.8);
        eng.add_passage("Unrelated content here", 0.5);
        let results = eng.query("GraphRAG retrieval", 5, "local");
        assert!(!results.is_empty());
        assert!(results[0].text.contains("GraphRAG"));
    }

    #[test]
    fn test_local_rag_top_k_limit() {
        let mut eng = LocalRagEngine::new();
        for i in 0..10 {
            eng.add_passage(format!("passage {i} keyword"), 0.5);
        }
        let results = eng.query("keyword", 3, "local");
        // All ten passages match, so the limit must be reached exactly.
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_local_rag_no_match_returns_empty() {
        let mut eng = LocalRagEngine::new();
        eng.add_passage("Completely unrelated text", 0.5);
        let results = eng.query("xyzzy", 5, "local");
        assert!(results.is_empty());
    }

    #[test]
    fn test_merge_indices_picks_best_score() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![("key1".to_string(), 0.5), ("key2".to_string(), 0.9)],
        };
        let i2 = LocalIndex {
            node_id: "B".to_string(),
            entries: vec![("key1".to_string(), 0.8), ("key3".to_string(), 0.7)],
        };
        let merged = FederatedIndexBuilder::merge_indices(vec![i1, i2]);
        let key1 = merged
            .entries
            .iter()
            .find(|(k, _, _)| k == "key1")
            .expect("should succeed");
        assert!((key1.1 - 0.8).abs() < 1e-9);
        assert_eq!(key1.2, "B");
        assert_eq!(merged.entries.len(), 3);
    }

    #[test]
    fn test_merge_indices_sorted_descending() {
        let i1 = LocalIndex {
            node_id: "A".to_string(),
            entries: vec![
                ("low".to_string(), 0.1),
                ("high".to_string(), 0.9),
                ("mid".to_string(), 0.5),
            ],
        };
        let merged = FederatedIndexBuilder::merge_indices(vec![i1]);
        for i in 1..merged.entries.len() {
            assert!(merged.entries[i - 1].1 >= merged.entries[i].1);
        }
    }

    #[test]
    fn test_merge_indices_empty_returns_empty() {
        let merged = FederatedIndexBuilder::merge_indices(vec![]);
        assert!(merged.entries.is_empty());
    }

    #[test]
    fn test_shard_index_creates_correct_shard_count() {
        let merged = MergedIndex {
            entries: (0..10)
                .map(|i| (format!("key{i}"), i as f64 * 0.1, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        assert_eq!(shards.len(), 3);
    }

    #[test]
    fn test_shard_index_all_entries_distributed() {
        let merged = MergedIndex {
            entries: (0..9)
                .map(|i| (format!("key{i}"), 0.5, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        let total: usize = shards.iter().map(|s| s.entries.len()).sum();
        assert_eq!(total, 9);
    }

    #[test]
    fn test_shard_index_zero_shards_returns_empty() {
        let merged = MergedIndex {
            entries: vec![("k".to_string(), 0.5, "A".to_string())],
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 0);
        assert!(shards.is_empty());
    }

    #[test]
    fn test_shard_index_ids_are_sequential() {
        let merged = MergedIndex {
            entries: (0..6)
                .map(|i| (format!("k{i}"), 0.5, "A".to_string()))
                .collect(),
        };
        let shards = FederatedIndexBuilder::shard_index(&merged, 3);
        for (expected, shard) in shards.iter().enumerate() {
            assert_eq!(shard.shard_id, expected);
        }
    }
}