// oxirs_graphrag/src/lib.rs
1//! # OxiRS GraphRAG
2//!
3//! [![Version](https://img.shields.io/badge/version-0.1.0-blue)](https://github.com/cool-japan/oxirs/releases)
4//!
5//! **Status**: Production Release (v0.1.0)
6//!
7//! GraphRAG (Graph Retrieval-Augmented Generation) combines vector similarity search
8//! with graph topology traversal for enhanced knowledge retrieval.
9//!
10//! ## Architecture
11//!
12//! ```text
13//! Query → Embed → Vector KNN + Keyword Search → Fusion → Graph Expansion → LLM Answer
14//! ```
15//!
16//! ## Key Features
17//!
18//! - **Hybrid Retrieval**: Vector similarity + BM25 keyword search
19//! - **Graph Expansion**: SPARQL-based N-hop neighbor traversal
20//! - **Community Detection**: Louvain algorithm for hierarchical summarization
21//! - **Context Building**: Intelligent subgraph extraction for LLM context
22//!
23//! ## Example
24//!
25//! ```rust,ignore
26//! use oxirs_graphrag::{GraphRAGEngine, GraphRAGConfig};
27//!
28//! let engine = GraphRAGEngine::new(config).await?;
29//! let result = engine.query("What safety issues affect battery cells?").await?;
30//! println!("Answer: {}", result.answer);
31//! ```
32
33pub mod cache;
34pub mod config;
35pub mod distributed;
// v1.1.0: Graph summarization for RAG (implemented in `graph_summarization` below)
37pub mod embeddings;
38pub mod federation;
39pub mod fusion;
40pub mod generation;
41pub mod graph;
42pub mod graph_summarization;
43pub mod query;
44pub mod reasoning;
45pub mod retrieval;
46pub mod sparql;
47pub mod streaming;
48pub mod temporal;
49
50// v1.1.0 TransE knowledge graph embedding model
51pub mod transe_model;
52
53// v1.1.0: Entity linking and disambiguation for knowledge graphs
54pub mod entity_linking;
55
56// v1.1.0 round 5: Community detection (Louvain-inspired greedy label propagation)
57pub mod community_detector;
58
59// v1.1.0 round 6: Knowledge graph path ranking (DFS + Dijkstra + scoring)
60pub mod path_ranker;
61
62// v1.1.0 round 7: String-to-RDF entity linking (mention detection + candidate ranking)
63pub mod entity_linker;
64
65// v1.1.0 round 11: Node2Vec-inspired graph embedding and structural node representations
66pub mod graph_embedder;
67
68// v1.1.0 round 12: Graph partitioning using greedy / label-propagation / bisection methods
69pub mod graph_partitioner;
70
71// v1.1.0 round 13: Rule-based knowledge triple extraction from natural language text
72pub mod triple_extractor;
73
74// v1.1.0 round 11: Multi-source knowledge fusion with provenance tracking
75pub mod knowledge_fusion;
76
77// v1.1.0 round 12: Context building for graph-based RAG (N-hop, ranking, truncation, formatting)
78pub mod context_builder;
79
80// v1.1.0 round 13: Graph path finding for RAG (BFS/DFS, shortest path, predicate filtering, scoring)
81pub mod path_finder;
82
83// v1.1.0 round 14: KG subgraph summarization via cluster-based abstraction
84pub mod summarizer;
85
86// v1.1.0 round 15: Entity type classification for knowledge graph nodes
87pub mod entity_classifier;
88
89use std::collections::HashMap;
90use std::sync::atomic::{AtomicU64, Ordering};
91use std::sync::Arc;
92use std::time::{Duration, SystemTime};
93
94use async_trait::async_trait;
95use chrono::{DateTime, Utc};
96use serde::{Deserialize, Serialize};
97use thiserror::Error;
98use tokio::sync::RwLock;
99
100// Re-exports
101pub use cache::query_cache::{CacheEntry, CacheStats, QueryCache, QueryCacheConfig};
102pub use config::{CacheConfiguration, GraphRAGConfig};
103pub use embeddings::node2vec::{
104    Node2VecConfig, Node2VecEmbedder, Node2VecEmbeddings, Node2VecWalkConfig,
105};
106pub use graph::community::{CommunityAlgorithm, CommunityConfig, CommunityDetector};
107pub use graph::embeddings::{CommunityAwareEmbeddings, CommunityStructure, EmbeddingConfig};
108pub use graph::traversal::GraphTraversal;
109pub use query::planner::QueryPlanner;
110pub use retrieval::fusion::FusionStrategy;
111
/// GraphRAG error types
///
/// Each variant carries a human-readable detail message from the failing
/// subsystem; `thiserror` derives `Display` and `std::error::Error` from
/// the `#[error]` attributes.
#[derive(Error, Debug)]
pub enum GraphRAGError {
    /// Vector KNN or threshold search failed.
    #[error("Vector search failed: {0}")]
    VectorSearchError(String),

