1pub mod cache;
34pub mod config;
35pub mod distributed;
36pub mod embeddings;
38pub mod federation;
39pub mod fusion;
40pub mod generation;
41pub mod graph;
42pub mod graph_summarization;
43pub mod query;
44pub mod reasoning;
45pub mod retrieval;
46pub mod sparql;
47pub mod streaming;
48pub mod temporal;
49
50pub mod transe_model;
52
53pub mod entity_linking;
55
56pub mod community_detector;
58
59pub mod path_ranker;
61
62pub mod entity_linker;
64
65pub mod graph_embedder;
67
68pub mod graph_partitioner;
70
71pub mod triple_extractor;
73
74pub mod knowledge_fusion;
76
77pub mod context_builder;
79
80pub mod path_finder;
82
83pub mod summarizer;
85
86pub mod entity_classifier;
88
89use std::collections::HashMap;
90use std::sync::atomic::{AtomicU64, Ordering};
91use std::sync::Arc;
92use std::time::{Duration, SystemTime};
93
94use async_trait::async_trait;
95use chrono::{DateTime, Utc};
96use serde::{Deserialize, Serialize};
97use thiserror::Error;
98use tokio::sync::RwLock;
99
100pub use cache::query_cache::{CacheEntry, CacheStats, QueryCache, QueryCacheConfig};
102pub use config::{CacheConfiguration, GraphRAGConfig};
103pub use embeddings::node2vec::{
104 Node2VecConfig, Node2VecEmbedder, Node2VecEmbeddings, Node2VecWalkConfig,
105};
106pub use graph::community::{CommunityAlgorithm, CommunityConfig, CommunityDetector};
107pub use graph::embeddings::{CommunityAwareEmbeddings, CommunityStructure, EmbeddingConfig};
108pub use graph::traversal::GraphTraversal;
109pub use query::planner::QueryPlanner;
110pub use retrieval::fusion::FusionStrategy;
111
/// Error type covering every GraphRAG subsystem, one variant per stage.
///
/// Each variant wraps a human-readable detail string; `thiserror` derives
/// `Display` from the `#[error]` attributes.
#[derive(Error, Debug)]
pub enum GraphRAGError {
    /// Failure in the vector (ANN) index.
    #[error("Vector search failed: {0}")]
    VectorSearchError(String),

    /// Failure while walking/expanding the knowledge graph.
    #[error("Graph traversal failed: {0}")]
    GraphTraversalError(String),

    /// Failure during community detection.
    #[error("Community detection failed: {0}")]
    CommunityDetectionError(String),

    /// Failure from the LLM client while generating an answer.
    #[error("LLM generation failed: {0}")]
    GenerationError(String),

    /// Failure from the embedding model.
    #[error("Embedding failed: {0}")]
    EmbeddingError(String),

    /// Failure executing a SPARQL query.
    #[error("SPARQL query failed: {0}")]
    SparqlError(String),

