1use std::marker::PhantomData;
2
3use futures::channel::{mpsc, oneshot};
4use futures::StreamExt;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
/// Message sent from a [`Channel`] to the [`Server`] over the shared unbounded mpsc queue.
///
/// Produced by `Channel::call` (an `AppRequest`) and by `Channel`'s `Drop` impl (a `Cancel`).
enum CoreRequest<X, Y> {
    /// A request that the server forwards to the wrapped service.
    AppRequest {
        /// The request value handed to the service's `call`.
        inner: X,
        /// One-shot sender on which the server delivers the response on success;
        /// if it is dropped instead, the caller's future resolves to an error.
        tx: oneshot::Sender<Y>,
        /// Id of the sending channel handle. When this request arrives, the server aborts
        /// any request still being processed for the same id (latest request wins).
        stream_id: u64,
    },
    /// Asks the server to abort the request currently tracked for `stream_id`, if any.
    Cancel {
        stream_id: u64,
    },
}
19
20pub struct ServerBuilder<X, Svc> {
21 svc: Svc,
22 phantom_x: PhantomData<X>,
23}
24impl<X, Svc: crate::Service<X> + 'static + Send> ServerBuilder<X, Svc>
25where
26 X: 'static + Send,
27 Svc::Future: Send,
28 Svc::Response: Send,
29{
30 pub fn new(svc: Svc) -> Self {
31 Self {
32 svc: svc,
33 phantom_x: PhantomData,
34 }
35 }
36 pub fn build(self) -> (Channel<X, Svc::Response>, Server<X, Svc>) {
37 let (tx, rx) = mpsc::unbounded();
38 let server = Server::new(rx, self.svc);
39 let chan = Channel::new(tx);
40 (chan, server)
41 }
42}
43
44pub struct Channel<X, Y> {
45 next_id: Arc<AtomicU64>,
46 stream_id: u64,
47 tx: mpsc::UnboundedSender<CoreRequest<X, Y>>,
48}
49impl<X, Y> Channel<X, Y> {
50 fn new(tx: mpsc::UnboundedSender<CoreRequest<X, Y>>) -> Self {
51 Self {
52 stream_id: 0,
53 next_id: Arc::new(AtomicU64::new(1)),
54 tx: tx,
55 }
56 }
57}
58impl<X, Y> Clone for Channel<X, Y> {
59 fn clone(&self) -> Self {
60 let next_id = self.next_id.clone();
61 let stream_id = next_id.fetch_add(1, Ordering::SeqCst);
62 Self {
63 stream_id,
64 next_id: next_id,
65 tx: self.tx.clone(),
66 }
67 }
68}
69impl<X, Y> Drop for Channel<X, Y> {
70 fn drop(&mut self) {
71 let cancel_req = CoreRequest::Cancel {
72 stream_id: self.stream_id,
73 };
74 self.tx.unbounded_send(cancel_req).ok();
75 }
76}
77impl<X: 'static + Send, Y: 'static + Send> crate::Service<X> for Channel<X, Y> {
78 type Response = Y;
79 type Error = anyhow::Error;
80 type Future =
81 std::pin::Pin<Box<dyn std::future::Future<Output = Result<Y, Self::Error>> + Send>>;
82
83 fn poll_ready(
84 &mut self,
85 _: &mut std::task::Context<'_>,
86 ) -> std::task::Poll<Result<(), Self::Error>> {
87 Ok(()).into()
88 }
89
90 fn call(&mut self, req: X) -> Self::Future {
91 let tx = self.tx.clone();
92 let stream_id = self.stream_id;
93 Box::pin(async move {
94 let (tx1, rx1) = oneshot::channel::<Y>();
95 let req = CoreRequest::AppRequest {
96 inner: req,
97 tx: tx1,
98 stream_id,
99 };
100 if tx.unbounded_send(req).is_err() {
101 anyhow::bail!("failed to send a request");
102 }
103 let rep = rx1.await?;
104 Ok(rep)
105 })
106 }
107}
108
109pub struct Server<X, Svc: crate::Service<X>> {
110 service: Svc,
111 rx: mpsc::UnboundedReceiver<CoreRequest<X, Svc::Response>>,
112}
113impl<X, Svc: crate::Service<X> + 'static + Send> Server<X, Svc>
114where
115 X: 'static + Send,
116 Svc::Future: Send,
117 Svc::Response: Send,
118{
119 fn new(rx: mpsc::UnboundedReceiver<CoreRequest<X, Svc::Response>>, service: Svc) -> Self {
120 Self { service, rx: rx }
121 }
122 pub async fn serve(mut self, executor: impl futures::task::Spawn) {
123 use futures::future::AbortHandle;
124 use futures::task::SpawnExt;
125 let mut processings: HashMap<u64, AbortHandle> = HashMap::new();
126 while let Some(req) = self.rx.next().await {
127 match req {
128 CoreRequest::AppRequest {
129 inner,
130 tx,
131 stream_id,
132 } => {
133 if let Some(handle) = processings.get(&stream_id) {
134 handle.abort();
135 }
136 processings.remove(&stream_id);
137
138 crate::poll_fn(|ctx| self.service.poll_ready(ctx))
140 .await
141 .ok();
142 let fut = self.service.call(inner);
143 let (fut, abort_handle) = futures::future::abortable(async move {
144 if let Ok(rep) = fut.await {
145 tx.send(rep).ok();
146 }
147 });
148 let fut = async move {
149 fut.await.ok();
150 };
151 if let Err(e) = executor.spawn(fut) {
152 abort_handle.abort();
153 }
154 processings.insert(stream_id, abort_handle);
155 }
156 CoreRequest::Cancel { stream_id } => {
157 if let Some(handle) = processings.get(&stream_id) {
158 handle.abort();
159 }
160 processings.remove(&stream_id);
161 }
162 }
163 }
164 }
165}
166
/// [`futures::task::Spawn`] adapter that hands tasks to `tokio::spawn`.
///
/// NOTE(review): `tokio::spawn` needs an ambient Tokio runtime; confirm that callers
/// drive `Server::serve` from inside one.
#[cfg(feature = "tokio-executor")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-executor")))]
pub struct TokioExecutor;

#[cfg(feature = "tokio-executor")]
impl futures::task::Spawn for TokioExecutor {
    /// Spawns `future` and discards the returned join handle, leaving the task detached.
    /// Always reports success.
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        drop(tokio::spawn(future));
        Result::Ok(())
    }
}
182
/// [`futures::task::Spawn`] adapter that hands tasks to `async_std::task::spawn`.
#[cfg(feature = "async-std-executor")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-std-executor")))]
pub struct AsyncStdExecutor;

#[cfg(feature = "async-std-executor")]
impl futures::task::Spawn for AsyncStdExecutor {
    /// Spawns `future` and discards the returned join handle, leaving the task detached.
    /// Always reports success.
    fn spawn_obj(
        &self,
        future: futures::task::FutureObj<'static, ()>,
    ) -> Result<(), futures::task::SpawnError> {
        drop(async_std::task::spawn(future));
        Result::Ok(())
    }
}