use super::functions::{
EntityExpandFunction, GetEntityContextFunction, GraphSearchFunction,
InferRelationshipsFunction, RelationshipTraverseFunction,
};
use super::FunctionCaller;
use crate::Result;
/// Stateless namespace for the function-calling helpers in this module: registering the
/// built-in graph functions, converting function definitions to OpenAI-style JSON,
/// parsing function calls out of a model response, and formatting results and the
/// system prompt for the model.
pub struct ToolRegistry;
impl ToolRegistry {
/// Registers the five built-in graph functions on `function_caller`.
///
/// Registration order is: graph search, entity expansion, relationship traversal,
/// entity context lookup, and relationship inference.
///
/// # Errors
///
/// This function never produces an error itself; it always returns `Ok(())`.
/// NOTE(review): whatever `register_function` returns (if anything) is discarded here —
/// confirm against `FunctionCaller` that this is intended.
pub fn register_default_functions(function_caller: &mut FunctionCaller) -> Result<()> {
function_caller.register_function(Box::new(GraphSearchFunction));
function_caller.register_function(Box::new(EntityExpandFunction));
function_caller.register_function(Box::new(RelationshipTraverseFunction));
function_caller.register_function(Box::new(GetEntityContextFunction));
// The only built-in that is created through a constructor rather than passed as a bare value.
function_caller.register_function(Box::new(InferRelationshipsFunction::new()));
Ok(())
}
/// Renders every function definition known to `function_caller` as a JSON array in the
/// OpenAI "tools" layout.
///
/// Each element has the shape
/// `{"type": "function", "function": {"name": .., "description": .., "parameters": ..}}`.
pub fn get_openai_function_definitions(function_caller: &FunctionCaller) -> json::JsonValue {
    let mut tools = Vec::new();
    for def in function_caller.get_function_definitions() {
        tools.push(json::object! {
            "type": "function",
            "function": {
                "name": def.name,
                "description": def.description,
                "parameters": def.parameters
            }
        });
    }
    json::JsonValue::Array(tools)
}
/// Extracts every function call found in an OpenAI-style response message.
///
/// Two layouts are recognised, and both may be present in the same response:
/// - `function_call`: a single object holding `name` and `arguments`;
/// - `tool_calls`: an array whose elements hold that same object under `function`.
///
/// Calls are returned in that order (`function_call` first, then `tool_calls` in array
/// order). `arguments` must be a string containing JSON text; candidates that are not
/// objects, or that lack a string `name` or string `arguments`, are skipped.
///
/// # Errors
///
/// Returns `GraphRAGError::Json` as soon as an `arguments` string fails to parse as JSON.
pub fn parse_openai_function_call(
    response: &json::JsonValue,
) -> Result<Vec<super::FunctionCall>> {
    // Indexing a missing key yields JSON null, and `members()` on a non-array yields an
    // empty iterator, so both layouts can be walked uniformly without up-front checks.
    let legacy_call = std::iter::once(&response["function_call"]);
    let tool_calls = response["tool_calls"]
        .members()
        .map(|tool_call| &tool_call["function"]);

    let mut function_calls = Vec::new();
    for call in legacy_call.chain(tool_calls) {
        if !call.is_object() {
            continue;
        }
        if let (Some(name), Some(arguments_str)) =
            (call["name"].as_str(), call["arguments"].as_str())
        {
            let arguments = json::parse(arguments_str).map_err(crate::GraphRAGError::Json)?;
            function_calls.push(super::FunctionCall {
                name: name.to_string(),
                arguments,
            });
        }
    }
    Ok(function_calls)
}
/// Formats function results as a numbered plain-text report for inclusion in a prompt.
///
/// Each entry lists the function name, a success/failure status, then either the
/// pretty-printed JSON result (on success) or the error text (on failure, when one is
/// recorded), followed by the execution time in milliseconds. An empty slice produces
/// the fixed sentence "No function calls were made.".
pub fn format_function_results_for_llm(results: &[super::FunctionResult]) -> String {
    if results.is_empty() {
        return String::from("No function calls were made.");
    }

    let entries: String = results
        .iter()
        .enumerate()
        .map(|(position, outcome)| {
            // Status line plus its optional detail line (result or error).
            let status = if outcome.success {
                format!(
                    " Status: Success\n Result: {}\n",
                    outcome.result.pretty(2)
                )
            } else {
                match &outcome.error {
                    Some(error) => format!(" Status: Failed\n Error: {error}\n"),
                    None => String::from(" Status: Failed\n"),
                }
            };
            format!(
                "{}. Function: {}\n{status} Execution time: {}ms\n\n",
                position + 1,
                outcome.function_name,
                outcome.execution_time_ms
            )
        })
        .collect();

    format!("Function call results:\n\n{entries}")
}
/// Returns the fixed system prompt that introduces the assistant, lists the five graph
/// functions, and gives guidelines and suggested call sequences per question type.
pub fn create_system_prompt() -> String {
    // Kept as a single raw literal so the prompt reads exactly as the model receives it.
    const SYSTEM_PROMPT: &str = r#"You are a GraphRAG assistant that can interact with a knowledge graph database through function calls.
Available functions:
1. graph_search: Search for entities by name or partial name match
2. entity_expand: Get all relationships and connected entities for a specific entity
3. relationship_traverse: Find paths between two entities in the graph
4. get_entity_context: Get text chunks where an entity appears for detailed context
5. infer_relationships: Infer implicit relationships based on context patterns and co-occurrence
Guidelines:
- Use function calls to gather information from the knowledge graph before answering
- Start with graph_search to find relevant entities
- Use entity_expand to understand explicit relationships around key entities
- Use infer_relationships to find implicit relationships (friends, enemies, etc.)
- Use relationship_traverse to find connections between entities
- Use get_entity_context to get detailed textual information
- Synthesize information from multiple function calls for comprehensive answers
- Distinguish between explicit (directly stated) and implicit (inferred) relationships
- If no relevant information is found, state that clearly
Remember to make strategic function calls based on the user's question type:
- For entity information: graph_search → entity_expand → get_entity_context
- For explicit relationships: graph_search → relationship_traverse
- For implicit relationships: graph_search → infer_relationships → get_entity_context
- For friendship queries: graph_search → infer_relationships (relation_type: "FRIEND")
- For detailed analysis: graph_search → entity_expand → infer_relationships → get_entity_context
Always explain your reasoning and cite the function call results in your answers."#;
    SYSTEM_PROMPT.to_owned()
}
}
/// Stateless namespace for assembling the text prompts sent to the model: one that asks
/// it to make function calls, and one that asks it to synthesize an answer from results.
pub struct PromptBuilder;
impl PromptBuilder {
/// Assembles the prompt that asks the model to answer `user_query` via function calls.
///
/// Sections, in order: the system prompt, a `- name: description` listing of the
/// functions known to `function_caller` (omitted when there are none), the formatted
/// `previous_results` (omitted when empty), the user query, and a closing instruction.
pub fn build_function_calling_prompt(
    user_query: &str,
    function_caller: &FunctionCaller,
    previous_results: &[super::FunctionResult],
) -> String {
    let mut text = ToolRegistry::create_system_prompt();
    text += "\n\n";

    let definitions = function_caller.get_function_definitions();
    if !definitions.is_empty() {
        let listing: String = definitions
            .into_iter()
            .map(|def| format!("- {}: {}\n", def.name, def.description))
            .collect();
        text += "Available functions:\n";
        text += &listing;
        text += "\n";
    }

    if !previous_results.is_empty() {
        text += "Previous function call results:\n";
        text += &ToolRegistry::format_function_results_for_llm(previous_results);
        text += "\n";
    }

    text += &format!("User query: {user_query}\n\n");
    text += "Please analyze the query and make appropriate function calls to gather information from the knowledge graph, then provide a comprehensive answer based on the results.";
    text
}
/// Assembles the prompt that asks the model to turn `function_results` into a final
/// answer to `user_query`.
///
/// The prompt consists of an opening instruction, the user query, the formatted
/// function results, and closing guidance on how to synthesize the answer.
pub fn build_synthesis_prompt(
    user_query: &str,
    function_results: &[super::FunctionResult],
) -> String {
    let query_section = format!("User query: {user_query}\n\n");
    let results_section = ToolRegistry::format_function_results_for_llm(function_results);

    [
        "Based on the following function call results, provide a comprehensive answer to the user's query.\n\n",
        query_section.as_str(),
        results_section.as_str(),
        "\nPlease synthesize this information into a clear, comprehensive answer that directly addresses the user's question. ",
        "Include specific details from the function results and explain any relationships or connections found. ",
        "If insufficient information was found, state that clearly and suggest what additional information might be helpful.",
    ]
    .concat()
}
}
/// Stateless namespace for keyword-based query heuristics: classifying a query and
/// pulling candidate entity names out of it (quoted spans and capitalized words).
pub struct QueryAnalyzer;
impl QueryAnalyzer {
/// Classifies `query` by case-insensitive substring matching against fixed cue phrases.
///
/// - entity-focused: "what is", "who is", "tell me about"
/// - relationship-focused: "relationship", "connect", "relation", "between", "how are",
///   "associated"
/// - context-focused: "context", "detail", "information", "describe", "explain"
/// - `requires_search` is false only when the query contains "list all" or "show all"
///
/// Complexity is `High` for relationship-focused queries, otherwise `Medium` for
/// context-focused ones, otherwise `Low`.
pub fn analyze_query(query: &str) -> QueryAnalysis {
    const ENTITY_CUES: [&str; 3] = ["what is", "who is", "tell me about"];
    const RELATIONSHIP_CUES: [&str; 6] = [
        "relationship",
        "connect",
        "relation",
        "between",
        "how are",
        "associated",
    ];
    const CONTEXT_CUES: [&str; 5] = ["context", "detail", "information", "describe", "explain"];
    const ENUMERATION_CUES: [&str; 2] = ["list all", "show all"];

    /// True when `haystack` contains at least one of `cues` as a substring.
    fn contains_any(haystack: &str, cues: &[&str]) -> bool {
        cues.iter().any(|&cue| haystack.contains(cue))
    }

    let normalized = query.to_lowercase();
    let is_entity_focused = contains_any(&normalized, &ENTITY_CUES);
    let is_relationship_focused = contains_any(&normalized, &RELATIONSHIP_CUES);
    let is_context_focused = contains_any(&normalized, &CONTEXT_CUES);
    let requires_search = !contains_any(&normalized, &ENUMERATION_CUES);

    // Relationship cues take precedence over context cues when both are present.
    let complexity = match (is_relationship_focused, is_context_focused) {
        (true, _) => QueryComplexity::High,
        (false, true) => QueryComplexity::Medium,
        (false, false) => QueryComplexity::Low,
    };

    QueryAnalysis {
        is_entity_focused,
        is_relationship_focused,
        is_context_focused,
        requires_search,
        complexity,
    }
}
/// Returns the trimmed contents of every quoted span in `query`, in order of appearance.
///
/// Both `"double"` and `'single'` quotes are recognised. A span is closed only by the
/// same character that opened it, so the other quote character may appear inside it.
/// A `'` that sits between letters/digits is treated as an apostrophe rather than a
/// quote: it neither opens a span (`John's`) nor closes a single-quoted one (`'O'Brien'`).
///
/// Spans that are empty or whitespace-only are skipped, and an unterminated span at the
/// end of the query is discarded.
pub fn extract_quoted_entities(query: &str) -> Vec<String> {
    let mut entities = Vec::new();
    let mut open_quote: Option<char> = None;
    let mut current_entity = String::new();
    let mut previous: Option<char> = None;
    let mut chars = query.chars().peekable();

    while let Some(ch) = chars.next() {
        match open_quote {
            None => {
                let opens_span = match ch {
                    '"' => true,
                    // A `'` directly after a letter or digit is an apostrophe, not a quote.
                    '\'' => !previous.map_or(false, char::is_alphanumeric),
                    _ => false,
                };
                if opens_span {
                    open_quote = Some(ch);
                    current_entity.clear();
                }
            }
            Some(quote) if ch == quote => {
                // Inside single quotes, a `'` followed by a letter or digit is an apostrophe.
                let is_apostrophe = quote == '\''
                    && chars.peek().map_or(false, |next| next.is_alphanumeric());
                if is_apostrophe {
                    current_entity.push(ch);
                } else {
                    let entity = current_entity.trim();
                    if !entity.is_empty() {
                        entities.push(entity.to_string());
                    }
                    current_entity.clear();
                    open_quote = None;
                }
            }
            Some(_) => current_entity.push(ch),
        }
        previous = Some(ch);
    }

    entities
}
/// Returns whitespace-separated words of `query` that look like proper nouns.
///
/// Surrounding non-alphanumeric characters are stripped from each word first, so that
/// punctuation does not influence the checks. A stripped word is kept when it
/// - is longer than two characters,
/// - starts with an uppercase letter, and
/// - is not written entirely in uppercase (which excludes acronyms such as "NASA").
///
/// Words are returned in their original order, without surrounding punctuation.
pub fn extract_capitalized_terms(query: &str) -> Vec<String> {
    query
        .split_whitespace()
        .map(|token| token.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|term| {
            let starts_upper = term.chars().next().map_or(false, char::is_uppercase);
            let all_upper = term.chars().all(char::is_uppercase);
            // Count characters rather than bytes so non-ASCII names are measured correctly.
            starts_upper && !all_upper && term.chars().count() > 2
        })
        .map(str::to_owned)
        .collect()
}
}
/// Result of the keyword-based classification performed by `QueryAnalyzer::analyze_query`.
///
/// The three `is_*` flags are independent of each other; a query may set several at once.
#[derive(Debug, Clone)]
pub struct QueryAnalysis {
/// The query contains "what is", "who is" or "tell me about".
pub is_entity_focused: bool,
/// The query contains a relationship cue such as "relation", "connect", "between",
/// "how are" or "associated".
pub is_relationship_focused: bool,
/// The query contains "context", "detail", "information", "describe" or "explain".
pub is_context_focused: bool,
/// False only when the query contains "list all" or "show all".
pub requires_search: bool,
/// `High` when relationship-focused, else `Medium` when context-focused, else `Low`.
pub complexity: QueryComplexity,
}
/// Coarse complexity rating assigned by `QueryAnalyzer::analyze_query`.
///
/// - `Low`: neither relationship nor context cues were found in the query.
/// - `Medium`: context cues were found, but no relationship cues.
/// - `High`: relationship cues were found.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryComplexity {
Low, Medium, High, }
#[cfg(test)]
mod tests {
    use super::*;

