autoagents_core/
session.rs

1use crate::agent::base::BaseAgent;
2use crate::agent::executor::AgentExecutor;
3use crate::agent::result::AgentRunResult;
4use crate::agent::runnable::{RunnableAgent, RunnableAgentBuilder};
5use crate::error::Error;
6use crate::protocol::{AgentID, Event, SessionId, SubmissionId};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::{mpsc, Mutex, RwLock};
12use tokio::task::JoinError;
13use tokio_stream::wrappers::ReceiverStream;
14use uuid::Uuid;
15
16/// Error types for Session operations
17#[derive(Debug, thiserror::Error)]
18pub enum SessionError {
19    #[error("Agent not found: {0}")]
20    AgentNotFound(Uuid),
21
22    #[error("No task set for agent: {0}")]
23    NoTaskSet(Uuid),
24
25    #[error("Task is None")]
26    EmptyTask,
27
28    #[error("Task join error: {0}")]
29    TaskJoinError(#[from] JoinError),
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Task {
34    pub prompt: String,
35    pub submission_id: SubmissionId,
36    pub completed: bool,
37    pub result: Option<Value>,
38    agent_id: Option<AgentID>,
39}
40
41impl Task {
42    pub fn new<T: Into<String>>(task: T, agent_id: Option<AgentID>) -> Self {
43        Self {
44            prompt: task.into(),
45            submission_id: Uuid::new_v4(),
46            completed: false,
47            result: None,
48            agent_id,
49        }
50    }
51}
52
53#[derive(Debug, Default)]
54struct State {
55    current_task: Option<Task>,
56    task_queue: Vec<Task>,
57}
58
59pub struct Session {
60    pub id: SessionId,
61    tx_event: mpsc::Sender<Event>,
62    rx_event: Option<mpsc::Receiver<Event>>,
63    state: Arc<Mutex<State>>,
64    agents: Arc<RwLock<HashMap<AgentID, Arc<dyn RunnableAgent + Send + Sync + 'static>>>>,
65}
66
67impl Session {
68    pub fn new(channel_buffer: usize) -> (SessionId, Self) {
69        let id = Uuid::new_v4();
70        let (tx_event, rx_event) = mpsc::channel(channel_buffer);
71        (
72            id,
73            Self {
74                id,
75                tx_event,
76                rx_event: Some(rx_event),
77                state: Arc::new(Mutex::new(State::default())),
78                agents: Arc::new(RwLock::new(HashMap::new())),
79            },
80        )
81    }
82
83    pub async fn add_task(&self, task: Task) {
84        let task_clone = task.clone();
85        if self
86            .tx_event
87            .send(Event::NewTask {
88                sub_id: task_clone.submission_id,
89                agent_id: task_clone.agent_id,
90                prompt: task_clone.prompt,
91            })
92            .await
93            .is_err()
94        {
95            panic!("Failed to send NewTask event");
96        }
97        let mut state = self.state.lock().await;
98        state.task_queue.push(task);
99    }
100
101    pub async fn set_current_task(&self, task: Task) {
102        let mut state = self.state.lock().await;
103        state.current_task = Some(task);
104    }
105
106    pub async fn is_task_queue_empty(&self) -> bool {
107        let state = self.state.lock().await;
108        state.task_queue.is_empty()
109    }
110
111    pub async fn get_top_task(&self) -> Option<Task> {
112        let mut state = self.state.lock().await;
113        if state.task_queue.is_empty() {
114            None
115        } else {
116            let task = state.task_queue.remove(0);
117            state.current_task = Some(task.clone());
118            Some(task)
119        }
120    }
121
122    pub async fn get_current_task(&self) -> Option<Task> {
123        let state = self.state.lock().await;
124        state.current_task.clone()
125    }
126
127    pub async fn get_task(&self, sub_id: SubmissionId) -> Option<Task> {
128        let state = self.state.lock().await;
129        state
130            .task_queue
131            .iter()
132            .find(|t| t.submission_id == sub_id)
133            .cloned()
134    }
135
136    pub fn event_sender(&self) -> mpsc::Sender<Event> {
137        self.tx_event.clone()
138    }
139
140    pub fn take_event_receiver(&mut self) -> Option<ReceiverStream<Event>> {
141        self.rx_event.take().map(ReceiverStream::new)
142    }
143
144    pub async fn register_agent_with_id<E>(&self, agent_id: Uuid, agent: BaseAgent<E>)
145    where
146        E: AgentExecutor,
147        E::Output: Into<Value> + Send,
148        E::Error: std::error::Error + Send + Sync + 'static,
149    {
150        let builder = RunnableAgentBuilder::new();
151        let runnable_agent = builder.build(agent);
152        self.agents.write().await.insert(agent_id, runnable_agent);
153    }
154
155    pub async fn run_task(
156        &self,
157        task: Option<Task>,
158        agent_id: AgentID,
159    ) -> Result<AgentRunResult, Error> {
160        let task = task.ok_or_else(|| SessionError::EmptyTask)?;
161        let agent = self
162            .agents
163            .read()
164            .await
165            .get(&agent_id)
166            .ok_or(SessionError::AgentNotFound(agent_id))?
167            .clone();
168
169        let join_handle = agent.spawn_task(task, self.tx_event.clone());
170
171        // Await the task completion:
172        let result: Result<AgentRunResult, Error> = match join_handle.await {
173            Ok(run_result) => run_result,
174            Err(join_err) => Err(SessionError::from(join_err).into()),
175        };
176        result
177    }
178
179    pub async fn run(&self, agent_id: Uuid) -> Result<AgentRunResult, Error> {
180        let task = self.get_top_task().await;
181        self.run_task(task, agent_id).await
182    }
183
184    pub async fn run_all(&self, agent_id: Uuid) -> Result<Vec<AgentRunResult>, Error> {
185        let mut results = Vec::new();
186        while !self.is_task_queue_empty().await {
187            let result = self.run(agent_id).await?;
188            results.push(result);
189        }
190        Ok(results)
191    }
192}
193
194#[derive(Default)]
195pub struct SessionManager {
196    sessions: RwLock<HashMap<SessionId, Arc<Mutex<Session>>>>,
197}
198
199impl SessionManager {
200    pub fn new() -> Self {
201        Self::default()
202    }
203
204    pub async fn create_session(&self, channel_buffer: usize) -> (SessionId, Arc<Mutex<Session>>) {
205        let (id, session) = Session::new(channel_buffer);
206        let session_arc = Arc::new(Mutex::new(session));
207        let mut sessions = self.sessions.write().await;
208        sessions.insert(id, session_arc.clone());
209        (id, session_arc)
210    }
211
212    pub async fn get_session(&self, session_id: &SessionId) -> Option<Arc<Mutex<Session>>> {
213        let sessions = self.sessions.read().await;
214        sessions.get(session_id).cloned()
215    }
216
217    pub async fn get_session_mut(&self, session_id: &SessionId) -> Option<Arc<Mutex<Session>>> {
218        // This can be same as get_session, since the interior mutability is in Mutex
219        let sessions = self.sessions.read().await;
220        sessions.get(session_id).cloned()
221    }
222}