async_session_types/multiplexing/
incoming.rs

1use futures::stream::FuturesUnordered;
2use futures::StreamExt;
3use std::{collections::btree_map::Entry, fmt::Debug};
4use tokio::select;
5
6use super::{Demultiplexer, ErrorChan, MultiMessage, Multiplexer};
7use crate::{Receiver, Sender, SessionError};
8
9/// An `IncomingMultiChannel` is established for each incoming TCP connection.
10/// The local party will act as the server for the remote client.
11///
12/// The `demux` part takes messages from the incoming connection and dispatches
13/// to the implementing protocols. If the protocol doesn't exist, a server
14/// session is instantiated to handle the messages.
15///
16/// The `mux` part listens for outgoing messages from all instantiated server
17/// sessions, relaying responses to the client.
18pub struct IncomingMultiChannel<P, R> {
19    demux: Demultiplexer<P, R>,
20    mux: Multiplexer<P, R>,
21    errors: ErrorChan,
22}
23
24impl<P: Ord + Copy + Debug, R> IncomingMultiChannel<P, R> {
25    /// Create a new `MultiChannel` by passing it the channels it can use to
26    /// send/receive messages to/from the underlying network connection.
27    pub fn new(tx: Sender<MultiMessage<P, R>>, rx: Receiver<MultiMessage<P, R>>) -> Self {
28        Self {
29            errors: ErrorChan::new(),
30            demux: Demultiplexer::new(rx),
31            mux: Multiplexer::new(tx),
32        }
33    }
34
35    /// Start consuming messages from the client, creating new protocol handlers
36    /// where one doesn't exist already.
37    ///
38    /// We must also listen to an internal channel that signals the abortion
39    /// of the whole connection due to some protocol violation.
40    ///
41    /// `init_server` should initiate the server side of the protocol, listening
42    /// to incoming messages from the client; it should return the channels we
43    /// can send the client messages to, and receive the replies over. When
44    /// sending or receiving from these channels fail, the protocol is over
45    /// and should be removed.
46    pub async fn run<F>(mut self, mut init_server: F)
47    where
48        F: FnMut(P, Sender<SessionError>) -> (Sender<R>, Receiver<R>),
49    {
50        // NOTE: The following comment illustrates the original synchronous version, but the gist of it is the same.
51        // In a loop, add the `rx` of `demux`, and all the `rxs` of `mux` to a `Select`, see which one is ready:
52        // * If `error.rx` returns `Ok` then a protocol violation occurred and we can exit the loop.
53        // * If `demux.rx` returns `Ok` then get or create the protocol in both `mux` and `demux`, and dispatch into the corresponding channel in `demux.txs`.
54        // * If `demux.rx` returns `Err` then the connection is closed and we can exit the loop.
55        // * If any of `mux.rxs` return `Ok` then wrap the message and send to `mux.tx`.
56        // * If any of `mux.rxs` return `Err` then that protocol is finished and can be removed from both `mux` and `demux`.
57
58        // We have to declare a `FuturesUnordered` here, can't store it in a field becuase it has an opaque type, `impl Future`.
59        let mut mux = FuturesUnordered::new();
60
61        loop {
62            select! {
63                // One of the sessions detection a violation of some rule and now all of them need to be closed.
64                // This will never return `None` because we never close the error channel.
65                Some(_) = self.errors.rx.recv() => break,
66
67                // One of the session handlers has a message to send to the client.
68                // This might return `None` when there are no sessions, but the pattern will discard this branch
69                // and wait on the others. The first incoming message will then establish a handler and the next
70                // loop should see it matching.
71                Some((pid, outgoing, rx)) = mux.next() => {
72                    match outgoing {
73                        Some(msg) => {
74                            self.handle_outgoing_reply(pid, msg);
75                            // Put the receiver back in the mux, so further replies from the session are handled.
76                            mux.push(Self::recv_outgoing(pid, rx));
77                        }
78                        // The outgoing channel was closed on our side.
79                        None => self.remove(pid),
80                    }
81                },
82
83                // There is an incoming request from the client.
84                incoming = self.demux.recv() => {
85                    match incoming {
86                        Some((pid, msg)) => {
87                            if let Some(rx) = self.handle_incoming_request(pid, msg, &mut init_server) {
88                                // Register the receiver with the reply multiplexer.
89                                mux.push(Self::recv_outgoing(pid, rx));
90                            }
91                        },
92                        // The incoming channel is closed.
93                        None => break,
94                    }
95                },
96            }
97        }
98    }
99
100    // We need exactly one function to produce the future that gets pushed into the `FuturesUnordered` otherwise it would have multiple types.
101    async fn recv_outgoing(pid: P, mut rx: Receiver<R>) -> (P, Option<R>, Receiver<R>) {
102        let o = rx.recv().await;
103        (pid, o, rx)
104    }
105
106    /// Dispatch an incoming request to the corresponding session.
107    ///
108    /// Initialize the session handler if it doesn't exist yet.
109    /// If we initiated a new handler, return the channel where
110    /// we can receive messages from it.
111    fn handle_incoming_request<F>(
112        &mut self,
113        pid: P,
114        msg: R,
115        init_server: &mut F,
116    ) -> Option<Receiver<R>>
117    where
118        F: FnMut(P, Sender<SessionError>) -> (Sender<R>, Receiver<R>),
119    {
120        match self.demux.txs.entry(pid) {
121            Entry::Vacant(e) => {
122                let (tx, rx) = init_server(pid, self.errors.tx.clone());
123                let _ = tx.send(msg);
124                e.insert(tx);
125                Some(rx)
126            }
127            Entry::Occupied(e) => {
128                // Ignoring send errors here; it means the channel is closed, but we'll realize that in the loop.
129                let _ = e.get().send(msg);
130                None
131            }
132        }
133    }
134
135    /// Wrap an outgoing reply and relay to the connection channel.
136    fn handle_outgoing_reply(&self, pid: P, msg: R) {
137        // Ignoring send errors here; it means the connection is closed, but we'll realize that in the loop.
138        let _ = self.mux.send(pid, msg);
139    }
140
141    /// Remove a session protocol that got closed on our side.
142    fn remove(&mut self, pid: P) {
143        self.demux.txs.remove(&pid);
144    }
145}
146
#[cfg(test)]
mod test {
    use crate::multiplexing::MultiMessage;
    use crate::session_channel_dyn;
    use crate::test::protocols::greetings::*;
    use crate::test::protocols::ping_pong::*;
    use crate::test::protocols::*;
    use crate::unbounded_channel;
    use crate::DynMessage;
    use crate::Receiver;
    use crate::Sender;
    use crate::SessionResult;
    use crate::{offer, ok, Chan};
    use std::fmt::Debug;
    use std::time::Duration;
    use std::time::Instant;
    use tokio::time::timeout as timeout_after;

