// autoagents_core/runtime/single_threaded.rs

1use super::{Runtime, RuntimeError, Task};
2use crate::{
3    agent::RunnableAgent,
4    error::Error,
5    protocol::{AgentID, Event, RuntimeID},
6};
7use async_trait::async_trait;
8use log::{debug, error, info, warn};
9use std::{
10    collections::HashMap,
11    sync::{
12        atomic::{AtomicBool, Ordering},
13        Arc,
14    },
15};
16use tokio::sync::{mpsc, Mutex, Notify, RwLock};
17use tokio_stream::wrappers::ReceiverStream;
18use uuid::Uuid;
19
/// Capacity used for the external event channel when `new` gets `None`, and for every
/// per-task forwarding channel handed to agents.
const DEFAULT_CHANNEL_BUFFER: usize = 100;
/// Capacity of the runtime's internal routing channel.
const DEFAULT_INTERNAL_BUFFER: usize = 1_000;
22
23/// Internal events that are processed within the runtime
24#[derive(Debug, Clone)]
25pub enum InternalEvent {
26    /// An event from an agent that needs processing
27    AgentEvent(Event),
28    /// Shutdown signal
29    Shutdown,
30}
31
32/// Single-threaded runtime implementation with internal event routing
33#[derive(Debug)]
34pub struct SingleThreadedRuntime {
35    pub id: RuntimeID,
36    // External event channel for application consumption
37    external_tx: mpsc::Sender<Event>,
38    external_rx: Mutex<Option<mpsc::Receiver<Event>>>,
39    // Internal event channel for runtime processing
40    internal_tx: mpsc::Sender<InternalEvent>,
41    internal_rx: Mutex<Option<mpsc::Receiver<InternalEvent>>>,
42    // Agent and subscription management
43    agents: Arc<RwLock<HashMap<AgentID, Arc<dyn RunnableAgent>>>>,
44    subscriptions: Arc<RwLock<HashMap<String, Vec<AgentID>>>>,
45    // Runtime state
46    shutdown_flag: Arc<AtomicBool>,
47    shutdown_notify: Arc<Notify>,
48}
49
50impl SingleThreadedRuntime {
51    pub fn new(channel_buffer: Option<usize>) -> Arc<Self> {
52        let id = Uuid::new_v4();
53        let buffer_size = channel_buffer.unwrap_or(DEFAULT_CHANNEL_BUFFER);
54
55        // Create channels
56        let (external_tx, external_rx) = mpsc::channel(buffer_size);
57        let (internal_tx, internal_rx) = mpsc::channel(DEFAULT_INTERNAL_BUFFER);
58
59        Arc::new(Self {
60            id,
61            external_tx,
62            external_rx: Mutex::new(Some(external_rx)),
63            internal_tx,
64            internal_rx: Mutex::new(Some(internal_rx)),
65            agents: Arc::new(RwLock::new(HashMap::new())),
66            subscriptions: Arc::new(RwLock::new(HashMap::new())),
67            shutdown_flag: Arc::new(AtomicBool::new(false)),
68            shutdown_notify: Arc::new(Notify::new()),
69        })
70    }
71
72    /// Creates an event sender that intercepts specific events for internal processing
73    fn create_intercepting_sender(&self) -> mpsc::Sender<Event> {
74        let internal_tx = self.internal_tx.clone();
75        let (interceptor_tx, mut interceptor_rx) = mpsc::channel(DEFAULT_CHANNEL_BUFFER);
76
77        tokio::spawn(async move {
78            while let Some(event) = interceptor_rx.recv().await {
79                if let Err(e) = internal_tx.send(InternalEvent::AgentEvent(event)).await {
80                    error!("Failed to forward event to internal channel: {e}");
81                    break;
82                }
83            }
84        });
85
86        interceptor_tx
87    }
88
89    async fn process_internal_event(&self, event: InternalEvent) -> Result<(), Error> {
90        match event {
91            InternalEvent::AgentEvent(event) => {
92                self.process_agent_event(event).await?;
93            }
94            InternalEvent::Shutdown => {
95                self.shutdown_flag.store(true, Ordering::SeqCst);
96                self.shutdown_notify.notify_waiters();
97            }
98        }
99        Ok(())
100    }
101
102    async fn process_agent_event(&self, event: Event) -> Result<(), Error> {
103        match event {
104            Event::PublishMessage { topic, message } => {
105                debug!("Processing publish message to topic: {topic}");
106                self.handle_publish_message(topic, message).await?;
107            }
108            Event::SendMessage { agent_id, message } => {
109                debug!("Processing send message to agent: {agent_id:?}");
110                self.handle_send_message(agent_id, message).await?;
111            }
112            _ => {
113                // All other events are forwarded to external channel
114                self.external_tx
115                    .send(event)
116                    .await
117                    .map_err(RuntimeError::EventError)?;
118            }
119        }
120        Ok(())
121    }
122
123    async fn handle_publish_message(&self, topic: String, message: String) -> Result<(), Error> {
124        let subscriptions = self.subscriptions.read().await;
125
126        if let Some(agents) = subscriptions.get(&topic) {
127            debug!(
128                "Publishing message to topic '{}' with {} subscribers",
129                topic,
130                agents.len()
131            );
132
133            for agent_id in agents {
134                let task = Task::new(message.clone(), Some(*agent_id));
135                self.execute_task_on_agent(*agent_id, task).await?;
136            }
137        } else {
138            debug!("No subscribers for topic: {topic}");
139        }
140
141        Ok(())
142    }
143
144    async fn handle_send_message(&self, agent_id: AgentID, message: String) -> Result<(), Error> {
145        let task = Task::new(message, Some(agent_id));
146        self.execute_task_on_agent(agent_id, task).await
147    }
148
149    async fn execute_task_on_agent(&self, agent_id: AgentID, task: Task) -> Result<(), Error> {
150        let agents = self.agents.read().await;
151
152        if let Some(agent) = agents.get(&agent_id) {
153            debug!("Executing task on agent: {agent_id:?}");
154
155            // Create a new task event and send it to external channel first
156            self.external_tx
157                .send(Event::NewTask {
158                    agent_id,
159                    task: task.clone(),
160                })
161                .await
162                .map_err(RuntimeError::EventError)?;
163
164            // Create intercepting sender for this agent
165            let tx = self.create_intercepting_sender();
166
167            // Use spawn_task for async execution
168            agent.clone().spawn_task(task, tx);
169        } else {
170            warn!("Agent not found: {agent_id:?}");
171            return Err(RuntimeError::AgentNotFound(agent_id).into());
172        }
173
174        Ok(())
175    }
176}
177
178#[async_trait]
179impl Runtime for SingleThreadedRuntime {
180    fn id(&self) -> RuntimeID {
181        self.id
182    }
183
184    async fn publish_message(&self, message: String, topic: String) -> Result<(), Error> {
185        debug!(
186            "Runtime received publish_message request for topic: {}",
187            topic
188        );
189
190        // Send the publish event through internal channel
191        self.internal_tx
192            .send(InternalEvent::AgentEvent(Event::PublishMessage {
193                topic,
194                message,
195            }))
196            .await
197            .map_err(RuntimeError::InternalEventError)?;
198
199        Ok(())
200    }
201
202    async fn send_message(&self, message: String, agent_id: AgentID) -> Result<(), Error> {
203        debug!(
204            "Runtime received send_message request to agent: {:?}",
205            agent_id
206        );
207
208        // Send the event through internal channel
209        self.internal_tx
210            .send(InternalEvent::AgentEvent(Event::SendMessage {
211                agent_id,
212                message,
213            }))
214            .await
215            .map_err(RuntimeError::InternalEventError)?;
216
217        Ok(())
218    }
219
220    async fn register_agent(&self, agent: Arc<dyn RunnableAgent>) -> Result<(), Error> {
221        let agent_id = agent.id();
222        info!("Registering agent: {:?}", agent_id);
223
224        self.agents.write().await.insert(agent_id, agent);
225        Ok(())
226    }
227
228    async fn subscribe(&self, agent_id: AgentID, topic: String) -> Result<(), Error> {
229        info!("Agent {:?} subscribing to topic: {}", agent_id, topic);
230
231        let mut subscriptions = self.subscriptions.write().await;
232        let agents = subscriptions.entry(topic).or_insert_with(Vec::new);
233
234        if !agents.contains(&agent_id) {
235            agents.push(agent_id);
236        }
237
238        Ok(())
239    }
240
241    async fn take_event_receiver(&self) -> Option<ReceiverStream<Event>> {
242        self.external_rx
243            .lock()
244            .await
245            .take()
246            .map(ReceiverStream::new)
247    }
248
249    async fn run(&self) -> Result<(), Error> {
250        info!("Runtime starting");
251
252        // Take the internal receiver
253        let mut internal_rx = self
254            .internal_rx
255            .lock()
256            .await
257            .take()
258            .ok_or(RuntimeError::EmptyTask)?;
259
260        // Process events until shutdown
261        loop {
262            tokio::select! {
263                // Process internal events
264                Some(event) = internal_rx.recv() => {
265                    debug!("Processing internal event: {event:?}");
266                    if let Err(e) = self.process_internal_event(event).await {
267                        error!("Error processing internal event: {e}");
268                    }
269                }
270                // Check for shutdown
271                _ = self.shutdown_notify.notified() => {
272                    if self.shutdown_flag.load(Ordering::SeqCst) {
273                        info!("Runtime received shutdown signal");
274                        break;
275                    }
276                }
277            }
278        }
279
280        // Drain remaining events
281        info!("Draining remaining events before shutdown");
282        while let Ok(event) = internal_rx.try_recv() {
283            if let Err(e) = self.process_internal_event(event).await {
284                error!("Error processing event during shutdown: {e}");
285            }
286        }
287
288        info!("Runtime stopped");
289        Ok(())
290    }
291
292    async fn stop(&self) -> Result<(), Error> {
293        info!("Initiating runtime shutdown");
294
295        // Send shutdown signal
296        let _ = self.internal_tx.send(InternalEvent::Shutdown).await;
297
298        Ok(())
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryProvider;
    use crate::protocol::TaskResult;
    use tokio::time::{sleep, Duration};

    /// Agent stub: for each task it emits `TaskStarted`, waits briefly, then emits
    /// `TaskComplete` with a fixed JSON payload.
    #[derive(Debug, Clone)]
    struct MockAgent {
        id: AgentID,
    }

    #[async_trait]
    impl RunnableAgent for MockAgent {
        fn id(&self) -> AgentID {
            self.id
        }

        fn name(&self) -> &'static str {
            "test"
        }

        fn description(&self) -> &'static str {
            "test"
        }

        fn memory(&self) -> Option<Arc<RwLock<Box<dyn MemoryProvider>>>> {
            None
        }

        async fn run(self: Arc<Self>, task: Task, tx: mpsc::Sender<Event>) -> Result<(), Error> {
            tx.send(Event::TaskStarted {
                sub_id: task.submission_id,
                agent_id: self.id,
                task_description: task.prompt.clone(),
            })
            .await
            .unwrap();

            // Simulate some work.
            sleep(Duration::from_millis(10)).await;

            tx.send(Event::TaskComplete {
                sub_id: task.submission_id,
                result: TaskResult::Value(serde_json::json!({
                    "message": "Task completed successfully"
                })),
            })
            .await
            .unwrap();

            Ok(())
        }
    }

    #[tokio::test]
    async fn test_runtime_creation() {
        let runtime = SingleThreadedRuntime::new(None);
        assert_ne!(runtime.id(), Uuid::nil());
    }

    #[tokio::test]
    async fn test_agent_registration() {
        let runtime = SingleThreadedRuntime::new(None);
        let agent = Arc::new(MockAgent { id: Uuid::new_v4() });

        runtime.register_agent(agent.clone()).await.unwrap();

        let agents = runtime.agents.read().await;
        assert!(agents.contains_key(&agent.id()));
    }

    #[tokio::test]
    async fn test_subscription() {
        let runtime = SingleThreadedRuntime::new(None);
        let agent_id = Uuid::new_v4();
        let topic = "test_topic".to_string();

        runtime.subscribe(agent_id, topic.clone()).await.unwrap();

        let subscriptions = runtime.subscriptions.read().await;
        assert!(subscriptions.contains_key(&topic));
        assert!(subscriptions.get(&topic).unwrap().contains(&agent_id));
    }

    /// Subscribing the same agent to the same topic twice must not duplicate the entry.
    #[tokio::test]
    async fn test_duplicate_subscription_is_ignored() {
        let runtime = SingleThreadedRuntime::new(None);
        let agent_id = Uuid::new_v4();
        let topic = "test_topic".to_string();

        runtime.subscribe(agent_id, topic.clone()).await.unwrap();
        runtime.subscribe(agent_id, topic.clone()).await.unwrap();

        let subscriptions = runtime.subscriptions.read().await;
        assert_eq!(subscriptions.get(&topic).map(Vec::len), Some(1));
    }

    /// The external event stream can be handed out exactly once.
    #[tokio::test]
    async fn test_event_receiver_taken_once() {
        let runtime = SingleThreadedRuntime::new(None);
        assert!(runtime.take_event_receiver().await.is_some());
        assert!(runtime.take_event_receiver().await.is_none());
    }
}
388}