Skip to main content

oxirs_graphrag/distributed/
mod.rs

1//! Distributed GraphRAG: federated subgraph expansion across multiple SPARQL endpoints.
2//!
3//! This module provides the building blocks for querying heterogeneous, geographically
4//! distributed knowledge graphs and merging the results into a single coherent subgraph
5//! suitable for retrieval-augmented generation.
6//!
7//! ## Architecture
8//!
9//! ```text
10//! Query Seeds
11//!     │
12//!     ▼
13//! FederatedSubgraphExpander ──► [Endpoint A] ──► subgraph_A
14//!     │                    ──► [Endpoint B] ──► subgraph_B   ──► merge + resolve ──► KnowledgeGraph
15//!     │                    ──► [Endpoint C] ──► subgraph_C
16//!     │
17//!     ▼
18//! DistributedEntityResolver  (sameAs closure)
19//!     │
20//!     ▼
21//! FederatedContextBuilder    (priority + confidence ranking)
22//!     │
23//!     ▼
24//! RAG context string
25//! ```
26
27use std::collections::{HashMap, HashSet, VecDeque};
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30
31use serde::{Deserialize, Serialize};
32use thiserror::Error;
33use tokio::sync::{Mutex, RwLock, Semaphore};
34use tracing::{debug, info, warn};
35
36use crate::{GraphRAGError, GraphRAGResult, ScoredEntity, Triple};
37
38// ─────────────────────────────────────────────────────────────────────────────
39// Error types
40// ─────────────────────────────────────────────────────────────────────────────
41
/// Distributed GraphRAG–specific error variants.
///
/// Each variant renders a human-readable message through `thiserror`'s
/// `#[error(...)]` attribute. When converted into [`GraphRAGError`] only that
/// rendered message is kept (see the `From<DistributedError>` impl in this
/// module).
#[derive(Error, Debug)]
pub enum DistributedError {
    /// An endpoint could not be reached; `reason` describes the cause.
    #[error("Endpoint {endpoint} is unreachable: {reason}")]
    EndpointUnreachable { endpoint: String, reason: String },

    /// The endpoint rejected the supplied credentials.
    #[error("Authentication failed for endpoint {endpoint}")]
    AuthFailed { endpoint: String },

    /// A query against `endpoint` did not finish within `timeout_ms` milliseconds.
    #[error("SPARQL query timeout after {timeout_ms}ms on endpoint {endpoint}")]
    QueryTimeout { endpoint: String, timeout_ms: u64 },

    /// A cycle was found while following `owl:sameAs` links starting at `uri`.
    #[error("Entity resolution cycle detected for URI {uri}")]
    SameAsCycle { uri: String },

    /// No endpoint was available to serve the query (for example, every
    /// endpoint is disabled or filtered out).
    #[error("No healthy endpoints available for query")]
    NoHealthyEndpoints,

    /// Data about `uri` from different endpoints could not be reconciled.
    #[error("Merge conflict: cannot reconcile {uri} across endpoints")]
    MergeConflict { uri: String },

    /// The configuration failed validation; the payload explains which rule
    /// was violated.
    #[error("Configuration invalid: {0}")]
    InvalidConfig(String),
}
66
67impl From<DistributedError> for GraphRAGError {
68    fn from(e: DistributedError) -> Self {
69        GraphRAGError::InternalError(e.to_string())
70    }
71}
72
73// ─────────────────────────────────────────────────────────────────────────────
74// Configuration
75// ─────────────────────────────────────────────────────────────────────────────
76
/// Authentication method for SPARQL endpoints.
///
/// Serialized with serde as an internally tagged enum: the variant is stored
/// in a `type` field using its snake_case name (`none`, `bearer`, `basic`,
/// `api_key`), alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EndpointAuth {
    /// No authentication
    None,
    /// HTTP Bearer token (sent as `Authorization: Bearer <token>`)
    Bearer { token: String },
    /// HTTP Basic auth
    Basic { username: String, password: String },
    /// API key in header: `header` is the header name, `key` its value
    ApiKey { header: String, key: String },
}
90
/// Configuration for a single remote SPARQL endpoint.
///
/// See the `Default` impl for the values used when a field is not set
/// explicitly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointConfig {
    /// Human-readable name for the endpoint. Also used to select endpoints by
    /// name and to record triple provenance, so it should be unique.
    pub name: String,
    /// Base URL of the SPARQL endpoint (queries are POSTed to this URL as-is)
    pub url: String,
    /// Authentication method
    pub auth: EndpointAuth,
    /// Per-endpoint query timeout in milliseconds (overrides global setting)
    pub timeout_ms: Option<u64>,
    /// Priority weight (higher = preferred; used when deduplicating conflicting triples)
    pub priority: f64,
    /// Whether this endpoint is enabled
    pub enabled: bool,
    /// Graph URI to restrict queries to (SPARQL FROM clause)
    pub graph_uri: Option<String>,
    /// Maximum triples to fetch from this endpoint per query (emitted as the
    /// `LIMIT` of the seed-expansion query)
    pub max_triples: usize,
}
111
112impl Default for EndpointConfig {
113    fn default() -> Self {
114        Self {
115            name: String::new(),
116            url: String::new(),
117            auth: EndpointAuth::None,
118            timeout_ms: None,
119            priority: 1.0,
120            enabled: true,
121            graph_uri: None,
122            max_triples: 10_000,
123        }
124    }
125}
126
/// Top-level configuration for federated GraphRAG.
///
/// Call `validate` before use; `active_endpoints` yields the subset of
/// `endpoints` that will actually be queried.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedGraphRAGConfig {
    /// List of remote endpoints to query
    pub endpoints: Vec<EndpointConfig>,
    /// Global query timeout in milliseconds. Used as the per-request timeout
    /// for any endpoint that does not set its own `timeout_ms`.
    pub global_timeout_ms: u64,
    /// Maximum concurrent endpoint requests
    pub max_concurrency: usize,
    /// Maximum transitive sameAs hops to follow
    pub same_as_max_depth: usize,
    /// Minimum endpoint priority to include in a query (0.0 = include all)
    pub min_endpoint_priority: f64,
    /// Whether to continue when some endpoints fail
    pub partial_results_ok: bool,
    /// Retry count for failed endpoint requests (retries are in addition to
    /// the initial attempt)
    pub retry_count: usize,
    /// Delay between retries in milliseconds
    pub retry_delay_ms: u64,
}
147
148impl Default for FederatedGraphRAGConfig {
149    fn default() -> Self {
150        Self {
151            endpoints: vec![],
152            global_timeout_ms: 30_000,
153            max_concurrency: 8,
154            same_as_max_depth: 5,
155            min_endpoint_priority: 0.0,
156            partial_results_ok: true,
157            retry_count: 2,
158            retry_delay_ms: 500,
159        }
160    }
161}
162
163impl FederatedGraphRAGConfig {
164    /// Validate configuration and return an error description if invalid.
165    pub fn validate(&self) -> Result<(), DistributedError> {
166        if self.global_timeout_ms == 0 {
167            return Err(DistributedError::InvalidConfig(
168                "global_timeout_ms must be > 0".into(),
169            ));
170        }
171        if self.max_concurrency == 0 {
172            return Err(DistributedError::InvalidConfig(
173                "max_concurrency must be > 0".into(),
174            ));
175        }
176        if self.same_as_max_depth == 0 {
177            return Err(DistributedError::InvalidConfig(
178                "same_as_max_depth must be > 0".into(),
179            ));
180        }
181        for ep in &self.endpoints {
182            if ep.url.is_empty() {
183                return Err(DistributedError::InvalidConfig(format!(
184                    "Endpoint '{}' has an empty URL",
185                    ep.name
186                )));
187            }
188            if ep.max_triples == 0 {
189                return Err(DistributedError::InvalidConfig(format!(
190                    "Endpoint '{}' max_triples must be > 0",
191                    ep.name
192                )));
193            }
194        }
195        Ok(())
196    }
197
198    /// Return only enabled endpoints with priority >= `min_endpoint_priority`.
199    pub fn active_endpoints(&self) -> Vec<&EndpointConfig> {
200        self.endpoints
201            .iter()
202            .filter(|ep| ep.enabled && ep.priority >= self.min_endpoint_priority)
203            .collect()
204    }
205}
206
207// ─────────────────────────────────────────────────────────────────────────────
208// Knowledge graph result type
209// ─────────────────────────────────────────────────────────────────────────────
210
/// A merged knowledge graph assembled from multiple endpoints.
///
/// `triples` and `provenance` are parallel vectors: the code in this module
/// always pushes to both together, so `provenance[i]` names the endpoint that
/// contributed `triples[i]`.
#[derive(Debug, Clone, Default)]
pub struct KnowledgeGraph {
    /// All triples gathered from the federation
    pub triples: Vec<Triple>,
    /// Provenance: which endpoint contributed each triple index
    pub provenance: Vec<String>,
    /// Entity equivalence classes after sameAs resolution (empty until a
    /// canonical map has been applied to the graph)
    pub equivalence_classes: Vec<HashSet<String>>,
    /// Canonical URIs chosen for each equivalence class (representative URI).
    /// Maps a URI to its representative.
    pub canonical_uris: HashMap<String, String>,
}
223
224impl KnowledgeGraph {
225    /// Create an empty knowledge graph
226    pub fn new() -> Self {
227        Self::default()
228    }
229
230    /// Return the number of distinct triples
231    pub fn triple_count(&self) -> usize {
232        self.triples.len()
233    }
234
235    /// Return true if the knowledge graph has no triples
236    pub fn is_empty(&self) -> bool {
237        self.triples.is_empty()
238    }
239
240    /// Resolve a URI to its canonical form (or return the URI unchanged)
241    pub fn canonical<'a>(&'a self, uri: &'a str) -> &'a str {
242        self.canonical_uris
243            .get(uri)
244            .map(|s| s.as_str())
245            .unwrap_or(uri)
246    }
247}
248
249// ─────────────────────────────────────────────────────────────────────────────
250// HTTP SPARQL client abstraction
251// ─────────────────────────────────────────────────────────────────────────────
252
/// Result of a single endpoint query
#[derive(Debug)]
struct EndpointResult {
    /// Name of the endpoint (`EndpointConfig::name`) that produced the triples
    endpoint_name: String,
    /// Triples returned by the endpoint's CONSTRUCT query
    triples: Vec<Triple>,
    /// Wall-clock time in milliseconds from the first attempt until the
    /// successful response (includes any retries and retry delays)
    latency_ms: u64,
}
260
/// Render `iri` as a SPARQL `IRIREF` token (`<...>`).
///
/// Characters the SPARQL grammar forbids inside an `IRIREF` (`<`, `>`, `"`,
/// `{`, `}`, `|`, `^`, `` ` ``, `\` and U+0000–U+0020) are percent-encoded so
/// that a malformed or hostile URI cannot terminate the token early and
/// inject query text. Well-formed IRIs pass through unchanged.
fn sparql_iri_ref(iri: &str) -> String {
    let mut token = String::with_capacity(iri.len() + 2);
    token.push('<');
    for ch in iri.chars() {
        match ch {
            '<' | '>' | '"' | '{' | '}' | '|' | '^' | '`' | '\\' | '\u{0}'..='\u{20}' => {
                // All forbidden characters are ASCII, so one %XX escape suffices.
                token.push_str(&format!("%{:02X}", ch as u32));
            }
            _ => token.push(ch),
        }
    }
    token.push('>');
    token
}

