autoagents_core/agent/prebuilt/executor/
basic.rs1use crate::agent::hooks::HookOutcome;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, EventHelper, ExecutorConfig};
4use crate::tool::{ToolCallResult, ToolT};
5use async_trait::async_trait;
6use autoagents_llm::ToolCall;
7use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::ops::Deref;
12use std::pin::Pin;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BasicAgentOutput {
18 pub response: String,
19 pub done: bool,
20}
21
22impl From<BasicAgentOutput> for Value {
23 fn from(output: BasicAgentOutput) -> Self {
24 serde_json::to_value(output).unwrap_or(Value::Null)
25 }
26}
27impl From<BasicAgentOutput> for String {
28 fn from(output: BasicAgentOutput) -> Self {
29 output.response
30 }
31}
32
33impl BasicAgentOutput {
34 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
37 serde_json::from_str::<T>(&self.response)
38 }
39
40 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
43 where
44 T: for<'de> serde::Deserialize<'de>,
45 F: FnOnce(&str) -> T,
46 {
47 self.try_parse::<T>()
48 .unwrap_or_else(|_| fallback(&self.response))
49 }
50}
51
52#[derive(Debug, thiserror::Error)]
54pub enum BasicExecutorError {
55 #[error("LLM error: {0}")]
56 LLMError(String),
57
58 #[error("Other error: {0}")]
59 Other(String),
60}
61
62#[derive(Debug)]
67pub struct BasicAgent<T: AgentDeriveT> {
68 inner: Arc<T>,
69}
70
71impl<T: AgentDeriveT> Clone for BasicAgent<T> {
72 fn clone(&self) -> Self {
73 Self {
74 inner: Arc::clone(&self.inner),
75 }
76 }
77}
78
79impl<T: AgentDeriveT> BasicAgent<T> {
80 pub fn new(inner: T) -> Self {
81 Self {
82 inner: Arc::new(inner),
83 }
84 }
85}
86
87impl<T: AgentDeriveT> Deref for BasicAgent<T> {
88 type Target = T;
89
90 fn deref(&self) -> &Self::Target {
91 &self.inner
92 }
93}
94
95#[async_trait]
97impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
98 type Output = <T as AgentDeriveT>::Output;
99
100 fn description(&self) -> &'static str {
101 self.inner.description()
102 }
103
104 fn output_schema(&self) -> Option<Value> {
105 self.inner.output_schema()
106 }
107
108 fn name(&self) -> &'static str {
109 self.inner.name()
110 }
111
112 fn tools(&self) -> Vec<Box<dyn ToolT>> {
113 self.inner.tools()
114 }
115}
116
117#[async_trait]
118impl<T> AgentHooks for BasicAgent<T>
119where
120 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
121{
122 async fn on_agent_create(&self) {
123 self.inner.on_agent_create().await
124 }
125
126 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
127 self.inner.on_run_start(task, ctx).await
128 }
129
130 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
131 self.inner.on_run_complete(task, result, ctx).await
132 }
133
134 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
135 self.inner.on_turn_start(turn_index, ctx).await
136 }
137
138 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
139 self.inner.on_turn_complete(turn_index, ctx).await
140 }
141
142 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
143 self.inner.on_tool_call(tool_call, ctx).await
144 }
145
146 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
147 self.inner.on_tool_start(tool_call, ctx).await
148 }
149
150 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
151 self.inner.on_tool_result(tool_call, result, ctx).await
152 }
153
154 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
155 self.inner.on_tool_error(tool_call, err, ctx).await
156 }
157 async fn on_agent_shutdown(&self) {
158 self.inner.on_agent_shutdown().await
159 }
160}
161
162#[async_trait]
164impl<T: AgentDeriveT> AgentExecutor for BasicAgent<T> {
165 type Output = BasicAgentOutput;
166 type Error = BasicExecutorError;
167
168 fn config(&self) -> ExecutorConfig {
169 ExecutorConfig { max_turns: 1 }
170 }
171
172 async fn execute(
173 &self,
174 task: &Task,
175 context: Arc<Context>,
176 ) -> Result<Self::Output, Self::Error> {
177 let tx_event = context.tx().ok();
178 EventHelper::send_task_started(
179 &tx_event,
180 task.submission_id,
181 context.config().id,
182 task.prompt.clone(),
183 context.config().name.clone(),
184 )
185 .await;
186
187 let mut messages = vec![ChatMessage {
188 role: ChatRole::System,
189 message_type: MessageType::Text,
190 content: context.config().description.clone(),
191 }];
192
193 let chat_msg = if let Some((mime, image_data)) = &task.image {
194 ChatMessage {
196 role: ChatRole::User,
197 message_type: MessageType::Image((*mime, image_data.clone())),
198 content: task.prompt.clone(),
199 }
200 } else {
201 ChatMessage {
203 role: ChatRole::User,
204 message_type: MessageType::Text,
205 content: task.prompt.clone(),
206 }
207 };
208 messages.push(chat_msg);
209 let response = context
210 .llm()
211 .chat(&messages, context.config().output_schema.clone())
212 .await
213 .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
214 let response_text = response.text().unwrap_or_default();
215 Ok(BasicAgentOutput {
216 response: response_text,
217 done: true,
218 })
219 }
220
221 async fn execute_stream(
222 &self,
223 task: &Task,
224 context: Arc<Context>,
225 ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
226 {
227 use futures::StreamExt;
228
229 let tx_event = context.tx().ok();
230 EventHelper::send_task_started(
231 &tx_event,
232 task.submission_id,
233 context.config().id,
234 task.prompt.clone(),
235 context.config().name.clone(),
236 )
237 .await;
238
239 let mut messages = vec![ChatMessage {
240 role: ChatRole::System,
241 message_type: MessageType::Text,
242 content: context.config().description.clone(),
243 }];
244
245 let chat_msg = if let Some((mime, image_data)) = &task.image {
246 ChatMessage {
248 role: ChatRole::User,
249 message_type: MessageType::Image((*mime, image_data.clone())),
250 content: task.prompt.clone(),
251 }
252 } else {
253 ChatMessage {
255 role: ChatRole::User,
256 message_type: MessageType::Text,
257 content: task.prompt.clone(),
258 }
259 };
260 messages.push(chat_msg);
261
262 let stream = context
263 .llm()
264 .chat_stream_struct(&messages, None, context.config().output_schema.clone())
265 .await
266 .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
267
268 let mapped_stream = stream.map(|chunk_result| match chunk_result {
269 Ok(chunk) => {
270 let content = chunk
271 .choices
272 .first()
273 .and_then(|choice| choice.delta.content.as_ref())
274 .map_or("", |v| v)
275 .to_string();
276
277 Ok(BasicAgentOutput {
278 response: content,
279 done: false,
280 })
281 }
282 Err(e) => Err(BasicExecutorError::LLMError(e.to_string())),
283 });
284
285 Ok(Box::pin(mapped_stream))
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::tests::agent::MockAgentImpl;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;

    /// Name given to every mock agent built by these tests.
    const AGENT_NAME: &str = "test_agent";
    /// Description given to every mock agent built by these tests.
    const AGENT_DESCRIPTION: &str = "Test agent description";

    /// Builds a `BasicAgent` wrapping a fresh mock agent with the shared name/description.
    fn mock_basic_agent() -> BasicAgent<MockAgentImpl> {
        BasicAgent::new(MockAgentImpl::new(AGENT_NAME, AGENT_DESCRIPTION))
    }

    #[test]
    fn test_basic_agent_creation() {
        let basic_agent = mock_basic_agent();

        assert_eq!(basic_agent.name(), "test_agent");
        assert_eq!(basic_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_clone() {
        let basic_agent = mock_basic_agent();
        let cloned_agent = basic_agent.clone();

        assert_eq!(cloned_agent.name(), "test_agent");
        assert_eq!(cloned_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_output_conversions() {
        let output = BasicAgentOutput {
            response: String::from("Test response"),
            done: true,
        };

        // Conversion to JSON yields an object.
        let value = Value::from(output.clone());
        assert!(value.is_object());

        // Conversion to String yields the raw response text.
        let string = String::from(output);
        assert_eq!(string, "Test response");
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::task::Task;
        use crate::agent::{AgentConfig, Context};
        use crate::protocol::ActorID;

        let basic_agent = mock_basic_agent();

        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: AGENT_NAME.to_string(),
            description: AGENT_DESCRIPTION.to_string(),
            output_schema: None,
        };
        let llm = Arc::new(MockLLMProvider {});
        let context_arc = Arc::new(Context::new(llm, None).with_config(config));

        let task = Task::new("Test task");
        let result = basic_agent.execute(&task, context_arc).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_executor_config() {
        let config = mock_basic_agent().config();
        assert_eq!(config.max_turns, 1);
    }
}