autoagents_core/agent/
runnable.rs1use super::base::BaseAgent;
2use super::executor::AgentExecutor;
3use super::result::AgentRunResult;
4use crate::error::Error;
5use crate::protocol::Event;
6use crate::session::Task;
7use crate::tool::ToolCallResult;
8use async_trait::async_trait;
9use autoagents_llm::chat::ChatMessage;
10use serde_json::Value;
11use std::sync::Arc;
12use tokio::sync::{mpsc, Mutex};
13use tokio::task::JoinHandle;
14use uuid::Uuid;
15
16#[derive(Debug, Default, Clone)]
17pub struct History {
18 pub messages: Vec<ChatMessage>,
19 pub tool_calls: Vec<ToolCallResult>,
20 pub tasks: Vec<Task>,
21}
22
23#[derive(Default)]
24pub struct AgentState {
25 history: History,
26}
27
28impl AgentState {
29 pub(crate) fn get_history(&self) -> History {
30 self.history.clone()
31 }
32
33 pub(crate) fn record_conversation(&mut self, message: ChatMessage) {
34 self.history.messages.push(message);
35 }
36
37 pub(crate) fn record_tool_call(&mut self, tool_call: ToolCallResult) {
38 self.history.tool_calls.push(tool_call);
39 }
40}
41
42#[async_trait]
44pub trait RunnableAgent: Send + Sync + 'static {
45 fn name(&self) -> &str;
46 fn description(&self) -> &str;
47 fn id(&self) -> Uuid;
48
49 async fn run(
50 self: Arc<Self>,
51 task: Task,
52 tx_event: mpsc::Sender<Event>,
53 ) -> Result<AgentRunResult, Error>;
54
55 fn spawn_task(
56 self: Arc<Self>,
57 task: Task,
58 tx_event: mpsc::Sender<Event>,
59 ) -> JoinHandle<Result<AgentRunResult, Error>> {
60 tokio::spawn(async move { self.run(task, tx_event).await })
61 }
62}
63
64pub struct RunnableAgentImpl<E>
66where
67 E: AgentExecutor,
68{
69 agent: BaseAgent<E>,
70 state: Arc<Mutex<AgentState>>,
71 id: Uuid,
72}
73
74impl<E> RunnableAgentImpl<E>
75where
76 E: AgentExecutor,
77{
78 pub fn new(agent: BaseAgent<E>) -> Self {
79 Self {
80 agent,
81 state: Arc::new(Mutex::new(AgentState::default())),
82 id: Uuid::new_v4(),
83 }
84 }
85}
86
87#[async_trait]
88impl<E> RunnableAgent for RunnableAgentImpl<E>
89where
90 E: AgentExecutor,
91 E::Output: Into<Value> + Send + Sync,
92 E::Error: std::error::Error + Send + Sync + 'static,
93{
94 fn name(&self) -> &str {
95 &self.agent.name
96 }
97
98 fn description(&self) -> &str {
99 &self.agent.description
100 }
101
102 fn id(&self) -> Uuid {
103 self.id
104 }
105
106 async fn run(
107 self: Arc<Self>,
108 task: Task,
109 tx_event: mpsc::Sender<Event>,
110 ) -> Result<AgentRunResult, Error> {
111 let llm = self.agent.llm.clone();
112 let result = self
113 .agent
114 .executor
115 .execute(llm, task.clone(), self.state.clone())
116 .await;
117
118 let task_result = match &result {
120 Ok(val) => crate::protocol::TaskResult::Value(
121 serde_json::to_value(val).unwrap_or(serde_json::Value::Null),
122 ),
123 Err(err) => crate::protocol::TaskResult::Failure(err.to_string()),
124 };
125
126 let _ = tx_event
127 .send(Event::TaskComplete {
128 sub_id: task.submission_id,
129 result: task_result,
130 })
131 .await;
132
133 match result {
134 Ok(val) => Ok(AgentRunResult::success(val.into())),
135 Err(e) => Ok(AgentRunResult::failure(e.to_string())),
136 }
137 }
138}
139
/// Stateless builder that turns a [`BaseAgent`] into a shareable `Arc<dyn RunnableAgent>`.
///
/// `Debug` and `Clone` are derived because this is a public, field-less type.
#[derive(Debug, Clone)]
pub struct RunnableAgentBuilder {}
142
143impl Default for RunnableAgentBuilder {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl RunnableAgentBuilder {
150 pub fn new() -> Self {
151 Self {}
152 }
153
154 pub fn build<E>(self, agent: BaseAgent<E>) -> Arc<dyn RunnableAgent>
155 where
156 E: AgentExecutor,
157 E::Output: Into<Value> + Send + Sync,
158 E::Error: std::error::Error + Send + Sync + 'static,
159 {
160 Arc::new(RunnableAgentImpl::new(agent))
161 }
162}