cs_mwc_libp2p_swarm/protocols_handler/
node_handler.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
21use crate::upgrade::SendWrapper;
22use crate::protocols_handler::{
23    KeepAlive,
24    ProtocolsHandler,
25    IntoProtocolsHandler,
26    ProtocolsHandlerEvent,
27    ProtocolsHandlerUpgrErr
28};
29
30use futures::prelude::*;
31use futures::stream::FuturesUnordered;
32use mwc_libp2p_core::{
33    Multiaddr,
34    Connected,
35    connection::{
36        ConnectionHandler,
37        ConnectionHandlerEvent,
38        IntoConnectionHandler,
39        Substream,
40        SubstreamEndpoint,
41    },
42    muxing::StreamMuxerBox,
43    upgrade::{self, InboundUpgradeApply, OutboundUpgradeApply, UpgradeError}
44};
45use std::{error, fmt, pin::Pin, task::Context, task::Poll, time::Duration};
46use wasm_timer::{Delay, Instant};
47
/// Prototype for a `NodeHandlerWrapper`.
///
/// Holds the not-yet-built handler together with the configuration that is
/// applied when the builder is turned into a `NodeHandlerWrapper` (see the
/// `IntoConnectionHandler` impl below).
pub struct NodeHandlerWrapperBuilder<TIntoProtoHandler> {
    /// The underlying handler.
    handler: TIntoProtoHandler,
    /// The substream upgrade protocol override, if any.
    substream_upgrade_protocol_override: Option<upgrade::Version>,
}
55
56impl<TIntoProtoHandler> NodeHandlerWrapperBuilder<TIntoProtoHandler>
57where
58    TIntoProtoHandler: IntoProtocolsHandler
59{
60    /// Builds a `NodeHandlerWrapperBuilder`.
61    pub(crate) fn new(handler: TIntoProtoHandler) -> Self {
62        NodeHandlerWrapperBuilder {
63            handler,
64            substream_upgrade_protocol_override: None,
65        }
66    }
67
68    pub(crate) fn with_substream_upgrade_protocol_override(
69        mut self,
70        version: Option<upgrade::Version>
71    ) -> Self {
72        self.substream_upgrade_protocol_override = version;
73        self
74    }
75}
76
77impl<TIntoProtoHandler, TProtoHandler> IntoConnectionHandler
78    for NodeHandlerWrapperBuilder<TIntoProtoHandler>
79where
80    TIntoProtoHandler: IntoProtocolsHandler<Handler = TProtoHandler>,
81    TProtoHandler: ProtocolsHandler,
82{
83    type Handler = NodeHandlerWrapper<TIntoProtoHandler::Handler>;
84
85    fn into_handler(self, connected: &Connected) -> Self::Handler {
86        NodeHandlerWrapper {
87            handler: self.handler.into_handler(&connected.peer_id, &connected.endpoint),
88            negotiating_in: Default::default(),
89            negotiating_out: Default::default(),
90            queued_dial_upgrades: Vec::new(),
91            unique_dial_upgrade_id: 0,
92            shutdown: Shutdown::None,
93            substream_upgrade_protocol_override: self.substream_upgrade_protocol_override,
94        }
95    }
96}
97
/// A `ConnectionHandler` for an underlying `ProtocolsHandler`.
///
/// Wraps around an implementation of `ProtocolsHandler`, and implements `NodeHandler`.
// TODO: add a caching system for protocols that are supported or not
pub struct NodeHandlerWrapper<TProtoHandler>
where
    TProtoHandler: ProtocolsHandler,
{
    /// The underlying handler.
    handler: TProtoHandler,
    /// Futures that upgrade incoming substreams.
    negotiating_in: FuturesUnordered<SubstreamUpgrade<
        TProtoHandler::InboundOpenInfo,
        InboundUpgradeApply<Substream<StreamMuxerBox>, SendWrapper<TProtoHandler::InboundProtocol>>,
    >>,
    /// Futures that upgrade outgoing substreams.
    negotiating_out: FuturesUnordered<SubstreamUpgrade<
        TProtoHandler::OutboundOpenInfo,
        OutboundUpgradeApply<Substream<StreamMuxerBox>, SendWrapper<TProtoHandler::OutboundProtocol>>,
    >>,
    /// For each outbound substream request, how to upgrade it. The first element of the tuple
    /// is the unique identifier (see `unique_dial_upgrade_id`). Entries are consumed in
    /// `inject_substream` when the corresponding substream is opened.
    queued_dial_upgrades: Vec<(u64, (upgrade::Version, SendWrapper<TProtoHandler::OutboundProtocol>))>,
    /// Unique identifier assigned to each queued dial upgrade.
    unique_dial_upgrade_id: u64,
    /// The currently planned connection & handler shutdown.
    shutdown: Shutdown,
    /// The substream upgrade protocol override, if any. Applied to outbound
    /// upgrades in `inject_substream`.
    substream_upgrade_protocol_override: Option<upgrade::Version>,
}
127
/// An in-progress substream upgrade (inbound or outbound), racing against a timeout.
///
/// Resolves to the stored `user_data` together with either the upgrade output
/// or a `ProtocolsHandlerUpgrErr` (see the `Future` impl below).
struct SubstreamUpgrade<UserData, Upgrade> {
    /// Context handed back to the `ProtocolsHandler` when the upgrade resolves.
    /// `None` once the future has produced its output; polling again then panics.
    user_data: Option<UserData>,
    /// Deadline for the negotiation; its expiry fails the upgrade with a timeout error.
    timeout: Delay,
    /// The upgrade being driven to completion.
    upgrade: Upgrade,
}

