multistream_select/
dialer_select.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Protocol negotiation strategies for the peer acting as the dialer.
22
23use crate::{Negotiated, NegotiationError, Version};
24use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine};
25
26use futures::{future::Either, prelude::*};
27use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}};
28
29/// Returns a `Future` that negotiates a protocol on the given I/O stream
30/// for a peer acting as the _dialer_ (or _initiator_).
31///
32/// This function is given an I/O stream and a list of protocols and returns a
33/// computation that performs the protocol negotiation with the remote. The
34/// returned `Future` resolves with the name of the negotiated protocol and
35/// a [`Negotiated`] I/O stream.
36///
37/// The chosen message flow for protocol negotiation depends on the numbers of
38/// supported protocols given. That is, this function delegates to serial or
39/// parallel variant based on the number of protocols given. The number of
40/// protocols is determined through the `size_hint` of the given iterator and
41/// thus an inaccurate size estimate may result in a suboptimal choice.
42///
43/// Within the scope of this library, a dialer always commits to a specific
44/// multistream-select [`Version`], whereas a listener always supports
45/// all versions supported by this library. Frictionless multistream-select
46/// protocol upgrades may thus proceed by deployments with updated listeners,
47/// eventually followed by deployments of dialers choosing the newer protocol.
48pub fn dialer_select_proto<R, I>(
49    inner: R,
50    protocols: I,
51    version: Version
52) -> DialerSelectFuture<R, I::IntoIter>
53where
54    R: AsyncRead + AsyncWrite,
55    I: IntoIterator,
56    I::Item: AsRef<[u8]>
57{
58    let iter = protocols.into_iter();
59    // We choose between the "serial" and "parallel" strategies based on the number of protocols.
60    if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
61       Either::Left(dialer_select_proto_serial(inner, iter, version))
62    } else {
63        Either::Right(dialer_select_proto_parallel(inner, iter, version))
64    }
65}
66
67/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
68/// either trying protocols in-order, or by requesting all protocols supported
69/// by the remote upfront, from which the first protocol found in the dialer's
70/// list of protocols is selected.
71pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
72
73/// Returns a `Future` that negotiates a protocol on the given I/O stream.
74///
75/// Just like [`dialer_select_proto`] but always using an iterative message flow,
76/// trying the given list of supported protocols one-by-one.
77///
78/// This strategy is preferable if the dialer only supports a few protocols.
79pub(crate) fn dialer_select_proto_serial<R, I>(
80    inner: R,
81    protocols: I,
82    version: Version
83) -> DialerSelectSeq<R, I::IntoIter>
84where
85    R: AsyncRead + AsyncWrite,
86    I: IntoIterator,
87    I::Item: AsRef<[u8]>
88{
89    let protocols = protocols.into_iter().peekable();
90    DialerSelectSeq {
91        version,
92        protocols,
93        state: SeqState::SendHeader {
94            io: MessageIO::new(inner),
95        }
96    }
97}
98
99/// Returns a `Future` that negotiates a protocol on the given I/O stream.
100///
101/// Just like [`dialer_select_proto`] but always using a message flow that first
102/// requests all supported protocols from the remote, selecting the first
103/// protocol from the given list of supported protocols that is supported
104/// by the remote.
105///
106/// This strategy may be beneficial if the dialer supports many protocols
107/// and it is unclear whether the remote supports one of the first few.
108pub(crate) fn dialer_select_proto_parallel<R, I>(
109    inner: R,
110    protocols: I,
111    version: Version
112) -> DialerSelectPar<R, I::IntoIter>
113where
114    R: AsyncRead + AsyncWrite,
115    I: IntoIterator,
116    I::Item: AsRef<[u8]>
117{
118    let protocols = protocols.into_iter();
119    DialerSelectPar {
120        version,
121        protocols,
122        state: ParState::SendHeader {
123            io: MessageIO::new(inner)
124        }
125    }
126}
127
128/// A `Future` returned by [`dialer_select_proto_serial`] which negotiates
129/// a protocol iteratively by considering one protocol after the other.
130#[pin_project::pin_project]
131pub struct DialerSelectSeq<R, I>
132where
133    R: AsyncRead + AsyncWrite,
134    I: Iterator,
135    I::Item: AsRef<[u8]>
136{
137    // TODO: It would be nice if eventually N = I::Item = Protocol.
138    protocols: iter::Peekable<I>,
139    state: SeqState<R, I::Item>,
140    version: Version,
141}
142
143enum SeqState<R, N>
144where
145    R: AsyncRead + AsyncWrite,
146    N: AsRef<[u8]>
147{
148    SendHeader { io: MessageIO<R>, },
149    SendProtocol { io: MessageIO<R>, protocol: N },
150    FlushProtocol { io: MessageIO<R>, protocol: N },
151    AwaitProtocol { io: MessageIO<R>, protocol: N },
152    Done
153}
154
155impl<R, I> Future for DialerSelectSeq<R, I>
156where
157    // The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
158    // It also makes the implementation considerably easier to write.
159    R: AsyncRead + AsyncWrite + Unpin,
160    I: Iterator,
161    I::Item: AsRef<[u8]>
162{
163    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
164
165    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166        let this = self.project();
167
168        loop {
169            match mem::replace(this.state, SeqState::Done) {
170                SeqState::SendHeader { mut io } => {
171                    match Pin::new(&mut io).poll_ready(cx)? {
172                        Poll::Ready(()) => {},
173                        Poll::Pending => {
174                            *this.state = SeqState::SendHeader { io };
175                            return Poll::Pending
176                        },
177                    }
178
179                    let h = HeaderLine::from(*this.version);
180                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
181                        return Poll::Ready(Err(From::from(err)));
182                    }
183
184                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
185
186                    // The dialer always sends the header and the first protocol
187                    // proposal in one go for efficiency.
188                    *this.state = SeqState::SendProtocol { io, protocol };
189                }
190
191                SeqState::SendProtocol { mut io, protocol } => {
192                    match Pin::new(&mut io).poll_ready(cx)? {
193                        Poll::Ready(()) => {},
194                        Poll::Pending => {
195                            *this.state = SeqState::SendProtocol { io, protocol };
196                            return Poll::Pending
197                        },
198                    }
199
200                    let p = Protocol::try_from(protocol.as_ref())?;
201                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
202                        return Poll::Ready(Err(From::from(err)));
203                    }
204                    log::debug!("Dialer: Proposed protocol: {}", p);
205
206                    if this.protocols.peek().is_some() {
207                        *this.state = SeqState::FlushProtocol { io, protocol }
208                    } else {
209                        match this.version {
210                            Version::V1 => *this.state = SeqState::FlushProtocol { io, protocol },
211                            // This is the only effect that `V1Lazy` has compared to `V1`:
212                            // Optimistically settling on the only protocol that
213                            // the dialer supports for this negotiation. Notably,
214                            // the dialer expects a regular `V1` response.
215                            Version::V1Lazy => {
216                                log::debug!("Dialer: Expecting proposed protocol: {}", p);
217                                let hl = HeaderLine::from(Version::V1Lazy);
218                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
219                                return Poll::Ready(Ok((protocol, io)))
220                            }
221                        }
222                    }
223                }
224
225                SeqState::FlushProtocol { mut io, protocol } => {
226                    match Pin::new(&mut io).poll_flush(cx)? {
227                        Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol },
228                        Poll::Pending => {
229                            *this.state = SeqState::FlushProtocol { io, protocol };
230                            return Poll::Pending
231                        },
232                    }
233                }
234
235                SeqState::AwaitProtocol { mut io, protocol } => {
236                    let msg = match Pin::new(&mut io).poll_next(cx)? {
237                        Poll::Ready(Some(msg)) => msg,
238                        Poll::Pending => {
239                            *this.state = SeqState::AwaitProtocol { io, protocol };
240                            return Poll::Pending
241                        }
242                        // Treat EOF error as [`NegotiationError::Failed`], not as
243                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
244                        // stream as a permissible way to "gracefully" fail a negotiation.
245                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
246                    };
247
248                    match msg {
249                        Message::Header(v) if v == HeaderLine::from(*this.version) => {
250                            *this.state = SeqState::AwaitProtocol { io, protocol };
251                        }
252                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
253                            log::debug!("Dialer: Received confirmation for protocol: {}", p);
254                            let io = Negotiated::completed(io.into_inner());
255                            return Poll::Ready(Ok((protocol, io)));
256                        }
257                        Message::NotAvailable => {
258                            log::debug!("Dialer: Received rejection of protocol: {}",
259                                String::from_utf8_lossy(protocol.as_ref()));
260                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
261                            *this.state = SeqState::SendProtocol { io, protocol }
262                        }
263                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
264                    }
265                }
266
267                SeqState::Done => panic!("SeqState::poll called after completion")
268            }
269        }
270    }
271}
272
273/// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates
274/// a protocol selectively by considering all supported protocols of the remote
275/// "in parallel".
276#[pin_project::pin_project]
277pub struct DialerSelectPar<R, I>
278where
279    R: AsyncRead + AsyncWrite,
280    I: Iterator,
281    I::Item: AsRef<[u8]>
282{
283    protocols: I,
284    state: ParState<R, I::Item>,
285    version: Version,
286}
287
288enum ParState<R, N>
289where
290    R: AsyncRead + AsyncWrite,
291    N: AsRef<[u8]>
292{
293    SendHeader { io: MessageIO<R> },
294    SendProtocolsRequest { io: MessageIO<R> },
295    Flush { io: MessageIO<R> },
296    RecvProtocols { io: MessageIO<R> },
297    SendProtocol { io: MessageIO<R>, protocol: N },
298    Done
299}
300
301impl<R, I> Future for DialerSelectPar<R, I>
302where
303    // The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
304    // It also makes the implementation considerably easier to write.
305    R: AsyncRead + AsyncWrite + Unpin,
306    I: Iterator,
307    I::Item: AsRef<[u8]>
308{
309    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
310
311    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
312        let this = self.project();
313
314        loop {
315            match mem::replace(this.state, ParState::Done) {
316                ParState::SendHeader { mut io } => {
317                    match Pin::new(&mut io).poll_ready(cx)? {
318                        Poll::Ready(()) => {},
319                        Poll::Pending => {
320                            *this.state = ParState::SendHeader { io };
321                            return Poll::Pending
322                        },
323                    }
324
325                    let msg = Message::Header(HeaderLine::from(*this.version));
326                    if let Err(err) = Pin::new(&mut io).start_send(msg) {
327                        return Poll::Ready(Err(From::from(err)));
328                    }
329
330                    *this.state = ParState::SendProtocolsRequest { io };
331                }
332
333                ParState::SendProtocolsRequest { mut io } => {
334                    match Pin::new(&mut io).poll_ready(cx)? {
335                        Poll::Ready(()) => {},
336                        Poll::Pending => {
337                            *this.state = ParState::SendProtocolsRequest { io };
338                            return Poll::Pending
339                        },
340                    }
341
342                    if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) {
343                        return Poll::Ready(Err(From::from(err)));
344                    }
345
346                    log::debug!("Dialer: Requested supported protocols.");
347                    *this.state = ParState::Flush { io }
348                }
349
350                ParState::Flush { mut io } => {
351                    match Pin::new(&mut io).poll_flush(cx)? {
352                        Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
353                        Poll::Pending => {
354                            *this.state = ParState::Flush { io };
355                            return Poll::Pending
356                        },
357                    }
358                }
359
360                ParState::RecvProtocols { mut io } => {
361                    let msg = match Pin::new(&mut io).poll_next(cx)? {
362                        Poll::Ready(Some(msg)) => msg,
363                        Poll::Pending => {
364                            *this.state = ParState::RecvProtocols { io };
365                            return Poll::Pending
366                        }
367                        // Treat EOF error as [`NegotiationError::Failed`], not as
368                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
369                        // stream as a permissible way to "gracefully" fail a negotiation.
370                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
371                    };
372
373                    match &msg {
374                        Message::Header(h) if h == &HeaderLine::from(*this.version) => {
375                            *this.state = ParState::RecvProtocols { io }
376                        }
377                        Message::Protocols(supported) => {
378                            let protocol = this.protocols.by_ref()
379                                .find(|p| supported.iter().any(|s|
380                                    s.as_ref() == p.as_ref()))
381                                .ok_or(NegotiationError::Failed)?;
382                            log::debug!("Dialer: Found supported protocol: {}",
383                                String::from_utf8_lossy(protocol.as_ref()));
384                            *this.state = ParState::SendProtocol { io, protocol };
385                        }
386                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
387                    }
388                }
389
390                ParState::SendProtocol { mut io, protocol } => {
391                    match Pin::new(&mut io).poll_ready(cx)? {
392                        Poll::Ready(()) => {},
393                        Poll::Pending => {
394                            *this.state = ParState::SendProtocol { io, protocol };
395                            return Poll::Pending
396                        },
397                    }
398
399                    let p = Protocol::try_from(protocol.as_ref())?;
400                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
401                        return Poll::Ready(Err(From::from(err)));
402                    }
403
404                    log::debug!("Dialer: Expecting proposed protocol: {}", p);
405                    let io = Negotiated::expecting(io.into_reader(), p, None);
406
407                    return Poll::Ready(Ok((protocol, io)))
408                }
409
410                ParState::Done => panic!("ParState::poll called after completion")
411            }
412        }
413    }
414}