/// Build a SPARQL CONSTRUCT query for seed expansion.
///
/// The query returns every triple in which one of `seeds` appears as subject
/// or as object, capped at `limit` results. When `graph_uri` is given the
/// query is restricted to that graph via a `FROM` clause.
///
/// `seeds` and `graph_uri` are escaped with [`sparql_iri_ref`]. Callers should
/// pass at least one seed: an empty `VALUES` block matches nothing.
fn build_seed_expansion_sparql(seeds: &[&str], graph_uri: Option<&str>, limit: usize) -> String {
    let values_block = seeds
        .iter()
        .map(|seed| sparql_iri_ref(seed))
        .collect::<Vec<_>>()
        .join(" ");

    let from_clause = graph_uri
        .map(|g| format!("FROM {}", sparql_iri_ref(g)))
        .unwrap_or_default();

    // In each UNION branch the triple pattern comes *before* the BIND: group
    // patterns are evaluated bottom-up, so a BIND placed first in a nested
    // group cannot see `?seed` from the outer VALUES block and would leave
    // `?s` unbound (making the branch match every triple).
    format!(
        r#"CONSTRUCT {{
    ?s ?p ?o .
}}
{from}
WHERE {{
    VALUES ?seed {{ {seeds} }}
    {{
        ?seed ?p ?o .
        BIND(?seed AS ?s)
    }} UNION {{
        ?s ?p ?seed .
        BIND(?seed AS ?o)
    }}
}}
LIMIT {limit}
"#,
        from = from_clause,
        seeds = values_block,
        limit = limit,
    )
}

/// Build a SPARQL SELECT query for sameAs links.
///
/// For every URI in `uris` (bound to `?a`) the query returns each `?b` linked
/// to it by `owl:sameAs` in either direction. When `graph_uri` is given the
/// query is restricted to that graph via a `FROM` clause.
///
/// `uris` and `graph_uri` are escaped with [`sparql_iri_ref`].
fn build_same_as_sparql(uris: &[&str], graph_uri: Option<&str>) -> String {
    let values_block = uris
        .iter()
        .map(|uri| sparql_iri_ref(uri))
        .collect::<Vec<_>>()
        .join(" ");

    let from_clause = graph_uri
        .map(|g| format!("FROM {}", sparql_iri_ref(g)))
        .unwrap_or_default();

    format!(
        r#"SELECT DISTINCT ?a ?b
{from}
WHERE {{
    VALUES ?a {{ {uris} }}
    {{
        ?a <http://www.w3.org/2002/07/owl#sameAs> ?b .
    }} UNION {{
        ?b <http://www.w3.org/2002/07/owl#sameAs> ?a .
    }}
}}
"#,
        from = from_clause,
        uris = values_block,
    )
}
320
321// ─────────────────────────────────────────────────────────────────────────────
322// HTTP executor (mockable in tests)
323// ─────────────────────────────────────────────────────────────────────────────
324
/// Trait for executing SPARQL CONSTRUCT queries against a remote endpoint.
///
/// The expander and resolver in this module are generic over this trait, so
/// an alternative executor can be substituted for the HTTP one. Implementors
/// must be `Send + Sync` because they are shared across spawned tasks behind
/// an `Arc`.
#[async_trait::async_trait]
pub trait EndpointExecutor: Send + Sync {
    /// Execute a SPARQL CONSTRUCT query and return RDF triples.
    ///
    /// * `endpoint` – target endpoint (URL, auth, …)
    /// * `sparql` – complete query text
    /// * `timeout` – upper bound for this single request
    async fn construct(
        &self,
        endpoint: &EndpointConfig,
        sparql: &str,
        timeout: Duration,
    ) -> GraphRAGResult<Vec<Triple>>;