    use super::IncomingMultiChannel;

    #[derive(Ord, Clone, PartialEq, PartialOrd, Eq, Debug, Copy)]
    enum Protos {
        PingPong,
        Greetings,
    }

    /// The test's end of the incoming connection: what the "client" sends on.
    type TxIn = Sender<MultiMessage<Protos, DynMessage>>;
    /// The test's end of the outgoing connection: where the "client" reads replies.
    type RxOut = Receiver<MultiMessage<Protos, DynMessage>>;

    async fn ping_pong_srv(
        c: Chan<ping_pong::Server, (), DynMessage>,
        t: Duration,
    ) -> SessionResult<()> {
        let (c, _ping) = c.recv(t).await?;
        c.send(Pong)?.close()
    }

    async fn greetings_srv(
        c: Chan<greetings::Server, (), DynMessage>,
        t: Duration,
    ) -> SessionResult<()> {
        let (c, Hail(cid)) = c.recv(t).await?;
        let c = c.send(Greetings(format!("Hello {}!", cid)))?;
        let mut c = c.enter();
        loop {
            c = offer! { c, t,
                Time => {
                    let (c, TimeRequest) = c.recv(t).await?;
                    let c = c.send(TimeResponse(Instant::now()))?;
                    c.zero()?
                },
                Add => {
                    let (c, AddRequest(a)) = c.recv(t).await?;
                    let (c, AddRequest(b)) = c.recv(t).await?;
                    let c = c.send(AddResponse(a + b))?;
                    c.zero()?
                },
                Quit => {
                    let (c, Quit) = c.recv(t).await?;
                    c.close()?;
                    break;
                }
            };
        }

        ok(())
    }

    /// Wait for the next message relayed to the client, failing the test if none
    /// arrives within `timeout` or if the outgoing connection has been closed.
    async fn recv_reply(rx_out: &mut RxOut, timeout: Duration) -> MultiMessage<Protos, DynMessage> {
        timeout_after(timeout, rx_out.recv())
            .await
            .expect("timed out waiting for a reply")
            .expect("the outgoing connection was closed unexpectedly")
    }