// Unconditional `Unpin` lets `poll` access the fields through `Pin<&mut Self>`;
// the `Future` impl additionally bounds `Upgrade: Unpin` before polling it.
impl<UserData, Upgrade> Unpin for SubstreamUpgrade<UserData, Upgrade> {}
135
136impl<UserData, Upgrade, UpgradeOutput, TUpgradeError> Future for SubstreamUpgrade<UserData, Upgrade>
137where
138    Upgrade: Future<Output = Result<UpgradeOutput, UpgradeError<TUpgradeError>>> + Unpin,
139{
140    type Output = (UserData, Result<UpgradeOutput, ProtocolsHandlerUpgrErr<TUpgradeError>>);
141
142    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
143        match self.timeout.poll_unpin(cx) {
144            Poll::Ready(Ok(_)) => return Poll::Ready((
145                self.user_data.take().expect("Future not to be polled again once ready."),
146                Err(ProtocolsHandlerUpgrErr::Timeout)),
147            ),
148            Poll::Ready(Err(_)) => return Poll::Ready((
149                self.user_data.take().expect("Future not to be polled again once ready."),
150                Err(ProtocolsHandlerUpgrErr::Timer)),
151            ),
152            Poll::Pending => {},
153        }
154
155        match self.upgrade.poll_unpin(cx) {
156            Poll::Ready(Ok(upgrade)) => Poll::Ready((
157                self.user_data.take().expect("Future not to be polled again once ready."),
158                Ok(upgrade),
159            )),
160            Poll::Ready(Err(err)) => Poll::Ready((
161                self.user_data.take().expect("Future not to be polled again once ready."),
162                Err(ProtocolsHandlerUpgrErr::Upgrade(err)),
163            )),
164            Poll::Pending => Poll::Pending,
165        }
166    }
167}
168
169
/// The options for a planned connection & handler shutdown.
///
/// A shutdown is planned anew based on the return value of
/// [`ProtocolsHandler::connection_keep_alive`] of the underlying handler
/// after every invocation of [`ProtocolsHandler::poll`].
///
/// A planned shutdown is always postponed for as long as there are ingoing
/// or outgoing substreams being negotiated, i.e. it is a graceful, "idle"
/// shutdown.
enum Shutdown {
    /// No shutdown is planned.
    None,
    /// A shut down is planned as soon as possible.
    Asap,
    /// A shut down is planned for when a `Delay` has elapsed.
    /// The `Instant` records the deadline so the timer is only reset when
    /// the keep-alive deadline actually changes.
    Later(Delay, Instant)
}
187
/// Error generated by the `NodeHandlerWrapper`.
#[derive(Debug)]
pub enum NodeHandlerWrapperError<TErr> {
    /// The connection handler encountered an error
    /// (a `ProtocolsHandlerEvent::Close` from the wrapped handler).
    Handler(TErr),
    /// The connection keep-alive timeout expired, i.e. the planned shutdown
    /// was carried out while no substreams were being negotiated.
    KeepAliveTimeout,
}
196
197impl<TErr> From<TErr> for NodeHandlerWrapperError<TErr> {
198    fn from(err: TErr) -> NodeHandlerWrapperError<TErr> {
199        NodeHandlerWrapperError::Handler(err)
200    }
201}
202
203impl<TErr> fmt::Display for NodeHandlerWrapperError<TErr>
204where
205    TErr: fmt::Display
206{
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        match self {
209            NodeHandlerWrapperError::Handler(err) => write!(f, "{}", err),
210            NodeHandlerWrapperError::KeepAliveTimeout =>
211                write!(f, "Connection closed due to expired keep-alive timeout."),
212        }
213    }
214}
215
216impl<TErr> error::Error for NodeHandlerWrapperError<TErr>
217where
218    TErr: error::Error + 'static
219{
220    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
221        match self {
222            NodeHandlerWrapperError::Handler(err) => Some(err),
223            NodeHandlerWrapperError::KeepAliveTimeout => None,
224        }
225    }
226}
227
impl<TProtoHandler> ConnectionHandler for NodeHandlerWrapper<TProtoHandler>
where
    TProtoHandler: ProtocolsHandler,
{
    type InEvent = TProtoHandler::InEvent;
    type OutEvent = TProtoHandler::OutEvent;
    type Error = NodeHandlerWrapperError<TProtoHandler::Error>;
    type Substream = Substream<StreamMuxerBox>;
    // The first element of the tuple is the unique upgrade identifier
    // (see `unique_dial_upgrade_id`); the `Duration` is the negotiation timeout.
    type OutboundOpenInfo = (u64, TProtoHandler::OutboundOpenInfo, Duration);

    /// Starts negotiating a newly opened substream.
    ///
    /// Inbound substreams are upgraded with the handler's current listen
    /// protocol; outbound substreams are matched (by upgrade id) against the
    /// entry queued in `poll` and upgraded with it. In both cases the pending
    /// negotiation, together with its timeout, is pushed onto the respective
    /// `FuturesUnordered` set, which is drained in `poll`.
    fn inject_substream(
        &mut self,
        substream: Self::Substream,
        endpoint: SubstreamEndpoint<Self::OutboundOpenInfo>,
    ) {
        match endpoint {
            SubstreamEndpoint::Listener => {
                let protocol = self.handler.listen_protocol();
                let timeout = *protocol.timeout();
                let (_, upgrade, user_data) = protocol.into_upgrade();
                let upgrade = upgrade::apply_inbound(substream, SendWrapper(upgrade));
                let timeout = Delay::new(timeout);
                self.negotiating_in.push(SubstreamUpgrade {
                    user_data: Some(user_data),
                    timeout,
                    upgrade,
                });
            }
            SubstreamEndpoint::Dialer((upgrade_id, user_data, timeout)) => {
                // Find the upgrade queued for this id in `poll`. An unknown id
                // indicates a logic error; in release builds the substream is
                // silently dropped.
                let pos = match self
                    .queued_dial_upgrades
                    .iter()
                    .position(|(id, _)| id == &upgrade_id)
                {
                    Some(p) => p,
                    None => {
                        debug_assert!(false, "Received an upgrade with an invalid upgrade ID");
                        return;
                    }
                };

                // The configured override replaces the protocol version chosen
                // by the handler, if the two differ.
                let (_, (mut version, upgrade)) = self.queued_dial_upgrades.remove(pos);
                if let Some(v) = self.substream_upgrade_protocol_override {
                    if v != version {
                        log::debug!("Substream upgrade protocol override: {:?} -> {:?}", version, v);
                        version = v;
                    }
                }
                let upgrade = upgrade::apply_outbound(substream, upgrade, version);
                let timeout = Delay::new(timeout);
                self.negotiating_out.push(SubstreamUpgrade {
                    user_data: Some(user_data),
                    timeout,
                    upgrade,
                });
            }
        }
    }

    /// Forwards an event to the underlying `ProtocolsHandler`.
    fn inject_event(&mut self, event: Self::InEvent) {
        self.handler.inject_event(event);
    }

    /// Forwards a remote-address change to the underlying `ProtocolsHandler`.
    fn inject_address_change(&mut self, new_address: &Multiaddr) {
        self.handler.inject_address_change(new_address);
    }

    /// Drives the handler: finished substream negotiations are delivered to
    /// the underlying handler first, then the handler itself is polled, the
    /// keep-alive/shutdown plan is refreshed, and finally — if nothing else is
    /// pending — a due shutdown is carried out by returning
    /// `Err(KeepAliveTimeout)`. The ordering of these steps is significant.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<
        Result<ConnectionHandlerEvent<Self::OutboundOpenInfo, Self::OutEvent>, Self::Error>
    > {
        // Deliver every completed (or failed/timed-out) inbound negotiation.
        while let Poll::Ready(Some((user_data, res))) = self.negotiating_in.poll_next_unpin(cx) {
            match res {
                Ok(upgrade) => self.handler.inject_fully_negotiated_inbound(upgrade, user_data),
                Err(err) => self.handler.inject_listen_upgrade_error(user_data, err),
            }
        }

        // Likewise for outbound negotiations.
        while let Poll::Ready(Some((user_data, res))) = self.negotiating_out.poll_next_unpin(cx) {
            match res {
                Ok(upgrade) => self.handler.inject_fully_negotiated_outbound(upgrade, user_data),
                Err(err) => self.handler.inject_dial_upgrade_error(user_data, err),
            }
        }

        // Poll the handler at the end so that we see the consequences of the method
        // calls on `self.handler`.
        let poll_result = self.handler.poll(cx);

        // Ask the handler whether it wants the connection (and the handler itself)
        // to be kept alive, which determines the planned shutdown, if any.
        // An unchanged `Until` deadline leaves the existing timer untouched.
        match (&mut self.shutdown, self.handler.connection_keep_alive()) {
            (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) =>
                if *deadline != t {
                    *deadline = t;
                    timer.reset_at(t)
                },
            (_, KeepAlive::Until(t)) => self.shutdown = Shutdown::Later(Delay::new_at(t), t),
            (_, KeepAlive::No) => self.shutdown = Shutdown::Asap,
            (_, KeepAlive::Yes) => self.shutdown = Shutdown::None
        };

        // Translate the handler's event, if any. An outbound substream request
        // gets a fresh upgrade id and its upgrade is queued until
        // `inject_substream` is called with the opened substream.
        match poll_result {
            Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => {
                return Poll::Ready(Ok(ConnectionHandlerEvent::Custom(event)));
            }
            Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                let id = self.unique_dial_upgrade_id;
                let timeout = *protocol.timeout();
                self.unique_dial_upgrade_id += 1;
                let (version, upgrade, info) = protocol.into_upgrade();
                self.queued_dial_upgrades.push((id, (version, SendWrapper(upgrade))));
                return Poll::Ready(Ok(
                    ConnectionHandlerEvent::OutboundSubstreamRequest((id, info, timeout)),
                ));
            }
            Poll::Ready(ProtocolsHandlerEvent::Close(err)) => return Poll::Ready(Err(err.into())),
            Poll::Pending => (),
        };

        // Check if the connection (and handler) should be shut down.
        // As long as we're still negotiating substreams, shutdown is always postponed.
        // A carried-out shutdown surfaces as a `KeepAliveTimeout` error.
        if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() {
            match self.shutdown {
                Shutdown::None => {},
                Shutdown::Asap => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)),
                Shutdown::Later(ref mut delay, _) => match Future::poll(Pin::new(delay), cx) {
                    Poll::Ready(_) => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)),
                    Poll::Pending => {}
                }
            }
        }

        Poll::Pending
    }
}
364}