acril_rt/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![allow(incomplete_features)]
4#![feature(return_type_notation)]
5
6use std::{
7    any::{Any, TypeId},
8    cell::RefCell,
9    collections::HashMap,
10    marker::PhantomData,
11    pin::Pin,
12    sync::Arc,
13    thread,
14};
15
16use acril::{Future, Handler, Service};
17use tokio::{
18    sync::{
19        mpsc::{self, unbounded_channel, Sender, UnboundedReceiver, UnboundedSender},
20        oneshot, RwLock,
21    },
22    task::{JoinHandle, LocalSet},
23};
24
25/// A convenience alias to a [`Box`]`<dyn `[`Any`]` + Send + Sync + 'static>`.
26type Unknown = Box<dyn Any + Send + Sync + 'static>;
27/// A type-erased service.
28type AnyProc = Unknown;
29/// A type-erased message.
30type AnyMsg = Unknown;
31
32/// A type-erased function, which handles a message sent to a service.
33type ProcessRunner = Arc<
34    dyn Fn(
35            AnyProc,
36            AddrErased,
37            AnyMsg,
38            UnboundedSender<RuntimeCommand>,
39        ) -> Pin<Box<dyn Future<Output = AnyProc>>>
40        + Send
41        + Sync,
42>;
43
44/// The sender half of a channel for sending messages to running services.
45type MessageSender = UnboundedSender<(u32, ProcessRunner, AnyMsg)>;
46
47/// A handle to a running [`Service`], allowing to send messages to it.
48pub struct Addr<S: Service> {
49    erased: AddrErased,
50    phantom: PhantomData<S>,
51}
52
53impl<S: Service> Clone for Addr<S> {
54    fn clone(&self) -> Self {
55        Self {
56            erased: self.erased.clone(),
57            phantom: PhantomData,
58        }
59    }
60}
61
62/// An erased service address, containing an ID and message sender.
63#[derive(Clone)]
64struct AddrErased {
65    id: u32,
66    runner_tx: MessageSender,
67}
68
69impl<S: Service<Context = Context<S>>> Addr<S> {
70    /// Send a message to the service pointed to by this address.
71    pub async fn send<M>(&self, msg: M) -> Result<S::Response, S::Error>
72    where
73        S: Handler<M> + Send + Sync + 'static,
74        M: 'static + Send + Sync,
75        S::Error: Send + Sync,
76        S::Response: Send + Sync,
77    {
78        let (tx, mut rx) = mpsc::channel(1);
79
80        self.erased
81            .runner_tx
82            .send((
83                self.erased.id,
84                message_handler::<S, M>(Some(tx)),
85                Box::new(msg),
86            ))
87            .unwrap();
88
89        rx.recv().await.unwrap()
90    }
91
92    /// Just send a message, without waiting for a response.
93    ///
94    /// Be aware that by using this function you allow desynchonization to happen, as this function
95    /// doesn't wait for the response.
96    ///
97    /// If you don't want desyncs to happen, use [`send`](Self::send).
98    pub fn do_send<M>(&self, msg: M)
99    where
100        S: Handler<M> + Send + Sync + 'static,
101        M: 'static + Send + Sync,
102        S::Error: Send + Sync,
103        S::Response: Send + Sync,
104    {
105        self.erased
106            .runner_tx
107            .send((self.erased.id, message_handler::<S, M>(None), Box::new(msg)))
108            .unwrap()
109    }
110}
111
112/// A service context.
113///
114/// Any service that is spawned onto a [`Runtime`] needs to have this type as context.
115pub struct Context<S: Service<Context = Self>> {
116    addr: Addr<S>,
117    commands: UnboundedSender<RuntimeCommand>,
118}
119
120impl<S: Service<Context = Self>> Context<S> {
121    /// Get the address of this service.
122    ///
123    /// Beware that if you use this to send a message to yourself (with [`Addr::send`]), you will
124    /// encounter a deadlock because to handle a message you need send a message, which just makes
125    /// the message handler never finish.
126    pub fn this(&self) -> Addr<S> {
127        self.addr.clone()
128    }
129
130    /// Spawn a future onto the current arbiter.
131    pub fn spawn<F: Future + 'static>(&self, fut: F) -> JoinHandle<F::Output> {
132        tokio::task::spawn_local(fut)
133    }
134
135    /// Try to retrieve the address of a singleton service `T`.
136    /// If there are more than one or no instances of `T` running,
137    /// this method will return a [`SingletonError`].
138    pub async fn try_singleton<T: Service<Context = Context<T>> + Any>(
139        &self,
140    ) -> Result<Addr<T>, SingletonError> {
141        let (addr_send, addr_recv) = oneshot::channel();
142
143        let _ = self.commands.send(RuntimeCommand::GetAddrOf {
144            ty: TypeId::of::<T>(),
145            addr: addr_send,
146        });
147
148        addr_recv
149            .await
150            .expect("runtime has been shut down")
151            .map(|erased| Addr {
152                erased,
153                phantom: PhantomData,
154            })
155    }
156
157    /// Retrieve an address to the singleton service `T`.
158    /// This method panics if the service wasn't running
159    /// or there were more than one instance of it.
160    pub async fn singleton<T: Service<Context = Context<T>> + Any>(&self) -> Addr<T> {
161        self.try_singleton().await.unwrap_or_else(|error| {
162            panic!(
163                "A singleton {T} was not available: {error:?}",
164                T = std::any::type_name::<T>()
165            )
166        })
167    }
168}
169
170/// Make a new message handler. This takes a responder, which is an optional channel for sending
171/// the response of the service. A context is created inline to avoid storing another thing inside
172/// of the event loop.
173fn message_handler<S, M>(responder: Option<Sender<Result<S::Response, S::Error>>>) -> ProcessRunner
174where
175    S: 'static + Handler<M, Context = Context<S>> + Send + Sync,
176    M: 'static + Send + Sync,
177    S::Error: Send + Sync,
178    S::Response: Send + Sync,
179{
180    Arc::new(move |actor, erased, msg, commands| {
181        let mut proc = actor.downcast::<S>().unwrap();
182        let msg = msg.downcast::<M>().unwrap();
183        let responder = responder.clone();
184
185        Box::pin(async move {
186            let res = proc
187                .call(
188                    *msg,
189                    &mut Context {
190                        commands,
191                        addr: Addr {
192                            erased,
193                            phantom: PhantomData::<S>,
194                        },
195                    },
196                )
197                .await;
198
199            if let Some(responder) = &responder {
200                responder.send(res).await.ok();
201            }
202
203            proc as Unknown
204        })
205    })
206}
207
208thread_local! {
209    /// A handle to the arbiter that the current task is running on.
210    static HANDLE: RefCell<Option<ArbiterHandle>> = RefCell::new(None);
211}
212
213/// An arbiter is a single-threaded event loop, allowing users to spawn tasks onto it.
214pub struct Arbiter {
215    thread_handle: thread::JoinHandle<()>,
216    // store a handle instead of the tx itself so we can make Arbiter deref to the handle
217    // to not copy-paste spawn and stop methods
218    arb: ArbiterHandle,
219}
220
/// A message sent to an arbiter's event loop.
enum ArbiterCommand {
    /// Shut the event loop down.
    Stop,
    /// Spawn the contained future as a task on the arbiter.
    Execute(Pin<Box<dyn Future<Output = ()> + Send>>),
}
228
229impl Arbiter {
230    async fn runner(mut rx: UnboundedReceiver<ArbiterCommand>) {
231        // clever trick: the loop ends if there is a `None` or `Some(Stop)`
232        while let Some(ArbiterCommand::Execute(fu)) = rx.recv().await {
233            tokio::task::spawn_local(fu);
234        }
235    }
236
237    /// Get a handle to the arbiter that the current task is running in.
238    /// If the arbiter is not available, this function panics. If you don't want a panic, use
239    /// [`Self::try_current`], which returns an [`Option`]`<`[`ArbiterHandle`]`>`.
240    pub fn current() -> ArbiterHandle {
241        Self::try_current().expect("no arbiter was available")
242    }
243
244    /// Get a handle to the arbiter that the current task is running in.
245    /// If the arbiter is not available, this returns [`None`].
246    pub fn try_current() -> Option<ArbiterHandle> {
247        HANDLE.with_borrow(|x| x.as_ref().map(|a| ArbiterHandle { tx: a.tx.clone() }))
248    }
249
250    /// Get a handle to this arbiter.
251    pub fn handle(&self) -> &ArbiterHandle {
252        &self.arb
253    }
254
255    pub fn new() -> Self {
256        Self::with_tokio_rt(|| {
257            tokio::runtime::Builder::new_current_thread()
258                .build()
259                .unwrap()
260        })
261    }
262
263    pub fn with_tokio_rt(factory: impl Fn() -> tokio::runtime::Runtime + Send + 'static) -> Self {
264        let (tx, rx) = unbounded_channel();
265        let (ready_tx, ready_rx) = std::sync::mpsc::channel::<()>();
266
267        Self {
268            thread_handle: thread::Builder::new()
269                .name("acril-rt-arbiter".to_string())
270                .spawn({
271                    let tx = tx.clone();
272                    move || {
273                        let tokio = factory();
274                        let local_set = LocalSet::new();
275                        let _guard = local_set.enter();
276
277                        // "register" the arbiter
278                        HANDLE.set(Some(ArbiterHandle { tx }));
279
280                        ready_tx.send(()).unwrap();
281
282                        tokio.block_on(local_set.run_until(Self::runner(rx)));
283
284                        // de-"register" the arbiter
285                        HANDLE.set(None);
286                    }
287                })
288                .unwrap(),
289            arb: ArbiterHandle {
290                tx: {
291                    ready_rx.recv().unwrap();
292                    tx
293                },
294            },
295        }
296    }
297
298    pub fn join(self) -> std::thread::Result<()> {
299        self.thread_handle.join()
300    }
301}
302
303/// A handle to an arbiter, allowing to spawn futures onto the arbiter or stop it.
304#[derive(Clone)]
305pub struct ArbiterHandle {
306    tx: UnboundedSender<ArbiterCommand>,
307}
308
309impl std::ops::Deref for Arbiter {
310    type Target = ArbiterHandle;
311    fn deref(&self) -> &Self::Target {
312        &self.arb
313    }
314}
315
316impl ArbiterHandle {
317    pub fn spawn(&self, future: impl Future<Output = ()> + Send + 'static) -> bool {
318        self.tx
319            .send(ArbiterCommand::Execute(Box::pin(future)))
320            .is_ok()
321    }
322
323    pub fn stop(&self) -> bool {
324        self.tx.send(ArbiterCommand::Stop).is_ok()
325    }
326}
327
328/// A command sent to a runtime.
329enum RuntimeCommand {
330    /// Spawn a service.
331    Process {
332        ty: TypeId,
333        proc: AnyProc,
334        addr: oneshot::Sender<AddrErased>,
335    },
336    /// Get the address of a singleton service with said [`TypeId`].
337    GetAddrOf {
338        ty: TypeId,
339        addr: oneshot::Sender<Result<AddrErased, SingletonError>>,
340    },
341}
342
343/// The main component of `acril_rt` - the runtime. It stores the running services
344/// and handles their lifecycle and messages sent to them.
345pub struct Runtime {
346    command_sender: UnboundedSender<RuntimeCommand>,
347}
348
349impl Runtime {
350    pub fn new() -> Self {
351        let (command_sender, command_recv) = unbounded_channel();
352
353        let _ = tokio::task::spawn_local(Self::event_loop(command_recv, command_sender.clone()));
354
355        Self { command_sender }
356    }
357
358    pub fn new_in(arbiter: &ArbiterHandle) -> Self {
359        let (command_sender, command_recv) = unbounded_channel();
360
361        arbiter.spawn(Self::event_loop(command_recv, command_sender.clone()));
362
363        Self { command_sender }
364    }
365
366    pub async fn spawn<S: Service<Context = Context<S>> + Send + Sync + 'static>(
367        &self,
368        service: S,
369    ) -> Addr<S> {
370        let (addr_send, addr_recv) = oneshot::channel();
371
372        self.command_sender
373            .send(RuntimeCommand::Process {
374                ty: service.type_id(),
375                proc: Box::new(service),
376                addr: addr_send,
377            })
378            .unwrap();
379
380        Addr {
381            erased: addr_recv.await.unwrap(),
382            phantom: PhantomData,
383        }
384    }
385
386    async fn event_loop(
387        mut commands: UnboundedReceiver<RuntimeCommand>,
388        commands_sender: UnboundedSender<RuntimeCommand>,
389    ) {
390        let processes: Arc<RwLock<HashMap<u32, (TypeId, AnyProc)>>> =
391            Arc::new(RwLock::new(HashMap::new()));
392        let mut count: u32 = 0;
393        let (message_sender, mut messages): (MessageSender, _) = unbounded_channel();
394
395        loop {
396            tokio::select! {
397                Some((id, runner, msg)) = messages.recv() => {
398                            let (ty, proc) = processes.write().await.remove(&id).unwrap();
399
400                            tokio::task::spawn_local({ let processes = processes.clone(); let commands_sender = commands_sender.clone(); let runner_tx = message_sender.clone(); async move {
401                                let proc = runner(proc, AddrErased { id, runner_tx }, msg, commands_sender).await;
402                                processes.write().await.insert(id, (ty, proc));
403                            }});
404                }
405                Some(command) = commands.recv() => {
406                    match command {
407                        RuntimeCommand::Process { ty, proc, addr } => {
408                            let id = count;
409                            processes.write().await.insert(id, (ty, proc));
410                            count += 1;
411                            let _ = addr.send(AddrErased { id, runner_tx: message_sender.clone() });
412                        }
413                        RuntimeCommand::GetAddrOf { ty, addr } => {
414                            let _ = addr.send(
415                                only_one(processes.read().await.iter()
416                                    .filter(|(_id, (ty_, _))| *ty_ == ty))
417                                    .map(|(id, _)| AddrErased { id: *id, runner_tx: message_sender.clone() }).map_err(|e| if e.is_some() { SingletonError::MoreThanOne } else { SingletonError::NotPresent})
418                            );
419                        }
420                    }
421                }
422                else => break
423            }
424        }
425    }
426}
427
/// An error returned while retrieving a singleton service.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingletonError {
    /// More than one instance of the requested service was running.
    MoreThanOne,
    /// The requested singleton service was not running.
    NotPresent,
}

impl std::fmt::Display for SingletonError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::MoreThanOne => "more than one instance of the requested service is running",
            Self::NotPresent => "the requested singleton service is not running",
        })
    }
}

/// Lets `SingletonError` be used with `?` in functions returning `Box<dyn Error>` and similar.
impl std::error::Error for SingletonError {}
436
/// Extracts the single item produced by `iter`.
///
/// Returns `Ok(item)` when the iterator yields exactly one item. Otherwise it returns
/// `Err(None)` if the iterator was empty, or `Err(Some(second))` carrying the second item if it
/// yielded more than one. At most two items are pulled from the iterator, and only one if it
/// was empty.
pub fn only_one<I: Iterator>(mut iter: I) -> Result<I::Item, Option<I::Item>> {
    let Some(first) = iter.next() else {
        return Err(None);
    };

    match iter.next() {
        None => Ok(first),
        surplus => Err(surplus),
    }
}
448
449/// `use acril_rt::prelude::*;` to import the commonly used types.
450pub mod prelude {
451    #[doc(no_inline)]
452    pub use acril::{self, Handler, Service};
453    pub use crate::{Addr, Arbiter, Context, Runtime};
454}