multistream_select/
negotiated.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::protocol::{Protocol, MessageReader, Message, ProtocolError, HeaderLine};
22
23use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
24use pin_project::pin_project;
25use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
26
27/// An I/O stream that has settled on an (application-layer) protocol to use.
28///
29/// A `Negotiated` represents an I/O stream that has _settled_ on a protocol
30/// to use. In particular, it is not implied that all of the protocol negotiation
31/// frames have yet been sent and / or received, just that the selected protocol
32/// is fully determined. This is to allow the last protocol negotiation frames
33/// sent by a peer to be combined in a single write, possibly piggy-backing
34/// data from the negotiated protocol on top.
35///
36/// Reading from a `Negotiated` I/O stream that still has pending negotiation
37/// protocol data to send implicitly triggers flushing of all yet unsent data.
38#[pin_project]
39#[derive(Debug)]
40pub struct Negotiated<TInner> {
41    #[pin]
42    state: State<TInner>
43}
44
45/// A `Future` that waits on the completion of protocol negotiation.
46#[derive(Debug)]
47pub struct NegotiatedComplete<TInner> {
48    inner: Option<Negotiated<TInner>>,
49}
50
51impl<TInner> Future for NegotiatedComplete<TInner>
52where
53    // `Unpin` is required not because of implementation details but because we produce the
54    // `Negotiated` as the output of the future.
55    TInner: AsyncRead + AsyncWrite + Unpin,
56{
57    type Output = Result<Negotiated<TInner>, NegotiationError>;
58
59    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60        let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
61        match Negotiated::poll(Pin::new(&mut io), cx) {
62            Poll::Pending => {
63                self.inner = Some(io);
64                Poll::Pending
65            },
66            Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
67            Poll::Ready(Err(err)) => {
68                self.inner = Some(io);
69                Poll::Ready(Err(err))
70            }
71        }
72    }
73}
74
75impl<TInner> Negotiated<TInner> {
76    /// Creates a `Negotiated` in state [`State::Completed`].
77    pub(crate) fn completed(io: TInner) -> Self {
78        Negotiated { state: State::Completed { io } }
79    }
80
81    /// Creates a `Negotiated` in state [`State::Expecting`] that is still
82    /// expecting confirmation of the given `protocol`.
83    pub(crate) fn expecting(
84        io: MessageReader<TInner>,
85        protocol: Protocol,
86        header: Option<HeaderLine>
87    ) -> Self {
88        Negotiated { state: State::Expecting { io, protocol, header } }
89    }
90
91    /// Polls the `Negotiated` for completion.
92    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
93    where
94        TInner: AsyncRead + AsyncWrite + Unpin
95    {
96        // Flush any pending negotiation data.
97        match self.as_mut().poll_flush(cx) {
98            Poll::Ready(Ok(())) => {},
99            Poll::Pending => return Poll::Pending,
100            Poll::Ready(Err(e)) => {
101                // If the remote closed the stream, it is important to still
102                // continue reading the data that was sent, if any.
103                if e.kind() != io::ErrorKind::WriteZero {
104                    return Poll::Ready(Err(e.into()))
105                }
106            }
107        }
108
109        let mut this = self.project();
110
111        if let StateProj::Completed { .. } = this.state.as_mut().project() {
112             return Poll::Ready(Ok(()));
113        }
114
115        // Read outstanding protocol negotiation messages.
116        loop {
117            match mem::replace(&mut *this.state, State::Invalid) {
118                State::Expecting { mut io, header, protocol } => {
119                    let msg = match Pin::new(&mut io).poll_next(cx)? {
120                        Poll::Ready(Some(msg)) => msg,
121                        Poll::Pending => {
122                            *this.state = State::Expecting { io, header, protocol };
123                            return Poll::Pending
124                        },
125                        Poll::Ready(None) => {
126                            return Poll::Ready(Err(ProtocolError::IoError(
127                                io::ErrorKind::UnexpectedEof.into()).into()));
128                        }
129                    };
130
131                    if let Message::Header(h) = &msg {
132                        if Some(h) == header.as_ref() {
133                            *this.state = State::Expecting { io, protocol, header: None };
134                            continue
135                        }
136                    }
137
138                    if let Message::Protocol(p) = &msg {
139                        if p.as_ref() == protocol.as_ref() {
140                            log::debug!("Negotiated: Received confirmation for protocol: {}", p);
141                            *this.state = State::Completed { io: io.into_inner() };
142                            return Poll::Ready(Ok(()));
143                        }
144                    }
145
146                    return Poll::Ready(Err(NegotiationError::Failed));
147                }
148
149                _ => panic!("Negotiated: Invalid state")
150            }
151        }
152    }
153
154    /// Returns a [`NegotiatedComplete`] future that waits for protocol
155    /// negotiation to complete.
156    pub fn complete(self) -> NegotiatedComplete<TInner> {
157        NegotiatedComplete { inner: Some(self) }
158    }
159}
160
161/// The states of a `Negotiated` I/O stream.
162#[pin_project(project = StateProj)]
163#[derive(Debug)]
164enum State<R> {
165    /// In this state, a `Negotiated` is still expecting to
166    /// receive confirmation of the protocol it has optimistically
167    /// settled on.
168    Expecting {
169        /// The underlying I/O stream.
170        #[pin]
171        io: MessageReader<R>,
172        /// The expected negotiation header/preamble (i.e. multistream-select version),
173        /// if one is still expected to be received.
174        header: Option<HeaderLine>,
175        /// The expected application protocol (i.e. name and version).
176        protocol: Protocol,
177    },
178
179    /// In this state, a protocol has been agreed upon and I/O
180    /// on the underlying stream can commence.
181    Completed { #[pin] io: R },
182
183    /// Temporary state while moving the `io` resource from
184    /// `Expecting` to `Completed`.
185    Invalid,
186}
187
188impl<TInner> AsyncRead for Negotiated<TInner>
189where
190    TInner: AsyncRead + AsyncWrite + Unpin
191{
192    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
193        -> Poll<Result<usize, io::Error>>
194    {
195        loop {
196            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
197                // If protocol negotiation is complete, commence with reading.
198                return io.poll_read(cx, buf);
199            }
200
201            // Poll the `Negotiated`, driving protocol negotiation to completion,
202            // including flushing of any remaining data.
203            match self.as_mut().poll(cx) {
204                Poll::Ready(Ok(())) => {},
205                Poll::Pending => return Poll::Pending,
206                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
207            }
208        }
209    }
210
211    // TODO: implement once method is stabilized in the futures crate
212    /*unsafe fn initializer(&self) -> Initializer {
213        match &self.state {
214            State::Completed { io, .. } => io.initializer(),
215            State::Expecting { io, .. } => io.inner_ref().initializer(),
216            State::Invalid => panic!("Negotiated: Invalid state"),
217        }
218    }*/
219
220    fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>])
221        -> Poll<Result<usize, io::Error>>
222    {
223        loop {
224            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
225                // If protocol negotiation is complete, commence with reading.
226                return io.poll_read_vectored(cx, bufs)
227            }
228
229            // Poll the `Negotiated`, driving protocol negotiation to completion,
230            // including flushing of any remaining data.
231            match self.as_mut().poll(cx) {
232                Poll::Ready(Ok(())) => {},
233                Poll::Pending => return Poll::Pending,
234                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
235            }
236        }
237    }
238}
239
240impl<TInner> AsyncWrite for Negotiated<TInner>
241where
242    TInner: AsyncWrite + AsyncRead + Unpin
243{
244    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
245        match self.project().state.project() {
246            StateProj::Completed { io } => io.poll_write(cx, buf),
247            StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
248            StateProj::Invalid => panic!("Negotiated: Invalid state"),
249        }
250    }
251
252    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
253        match self.project().state.project() {
254            StateProj::Completed { io } => io.poll_flush(cx),
255            StateProj::Expecting { io, .. } => io.poll_flush(cx),
256            StateProj::Invalid => panic!("Negotiated: Invalid state"),
257        }
258    }
259
260    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
261        // Ensure all data has been flushed and expected negotiation messages
262        // have been received.
263        ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
264        ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
265
266        // Continue with the shutdown of the underlying I/O stream.
267        match self.project().state.project() {
268            StateProj::Completed { io, .. } => io.poll_close(cx),
269            StateProj::Expecting { io, .. } => io.poll_close(cx),
270            StateProj::Invalid => panic!("Negotiated: Invalid state"),
271        }
272    }
273
274    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
275        -> Poll<Result<usize, io::Error>>
276    {
277        match self.project().state.project() {
278            StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
279            StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
280            StateProj::Invalid => panic!("Negotiated: Invalid state"),
281        }
282    }
283}
284
285/// Error that can happen when negotiating a protocol with the remote.
286#[derive(Debug)]
287pub enum NegotiationError {
288    /// A protocol error occurred during the negotiation.
289    ProtocolError(ProtocolError),
290
291    /// Protocol negotiation failed because no protocol could be agreed upon.
292    Failed,
293}
294
295impl From<ProtocolError> for NegotiationError {
296    fn from(err: ProtocolError) -> NegotiationError {
297        NegotiationError::ProtocolError(err)
298    }
299}
300
301impl From<io::Error> for NegotiationError {
302    fn from(err: io::Error) -> NegotiationError {
303        ProtocolError::from(err).into()
304    }
305}
306
307impl From<NegotiationError> for io::Error {
308    fn from(err: NegotiationError) -> io::Error {
309        if let NegotiationError::ProtocolError(e) = err {
310            return e.into()
311        }
312        io::Error::new(io::ErrorKind::Other, err)
313    }
314}
315
316impl Error for NegotiationError {
317    fn source(&self) -> Option<&(dyn Error + 'static)> {
318        match self {
319            NegotiationError::ProtocolError(err) => Some(err),
320            _ => None,
321        }
322    }
323}
324
325impl fmt::Display for NegotiationError {
326    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
327        match self {
328            NegotiationError::ProtocolError(p) =>
329                fmt.write_fmt(format_args!("Protocol error: {}", p)),
330            NegotiationError::Failed =>
331                fmt.write_str("Protocol negotiation failed.")
332        }
333    }
334}