pub mod cache;
pub mod config;
pub mod distributed;
pub mod embeddings;
pub mod federation;
pub mod fusion;
pub mod generation;
pub mod graph;
pub mod graph_summarization;
pub mod query;
pub mod reasoning;
pub mod retrieval;
pub mod sparql;
pub mod streaming;
pub mod temporal;
pub mod transe_model;
pub mod entity_linking;
pub mod community_detector;
pub mod path_ranker;
pub mod entity_linker;
pub mod graph_embedder;
pub mod graph_partitioner;
pub mod triple_extractor;
pub mod knowledge_fusion;
pub mod context_builder;
pub mod path_finder;
pub mod summarizer;
pub mod entity_classifier;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::RwLock;
pub use cache::query_cache::{CacheEntry, CacheStats, QueryCache, QueryCacheConfig};
pub use config::{CacheConfiguration, GraphRAGConfig};
pub use embeddings::node2vec::{
Node2VecConfig, Node2VecEmbedder, Node2VecEmbeddings, Node2VecWalkConfig,
};
pub use graph::community::{CommunityAlgorithm, CommunityConfig, CommunityDetector};
pub use graph::embeddings::{CommunityAwareEmbeddings, CommunityStructure, EmbeddingConfig};
pub use graph::traversal::GraphTraversal;
pub use query::planner::QueryPlanner;
pub use retrieval::fusion::FusionStrategy;
/// Unified error type for every GraphRAG subsystem.
///
/// Each variant wraps a human-readable message from the failing stage so
/// callers can distinguish, e.g., a vector-index failure from an LLM failure.
#[derive(Error, Debug)]
pub enum GraphRAGError {
#[error("Vector search failed: {0}")]
VectorSearchError(String),
#[error("Graph traversal failed: {0}")]
GraphTraversalError(String),
#[error("Community detection failed: {0}")]
CommunityDetectionError(String),
#[error("LLM generation failed: {0}")]
GenerationError(String),
#[error("Embedding failed: {0}")]
EmbeddingError(String),
#[error("SPARQL query failed: {0}")]
SparqlError(String),
#[error("Configuration error: {0}")]
ConfigError(String),
#[error("Internal error: {0}")]
InternalError(String),
}
/// Convenience alias used by all fallible GraphRAG APIs in this module.
pub type GraphRAGResult<T> = Result<T, GraphRAGError>;
/// A single RDF-style statement: subject–predicate–object, stored as plain
/// strings (typically IRIs or literals; no validation is performed here).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Triple {
/// Statement subject (e.g. an entity IRI).
pub subject: String,
/// Property/relation connecting subject and object.
pub predicate: String,
/// Statement object (entity IRI or literal value).
pub object: String,
}
impl Triple {
    /// Builds a triple, converting each component into an owned `String`.
    pub fn new(
        subject: impl Into<String>,
        predicate: impl Into<String>,
        object: impl Into<String>,
    ) -> Self {
        let subject = subject.into();
        let predicate = predicate.into();
        let object = object.into();
        Self {
            subject,
            predicate,
            object,
        }
    }
}
/// An entity produced by retrieval, with its relevance score and the
/// retrieval channel that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredEntity {
/// Entity identifier (URI).
pub uri: String,
/// Relevance score; higher is better (RRF-fused, not normalised).
pub score: f64,
/// Which retrieval channel(s) produced this entity.
pub source: ScoreSource,
/// Free-form auxiliary attributes.
pub metadata: HashMap<String, String>,
}
/// Provenance of a retrieval score.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScoreSource {
/// Dense vector (embedding) search only.
Vector,
/// Keyword/SPARQL search only.
Keyword,
/// Found by both channels; contributions were combined.
Fused,
/// Graph-based scoring (not produced within this module).
Graph,
}
/// A detected graph community plus a short textual summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunitySummary {
/// Stable identifier, e.g. "community_3".
pub id: String,
/// Human-readable description of the community.
pub summary: String,
/// URIs of the member entities.
pub entities: Vec<String>,
/// A few triples illustrating the community's content.
pub representative_triples: Vec<Triple>,
/// Hierarchy level (always 0 in this module's detector).
pub level: u32,
/// Modularity of the partition (0.0 when not computed).
pub modularity: f64,
}
/// Audit trail describing how an answer was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProvenance {
/// When the query was processed (UTC).
pub timestamp: DateTime<Utc>,
/// The user's query as received.
pub original_query: String,
/// Rewritten/expanded query, if query expansion was applied.
pub expanded_query: Option<String>,
/// Seed entity URIs that anchored graph expansion.
pub seed_entities: Vec<String>,
/// Triples supporting the generated answer.
pub source_triples: Vec<Triple>,
/// IDs of communities whose summaries were used.
pub community_sources: Vec<String>,
/// End-to-end processing time in milliseconds.
pub processing_time_ms: u64,
}
/// Complete output of one GraphRAG query.
///
/// NOTE(review): the `2` suffix avoids a clash with the `GraphRAGResult<T>`
/// alias above; consider a clearer name (e.g. `GraphRAGResponse`) in a
/// breaking release.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGResult2 {
/// The LLM-generated answer text.
pub answer: String,
/// The expanded subgraph the answer was grounded in.
pub subgraph: Vec<Triple>,
/// Fused seed entities, best first.
pub seeds: Vec<ScoredEntity>,
/// Community summaries (empty when community detection is disabled).
pub communities: Vec<CommunitySummary>,
/// How the answer was derived.
pub provenance: QueryProvenance,
/// Heuristic confidence in [0, 1].
pub confidence: f64,
}
/// Abstraction over a dense vector index used for semantic retrieval.
#[async_trait]
pub trait VectorIndexTrait: Send + Sync {
/// Returns the `k` nearest neighbours of `query_vector` as
/// `(entity id, score)` pairs.
async fn search_knn(
&self,
query_vector: &[f32],
k: usize,
) -> GraphRAGResult<Vec<(String, f32)>>;
/// Returns all entries whose score passes `threshold`.
async fn search_threshold(
&self,
query_vector: &[f32],
threshold: f32,
) -> GraphRAGResult<Vec<(String, f32)>>;
}
/// Abstraction over a text-embedding model.
#[async_trait]
pub trait EmbeddingModelTrait: Send + Sync {
/// Embeds a single text into a dense vector.
async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;
/// Embeds a batch of texts; implementations should return one vector
/// per input, in order.
async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
}
/// Abstraction over a SPARQL query engine.
#[async_trait]
pub trait SparqlEngineTrait: Send + Sync {
/// Runs a SELECT query; each row maps variable name to its bound value.
async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;
/// Runs an ASK query, returning the boolean result.
async fn ask(&self, query: &str) -> GraphRAGResult<bool>;
/// Runs a CONSTRUCT query, returning the produced triples.
async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
}
/// Abstraction over an LLM used for answer generation.
#[async_trait]
pub trait LlmClientTrait: Send + Sync {
/// Generates an answer to `query` grounded in `context`.
async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;
/// Streaming variant; `callback` is presumably invoked once per generated
/// chunk (implementation-defined), and the complete text is also returned.
async fn generate_stream(
&self,
context: &str,
query: &str,
callback: Box<dyn Fn(&str) + Send + Sync>,
) -> GraphRAGResult<String>;
}
/// A cached pipeline result together with its expiry bookkeeping.
#[derive(Debug, Clone)]
struct CachedResult {
// The full result returned to the caller on a cache hit.
result: GraphRAGResult2,
// When the entry was stored.
timestamp: SystemTime,
// How long the entry stays valid (chosen by the engine's TTL policy).
ttl: Duration,
}
impl CachedResult {
    /// Returns `true` while the entry is within its TTL.
    ///
    /// A clock error from `SystemTime::elapsed` (e.g. the system clock moved
    /// backwards) is treated as stale, forcing a recompute.
    fn is_fresh(&self) -> bool {
        match self.timestamp.elapsed() {
            Ok(age) => age < self.ttl,
            Err(_) => false,
        }
    }
}
/// TTL policy for the query-result cache (all durations in seconds).
#[derive(Debug, Clone)]
pub struct CacheConfig {
/// TTL at a moderate update rate, and always when `adaptive` is off.
pub base_ttl_seconds: u64,
/// Shortest TTL, used when the graph is changing rapidly.
pub min_ttl_seconds: u64,
/// Longest TTL, used when the graph is nearly static.
pub max_ttl_seconds: u64,
/// When true, the TTL adapts to the observed graph-update count.
pub adaptive: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
base_ttl_seconds: 3600,
min_ttl_seconds: 300,
max_ttl_seconds: 86400,
adaptive: true,
}
}
}
/// Orchestrates the GraphRAG pipeline over pluggable backends: a vector
/// index `V`, an embedding model `E`, a SPARQL engine `S`, and an LLM client
/// `L`. Query results are memoised in an LRU cache with adaptive TTLs.
pub struct GraphRAGEngine<V, E, S, L>
where
V: VectorIndexTrait,
E: EmbeddingModelTrait,
S: SparqlEngineTrait,
L: LlmClientTrait,
{
vec_index: Arc<V>,
embedding_model: Arc<E>,
sparql_engine: Arc<S>,
llm_client: Arc<L>,
config: GraphRAGConfig,
// Query text -> cached pipeline result, bounded by `config.cache_size`.
cache: Arc<RwLock<lru::LruCache<String, CachedResult>>>,
cache_config: CacheConfig,
// Running count of graph mutations; drives adaptive TTL selection.
graph_update_count: Arc<AtomicU64>,
// NOTE(review): set to None by both constructors and never populated in
// this file; confirm whether an injection point is still planned.
community_detector: Option<Arc<CommunityDetector>>,
}
impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
where
V: VectorIndexTrait,
E: EmbeddingModelTrait,
S: SparqlEngineTrait,
L: LlmClientTrait,
{
/// Builds an engine, deriving the cache TTL policy from
/// `config.cache_config` and delegating to [`Self::with_cache_config`].
pub fn new(
    vec_index: Arc<V>,
    embedding_model: Arc<E>,
    sparql_engine: Arc<S>,
    llm_client: Arc<L>,
    config: GraphRAGConfig,
) -> Self {
    // Copy the TTL knobs out of the global configuration.
    let cc = &config.cache_config;
    let cache_config = CacheConfig {
        base_ttl_seconds: cc.base_ttl_seconds,
        min_ttl_seconds: cc.min_ttl_seconds,
        max_ttl_seconds: cc.max_ttl_seconds,
        adaptive: cc.adaptive,
    };
    Self::with_cache_config(
        vec_index,
        embedding_model,
        sparql_engine,
        llm_client,
        config,
        cache_config,
    )
}
/// Builds an engine with an explicit cache TTL policy.
///
/// The LRU capacity comes from `config.cache_size`; an absent or zero value
/// falls back to 1000 entries.
pub fn with_cache_config(
    vec_index: Arc<V>,
    embedding_model: Arc<E>,
    sparql_engine: Arc<S>,
    llm_client: Arc<L>,
    config: GraphRAGConfig,
    cache_config: CacheConfig,
) -> Self {
    // The `match` keeps this evaluable in a `const` context on toolchains
    // where `Option::unwrap` is not yet a const fn.
    const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize =
        match std::num::NonZeroUsize::new(1000) {
            Some(size) => size,
            None => panic!("constant is non-zero"),
        };
    let cache_size = match config.cache_size.and_then(std::num::NonZeroUsize::new) {
        Some(size) => size,
        None => DEFAULT_CACHE_SIZE,
    };
    Self {
        cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
        graph_update_count: Arc::new(AtomicU64::new(0)),
        community_detector: None,
        vec_index,
        embedding_model,
        sparql_engine,
        llm_client,
        config,
        cache_config,
    }
}
/// Picks a TTL for newly cached results: fixed `base_ttl_seconds` when
/// adaptive caching is off; otherwise shorter TTLs the more the graph has
/// been updated.
fn calculate_ttl(&self) -> Duration {
    if !self.cache_config.adaptive {
        return Duration::from_secs(self.cache_config.base_ttl_seconds);
    }
    // NOTE(review): this counter is a running total that is never reset, not
    // a per-hour rate — once it passes 100 the cache stays at the minimum
    // TTL forever. Confirm whether a periodic reset is intended.
    let update_count = self.graph_update_count.load(Ordering::Relaxed) as f64;
    let ttl_secs = if update_count > 100.0 {
        self.cache_config.min_ttl_seconds
    } else if update_count > 10.0 {
        self.cache_config.base_ttl_seconds / 2
    } else {
        self.cache_config.max_ttl_seconds
    };
    Duration::from_secs(ttl_secs)
}
/// Records that the underlying graph changed; feeds the adaptive TTL
/// heuristic in `calculate_ttl`.
pub fn record_graph_update(&self) {
self.graph_update_count.fetch_add(1, Ordering::Relaxed);
}
/// Returns `(current entry count, capacity)` of the result cache.
pub async fn get_cache_stats(&self) -> (usize, usize) {
let cache = self.cache.read().await;
(cache.len(), cache.cap().get())
}
/// Executes the full GraphRAG pipeline for `query`: cache lookup →
/// embedding → hybrid (vector + keyword) retrieval → RRF fusion → subgraph
/// expansion → optional community detection → context assembly → LLM
/// generation. The final result is cached with an adaptive TTL.
///
/// # Errors
/// Propagates failures from the embedding model, vector index, SPARQL
/// engine, or LLM client.
pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
    let start_time = std::time::Instant::now();
    // Fast path: serve a still-fresh cached result. `peek` avoids needing a
    // write lock (no LRU-order update), and borrowing the key as `&str`
    // avoids allocating a `String` on every lookup.
    {
        let cache = self.cache.read().await;
        if let Some(cached) = cache.peek(query) {
            if cached.is_fresh() {
                return Ok(cached.result.clone());
            }
        }
    }
    // Hybrid retrieval: dense (embedding kNN) + sparse (keyword SPARQL).
    let query_vec = self.embedding_model.embed(query).await?;
    let vector_results = self
        .vec_index
        .search_knn(&query_vec, self.config.top_k)
        .await?;
    let keyword_results = self.keyword_search(query).await?;
    let seeds = self.fuse_results(&vector_results, &keyword_results)?;
    // Expand the fused seeds into a supporting subgraph.
    let subgraph = self.expand_graph(&seeds).await?;
    let communities = if self.config.enable_communities {
        self.detect_communities(&subgraph)?
    } else {
        vec![]
    };
    let context = self.build_context(&subgraph, &communities, query)?;
    let answer = self.llm_client.generate(&context, query).await?;
    let confidence = self.calculate_confidence(&seeds, &subgraph);
    let result = GraphRAGResult2 {
        answer,
        subgraph: subgraph.clone(),
        seeds: seeds.clone(),
        communities,
        provenance: QueryProvenance {
            timestamp: Utc::now(),
            original_query: query.to_string(),
            expanded_query: None,
            seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
            source_triples: subgraph,
            community_sources: vec![],
            processing_time_ms: start_time.elapsed().as_millis() as u64,
        },
        confidence,
    };
    // Cache under a TTL that shrinks as the graph churns faster.
    let ttl = self.calculate_ttl();
    let cached = CachedResult {
        result: result.clone(),
        timestamp: SystemTime::now(),
        ttl,
    };
    self.cache.write().await.put(query.to_string(), cached);
    Ok(result)
}
/// Sparse retrieval: ranks entities by how many of the query's whitespace
/// tokens case-insensitively REGEX-match one of their labels (rdfs:label,
/// schema:name, or dc:title), via a SPARQL COUNT/GROUP BY query.
///
/// # Errors
/// Propagates SPARQL engine failures. Result rows with a missing or
/// non-numeric `score` binding are silently skipped.
async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
    let terms: Vec<&str> = query.split_whitespace().collect();
    if terms.is_empty() {
        return Ok(vec![]);
    }
    let filters: Vec<String> = terms
        .iter()
        .map(|term| {
            // Escape backslashes and double quotes so a term cannot break
            // out of the string literal (SPARQL injection hardening). The
            // term is still interpreted by REGEX as a regular expression.
            let escaped = term.replace('\\', "\\\\").replace('"', "\\\"");
            format!("REGEX(STR(?label), \"{}\", \"i\")", escaped)
        })
        .collect();
    let sparql = format!(
        r#"
SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
?entity rdfs:label|schema:name|dc:title ?label .
FILTER({})
}}
GROUP BY ?entity
ORDER BY DESC(?score)
LIMIT {}
"#,
        filters.join(" || "),
        self.config.top_k
    );
    let results = self.sparql_engine.select(&sparql).await?;
    Ok(results
        .into_iter()
        .filter_map(|row| {
            let entity = row.get("entity")?.clone();
            let score = row.get("score")?.parse::<f32>().ok()?;
            Some((entity, score))
        })
        .collect())
}
/// Fuses dense and sparse rankings with Reciprocal Rank Fusion (RRF): each
/// list contributes `weight * 1 / (k + rank + 1)` with k = 60, and entities
/// found by both channels are marked `ScoreSource::Fused`. Returns at most
/// `config.max_seeds` entities, best first.
fn fuse_results(
    &self,
    vector_results: &[(String, f32)],
    keyword_results: &[(String, f32)],
) -> GraphRAGResult<Vec<ScoredEntity>> {
    // Standard RRF damping constant.
    let k = 60.0;
    let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();
    for (rank, (uri, _score)) in vector_results.iter().enumerate() {
        let rrf_score = 1.0 / (k + rank as f64 + 1.0);
        scores.insert(
            uri.clone(),
            (
                rrf_score * self.config.vector_weight as f64,
                ScoreSource::Vector,
            ),
        );
    }
    for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
        let rrf_score = 1.0 / (k + rank as f64 + 1.0);
        let keyword_contribution = rrf_score * self.config.keyword_weight as f64;
        // Entry API: one hash lookup instead of a get/clone/insert round trip.
        scores
            .entry(uri.clone())
            .and_modify(|(score, source)| {
                *score += keyword_contribution;
                *source = ScoreSource::Fused;
            })
            .or_insert((keyword_contribution, ScoreSource::Keyword));
    }
    let mut entities: Vec<ScoredEntity> = scores
        .into_iter()
        .map(|(uri, (score, source))| ScoredEntity {
            uri,
            score,
            source,
            metadata: HashMap::new(),
        })
        .collect();
    // Highest score first; incomparable (NaN) scores compare as equal.
    entities.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    entities.truncate(self.config.max_seeds);
    Ok(entities)
}
/// Expands the fused seed entities into a bounded RDF subgraph with a single
/// SPARQL CONSTRUCT: outgoing triples, incoming triples, and (when
/// `expansion_hops` > 1) triples of nodes reached through a property path.
///
/// NOTE(review): seed URIs are interpolated into `<...>` unescaped — a URI
/// containing `>` or whitespace would corrupt the query; confirm URIs are
/// validated upstream.
async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
if seeds.is_empty() {
return Ok(vec![]);
}
// Render seeds as IRI references for the VALUES clause.
let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
let values = seed_uris.join(" ");
let hops = self.config.expansion_hops;
// `(:|!:){1,n}` is presumably the "any predicate, 1..n hops" property-path
// idiom (`:` or a negated property set) — confirm the engine supports it.
let path_pattern = if hops == 1 {
"?seed ?p ?neighbor".to_string()
} else {
format!("?seed (:|!:){{1,{}}} ?neighbor", hops)
};
let sparql = format!(
r#"
CONSTRUCT {{
?seed ?p ?o .
?s ?p2 ?seed .
?neighbor ?p3 ?o2 .
}}
WHERE {{
VALUES ?seed {{ {} }}
{{
?seed ?p ?o .
}} UNION {{
?s ?p2 ?seed .
}} UNION {{
{}
?neighbor ?p3 ?o2 .
}}
}}
LIMIT {}
"#,
values, path_pattern, self.config.max_subgraph_size
);
self.sparql_engine.construct(&sparql).await
}
/// Groups subgraph entities into communities: builds an undirected entity
/// graph (one node per distinct subject/object, one edge per triple) and
/// takes its connected components, keeping those with at least two members.
/// Each community gets up to five representative triples.
fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
    use petgraph::graph::UnGraph;
    if subgraph.is_empty() {
        return Ok(vec![]);
    }
    let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
    let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
    for triple in subgraph {
        let subj_idx = *node_indices
            .entry(triple.subject.clone())
            .or_insert_with(|| graph.add_node(triple.subject.clone()));
        let obj_idx = *node_indices
            .entry(triple.object.clone())
            .or_insert_with(|| graph.add_node(triple.object.clone()));
        // Skip self-loops; they add no connectivity information.
        if subj_idx != obj_idx {
            graph.add_edge(subj_idx, obj_idx, ());
        }
    }
    // On an undirected graph, strongly connected components coincide with
    // connected components, so kosaraju_scc yields the member lists directly.
    let components = petgraph::algo::kosaraju_scc(&graph);
    let communities: Vec<CommunitySummary> = components
        .into_iter()
        .enumerate()
        .filter(|(_, component)| component.len() >= 2)
        .map(|(idx, component)| {
            let entities: Vec<String> = component
                .iter()
                .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
                .collect();
            // Hash-set membership avoids an O(|entities|) linear scan per
            // triple when picking representatives below.
            let entity_set: std::collections::HashSet<&str> =
                entities.iter().map(String::as_str).collect();
            let representative_triples: Vec<Triple> = subgraph
                .iter()
                .filter(|t| {
                    entity_set.contains(t.subject.as_str())
                        || entity_set.contains(t.object.as_str())
                })
                .take(5)
                .cloned()
                .collect();
            CommunitySummary {
                id: format!("community_{}", idx),
                summary: format!("Community with {} entities", entities.len()),
                entities,
                representative_triples,
                level: 0,
                modularity: 0.0,
            }
        })
        .collect();
    Ok(communities)
}
/// Renders the retrieved knowledge into a markdown context string for the
/// LLM: community summaries first (when present), then up to
/// `config.max_context_triples` facts as "subject → predicate → object".
/// `_query` is currently unused (reserved for query-aware context shaping).
fn build_context(
    &self,
    subgraph: &[Triple],
    communities: &[CommunitySummary],
    _query: &str,
) -> GraphRAGResult<String> {
    // `fmt::Write` formats straight into the buffer instead of allocating an
    // intermediate String per line via `format!`.
    use std::fmt::Write as _;
    let mut context = String::new();
    if !communities.is_empty() {
        context.push_str("## Community Context\n\n");
        for community in communities {
            // Writing into a String cannot fail; the fmt::Result is ignored.
            let _ = writeln!(context, "### {}", community.id);
            let _ = writeln!(context, "{}", community.summary);
            let _ = writeln!(context, "Entities: {}\n", community.entities.join(", "));
        }
    }
    context.push_str("## Knowledge Graph Facts\n\n");
    for triple in subgraph.iter().take(self.config.max_context_triples) {
        let _ = writeln!(
            context,
            "- {} → {} → {}",
            triple.subject, triple.predicate, triple.object
        );
    }
    Ok(context)
}
/// Heuristic confidence in [0, 1]: a 60/40 blend of the mean seed score and
/// the fraction of subgraph triples touching a seed entity. Zero when there
/// are no seeds.
fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
    if seeds.is_empty() {
        return 0.0;
    }
    let total: f64 = seeds.iter().map(|s| s.score).sum();
    let avg_seed_score = total / seeds.len() as f64;
    let seed_uris: std::collections::HashSet<&String> =
        seeds.iter().map(|s| &s.uri).collect();
    let coverage = if subgraph.is_empty() {
        0.0
    } else {
        let covered = subgraph
            .iter()
            .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
            .count();
        (covered as f64 / subgraph.len() as f64).min(1.0)
    };
    (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Triple::new` accepts anything `Into<String>` and stores the three
    /// components verbatim.
    #[test]
    fn test_triple_creation() {
        let (s, p, o) = (
            "http://example.org/s",
            "http://example.org/p",
            "http://example.org/o",
        );
        let triple = Triple::new(s, p, o);
        assert_eq!(triple.subject, s);
        assert_eq!(triple.predicate, p);
        assert_eq!(triple.object, o);
    }

    /// A `ScoredEntity` keeps the score and provenance it was built with.
    #[test]
    fn test_scored_entity() {
        let entity = ScoredEntity {
            uri: String::from("http://example.org/entity"),
            score: 0.85,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        };
        assert_eq!(entity.score, 0.85);
        assert_eq!(entity.source, ScoreSource::Fused);
    }
}