ort_tcp/
muxer.rs

1use crate::next_or_pending;
2use bytes::{Buf, BufMut, BytesMut};
3use drain::Watch as Drain;
4use futures::{prelude::*, stream::FuturesUnordered};
5use std::collections::HashMap;
6use tokio::{
7    io,
8    sync::{mpsc, oneshot},
9};
10use tokio_util::codec::{Decoder, Encoder};
11use tracing::{debug, debug_span, error, info, trace, Instrument};
12
13#[derive(Default, Debug)]
14pub struct Muxer<E, D> {
15    buffer_capacity: usize,
16    encoder: FramedEncode<E>,
17    decoder: FramedDecode<D>,
18}
19
/// A value tagged with the id of the request it belongs to.
///
/// Ids pair each response with its request on a multiplexed connection.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Frame<T> {
    /// Identifier shared by a request and its response.
    pub id: u64,
    /// The payload carried by this frame.
    pub value: T,
}
25
/// Encoder adapter that writes a frame's `id` ahead of the value encoded by
/// the wrapped encoder.
#[derive(Debug, Default)]
pub struct FramedEncode<E> {
    /// The encoder responsible for the frame's value.
    inner: E,
}
30
31#[derive(Debug)]
32pub struct FramedDecode<D> {
33    inner: D,
34    state: DecodeState,
35}
36
/// Progress of [`FramedDecode`] through the current frame.
#[derive(Debug)]
enum DecodeState {
    /// No header consumed yet; the next 8 bytes are the frame id.
    Init,
    /// The id header was consumed, but the value is still incomplete.
    Head {
        /// Id read from the already-consumed header.
        id: u64,
    },
}
42
43pub fn spawn_client<Req, Rsp, W, R>(
44    mut write: W,
45    mut read: R,
46    buffer_capacity: usize,
47) -> mpsc::Sender<(Req, oneshot::Sender<Rsp>)>
48where
49    Req: Send + 'static,
50    Rsp: Send + 'static,
51    W: Sink<Frame<Req>, Error = io::Error> + Send + Unpin + 'static,
52    R: Stream<Item = io::Result<Frame<Rsp>>> + Send + Unpin + 'static,
53{
54    let (req_tx, mut req_rx) = mpsc::channel(buffer_capacity);
55
56    tokio::spawn(
57        async move {
58            let mut next_id = 1u64;
59            let mut in_flight = HashMap::<u64, oneshot::Sender<Rsp>>::new();
60
61            loop {
62                if next_id == std::u64::MAX {
63                    info!("Client exhausted request IDs");
64                    break;
65                }
66
67                tokio::select! {
68                    // Read requests from the stream and write them on the socket.
69                    // Stash the response oneshot for when the response is read.
70                    req = req_rx.recv() => match req {
71                        Some((value, rsp_tx)) => {
72                            let id = next_id;
73                            next_id += 1;
74                            trace!(id, "Dispatching request");
75                            let f = in_flight.entry(id);
76                            if let std::collections::hash_map::Entry::Occupied(_) = f {
77                                error!(id, "Request ID already in-flight");
78                                return Err(io::Error::new(
79                                    io::ErrorKind::InvalidInput,
80                                    "Request ID is already in-flight",
81                                ));
82                            }
83                            if let Err(error) = write.send(Frame { id, value }).await {
84                                error!(id, %error, "Failed to write response");
85                                return Err(error);
86                            }
87                            f.or_insert(rsp_tx);
88                        }
89                        None => {
90                            debug!("Client dropped its send handle");
91                            break;
92                        }
93                    },
94
95                    // Read responses from the socket and send them back to the
96                    // client.
97                    rsp = read.try_next() => match rsp? {
98                        Some(Frame { id, value }) => {
99                            trace!(id, "Dispatching response");
100                            match in_flight.remove(&id) {
101                                Some(tx) => {
102                                    let _ = tx.send(value);
103                                }
104                                None => return Err(io::Error::new(
105                                    io::ErrorKind::InvalidInput,
106                                    "Response for unknown request",
107                                )),
108                            }
109                        }
110                        None => {
111                            debug!(in_flight=in_flight.len(), "Server closed");
112                            if in_flight.is_empty() {
113                                return Ok(());
114                            } else {
115                                return Err(io::Error::new(
116                                    io::ErrorKind::ConnectionReset,
117                                    "Server closed",
118                                ));
119                            }
120                        }
121                    },
122                }
123            }
124
125            debug!("Allowing pending responses to complete");
126
127            // We shan't be sending any more requests. Keep reading
128            // responses, though.
129            drop((req_rx, write));
130
131            // Satisfy remaining responses.
132            while let Some(Frame { id, value }) = read.try_next().await? {
133                match in_flight.remove(&id) {
134                    Some(tx) => {
135                        let _ = tx.send(value);
136                    }
137                    None => {
138                        return Err(io::Error::new(
139                            io::ErrorKind::InvalidInput,
140                            "Response for unknown request",
141                        ));
142                    }
143                }
144            }
145            if !in_flight.is_empty() {
146                return Err(io::Error::new(
147                    io::ErrorKind::ConnectionReset,
148                    "Some requests did not receive a response",
149                ));
150            }
151
152            Ok(())
153        }
154        .in_current_span(),
155    );
156
157    req_tx
158}
159
160type Channel<Req, Rsp> = mpsc::Receiver<(Req, oneshot::Sender<Rsp>)>;
161type JoinHandle = tokio::task::JoinHandle<io::Result<()>>;
162
163pub fn spawn_server<Req, Rsp, R, W>(
164    mut read: R,
165    mut write: W,
166    drain: Drain,
167    buffer_capacity: usize,
168) -> (Channel<Req, Rsp>, JoinHandle)
169where
170    Req: Send + 'static,
171    Rsp: Send + 'static,
172    R: Stream<Item = io::Result<Frame<Req>>> + Send + Unpin + 'static,
173    W: Sink<Frame<Rsp>, Error = io::Error> + Send + Unpin + 'static,
174{
175    let (tx, rx) = mpsc::channel(buffer_capacity);
176
177    let handle = tokio::spawn(async move {
178        tokio::pin! {
179            let closed = drain.signaled();
180        }
181
182        let mut last_id = 0u64;
183        let mut in_flight = FuturesUnordered::new();
184        loop {
185            tokio::select! {
186                shutdown = (&mut closed) => {
187                    debug!("Shutdown signaled; draining in-flight requests");
188                    drop(read);
189                    drop(tx);
190                    while let Some(Frame { id, value }) = in_flight.try_next().await? {
191                        trace!(id, "In-flight response completed");
192                        write.send(Frame { id, value }).await?;
193                    }
194                    debug!("In-flight requests completed");
195                    drop(shutdown);
196                    return Ok(());
197                }
198
199                req = next_or_pending(&mut in_flight) => {
200                    let Frame { id, value } = req?;
201                    trace!(id, "In-flight response completed");
202                    if let Err(error) = write.send(Frame { id, value }).await {
203                        error!(%error, "Write failed");
204                        return Err(error);
205                    }
206                }
207
208                msg = read.try_next() => {
209                    let Frame { id, value } = match msg? {
210                        Some(f) => f,
211                        None => {
212                            trace!("Draining in-flight responses after client stream completed.");
213                            while let Some(Frame { id, value }) = in_flight.try_next().await? {
214                                trace!(id, "In-flight response completed");
215                                write.send(Frame { id, value }).await?;
216                            }
217                            return Ok(());
218                        }
219                    };
220
221                    if id <= last_id {
222                        return Err(io::Error::new(
223                            io::ErrorKind::InvalidInput,
224                            "Request ID too low",
225                        ));
226                    }
227                    last_id = id;
228
229                    trace!(id, "Dispatching request");
230                    let (rsp_tx, rsp_rx) = oneshot::channel();
231                    if tx.send((value, rsp_tx)).await.is_err() {
232                        return Err(io::Error::new(
233                            io::ErrorKind::ConnectionAborted,
234                            "Lost service",
235                        ));
236                    }
237                    in_flight.push(rsp_rx.map(move |v| match v {
238                        Ok(value) => Ok(Frame { id, value }),
239                        Err(_) => Err(io::Error::new(
240                            io::ErrorKind::ConnectionAborted,
241                            "Server dropped response",
242                        )),
243                    }));
244                }
245            }
246        }
247    }.instrument(debug_span!("mux")));
248
249    (rx, handle)
250}
251
252// === impl FramedDecode ===
253
254impl<D> From<D> for FramedDecode<D> {
255    fn from(inner: D) -> Self {
256        Self {
257            inner,
258            state: DecodeState::Init,
259        }
260    }
261}
262
263impl<D: Default> Default for FramedDecode<D> {
264    fn default() -> Self {
265        Self::from(D::default())
266    }
267}
268
269impl<D: Decoder> Decoder for FramedDecode<D> {
270    type Item = Frame<D::Item>;
271    type Error = D::Error;
272
273    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame<D::Item>>, D::Error> {
274        let id = match self.state {
275            DecodeState::Init => {
276                if src.len() < 8 {
277                    return Ok(None);
278                }
279                src.get_u64()
280            }
281            DecodeState::Head { id } => {
282                self.state = DecodeState::Init;
283                id
284            }
285        };
286
287        match self.inner.decode(src)? {
288            Some(value) => Ok(Some(Frame { id, value })),
289            None => {
290                self.state = DecodeState::Head { id };
291                Ok(None)
292            }
293        }
294    }
295}
296
297// === impl FramedEncode ===
298
299impl<E> From<E> for FramedEncode<E> {
300    fn from(inner: E) -> Self {
301        Self { inner }
302    }
303}
304
305impl<T, C: Encoder<T>> Encoder<Frame<T>> for FramedEncode<C> {
306    type Error = C::Error;
307
308    fn encode(
309        &mut self,
310        Frame { id, value }: Frame<T>,
311        dst: &mut BytesMut,
312    ) -> Result<(), C::Error> {
313        dst.reserve(8);
314        dst.put_u64(id);
315        self.inner.encode(value, dst)
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tokio_util::codec::LengthDelimitedCodec;

    #[tokio::test]
    async fn roundtrip() {
        let b0 = Bytes::from_static(b"abcde");
        let b1 = Bytes::from_static(b"fghij");

        let mut buf = BytesMut::with_capacity(100);

        let mut enc = FramedEncode::<LengthDelimitedCodec>::default();
        enc.encode(
            Frame {
                id: 1,
                value: b0.clone(),
            },
            &mut buf,
        )
        .expect("must encode");
        enc.encode(
            Frame {
                id: 2,
                value: b1.clone(),
            },
            &mut buf,
        )
        .expect("must encode");

        let mut dec = FramedDecode::<LengthDelimitedCodec>::default();
        let d0 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must decode");
        let d1 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must decode");
        assert_eq!(d0.id, 1);
        assert_eq!(d0.value.freeze(), b0);
        assert_eq!(d1.id, 2);
        assert_eq!(d1.value.freeze(), b1);
    }

    /// A frame delivered in pieces decodes exactly once. The id from an
    /// already-consumed header is carried across calls while the body is
    /// still incomplete.
    #[test]
    fn decodes_frame_split_across_reads() {
        let body = Bytes::from_static(b"hello");

        let mut encoded = BytesMut::new();
        FramedEncode::<LengthDelimitedCodec>::default()
            .encode(
                Frame {
                    id: 7,
                    value: body.clone(),
                },
                &mut encoded,
            )
            .expect("must encode");

        // The id is written as an 8-byte big-endian prefix.
        assert_eq!(&encoded[..8], &7u64.to_be_bytes()[..]);

        let mut dec = FramedDecode::<LengthDelimitedCodec>::default();
        let mut buf = BytesMut::new();

        // Not even a complete id header yet: nothing is consumed.
        buf.extend_from_slice(&encoded[..5]);
        assert!(dec.decode(&mut buf).expect("must decode").is_none());
        assert_eq!(buf.len(), 5);

        // Full header but partial body: the id is consumed and remembered.
        buf.extend_from_slice(&encoded[5..10]);
        assert!(dec.decode(&mut buf).expect("must decode").is_none());

        // The rest of the body completes the frame with the remembered id.
        buf.extend_from_slice(&encoded[10..]);
        let frame = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must decode");
        assert_eq!(frame.id, 7);
        assert_eq!(frame.value.freeze(), body);
        assert!(buf.is_empty());
    }
}