norpc/runtime/mod.rs

1use std::marker::PhantomData;
2
3use futures::channel::{mpsc, oneshot};
4use futures::StreamExt;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9enum CoreRequest<X, Y> {
10    AppRequest {
11        inner: X,
12        tx: oneshot::Sender<Y>,
13        stream_id: u64,
14    },
15    Cancel {
16        stream_id: u64,
17    },
18}
19
20pub struct ServerBuilder<X, Svc> {
21    svc: Svc,
22    phantom_x: PhantomData<X>,
23}
24impl<X, Svc: crate::Service<X> + 'static + Send> ServerBuilder<X, Svc>
25where
26    X: 'static + Send,
27    Svc::Future: Send,
28    Svc::Response: Send,
29{
30    pub fn new(svc: Svc) -> Self {
31        Self {
32            svc: svc,
33            phantom_x: PhantomData,
34        }
35    }
36    pub fn build(self) -> (Channel<X, Svc::Response>, Server<X, Svc>) {
37        let (tx, rx) = mpsc::unbounded();
38        let server = Server::new(rx, self.svc);
39        let chan = Channel::new(tx);
40        (chan, server)
41    }
42}
43
44pub struct Channel<X, Y> {
45    next_id: Arc<AtomicU64>,
46    stream_id: u64,
47    tx: mpsc::UnboundedSender<CoreRequest<X, Y>>,
48}
49impl<X, Y> Channel<X, Y> {
50    fn new(tx: mpsc::UnboundedSender<CoreRequest<X, Y>>) -> Self {
51        Self {
52            stream_id: 0,
53            next_id: Arc::new(AtomicU64::new(1)),
54            tx: tx,
55        }
56    }
57}
58impl<X, Y> Clone for Channel<X, Y> {
59    fn clone(&self) -> Self {
60        let next_id = self.next_id.clone();
61        let stream_id = next_id.fetch_add(1, Ordering::SeqCst);
62        Self {
63            stream_id,
64            next_id: next_id,
65            tx: self.tx.clone(),
66        }
67    }
68}
69impl<X, Y> Drop for Channel<X, Y> {
70    fn drop(&mut self) {
71        let cancel_req = CoreRequest::Cancel {
72            stream_id: self.stream_id,
73        };
74        self.tx.unbounded_send(cancel_req).ok();
75    }
76}
77impl<X: 'static + Send, Y: 'static + Send> crate::Service<X> for Channel<X, Y> {
78    type Response = Y;
79    type Error = anyhow::Error;
80    type Future =
81        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Y, Self::Error>> + Send>>;
82
83    fn poll_ready(
84        &mut self,
85        _: &mut std::task::Context<'_>,
86    ) -> std::task::Poll<Result<(), Self::Error>> {
87        Ok(()).into()
88    }
89
90    fn call(&mut self, req: X) -> Self::Future {
91        let tx = self.tx.clone();
92        let stream_id = self.stream_id;
93        Box::pin(async move {
94            let (tx1, rx1) = oneshot::channel::<Y>();
95            let req = CoreRequest::AppRequest {
96                inner: req,
97                tx: tx1,
98                stream_id,
99            };
100            if tx.unbounded_send(req).is_err() {
101                anyhow::bail!("failed to send a request");
102            }
103            let rep = rx1.await?;
104            Ok(rep)
105        })
106    }
107}
108
109pub struct Server<X, Svc: crate::Service<X>> {
110    service: Svc,
111    rx: mpsc::UnboundedReceiver<CoreRequest<X, Svc::Response>>,
112}
113impl<X, Svc: crate::Service<X> + 'static + Send> Server<X, Svc>
114where
115    X: 'static + Send,
116    Svc::Future: Send,
117    Svc::Response: Send,
118{
119    fn new(rx: mpsc::UnboundedReceiver<CoreRequest<X, Svc::Response>>, service: Svc) -> Self {
120        Self { service, rx: rx }
121    }
122    pub async fn serve(mut self, executor: impl futures::task::Spawn) {
123        use futures::future::AbortHandle;
124        use futures::task::SpawnExt;
125        let mut processings: HashMap<u64, AbortHandle> = HashMap::new();
126        while let Some(req) = self.rx.next().await {
127            match req {
128                CoreRequest::AppRequest {
129                    inner,
130                    tx,
131                    stream_id,
132                } => {
133                    if let Some(handle) = processings.get(&stream_id) {
134                        handle.abort();
135                    }
136                    processings.remove(&stream_id);
137
138                    // backpressure
139                    crate::poll_fn(|ctx| self.service.poll_ready(ctx))
140                        .await
141                        .ok();
142                    let fut = self.service.call(inner);
143                    let (fut, abort_handle) = futures::future::abortable(async move {
144                        if let Ok(rep) = fut.await {
145                            tx.send(rep).ok();
146                        }
147                    });
148                    let fut = async move {
149                        fut.await.ok();
150                    };
151                    if let Err(e) = executor.spawn(fut) {
152                        abort_handle.abort();
153                    }
154                    processings.insert(stream_id, abort_handle);
155                }
156                CoreRequest::Cancel { stream_id } => {
157                    if let Some(handle) = processings.get(&stream_id) {
158                        handle.abort();
159                    }
160                    processings.remove(&stream_id);
161                }
162            }
163        }
164    }
165}
166
/// Tokio support: spawns server tasks with `tokio::spawn`.
#[cfg(feature = "tokio-executor")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-executor")))]
pub struct TokioExecutor;

#[cfg(feature = "tokio-executor")]
impl futures::task::Spawn for TokioExecutor {
    /// Hands `future` to `tokio::spawn` and detaches it (the join handle is not kept).
    ///
    /// Always reports success.
    // NOTE(review): per Tokio's docs `tokio::spawn` panics when no runtime is active —
    // confirm `Server::serve` is always driven from inside a Tokio runtime.
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        tokio::spawn(future);
        Ok(())
    }
}
182
/// async-std support: spawns server tasks with `async_std::task::spawn`.
#[cfg(feature = "async-std-executor")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-std-executor")))]
pub struct AsyncStdExecutor;

#[cfg(feature = "async-std-executor")]
impl futures::task::Spawn for AsyncStdExecutor {
    /// Hands `future` to `async_std::task::spawn` and detaches it (the join handle is not
    /// kept).
    ///
    /// Always reports success.
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        async_std::task::spawn(future);
        Ok(())
    }
}