    /// Invalid or inconsistent configuration.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Catch-all for invariant violations not covered above.
    #[error("Internal error: {0}")]
    InternalError(String),
}
139
/// Convenience alias: `Result` specialized to [`GraphRAGError`].
pub type GraphRAGResult<T> = Result<T, GraphRAGError>;
141
/// An RDF-style (subject, predicate, object) statement.
///
/// All three components are stored as plain strings; no IRI validation is
/// performed here.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Triple {
    // Subject term (typically an IRI).
    pub subject: String,
    // Predicate term (typically an IRI).
    pub predicate: String,
    // Object term (IRI or literal — not distinguished here).
    pub object: String,
}
149
150impl Triple {
151 pub fn new(
152 subject: impl Into<String>,
153 predicate: impl Into<String>,
154 object: impl Into<String>,
155 ) -> Self {
156 Self {
157 subject: subject.into(),
158 predicate: predicate.into(),
159 object: object.into(),
160 }
161 }
162}
163
/// An entity URI paired with a retrieval relevance score and its origin.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredEntity {
    // Entity identifier (IRI).
    pub uri: String,
    // Relevance score; fused scores are RRF-style, so absolute magnitudes
    // are small and only relative order is meaningful.
    pub score: f64,
    // Which retrieval stage produced the score.
    pub source: ScoreSource,
    // Free-form auxiliary data attached by retrieval stages.
    pub metadata: HashMap<String, String>,
}
176
/// Provenance of a [`ScoredEntity`] score.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScoreSource {
    // Vector (embedding similarity) search.
    Vector,
    // Keyword / label match search.
    Keyword,
    // Combined vector + keyword contribution (reciprocal-rank fusion).
    Fused,
    // Graph-derived scoring. NOTE(review): not produced anywhere in this
    // file — presumably assigned by another module; confirm.
    Graph,
}
189
/// Summary of one detected entity community.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunitySummary {
    // Stable identifier, e.g. "community_0".
    pub id: String,
    // Human-readable description of the community.
    pub summary: String,
    // URIs of the member entities.
    pub entities: Vec<String>,
    // A handful of triples illustrating the community (5 max in this file).
    pub representative_triples: Vec<Triple>,
    // Hierarchy level; the detector in this file only produces level 0.
    pub level: u32,
    // Modularity of the partition; 0.0 when not computed (as here).
    pub modularity: f64,
}
206
/// Audit trail describing how an answer was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProvenance {
    // When the query was processed (UTC).
    pub timestamp: DateTime<Utc>,
    // The user's query exactly as received.
    pub original_query: String,
    // Rewritten/expanded query if expansion ran (always None in this file).
    pub expanded_query: Option<String>,
    // URIs of the fused seed entities that anchored retrieval.
    pub seed_entities: Vec<String>,
    // The retrieved subgraph triples that backed the answer.
    pub source_triples: Vec<Triple>,
    // IDs of communities whose summaries fed the context (unpopulated here).
    pub community_sources: Vec<String>,
    // End-to-end wall-clock latency in milliseconds.
    pub processing_time_ms: u64,
}
225
/// Full response from [`GraphRAGEngine::query`].
///
/// NOTE(review): the `2` suffix presumably avoids a clash with the
/// `GraphRAGResult<T>` alias — consider a clearer name (e.g.
/// `QueryResponse`) in a breaking release.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGResult2 {
    // The LLM-generated answer text.
    pub answer: String,
    // Triples retrieved around the seed entities.
    pub subgraph: Vec<Triple>,
    // Fused seed entities used to anchor graph expansion.
    pub seeds: Vec<ScoredEntity>,
    // Community summaries (empty when community detection is disabled).
    pub communities: Vec<CommunitySummary>,
    // Audit trail for this answer.
    pub provenance: QueryProvenance,
    // Heuristic confidence in [0, 1] (see `calculate_confidence`).
    pub confidence: f64,
}
242
/// Abstraction over an approximate-nearest-neighbour vector index.
///
/// Implementations must be shareable across async tasks (`Send + Sync`).
#[async_trait]
pub trait VectorIndexTrait: Send + Sync {
    /// Return up to `k` `(entity URI, similarity score)` pairs closest to
    /// `query_vector`.
    async fn search_knn(
        &self,
        query_vector: &[f32],
        k: usize,
    ) -> GraphRAGResult<Vec<(String, f32)>>;

    /// Return all `(entity URI, score)` pairs passing `threshold`.
    /// NOTE(review): whether `threshold` is a minimum similarity or a
    /// maximum distance is implementation-defined — not visible here.
    async fn search_threshold(
        &self,
        query_vector: &[f32],
        threshold: f32,
    ) -> GraphRAGResult<Vec<(String, f32)>>;
}
260
/// Abstraction over a text-embedding model.
#[async_trait]
pub trait EmbeddingModelTrait: Send + Sync {
    /// Embed a single text into a dense vector.
    async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;

    /// Embed several texts in one call — presumably returned in input
    /// order; confirm with implementors.
    async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
}
270
/// Abstraction over a SPARQL query engine.
#[async_trait]
pub trait SparqlEngineTrait: Send + Sync {
    /// Run a SELECT query; each row maps variable name -> bound value.
    async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;

