1pub mod config;
34pub mod generation;
35pub mod graph;
36pub mod query;
37pub mod retrieval;
38pub mod sparql;
39
40use std::collections::HashMap;
41use std::sync::Arc;
42
43use async_trait::async_trait;
44use chrono::{DateTime, Utc};
45use serde::{Deserialize, Serialize};
46use thiserror::Error;
47use tokio::sync::RwLock;
48
49pub use config::GraphRAGConfig;
51pub use graph::community::CommunityDetector;
52pub use graph::traversal::GraphTraversal;
53pub use query::planner::QueryPlanner;
54pub use retrieval::fusion::FusionStrategy;
55
/// Errors reported by the GraphRAG pipeline.
///
/// Every variant wraps a free-form message that is interpolated into the
/// `Display` text declared by the variant's `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum GraphRAGError {
    /// Displayed as `Vector search failed: <message>`.
    #[error("Vector search failed: {0}")]
    VectorSearchError(String),

    /// Displayed as `Graph traversal failed: <message>`.
    #[error("Graph traversal failed: {0}")]
    GraphTraversalError(String),

    /// Displayed as `Community detection failed: <message>`.
    #[error("Community detection failed: {0}")]
    CommunityDetectionError(String),

    /// Displayed as `LLM generation failed: <message>`.
    #[error("LLM generation failed: {0}")]
    GenerationError(String),

    /// Displayed as `Embedding failed: <message>`.
    #[error("Embedding failed: {0}")]
    EmbeddingError(String),

    /// Displayed as `SPARQL query failed: <message>`.
    #[error("SPARQL query failed: {0}")]
    SparqlError(String),

    /// Displayed as `Configuration error: <message>`.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Displayed as `Internal error: <message>`.
    #[error("Internal error: {0}")]
    InternalError(String),
}
83
/// Shorthand for a `Result` whose error type is [`GraphRAGError`].
pub type GraphRAGResult<T> = Result<T, GraphRAGError>;
85
/// A subject–predicate–object statement with all three parts held as owned strings.
///
/// The type is hashable and comparable, so it can be used as a set or map key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Triple {
    /// Subject of the statement.
    pub subject: String,
    /// Predicate linking subject and object.
    pub predicate: String,
    /// Object of the statement.
    pub object: String,
}
93
94impl Triple {
95 pub fn new(
96 subject: impl Into<String>,
97 predicate: impl Into<String>,
98 object: impl Into<String>,
99 ) -> Self {
100 Self {
101 subject: subject.into(),
102 predicate: predicate.into(),
103 object: object.into(),
104 }
105 }
106}
107
/// An entity identifier paired with a relevance score and the origin of that score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredEntity {
    /// Identifier of the entity (the engine in this file embeds it as `<uri>` in SPARQL).
    pub uri: String,
    /// Relevance score; the engine in this file sorts seeds by this value, highest first.
    pub score: f64,
    /// Which retrieval path produced `score`.
    pub source: ScoreSource,
    /// Free-form key/value annotations; the engine in this file leaves it empty.
    pub metadata: HashMap<String, String>,
}
120
/// Origin of a [`ScoredEntity`]'s score.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScoreSource {
    /// The entity was found only by the vector index.
    Vector,
    /// The entity was found only by the keyword search.
    Keyword,
    /// The entity was found by both searches and their contributions were added.
    Fused,
    /// Not produced by the code visible in this file; presumably reserved for
    /// scores derived from graph structure — confirm with other modules.
    Graph,
}
133
/// Description of one group of related entities found in a retrieved subgraph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunitySummary {
    /// Identifier of the community.
    pub id: String,
    /// Human-readable description; it is inserted verbatim into the LLM context.
    pub summary: String,
    /// Names of the entities that belong to the community.
    pub entities: Vec<String>,
    /// A sample of triples that touch at least one member entity.
    pub representative_triples: Vec<Triple>,
    /// Hierarchy level; the detection in this file always sets `0`.
    pub level: u32,
    /// Modularity value; the detection in this file always sets `0.0`.
    pub modularity: f64,
}
150
/// Record of how an answer was produced, stored alongside the answer itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProvenance {
    /// UTC time at which the result was assembled.
    pub timestamp: DateTime<Utc>,
    /// The query text exactly as the caller supplied it.
    pub original_query: String,
    /// Rewritten query, if any; the engine in this file always sets `None`.
    pub expanded_query: Option<String>,
    /// URIs of the seed entities used for graph expansion.
    pub seed_entities: Vec<String>,
    /// Triples of the subgraph that was given to the LLM as context.
    pub source_triples: Vec<Triple>,
    /// Presumably identifiers of the communities that contributed context —
    /// TODO confirm the intended content with consumers of this field.
    pub community_sources: Vec<String>,
    /// Wall-clock time spent on the query, in milliseconds.
    pub processing_time_ms: u64,
}
169
/// Complete outcome of one engine query.
///
/// NOTE(review): the `2` suffix presumably avoids a clash with the
/// [`GraphRAGResult`] type alias — confirm before renaming.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGResult2 {
    /// Text generated by the LLM client.
    pub answer: String,
    /// Triples retrieved by graph expansion.
    pub subgraph: Vec<Triple>,
    /// Seed entities selected after fusing vector and keyword results.
    pub seeds: Vec<ScoredEntity>,
    /// Communities detected in `subgraph` (empty when detection is disabled).
    pub communities: Vec<CommunitySummary>,
    /// How the answer was produced.
    pub provenance: QueryProvenance,
    /// Heuristic confidence computed from seed scores and seed coverage of the subgraph.
    pub confidence: f64,
}
186
/// Abstraction over a vector index queried with `f32` embedding vectors.
///
/// Both methods yield `(String, f32)` pairs. The engine in this file treats
/// the string as an entity URI and uses each pair's position in the returned
/// list as its rank; it does not read the `f32` value.
#[async_trait]
pub trait VectorIndexTrait: Send + Sync {
    /// Searches the index with `query_vector`, parameterised by `k`.
    ///
    /// Presumably returns the `k` nearest entries, best first — confirm with
    /// the implementation in use.
    async fn search_knn(
        &self,
        query_vector: &[f32],
        k: usize,
    ) -> GraphRAGResult<Vec<(String, f32)>>;

