use crate::core::context_update::RelationType;
use crate::graph::entity_graph::SimpleEntityGraph;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Tunable settings for [`GraphRagEnricher`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRagConfig {
/// Master switch: when `false`, `should_enrich` always returns `false`.
pub enabled: bool,
/// Traversal depth that `analyze_query` passes to `find_related_with_depth`.
pub max_depth: usize,
/// Upper bound on the relations kept per start entity, and on the final
/// relation list returned by `enrich_result`.
pub max_relations_per_entity: usize,
/// Minimum `entity_count()` a graph needs before `should_enrich` returns `true`.
pub min_graph_size: usize,
/// NOTE(review): not read by any enricher method in this file; presumably a
/// time budget in milliseconds enforced by a caller — confirm.
pub timeout_ms: u64,
/// Number of seconds a cached relation list stays valid.
pub cache_ttl_secs: u64,
/// Relation types ordered from most to least important. A type's index is
/// its priority rank; types missing from the list rank last (`usize::MAX`).
pub relation_priority: Vec<RelationType>,
}
impl Default for GraphRagConfig {
    /// Returns the stock configuration: enrichment enabled, depth 2, at most
    /// 10 relations per entity, graphs of at least 5 entities, a 100 ms
    /// timeout value, and a 300 s cache TTL.
    fn default() -> Self {
        // Position in this list is the priority rank (index 0 ranks highest).
        let relation_priority = Vec::from([
            RelationType::CausedBy,
            RelationType::DependsOn,
            RelationType::Implements,
            RelationType::LeadsTo,
            RelationType::Solves,
            RelationType::RelatedTo,
        ]);

        Self {
            relation_priority,
            cache_ttl_secs: 300,
            timeout_ms: 100,
            min_graph_size: 5,
            max_relations_per_entity: 10,
            max_depth: 2,
            enabled: true,
        }
    }
}
/// A cached traversal result together with the moment it was stored.
#[derive(Debug, Clone)]
struct CachedRelations {
/// The relation list as it was returned to the caller.
relations: Vec<RelationInfo>,
/// Insertion time; compared against `cache_ttl_secs` to decide freshness.
cached_at: Instant,
}
/// One relation discovered while walking the graph from a start entity.
#[derive(Debug, Clone)]
pub struct RelationInfo {
/// Name of the entity the relation points to.
pub target_entity: String,
/// Kind of relation that led to `target_entity`.
pub relation_type: RelationType,
/// Context string returned by the graph alongside the relation.
pub context: String,
/// Traversal depth at which the target was recorded; direct relations of
/// the start entity have depth 1.
pub depth: usize,
}
/// Graph-derived context attached to a single result by `enrich_result`.
#[derive(Debug, Clone)]
pub struct GraphEnrichment {
/// Relations selected for the result, capped at `max_relations_per_entity`.
pub related_entities: Vec<RelationInfo>,
/// Paths (as entity-name sequences) between result and query entities;
/// only paths with 3 to 5 nodes are kept.
pub paths_to_query: Vec<Vec<String>>,
/// Human-readable rendering of the above; empty when there is nothing to show.
pub context_string: String,
}
/// Query-level overview produced by `analyze_query`.
#[derive(Debug, Clone, Default)]
pub struct GlobalGraphInsights {
/// For each query entity with at least one relation: the entity name and
/// up to five related entity names.
pub query_entity_map: Vec<(String, Vec<String>)>,
/// NOTE(review): `analyze_query` in this file never fills this field;
/// presumably populated by a caller — confirm.
pub structural_insights: Vec<String>,
/// "[System Knowledge Map]" text block; empty when `query_entity_map` is empty.
pub formatted: String,
}
/// Enriches retrieval results with relations and paths taken from a
/// `SimpleEntityGraph`, caching relation lookups for a configurable TTL.
pub struct GraphRagEnricher {
// Settings controlling depth, limits, TTL and relation priorities.
config: GraphRagConfig,
// Relation lookups cached by `find_related_with_depth`; entries carry
// their insertion time so they can expire.
relation_cache: Arc<DashMap<String, CachedRelations>>,
}
impl GraphRagEnricher {
/// Creates an enricher that uses `config` and starts with an empty
/// relation cache.
pub fn new(config: GraphRagConfig) -> Self {
    let relation_cache = Arc::new(DashMap::new());
    Self {
        config,
        relation_cache,
    }
}
/// Creates an enricher configured with `GraphRagConfig::default()`.
pub fn with_defaults() -> Self {
    let config = GraphRagConfig::default();
    Self::new(config)
}
/// Reports whether enrichment should run for `graph`.
///
/// Returns `true` only when enrichment is enabled and the graph holds at
/// least `min_graph_size` entities. The graph is not queried when
/// enrichment is disabled.
pub fn should_enrich(&self, graph: &SimpleEntityGraph) -> bool {
    self.config.enabled && graph.entity_count() >= self.config.min_graph_size
}
/// Returns the graph entities mentioned in `text`, in order of first
/// appearance and without duplicates.
///
/// The text is split on whitespace; each token has leading and trailing
/// characters other than alphanumerics, `_` and `-` stripped, and tokens
/// shorter than two bytes are ignored. The lowercased token is looked up
/// first; the token in its original case is tried only if that misses.
pub fn extract_graph_entities(&self, graph: &SimpleEntityGraph, text: &str) -> Vec<String> {
    let is_edge_noise = |c: char| !(c.is_alphanumeric() || c == '_' || c == '-');

    let mut already_found: HashSet<String> = HashSet::new();
    let mut matches: Vec<String> = Vec::new();

    for token in text.split_whitespace() {
        let candidate = token.trim_matches(is_edge_noise);
        if candidate.len() < 2 {
            continue;
        }

        let lowered = candidate.to_lowercase();
        let hit = if graph.has_entity(&lowered) {
            Some(lowered)
        } else if graph.has_entity(candidate) {
            Some(candidate.to_string())
        } else {
            None
        };

        // Keep only the first occurrence of every entity name.
        if let Some(name) = hit {
            if already_found.insert(name.clone()) {
                matches.push(name);
            }
        }
    }

    matches
}
/// Collects the entities reachable from `start_entity` within `max_depth`
/// hops, ordered by relation priority and then by depth, and truncated to
/// `max_relations_per_entity`.
///
/// Results are cached for `cache_ttl_secs`. The cache key combines
/// `max_depth` with the entity name, so lookups of the same entity at
/// different depths (for example depth 1 from `enrich_result` and
/// `config.max_depth` from `analyze_query`) never return each other's
/// results.
///
/// NOTE(review): the cache key does not identify the graph, so an enricher
/// reused across different graphs, or after the graph has changed, can
/// return stale relations until the entry expires.
pub fn find_related_with_depth(
    &self,
    graph: &SimpleEntityGraph,
    start_entity: &str,
    max_depth: usize,
) -> Vec<RelationInfo> {
    // The depth prefix is all digits up to the first ':', so distinct
    // (depth, entity) pairs can never produce the same key.
    let cache_key = format!("{max_depth}:{start_entity}");
    if let Some(cached) = self.get_cached_relations(&cache_key) {
        return cached;
    }

    let mut results = Vec::new();
    let mut visited = HashSet::new();
    // Mark the start so cycles leading back to it are not reported.
    visited.insert(start_entity.to_string());
    self.traverse_relations(graph, start_entity, 1, max_depth, &mut visited, &mut results);

    // Stable sort: equal (priority, depth) pairs keep traversal order.
    results.sort_by_key(|rel| (self.get_relation_priority(&rel.relation_type), rel.depth));
    results.truncate(self.config.max_relations_per_entity);

    self.cache_relations(&cache_key, results.clone());
    results
}
/// Walks the graph breadth-first from `entity`, appending every newly
/// reached entity to `results`.
///
/// * `current_depth` — depth assigned to the direct relations of `entity`.
/// * `max_depth` — relations deeper than this are not recorded.
/// * `visited` — names already seen; updated in place so that each entity is
///   reported at most once and cycles terminate.
///
/// Breadth-first order guarantees that an entity is recorded at its minimal
/// depth. A depth-first walk sharing one `visited` set could first reach an
/// entity through a longer route, record it at the depth limit and never
/// expand it, hiding entities that are actually within `max_depth`.
fn traverse_relations(
    &self,
    graph: &SimpleEntityGraph,
    entity: &str,
    current_depth: usize,
    max_depth: usize,
    visited: &mut HashSet<String>,
    results: &mut Vec<RelationInfo>,
) {
    let mut frontier: std::collections::VecDeque<(String, usize)> =
        std::collections::VecDeque::new();
    frontier.push_back((entity.to_string(), current_depth));

    while let Some((source, depth)) = frontier.pop_front() {
        // Depths in the queue never decrease, so nothing after this entry
        // can be within the limit either.
        if depth > max_depth {
            break;
        }

        for (target, rel_type, context) in graph.get_entity_relationships(&source) {
            if visited.contains(&target) {
                continue;
            }
            visited.insert(target.clone());
            results.push(RelationInfo {
                target_entity: target.clone(),
                relation_type: rel_type.clone(),
                context: context.clone(),
                depth,
            });
            frontier.push_back((target.clone(), depth + 1));
        }
    }
}
/// Returns a copy of the cached relations stored under `entity` if the
/// entry is younger than `cache_ttl_secs`; otherwise evicts the expired
/// entry (if any) and returns `None`.
fn get_cached_relations(&self, entity: &str) -> Option<Vec<RelationInfo>> {
    {
        let entry = self.relation_cache.get(entity)?;
        let max_age = Duration::from_secs(self.config.cache_ttl_secs);
        if entry.cached_at.elapsed() < max_age {
            return Some(entry.relations.clone());
        }
        // `entry` goes out of scope here: the map reference must be released
        // before `remove` is called on the same map.
    }
    self.relation_cache.remove(entity);
    None
}
/// Stores `relations` under `entity`, stamped with the current time and
/// replacing any previous entry for that key.
fn cache_relations(&self, entity: &str, relations: Vec<RelationInfo>) {
    let key = entity.to_string();
    let entry = CachedRelations {
        relations,
        cached_at: Instant::now(),
    };
    self.relation_cache.insert(key, entry);
}
/// Returns the rank of `rel_type` in `config.relation_priority` (0 is the
/// highest priority), or `usize::MAX` when the type is not listed.
fn get_relation_priority(&self, rel_type: &RelationType) -> usize {
    for (rank, candidate) in self.config.relation_priority.iter().enumerate() {
        if candidate == rel_type {
            return rank;
        }
    }
    usize::MAX
}
/// Keeps the relations that are relevant to the query, most relevant first.
///
/// With no query entities the input is returned unchanged. Otherwise every
/// relation is scored with `relation_relevance_score` against the query
/// entities and their lowercased graph neighbours; relations scoring zero
/// are dropped and the rest are ordered by descending score, with ties
/// keeping their input order.
pub fn filter_relevant_relations(
    &self,
    relations: &[RelationInfo],
    query_entities: &[String],
    graph: &SimpleEntityGraph,
) -> Vec<RelationInfo> {
    if query_entities.is_empty() {
        return relations.to_vec();
    }

    // Lowercased names of everything adjacent to any query entity.
    let mut query_neighbors: HashSet<String> = HashSet::new();
    for query_entity in query_entities {
        let adjacent = graph.find_related_entities(query_entity);
        query_neighbors.extend(adjacent.into_iter().map(|name| name.to_lowercase()));
    }

    let mut ranked: Vec<(u32, &RelationInfo)> = relations
        .iter()
        .filter_map(|rel| {
            let score = self.relation_relevance_score(rel, query_entities, &query_neighbors);
            (score > 0).then_some((score, rel))
        })
        .collect();

    // Stable sort keeps the input order among equal scores.
    ranked.sort_by_key(|&(score, _)| std::cmp::Reverse(score));

    ranked.into_iter().map(|(_, rel)| rel.clone()).collect()
}
/// Renders `path` as `a ─REL→ b ─REL→ c`, labelling each hop with the
/// relation type the graph reports from the hop's source to its target.
///
/// A hop with no matching relation is labelled with a plain `→`. Paths with
/// fewer than two entries are joined as-is (an empty path yields an empty
/// string).
pub fn summarize_path(&self, path: &[String], graph: &SimpleEntityGraph) -> String {
    if path.len() < 2 {
        return path.join(" → ");
    }

    let mut segments: Vec<String> = Vec::with_capacity(path.len() * 2 - 1);
    segments.push(path[0].clone());

    for hop in path.windows(2) {
        let (from, to) = (&hop[0], &hop[1]);
        let outgoing = graph.get_entity_relationships(from);
        let label = match outgoing.iter().find(|(target, _, _)| target == to) {
            Some((_, relation, _)) => self.format_relation_type(relation),
            None => "→".to_string(),
        };
        segments.push(format!("─{}→", label));
        segments.push(to.clone());
    }

    segments.join(" ")
}
/// Returns the upper-case display label for `rel_type`.
fn format_relation_type(&self, rel_type: &RelationType) -> String {
    let label: &'static str = match rel_type {
        RelationType::CausedBy => "CAUSED_BY",
        RelationType::DependsOn => "DEPENDS_ON",
        RelationType::Implements => "IMPLEMENTS",
        RelationType::LeadsTo => "LEADS_TO",
        RelationType::RelatedTo => "RELATED_TO",
        RelationType::RequiredBy => "REQUIRED_BY",
        RelationType::ConflictsWith => "CONFLICTS",
        RelationType::Solves => "SOLVES",
    };
    label.to_string()
}
/// Scores how relevant `relation` is to the query. The score is the sum of:
///
/// * 100 if the target (case-insensitively) is one of `query_entities`;
/// * 50 if the lowercased target is in `query_neighbors`;
/// * 30, 20 or 10 for the three highest-priority relation types;
/// * 20 for depth 1, 10 for depth 2.
///
/// `query_neighbors` is compared as given, so it should already hold
/// lowercased names.
pub fn relation_relevance_score(
    &self,
    relation: &RelationInfo,
    query_entities: &[String],
    query_neighbors: &HashSet<String>,
) -> u32 {
    let target = relation.target_entity.to_lowercase();

    let is_query_entity = query_entities
        .iter()
        .any(|candidate| candidate.to_lowercase() == target);
    let direct_bonus = if is_query_entity { 100 } else { 0 };

    let neighbor_bonus = if query_neighbors.contains(&target) {
        50
    } else {
        0
    };

    let type_bonus = match self.get_relation_priority(&relation.relation_type) {
        0 => 30,
        1 => 20,
        2 => 10,
        _ => 0,
    };

    let depth_bonus = match relation.depth {
        1 => 20,
        2 => 10,
        _ => 0,
    };

    direct_bonus + neighbor_bonus + type_bonus + depth_bonus
}
pub fn analyze_query(
&self,
graph: &SimpleEntityGraph,
query_entities: &[String],
) -> GlobalGraphInsights {
let mut insights = GlobalGraphInsights::default();
for entity in query_entities {
let relations = self.find_related_with_depth(graph, entity, self.config.max_depth);
if !relations.is_empty() {
let related: Vec<String> = relations
.iter()
.take(5)
.map(|r| r.target_entity.clone())
.collect();
insights.query_entity_map.push((entity.clone(), related));
}
}
if !insights.query_entity_map.is_empty() {
let mut formatted = String::from("\n[System Knowledge Map]:\n");
for (entity, related) in &insights.query_entity_map {
formatted.push_str(&format!("• {} → {}\n", entity, related.join(", ")));
}
insights.formatted = formatted;
}
insights
}
/// Produces graph context for one result.
///
/// * Entities mentioned in `content` are looked up in the graph; the direct
///   relations (depth 1) of the first three are gathered and de-duplicated
///   by target name.
/// * For the first two result entities and the first two query entities,
///   shortest paths with 3 to 5 nodes are collected.
/// * Relations are filtered by relevance to the query; if the filter removes
///   everything, the unfiltered list is used instead. Either way the list is
///   capped at `max_relations_per_entity`.
pub fn enrich_result(
    &self,
    graph: &SimpleEntityGraph,
    content: &str,
    query_entities: &[String],
) -> GraphEnrichment {
    let result_entities = self.extract_graph_entities(graph, content);

    let mut related: Vec<RelationInfo> = result_entities
        .iter()
        .take(3)
        .flat_map(|entity| self.find_related_with_depth(graph, entity, 1))
        .collect();

    let mut paths = Vec::new();
    for result_entity in result_entities.iter().take(2) {
        for query_entity in query_entities.iter().take(2) {
            if result_entity == query_entity {
                continue;
            }
            let Some(path) = graph.find_shortest_path(result_entity, query_entity) else {
                continue;
            };
            // Two-node paths are direct edges; longer than five is too remote.
            if (3..=5).contains(&path.len()) {
                paths.push(path);
            }
        }
    }

    // Keep the first relation seen for every target entity.
    let mut seen_targets = HashSet::new();
    related.retain(|rel| seen_targets.insert(rel.target_entity.clone()));

    let relevant = self.filter_relevant_relations(&related, query_entities, graph);
    let chosen = if relevant.is_empty() { related } else { relevant };
    let final_related: Vec<RelationInfo> = chosen
        .into_iter()
        .take(self.config.max_relations_per_entity)
        .collect();

    let context_string = self.format_enrichment_with_summary(graph, &final_related, &paths);

    GraphEnrichment {
        related_entities: final_related,
        paths_to_query: paths,
        context_string,
    }
}
pub fn find_cross_result_insights(
&self,
graph: &SimpleEntityGraph,
result_entities: &[Vec<String>],
) -> Vec<String> {
let mut insights = Vec::new();
if result_entities.len() < 2 {
return insights;
}
for i in 0..result_entities.len().min(3) {
for j in (i + 1)..result_entities.len().min(3) {
if let (Some(e1), Some(e2)) =
(result_entities[i].first(), result_entities[j].first())
&& e1 != e2
&& let Some(path) = graph.find_shortest_path(e1, e2)
&& path.len() > 2
&& path.len() <= 5
{
let summarized = self.summarize_path(&path, graph);
insights.push(format!("Connection: {}", summarized));
}
}
}
insights
}
/// Renders relations and paths as a framed "Graph Context" text block.
///
/// At most eight relations are shown, grouped by relation type, followed by
/// at most two summarized paths. Returns an empty string when both inputs
/// are empty.
///
/// Groups are emitted in the order their relation type first appears in
/// `relations`, so the output is deterministic and follows the caller's
/// ranking. (Grouping through a `HashMap` would make the line order vary
/// from run to run.)
fn format_enrichment_with_summary(
    &self,
    graph: &SimpleEntityGraph,
    relations: &[RelationInfo],
    paths: &[Vec<String>],
) -> String {
    if relations.is_empty() && paths.is_empty() {
        return String::new();
    }

    let mut output = String::from("\n───────────────────────────\n📊 Graph Context:\n");

    // Insertion-ordered grouping; a linear scan is fine for at most 8 items.
    let mut groups: Vec<(&RelationType, Vec<&str>)> = Vec::new();
    for rel in relations.iter().take(8) {
        let target = rel.target_entity.as_str();
        match groups
            .iter()
            .position(|(kind, _)| **kind == rel.relation_type)
        {
            Some(index) => groups[index].1.push(target),
            None => groups.push((&rel.relation_type, vec![target])),
        }
    }
    for (rel_type, entities) in &groups {
        output.push_str(&format!(
            "• {}: {}\n",
            self.format_relation_type(rel_type),
            entities.join(", ")
        ));
    }

    for path in paths.iter().take(2) {
        let summarized = self.summarize_path(path, graph);
        output.push_str(&format!("• Path: {}\n", summarized));
    }

    output.push_str("───────────────────────────");
    output
}
/// Renders relations and paths as a framed "Graph Context" text block,
/// using the `Debug` name of each relation type and plain ` → ` joined
/// paths.
///
/// At most eight relations (grouped by type) and two paths are shown; an
/// empty string is returned when both inputs are empty. Groups are emitted
/// in the order their relation type first appears in `relations`, which
/// keeps the output deterministic.
// Not called anywhere in this file, hence the lint allowance.
#[allow(dead_code)]
fn format_enrichment(&self, relations: &[RelationInfo], paths: &[Vec<String>]) -> String {
    if relations.is_empty() && paths.is_empty() {
        return String::new();
    }

    let mut output = String::from("\n───────────────────────────\n📊 Graph Context:\n");

    // Insertion-ordered grouping; a linear scan is fine for at most 8 items.
    let mut groups: Vec<(&RelationType, Vec<&str>)> = Vec::new();
    for rel in relations.iter().take(8) {
        let target = rel.target_entity.as_str();
        match groups
            .iter()
            .position(|(kind, _)| **kind == rel.relation_type)
        {
            Some(index) => groups[index].1.push(target),
            None => groups.push((&rel.relation_type, vec![target])),
        }
    }
    for (rel_type, entities) in &groups {
        output.push_str(&format!("• {:?}: {}\n", rel_type, entities.join(", ")));
    }

    for path in paths.iter().take(2) {
        output.push_str(&format!("• Path: {}\n", path.join(" → ")));
    }

    output.push_str("───────────────────────────");
    output
}
pub fn cleanup_cache(&self) {
let ttl = Duration::from_secs(self.config.cache_ttl_secs);
self.relation_cache
.retain(|_, cached| cached.cached_at.elapsed() < ttl);
}
/// Returns `(valid, total)`: the number of cache entries younger than
/// `cache_ttl_secs` and the number of entries overall.
pub fn cache_stats(&self) -> (usize, usize) {
    let total = self.relation_cache.len();
    let max_age = Duration::from_secs(self.config.cache_ttl_secs);

    let mut valid = 0usize;
    for entry in self.relation_cache.iter() {
        if entry.cached_at.elapsed() < max_age {
            valid += 1;
        }
    }

    (valid, total)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::context_update::{EntityRelationship, EntityType};

    /// Shorthand for a relationship between two named entities.
    fn relationship(
        from: &str,
        to: &str,
        relation_type: RelationType,
        context: &str,
    ) -> EntityRelationship {
        EntityRelationship {
            from_entity: from.to_string(),
            to_entity: to.to_string(),
            relation_type,
            context: context.to_string(),
        }
    }

    /// Shorthand for a `RelationInfo` with an empty context string.
    fn relation(target: &str, relation_type: RelationType, depth: usize) -> RelationInfo {
        RelationInfo {
            target_entity: target.to_string(),
            relation_type,
            context: String::new(),
            depth,
        }
    }

    /// Five entities with the edges
    /// payment→database, database→timeout, auth→user and payment→auth.
    fn create_test_graph() -> SimpleEntityGraph {
        let mut graph = SimpleEntityGraph::new();
        let now = chrono::Utc::now();

        for (name, entity_type, description) in [
            ("payment", EntityType::Concept, "Payment processing"),
            ("database", EntityType::Technology, "Database system"),
            ("timeout", EntityType::Concept, "Connection timeout"),
            ("auth", EntityType::Concept, "Authentication"),
            ("user", EntityType::Concept, "User entity"),
        ] {
            graph.add_or_update_entity(name.to_string(), entity_type, now, description);
        }

        graph.add_relationship(relationship(
            "payment",
            "database",
            RelationType::DependsOn,
            "Payment depends on database",
        ));
        graph.add_relationship(relationship(
            "database",
            "timeout",
            RelationType::CausedBy,
            "Database can cause timeout",
        ));
        graph.add_relationship(relationship(
            "auth",
            "user",
            RelationType::LeadsTo,
            "Auth leads to user",
        ));
        graph.add_relationship(relationship(
            "payment",
            "auth",
            RelationType::DependsOn,
            "Payment requires auth",
        ));

        graph
    }

    #[test]
    fn test_should_enrich() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        assert!(enricher.should_enrich(&graph));

        // An empty graph is below the default `min_graph_size`.
        let empty_graph = SimpleEntityGraph::new();
        assert!(!enricher.should_enrich(&empty_graph));
    }

    #[test]
    fn test_extract_graph_entities() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        let text = "The payment system had a database timeout";
        let entities = enricher.extract_graph_entities(&graph, text);
        assert!(entities.contains(&"payment".to_string()));
        assert!(entities.contains(&"database".to_string()));
        assert!(entities.contains(&"timeout".to_string()));
    }

    #[test]
    fn test_find_related_with_depth() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        let related = enricher.find_related_with_depth(&graph, "payment", 2);
        assert!(related.iter().any(|r| r.target_entity == "database"));
        assert!(related.iter().any(|r| r.target_entity == "auth"));
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = SimpleEntityGraph::new();
        let now = chrono::Utc::now();
        for name in ["a", "b", "c"] {
            graph.add_or_update_entity(name.to_string(), EntityType::Concept, now, "Test entity");
        }
        // a → b → c → a forms a cycle.
        for (from, to) in [("a", "b"), ("b", "c"), ("c", "a")] {
            graph.add_relationship(relationship(from, to, RelationType::RelatedTo, ""));
        }

        let enricher = GraphRagEnricher::with_defaults();
        let related = enricher.find_related_with_depth(&graph, "a", 5);
        // Only b and c are reported; the start entity never reappears.
        assert_eq!(related.len(), 2);
    }

    #[test]
    fn test_cache_works() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        let related1 = enricher.find_related_with_depth(&graph, "payment", 2);
        let related2 = enricher.find_related_with_depth(&graph, "payment", 2);
        assert_eq!(related1.len(), related2.len());

        let (valid, total) = enricher.cache_stats();
        assert!(valid > 0);
        assert_eq!(valid, total);
    }

    #[test]
    fn test_enrich_result() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        let content = "Payment processing failed due to database issues";
        let query_entities = vec!["timeout".to_string()];
        let enrichment = enricher.enrich_result(&graph, content, &query_entities);
        assert!(!enrichment.related_entities.is_empty());
        assert!(!enrichment.context_string.is_empty());
        assert!(enrichment.context_string.contains("Graph Context"));
    }

    #[test]
    fn test_path_summarization() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        let path = vec![
            "payment".to_string(),
            "database".to_string(),
            "timeout".to_string(),
        ];
        let summarized = enricher.summarize_path(&path, &graph);

        // The shape "payment ─…→ database ─…→ timeout" holds whatever labels
        // the graph lookup yields for the two hops.
        assert!(summarized.starts_with("payment ─"), "got {summarized}");
        assert!(summarized.contains("→ database ─"), "got {summarized}");
        assert!(summarized.ends_with("→ timeout"), "got {summarized}");
    }

    #[test]
    fn test_path_summarization_short_paths() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();

        let empty: Vec<String> = Vec::new();
        assert_eq!(enricher.summarize_path(&empty, &graph), "");

        let single = vec!["payment".to_string()];
        assert_eq!(enricher.summarize_path(&single, &graph), "payment");
    }

    #[test]
    fn test_relevance_filtering() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        let relations = vec![
            relation("database", RelationType::DependsOn, 1),
            relation("timeout", RelationType::CausedBy, 2),
            relation("unrelated", RelationType::RelatedTo, 3),
        ];
        let query_entities = vec!["timeout".to_string()];
        let filtered = enricher.filter_relevant_relations(&relations, &query_entities, &graph);

        assert!(!filtered.is_empty());
        // A direct query match (+100) outranks anything "database" can score.
        assert_eq!(filtered[0].target_entity, "timeout");
        // Depth 3 with a low-priority type and no query link scores zero.
        assert!(filtered.iter().all(|r| r.target_entity != "unrelated"));
    }

    #[test]
    fn test_relevance_scoring() {
        let enricher = GraphRagEnricher::with_defaults();
        let candidate = relation("database", RelationType::DependsOn, 1);
        let query_entities = vec!["database".to_string()];
        let query_neighbors: HashSet<String> = HashSet::new();
        let score =
            enricher.relation_relevance_score(&candidate, &query_entities, &query_neighbors);
        assert!(
            score >= 100,
            "Direct match should give high score, got {}",
            score
        );
    }

    #[test]
    fn test_format_relation_type() {
        let enricher = GraphRagEnricher::with_defaults();
        let expectations = [
            (RelationType::CausedBy, "CAUSED_BY"),
            (RelationType::DependsOn, "DEPENDS_ON"),
            (RelationType::Implements, "IMPLEMENTS"),
            (RelationType::LeadsTo, "LEADS_TO"),
            (RelationType::RelatedTo, "RELATED_TO"),
            (RelationType::RequiredBy, "REQUIRED_BY"),
            (RelationType::ConflictsWith, "CONFLICTS"),
            (RelationType::Solves, "SOLVES"),
        ];
        for (rel_type, expected) in &expectations {
            assert_eq!(enricher.format_relation_type(rel_type), *expected);
        }
    }

    #[test]
    fn test_cross_result_insights_with_summary() {
        let enricher = GraphRagEnricher::with_defaults();
        let graph = create_test_graph();
        let result_entities = vec![vec!["payment".to_string()], vec!["timeout".to_string()]];
        let insights = enricher.find_cross_result_insights(&graph, &result_entities);

        // Whether a path is found depends on the graph's path search, so
        // only the shape of any insight that is produced is checked.
        if let Some(insight) = insights.first() {
            assert!(
                insight.contains("Connection:"),
                "Should contain 'Connection:'"
            );
            assert!(
                insight.contains("payment") && insight.contains("timeout"),
                "Should contain both entities"
            );
        }
    }
}