    /// Run an ASK query, returning its boolean result.
    async fn ask(&self, query: &str) -> GraphRAGResult<bool>;

    /// Run a CONSTRUCT query, returning the constructed triples.
    async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
}
283
/// Abstraction over an LLM completion client.
#[async_trait]
pub trait LlmClientTrait: Send + Sync {
    /// Generate an answer to `query` grounded in `context`.
    async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;

    /// Streaming variant: `callback` is invoked per emitted chunk and the
    /// complete text is also returned. NOTE(review): chunk granularity
    /// (token vs. line) is implementation-defined.
    async fn generate_stream(
        &self,
        context: &str,
        query: &str,
        callback: Box<dyn Fn(&str) + Send + Sync>,
    ) -> GraphRAGResult<String>;
}
298
/// A cached query answer plus the bookkeeping needed for TTL checks.
#[derive(Debug, Clone)]
struct CachedResult {
    // Materialized answer served on cache hits.
    result: GraphRAGResult2,
    // When the entry was stored.
    timestamp: SystemTime,
    // Freshness window, computed adaptively at store time.
    ttl: Duration,
}
306
307impl CachedResult {
308 fn is_fresh(&self) -> bool {
310 self.timestamp
311 .elapsed()
312 .map(|elapsed| elapsed < self.ttl)
313 .unwrap_or(false)
314 }
315}
316
/// TTL policy for the query-result cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    // TTL applied at a moderate graph-update rate (seconds).
    pub base_ttl_seconds: u64,
    // Floor TTL under heavy graph churn (seconds).
    pub min_ttl_seconds: u64,
    // Ceiling TTL for a quiet graph (seconds).
    pub max_ttl_seconds: u64,
    // When false, `base_ttl_seconds` is always used (see `calculate_ttl`).
    pub adaptive: bool,
}
329
330impl Default for CacheConfig {
331 fn default() -> Self {
332 Self {
333 base_ttl_seconds: 3600,
334 min_ttl_seconds: 300,
335 max_ttl_seconds: 86400,
336 adaptive: true,
337 }
338 }
339}
340
/// Orchestrates hybrid GraphRAG: vector + keyword retrieval, SPARQL graph
/// expansion, optional community detection, and LLM answer generation,
/// fronted by an adaptive-TTL LRU result cache.
pub struct GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    // ANN index over entity embeddings.
    vec_index: Arc<V>,
    // Text -> embedding model.
    embedding_model: Arc<E>,
    // SPARQL backend used for keyword search and graph expansion.
    sparql_engine: Arc<S>,
    // LLM used to synthesize the final answer.
    llm_client: Arc<L>,
    // Engine-wide tuning knobs (top_k, weights, limits, ...).
    config: GraphRAGConfig,
    // Raw query string -> cached answer, LRU-evicted.
    cache: Arc<RwLock<lru::LruCache<String, CachedResult>>>,
    // TTL policy applied when storing into `cache`.
    cache_config: CacheConfig,
    // Cumulative count of graph mutations; drives the adaptive TTL.
    graph_update_count: Arc<AtomicU64>,
    // Pluggable detector; never set in this file (always None here).
    community_detector: Option<Arc<CommunityDetector>>,
}
368
369impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
370where
371 V: VectorIndexTrait,
372 E: EmbeddingModelTrait,
373 S: SparqlEngineTrait,
374 L: LlmClientTrait,
375{
376 pub fn new(
378 vec_index: Arc<V>,
379 embedding_model: Arc<E>,
380 sparql_engine: Arc<S>,
381 llm_client: Arc<L>,
382 config: GraphRAGConfig,
383 ) -> Self {
384 let cache_config = CacheConfig {
385 base_ttl_seconds: config.cache_config.base_ttl_seconds,
386 min_ttl_seconds: config.cache_config.min_ttl_seconds,
387 max_ttl_seconds: config.cache_config.max_ttl_seconds,
388 adaptive: config.cache_config.adaptive,
389 };
390
391 Self::with_cache_config(
392 vec_index,
393 embedding_model,
394 sparql_engine,
395 llm_client,
396 config,
397 cache_config,
398 )
399 }
400
401 pub fn with_cache_config(
403 vec_index: Arc<V>,
404 embedding_model: Arc<E>,
405 sparql_engine: Arc<S>,
406 llm_client: Arc<L>,
407 config: GraphRAGConfig,
408 cache_config: CacheConfig,
409 ) -> Self {
410 const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize = match std::num::NonZeroUsize::new(1000) {
411 Some(size) => size,
412 None => panic!("constant is non-zero"),
413 };
414
415 let cache_size = config
416 .cache_size
417 .and_then(std::num::NonZeroUsize::new)
418 .unwrap_or(DEFAULT_CACHE_SIZE);
419
420 Self {
421 vec_index,
422 embedding_model,
423 sparql_engine,
424 llm_client,
425 config,
426 cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
427 cache_config,
428 graph_update_count: Arc::new(AtomicU64::new(0)),
429 community_detector: None,
430 }
431 }
432
433 fn calculate_ttl(&self) -> Duration {
435 if !self.cache_config.adaptive {
436 return Duration::from_secs(self.cache_config.base_ttl_seconds);
437 }
438
439 let updates_per_hour = self.graph_update_count.load(Ordering::Relaxed) as f64;
440
441 let ttl_secs = if updates_per_hour > 100.0 {
443 self.cache_config.min_ttl_seconds } else if updates_per_hour > 10.0 {
445 self.cache_config.base_ttl_seconds / 2 } else {
447 self.cache_config.max_ttl_seconds };
449
450 Duration::from_secs(ttl_secs)
451 }
452
453 pub fn record_graph_update(&self) {
455 self.graph_update_count.fetch_add(1, Ordering::Relaxed);
456 }
457
458 pub async fn get_cache_stats(&self) -> (usize, usize) {
460 let cache = self.cache.read().await;
461 (cache.len(), cache.cap().get())
462 }
463
464 pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
466 let start_time = std::time::Instant::now();
467
468 {
470 let cache = self.cache.read().await;
471 if let Some(cached) = cache.peek(&query.to_string()) {
472 if cached.is_fresh() {
473 return Ok(cached.result.clone());
474 }
475 }
476 }
477
478 let query_vec = self.embedding_model.embed(query).await?;
480
481 let vector_results = self
483 .vec_index
484 .search_knn(&query_vec, self.config.top_k)
485 .await?;
486
487 let keyword_results = self.keyword_search(query).await?;
489
490 let seeds = self.fuse_results(&vector_results, &keyword_results)?;
492
493 let subgraph = self.expand_graph(&seeds).await?;
495
496 let communities = if self.config.enable_communities {
498 self.detect_communities(&subgraph)?
499 } else {
500 vec![]
501 };
502
503 let context = self.build_context(&subgraph, &communities, query)?;
505
506 let answer = self.llm_client.generate(&context, query).await?;
508
509 let confidence = self.calculate_confidence(&seeds, &subgraph);
511
512 let result = GraphRAGResult2 {
513 answer,
514 subgraph: subgraph.clone(),
515 seeds: seeds.clone(),
516 communities,
517 provenance: QueryProvenance {
518 timestamp: Utc::now(),
519 original_query: query.to_string(),
520 expanded_query: None,
521 seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
522 source_triples: subgraph,
523 community_sources: vec![],
524 processing_time_ms: start_time.elapsed().as_millis() as u64,
525 },
526 confidence,
527 };
528
529 let ttl = self.calculate_ttl();
531 let cached = CachedResult {
532 result: result.clone(),
533 timestamp: SystemTime::now(),
534 ttl,
535 };
536 self.cache.write().await.put(query.to_string(), cached);
537
538 Ok(result)
539 }
540
541 async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
543 let terms: Vec<&str> = query.split_whitespace().collect();
545 if terms.is_empty() {
546 return Ok(vec![]);
547 }
548
549 let filters: Vec<String> = terms
551 .iter()
552 .map(|term| format!("REGEX(STR(?label), \"{}\", \"i\")", term))
553 .collect();
554
555 let sparql = format!(
556 r#"
557 SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
558 ?entity rdfs:label|schema:name|dc:title ?label .
559 FILTER({})
560 }}
561 GROUP BY ?entity
562 ORDER BY DESC(?score)
563 LIMIT {}
564 "#,
565 filters.join(" || "),
566 self.config.top_k
567 );
568
569 let results = self.sparql_engine.select(&sparql).await?;
570
571 Ok(results
572 .into_iter()
573 .filter_map(|row| {
574 let entity = row.get("entity")?.clone();
575 let score = row.get("score")?.parse::<f32>().ok()?;
576 Some((entity, score))
577 })
578 .collect())
579 }
580
581 fn fuse_results(
583 &self,
584 vector_results: &[(String, f32)],
585 keyword_results: &[(String, f32)],
586 ) -> GraphRAGResult<Vec<ScoredEntity>> {
587 let k = 60.0; let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();
590
591 for (rank, (uri, score)) in vector_results.iter().enumerate() {
593 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
594 scores.insert(
595 uri.clone(),
596 (
597 rrf_score * self.config.vector_weight as f64,
598 ScoreSource::Vector,
599 ),
600 );
601 }
602
603 for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
605 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
606 let keyword_contribution = rrf_score * self.config.keyword_weight as f64;
607
608 match scores.get(uri).cloned() {
609 Some((existing_score, _)) => {
610 let new_score = existing_score + keyword_contribution;
611 scores.insert(uri.clone(), (new_score, ScoreSource::Fused));
612 }
613 None => {
614 scores.insert(uri.clone(), (keyword_contribution, ScoreSource::Keyword));
615 }
616 }
617 }
618
619 let mut entities: Vec<ScoredEntity> = scores
621 .into_iter()
622 .map(|(uri, (score, source))| ScoredEntity {
623 uri,
624 score,
625 source,
626 metadata: HashMap::new(),
627 })
628 .collect();
629
630 entities.sort_by(|a, b| {
631 b.score
632 .partial_cmp(&a.score)
633 .unwrap_or(std::cmp::Ordering::Equal)
634 });
635 entities.truncate(self.config.max_seeds);
636
637 Ok(entities)
638 }
639
    /// Expand the seed entities into a bounded subgraph via one CONSTRUCT
    /// query: outgoing edges of each seed, incoming edges to each seed, and
    /// (for multi-hop) edges of entities reachable within
    /// `config.expansion_hops` hops.
    ///
    /// NOTE(review): the multi-hop pattern uses the quantified property
    /// path `(:|!:){{1,n}}`; `{m,n}` quantifiers were dropped from the
    /// final SPARQL 1.1 REC, so this relies on an engine extension —
    /// confirm the backend supports it.
    ///
    /// NOTE(review): seed URIs are interpolated raw into `VALUES`; they
    /// come from the index/keyword search, but a URI containing `>` would
    /// corrupt the query — consider validating upstream.
    async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
        // No seeds: nothing to expand, skip the round-trip.
        if seeds.is_empty() {
            return Ok(vec![]);
        }

        // Render seeds as `<uri>` terms for the VALUES clause.
        let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
        let values = seed_uris.join(" ");

        let hops = self.config.expansion_hops;
        // For a single hop the neighbor pattern degenerates to a plain
        // triple pattern; otherwise use a quantified path over any
        // predicate ((:|!:) matches every IRI predicate).
        let path_pattern = if hops == 1 {
            "?seed ?p ?neighbor".to_string()
        } else {
            format!("?seed (:|!:){{1,{}}} ?neighbor", hops)
        };

        let sparql = format!(
            r#"
            CONSTRUCT {{
                ?seed ?p ?o .
                ?s ?p2 ?seed .
                ?neighbor ?p3 ?o2 .
            }}
            WHERE {{
                VALUES ?seed {{ {} }}
                {{
                    ?seed ?p ?o .
                }} UNION {{
                    ?s ?p2 ?seed .
                }} UNION {{
                    {}
                    ?neighbor ?p3 ?o2 .
                }}
            }}
            LIMIT {}
            "#,
            values, path_pattern, self.config.max_subgraph_size
        );

        self.sparql_engine.construct(&sparql).await
    }
