use crate::{core::KnowledgeGraph, GraphRAGError, Result};
use json::JsonValue;
use std::collections::HashMap;
pub mod agent;
pub mod enhanced_registry;
pub mod functions;
pub mod tools;
/// Describes a callable function: its name, a human-readable description,
/// and the arguments it accepts.
#[derive(Clone, Debug)]
pub struct FunctionDefinition {
    /// Function name; `FunctionCaller::register_function` uses it as the
    /// registration key.
    pub name: String,
    /// Human-readable description of what the function does.
    pub description: String,
    /// Parameter description. The in-file test mock supplies a
    /// JSON-Schema-style object (`type` / `properties` / `required`) —
    /// NOTE(review): confirm all implementations follow that shape.
    pub parameters: JsonValue,
    /// Flag whose meaning is left to consumers of the definition; it is not
    /// read in this file. NOTE(review): document the intended semantics.
    pub required: bool,
}
/// Outcome of a single function invocation.
#[derive(Clone, Debug)]
pub struct FunctionResult {
    /// Name of the function that was invoked.
    pub function_name: String,
    /// Arguments the function was invoked with.
    pub arguments: JsonValue,
    /// Value produced by the function; `JsonValue::Null` when the call failed.
    pub result: JsonValue,
    /// `true` when the call completed without error.
    pub success: bool,
    /// Error message when `success` is `false`; `None` otherwise.
    pub error: Option<String>,
    /// Elapsed time of the call, in milliseconds.
    pub execution_time_ms: u64,
}
/// A request to invoke the function registered as `name` with `arguments`.
#[derive(Clone, Debug)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments to pass to the function.
    pub arguments: JsonValue,
}
/// A function that `FunctionCaller` can register and invoke by name.
///
/// Implementors must be `Send + Sync`.
pub trait CallableFunction: Send + Sync {
    /// Returns this function's definition; its `name` is used as the
    /// registration key.
    fn definition(&self) -> FunctionDefinition;

    /// Checks `arguments` before execution. `FunctionCaller` only invokes
    /// `call` after this returns `Ok(())`.
    fn validate_arguments(&self, arguments: &JsonValue) -> Result<()>;

