async_session_types/multiplexing/
outgoing.rs

1use futures::{stream::FuturesUnordered, StreamExt};
2use tokio::select;
3
4use super::{AddChan, AddMsg, Demultiplexer, ErrorChan, MultiMessage, Multiplexer};
5use crate::{Receiver, Sender, SessionError};
6
7/// Control signals for the running outgoing channel, allowing the rest of
8/// the program to register new sessions and to abort all ongoing ones.
9pub struct Control<P, R> {
10    errors: Sender<SessionError>,
11    adds: Sender<AddMsg<P, R>>,
12}
13
14impl<P, R> Control<P, R> {
15    /// Register the client side of a session protocol, so that `run` starts listening
16    /// to messages coming from the protocol and multiplexing them to the outgoing
17    /// connection towards the server, and also starts dispaching server messages
18    /// to this protocol.
19    ///
20    /// The session should be automatically removed when it closes the other sides of
21    /// the channels.
22    ///
23    /// Returns `false` if the channel is already closed and the session could not be registered.
24    pub fn add(&self, protocol_id: P, tx: Sender<R>, rx: Receiver<R>) -> bool {
25        self.adds.send((protocol_id, tx, rx)).is_ok()
26    }
27
28    /// Remove all sessions due to protocol violation. Use an internal channel to notify `run`.
29    pub fn error(&self, _protocol_id: P, error: SessionError) {
30        // Ignoring send errors, it would mean the channel is already closed.
31        let _ = self.errors.send(error);
32    }
33}
34
35/// An `OutgoingMultiChannel` is established for each outgoing TCP connection.
36/// The local party will be the client of the remote server.
37///
38/// The `mux` part listens to multiple sessions initiated by this side and
39/// relay their messages to the common outgoing channel.
40///
41/// The `demux` part listens to the incoming replies from the server on the
42/// connection and dispatches the messages to their corresponding protocol
43/// handlers.
44pub struct OutgoingMultiChannel<P, R> {
45    mux: Multiplexer<P, R>,
46    demux: Demultiplexer<P, R>,
47    errors: ErrorChan,
48    adds: AddChan<P, R>,
49}
50
51impl<P: Ord + Copy, R: 'static + Send + Sync> OutgoingMultiChannel<P, R> {
52    /// Create a new `MultiChannel` by passing it the channels it can use to
53    /// send/receive messages to/from the underlying network connection.
54    pub fn new(tx: Sender<MultiMessage<P, R>>, rx: Receiver<MultiMessage<P, R>>) -> Self {
55        Self {
56            mux: Multiplexer::new(tx),
57            demux: Demultiplexer::new(rx),
58            errors: ErrorChan::new(),
59            adds: AddChan::new(),
60        }
61    }
62
63    /// Get a handler that can be used to send signals to the channel.
64    ///
65    /// Call this before the channel is run.
66    pub fn control(&self) -> Control<P, R> {
67        Control {
68            errors: self.errors.tx.clone(),
69            adds: self.adds.tx.clone(),
70        }
71    }
72
73    /// Start consuming messages from the remote server, relaying them to
74    /// the local client side of the protocols.
75    ///
76    /// We must also listen to an internal channel that signals the abortion
77    /// of the whole connection due to some protocol violation.
78    ///
79    /// Yet another internal channel must be used to receive registration requests.
80    pub async fn run(mut self) {
81        // Futures aggregator for the session clients on our side.
82        let mut mux = FuturesUnordered::new();
83
84        loop {
85            select! {
86                // Abandon all channels if there's an error. This should never return `None`.
87                Some(err) = self.errors.rx.recv() => {
88                    // We can ignore disconnect errors, they can come from replaced sessions.
89                    // We'll recognise a real disconnection by not being able to receive from
90                    // the remote side here in the incoming arm.
91                    match err {
92                        SessionError::Disconnected => (),
93                        _ => break,
94                    }
95                },
96
97                // We are initiating a new protocol. This should never return `None`.
98                Some(add) = self.adds.rx.recv() => {
99                    let (pid, rx) = self.add(add);
100                    mux.push(Self::recv_outgoing(pid, rx));
101                }
102
103                // We have an outgoing request from one of the sessions.
104                // Initially `mux` is empty and this branch will be disabled by the `select!`.
105                // Only after the first `add` message will it get another chaince.
106                Some((pid, outgoing, rx)) = mux.next() => {
107                    match outgoing {
108                        Some(msg) => {
109                            self.handle_outgoing_request(pid, msg);
110                            // Re-queue the receiver.
111                            mux.push(Self::recv_outgoing(pid, rx));
112                        }
113                        // The outgoing channel got closed on our side.
114                        // The other sessions can keep going.
115                        None =>self.remove(pid),
116                    }
117                }
118
119                // There is an incoming reply from the server.
120                incoming = self.demux.recv() => {
121                    match incoming {
122                        Some((pid, msg)) => self.handle_incoming_reply(pid, msg),
123                        // The channel to server is closed.
124                        None => break,
125                    }
126                }
127            }
128        }
129    }
130
131    // We need exactly one function to produce the future that gets pushed into the
132    // `FuturesUnordered` otherwise it would have multiple conflicting anonymous types.
133    async fn recv_outgoing(pid: P, mut rx: Receiver<R>) -> (P, Option<R>, Receiver<R>) {
134        let o = rx.recv().await;
135        (pid, o, rx)
136    }
137
138    /// Add a new session, which is the client side of a protocol this side initiated.
139    ///
140    /// If a session with the same ID already exists it is overwritten.
141    /// This should cause a disconnection error to be raised in the session,
142    /// which is why we have to ignore those, and not abort all other sessions
143    /// when such an error is reported.
144    ///
145    /// Returns the receiver we need to register with the multiplexer.
146    fn add(&mut self, add: AddMsg<P, R>) -> (P, Receiver<R>) {
147        let (pid, tx, rx) = add;
148        self.demux.txs.insert(pid, tx);
149        (pid, rx)
150    }
151
152    /// Dispatch an incoming reply to the corresponding session that sent the orginal request as a client.
153    ///
154    /// Abort if the protocol doesn't exist. This is an outgoing connection, the local party initiates.
155    fn handle_incoming_reply(&mut self, pid: P, msg: R) {
156        match self.demux.txs.get(&pid) {
157            Some(tx) => {
158                // Ignoring send errors here; it means the session has ended on our side,
159                // but we'll realise this in the loop when trying to receive from it.
160                let _ = tx.send(msg);
161            }
162            None => {
163                // A message to a protocol we did not initiate.
164                let _ = self
165                    .errors
166                    .tx
167                    .send(SessionError::UnexpectedMessage(Box::new(msg)));
168            }
169        }
170    }
171
172    /// Wrap an outgoing request and relay to the connection channel.
173    fn handle_outgoing_request(&self, pid: P, msg: R) {
174        // Ignoring send errors here; it means the connection is closed,
175        // but we'll realize that in the loop when trying to receive.
176        let _ = self.mux.send(pid, msg);
177    }
178
179    /// Remove a session protocol that got closed on our side.
180    fn remove(&mut self, pid: P) {
181        self.demux.txs.remove(&pid);
182    }
183}
184
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tokio::time::timeout as timeout_after;

    use super::OutgoingMultiChannel;
    use crate::multiplexing::MultiMessage;
    use crate::test::protocols::greetings::*;
    use crate::test::protocols::ping_pong::*;
    use crate::test::protocols::*;
    use crate::{session_channel_dyn, unbounded_channel, Chan, DynMessage, SessionResult};

    /// Protocol identifier type used on the wire in this test.
    type Pid = u8;

    /// Protocol identifiers used in this test.
    mod protos {
        pub const PING_PONG: u8 = 1;
        pub const GREETINGS: u8 = 2;
    }

    /// Registers one greetings session, checks that its first message is
    /// multiplexed, then checks that a reply for a never-registered protocol
    /// shuts the channel down.
    #[tokio::test]
    async fn basics() {
        let step_timeout = Duration::from_millis(100);

        // Create an OutgoingMultiChannel. It needs a pair of channels,
        // one for requests going out and one for replies coming in.
        let (replies_tx, replies_rx) = unbounded_channel();
        let (requests_tx, mut requests_rx) = unbounded_channel();
        let channel = OutgoingMultiChannel::<Pid, DynMessage>::new(requests_tx, replies_rx);

        // Grab the control handle before moving the channel into a task.
        let control = channel.control();

        // Start multiplexing in the background.
        tokio::spawn(channel.run());

        // The test plays the session server; this is the client side of the
        // greetings protocol, which runs in the background.
        async fn greeting_client(
            c: Chan<greetings::Client, (), DynMessage>,
            timeout: Duration,
        ) -> SessionResult<()> {
            let c = c.send(Hail("Punter".into()))?;
            let (c, Greetings(_)) = c.recv(timeout).await?;
            let c = c.enter();

            let c = c.sel2().sel1();
            let c = c.send(AddRequest(1))?;
            let c = c.send(AddRequest(2))?;
            let (c, AddResponse(_)) = c.recv(timeout).await?;

            c.zero()?.sel2().sel2().send(Quit)?.close()
        }

        let start_greeting = || {
            let (session, (tx, rx)) = session_channel_dyn::<greetings::Client, DynMessage>();
            // Start the session interacting with the greeting server in a task.
            tokio::spawn(greeting_client(session, step_timeout));
            // Register the channels with the multiplexer.
            control.add(protos::GREETINGS, tx, rx);
        };

        start_greeting();

        // See what the outgoing multiplexer wants to send.
        let first = timeout_after(step_timeout, requests_rx.recv())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(first.protocol_id, protos::GREETINGS);
        assert!(first.payload.downcast::<Hail>().is_ok());

        // Now send an unexpected reply to a never requested ping.
        // It should cause the whole thing to be closed.
        let unexpected = MultiMessage::new(protos::PING_PONG, Pong);
        replies_tx.send(unexpected).unwrap();

        // Wait a bit for the message to take effect.
        tokio::time::sleep(step_timeout / 2).await;

        // With the channel shut down, further sends must fail.
        let probe = MultiMessage::new(protos::GREETINGS, Hail("Still there?".into()));
        assert!(replies_tx.send(probe).is_err());
    }
}
273        assert!(res.is_err());
274    }
275}