async_session_types/multiplexing/outgoing.rs
1use futures::{stream::FuturesUnordered, StreamExt};
2use tokio::select;
3
4use super::{AddChan, AddMsg, Demultiplexer, ErrorChan, MultiMessage, Multiplexer};
5use crate::{Receiver, Sender, SessionError};
6
7/// Control signals for the running outgoing channel, allowing the rest of
8/// the program to register new sessions and to abort all ongoing ones.
9pub struct Control<P, R> {
10 errors: Sender<SessionError>,
11 adds: Sender<AddMsg<P, R>>,
12}
13
14impl<P, R> Control<P, R> {
15 /// Register the client side of a session protocol, so that `run` starts listening
16 /// to messages coming from the protocol and multiplexing them to the outgoing
17 /// connection towards the server, and also starts dispaching server messages
18 /// to this protocol.
19 ///
20 /// The session should be automatically removed when it closes the other sides of
21 /// the channels.
22 ///
23 /// Returns `false` if the channel is already closed and the session could not be registered.
24 pub fn add(&self, protocol_id: P, tx: Sender<R>, rx: Receiver<R>) -> bool {
25 self.adds.send((protocol_id, tx, rx)).is_ok()
26 }
27
28 /// Remove all sessions due to protocol violation. Use an internal channel to notify `run`.
29 pub fn error(&self, _protocol_id: P, error: SessionError) {
30 // Ignoring send errors, it would mean the channel is already closed.
31 let _ = self.errors.send(error);
32 }
33}
34
35/// An `OutgoingMultiChannel` is established for each outgoing TCP connection.
36/// The local party will be the client of the remote server.
37///
38/// The `mux` part listens to multiple sessions initiated by this side and
39/// relay their messages to the common outgoing channel.
40///
41/// The `demux` part listens to the incoming replies from the server on the
42/// connection and dispatches the messages to their corresponding protocol
43/// handlers.
44pub struct OutgoingMultiChannel<P, R> {
45 mux: Multiplexer<P, R>,
46 demux: Demultiplexer<P, R>,
47 errors: ErrorChan,
48 adds: AddChan<P, R>,
49}
50
51impl<P: Ord + Copy, R: 'static + Send + Sync> OutgoingMultiChannel<P, R> {
52 /// Create a new `MultiChannel` by passing it the channels it can use to
53 /// send/receive messages to/from the underlying network connection.
54 pub fn new(tx: Sender<MultiMessage<P, R>>, rx: Receiver<MultiMessage<P, R>>) -> Self {
55 Self {
56 mux: Multiplexer::new(tx),
57 demux: Demultiplexer::new(rx),
58 errors: ErrorChan::new(),
59 adds: AddChan::new(),
60 }
61 }
62
63 /// Get a handler that can be used to send signals to the channel.
64 ///
65 /// Call this before the channel is run.
66 pub fn control(&self) -> Control<P, R> {
67 Control {
68 errors: self.errors.tx.clone(),
69 adds: self.adds.tx.clone(),
70 }
71 }
72
73 /// Start consuming messages from the remote server, relaying them to
74 /// the local client side of the protocols.
75 ///
76 /// We must also listen to an internal channel that signals the abortion
77 /// of the whole connection due to some protocol violation.
78 ///
79 /// Yet another internal channel must be used to receive registration requests.
80 pub async fn run(mut self) {
81 // Futures aggregator for the session clients on our side.
82 let mut mux = FuturesUnordered::new();
83
84 loop {
85 select! {
86 // Abandon all channels if there's an error. This should never return `None`.
87 Some(err) = self.errors.rx.recv() => {
88 // We can ignore disconnect errors, they can come from replaced sessions.
89 // We'll recognise a real disconnection by not being able to receive from
90 // the remote side here in the incoming arm.
91 match err {
92 SessionError::Disconnected => (),
93 _ => break,
94 }
95 },
96
97 // We are initiating a new protocol. This should never return `None`.
98 Some(add) = self.adds.rx.recv() => {
99 let (pid, rx) = self.add(add);
100 mux.push(Self::recv_outgoing(pid, rx));
101 }
102
103 // We have an outgoing request from one of the sessions.
104 // Initially `mux` is empty and this branch will be disabled by the `select!`.
105 // Only after the first `add` message will it get another chaince.
106 Some((pid, outgoing, rx)) = mux.next() => {
107 match outgoing {
108 Some(msg) => {
109 self.handle_outgoing_request(pid, msg);
110 // Re-queue the receiver.
111 mux.push(Self::recv_outgoing(pid, rx));
112 }
113 // The outgoing channel got closed on our side.
114 // The other sessions can keep going.
115 None =>self.remove(pid),
116 }
117 }
118
119 // There is an incoming reply from the server.
120 incoming = self.demux.recv() => {
121 match incoming {
122 Some((pid, msg)) => self.handle_incoming_reply(pid, msg),
123 // The channel to server is closed.
124 None => break,
125 }
126 }
127 }
128 }
129 }
130
131 // We need exactly one function to produce the future that gets pushed into the
132 // `FuturesUnordered` otherwise it would have multiple conflicting anonymous types.
133 async fn recv_outgoing(pid: P, mut rx: Receiver<R>) -> (P, Option<R>, Receiver<R>) {
134 let o = rx.recv().await;
135 (pid, o, rx)
136 }
137
138 /// Add a new session, which is the client side of a protocol this side initiated.
139 ///
140 /// If a session with the same ID already exists it is overwritten.
141 /// This should cause a disconnection error to be raised in the session,
142 /// which is why we have to ignore those, and not abort all other sessions
143 /// when such an error is reported.
144 ///
145 /// Returns the receiver we need to register with the multiplexer.
146 fn add(&mut self, add: AddMsg<P, R>) -> (P, Receiver<R>) {
147 let (pid, tx, rx) = add;
148 self.demux.txs.insert(pid, tx);
149 (pid, rx)
150 }
151
152 /// Dispatch an incoming reply to the corresponding session that sent the orginal request as a client.
153 ///
154 /// Abort if the protocol doesn't exist. This is an outgoing connection, the local party initiates.
155 fn handle_incoming_reply(&mut self, pid: P, msg: R) {
156 match self.demux.txs.get(&pid) {
157 Some(tx) => {
158 // Ignoring send errors here; it means the session has ended on our side,
159 // but we'll realise this in the loop when trying to receive from it.
160 let _ = tx.send(msg);
161 }
162 None => {
163 // A message to a protocol we did not initiate.
164 let _ = self
165 .errors
166 .tx
167 .send(SessionError::UnexpectedMessage(Box::new(msg)));
168 }
169 }
170 }
171
172 /// Wrap an outgoing request and relay to the connection channel.
173 fn handle_outgoing_request(&self, pid: P, msg: R) {
174 // Ignoring send errors here; it means the connection is closed,
175 // but we'll realize that in the loop when trying to receive.
176 let _ = self.mux.send(pid, msg);
177 }
178
179 /// Remove a session protocol that got closed on our side.
180 fn remove(&mut self, pid: P) {
181 self.demux.txs.remove(&pid);
182 }
183}
184
#[cfg(test)]
mod test {
    use crate::multiplexing::MultiMessage;
    use crate::session_channel_dyn;
    use crate::test::protocols::greetings::*;
    use crate::test::protocols::ping_pong::*;
    use crate::test::protocols::*;
    use crate::unbounded_channel;
    use crate::Chan;
    use crate::DynMessage;
    use crate::SessionResult;
    use std::time::Duration;
    use tokio::time::timeout as timeout_after;

