autoagents_core/
session.rs1use crate::agent::base::BaseAgent;
2use crate::agent::executor::AgentExecutor;
3use crate::agent::result::AgentRunResult;
4use crate::agent::runnable::{RunnableAgent, RunnableAgentBuilder};
5use crate::error::Error;
6use crate::protocol::{AgentID, Event, SessionId, SubmissionId};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::{mpsc, Mutex, RwLock};
12use tokio::task::JoinError;
13use tokio_stream::wrappers::ReceiverStream;
14use uuid::Uuid;
15
16#[derive(Debug, thiserror::Error)]
18pub enum SessionError {
19 #[error("Agent not found: {0}")]
20 AgentNotFound(Uuid),
21
22 #[error("No task set for agent: {0}")]
23 NoTaskSet(Uuid),
24
25 #[error("Task is None")]
26 EmptyTask,
27
28 #[error("Task join error: {0}")]
29 TaskJoinError(#[from] JoinError),
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Task {
34 pub prompt: String,
35 pub submission_id: SubmissionId,
36 pub completed: bool,
37 pub result: Option<Value>,
38 agent_id: Option<AgentID>,
39}
40
41impl Task {
42 pub fn new<T: Into<String>>(task: T, agent_id: Option<AgentID>) -> Self {
43 Self {
44 prompt: task.into(),
45 submission_id: Uuid::new_v4(),
46 completed: false,
47 result: None,
48 agent_id,
49 }
50 }
51}
52
53#[derive(Debug, Default)]
54struct State {
55 current_task: Option<Task>,
56 task_queue: Vec<Task>,
57}
58
59pub struct Session {
60 pub id: SessionId,
61 tx_event: mpsc::Sender<Event>,
62 rx_event: Option<mpsc::Receiver<Event>>,
63 state: Arc<Mutex<State>>,
64 agents: Arc<RwLock<HashMap<AgentID, Arc<dyn RunnableAgent + Send + Sync + 'static>>>>,
65}
66
67impl Session {
68 pub fn new(channel_buffer: usize) -> (SessionId, Self) {
69 let id = Uuid::new_v4();
70 let (tx_event, rx_event) = mpsc::channel(channel_buffer);
71 (
72 id,
73 Self {
74 id,
75 tx_event,
76 rx_event: Some(rx_event),
77 state: Arc::new(Mutex::new(State::default())),
78 agents: Arc::new(RwLock::new(HashMap::new())),
79 },
80 )
81 }
82
83 pub async fn add_task(&self, task: Task) {
84 let task_clone = task.clone();
85 if self
86 .tx_event
87 .send(Event::NewTask {
88 sub_id: task_clone.submission_id,
89 agent_id: task_clone.agent_id,
90 prompt: task_clone.prompt,
91 })
92 .await
93 .is_err()
94 {
95 panic!("Failed to send NewTask event");
96 }
97 let mut state = self.state.lock().await;
98 state.task_queue.push(task);
99 }
100
101 pub async fn set_current_task(&self, task: Task) {
102 let mut state = self.state.lock().await;
103 state.current_task = Some(task);
104 }
105
106 pub async fn is_task_queue_empty(&self) -> bool {
107 let state = self.state.lock().await;
108 state.task_queue.is_empty()
109 }
110
111 pub async fn get_top_task(&self) -> Option<Task> {
112 let mut state = self.state.lock().await;
113 if state.task_queue.is_empty() {
114 None
115 } else {
116 let task = state.task_queue.remove(0);
117 state.current_task = Some(task.clone());
118 Some(task)
119 }
120 }
121
122 pub async fn get_current_task(&self) -> Option<Task> {
123 let state = self.state.lock().await;
124 state.current_task.clone()
125 }
126
127 pub async fn get_task(&self, sub_id: SubmissionId) -> Option<Task> {
128 let state = self.state.lock().await;
129 state
130 .task_queue
131 .iter()
132 .find(|t| t.submission_id == sub_id)
133 .cloned()
134 }
135
136 pub fn event_sender(&self) -> mpsc::Sender<Event> {
137 self.tx_event.clone()
138 }
139
140 pub fn take_event_receiver(&mut self) -> Option<ReceiverStream<Event>> {
141 self.rx_event.take().map(ReceiverStream::new)
142 }
143
144 pub async fn register_agent_with_id<E>(&self, agent_id: Uuid, agent: BaseAgent<E>)
145 where
146 E: AgentExecutor,
147 E::Output: Into<Value> + Send,
148 E::Error: std::error::Error + Send + Sync + 'static,
149 {
150 let builder = RunnableAgentBuilder::new();
151 let runnable_agent = builder.build(agent);
152 self.agents.write().await.insert(agent_id, runnable_agent);
153 }
154
155 pub async fn run_task(
156 &self,
157 task: Option<Task>,
158 agent_id: AgentID,
159 ) -> Result<AgentRunResult, Error> {
160 let task = task.ok_or_else(|| SessionError::EmptyTask)?;
161 let agent = self
162 .agents
163 .read()
164 .await
165 .get(&agent_id)
166 .ok_or(SessionError::AgentNotFound(agent_id))?
167 .clone();
168
169 let join_handle = agent.spawn_task(task, self.tx_event.clone());
170
171 let result: Result<AgentRunResult, Error> = match join_handle.await {
173 Ok(run_result) => run_result,
174 Err(join_err) => Err(SessionError::from(join_err).into()),
175 };
176 result
177 }
178
179 pub async fn run(&self, agent_id: Uuid) -> Result<AgentRunResult, Error> {
180 let task = self.get_top_task().await;
181 self.run_task(task, agent_id).await
182 }
183
184 pub async fn run_all(&self, agent_id: Uuid) -> Result<Vec<AgentRunResult>, Error> {
185 let mut results = Vec::new();
186 while !self.is_task_queue_empty().await {
187 let result = self.run(agent_id).await?;
188 results.push(result);
189 }
190 Ok(results)
191 }
192}
193
194#[derive(Default)]
195pub struct SessionManager {
196 sessions: RwLock<HashMap<SessionId, Arc<Mutex<Session>>>>,
197}
198
199impl SessionManager {
200 pub fn new() -> Self {
201 Self::default()
202 }
203
204 pub async fn create_session(&self, channel_buffer: usize) -> (SessionId, Arc<Mutex<Session>>) {
205 let (id, session) = Session::new(channel_buffer);
206 let session_arc = Arc::new(Mutex::new(session));
207 let mut sessions = self.sessions.write().await;
208 sessions.insert(id, session_arc.clone());
209 (id, session_arc)
210 }
211
212 pub async fn get_session(&self, session_id: &SessionId) -> Option<Arc<Mutex<Session>>> {
213 let sessions = self.sessions.read().await;
214 sessions.get(session_id).cloned()
215 }
216
217 pub async fn get_session_mut(&self, session_id: &SessionId) -> Option<Arc<Mutex<Session>>> {
218 let sessions = self.sessions.read().await;
220 sessions.get(session_id).cloned()
221 }
222}