    /// Searches the index with `query_vector`, parameterised by `threshold`.
    ///
    /// Presumably returns every entry whose score passes `threshold` —
    /// confirm with the implementation in use. Not called in this file.
    async fn search_threshold(
        &self,
        query_vector: &[f32],
        threshold: f32,
    ) -> GraphRAGResult<Vec<(String, f32)>>;
}
204
/// Abstraction over a model that turns text into `f32` embedding vectors.
#[async_trait]
pub trait EmbeddingModelTrait: Send + Sync {
    /// Embeds a single text; the engine passes the result to the vector index.
    async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;

    /// Embeds several texts in one call.
    ///
    /// Presumably returns one vector per input, in input order — confirm with
    /// the implementation in use. Not called in this file.
    async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
}
214
/// Abstraction over a SPARQL endpoint; each method takes the query as text.
#[async_trait]
pub trait SparqlEngineTrait: Send + Sync {
    /// Runs a `SELECT` query and returns one map per result row.
    ///
    /// The engine in this file looks values up by variable name without the
    /// leading `?` (for example `"entity"`), so implementations are expected
    /// to key rows that way.
    async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;

    /// Runs an `ASK` query and returns its boolean answer. Not called in this file.
    async fn ask(&self, query: &str) -> GraphRAGResult<bool>;

    /// Runs a `CONSTRUCT` query and returns the resulting triples.
    async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
}
227
/// Abstraction over a text-generation client.
#[async_trait]
pub trait LlmClientTrait: Send + Sync {
    /// Produces an answer to `query` given the supporting `context` text.
    async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;

