autoagents_core/agent/prebuilt/executor/
basic.rs1use crate::agent::hooks::HookOutcome;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, EventHelper, ExecutorConfig};
4use crate::tool::{ToolCallResult, ToolT};
5use async_trait::async_trait;
6use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
7use autoagents_llm::ToolCall;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::ops::Deref;
12use std::pin::Pin;
13use std::sync::Arc;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BasicAgentOutput {
18 pub response: String,
19 pub done: bool,
20}
21
22impl From<BasicAgentOutput> for Value {
23 fn from(output: BasicAgentOutput) -> Self {
24 serde_json::to_value(output).unwrap_or(Value::Null)
25 }
26}
27impl From<BasicAgentOutput> for String {
28 fn from(output: BasicAgentOutput) -> Self {
29 output.response
30 }
31}
32
33#[derive(Debug, thiserror::Error)]
35pub enum BasicExecutorError {
36 #[error("LLM error: {0}")]
37 LLMError(String),
38
39 #[error("Other error: {0}")]
40 Other(String),
41}
42
43#[derive(Debug)]
45pub struct BasicAgent<T: AgentDeriveT> {
46 inner: Arc<T>,
47}
48
49impl<T: AgentDeriveT> Clone for BasicAgent<T> {
50 fn clone(&self) -> Self {
51 Self {
52 inner: Arc::clone(&self.inner),
53 }
54 }
55}
56
57impl<T: AgentDeriveT> BasicAgent<T> {
58 pub fn new(inner: T) -> Self {
59 Self {
60 inner: Arc::new(inner),
61 }
62 }
63}
64
65impl<T: AgentDeriveT> Deref for BasicAgent<T> {
66 type Target = T;
67
68 fn deref(&self) -> &Self::Target {
69 &self.inner
70 }
71}
72
73#[async_trait]
75impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
76 type Output = <T as AgentDeriveT>::Output;
77
78 fn description(&self) -> &'static str {
79 self.inner.description()
80 }
81
82 fn output_schema(&self) -> Option<Value> {
83 self.inner.output_schema()
84 }
85
86 fn name(&self) -> &'static str {
87 self.inner.name()
88 }
89
90 fn tools(&self) -> Vec<Box<dyn ToolT>> {
91 self.inner.tools()
92 }
93}
94
95#[async_trait]
96impl<T> AgentHooks for BasicAgent<T>
97where
98 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
99{
100 async fn on_agent_create(&self) {
101 self.inner.on_agent_create().await
102 }
103
104 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
105 self.inner.on_run_start(task, ctx).await
106 }
107
108 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
109 self.inner.on_run_complete(task, result, ctx).await
110 }
111
112 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
113 self.inner.on_turn_start(turn_index, ctx).await
114 }
115
116 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
117 self.inner.on_turn_complete(turn_index, ctx).await
118 }
119
120 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
121 self.inner.on_tool_call(tool_call, ctx).await
122 }
123
124 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
125 self.inner.on_tool_start(tool_call, ctx).await
126 }
127
128 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
129 self.inner.on_tool_result(tool_call, result, ctx).await
130 }
131
132 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
133 self.inner.on_tool_error(tool_call, err, ctx).await
134 }
135 async fn on_agent_shutdown(&self) {
136 self.inner.on_agent_shutdown().await
137 }
138}
139
140#[async_trait]
142impl<T: AgentDeriveT> AgentExecutor for BasicAgent<T> {
143 type Output = BasicAgentOutput;
144 type Error = BasicExecutorError;
145
146 fn config(&self) -> ExecutorConfig {
147 ExecutorConfig { max_turns: 1 }
148 }
149
150 async fn execute(
151 &self,
152 task: &Task,
153 context: Arc<Context>,
154 ) -> Result<Self::Output, Self::Error> {
155 let tx_event = context.tx().ok();
156 EventHelper::send_task_started(
157 &tx_event,
158 task.submission_id,
159 context.config().id,
160 task.prompt.clone(),
161 context.config().name.clone(),
162 )
163 .await;
164
165 let mut messages = vec![ChatMessage {
166 role: ChatRole::System,
167 message_type: MessageType::Text,
168 content: context.config().description.clone(),
169 }];
170
171 let chat_msg = ChatMessage {
172 role: ChatRole::User,
173 message_type: MessageType::Text,
174 content: task.prompt.clone(),
175 };
176 messages.push(chat_msg);
177 let response = context
178 .llm()
179 .chat(&messages, None, context.config().output_schema.clone())
180 .await
181 .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
182 let response_text = response.text().unwrap_or_default();
183 Ok(BasicAgentOutput {
184 response: response_text,
185 done: true,
186 })
187 }
188
189 async fn execute_stream(
190 &self,
191 task: &Task,
192 context: Arc<Context>,
193 ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
194 {
195 use futures::StreamExt;
196
197 let tx_event = context.tx().ok();
198 EventHelper::send_task_started(
199 &tx_event,
200 task.submission_id,
201 context.config().id,
202 task.prompt.clone(),
203 context.config().name.clone(),
204 )
205 .await;
206
207 let mut messages = vec![ChatMessage {
208 role: ChatRole::System,
209 message_type: MessageType::Text,
210 content: context.config().description.clone(),
211 }];
212
213 let chat_msg = ChatMessage {
214 role: ChatRole::User,
215 message_type: MessageType::Text,
216 content: task.prompt.clone(),
217 };
218 messages.push(chat_msg);
219
220 let stream = context
221 .llm()
222 .chat_stream_struct(&messages, None, context.config().output_schema.clone())
223 .await
224 .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
225
226 let mapped_stream = stream.map(|chunk_result| match chunk_result {
227 Ok(chunk) => {
228 let content = chunk
229 .choices
230 .first()
231 .and_then(|choice| choice.delta.content.as_ref())
232 .map_or("", |v| v)
233 .to_string();
234
235 Ok(BasicAgentOutput {
236 response: content,
237 done: false,
238 })
239 }
240 Err(e) => Err(BasicExecutorError::LLMError(e.to_string())),
241 });
242
243 Ok(Box::pin(mapped_stream))
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::tests::agent::MockAgentImpl;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;

    #[test]
    fn test_basic_agent_creation() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        assert_eq!(basic_agent.name(), "test_agent");
        assert_eq!(basic_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_clone() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);
        let cloned_agent = basic_agent.clone();

        assert_eq!(cloned_agent.name(), "test_agent");
        assert_eq!(cloned_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_output_conversions() {
        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };

        // The JSON form must carry both fields, not merely be an object.
        let value: Value = output.clone().into();
        assert!(value.is_object());
        assert_eq!(value["response"], "Test response");
        assert_eq!(value["done"], true);

        // The String form keeps only the response text.
        let string: String = output.into();
        assert_eq!(string, "Test response");
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::task::Task;
        use crate::agent::{AgentConfig, Context};
        use crate::protocol::ActorID;

        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test_agent".to_string(),
            description: "Test agent description".to_string(),
            output_schema: None,
        };

        let context = Context::new(llm, None).with_config(config);

        let context_arc = Arc::new(context);
        let task = Task::new("Test task");
        let output = basic_agent
            .execute(&task, context_arc)
            .await
            .expect("execute should succeed against the mock LLM");

        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_executor_config() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let config = basic_agent.config();
        assert_eq!(config.max_turns, 1);
    }
}