    /// SPARQL-based graph expansion or traversal failed.
    #[error("Graph traversal failed: {0}")]
    GraphTraversalError(String),

    /// Community detection over the retrieved subgraph failed.
    #[error("Community detection failed: {0}")]
    CommunityDetectionError(String),

    /// The LLM failed to generate an answer.
    #[error("LLM generation failed: {0}")]
    GenerationError(String),

    /// Text embedding failed.
    #[error("Embedding failed: {0}")]
    EmbeddingError(String),

    /// A SPARQL query was rejected or failed to execute.
    #[error("SPARQL query failed: {0}")]
    SparqlError(String),

    /// Invalid or inconsistent configuration.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Catch-all for unexpected internal failures.
    #[error("Internal error: {0}")]
    InternalError(String),
}

/// Convenience alias used by all fallible GraphRAG APIs in this crate.
pub type GraphRAGResult<T> = Result<T, GraphRAGError>;
141
/// Triple representation for RDF data
///
/// A minimal string-based subject/predicate/object record. No distinction
/// is made here between IRIs, blank nodes, and literals — all three terms
/// are stored as plain strings.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Triple {
    /// Subject term.
    pub subject: String,
    /// Predicate term.
    pub predicate: String,
    /// Object term.
    pub object: String,
}

impl Triple {
    /// Construct a triple from anything convertible into `String`
    /// (`&str`, `String`, `Cow<str>`, ...).
    pub fn new(
        subject: impl Into<String>,
        predicate: impl Into<String>,
        object: impl Into<String>,
    ) -> Self {
        Self {
            subject: subject.into(),
            predicate: predicate.into(),
            object: object.into(),
        }
    }
}
163
/// Entity with relevance score
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredEntity {
    /// Entity URI
    pub uri: String,
    /// Relevance score (0.0 - 1.0 by convention; not enforced here — note
    /// that RRF-fused scores produced by `fuse_results` are typically much
    /// smaller than 1.0)
    pub score: f64,
    /// Source of the score (vector, keyword, or fused)
    pub source: ScoreSource,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}

/// Source of entity score
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScoreSource {
    /// Score from vector similarity search
    Vector,
    /// Score from keyword/BM25 search
    Keyword,
    /// Fused score from multiple sources
    Fused,
    /// Score from graph traversal (path-based)
    Graph,
}
189
/// Community summary for hierarchical retrieval
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunitySummary {
    /// Community identifier
    pub id: String,
    /// Human-readable summary of the community
    pub summary: String,
    /// Member entities in this community
    pub entities: Vec<String>,
    /// Representative triples from this community (a small sample, not the
    /// exhaustive set)
    pub representative_triples: Vec<Triple>,
    /// Community level in hierarchy (0 = leaf, higher = more abstract)
    pub level: u32,
    /// Modularity score of the detected partition
    pub modularity: f64,
}

