use langchainrust::{
AgentExecutor, BaseAgent, AgentError, AgentOutput, AgentStep, AgentFinish, AgentAction,
CallbackManager, CallbackHandler, RunTree,
tools::{Calculator, SimpleMathTool},
BaseTool,
};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Test callback handler that records every callback it receives.
///
/// All state lives behind `Arc<Mutex<_>>` so the handler can be shared (via `Arc`)
/// with a `CallbackManager` while the test keeps its own handle to inspect what
/// was recorded.
#[derive(Default)]
pub struct ToolTrackingHandler {
    /// Chronological log of every callback, formatted as `"<event>:<run name>[:<detail>]"`.
    calls: Arc<Mutex<Vec<String>>>,
    /// Number of tool-start callbacks received.
    tool_start_count: Arc<Mutex<usize>>,
    /// Number of tool-end callbacks received.
    tool_end_count: Arc<Mutex<usize>>,
    /// Number of tool-error callbacks received.
    tool_error_count: Arc<Mutex<usize>>,
    /// `(tool name, raw input)` pairs, in the order tool-start callbacks arrived.
    tool_inputs: Arc<Mutex<Vec<(String, String)>>>,
    /// `(run name, output)` pairs, in the order tool-end callbacks arrived.
    tool_outputs: Arc<Mutex<Vec<(String, String)>>>,
    /// Timing entries keyed `"<name>_start"` / `"<name>_end"`, written by the tool callbacks.
    tool_durations: Arc<Mutex<Vec<(String, i64)>>>,
}

impl ToolTrackingHandler {
    /// Creates a handler with an empty call log and all counters at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a snapshot of the chronological callback log.
    pub fn get_calls(&self) -> Vec<String> {
        self.calls.lock().unwrap().clone()
    }

    /// Returns how many tool-start callbacks have been recorded.
    pub fn get_tool_start_count(&self) -> usize {
        *self.tool_start_count.lock().unwrap()
    }

    /// Returns how many tool-end callbacks have been recorded.
    pub fn get_tool_end_count(&self) -> usize {
        *self.tool_end_count.lock().unwrap()
    }

    /// Returns how many tool-error callbacks have been recorded.
    pub fn get_tool_error_count(&self) -> usize {
        *self.tool_error_count.lock().unwrap()
    }

    /// Returns a snapshot of the recorded `(tool name, input)` pairs.
    pub fn get_tool_inputs(&self) -> Vec<(String, String)> {
        self.tool_inputs.lock().unwrap().clone()
    }

    /// Returns a snapshot of the recorded `(run name, output)` pairs.
    pub fn get_tool_outputs(&self) -> Vec<(String, String)> {
        self.tool_outputs.lock().unwrap().clone()
    }

    /// Returns a snapshot of the timing entries written by the tool callbacks.
    pub fn get_tool_durations(&self) -> Vec<(String, i64)> {
        self.tool_durations.lock().unwrap().clone()
    }

    /// Checks that tool callbacks in the log are properly paired.
    ///
    /// Returns `true` when every `tool_start:` entry is closed by exactly one
    /// `tool_end:` or `tool_error:` entry before the next `tool_start:`, and no
    /// tool is left open at the end. All other entries are ignored. Nested tool
    /// calls (start, start, end, end) are therefore reported as out of order.
    pub fn verify_call_order(&self) -> bool {
        let mut tool_open = false;
        for call in self.get_calls() {
            if call.starts_with("tool_start:") {
                if tool_open {
                    return false;
                }
                tool_open = true;
            } else if call.starts_with("tool_end:") || call.starts_with("tool_error:") {
                if !tool_open {
                    return false;
                }
                tool_open = false;
            }
        }
        !tool_open
    }
}
/// Records every callback into the handler's shared log and counters.
///
/// Each method appends `"<event>:<run name>[:<detail>]"` to `calls`; the tool
/// callbacks additionally update counters, inputs/outputs and timing entries.
#[async_trait]
impl CallbackHandler for ToolTrackingHandler {
    async fn on_run_start(&self, run: &RunTree) {
        self.calls.lock().unwrap().push(format!("run_start:{}", run.name));
    }

    async fn on_run_end(&self, run: &RunTree) {
        self.calls.lock().unwrap().push(format!("run_end:{}", run.name));
    }

    async fn on_run_error(&self, run: &RunTree, error: &str) {
        self.calls.lock().unwrap().push(format!("run_error:{}:{}", run.name, error));
    }

