1use super::{Runtime, RuntimeError};
2use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
3use crate::protocol::InternalEvent;
4use crate::utils::{receiver_into_stream, BoxEventStream};
5use crate::{
6 actor::{AnyActor, Transport},
7 error::Error,
8 protocol::{Event, RuntimeID},
9};
10use async_trait::async_trait;
11use log::{debug, error, info, warn};
12use std::{
13 any::{Any, TypeId},
14 collections::HashMap,
15 sync::{
16 atomic::{AtomicBool, Ordering},
17 Arc,
18 },
19};
20use tokio::sync::{mpsc, Mutex, Notify, RwLock};
21use uuid::Uuid;
22
/// Capacity of the internal command channel that feeds the event loop.
///
/// Unlike the external event channel, this capacity is not configurable by
/// callers of `SingleThreadedRuntime::with_transport`.
const DEFAULT_INTERNAL_BUFFER: usize = 1_000;
24
25#[derive(Debug)]
27struct Subscription {
28 topic_type: TypeId,
29 actors: Vec<Arc<dyn AnyActor>>,
30}
31
32#[derive(Debug)]
33pub struct SingleThreadedRuntime {
35 pub id: RuntimeID,
36 external_tx: mpsc::Sender<Event>,
38 external_rx: Mutex<Option<mpsc::Receiver<Event>>>,
39 internal_tx: mpsc::Sender<InternalEvent>,
41 internal_rx: Mutex<Option<mpsc::Receiver<InternalEvent>>>,
42 subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
44 transport: Arc<dyn Transport>,
46 shutdown_flag: Arc<AtomicBool>,
48 shutdown_notify: Arc<Notify>,
49}
50
51impl SingleThreadedRuntime {
52 pub fn new(channel_buffer: Option<usize>) -> Arc<Self> {
53 Self::with_transport(channel_buffer, Arc::new(crate::actor::LocalTransport))
54 }
55
56 pub fn with_transport(
57 channel_buffer: Option<usize>,
58 transport: Arc<dyn Transport>,
59 ) -> Arc<Self> {
60 let id = Uuid::new_v4();
61 let buffer_size = channel_buffer.unwrap_or(DEFAULT_CHANNEL_BUFFER);
62
63 let (external_tx, external_rx) = mpsc::channel(buffer_size);
65 let (internal_tx, internal_rx) = mpsc::channel(DEFAULT_INTERNAL_BUFFER);
66
67 Arc::new(Self {
68 id,
69 external_tx,
70 external_rx: Mutex::new(Some(external_rx)),
71 internal_tx,
72 internal_rx: Mutex::new(Some(internal_rx)),
73 subscriptions: Arc::new(RwLock::new(HashMap::new())),
74 transport,
75 shutdown_flag: Arc::new(AtomicBool::new(false)),
76 shutdown_notify: Arc::new(Notify::new()),
77 })
78 }
79
80 async fn process_internal_event(&self, event: InternalEvent) -> Result<(), Error> {
82 debug!("Received internal event: {event:?}");
83 match event {
84 InternalEvent::ProtocolEvent(event) => {
85 self.process_protocol_event(event).await?;
86 }
87 InternalEvent::Shutdown => {
88 self.shutdown_flag.store(true, Ordering::SeqCst);
89 self.shutdown_notify.notify_waiters();
90 }
91 }
92 Ok(())
93 }
94
95 async fn process_protocol_event(&self, event: Event) -> Result<(), Error> {
97 match event {
98 Event::PublishMessage {
99 topic_type,
100 topic_name,
101 message,
102 } => {
103 self.handle_publish_message(&topic_name, topic_type, message)
104 .await?;
105 }
106 _ => {
107 self.external_tx
109 .send(event)
110 .await
111 .map_err(RuntimeError::EventError)?;
112 }
113 }
114 Ok(())
115 }
116
117 async fn handle_publish_message(
119 &self,
120 topic_name: &str,
121 topic_type: TypeId,
122 message: Arc<dyn Any + Send + Sync>,
123 ) -> Result<(), RuntimeError> {
124 debug!("Handling publish event: {topic_name}");
125
126 let subscriptions = self.subscriptions.read().await;
127
128 if let Some(subscription) = subscriptions.get(topic_name) {
129 if subscription.topic_type != topic_type {
131 error!(
132 "Type mismatch for topic '{}': expected {:?}, got {:?}",
133 topic_name, subscription.topic_type, topic_type
134 );
135 return Err(RuntimeError::TopicTypeMismatch(
136 topic_name.to_owned(),
137 topic_type,
138 ));
139 }
140
141 for actor in &subscription.actors {
143 if let Err(e) = self
144 .transport
145 .send(actor.as_ref(), Arc::clone(&message))
146 .await
147 {
148 error!("Failed to send message to subscriber: {e}");
149 }
150 }
151 } else {
152 debug!("No subscribers for topic: {topic_name}");
153 }
154
155 Ok(())
156 }
157
158 async fn handle_subscribe(
160 &self,
161 topic_name: &str,
162 topic_type: TypeId,
163 actor: Arc<dyn AnyActor>,
164 ) -> Result<(), RuntimeError> {
165 info!("Actor subscribing to topic: {topic_name}");
166
167 let mut subscriptions = self.subscriptions.write().await;
168
169 match subscriptions.get_mut(topic_name) {
170 Some(subscription) => {
171 if subscription.topic_type != topic_type {
173 return Err(RuntimeError::TopicTypeMismatch(
174 topic_name.to_string(),
175 subscription.topic_type,
176 ));
177 }
178 subscription.actors.push(actor);
179 }
180 None => {
181 subscriptions.insert(
183 topic_name.to_string(),
184 Subscription {
185 topic_type,
186 actors: vec![actor],
187 },
188 );
189 }
190 }
191
192 Ok(())
193 }
194
195 async fn event_loop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
197 let mut internal_rx = self
198 .internal_rx
199 .lock()
200 .await
201 .take()
202 .ok_or("Internal receiver already taken")?;
203
204 info!("Runtime event loop starting");
205
206 loop {
207 tokio::select! {
208 Some(event) = internal_rx.recv() => {
210 debug!("Processing internal event");
211
212 if matches!(event, InternalEvent::Shutdown) {
214 info!("Received shutdown event");
215 self.process_internal_event(event).await?;
216 break;
217 }
218
219 if let Err(e) = self.process_internal_event(event).await {
220 error!("Error processing internal event: {e}");
221 break;
222 }
223 }
224 _ = self.shutdown_notify.notified() => {
226 if self.shutdown_flag.load(Ordering::SeqCst) {
227 info!("Runtime received shutdown notification");
228 break;
229 }
230 }
231 else => {
233 warn!("Internal event channel closed");
234 break;
235 }
236 }
237 }
238
239 info!("Draining remaining events before shutdown");
241 while let Ok(event) = internal_rx.try_recv() {
242 if let Err(e) = self.process_internal_event(event).await {
243 error!("Error processing event during shutdown: {e}");
244 }
245 }
246
247 info!("Runtime event loop stopped");
248 Ok(())
249 }
250}
251
252#[async_trait]
253impl Runtime for SingleThreadedRuntime {
254 fn id(&self) -> RuntimeID {
255 self.id
256 }
257
258 async fn subscribe_any(
259 &self,
260 topic_name: &str,
261 topic_type: TypeId,
262 actor: Arc<dyn AnyActor>,
263 ) -> Result<(), RuntimeError> {
264 self.handle_subscribe(topic_name, topic_type, actor).await
265 }
266
267 async fn publish_any(
268 &self,
269 topic_name: &str,
270 topic_type: TypeId,
271 message: Arc<dyn Any + Send + Sync>,
272 ) -> Result<(), RuntimeError> {
273 self.handle_publish_message(topic_name, topic_type, message)
274 .await
275 }
276
277 fn tx(&self) -> mpsc::Sender<Event> {
278 let internal_tx = self.internal_tx.clone();
280 let (interceptor_tx, mut interceptor_rx) = mpsc::channel::<Event>(DEFAULT_CHANNEL_BUFFER);
281
282 tokio::spawn(async move {
283 while let Some(event) = interceptor_rx.recv().await {
284 if let Err(e) = internal_tx.send(InternalEvent::ProtocolEvent(event)).await {
285 error!("Failed to forward event to internal channel: {e}");
286 break;
287 }
288 }
289 });
290
291 interceptor_tx
292 }
293
294 async fn transport(&self) -> Arc<dyn Transport> {
295 Arc::clone(&self.transport)
296 }
297
298 async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
299 let mut guard = self.external_rx.lock().await;
300 guard.take().map(receiver_into_stream)
301 }
302
303 async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
304 info!("Starting SingleThreadedRuntime {}", self.id);
305 self.event_loop().await
306 }
307
308 async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
309 info!("Initiating runtime shutdown for {}", self.id);
310
311 self.internal_tx
313 .send(InternalEvent::Shutdown)
314 .await
315 .map_err(|e| format!("Failed to send shutdown signal: {e}"))?;
316
317 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
319
320 Ok(())
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::actor::{CloneableMessage, Topic};
    use crate::runtime::{RuntimeConfig, TypedRuntime};
    use ractor::{Actor, ActorProcessingErr, ActorRef};
    use tokio::task::JoinHandle;
    use tokio::time::{sleep, timeout, Duration, Instant};

    /// Output of `Runtime::run`, named to keep helper signatures readable.
    type RunResult = Result<(), Box<dyn std::error::Error + Send + Sync>>;

    /// Upper bound on waiting for asynchronous delivery or shutdown. Tests poll
    /// instead of sleeping for a fixed time, so the happy path is fast and a
    /// slow CI machine does not cause spurious failures.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(2);

    #[derive(Clone, Debug)]
    struct TestMessage {
        content: String,
    }

    impl crate::actor::ActorMessage for TestMessage {}
    impl CloneableMessage for TestMessage {}

    /// Actor that appends the `content` of every handled message to `received`.
    struct TestActor {
        received: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Actor for TestActor {
        type Msg = TestMessage;
        type State = ();
        type Arguments = Arc<Mutex<Vec<String>>>;

        async fn pre_start(
            &self,
            _myself: ActorRef<Self::Msg>,
            _args: Self::Arguments,
        ) -> Result<Self::State, ActorProcessingErr> {
            Ok(())
        }

        async fn handle(
            &self,
            _myself: ActorRef<Self::Msg>,
            message: Self::Msg,
            _state: &mut Self::State,
        ) -> Result<(), ActorProcessingErr> {
            let mut received = self.received.lock().await;
            received.push(message.content);
            Ok(())
        }
    }

    /// Creates a runtime and spawns its event loop on a background task.
    fn start_runtime(
        channel_buffer: Option<usize>,
    ) -> (Arc<SingleThreadedRuntime>, JoinHandle<RunResult>) {
        let runtime = SingleThreadedRuntime::new(channel_buffer);
        let runtime_handle = Arc::clone(&runtime);
        let runtime_task = tokio::spawn(async move { runtime_handle.run().await });
        (runtime, runtime_task)
    }

    /// Stops the runtime and asserts that its event loop exits cleanly.
    ///
    /// Unlike aborting the task, this surfaces an event loop that hangs or
    /// returns an error.
    async fn stop_runtime(runtime: &SingleThreadedRuntime, runtime_task: JoinHandle<RunResult>) {
        runtime.stop().await.unwrap();
        let outcome = timeout(WAIT_TIMEOUT, runtime_task)
            .await
            .expect("event loop did not exit after stop()")
            .expect("event loop task panicked");
        assert!(outcome.is_ok(), "event loop returned an error: {outcome:?}");
    }

    /// Spawns a `TestActor` and returns its ref plus the log it records into.
    async fn spawn_recorder() -> (ActorRef<TestMessage>, Arc<Mutex<Vec<String>>>) {
        let received = Arc::new(Mutex::new(Vec::new()));
        let (actor_ref, _actor_handle) = Actor::spawn(
            None,
            TestActor {
                received: Arc::clone(&received),
            },
            Arc::clone(&received),
        )
        .await
        .unwrap();
        (actor_ref, received)
    }

    /// Waits until `received` holds at least `expected` entries, or until
    /// `WAIT_TIMEOUT` elapses. The caller's assertions report any shortfall.
    async fn wait_for_messages(received: &Mutex<Vec<String>>, expected: usize) {
        let deadline = Instant::now() + WAIT_TIMEOUT;
        while received.lock().await.len() < expected && Instant::now() < deadline {
            sleep(Duration::from_millis(5)).await;
        }
    }

    #[tokio::test]
    async fn test_runtime_creation() {
        let runtime = SingleThreadedRuntime::new(None);
        assert_ne!(runtime.id(), Uuid::nil());
    }

    #[tokio::test]
    async fn test_publish_subscribe_cloneable() {
        let (runtime, runtime_task) = start_runtime(Some(10));
        let (actor_ref, received) = spawn_recorder().await;

        let topic = Topic::<TestMessage>::new("test_topic");
        runtime.subscribe(&topic, actor_ref).await.unwrap();

        for content in ["Hello", "World"] {
            runtime
                .publish(
                    &topic,
                    TestMessage {
                        content: content.to_string(),
                    },
                )
                .await
                .unwrap();
        }

        wait_for_messages(&received, 2).await;
        assert_eq!(*received.lock().await, ["Hello", "World"]);

        stop_runtime(&runtime, runtime_task).await;
    }

    #[tokio::test]
    async fn test_type_safety() {
        let (runtime, runtime_task) = start_runtime(None);

        let topic_name = "typed_topic";
        let topic1 = Topic::<TestMessage>::new(topic_name);
        let (actor_ref, _received) = spawn_recorder().await;
        runtime.subscribe(&topic1, actor_ref).await.unwrap();

        #[derive(Clone)]
        struct OtherMessage;
        impl crate::actor::ActorMessage for OtherMessage {}
        impl CloneableMessage for OtherMessage {}

        struct OtherActor;
        #[async_trait]
        impl Actor for OtherActor {
            type Msg = OtherMessage;
            type State = ();
            type Arguments = ();

            async fn pre_start(
                &self,
                _myself: ActorRef<Self::Msg>,
                _args: Self::Arguments,
            ) -> Result<Self::State, ActorProcessingErr> {
                Ok(())
            }

            async fn handle(
                &self,
                _myself: ActorRef<Self::Msg>,
                _message: Self::Msg,
                _state: &mut Self::State,
            ) -> Result<(), ActorProcessingErr> {
                Ok(())
            }
        }

        let topic2 = Topic::<OtherMessage>::new(topic_name);
        let (other_ref, _) = Actor::spawn(None, OtherActor, ()).await.unwrap();

        // A second payload type on an existing topic name must be rejected.
        match runtime.subscribe(&topic2, other_ref).await {
            Err(RuntimeError::TopicTypeMismatch(topic, _)) => assert_eq!(topic, topic_name),
            _ => panic!("Expected TopicTypeMismatch error"),
        }

        stop_runtime(&runtime, runtime_task).await;
    }

    #[tokio::test]
    async fn test_message_ordering() {
        let (runtime, runtime_task) = start_runtime(Some(10));
        let (actor_ref, received) = spawn_recorder().await;

        let topic = Topic::<TestMessage>::new("order_test");
        runtime.subscribe(&topic, actor_ref).await.unwrap();

        let expected: Vec<String> = (0..10).map(|i| format!("Message {i}")).collect();
        for content in &expected {
            runtime
                .publish(
                    &topic,
                    TestMessage {
                        content: content.clone(),
                    },
                )
                .await
                .unwrap();
        }

        wait_for_messages(&received, expected.len()).await;
        assert_eq!(*received.lock().await, expected);

        stop_runtime(&runtime, runtime_task).await;
    }

    #[tokio::test]
    async fn test_runtime_multiple_topics() {
        let (runtime, runtime_task) = start_runtime(Some(10));

        let topic1 = Topic::<TestMessage>::new("topic1");
        let topic2 = Topic::<TestMessage>::new("topic2");
        let (actor_ref1, received1) = spawn_recorder().await;
        let (actor_ref2, received2) = spawn_recorder().await;

        runtime.subscribe(&topic1, actor_ref1).await.unwrap();
        runtime.subscribe(&topic2, actor_ref2).await.unwrap();

        runtime
            .publish(
                &topic1,
                TestMessage {
                    content: "topic1_message".to_string(),
                },
            )
            .await
            .unwrap();
        runtime
            .publish(
                &topic2,
                TestMessage {
                    content: "topic2_message".to_string(),
                },
            )
            .await
            .unwrap();

        wait_for_messages(&received1, 1).await;
        wait_for_messages(&received2, 1).await;
        // Each actor must see only the message published on its own topic.
        assert_eq!(*received1.lock().await, ["topic1_message"]);
        assert_eq!(*received2.lock().await, ["topic2_message"]);

        stop_runtime(&runtime, runtime_task).await;
    }

    #[tokio::test]
    async fn test_runtime_subscribe_multiple_actors_same_topic() {
        let (runtime, runtime_task) = start_runtime(Some(10));

        let topic = Topic::<TestMessage>::new("shared_topic");
        let (actor_ref1, received1) = spawn_recorder().await;
        let (actor_ref2, received2) = spawn_recorder().await;

        runtime.subscribe(&topic, actor_ref1).await.unwrap();
        runtime.subscribe(&topic, actor_ref2).await.unwrap();

        runtime
            .publish(
                &topic,
                TestMessage {
                    content: "broadcast_message".to_string(),
                },
            )
            .await
            .unwrap();

        wait_for_messages(&received1, 1).await;
        wait_for_messages(&received2, 1).await;
        // Every subscriber of the topic receives the broadcast exactly once.
        assert_eq!(*received1.lock().await, ["broadcast_message"]);
        assert_eq!(*received2.lock().await, ["broadcast_message"]);

        stop_runtime(&runtime, runtime_task).await;
    }

    #[test]
    fn test_runtime_config_creation() {
        let config = RuntimeConfig {
            queue_size: Some(100),
        };
        assert_eq!(config.queue_size, Some(100));
    }

    #[test]
    fn test_runtime_id_generation() {
        let runtime1 = SingleThreadedRuntime::new(None);
        let runtime2 = SingleThreadedRuntime::new(None);

        assert_ne!(runtime1.id(), runtime2.id());
    }
}