    /// A question containing "relationship" and "between" is relationship-focused and
    /// therefore rated as high complexity.
    #[test]
    fn test_query_analysis() {
        let query = "What is the relationship between John and Mary?";
        let QueryAnalysis {
            is_relationship_focused,
            complexity,
            ..
        } = QueryAnalyzer::analyze_query(query);

        assert!(is_relationship_focused);
        assert_eq!(complexity, QueryComplexity::High);
    }

    /// Both double-quoted and single-quoted names are extracted.
    #[test]
    fn test_extract_quoted_entities() {
        let query = "Tell me about \"John Smith\" and 'Mary Jones'";
        let entities = QueryAnalyzer::extract_quoted_entities(query);

        assert_eq!(entities.len(), 2);
        for expected in ["John Smith", "Mary Jones"] {
            assert!(entities.iter().any(|entity| entity.as_str() == expected));
        }
    }

    /// Capitalized words in a sentence are picked up as candidate terms.
    #[test]
    fn test_extract_capitalized_terms() {
        let query = "John Smith works at Acme Corp in New York";
        let terms = QueryAnalyzer::extract_capitalized_terms(query);

        for expected in ["John", "Smith", "Acme"] {
            assert!(terms.iter().any(|term| term.as_str() == expected));
        }
    }

    /// The system prompt mentions the assistant's role and the function-call mechanism.
    #[test]
    fn test_system_prompt_creation() {
        let prompt = ToolRegistry::create_system_prompt();

        for needle in ["GraphRAG assistant", "function calls"] {
            assert!(prompt.contains(needle));
        }
    }
}