oxirs_graphrag/
lib.rs

//! # OxiRS GraphRAG
//!
//! [![Version](https://img.shields.io/badge/version-0.1.0-blue)](https://github.com/cool-japan/oxirs/releases)
//!
//! **Status**: Production Release (v0.1.0)
//!
//! GraphRAG (Graph Retrieval-Augmented Generation) combines vector similarity search
//! with graph topology traversal for enhanced knowledge retrieval.
//!
//! ## Architecture
//!
//! ```text
//! Query → Embed → Vector KNN + Keyword Search → Fusion → Graph Expansion → LLM Answer
//! ```
//!
//! ## Key Features
//!
//! - **Hybrid Retrieval**: Vector similarity + BM25 keyword search
//! - **Graph Expansion**: SPARQL-based N-hop neighbor traversal
//! - **Community Detection**: Louvain algorithm for hierarchical summarization
//! - **Context Building**: Intelligent subgraph extraction for LLM context
//!
//! ## Example
//!
//! ```rust,ignore
//! use oxirs_graphrag::{GraphRAGEngine, GraphRAGConfig};
//!
//! // `vec_index`, `embedding_model`, `sparql_engine`, and `llm_client` are
//! // `Arc`-wrapped implementations of the vector-index, embedding, SPARQL,
//! // and LLM traits defined below.
//! let engine = GraphRAGEngine::new(vec_index, embedding_model, sparql_engine, llm_client, config);
//! let result = engine.query("What safety issues affect battery cells?").await?;
//! println!("Answer: {}", result.answer);
//! ```

pub mod config;
pub mod generation;
pub mod graph;
pub mod query;
pub mod retrieval;
pub mod sparql;

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::RwLock;

// Re-exports
pub use config::GraphRAGConfig;
pub use graph::community::CommunityDetector;
pub use graph::traversal::GraphTraversal;
pub use query::planner::QueryPlanner;
pub use retrieval::fusion::FusionStrategy;

/// GraphRAG error types
#[derive(Error, Debug)]
pub enum GraphRAGError {
    #[error("Vector search failed: {0}")]
    VectorSearchError(String),

    #[error("Graph traversal failed: {0}")]
    GraphTraversalError(String),

    #[error("Community detection failed: {0}")]
    CommunityDetectionError(String),

    #[error("LLM generation failed: {0}")]
    GenerationError(String),

    #[error("Embedding failed: {0}")]
    EmbeddingError(String),

    #[error("SPARQL query failed: {0}")]
    SparqlError(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("Internal error: {0}")]
    InternalError(String),
}

pub type GraphRAGResult<T> = Result<T, GraphRAGError>;

/// Triple representation for RDF data
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Triple {
    pub subject: String,
    pub predicate: String,
    pub object: String,
}

impl Triple {
    pub fn new(
        subject: impl Into<String>,
        predicate: impl Into<String>,
        object: impl Into<String>,
    ) -> Self {
        Self {
            subject: subject.into(),
            predicate: predicate.into(),
            object: object.into(),
        }
    }
}

/// Entity with relevance score
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredEntity {
    /// Entity URI
    pub uri: String,
    /// Relevance score (0.0 - 1.0)
    pub score: f64,
    /// Source of the score (vector, keyword, or fused)
    pub source: ScoreSource,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}

/// Source of entity score
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScoreSource {
    /// Score from vector similarity search
    Vector,
    /// Score from keyword/BM25 search
    Keyword,
    /// Fused score from multiple sources
    Fused,
    /// Score from graph traversal (path-based)
    Graph,
}

/// Community summary for hierarchical retrieval
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunitySummary {
    /// Community identifier
    pub id: String,
    /// Human-readable summary of the community
    pub summary: String,
    /// Member entities in this community
    pub entities: Vec<String>,
    /// Representative triples from this community
    pub representative_triples: Vec<Triple>,
    /// Community level in hierarchy (0 = leaf, higher = more abstract)
    pub level: u32,
    /// Modularity score
    pub modularity: f64,
}

/// Query provenance for attribution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProvenance {
    /// Query timestamp
    pub timestamp: DateTime<Utc>,
    /// Original query text
    pub original_query: String,
    /// Expanded query (if any)
    pub expanded_query: Option<String>,
    /// Seed entities used
    pub seed_entities: Vec<String>,
    /// Triples contributing to the answer
    pub source_triples: Vec<Triple>,
    /// Community summaries used (if hierarchical)
    pub community_sources: Vec<String>,
    /// Processing time in milliseconds
    pub processing_time_ms: u64,
}

/// GraphRAG query result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGResult2 {
    /// Natural language answer
    pub answer: String,
    /// Source subgraph (RDF triples)
    pub subgraph: Vec<Triple>,
    /// Seed entities with scores
    pub seeds: Vec<ScoredEntity>,
    /// Community summaries (if enabled)
    pub communities: Vec<CommunitySummary>,
    /// Provenance information
    pub provenance: QueryProvenance,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f64,
}

/// Trait for vector index operations
#[async_trait]
pub trait VectorIndexTrait: Send + Sync {
    /// Search for k nearest neighbors
    async fn search_knn(
        &self,
        query_vector: &[f32],
        k: usize,
    ) -> GraphRAGResult<Vec<(String, f32)>>;

    /// Search with similarity threshold
    async fn search_threshold(
        &self,
        query_vector: &[f32],
        threshold: f32,
    ) -> GraphRAGResult<Vec<(String, f32)>>;
}

/// Trait for embedding model operations
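///
/// # Example
///
/// A minimal mock implementation, e.g. for tests or for wiring up the engine;
/// the 384-dimension vector size is an arbitrary placeholder:
///
/// ```rust,ignore
/// struct MockEmbedder;
///
/// #[async_trait]
/// impl EmbeddingModelTrait for MockEmbedder {
///     async fn embed(&self, _text: &str) -> GraphRAGResult<Vec<f32>> {
///         Ok(vec![0.0; 384])
///     }
///
///     async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>> {
///         Ok(texts.iter().map(|_| vec![0.0; 384]).collect())
///     }
/// }
/// ```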
#[async_trait]
pub trait EmbeddingModelTrait: Send + Sync {
    /// Embed text into vector
    async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;

    /// Embed multiple texts in batch
    async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
}

/// Trait for SPARQL engine operations
#[async_trait]
pub trait SparqlEngineTrait: Send + Sync {
    /// Execute SELECT query
    async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;

    /// Execute ASK query
    async fn ask(&self, query: &str) -> GraphRAGResult<bool>;

    /// Execute CONSTRUCT query
    async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
}

/// Trait for LLM client operations
#[async_trait]
pub trait LlmClientTrait: Send + Sync {
    /// Generate response from context and query
    async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;

    /// Generate with streaming response
    async fn generate_stream(
        &self,
        context: &str,
        query: &str,
        callback: Box<dyn Fn(&str) + Send + Sync>,
    ) -> GraphRAGResult<String>;
}

/// Main GraphRAG engine
pub struct GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    /// Vector index for similarity search
    vec_index: Arc<V>,
    /// Embedding model for query vectorization
    embedding_model: Arc<E>,
    /// SPARQL engine for graph traversal
    sparql_engine: Arc<S>,
    /// LLM client for answer generation
    llm_client: Arc<L>,
    /// Configuration
    config: GraphRAGConfig,
    /// Query result cache
    cache: Arc<RwLock<lru::LruCache<String, GraphRAGResult2>>>,
    /// Community detector (lazy initialized)
    community_detector: Option<Arc<CommunityDetector>>,
}

impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    /// Create a new GraphRAG engine
    pub fn new(
        vec_index: Arc<V>,
        embedding_model: Arc<E>,
        sparql_engine: Arc<S>,
        llm_client: Arc<L>,
        config: GraphRAGConfig,
    ) -> Self {
        const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize = match std::num::NonZeroUsize::new(1000) {
            Some(size) => size,
            None => panic!("constant is non-zero"),
        };

        let cache_size = config
            .cache_size
            .and_then(std::num::NonZeroUsize::new)
            .unwrap_or(DEFAULT_CACHE_SIZE);

        Self {
            vec_index,
            embedding_model,
            sparql_engine,
            llm_client,
            config,
            cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
            community_detector: None,
        }
    }

    /// Execute a GraphRAG query
    pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
        let start_time = std::time::Instant::now();

        // Check cache
        if let Some(cached) = self.cache.read().await.peek(&query.to_string()) {
            return Ok(cached.clone());
        }

        // 1. Embed query
        let query_vec = self.embedding_model.embed(query).await?;

        // 2. Vector retrieval (Top-K)
        let vector_results = self
            .vec_index
            .search_knn(&query_vec, self.config.top_k)
            .await?;

        // 3. Keyword retrieval (BM25) - simplified for now
        let keyword_results = self.keyword_search(query).await?;

        // 4. Fusion (RRF)
        let seeds = self.fuse_results(&vector_results, &keyword_results)?;

        // 5. Graph expansion (SPARQL)
        let subgraph = self.expand_graph(&seeds).await?;

        // 6. Community detection (optional)
        let communities = if self.config.enable_communities {
            self.detect_communities(&subgraph)?
        } else {
            vec![]
        };

        // 7. Build context
        let context = self.build_context(&subgraph, &communities, query)?;

        // 8. Generate answer
        let answer = self.llm_client.generate(&context, query).await?;

        // Calculate confidence based on seed scores and graph coverage
        let confidence = self.calculate_confidence(&seeds, &subgraph);

        let result = GraphRAGResult2 {
            answer,
            subgraph: subgraph.clone(),
            seeds: seeds.clone(),
            communities,
            provenance: QueryProvenance {
                timestamp: Utc::now(),
                original_query: query.to_string(),
                expanded_query: None,
                seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
                source_triples: subgraph,
                community_sources: vec![],
                processing_time_ms: start_time.elapsed().as_millis() as u64,
            },
            confidence,
        };

        // Update cache
        self.cache
            .write()
            .await
            .put(query.to_string(), result.clone());

        Ok(result)
    }

    /// Keyword search using BM25 (simplified)
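    ///
    /// The "BM25" score is approximated by a per-entity label match count:
    /// one case-insensitive REGEX filter per whitespace-separated query term.
    /// For a query like "battery safety" the generated SELECT has this shape
    /// (illustrative only; the LIMIT comes from `config.top_k`):
    ///
    /// ```text
    /// SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {
    ///     ?entity rdfs:label|schema:name|dc:title ?label .
    ///     FILTER(REGEX(STR(?label), "battery", "i") || REGEX(STR(?label), "safety", "i"))
    /// }
    /// GROUP BY ?entity ORDER BY DESC(?score)
    /// ```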
    async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
        // Build SPARQL query with text matching
        let terms: Vec<&str> = query.split_whitespace().collect();
        if terms.is_empty() {
            return Ok(vec![]);
        }

        // Create a SPARQL FILTER with a case-insensitive regex per term.
        // Note: terms are interpolated verbatim; regex metacharacters are not escaped.
        let filters: Vec<String> = terms
            .iter()
            .map(|term| format!("REGEX(STR(?label), \"{}\", \"i\")", term))
            .collect();

        // Declare the label-property prefixes explicitly so the query parses
        // on engines without predefined prefixes.
        let sparql = format!(
            r#"
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX schema: <http://schema.org/>
            PREFIX dc: <http://purl.org/dc/elements/1.1/>
            SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
                ?entity rdfs:label|schema:name|dc:title ?label .
                FILTER({})
            }}
            GROUP BY ?entity
            ORDER BY DESC(?score)
            LIMIT {}
            "#,
            filters.join(" || "),
            self.config.top_k
        );

        let results = self.sparql_engine.select(&sparql).await?;

        Ok(results
            .into_iter()
            .filter_map(|row| {
                let entity = row.get("entity")?.clone();
                let score = row.get("score")?.parse::<f32>().ok()?;
                Some((entity, score))
            })
            .collect())
    }

    /// Fuse vector and keyword results using Reciprocal Rank Fusion
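    ///
    /// Each ranked list contributes `1 / (k + rank)` with a 1-based rank and
    /// k = 60, scaled by `config.vector_weight` or `config.keyword_weight`;
    /// an entity present in both lists receives the sum of its contributions
    /// and is tagged `ScoreSource::Fused`.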
    fn fuse_results(
        &self,
        vector_results: &[(String, f32)],
        keyword_results: &[(String, f32)],
    ) -> GraphRAGResult<Vec<ScoredEntity>> {
        let k = 60.0; // RRF constant

        let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();

        // Add vector scores
        for (rank, (uri, _score)) in vector_results.iter().enumerate() {
            let rrf_score = 1.0 / (k + rank as f64 + 1.0);
            scores.insert(
                uri.clone(),
                (
                    rrf_score * self.config.vector_weight as f64,
                    ScoreSource::Vector,
                ),
            );
        }

        // Add keyword scores
        for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
            let rrf_score = 1.0 / (k + rank as f64 + 1.0);
            let keyword_contribution = rrf_score * self.config.keyword_weight as f64;

            match scores.get(uri).cloned() {
                Some((existing_score, _)) => {
                    let new_score = existing_score + keyword_contribution;
                    scores.insert(uri.clone(), (new_score, ScoreSource::Fused));
                }
                None => {
                    scores.insert(uri.clone(), (keyword_contribution, ScoreSource::Keyword));
                }
            }
        }

        // Sort by score and take top results
        let mut entities: Vec<ScoredEntity> = scores
            .into_iter()
            .map(|(uri, (score, source))| ScoredEntity {
                uri,
                score,
                source,
                metadata: HashMap::new(),
            })
            .collect();

        entities.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        entities.truncate(self.config.max_seeds);

        Ok(entities)
    }

    /// Expand graph from seed entities using SPARQL
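    ///
    /// Note: when `expansion_hops > 1` the generated pattern relies on the
    /// `{1,n}` property-path quantifier, which is an engine extension
    /// (curly-brace quantifiers were dropped from the final SPARQL 1.1
    /// grammar); single-hop expansion works on any conforming engine.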
    async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
        if seeds.is_empty() {
            return Ok(vec![]);
        }

        let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
        let values = seed_uris.join(" ");

        // N-hop neighbor expansion. `(<>|!<>)` is the "any predicate" property
        // path idiom, avoiding the need for a declared prefix.
        let hops = self.config.expansion_hops;
        let path_pattern = if hops == 1 {
            "?seed ?p ?neighbor".to_string()
        } else {
            format!("?seed (<>|!<>){{1,{}}} ?neighbor", hops)
        };

        let sparql = format!(
            r#"
            CONSTRUCT {{
                ?seed ?p ?o .
                ?s ?p2 ?seed .
                ?neighbor ?p3 ?o2 .
            }}
            WHERE {{
                VALUES ?seed {{ {} }}
                {{
                    ?seed ?p ?o .
                }} UNION {{
                    ?s ?p2 ?seed .
                }} UNION {{
                    {}
                    ?neighbor ?p3 ?o2 .
                }}
            }}
            LIMIT {}
            "#,
            values, path_pattern, self.config.max_subgraph_size
        );

        self.sparql_engine.construct(&sparql).await
    }

    /// Detect communities in the subgraph (connected-component approximation
    /// of Louvain-style clustering)
    fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
        use petgraph::graph::UnGraph;

        if subgraph.is_empty() {
            return Ok(vec![]);
        }

        // Build undirected graph
        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
        let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();

        for triple in subgraph {
            let subj_idx = *node_indices
                .entry(triple.subject.clone())
                .or_insert_with(|| graph.add_node(triple.subject.clone()));
            let obj_idx = *node_indices
                .entry(triple.object.clone())
                .or_insert_with(|| graph.add_node(triple.object.clone()));

            if subj_idx != obj_idx {
                graph.add_edge(subj_idx, obj_idx, ());
            }
        }

        // Simple community detection based on connected components
        // (a full Louvain implementation would be more complex)
        let components = petgraph::algo::kosaraju_scc(&graph);

        let communities: Vec<CommunitySummary> = components
            .into_iter()
            .enumerate()
            .filter(|(_, component)| component.len() >= 2)
            .map(|(idx, component)| {
                let entities: Vec<String> = component
                    .iter()
                    .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
                    .collect();

                let representative_triples: Vec<Triple> = subgraph
                    .iter()
                    .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
                    .take(5)
                    .cloned()
                    .collect();

                CommunitySummary {
                    id: format!("community_{}", idx),
                    summary: format!("Community with {} entities", entities.len()),
                    entities,
                    representative_triples,
                    level: 0,
                    modularity: 0.0,
                }
            })
            .collect();

        Ok(communities)
    }

    /// Build context string for LLM from subgraph and communities
    fn build_context(
        &self,
        subgraph: &[Triple],
        communities: &[CommunitySummary],
        _query: &str,
    ) -> GraphRAGResult<String> {
        let mut context = String::new();

        // Add community summaries if available
        if !communities.is_empty() {
            context.push_str("## Community Context\n\n");
            for community in communities {
                context.push_str(&format!("### {}\n", community.id));
                context.push_str(&format!("{}\n", community.summary));
                context.push_str(&format!("Entities: {}\n\n", community.entities.join(", ")));
            }
        }

        // Add relevant triples
        context.push_str("## Knowledge Graph Facts\n\n");
        for triple in subgraph.iter().take(self.config.max_context_triples) {
            context.push_str(&format!(
                "- {} → {} → {}\n",
                triple.subject, triple.predicate, triple.object
            ));
        }

        Ok(context)
    }

    /// Calculate confidence score based on retrieval quality
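    ///
    /// `confidence = 0.6 * avg_seed_score + 0.4 * coverage`, capped at 1.0,
    /// where coverage is the fraction of subgraph triples that touch a seed
    /// entity.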
    fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
        if seeds.is_empty() {
            return 0.0;
        }

        // Average seed score
        let avg_seed_score: f64 = seeds.iter().map(|s| s.score).sum::<f64>() / seeds.len() as f64;

        // Graph coverage (fraction of subgraph triples that touch a seed entity)
        let seed_uris: std::collections::HashSet<_> = seeds.iter().map(|s| &s.uri).collect();
        let covered: usize = subgraph
            .iter()
            .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
            .count();
        let coverage = if subgraph.is_empty() {
            0.0
        } else {
            (covered as f64 / subgraph.len() as f64).min(1.0)
        };

        // Combined confidence
        (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_triple_creation() {
        let triple = Triple::new(
            "http://example.org/s",
            "http://example.org/p",
            "http://example.org/o",
        );
        assert_eq!(triple.subject, "http://example.org/s");
        assert_eq!(triple.predicate, "http://example.org/p");
        assert_eq!(triple.object, "http://example.org/o");
    }

    #[test]
    fn test_scored_entity() {
        let entity = ScoredEntity {
            uri: "http://example.org/entity".to_string(),
            score: 0.85,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        };
        assert_eq!(entity.score, 0.85);
        assert_eq!(entity.source, ScoreSource::Fused);
    }
}