langchain_rust/agent/executor.rs

1use std::{collections::HashMap, sync::Arc};
2
3use async_trait::async_trait;
4use serde_json::json;
5use tokio::sync::Mutex;
6
7use crate::{
8    chain::{chain_trait::Chain, ChainError},
9    language_models::GenerateResult,
10    memory::SimpleMemory,
11    prompt::PromptArgs,
12    schemas::{
13        agent::{AgentAction, AgentEvent},
14        memory::BaseMemory,
15    },
16    tools::Tool,
17};
18
19use super::{agent::Agent, AgentError};
20
21pub struct AgentExecutor<A>
22where
23    A: Agent,
24{
25    agent: A,
26    max_iterations: Option<i32>,
27    break_if_error: bool,
28    pub memory: Option<Arc<Mutex<dyn BaseMemory>>>,
29}
30
31impl<A> AgentExecutor<A>
32where
33    A: Agent,
34{
35    pub fn from_agent(agent: A) -> Self {
36        Self {
37            agent,
38            max_iterations: Some(10),
39            break_if_error: false,
40            memory: None,
41        }
42    }
43
44    pub fn with_max_iterations(mut self, max_iterations: i32) -> Self {
45        self.max_iterations = Some(max_iterations);
46        self
47    }
48
49    pub fn with_memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
50        self.memory = Some(memory);
51        self
52    }
53
54    pub fn with_break_if_error(mut self, break_if_error: bool) -> Self {
55        self.break_if_error = break_if_error;
56        self
57    }
58
59    fn get_name_to_tools(&self) -> HashMap<String, Arc<dyn Tool>> {
60        let mut name_to_tool = HashMap::new();
61        for tool in self.agent.get_tools().iter() {
62            log::debug!("Loading Tool:{}", tool.name());
63            name_to_tool.insert(tool.name().trim().replace(" ", "_"), tool.clone());
64        }
65        name_to_tool
66    }
67}
68
69#[async_trait]
70impl<A> Chain for AgentExecutor<A>
71where
72    A: Agent + Send + Sync,
73{
74    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
75        let mut input_variables = input_variables.clone();
76        let name_to_tools = self.get_name_to_tools();
77        let mut steps: Vec<(AgentAction, String)> = Vec::new();
78        log::debug!("steps: {:?}", steps);
79        if let Some(memory) = &self.memory {
80            let memory = memory.lock().await;
81            input_variables.insert("chat_history".to_string(), json!(memory.messages()));
82        } else {
83            input_variables.insert(
84                "chat_history".to_string(),
85                json!(SimpleMemory::new().messages()),
86            );
87        }
88
89        loop {
90            let agent_event = self
91                .agent
92                .plan(&steps, input_variables.clone())
93                .await
94                .map_err(|e| ChainError::AgentError(format!("Error in agent planning: {}", e)))?;
95            match agent_event {
96                AgentEvent::Action(actions) => {
97                    for action in actions {
98                        log::debug!("Action: {:?}", action.tool_input);
99                        let tool = name_to_tools
100                            .get(&action.tool)
101                            .ok_or_else(|| {
102                                AgentError::ToolError(format!("Tool {} not found", action.tool))
103                            })
104                            .map_err(|e| ChainError::AgentError(e.to_string()))?;
105
106                        let observation_result = tool.call(&action.tool_input).await;
107
108                        let observation = match observation_result {
109                            Ok(result) => result,
110                            Err(err) => {
111                                log::info!(
112                                    "The tool return the following error: {}",
113                                    err.to_string()
114                                );
115                                if self.break_if_error {
116                                    return Err(ChainError::AgentError(
117                                        AgentError::ToolError(err.to_string()).to_string(),
118                                    ));
119                                } else {
120                                    format!("The tool return the following error: {}", err)
121                                }
122                            }
123                        };
124
125                        steps.push((action, observation));
126                    }
127                }
128                AgentEvent::Finish(finish) => {
129                    if let Some(memory) = &self.memory {
130                        let mut memory = memory.lock().await;
131                        memory.add_user_message(&input_variables["input"]);
132                        memory.add_ai_message(&finish.output);
133                    }
134                    return Ok(GenerateResult {
135                        generation: finish.output,
136                        ..Default::default()
137                    });
138                }
139            }
140
141            if let Some(max_iterations) = self.max_iterations {
142                if steps.len() >= max_iterations as usize {
143                    return Ok(GenerateResult {
144                        generation: "Max iterations reached".to_string(),
145                        ..Default::default()
146                    });
147                }
148            }
149        }
150    }
151
152    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
153        let result = self.call(input_variables).await?;
154        Ok(result.generation)
155    }
156}