1#![recursion_limit = "512"]
40use future::FusedFuture;
41use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
42use futures::channel::oneshot::{channel, Sender};
43use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
44use futures::lock::Mutex;
45use futures::select;
46use futures::stream::{FuturesUnordered, StreamExt};
47use futures::{self, Future};
48use futures::{
49 future,
50 future::{BoxFuture, FutureExt},
51};
52use serde::{de::DeserializeOwned, Deserialize, Serialize};
53use std::fmt::Debug;
54use std::{collections::HashMap, pin::Pin, sync::Arc};
55use thiserror::Error;
56use tracing::{instrument, trace, debug_span, Instrument};
57
58#[cfg(test)]
59mod test;
60
61pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}
63impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}
64
65pub trait Msg: Serialize + DeserializeOwned + Send + Sync + 'static {}
68impl<T> Msg for T where T: Serialize + DeserializeOwned + Send + Sync + 'static {}
69
70#[derive(Error, Debug)]
71pub enum Error {
72 #[error("queue management error")]
73 Queue,
74
75 #[error("unknown response {0}")]
76 UnknownResponse(u64),
77
78 #[error("response channel {0} dropped")]
79 ResponseChannelDropped(u64),
80
81 #[error(transparent)]
82 Io(#[from] futures::io::Error),
83
84 #[error(transparent)]
85 Postcard(#[from] postcard::Error),
86
87 #[error(transparent)]
88 Canceled(#[from] futures::channel::oneshot::Canceled),
89}
90
91struct EnvelopeWriter {
92 wr: Box<dyn AsyncWrite + Unpin + Send + Sync + 'static>,
93}
94
95impl EnvelopeWriter {
96 fn new<W: AsyncWrite + Unpin + Send + Sync + 'static>(wtr: W) -> Self {
97 EnvelopeWriter { wr: Box::new(wtr) }
98 }
99
100 #[instrument(skip(self, e))]
102 async fn send<OneWay: Msg, Request: Msg, Response: Msg>(
103 &mut self,
104 e: Envelope<OneWay, Request, Response>,
105 ) -> Result<(), Error> {
106 use byteorder_async::{LittleEndian, WriterToByteOrder};
107 let bytes = postcard::to_allocvec(&e)?;
108 let wsz: u64 = bytes.len() as u64;
109 self.wr.byte_order().write_u64::<LittleEndian>(wsz).await?;
110 self.wr.write_all(bytes.as_slice()).await?;
111 Ok(())
112 }
113}
114
115struct EnvelopeReader {
116 rd: Box<dyn AsyncRead + Unpin + Send + Sync + 'static>,
117 rdbuf: Vec<u8>,
118}
119
120impl EnvelopeReader {
121 fn new<R: AsyncRead + Unpin + Send + Sync + 'static>(rdr: R) -> Self {
122 EnvelopeReader {
123 rd: Box::new(rdr),
124 rdbuf: Vec::new(),
125 }
126 }
127
128 #[instrument(skip(self))]
130 async fn recv<OneWay: Msg, Request: Msg, Response: Msg>(
131 &mut self,
132 ) -> Result<Envelope<OneWay, Request, Response>, Error> {
133 use byteorder_async::{LittleEndian, ReaderToByteOrder};
134 let rsz: u64 = self.rd.byte_order().read_u64::<LittleEndian>().await?;
135 self.rdbuf.resize(rsz as usize, 0);
136 self.rd.read_exact(self.rdbuf.as_mut_slice()).await?;
137 Ok(postcard::from_bytes(self.rdbuf.as_slice())?)
138 }
139}
140
141struct Reception<OneWay: Msg, Request: Msg, Response: Msg> {
149 next_request: u64,
151
152 requests: HashMap<u64, Sender<Response>>,
156
157 enqueue: UnboundedSender<Envelope<OneWay, Request, Response>>,
159}
160
161pub struct Queue<OneWay: Msg, Request: Msg, Response: Msg> {
165 reception: Arc<Mutex<Reception<OneWay, Request, Response>>>,
166}
167
168impl<OneWay: Msg, Request: Msg, Response: Msg> Clone for Queue<OneWay, Request, Response> {
169 fn clone(&self) -> Self {
170 Self {
171 reception: self.reception.clone(),
172 }
173 }
174}
175
176impl<OneWay: Msg, Request: Msg, Response: Msg> Queue<OneWay, Request, Response> {
177 fn new(reception: Reception<OneWay, Request, Response>) -> Self {
178 Self {
179 reception: Arc::new(Mutex::new(reception)),
180 }
181 }
182
183 pub fn enqueue_oneway(
185 &self,
186 oneway: OneWay,
187 ) -> impl Future<Output = Result<(), Error>> + 'static {
188 let reception = self.reception.clone();
189 async move {
190 let env = Envelope::<OneWay, Request, Response>::OneWay(oneway);
191 let guard = reception.lock().await;
192 guard.enqueue.unbounded_send(env).map_err(|_| Error::Queue)
193 }
194 }
195
196 pub fn enqueue_request(
199 &self,
200 req: Request,
201 ) -> impl Future<Output = Result<Response, Error>> + 'static {
202 let reception = self.reception.clone();
203 async move {
204 let (send_err, recv) = {
205 let mut guard = reception.lock().await;
206 let curr = guard.next_request;
207 let env = Envelope::<OneWay, Request, Response>::Request(curr, req);
208 let send_err = guard.enqueue.unbounded_send(env);
209 let (send, recv) = channel();
210 if send_err.is_ok() {
211 tracing::trace!(?curr, "enqueued envelope for request");
212 guard.next_request += 1;
213 guard.requests.insert(curr, send);
214 }
215 (send_err, recv)
217 };
218 if send_err.is_ok() {
219 Ok(recv.await?)
220 } else {
221 Err(Error::Queue)
222 }
223 }
224 }
225}
226
227type PendingWrite = Pin<Box<dyn FusedFuture<Output = Result<(), Error>> + Send + Sync + 'static>>;
228type PendingRead<OneWay, Request, Response> = Pin<
229 Box<
230 dyn FusedFuture<Output = Result<Envelope<OneWay, Request, Response>, Error>>
231 + Send
232 + Sync
233 + 'static,
234 >,
235>;
236pub struct Connection<OneWay: Msg, Request: Msg, Response: Msg> {
237 reader: Arc<Mutex<EnvelopeReader>>,
241
242 writer: Arc<Mutex<EnvelopeWriter>>,
246
247 reads_in_progress: FuturesUnordered<PendingRead<OneWay, Request, Response>>,
250
251 writes_in_progress: FuturesUnordered<PendingWrite>,
254
255 pub queue: Queue<OneWay, Request, Response>,
258
259 dequeue: UnboundedReceiver<Envelope<OneWay, Request, Response>>,
261
262 responses: FuturesUnordered<BoxFuture<'static, (u64, Response)>>,
264
265 envelope_count: usize,
267}
268
269#[serde(bound = "")]
270#[derive(Debug, Serialize, Deserialize)]
271enum Envelope<OneWay: Msg, Request: Msg, Response: Msg> {
272 OneWay(OneWay),
273 Request(u64, Request),
274 Response(u64, Response),
275}
276
277impl<OneWay: Msg, Request: Msg, Response: Msg> Connection<OneWay, Request, Response> {
285 pub fn new_split<R, W>(rdr: R, wtr: W) -> Self
289 where
290 R: AsyncRead + Unpin + Send + Sync + 'static,
291 W: AsyncWrite + Unpin + Send + Sync + 'static,
292 {
293 let reader = Arc::new(Mutex::new(EnvelopeReader::new(rdr)));
294 let writer = Arc::new(Mutex::new(EnvelopeWriter::new(wtr)));
295 let next_request = 0;
296 let requests = HashMap::new();
297 let responses = FuturesUnordered::new();
298 let (enqueue, dequeue) = unbounded();
299 let queue = Queue::new(Reception {
300 next_request,
301 requests,
302 enqueue,
303 });
304
305 let reads_in_progress = FuturesUnordered::new();
306 let writes_in_progress = FuturesUnordered::new();
307 Connection {
308 reader,
309 writer,
310 queue,
311 reads_in_progress,
312 writes_in_progress,
313 responses,
314 dequeue,
315 envelope_count: 0,
316 }
317 }
318
319 pub fn new<RW: AsyncReadWrite>(rw: RW) -> Self {
322 let (rdr, wtr) = rw.split();
323 Self::new_split(rdr, wtr)
324 }
325
326 pub fn enqueue_oneway(
328 &self,
329 oneway: OneWay,
330 ) -> impl Future<Output = Result<(), Error>> + 'static {
331 self.queue.enqueue_oneway(oneway)
332 }
333
334 pub fn enqueue_request(
336 &self,
337 req: Request,
338 ) -> impl Future<Output = Result<Response, Error>> + 'static {
339 self.queue.enqueue_request(req)
340 }
341
342 fn issue_read(&mut self) {
343 let rdr = self.reader.clone();
344 let fut = Box::pin(
345 async move { rdr.lock().await.recv::<OneWay, Request, Response>().await }.fuse(),
346 );
347 self.reads_in_progress.push(fut);
348 }
349
350 fn issue_write(&mut self, env: Envelope<OneWay, Request, Response>) {
351 let wtr = self.writer.clone();
352 let fut = Box::pin(async move { wtr.lock().await.send(env).await }.fuse());
353 self.writes_in_progress.push(fut);
354 }
355
356 pub async fn advance<ServeRequest, FutureResponse, ServeOneWay>(
368 &mut self,
369 srv_req: ServeRequest,
370 srv_ow: ServeOneWay,
371 ) -> Result<(), Error>
372 where
373 ServeRequest: FnOnce(Request) -> FutureResponse,
374 FutureResponse: Future<Output = Response> + Send + 'static,
375 ServeOneWay: FnOnce(OneWay) -> (),
376 {
377 if self.reads_in_progress.len() == 0 {
378 self.issue_read();
379 }
380 select! {
381 next_written = self.writes_in_progress.next() => match next_written {
382 None => (Ok(())),
383 Some(res) => res
384 },
385 next_enqueued = self.dequeue.next() => match next_enqueued {
386 None => Ok(()),
387 Some(env) => {
388 trace!("dequeued envelope, sending");
389 self.issue_write(env);
390 Ok(())
391 }
392 },
393 next_response = self.responses.next() => {
394 match next_response {
395 None => Ok(()),
396 Some((n, response)) => {
397 trace!(n, "finished serving request, enqueueing response");
398 let env = Envelope::Response(n, response);
399 let guard = self.queue.reception.lock().await;
400 guard.enqueue.unbounded_send(env).map_err(|_| Error::Queue)
401 }
402 }
403 },
404 next_read = self.reads_in_progress.next() => match next_read {
405 None => Ok(()),
406 Some(read_result) => {
407 self.envelope_count += 1;
408 self.issue_read();
409 let env = read_result?;
410 match env {
411 Envelope::OneWay(ow) => {
412 trace!("received one-way, calling service function");
413 let span = debug_span!("RPC", e=self.envelope_count);
414 Ok(span.in_scope(|| srv_ow(ow)))
415 },
416 Envelope::Request(n, req) => {
417 trace!(n, "received request, calling service function");
418 let span = debug_span!("RPC", e=self.envelope_count);
419 let res_fut = srv_req(req);
420 let boxed : BoxFuture<'static,_> = Box::pin(res_fut.instrument(span).map(move |r| (n, r)));
421 Ok(self.responses.push(boxed))
422 },
423 Envelope::Response(n, res) => {
424 trace!(n, "received response, fulfilling future");
425 let mut guard = self.queue.reception.lock().await;
426 match guard.requests.remove(&n.clone()) {
427 None => Err(Error::UnknownResponse(n)),
428 Some(send) => {
429 match send.send(res) {
430 Ok(_) => Ok(()),
431 Err(_) => Err(Error::ResponseChannelDropped(n))
432 }
433 }
434 }
435 }
436 }
437 }
438 }
439 }
440 }
441}