1use std::any::Any;
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use anyhow::{Context, Result, anyhow};
8use futures::future::BoxFuture;
9#[cfg(any(feature = "client", feature = "server"))]
10use futures_util::{Sink, SinkExt, Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_binary::binary_stream::Endian;
13#[cfg(feature = "server")]
14use tokio::net::{TcpListener, TcpStream};
15use tokio::spawn;
16use tokio::sync::{Mutex, RwLock, mpsc};
17use tokio::task::JoinHandle;
18#[cfg(feature = "client")]
19use tokio_tungstenite::connect_async;
20#[cfg(any(feature = "client", feature = "server"))]
21use tokio_tungstenite::tungstenite;
22use tracing::{error, info, trace};
23use ulid::Ulid;
24
25#[cfg(feature = "client")]
26use crate::message::HandshakeMessage;
27use crate::message::{Message, MessageBody, RpcMessage};
28use crate::{
29 CONTEXT, Identity, InstanceId, InventoryItem, MessageHandler, MessageId, ObjectRef,
30 RuntimeContext, RuntimeId, TypeId, TypeInfo,
31};
32
33#[cfg(feature = "server")]
34#[allow(unused_imports)]
35use crate::api;
36
37type MessageFuture = Pin<Box<dyn Future<Output = Message> + Send>>;
39
40type Executor = Box<dyn Fn(Message) -> MessageFuture + Send + Sync>;
42
43struct MessageData {
46 message: Message,
47 return_tx: mpsc::Sender<Message>,
48}
49
50#[derive(Clone, Deserialize, Serialize)]
52pub struct RuntimeInfo {
53 pub id: RuntimeId,
54}
55
/// Public API of an RPC runtime.
///
/// Implemented for [`Runtime`] in this module. Methods carrying a `cfg` attribute exist only
/// when the corresponding networking feature (`client` / `server`) is enabled.
#[crate::async_trait]
pub trait RuntimeTrait {
    /// Creates a runtime identified by `id` and starts its event loop.
    async fn create(id: RuntimeId) -> Result<Self>
    where
        Self: Sized;
    /// Runs `f` on a task whose runtime context targets `target_id`, and returns either the
    /// first error reported through that context or the outcome of `f`.
    async fn execute<F>(&self, target_id: RuntimeId, f: F) -> Result<()>
    where
        F: Future<Output = Result<()>> + Send + 'static;
    /// Same as `execute`, with this runtime's own id as the target.
    async fn execute_local<F>(&self, f: F) -> Result<()>
    where
        F: Future<Output = Result<()>> + Send + 'static;
    /// Registers handler `H`, so requests whose type is `H::Input` can be executed here.
    async fn register_handler<H: MessageHandler>(&self) -> Result<()>;
    /// Stores `instance` in this runtime and returns a reference to it.
    async fn register_instance<T>(&self, instance: T) -> ObjectRef<T>
    where
        T: Identity + Any + Send + Sync + 'static;
    /// Looks up a registered instance by id. `Err` holds a short reason ("instance not
    /// found" or "type mismatch" in the `Runtime` implementation).
    async fn get_instance<T>(&self, id: InstanceId) -> Result<Arc<RwLock<T>>, String>
    where
        T: Identity + Any + Send + Sync + 'static;
    /// Removes a registered instance and returns ownership of it, or `None` if that is not
    /// possible.
    async fn take_instance<T>(&self, id: InstanceId) -> Option<T>
    where
        T: Identity + Any + Send + Sync + 'static;
    /// Invokes handler `T` on runtime `target_id` with `args` and returns its decoded output.
    async fn call<T: MessageHandler>(
        &self,
        target_id: RuntimeId,
        args: T::Input,
    ) -> Result<T::Output>;
    /// Lists runtimes for which a route is currently registered.
    #[cfg(any(feature = "client", feature = "server"))]
    async fn connected_runtimes(&self) -> Vec<RuntimeInfo>;
    /// Lists metadata for every registered handler.
    async fn inventory(&self) -> Vec<InventoryItem>;
    /// Connects to the server at `addr`. The returned handle belongs to the connection task.
    #[cfg(feature = "client")]
    async fn connect(&self, addr: String) -> Result<JoinHandle<()>>;
    /// Starts accepting connections on `addr` in a background task and returns its handle.
    #[cfg(feature = "server")]
    async fn start_server(&self, addr: String) -> JoinHandle<Result<()>>;
}
90
/// Crate-internal operations used to build and drive a runtime.
#[crate::async_trait]
pub(crate) trait RuntimeInternalTrait {
    /// Runs every handler registration collected through `inventory`.
    async fn register_handlers(&self) -> Result<()>;
    /// Spawns the event loop that dispatches and routes queued messages.
    async fn runtime_worker(&self);
    /// Serializes messages received on `rx` and writes them to the connection sink `tx`.
    #[cfg(any(feature = "client", feature = "server"))]
    async fn message_encoder<S: Sink<tungstenite::Message> + Unpin + Send>(
        tx: S,
        rx: mpsc::Receiver<Message>,
    ) where
        <S as Sink<tungstenite::Message>>::Error: std::fmt::Display;
    /// Reads frames from the connection stream `rx` and queues decoded messages on the
    /// event loop, with `tx` as their reply channel.
    #[cfg(any(feature = "client", feature = "server"))]
    async fn message_decoder<
        S: Stream<Item = std::result::Result<tungstenite::Message, tungstenite::Error>> + Unpin + Send,
    >(
        &self,
        tx: mpsc::Sender<Message>,
        rx: S,
    );
    /// Upgrades an accepted TCP stream to a WebSocket and serves it until it ends.
    #[cfg(feature = "server")]
    async fn accept_connection(&self, stream: TcpStream) -> Result<()>;
}
112
113pub struct HandlerRegistration {
115 pub register: fn(Runtime) -> BoxFuture<'static, Result<()>>,
117}
118
119inventory::collect!(HandlerRegistration);
121
122pub type Runtime = Arc<RuntimeImpl>;
123
124pub struct RuntimeImpl {
129 id: RuntimeId,
130 executors: Arc<RwLock<HashMap<TypeId, Arc<Executor>>>>,
131 message_handlers: Arc<Mutex<HashMap<Ulid, mpsc::Sender<Message>>>>,
133 inventory: Arc<Mutex<HashMap<TypeId, InventoryItem>>>,
134 runtime_rx: Arc<Mutex<mpsc::Receiver<MessageData>>>,
135 runtime_tx: mpsc::Sender<MessageData>,
136 runtime_worker_handle: Mutex<Option<JoinHandle<()>>>,
137
138 instance_registry: Arc<RwLock<HashMap<Ulid, Arc<dyn Any + Send + Sync>>>>,
140
141 #[cfg(any(feature = "client", feature = "server"))]
143 runtime_registry: Arc<RwLock<HashMap<RuntimeId, mpsc::Sender<Message>>>>,
144}
145
146#[crate::async_trait]
147impl RuntimeTrait for Runtime {
148 async fn create(id: RuntimeId) -> Result<Self> {
153 info!(runtime_id = %id, "Creating new RPC runtime");
154 let (runtime_tx, runtime_rx) = mpsc::channel(1024);
155 let rt = Self::new(RuntimeImpl {
156 id,
157 executors: Arc::new(RwLock::new(HashMap::new())),
158 message_handlers: Arc::new(Mutex::new(HashMap::new())),
159 inventory: Arc::new(Mutex::new(HashMap::new())),
160 runtime_rx: Arc::new(Mutex::new(runtime_rx)),
161 runtime_tx,
162 runtime_worker_handle: Mutex::new(None),
163 instance_registry: Arc::new(RwLock::new(HashMap::new())),
164
165 #[cfg(any(feature = "client", feature = "server"))]
166 runtime_registry: Arc::new(RwLock::new(HashMap::new())),
167 });
168
169 rt.register_handlers().await?;
170 rt.runtime_worker().await;
171
172 Ok(rt)
173 }
174
175 async fn execute<F>(&self, target_id: RuntimeId, f: F) -> Result<()>
181 where
182 F: Future<Output = Result<()>> + Send + 'static,
183 {
184 let (error_tx, mut error_rx) = mpsc::channel(1);
185
186 let ctx = RuntimeContext {
188 target_id: target_id.clone(),
189 runtime: self.clone(),
190 error_tx,
191 };
192
193 let task = spawn(CONTEXT.scope(ctx, f));
194 if let Some(error) = error_rx.recv().await {
195 task.abort();
196 Err(error)
197 } else {
198 task.await?
199 }
200 }
201
202 async fn execute_local<F>(&self, f: F) -> Result<()>
204 where
205 F: Future<Output = Result<()>> + Send + 'static,
206 {
207 self.execute(self.id.clone(), f).await
208 }
209
210 async fn register_handler<H: MessageHandler>(&self) -> Result<()> {
216 let type_id = <<H as MessageHandler>::Input as TypeInfo>::type_id();
218 self.inventory.lock().await.insert(type_id, H::type_info());
219
220 let executor = {
222 let runtime = self.clone();
223 let exec_fn = move |msg: Message| -> MessageFuture {
224 let runtime = runtime.clone();
225 Box::pin(async move {
226 let (error_tx, _error_rx) = mpsc::channel(1);
228 let ctx = RuntimeContext {
229 target_id: runtime.id.clone(),
230 runtime: runtime.clone(),
231 error_tx,
232 };
233
234 CONTEXT
235 .scope(ctx, async move {
236 let reply_data: Result<_, String> = match msg.data {
238 MessageBody::Handshake(_msg) => Ok(vec![]),
239 MessageBody::Rpc(message) => match message.data {
240 Ok(vec) => match serde_binary::from_vec::<H::Input>(
241 vec,
242 Endian::Little,
243 ) {
244 Ok(input) => match H::handle(input).await {
245 Ok(output) => {
246 serde_binary::to_vec(&output, Endian::Little)
247 .map_err(|e| e.to_string())
248 }
249 Err(err_str) => Err(err_str),
250 },
251 Err(err) => Err(err.to_string()),
252 },
253 Err(err_str) => Err(err_str),
254 },
255 };
256
257 Message {
259 target_id: msg.source_id.clone(),
260 source_id: runtime.id.clone(),
261 message_id: msg.message_id,
262 is_answer: true,
263 is_closed: false,
264 data: MessageBody::Rpc(RpcMessage {
265 r#type: <<H as MessageHandler>::Output as TypeInfo>::type_id(),
266 data: reply_data,
267 }),
268 }
269 })
270 .await
271 })
272 };
273 Arc::new(Box::new(exec_fn) as Executor)
274 };
275
276 info!(
278 input_type = %H::Input::type_name(),
279 "Registering RPC handler"
280 );
281 self.executors.write().await.insert(type_id, executor);
282 Ok(())
283 }
284
285 async fn register_instance<T>(&self, instance: T) -> ObjectRef<T>
288 where
289 T: Identity + Any + Send + Sync + 'static,
290 {
291 let id = *instance.id();
292 self.instance_registry
293 .write()
294 .await
295 .insert(id, Arc::new(RwLock::new(instance)));
296 ObjectRef::create(self.id.clone(), id)
297 }
298
299 async fn get_instance<T>(&self, id: InstanceId) -> Result<Arc<RwLock<T>>, String>
305 where
306 T: Identity + Any + Send + Sync + 'static,
307 {
308 let maybe_instance = { self.instance_registry.read().await.get(&id).cloned() };
310
311 match maybe_instance {
312 Some(instance) => match instance.downcast::<RwLock<T>>() {
313 Ok(typed_instance) => Ok(typed_instance),
314 Err(_) => Err("type mismatch".into()),
315 },
316 None => Err("instance not found".into()),
317 }
318 }
319
320 async fn take_instance<T>(&self, id: InstanceId) -> Option<T>
324 where
325 T: Identity + Any + Send + Sync + 'static,
326 {
327 let maybe_instance = {
328 self.instance_registry
329 .write()
330 .await
331 .remove(&id)
332 .and_then(|arc_any| arc_any.downcast::<RwLock<T>>().ok())
333 };
334
335 maybe_instance.and_then(|instance| match Arc::try_unwrap(instance) {
336 Ok(instance) => Some(instance.into_inner()),
337 Err(_) => None,
338 })
339 }
340
341 async fn call<T: MessageHandler>(
346 &self,
347 target_id: RuntimeId,
348 args: T::Input,
349 ) -> Result<T::Output> {
350 let message_id = MessageId::new();
351 trace!(
352 message_id = %message_id,
353 input_type = %T::Input::type_name(),
354 "Sending RPC request"
355 );
356
357 let result = async move {
358 let data = Message {
360 target_id,
361 source_id: self.id.clone(),
362 message_id,
363 is_answer: false,
364 is_closed: false,
365 data: MessageBody::Rpc(RpcMessage {
366 r#type: <<T as MessageHandler>::Input as TypeInfo>::type_id(),
367 data: Ok(serde_binary::to_vec(&args, Endian::Little)?),
368 }),
369 };
370
371 let (return_tx, mut return_rx) = mpsc::channel(1);
373
374 {
376 self.message_handlers
377 .lock()
378 .await
379 .insert(message_id, return_tx.clone());
380 }
381
382 self.runtime_tx
384 .send(MessageData {
385 return_tx,
386 message: data,
387 })
388 .await
389 .map_err(|e| anyhow!(e))?;
390
391 let ret_msg = return_rx.recv().await.context("no answer")?;
393
394 match ret_msg.data {
396 MessageBody::Rpc(message) => serde_binary::from_vec::<T::Output>(
397 message.data.map_err(|e| anyhow!(e))?,
398 Endian::Little,
399 )
400 .context("failed to deserialize response"),
401 _ => Err(anyhow!("unsupported answer type")),
402 }
403 }
404 .await;
405
406 self.message_handlers.lock().await.remove(&message_id);
408
409 result
410 }
411
412 #[cfg(any(feature = "client", feature = "server"))]
413 async fn connected_runtimes(&self) -> Vec<RuntimeInfo> {
415 self.runtime_registry
416 .read()
417 .await
418 .keys()
419 .map(|id| RuntimeInfo { id: id.clone() })
420 .collect()
421 }
422
423 async fn inventory(&self) -> Vec<InventoryItem> {
425 self.inventory.lock().await.values().cloned().collect()
426 }
427
428 #[cfg(feature = "client")]
429 async fn connect(&self, addr: String) -> Result<JoinHandle<()>> {
430 let (conn, _) = connect_async(&addr).await?;
431 let (mut conn_tx, conn_rx) = conn.split();
432 let (return_tx, return_rx) = mpsc::channel(16);
433
434 self.runtime_registry
435 .write()
436 .await
437 .insert("server".into(), return_tx.clone());
438
439 conn_tx
441 .send(tungstenite::Message::Binary(
442 serde_binary::to_vec(
443 &Message {
444 target_id: "server".into(),
445 source_id: self.id.clone(),
446 message_id: Ulid::new(),
447 is_answer: false,
448 is_closed: true,
449 data: MessageBody::Handshake(HandshakeMessage {
450 runtime_id: self.id.clone(),
451 }),
452 },
453 Endian::Little,
454 )?
455 .into(),
456 ))
457 .await?;
458
459 let runtime = self.clone();
460
461 Ok(spawn(async move {
462 let sender_handle = spawn(Self::message_encoder(conn_tx, return_rx));
464
465 runtime.message_decoder(return_tx, conn_rx).await;
467
468 sender_handle.abort();
469 }))
470 }
471
472 #[cfg(feature = "server")]
473 async fn start_server(&self, addr: String) -> JoinHandle<Result<()>> {
474 let runtime = self.clone();
475 spawn(async move {
476 let server = TcpListener::bind(addr).await?;
477
478 loop {
479 let (socket, _) = server.accept().await?;
480
481 let runtime = runtime.clone();
482 spawn(async move {
483 if let Err(e) = runtime.accept_connection(socket).await {
484 error!("connection error: {}", e);
485 }
486 });
487 }
488 })
489 }
490}
491
492#[crate::async_trait]
493impl RuntimeInternalTrait for Runtime {
494 async fn register_handlers(&self) -> Result<()> {
496 for reg in inventory::iter::<HandlerRegistration> {
497 (reg.register)(self.clone()).await?;
498 }
499 Ok(())
500 }
501
502 async fn runtime_worker(&self) {
507 let runtime = self.clone();
508 info!(runtime_id = %runtime.id, "Starting RPC runtime event loop");
509
510 let handle = spawn(async move {
511 let mut rx = runtime.runtime_rx.lock().await;
513 while let Some(MessageData { return_tx, message }) = rx.recv().await {
514 let msg_id = message.message_id;
515 trace!(message_id = %msg_id, is_answer = %message.is_answer, "Received message");
516
517 if message.target_id == runtime.id {
519 if message.is_answer {
521 if let Some(handler) = {
522 let mut handlers = runtime.message_handlers.lock().await;
523 handlers.remove(&msg_id)
524 } {
525 trace!(message_id = %msg_id, "Received answer");
526 let _ = handler.send(message).await;
527 }
528 } else {
529 match message.data {
530 #[cfg(any(feature = "client", feature = "server"))]
532 MessageBody::Handshake(ref msg) => {
533 info!(runtime = %msg.runtime_id, "Connected");
534 runtime
535 .runtime_registry
536 .write()
537 .await
538 .insert(msg.runtime_id.clone(), return_tx);
539 }
540 #[cfg(not(any(feature = "client", feature = "server")))]
542 MessageBody::Handshake(_) => {}
543 MessageBody::Rpc(ref msg) => {
544 if let Some(exec) =
545 { runtime.executors.read().await.get(&msg.r#type).cloned() }
546 {
547 trace!(message_id = %msg_id, "Dispatching to handler");
548 let handler_exec = exec.clone();
549 let handler_msg = message;
550 let handler_tx = return_tx.clone();
551 spawn(async move {
552 let response = handler_exec(handler_msg).await;
553 trace!(message_id = %msg_id, "Handler execution complete");
554 if let Err(err) = handler_tx.send(response).await {
555 error!(
556 message_id = %msg_id,
557 error = ?err,
558 "Failed to send RPC response"
559 );
560 }
561 });
562 } else {
563 error!(
564 message_id = %msg_id,
565 message_type = ?msg.r#type,
566 "No executor registered for message type"
567 );
568 let _ = return_tx
569 .send(Message {
570 target_id: message.source_id,
571 source_id: runtime.id.clone(),
572 message_id: Ulid::new(),
573 is_answer: true,
574 is_closed: true,
575 data: MessageBody::Rpc(RpcMessage {
576 r#type: msg.r#type,
577 data: Err("no handler registered".into()),
578 }),
579 })
580 .await;
581 }
582 }
583 }
584 }
585 } else {
586 #[cfg(any(feature = "client", feature = "server"))]
588 {
589 let target = message.target_id.clone();
590 let tx_opt =
591 { runtime.runtime_registry.read().await.get(&target).cloned() };
592
593 if let Some(tx) = tx_opt {
594 if let Err(err) = tx.send(message).await {
595 error!(
596 message_id = %msg_id,
597 target_id = %target,
598 error = ?err,
599 "Failed to forward message to runtime"
600 );
601 }
602 } else {
603 let server_tx_opt =
604 { runtime.runtime_registry.read().await.get("server").cloned() };
605 if let Some(tx) = server_tx_opt {
606 if let Err(err) = tx.send(message).await {
607 error!(
608 message_id = %msg_id,
609 target_id = %target,
610 error = ?err,
611 "Failed to forward message to runtime"
612 );
613 }
614 }
615 }
616 }
617 #[cfg(not(any(feature = "client", feature = "server")))]
619 unreachable!("Received message for unknown target {}", message.target_id);
620 }
621 }
622 });
623
624 self.runtime_worker_handle.lock().await.replace(handle);
625 }
626
627 #[cfg(any(feature = "client", feature = "server"))]
628 async fn message_encoder<S: Sink<tungstenite::Message> + Unpin + Send>(
629 mut tx: S,
630 mut rx: mpsc::Receiver<Message>,
631 ) where
632 <S as Sink<tungstenite::Message>>::Error: std::fmt::Display,
633 {
634 while let Some(message) = rx.recv().await {
635 match serde_binary::to_vec(&message, Endian::Little) {
636 Ok(message) => {
637 let message = tungstenite::Message::Binary(message.into());
638 if let Err(e) = tx.send(message).await {
639 error!("error sending message: {}", e);
640 }
641 }
642 Err(e) => error!("error sending message: {}", e),
643 }
644 }
645 }
646
647 #[cfg(any(feature = "client", feature = "server"))]
648 async fn message_decoder<
649 S: Stream<Item = std::result::Result<tungstenite::Message, tungstenite::Error>> + Unpin + Send,
650 >(
651 &self,
652 tx: mpsc::Sender<Message>,
653 mut rx: S,
654 ) {
655 while let Some(message) = rx.next().await {
656 match message {
657 Ok(tungstenite::Message::Binary(message)) => {
658 match serde_binary::from_slice(&message, Endian::Little) {
659 Ok(message) => {
660 let _ = self
661 .runtime_tx
662 .send(MessageData {
663 message,
664 return_tx: tx.clone(),
665 })
666 .await;
667 }
668 Err(e) => error!("error decoding message: {}", e),
669 }
670 }
671 Ok(_) => {
672 error!("unsupported message type");
673 break;
674 }
675 Err(e) => {
676 error!("connection error: {}", e);
677 break;
678 }
679 }
680 }
681 }
682
683 #[cfg(feature = "server")]
684 async fn accept_connection(&self, stream: TcpStream) -> Result<()> {
685 let ws = tokio_tungstenite::accept_async(stream).await?;
686 let (stream_tx, stream_rx) = ws.split();
687 let (return_tx, return_rx) = mpsc::channel(16);
688
689 let transfer_handle = spawn(Self::message_encoder(stream_tx, return_rx));
690 self.message_decoder(return_tx, stream_rx).await;
691 transfer_handle.abort();
692
693 Ok(())
694 }
695}