    async fn on_chain_start(&self, run: &RunTree, _inputs: &serde_json::Value) {
        self.calls.lock().unwrap().push(format!("chain_start:{}", run.name));
    }

    async fn on_chain_end(&self, run: &RunTree, _outputs: &serde_json::Value) {
        self.calls.lock().unwrap().push(format!("chain_end:{}", run.name));
    }

    async fn on_chain_error(&self, run: &RunTree, error: &str) {
        self.calls.lock().unwrap().push(format!("chain_error:{}:{}", run.name, error));
    }

    async fn on_tool_start(&self, run: &RunTree, tool_name: &str, input: &str) {
        self.calls.lock().unwrap().push(format!("tool_start:{}:{}", run.name, tool_name));
        // Increment in a single statement so the guard is released immediately
        // rather than being held while the other mutexes are locked.
        *self.tool_start_count.lock().unwrap() += 1;
        self.tool_inputs.lock().unwrap().push((tool_name.to_string(), input.to_string()));
        // A freshly created `Instant` always reports ~0 elapsed time, so store a
        // wall-clock timestamp instead; a call's duration is `<name>_end - <name>_start`.
        self.tool_durations
            .lock()
            .unwrap()
            .push((format!("{}_start", tool_name), epoch_millis()));
    }

    async fn on_tool_end(&self, run: &RunTree, output: &str) {
        // NOTE(review): outputs are keyed by `run.name`; the tests assume this
        // equals the tool name (e.g. "math") — confirm against the executor.
        let tool_name = run.name.clone();
        self.calls.lock().unwrap().push(format!("tool_end:{}:{}", run.name, output));
        *self.tool_end_count.lock().unwrap() += 1;
        self.tool_outputs.lock().unwrap().push((tool_name.clone(), output.to_string()));
        self.tool_durations
            .lock()
            .unwrap()
            .push((format!("{}_end", tool_name), epoch_millis()));
    }

    async fn on_tool_error(&self, run: &RunTree, error: &str) {
        self.calls.lock().unwrap().push(format!("tool_error:{}:{}", run.name, error));
        *self.tool_error_count.lock().unwrap() += 1;
    }

    async fn on_llm_start(&self, run: &RunTree, _messages: &[langchainrust::schema::Message]) {
        self.calls.lock().unwrap().push(format!("llm_start:{}", run.name));
    }

    async fn on_llm_end(&self, run: &RunTree, _response: &str) {
        self.calls.lock().unwrap().push(format!("llm_end:{}", run.name));
    }
}

/// Milliseconds since the Unix epoch (saturating at `i64::MAX`), or 0 if the
/// system clock reports a time before the epoch.
fn epoch_millis() -> i64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX))
        .unwrap_or(0)
}
/// Scripted agent: requests one tool once, then finishes with a fixed answer
/// that embeds the tool's observation.
struct ToolCallingAgent {
    /// Name of the tool requested on the first planning step.
    tool_to_call: String,
    /// Raw input string handed to that tool.
    tool_input: String,
    /// Leading text of the final answer produced after the tool has run.
    final_answer: String,
}

impl ToolCallingAgent {
    /// Builds an agent that will call `tool_name` with `tool_input` and then
    /// answer with `final_answer`.
    fn new(tool_name: &str, tool_input: &str, final_answer: &str) -> Self {
        let tool_to_call = String::from(tool_name);
        let tool_input = String::from(tool_input);
        let final_answer = String::from(final_answer);
        Self {
            tool_to_call,
            tool_input,
            final_answer,
        }
    }
}
#[async_trait]
impl BaseAgent for ToolCallingAgent {
    /// Plans one step of the scripted run.
    ///
    /// With no intermediate steps yet, requests `tool_to_call` with `tool_input`.
    /// Otherwise finishes, embedding the first step's observation in the answer.
    /// The user inputs are not consulted; the script is fixed at construction.
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        match intermediate_steps.first() {
            None => Ok(AgentOutput::Action(AgentAction {
                tool: self.tool_to_call.clone(),
                tool_input: langchainrust::agents::ToolInput::String(self.tool_input.clone()),
                log: format!("我需要使用 {} 工具", self.tool_to_call),
            })),
            Some(first_step) => Ok(AgentOutput::Finish(AgentFinish::new(
                format!("{} (工具结果: {})", self.final_answer, first_step.observation),
                format!("使用了 {} 工具", self.tool_to_call),
            ))),
        }
    }
}
/// Scripted agent that requests each `(tool, input)` pair in order and then
/// finishes with a summary of all observations.
struct MultiToolAgent {
    /// Tool calls to issue, in order, as `(tool name, raw input)`.
    tool_calls: Vec<(String, String)>,
    /// Position of the next entry in `tool_calls`; behind a mutex because
    /// planning only receives `&self`.
    current_index: Arc<Mutex<usize>>,
}