/// Query provenance for attribution
///
/// Captures everything needed to explain how an answer was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProvenance {
    /// Query timestamp (UTC)
    pub timestamp: DateTime<Utc>,
    /// Original query text
    pub original_query: String,
    /// Expanded query (if any)
    pub expanded_query: Option<String>,
    /// Seed entities used
    pub seed_entities: Vec<String>,
    /// Triples contributing to the answer
    pub source_triples: Vec<Triple>,
    /// Community summaries used (if hierarchical)
    pub community_sources: Vec<String>,
    /// End-to-end processing time in milliseconds
    pub processing_time_ms: u64,
}
225
/// GraphRAG query result
///
/// Named `GraphRAGResult2` to avoid clashing with the fallible-result alias
/// `GraphRAGResult<T>` defined above.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGResult2 {
    /// Natural language answer
    pub answer: String,
    /// Source subgraph (RDF triples)
    pub subgraph: Vec<Triple>,
    /// Seed entities with scores
    pub seeds: Vec<ScoredEntity>,
    /// Community summaries (if enabled)
    pub communities: Vec<CommunitySummary>,
    /// Provenance information
    pub provenance: QueryProvenance,
    /// Confidence score (0.0 - 1.0); see `calculate_confidence`
    pub confidence: f64,
}
242
/// Trait for vector index operations
///
/// Results are `(entity_uri, similarity_score)` pairs.
#[async_trait]
pub trait VectorIndexTrait: Send + Sync {
    /// Search for k nearest neighbors
    ///
    /// # Errors
    /// Implementations report failures via [`GraphRAGResult`].
    async fn search_knn(
        &self,
        query_vector: &[f32],
        k: usize,
    ) -> GraphRAGResult<Vec<(String, f32)>>;

    /// Search with similarity threshold
    ///
    /// # Errors
    /// Implementations report failures via [`GraphRAGResult`].
    async fn search_threshold(
        &self,
        query_vector: &[f32],
        threshold: f32,
    ) -> GraphRAGResult<Vec<(String, f32)>>;
}

/// Trait for embedding model operations
#[async_trait]
pub trait EmbeddingModelTrait: Send + Sync {
    /// Embed text into vector
    async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;

    /// Embed multiple texts in batch (one vector per input, same order)
    async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
}

/// Trait for SPARQL engine operations
#[async_trait]
pub trait SparqlEngineTrait: Send + Sync {
    /// Execute SELECT query; each row maps variable name → bound value
    async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;

    /// Execute ASK query
    async fn ask(&self, query: &str) -> GraphRAGResult<bool>;

    /// Execute CONSTRUCT query, returning the constructed triples
    async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
}

/// Trait for LLM client operations
#[async_trait]
pub trait LlmClientTrait: Send + Sync {
    /// Generate response from context and query
    async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;

    /// Generate with streaming response
    ///
    /// `callback` is presumably invoked with incremental chunks while the
    /// full response is also returned on completion — confirm with concrete
    /// implementations.
    async fn generate_stream(
        &self,
        context: &str,
        query: &str,
        callback: Box<dyn Fn(&str) + Send + Sync>,
    ) -> GraphRAGResult<String>;
}
298
299/// Cached result with metadata
300#[derive(Debug, Clone)]
301struct CachedResult {
302    result: GraphRAGResult2,
303    timestamp: SystemTime,
304    ttl: Duration,
305}
306
307impl CachedResult {
308    /// Check if the cached result is still fresh
309    fn is_fresh(&self) -> bool {
310        self.timestamp
311            .elapsed()
312            .map(|elapsed| elapsed < self.ttl)
313            .unwrap_or(false)
314    }
315}
316
/// Cache configuration
///
/// TTL bounds consumed by the adaptive caching logic in `GraphRAGEngine`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Base TTL in seconds (default: 3600 = 1 hour)
    pub base_ttl_seconds: u64,
    /// Minimum TTL in seconds (default: 300 = 5 minutes)
    pub min_ttl_seconds: u64,
    /// Maximum TTL in seconds (default: 86400 = 24 hours)
    pub max_ttl_seconds: u64,
    /// Enable adaptive TTL based on update frequency
    pub adaptive: bool,
}

