oxirs_graphrag/
lib.rs

//! # OxiRS GraphRAG
//!
//! GraphRAG (Graph Retrieval-Augmented Generation) combines vector similarity search
//! with graph topology traversal for enhanced knowledge retrieval.
//!
//! ## Architecture
//!
//! ```text
//! Query → Embed → Vector KNN + Keyword Search → Fusion → Graph Expansion → LLM Answer
//! ```
//!
//! ## Key Features
//!
//! - **Hybrid Retrieval**: Vector similarity + BM25 keyword search
//! - **Graph Expansion**: SPARQL-based N-hop neighbor traversal
//! - **Community Detection**: Louvain algorithm for hierarchical summarization
//! - **Context Building**: Intelligent subgraph extraction for LLM context
//!
//! ## Example
//!
//! ```rust,ignore
//! use oxirs_graphrag::{GraphRAGEngine, GraphRAGConfig};
//!
//! // vec_index, embedding_model, sparql_engine, and llm_client implement the crate's
//! // backend traits (VectorIndexTrait, EmbeddingModelTrait, SparqlEngineTrait, LlmClientTrait).
//! let engine = GraphRAGEngine::new(vec_index, embedding_model, sparql_engine, llm_client, config);
//! let result = engine.query("What safety issues affect battery cells?").await?;
//! println!("Answer: {}", result.answer);
//! ```

pub mod config;
pub mod generation;
pub mod graph;
pub mod query;
pub mod retrieval;
pub mod sparql;

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::RwLock;

// Re-exports
pub use config::GraphRAGConfig;
pub use graph::community::CommunityDetector;
pub use graph::traversal::GraphTraversal;
pub use query::planner::QueryPlanner;
pub use retrieval::fusion::FusionStrategy;

/// GraphRAG error types
#[derive(Error, Debug)]
pub enum GraphRAGError {
    #[error("Vector search failed: {0}")]
    VectorSearchError(String),

    #[error("Graph traversal failed: {0}")]
    GraphTraversalError(String),

    #[error("Community detection failed: {0}")]
    CommunityDetectionError(String),

    #[error("LLM generation failed: {0}")]
    GenerationError(String),

    #[error("Embedding failed: {0}")]
    EmbeddingError(String),

    #[error("SPARQL query failed: {0}")]
    SparqlError(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("Internal error: {0}")]
    InternalError(String),
}

pub type GraphRAGResult<T> = Result<T, GraphRAGError>;

/// Triple representation for RDF data
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Triple {
    pub subject: String,
    pub predicate: String,
    pub object: String,
}

impl Triple {
    pub fn new(
        subject: impl Into<String>,
        predicate: impl Into<String>,
        object: impl Into<String>,
    ) -> Self {
        Self {
            subject: subject.into(),
            predicate: predicate.into(),
            object: object.into(),
        }
    }
}

/// Entity with relevance score
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredEntity {
    /// Entity URI
    pub uri: String,
    /// Relevance score (0.0 - 1.0)
    pub score: f64,
    /// Source of the score (vector, keyword, fused, or graph)
    pub source: ScoreSource,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}

/// Source of entity score
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScoreSource {
    /// Score from vector similarity search
    Vector,
    /// Score from keyword/BM25 search
    Keyword,
    /// Fused score from multiple sources
    Fused,
    /// Score from graph traversal (path-based)
    Graph,
}

/// Community summary for hierarchical retrieval
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunitySummary {
    /// Community identifier
    pub id: String,
    /// Human-readable summary of the community
    pub summary: String,
    /// Member entities in this community
    pub entities: Vec<String>,
    /// Representative triples from this community
    pub representative_triples: Vec<Triple>,
    /// Community level in hierarchy (0 = leaf, higher = more abstract)
    pub level: u32,
    /// Modularity score
    pub modularity: f64,
}

/// Query provenance for attribution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProvenance {
    /// Query timestamp
    pub timestamp: DateTime<Utc>,
    /// Original query text
    pub original_query: String,
    /// Expanded query (if any)
    pub expanded_query: Option<String>,
    /// Seed entities used
    pub seed_entities: Vec<String>,
    /// Triples contributing to the answer
    pub source_triples: Vec<Triple>,
    /// Community summaries used (if hierarchical)
    pub community_sources: Vec<String>,
    /// Processing time in milliseconds
    pub processing_time_ms: u64,
}

/// GraphRAG query result (not to be confused with the [`GraphRAGResult`] error-handling alias)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGResult2 {
    /// Natural language answer
    pub answer: String,
    /// Source subgraph (RDF triples)
    pub subgraph: Vec<Triple>,
    /// Seed entities with scores
    pub seeds: Vec<ScoredEntity>,
    /// Community summaries (if enabled)
    pub communities: Vec<CommunitySummary>,
    /// Provenance information
    pub provenance: QueryProvenance,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f64,
}

/// Trait for vector index operations
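///
/// A minimal in-memory implementation sketch, shown to illustrate how the trait is
/// used by the engine. `BruteForceIndex`, its cosine-similarity scoring, and the
/// best-first ordering of results are choices made here for illustration, not
/// requirements documented by the trait (the RRF fusion in this crate only relies
/// on rank order):
///
/// ```rust,ignore
/// use async_trait::async_trait;
/// use oxirs_graphrag::{GraphRAGResult, VectorIndexTrait};
///
/// struct BruteForceIndex {
///     entries: Vec<(String, Vec<f32>)>, // (entity URI, embedding)
/// }
///
/// #[async_trait]
/// impl VectorIndexTrait for BruteForceIndex {
///     async fn search_knn(&self, query: &[f32], k: usize) -> GraphRAGResult<Vec<(String, f32)>> {
///         let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
///         let mut scored: Vec<(String, f32)> = self
///             .entries
///             .iter()
///             .map(|(uri, vec)| {
///                 // Cosine similarity; assumes non-zero vectors of equal dimension.
///                 let dot: f32 = vec.iter().zip(query).map(|(a, b)| a * b).sum();
///                 (uri.clone(), dot / (norm(vec) * norm(query)))
///             })
///             .collect();
///         scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
///         scored.truncate(k);
///         Ok(scored)
///     }
///
///     async fn search_threshold(&self, query: &[f32], threshold: f32) -> GraphRAGResult<Vec<(String, f32)>> {
///         let mut all = self.search_knn(query, self.entries.len()).await?;
///         all.retain(|(_, score)| *score >= threshold);
///         Ok(all)
///     }
/// }
/// ```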
#[async_trait]
pub trait VectorIndexTrait: Send + Sync {
    /// Search for k nearest neighbors
    async fn search_knn(
        &self,
        query_vector: &[f32],
        k: usize,
    ) -> GraphRAGResult<Vec<(String, f32)>>;

    /// Search with similarity threshold
    async fn search_threshold(
        &self,
        query_vector: &[f32],
        threshold: f32,
    ) -> GraphRAGResult<Vec<(String, f32)>>;
}

/// Trait for embedding model operations
#[async_trait]
pub trait EmbeddingModelTrait: Send + Sync {
    /// Embed text into vector
    async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;

    /// Embed multiple texts in batch
    async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
}

/// Trait for SPARQL engine operations
#[async_trait]
pub trait SparqlEngineTrait: Send + Sync {
    /// Execute SELECT query
    async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;

    /// Execute ASK query
    async fn ask(&self, query: &str) -> GraphRAGResult<bool>;

    /// Execute CONSTRUCT query
    async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
}

/// Trait for LLM client operations
#[async_trait]
pub trait LlmClientTrait: Send + Sync {
    /// Generate response from context and query
    async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;

    /// Generate with streaming response
    async fn generate_stream(
        &self,
        context: &str,
        query: &str,
        callback: Box<dyn Fn(&str) + Send + Sync>,
    ) -> GraphRAGResult<String>;
}

/// Main GraphRAG engine
pub struct GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    /// Vector index for similarity search
    vec_index: Arc<V>,
    /// Embedding model for query vectorization
    embedding_model: Arc<E>,
    /// SPARQL engine for graph traversal
    sparql_engine: Arc<S>,
    /// LLM client for answer generation
    llm_client: Arc<L>,
    /// Configuration
    config: GraphRAGConfig,
    /// Query result cache
    cache: Arc<RwLock<lru::LruCache<String, GraphRAGResult2>>>,
    /// Community detector (lazy initialized)
    community_detector: Option<Arc<CommunityDetector>>,
}

impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    /// Create a new GraphRAG engine
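    ///
    /// The backend dependencies are injected behind `Arc`s so they can be shared
    /// with the rest of an application. A wiring sketch (the `my_*` values are
    /// placeholders for types implementing the four backend traits, not items
    /// provided by this crate):
    ///
    /// ```rust,ignore
    /// use std::sync::Arc;
    ///
    /// let engine = GraphRAGEngine::new(
    ///     Arc::new(my_vector_index),    // impl VectorIndexTrait
    ///     Arc::new(my_embedding_model), // impl EmbeddingModelTrait
    ///     Arc::new(my_sparql_engine),   // impl SparqlEngineTrait
    ///     Arc::new(my_llm_client),      // impl LlmClientTrait
    ///     config,                       // GraphRAGConfig built elsewhere
    /// );
    /// ```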
    pub fn new(
        vec_index: Arc<V>,
        embedding_model: Arc<E>,
        sparql_engine: Arc<S>,
        llm_client: Arc<L>,
        config: GraphRAGConfig,
    ) -> Self {
        const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize = match std::num::NonZeroUsize::new(1000) {
            Some(size) => size,
            None => panic!("constant is non-zero"),
        };

        let cache_size = config
            .cache_size
            .and_then(std::num::NonZeroUsize::new)
            .unwrap_or(DEFAULT_CACHE_SIZE);

        Self {
            vec_index,
            embedding_model,
            sparql_engine,
            llm_client,
            config,
            cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
            community_detector: None,
        }
    }

    /// Execute a GraphRAG query
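    ///
    /// Runs the full pipeline described in the crate docs: embed the query, run
    /// vector KNN and keyword retrieval, fuse the two lists with RRF, expand the
    /// fused seeds into a subgraph via SPARQL, optionally detect communities,
    /// build an LLM context, and generate the answer. Results are cached by the
    /// exact query string. Illustrative usage (assumes an `engine` wired as in
    /// [`Self::new`]):
    ///
    /// ```rust,ignore
    /// let result = engine.query("What safety issues affect battery cells?").await?;
    /// println!("{} (confidence {:.2})", result.answer, result.confidence);
    /// for triple in &result.provenance.source_triples {
    ///     println!("  evidence: {} {} {}", triple.subject, triple.predicate, triple.object);
    /// }
    /// ```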
    pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
        let start_time = std::time::Instant::now();

        // Check cache
        if let Some(cached) = self.cache.read().await.peek(&query.to_string()) {
            return Ok(cached.clone());
        }

        // 1. Embed query
        let query_vec = self.embedding_model.embed(query).await?;

        // 2. Vector retrieval (Top-K)
        let vector_results = self
            .vec_index
            .search_knn(&query_vec, self.config.top_k)
            .await?;

        // 3. Keyword retrieval (BM25) - simplified for now
        let keyword_results = self.keyword_search(query).await?;

        // 4. Fusion (RRF)
        let seeds = self.fuse_results(&vector_results, &keyword_results)?;

        // 5. Graph expansion (SPARQL)
        let subgraph = self.expand_graph(&seeds).await?;

        // 6. Community detection (optional)
        let communities = if self.config.enable_communities {
            self.detect_communities(&subgraph)?
        } else {
            vec![]
        };

        // 7. Build context
        let context = self.build_context(&subgraph, &communities, query)?;

        // 8. Generate answer
        let answer = self.llm_client.generate(&context, query).await?;

        // Calculate confidence based on seed scores and graph coverage
        let confidence = self.calculate_confidence(&seeds, &subgraph);

        let result = GraphRAGResult2 {
            answer,
            subgraph: subgraph.clone(),
            seeds: seeds.clone(),
            communities,
            provenance: QueryProvenance {
                timestamp: Utc::now(),
                original_query: query.to_string(),
                expanded_query: None,
                seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
                source_triples: subgraph,
                community_sources: vec![],
                processing_time_ms: start_time.elapsed().as_millis() as u64,
            },
            confidence,
        };

        // Update cache
        self.cache
            .write()
            .await
            .put(query.to_string(), result.clone());

        Ok(result)
    }

    /// Keyword search over entity labels: a simplified stand-in for BM25 that
    /// counts case-insensitive term matches via SPARQL REGEX
    async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
        // Build SPARQL query with text matching
        let terms: Vec<&str> = query.split_whitespace().collect();
        if terms.is_empty() {
            return Ok(vec![]);
        }

        // Create a SPARQL FILTER with a case-insensitive regex per term.
        // Note: terms are interpolated unescaped, so regex metacharacters or quotes
        // in the user query will leak into the generated SPARQL.
        let filters: Vec<String> = terms
            .iter()
            .map(|term| format!("REGEX(STR(?label), \"{}\", \"i\")", term))
            .collect();

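        // For a query like "battery safety" (illustrative), the generated SPARQL
        // has this shape:
        //
        //   PREFIX rdfs:   <http://www.w3.org/2000/01/rdf-schema#>
        //   PREFIX schema: <http://schema.org/>
        //   PREFIX dc:     <http://purl.org/dc/elements/1.1/>
        //   SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {
        //       ?entity rdfs:label|schema:name|dc:title ?label .
        //       FILTER(REGEX(STR(?label), "battery", "i") || REGEX(STR(?label), "safety", "i"))
        //   }
        //   GROUP BY ?entity
        //   ORDER BY DESC(?score)
        //   LIMIT <top_k>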
        // Declare the prefixes used in the label alternation explicitly; a
        // conforming SPARQL engine rejects undeclared prefixes.
        let sparql = format!(
            r#"
            PREFIX rdfs:   <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX schema: <http://schema.org/>
            PREFIX dc:     <http://purl.org/dc/elements/1.1/>
            SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
                ?entity rdfs:label|schema:name|dc:title ?label .
                FILTER({})
            }}
            GROUP BY ?entity
            ORDER BY DESC(?score)
            LIMIT {}
            "#,
            filters.join(" || "),
            self.config.top_k
        );

        let results = self.sparql_engine.select(&sparql).await?;

        Ok(results
            .into_iter()
            .filter_map(|row| {
                let entity = row.get("entity")?.clone();
                let score = row.get("score")?.parse::<f32>().ok()?;
                Some((entity, score))
            })
            .collect())
    }

    /// Fuse vector and keyword results using Reciprocal Rank Fusion
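    ///
    /// Each entity's fused score is a weighted Reciprocal Rank Fusion sum over the
    /// result lists it appears in: for 0-based rank `r` in a list, the contribution
    /// is `weight / (k + r + 1)` with `k = 60`, and contributions add across lists
    /// (raw similarity and keyword scores only determine rank, not magnitude).
    /// For example, an entity ranked first in both lists with both weights at 1.0
    /// fuses to 2/61 ≈ 0.033.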
    fn fuse_results(
        &self,
        vector_results: &[(String, f32)],
        keyword_results: &[(String, f32)],
    ) -> GraphRAGResult<Vec<ScoredEntity>> {
        let k = 60.0; // RRF constant

        let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();

        // Add vector scores
        for (rank, (uri, _score)) in vector_results.iter().enumerate() {
            let rrf_score = 1.0 / (k + rank as f64 + 1.0);
            scores.insert(
                uri.clone(),
                (
                    rrf_score * self.config.vector_weight as f64,
                    ScoreSource::Vector,
                ),
            );
        }

        // Add keyword scores
        for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
            let rrf_score = 1.0 / (k + rank as f64 + 1.0);
            let keyword_contribution = rrf_score * self.config.keyword_weight as f64;

            match scores.get(uri).cloned() {
                Some((existing_score, _)) => {
                    let new_score = existing_score + keyword_contribution;
                    scores.insert(uri.clone(), (new_score, ScoreSource::Fused));
                }
                None => {
                    scores.insert(uri.clone(), (keyword_contribution, ScoreSource::Keyword));
                }
            }
        }

        // Sort by score and take top results
        let mut entities: Vec<ScoredEntity> = scores
            .into_iter()
            .map(|(uri, (score, source))| ScoredEntity {
                uri,
                score,
                source,
                metadata: HashMap::new(),
            })
            .collect();

        entities.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        entities.truncate(self.config.max_seeds);

        Ok(entities)
    }

    /// Expand graph from seed entities using SPARQL
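    ///
    /// Builds a single CONSTRUCT query that unions the seeds' outgoing edges, their
    /// incoming edges, and (via a property path of up to `expansion_hops` predicates)
    /// the outgoing edges of reachable neighbors, capped at `max_subgraph_size`
    /// triples. Sketch of the generated pattern (seed IRIs illustrative):
    ///
    /// ```text
    /// VALUES ?seed { <seed1> <seed2> }
    /// { ?seed ?p ?o } UNION { ?s ?p2 ?seed } UNION { <path to ?neighbor> . ?neighbor ?p3 ?o2 }
    /// ```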
    async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
        if seeds.is_empty() {
            return Ok(vec![]);
        }

        let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
        let values = seed_uris.join(" ");

        // N-hop neighbor expansion.
        // `(<>|!<>)` is the property-path idiom for "any predicate" (a predicate either
        // equals the base IRI or it does not) and avoids relying on an undeclared prefix.
        // Note that the `{1,N}` quantifier is an extension beyond SPARQL 1.1, which only
        // standardizes `?`, `*`, and `+`; the target engine must support it.
        let hops = self.config.expansion_hops;
        let path_pattern = if hops == 1 {
            "?seed ?p ?neighbor".to_string()
        } else {
            format!("?seed (<>|!<>){{1,{}}} ?neighbor", hops)
        };

        let sparql = format!(
            r#"
            CONSTRUCT {{
                ?seed ?p ?o .
                ?s ?p2 ?seed .
                ?neighbor ?p3 ?o2 .
            }}
            WHERE {{
                VALUES ?seed {{ {} }}
                {{
                    ?seed ?p ?o .
                }} UNION {{
                    ?s ?p2 ?seed .
                }} UNION {{
                    {} .
                    ?neighbor ?p3 ?o2 .
                }}
            }}
            LIMIT {}
            "#,
            values, path_pattern, self.config.max_subgraph_size
        );

        self.sparql_engine.construct(&sparql).await
    }

    /// Detect communities in the subgraph (connected-component grouping as a
    /// stand-in for full Louvain)
    fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
        use petgraph::graph::UnGraph;

        if subgraph.is_empty() {
            return Ok(vec![]);
        }

        // Build undirected graph
        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
        let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();

        for triple in subgraph {
            let subj_idx = *node_indices
                .entry(triple.subject.clone())
                .or_insert_with(|| graph.add_node(triple.subject.clone()));
            let obj_idx = *node_indices
                .entry(triple.object.clone())
                .or_insert_with(|| graph.add_node(triple.object.clone()));

            if subj_idx != obj_idx {
                graph.add_edge(subj_idx, obj_idx, ());
            }
        }

        // Simple community detection based on connected components: on an undirected
        // graph, Kosaraju's strongly connected components are exactly the connected
        // components. (A full Louvain implementation would refine these by modularity.)
        let components = petgraph::algo::kosaraju_scc(&graph);

        let communities: Vec<CommunitySummary> = components
            .into_iter()
            .enumerate()
            .filter(|(_, component)| component.len() >= 2)
            .map(|(idx, component)| {
                let entities: Vec<String> = component
                    .iter()
                    .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
                    .collect();

                let representative_triples: Vec<Triple> = subgraph
                    .iter()
                    .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
                    .take(5)
                    .cloned()
                    .collect();

                CommunitySummary {
                    id: format!("community_{}", idx),
                    summary: format!("Community with {} entities", entities.len()),
                    entities,
                    representative_triples,
                    level: 0,
                    modularity: 0.0,
                }
            })
            .collect();

        Ok(communities)
    }

    /// Build context string for LLM from subgraph and communities
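    ///
    /// The produced context is plain Markdown-style text; for example (entity and
    /// predicate names illustrative):
    ///
    /// ```text
    /// ## Community Context
    ///
    /// ### community_0
    /// Community with 4 entities
    /// Entities: ex:A, ex:B, ex:C, ex:D
    ///
    /// ## Knowledge Graph Facts
    ///
    /// - ex:A → ex:relatedTo → ex:B
    /// ```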
    fn build_context(
        &self,
        subgraph: &[Triple],
        communities: &[CommunitySummary],
        _query: &str,
    ) -> GraphRAGResult<String> {
        let mut context = String::new();

        // Add community summaries if available
        if !communities.is_empty() {
            context.push_str("## Community Context\n\n");
            for community in communities {
                context.push_str(&format!("### {}\n", community.id));
                context.push_str(&format!("{}\n", community.summary));
                context.push_str(&format!("Entities: {}\n\n", community.entities.join(", ")));
            }
        }

        // Add relevant triples
        context.push_str("## Knowledge Graph Facts\n\n");
        for triple in subgraph.iter().take(self.config.max_context_triples) {
            context.push_str(&format!(
                "- {} → {} → {}\n",
                triple.subject, triple.predicate, triple.object
            ));
        }

        Ok(context)
    }

    /// Calculate confidence score based on retrieval quality
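    ///
    /// `confidence = min(0.6 * avg_seed_score + 0.4 * coverage, 1.0)`, where
    /// `coverage` is the fraction of subgraph triples whose subject or object is a
    /// seed entity (0.0 for an empty subgraph), and an empty seed list yields 0.0.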
    fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
        if seeds.is_empty() {
            return 0.0;
        }

        // Average seed score
        let avg_seed_score: f64 = seeds.iter().map(|s| s.score).sum::<f64>() / seeds.len() as f64;

        // Graph coverage: fraction of subgraph triples that touch a seed entity
        let seed_uris: std::collections::HashSet<_> = seeds.iter().map(|s| &s.uri).collect();
        let covered: usize = subgraph
            .iter()
            .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
            .count();
        let coverage = if subgraph.is_empty() {
            0.0
        } else {
            (covered as f64 / subgraph.len() as f64).min(1.0)
        };

        // Combined confidence
        (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_triple_creation() {
        let triple = Triple::new(
            "http://example.org/s",
            "http://example.org/p",
            "http://example.org/o",
        );
        assert_eq!(triple.subject, "http://example.org/s");
        assert_eq!(triple.predicate, "http://example.org/p");
        assert_eq!(triple.object, "http://example.org/o");
    }

    #[test]
    fn test_scored_entity() {
        let entity = ScoredEntity {
            uri: "http://example.org/entity".to_string(),
            score: 0.85,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        };
        assert_eq!(entity.score, 0.85);
        assert_eq!(entity.source, ScoreSource::Fused);
    }
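
    // Standalone worked examples of the scoring arithmetic used above. They mirror
    // the formulas in `fuse_results` and `calculate_confidence` without constructing
    // an engine (that would require mocking all four backend traits), so the ranks
    // and weights below are illustrative values, not fixtures from the crate.
    #[test]
    fn test_rrf_arithmetic_example() {
        let k = 60.0_f64;
        let rrf = |rank: usize| 1.0 / (k + rank as f64 + 1.0);

        // An entity ranked first in both the vector and keyword lists, with both
        // weights at 1.0, fuses to 2/61.
        let fused = rrf(0) * 1.0 + rrf(0) * 1.0;
        assert!((fused - 2.0 / 61.0).abs() < 1e-12);

        // Lower-ranked entries always contribute less.
        assert!(rrf(0) > rrf(5));
    }

    #[test]
    fn test_confidence_blend_example() {
        // Mirrors calculate_confidence: 0.6 * average seed score + 0.4 * coverage,
        // capped at 1.0.
        let avg_seed_score = 0.5_f64;
        let coverage = 1.0_f64;
        let confidence = (avg_seed_score * 0.6 + coverage * 0.4).min(1.0);
        assert!((confidence - 0.7).abs() < 1e-12);
    }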
}