impl MultiToolAgent {
    /// Creates an agent whose cursor starts at the first scripted call.
    fn new(tool_calls: Vec<(String, String)>) -> Self {
        let current_index = Arc::new(Mutex::new(0usize));
        Self {
            current_index,
            tool_calls,
        }
    }
}
#[async_trait]
impl BaseAgent for MultiToolAgent {
    /// Issues the next scripted tool call, or finishes once every call has been issued.
    ///
    /// A call with no intermediate steps is treated as the start of a new run, so
    /// the cursor is rewound to the first scripted call. Without this, a reused
    /// agent would finish immediately on its second run without calling any tool.
    /// The final answer lists `"<tool>: <observation>"` for every step, joined by `"; "`.
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        // Claim the next scripted call, holding the lock only while the cursor is updated.
        let next_call = {
            let mut index = self.current_index.lock().unwrap();
            if intermediate_steps.is_empty() {
                *index = 0;
            }
            let call = self.tool_calls.get(*index);
            if call.is_some() {
                *index += 1;
            }
            call
        };

        if let Some((tool_name, tool_input)) = next_call {
            return Ok(AgentOutput::Action(AgentAction {
                tool: tool_name.clone(),
                tool_input: langchainrust::agents::ToolInput::String(tool_input.clone()),
                log: format!("调用工具 {}", tool_name),
            }));
        }

        let observations = intermediate_steps
            .iter()
            .map(|s| format!("{}: {}", s.action.tool, s.observation))
            .collect::<Vec<_>>()
            .join("; ");
        Ok(AgentOutput::Finish(AgentFinish::new(
            format!("完成所有工具调用。结果: {}", observations),
            String::new(),
        )))
    }
}
/// Agent that first requests a tool named `nonexistent_tool` and then reports
/// the first step's observation as its final answer.
struct ErrorToolAgent;