impl Default for CacheConfig {
    /// One-hour base TTL, bounded by [5 minutes, 24 hours], adaptive on.
    fn default() -> Self {
        const HOUR_SECS: u64 = 3600;
        CacheConfig {
            base_ttl_seconds: HOUR_SECS,
            min_ttl_seconds: 300,
            max_ttl_seconds: 24 * HOUR_SECS,
            adaptive: true,
        }
    }
}
340
/// Main GraphRAG engine
///
/// Generic over its four collaborators so implementations can be swapped in
/// tests and deployments:
/// - `V`: vector index ([`VectorIndexTrait`])
/// - `E`: embedding model ([`EmbeddingModelTrait`])
/// - `S`: SPARQL engine ([`SparqlEngineTrait`])
/// - `L`: LLM client ([`LlmClientTrait`])
pub struct GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    /// Vector index for similarity search
    vec_index: Arc<V>,
    /// Embedding model for query vectorization
    embedding_model: Arc<E>,
    /// SPARQL engine for graph traversal
    sparql_engine: Arc<S>,
    /// LLM client for answer generation
    llm_client: Arc<L>,
    /// Configuration
    config: GraphRAGConfig,
    /// Query result cache (LRU, keyed by query text) with adaptive TTL
    cache: Arc<RwLock<lru::LruCache<String, CachedResult>>>,
    /// Cache TTL configuration
    cache_config: CacheConfig,
    /// Monotonic graph update counter driving adaptive TTL
    graph_update_count: Arc<AtomicU64>,
    /// Community detector (lazy initialized; in this file it is only ever
    /// set to `None` — `query()` uses ad-hoc component detection instead)
    community_detector: Option<Arc<CommunityDetector>>,
}
368
369impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
370where
371    V: VectorIndexTrait,
372    E: EmbeddingModelTrait,
373    S: SparqlEngineTrait,
374    L: LlmClientTrait,
375{
376    /// Create a new GraphRAG engine
377    pub fn new(
378        vec_index: Arc<V>,
379        embedding_model: Arc<E>,
380        sparql_engine: Arc<S>,
381        llm_client: Arc<L>,
382        config: GraphRAGConfig,
383    ) -> Self {
384        let cache_config = CacheConfig {
385            base_ttl_seconds: config.cache_config.base_ttl_seconds,
386            min_ttl_seconds: config.cache_config.min_ttl_seconds,
387            max_ttl_seconds: config.cache_config.max_ttl_seconds,
388            adaptive: config.cache_config.adaptive,
389        };
390
391        Self::with_cache_config(
392            vec_index,
393            embedding_model,
394            sparql_engine,
395            llm_client,
396            config,
397            cache_config,
398        )
399    }
400
    /// Create a new GraphRAG engine with custom cache configuration
    ///
    /// `cache_config` controls the TTL behavior of the result cache, while
    /// the cache *capacity* (entry count) comes from `config.cache_size`,
    /// defaulting to 1000 when unset or zero.
    pub fn with_cache_config(
        vec_index: Arc<V>,
        embedding_model: Arc<E>,
        sparql_engine: Arc<S>,
        llm_client: Arc<L>,
        config: GraphRAGConfig,
        cache_config: CacheConfig,
    ) -> Self {
        // Const-context equivalent of `NonZeroUsize::new(1000).unwrap()`:
        // the match forces the non-zero proof at compile time.
        const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize = match std::num::NonZeroUsize::new(1000) {
            Some(size) => size,
            None => panic!("constant is non-zero"),
        };

        // A configured size of Some(0) also falls back to the default,
        // because `NonZeroUsize::new(0)` yields `None`.
        let cache_size = config
            .cache_size
            .and_then(std::num::NonZeroUsize::new)
            .unwrap_or(DEFAULT_CACHE_SIZE);

        Self {
            vec_index,
            embedding_model,
            sparql_engine,
            llm_client,
            config,
            cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
            cache_config,
            graph_update_count: Arc::new(AtomicU64::new(0)),
            // NOTE(review): never populated anywhere in this file — confirm
            // whether lazy initialization happens elsewhere.
            community_detector: None,
        }
    }
432
433    /// Calculate adaptive TTL based on graph update frequency
434    fn calculate_ttl(&self) -> Duration {
435        if !self.cache_config.adaptive {
436            return Duration::from_secs(self.cache_config.base_ttl_seconds);
437        }
438
439        let updates_per_hour = self.graph_update_count.load(Ordering::Relaxed) as f64;
440
441        // More updates = shorter TTL
442        let ttl_secs = if updates_per_hour > 100.0 {
443            self.cache_config.min_ttl_seconds // High update rate: 5 min TTL
444        } else if updates_per_hour > 10.0 {
445            self.cache_config.base_ttl_seconds / 2 // Medium: 30 min TTL
446        } else {
447            self.cache_config.max_ttl_seconds // Low update rate: 24 hour TTL
448        };
449
450        Duration::from_secs(ttl_secs)
451    }
452
    /// Record graph update for adaptive TTL calculation
    ///
    /// Increments a monotonic counter; see `calculate_ttl` for how the
    /// count shortens cache TTLs.
    pub fn record_graph_update(&self) {
        self.graph_update_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Get current cache occupancy for monitoring
    ///
    /// Returns `(entries_in_cache, capacity)`. Despite earlier wording this
    /// is not a hit *rate*: hit/miss counters are not tracked here.
    pub async fn get_cache_stats(&self) -> (usize, usize) {
        let cache = self.cache.read().await;
        (cache.len(), cache.cap().get())
    }
463
    /// Execute a GraphRAG query
    ///
    /// Pipeline: cache lookup → query embedding → vector KNN + keyword
    /// search → RRF fusion → SPARQL graph expansion → optional community
    /// detection → context building → LLM answer generation. Fresh results
    /// are cached with an adaptive TTL.
    ///
    /// # Errors
    /// Propagates failures from any stage (embedding, search, SPARQL, LLM).
    pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
        let start_time = std::time::Instant::now();

        // Check cache with freshness validation. `peek` is used because a
        // read lock cannot reorder the LRU list; a stale entry is left in
        // place here and simply overwritten after recomputation below.
        {
            let cache = self.cache.read().await;
            if let Some(cached) = cache.peek(&query.to_string()) {
                if cached.is_fresh() {
                    return Ok(cached.result.clone());
                }
            }
        }

        // 1. Embed query
        let query_vec = self.embedding_model.embed(query).await?;

        // 2. Vector retrieval (Top-K)
        let vector_results = self
            .vec_index
            .search_knn(&query_vec, self.config.top_k)
            .await?;

        // 3. Keyword retrieval (BM25) - simplified for now
        let keyword_results = self.keyword_search(query).await?;

        // 4. Fusion (RRF) — rank-based, see `fuse_results`
        let seeds = self.fuse_results(&vector_results, &keyword_results)?;

        // 5. Graph expansion (SPARQL) around the fused seed entities
        let subgraph = self.expand_graph(&seeds).await?;

        // 6. Community detection (optional)
        let communities = if self.config.enable_communities {
            self.detect_communities(&subgraph)?
        } else {
            vec![]
        };

        // 7. Build context string handed to the LLM
        let context = self.build_context(&subgraph, &communities, query)?;

        // 8. Generate answer
        let answer = self.llm_client.generate(&context, query).await?;

        // Calculate confidence based on seed scores and graph coverage
        let confidence = self.calculate_confidence(&seeds, &subgraph);

        let result = GraphRAGResult2 {
            answer,
            subgraph: subgraph.clone(),
            seeds: seeds.clone(),
            communities,
            provenance: QueryProvenance {
                timestamp: Utc::now(),
                original_query: query.to_string(),
                expanded_query: None,
                seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
                source_triples: subgraph,
                community_sources: vec![],
                processing_time_ms: start_time.elapsed().as_millis() as u64,
            },
            confidence,
        };

        // Update cache with adaptive TTL (see `calculate_ttl`)
        let ttl = self.calculate_ttl();
        let cached = CachedResult {
            result: result.clone(),
            timestamp: SystemTime::now(),
            ttl,
        };
        self.cache.write().await.put(query.to_string(), cached);

        Ok(result)
    }
540
541    /// Keyword search using BM25 (simplified)
542    async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
543        // Build SPARQL query with text matching
544        let terms: Vec<&str> = query.split_whitespace().collect();
545        if terms.is_empty() {
546            return Ok(vec![]);
547        }
548
549        // Create SPARQL FILTER with regex for each term
550        let filters: Vec<String> = terms
551            .iter()
552            .map(|term| format!("REGEX(STR(?label), \"{}\", \"i\")", term))
553            .collect();
554
555        let sparql = format!(
556            r#"
557            SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
558                ?entity rdfs:label|schema:name|dc:title ?label .
559                FILTER({})
560            }}
561            GROUP BY ?entity
562            ORDER BY DESC(?score)
563            LIMIT {}
564            "#,
565            filters.join(" || "),
566            self.config.top_k
567        );
568
569        let results = self.sparql_engine.select(&sparql).await?;
570
571        Ok(results
572            .into_iter()
573            .filter_map(|row| {
574                let entity = row.get("entity")?.clone();
575                let score = row.get("score")?.parse::<f32>().ok()?;
576                Some((entity, score))
577            })
578            .collect())
579    }
580
581    /// Fuse vector and keyword results using Reciprocal Rank Fusion
582    fn fuse_results(
583        &self,
584        vector_results: &[(String, f32)],
585        keyword_results: &[(String, f32)],
586    ) -> GraphRAGResult<Vec<ScoredEntity>> {
587        let k = 60.0; // RRF constant
588
589        let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();
590
591        // Add vector scores
592        for (rank, (uri, score)) in vector_results.iter().enumerate() {
593            let rrf_score = 1.0 / (k + rank as f64 + 1.0);
594            scores.insert(
595                uri.clone(),
596                (
597                    rrf_score * self.config.vector_weight as f64,
598                    ScoreSource::Vector,
599                ),
600            );
601        }
602
603        // Add keyword scores
604        for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
605            let rrf_score = 1.0 / (k + rank as f64 + 1.0);
606            let keyword_contribution = rrf_score * self.config.keyword_weight as f64;
607
608            match scores.get(uri).cloned() {
609                Some((existing_score, _)) => {
610                    let new_score = existing_score + keyword_contribution;
611                    scores.insert(uri.clone(), (new_score, ScoreSource::Fused));
612                }
613                None => {
614                    scores.insert(uri.clone(), (keyword_contribution, ScoreSource::Keyword));
615                }
616            }
617        }
618
619        // Sort by score and take top results
620        let mut entities: Vec<ScoredEntity> = scores
621            .into_iter()
622            .map(|(uri, (score, source))| ScoredEntity {
623                uri,
624                score,
625                source,
626                metadata: HashMap::new(),
627            })
628            .collect();
629
630        entities.sort_by(|a, b| {
631            b.score
632                .partial_cmp(&a.score)
633                .unwrap_or(std::cmp::Ordering::Equal)
634        });
635        entities.truncate(self.config.max_seeds);
636
637        Ok(entities)
638    }
639
    /// Expand graph from seed entities using SPARQL
    ///
    /// Builds a CONSTRUCT query that collects, per seed: its outgoing
    /// triples, its incoming triples, and the outgoing triples of its N-hop
    /// neighbors (N = `config.expansion_hops`). The result is capped at
    /// `config.max_subgraph_size` triples.
    ///
    /// NOTE(review): the multi-hop branch uses the counted path quantifier
    /// `(:|!:){1,n}`, which was dropped from the final SPARQL 1.1 spec —
    /// confirm the target engine supports it as an extension.
    async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
        if seeds.is_empty() {
            return Ok(vec![]);
        }

        // Inject seeds through a VALUES block as angle-bracketed IRIs.
        let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
        let values = seed_uris.join(" ");

        // N-hop neighbor expansion; `(:|!:)` matches any predicate.
        let hops = self.config.expansion_hops;
        let path_pattern = if hops == 1 {
            "?seed ?p ?neighbor".to_string()
        } else {
            format!("?seed (:|!:){{1,{}}} ?neighbor", hops)
        };

        let sparql = format!(
            r#"
            CONSTRUCT {{
                ?seed ?p ?o .
                ?s ?p2 ?seed .
                ?neighbor ?p3 ?o2 .
            }}
            WHERE {{
                VALUES ?seed {{ {} }}
                {{
                    ?seed ?p ?o .
                }} UNION {{
                    ?s ?p2 ?seed .
                }} UNION {{
                    {}
                    ?neighbor ?p3 ?o2 .
                }}
            }}
            LIMIT {}
            "#,
            values, path_pattern, self.config.max_subgraph_size
        );

        self.sparql_engine.construct(&sparql).await
    }
