1use super::{Runtime, RuntimeError};
2use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
3use crate::utils::{BoxEventStream, receiver_into_stream};
4use crate::{
5 actor::{AnyActor, Transport},
6 error::Error,
7};
8use async_trait::async_trait;
9use autoagents_protocol::{Event, InternalEvent, RuntimeID};
10use futures_util::StreamExt;
11use log::{debug, error, info, warn};
12use std::{
13 any::{Any, TypeId},
14 collections::HashMap,
15 sync::{
16 Arc,
17 atomic::{AtomicBool, Ordering},
18 },
19};
20use tokio::sync::{Mutex, Notify, RwLock, broadcast, mpsc};
21use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};
22use uuid::Uuid;
23
/// Capacity of the internal event queue consumed by the runtime's event loop.
const DEFAULT_INTERNAL_BUFFER: usize = 1_000;
25
/// Subscribers registered under a single topic name.
#[derive(Debug)]
struct Subscription {
    /// `TypeId` supplied by the first subscriber of the topic; later
    /// subscribe and publish calls are compared against it.
    topic_type: TypeId,
    /// Actors that are sent every message published to the topic.
    actors: Vec<Arc<dyn AnyActor>>,
}
32
/// Runtime that processes queued events one at a time in a single event loop
/// and fans published messages out to subscribed actors.
#[derive(Debug)]
pub struct SingleThreadedRuntime {
    /// Identifier generated with `Uuid::new_v4` when the runtime is built.
    pub id: RuntimeID,
    /// Sending half of the external event channel; receives every protocol
    /// event that is not a `PublishMessage`.
    external_tx: mpsc::Sender<Event>,
    /// Receiving half of the external event channel; handed out at most once
    /// by `take_event_receiver`, after which it is `None`.
    external_rx: Mutex<Option<mpsc::Receiver<Event>>>,
    /// Broadcast sender mirroring the external events to `subscribe_events` streams.
    broadcast_tx: broadcast::Sender<Event>,
    /// Sending half of the internal queue read by the event loop.
    internal_tx: mpsc::Sender<InternalEvent>,
    /// Receiving half of the internal queue; taken by the event loop when it
    /// starts, so the loop can only be started once.
    internal_rx: Mutex<Option<mpsc::Receiver<InternalEvent>>>,
    /// Topic name -> subscribers of that topic.
    subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
    /// Transport used to hand published messages to subscribed actors.
    transport: Arc<dyn Transport>,
    /// Set to `true` when an `InternalEvent::Shutdown` is processed.
    shutdown_flag: Arc<AtomicBool>,
    /// Notified alongside `shutdown_flag`; the event loop waits on it.
    shutdown_notify: Arc<Notify>,
}
53
54impl SingleThreadedRuntime {
55 pub fn new(channel_buffer: Option<usize>) -> Arc<Self> {
56 Self::with_transport(channel_buffer, Arc::new(crate::actor::LocalTransport))
57 }
58
59 pub fn with_transport(
60 channel_buffer: Option<usize>,
61 transport: Arc<dyn Transport>,
62 ) -> Arc<Self> {
63 let id = Uuid::new_v4();
64 let buffer_size = channel_buffer.unwrap_or(DEFAULT_CHANNEL_BUFFER);
65
66 let (external_tx, external_rx) = mpsc::channel(buffer_size);
68 let (internal_tx, internal_rx) = mpsc::channel(DEFAULT_INTERNAL_BUFFER);
69 let (broadcast_tx, _) = broadcast::channel(buffer_size);
70
71 Arc::new(Self {
72 id,
73 external_tx,
74 external_rx: Mutex::new(Some(external_rx)),
75 broadcast_tx,
76 internal_tx,
77 internal_rx: Mutex::new(Some(internal_rx)),
78 subscriptions: Arc::new(RwLock::new(HashMap::new())),
79 transport,
80 shutdown_flag: Arc::new(AtomicBool::new(false)),
81 shutdown_notify: Arc::new(Notify::new()),
82 })
83 }
84
85 async fn process_internal_event(&self, event: InternalEvent) -> Result<(), Error> {
87 debug!("Received internal event: {event:?}");
88 match event {
89 InternalEvent::ProtocolEvent(event) => {
90 self.process_protocol_event(event).await?;
91 }
92 InternalEvent::Shutdown => {
93 self.shutdown_flag.store(true, Ordering::SeqCst);
94 self.shutdown_notify.notify_waiters();
95 }
96 }
97 Ok(())
98 }
99
100 async fn process_protocol_event(&self, event: Event) -> Result<(), Error> {
102 if let Event::PublishMessage {
103 topic_type,
104 topic_name,
105 message,
106 } = event
107 {
108 self.handle_publish_message(&topic_name, topic_type, message)
109 .await?;
110 } else {
111 let _ = self.broadcast_tx.send(event.clone());
113 self.external_tx
114 .send(event)
115 .await
116 .map_err(RuntimeError::EventError)?;
117 }
118 Ok(())
119 }
120
121 async fn handle_publish_message(
123 &self,
124 topic_name: &str,
125 topic_type: TypeId,
126 message: Arc<dyn Any + Send + Sync>,
127 ) -> Result<(), RuntimeError> {
128 debug!("Handling publish event: {topic_name}");
129
130 let subscriptions = self.subscriptions.read().await;
131
132 if let Some(subscription) = subscriptions.get(topic_name) {
133 if subscription.topic_type != topic_type {
135 error!(
136 "Type mismatch for topic '{}': expected {:?}, got {:?}",
137 topic_name, subscription.topic_type, topic_type
138 );
139 return Err(RuntimeError::TopicTypeMismatch(
140 topic_name.to_owned(),
141 topic_type,
142 ));
143 }
144
145 for actor in &subscription.actors {
147 if let Err(e) = self
148 .transport
149 .send(actor.as_ref(), Arc::clone(&message))
150 .await
151 {
152 error!("Failed to send message to subscriber: {e}");
153 }
154 }
155 } else {
156 debug!("No subscribers for topic: {topic_name}");
157 }
158
159 Ok(())
160 }
161
162 async fn handle_subscribe(
164 &self,
165 topic_name: &str,
166 topic_type: TypeId,
167 actor: Arc<dyn AnyActor>,
168 ) -> Result<(), RuntimeError> {
169 info!("Actor subscribing to topic: {topic_name}");
170
171 let mut subscriptions = self.subscriptions.write().await;
172
173 match subscriptions.get_mut(topic_name) {
174 Some(subscription) => {
175 if subscription.topic_type != topic_type {
177 return Err(RuntimeError::TopicTypeMismatch(
178 topic_name.to_string(),
179 subscription.topic_type,
180 ));
181 }
182 subscription.actors.push(actor);
183 }
184 None => {
185 subscriptions.insert(
187 topic_name.to_string(),
188 Subscription {
189 topic_type,
190 actors: vec![actor],
191 },
192 );
193 }
194 }
195
196 Ok(())
197 }
198
199 async fn event_loop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
201 let mut internal_rx = self
202 .internal_rx
203 .lock()
204 .await
205 .take()
206 .ok_or("Internal receiver already taken")?;
207
208 info!("Runtime event loop starting");
209
210 loop {
211 tokio::select! {
212 Some(event) = internal_rx.recv() => {
214 debug!("Processing internal event");
215
216 if matches!(event, InternalEvent::Shutdown) {
218 info!("Received shutdown event");
219 self.process_internal_event(event).await?;
220 break;
221 }
222
223 if let Err(e) = self.process_internal_event(event).await {
224 error!("Error processing internal event: {e}");
225 break;
226 }
227 }
228 _ = self.shutdown_notify.notified() => {
230 if self.shutdown_flag.load(Ordering::SeqCst) {
231 info!("Runtime received shutdown notification");
232 break;
233 }
234 }
235 else => {
237 warn!("Internal event channel closed");
238 break;
239 }
240 }
241 }
242
243 info!("Draining remaining events before shutdown");
245 while let Ok(event) = internal_rx.try_recv() {
246 if let Err(e) = self.process_internal_event(event).await {
247 error!("Error processing event during shutdown: {e}");
248 }
249 }
250
251 info!("Runtime event loop stopped");
252 Ok(())
253 }
254}
255
256#[async_trait]
257impl Runtime for SingleThreadedRuntime {
/// Returns the identifier stored in this runtime's `id` field.
fn id(&self) -> RuntimeID {
    self.id
}
261
/// Registers `actor` as a subscriber of `topic_name` with message type `topic_type`.
///
/// Delegates to `handle_subscribe` and returns its result unchanged.
async fn subscribe_any(
    &self,
    topic_name: &str,
    topic_type: TypeId,
    actor: Arc<dyn AnyActor>,
) -> Result<(), RuntimeError> {
    self.handle_subscribe(topic_name, topic_type, actor).await
}
270
/// Publishes `message` to the subscribers of `topic_name`.
///
/// Delegates directly to `handle_publish_message`, so delivery is awaited by
/// the caller rather than going through the internal event queue. Its result
/// is returned unchanged.
async fn publish_any(
    &self,
    topic_name: &str,
    topic_type: TypeId,
    message: Arc<dyn Any + Send + Sync>,
) -> Result<(), RuntimeError> {
    self.handle_publish_message(topic_name, topic_type, message)
        .await
}
280
281 fn tx(&self) -> mpsc::Sender<Event> {
282 let internal_tx = self.internal_tx.clone();
284 let (interceptor_tx, mut interceptor_rx) = mpsc::channel::<Event>(DEFAULT_CHANNEL_BUFFER);
285
286 tokio::spawn(async move {
287 while let Some(event) = interceptor_rx.recv().await {
288 if let Err(e) = internal_tx.send(InternalEvent::ProtocolEvent(event)).await {
289 error!("Failed to forward event to internal channel: {e}");
290 break;
291 }
292 }
293 });
294
295 interceptor_tx
296 }
297
/// Returns another `Arc` handle to the transport held by this runtime.
async fn transport(&self) -> Arc<dyn Transport> {
    Arc::clone(&self.transport)
}
301
302 async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
303 let mut guard = self.external_rx.lock().await;
304 guard.take().map(receiver_into_stream)
305 }
306
307 async fn subscribe_events(&self) -> BoxEventStream<Event> {
308 let rx = self.broadcast_tx.subscribe();
309 let stream = BroadcastStream::new(rx)
310 .filter_map(|item: Result<Event, BroadcastStreamRecvError>| async move { item.ok() });
311 Box::pin(stream)
312 }
313
/// Logs the runtime id, then runs `event_loop` to completion and returns its result.
async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    info!("Starting SingleThreadedRuntime {}", self.id);
    self.event_loop().await
}
318
319 async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
320 info!("Initiating runtime shutdown for {}", self.id);
321
322 self.internal_tx
324 .send(InternalEvent::Shutdown)
325 .await
326 .map_err(|e| format!("Failed to send shutdown signal: {e}"))?;
327
328 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
330
331 Ok(())
332 }
333}
334
335#[cfg(test)]
336mod tests {
337 use super::*;
338 use crate::actor::{CloneableMessage, Topic};
339 use crate::runtime::{RuntimeConfig, TypedRuntime};
340 use ractor::{Actor, ActorProcessingErr, ActorRef};
341 use tokio::time::{Duration, sleep};
342
/// Message payload used by the runtime tests.
#[derive(Clone, Debug)]
struct TestMessage {
    /// Text that `TestActor` records when it handles the message.
    content: String,
}

// Marker impls; presumably what `Topic`/`publish` require of a message type —
// confirm against the trait definitions.
impl crate::actor::ActorMessage for TestMessage {}
impl CloneableMessage for TestMessage {}
351
352 struct TestActor {
354 received: Arc<Mutex<Vec<String>>>,
355 }
356
357 #[async_trait]
358 impl Actor for TestActor {
359 type Msg = TestMessage;
360 type State = ();
361 type Arguments = Arc<Mutex<Vec<String>>>;
362
363 async fn pre_start(
364 &self,
365 _myself: ActorRef<Self::Msg>,
366 _args: Self::Arguments,
367 ) -> Result<Self::State, ActorProcessingErr> {
368 Ok(())
369 }
370
371 async fn handle(
372 &self,
373 _myself: ActorRef<Self::Msg>,
374 message: Self::Msg,
375 _state: &mut Self::State,
376 ) -> Result<(), ActorProcessingErr> {
377 let mut received = self.received.lock().await;
378 received.push(message.content);
379 Ok(())
380 }
381 }
382
383 #[tokio::test]
384 async fn test_runtime_creation() {
385 let runtime = SingleThreadedRuntime::new(None);
386 assert_ne!(runtime.id(), Uuid::nil());
387 }
388
389 #[tokio::test]
390 async fn test_publish_subscribe_cloneable() {
391 let runtime = SingleThreadedRuntime::new(Some(10));
392 let runtime_handle = runtime.clone();
393
394 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
396
397 let received = Arc::new(Mutex::new(Vec::new()));
399 let (actor_ref, _actor_handle) = Actor::spawn(
400 None,
401 TestActor {
402 received: received.clone(),
403 },
404 received.clone(),
405 )
406 .await
407 .unwrap();
408
409 let topic = Topic::<TestMessage>::new("test_topic");
411 runtime.subscribe(&topic, actor_ref).await.unwrap();
412
413 runtime
415 .publish(
416 &topic,
417 TestMessage {
418 content: "Hello".to_string(),
419 },
420 )
421 .await
422 .unwrap();
423
424 runtime
425 .publish(
426 &topic,
427 TestMessage {
428 content: "World".to_string(),
429 },
430 )
431 .await
432 .unwrap();
433
434 sleep(Duration::from_millis(100)).await;
436
437 let received_msgs = received.lock().await;
439 assert_eq!(received_msgs.len(), 2);
440 assert_eq!(received_msgs[0], "Hello");
441 assert_eq!(received_msgs[1], "World");
442
443 runtime.stop().await.unwrap();
445 runtime_task.abort();
446 }
447
448 #[tokio::test]
449 async fn test_type_safety() {
450 let runtime = SingleThreadedRuntime::new(None);
451 let runtime_handle = runtime.clone();
452
453 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
455
456 let topic_name = "typed_topic";
458 let topic1 = Topic::<TestMessage>::new(topic_name);
459
460 let received = Arc::new(Mutex::new(Vec::new()));
461 let (actor_ref, _) = Actor::spawn(
462 None,
463 TestActor {
464 received: received.clone(),
465 },
466 received.clone(),
467 )
468 .await
469 .unwrap();
470
471 runtime.subscribe(&topic1, actor_ref).await.unwrap();
472
473 sleep(Duration::from_millis(50)).await;
475
476 #[derive(Clone)]
478 struct OtherMessage;
479 impl crate::actor::ActorMessage for OtherMessage {}
480 impl CloneableMessage for OtherMessage {}
481
482 let topic2 = Topic::<OtherMessage>::new(topic_name);
483
484 struct OtherActor;
485 #[async_trait]
486 impl Actor for OtherActor {
487 type Msg = OtherMessage;
488 type State = ();
489 type Arguments = ();
490
491 async fn pre_start(
492 &self,
493 _myself: ActorRef<Self::Msg>,
494 _args: Self::Arguments,
495 ) -> Result<Self::State, ActorProcessingErr> {
496 Ok(())
497 }
498
499 async fn handle(
500 &self,
501 _myself: ActorRef<Self::Msg>,
502 _message: Self::Msg,
503 _state: &mut Self::State,
504 ) -> Result<(), ActorProcessingErr> {
505 Ok(())
506 }
507 }
508
509 let (other_ref, _) = Actor::spawn(None, OtherActor, ()).await.unwrap();
510
511 let result = runtime.subscribe(&topic2, other_ref).await;
513
514 assert!(result.is_err());
516
517 if let Err(RuntimeError::TopicTypeMismatch(topic, _)) = result {
519 assert_eq!(topic, topic_name);
520 } else {
521 panic!("Expected TopicTypeMismatch error");
522 }
523
524 runtime.stop().await.unwrap();
526 runtime_task.abort();
527 }
528
529 #[tokio::test]
530 async fn test_message_ordering() {
531 let runtime = SingleThreadedRuntime::new(Some(10));
532 let runtime_handle = runtime.clone();
533
534 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
536
537 let received = Arc::new(Mutex::new(Vec::new()));
539 let (actor_ref, _actor_handle) = Actor::spawn(
540 None,
541 TestActor {
542 received: received.clone(),
543 },
544 received.clone(),
545 )
546 .await
547 .unwrap();
548
549 let topic = Topic::<TestMessage>::new("order_test");
551 runtime.subscribe(&topic, actor_ref).await.unwrap();
552
553 for i in 0..10 {
555 runtime
556 .publish(
557 &topic,
558 TestMessage {
559 content: format!("Message {i}"),
560 },
561 )
562 .await
563 .unwrap();
564 }
565
566 sleep(Duration::from_millis(200)).await;
568
569 let received_msgs = received.lock().await;
571 assert_eq!(received_msgs.len(), 10);
572
573 for (i, msg) in received_msgs.iter().enumerate() {
574 assert_eq!(msg, &format!("Message {i}"));
575 }
576
577 runtime.stop().await.unwrap();
579 runtime_task.abort();
580 }
581
582 #[tokio::test]
583 async fn test_runtime_multiple_topics() {
584 let runtime = SingleThreadedRuntime::new(Some(10));
585 let runtime_handle = runtime.clone();
586
587 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
589
590 let topic1 = Topic::<TestMessage>::new("topic1");
592 let topic2 = Topic::<TestMessage>::new("topic2");
593
594 let received1 = Arc::new(Mutex::new(Vec::new()));
595 let received2 = Arc::new(Mutex::new(Vec::new()));
596
597 let (actor_ref1, _) = Actor::spawn(
598 None,
599 TestActor {
600 received: received1.clone(),
601 },
602 received1.clone(),
603 )
604 .await
605 .unwrap();
606
607 let (actor_ref2, _) = Actor::spawn(
608 None,
609 TestActor {
610 received: received2.clone(),
611 },
612 received2.clone(),
613 )
614 .await
615 .unwrap();
616
617 runtime.subscribe(&topic1, actor_ref1).await.unwrap();
619 runtime.subscribe(&topic2, actor_ref2).await.unwrap();
620 sleep(Duration::from_millis(50)).await;
621
622 let message1 = TestMessage {
624 content: "topic1_message".to_string(),
625 };
626 runtime.publish(&topic1, message1).await.unwrap();
627 sleep(Duration::from_millis(50)).await;
628
629 let message2 = TestMessage {
631 content: "topic2_message".to_string(),
632 };
633 runtime.publish(&topic2, message2).await.unwrap();
634 sleep(Duration::from_millis(50)).await;
635
636 let received_msgs1 = received1.lock().await;
638 let received_msgs2 = received2.lock().await;
639
640 assert_eq!(received_msgs1.len(), 1);
641 assert_eq!(received_msgs1[0], "topic1_message");
642
643 assert_eq!(received_msgs2.len(), 1);
644 assert_eq!(received_msgs2[0], "topic2_message");
645
646 runtime.stop().await.unwrap();
648 runtime_task.abort();
649 }
650
651 #[tokio::test]
652 async fn test_runtime_subscribe_multiple_actors_same_topic() {
653 let runtime = SingleThreadedRuntime::new(Some(10));
654 let runtime_handle = runtime.clone();
655
656 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
658
659 let topic = Topic::<TestMessage>::new("shared_topic");
660
661 let received1 = Arc::new(Mutex::new(Vec::new()));
662 let received2 = Arc::new(Mutex::new(Vec::new()));
663
664 let (actor_ref1, _) = Actor::spawn(
665 None,
666 TestActor {
667 received: received1.clone(),
668 },
669 received1.clone(),
670 )
671 .await
672 .unwrap();
673
674 let (actor_ref2, _) = Actor::spawn(
675 None,
676 TestActor {
677 received: received2.clone(),
678 },
679 received2.clone(),
680 )
681 .await
682 .unwrap();
683
684 runtime.subscribe(&topic, actor_ref1).await.unwrap();
686 runtime.subscribe(&topic, actor_ref2).await.unwrap();
687 sleep(Duration::from_millis(50)).await;
688
689 let message = TestMessage {
691 content: "broadcast_message".to_string(),
692 };
693 runtime.publish(&topic, message).await.unwrap();
694 sleep(Duration::from_millis(100)).await;
695
696 let received_msgs1 = received1.lock().await;
698 let received_msgs2 = received2.lock().await;
699
700 assert_eq!(received_msgs1.len(), 1);
701 assert_eq!(received_msgs1[0], "broadcast_message");
702
703 assert_eq!(received_msgs2.len(), 1);
704 assert_eq!(received_msgs2[0], "broadcast_message");
705
706 runtime.stop().await.unwrap();
708 runtime_task.abort();
709 }
710
/// `RuntimeConfig` keeps the queue size it was constructed with.
#[test]
fn test_runtime_config_creation() {
    let expected_queue_size = Some(100);
    let config = RuntimeConfig {
        queue_size: expected_queue_size,
    };
    assert_eq!(config.queue_size, expected_queue_size);
}
718
/// Two independently built runtimes get different ids.
#[test]
fn test_runtime_id_generation() {
    let first = SingleThreadedRuntime::new(None);
    let second = SingleThreadedRuntime::new(None);

    assert_ne!(first.id(), second.id());
}
726}