    /// Execute a SPARQL SELECT query and return rows (variable → value maps).
    ///
    /// Parameters have the same meaning as for `construct`. Each row maps a
    /// variable name (without the leading `?`) to the bound value's string
    /// form.
    async fn select(
        &self,
        endpoint: &EndpointConfig,
        sparql: &str,
        timeout: Duration,
    ) -> GraphRAGResult<Vec<HashMap<String, String>>>;
}
344
/// Default HTTP-based endpoint executor using reqwest
pub struct HttpEndpointExecutor {
    /// HTTP client reused for every request issued by this executor
    client: reqwest::Client,
}
349
350impl HttpEndpointExecutor {
351    /// Create a new HTTP executor
352    pub fn new() -> GraphRAGResult<Self> {
353        let client = reqwest::Client::builder()
354            .build()
355            .map_err(|e| GraphRAGError::InternalError(format!("HTTP client init: {e}")))?;
356        Ok(Self { client })
357    }
358
359    /// Apply authentication headers to a request builder
360    fn apply_auth(
361        &self,
362        builder: reqwest::RequestBuilder,
363        auth: &EndpointAuth,
364    ) -> reqwest::RequestBuilder {
365        match auth {
366            EndpointAuth::None => builder,
367            EndpointAuth::Bearer { token } => {
368                builder.header("Authorization", format!("Bearer {}", token))
369            }
370            EndpointAuth::Basic { username, password } => {
371                builder.basic_auth(username, Some(password))
372            }
373            EndpointAuth::ApiKey { header, key } => builder.header(header.as_str(), key.as_str()),
374        }
375    }
376}
377
378#[async_trait::async_trait]
379impl EndpointExecutor for HttpEndpointExecutor {
380    async fn construct(
381        &self,
382        endpoint: &EndpointConfig,
383        sparql: &str,
384        timeout: Duration,
385    ) -> GraphRAGResult<Vec<Triple>> {
386        let builder: reqwest::RequestBuilder = self
387            .client
388            .post(&endpoint.url)
389            .timeout(timeout)
390            .header("Content-Type", "application/sparql-query")
391            .header("Accept", "application/n-triples")
392            .body(sparql.to_string());
393        let builder = self.apply_auth(builder, &endpoint.auth);
394
395        let response = builder
396            .send()
397            .await
398            .map_err(|e| GraphRAGError::SparqlError(format!("HTTP error: {e}")))?;
399
400        let status = response.status();
401        if !status.is_success() {
402            return Err(GraphRAGError::SparqlError(format!(
403                "HTTP {} from {}",
404                status, endpoint.url
405            )));
406        }
407
408        let body = response
409            .text()
410            .await
411            .map_err(|e| GraphRAGError::SparqlError(format!("Response read error: {e}")))?;
412
413        parse_n_triples(&body)
414    }
415
416    async fn select(
417        &self,
418        endpoint: &EndpointConfig,
419        sparql: &str,
420        timeout: Duration,
421    ) -> GraphRAGResult<Vec<HashMap<String, String>>> {
422        let builder: reqwest::RequestBuilder = self
423            .client
424            .post(&endpoint.url)
425            .timeout(timeout)
426            .header("Content-Type", "application/sparql-query")
427            .header("Accept", "application/sparql-results+json")
428            .body(sparql.to_string());
429        let builder = self.apply_auth(builder, &endpoint.auth);
430
431        let response = builder
432            .send()
433            .await
434            .map_err(|e| GraphRAGError::SparqlError(format!("HTTP error: {e}")))?;
435
436        let status = response.status();
437        if !status.is_success() {
438            return Err(GraphRAGError::SparqlError(format!(
439                "HTTP {} from {}",
440                status, endpoint.url
441            )));
442        }
443
444        let body = response
445            .text()
446            .await
447            .map_err(|e| GraphRAGError::SparqlError(format!("Response read error: {e}")))?;
448
449        parse_sparql_json_results(&body)
450    }
451}
452
453/// Minimal N-Triples parser (handles `<s> <p> <o> .` and string literals)
454fn parse_n_triples(body: &str) -> GraphRAGResult<Vec<Triple>> {
455    let mut triples = Vec::new();
456    for line in body.lines() {
457        let line = line.trim();
458        if line.is_empty() || line.starts_with('#') {
459            continue;
460        }
461        // Very lightweight parser: split on whitespace-delimited tokens
462        let tokens: Vec<&str> = line.splitn(4, ' ').collect();
463        if tokens.len() < 3 {
464            continue;
465        }
466        let s = strip_angle_brackets(tokens[0]);
467        let p = strip_angle_brackets(tokens[1]);
468        let o = if tokens[2].starts_with('<') {
469            strip_angle_brackets(tokens[2]).to_string()
470        } else {
471            tokens[2].to_string()
472        };
473        if !s.is_empty() && !p.is_empty() {
474            triples.push(Triple::new(s, p, o));
475        }
476    }
477    Ok(triples)
478}
479
/// Remove every leading `<` and every trailing `>` from `s`, returning the
/// slice in between. Input without surrounding brackets is returned as-is.
fn strip_angle_brackets(s: &str) -> &str {
    let without_opening = s.trim_start_matches(|c: char| c == '<');
    without_opening.trim_end_matches(|c: char| c == '>')
}
483
484/// Minimal SPARQL JSON results parser for SELECT queries
485fn parse_sparql_json_results(body: &str) -> GraphRAGResult<Vec<HashMap<String, String>>> {
486    // Use serde_json for reliability
487    let json: serde_json::Value = serde_json::from_str(body)
488        .map_err(|e| GraphRAGError::InternalError(format!("JSON parse error: {e}")))?;
489
490    let vars: Vec<String> = json["head"]["vars"]
491        .as_array()
492        .unwrap_or(&vec![])
493        .iter()
494        .filter_map(|v| v.as_str().map(|s| s.to_string()))
495        .collect();
496
497    let bindings = json["results"]["bindings"]
498        .as_array()
499        .unwrap_or(&vec![])
500        .clone();
501
502    let mut rows = Vec::new();
503    for binding in bindings {
504        let mut row = HashMap::new();
505        for var in &vars {
506            if let Some(val) = binding.get(var) {
507                let value = val["value"].as_str().unwrap_or("").to_string();
508                row.insert(var.clone(), value);
509            }
510        }
511        rows.push(row);
512    }
513    Ok(rows)
514}
515
516// ─────────────────────────────────────────────────────────────────────────────
517// FederatedSubgraphExpander
518// ─────────────────────────────────────────────────────────────────────────────
519
/// Expands subgraphs across multiple SPARQL endpoints concurrently and merges
/// the results into a single [`KnowledgeGraph`].
///
/// Generic over the [`EndpointExecutor`] used to talk to the endpoints.
pub struct FederatedSubgraphExpander<E: EndpointExecutor> {
    /// Federation settings: endpoints, timeouts, concurrency and retry policy
    config: FederatedGraphRAGConfig,
    /// Executor shared with the tasks spawned for each endpoint
    executor: Arc<E>,
}
526
527impl<E: EndpointExecutor + 'static> FederatedSubgraphExpander<E> {
528    /// Create a new expander with the given config and executor
529    pub fn new(config: FederatedGraphRAGConfig, executor: Arc<E>) -> Self {
530        Self { config, executor }
531    }
532
533    /// Expand subgraphs for the given seed entities across all active endpoints.
534    ///
535    /// Queries are issued concurrently (bounded by `config.max_concurrency`).
536    /// If `config.partial_results_ok` is true, endpoint failures are logged but
537    /// do not abort the overall operation.
538    pub async fn expand_federated(
539        &self,
540        seeds: &[ScoredEntity],
541        endpoints: Option<&[String]>,
542    ) -> GraphRAGResult<KnowledgeGraph> {
543        if seeds.is_empty() {
544            return Ok(KnowledgeGraph::new());
545        }
546
547        let seed_uris: Vec<&str> = seeds.iter().map(|s| s.uri.as_str()).collect();
548
549        // Determine which endpoints to query
550        let active: Vec<&EndpointConfig> = match endpoints {
551            Some(names) => self
552                .config
553                .active_endpoints()
554                .into_iter()
555                .filter(|ep| names.iter().any(|n| n == &ep.name))
556                .collect(),
557            None => self.config.active_endpoints(),
558        };
559
560        if active.is_empty() {
561            return Err(DistributedError::NoHealthyEndpoints.into());
562        }
563
564        info!(
565            "Federated expansion: {} seeds across {} endpoints",
566            seeds.len(),
567            active.len()
568        );
569
570        let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
571        let results: Arc<Mutex<Vec<EndpointResult>>> = Arc::new(Mutex::new(Vec::new()));
572        let mut handles = Vec::new();
573
574        for ep in active {
575            let ep = ep.clone();
576            let executor = Arc::clone(&self.executor);
577            let sem = Arc::clone(&semaphore);
578            let results = Arc::clone(&results);
579            let seed_uris: Vec<String> = seed_uris.iter().map(|s| s.to_string()).collect();
580            let timeout_ms = ep.timeout_ms.unwrap_or(self.config.global_timeout_ms);
581            let timeout = Duration::from_millis(timeout_ms);
582            let retry_count = self.config.retry_count;
583            let retry_delay = Duration::from_millis(self.config.retry_delay_ms);
584            let partial_ok = self.config.partial_results_ok;
585
586            let handle = tokio::spawn(async move {
587                let _permit = match sem.acquire_owned().await {
588                    Ok(p) => p,
589                    Err(e) => {
590                        warn!("Semaphore acquire failed: {e}");
591                        return;
592                    }
593                };
594
595                let sparql = build_seed_expansion_sparql(
596                    &seed_uris.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
597                    ep.graph_uri.as_deref(),
598                    ep.max_triples,
599                );
600
601                let start = Instant::now();
602                let mut last_err = None;
603
604                for attempt in 0..=retry_count {
605                    if attempt > 0 {
606                        tokio::time::sleep(retry_delay).await;
607                    }
608
609                    match executor.construct(&ep, &sparql, timeout).await {
610                        Ok(triples) => {
611                            let latency_ms = start.elapsed().as_millis() as u64;
612                            debug!(
613                                endpoint = %ep.name,
614                                triples = triples.len(),
615                                latency_ms,
616                                "Endpoint query succeeded"
617                            );
618                            let mut guard = results.lock().await;
619                            guard.push(EndpointResult {
620                                endpoint_name: ep.name.clone(),
621                                triples,
622                                latency_ms,
623                            });
624                            return;
625                        }
626                        Err(e) => {
627                            warn!(
628                                endpoint = %ep.name,
629                                attempt,
630                                error = %e,
631                                "Endpoint query failed"
632                            );
633                            last_err = Some(e);
634                        }
635                    }
636                }
637
638                if !partial_ok {
639                    warn!(
640                        endpoint = %ep.name,
641                        error = ?last_err,
642                        "Endpoint permanently failed and partial_results_ok=false"
643                    );
644                }
645            });
646
647            handles.push(handle);
648        }
649
650        // Wait for all tasks
651        for h in handles {
652            if let Err(e) = h.await {
653                warn!("Task join error: {e}");
654            }
655        }
656
657        let endpoint_results = Arc::try_unwrap(results)
658            .map_err(|_| GraphRAGError::InternalError("Arc unwrap failed".into()))?
659            .into_inner();
660
661        if endpoint_results.is_empty() && !self.config.partial_results_ok {
662            return Err(DistributedError::NoHealthyEndpoints.into());
663        }
664
665        self.merge_results(endpoint_results)
666    }
667
668    /// Merge endpoint results into a unified [`KnowledgeGraph`], deduplicating
669    /// triples and recording provenance.
670    fn merge_results(&self, results: Vec<EndpointResult>) -> GraphRAGResult<KnowledgeGraph> {
671        let mut kg = KnowledgeGraph::new();
672        // Use a set to deduplicate (subject, predicate, object)
673        let mut seen: HashSet<(String, String, String)> = HashSet::new();
674
675        // Sort by endpoint priority descending (higher priority wins dedup)
676        let mut priority_map: HashMap<String, f64> = HashMap::new();
677        for ep in &self.config.endpoints {
678            priority_map.insert(ep.name.clone(), ep.priority);
679        }
680
681        let mut sorted_results = results;
682        sorted_results.sort_by(|a, b| {
683            let pa = priority_map.get(&a.endpoint_name).copied().unwrap_or(1.0);
684            let pb = priority_map.get(&b.endpoint_name).copied().unwrap_or(1.0);
685            pb.partial_cmp(&pa).unwrap_or(std::cmp::Ordering::Equal)
686        });
687
688        for result in sorted_results {
689            for triple in result.triples {
690                let key = (
691                    triple.subject.clone(),
692                    triple.predicate.clone(),
693                    triple.object.clone(),
694                );
695                if seen.insert(key) {
696                    kg.triples.push(triple);
697                    kg.provenance.push(result.endpoint_name.clone());
698                }
699            }
700        }
701
702        Ok(kg)
703    }
704}
705
706// ─────────────────────────────────────────────────────────────────────────────
707// DistributedEntityResolver
708// ─────────────────────────────────────────────────────────────────────────────
709
/// Resolves entity identity across endpoints using owl:sameAs links.
///
/// The resolver computes the transitive sameAs closure: if A sameAs B and
/// B sameAs C, all three are placed in the same equivalence class.
///
/// Generic over the [`EndpointExecutor`] used to talk to the endpoints.
pub struct DistributedEntityResolver<E: EndpointExecutor> {
    /// Federation settings: endpoints, timeouts, concurrency and the sameAs
    /// depth limit
    config: FederatedGraphRAGConfig,
    /// Executor shared with the tasks spawned for each endpoint
    executor: Arc<E>,
}
718
719impl<E: EndpointExecutor + 'static> DistributedEntityResolver<E> {
720    /// Create a new resolver
721    pub fn new(config: FederatedGraphRAGConfig, executor: Arc<E>) -> Self {
722        Self { config, executor }
723    }
724
725    /// Compute the transitive owl:sameAs closure for the given URIs across all
726    /// active endpoints.
727    ///
728    /// Returns a map from each input URI (and discovered aliases) to the
729    /// canonical representative URI for its equivalence class.
730    pub async fn same_as_closure(
731        &self,
732        uris: &[String],
733    ) -> GraphRAGResult<HashMap<String, String>> {
734        if uris.is_empty() {
735            return Ok(HashMap::new());
736        }
737
738        // Union-Find structure for equivalence classes
739        let parent: Arc<RwLock<HashMap<String, String>>> = Arc::new(RwLock::new(HashMap::new()));
740
741        // Initialize each URI as its own parent
742        {
743            let mut p = parent.write().await;
744            for uri in uris {
745                p.insert(uri.clone(), uri.clone());
746            }
747        }
748
749        // BFS frontier: expand sameAs links up to max depth
750        let mut frontier: VecDeque<String> = uris.iter().cloned().collect();
751        let mut visited: HashSet<String> = HashSet::from_iter(uris.iter().cloned());
752        let mut depth = 0usize;
753
754        while !frontier.is_empty() && depth < self.config.same_as_max_depth {
755            let batch: Vec<String> = frontier.drain(..).collect();
756            let batch_refs: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
757
758            // Query all endpoints for sameAs links
759            let links = self.fetch_same_as_links(&batch_refs).await?;
760
761            let mut p = parent.write().await;
762            for (a, b) in links {
763                // Ensure both exist in the union-find
764                p.entry(a.clone()).or_insert_with(|| a.clone());
765                p.entry(b.clone()).or_insert_with(|| b.clone());
766
767                // Union a and b
768                let root_a = find_root_path(&p, &a);
769                let root_b = find_root_path(&p, &b);
770                if root_a != root_b {
771                    // Prefer lexicographically smaller URI as canonical
772                    let canonical = if root_a <= root_b {
773                        root_a.clone()
774                    } else {
775                        root_b.clone()
776                    };
777                    p.insert(root_a, canonical.clone());
778                    p.insert(root_b, canonical);
779                }
780
781                // Add newly discovered URIs to the frontier
782                if !visited.contains(&b) {
783                    visited.insert(b.clone());
784                    frontier.push_back(b);
785                }
786            }
787
788            depth += 1;
789        }
790
791        // Flatten all paths to canonical roots
792        let p = parent.read().await;
793        let mut result = HashMap::new();
794        for uri in p.keys() {
795            let canonical = find_root_path(&p, uri);
796            result.insert(uri.clone(), canonical);
797        }
798        Ok(result)
799    }
800
801    /// Fetch raw owl:sameAs pairs from all active endpoints for the given URIs
802    async fn fetch_same_as_links(&self, uris: &[&str]) -> GraphRAGResult<Vec<(String, String)>> {
803        let active = self.config.active_endpoints();
804        let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
805        let pairs: Arc<Mutex<Vec<(String, String)>>> = Arc::new(Mutex::new(Vec::new()));
806        let mut handles = Vec::new();
807
808        for ep in active {
809            let ep = ep.clone();
810            let executor = Arc::clone(&self.executor);
811            let sem = Arc::clone(&semaphore);
812            let pairs = Arc::clone(&pairs);
813            let uris_owned: Vec<String> = uris.iter().map(|s| s.to_string()).collect();
814            let timeout_ms = ep.timeout_ms.unwrap_or(self.config.global_timeout_ms);
815            let timeout = Duration::from_millis(timeout_ms);
816
817            let handle = tokio::spawn(async move {
818                let _permit = match sem.acquire_owned().await {
819                    Ok(p) => p,
820                    Err(_) => return,
821                };
822
823                let sparql = build_same_as_sparql(
824                    &uris_owned.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
825                    ep.graph_uri.as_deref(),
826                );
827
828                match executor.select(&ep, &sparql, timeout).await {
829                    Ok(rows) => {
830                        let mut guard = pairs.lock().await;
831                        for row in rows {
832                            if let (Some(a), Some(b)) = (row.get("a"), row.get("b")) {
833                                guard.push((a.clone(), b.clone()));
834                            }
835                        }
836                    }
837                    Err(e) => {
838                        debug!(endpoint = %ep.name, error = %e, "sameAs fetch failed");
839                    }
840                }
841            });
842
843            handles.push(handle);
844        }
845
846        for h in handles {
847            let _ = h.await;
848        }
849
850        let guard = Arc::try_unwrap(pairs)
851            .map_err(|_| GraphRAGError::InternalError("Arc unwrap failed".into()))?
852            .into_inner();
853
854        Ok(guard)
855    }
856
857    /// Apply sameAs closure to a knowledge graph, rewriting URIs to canonical forms
858    /// and deduplicating triples that become identical after rewriting.
859    pub fn apply_to_graph(&self, kg: &mut KnowledgeGraph, canonical_map: &HashMap<String, String>) {
860        let canonicalize = |s: &str| -> String {
861            canonical_map
862                .get(s)
863                .cloned()
864                .unwrap_or_else(|| s.to_string())
865        };
866
867        let mut seen: HashSet<(String, String, String)> = HashSet::new();
868        let mut new_triples = Vec::new();
869        let mut new_provenance = Vec::new();
870
871        for (triple, prov) in kg.triples.iter().zip(kg.provenance.iter()) {
872            let s = canonicalize(&triple.subject);
873            let p = triple.predicate.clone();
874            let o = canonicalize(&triple.object);
875            let key = (s.clone(), p.clone(), o.clone());
876            if seen.insert(key) {
877                new_triples.push(Triple::new(s, p, o));
878                new_provenance.push(prov.clone());
879            }
880        }
881
882        kg.triples = new_triples;
883        kg.provenance = new_provenance;
884        kg.canonical_uris = canonical_map.clone();
885
886        // Rebuild equivalence classes
887        let mut classes: HashMap<String, HashSet<String>> = HashMap::new();
888        for (uri, canonical) in canonical_map {
889            classes
890                .entry(canonical.clone())
891                .or_default()
892                .insert(uri.clone());
893        }
894        kg.equivalence_classes = classes.into_values().collect();
895    }
896}
897
/// Follow the `parent` links starting at `uri` and return the root of its set.
///
/// A root is a URI that either has no entry in `parent` or maps to itself. Despite the
/// function name, no path compression takes place: `parent` is only borrowed immutably.
///
/// The walk is bounded by `parent.len()` hops. An acyclic chain can use each key at most
/// once, so every well-formed chain reaches its real root regardless of its length; if the
/// map contains a cycle the walk stops after that many hops and returns the URI it
/// reached, so the function always terminates.
fn find_root_path(parent: &HashMap<String, String>, uri: &str) -> String {
    let mut current = uri;
    for _ in 0..parent.len() {
        match parent.get(current) {
            // Keep climbing while there is a parent different from the node itself.
            Some(next) if next.as_str() != current => current = next.as_str(),
            // No entry, or a self-loop: `current` is the root.
            _ => break,
        }
    }
    current.to_string()
}
914
915// ─────────────────────────────────────────────────────────────────────────────
916// FederatedContextBuilder
917// ─────────────────────────────────────────────────────────────────────────────
918
/// Strategy used by [`FederatedContextBuilder`] to order triples in the generated context.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ContextOrderingStrategy {
    /// Triples from endpoints with a larger priority value come first.
    ByEndpointPriority,
    /// Triples from endpoints with a lower recorded latency come first; endpoints
    /// without a recorded latency are treated as the slowest.
    ByLatency,
    /// Keep the order in which the triples are stored in the graph.
    Insertion,
}
929
/// Configuration for the federated context builder
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederatedContextConfig {
    /// Maximum number of fact lines (triples) written into the context
    pub max_context_triples: usize,
    /// Size budget for the context string; no further facts are appended once the
    /// context has reached it.
    /// NOTE(review): the builder compares this against `String::len`, i.e. UTF-8 bytes.
    pub max_context_chars: usize,
    /// Triple ordering strategy
    pub ordering: ContextOrderingStrategy,
    /// Whether each fact line is suffixed with the name of the endpoint it came from
    pub include_provenance: bool,
    /// Triples from endpoints whose priority is below this value are left out
    /// (endpoints without a registered priority count as `1.0`)
    pub min_endpoint_priority: f64,
    /// Whether to emit an "Entity Equivalences" section listing equivalence classes
    /// with more than one member
    pub include_equivalences: bool,
}
946
947impl Default for FederatedContextConfig {
948    fn default() -> Self {
949        Self {
950            max_context_triples: 500,
951            max_context_chars: 50_000,
952            ordering: ContextOrderingStrategy::ByEndpointPriority,
953            include_provenance: false,
954            include_equivalences: false,
955            min_endpoint_priority: 0.0,
956        }
957    }
958}
959
/// Builds RAG context strings from distributed knowledge graphs.
///
/// Endpoint priorities are copied from the federation configuration when the builder is
/// created; endpoint latencies start empty and are filled in through `record_latency`.
pub struct FederatedContextBuilder {
    /// Limits, ordering strategy and annotation switches for the generated context
    config: FederatedContextConfig,
    /// Endpoint name → priority, captured in `new`
    endpoint_priorities: HashMap<String, f64>,
    /// Endpoint name → most recently recorded latency in milliseconds
    endpoint_latencies: Arc<RwLock<HashMap<String, u64>>>,
}
968
969impl FederatedContextBuilder {
970    /// Create a new context builder
971    pub fn new(config: FederatedContextConfig, graphrag_config: &FederatedGraphRAGConfig) -> Self {
972        let endpoint_priorities: HashMap<String, f64> = graphrag_config
973            .endpoints
974            .iter()
975            .map(|ep| (ep.name.clone(), ep.priority))
976            .collect();
977
978        Self {
979            config,
980            endpoint_priorities,
981            endpoint_latencies: Arc::new(RwLock::new(HashMap::new())),
982        }
983    }
984
985    /// Record observed latency for an endpoint (used in ByLatency ordering)
986    pub async fn record_latency(&self, endpoint_name: &str, latency_ms: u64) {
987        let mut lats = self.endpoint_latencies.write().await;
988        lats.insert(endpoint_name.to_string(), latency_ms);
989    }
990
991    /// Build a context string from a [`KnowledgeGraph`].
992    ///
993    /// Triples are ordered according to the configured strategy, truncated to
994    /// respect both `max_context_triples` and `max_context_chars`.
995    pub async fn build_context(&self, kg: &KnowledgeGraph, query: &str) -> GraphRAGResult<String> {
996        if kg.is_empty() {
997            return Ok(String::new());
998        }
999
1000        // Create (triple_index, priority_key) pairs for sorting
1001        let latencies = self.endpoint_latencies.read().await;
1002        let mut indexed: Vec<(usize, f64)> = kg
1003            .triples
1004            .iter()
1005            .enumerate()
1006            .map(|(i, _)| {
1007                let ep = kg.provenance.get(i).map(|s| s.as_str()).unwrap_or("");
1008                let sort_key = match self.config.ordering {
1009                    ContextOrderingStrategy::ByEndpointPriority => {
1010                        // Higher priority → lower sort key (we sort ascending, then reverse)
1011                        self.endpoint_priorities.get(ep).copied().unwrap_or(1.0)
1012                    }
1013                    ContextOrderingStrategy::ByLatency => {
1014                        // Lower latency → higher priority
1015                        let lat = latencies.get(ep).copied().unwrap_or(u64::MAX);
1016                        // Invert: smaller latency → larger sort key
1017                        1.0 / (lat as f64 + 1.0)
1018                    }
1019                    ContextOrderingStrategy::Insertion => i as f64,
1020                };
1021                (i, sort_key)
1022            })
1023            .filter(|(i, _)| {
1024                let ep = kg.provenance.get(*i).map(|s| s.as_str()).unwrap_or("");
1025                let prio = self.endpoint_priorities.get(ep).copied().unwrap_or(1.0);
1026                prio >= self.config.min_endpoint_priority
1027            })
1028            .collect();
1029
1030        // Sort: ByEndpointPriority and ByLatency both want descending key
1031        match self.config.ordering {
1032            ContextOrderingStrategy::ByEndpointPriority | ContextOrderingStrategy::ByLatency => {
1033                indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1034            }
1035            ContextOrderingStrategy::Insertion => {
1036                indexed.sort_by_key(|(i, _)| *i);
1037            }
1038        }
1039
1040        let mut context = format!("## Knowledge Graph Context\n\nQuery: {}\n\n", query);
1041
1042        // Add equivalence class info if requested
1043        if self.config.include_equivalences && !kg.equivalence_classes.is_empty() {
1044            context.push_str("### Entity Equivalences\n");
1045            for class in &kg.equivalence_classes {
1046                if class.len() > 1 {
1047                    let mut members: Vec<&str> = class.iter().map(|s| s.as_str()).collect();
1048                    members.sort();
1049                    context.push_str(&format!("- {}\n", members.join(" ≡ ")));
1050                }
1051            }
1052            context.push('\n');
1053        }
1054
1055        context.push_str("### Facts\n\n");
1056
1057        for (triple_count, (idx, _)) in indexed.iter().enumerate() {
1058            if triple_count >= self.config.max_context_triples {
1059                break;
1060            }
1061            if context.len() >= self.config.max_context_chars {
1062                break;
1063            }
1064
1065            let triple = &kg.triples[*idx];
1066            let line = if self.config.include_provenance {
1067                let ep = kg.provenance.get(*idx).map(|s| s.as_str()).unwrap_or("?");
1068                format!(
1069                    "- {} → {} → {} [{}]\n",
1070                    triple.subject, triple.predicate, triple.object, ep
1071                )
1072            } else {
1073                format!(
1074                    "- {} → {} → {}\n",
1075                    triple.subject, triple.predicate, triple.object
1076                )
1077            };
1078
1079            context.push_str(&line);
1080        }
1081
1082        Ok(context)
1083    }
1084}
1085
1086// ─────────────────────────────────────────────────────────────────────────────
1087// DistributedGraphRAGMetrics
1088// ─────────────────────────────────────────────────────────────────────────────
1089
/// Per-endpoint performance snapshot
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointMetrics {
    /// Endpoint name
    pub name: String,
    /// Total number of queries recorded for this endpoint (successes and failures)
    pub total_queries: u64,
    /// Number of successful queries
    pub successful_queries: u64,
    /// Number of failed queries
    pub failed_queries: u64,
    /// Total triples retrieved from this endpoint
    pub total_triples: u64,
    /// Exponential moving average of the latency of successful queries, in milliseconds
    pub avg_latency_ms: f64,
    /// Minimum observed latency in milliseconds (`u64::MAX` until a success is recorded)
    pub min_latency_ms: u64,
    /// Maximum observed latency in milliseconds (`0` until a success is recorded)
    pub max_latency_ms: u64,
    /// Hit rate: fraction of queries that returned ≥1 triple
    pub hit_rate: f64,
}
1112
1113impl EndpointMetrics {
1114    fn new(name: impl Into<String>) -> Self {
1115        Self {
1116            name: name.into(),
1117            total_queries: 0,
1118            successful_queries: 0,
1119            failed_queries: 0,
1120            total_triples: 0,
1121            avg_latency_ms: 0.0,
1122            min_latency_ms: u64::MAX,
1123            max_latency_ms: 0,
1124            hit_rate: 0.0,
1125        }
1126    }
1127}
1128
/// Aggregate metrics across all endpoints
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AggregateMetrics {
    /// Number of completed federation queries recorded
    pub total_federation_queries: u64,
    /// Total triples gathered across all recorded federation queries
    pub total_triples_gathered: u64,
    /// Number of entity resolution operations
    pub entity_resolution_ops: u64,
    /// Exponential moving average of the wall-clock federation latency, in milliseconds
    pub avg_federation_latency_ms: f64,
    /// Number of federation queries flagged as having a partial failure
    /// (some endpoints failed)
    pub partial_failure_count: u64,
}
1143
/// Thread-safe metrics tracker for distributed GraphRAG operations.
///
/// Both counters sit behind async `RwLock`s, so recording and reading go through
/// `async` methods that take `&self`.
pub struct DistributedGraphRAGMetrics {
    /// Per-endpoint counters, keyed by endpoint name
    endpoint_metrics: Arc<RwLock<HashMap<String, EndpointMetrics>>>,
    /// Aggregate counters
    aggregate: Arc<RwLock<AggregateMetrics>>,
    /// EMA smoothing factor (0 < alpha ≤ 1); weight given to the newest sample
    ema_alpha: f64,
}
1153
1154impl DistributedGraphRAGMetrics {
1155    /// Create a new metrics tracker
1156    pub fn new(endpoints: &[EndpointConfig]) -> Self {
1157        let mut ep_map = HashMap::new();
1158        for ep in endpoints {
1159            ep_map.insert(ep.name.clone(), EndpointMetrics::new(&ep.name));
1160        }
1161
1162        Self {
1163            endpoint_metrics: Arc::new(RwLock::new(ep_map)),
1164            aggregate: Arc::new(RwLock::new(AggregateMetrics::default())),
1165            ema_alpha: 0.2,
1166        }
1167    }
1168
1169    /// Record a successful query result for an endpoint
1170    pub async fn record_success(&self, endpoint_name: &str, latency_ms: u64, triple_count: usize) {
1171        let mut guard = self.endpoint_metrics.write().await;
1172        let m = guard
1173            .entry(endpoint_name.to_string())
1174            .or_insert_with(|| EndpointMetrics::new(endpoint_name));
1175
1176        m.total_queries += 1;
1177        m.successful_queries += 1;
1178        m.total_triples += triple_count as u64;
1179
1180        // Update EMA latency
1181        if m.total_queries == 1 {
1182            m.avg_latency_ms = latency_ms as f64;
1183        } else {
1184            m.avg_latency_ms =
1185                self.ema_alpha * latency_ms as f64 + (1.0 - self.ema_alpha) * m.avg_latency_ms;
1186        }
1187
1188        // Update min/max
1189        if latency_ms < m.min_latency_ms {
1190            m.min_latency_ms = latency_ms;
1191        }
1192        if latency_ms > m.max_latency_ms {
1193            m.max_latency_ms = latency_ms;
1194        }
1195
1196        // Recompute hit rate
1197        let hits = m.successful_queries - if triple_count == 0 { 1 } else { 0 };
1198        m.hit_rate = hits as f64 / m.total_queries as f64;
1199    }
1200
1201    /// Record a failed query for an endpoint
1202    pub async fn record_failure(&self, endpoint_name: &str) {
1203        let mut guard = self.endpoint_metrics.write().await;
1204        let m = guard
1205            .entry(endpoint_name.to_string())
1206            .or_insert_with(|| EndpointMetrics::new(endpoint_name));
1207
1208        m.total_queries += 1;
1209        m.failed_queries += 1;
1210        // Recompute hit rate (failure counts as miss)
1211        m.hit_rate = if m.total_queries > 0 {
1212            m.successful_queries as f64 / m.total_queries as f64
1213        } else {
1214            0.0
1215        };
1216    }
1217
1218    /// Record a completed federation query
1219    pub async fn record_federation_query(
1220        &self,
1221        wall_latency_ms: u64,
1222        total_triples: usize,
1223        had_partial_failure: bool,
1224    ) {
1225        let mut agg = self.aggregate.write().await;
1226        agg.total_federation_queries += 1;
1227        agg.total_triples_gathered += total_triples as u64;
1228        if had_partial_failure {
1229            agg.partial_failure_count += 1;
1230        }
1231        if agg.total_federation_queries == 1 {
1232            agg.avg_federation_latency_ms = wall_latency_ms as f64;
1233        } else {
1234            agg.avg_federation_latency_ms = self.ema_alpha * wall_latency_ms as f64
1235                + (1.0 - self.ema_alpha) * agg.avg_federation_latency_ms;
1236        }
1237    }
1238
1239    /// Record an entity resolution operation
1240    pub async fn record_entity_resolution(&self) {
1241        let mut agg = self.aggregate.write().await;
1242        agg.entity_resolution_ops += 1;
1243    }
1244
1245    /// Retrieve a snapshot of metrics for a specific endpoint
1246    pub async fn endpoint_snapshot(&self, name: &str) -> Option<EndpointMetrics> {
1247        self.endpoint_metrics.read().await.get(name).cloned()
1248    }
1249
1250    /// Retrieve a snapshot of all endpoint metrics
1251    pub async fn all_endpoint_snapshots(&self) -> Vec<EndpointMetrics> {
1252        self.endpoint_metrics
1253            .read()
1254            .await
1255            .values()
1256            .cloned()
1257            .collect()
1258    }
1259
1260    /// Retrieve aggregate metrics
1261    pub async fn aggregate_snapshot(&self) -> AggregateMetrics {
1262        self.aggregate.read().await.clone()
1263    }
1264
1265    /// Return the endpoint name with the lowest average latency
1266    pub async fn fastest_endpoint(&self) -> Option<String> {
1267        let guard = self.endpoint_metrics.read().await;
1268        guard
1269            .values()
1270            .filter(|m| m.successful_queries > 0)
1271            .min_by(|a, b| {
1272                a.avg_latency_ms
1273                    .partial_cmp(&b.avg_latency_ms)
1274                    .unwrap_or(std::cmp::Ordering::Equal)
1275            })
1276            .map(|m| m.name.clone())
1277    }
1278
1279    /// Return the endpoint with the highest hit rate
1280    pub async fn best_hit_rate_endpoint(&self) -> Option<String> {
1281        let guard = self.endpoint_metrics.read().await;
1282        guard
1283            .values()
1284            .filter(|m| m.total_queries > 0)
1285            .max_by(|a, b| {
1286                a.hit_rate
1287                    .partial_cmp(&b.hit_rate)
1288                    .unwrap_or(std::cmp::Ordering::Equal)
1289            })
1290            .map(|m| m.name.clone())
1291    }
1292}
1293
1294// ─────────────────────────────────────────────────────────────────────────────
1295// Tests
1296// ─────────────────────────────────────────────────────────────────────────────
1297
1298#[cfg(test)]
1299mod tests {
1300    use super::*;
1301    use crate::{GraphRAGResult, ScoreSource};
1302    use async_trait::async_trait;
1303    use std::collections::HashMap;
1304
1305    // ── Mock executor ────────────────────────────────────────────────────────
1306
    /// In-memory [`EndpointExecutor`] that serves canned data per endpoint name instead
    /// of issuing SPARQL requests.
    struct MockExecutor {
        /// Triples returned by the `construct` call (keyed by endpoint name)
        triples_by_endpoint: HashMap<String, Vec<Triple>>,
        /// sameAs pairs returned by the `select` call (keyed by endpoint name)
        same_as_by_endpoint: HashMap<String, Vec<(String, String)>>,
    }
1313
1314    impl MockExecutor {
1315        fn new() -> Self {
1316            Self {
1317                triples_by_endpoint: HashMap::new(),
1318                same_as_by_endpoint: HashMap::new(),
1319            }
1320        }
1321
1322        fn with_triples(mut self, endpoint: &str, triples: Vec<Triple>) -> Self {
1323            self.triples_by_endpoint
1324                .insert(endpoint.to_string(), triples);
1325            self
1326        }
1327
1328        fn with_same_as(mut self, endpoint: &str, pairs: Vec<(String, String)>) -> Self {
1329            self.same_as_by_endpoint.insert(endpoint.to_string(), pairs);
1330            self
1331        }
1332    }
1333
1334    #[async_trait]
1335    impl EndpointExecutor for MockExecutor {
1336        async fn construct(
1337            &self,
1338            endpoint: &EndpointConfig,
1339            _sparql: &str,
1340            _timeout: Duration,
1341        ) -> GraphRAGResult<Vec<Triple>> {
1342            Ok(self
1343                .triples_by_endpoint
1344                .get(&endpoint.name)
1345                .cloned()
1346                .unwrap_or_default())
1347        }
1348
1349        async fn select(
1350            &self,
1351            endpoint: &EndpointConfig,
1352            _sparql: &str,
1353            _timeout: Duration,
1354        ) -> GraphRAGResult<Vec<HashMap<String, String>>> {
1355            let pairs = self
1356                .same_as_by_endpoint
1357                .get(&endpoint.name)
1358                .cloned()
1359                .unwrap_or_default();
1360            Ok(pairs
1361                .into_iter()
1362                .map(|(a, b)| {
1363                    let mut m = HashMap::new();
1364                    m.insert("a".to_string(), a);
1365                    m.insert("b".to_string(), b);
1366                    m
1367                })
1368                .collect())
1369        }
1370    }
1371
1372    // ── Helper constructors ──────────────────────────────────────────────────
1373
1374    fn make_endpoint(name: &str, priority: f64) -> EndpointConfig {
1375        EndpointConfig {
1376            name: name.to_string(),
1377            url: format!("http://example.org/{}/sparql", name),
1378            auth: EndpointAuth::None,
1379            timeout_ms: Some(5_000),
1380            priority,
1381            enabled: true,
1382            graph_uri: None,
1383            max_triples: 1_000,
1384        }
1385    }
1386
1387    fn make_seed(uri: &str, score: f64) -> ScoredEntity {
1388        ScoredEntity {
1389            uri: uri.to_string(),
1390            score,
1391            source: ScoreSource::Vector,
1392            metadata: HashMap::new(),
1393        }
1394    }
1395
    /// Shorthand for `Triple::new` with subject `s`, predicate `p` and object `o`.
    fn make_triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(s, p, o)
    }