682
683    /// Detect communities in the subgraph using Louvain algorithm
684    fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
685        use petgraph::graph::UnGraph;
686
687        if subgraph.is_empty() {
688            return Ok(vec![]);
689        }
690
691        // Build undirected graph
692        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
693        let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
694
695        for triple in subgraph {
696            let subj_idx = *node_indices
697                .entry(triple.subject.clone())
698                .or_insert_with(|| graph.add_node(triple.subject.clone()));
699            let obj_idx = *node_indices
700                .entry(triple.object.clone())
701                .or_insert_with(|| graph.add_node(triple.object.clone()));
702
703            if subj_idx != obj_idx {
704                graph.add_edge(subj_idx, obj_idx, ());
705            }
706        }
707
708        // Simple community detection based on connected components
709        // (Full Louvain implementation would be more complex)
710        let components = petgraph::algo::kosaraju_scc(&graph);
711
712        let communities: Vec<CommunitySummary> = components
713            .into_iter()
714            .enumerate()
715            .filter(|(_, component)| component.len() >= 2)
716            .map(|(idx, component)| {
717                let entities: Vec<String> = component
718                    .iter()
719                    .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
720                    .collect();
721
722                let representative_triples: Vec<Triple> = subgraph
723                    .iter()
724                    .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
725                    .take(5)
726                    .cloned()
727                    .collect();
728
729                CommunitySummary {
730                    id: format!("community_{}", idx),
731                    summary: format!("Community with {} entities", entities.len()),
732                    entities,
733                    representative_triples,
734                    level: 0,
735                    modularity: 0.0,
736                }
737            })
738            .collect();
739
740        Ok(communities)
741    }
742
743    /// Build context string for LLM from subgraph and communities
744    fn build_context(
745        &self,
746        subgraph: &[Triple],
747        communities: &[CommunitySummary],
748        _query: &str,
749    ) -> GraphRAGResult<String> {
750        let mut context = String::new();
751
752        // Add community summaries if available
753        if !communities.is_empty() {
754            context.push_str("## Community Context\n\n");
755            for community in communities {
756                context.push_str(&format!("### {}\n", community.id));
757                context.push_str(&format!("{}\n", community.summary));
758                context.push_str(&format!("Entities: {}\n\n", community.entities.join(", ")));
759            }
760        }
761
762        // Add relevant triples
763        context.push_str("## Knowledge Graph Facts\n\n");
764        for triple in subgraph.iter().take(self.config.max_context_triples) {
765            context.push_str(&format!(
766                "- {} → {} → {}\n",
767                triple.subject, triple.predicate, triple.object
768            ));
769        }
770
771        Ok(context)
772    }
773
774    /// Calculate confidence score based on retrieval quality
775    fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
776        if seeds.is_empty() {
777            return 0.0;
778        }
779
780        // Average seed score
781        let avg_seed_score: f64 = seeds.iter().map(|s| s.score).sum::<f64>() / seeds.len() as f64;
782
783        // Graph coverage (how many seeds appear in subgraph)
784        let seed_uris: std::collections::HashSet<_> = seeds.iter().map(|s| &s.uri).collect();
785        let covered: usize = subgraph
786            .iter()
787            .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
788            .count();
789        let coverage = if subgraph.is_empty() {
790            0.0
791        } else {
792            (covered as f64 / subgraph.len() as f64).min(1.0)
793        };
794
795        // Combined confidence
796        (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
797    }
798}
799
800#[cfg(test)]
801mod tests {
802    use super::*;
803
    /// `Triple::new` accepts `&str` arguments and stores them verbatim.
    #[test]
    fn test_triple_creation() {
        let triple = Triple::new(
            "http://example.org/s",
            "http://example.org/p",
            "http://example.org/o",
        );
        assert_eq!(triple.subject, "http://example.org/s");
        assert_eq!(triple.predicate, "http://example.org/p");
        assert_eq!(triple.object, "http://example.org/o");
    }
815
    /// Struct-literal construction of `ScoredEntity` round-trips its fields.
    #[test]
    fn test_scored_entity() {
        let entity = ScoredEntity {
            uri: "http://example.org/entity".to_string(),
            score: 0.85,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        };
        // Exact float comparison is fine here: 0.85 is copied, not computed.
        assert_eq!(entity.score, 0.85);
        assert_eq!(entity.source, ScoreSource::Fused);
    }
827}