use langchainrust::{
AgentAction, AgentOutput, AgentStep, ToolInput, AgentFinish,
Calculator, DateTimeTool, SimpleMathTool, BaseTool,
AgentExecutor, BaseAgent, AgentError,
};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
/// Test agent that asks for two tool calls in a single planning step.
///
/// On the first call (no intermediate steps yet) it returns
/// `AgentOutput::Actions` holding a `datetime` call followed by a
/// `calculator` call. On any later call it finishes with `"结果: "`
/// followed by every recorded observation joined with `", "`.
struct MultiActionAgent;

#[async_trait]
impl BaseAgent for MultiActionAgent {
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        // First round: nothing has run yet, so request both tools at once.
        if intermediate_steps.is_empty() {
            let datetime_call = AgentAction {
                tool: "datetime".to_string(),
                tool_input: ToolInput::String("now".to_string()),
                log: "call_datetime_1".to_string(),
            };
            let calculator_call = AgentAction {
                tool: "calculator".to_string(),
                tool_input: ToolInput::Object(serde_json::json!({"expression": "100 + 200"})),
                log: "call_calc_1".to_string(),
            };
            return Ok(AgentOutput::Actions(vec![datetime_call, calculator_call]));
        }

        // Later rounds: report every observation in the order it was recorded.
        let joined = intermediate_steps
            .iter()
            .map(|step| step.observation.as_str())
            .collect::<Vec<&str>>()
            .join(", ");
        Ok(AgentOutput::Finish(AgentFinish::new(
            format!("结果: {joined}"),
            String::new(),
        )))
    }

    /// Lists the tool names this agent plans to call.
    fn get_allowed_tools(&self) -> Option<Vec<&str>> {
        let names = ["datetime", "calculator"];
        Some(names.to_vec())
    }
}
/// Test agent that makes exactly one `calculator` call.
///
/// With no intermediate steps it requests `calculator` with the expression
/// `"25 * 4"`. Once a step exists it finishes, using the first step's
/// observation verbatim as the final output.
struct SingleActionAgent;

#[async_trait]
impl BaseAgent for SingleActionAgent {
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        match intermediate_steps.first() {
            // The tool already ran: echo its observation as the answer.
            Some(first_step) => {
                let answer = first_step.observation.clone();
                Ok(AgentOutput::Finish(AgentFinish::new(answer, String::new())))
            }
            // Nothing has run yet: ask for the calculator.
            None => {
                let call = AgentAction {
                    tool: "calculator".to_string(),
                    tool_input: ToolInput::Object(serde_json::json!({"expression": "25 * 4"})),
                    log: "call_calc_1".to_string(),
                };
                Ok(AgentOutput::Action(call))
            }
        }
    }

    /// Lists the single tool name this agent plans to call.
    fn get_allowed_tools(&self) -> Option<Vec<&str>> {
        let names = ["calculator"];
        Some(names.to_vec())
    }
}
/// Test agent that never calls a tool.
///
/// Every call to `plan` immediately returns a fixed final answer,
/// regardless of the intermediate steps or inputs it receives.
struct DirectFinishAgent;

