langchain_rust/agent/
executor.rs1use std::{collections::HashMap, sync::Arc};
2
3use async_trait::async_trait;
4use serde_json::json;
5use tokio::sync::Mutex;
6
7use crate::{
8 chain::{chain_trait::Chain, ChainError},
9 language_models::GenerateResult,
10 memory::SimpleMemory,
11 prompt::PromptArgs,
12 schemas::{
13 agent::{AgentAction, AgentEvent},
14 memory::BaseMemory,
15 },
16 tools::Tool,
17};
18
19use super::{agent::Agent, AgentError};
20
21pub struct AgentExecutor<A>
22where
23 A: Agent,
24{
25 agent: A,
26 max_iterations: Option<i32>,
27 break_if_error: bool,
28 pub memory: Option<Arc<Mutex<dyn BaseMemory>>>,
29}
30
31impl<A> AgentExecutor<A>
32where
33 A: Agent,
34{
35 pub fn from_agent(agent: A) -> Self {
36 Self {
37 agent,
38 max_iterations: Some(10),
39 break_if_error: false,
40 memory: None,
41 }
42 }
43
44 pub fn with_max_iterations(mut self, max_iterations: i32) -> Self {
45 self.max_iterations = Some(max_iterations);
46 self
47 }
48
49 pub fn with_memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
50 self.memory = Some(memory);
51 self
52 }
53
54 pub fn with_break_if_error(mut self, break_if_error: bool) -> Self {
55 self.break_if_error = break_if_error;
56 self
57 }
58
59 fn get_name_to_tools(&self) -> HashMap<String, Arc<dyn Tool>> {
60 let mut name_to_tool = HashMap::new();
61 for tool in self.agent.get_tools().iter() {
62 log::debug!("Loading Tool:{}", tool.name());
63 name_to_tool.insert(tool.name().trim().replace(" ", "_"), tool.clone());
64 }
65 name_to_tool
66 }
67}
68
69#[async_trait]
70impl<A> Chain for AgentExecutor<A>
71where
72 A: Agent + Send + Sync,
73{
74 async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
75 let mut input_variables = input_variables.clone();
76 let name_to_tools = self.get_name_to_tools();
77 let mut steps: Vec<(AgentAction, String)> = Vec::new();
78 log::debug!("steps: {:?}", steps);
79 if let Some(memory) = &self.memory {
80 let memory = memory.lock().await;
81 input_variables.insert("chat_history".to_string(), json!(memory.messages()));
82 } else {
83 input_variables.insert(
84 "chat_history".to_string(),
85 json!(SimpleMemory::new().messages()),
86 );
87 }
88
89 loop {
90 let agent_event = self
91 .agent
92 .plan(&steps, input_variables.clone())
93 .await
94 .map_err(|e| ChainError::AgentError(format!("Error in agent planning: {}", e)))?;
95 match agent_event {
96 AgentEvent::Action(actions) => {
97 for action in actions {
98 log::debug!("Action: {:?}", action.tool_input);
99 let tool = name_to_tools
100 .get(&action.tool)
101 .ok_or_else(|| {
102 AgentError::ToolError(format!("Tool {} not found", action.tool))
103 })
104 .map_err(|e| ChainError::AgentError(e.to_string()))?;
105
106 let observation_result = tool.call(&action.tool_input).await;
107
108 let observation = match observation_result {
109 Ok(result) => result,
110 Err(err) => {
111 log::info!(
112 "The tool return the following error: {}",
113 err.to_string()
114 );
115 if self.break_if_error {
116 return Err(ChainError::AgentError(
117 AgentError::ToolError(err.to_string()).to_string(),
118 ));
119 } else {
120 format!("The tool return the following error: {}", err)
121 }
122 }
123 };
124
125 steps.push((action, observation));
126 }
127 }
128 AgentEvent::Finish(finish) => {
129 if let Some(memory) = &self.memory {
130 let mut memory = memory.lock().await;
131 memory.add_user_message(&input_variables["input"]);
132 memory.add_ai_message(&finish.output);
133 }
134 return Ok(GenerateResult {
135 generation: finish.output,
136 ..Default::default()
137 });
138 }
139 }
140
141 if let Some(max_iterations) = self.max_iterations {
142 if steps.len() >= max_iterations as usize {
143 return Ok(GenerateResult {
144 generation: "Max iterations reached".to_string(),
145 ..Default::default()
146 });
147 }
148 }
149 }
150 }
151
152 async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
153 let result = self.call(input_variables).await?;
154 Ok(result.generation)
155 }
156}