use crate::agents::{AgentError, AgentOutput, AgentStep, BaseAgent};
use crate::core::tools::BaseTool;
use crate::language_models::OpenAIChat;
use crate::schema::Message;
use crate::core::language_models::BaseChatModel;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use super::parser::ReActOutputParser;
use super::prompt::{build_react_prompt, format_scratchpad};
/// A ReAct (Reasoning + Acting) agent: builds a Thought/Action/Observation
/// style prompt, sends it to the chat model, and parses the model's reply
/// into either a tool invocation or a final answer.
pub struct ReActAgent {
// Chat model queried once per planning step.
llm: OpenAIChat,
// Tools the agent may call; surfaced to the LLM via name + description.
tools: Vec<Arc<dyn BaseTool>>,
// Parses raw LLM text into an `AgentOutput` (action or finish).
parser: ReActOutputParser,
// Optional system-level instructions prepended to every prompt.
system_prompt: Option<String>,
}
impl ReActAgent {
    /// Creates a new ReAct agent from a chat model, a tool set and an
    /// optional system prompt.
    pub fn new(llm: OpenAIChat, tools: Vec<Arc<dyn BaseTool>>, system_prompt: Option<String>) -> Self {
        let parser = ReActOutputParser::new();
        Self { llm, tools, parser, system_prompt }
    }

    /// Renders the tool list as newline-separated "name: description" lines.
    fn format_tools(&self) -> String {
        let mut lines = Vec::with_capacity(self.tools.len());
        for tool in &self.tools {
            lines.push(format!("{}: {}", tool.name(), tool.description()));
        }
        lines.join("\n")
    }

    /// Collects the names of all registered tools.
    fn get_tool_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.tools.len());
        for tool in &self.tools {
            names.push(tool.name());
        }
        names
    }

    /// Assembles the full prompt text for one planning step: the templated
    /// ReAct question + scratchpad, optionally preceded by the conversation
    /// history, optionally preceded by the system prompt.
    fn build_prompt(&self, input: &str, intermediate_steps: &[AgentStep], history: Option<&str>) -> String {
        let tools_description = self.format_tools();
        let tool_names = self.get_tool_names();
        let scratchpad = format_scratchpad(intermediate_steps);
        let base = build_react_prompt(&tools_description, &tool_names, input, &scratchpad);

        // Prepend prior conversation history only when present and non-empty.
        let with_history = match history {
            Some(h) if !h.is_empty() => format!("之前的对话历史:\n{}\n\n{}", h, base),
            _ => base,
        };

        // The system prompt, when configured, goes first.
        match &self.system_prompt {
            Some(sys) => format!("{}\n\n{}", sys, with_history),
            None => with_history,
        }
    }
}
#[async_trait]
impl BaseAgent for ReActAgent {
    /// Executes one planning step: builds the prompt from the question,
    /// optional history and prior intermediate steps, queries the LLM once,
    /// and parses the reply into the next action or a final answer.
    async fn plan(
        &self,
        intermediate_steps: &[AgentStep],
        inputs: &HashMap<String, String>,
    ) -> Result<AgentOutput, AgentError> {
        // The user question is mandatory; conversation history is optional.
        let input = match inputs.get("input") {
            Some(v) => v,
            None => return Err(AgentError::Other("缺少输入参数 'input'".to_string())),
        };
        let history = inputs.get("history").map(String::as_str);

        let prompt_text = self.build_prompt(input, intermediate_steps, history);
        let messages = vec![Message::human(prompt_text)];

        // A failed LLM call is surfaced as an agent-level error.
        let result = match self.llm.chat(messages, None).await {
            Ok(r) => r,
            Err(e) => return Err(AgentError::Other(format!("LLM 调用失败: {}", e))),
        };
        self.parser.parse(&result.content)
    }

