use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::{Mutex, RwLock, Semaphore};
use tracing::{debug, info, warn};
use crate::{GraphRAGError, GraphRAGResult, ScoredEntity, Triple};
#[derive(Error, Debug)]
pub enum DistributedError {
#[error("Endpoint {endpoint} is unreachable: {reason}")]
EndpointUnreachable { endpoint: String, reason: String },
#[error("Authentication failed for endpoint {endpoint}")]
AuthFailed { endpoint: String },
#[error("SPARQL query timeout after {timeout_ms}ms on endpoint {endpoint}")]
QueryTimeout { endpoint: String, timeout_ms: u64 },
#[error("Entity resolution cycle detected for URI {uri}")]
SameAsCycle { uri: String },
#[error("No healthy endpoints available for query")]
NoHealthyEndpoints,
#[error("Merge conflict: cannot reconcile {uri} across endpoints")]
MergeConflict { uri: String },
#[error("Configuration invalid: {0}")]
InvalidConfig(String),
}
impl From<DistributedError> for GraphRAGError {
fn from(e: DistributedError) -> Self {
GraphRAGError::InternalError(e.to_string())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EndpointAuth {
None,
Bearer { token: String },
Basic { username: String, password: String },
ApiKey { header: String, key: String },
}
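/// Configuration for a single SPARQL endpoint participating in federation.
///
/// `timeout_ms` overrides the global timeout when set, `priority` drives merge and
/// context ordering, `graph_uri` scopes queries to a named graph via a `FROM`
/// clause, and `max_triples` caps the seed-expansion `CONSTRUCT` with a `LIMIT`.
///
/// A sketch of a typical endpoint (values are illustrative, not defaults):
///
/// ```ignore
/// let endpoint = EndpointConfig {
///     name: "wikidata".into(),
///     url: "https://query.wikidata.org/sparql".into(),
///     auth: EndpointAuth::None,
///     timeout_ms: Some(10_000),
///     priority: 2.0,
///     enabled: true,
///     graph_uri: None,
///     max_triples: 5_000,
/// };
/// ```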
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointConfig {
pub name: String,
pub url: String,
pub auth: EndpointAuth,
pub timeout_ms: Option<u64>,
pub priority: f64,
pub enabled: bool,
pub graph_uri: Option<String>,
pub max_triples: usize,
}
impl Default for EndpointConfig {
fn default() -> Self {
Self {
name: String::new(),
url: String::new(),
auth: EndpointAuth::None,
timeout_ms: None,
priority: 1.0,
enabled: true,
graph_uri: None,
max_triples: 10_000,
}
}
}
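/// Top-level configuration for federated GraphRAG retrieval: the endpoint list,
/// the global timeout and concurrency budget, the owl:sameAs expansion depth,
/// the retry policy, and whether partial results are acceptable.
///
/// A sketch of typical construction (the endpoint values are illustrative):
///
/// ```ignore
/// let config = FederatedGraphRAGConfig {
///     endpoints: vec![endpoint_a, endpoint_b],
///     max_concurrency: 4,
///     ..Default::default()
/// };
/// config.validate()?;
/// ```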
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedGraphRAGConfig {
pub endpoints: Vec<EndpointConfig>,
pub global_timeout_ms: u64,
pub max_concurrency: usize,
pub same_as_max_depth: usize,
pub min_endpoint_priority: f64,
pub partial_results_ok: bool,
pub retry_count: usize,
pub retry_delay_ms: u64,
}
impl Default for FederatedGraphRAGConfig {
fn default() -> Self {
Self {
endpoints: vec![],
global_timeout_ms: 30_000,
max_concurrency: 8,
same_as_max_depth: 5,
min_endpoint_priority: 0.0,
partial_results_ok: true,
retry_count: 2,
retry_delay_ms: 500,
}
}
}
impl FederatedGraphRAGConfig {
pub fn validate(&self) -> Result<(), DistributedError> {
if self.global_timeout_ms == 0 {
return Err(DistributedError::InvalidConfig(
"global_timeout_ms must be > 0".into(),
));
}
if self.max_concurrency == 0 {
return Err(DistributedError::InvalidConfig(
"max_concurrency must be > 0".into(),
));
}
if self.same_as_max_depth == 0 {
return Err(DistributedError::InvalidConfig(
"same_as_max_depth must be > 0".into(),
));
}
for ep in &self.endpoints {
if ep.url.is_empty() {
return Err(DistributedError::InvalidConfig(format!(
"Endpoint '{}' has an empty URL",
ep.name
)));
}
if ep.max_triples == 0 {
return Err(DistributedError::InvalidConfig(format!(
"Endpoint '{}' max_triples must be > 0",
ep.name
)));
}
}
Ok(())
}
pub fn active_endpoints(&self) -> Vec<&EndpointConfig> {
self.endpoints
.iter()
.filter(|ep| ep.enabled && ep.priority >= self.min_endpoint_priority)
.collect()
}
}
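/// A knowledge graph merged from one or more endpoints. `triples` and `provenance`
/// are parallel vectors: `provenance[i]` names the endpoint that contributed
/// `triples[i]`. `canonical_uris` and `equivalence_classes` are filled in by
/// `DistributedEntityResolver::apply_to_graph`.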
#[derive(Debug, Clone, Default)]
pub struct KnowledgeGraph {
pub triples: Vec<Triple>,
pub provenance: Vec<String>,
pub equivalence_classes: Vec<HashSet<String>>,
pub canonical_uris: HashMap<String, String>,
}
impl KnowledgeGraph {
pub fn new() -> Self {
Self::default()
}
pub fn triple_count(&self) -> usize {
self.triples.len()
}
pub fn is_empty(&self) -> bool {
self.triples.is_empty()
}
pub fn canonical<'a>(&'a self, uri: &'a str) -> &'a str {
self.canonical_uris
.get(uri)
.map(|s| s.as_str())
.unwrap_or(uri)
}
}
#[derive(Debug)]
struct EndpointResult {
endpoint_name: String,
triples: Vec<Triple>,
latency_ms: u64,
}
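/// Builds a `CONSTRUCT` query that gathers every triple in which one of the seed
/// URIs appears as subject or object, optionally scoped with `FROM <graph>` and
/// capped with `LIMIT`.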
fn build_seed_expansion_sparql(seeds: &[&str], graph_uri: Option<&str>, limit: usize) -> String {
let values: Vec<String> = seeds.iter().map(|s| format!("<{}>", s)).collect();
let values_block = values.join(" ");
let from_clause = match graph_uri {
Some(g) => format!("FROM <{}>", g),
None => String::new(),
};
format!(
r#"CONSTRUCT {{
?s ?p ?o .
}}
{from}
WHERE {{
VALUES ?seed {{ {seeds} }}
{{
BIND(?seed AS ?s)
?s ?p ?o .
}} UNION {{
?s ?p ?seed .
BIND(?seed AS ?o)
}}
}}
LIMIT {limit}
"#,
from = from_clause,
seeds = values_block,
limit = limit,
)
}
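/// Builds a `SELECT DISTINCT ?a ?b` query returning owl:sameAs links that touch
/// the given URIs in either direction.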
fn build_same_as_sparql(uris: &[&str], graph_uri: Option<&str>) -> String {
let values: Vec<String> = uris.iter().map(|s| format!("<{}>", s)).collect();
let values_block = values.join(" ");
let from_clause = match graph_uri {
Some(g) => format!("FROM <{}>", g),
None => String::new(),
};
format!(
r#"SELECT DISTINCT ?a ?b
{from}
WHERE {{
VALUES ?a {{ {uris} }}
{{
?a <http://www.w3.org/2002/07/owl#sameAs> ?b .
}} UNION {{
?b <http://www.w3.org/2002/07/owl#sameAs> ?a .
}}
}}
"#,
from = from_clause,
uris = values_block,
)
}
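/// Abstraction over SPARQL execution so the federation logic can be tested with
/// mocks. `construct` runs a CONSTRUCT query and returns triples; `select` runs a
/// SELECT query and returns one `variable -> value` map per result row.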
#[async_trait::async_trait]
pub trait EndpointExecutor: Send + Sync {
async fn construct(
&self,
endpoint: &EndpointConfig,
sparql: &str,
timeout: Duration,
) -> GraphRAGResult<Vec<Triple>>;
async fn select(
&self,
endpoint: &EndpointConfig,
sparql: &str,
timeout: Duration,
) -> GraphRAGResult<Vec<HashMap<String, String>>>;
}
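/// `EndpointExecutor` backed by `reqwest`: posts the query body with
/// `Content-Type: application/sparql-query` and applies the configured auth.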
pub struct HttpEndpointExecutor {
client: reqwest::Client,
}
impl HttpEndpointExecutor {
pub fn new() -> GraphRAGResult<Self> {
let client = reqwest::Client::builder()
.build()
.map_err(|e| GraphRAGError::InternalError(format!("HTTP client init: {e}")))?;
Ok(Self { client })
}
fn apply_auth(
&self,
builder: reqwest::RequestBuilder,
auth: &EndpointAuth,
) -> reqwest::RequestBuilder {
match auth {
EndpointAuth::None => builder,
EndpointAuth::Bearer { token } => {
builder.header("Authorization", format!("Bearer {}", token))
}
EndpointAuth::Basic { username, password } => {
builder.basic_auth(username, Some(password))
}
EndpointAuth::ApiKey { header, key } => builder.header(header.as_str(), key.as_str()),
}
}
}
#[async_trait::async_trait]
impl EndpointExecutor for HttpEndpointExecutor {
async fn construct(
&self,
endpoint: &EndpointConfig,
sparql: &str,
timeout: Duration,
) -> GraphRAGResult<Vec<Triple>> {
let builder: reqwest::RequestBuilder = self
.client
.post(&endpoint.url)
.timeout(timeout)
.header("Content-Type", "application/sparql-query")
.header("Accept", "application/n-triples")
.body(sparql.to_string());
let builder = self.apply_auth(builder, &endpoint.auth);
let response = builder
.send()
.await
.map_err(|e| GraphRAGError::SparqlError(format!("HTTP error: {e}")))?;
let status = response.status();
if !status.is_success() {
return Err(GraphRAGError::SparqlError(format!(
"HTTP {} from {}",
status, endpoint.url
)));
}
let body = response
.text()
.await
.map_err(|e| GraphRAGError::SparqlError(format!("Response read error: {e}")))?;
parse_n_triples(&body)
}
async fn select(
&self,
endpoint: &EndpointConfig,
sparql: &str,
timeout: Duration,
) -> GraphRAGResult<Vec<HashMap<String, String>>> {
let builder: reqwest::RequestBuilder = self
.client
.post(&endpoint.url)
.timeout(timeout)
.header("Content-Type", "application/sparql-query")
.header("Accept", "application/sparql-results+json")
.body(sparql.to_string());
let builder = self.apply_auth(builder, &endpoint.auth);
let response = builder
.send()
.await
.map_err(|e| GraphRAGError::SparqlError(format!("HTTP error: {e}")))?;
let status = response.status();
if !status.is_success() {
return Err(GraphRAGError::SparqlError(format!(
"HTTP {} from {}",
status, endpoint.url
)));
}
let body = response
.text()
.await
.map_err(|e| GraphRAGError::SparqlError(format!("Response read error: {e}")))?;
parse_sparql_json_results(&body)
}
}
fn parse_n_triples(body: &str) -> GraphRAGResult<Vec<Triple>> {
let mut triples = Vec::new();
for line in body.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
// Minimal N-Triples parsing: split off subject and predicate, then treat the
// remainder (minus the terminating " .") as the object so literals containing
// spaces are preserved. Quoted literals keep their surrounding quotes.
let tokens: Vec<&str> = line.splitn(3, ' ').collect();
if tokens.len() < 3 {
continue;
}
let s = strip_angle_brackets(tokens[0]);
let p = strip_angle_brackets(tokens[1]);
let rest = tokens[2].trim_end().trim_end_matches('.').trim_end();
let o = if rest.starts_with('<') {
strip_angle_brackets(rest).to_string()
} else {
rest.to_string()
};
if !s.is_empty() && !p.is_empty() {
triples.push(Triple::new(s, p, o));
}
}
Ok(triples)
}
fn strip_angle_brackets(s: &str) -> &str {
s.trim_start_matches('<').trim_end_matches('>')
}
fn parse_sparql_json_results(body: &str) -> GraphRAGResult<Vec<HashMap<String, String>>> {
let json: serde_json::Value = serde_json::from_str(body)
.map_err(|e| GraphRAGError::InternalError(format!("JSON parse error: {e}")))?;
let vars: Vec<String> = json["head"]["vars"]
.as_array()
.unwrap_or(&vec![])
.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect();
let bindings = json["results"]["bindings"]
.as_array()
.unwrap_or(&vec![])
.clone();
let mut rows = Vec::new();
for binding in bindings {
let mut row = HashMap::new();
for var in &vars {
if let Some(val) = binding.get(var) {
let value = val["value"].as_str().unwrap_or("").to_string();
row.insert(var.clone(), value);
}
}
rows.push(row);
}
Ok(rows)
}
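/// Fans the seed-expansion CONSTRUCT query out to all active endpoints concurrently
/// (bounded by `max_concurrency`), retries failed endpoints, and merges the
/// deduplicated triples in endpoint-priority order while recording provenance.
///
/// A sketch of typical usage (the `config` and `seed_entities` values are assumed
/// to exist):
///
/// ```ignore
/// let executor = Arc::new(HttpEndpointExecutor::new()?);
/// let expander = FederatedSubgraphExpander::new(config, executor);
/// let kg = expander.expand_federated(&seed_entities, None).await?;
/// println!("merged {} triples", kg.triple_count());
/// ```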
pub struct FederatedSubgraphExpander<E: EndpointExecutor> {
config: FederatedGraphRAGConfig,
executor: Arc<E>,
}
impl<E: EndpointExecutor + 'static> FederatedSubgraphExpander<E> {
pub fn new(config: FederatedGraphRAGConfig, executor: Arc<E>) -> Self {
Self { config, executor }
}
pub async fn expand_federated(
&self,
seeds: &[ScoredEntity],
endpoints: Option<&[String]>,
) -> GraphRAGResult<KnowledgeGraph> {
if seeds.is_empty() {
return Ok(KnowledgeGraph::new());
}
let seed_uris: Vec<&str> = seeds.iter().map(|s| s.uri.as_str()).collect();
let active: Vec<&EndpointConfig> = match endpoints {
Some(names) => self
.config
.active_endpoints()
.into_iter()
.filter(|ep| names.iter().any(|n| n == &ep.name))
.collect(),
None => self.config.active_endpoints(),
};
if active.is_empty() {
return Err(DistributedError::NoHealthyEndpoints.into());
}
info!(
"Federated expansion: {} seeds across {} endpoints",
seeds.len(),
active.len()
);
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
let results: Arc<Mutex<Vec<EndpointResult>>> = Arc::new(Mutex::new(Vec::new()));
let mut handles = Vec::new();
for ep in active {
let ep = ep.clone();
let executor = Arc::clone(&self.executor);
let sem = Arc::clone(&semaphore);
let results = Arc::clone(&results);
let seed_uris: Vec<String> = seed_uris.iter().map(|s| s.to_string()).collect();
let timeout_ms = ep.timeout_ms.unwrap_or(self.config.global_timeout_ms);
let timeout = Duration::from_millis(timeout_ms);
let retry_count = self.config.retry_count;
let retry_delay = Duration::from_millis(self.config.retry_delay_ms);
let partial_ok = self.config.partial_results_ok;
let handle = tokio::spawn(async move {
let _permit = match sem.acquire_owned().await {
Ok(p) => p,
Err(e) => {
warn!("Semaphore acquire failed: {e}");
return;
}
};
let sparql = build_seed_expansion_sparql(
&seed_uris.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
ep.graph_uri.as_deref(),
ep.max_triples,
);
let start = Instant::now();
let mut last_err = None;
for attempt in 0..=retry_count {
if attempt > 0 {
tokio::time::sleep(retry_delay).await;
}
match executor.construct(&ep, &sparql, timeout).await {
Ok(triples) => {
let latency_ms = start.elapsed().as_millis() as u64;
debug!(
endpoint = %ep.name,
triples = triples.len(),
latency_ms,
"Endpoint query succeeded"
);
let mut guard = results.lock().await;
guard.push(EndpointResult {
endpoint_name: ep.name.clone(),
triples,
latency_ms,
});
return;
}
Err(e) => {
warn!(
endpoint = %ep.name,
attempt,
error = %e,
"Endpoint query failed"
);
last_err = Some(e);
}
}
}
if !partial_ok {
warn!(
endpoint = %ep.name,
error = ?last_err,
"Endpoint permanently failed and partial_results_ok=false"
);
}
});
handles.push(handle);
}
for h in handles {
if let Err(e) = h.await {
warn!("Task join error: {e}");
}
}
let endpoint_results = Arc::try_unwrap(results)
.map_err(|_| GraphRAGError::InternalError("Arc unwrap failed".into()))?
.into_inner();
// Note: only the "every endpoint failed" case is treated as fatal here; when at
// least one endpoint succeeded, individual failures are surfaced via the warnings
// logged above even if partial_results_ok is false.
if endpoint_results.is_empty() && !self.config.partial_results_ok {
return Err(DistributedError::NoHealthyEndpoints.into());
}
self.merge_results(endpoint_results)
}
fn merge_results(&self, results: Vec<EndpointResult>) -> GraphRAGResult<KnowledgeGraph> {
let mut kg = KnowledgeGraph::new();
let mut seen: HashSet<(String, String, String)> = HashSet::new();
let mut priority_map: HashMap<String, f64> = HashMap::new();
for ep in &self.config.endpoints {
priority_map.insert(ep.name.clone(), ep.priority);
}
let mut sorted_results = results;
sorted_results.sort_by(|a, b| {
let pa = priority_map.get(&a.endpoint_name).copied().unwrap_or(1.0);
let pb = priority_map.get(&b.endpoint_name).copied().unwrap_or(1.0);
pb.partial_cmp(&pa).unwrap_or(std::cmp::Ordering::Equal)
});
for result in sorted_results {
for triple in result.triples {
let key = (
triple.subject.clone(),
triple.predicate.clone(),
triple.object.clone(),
);
if seen.insert(key) {
kg.triples.push(triple);
kg.provenance.push(result.endpoint_name.clone());
}
}
}
Ok(kg)
}
}
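/// Resolves entity identity across endpoints by expanding the owl:sameAs closure
/// breadth-first (up to `same_as_max_depth` rounds) into a union-find style parent
/// map, then rewriting a merged graph onto canonical URIs.
///
/// A sketch of typical usage (the `config`, `executor`, `uris`, and `kg` values
/// are assumed to exist):
///
/// ```ignore
/// let resolver = DistributedEntityResolver::new(config, executor);
/// let canonical = resolver.same_as_closure(&uris).await?;
/// resolver.apply_to_graph(&mut kg, &canonical);
/// ```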
pub struct DistributedEntityResolver<E: EndpointExecutor> {
config: FederatedGraphRAGConfig,
executor: Arc<E>,
}
impl<E: EndpointExecutor + 'static> DistributedEntityResolver<E> {
pub fn new(config: FederatedGraphRAGConfig, executor: Arc<E>) -> Self {
Self { config, executor }
}
pub async fn same_as_closure(
&self,
uris: &[String],
) -> GraphRAGResult<HashMap<String, String>> {
if uris.is_empty() {
return Ok(HashMap::new());
}
let parent: Arc<RwLock<HashMap<String, String>>> = Arc::new(RwLock::new(HashMap::new()));
{
let mut p = parent.write().await;
for uri in uris {
p.insert(uri.clone(), uri.clone());
}
}
let mut frontier: VecDeque<String> = uris.iter().cloned().collect();
let mut visited: HashSet<String> = HashSet::from_iter(uris.iter().cloned());
let mut depth = 0usize;
while !frontier.is_empty() && depth < self.config.same_as_max_depth {
let batch: Vec<String> = frontier.drain(..).collect();
let batch_refs: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
let links = self.fetch_same_as_links(&batch_refs).await?;
let mut p = parent.write().await;
for (a, b) in links {
p.entry(a.clone()).or_insert_with(|| a.clone());
p.entry(b.clone()).or_insert_with(|| b.clone());
let root_a = find_root_path(&p, &a);
let root_b = find_root_path(&p, &b);
if root_a != root_b {
let canonical = if root_a <= root_b {
root_a.clone()
} else {
root_b.clone()
};
p.insert(root_a, canonical.clone());
p.insert(root_b, canonical);
}
if !visited.contains(&b) {
visited.insert(b.clone());
frontier.push_back(b);
}
}
depth += 1;
}
let p = parent.read().await;
let mut result = HashMap::new();
for uri in p.keys() {
let canonical = find_root_path(&p, uri);
result.insert(uri.clone(), canonical);
}
Ok(result)
}
async fn fetch_same_as_links(&self, uris: &[&str]) -> GraphRAGResult<Vec<(String, String)>> {
let active = self.config.active_endpoints();
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
let pairs: Arc<Mutex<Vec<(String, String)>>> = Arc::new(Mutex::new(Vec::new()));
let mut handles = Vec::new();
for ep in active {
let ep = ep.clone();
let executor = Arc::clone(&self.executor);
let sem = Arc::clone(&semaphore);
let pairs = Arc::clone(&pairs);
let uris_owned: Vec<String> = uris.iter().map(|s| s.to_string()).collect();
let timeout_ms = ep.timeout_ms.unwrap_or(self.config.global_timeout_ms);
let timeout = Duration::from_millis(timeout_ms);
let handle = tokio::spawn(async move {
let _permit = match sem.acquire_owned().await {
Ok(p) => p,
Err(_) => return,
};
let sparql = build_same_as_sparql(
&uris_owned.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
ep.graph_uri.as_deref(),
);
match executor.select(&ep, &sparql, timeout).await {
Ok(rows) => {
let mut guard = pairs.lock().await;
for row in rows {
if let (Some(a), Some(b)) = (row.get("a"), row.get("b")) {
guard.push((a.clone(), b.clone()));
}
}
}
Err(e) => {
debug!(endpoint = %ep.name, error = %e, "sameAs fetch failed");
}
}
});
handles.push(handle);
}
for h in handles {
let _ = h.await;
}
let guard = Arc::try_unwrap(pairs)
.map_err(|_| GraphRAGError::InternalError("Arc unwrap failed".into()))?
.into_inner();
Ok(guard)
}
pub fn apply_to_graph(&self, kg: &mut KnowledgeGraph, canonical_map: &HashMap<String, String>) {
let canonicalize = |s: &str| -> String {
canonical_map
.get(s)
.cloned()
.unwrap_or_else(|| s.to_string())
};
let mut seen: HashSet<(String, String, String)> = HashSet::new();
let mut new_triples = Vec::new();
let mut new_provenance = Vec::new();
for (triple, prov) in kg.triples.iter().zip(kg.provenance.iter()) {
let s = canonicalize(&triple.subject);
let p = triple.predicate.clone();
let o = canonicalize(&triple.object);
let key = (s.clone(), p.clone(), o.clone());
if seen.insert(key) {
new_triples.push(Triple::new(s, p, o));
new_provenance.push(prov.clone());
}
}
kg.triples = new_triples;
kg.provenance = new_provenance;
kg.canonical_uris = canonical_map.clone();
let mut classes: HashMap<String, HashSet<String>> = HashMap::new();
for (uri, canonical) in canonical_map {
classes
.entry(canonical.clone())
.or_default()
.insert(uri.clone());
}
kg.equivalence_classes = classes.into_values().collect();
}
}
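/// Follows parent links in the union-find map until a fixed point is reached,
/// with a hard cap of 100 hops as a cycle guard.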
fn find_root_path(parent: &HashMap<String, String>, uri: &str) -> String {
let mut current = uri.to_string();
let mut depth = 0usize;
loop {
let next = parent
.get(&current)
.cloned()
.unwrap_or_else(|| current.clone());
if next == current || depth > 100 {
return current;
}
current = next;
depth += 1;
}
}
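/// How facts are ordered when rendering context: by configured endpoint priority,
/// by observed endpoint latency (faster endpoints first), or in insertion order.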
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ContextOrderingStrategy {
ByEndpointPriority,
ByLatency,
Insertion,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedContextConfig {
pub max_context_triples: usize,
pub max_context_chars: usize,
pub ordering: ContextOrderingStrategy,
pub include_provenance: bool,
pub min_endpoint_priority: f64,
pub include_equivalences: bool,
}
impl Default for FederatedContextConfig {
fn default() -> Self {
Self {
max_context_triples: 500,
max_context_chars: 50_000,
ordering: ContextOrderingStrategy::ByEndpointPriority,
include_provenance: false,
include_equivalences: false,
min_endpoint_priority: 0.0,
}
}
}
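/// Renders a merged `KnowledgeGraph` into an LLM-ready context string, optionally
/// including provenance tags and entity equivalence classes, bounded by
/// `max_context_triples` and `max_context_chars`.
///
/// A sketch of typical usage (the `kg` and `graphrag_config` values are assumed
/// to exist):
///
/// ```ignore
/// let builder = FederatedContextBuilder::new(FederatedContextConfig::default(), &graphrag_config);
/// let context = builder.build_context(&kg, "How are A and B related?").await?;
/// ```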
pub struct FederatedContextBuilder {
config: FederatedContextConfig,
endpoint_priorities: HashMap<String, f64>,
endpoint_latencies: Arc<RwLock<HashMap<String, u64>>>,
}
impl FederatedContextBuilder {
pub fn new(config: FederatedContextConfig, graphrag_config: &FederatedGraphRAGConfig) -> Self {
let endpoint_priorities: HashMap<String, f64> = graphrag_config
.endpoints
.iter()
.map(|ep| (ep.name.clone(), ep.priority))
.collect();
Self {
config,
endpoint_priorities,
endpoint_latencies: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn record_latency(&self, endpoint_name: &str, latency_ms: u64) {
let mut lats = self.endpoint_latencies.write().await;
lats.insert(endpoint_name.to_string(), latency_ms);
}
pub async fn build_context(&self, kg: &KnowledgeGraph, query: &str) -> GraphRAGResult<String> {
if kg.is_empty() {
return Ok(String::new());
}
let latencies = self.endpoint_latencies.read().await;
let mut indexed: Vec<(usize, f64)> = kg
.triples
.iter()
.enumerate()
.map(|(i, _)| {
let ep = kg.provenance.get(i).map(|s| s.as_str()).unwrap_or("");
let sort_key = match self.config.ordering {
ContextOrderingStrategy::ByEndpointPriority => {
self.endpoint_priorities.get(ep).copied().unwrap_or(1.0)
}
ContextOrderingStrategy::ByLatency => {
let lat = latencies.get(ep).copied().unwrap_or(u64::MAX);
1.0 / (lat as f64 + 1.0)
}
ContextOrderingStrategy::Insertion => i as f64,
};
(i, sort_key)
})
.filter(|(i, _)| {
let ep = kg.provenance.get(*i).map(|s| s.as_str()).unwrap_or("");
let prio = self.endpoint_priorities.get(ep).copied().unwrap_or(1.0);
prio >= self.config.min_endpoint_priority
})
.collect();
match self.config.ordering {
ContextOrderingStrategy::ByEndpointPriority | ContextOrderingStrategy::ByLatency => {
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
ContextOrderingStrategy::Insertion => {
indexed.sort_by_key(|(i, _)| *i);
}
}
let mut context = format!("## Knowledge Graph Context\n\nQuery: {}\n\n", query);
if self.config.include_equivalences && !kg.equivalence_classes.is_empty() {
context.push_str("### Entity Equivalences\n");
for class in &kg.equivalence_classes {
if class.len() > 1 {
let mut members: Vec<&str> = class.iter().map(|s| s.as_str()).collect();
members.sort();
context.push_str(&format!("- {}\n", members.join(" ≡ ")));
}
}
context.push('\n');
}
context.push_str("### Facts\n\n");
for (triple_count, (idx, _)) in indexed.iter().enumerate() {
if triple_count >= self.config.max_context_triples {
break;
}
if context.len() >= self.config.max_context_chars {
break;
}
let triple = &kg.triples[*idx];
let line = if self.config.include_provenance {
let ep = kg.provenance.get(*idx).map(|s| s.as_str()).unwrap_or("?");
format!(
"- {} → {} → {} [{}]\n",
triple.subject, triple.predicate, triple.object, ep
)
} else {
format!(
"- {} → {} → {}\n",
triple.subject, triple.predicate, triple.object
)
};
context.push_str(&line);
}
Ok(context)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointMetrics {
pub name: String,
pub total_queries: u64,
pub successful_queries: u64,
pub failed_queries: u64,
pub total_triples: u64,
pub avg_latency_ms: f64,
pub min_latency_ms: u64,
pub max_latency_ms: u64,
pub hit_rate: f64,
}
impl EndpointMetrics {
fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
total_queries: 0,
successful_queries: 0,
failed_queries: 0,
total_triples: 0,
avg_latency_ms: 0.0,
min_latency_ms: u64::MAX,
max_latency_ms: 0,
hit_rate: 0.0,
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AggregateMetrics {
pub total_federation_queries: u64,
pub total_triples_gathered: u64,
pub entity_resolution_ops: u64,
pub avg_federation_latency_ms: f64,
pub partial_failure_count: u64,
}
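/// Per-endpoint and aggregate metrics for federated retrieval. Latencies use an
/// exponential moving average (alpha = 0.2); `hit_rate` approximates the share of
/// queries that succeeded with non-empty results.
///
/// A sketch of typical usage (endpoint names are illustrative):
///
/// ```ignore
/// let metrics = DistributedGraphRAGMetrics::new(&config.endpoints);
/// metrics.record_success("ep_a", 120, 42).await;
/// metrics.record_failure("ep_b").await;
/// let fastest = metrics.fastest_endpoint().await;
/// ```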
pub struct DistributedGraphRAGMetrics {
endpoint_metrics: Arc<RwLock<HashMap<String, EndpointMetrics>>>,
aggregate: Arc<RwLock<AggregateMetrics>>,
ema_alpha: f64,
}
impl DistributedGraphRAGMetrics {
pub fn new(endpoints: &[EndpointConfig]) -> Self {
let mut ep_map = HashMap::new();
for ep in endpoints {
ep_map.insert(ep.name.clone(), EndpointMetrics::new(&ep.name));
}
Self {
endpoint_metrics: Arc::new(RwLock::new(ep_map)),
aggregate: Arc::new(RwLock::new(AggregateMetrics::default())),
ema_alpha: 0.2,
}
}
pub async fn record_success(&self, endpoint_name: &str, latency_ms: u64, triple_count: usize) {
let mut guard = self.endpoint_metrics.write().await;
let m = guard
.entry(endpoint_name.to_string())
.or_insert_with(|| EndpointMetrics::new(endpoint_name));
m.total_queries += 1;
m.successful_queries += 1;
m.total_triples += triple_count as u64;
if m.total_queries == 1 {
m.avg_latency_ms = latency_ms as f64;
} else {
m.avg_latency_ms =
self.ema_alpha * latency_ms as f64 + (1.0 - self.ema_alpha) * m.avg_latency_ms;
}
if latency_ms < m.min_latency_ms {
m.min_latency_ms = latency_ms;
}
if latency_ms > m.max_latency_ms {
m.max_latency_ms = latency_ms;
}
// Approximate "non-empty result" rate: a success that returned zero triples does
// not count as a hit. Only the current call's emptiness is tracked, so the rate
// is an approximation over the query history.
let hits = m.successful_queries - if triple_count == 0 { 1 } else { 0 };
m.hit_rate = hits as f64 / m.total_queries as f64;
}
pub async fn record_failure(&self, endpoint_name: &str) {
let mut guard = self.endpoint_metrics.write().await;
let m = guard
.entry(endpoint_name.to_string())
.or_insert_with(|| EndpointMetrics::new(endpoint_name));
m.total_queries += 1;
m.failed_queries += 1;
m.hit_rate = if m.total_queries > 0 {
m.successful_queries as f64 / m.total_queries as f64
} else {
0.0
};
}
pub async fn record_federation_query(
&self,
wall_latency_ms: u64,
total_triples: usize,
had_partial_failure: bool,
) {
let mut agg = self.aggregate.write().await;
agg.total_federation_queries += 1;
agg.total_triples_gathered += total_triples as u64;
if had_partial_failure {
agg.partial_failure_count += 1;
}
if agg.total_federation_queries == 1 {
agg.avg_federation_latency_ms = wall_latency_ms as f64;
} else {
agg.avg_federation_latency_ms = self.ema_alpha * wall_latency_ms as f64
+ (1.0 - self.ema_alpha) * agg.avg_federation_latency_ms;
}
}
pub async fn record_entity_resolution(&self) {
let mut agg = self.aggregate.write().await;
agg.entity_resolution_ops += 1;
}
pub async fn endpoint_snapshot(&self, name: &str) -> Option<EndpointMetrics> {
self.endpoint_metrics.read().await.get(name).cloned()
}
pub async fn all_endpoint_snapshots(&self) -> Vec<EndpointMetrics> {
self.endpoint_metrics
.read()
.await
.values()
.cloned()
.collect()
}
pub async fn aggregate_snapshot(&self) -> AggregateMetrics {
self.aggregate.read().await.clone()
}
pub async fn fastest_endpoint(&self) -> Option<String> {
let guard = self.endpoint_metrics.read().await;
guard
.values()
.filter(|m| m.successful_queries > 0)
.min_by(|a, b| {
a.avg_latency_ms
.partial_cmp(&b.avg_latency_ms)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|m| m.name.clone())
}
pub async fn best_hit_rate_endpoint(&self) -> Option<String> {
let guard = self.endpoint_metrics.read().await;
guard
.values()
.filter(|m| m.total_queries > 0)
.max_by(|a, b| {
a.hit_rate
.partial_cmp(&b.hit_rate)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|m| m.name.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{GraphRAGResult, ScoreSource};
use async_trait::async_trait;
use std::collections::HashMap;
struct MockExecutor {
triples_by_endpoint: HashMap<String, Vec<Triple>>,
same_as_by_endpoint: HashMap<String, Vec<(String, String)>>,
}
impl MockExecutor {
fn new() -> Self {
Self {
triples_by_endpoint: HashMap::new(),
same_as_by_endpoint: HashMap::new(),
}
}
fn with_triples(mut self, endpoint: &str, triples: Vec<Triple>) -> Self {
self.triples_by_endpoint
.insert(endpoint.to_string(), triples);
self
}
fn with_same_as(mut self, endpoint: &str, pairs: Vec<(String, String)>) -> Self {
self.same_as_by_endpoint.insert(endpoint.to_string(), pairs);
self
}
}
#[async_trait]
impl EndpointExecutor for MockExecutor {
async fn construct(
&self,
endpoint: &EndpointConfig,
_sparql: &str,
_timeout: Duration,
) -> GraphRAGResult<Vec<Triple>> {
Ok(self
.triples_by_endpoint
.get(&endpoint.name)
.cloned()
.unwrap_or_default())
}
async fn select(
&self,
endpoint: &EndpointConfig,
_sparql: &str,
_timeout: Duration,
) -> GraphRAGResult<Vec<HashMap<String, String>>> {
let pairs = self
.same_as_by_endpoint
.get(&endpoint.name)
.cloned()
.unwrap_or_default();
Ok(pairs
.into_iter()
.map(|(a, b)| {
let mut m = HashMap::new();
m.insert("a".to_string(), a);
m.insert("b".to_string(), b);
m
})
.collect())
}
}
fn make_endpoint(name: &str, priority: f64) -> EndpointConfig {
EndpointConfig {
name: name.to_string(),
url: format!("http://example.org/{}/sparql", name),
auth: EndpointAuth::None,
timeout_ms: Some(5_000),
priority,
enabled: true,
graph_uri: None,
max_triples: 1_000,
}
}
fn make_seed(uri: &str, score: f64) -> ScoredEntity {
ScoredEntity {
uri: uri.to_string(),
score,
source: ScoreSource::Vector,
metadata: HashMap::new(),
}
}
fn make_triple(s: &str, p: &str, o: &str) -> Triple {
Triple::new(s, p, o)
}
#[test]
fn test_federated_config_validation_valid() {
let config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep1", 1.0)],
global_timeout_ms: 10_000,
max_concurrency: 4,
same_as_max_depth: 3,
..Default::default()
};
assert!(config.validate().is_ok());
}
#[test]
fn test_federated_config_validation_zero_timeout() {
let config = FederatedGraphRAGConfig {
global_timeout_ms: 0,
..Default::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_federated_config_validation_zero_concurrency() {
let config = FederatedGraphRAGConfig {
max_concurrency: 0,
global_timeout_ms: 1_000,
..Default::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_federated_config_validation_empty_url() {
let mut ep = make_endpoint("ep1", 1.0);
ep.url = String::new();
let config = FederatedGraphRAGConfig {
endpoints: vec![ep],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_federated_config_active_endpoints_filters_disabled() {
let mut ep_disabled = make_endpoint("ep_off", 1.0);
ep_disabled.enabled = false;
let config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep_on", 1.0), ep_disabled],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
let active = config.active_endpoints();
assert_eq!(active.len(), 1);
assert_eq!(active[0].name, "ep_on");
}
#[tokio::test]
async fn test_federated_expansion_merges_two_endpoints() {
let triples_a = vec![
make_triple("http://a/s1", "http://p", "http://a/o1"),
make_triple("http://a/s2", "http://p", "http://a/o2"),
];
let triples_b = vec![
make_triple("http://b/s1", "http://p", "http://b/o1"),
make_triple("http://a/s1", "http://p", "http://a/o1"), ];
let executor = MockExecutor::new()
.with_triples("ep_a", triples_a)
.with_triples("ep_b", triples_b);
let config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep_a", 2.0), make_endpoint("ep_b", 1.0)],
global_timeout_ms: 5_000,
max_concurrency: 4,
same_as_max_depth: 3,
partial_results_ok: true,
..Default::default()
};
let expander = FederatedSubgraphExpander::new(config, Arc::new(executor));
let seeds = vec![make_seed("http://a/s1", 0.9)];
let kg = expander
.expand_federated(&seeds, None)
.await
.expect("should succeed");
assert_eq!(kg.triple_count(), 3);
assert!(!kg.is_empty());
}
#[tokio::test]
async fn test_federated_expansion_empty_seeds() {
let executor = MockExecutor::new();
let config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep_a", 1.0)],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
let expander = FederatedSubgraphExpander::new(config, Arc::new(executor));
let kg = expander
.expand_federated(&[], None)
.await
.expect("should succeed");
assert!(kg.is_empty());
}
#[tokio::test]
async fn test_federated_expansion_no_active_endpoints() {
let mut ep = make_endpoint("ep1", 1.0);
ep.enabled = false;
let executor = MockExecutor::new();
let config = FederatedGraphRAGConfig {
endpoints: vec![ep],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
let expander = FederatedSubgraphExpander::new(config, Arc::new(executor));
let seeds = vec![make_seed("http://s", 0.9)];
let result = expander.expand_federated(&seeds, None).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_distributed_entity_resolver_same_as_direct() {
let same_as_pairs = vec![("http://a/e1".to_string(), "http://b/e1".to_string())];
let executor = MockExecutor::new().with_same_as("ep_a", same_as_pairs);
let config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep_a", 1.0)],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
let uris = vec!["http://a/e1".to_string()];
let closure = resolver
.same_as_closure(&uris)
.await
.expect("should succeed");
let canon_a = closure.get("http://a/e1").expect("should succeed");
let canon_b = closure.get("http://b/e1").expect("should succeed");
assert_eq!(
canon_a, canon_b,
"Same-as entities should share canonical URI"
);
}
#[tokio::test]
async fn test_distributed_entity_resolver_no_links() {
let executor = MockExecutor::new();
let config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep_a", 1.0)],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
let uris = vec!["http://example.org/e1".to_string()];
let closure = resolver
.same_as_closure(&uris)
.await
.expect("should succeed");
let canon = closure
.get("http://example.org/e1")
.expect("should succeed");
assert_eq!(canon, "http://example.org/e1");
}
#[tokio::test]
async fn test_distributed_entity_resolver_transitive_chain() {
let same_as_pairs_ep1 = vec![("http://a/e1".to_string(), "http://b/e1".to_string())];
let same_as_pairs_ep2 = vec![("http://b/e1".to_string(), "http://c/e1".to_string())];
let executor = MockExecutor::new()
.with_same_as("ep1", same_as_pairs_ep1)
.with_same_as("ep2", same_as_pairs_ep2);
let config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep1", 1.0), make_endpoint("ep2", 1.0)],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 5,
..Default::default()
};
let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
let uris = vec!["http://a/e1".to_string()];
let closure = resolver
.same_as_closure(&uris)
.await
.expect("should succeed");
if let Some(canon_a) = closure.get("http://a/e1") {
if let Some(canon_b) = closure.get("http://b/e1") {
assert_eq!(canon_a, canon_b);
}
}
}
#[test]
fn test_apply_to_graph_rewrites_uris() {
let executor = MockExecutor::new();
let config = FederatedGraphRAGConfig::default();
let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
let mut kg = KnowledgeGraph::new();
kg.triples = vec![
make_triple("http://a/e1", "http://p", "http://b/e1"),
make_triple("http://a/e1", "http://p", "http://a/e1"), ];
kg.provenance = vec!["ep_a".to_string(), "ep_a".to_string()];
let mut canonical = HashMap::new();
canonical.insert("http://a/e1".to_string(), "http://canonical/e1".to_string());
canonical.insert("http://b/e1".to_string(), "http://canonical/e1".to_string());
resolver.apply_to_graph(&mut kg, &canonical);
assert_eq!(kg.triple_count(), 1);
assert_eq!(kg.triples[0].subject, "http://canonical/e1");
assert_eq!(kg.triples[0].object, "http://canonical/e1");
}
#[tokio::test]
async fn test_federated_context_builder_basic() {
let graphrag_config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep_a", 2.0), make_endpoint("ep_b", 1.0)],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
let ctx_config = FederatedContextConfig {
max_context_triples: 100,
max_context_chars: 10_000,
ordering: ContextOrderingStrategy::ByEndpointPriority,
include_provenance: true,
include_equivalences: false,
..Default::default()
};
let builder = FederatedContextBuilder::new(ctx_config, &graphrag_config);
let mut kg = KnowledgeGraph::new();
kg.triples = vec![
make_triple("http://s1", "http://p", "http://o1"),
make_triple("http://s2", "http://p", "http://o2"),
];
kg.provenance = vec!["ep_a".to_string(), "ep_b".to_string()];
let context = builder
.build_context(&kg, "test query")
.await
.expect("should succeed");
assert!(context.contains("test query"));
assert!(context.contains("http://s1"));
assert!(context.contains("http://s2"));
assert!(context.contains("[ep_a]") || context.contains("[ep_b]"));
}
#[tokio::test]
async fn test_federated_context_builder_empty_kg() {
let graphrag_config = FederatedGraphRAGConfig::default();
let ctx_config = FederatedContextConfig::default();
let builder = FederatedContextBuilder::new(ctx_config, &graphrag_config);
let kg = KnowledgeGraph::new();
let context = builder
.build_context(&kg, "test")
.await
.expect("should succeed");
assert!(context.is_empty());
}
#[tokio::test]
async fn test_federated_context_builder_respects_max_triples() {
let graphrag_config = FederatedGraphRAGConfig {
endpoints: vec![make_endpoint("ep_a", 1.0)],
global_timeout_ms: 5_000,
max_concurrency: 2,
same_as_max_depth: 3,
..Default::default()
};
let ctx_config = FederatedContextConfig {
max_context_triples: 2,
max_context_chars: 100_000,
ordering: ContextOrderingStrategy::Insertion,
include_provenance: false,
include_equivalences: false,
..Default::default()
};
let builder = FederatedContextBuilder::new(ctx_config, &graphrag_config);
let mut kg = KnowledgeGraph::new();
kg.triples = (0..10)
.map(|i| {
make_triple(
&format!("http://s{}", i),
"http://p",
&format!("http://o{}", i),
)
})
.collect();
kg.provenance = (0..10).map(|_| "ep_a".to_string()).collect();
let context = builder
.build_context(&kg, "test")
.await
.expect("should succeed");
let triple_lines = context.lines().filter(|l| l.starts_with("- ")).count();
assert!(
triple_lines <= 2,
"Expected at most 2 triples, got {}",
triple_lines
);
}
#[tokio::test]
async fn test_distributed_metrics_tracking_success() {
let endpoints = vec![make_endpoint("ep_a", 1.0), make_endpoint("ep_b", 1.0)];
let metrics = DistributedGraphRAGMetrics::new(&endpoints);
metrics.record_success("ep_a", 150, 42).await;
metrics.record_success("ep_a", 100, 30).await;
let snap = metrics
.endpoint_snapshot("ep_a")
.await
.expect("should succeed");
assert_eq!(snap.total_queries, 2);
assert_eq!(snap.successful_queries, 2);
assert_eq!(snap.failed_queries, 0);
assert_eq!(snap.total_triples, 72);
assert!(snap.avg_latency_ms > 0.0);
}
#[tokio::test]
async fn test_distributed_metrics_tracking_failure() {
let endpoints = vec![make_endpoint("ep_a", 1.0)];
let metrics = DistributedGraphRAGMetrics::new(&endpoints);
metrics.record_failure("ep_a").await;
metrics.record_failure("ep_a").await;
let snap = metrics
.endpoint_snapshot("ep_a")
.await
.expect("should succeed");
assert_eq!(snap.total_queries, 2);
assert_eq!(snap.failed_queries, 2);
assert_eq!(snap.successful_queries, 0);
assert_eq!(snap.hit_rate, 0.0);
}
#[tokio::test]
async fn test_distributed_metrics_aggregate() {
let endpoints = vec![make_endpoint("ep_a", 1.0)];
let metrics = DistributedGraphRAGMetrics::new(&endpoints);
metrics.record_federation_query(200, 100, false).await;
metrics.record_federation_query(300, 50, true).await;
metrics.record_entity_resolution().await;
let agg = metrics.aggregate_snapshot().await;
assert_eq!(agg.total_federation_queries, 2);
assert_eq!(agg.total_triples_gathered, 150);
assert_eq!(agg.entity_resolution_ops, 1);
assert_eq!(agg.partial_failure_count, 1);
assert!(agg.avg_federation_latency_ms > 0.0);
}
#[tokio::test]
async fn test_distributed_metrics_fastest_endpoint() {
let endpoints = vec![make_endpoint("ep_a", 1.0), make_endpoint("ep_b", 1.0)];
let metrics = DistributedGraphRAGMetrics::new(&endpoints);
metrics.record_success("ep_a", 500, 10).await;
metrics.record_success("ep_b", 50, 10).await;
let fastest = metrics.fastest_endpoint().await.expect("should succeed");
assert_eq!(fastest, "ep_b");
}
#[tokio::test]
async fn test_distributed_metrics_hit_rate() {
let endpoints = vec![make_endpoint("ep_a", 1.0)];
let metrics = DistributedGraphRAGMetrics::new(&endpoints);
metrics.record_success("ep_a", 100, 5).await; metrics.record_failure("ep_a").await;
let snap = metrics
.endpoint_snapshot("ep_a")
.await
.expect("should succeed");
assert_eq!(snap.total_queries, 2);
assert!(snap.hit_rate >= 0.0 && snap.hit_rate <= 1.0);
}
#[test]
fn test_parse_n_triples_basic() {
let body = "<http://s> <http://p> <http://o> .\n";
let triples = parse_n_triples(body).expect("should succeed");
assert_eq!(triples.len(), 1);
assert_eq!(triples[0].subject, "http://s");
assert_eq!(triples[0].predicate, "http://p");
assert_eq!(triples[0].object, "http://o");
}
#[test]
fn test_parse_n_triples_skips_comments() {
let body = "# comment\n<http://s> <http://p> <http://o> .\n";
let triples = parse_n_triples(body).expect("should succeed");
assert_eq!(triples.len(), 1);
}
#[test]
fn test_parse_n_triples_empty() {
let triples = parse_n_triples("").expect("should succeed");
assert!(triples.is_empty());
}
#[test]
fn test_build_seed_expansion_sparql_includes_seeds() {
let sparql = build_seed_expansion_sparql(
&["http://example.org/e1", "http://example.org/e2"],
None,
500,
);
assert!(sparql.contains("<http://example.org/e1>"));
assert!(sparql.contains("<http://example.org/e2>"));
assert!(sparql.contains("LIMIT 500"));
}
#[test]
fn test_build_seed_expansion_sparql_with_graph() {
let sparql = build_seed_expansion_sparql(
&["http://example.org/e1"],
Some("http://example.org/graph"),
100,
);
assert!(sparql.contains("FROM <http://example.org/graph>"));
}
#[test]
fn test_build_same_as_sparql() {
let sparql = build_same_as_sparql(&["http://a/e1", "http://b/e1"], None);
assert!(sparql.contains("owl#sameAs"));
assert!(sparql.contains("<http://a/e1>"));
}
#[test]
fn test_knowledge_graph_canonical_lookup() {
let mut kg = KnowledgeGraph::new();
kg.canonical_uris
.insert("http://b/e1".to_string(), "http://canonical/e1".to_string());
assert_eq!(kg.canonical("http://b/e1"), "http://canonical/e1");
assert_eq!(kg.canonical("http://unknown"), "http://unknown");
}
#[test]
fn test_endpoint_auth_variants() {
let bearer = EndpointAuth::Bearer {
token: "tok123".to_string(),
};
let basic = EndpointAuth::Basic {
username: "user".to_string(),
password: "pass".to_string(),
};
let api = EndpointAuth::ApiKey {
header: "X-API-Key".to_string(),
key: "key123".to_string(),
};
assert_ne!(bearer, EndpointAuth::None);
assert_ne!(basic, EndpointAuth::None);
assert_ne!(api, EndpointAuth::None);
}
}