1use crate::next_or_pending;
2use bytes::{Buf, BufMut, BytesMut};
3use drain::Watch as Drain;
4use futures::{prelude::*, stream::FuturesUnordered};
5use std::collections::HashMap;
6use tokio::{
7 io,
8 sync::{mpsc, oneshot},
9};
10use tokio_util::codec::{Decoder, Encoder};
11use tracing::{debug, debug_span, error, info, trace, Instrument};
12
/// Groups a buffer capacity with a framed encoder and a framed decoder.
///
/// NOTE(review): no code visible here reads these fields or builds a `Muxer` other than via
/// the derived `Default`; presumably they configure `spawn_client`/`spawn_server` elsewhere —
/// confirm against callers.
#[derive(Default, Debug)]
pub struct Muxer<E, D> {
    /// Presumably the channel capacity handed to the spawn functions as `buffer_capacity` — TODO confirm.
    buffer_capacity: usize,
    /// Encoder that writes each outgoing value behind its 8-byte frame ID.
    encoder: FramedEncode<E>,
    /// Decoder that reads an 8-byte frame ID before each incoming value.
    decoder: FramedDecode<D>,
}
19
/// A value tagged with a `u64` ID.
///
/// The client and server in this module use the ID to pair each response with the request
/// that produced it. When encoded, the ID precedes the value as 8 big-endian bytes.
#[derive(Debug)]
pub struct Frame<T> {
    /// Identifier shared by a request and its response.
    pub id: u64,
    /// The payload carried by this frame.
    pub value: T,
}
25
/// Encoder adapter that writes a [`Frame`]'s ID as an 8-byte prefix and then delegates the
/// frame's value to `inner`.
#[derive(Default, Debug)]
pub struct FramedEncode<E> {
    /// Encoder for the frame payload.
    inner: E,
}
30
/// Decoder adapter that reads an 8-byte frame ID and then decodes the payload with `inner`.
#[derive(Debug)]
pub struct FramedDecode<D> {
    /// Decoder for the frame payload.
    inner: D,
    /// Whether an ID header has already been consumed for a payload that is not yet complete.
    state: DecodeState,
}
36
/// Progress of [`FramedDecode`] through the current frame.
#[derive(Debug)]
enum DecodeState {
    /// No header consumed; the next 8 bytes are a frame ID.
    Init,
    /// The header carrying `id` was consumed, but its payload has not been fully decoded yet.
    Head { id: u64 },
}
42
43pub fn spawn_client<Req, Rsp, W, R>(
44 mut write: W,
45 mut read: R,
46 buffer_capacity: usize,
47) -> mpsc::Sender<(Req, oneshot::Sender<Rsp>)>
48where
49 Req: Send + 'static,
50 Rsp: Send + 'static,
51 W: Sink<Frame<Req>, Error = io::Error> + Send + Unpin + 'static,
52 R: Stream<Item = io::Result<Frame<Rsp>>> + Send + Unpin + 'static,
53{
54 let (req_tx, mut req_rx) = mpsc::channel(buffer_capacity);
55
56 tokio::spawn(
57 async move {
58 let mut next_id = 1u64;
59 let mut in_flight = HashMap::<u64, oneshot::Sender<Rsp>>::new();
60
61 loop {
62 if next_id == std::u64::MAX {
63 info!("Client exhausted request IDs");
64 break;
65 }
66
67 tokio::select! {
68 req = req_rx.recv() => match req {
71 Some((value, rsp_tx)) => {
72 let id = next_id;
73 next_id += 1;
74 trace!(id, "Dispatching request");
75 let f = in_flight.entry(id);
76 if let std::collections::hash_map::Entry::Occupied(_) = f {
77 error!(id, "Request ID already in-flight");
78 return Err(io::Error::new(
79 io::ErrorKind::InvalidInput,
80 "Request ID is already in-flight",
81 ));
82 }
83 if let Err(error) = write.send(Frame { id, value }).await {
84 error!(id, %error, "Failed to write response");
85 return Err(error);
86 }
87 f.or_insert(rsp_tx);
88 }
89 None => {
90 debug!("Client dropped its send handle");
91 break;
92 }
93 },
94
95 rsp = read.try_next() => match rsp? {
98 Some(Frame { id, value }) => {
99 trace!(id, "Dispatching response");
100 match in_flight.remove(&id) {
101 Some(tx) => {
102 let _ = tx.send(value);
103 }
104 None => return Err(io::Error::new(
105 io::ErrorKind::InvalidInput,
106 "Response for unknown request",
107 )),
108 }
109 }
110 None => {
111 debug!(in_flight=in_flight.len(), "Server closed");
112 if in_flight.is_empty() {
113 return Ok(());
114 } else {
115 return Err(io::Error::new(
116 io::ErrorKind::ConnectionReset,
117 "Server closed",
118 ));
119 }
120 }
121 },
122 }
123 }
124
125 debug!("Allowing pending responses to complete");
126
127 drop((req_rx, write));
130
131 while let Some(Frame { id, value }) = read.try_next().await? {
133 match in_flight.remove(&id) {
134 Some(tx) => {
135 let _ = tx.send(value);
136 }
137 None => {
138 return Err(io::Error::new(
139 io::ErrorKind::InvalidInput,
140 "Response for unknown request",
141 ));
142 }
143 }
144 }
145 if !in_flight.is_empty() {
146 return Err(io::Error::new(
147 io::ErrorKind::ConnectionReset,
148 "Some requests did not receive a response",
149 ));
150 }
151
152 Ok(())
153 }
154 .in_current_span(),
155 );
156
157 req_tx
158}
159
/// Receiving half returned by [`spawn_server`]: each item is a decoded request paired with the
/// oneshot sender on which its response is expected.
type Channel<Req, Rsp> = mpsc::Receiver<(Req, oneshot::Sender<Rsp>)>;
/// Handle to the connection task spawned by [`spawn_server`].
type JoinHandle = tokio::task::JoinHandle<io::Result<()>>;
162
163pub fn spawn_server<Req, Rsp, R, W>(
164 mut read: R,
165 mut write: W,
166 drain: Drain,
167 buffer_capacity: usize,
168) -> (Channel<Req, Rsp>, JoinHandle)
169where
170 Req: Send + 'static,
171 Rsp: Send + 'static,
172 R: Stream<Item = io::Result<Frame<Req>>> + Send + Unpin + 'static,
173 W: Sink<Frame<Rsp>, Error = io::Error> + Send + Unpin + 'static,
174{
175 let (tx, rx) = mpsc::channel(buffer_capacity);
176
177 let handle = tokio::spawn(async move {
178 tokio::pin! {
179 let closed = drain.signaled();
180 }
181
182 let mut last_id = 0u64;
183 let mut in_flight = FuturesUnordered::new();
184 loop {
185 tokio::select! {
186 shutdown = (&mut closed) => {
187 debug!("Shutdown signaled; draining in-flight requests");
188 drop(read);
189 drop(tx);
190 while let Some(Frame { id, value }) = in_flight.try_next().await? {
191 trace!(id, "In-flight response completed");
192 write.send(Frame { id, value }).await?;
193 }
194 debug!("In-flight requests completed");
195 drop(shutdown);
196 return Ok(());
197 }
198
199 req = next_or_pending(&mut in_flight) => {
200 let Frame { id, value } = req?;
201 trace!(id, "In-flight response completed");
202 if let Err(error) = write.send(Frame { id, value }).await {
203 error!(%error, "Write failed");
204 return Err(error);
205 }
206 }
207
208 msg = read.try_next() => {
209 let Frame { id, value } = match msg? {
210 Some(f) => f,
211 None => {
212 trace!("Draining in-flight responses after client stream completed.");
213 while let Some(Frame { id, value }) = in_flight.try_next().await? {
214 trace!(id, "In-flight response completed");
215 write.send(Frame { id, value }).await?;
216 }
217 return Ok(());
218 }
219 };
220
221 if id <= last_id {
222 return Err(io::Error::new(
223 io::ErrorKind::InvalidInput,
224 "Request ID too low",
225 ));
226 }
227 last_id = id;
228
229 trace!(id, "Dispatching request");
230 let (rsp_tx, rsp_rx) = oneshot::channel();
231 if tx.send((value, rsp_tx)).await.is_err() {
232 return Err(io::Error::new(
233 io::ErrorKind::ConnectionAborted,
234 "Lost service",
235 ));
236 }
237 in_flight.push(rsp_rx.map(move |v| match v {
238 Ok(value) => Ok(Frame { id, value }),
239 Err(_) => Err(io::Error::new(
240 io::ErrorKind::ConnectionAborted,
241 "Server dropped response",
242 )),
243 }));
244 }
245 }
246 }
247 }.instrument(debug_span!("mux")));
248
249 (rx, handle)
250}
251
252impl<D> From<D> for FramedDecode<D> {
255 fn from(inner: D) -> Self {
256 Self {
257 inner,
258 state: DecodeState::Init,
259 }
260 }
261}
262
263impl<D: Default> Default for FramedDecode<D> {
264 fn default() -> Self {
265 Self::from(D::default())
266 }
267}
268
269impl<D: Decoder> Decoder for FramedDecode<D> {
270 type Item = Frame<D::Item>;
271 type Error = D::Error;
272
273 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame<D::Item>>, D::Error> {
274 let id = match self.state {
275 DecodeState::Init => {
276 if src.len() < 8 {
277 return Ok(None);
278 }
279 src.get_u64()
280 }
281 DecodeState::Head { id } => {
282 self.state = DecodeState::Init;
283 id
284 }
285 };
286
287 match self.inner.decode(src)? {
288 Some(value) => Ok(Some(Frame { id, value })),
289 None => {
290 self.state = DecodeState::Head { id };
291 Ok(None)
292 }
293 }
294 }
295}
296
297impl<E> From<E> for FramedEncode<E> {
300 fn from(inner: E) -> Self {
301 Self { inner }
302 }
303}
304
305impl<T, C: Encoder<T>> Encoder<Frame<T>> for FramedEncode<C> {
306 type Error = C::Error;
307
308 fn encode(
309 &mut self,
310 Frame { id, value }: Frame<T>,
311 dst: &mut BytesMut,
312 ) -> Result<(), C::Error> {
313 dst.reserve(8);
314 dst.put_u64(id);
315 self.inner.encode(value, dst)
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use tokio_util::codec::LengthDelimitedCodec;

    /// Frames encoded back-to-back into one buffer decode to the same IDs and payloads, in order.
    #[tokio::test]
    async fn roundtrip() {
        let b0 = Bytes::from_static(b"abcde");
        let b1 = Bytes::from_static(b"fghij");

        let mut buf = BytesMut::with_capacity(100);

        let mut enc = FramedEncode::<LengthDelimitedCodec>::default();
        enc.encode(
            Frame {
                id: 1,
                value: b0.clone(),
            },
            &mut buf,
        )
        .expect("must encode");
        enc.encode(
            Frame {
                id: 2,
                value: b1.clone(),
            },
            &mut buf,
        )
        .expect("must encode");

        let mut dec = FramedDecode::<LengthDelimitedCodec>::default();
        let d0 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must decode");
        let d1 = dec
            .decode(&mut buf)
            .expect("must decode")
            .expect("must decode");
        assert_eq!(d0.id, 1);
        assert_eq!(d0.value.freeze(), b0);
        assert_eq!(d1.id, 2);
        assert_eq!(d1.value.freeze(), b1);
    }

    /// A frame that arrives one byte at a time is only yielded once complete.
    ///
    /// This exercises the path where the ID header has been consumed but the payload is still
    /// pending, so the decoder must resume without re-reading a header.
    #[test]
    fn decodes_partial_frames() {
        let payload = Bytes::from_static(b"hello");

        let mut encoded = BytesMut::new();
        FramedEncode::<LengthDelimitedCodec>::default()
            .encode(
                Frame {
                    id: 7,
                    value: payload.clone(),
                },
                &mut encoded,
            )
            .expect("must encode");

        let mut dec = FramedDecode::<LengthDelimitedCodec>::default();
        let mut buf = BytesMut::new();
        let total = encoded.len();
        for (i, byte) in encoded.iter().enumerate() {
            buf.put_u8(*byte);
            let decoded = dec.decode(&mut buf).expect("must decode");
            if i + 1 < total {
                assert!(decoded.is_none(), "frame yielded early at byte {}", i);
            } else {
                let frame = decoded.expect("must yield the complete frame");
                assert_eq!(frame.id, 7);
                assert_eq!(frame.value.freeze(), payload);
            }
        }
        assert!(buf.is_empty());
    }
}