mod tool_callbacks_integration;
use langchainrust::{AgentAction, AgentFinish, AgentStep, AgentOutput, BaseAgent, AgentExecutor, AgentError, Calculator, ToolRegistry, Tool};
use langchainrust::tools::CalculatorInput;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
/// Scripted agent used by the executor tests.
///
/// Successive `plan` calls hand back the entries of `actions` in order; the
/// cursor lives in an atomic because `plan` only receives `&self`.
struct MockAgent {
    /// Outputs returned by successive `plan` calls, first to last.
    actions: Vec<AgentOutput>,
    /// Index of the next scripted output to return.
    current: std::sync::atomic::AtomicUsize,
}

impl MockAgent {
    /// Builds a mock that replays `actions` starting from the first entry.
    fn new(actions: Vec<AgentOutput>) -> Self {
        // `AtomicUsize::default()` starts the cursor at zero.
        let current = std::sync::atomic::AtomicUsize::default();
        Self { actions, current }
    }
}
#[async_trait]
impl BaseAgent for MockAgent {
    /// Returns the next scripted output, ignoring the steps and inputs.
    ///
    /// Every call advances the cursor, even after the script is exhausted; past
    /// the end it keeps answering with a default `Finish` output.
    async fn plan(
        &self,
        _intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        use std::sync::atomic::Ordering;

        let position = self.current.fetch_add(1, Ordering::SeqCst);
        let output = self.actions.get(position).cloned().unwrap_or_else(|| {
            AgentOutput::Finish(AgentFinish::new("默认答案".to_string(), String::new()))
        });
        Ok(output)
    }
}
/// Round-trips an `AgentAction` through JSON and checks that its fields survive.
#[test]
fn test_agent_action_serialization() {
    let action = AgentAction {
        tool: "calculator".to_string(),
        tool_input: langchainrust::ToolInput::String("2 + 2".to_string()),
        log: "Thought: 计算加法\nAction: calculator\nAction Input: 2 + 2".to_string(),
    };

    let json = serde_json::to_string(&action).unwrap();
    // The serialized text must carry both the tool name and its raw input.
    assert!(json.contains("calculator"));
    assert!(json.contains("2 + 2"));

    let deserialized: AgentAction = serde_json::from_str(&json).unwrap();
    assert_eq!(deserialized.tool, "calculator");
    // The log contains a newline and non-ASCII text; both must survive
    // escaping on the way out and unescaping on the way back in.
    assert_eq!(deserialized.log, action.log);
}
/// `AgentFinish::new` exposes its first argument as the `output` return value.
#[test]
fn test_agent_finish() {
    let answer = "答案是 42";
    let finish = AgentFinish::new(answer.to_string(), "Final Answer: 答案是 42".to_string());

    assert!(finish.return_values.contains_key("output"));
    assert_eq!(finish.output(), Some(answer));
}
/// `AgentStep::new` pairs an action with the observation it produced.
#[test]
fn test_agent_step() {
    let action = AgentAction {
        tool: "test".to_string(),
        tool_input: langchainrust::ToolInput::String("input".to_string()),
        log: "log".to_string(),
    };
    // `action` is not used again, so it is moved into the step rather than cloned.
    let step = AgentStep::new(action, "observation".to_string());

    assert_eq!(step.observation, "observation");
    assert_eq!(step.action.tool, "test");
}
/// An agent that finishes on its first plan call has its answer returned as-is.
#[tokio::test]
async fn test_agent_executor_direct_answer() {
    let script = vec![AgentOutput::Finish(AgentFinish::new(
        "直接答案".to_string(),
        String::new(),
    ))];
    let agent = Arc::new(MockAgent::new(script));
    // No tools are registered; none should be needed.
    let executor = AgentExecutor::new(agent, vec![]);

    let answer = executor
        .invoke("测试问题".to_string())
        .await
        .unwrap();
    assert_eq!(answer, "直接答案");
}
/// One calculator action followed by a finish: the executor handles the tool
/// call and then returns the agent's final answer.
#[tokio::test]
async fn test_agent_executor_with_tool() {
    let call_calculator = AgentOutput::Action(AgentAction {
        tool: "calculator".to_string(),
        tool_input: langchainrust::ToolInput::String(r#"{"expression": "2 + 3"}"#.to_string()),
        log: String::new(),
    });
    let finish = AgentOutput::Finish(AgentFinish::new("计算完成".to_string(), String::new()));
    let agent = Arc::new(MockAgent::new(vec![call_calculator, finish]));

    let calculator = Arc::new(Calculator::new());
    // Build the tool list inline so the element can coerce to the executor's tool type.
    let executor = AgentExecutor::new(agent, vec![calculator]);

    let answer = executor.invoke("计算 2 + 3".to_string()).await.unwrap();
    assert_eq!(answer, "计算完成");
}
/// Requesting a tool that was never registered must fail with
/// `AgentError::ToolNotFound` carrying the requested tool name.
#[tokio::test]
async fn test_agent_executor_tool_not_found() {
    let agent = Arc::new(MockAgent::new(vec![AgentOutput::Action(AgentAction {
        tool: "nonexistent".to_string(),
        tool_input: langchainrust::ToolInput::String("input".to_string()),
        log: String::new(),
    })]));
    // No tools are registered, so looking up "nonexistent" cannot succeed.
    let executor = AgentExecutor::new(agent, vec![]);

    let result = executor.invoke("测试".to_string()).await;
    // Match on the whole `Result` so an unexpected `Ok` or a different error
    // variant is reported with its actual value instead of a generic panic.
    match result {
        Err(AgentError::ToolNotFound(name)) => assert_eq!(name, "nonexistent"),
        other => panic!("期望 ToolNotFound 错误, got {:?}", other),
    }
}
/// An agent that never finishes is stopped by the iteration cap, and the
/// returned text mentions the iteration limit.
#[tokio::test]
async fn test_agent_executor_max_iterations() {
    // The same calculator action scripted 20 times, well beyond the cap of 3.
    let looping_step = AgentOutput::Action(AgentAction {
        tool: "calculator".to_string(),
        tool_input: langchainrust::ToolInput::String(r#"{"expression": "1"}"#.to_string()),
        log: String::new(),
    });
    let agent = Arc::new(MockAgent::new(vec![looping_step; 20]));

    let calculator = Arc::new(Calculator::new());
    let executor = AgentExecutor::new(agent, vec![calculator]).with_max_iterations(3);

    let outcome = executor.invoke("测试".to_string()).await.unwrap();
    assert!(outcome.contains("iteration limit"));
}
/// With verbose mode enabled, a directly-finishing agent still yields its answer unchanged.
#[tokio::test]
async fn test_agent_executor_verbose() {
    let finish = AgentOutput::Finish(AgentFinish::new("答案".to_string(), String::new()));
    let agent = Arc::new(MockAgent::new(vec![finish]));
    let executor = AgentExecutor::new(agent, vec![]).with_verbose(true);

    let answer = executor.invoke("测试".to_string()).await.unwrap();
    assert_eq!(answer, "答案");
}
/// `ToolInput` displays a string payload verbatim, and an object payload so
/// that its keys are visible in the rendered text.
#[test]
fn test_tool_input_display() {
    let plain = langchainrust::ToolInput::String("test".to_string());
    assert_eq!(plain.to_string(), "test");

    let object = langchainrust::ToolInput::Object(serde_json::json!({"key": "value"}));
    let rendered = object.to_string();
    assert!(rendered.contains("key"));
}
/// Each `AgentError` variant's `Display` output contains its expected message fragment.
#[test]
fn test_agent_error_display() {
    // (error, fragment its Display text must contain)
    let cases = [
        (AgentError::ToolNotFound("test".to_string()), "工具未找到"),
        (AgentError::MaxIterationsReached, "最大迭代次数"),
        (AgentError::OutputParsingError("parse error".to_string()), "解析错误"),
    ];
    for (error, fragment) in &cases {
        assert!(error.to_string().contains(*fragment));
    }
}