async_session_types/multiplexing/incoming.rs
1use futures::stream::FuturesUnordered;
2use futures::StreamExt;
3use std::{collections::btree_map::Entry, fmt::Debug};
4use tokio::select;
5
6use super::{Demultiplexer, ErrorChan, MultiMessage, Multiplexer};
7use crate::{Receiver, Sender, SessionError};
8
9/// An `IncomingMultiChannel` is established for each incoming TCP connection.
10/// The local party will act as the server for the remote client.
11///
12/// The `demux` part takes messages from the incoming connection and dispatches
13/// to the implementing protocols. If the protocol doesn't exist, a server
14/// session is instantiated to handle the messages.
15///
16/// The `mux` part listens for outgoing messages from all instantiated server
17/// sessions, relaying responses to the client.
18pub struct IncomingMultiChannel<P, R> {
19 demux: Demultiplexer<P, R>,
20 mux: Multiplexer<P, R>,
21 errors: ErrorChan,
22}
23
24impl<P: Ord + Copy + Debug, R> IncomingMultiChannel<P, R> {
25 /// Create a new `MultiChannel` by passing it the channels it can use to
26 /// send/receive messages to/from the underlying network connection.
27 pub fn new(tx: Sender<MultiMessage<P, R>>, rx: Receiver<MultiMessage<P, R>>) -> Self {
28 Self {
29 errors: ErrorChan::new(),
30 demux: Demultiplexer::new(rx),
31 mux: Multiplexer::new(tx),
32 }
33 }
34
35 /// Start consuming messages from the client, creating new protocol handlers
36 /// where one doesn't exist already.
37 ///
38 /// We must also listen to an internal channel that signals the abortion
39 /// of the whole connection due to some protocol violation.
40 ///
41 /// `init_server` should initiate the server side of the protocol, listening
42 /// to incoming messages from the client; it should return the channels we
43 /// can send the client messages to, and receive the replies over. When
44 /// sending or receiving from these channels fail, the protocol is over
45 /// and should be removed.
46 pub async fn run<F>(mut self, mut init_server: F)
47 where
48 F: FnMut(P, Sender<SessionError>) -> (Sender<R>, Receiver<R>),
49 {
50 // NOTE: The following comment illustrates the original synchronous version, but the gist of it is the same.
51 // In a loop, add the `rx` of `demux`, and all the `rxs` of `mux` to a `Select`, see which one is ready:
52 // * If `error.rx` returns `Ok` then a protocol violation occurred and we can exit the loop.
53 // * If `demux.rx` returns `Ok` then get or create the protocol in both `mux` and `demux`, and dispatch into the corresponding channel in `demux.txs`.
54 // * If `demux.rx` returns `Err` then the connection is closed and we can exit the loop.
55 // * If any of `mux.rxs` return `Ok` then wrap the message and send to `mux.tx`.
56 // * If any of `mux.rxs` return `Err` then that protocol is finished and can be removed from both `mux` and `demux`.
57
58 // We have to declare a `FuturesUnordered` here, can't store it in a field becuase it has an opaque type, `impl Future`.
59 let mut mux = FuturesUnordered::new();
60
61 loop {
62 select! {
63 // One of the sessions detection a violation of some rule and now all of them need to be closed.
64 // This will never return `None` because we never close the error channel.
65 Some(_) = self.errors.rx.recv() => break,
66
67 // One of the session handlers has a message to send to the client.
68 // This might return `None` when there are no sessions, but the pattern will discard this branch
69 // and wait on the others. The first incoming message will then establish a handler and the next
70 // loop should see it matching.
71 Some((pid, outgoing, rx)) = mux.next() => {
72 match outgoing {
73 Some(msg) => {
74 self.handle_outgoing_reply(pid, msg);
75 // Put the receiver back in the mux, so further replies from the session are handled.
76 mux.push(Self::recv_outgoing(pid, rx));
77 }
78 // The outgoing channel was closed on our side.
79 None => self.remove(pid),
80 }
81 },
82
83 // There is an incoming request from the client.
84 incoming = self.demux.recv() => {
85 match incoming {
86 Some((pid, msg)) => {
87 if let Some(rx) = self.handle_incoming_request(pid, msg, &mut init_server) {
88 // Register the receiver with the reply multiplexer.
89 mux.push(Self::recv_outgoing(pid, rx));
90 }
91 },
92 // The incoming channel is closed.
93 None => break,
94 }
95 },
96 }
97 }
98 }
99
100 // We need exactly one function to produce the future that gets pushed into the `FuturesUnordered` otherwise it would have multiple types.
101 async fn recv_outgoing(pid: P, mut rx: Receiver<R>) -> (P, Option<R>, Receiver<R>) {
102 let o = rx.recv().await;
103 (pid, o, rx)
104 }
105
106 /// Dispatch an incoming request to the corresponding session.
107 ///
108 /// Initialize the session handler if it doesn't exist yet.
109 /// If we initiated a new handler, return the channel where
110 /// we can receive messages from it.
111 fn handle_incoming_request<F>(
112 &mut self,
113 pid: P,
114 msg: R,
115 init_server: &mut F,
116 ) -> Option<Receiver<R>>
117 where
118 F: FnMut(P, Sender<SessionError>) -> (Sender<R>, Receiver<R>),
119 {
120 match self.demux.txs.entry(pid) {
121 Entry::Vacant(e) => {
122 let (tx, rx) = init_server(pid, self.errors.tx.clone());
123 let _ = tx.send(msg);
124 e.insert(tx);
125 Some(rx)
126 }
127 Entry::Occupied(e) => {
128 // Ignoring send errors here; it means the channel is closed, but we'll realize that in the loop.
129 let _ = e.get().send(msg);
130 None
131 }
132 }
133 }
134
135 /// Wrap an outgoing reply and relay to the connection channel.
136 fn handle_outgoing_reply(&self, pid: P, msg: R) {
137 // Ignoring send errors here; it means the connection is closed, but we'll realize that in the loop.
138 let _ = self.mux.send(pid, msg);
139 }
140
141 /// Remove a session protocol that got closed on our side.
142 fn remove(&mut self, pid: P) {
143 self.demux.txs.remove(&pid);
144 }
145}
146
#[cfg(test)]
mod test {
    use crate::multiplexing::MultiMessage;
    use crate::session_channel_dyn;
    use crate::test::protocols::greetings::*;
    use crate::test::protocols::ping_pong::*;
    use crate::test::protocols::*;
    use crate::unbounded_channel;
    use crate::DynMessage;
    use crate::Receiver;
    use crate::Sender;
    use crate::SessionResult;
    use crate::{offer, ok, Chan};
    use std::fmt::Debug;
    use std::time::Duration;
    use std::time::Instant;
    use tokio::time::timeout as timeout_after;

