1pub mod config;
30pub mod generation;
31pub mod graph;
32pub mod query;
33pub mod retrieval;
34pub mod sparql;
35
36use std::collections::HashMap;
37use std::sync::Arc;
38
39use async_trait::async_trait;
40use chrono::{DateTime, Utc};
41use serde::{Deserialize, Serialize};
42use thiserror::Error;
43use tokio::sync::RwLock;
44
45pub use config::GraphRAGConfig;
47pub use graph::community::CommunityDetector;
48pub use graph::traversal::GraphTraversal;
49pub use query::planner::QueryPlanner;
50pub use retrieval::fusion::FusionStrategy;
51
52#[derive(Error, Debug)]
54pub enum GraphRAGError {
55 #[error("Vector search failed: {0}")]
56 VectorSearchError(String),
57
58 #[error("Graph traversal failed: {0}")]
59 GraphTraversalError(String),
60
61 #[error("Community detection failed: {0}")]
62 CommunityDetectionError(String),
63
64 #[error("LLM generation failed: {0}")]
65 GenerationError(String),
66
67 #[error("Embedding failed: {0}")]
68 EmbeddingError(String),
69
70 #[error("SPARQL query failed: {0}")]
71 SparqlError(String),
72
73 #[error("Configuration error: {0}")]
74 ConfigError(String),
75
76 #[error("Internal error: {0}")]
77 InternalError(String),
78}
79
80pub type GraphRAGResult<T> = Result<T, GraphRAGError>;
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
84pub struct Triple {
85 pub subject: String,
86 pub predicate: String,
87 pub object: String,
88}
89
90impl Triple {
91 pub fn new(
92 subject: impl Into<String>,
93 predicate: impl Into<String>,
94 object: impl Into<String>,
95 ) -> Self {
96 Self {
97 subject: subject.into(),
98 predicate: predicate.into(),
99 object: object.into(),
100 }
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ScoredEntity {
107 pub uri: String,
109 pub score: f64,
111 pub source: ScoreSource,
113 pub metadata: HashMap<String, String>,
115}
116
117#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
119pub enum ScoreSource {
120 Vector,
122 Keyword,
124 Fused,
126 Graph,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct CommunitySummary {
133 pub id: String,
135 pub summary: String,
137 pub entities: Vec<String>,
139 pub representative_triples: Vec<Triple>,
141 pub level: u32,
143 pub modularity: f64,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct QueryProvenance {
150 pub timestamp: DateTime<Utc>,
152 pub original_query: String,
154 pub expanded_query: Option<String>,
156 pub seed_entities: Vec<String>,
158 pub source_triples: Vec<Triple>,
160 pub community_sources: Vec<String>,
162 pub processing_time_ms: u64,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct GraphRAGResult2 {
169 pub answer: String,
171 pub subgraph: Vec<Triple>,
173 pub seeds: Vec<ScoredEntity>,
175 pub communities: Vec<CommunitySummary>,
177 pub provenance: QueryProvenance,
179 pub confidence: f64,
181}
182
183#[async_trait]
185pub trait VectorIndexTrait: Send + Sync {
186 async fn search_knn(
188 &self,
189 query_vector: &[f32],
190 k: usize,
191 ) -> GraphRAGResult<Vec<(String, f32)>>;
192
193 async fn search_threshold(
195 &self,
196 query_vector: &[f32],
197 threshold: f32,
198 ) -> GraphRAGResult<Vec<(String, f32)>>;
199}
200
201#[async_trait]
203pub trait EmbeddingModelTrait: Send + Sync {
204 async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;
206
207 async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
209}
210
211#[async_trait]
213pub trait SparqlEngineTrait: Send + Sync {
214 async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;
216
217 async fn ask(&self, query: &str) -> GraphRAGResult<bool>;
219
220 async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
222}
223
224#[async_trait]
226pub trait LlmClientTrait: Send + Sync {
227 async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;
229
230 async fn generate_stream(
232 &self,
233 context: &str,
234 query: &str,
235 callback: Box<dyn Fn(&str) + Send + Sync>,
236 ) -> GraphRAGResult<String>;
237}
238
239pub struct GraphRAGEngine<V, E, S, L>
241where
242 V: VectorIndexTrait,
243 E: EmbeddingModelTrait,
244 S: SparqlEngineTrait,
245 L: LlmClientTrait,
246{
247 vec_index: Arc<V>,
249 embedding_model: Arc<E>,
251 sparql_engine: Arc<S>,
253 llm_client: Arc<L>,
255 config: GraphRAGConfig,
257 cache: Arc<RwLock<lru::LruCache<String, GraphRAGResult2>>>,
259 community_detector: Option<Arc<CommunityDetector>>,
261}
262
263impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
264where
265 V: VectorIndexTrait,
266 E: EmbeddingModelTrait,
267 S: SparqlEngineTrait,
268 L: LlmClientTrait,
269{
270 pub fn new(
272 vec_index: Arc<V>,
273 embedding_model: Arc<E>,
274 sparql_engine: Arc<S>,
275 llm_client: Arc<L>,
276 config: GraphRAGConfig,
277 ) -> Self {
278 const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize = match std::num::NonZeroUsize::new(1000) {
279 Some(size) => size,
280 None => panic!("constant is non-zero"),
281 };
282
283 let cache_size = config
284 .cache_size
285 .and_then(std::num::NonZeroUsize::new)
286 .unwrap_or(DEFAULT_CACHE_SIZE);
287
288 Self {
289 vec_index,
290 embedding_model,
291 sparql_engine,
292 llm_client,
293 config,
294 cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
295 community_detector: None,
296 }
297 }
298
299 pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
301 let start_time = std::time::Instant::now();
302
303 if let Some(cached) = self.cache.read().await.peek(&query.to_string()) {
305 return Ok(cached.clone());
306 }
307
308 let query_vec = self.embedding_model.embed(query).await?;
310
311 let vector_results = self
313 .vec_index
314 .search_knn(&query_vec, self.config.top_k)
315 .await?;
316
317 let keyword_results = self.keyword_search(query).await?;
319
320 let seeds = self.fuse_results(&vector_results, &keyword_results)?;
322
323 let subgraph = self.expand_graph(&seeds).await?;
325
326 let communities = if self.config.enable_communities {
328 self.detect_communities(&subgraph)?
329 } else {
330 vec![]
331 };
332
333 let context = self.build_context(&subgraph, &communities, query)?;
335
336 let answer = self.llm_client.generate(&context, query).await?;
338
339 let confidence = self.calculate_confidence(&seeds, &subgraph);
341
342 let result = GraphRAGResult2 {
343 answer,
344 subgraph: subgraph.clone(),
345 seeds: seeds.clone(),
346 communities,
347 provenance: QueryProvenance {
348 timestamp: Utc::now(),
349 original_query: query.to_string(),
350 expanded_query: None,
351 seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
352 source_triples: subgraph,
353 community_sources: vec![],
354 processing_time_ms: start_time.elapsed().as_millis() as u64,
355 },
356 confidence,
357 };
358
359 self.cache
361 .write()
362 .await
363 .put(query.to_string(), result.clone());
364
365 Ok(result)
366 }
367
368 async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
370 let terms: Vec<&str> = query.split_whitespace().collect();
372 if terms.is_empty() {
373 return Ok(vec![]);
374 }
375
376 let filters: Vec<String> = terms
378 .iter()
379 .map(|term| format!("REGEX(STR(?label), \"{}\", \"i\")", term))
380 .collect();
381
382 let sparql = format!(
383 r#"
384 SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
385 ?entity rdfs:label|schema:name|dc:title ?label .
386 FILTER({})
387 }}
388 GROUP BY ?entity
389 ORDER BY DESC(?score)
390 LIMIT {}
391 "#,
392 filters.join(" || "),
393 self.config.top_k
394 );
395
396 let results = self.sparql_engine.select(&sparql).await?;
397
398 Ok(results
399 .into_iter()
400 .filter_map(|row| {
401 let entity = row.get("entity")?.clone();
402 let score = row.get("score")?.parse::<f32>().ok()?;
403 Some((entity, score))
404 })
405 .collect())
406 }
407
408 fn fuse_results(
410 &self,
411 vector_results: &[(String, f32)],
412 keyword_results: &[(String, f32)],
413 ) -> GraphRAGResult<Vec<ScoredEntity>> {
414 let k = 60.0; let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();
417
418 for (rank, (uri, score)) in vector_results.iter().enumerate() {
420 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
421 scores.insert(
422 uri.clone(),
423 (
424 rrf_score * self.config.vector_weight as f64,
425 ScoreSource::Vector,
426 ),
427 );
428 }
429
430 for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
432 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
433 let keyword_contribution = rrf_score * self.config.keyword_weight as f64;
434
435 match scores.get(uri).cloned() {
436 Some((existing_score, _)) => {
437 let new_score = existing_score + keyword_contribution;
438 scores.insert(uri.clone(), (new_score, ScoreSource::Fused));
439 }
440 None => {
441 scores.insert(uri.clone(), (keyword_contribution, ScoreSource::Keyword));
442 }
443 }
444 }
445
446 let mut entities: Vec<ScoredEntity> = scores
448 .into_iter()
449 .map(|(uri, (score, source))| ScoredEntity {
450 uri,
451 score,
452 source,
453 metadata: HashMap::new(),
454 })
455 .collect();
456
457 entities.sort_by(|a, b| {
458 b.score
459 .partial_cmp(&a.score)
460 .unwrap_or(std::cmp::Ordering::Equal)
461 });
462 entities.truncate(self.config.max_seeds);
463
464 Ok(entities)
465 }
466
467 async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
469 if seeds.is_empty() {
470 return Ok(vec![]);
471 }
472
473 let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
474 let values = seed_uris.join(" ");
475
476 let hops = self.config.expansion_hops;
478 let path_pattern = if hops == 1 {
479 "?seed ?p ?neighbor".to_string()
480 } else {
481 format!("?seed (:|!:){{1,{}}} ?neighbor", hops)
482 };
483
484 let sparql = format!(
485 r#"
486 CONSTRUCT {{
487 ?seed ?p ?o .
488 ?s ?p2 ?seed .
489 ?neighbor ?p3 ?o2 .
490 }}
491 WHERE {{
492 VALUES ?seed {{ {} }}
493 {{
494 ?seed ?p ?o .
495 }} UNION {{
496 ?s ?p2 ?seed .
497 }} UNION {{
498 {}
499 ?neighbor ?p3 ?o2 .
500 }}
501 }}
502 LIMIT {}
503 "#,
504 values, path_pattern, self.config.max_subgraph_size
505 );
506
507 self.sparql_engine.construct(&sparql).await
508 }
509
510 fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
512 use petgraph::graph::UnGraph;
513
514 if subgraph.is_empty() {
515 return Ok(vec![]);
516 }
517
518 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
520 let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
521
522 for triple in subgraph {
523 let subj_idx = *node_indices
524 .entry(triple.subject.clone())
525 .or_insert_with(|| graph.add_node(triple.subject.clone()));
526 let obj_idx = *node_indices
527 .entry(triple.object.clone())
528 .or_insert_with(|| graph.add_node(triple.object.clone()));
529
530 if subj_idx != obj_idx {
531 graph.add_edge(subj_idx, obj_idx, ());
532 }
533 }
534
535 let components = petgraph::algo::kosaraju_scc(&graph);
538
539 let communities: Vec<CommunitySummary> = components
540 .into_iter()
541 .enumerate()
542 .filter(|(_, component)| component.len() >= 2)
543 .map(|(idx, component)| {
544 let entities: Vec<String> = component
545 .iter()
546 .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
547 .collect();
548
549 let representative_triples: Vec<Triple> = subgraph
550 .iter()
551 .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
552 .take(5)
553 .cloned()
554 .collect();
555
556 CommunitySummary {
557 id: format!("community_{}", idx),
558 summary: format!("Community with {} entities", entities.len()),
559 entities,
560 representative_triples,
561 level: 0,
562 modularity: 0.0,
563 }
564 })
565 .collect();
566
567 Ok(communities)
568 }
569
570 fn build_context(
572 &self,
573 subgraph: &[Triple],
574 communities: &[CommunitySummary],
575 _query: &str,
576 ) -> GraphRAGResult<String> {
577 let mut context = String::new();
578
579 if !communities.is_empty() {
581 context.push_str("## Community Context\n\n");
582 for community in communities {
583 context.push_str(&format!("### {}\n", community.id));
584 context.push_str(&format!("{}\n", community.summary));
585 context.push_str(&format!("Entities: {}\n\n", community.entities.join(", ")));
586 }
587 }
588
589 context.push_str("## Knowledge Graph Facts\n\n");
591 for triple in subgraph.iter().take(self.config.max_context_triples) {
592 context.push_str(&format!(
593 "- {} → {} → {}\n",
594 triple.subject, triple.predicate, triple.object
595 ));
596 }
597
598 Ok(context)
599 }
600
601 fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
603 if seeds.is_empty() {
604 return 0.0;
605 }
606
607 let avg_seed_score: f64 = seeds.iter().map(|s| s.score).sum::<f64>() / seeds.len() as f64;
609
610 let seed_uris: std::collections::HashSet<_> = seeds.iter().map(|s| &s.uri).collect();
612 let covered: usize = subgraph
613 .iter()
614 .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
615 .count();
616 let coverage = if subgraph.is_empty() {
617 0.0
618 } else {
619 (covered as f64 / subgraph.len() as f64).min(1.0)
620 };
621
622 (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    /// `Triple::new` stores each component verbatim.
    #[test]
    fn test_triple_creation() {
        let (s, p, o) = (
            "http://example.org/s",
            "http://example.org/p",
            "http://example.org/o",
        );
        let triple = Triple::new(s, p, o);
        assert_eq!(triple.subject, s);
        assert_eq!(triple.predicate, p);
        assert_eq!(triple.object, o);
    }

    /// A hand-built `ScoredEntity` exposes the score and source it was given.
    #[test]
    fn test_scored_entity() {
        let expected_source = ScoreSource::Fused;
        let entity = ScoredEntity {
            uri: String::from("http://example.org/entity"),
            score: 0.85,
            source: expected_source,
            metadata: HashMap::default(),
        };
        assert_eq!(entity.score, 0.85);
        assert_eq!(entity.source, expected_source);
    }
}