    /// The agent is restricted to its own registered tools.
    fn get_allowed_tools(&self) -> Option<Vec<&str>> {
        Some(self.get_tool_names())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Calculator;
    use crate::language_models::OpenAIConfig;
    use crate::agents::{AgentAction, ToolInput};

    /// Builds the LLM config used by all tests.
    ///
    /// SECURITY: credentials must never be hard-coded in source (the previous
    /// version embedded a live-looking API key and proxy URL). The key and
    /// base URL are now read from the environment; the placeholder fallbacks
    /// are sufficient for the non-ignored unit tests, which never touch the
    /// network. The `#[ignore]`d live-API tests require `OPENAI_API_KEY`
    /// (and optionally `OPENAI_BASE_URL`) to be set.
    fn create_test_config() -> OpenAIConfig {
        OpenAIConfig {
            api_key: std::env::var("OPENAI_API_KEY")
                .unwrap_or_else(|_| "test-api-key".to_string()),
            base_url: std::env::var("OPENAI_BASE_URL")
                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            temperature: Some(0.0),
            max_tokens: Some(500),
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            streaming: false,
            organization: None,
            tools: None,
            tool_choice: None,
        }
    }

    /// The tool description block lists each tool by name.
    #[test]
    fn test_format_tools_description() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator)];
        let agent = ReActAgent::new(llm, tools, None);
        let desc = agent.format_tools();
        assert!(desc.contains("calculator"));
    }

    /// Tool names are collected in registration order.
    #[test]
    fn test_get_tool_names() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator)];
        let agent = ReActAgent::new(llm, tools, None);
        let names = agent.get_tool_names();
        assert_eq!(names, vec!["calculator"]);
    }

    /// The base prompt embeds the question, the tools and the ReAct markers.
    #[test]
    fn test_build_prompt() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator)];
        let agent = ReActAgent::new(llm, tools, None);
        let prompt = agent.build_prompt("计算 2 + 2", &[], None);
        assert!(prompt.contains("计算 2 + 2"));
        assert!(prompt.contains("calculator"));
        assert!(prompt.contains("Question:"));
        assert!(prompt.contains("Thought:"));
    }

    /// Non-empty history is prepended under the history header.
    #[test]
    fn test_build_prompt_with_history() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator)];
        let agent = ReActAgent::new(llm, tools, None);
        let prompt = agent.build_prompt("计算 3 + 3", &[], Some("用户: 你好\n助手: 你好!"));
        assert!(prompt.contains("之前的对话历史"));
        assert!(prompt.contains("你好"));
    }

    /// A configured system prompt leads the final prompt text.
    #[test]
    fn test_build_prompt_with_system_prompt() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator)];
        let agent = ReActAgent::new(llm, tools, Some("你是一个数学助手".to_string()));
        let prompt = agent.build_prompt("计算 4 + 4", &[], None);
        assert!(prompt.contains("你是一个数学助手"));
    }

    /// Live API: a tool-free question should finish directly.
    #[tokio::test]
    #[ignore = "需要真实 API 调用"]
    async fn test_real_api_simple() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![];
        let agent = ReActAgent::new(llm, tools, None);
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), "什么是 Rust 语言?".to_string());
        let result = agent.plan(&[], &inputs).await.unwrap();
        match result {
            AgentOutput::Finish(finish) => {
                println!("答案: {:?}", finish.return_values);
                assert!(finish.output().is_some());
            }
            AgentOutput::Action(_) => {
                println!("LLM 尝试调用工具");
            }
            AgentOutput::Actions(_) => {
                println!("ReActAgent 不支持并行工具调用");
            }
        }
    }

    /// Live API: an arithmetic question should route to the calculator tool.
    #[tokio::test]
    #[ignore = "需要真实 API 调用"]
    async fn test_real_api_with_calculator() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator)];
        let agent = ReActAgent::new(llm, tools, None);
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), "计算 37 加 48 等于多少?".to_string());
        let result = agent.plan(&[], &inputs).await.unwrap();
        match result {
            AgentOutput::Action(action) => {
                println!("动作: {}({})", action.tool, action.tool_input);
                assert_eq!(action.tool, "calculator");
            }
            AgentOutput::Finish(finish) => {
                println!("直接答案: {:?}", finish.return_values);
            }
            AgentOutput::Actions(_) => {
                println!("ReActAgent 不支持并行工具调用");
            }
        }
    }

    /// Live API: with one step already observed, the agent plans the next.
    #[tokio::test]
    #[ignore = "需要真实 API 调用"]
    async fn test_real_api_multi_step() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![Arc::new(Calculator)];
        let agent = ReActAgent::new(llm, tools, None);
        // Pre-seed the scratchpad with a completed first step (37 + 48 = 85).
        let steps = vec![
            AgentStep::new(
                AgentAction {
                    tool: "calculator".to_string(),
                    tool_input: ToolInput::String("37 + 48".to_string()),
                    log: "我需要先计算 37 + 48".to_string(),
                },
                "85".to_string(),
            ),
        ];
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), "计算 (37 + 48) * 2 等于多少?".to_string());
        let result = agent.plan(&steps, &inputs).await.unwrap();
        match result {
            AgentOutput::Action(action) => {
                println!("下一步动作: {}({})", action.tool, action.tool_input);
                assert_eq!(action.tool, "calculator");
            }
            AgentOutput::Finish(finish) => {
                println!("最终答案: {:?}", finish.return_values);
            }
            AgentOutput::Actions(_) => {
                println!("ReActAgent 不支持并行工具调用");
            }
        }
    }

    /// Live API: history passed via `inputs` should inform the answer.
    #[tokio::test]
    #[ignore = "需要真实 API 调用"]
    async fn test_real_api_with_memory() {
        let config = create_test_config();
        let llm = OpenAIChat::new(config);
        let tools: Vec<Arc<dyn BaseTool>> = vec![];
        let agent = ReActAgent::new(llm, tools, None);
        let history = "Human: 我叫张三\nAI: 好的,张三,我记住了。";
        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), "我叫什么名字?".to_string());
        inputs.insert("history".to_string(), history.to_string());
        let result = agent.plan(&[], &inputs).await.unwrap();
        match result {
            AgentOutput::Finish(finish) => {
                println!("答案: {:?}", finish.return_values);
                let output = finish.output().unwrap_or("");
                assert!(output.contains("张三"), "应该记住用户名字");
            }
            AgentOutput::Action(_) => {
                println!("LLM 尝试调用工具");
            }
            AgentOutput::Actions(_) => {
                println!("ReActAgent 不支持并行工具调用");
            }
        }
    }
}