autoagents_core/agent/prebuilt/simple.rs

1use crate::agent::base::{AgentDeriveT, BaseAgent};
2use crate::agent::executor::{AgentExecutor, TurnResult};
3use crate::agent::runnable::AgentState;
4use crate::session::Task;
5use crate::tool::ToolCallResult;
6use async_trait::async_trait;
7use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, Tool};
8use autoagents_llm::{LLMProvider, ToolCall, ToolT};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13
14/// A simple executor that processes user prompts and handles tool calls
15#[derive(Clone)]
16pub struct SimpleExecutor {
17    /// System prompt for the executor
18    system_prompt: String,
19    /// Maximum number of turns
20    max_turns: usize,
21    /// Tools available to this executor
22    tools: Vec<Arc<Box<dyn ToolT>>>,
23}
24
25/// Output type for the simple executor
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct SimpleOutput {
28    /// The final response from the agent
29    pub response: String,
30    /// Tool calls that were made
31    pub tool_calls: Vec<ToolCall>,
32}
33
34/// Error type for the simple executor
35#[derive(Debug, thiserror::Error)]
36pub enum SimpleError {
37    #[error("LLM error: {0}")]
38    LLMError(String),
39
40    #[error("Session error: {0}")]
41    SessionError(String),
42
43    #[error("Tool execution error: {0}")]
44    ToolError(String),
45
46    #[error("Maximum turns exceeded")]
47    MaxTurnsExceeded,
48}
49
50impl SimpleExecutor {
51    /// Create a new simple executor
52    pub fn new(system_prompt: String, tools: Vec<Arc<Box<dyn ToolT>>>) -> Self {
53        Self {
54            system_prompt,
55            max_turns: 10,
56            tools,
57        }
58    }
59
60    /// Set the maximum number of turns
61    pub fn with_max_turns(mut self, max_turns: usize) -> Self {
62        self.max_turns = max_turns;
63        self
64    }
65
66    /// Process tool calls from the LLM response
67    async fn process_tool_calls(
68        &self,
69        tool_calls: Vec<ToolCall>,
70    ) -> Result<Option<ToolCallResult>, SimpleError> {
71        // // Process each tool call
72        if let Some(tool) = tool_calls.first() {
73            let arguments = tool.function.arguments.clone();
74            for tool_self in self.tools.clone() {
75                if tool.function.name == tool_self.name() {
76                    let result = tool_self.run(serde_json::from_str(&arguments.clone()).unwrap());
77                    // let result = llm.call_tool(tool_name, arguments.clone());
78                    return Ok(Some(ToolCallResult {
79                        tool_name: tool.function.name.clone(),
80                        arguments: arguments.clone().into(),
81                        result,
82                    }));
83                }
84            }
85        }
86
87        Ok(None)
88    }
89}
90
91#[async_trait]
92impl AgentExecutor for SimpleExecutor {
93    type Output = Value;
94    type Error = SimpleError;
95
96    async fn execute(
97        &self,
98        llm: Arc<dyn LLMProvider>,
99        task: Task,
100        state: Arc<Mutex<AgentState>>,
101    ) -> Result<Self::Output, Self::Error> {
102        // Initialize conversation with system prompt and task
103        let mut previous_chat_history = state.lock().await.get_history().messages;
104        let messages = vec![
105            ChatMessage {
106                role: ChatRole::Assistant,
107                message_type: MessageType::Text,
108                content: self.system_prompt.clone(),
109            },
110            ChatMessage {
111                role: ChatRole::User,
112                message_type: MessageType::Text,
113                content: task.prompt,
114            },
115        ];
116        previous_chat_history.extend(messages);
117
118        let mut turn_count = 0;
119
120        // Process turns until we get a final response or hit the limit
121        loop {
122            turn_count += 1;
123            if turn_count > self.max_turns {
124                return Err(SimpleError::MaxTurnsExceeded);
125            }
126
127            match self
128                .process_turn(llm.clone(), &mut previous_chat_history, state.clone())
129                .await?
130            {
131                TurnResult::Complete(output) => return Ok(output),
132                TurnResult::Continue => continue,
133                TurnResult::Error(e) => {
134                    // Add error to conversation and continue
135                    previous_chat_history.push(ChatMessage {
136                        role: ChatRole::Assistant,
137                        message_type: MessageType::Text,
138                        content: format!("Error occurred: {}", e),
139                    });
140                }
141            }
142        }
143    }
144
145    async fn process_turn(
146        &self,
147        llm: Arc<dyn LLMProvider>,
148        messages: &mut Vec<ChatMessage>,
149        state: Arc<Mutex<AgentState>>,
150    ) -> Result<TurnResult<Self::Output>, Self::Error> {
151        let has_tools = !self.tools.is_empty();
152        let response = if has_tools {
153            let tools = self.tools.iter().map(Tool::from).collect::<Vec<_>>();
154            llm.chat_with_tools(messages.as_slice(), Some(tools.as_slice()))
155                .await
156        } else {
157            // Call the LLM
158            llm.chat(messages.as_slice()).await
159        };
160
161        let response_mapped = response.map_err(|e| SimpleError::LLMError(e.to_string()))?;
162
163        // Check if the response contains tool calls
164        let final_response = if response_mapped.tool_calls().is_none() {
165            // No tool calls, this is the final response
166            response_mapped.text().clone().unwrap()
167        } else {
168            let tool_calls = response_mapped.tool_calls().unwrap();
169            let result = self.process_tool_calls(tool_calls).await.unwrap();
170            let restult_string = result.clone().unwrap().result.to_string();
171            messages.push(ChatMessage {
172                role: ChatRole::Assistant,
173                message_type: MessageType::Text,
174                content: restult_string,
175            });
176            state.lock().await.record_tool_call(result.unwrap());
177            // Continue the conversation after tool calls
178            return Ok(TurnResult::Continue);
179        };
180        // Record the final message
181        state.lock().await.record_conversation(ChatMessage {
182            role: ChatRole::Assistant,
183            message_type: MessageType::Text,
184            content: final_response.clone(),
185        });
186
187        Ok(TurnResult::Complete(final_response.into()))
188    }
189
190    fn max_turns(&self) -> usize {
191        self.max_turns
192    }
193}
194
195/// Builder for creating Simple agents
196pub struct SimpleAgentBuilder {
197    name: String,
198    executor: SimpleExecutor,
199    description: String,
200    tools: Vec<Arc<Box<dyn ToolT>>>,
201    llm: Option<Arc<dyn LLMProvider>>,
202}
203
204impl SimpleAgentBuilder {
205    pub fn new(name: impl Into<String>, system_prompt: String) -> Self {
206        Self {
207            name: name.into(),
208            description: String::new(),
209            executor: SimpleExecutor::new(system_prompt, vec![]),
210            tools: Vec::new(),
211            llm: None,
212        }
213    }
214
215    pub fn from_agent<T: AgentDeriveT>(agent: T) -> Self {
216        let tools: Vec<Arc<Box<dyn ToolT>>> = agent
217            .tools()
218            .into_iter()
219            .map(|tool| {
220                let boxed: Box<dyn ToolT> = tool;
221                let arc: Arc<Box<dyn ToolT>> = Arc::new(boxed);
222                arc
223            })
224            .collect();
225        Self {
226            name: agent.name().into(),
227            description: agent.description().into(),
228            executor: SimpleExecutor::new(agent.description().into(), tools.clone()),
229            tools,
230            llm: None,
231        }
232    }
233
234    pub fn with_llm(mut self, llm: Arc<dyn LLMProvider>) -> Self {
235        self.llm = Some(llm);
236        self
237    }
238
239    pub fn build(self) -> Result<BaseAgent<SimpleExecutor>, SimpleError> {
240        let llm = self
241            .llm
242            .ok_or_else(|| SimpleError::LLMError("LLm is not set".into()))?;
243        Ok(BaseAgent::new(
244            self.name,
245            self.description,
246            self.executor,
247            self.tools,
248            llm,
249        ))
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_executor_creation() {
        let executor = SimpleExecutor::new("You are a helpful assistant".to_string(), vec![])
            .with_max_turns(5);

        assert_eq!(executor.max_turns(), 5);
    }

    /// `SimpleExecutor::new` sets a turn limit of 10.
    #[test]
    fn test_simple_executor_default_max_turns() {
        let executor = SimpleExecutor::new("You are a helpful assistant".to_string(), vec![]);

        assert_eq!(executor.max_turns(), 10);
    }

    /// Building without an LLM must fail with an `LLMError` rather than panic.
    #[test]
    fn test_build_without_llm_fails() {
        let result =
            SimpleAgentBuilder::new("agent", "You are a helpful assistant".to_string()).build();

        assert!(matches!(result, Err(SimpleError::LLMError(_))));
    }

    /// Error variants render their documented messages.
    #[test]
    fn test_error_display() {
        assert_eq!(
            SimpleError::MaxTurnsExceeded.to_string(),
            "Maximum turns exceeded"
        );
        assert_eq!(
            SimpleError::ToolError("boom".to_string()).to_string(),
            "Tool execution error: boom"
        );
    }
}