    use super::IncomingMultiChannel;

    /// Protocol IDs addressing the sessions multiplexed over the test connection.
    #[derive(Ord, Clone, PartialEq, PartialOrd, Eq, Debug, Copy)]
    enum Protos {
        PingPong,
        Greetings,
    }

    /// Server side of ping-pong: receive one message, answer with `Pong`, close.
    async fn ping_pong_srv(
        c: Chan<ping_pong::Server, (), DynMessage>,
        t: Duration,
    ) -> SessionResult<()> {
        let (c, _ping) = c.recv(t).await?;
        c.send(Pong)?.close()
    }

    /// Server side of greetings: answer the `Hail`, then serve time and add
    /// requests until the client quits.
    async fn greetings_srv(
        c: Chan<greetings::Server, (), DynMessage>,
        t: Duration,
    ) -> SessionResult<()> {
        let (c, Hail(cid)) = c.recv(t).await?;
        let c = c.send(Greetings(format!("Hello {}!", cid)))?;
        let mut c = c.enter();
        loop {
            c = offer! { c, t,
                Time => {
                    let (c, TimeRequest) = c.recv(t).await?;
                    let c = c.send(TimeResponse(Instant::now()))?;
                    c.zero()?
                },
                Add => {
                    let (c, AddRequest(a)) = c.recv(t).await?;
                    let (c, AddRequest(b)) = c.recv(t).await?;
                    let c = c.send(AddResponse(a + b))?;
                    c.zero()?
                },
                Quit => {
                    let (c, Quit) = c.recv(t).await?;
                    c.close()?;
                    break;
                }
            };
        }

        ok(())
    }