    use super::OutgoingMultiChannel;

    type PID = u8;
    mod protos {
        pub const PING_PONG: u8 = 1;
        pub const GREETINGS: u8 = 2;
    }

    #[tokio::test]
    async fn basics() {
        let timeout = Duration::from_millis(100);

        // Create an OutgoingMultiChannel. It needs a pair of channels:
        // one for requests going out and one for replies coming in.
        let (tx_in, rx_in) = unbounded_channel();
        let (tx_out, mut rx_out) = unbounded_channel();
        let channel = OutgoingMultiChannel::<PID, DynMessage>::new(tx_out, rx_in);

        // Grab the control handle before moving the channel into a task.
        let control = channel.control();

        // Start multiplexing in the background.
        tokio::spawn(channel.run());

        // The test itself plays the server by reading `rx_out` and writing `tx_in`.
        // The client sessions run in background tasks.
        async fn cli(
            c: Chan<greetings::Client, (), DynMessage>,
            timeout: Duration,
        ) -> SessionResult<()> {
            let c = c.send(Hail("Punter".into()))?;
            let (c, Greetings(_)) = c.recv(timeout).await?;
            let c = c.enter();
            let (c, AddResponse(_)) = c
                .sel2()
                .sel1()
                .send(AddRequest(1))?
                .send(AddRequest(2))?
                .recv(timeout)
                .await?;

            c.zero()?.sel2().sel2().send(Quit)?.close()
        }

        let start_greeting = || {
            let (c, (tx, rx)) = session_channel_dyn::<greetings::Client, DynMessage>();
            // Start the client session, which talks to the greeting server, in a task.
            tokio::spawn(cli(c, timeout));
            // Register the channels with the multiplexer. The channel is running,
            // so the registration must be accepted.
            assert!(control.add(protos::GREETINGS, tx, rx));
        };

        start_greeting();

        // See what the outgoing multiplexer wants to send.
        let res = timeout_after(timeout, rx_out.recv())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(res.protocol_id, protos::GREETINGS);
        assert!(res.payload.downcast::<Hail>().is_ok());

        // Now send an unexpected reply to a ping we never requested.
        // It should cause the whole channel to be closed.
        tx_in
            .send(MultiMessage::new(protos::PING_PONG, Pong))
            .unwrap();

        // Wait a bit for the message to take effect.
        tokio::time::sleep(timeout / 2).await;

        // Once `run` has stopped, the incoming receiver is dropped, so sending fails.
        let res = tx_in.send(MultiMessage::new(
            protos::GREETINGS,
            Hail("Still there?".into()),
        ));
        assert!(res.is_err());
    }
}