    /// Run a complete ping-pong session and check the reply is a `Pong`.
    async fn test_ping(tx_in: &TxIn, rx_out: &mut RxOut, timeout: Duration) {
        tx_in
            .send(MultiMessage::new(Protos::PingPong, Ping))
            .unwrap();

        let res = recv_reply(rx_out, timeout).await;

        assert_eq!(res.protocol_id, Protos::PingPong);
        assert!(res.payload.downcast::<Pong>().is_ok());
    }

    /// Run a greetings session, checking the content of every reply, then quit it.
    async fn test_greetings(tx_in: &TxIn, rx_out: &mut RxOut, timeout: Duration) {
        let pid = Protos::Greetings;
        tx_in
            .send(MultiMessage::new(pid, Hail("Multi".into())))
            .unwrap();

        let res = recv_reply(rx_out, timeout).await;
        assert_eq!(res.protocol_id, pid);
        match res.payload.downcast::<Greetings>() {
            Ok(greetings) => {
                let Greetings(text) = *greetings;
                assert_eq!(text, "Hello Multi!");
            }
            Err(_) => panic!("expected a Greetings reply"),
        }

        tx_in.send(MultiMessage::new(pid, AddRequest(1))).unwrap();
        tx_in.send(MultiMessage::new(pid, AddRequest(2))).unwrap();

        let res = recv_reply(rx_out, timeout).await;
        assert_eq!(res.protocol_id, pid);
        // Check the actual sum, not just the message type.
        match res.payload.downcast::<AddResponse>() {
            Ok(resp) => {
                let AddResponse(sum) = *resp;
                assert_eq!(sum, 3);
            }
            Err(_) => panic!("expected an AddResponse reply"),
        }

        tx_in.send(MultiMessage::new(pid, Quit)).unwrap();
    }

    /// Check that a protocol violation in one session tears down the whole channel.
    async fn test_abort(tx_in: &TxIn, rx_out: &mut RxOut, timeout: Duration) {
        tx_in
            .send(MultiMessage::new(Protos::Greetings, Hail("Abort".into())))
            .unwrap();

        let res = recv_reply(rx_out, timeout).await;

        assert!(res.payload.downcast::<Greetings>().is_ok());

        // Send an invalid message.
        tx_in
            .send(MultiMessage::new(Protos::PingPong, "Boom!!!"))
            .unwrap();

        // The violation stops the channel, which drops its end of the outgoing
        // connection; wait for that rather than sleeping for a fixed amount of time.
        let closed = timeout_after(timeout, rx_out.recv())
            .await
            .expect("the channel did not shut down after the protocol violation");
        assert!(closed.is_none());

        // The incoming connection is gone as well, so the other protocol is closed too.
        let res = tx_in.send(MultiMessage::new(Protos::Greetings, AddRequest(10)));
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn basics() {
        let timeout = Duration::from_millis(100);

        // Create an IncomingMultiChannel. It needs a pair of channels, an incoming and outgoing one.
        // Whichever side we are not passing to the constructor is what we're going to use in the test.
        let (tx_in, rx_in) = unbounded_channel();
        let (tx_out, mut rx_out) = unbounded_channel();
        let channel = IncomingMultiChannel::<Protos, DynMessage>::new(tx_out, rx_in);

        // Start the channel by passing it a closure that tells it how to instantiate servers.
        tokio::spawn(channel.run(move |p, errors| match p {
            Protos::PingPong => {
                let (c, (tx, rx)) = session_channel_dyn();
                tokio::spawn(async move {
                    if let Err(e) = ping_pong_srv(c, timeout).await {
                        let _ = errors.send(e);
                    }
                });
                (tx, rx)
            }

            Protos::Greetings => {
                let (c, (tx, rx)) = session_channel_dyn();
                tokio::spawn(async move {
                    if let Err(e) = greetings_srv(c, timeout).await {
                        let _ = errors.send(e);
                    }
                });
                (tx, rx)
            }
        }));

        // Act as session clients, send some messages, verify that the responses arrive.
        test_ping(&tx_in, &mut rx_out, timeout).await;
        test_greetings(&tx_in, &mut rx_out, timeout).await;
        test_ping(&tx_in, &mut rx_out, timeout).await;
        test_abort(&tx_in, &mut rx_out, timeout).await;
    }
}