#[async_trait]
impl BaseAgent for ErrorToolAgent {
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        match intermediate_steps.first() {
            Some(step) => Ok(AgentOutput::Finish(AgentFinish::new(
                format!("工具调用失败: {}", step.observation),
                String::new(),
            ))),
            None => Ok(AgentOutput::Action(AgentAction {
                tool: "nonexistent_tool".to_string(),
                tool_input: langchainrust::agents::ToolInput::String("test".to_string()),
                log: "尝试调用不存在的工具".to_string(),
            })),
        }
    }
}
/// A single calculator call must yield exactly one start/end pair and no error
/// callbacks, and the recorded input and output must match the call.
#[tokio::test]
async fn test_single_tool_callback_chain() {
    let handler = Arc::new(ToolTrackingHandler::new());
    let executor = {
        let callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator::new())];
        let agent = Arc::new(ToolCallingAgent::new(
            "calculator",
            "{\"expression\": \"10 + 20\"}",
            "计算完成",
        ));
        AgentExecutor::new(agent, tools).with_callbacks(callbacks)
    };

    let output = {
        let result = executor.invoke("计算 10 + 20".to_string()).await;
        assert!(result.is_ok());
        result.unwrap()
    };
    assert!(output.contains("30"));

    assert_eq!(handler.get_tool_start_count(), 1, "on_tool_start 应被调用 1 次");
    assert_eq!(handler.get_tool_end_count(), 1, "on_tool_end 应被调用 1 次");
    assert_eq!(handler.get_tool_error_count(), 0, "不应有错误回调");
    assert!(handler.verify_call_order(), "回调顺序应正确: start → end");

    let inputs = handler.get_tool_inputs();
    assert_eq!(inputs.len(), 1);
    let (input_tool, input_text) = &inputs[0];
    assert_eq!(input_tool, "calculator");
    assert!(input_text.contains("10 + 20"));

    let outputs = handler.get_tool_outputs();
    assert_eq!(outputs.len(), 1);
    let (_, output_text) = &outputs[0];
    assert!(output_text.contains("30"));
}
/// Two scripted tool calls must each produce a properly ordered start/end pair.
#[tokio::test]
async fn test_multiple_tools_callback_chain() {
    let handler = Arc::new(ToolTrackingHandler::new());
    let callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));
    let tools: Vec<Arc<dyn BaseTool>> = vec![
        Arc::new(Calculator::new()),
        Arc::new(SimpleMathTool::new()),
    ];

    let script = [
        ("calculator", "{\"expression\": \"5 + 5\"}"),
        ("math", "{\"operation\": \"sqrt\", \"value\": 100}"),
    ];
    let agent = Arc::new(MultiToolAgent::new(
        script
            .iter()
            .map(|&(name, input)| (name.to_string(), input.to_string()))
            .collect(),
    ));

    let executor = AgentExecutor::new(agent, tools).with_callbacks(callbacks);
    let result = executor.invoke("计算多个数学问题".to_string()).await;
    assert!(result.is_ok());

    assert_eq!(handler.get_tool_start_count(), 2, "应调用 2 个工具");
    assert_eq!(handler.get_tool_end_count(), 2);
    assert_eq!(handler.get_tool_error_count(), 0);
    assert!(handler.verify_call_order(), "多工具回调顺序应正确");

    let calls = handler.get_calls();
    let started = |tool: &str| {
        calls
            .iter()
            .any(|c| c.contains("tool_start") && c.contains(tool))
    };
    assert!(started("calculator"));
    assert!(started("math"));
}
/// Requesting a tool the executor does not know must fail the run and surface
/// chain-level error callbacks.
#[tokio::test]
async fn test_tool_error_callback() {
    let handler = Arc::new(ToolTrackingHandler::new());
    let callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));
    let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator::new())];
    let agent = Arc::new(ErrorToolAgent);
    let executor = AgentExecutor::new(agent, tools)
        .with_max_iterations(2)
        .with_callbacks(callbacks);

    let outcome = executor.invoke("测试错误".to_string()).await;
    assert!(outcome.is_err(), "调用不存在的工具应返回错误");

    let calls = handler.get_calls();
    let has_prefix = |prefix: &str| calls.iter().any(|c| c.starts_with(prefix));
    assert!(has_prefix("chain_start"), "应有 chain_start");
    assert!(has_prefix("chain_error"), "应有 chain_error");
    assert!(calls.iter().any(|c| c.contains("error")), "应有错误处理回调");
}
/// Feeding the calculator an input that is not JSON must still produce a tool
/// start callback followed by an end or error callback.
#[tokio::test]
async fn test_tool_invalid_input_callback() {
    let handler = Arc::new(ToolTrackingHandler::new());
    let callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));
    let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator::new())];
    let agent = Arc::new(ToolCallingAgent::new(
        "calculator",
        "invalid json input",
        "计算完成",
    ));
    let executor = AgentExecutor::new(agent, tools).with_callbacks(callbacks);

    // Whether the run itself succeeds depends on how the executor surfaces tool
    // failures; this test only checks the callbacks, so the result is discarded
    // explicitly instead of being bound to an unused variable.
    let _ = executor.invoke("测试无效输入".to_string()).await;

    assert!(handler.get_tool_start_count() >= 1);
    let total_callbacks = handler.get_tool_start_count()
        + handler.get_tool_end_count()
        + handler.get_tool_error_count();
    assert!(total_callbacks >= 2, "至少应有 start + end/error");
}
/// A successful run must emit chain and tool events, properly nested:
/// chain_start < tool_start < tool_end < chain_end (first occurrence of each).
#[tokio::test]
async fn test_full_execution_trace() {
    let handler = Arc::new(ToolTrackingHandler::new());
    let callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));
    let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator::new())];
    let agent = Arc::new(ToolCallingAgent::new(
        "calculator",
        "{\"expression\": \"2 * 3\"}",
        "乘法计算完成",
    ));
    let executor = AgentExecutor::new(agent, tools).with_callbacks(callbacks);
    executor.invoke("计算 2 * 3".to_string()).await.unwrap();

    let calls = handler.get_calls();
    let first_index = |prefix: &str| calls.iter().position(|c| c.starts_with(prefix));

    for &(prefix, message) in &[
        ("chain_start", "应有 chain_start"),
        ("chain_end", "应有 chain_end"),
        ("tool_start", "应有 tool_start"),
        ("tool_end", "应有 tool_end"),
    ] {
        assert!(first_index(prefix).is_some(), "{}", message);
    }

    let chain_start_idx = first_index("chain_start").unwrap();
    let tool_start_idx = first_index("tool_start").unwrap();
    let tool_end_idx = first_index("tool_end").unwrap();
    let chain_end_idx = first_index("chain_end").unwrap();

    assert!(chain_start_idx < tool_start_idx, "chain_start 应在 tool_start 之前");
    assert!(tool_start_idx < tool_end_idx, "tool_start 应在 tool_end 之前");
    assert!(tool_end_idx < chain_end_idx, "tool_end 应在 chain_end 之前");
}
/// Every handler registered on the manager must receive the tool events.
#[tokio::test]
async fn test_multiple_handlers_receive_tool_events() {
    let handler1 = Arc::new(ToolTrackingHandler::new());
    let handler2 = Arc::new(ToolTrackingHandler::new());
    let manager = CallbackManager::new()
        .add_handler(handler1.clone())
        .add_handler(handler2.clone());
    let callbacks = Arc::new(manager);

    let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator::new())];
    let agent = Arc::new(ToolCallingAgent::new(
        "calculator",
        "{\"expression\": \"1 + 1\"}",
        "完成",
    ));
    let executor = AgentExecutor::new(agent, tools).with_callbacks(callbacks);
    executor.invoke("测试".to_string()).await.unwrap();

    let labelled = [("handler1", &handler1), ("handler2", &handler2)];
    for &(label, handler) in &labelled {
        assert_eq!(handler.get_tool_start_count(), 1, "{} 应收到 tool_start", label);
    }
    for &(label, handler) in &labelled {
        assert_eq!(handler.get_tool_end_count(), 1, "{} 应收到 tool_end", label);
    }
}
/// The handler must record the exact tool name, input and output of a call.
#[tokio::test]
async fn test_tool_input_output_integrity() {
    let handler = Arc::new(ToolTrackingHandler::new());
    let callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));
    let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(SimpleMathTool::new())];
    let agent = Arc::new(ToolCallingAgent::new(
        "math",
        "{\"operation\": \"factorial\", \"value\": 5}",
        "阶乘计算完成",
    ));
    let executor = AgentExecutor::new(agent, tools).with_callbacks(callbacks);
    executor.invoke("计算 5 的阶乘".to_string()).await.unwrap();

    let inputs = handler.get_tool_inputs();
    assert_eq!(inputs.len(), 1);
    let (input_tool, raw_input) = &inputs[0];
    assert_eq!(input_tool, "math");
    assert!(raw_input.contains("factorial"));
    assert!(raw_input.contains("5"));

    let outputs = handler.get_tool_outputs();
    assert_eq!(outputs.len(), 1);
    let (output_tool, raw_output) = &outputs[0];
    assert_eq!(output_tool, "math");
    assert!(raw_output.contains("120"));
}
/// Running a tool directly (outside an `AgentExecutor`) must not reach a
/// handler, even though a `CallbackManager` holding that handler exists.
#[tokio::test]
async fn test_direct_tool_run_no_callback() {
    let handler = Arc::new(ToolTrackingHandler::new());
    // Built but intentionally never attached to the tool. The leading underscore
    // silences the unused-variable warning while, unlike `_`, keeping the manager
    // (and its handler reference) alive for the whole test.
    let _callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));

    let calc = Calculator::new();
    let result = calc
        .run("{\"expression\": \"3 + 4\"}".to_string())
        .await
        .unwrap();
    assert!(result.contains("7"));

    assert_eq!(handler.get_tool_start_count(), 0, "直接调用不应触发回调");
    assert_eq!(handler.get_tool_end_count(), 0);
    assert_eq!(handler.get_tool_error_count(), 0);
}
/// The log must mention the executor (chain events) and the calculator (tool
/// events) at least twice each, i.e. both a start and an end.
#[tokio::test]
async fn test_tool_run_tree_hierarchy() {
    let handler = Arc::new(ToolTrackingHandler::new());
    let callbacks = Arc::new(CallbackManager::new().add_handler(handler.clone()));
    let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator::new())];
    let agent = Arc::new(ToolCallingAgent::new(
        "calculator",
        "{\"expression\": \"100 / 25\"}",
        "除法计算完成",
    ));
    let executor = AgentExecutor::new(agent, tools).with_callbacks(callbacks);
    executor.invoke("计算 100 / 25".to_string()).await.unwrap();

    let calls = handler.get_calls();
    let mentions = |needle: &str| calls.iter().filter(|c| c.contains(needle)).count();
    let chain_calls = mentions("AgentExecutor");
    let tool_calls = mentions("calculator");
    assert!(chain_calls >= 2, "应有 chain start/end");
    assert!(tool_calls >= 2, "应有 tool start/end");
}