    /// Streaming variant of [`LlmClientTrait::generate`].
    ///
    /// Presumably invokes `callback` with each chunk as it arrives and returns
    /// the full text at the end — TODO confirm; not called in this file.
    async fn generate_stream(
        &self,
        context: &str,
        query: &str,
        callback: Box<dyn Fn(&str) + Send + Sync>,
    ) -> GraphRAGResult<String>;
}
242
/// Retrieval-augmented generation engine that combines a vector index, a
/// SPARQL endpoint and an LLM client.
///
/// The four collaborators are generic and held behind `Arc`.
pub struct GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    /// Vector index used to find seed entities.
    vec_index: Arc<V>,
    /// Model that embeds the query text.
    embedding_model: Arc<E>,
    /// SPARQL endpoint used for keyword search and graph expansion.
    sparql_engine: Arc<S>,
    /// Client that generates the final answer.
    llm_client: Arc<L>,
    /// Tuning parameters (top-k, weights, limits, feature switches).
    config: GraphRAGConfig,
    /// Results keyed by the exact query text, behind an async read/write lock.
    cache: Arc<RwLock<lru::LruCache<String, GraphRAGResult2>>>,
    /// Optional external detector; `new` sets it to `None` and the code
    /// visible in this file never reads it.
    community_detector: Option<Arc<CommunityDetector>>,
}
266
267impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
268where
269 V: VectorIndexTrait,
270 E: EmbeddingModelTrait,
271 S: SparqlEngineTrait,
272 L: LlmClientTrait,
273{
274 pub fn new(
276 vec_index: Arc<V>,
277 embedding_model: Arc<E>,
278 sparql_engine: Arc<S>,
279 llm_client: Arc<L>,
280 config: GraphRAGConfig,
281 ) -> Self {
282 const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize = match std::num::NonZeroUsize::new(1000) {
283 Some(size) => size,
284 None => panic!("constant is non-zero"),
285 };
286
287 let cache_size = config
288 .cache_size
289 .and_then(std::num::NonZeroUsize::new)
290 .unwrap_or(DEFAULT_CACHE_SIZE);
291
292 Self {
293 vec_index,
294 embedding_model,
295 sparql_engine,
296 llm_client,
297 config,
298 cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
299 community_detector: None,
300 }
301 }
302
303 pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
305 let start_time = std::time::Instant::now();
306
307 if let Some(cached) = self.cache.read().await.peek(&query.to_string()) {
309 return Ok(cached.clone());
310 }
311
312 let query_vec = self.embedding_model.embed(query).await?;
314
315 let vector_results = self
317 .vec_index
318 .search_knn(&query_vec, self.config.top_k)
319 .await?;
320
321 let keyword_results = self.keyword_search(query).await?;
323
324 let seeds = self.fuse_results(&vector_results, &keyword_results)?;
326
327 let subgraph = self.expand_graph(&seeds).await?;
329
330 let communities = if self.config.enable_communities {
332 self.detect_communities(&subgraph)?
333 } else {
334 vec![]
335 };
336
337 let context = self.build_context(&subgraph, &communities, query)?;
339
340 let answer = self.llm_client.generate(&context, query).await?;
342
343 let confidence = self.calculate_confidence(&seeds, &subgraph);
345
346 let result = GraphRAGResult2 {
347 answer,
348 subgraph: subgraph.clone(),
349 seeds: seeds.clone(),
350 communities,
351 provenance: QueryProvenance {
352 timestamp: Utc::now(),
353 original_query: query.to_string(),
354 expanded_query: None,
355 seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
356 source_triples: subgraph,
357 community_sources: vec![],
358 processing_time_ms: start_time.elapsed().as_millis() as u64,
359 },
360 confidence,
361 };
362
363 self.cache
365 .write()
366 .await
367 .put(query.to_string(), result.clone());
368
369 Ok(result)
370 }
371
372 async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
374 let terms: Vec<&str> = query.split_whitespace().collect();
376 if terms.is_empty() {
377 return Ok(vec![]);
378 }
379
380 let filters: Vec<String> = terms
382 .iter()
383 .map(|term| format!("REGEX(STR(?label), \"{}\", \"i\")", term))
384 .collect();
385
386 let sparql = format!(
387 r#"
388 SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
389 ?entity rdfs:label|schema:name|dc:title ?label .
390 FILTER({})
391 }}
392 GROUP BY ?entity
393 ORDER BY DESC(?score)
394 LIMIT {}
395 "#,
396 filters.join(" || "),
397 self.config.top_k
398 );
399
400 let results = self.sparql_engine.select(&sparql).await?;
401
402 Ok(results
403 .into_iter()
404 .filter_map(|row| {
405 let entity = row.get("entity")?.clone();
406 let score = row.get("score")?.parse::<f32>().ok()?;
407 Some((entity, score))
408 })
409 .collect())
410 }
411
412 fn fuse_results(
414 &self,
415 vector_results: &[(String, f32)],
416 keyword_results: &[(String, f32)],
417 ) -> GraphRAGResult<Vec<ScoredEntity>> {
418 let k = 60.0; let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();
421
422 for (rank, (uri, score)) in vector_results.iter().enumerate() {
424 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
425 scores.insert(
426 uri.clone(),
427 (
428 rrf_score * self.config.vector_weight as f64,
429 ScoreSource::Vector,
430 ),
431 );
432 }
433
434 for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
436 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
437 let keyword_contribution = rrf_score * self.config.keyword_weight as f64;
438
439 match scores.get(uri).cloned() {
440 Some((existing_score, _)) => {
441 let new_score = existing_score + keyword_contribution;
442 scores.insert(uri.clone(), (new_score, ScoreSource::Fused));
443 }
444 None => {
445 scores.insert(uri.clone(), (keyword_contribution, ScoreSource::Keyword));
446 }
447 }
448 }
449
450 let mut entities: Vec<ScoredEntity> = scores
452 .into_iter()
453 .map(|(uri, (score, source))| ScoredEntity {
454 uri,
455 score,
456 source,
457 metadata: HashMap::new(),
458 })
459 .collect();
460
461 entities.sort_by(|a, b| {
462 b.score
463 .partial_cmp(&a.score)
464 .unwrap_or(std::cmp::Ordering::Equal)
465 });
466 entities.truncate(self.config.max_seeds);
467
468 Ok(entities)
469 }
470
471 async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
473 if seeds.is_empty() {
474 return Ok(vec![]);
475 }
476
477 let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
478 let values = seed_uris.join(" ");
479
480 let hops = self.config.expansion_hops;
482 let path_pattern = if hops == 1 {
483 "?seed ?p ?neighbor".to_string()
484 } else {
485 format!("?seed (:|!:){{1,{}}} ?neighbor", hops)
486 };
487
488 let sparql = format!(
489 r#"
490 CONSTRUCT {{
491 ?seed ?p ?o .
492 ?s ?p2 ?seed .
493 ?neighbor ?p3 ?o2 .
494 }}
495 WHERE {{
496 VALUES ?seed {{ {} }}
497 {{
498 ?seed ?p ?o .
499 }} UNION {{
500 ?s ?p2 ?seed .
501 }} UNION {{
502 {}
503 ?neighbor ?p3 ?o2 .
504 }}
505 }}
506 LIMIT {}
507 "#,
508 values, path_pattern, self.config.max_subgraph_size
509 );
510
511 self.sparql_engine.construct(&sparql).await
512 }
513
514 fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
516 use petgraph::graph::UnGraph;
517
518 if subgraph.is_empty() {
519 return Ok(vec![]);
520 }
521
522 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
524 let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
525
526 for triple in subgraph {
527 let subj_idx = *node_indices
528 .entry(triple.subject.clone())
529 .or_insert_with(|| graph.add_node(triple.subject.clone()));
530 let obj_idx = *node_indices
531 .entry(triple.object.clone())
532 .or_insert_with(|| graph.add_node(triple.object.clone()));
533
534 if subj_idx != obj_idx {
535 graph.add_edge(subj_idx, obj_idx, ());
536 }
537 }
538
539 let components = petgraph::algo::kosaraju_scc(&graph);
542
543 let communities: Vec<CommunitySummary> = components
544 .into_iter()
545 .enumerate()
546 .filter(|(_, component)| component.len() >= 2)
547 .map(|(idx, component)| {
548 let entities: Vec<String> = component
549 .iter()
550 .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
551 .collect();
552
553 let representative_triples: Vec<Triple> = subgraph
554 .iter()
555 .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
556 .take(5)
557 .cloned()
558 .collect();
559
560 CommunitySummary {
561 id: format!("community_{}", idx),
562 summary: format!("Community with {} entities", entities.len()),
563 entities,
564 representative_triples,
565 level: 0,
566 modularity: 0.0,
567 }
568 })
569 .collect();
570
571 Ok(communities)
572 }
573
574 fn build_context(
576 &self,
577 subgraph: &[Triple],
578 communities: &[CommunitySummary],
579 _query: &str,
580 ) -> GraphRAGResult<String> {
581 let mut context = String::new();
582
583 if !communities.is_empty() {
585 context.push_str("## Community Context\n\n");
586 for community in communities {
587 context.push_str(&format!("### {}\n", community.id));
588 context.push_str(&format!("{}\n", community.summary));
589 context.push_str(&format!("Entities: {}\n\n", community.entities.join(", ")));
590 }
591 }
592
593 context.push_str("## Knowledge Graph Facts\n\n");
595 for triple in subgraph.iter().take(self.config.max_context_triples) {
596 context.push_str(&format!(
597 "- {} → {} → {}\n",
598 triple.subject, triple.predicate, triple.object
599 ));
600 }
601
602 Ok(context)
603 }
604
605 fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
607 if seeds.is_empty() {
608 return 0.0;
609 }
610
611 let avg_seed_score: f64 = seeds.iter().map(|s| s.score).sum::<f64>() / seeds.len() as f64;
613
614 let seed_uris: std::collections::HashSet<_> = seeds.iter().map(|s| &s.uri).collect();
616 let covered: usize = subgraph
617 .iter()
618 .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
619 .count();
620 let coverage = if subgraph.is_empty() {
621 0.0
622 } else {
623 (covered as f64 / subgraph.len() as f64).min(1.0)
624 };
625
626 (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
628 }
629}
630
#[cfg(test)]
mod tests {
    use super::*;

    /// `Triple::new` stores each component exactly as given.
    #[test]
    fn test_triple_creation() {
        let subject = "http://example.org/s";
        let predicate = "http://example.org/p";
        let object = "http://example.org/o";

        let triple = Triple::new(subject, predicate, object);

        assert_eq!(triple.subject, "http://example.org/s");
        assert_eq!(triple.predicate, "http://example.org/p");
        assert_eq!(triple.object, "http://example.org/o");
    }

    /// A `ScoredEntity` keeps the score and source it was built with.
    #[test]
    fn test_scored_entity() {
        let ScoredEntity { score, source, .. } = ScoredEntity {
            uri: "http://example.org/entity".to_string(),
            score: 0.85,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        };

        assert_eq!(score, 0.85);
        assert_eq!(source, ScoreSource::Fused);
    }
}