    /// Drive an `IncomingMultiChannel` as a client would: run several
    /// protocols over one connection, then check that a protocol violation
    /// shuts the whole connection down.
    #[tokio::test]
    async fn basics() {
        let timeout = Duration::from_millis(100);

        // Create an IncomingMultiChannel. It needs a pair of channels, an incoming and outgoing one.
        // Whichever side we are not passing to the constructor is what we're going to use in the test.
        let (tx_in, rx_in) = unbounded_channel();
        let (tx_out, mut rx_out) = unbounded_channel();
        let channel = IncomingMultiChannel::<Protos, DynMessage>::new(tx_out, rx_in);

        type TxIn = Sender<MultiMessage<Protos, DynMessage>>;
        type RxOut = Receiver<MultiMessage<Protos, DynMessage>>;

        // Start the channel by passing it a closure that tells it how to instantiate servers.
        tokio::spawn(channel.run(move |p, errors| match p {
            Protos::PingPong => {
                let (c, (tx, rx)) = session_channel_dyn();
                tokio::spawn(async move {
                    if let Err(e) = ping_pong_srv(c, timeout).await {
                        let _ = errors.send(e);
                    }
                });
                (tx, rx)
            }

            Protos::Greetings => {
                let (c, (tx, rx)) = session_channel_dyn();
                tokio::spawn(async move {
                    if let Err(e) = greetings_srv(c, timeout).await {
                        let _ = errors.send(e);
                    }
                });
                (tx, rx)
            }
        }));

        // Act as session clients, send some messages, verify that the responses arrive.

        /// Wait (at most `timeout`) for the next reply relayed to the client.
        async fn recv_reply(
            rx_out: &mut RxOut,
            timeout: Duration,
        ) -> MultiMessage<Protos, DynMessage> {
            timeout_after(timeout, rx_out.recv())
                .await
                .unwrap()
                .unwrap()
        }

        async fn test_ping(tx_in: &TxIn, rx_out: &mut RxOut, timeout: Duration) {
            tx_in
                .send(MultiMessage::new(Protos::PingPong, Ping))
                .unwrap();

            let res = recv_reply(rx_out, timeout).await;
            assert_eq!(res.protocol_id, Protos::PingPong);
            assert!(res.payload.downcast::<Pong>().is_ok());
        }

        async fn test_greetings(tx_in: &TxIn, rx_out: &mut RxOut, timeout: Duration) {
            let pid = Protos::Greetings;
            tx_in
                .send(MultiMessage::new(pid, Hail("Multi".into())))
                .unwrap();

            let res = recv_reply(rx_out, timeout).await;
            assert_eq!(res.protocol_id, pid);
            match res.payload.downcast::<Greetings>() {
                Ok(greetings) => assert_eq!(greetings.0, "Hello Multi!"),
                Err(_) => panic!("expected a Greetings reply"),
            }

            tx_in.send(MultiMessage::new(pid, AddRequest(1))).unwrap();
            tx_in.send(MultiMessage::new(pid, AddRequest(2))).unwrap();

            let res = recv_reply(rx_out, timeout).await;
            assert_eq!(res.protocol_id, pid);
            match res.payload.downcast::<AddResponse>() {
                Ok(sum) => assert_eq!(sum.0, 3),
                Err(_) => panic!("expected an AddResponse reply"),
            }

            tx_in.send(MultiMessage::new(pid, Quit)).unwrap();
        }

        async fn test_abort(tx_in: &TxIn, rx_out: &mut RxOut, timeout: Duration) {
            tx_in
                .send(MultiMessage::new(Protos::Greetings, Hail("Abort".into())))
                .unwrap();

            let res = recv_reply(rx_out, timeout).await;
            assert_eq!(res.protocol_id, Protos::Greetings);
            assert!(res.payload.downcast::<Greetings>().is_ok());

            // Send an invalid message.
            tx_in
                .send(MultiMessage::new(Protos::PingPong, "Boom!!!"))
                .unwrap();

            // It should shut down the whole channel, closing the other protocol as well.
            // Instead of sleeping for a fixed time, wait (bounded by the timeout) for the
            // outgoing stream to end, which happens once the channel drops its connection channels.
            let closed = timeout_after(timeout, rx_out.recv()).await.unwrap();
            assert!(closed.is_none());

            let res = tx_in.send(MultiMessage::new(Protos::Greetings, AddRequest(10)));
            assert!(res.is_err());
        }

        test_ping(&tx_in, &mut rx_out, timeout).await;
        test_greetings(&tx_in, &mut rx_out, timeout).await;
        test_ping(&tx_in, &mut rx_out, timeout).await;
        test_abort(&tx_in, &mut rx_out, timeout).await;
    }
}