682
683 fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
685 use petgraph::graph::UnGraph;
686
687 if subgraph.is_empty() {
688 return Ok(vec![]);
689 }
690
691 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
693 let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
694
695 for triple in subgraph {
696 let subj_idx = *node_indices
697 .entry(triple.subject.clone())
698 .or_insert_with(|| graph.add_node(triple.subject.clone()));
699 let obj_idx = *node_indices
700 .entry(triple.object.clone())
701 .or_insert_with(|| graph.add_node(triple.object.clone()));
702
703 if subj_idx != obj_idx {
704 graph.add_edge(subj_idx, obj_idx, ());
705 }
706 }
707
708 let components = petgraph::algo::kosaraju_scc(&graph);
711
712 let communities: Vec<CommunitySummary> = components
713 .into_iter()
714 .enumerate()
715 .filter(|(_, component)| component.len() >= 2)
716 .map(|(idx, component)| {
717 let entities: Vec<String> = component
718 .iter()
719 .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
720 .collect();
721
722 let representative_triples: Vec<Triple> = subgraph
723 .iter()
724 .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
725 .take(5)
726 .cloned()
727 .collect();
728
729 CommunitySummary {
730 id: format!("community_{}", idx),
731 summary: format!("Community with {} entities", entities.len()),
732 entities,
733 representative_triples,
734 level: 0,
735 modularity: 0.0,
736 }
737 })
738 .collect();
739
740 Ok(communities)
741 }
742
743 fn build_context(
745 &self,
746 subgraph: &[Triple],
747 communities: &[CommunitySummary],
748 _query: &str,
749 ) -> GraphRAGResult<String> {
750 let mut context = String::new();
751
752 if !communities.is_empty() {
754 context.push_str("## Community Context\n\n");
755 for community in communities {
756 context.push_str(&format!("### {}\n", community.id));
757 context.push_str(&format!("{}\n", community.summary));
758 context.push_str(&format!("Entities: {}\n\n", community.entities.join(", ")));
759 }
760 }
761
762 context.push_str("## Knowledge Graph Facts\n\n");
764 for triple in subgraph.iter().take(self.config.max_context_triples) {
765 context.push_str(&format!(
766 "- {} → {} → {}\n",
767 triple.subject, triple.predicate, triple.object
768 ));
769 }
770
771 Ok(context)
772 }
773
774 fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
776 if seeds.is_empty() {
777 return 0.0;
778 }
779
780 let avg_seed_score: f64 = seeds.iter().map(|s| s.score).sum::<f64>() / seeds.len() as f64;
782
783 let seed_uris: std::collections::HashSet<_> = seeds.iter().map(|s| &s.uri).collect();
785 let covered: usize = subgraph
786 .iter()
787 .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
788 .count();
789 let coverage = if subgraph.is_empty() {
790 0.0
791 } else {
792 (covered as f64 / subgraph.len() as f64).min(1.0)
793 };
794
795 (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
797 }
798}
799
#[cfg(test)]
mod tests {
    use super::*;

    /// `Triple::new` stores the three components verbatim.
    #[test]
    fn test_triple_creation() {
        let t = Triple::new(
            "http://example.org/s",
            "http://example.org/p",
            "http://example.org/o",
        );
        assert_eq!(t.subject, "http://example.org/s");
        assert_eq!(t.predicate, "http://example.org/p");
        assert_eq!(t.object, "http://example.org/o");
    }

    /// `ScoredEntity` fields round-trip through construction.
    #[test]
    fn test_scored_entity() {
        let scored = ScoredEntity {
            uri: String::from("http://example.org/entity"),
            score: 0.85,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        };
        assert_eq!(scored.score, 0.85);
        assert_eq!(scored.source, ScoreSource::Fused);
    }
}