autoagents_core/agent/prebuilt/executor/
basic.rs1use crate::agent::hooks::HookOutcome;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
4use crate::tool::{ToolCallResult, ToolT};
5use async_trait::async_trait;
6use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
7use autoagents_llm::ToolCall;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::ops::Deref;
12use std::pin::Pin;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BasicAgentOutput {
18 pub response: String,
19 pub done: bool,
20}
21
22impl From<BasicAgentOutput> for Value {
23 fn from(output: BasicAgentOutput) -> Self {
24 serde_json::to_value(output).unwrap_or(Value::Null)
25 }
26}
27impl From<BasicAgentOutput> for String {
28 fn from(output: BasicAgentOutput) -> Self {
29 output.response
30 }
31}
32
33#[derive(Debug, thiserror::Error)]
35pub enum BasicExecutorError {
36 #[error("LLM error: {0}")]
37 LLMError(String),
38
39 #[error("Other error: {0}")]
40 Other(String),
41}
42
43#[derive(Debug)]
45pub struct BasicAgent<T: AgentDeriveT> {
46 inner: Arc<T>,
47}
48
49impl<T: AgentDeriveT> Clone for BasicAgent<T> {
50 fn clone(&self) -> Self {
51 Self {
52 inner: Arc::clone(&self.inner),
53 }
54 }
55}
56
57impl<T: AgentDeriveT> BasicAgent<T> {
58 pub fn new(inner: T) -> Self {
59 Self {
60 inner: Arc::new(inner),
61 }
62 }
63}
64
65impl<T: AgentDeriveT> Deref for BasicAgent<T> {
66 type Target = T;
67
68 fn deref(&self) -> &Self::Target {
69 &self.inner
70 }
71}
72
73#[async_trait]
75impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
76 type Output = <T as AgentDeriveT>::Output;
77
78 fn description(&self) -> &'static str {
79 self.inner.description()
80 }
81
82 fn output_schema(&self) -> Option<Value> {
83 self.inner.output_schema()
84 }
85
86 fn name(&self) -> &'static str {
87 self.inner.name()
88 }
89
90 fn tools(&self) -> Vec<Box<dyn ToolT>> {
91 self.inner.tools()
92 }
93}
94
95#[async_trait]
96impl<T> AgentHooks for BasicAgent<T>
97where
98 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
99{
100 async fn on_agent_create(&self) {
101 self.inner.on_agent_create().await
102 }
103
104 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
105 self.inner.on_run_start(task, ctx).await
106 }
107
108 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
109 self.inner.on_run_complete(task, result, ctx).await
110 }
111
112 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
113 self.inner.on_turn_start(turn_index, ctx).await
114 }
115
116 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
117 self.inner.on_turn_complete(turn_index, ctx).await
118 }
119
120 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
121 self.inner.on_tool_call(tool_call, ctx).await
122 }
123
124 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
125 self.inner.on_tool_start(tool_call, ctx).await
126 }
127
128 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
129 self.inner.on_tool_result(tool_call, result, ctx).await
130 }
131
132 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
133 self.inner.on_tool_error(tool_call, err, ctx).await
134 }
135 async fn on_agent_shutdown(&self) {
136 self.inner.on_agent_shutdown().await
137 }
138}
139
140#[async_trait]
142impl<T: AgentDeriveT> AgentExecutor for BasicAgent<T> {
143 type Output = BasicAgentOutput;
144 type Error = BasicExecutorError;
145
146 fn config(&self) -> ExecutorConfig {
147 ExecutorConfig { max_turns: 1 }
148 }
149
150 async fn execute(
151 &self,
152 task: &Task,
153 context: Arc<Context>,
154 ) -> Result<Self::Output, Self::Error> {
155 let mut messages = vec![ChatMessage {
156 role: ChatRole::System,
157 message_type: MessageType::Text,
158 content: context.config().description.clone(),
159 }];
160
161 let chat_msg = ChatMessage {
162 role: ChatRole::User,
163 message_type: MessageType::Text,
164 content: task.prompt.clone(),
165 };
166 messages.push(chat_msg);
167 let response = context
168 .llm()
169 .chat(&messages, None, context.config().output_schema.clone())
170 .await
171 .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
172 let response_text = response.text().unwrap_or_default();
173 Ok(BasicAgentOutput {
174 response: response_text,
175 done: true,
176 })
177 }
178
179 async fn execute_stream(
180 &self,
181 task: &Task,
182 context: Arc<Context>,
183 ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
184 {
185 use futures::StreamExt;
186
187 let mut messages = vec![ChatMessage {
188 role: ChatRole::System,
189 message_type: MessageType::Text,
190 content: context.config().description.clone(),
191 }];
192
193 let chat_msg = ChatMessage {
194 role: ChatRole::User,
195 message_type: MessageType::Text,
196 content: task.prompt.clone(),
197 };
198 messages.push(chat_msg);
199
200 let stream = context
201 .llm()
202 .chat_stream_struct(&messages, None, context.config().output_schema.clone())
203 .await
204 .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
205
206 let mapped_stream = stream.map(|chunk_result| match chunk_result {
207 Ok(chunk) => {
208 let content = chunk
209 .choices
210 .first()
211 .and_then(|choice| choice.delta.content.as_ref())
212 .map_or("", |v| v)
213 .to_string();
214
215 Ok(BasicAgentOutput {
216 response: content,
217 done: false,
218 })
219 }
220 Err(e) => Err(BasicExecutorError::LLMError(e.to_string())),
221 });
222
223 Ok(Box::pin(mapped_stream))
224 }
225}
226
#[cfg(test)]
mod tests {
    // `super::*` already brings in `Arc`, `Value`, `Task`, `Context` and `AgentDeriveT`.
    use super::*;
    use crate::tests::agent::MockAgentImpl;
    use autoagents_test_utils::llm::MockLLMProvider;

    #[test]
    fn test_basic_agent_creation() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        assert_eq!(basic_agent.name(), "test_agent");
        assert_eq!(basic_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_clone() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);
        let cloned_agent = basic_agent.clone();

        assert_eq!(cloned_agent.name(), "test_agent");
        assert_eq!(cloned_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_output_conversions() {
        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };

        // The JSON form must expose both fields, not merely be some object.
        let value: Value = output.clone().into();
        assert!(value.is_object());
        assert_eq!(value["response"].as_str(), Some("Test response"));
        assert_eq!(value["done"].as_bool(), Some(true));

        // ...and it must deserialize back into an equivalent output.
        let round_trip: BasicAgentOutput =
            serde_json::from_value(value).expect("serialized output should deserialize");
        assert_eq!(round_trip.response, "Test response");
        assert!(round_trip.done);

        // Converting to `String` keeps only the response text.
        let string: String = output.into();
        assert_eq!(string, "Test response");
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::AgentConfig;
        use crate::protocol::ActorID;

        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test_agent".to_string(),
            description: "Test agent description".to_string(),
            output_schema: None,
        };

        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("Test task");

        let output = basic_agent
            .execute(&task, context)
            .await
            .expect("execute should succeed with the mock LLM provider");

        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_executor_config() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let config = basic_agent.config();
        assert_eq!(config.max_turns, 1);
    }
}