1use futures::{future::BoxFuture, prelude::*, ready};
2use log::*;
3use std::{
4    collections::hash_map::DefaultHasher,
5    hash::{Hash, Hasher},
6    ops::DerefMut,
7    pin::Pin,
8    sync::{Arc, RwLock},
9    task::{Context, Poll},
10};
11use tokio::{
12    runtime::Builder,
13    sync::{mpsc, oneshot},
14    task::JoinHandle,
15    task_local,
16};
17
18type Sender = mpsc::UnboundedSender<BoxFuture<'static, ()>>;
19type Receiver = mpsc::UnboundedReceiver<BoxFuture<'static, ()>>;
20type CactchResult<T> = std::result::Result<T, Box<dyn std::any::Any + 'static + Send>>;
21
// Task-local worker identifier. `run` scopes it to the index of the worker
// queue around every future it executes, and `current` reads it back.
task_local! {
    pub static ID: u64;
}
25
26pub fn current() -> u64 {
28    ID.get()
29}
30
// Process-wide optional router. It starts out as `None`; it is written by
// `Router::set_as_global` / `Router::with_global_mut` and read by
// `Router::with_global`.
lazy_static::lazy_static! {
    static ref GLOBAL_ROUTER: RwLock<Option<Router>> = RwLock::new(None);
}
34
/// Handle used to submit futures to a fixed set of worker queues.
///
/// Clones share the same list of senders through the inner `Arc`.
#[derive(Clone)]
pub struct Router {
    /// One unbounded sender per worker queue; `via` selects an entry by key hash.
    tx: Arc<Vec<Sender>>,
}
39
/// Future handed back to the caller when work is submitted to a worker queue.
///
/// Wraps the oneshot receiver on which the submitted future's output, or the
/// payload of a panic caught while running it, is delivered.
pub struct Via<T>(oneshot::Receiver<CactchResult<T>>);
41
42impl<T> Via<T> {
43    fn new<F, R>(tx: &Sender, f: F) -> Self
44    where
45        T: Send + 'static,
46        F: FnOnce() -> R,
47        R: Future<Output = T> + Send + 'static,
48    {
49        let (otx, orx) = oneshot::channel();
50
51        let fut = std::panic::AssertUnwindSafe(f())
52            .catch_unwind()
53            .then(move |r| async move {
54                let _ = otx.send(r);
55            })
56            .boxed();
57
58        if tx.send(fut).is_err() {
59            panic!("Couldn't send future to router; the future will never be resolved");
60        }
61
62        Self(orx)
63    }
64}
65
66impl<T> Future for Via<T> {
67    type Output = T;
68
69    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
70        match ready!(self.0.poll_unpin(cx)) {
71            Ok(output) => Poll::Ready(output.expect("panic in the future in the ert router")),
72            Err(_) => panic!("the future in ert was cancelled"),
73        }
74    }
75}
76
77fn open(workers: usize) -> (Vec<Sender>, Vec<Receiver>) {
78    (0..workers).map(|_| mpsc::unbounded_channel()).unzip()
79}
80
81fn run(rxs: Vec<Receiver>) -> Vec<JoinHandle<()>> {
82    rxs.into_iter()
83        .enumerate()
84        .map(|(i, rx)| {
85            tokio::spawn(rx.for_each(move |t| ID.scope(i as u64, async move { t.await })))
86        })
87        .collect()
88}
89
90impl Router {
91    pub fn new(workers: usize) -> Self {
92        if workers == 0 {
93            panic!("Invalid number of workers: {}", workers);
94        }
95
96        let (txs, rxs) = open(workers);
97
98        run(rxs);
99
100        Self { tx: Arc::new(txs) }
101    }
102
103    pub fn run_on_thread(workers: usize) -> Self {
104        let (txs, rxs) = open(workers);
105
106        std::thread::spawn(move || {
107            let mut rt = Builder::new()
108                .threaded_scheduler()
109                .enable_all()
110                .build()
111                .unwrap();
112            rt.block_on(async move {
113                if let Err(e) = futures::future::try_join_all(run(rxs)).await {
114                    error!("Couldn't join router worker thread successfully: {}", e);
115                }
116            });
117        });
118
119        Self { tx: Arc::new(txs) }
120    }
121
122    pub fn set_as_global(self) {
123        *GLOBAL_ROUTER.write().unwrap() = Some(self);
124    }
125
126    pub fn with_global<F, R>(f: F) -> R
127    where
128        F: FnOnce(Option<&Router>) -> R,
129    {
130        f(GLOBAL_ROUTER.read().unwrap().as_ref())
131    }
132
133    pub fn with_global_mut<F, R>(f: F) -> R
134    where
135        F: FnOnce(&mut Option<Router>) -> R,
136    {
137        f(GLOBAL_ROUTER.write().unwrap().deref_mut())
138    }
139
140    pub fn via<K, F, T, R>(&self, key: K, f: F) -> Via<T>
141    where
142        K: Hash,
143        T: Send + 'static,
144        F: FnOnce() -> R,
145        R: Future<Output = T> + Send + 'static,
146    {
147        let h = {
148            let mut hasher = DefaultHasher::new();
149            key.hash(&mut hasher);
150            hasher.finish() as usize
151        };
152
153        Via::new(&self.tx[h % self.tx.len()], f)
154    }
155}