multistream_select/
listener_select.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Protocol negotiation strategies for the peer acting as the listener
22//! in a multistream-select protocol negotiation.
23
24use crate::{Negotiated, NegotiationError};
25use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine};
26
27use futures::prelude::*;
28use smallvec::SmallVec;
29use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}};
30
31/// Returns a `Future` that negotiates a protocol on the given I/O stream
32/// for a peer acting as the _listener_ (or _responder_).
33///
34/// This function is given an I/O stream and a list of protocols and returns a
35/// computation that performs the protocol negotiation with the remote. The
36/// returned `Future` resolves with the name of the negotiated protocol and
37/// a [`Negotiated`] I/O stream.
38pub fn listener_select_proto<R, I>(
39    inner: R,
40    protocols: I,
41) -> ListenerSelectFuture<R, I::Item>
42where
43    R: AsyncRead + AsyncWrite,
44    I: IntoIterator,
45    I::Item: AsRef<[u8]>
46{
47    let protocols = protocols.into_iter().filter_map(|n|
48        match Protocol::try_from(n.as_ref()) {
49            Ok(p) => Some((n, p)),
50            Err(e) => {
51                log::warn!("Listener: Ignoring invalid protocol: {} due to {}",
52                      String::from_utf8_lossy(n.as_ref()), e);
53                None
54            }
55        });
56    ListenerSelectFuture {
57        protocols: SmallVec::from_iter(protocols),
58        state: State::RecvHeader {
59            io: MessageIO::new(inner)
60        },
61        last_sent_na: false,
62    }
63}
64
65/// The `Future` returned by [`listener_select_proto`] that performs a
66/// multistream-select protocol negotiation on an underlying I/O stream.
67#[pin_project::pin_project]
68pub struct ListenerSelectFuture<R, N>
69where
70    R: AsyncRead + AsyncWrite,
71    N: AsRef<[u8]>
72{
73    // TODO: It would be nice if eventually N = Protocol, which has a
74    // few more implications on the API.
75    protocols: SmallVec<[(N, Protocol); 8]>,
76    state: State<R, N>,
77    /// Whether the last message sent was a protocol rejection (i.e. `na\n`).
78    ///
79    /// If the listener reads garbage or EOF after such a rejection,
80    /// the dialer is likely using `V1Lazy` and negotiation must be
81    /// considered failed, but not with a protocol violation or I/O
82    /// error.
83    last_sent_na: bool,
84}
85
86enum State<R, N>
87where
88    R: AsyncRead + AsyncWrite,
89    N: AsRef<[u8]>
90{
91    RecvHeader { io: MessageIO<R> },
92    SendHeader { io: MessageIO<R> },
93    RecvMessage { io: MessageIO<R> },
94    SendMessage {
95        io: MessageIO<R>,
96        message: Message,
97        protocol: Option<N>
98    },
99    Flush {
100        io: MessageIO<R>,
101        protocol: Option<N>
102    },
103    Done
104}
105
106impl<R, N> Future for ListenerSelectFuture<R, N>
107where
108    // The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
109    // It also makes the implementation considerably easier to write.
110    R: AsyncRead + AsyncWrite + Unpin,
111    N: AsRef<[u8]> + Clone
112{
113    type Output = Result<(N, Negotiated<R>), NegotiationError>;
114
115    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116        let this = self.project();
117
118        loop {
119            match mem::replace(this.state, State::Done) {
120                State::RecvHeader { mut io } => {
121                    match io.poll_next_unpin(cx) {
122                        Poll::Ready(Some(Ok(Message::Header(h)))) => {
123                            match h {
124                                HeaderLine::V1 => *this.state = State::SendHeader { io }
125                            }
126                        }
127                        Poll::Ready(Some(Ok(_))) => {
128                            return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
129                        },
130                        Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
131                        // Treat EOF error as [`NegotiationError::Failed`], not as
132                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
133                        // stream as a permissible way to "gracefully" fail a negotiation.
134                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
135                        Poll::Pending => {
136                            *this.state = State::RecvHeader { io };
137                            return Poll::Pending
138                        }
139                    }
140                }
141
142                State::SendHeader { mut io } => {
143                    match Pin::new(&mut io).poll_ready(cx) {
144                        Poll::Pending => {
145                            *this.state = State::SendHeader { io };
146                            return Poll::Pending
147                        },
148                        Poll::Ready(Ok(())) => {},
149                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
150                    }
151
152                    let msg = Message::Header(HeaderLine::V1);
153                    if let Err(err) = Pin::new(&mut io).start_send(msg) {
154                        return Poll::Ready(Err(From::from(err)));
155                    }
156
157                    *this.state = State::Flush { io, protocol: None };
158                }
159
160                State::RecvMessage { mut io } => {
161                    let msg = match Pin::new(&mut io).poll_next(cx) {
162                        Poll::Ready(Some(Ok(msg))) => msg,
163                        // Treat EOF error as [`NegotiationError::Failed`], not as
164                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
165                        // stream as a permissible way to "gracefully" fail a negotiation.
166                        //
167                        // This is e.g. important when a listener rejects a protocol with
168                        // [`Message::NotAvailable`] and the dialer does not have alternative
169                        // protocols to propose. Then the dialer will stop the negotiation and drop
170                        // the corresponding stream. As a listener this EOF should be interpreted as
171                        // a failed negotiation.
172                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
173                        Poll::Pending => {
174                            *this.state = State::RecvMessage { io };
175                            return Poll::Pending;
176                        }
177                        Poll::Ready(Some(Err(err))) => {
178                            if *this.last_sent_na {
179                                // When we read garbage or EOF after having already rejected a
180                                // protocol, the dialer is most likely using `V1Lazy` and has
181                                // optimistically settled on this protocol, so this is really a
182                                // failed negotiation, not a protocol violation. In this case
183                                // the dialer also raises `NegotiationError::Failed` when finally
184                                // reading the `N/A` response.
185                                if let ProtocolError::InvalidMessage = &err {
186                                    log::trace!("Listener: Negotiation failed with invalid \
187                                        message after protocol rejection.");
188                                    return Poll::Ready(Err(NegotiationError::Failed))
189                                }
190                                if let ProtocolError::IoError(e) = &err {
191                                    if e.kind() == std::io::ErrorKind::UnexpectedEof {
192                                        log::trace!("Listener: Negotiation failed with EOF \
193                                            after protocol rejection.");
194                                        return Poll::Ready(Err(NegotiationError::Failed))
195                                    }
196                                }
197                            }
198
199                            return Poll::Ready(Err(From::from(err)))
200                        }
201                    };
202
203                    match msg {
204                        Message::ListProtocols => {
205                            let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect();
206                            let message = Message::Protocols(supported);
207                            *this.state = State::SendMessage { io, message, protocol: None }
208                        }
209                        Message::Protocol(p) => {
210                            let protocol = this.protocols.iter().find_map(|(name, proto)| {
211                                if &p == proto {
212                                    Some(name.clone())
213                                } else {
214                                    None
215                                }
216                            });
217
218                            let message = if protocol.is_some() {
219                                log::debug!("Listener: confirming protocol: {}", p);
220                                Message::Protocol(p.clone())
221                            } else {
222                                log::debug!("Listener: rejecting protocol: {}",
223                                    String::from_utf8_lossy(p.as_ref()));
224                                Message::NotAvailable
225                            };
226
227                            *this.state = State::SendMessage { io, message, protocol };
228                        }
229                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
230                    }
231                }
232
233                State::SendMessage { mut io, message, protocol } => {
234                    match Pin::new(&mut io).poll_ready(cx) {
235                        Poll::Pending => {
236                            *this.state = State::SendMessage { io, message, protocol };
237                            return Poll::Pending
238                        },
239                        Poll::Ready(Ok(())) => {},
240                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
241                    }
242
243                    if let Message::NotAvailable = &message  {
244                        *this.last_sent_na = true;
245                    } else {
246                        *this.last_sent_na = false;
247                    }
248
249                    if let Err(err) = Pin::new(&mut io).start_send(message) {
250                        return Poll::Ready(Err(From::from(err)));
251                    }
252
253                    *this.state = State::Flush { io, protocol };
254                }
255
256                State::Flush { mut io, protocol } => {
257                    match Pin::new(&mut io).poll_flush(cx) {
258                        Poll::Pending => {
259                            *this.state = State::Flush { io, protocol };
260                            return Poll::Pending
261                        },
262                        Poll::Ready(Ok(())) => {
263                            // If a protocol has been selected, finish negotiation.
264                            // Otherwise expect to receive another message.
265                            match protocol {
266                                Some(protocol) => {
267                                    log::debug!("Listener: sent confirmed protocol: {}",
268                                        String::from_utf8_lossy(protocol.as_ref()));
269                                    let io = Negotiated::completed(io.into_inner());
270                                    return Poll::Ready(Ok((protocol, io)))
271                                }
272                                None => *this.state = State::RecvMessage { io }
273                            }
274                        }
275                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
276                    }
277                }
278
279                State::Done => panic!("State::poll called after completion")
280            }
281        }
282    }
283}