#[async_trait]
impl BaseAgent for DirectFinishAgent {
    async fn plan(
        &self,
        _intermediate_steps: &[AgentStep],
        _inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        let answer = String::from("这是直接回答,不需要工具");
        let finish = AgentFinish::new(answer, String::new());
        Ok(AgentOutput::Finish(finish))
    }
}
/// Tests for the agent output types and for `AgentExecutor` driven by the
/// stub agents defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// `AgentOutput::Actions` counts as an action (not a finish), exposes all
    /// of its actions in order via `actions()`, and yields no single `action()`.
    #[test]
    fn test_agent_output_actions_variant() {
        let actions = vec![
            AgentAction {
                tool: "calculator".to_string(),
                tool_input: ToolInput::String("1 + 1".to_string()),
                log: "call_1".to_string(),
            },
            AgentAction {
                tool: "datetime".to_string(),
                tool_input: ToolInput::String("now".to_string()),
                log: "call_2".to_string(),
            },
        ];
        let output = AgentOutput::Actions(actions);
        assert!(output.is_action());
        assert!(!output.is_finish());
        assert_eq!(output.actions().len(), 2);
        assert!(output.action().is_none());
        let action_refs = output.actions();
        assert_eq!(action_refs[0].tool, "calculator");
        assert_eq!(action_refs[1].tool, "datetime");
    }

    /// Both `Action` and a one-element `Actions` are actions, and `actions()`
    /// works for both. Only the `Action` variant returns `Some` from `action()`.
    #[test]
    fn test_agent_output_single_vs_multiple_actions() {
        let single = AgentOutput::Action(AgentAction {
            tool: "calculator".to_string(),
            tool_input: ToolInput::String("test".to_string()),
            log: "log".to_string(),
        });
        let multiple = AgentOutput::Actions(vec![
            AgentAction {
                tool: "calculator".to_string(),
                tool_input: ToolInput::String("test".to_string()),
                log: "log".to_string(),
            },
        ]);
        assert!(single.is_action());
        assert!(multiple.is_action());
        assert!(single.action().is_some());
        assert!(multiple.action().is_none());
        assert_eq!(single.actions().len(), 1);
        assert_eq!(multiple.actions().len(), 1);
    }

    /// Runs `SingleActionAgent` against a real `Calculator`. The final answer
    /// is the calculator's observation for "25 * 4", so it must contain 100.
    #[tokio::test]
    async fn test_executor_single_action_with_real_calculator() {
        let tools: Vec<Arc<dyn BaseTool>> = vec![
            Arc::new(Calculator::new()),
        ];
        let executor = AgentExecutor::new(
            Arc::new(SingleActionAgent),
            tools,
        ).with_max_iterations(2);
        let result = executor.invoke("计算 25 * 4".to_string()).await.unwrap();
        assert!(result.contains("100"), "Calculator 应返回计算结果 100");
    }

    /// Runs `MultiActionAgent`, which plans a datetime call and a calculator
    /// call in one step. The result must reflect both observations.
    #[tokio::test]
    async fn test_executor_parallel_actions_with_real_tools() {
        let tools: Vec<Arc<dyn BaseTool>> = vec![
            Arc::new(DateTimeTool::new()),
            Arc::new(SimpleMathTool::new()),
        ];
        let executor = AgentExecutor::new(
            Arc::new(MultiActionAgent),
            tools,
        ).with_max_iterations(2);
        let result = executor.invoke("现在几点?顺便算一下 100 + 200".to_string()).await.unwrap();
        assert!(result.contains("300"), "应包含 SimpleMathTool 计算结果 300");

        // MultiActionAgent finishes with "结果: " followed by every observation
        // joined with ", ". A length check cannot prove the datetime tool ran:
        // the prefix plus "300" already exceeds any small threshold. Instead,
        // require at least two joined observations.
        // NOTE(review): this assumes the calculator observation alone does
        // not itself contain ", ".
        let observations = result
            .strip_prefix("结果: ")
            .expect("最终输出应以 \"结果: \" 开头");
        assert!(
            observations.split(", ").count() >= 2,
            "应包含 DateTimeTool 时间输出"
        );
    }

    /// `DirectFinishAgent` finishes on its first plan. The executor therefore
    /// returns that output verbatim, even with no tools registered.
    #[tokio::test]
    async fn test_executor_direct_finish_without_tools() {
        let tools: Vec<Arc<dyn BaseTool>> = vec![];
        let executor = AgentExecutor::new(
            Arc::new(DirectFinishAgent),
            tools,
        ).with_max_iterations(1);
        let result = executor.invoke("你好".to_string()).await.unwrap();
        assert_eq!(result, "这是直接回答,不需要工具");
    }

    /// `AgentStep::new` keeps both the action and the observation unchanged.
    #[test]
    fn test_agent_step_records_action_and_observation() {
        let action = AgentAction {
            tool: "calculator".to_string(),
            tool_input: ToolInput::String("10 + 20".to_string()),
            log: "需要计算 10 + 20".to_string(),
        };
        let step = AgentStep::new(action, "{\"result\": 30}".to_string());
        assert_eq!(step.action.tool, "calculator");
        assert_eq!(step.action.tool_input.to_string(), "10 + 20");
        assert_eq!(step.observation, "{\"result\": 30}");
    }

    /// `AgentFinish::new` stores the answer under the "output" key of
    /// `return_values`, and `output()` returns it.
    #[test]
    fn test_agent_finish_constructs_final_answer() {
        let finish = AgentFinish::new(
            "最终答案是 42".to_string(),
            "经过计算得出".to_string(),
        );
        assert_eq!(finish.output(), Some("最终答案是 42"));
        assert!(finish.return_values.contains_key("output"));
        assert_eq!(
            finish.return_values.get("output").and_then(|v| v.as_str()),
            Some("最终答案是 42")
        );
    }

    /// String tool inputs render verbatim. Object inputs render as text that
    /// contains their keys and values.
    #[test]
    fn test_tool_input_string_and_object_formats() {
        let string_input = ToolInput::String("1 + 2".to_string());
        assert_eq!(string_input.to_string(), "1 + 2");
        let json_input = ToolInput::Object(serde_json::json!({
            "expression": "10 * 5"
        }));
        let json_str = json_input.to_string();
        assert!(json_str.contains("expression"));
        assert!(json_str.contains("10 * 5"));
    }
}