1use super::{Runtime, RuntimeError};
2use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
3use crate::utils::{BoxEventStream, receiver_into_stream};
4use crate::{
5 actor::{AnyActor, Transport},
6 error::Error,
7};
8use async_trait::async_trait;
9use autoagents_protocol::{Event, InternalEvent, RuntimeID};
10use futures_util::StreamExt;
11use log::{debug, error, info, warn};
12use std::{
13 any::{Any, TypeId},
14 collections::HashMap,
15 sync::{
16 Arc,
17 atomic::{AtomicBool, Ordering},
18 },
19};
20use tokio::sync::{Mutex, Notify, RwLock, broadcast, mpsc};
21use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};
22use uuid::Uuid;
23
/// Capacity of the bounded channel that carries internal events to the event loop.
const DEFAULT_INTERNAL_BUFFER: usize = 1_000;
25
26#[derive(Debug)]
28struct Subscription {
29 topic_type: TypeId,
30 actors: Vec<Arc<dyn AnyActor>>,
31}
32
33#[derive(Debug)]
34pub struct SingleThreadedRuntime {
36 pub id: RuntimeID,
37 external_tx: mpsc::Sender<Event>,
39 external_rx: Mutex<Option<mpsc::Receiver<Event>>>,
40 broadcast_tx: broadcast::Sender<Event>,
42 internal_tx: mpsc::Sender<InternalEvent>,
44 internal_rx: Mutex<Option<mpsc::Receiver<InternalEvent>>>,
45 subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
47 transport: Arc<dyn Transport>,
49 shutdown_flag: Arc<AtomicBool>,
51 shutdown_notify: Arc<Notify>,
52}
53
54impl SingleThreadedRuntime {
55 pub fn new(channel_buffer: Option<usize>) -> Arc<Self> {
56 Self::with_transport(channel_buffer, Arc::new(crate::actor::LocalTransport))
57 }
58
59 pub fn with_transport(
60 channel_buffer: Option<usize>,
61 transport: Arc<dyn Transport>,
62 ) -> Arc<Self> {
63 let id = Uuid::new_v4();
64 let buffer_size = channel_buffer.unwrap_or(DEFAULT_CHANNEL_BUFFER);
65
66 let (external_tx, external_rx) = mpsc::channel(buffer_size);
68 let (internal_tx, internal_rx) = mpsc::channel(DEFAULT_INTERNAL_BUFFER);
69 let (broadcast_tx, _) = broadcast::channel(buffer_size);
70
71 Arc::new(Self {
72 id,
73 external_tx,
74 external_rx: Mutex::new(Some(external_rx)),
75 broadcast_tx,
76 internal_tx,
77 internal_rx: Mutex::new(Some(internal_rx)),
78 subscriptions: Arc::new(RwLock::new(HashMap::new())),
79 transport,
80 shutdown_flag: Arc::new(AtomicBool::new(false)),
81 shutdown_notify: Arc::new(Notify::new()),
82 })
83 }
84
85 async fn process_internal_event(&self, event: InternalEvent) -> Result<(), Error> {
87 debug!("Received internal event: {event:?}");
88 match event {
89 InternalEvent::ProtocolEvent(event) => {
90 self.process_protocol_event(event).await?;
91 }
92 InternalEvent::Shutdown => {
93 self.shutdown_flag.store(true, Ordering::SeqCst);
94 self.shutdown_notify.notify_waiters();
95 }
96 }
97 Ok(())
98 }
99
100 async fn process_protocol_event(&self, event: Event) -> Result<(), Error> {
102 match event {
103 Event::PublishMessage {
104 topic_type,
105 topic_name,
106 message,
107 } => {
108 self.handle_publish_message(&topic_name, topic_type, message)
109 .await?;
110 }
111 _ => {
112 let _ = self.broadcast_tx.send(event.clone());
114 self.external_tx
115 .send(event)
116 .await
117 .map_err(RuntimeError::EventError)?;
118 }
119 }
120 Ok(())
121 }
122
123 async fn handle_publish_message(
125 &self,
126 topic_name: &str,
127 topic_type: TypeId,
128 message: Arc<dyn Any + Send + Sync>,
129 ) -> Result<(), RuntimeError> {
130 debug!("Handling publish event: {topic_name}");
131
132 let subscriptions = self.subscriptions.read().await;
133
134 if let Some(subscription) = subscriptions.get(topic_name) {
135 if subscription.topic_type != topic_type {
137 error!(
138 "Type mismatch for topic '{}': expected {:?}, got {:?}",
139 topic_name, subscription.topic_type, topic_type
140 );
141 return Err(RuntimeError::TopicTypeMismatch(
142 topic_name.to_owned(),
143 topic_type,
144 ));
145 }
146
147 for actor in &subscription.actors {
149 if let Err(e) = self
150 .transport
151 .send(actor.as_ref(), Arc::clone(&message))
152 .await
153 {
154 error!("Failed to send message to subscriber: {e}");
155 }
156 }
157 } else {
158 debug!("No subscribers for topic: {topic_name}");
159 }
160
161 Ok(())
162 }
163
164 async fn handle_subscribe(
166 &self,
167 topic_name: &str,
168 topic_type: TypeId,
169 actor: Arc<dyn AnyActor>,
170 ) -> Result<(), RuntimeError> {
171 info!("Actor subscribing to topic: {topic_name}");
172
173 let mut subscriptions = self.subscriptions.write().await;
174
175 match subscriptions.get_mut(topic_name) {
176 Some(subscription) => {
177 if subscription.topic_type != topic_type {
179 return Err(RuntimeError::TopicTypeMismatch(
180 topic_name.to_string(),
181 subscription.topic_type,
182 ));
183 }
184 subscription.actors.push(actor);
185 }
186 None => {
187 subscriptions.insert(
189 topic_name.to_string(),
190 Subscription {
191 topic_type,
192 actors: vec![actor],
193 },
194 );
195 }
196 }
197
198 Ok(())
199 }
200
201 async fn event_loop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
203 let mut internal_rx = self
204 .internal_rx
205 .lock()
206 .await
207 .take()
208 .ok_or("Internal receiver already taken")?;
209
210 info!("Runtime event loop starting");
211
212 loop {
213 tokio::select! {
214 Some(event) = internal_rx.recv() => {
216 debug!("Processing internal event");
217
218 if matches!(event, InternalEvent::Shutdown) {
220 info!("Received shutdown event");
221 self.process_internal_event(event).await?;
222 break;
223 }
224
225 if let Err(e) = self.process_internal_event(event).await {
226 error!("Error processing internal event: {e}");
227 break;
228 }
229 }
230 _ = self.shutdown_notify.notified() => {
232 if self.shutdown_flag.load(Ordering::SeqCst) {
233 info!("Runtime received shutdown notification");
234 break;
235 }
236 }
237 else => {
239 warn!("Internal event channel closed");
240 break;
241 }
242 }
243 }
244
245 info!("Draining remaining events before shutdown");
247 while let Ok(event) = internal_rx.try_recv() {
248 if let Err(e) = self.process_internal_event(event).await {
249 error!("Error processing event during shutdown: {e}");
250 }
251 }
252
253 info!("Runtime event loop stopped");
254 Ok(())
255 }
256}
257
258#[async_trait]
259impl Runtime for SingleThreadedRuntime {
260 fn id(&self) -> RuntimeID {
261 self.id
262 }
263
264 async fn subscribe_any(
265 &self,
266 topic_name: &str,
267 topic_type: TypeId,
268 actor: Arc<dyn AnyActor>,
269 ) -> Result<(), RuntimeError> {
270 self.handle_subscribe(topic_name, topic_type, actor).await
271 }
272
273 async fn publish_any(
274 &self,
275 topic_name: &str,
276 topic_type: TypeId,
277 message: Arc<dyn Any + Send + Sync>,
278 ) -> Result<(), RuntimeError> {
279 self.handle_publish_message(topic_name, topic_type, message)
280 .await
281 }
282
283 fn tx(&self) -> mpsc::Sender<Event> {
284 let internal_tx = self.internal_tx.clone();
286 let (interceptor_tx, mut interceptor_rx) = mpsc::channel::<Event>(DEFAULT_CHANNEL_BUFFER);
287
288 tokio::spawn(async move {
289 while let Some(event) = interceptor_rx.recv().await {
290 if let Err(e) = internal_tx.send(InternalEvent::ProtocolEvent(event)).await {
291 error!("Failed to forward event to internal channel: {e}");
292 break;
293 }
294 }
295 });
296
297 interceptor_tx
298 }
299
300 async fn transport(&self) -> Arc<dyn Transport> {
301 Arc::clone(&self.transport)
302 }
303
304 async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
305 let mut guard = self.external_rx.lock().await;
306 guard.take().map(receiver_into_stream)
307 }
308
309 async fn subscribe_events(&self) -> BoxEventStream<Event> {
310 let rx = self.broadcast_tx.subscribe();
311 let stream = BroadcastStream::new(rx)
312 .filter_map(|item: Result<Event, BroadcastStreamRecvError>| async move { item.ok() });
313 Box::pin(stream)
314 }
315
316 async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
317 info!("Starting SingleThreadedRuntime {}", self.id);
318 self.event_loop().await
319 }
320
321 async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
322 info!("Initiating runtime shutdown for {}", self.id);
323
324 self.internal_tx
326 .send(InternalEvent::Shutdown)
327 .await
328 .map_err(|e| format!("Failed to send shutdown signal: {e}"))?;
329
330 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
332
333 Ok(())
334 }
335}
336
337#[cfg(test)]
338mod tests {
339 use super::*;
340 use crate::actor::{CloneableMessage, Topic};
341 use crate::runtime::{RuntimeConfig, TypedRuntime};
342 use ractor::{Actor, ActorProcessingErr, ActorRef};
343 use tokio::time::{Duration, sleep};
344
345 #[derive(Clone, Debug)]
347 struct TestMessage {
348 content: String,
349 }
350
351 impl crate::actor::ActorMessage for TestMessage {}
352 impl CloneableMessage for TestMessage {}
353
354 struct TestActor {
356 received: Arc<Mutex<Vec<String>>>,
357 }
358
359 #[async_trait]
360 impl Actor for TestActor {
361 type Msg = TestMessage;
362 type State = ();
363 type Arguments = Arc<Mutex<Vec<String>>>;
364
365 async fn pre_start(
366 &self,
367 _myself: ActorRef<Self::Msg>,
368 _args: Self::Arguments,
369 ) -> Result<Self::State, ActorProcessingErr> {
370 Ok(())
371 }
372
373 async fn handle(
374 &self,
375 _myself: ActorRef<Self::Msg>,
376 message: Self::Msg,
377 _state: &mut Self::State,
378 ) -> Result<(), ActorProcessingErr> {
379 let mut received = self.received.lock().await;
380 received.push(message.content);
381 Ok(())
382 }
383 }
384
385 #[tokio::test]
386 async fn test_runtime_creation() {
387 let runtime = SingleThreadedRuntime::new(None);
388 assert_ne!(runtime.id(), Uuid::nil());
389 }
390
391 #[tokio::test]
392 async fn test_publish_subscribe_cloneable() {
393 let runtime = SingleThreadedRuntime::new(Some(10));
394 let runtime_handle = runtime.clone();
395
396 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
398
399 let received = Arc::new(Mutex::new(Vec::new()));
401 let (actor_ref, _actor_handle) = Actor::spawn(
402 None,
403 TestActor {
404 received: received.clone(),
405 },
406 received.clone(),
407 )
408 .await
409 .unwrap();
410
411 let topic = Topic::<TestMessage>::new("test_topic");
413 runtime.subscribe(&topic, actor_ref).await.unwrap();
414
415 runtime
417 .publish(
418 &topic,
419 TestMessage {
420 content: "Hello".to_string(),
421 },
422 )
423 .await
424 .unwrap();
425
426 runtime
427 .publish(
428 &topic,
429 TestMessage {
430 content: "World".to_string(),
431 },
432 )
433 .await
434 .unwrap();
435
436 sleep(Duration::from_millis(100)).await;
438
439 let received_msgs = received.lock().await;
441 assert_eq!(received_msgs.len(), 2);
442 assert_eq!(received_msgs[0], "Hello");
443 assert_eq!(received_msgs[1], "World");
444
445 runtime.stop().await.unwrap();
447 runtime_task.abort();
448 }
449
450 #[tokio::test]
451 async fn test_type_safety() {
452 let runtime = SingleThreadedRuntime::new(None);
453 let runtime_handle = runtime.clone();
454
455 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
457
458 let topic_name = "typed_topic";
460 let topic1 = Topic::<TestMessage>::new(topic_name);
461
462 let received = Arc::new(Mutex::new(Vec::new()));
463 let (actor_ref, _) = Actor::spawn(
464 None,
465 TestActor {
466 received: received.clone(),
467 },
468 received.clone(),
469 )
470 .await
471 .unwrap();
472
473 runtime.subscribe(&topic1, actor_ref).await.unwrap();
474
475 sleep(Duration::from_millis(50)).await;
477
478 #[derive(Clone)]
480 struct OtherMessage;
481 impl crate::actor::ActorMessage for OtherMessage {}
482 impl CloneableMessage for OtherMessage {}
483
484 let topic2 = Topic::<OtherMessage>::new(topic_name);
485
486 struct OtherActor;
487 #[async_trait]
488 impl Actor for OtherActor {
489 type Msg = OtherMessage;
490 type State = ();
491 type Arguments = ();
492
493 async fn pre_start(
494 &self,
495 _myself: ActorRef<Self::Msg>,
496 _args: Self::Arguments,
497 ) -> Result<Self::State, ActorProcessingErr> {
498 Ok(())
499 }
500
501 async fn handle(
502 &self,
503 _myself: ActorRef<Self::Msg>,
504 _message: Self::Msg,
505 _state: &mut Self::State,
506 ) -> Result<(), ActorProcessingErr> {
507 Ok(())
508 }
509 }
510
511 let (other_ref, _) = Actor::spawn(None, OtherActor, ()).await.unwrap();
512
513 let result = runtime.subscribe(&topic2, other_ref).await;
515
516 assert!(result.is_err());
518
519 if let Err(RuntimeError::TopicTypeMismatch(topic, _)) = result {
521 assert_eq!(topic, topic_name);
522 } else {
523 panic!("Expected TopicTypeMismatch error");
524 }
525
526 runtime.stop().await.unwrap();
528 runtime_task.abort();
529 }
530
531 #[tokio::test]
532 async fn test_message_ordering() {
533 let runtime = SingleThreadedRuntime::new(Some(10));
534 let runtime_handle = runtime.clone();
535
536 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
538
539 let received = Arc::new(Mutex::new(Vec::new()));
541 let (actor_ref, _actor_handle) = Actor::spawn(
542 None,
543 TestActor {
544 received: received.clone(),
545 },
546 received.clone(),
547 )
548 .await
549 .unwrap();
550
551 let topic = Topic::<TestMessage>::new("order_test");
553 runtime.subscribe(&topic, actor_ref).await.unwrap();
554
555 for i in 0..10 {
557 runtime
558 .publish(
559 &topic,
560 TestMessage {
561 content: format!("Message {i}"),
562 },
563 )
564 .await
565 .unwrap();
566 }
567
568 sleep(Duration::from_millis(200)).await;
570
571 let received_msgs = received.lock().await;
573 assert_eq!(received_msgs.len(), 10);
574
575 for (i, msg) in received_msgs.iter().enumerate() {
576 assert_eq!(msg, &format!("Message {i}"));
577 }
578
579 runtime.stop().await.unwrap();
581 runtime_task.abort();
582 }
583
584 #[tokio::test]
585 async fn test_runtime_multiple_topics() {
586 let runtime = SingleThreadedRuntime::new(Some(10));
587 let runtime_handle = runtime.clone();
588
589 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
591
592 let topic1 = Topic::<TestMessage>::new("topic1");
594 let topic2 = Topic::<TestMessage>::new("topic2");
595
596 let received1 = Arc::new(Mutex::new(Vec::new()));
597 let received2 = Arc::new(Mutex::new(Vec::new()));
598
599 let (actor_ref1, _) = Actor::spawn(
600 None,
601 TestActor {
602 received: received1.clone(),
603 },
604 received1.clone(),
605 )
606 .await
607 .unwrap();
608
609 let (actor_ref2, _) = Actor::spawn(
610 None,
611 TestActor {
612 received: received2.clone(),
613 },
614 received2.clone(),
615 )
616 .await
617 .unwrap();
618
619 runtime.subscribe(&topic1, actor_ref1).await.unwrap();
621 runtime.subscribe(&topic2, actor_ref2).await.unwrap();
622 sleep(Duration::from_millis(50)).await;
623
624 let message1 = TestMessage {
626 content: "topic1_message".to_string(),
627 };
628 runtime.publish(&topic1, message1).await.unwrap();
629 sleep(Duration::from_millis(50)).await;
630
631 let message2 = TestMessage {
633 content: "topic2_message".to_string(),
634 };
635 runtime.publish(&topic2, message2).await.unwrap();
636 sleep(Duration::from_millis(50)).await;
637
638 let received_msgs1 = received1.lock().await;
640 let received_msgs2 = received2.lock().await;
641
642 assert_eq!(received_msgs1.len(), 1);
643 assert_eq!(received_msgs1[0], "topic1_message");
644
645 assert_eq!(received_msgs2.len(), 1);
646 assert_eq!(received_msgs2[0], "topic2_message");
647
648 runtime.stop().await.unwrap();
650 runtime_task.abort();
651 }
652
653 #[tokio::test]
654 async fn test_runtime_subscribe_multiple_actors_same_topic() {
655 let runtime = SingleThreadedRuntime::new(Some(10));
656 let runtime_handle = runtime.clone();
657
658 let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
660
661 let topic = Topic::<TestMessage>::new("shared_topic");
662
663 let received1 = Arc::new(Mutex::new(Vec::new()));
664 let received2 = Arc::new(Mutex::new(Vec::new()));
665
666 let (actor_ref1, _) = Actor::spawn(
667 None,
668 TestActor {
669 received: received1.clone(),
670 },
671 received1.clone(),
672 )
673 .await
674 .unwrap();
675
676 let (actor_ref2, _) = Actor::spawn(
677 None,
678 TestActor {
679 received: received2.clone(),
680 },
681 received2.clone(),
682 )
683 .await
684 .unwrap();
685
686 runtime.subscribe(&topic, actor_ref1).await.unwrap();
688 runtime.subscribe(&topic, actor_ref2).await.unwrap();
689 sleep(Duration::from_millis(50)).await;
690
691 let message = TestMessage {
693 content: "broadcast_message".to_string(),
694 };
695 runtime.publish(&topic, message).await.unwrap();
696 sleep(Duration::from_millis(100)).await;
697
698 let received_msgs1 = received1.lock().await;
700 let received_msgs2 = received2.lock().await;
701
702 assert_eq!(received_msgs1.len(), 1);
703 assert_eq!(received_msgs1[0], "broadcast_message");
704
705 assert_eq!(received_msgs2.len(), 1);
706 assert_eq!(received_msgs2[0], "broadcast_message");
707
708 runtime.stop().await.unwrap();
710 runtime_task.abort();
711 }
712
713 #[test]
714 fn test_runtime_config_creation() {
715 let config = RuntimeConfig {
716 queue_size: Some(100),
717 };
718 assert_eq!(config.queue_size, Some(100));
719 }
720
721 #[test]
722 fn test_runtime_id_generation() {
723 let runtime1 = SingleThreadedRuntime::new(None);
724 let runtime2 = SingleThreadedRuntime::new(None);
725
726 assert_ne!(runtime1.id(), runtime2.id());
727 }
728}