1399
1400    // ── test_federated_config_validation ─────────────────────────────────────
1401
1402    #[test]
1403    fn test_federated_config_validation_valid() {
1404        let config = FederatedGraphRAGConfig {
1405            endpoints: vec![make_endpoint("ep1", 1.0)],
1406            global_timeout_ms: 10_000,
1407            max_concurrency: 4,
1408            same_as_max_depth: 3,
1409            ..Default::default()
1410        };
1411        assert!(config.validate().is_ok());
1412    }
1413
1414    #[test]
1415    fn test_federated_config_validation_zero_timeout() {
1416        let config = FederatedGraphRAGConfig {
1417            global_timeout_ms: 0,
1418            ..Default::default()
1419        };
1420        assert!(config.validate().is_err());
1421    }
1422
1423    #[test]
1424    fn test_federated_config_validation_zero_concurrency() {
1425        let config = FederatedGraphRAGConfig {
1426            max_concurrency: 0,
1427            global_timeout_ms: 1_000,
1428            ..Default::default()
1429        };
1430        assert!(config.validate().is_err());
1431    }
1432
1433    #[test]
1434    fn test_federated_config_validation_empty_url() {
1435        let mut ep = make_endpoint("ep1", 1.0);
1436        ep.url = String::new();
1437        let config = FederatedGraphRAGConfig {
1438            endpoints: vec![ep],
1439            global_timeout_ms: 5_000,
1440            max_concurrency: 2,
1441            same_as_max_depth: 3,
1442            ..Default::default()
1443        };
1444        assert!(config.validate().is_err());
1445    }
1446
1447    #[test]
1448    fn test_federated_config_active_endpoints_filters_disabled() {
1449        let mut ep_disabled = make_endpoint("ep_off", 1.0);
1450        ep_disabled.enabled = false;
1451        let config = FederatedGraphRAGConfig {
1452            endpoints: vec![make_endpoint("ep_on", 1.0), ep_disabled],
1453            global_timeout_ms: 5_000,
1454            max_concurrency: 2,
1455            same_as_max_depth: 3,
1456            ..Default::default()
1457        };
1458        let active = config.active_endpoints();
1459        assert_eq!(active.len(), 1);
1460        assert_eq!(active[0].name, "ep_on");
1461    }
1462
1463    // ── test_federated_subgraph_expander ─────────────────────────────────────
1464
1465    #[tokio::test]
1466    async fn test_federated_expansion_merges_two_endpoints() {
1467        let triples_a = vec![
1468            make_triple("http://a/s1", "http://p", "http://a/o1"),
1469            make_triple("http://a/s2", "http://p", "http://a/o2"),
1470        ];
1471        let triples_b = vec![
1472            make_triple("http://b/s1", "http://p", "http://b/o1"),
1473            make_triple("http://a/s1", "http://p", "http://a/o1"), // duplicate
1474        ];
1475        let executor = MockExecutor::new()
1476            .with_triples("ep_a", triples_a)
1477            .with_triples("ep_b", triples_b);
1478
1479        let config = FederatedGraphRAGConfig {
1480            endpoints: vec![make_endpoint("ep_a", 2.0), make_endpoint("ep_b", 1.0)],
1481            global_timeout_ms: 5_000,
1482            max_concurrency: 4,
1483            same_as_max_depth: 3,
1484            partial_results_ok: true,
1485            ..Default::default()
1486        };
1487
1488        let expander = FederatedSubgraphExpander::new(config, Arc::new(executor));
1489        let seeds = vec![make_seed("http://a/s1", 0.9)];
1490        let kg = expander
1491            .expand_federated(&seeds, None)
1492            .await
1493            .expect("should succeed");
1494
1495        // 3 unique triples: 2 from ep_a + 1 new from ep_b (duplicate filtered)
1496        assert_eq!(kg.triple_count(), 3);
1497        assert!(!kg.is_empty());
1498    }
1499
1500    #[tokio::test]
1501    async fn test_federated_expansion_empty_seeds() {
1502        let executor = MockExecutor::new();
1503        let config = FederatedGraphRAGConfig {
1504            endpoints: vec![make_endpoint("ep_a", 1.0)],
1505            global_timeout_ms: 5_000,
1506            max_concurrency: 2,
1507            same_as_max_depth: 3,
1508            ..Default::default()
1509        };
1510        let expander = FederatedSubgraphExpander::new(config, Arc::new(executor));
1511        let kg = expander
1512            .expand_federated(&[], None)
1513            .await
1514            .expect("should succeed");
1515        assert!(kg.is_empty());
1516    }
1517
1518    #[tokio::test]
1519    async fn test_federated_expansion_no_active_endpoints() {
1520        let mut ep = make_endpoint("ep1", 1.0);
1521        ep.enabled = false;
1522        let executor = MockExecutor::new();
1523        let config = FederatedGraphRAGConfig {
1524            endpoints: vec![ep],
1525            global_timeout_ms: 5_000,
1526            max_concurrency: 2,
1527            same_as_max_depth: 3,
1528            ..Default::default()
1529        };
1530        let expander = FederatedSubgraphExpander::new(config, Arc::new(executor));
1531        let seeds = vec![make_seed("http://s", 0.9)];
1532        let result = expander.expand_federated(&seeds, None).await;
1533        assert!(result.is_err());
1534    }
1535
1536    // ── test_distributed_entity_resolver_sameAs ──────────────────────────────
1537
1538    #[tokio::test]
1539    async fn test_distributed_entity_resolver_same_as_direct() {
1540        let same_as_pairs = vec![("http://a/e1".to_string(), "http://b/e1".to_string())];
1541        let executor = MockExecutor::new().with_same_as("ep_a", same_as_pairs);
1542
1543        let config = FederatedGraphRAGConfig {
1544            endpoints: vec![make_endpoint("ep_a", 1.0)],
1545            global_timeout_ms: 5_000,
1546            max_concurrency: 2,
1547            same_as_max_depth: 3,
1548            ..Default::default()
1549        };
1550
1551        let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
1552        let uris = vec!["http://a/e1".to_string()];
1553        let closure = resolver
1554            .same_as_closure(&uris)
1555            .await
1556            .expect("should succeed");
1557
1558        // Both http://a/e1 and http://b/e1 should map to the same canonical URI
1559        let canon_a = closure.get("http://a/e1").expect("should succeed");
1560        let canon_b = closure.get("http://b/e1").expect("should succeed");
1561        assert_eq!(
1562            canon_a, canon_b,
1563            "Same-as entities should share canonical URI"
1564        );
1565    }
1566
1567    #[tokio::test]
1568    async fn test_distributed_entity_resolver_no_links() {
1569        let executor = MockExecutor::new();
1570        let config = FederatedGraphRAGConfig {
1571            endpoints: vec![make_endpoint("ep_a", 1.0)],
1572            global_timeout_ms: 5_000,
1573            max_concurrency: 2,
1574            same_as_max_depth: 3,
1575            ..Default::default()
1576        };
1577
1578        let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
1579        let uris = vec!["http://example.org/e1".to_string()];
1580        let closure = resolver
1581            .same_as_closure(&uris)
1582            .await
1583            .expect("should succeed");
1584
1585        // Without any sameAs links, e1 maps to itself
1586        let canon = closure
1587            .get("http://example.org/e1")
1588            .expect("should succeed");
1589        assert_eq!(canon, "http://example.org/e1");
1590    }
1591
1592    #[tokio::test]
1593    async fn test_distributed_entity_resolver_transitive_chain() {
1594        // A sameAs B, B sameAs C — all three should end up in the same class
1595        let same_as_pairs_ep1 = vec![("http://a/e1".to_string(), "http://b/e1".to_string())];
1596        let same_as_pairs_ep2 = vec![("http://b/e1".to_string(), "http://c/e1".to_string())];
1597        let executor = MockExecutor::new()
1598            .with_same_as("ep1", same_as_pairs_ep1)
1599            .with_same_as("ep2", same_as_pairs_ep2);
1600
1601        let config = FederatedGraphRAGConfig {
1602            endpoints: vec![make_endpoint("ep1", 1.0), make_endpoint("ep2", 1.0)],
1603            global_timeout_ms: 5_000,
1604            max_concurrency: 2,
1605            same_as_max_depth: 5,
1606            ..Default::default()
1607        };
1608
1609        let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
1610        let uris = vec!["http://a/e1".to_string()];
1611        let closure = resolver
1612            .same_as_closure(&uris)
1613            .await
1614            .expect("should succeed");
1615
1616        // Check that the discovered URIs (at least a/e1 and b/e1) share a canonical form
1617        if let Some(canon_a) = closure.get("http://a/e1") {
1618            if let Some(canon_b) = closure.get("http://b/e1") {
1619                assert_eq!(canon_a, canon_b);
1620            }
1621        }
1622    }
1623
1624    #[test]
1625    fn test_apply_to_graph_rewrites_uris() {
1626        let executor = MockExecutor::new();
1627        let config = FederatedGraphRAGConfig::default();
1628        let resolver = DistributedEntityResolver::new(config, Arc::new(executor));
1629
1630        let mut kg = KnowledgeGraph::new();
1631        kg.triples = vec![
1632            make_triple("http://a/e1", "http://p", "http://b/e1"),
1633            make_triple("http://a/e1", "http://p", "http://a/e1"), // self-loop
1634        ];
1635        kg.provenance = vec!["ep_a".to_string(), "ep_a".to_string()];
1636
1637        let mut canonical = HashMap::new();
1638        canonical.insert("http://a/e1".to_string(), "http://canonical/e1".to_string());
1639        canonical.insert("http://b/e1".to_string(), "http://canonical/e1".to_string());
1640
1641        resolver.apply_to_graph(&mut kg, &canonical);
1642
1643        // After rewriting: both triples become <canonical/e1> <p> <canonical/e1>
1644        // which is the same — deduplication keeps only 1
1645        assert_eq!(kg.triple_count(), 1);
1646        assert_eq!(kg.triples[0].subject, "http://canonical/e1");
1647        assert_eq!(kg.triples[0].object, "http://canonical/e1");
1648    }
1649
1650    // ── test_federated_context_builder ───────────────────────────────────────
1651
1652    #[tokio::test]
1653    async fn test_federated_context_builder_basic() {
1654        let graphrag_config = FederatedGraphRAGConfig {
1655            endpoints: vec![make_endpoint("ep_a", 2.0), make_endpoint("ep_b", 1.0)],
1656            global_timeout_ms: 5_000,
1657            max_concurrency: 2,
1658            same_as_max_depth: 3,
1659            ..Default::default()
1660        };
1661
1662        let ctx_config = FederatedContextConfig {
1663            max_context_triples: 100,
1664            max_context_chars: 10_000,
1665            ordering: ContextOrderingStrategy::ByEndpointPriority,
1666            include_provenance: true,
1667            include_equivalences: false,
1668            ..Default::default()
1669        };
1670
1671        let builder = FederatedContextBuilder::new(ctx_config, &graphrag_config);
1672
1673        let mut kg = KnowledgeGraph::new();
1674        kg.triples = vec![
1675            make_triple("http://s1", "http://p", "http://o1"),
1676            make_triple("http://s2", "http://p", "http://o2"),
1677        ];
1678        kg.provenance = vec!["ep_a".to_string(), "ep_b".to_string()];
1679
1680        let context = builder
1681            .build_context(&kg, "test query")
1682            .await
1683            .expect("should succeed");
1684
1685        assert!(context.contains("test query"));
1686        assert!(context.contains("http://s1"));
1687        assert!(context.contains("http://s2"));
1688        // Provenance included
1689        assert!(context.contains("[ep_a]") || context.contains("[ep_b]"));
1690    }
1691
1692    #[tokio::test]
1693    async fn test_federated_context_builder_empty_kg() {
1694        let graphrag_config = FederatedGraphRAGConfig::default();
1695        let ctx_config = FederatedContextConfig::default();
1696        let builder = FederatedContextBuilder::new(ctx_config, &graphrag_config);
1697        let kg = KnowledgeGraph::new();
1698        let context = builder
1699            .build_context(&kg, "test")
1700            .await
1701            .expect("should succeed");
1702        assert!(context.is_empty());
1703    }
1704
1705    #[tokio::test]
1706    async fn test_federated_context_builder_respects_max_triples() {
1707        let graphrag_config = FederatedGraphRAGConfig {
1708            endpoints: vec![make_endpoint("ep_a", 1.0)],
1709            global_timeout_ms: 5_000,
1710            max_concurrency: 2,
1711            same_as_max_depth: 3,
1712            ..Default::default()
1713        };
1714
1715        let ctx_config = FederatedContextConfig {
1716            max_context_triples: 2,
1717            max_context_chars: 100_000,
1718            ordering: ContextOrderingStrategy::Insertion,
1719            include_provenance: false,
1720            include_equivalences: false,
1721            ..Default::default()
1722        };
1723
1724        let builder = FederatedContextBuilder::new(ctx_config, &graphrag_config);
1725
1726        let mut kg = KnowledgeGraph::new();
1727        kg.triples = (0..10)
1728            .map(|i| {
1729                make_triple(
1730                    &format!("http://s{}", i),
1731                    "http://p",
1732                    &format!("http://o{}", i),
1733                )
1734            })
1735            .collect();
1736        kg.provenance = (0..10).map(|_| "ep_a".to_string()).collect();
1737
1738        let context = builder
1739            .build_context(&kg, "test")
1740            .await
1741            .expect("should succeed");
1742
1743        // Count lines starting with "- " to determine triple count
1744        let triple_lines = context.lines().filter(|l| l.starts_with("- ")).count();
1745        assert!(
1746            triple_lines <= 2,
1747            "Expected at most 2 triples, got {}",
1748            triple_lines
1749        );
1750    }
1751
1752    // ── test_distributed_metrics_tracking ────────────────────────────────────
1753
1754    #[tokio::test]
1755    async fn test_distributed_metrics_tracking_success() {
1756        let endpoints = vec![make_endpoint("ep_a", 1.0), make_endpoint("ep_b", 1.0)];
1757        let metrics = DistributedGraphRAGMetrics::new(&endpoints);
1758
1759        metrics.record_success("ep_a", 150, 42).await;
1760        metrics.record_success("ep_a", 100, 30).await;
1761
1762        let snap = metrics
1763            .endpoint_snapshot("ep_a")
1764            .await
1765            .expect("should succeed");
1766        assert_eq!(snap.total_queries, 2);
1767        assert_eq!(snap.successful_queries, 2);
1768        assert_eq!(snap.failed_queries, 0);
1769        assert_eq!(snap.total_triples, 72);
1770        assert!(snap.avg_latency_ms > 0.0);
1771    }
1772
1773    #[tokio::test]
1774    async fn test_distributed_metrics_tracking_failure() {
1775        let endpoints = vec![make_endpoint("ep_a", 1.0)];
1776        let metrics = DistributedGraphRAGMetrics::new(&endpoints);
1777
1778        metrics.record_failure("ep_a").await;
1779        metrics.record_failure("ep_a").await;
1780
1781        let snap = metrics
1782            .endpoint_snapshot("ep_a")
1783            .await
1784            .expect("should succeed");
1785        assert_eq!(snap.total_queries, 2);
1786        assert_eq!(snap.failed_queries, 2);
1787        assert_eq!(snap.successful_queries, 0);
1788        assert_eq!(snap.hit_rate, 0.0);
1789    }
1790
1791    #[tokio::test]
1792    async fn test_distributed_metrics_aggregate() {
1793        let endpoints = vec![make_endpoint("ep_a", 1.0)];
1794        let metrics = DistributedGraphRAGMetrics::new(&endpoints);
1795
1796        metrics.record_federation_query(200, 100, false).await;
1797        metrics.record_federation_query(300, 50, true).await;
1798        metrics.record_entity_resolution().await;
1799
1800        let agg = metrics.aggregate_snapshot().await;
1801        assert_eq!(agg.total_federation_queries, 2);
1802        assert_eq!(agg.total_triples_gathered, 150);
1803        assert_eq!(agg.entity_resolution_ops, 1);
1804        assert_eq!(agg.partial_failure_count, 1);
1805        assert!(agg.avg_federation_latency_ms > 0.0);
1806    }
1807
1808    #[tokio::test]
1809    async fn test_distributed_metrics_fastest_endpoint() {
1810        let endpoints = vec![make_endpoint("ep_a", 1.0), make_endpoint("ep_b", 1.0)];
1811        let metrics = DistributedGraphRAGMetrics::new(&endpoints);
1812
1813        // ep_a is slow, ep_b is fast
1814        metrics.record_success("ep_a", 500, 10).await;
1815        metrics.record_success("ep_b", 50, 10).await;
1816
1817        let fastest = metrics.fastest_endpoint().await.expect("should succeed");
1818        assert_eq!(fastest, "ep_b");
1819    }
1820
1821    #[tokio::test]
1822    async fn test_distributed_metrics_hit_rate() {
1823        let endpoints = vec![make_endpoint("ep_a", 1.0)];
1824        let metrics = DistributedGraphRAGMetrics::new(&endpoints);
1825
1826        metrics.record_success("ep_a", 100, 5).await; // hit (triple_count > 0)
1827        metrics.record_failure("ep_a").await; // miss
1828
1829        let snap = metrics
1830            .endpoint_snapshot("ep_a")
1831            .await
1832            .expect("should succeed");
1833        assert_eq!(snap.total_queries, 2);
1834        // 1 success + 1 failure
1835        assert!(snap.hit_rate >= 0.0 && snap.hit_rate <= 1.0);
1836    }
1837
1838    // ── Parse helpers ────────────────────────────────────────────────────────
1839
1840    #[test]
1841    fn test_parse_n_triples_basic() {
1842        let body = "<http://s> <http://p> <http://o> .\n";
1843        let triples = parse_n_triples(body).expect("should succeed");
1844        assert_eq!(triples.len(), 1);
1845        assert_eq!(triples[0].subject, "http://s");
1846        assert_eq!(triples[0].predicate, "http://p");
1847        assert_eq!(triples[0].object, "http://o");
1848    }
1849
1850    #[test]
1851    fn test_parse_n_triples_skips_comments() {
1852        let body = "# comment\n<http://s> <http://p> <http://o> .\n";
1853        let triples = parse_n_triples(body).expect("should succeed");
1854        assert_eq!(triples.len(), 1);
1855    }
1856
1857    #[test]
1858    fn test_parse_n_triples_empty() {
1859        let triples = parse_n_triples("").expect("should succeed");
1860        assert!(triples.is_empty());
1861    }
1862
1863    #[test]
1864    fn test_build_seed_expansion_sparql_includes_seeds() {
1865        let sparql = build_seed_expansion_sparql(
1866            &["http://example.org/e1", "http://example.org/e2"],
1867            None,
1868            500,
1869        );
1870        assert!(sparql.contains("<http://example.org/e1>"));
1871        assert!(sparql.contains("<http://example.org/e2>"));
1872        assert!(sparql.contains("LIMIT 500"));
1873    }
1874
1875    #[test]
1876    fn test_build_seed_expansion_sparql_with_graph() {
1877        let sparql = build_seed_expansion_sparql(
1878            &["http://example.org/e1"],
1879            Some("http://example.org/graph"),
1880            100,
1881        );
1882        assert!(sparql.contains("FROM <http://example.org/graph>"));
1883    }
1884
1885    #[test]
1886    fn test_build_same_as_sparql() {
1887        let sparql = build_same_as_sparql(&["http://a/e1", "http://b/e1"], None);
1888        assert!(sparql.contains("owl#sameAs"));
1889        assert!(sparql.contains("<http://a/e1>"));
1890    }
1891
1892    #[test]
1893    fn test_knowledge_graph_canonical_lookup() {
1894        let mut kg = KnowledgeGraph::new();
1895        kg.canonical_uris
1896            .insert("http://b/e1".to_string(), "http://canonical/e1".to_string());
1897        assert_eq!(kg.canonical("http://b/e1"), "http://canonical/e1");
1898        assert_eq!(kg.canonical("http://unknown"), "http://unknown");
1899    }
1900
1901    #[test]
1902    fn test_endpoint_auth_variants() {
1903        let bearer = EndpointAuth::Bearer {
1904            token: "tok123".to_string(),
1905        };
1906        let basic = EndpointAuth::Basic {
1907            username: "user".to_string(),
1908            password: "pass".to_string(),
1909        };
1910        let api = EndpointAuth::ApiKey {
1911            header: "X-API-Key".to_string(),
1912            key: "key123".to_string(),
1913        };
1914        assert_ne!(bearer, EndpointAuth::None);
1915        assert_ne!(basic, EndpointAuth::None);
1916        assert_ne!(api, EndpointAuth::None);
1917    }
1918}