// autoagents_core/runtime/single_threaded.rs

1use super::{Runtime, RuntimeError, Task};
2use crate::{
3    agent::RunnableAgent,
4    error::Error,
5    protocol::{AgentID, Event, RuntimeID},
6};
7use async_trait::async_trait;
8use log::{debug, error, info, warn};
9use std::{
10    collections::HashMap,
11    sync::{
12        atomic::{AtomicBool, Ordering},
13        Arc,
14    },
15};
16use tokio::sync::{mpsc, Mutex, Notify, RwLock};
17use tokio_stream::wrappers::ReceiverStream;
18use uuid::Uuid;
19
/// Default capacity of the external (application-facing) event channel; also used for the
/// per-task forwarding channels created by the runtime.
const DEFAULT_CHANNEL_BUFFER: usize = 100;
/// Capacity of the internal routing channel drained by the runtime's event loop.
const DEFAULT_INTERNAL_BUFFER: usize = 1_000;
22
23/// Internal events that are processed within the runtime
24#[derive(Debug, Clone)]
25enum InternalEvent {
26    /// An event from an agent that needs processing
27    AgentEvent(Event),
28    /// Shutdown signal
29    Shutdown,
30}
31
32/// Single-threaded runtime implementation with internal event routing
33#[derive(Debug)]
34pub struct SingleThreadedRuntime {
35    pub id: RuntimeID,
36    // External event channel for application consumption
37    external_tx: mpsc::Sender<Event>,
38    external_rx: Mutex<Option<mpsc::Receiver<Event>>>,
39    // Internal event channel for runtime processing
40    internal_tx: mpsc::Sender<InternalEvent>,
41    internal_rx: Mutex<Option<mpsc::Receiver<InternalEvent>>>,
42    // Agent and subscription management
43    agents: Arc<RwLock<HashMap<AgentID, Arc<dyn RunnableAgent>>>>,
44    subscriptions: Arc<RwLock<HashMap<String, Vec<AgentID>>>>,
45    // Runtime state
46    shutdown_flag: Arc<AtomicBool>,
47    shutdown_notify: Arc<Notify>,
48}
49
50impl SingleThreadedRuntime {
51    pub fn new(channel_buffer: Option<usize>) -> Arc<Self> {
52        let id = Uuid::new_v4();
53        let buffer_size = channel_buffer.unwrap_or(DEFAULT_CHANNEL_BUFFER);
54
55        // Create channels
56        let (external_tx, external_rx) = mpsc::channel(buffer_size);
57        let (internal_tx, internal_rx) = mpsc::channel(DEFAULT_INTERNAL_BUFFER);
58
59        Arc::new(Self {
60            id,
61            external_tx,
62            external_rx: Mutex::new(Some(external_rx)),
63            internal_tx,
64            internal_rx: Mutex::new(Some(internal_rx)),
65            agents: Arc::new(RwLock::new(HashMap::new())),
66            subscriptions: Arc::new(RwLock::new(HashMap::new())),
67            shutdown_flag: Arc::new(AtomicBool::new(false)),
68            shutdown_notify: Arc::new(Notify::new()),
69        })
70    }
71
72    /// Creates an event sender that intercepts specific events for internal processing
73    fn create_intercepting_sender(&self) -> mpsc::Sender<Event> {
74        let internal_tx = self.internal_tx.clone();
75        let (interceptor_tx, mut interceptor_rx) = mpsc::channel(DEFAULT_CHANNEL_BUFFER);
76
77        tokio::spawn(async move {
78            while let Some(event) = interceptor_rx.recv().await {
79                if let Err(e) = internal_tx.send(InternalEvent::AgentEvent(event)).await {
80                    error!("Failed to forward event to internal channel: {e}");
81                    break;
82                }
83            }
84        });
85
86        interceptor_tx
87    }
88
89    async fn process_internal_event(&self, event: InternalEvent) -> Result<(), Error> {
90        match event {
91            InternalEvent::AgentEvent(event) => {
92                self.process_agent_event(event).await?;
93            }
94            InternalEvent::Shutdown => {
95                self.shutdown_flag.store(true, Ordering::SeqCst);
96                self.shutdown_notify.notify_waiters();
97            }
98        }
99        Ok(())
100    }
101
102    async fn process_agent_event(&self, event: Event) -> Result<(), Error> {
103        match event {
104            Event::PublishMessage { topic, message } => {
105                debug!("Processing publish message to topic: {topic}");
106                self.handle_publish_message(topic, message).await?;
107            }
108            Event::SendMessage { agent_id, message } => {
109                debug!("Processing send message to agent: {agent_id:?}");
110                self.handle_send_message(agent_id, message).await?;
111            }
112            _ => {
113                // All other events are forwarded to external channel
114                self.external_tx
115                    .send(event)
116                    .await
117                    .map_err(|_| RuntimeError::EmptyTask)?;
118            }
119        }
120        Ok(())
121    }
122
123    async fn handle_publish_message(&self, topic: String, message: String) -> Result<(), Error> {
124        let subscriptions = self.subscriptions.read().await;
125
126        if let Some(agents) = subscriptions.get(&topic) {
127            debug!(
128                "Publishing message to topic '{}' with {} subscribers",
129                topic,
130                agents.len()
131            );
132
133            for agent_id in agents {
134                let task = Task::new(message.clone(), Some(*agent_id));
135                self.execute_task_on_agent(*agent_id, task).await?;
136            }
137        } else {
138            debug!("No subscribers for topic: {topic}");
139        }
140
141        Ok(())
142    }
143
144    async fn handle_send_message(&self, agent_id: AgentID, message: String) -> Result<(), Error> {
145        let task = Task::new(message, Some(agent_id));
146        self.execute_task_on_agent(agent_id, task).await
147    }
148
149    async fn execute_task_on_agent(&self, agent_id: AgentID, task: Task) -> Result<(), Error> {
150        let agents = self.agents.read().await;
151
152        if let Some(agent) = agents.get(&agent_id) {
153            debug!("Executing task on agent: {agent_id:?}");
154
155            // Create a new task event and send it to external channel first
156            self.external_tx
157                .send(Event::NewTask {
158                    agent_id,
159                    task: task.clone(),
160                })
161                .await
162                .map_err(|_| RuntimeError::EmptyTask)?;
163
164            // Create intercepting sender for this agent
165            let tx = self.create_intercepting_sender();
166
167            // Use spawn_task for async execution
168            agent.clone().spawn_task(task, tx);
169        } else {
170            warn!("Agent not found: {agent_id:?}");
171            return Err(RuntimeError::AgentNotFound(agent_id).into());
172        }
173
174        Ok(())
175    }
176}
177
178#[async_trait]
179impl Runtime for SingleThreadedRuntime {
180    fn id(&self) -> RuntimeID {
181        self.id
182    }
183
184    async fn publish_message(&self, message: String, topic: String) -> Result<(), Error> {
185        debug!(
186            "Runtime received publish_message request for topic: {}",
187            topic
188        );
189
190        // Send the publish event through internal channel
191        self.internal_tx
192            .send(InternalEvent::AgentEvent(Event::PublishMessage {
193                topic,
194                message,
195            }))
196            .await
197            .map_err(|_| RuntimeError::EmptyTask)?;
198
199        Ok(())
200    }
201
202    async fn send_message(&self, message: String, agent_id: AgentID) -> Result<(), Error> {
203        debug!(
204            "Runtime received send_message request to agent: {:?}",
205            agent_id
206        );
207
208        // Send the event through internal channel
209        self.internal_tx
210            .send(InternalEvent::AgentEvent(Event::SendMessage {
211                agent_id,
212                message,
213            }))
214            .await
215            .map_err(|_| RuntimeError::EmptyTask)?;
216
217        Ok(())
218    }
219
220    async fn register_agent(&self, agent: Arc<dyn RunnableAgent>) -> Result<(), Error> {
221        let agent_id = agent.id();
222        info!("Registering agent: {:?}", agent_id);
223
224        self.agents.write().await.insert(agent_id, agent);
225        Ok(())
226    }
227
228    async fn subscribe(&self, agent_id: AgentID, topic: String) -> Result<(), Error> {
229        info!("Agent {:?} subscribing to topic: {}", agent_id, topic);
230
231        let mut subscriptions = self.subscriptions.write().await;
232        let agents = subscriptions.entry(topic).or_insert_with(Vec::new);
233
234        if !agents.contains(&agent_id) {
235            agents.push(agent_id);
236        }
237
238        Ok(())
239    }
240
241    async fn take_event_receiver(&self) -> Option<ReceiverStream<Event>> {
242        self.external_rx
243            .lock()
244            .await
245            .take()
246            .map(ReceiverStream::new)
247    }
248
249    async fn run(&self) -> Result<(), Error> {
250        info!("Runtime starting");
251
252        // Take the internal receiver
253        let mut internal_rx = self
254            .internal_rx
255            .lock()
256            .await
257            .take()
258            .ok_or(RuntimeError::EmptyTask)?;
259
260        // Process events until shutdown
261        loop {
262            tokio::select! {
263                // Process internal events
264                Some(event) = internal_rx.recv() => {
265                    if let Err(e) = self.process_internal_event(event).await {
266                        error!("Error processing internal event: {e}");
267                    }
268                }
269                // Check for shutdown
270                _ = self.shutdown_notify.notified() => {
271                    if self.shutdown_flag.load(Ordering::SeqCst) {
272                        info!("Runtime received shutdown signal");
273                        break;
274                    }
275                }
276            }
277        }
278
279        // Drain remaining events
280        info!("Draining remaining events before shutdown");
281        while let Ok(event) = internal_rx.try_recv() {
282            if let Err(e) = self.process_internal_event(event).await {
283                error!("Error processing event during shutdown: {e}");
284            }
285        }
286
287        info!("Runtime stopped");
288        Ok(())
289    }
290
291    async fn stop(&self) -> Result<(), Error> {
292        info!("Initiating runtime shutdown");
293
294        // Send shutdown signal
295        let _ = self.internal_tx.send(InternalEvent::Shutdown).await;
296
297        Ok(())
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryProvider;
    use crate::protocol::TaskResult;
    use tokio::time::{sleep, Duration};

    /// Minimal agent for the tests: reports that it started, waits briefly, then reports
    /// completion.
    #[derive(Debug, Clone)]
    struct MockAgent {
        id: AgentID,
    }

    #[async_trait]
    impl RunnableAgent for MockAgent {
        fn id(&self) -> AgentID {
            self.id
        }

        fn name(&self) -> &'static str {
            "test"
        }

        fn description(&self) -> &'static str {
            "test"
        }

        fn memory(&self) -> Option<Arc<RwLock<Box<dyn MemoryProvider>>>> {
            None
        }

        async fn run(self: Arc<Self>, task: Task, tx: mpsc::Sender<Event>) -> Result<(), Error> {
            let started = Event::TaskStarted {
                sub_id: task.submission_id,
                agent_id: self.id,
                task_description: task.prompt.clone(),
            };
            tx.send(started).await.unwrap();

            // Pretend to do a little work before finishing.
            sleep(Duration::from_millis(10)).await;

            let payload = serde_json::json!({
                "message": "Task completed successfully"
            });
            let completed = Event::TaskComplete {
                sub_id: task.submission_id,
                result: TaskResult::Value(payload),
            };
            tx.send(completed).await.unwrap();

            Ok(())
        }
    }

    #[tokio::test]
    async fn test_runtime_creation() {
        let rt = SingleThreadedRuntime::new(None);
        let runtime_id = rt.id();
        assert_ne!(runtime_id, Uuid::nil());
    }

    #[tokio::test]
    async fn test_agent_registration() {
        let rt = SingleThreadedRuntime::new(None);
        let mock = Arc::new(MockAgent { id: Uuid::new_v4() });
        let expected_id = mock.id();

        rt.register_agent(mock).await.unwrap();

        let registry = rt.agents.read().await;
        assert!(registry.contains_key(&expected_id));
    }

    #[tokio::test]
    async fn test_subscription() {
        let rt = SingleThreadedRuntime::new(None);
        let subscriber = Uuid::new_v4();
        let topic = String::from("test_topic");

        rt.subscribe(subscriber, topic.clone()).await.unwrap();

        let subs = rt.subscriptions.read().await;
        assert!(subs.contains_key(&topic));
        let members = subs.get(&topic).unwrap();
        assert!(members.contains(&subscriber));
    }
}