use super::{FunctionCall, FunctionCaller, FunctionContext, FunctionResult};
use crate::{core::KnowledgeGraph, Result};
use std::collections::HashMap;
/// High-level approach the agent takes to answer a query.
///
/// Fieldless, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryStrategy {
    /// Search the graph for each entity mentioned in the query.
    EntitySearch,
    /// Search each mentioned entity and, when at least two are present, look for
    /// paths connecting the first two.
    RelationshipExploration,
    /// Gather information about the mentioned entities.
    ContextualAnalysis,
    /// Fallback used when no keyword or known entity drives the plan.
    Adaptive,
}
/// Execution plan built for one query: the chosen strategy plus the calls it expands to.
#[derive(Clone, Debug)]
pub struct QueryPlan {
    /// Strategy the planner selected.
    pub strategy: QueryStrategy,
    /// Function calls handed to the function caller for execution.
    pub function_calls: Vec<FunctionCall>,
    /// Labels describing which kinds of results the plan anticipates.
    pub expected_outcomes: Vec<String>,
    /// Heuristic confidence score attached to the plan.
    pub confidence: f32,
}
/// Agent that plans function calls for a natural-language query, runs them through a
/// [`FunctionCaller`] against a knowledge graph, and summarises the results.
pub struct GraphRAGAgent {
    /// Sessions recorded by `process_query`, in insertion order.
    query_history: Vec<QuerySession>,
    /// Dispatcher that executes the planned calls.
    function_caller: FunctionCaller,
    /// Iteration budget configured through `set_max_iterations`.
    /// NOTE(review): stored but not read by the query pipeline in this file — confirm intent.
    max_iterations: usize,
}
/// Record of one processed query: the plan, the raw function results and the answer.
#[derive(Clone, Debug)]
pub struct QuerySession {
    /// Query text exactly as supplied by the caller.
    pub query: String,
    /// Plan that was executed for the query.
    pub plan: QueryPlan,
    /// Results returned by the function caller for the plan's calls.
    pub function_results: Vec<FunctionResult>,
    /// Synthesised, human-readable answer, if one was produced.
    pub answer: Option<String>,
    /// Wall-clock processing time in milliseconds.
    pub execution_time_ms: u64,
    /// Whether the session completed successfully.
    pub success: bool,
}
impl GraphRAGAgent {
pub fn new() -> Self {
Self {
function_caller: FunctionCaller::new(),
query_history: Vec::new(),
max_iterations: 5,
}
}
pub fn with_function_caller(function_caller: FunctionCaller) -> Self {
Self {
function_caller,
query_history: Vec::new(),
max_iterations: 5,
}
}
pub fn process_query(
&mut self,
query: &str,
knowledge_graph: &KnowledgeGraph,
) -> Result<QuerySession> {
let start_time = std::time::Instant::now();
let plan = self.generate_query_plan(query, knowledge_graph)?;
let context = FunctionContext {
knowledge_graph,
query,
previous_results: &[],
metadata: HashMap::new(),
};
let function_results = self
.function_caller
.call_functions(plan.function_calls.clone(), &context)?;
let answer = self.synthesize_answer(query, &function_results, knowledge_graph)?;
let session = QuerySession {
query: query.to_string(),
plan,
function_results,
answer: Some(answer),
execution_time_ms: start_time.elapsed().as_millis() as u64,
success: true,
};
self.query_history.push(session.clone());
Ok(session)
}
fn generate_query_plan(
&self,
query: &str,
knowledge_graph: &KnowledgeGraph,
) -> Result<QueryPlan> {
let query_lower = query.to_lowercase();
let potential_entities = self.extract_entity_names_from_query(query, knowledge_graph);
let strategy = if query_lower.contains("relationship")
|| query_lower.contains("connect")
|| query_lower.contains("relation")
|| query_lower.contains("between")
{
QueryStrategy::RelationshipExploration
} else if query_lower.contains("context")
|| query_lower.contains("detail")
|| query_lower.contains("about")
|| query_lower.contains("information")
{
QueryStrategy::ContextualAnalysis
} else if !potential_entities.is_empty() {
QueryStrategy::EntitySearch
} else {
QueryStrategy::Adaptive
};
let function_calls = match strategy {
QueryStrategy::EntitySearch => self.plan_entity_search(&potential_entities),
QueryStrategy::RelationshipExploration => {
self.plan_relationship_exploration(&potential_entities)
},
QueryStrategy::ContextualAnalysis => self.plan_contextual_analysis(&potential_entities),
QueryStrategy::Adaptive => self.plan_adaptive_search(query, &potential_entities),
};
Ok(QueryPlan {
strategy,
function_calls,
expected_outcomes: vec!["entities".to_string(), "relationships".to_string()],
confidence: 0.8,
})
}
fn extract_entity_names_from_query(
&self,
query: &str,
knowledge_graph: &KnowledgeGraph,
) -> Vec<String> {
let words: Vec<&str> = query.split_whitespace().collect();
let mut entities = Vec::new();
for window in words.windows(1).chain(words.windows(2)) {
let potential_name = window.join(" ");
for entity in knowledge_graph.entities() {
if entity
.name
.to_lowercase()
.contains(&potential_name.to_lowercase())
{
entities.push(entity.name.clone());
break;
}
}
}
if let Some(start) = query.find('"') {
if let Some(end) = query[start + 1..].find('"') {
let quoted = &query[start + 1..start + 1 + end];
entities.push(quoted.to_string());
}
}
entities.sort();
entities.dedup();
entities
}
fn plan_entity_search(&self, entities: &[String]) -> Vec<FunctionCall> {
let mut calls = Vec::new();
for entity in entities {
calls.push(FunctionCall {
name: "graph_search".to_string(),
arguments: json::object! {
"entity_name": entity.clone(),
"limit": 5
},
});
}
calls
}
fn plan_relationship_exploration(&self, entities: &[String]) -> Vec<FunctionCall> {
let mut calls = Vec::new();
for entity in entities {
calls.push(FunctionCall {
name: "graph_search".to_string(),
arguments: json::object! {
"entity_name": entity.clone(),
"limit": 3
},
});
}
if entities.len() >= 2 {
calls.push(FunctionCall {
name: "relationship_traverse".to_string(),
arguments: json::object! {
"source_entity": entities[0].clone(),
"target_entity": entities[1].clone(),
"max_hops": 3
},
});
}
calls
}
fn plan_contextual_analysis(&self, entities: &[String]) -> Vec<FunctionCall> {
let mut calls = Vec::new();
for entity in entities {
calls.push(FunctionCall {
name: "graph_search".to_string(),
arguments: json::object! {
"entity_name": entity.clone(),
"limit": 3
},
});
}
calls
}
fn plan_adaptive_search(&self, query: &str, entities: &[String]) -> Vec<FunctionCall> {
let mut calls = Vec::new();
if entities.is_empty() {
let key_terms: Vec<&str> = query
.split_whitespace()
.filter(|word| {
word.len() > 3
&& word
.chars()
.next()
.expect("non-empty string")
.is_uppercase()
})
.collect();
for term in key_terms.iter().take(3) {
calls.push(FunctionCall {
name: "graph_search".to_string(),
arguments: json::object! {
"entity_name": term.to_string(),
"limit": 5
},
});
}
} else {
calls.extend(self.plan_entity_search(entities));
}
calls
}
fn synthesize_answer(
&self,
query: &str,
function_results: &[FunctionResult],
_knowledge_graph: &KnowledgeGraph,
) -> Result<String> {
if function_results.is_empty() {
return Ok("No relevant information found in the knowledge graph.".to_string());
}
let mut answer_parts = Vec::new();
for result in function_results {
if !result.success {
continue;
}
match result.function_name.as_str() {
"graph_search" => {
if result.result["entities"].is_array() {
let entities: Vec<_> = result.result["entities"].members().collect();
if !entities.is_empty() {
answer_parts.push(format!(
"Found {} relevant entities: {}",
entities.len(),
entities
.iter()
.map(|e| e["name"].as_str().unwrap_or("Unknown"))
.collect::<Vec<_>>()
.join(", ")
));
}
}
},
"entity_expand" => {
if result.result["relationships"].is_array() {
let relationships: Vec<_> =
result.result["relationships"].members().collect();
if !relationships.is_empty() {
answer_parts.push(format!(
"Found {} relationships for the entity",
relationships.len()
));
}
}
},
"relationship_traverse" => {
if result.result["paths"].is_array() {
let paths: Vec<_> = result.result["paths"].members().collect();
if !paths.is_empty() {
answer_parts.push(format!(
"Found {} connection paths between the entities",
paths.len()
));
} else {
answer_parts.push(
"No direct connection found between the entities".to_string(),
);
}
}
},
"get_entity_context" => {
if result.result["context_chunks"].is_array() {
let chunks: Vec<_> = result.result["context_chunks"].members().collect();
if !chunks.is_empty() {
answer_parts.push(format!(
"Found {} text contexts mentioning the entity",
chunks.len()
));
}
}
},
_ => {},
}
}
if answer_parts.is_empty() {
Ok("The query was processed but no specific information was found.".to_string())
} else {
Ok(format!(
"Query: \"{}\"\n\nResults:\n{}",
query,
answer_parts.join("\n")
))
}
}
pub fn get_query_history(&self) -> &[QuerySession] {
&self.query_history
}
pub fn get_statistics(&self) -> super::FunctionCallStatistics {
self.function_caller.get_statistics()
}
pub fn clear_history(&mut self) {
self.query_history.clear();
self.function_caller.clear_history();
}
pub fn set_max_iterations(&mut self, max_iterations: usize) {
self.max_iterations = max_iterations;
}
pub fn get_function_caller_mut(&mut self) -> &mut FunctionCaller {
&mut self.function_caller
}
}
impl Default for GraphRAGAgent {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Entity, EntityId, KnowledgeGraph};

    #[test]
    fn test_entity_extraction_from_query() {
        let agent = GraphRAGAgent::new();
        let mut graph = KnowledgeGraph::new();
        let entity = Entity::new(
            EntityId::new("test_entity".to_string()),
            "John Smith".to_string(),
            "PERSON".to_string(),
            0.9,
        );
        graph.add_entity(entity).unwrap();
        let entities = agent.extract_entity_names_from_query("Tell me about John Smith", &graph);
        assert!(
            entities.iter().any(|name| name == "John Smith"),
            "expected \"John Smith\" in {:?}",
            entities
        );
    }

    #[test]
    fn test_quoted_phrase_extraction() {
        let agent = GraphRAGAgent::new();
        let graph = KnowledgeGraph::new();
        let entities = agent.extract_entity_names_from_query("Who is \"Ada Lovelace\"?", &graph);
        assert!(
            entities.iter().any(|name| name == "Ada Lovelace"),
            "expected quoted phrase in {:?}",
            entities
        );
    }

    #[test]
    fn test_query_plan_generation() {
        let agent = GraphRAGAgent::new();
        let graph = KnowledgeGraph::new();
        let plan = agent
            .generate_query_plan("What is the relationship between A and B?", &graph)
            .unwrap();
        // `matches!` only yields a bool; it must be asserted to actually check anything.
        assert!(
            matches!(plan.strategy, QueryStrategy::RelationshipExploration),
            "unexpected strategy: {:?}",
            plan.strategy
        );
        // An empty graph yields no candidate entities, so there is nothing to search for.
        assert!(plan.function_calls.is_empty());
    }

    #[test]
    fn test_adaptive_plan_uses_capitalised_terms() {
        let agent = GraphRAGAgent::new();
        let graph = KnowledgeGraph::new();
        let plan = agent.generate_query_plan("Who is Alice", &graph).unwrap();
        assert!(
            matches!(plan.strategy, QueryStrategy::Adaptive),
            "unexpected strategy: {:?}",
            plan.strategy
        );
        assert_eq!(plan.function_calls.len(), 1);
        assert_eq!(plan.function_calls[0].name, "graph_search");
    }
}