    /// Executes the function with `arguments`. `context` exposes the knowledge
    /// graph, the current query, and previously collected results.
    fn call(&self, arguments: JsonValue, context: &FunctionContext) -> Result<JsonValue>;
}
/// Inputs handed to a [`CallableFunction`] when it is invoked.
#[derive(Debug)]
pub struct FunctionContext<'a> {
    /// Knowledge graph the function may consult.
    pub knowledge_graph: &'a KnowledgeGraph,
    /// The query currently being answered.
    pub query: &'a str,
    /// Results supplied by the caller from earlier invocations (presumably
    /// for the same query — NOTE(review): confirm with callers).
    pub previous_results: &'a [FunctionResult],
    /// Extra key/value data; `FunctionCaller` does not interpret it.
    pub metadata: HashMap<String, JsonValue>,
}
/// Registry of [`CallableFunction`]s that executes calls by name and keeps a
/// history of their outcomes.
pub struct FunctionCaller {
    /// Registered functions, keyed by their definition's `name`.
    functions: HashMap<String, Box<dyn CallableFunction>>,
    /// Largest batch `call_functions` accepts in one request.
    max_calls_per_query: usize,
    /// Recorded call results, oldest first.
    call_history: Vec<FunctionResult>,
}
impl FunctionCaller {
pub fn new() -> Self {
Self {
functions: HashMap::new(),
max_calls_per_query: 10,
call_history: Vec::new(),
}
}
pub fn register_function(&mut self, function: Box<dyn CallableFunction>) {
let name = function.definition().name.clone();
self.functions.insert(name, function);
}
pub fn get_function_definitions(&self) -> Vec<FunctionDefinition> {
self.functions.values().map(|f| f.definition()).collect()
}
pub fn call_function(
&mut self,
function_call: FunctionCall,
context: &FunctionContext,
) -> Result<FunctionResult> {
let start_time = std::time::Instant::now();
let function =
self.functions
.get(&function_call.name)
.ok_or_else(|| GraphRAGError::Generation {
message: format!("Function '{}' not found", function_call.name),
})?;
if let Err(e) = function.validate_arguments(&function_call.arguments) {
return Ok(FunctionResult {
function_name: function_call.name,
arguments: function_call.arguments,
result: JsonValue::Null,
success: false,
error: Some(e.to_string()),
execution_time_ms: start_time.elapsed().as_millis() as u64,
});
}
let result = match function.call(function_call.arguments.clone(), context) {
Ok(result) => FunctionResult {
function_name: function_call.name.clone(),
arguments: function_call.arguments,
result,
success: true,
error: None,
execution_time_ms: start_time.elapsed().as_millis() as u64,
},
Err(e) => FunctionResult {
function_name: function_call.name,
arguments: function_call.arguments,
result: JsonValue::Null,
success: false,
error: Some(e.to_string()),
execution_time_ms: start_time.elapsed().as_millis() as u64,
},
};
self.call_history.push(result.clone());
Ok(result)
}
pub fn call_functions(
&mut self,
function_calls: Vec<FunctionCall>,
context: &FunctionContext,
) -> Result<Vec<FunctionResult>> {
if function_calls.len() > self.max_calls_per_query {
return Err(GraphRAGError::Generation {
message: format!(
"Too many function calls requested: {} (max: {})",
function_calls.len(),
self.max_calls_per_query
),
});
}
let mut results = Vec::new();
for call in function_calls {
let result = self.call_function(call, context)?;
results.push(result);
}
Ok(results)
}
pub fn get_call_history(&self) -> &[FunctionResult] {
&self.call_history
}
pub fn clear_history(&mut self) {
self.call_history.clear();
}
pub fn get_statistics(&self) -> FunctionCallStatistics {
let total_calls = self.call_history.len();
let successful_calls = self.call_history.iter().filter(|r| r.success).count();
let failed_calls = total_calls - successful_calls;
let total_execution_time: u64 = self.call_history.iter().map(|r| r.execution_time_ms).sum();
let average_execution_time = if total_calls > 0 {
total_execution_time / total_calls as u64
} else {
0
};
let mut function_usage = HashMap::new();
for result in &self.call_history {
*function_usage
.entry(result.function_name.clone())
.or_insert(0) += 1;
}
FunctionCallStatistics {
total_calls,
successful_calls,
failed_calls,
total_execution_time_ms: total_execution_time,
average_execution_time_ms: average_execution_time,
function_usage,
}
}
pub fn set_max_calls_per_query(&mut self, max_calls: usize) {
self.max_calls_per_query = max_calls;
}
}
impl Default for FunctionCaller {
fn default() -> Self {
Self::new()
}
}
/// Aggregate view of a `FunctionCaller`'s call history.
#[derive(Clone, Debug)]
pub struct FunctionCallStatistics {
    /// Number of calls in the history.
    pub total_calls: usize,
    /// Calls whose result reported `success == true`.
    pub successful_calls: usize,
    /// Calls that did not succeed (`total_calls - successful_calls`).
    pub failed_calls: usize,
    /// Sum of `execution_time_ms` over all calls.
    pub total_execution_time_ms: u64,
    /// Truncated integer mean execution time; 0 when there are no calls.
    pub average_execution_time_ms: u64,
    /// Number of calls per function name.
    pub function_usage: HashMap<String, usize>,
}
impl FunctionCallStatistics {
    /// Logs these statistics through `tracing`.
    ///
    /// A summary, including the success rate as a percentage, is emitted at
    /// `INFO` level. Per-function call counts follow at `DEBUG` level, most-used
    /// first, with ties broken by function name so the output order is
    /// deterministic.
    pub fn print(&self) {
        let success_rate = if self.total_calls > 0 {
            (self.successful_calls as f64 / self.total_calls as f64) * 100.0
        } else {
            0.0
        };
        tracing::info!(
            total_calls = self.total_calls,
            successful_calls = self.successful_calls,
            failed_calls = self.failed_calls,
            success_rate = format!("{:.1}%", success_rate),
            total_execution_time_ms = self.total_execution_time_ms,
            avg_execution_time_ms = self.average_execution_time_ms,
            "Function call statistics"
        );

        // `function_usage` is a HashMap with unspecified iteration order; sort
        // by descending count, then by name, so equal counts always log in the
        // same order.
        let mut usage: Vec<(&String, &usize)> = self.function_usage.iter().collect();
        usage.sort_unstable_by(|(name_a, count_a), (name_b, count_b)| {
            count_b.cmp(count_a).then_with(|| name_a.cmp(name_b))
        });
        for (function, count) in usage {
            tracing::debug!(function = %function, call_count = count, "Function usage");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeGraph;

    /// Test double that echoes its arguments and requires a `test_param` argument.
    struct MockFunction {
        name: String,
    }

    impl CallableFunction for MockFunction {
        fn call(&self, arguments: JsonValue, _context: &FunctionContext) -> Result<JsonValue> {
            Ok(json::object! {
                "function": self.name.clone(),
                "arguments": arguments,
                "result": "mock_result"
            })
        }

        fn definition(&self) -> FunctionDefinition {
            FunctionDefinition {
                name: self.name.clone(),
                description: "Mock function for testing".to_string(),
                parameters: json::object! {
                    "type": "object",
                    "properties": {
                        "test_param": {
                            "type": "string",
                            "description": "Test parameter"
                        }
                    },
                    "required": ["test_param"]
                },
                required: false,
            }
        }

        fn validate_arguments(&self, arguments: &JsonValue) -> Result<()> {
            if arguments["test_param"].is_null() {
                return Err(GraphRAGError::Generation {
                    message: "test_param is required".to_string(),
                });
            }
            Ok(())
        }
    }

    /// Builds a caller with a single `MockFunction` registered as `name`.
    fn caller_with_mock(name: &str) -> FunctionCaller {
        let mut caller = FunctionCaller::new();
        caller.register_function(Box::new(MockFunction {
            name: name.to_string(),
        }));
        caller
    }

    /// Builds a context over `graph` with no previous results and no metadata.
    fn empty_context(graph: &KnowledgeGraph) -> FunctionContext<'_> {
        FunctionContext {
            knowledge_graph: graph,
            query: "test query",
            previous_results: &[],
            metadata: HashMap::new(),
        }
    }

    /// A call to `name` whose arguments pass `MockFunction` validation.
    fn valid_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            arguments: json::object! {
                "test_param": "test_value"
            },
        }
    }

    #[test]
    fn test_function_registration() {
        let caller = caller_with_mock("test_function");
        assert_eq!(caller.get_function_definitions().len(), 1);
    }

    #[test]
    fn test_function_call() {
        let mut caller = caller_with_mock("test_function");
        let graph = KnowledgeGraph::new();
        let context = empty_context(&graph);
        let result = caller
            .call_function(valid_call("test_function"), &context)
            .unwrap();
        assert!(result.success);
        assert_eq!(result.function_name, "test_function");
        assert!(result.error.is_none());
        assert_eq!(caller.get_call_history().len(), 1);
    }

    #[test]
    fn test_unknown_function_is_an_error() {
        let mut caller = caller_with_mock("test_function");
        let graph = KnowledgeGraph::new();
        let context = empty_context(&graph);
        assert!(caller.call_function(valid_call("missing"), &context).is_err());
        assert!(caller.get_call_history().is_empty());
    }

    #[test]
    fn test_invalid_arguments_yield_failed_result() {
        let mut caller = caller_with_mock("test_function");
        let graph = KnowledgeGraph::new();
        let context = empty_context(&graph);
        let call = FunctionCall {
            name: "test_function".to_string(),
            arguments: JsonValue::new_object(),
        };
        let result = caller.call_function(call, &context).unwrap();
        assert!(!result.success);
        assert!(result.result.is_null());
        assert!(result.error.is_some());
    }

    #[test]
    fn test_too_many_calls_are_rejected() {
        let mut caller = caller_with_mock("test_function");
        caller.set_max_calls_per_query(1);
        let graph = KnowledgeGraph::new();
        let context = empty_context(&graph);
        let calls = vec![valid_call("test_function"), valid_call("test_function")];
        assert!(caller.call_functions(calls, &context).is_err());
        assert!(caller.get_call_history().is_empty());
    }

    #[test]
    fn test_statistics_count_calls_per_function() {
        let mut caller = caller_with_mock("test_function");
        let graph = KnowledgeGraph::new();
        let context = empty_context(&graph);
        let calls = vec![valid_call("test_function"), valid_call("test_function")];
        let results = caller.call_functions(calls, &context).unwrap();
        assert_eq!(results.len(), 2);

        let stats = caller.get_statistics();
        assert_eq!(stats.total_calls, 2);
        assert_eq!(stats.successful_calls, 2);
        assert_eq!(stats.failed_calls, 0);
        assert_eq!(stats.function_usage.get("test_function"), Some(&2));

        caller.clear_history();
        assert_eq!(caller.get_statistics().total_calls, 0);
    }
}