1use super::{Runtime, RuntimeError, Task};
2use crate::{
3 agent::RunnableAgent,
4 error::Error,
5 protocol::{AgentID, Event, RuntimeID},
6};
7use async_trait::async_trait;
8use log::{debug, error, info, warn};
9use std::{
10 collections::HashMap,
11 sync::{
12 atomic::{AtomicBool, Ordering},
13 Arc,
14 },
15};
16use tokio::sync::{mpsc, Mutex, Notify, RwLock};
17use tokio_stream::wrappers::ReceiverStream;
18use uuid::Uuid;
19
/// Capacity used for the external event channel when `new` is given `None`, and for each
/// per-task interceptor channel.
const DEFAULT_CHANNEL_BUFFER: usize = 100;
/// Capacity of the runtime's internal event channel (not configurable through `new`).
const DEFAULT_INTERNAL_BUFFER: usize = 1_000;
22
23#[derive(Debug, Clone)]
25pub enum InternalEvent {
26 AgentEvent(Event),
28 Shutdown,
30}
31
32#[derive(Debug)]
34pub struct SingleThreadedRuntime {
35 pub id: RuntimeID,
36 external_tx: mpsc::Sender<Event>,
38 external_rx: Mutex<Option<mpsc::Receiver<Event>>>,
39 internal_tx: mpsc::Sender<InternalEvent>,
41 internal_rx: Mutex<Option<mpsc::Receiver<InternalEvent>>>,
42 agents: Arc<RwLock<HashMap<AgentID, Arc<dyn RunnableAgent>>>>,
44 subscriptions: Arc<RwLock<HashMap<String, Vec<AgentID>>>>,
45 shutdown_flag: Arc<AtomicBool>,
47 shutdown_notify: Arc<Notify>,
48}
49
50impl SingleThreadedRuntime {
51 pub fn new(channel_buffer: Option<usize>) -> Arc<Self> {
52 let id = Uuid::new_v4();
53 let buffer_size = channel_buffer.unwrap_or(DEFAULT_CHANNEL_BUFFER);
54
55 let (external_tx, external_rx) = mpsc::channel(buffer_size);
57 let (internal_tx, internal_rx) = mpsc::channel(DEFAULT_INTERNAL_BUFFER);
58
59 Arc::new(Self {
60 id,
61 external_tx,
62 external_rx: Mutex::new(Some(external_rx)),
63 internal_tx,
64 internal_rx: Mutex::new(Some(internal_rx)),
65 agents: Arc::new(RwLock::new(HashMap::new())),
66 subscriptions: Arc::new(RwLock::new(HashMap::new())),
67 shutdown_flag: Arc::new(AtomicBool::new(false)),
68 shutdown_notify: Arc::new(Notify::new()),
69 })
70 }
71
72 fn create_intercepting_sender(&self) -> mpsc::Sender<Event> {
74 let internal_tx = self.internal_tx.clone();
75 let (interceptor_tx, mut interceptor_rx) = mpsc::channel(DEFAULT_CHANNEL_BUFFER);
76
77 tokio::spawn(async move {
78 while let Some(event) = interceptor_rx.recv().await {
79 if let Err(e) = internal_tx.send(InternalEvent::AgentEvent(event)).await {
80 error!("Failed to forward event to internal channel: {e}");
81 break;
82 }
83 }
84 });
85
86 interceptor_tx
87 }
88
89 async fn process_internal_event(&self, event: InternalEvent) -> Result<(), Error> {
90 match event {
91 InternalEvent::AgentEvent(event) => {
92 self.process_agent_event(event).await?;
93 }
94 InternalEvent::Shutdown => {
95 self.shutdown_flag.store(true, Ordering::SeqCst);
96 self.shutdown_notify.notify_waiters();
97 }
98 }
99 Ok(())
100 }
101
102 async fn process_agent_event(&self, event: Event) -> Result<(), Error> {
103 match event {
104 Event::PublishMessage { topic, message } => {
105 debug!("Processing publish message to topic: {topic}");
106 self.handle_publish_message(topic, message).await?;
107 }
108 Event::SendMessage { agent_id, message } => {
109 debug!("Processing send message to agent: {agent_id:?}");
110 self.handle_send_message(agent_id, message).await?;
111 }
112 _ => {
113 self.external_tx
115 .send(event)
116 .await
117 .map_err(RuntimeError::EventError)?;
118 }
119 }
120 Ok(())
121 }
122
123 async fn handle_publish_message(&self, topic: String, message: String) -> Result<(), Error> {
124 let subscriptions = self.subscriptions.read().await;
125
126 if let Some(agents) = subscriptions.get(&topic) {
127 debug!(
128 "Publishing message to topic '{}' with {} subscribers",
129 topic,
130 agents.len()
131 );
132
133 for agent_id in agents {
134 let task = Task::new(message.clone(), Some(*agent_id));
135 self.execute_task_on_agent(*agent_id, task).await?;
136 }
137 } else {
138 debug!("No subscribers for topic: {topic}");
139 }
140
141 Ok(())
142 }
143
144 async fn handle_send_message(&self, agent_id: AgentID, message: String) -> Result<(), Error> {
145 let task = Task::new(message, Some(agent_id));
146 self.execute_task_on_agent(agent_id, task).await
147 }
148
149 async fn execute_task_on_agent(&self, agent_id: AgentID, task: Task) -> Result<(), Error> {
150 let agents = self.agents.read().await;
151
152 if let Some(agent) = agents.get(&agent_id) {
153 debug!("Executing task on agent: {agent_id:?}");
154
155 self.external_tx
157 .send(Event::NewTask {
158 agent_id,
159 task: task.clone(),
160 })
161 .await
162 .map_err(RuntimeError::EventError)?;
163
164 let tx = self.create_intercepting_sender();
166
167 agent.clone().spawn_task(task, tx);
169 } else {
170 warn!("Agent not found: {agent_id:?}");
171 return Err(RuntimeError::AgentNotFound(agent_id).into());
172 }
173
174 Ok(())
175 }
176}
177
178#[async_trait]
179impl Runtime for SingleThreadedRuntime {
180 fn id(&self) -> RuntimeID {
181 self.id
182 }
183
184 async fn publish_message(&self, message: String, topic: String) -> Result<(), Error> {
185 debug!(
186 "Runtime received publish_message request for topic: {}",
187 topic
188 );
189
190 self.internal_tx
192 .send(InternalEvent::AgentEvent(Event::PublishMessage {
193 topic,
194 message,
195 }))
196 .await
197 .map_err(RuntimeError::InternalEventError)?;
198
199 Ok(())
200 }
201
202 async fn send_message(&self, message: String, agent_id: AgentID) -> Result<(), Error> {
203 debug!(
204 "Runtime received send_message request to agent: {:?}",
205 agent_id
206 );
207
208 self.internal_tx
210 .send(InternalEvent::AgentEvent(Event::SendMessage {
211 agent_id,
212 message,
213 }))
214 .await
215 .map_err(RuntimeError::InternalEventError)?;
216
217 Ok(())
218 }
219
220 async fn register_agent(&self, agent: Arc<dyn RunnableAgent>) -> Result<(), Error> {
221 let agent_id = agent.id();
222 info!("Registering agent: {:?}", agent_id);
223
224 self.agents.write().await.insert(agent_id, agent);
225 Ok(())
226 }
227
228 async fn subscribe(&self, agent_id: AgentID, topic: String) -> Result<(), Error> {
229 info!("Agent {:?} subscribing to topic: {}", agent_id, topic);
230
231 let mut subscriptions = self.subscriptions.write().await;
232 let agents = subscriptions.entry(topic).or_insert_with(Vec::new);
233
234 if !agents.contains(&agent_id) {
235 agents.push(agent_id);
236 }
237
238 Ok(())
239 }
240
241 async fn take_event_receiver(&self) -> Option<ReceiverStream<Event>> {
242 self.external_rx
243 .lock()
244 .await
245 .take()
246 .map(ReceiverStream::new)
247 }
248
249 async fn run(&self) -> Result<(), Error> {
250 info!("Runtime starting");
251
252 let mut internal_rx = self
254 .internal_rx
255 .lock()
256 .await
257 .take()
258 .ok_or(RuntimeError::EmptyTask)?;
259
260 loop {
262 tokio::select! {
263 Some(event) = internal_rx.recv() => {
265 debug!("Processing internal event: {event:?}");
266 if let Err(e) = self.process_internal_event(event).await {
267 error!("Error processing internal event: {e}");
268 }
269 }
270 _ = self.shutdown_notify.notified() => {
272 if self.shutdown_flag.load(Ordering::SeqCst) {
273 info!("Runtime received shutdown signal");
274 break;
275 }
276 }
277 }
278 }
279
280 info!("Draining remaining events before shutdown");
282 while let Ok(event) = internal_rx.try_recv() {
283 if let Err(e) = self.process_internal_event(event).await {
284 error!("Error processing event during shutdown: {e}");
285 }
286 }
287
288 info!("Runtime stopped");
289 Ok(())
290 }
291
292 async fn stop(&self) -> Result<(), Error> {
293 info!("Initiating runtime shutdown");
294
295 let _ = self.internal_tx.send(InternalEvent::Shutdown).await;
297
298 Ok(())
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryProvider;
    use crate::protocol::TaskResult;
    use tokio::time::{sleep, Duration};

    /// Minimal agent: reports that the task started, waits briefly, then reports completion.
    #[derive(Debug, Clone)]
    struct MockAgent {
        id: AgentID,
    }

    #[async_trait]
    impl RunnableAgent for MockAgent {
        fn id(&self) -> AgentID {
            self.id
        }

        fn name(&self) -> &'static str {
            "test"
        }

        fn description(&self) -> &'static str {
            "test"
        }

        fn memory(&self) -> Option<Arc<RwLock<Box<dyn MemoryProvider>>>> {
            None
        }

        async fn run(self: Arc<Self>, task: Task, tx: mpsc::Sender<Event>) -> Result<(), Error> {
            let submission = task.submission_id;

            let started = Event::TaskStarted {
                sub_id: submission,
                agent_id: self.id,
                task_description: task.prompt.clone(),
            };
            tx.send(started).await.unwrap();

            sleep(Duration::from_millis(10)).await;

            let completed = Event::TaskComplete {
                sub_id: submission,
                result: TaskResult::Value(serde_json::json!({
                    "message": "Task completed successfully"
                })),
            };
            tx.send(completed).await.unwrap();

            Ok(())
        }
    }

    #[tokio::test]
    async fn test_runtime_creation() {
        let rt = SingleThreadedRuntime::new(None);
        assert_ne!(rt.id(), Uuid::nil());
    }

    #[tokio::test]
    async fn test_agent_registration() {
        let rt = SingleThreadedRuntime::new(None);
        let mock = Arc::new(MockAgent { id: Uuid::new_v4() });
        let expected_id = mock.id();

        rt.register_agent(mock).await.unwrap();

        let registry = rt.agents.read().await;
        assert!(registry.contains_key(&expected_id));
    }

    #[tokio::test]
    async fn test_subscription() {
        let rt = SingleThreadedRuntime::new(None);
        let subscriber = Uuid::new_v4();
        let topic = String::from("test_topic");

        rt.subscribe(subscriber, topic.clone()).await.unwrap();

        let table = rt.subscriptions.read().await;
        assert!(table.contains_key(&topic));
        let subscribers = &table[&topic];
        assert!(subscribers.contains(&subscriber));
    }
}
388}