cs_mwc_libp2p_swarm/protocols_handler/select.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::upgrade::{SendWrapper, InboundUpgradeSend, OutboundUpgradeSend};
22use crate::protocols_handler::{
23    KeepAlive,
24    SubstreamProtocol,
25    IntoProtocolsHandler,
26    ProtocolsHandler,
27    ProtocolsHandlerEvent,
28    ProtocolsHandlerUpgrErr,
29};
30
31use mwc_libp2p_core::{
32    ConnectedPoint,
33    Multiaddr,
34    PeerId,
35    either::{EitherError, EitherOutput},
36    upgrade::{EitherUpgrade, SelectUpgrade, UpgradeError, NegotiationError, ProtocolError}
37};
38use std::{cmp, task::Context, task::Poll};
39
/// Pairs two protocols so that they can later be turned into one combined
/// handler (see the `IntoProtocolsHandler` impl for this type).
#[derive(Clone, Debug)]
pub struct IntoProtocolsHandlerSelect<TProto1, TProto2> {
    /// Protocol occupying the first (left-hand) slot.
    proto1: TProto1,
    /// Protocol occupying the second (right-hand) slot.
    proto2: TProto2,
}

impl<TProto1, TProto2> IntoProtocolsHandlerSelect<TProto1, TProto2> {
    /// Creates the pair, storing `proto1` in the first slot and `proto2` in
    /// the second.
    pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self {
        Self { proto1, proto2 }
    }
}
58
59impl<TProto1, TProto2> IntoProtocolsHandler for IntoProtocolsHandlerSelect<TProto1, TProto2>
60where
61    TProto1: IntoProtocolsHandler,
62    TProto2: IntoProtocolsHandler,
63{
64    type Handler = ProtocolsHandlerSelect<TProto1::Handler, TProto2::Handler>;
65
66    fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler {
67        ProtocolsHandlerSelect {
68            proto1: self.proto1.into_handler(remote_peer_id, connected_point),
69            proto2: self.proto2.into_handler(remote_peer_id, connected_point),
70        }
71    }
72
73    fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol {
74        SelectUpgrade::new(SendWrapper(self.proto1.inbound_protocol()), SendWrapper(self.proto2.inbound_protocol()))
75    }
76}
77
/// A handler that runs two handlers side by side on one connection (see the
/// `ProtocolsHandler` impl for this type).
#[derive(Clone, Debug)]
pub struct ProtocolsHandlerSelect<TProto1, TProto2> {
    /// Handler occupying the first (left-hand) slot.
    proto1: TProto1,
    /// Handler occupying the second (right-hand) slot.
    proto2: TProto2,
}

impl<TProto1, TProto2> ProtocolsHandlerSelect<TProto1, TProto2> {
    /// Creates the combined handler, storing `proto1` in the first slot and
    /// `proto2` in the second.
    pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self {
        Self { proto1, proto2 }
    }
}
96
97impl<TProto1, TProto2> ProtocolsHandler for ProtocolsHandlerSelect<TProto1, TProto2>
98where
99    TProto1: ProtocolsHandler,
100    TProto2: ProtocolsHandler,
101{
102    type InEvent = EitherOutput<TProto1::InEvent, TProto2::InEvent>;
103    type OutEvent = EitherOutput<TProto1::OutEvent, TProto2::OutEvent>;
104    type Error = EitherError<TProto1::Error, TProto2::Error>;
105    type InboundProtocol = SelectUpgrade<SendWrapper<<TProto1 as ProtocolsHandler>::InboundProtocol>, SendWrapper<<TProto2 as ProtocolsHandler>::InboundProtocol>>;
106    type OutboundProtocol = EitherUpgrade<SendWrapper<TProto1::OutboundProtocol>, SendWrapper<TProto2::OutboundProtocol>>;
107    type OutboundOpenInfo = EitherOutput<TProto1::OutboundOpenInfo, TProto2::OutboundOpenInfo>;
108    type InboundOpenInfo = (TProto1::InboundOpenInfo, TProto2::InboundOpenInfo);
109
110    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
111        let proto1 = self.proto1.listen_protocol();
112        let proto2 = self.proto2.listen_protocol();
113        let timeout = *std::cmp::max(proto1.timeout(), proto2.timeout());
114        let (_, u1, i1) = proto1.into_upgrade();
115        let (_, u2, i2) = proto2.into_upgrade();
116        let choice = SelectUpgrade::new(SendWrapper(u1), SendWrapper(u2));
117        SubstreamProtocol::new(choice, (i1, i2)).with_timeout(timeout)
118    }
119
120    fn inject_fully_negotiated_outbound(&mut self, protocol: <Self::OutboundProtocol as OutboundUpgradeSend>::Output, endpoint: Self::OutboundOpenInfo) {
121        match (protocol, endpoint) {
122            (EitherOutput::First(protocol), EitherOutput::First(info)) =>
123                self.proto1.inject_fully_negotiated_outbound(protocol, info),
124            (EitherOutput::Second(protocol), EitherOutput::Second(info)) =>
125                self.proto2.inject_fully_negotiated_outbound(protocol, info),
126            (EitherOutput::First(_), EitherOutput::Second(_)) =>
127                panic!("wrong API usage: the protocol doesn't match the upgrade info"),
128            (EitherOutput::Second(_), EitherOutput::First(_)) =>
129                panic!("wrong API usage: the protocol doesn't match the upgrade info")
130        }
131    }
132
133    fn inject_fully_negotiated_inbound(&mut self, protocol: <Self::InboundProtocol as InboundUpgradeSend>::Output, (i1, i2): Self::InboundOpenInfo) {
134        match protocol {
135            EitherOutput::First(protocol) =>
136                self.proto1.inject_fully_negotiated_inbound(protocol, i1),
137            EitherOutput::Second(protocol) =>
138                self.proto2.inject_fully_negotiated_inbound(protocol, i2)
139        }
140    }
141
142    fn inject_event(&mut self, event: Self::InEvent) {
143        match event {
144            EitherOutput::First(event) => self.proto1.inject_event(event),
145            EitherOutput::Second(event) => self.proto2.inject_event(event),
146        }
147    }
148
149    fn inject_address_change(&mut self, new_address: &Multiaddr) {
150        self.proto1.inject_address_change(new_address);
151        self.proto2.inject_address_change(new_address)
152    }
153
154    fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>) {
155        match (info, error) {
156            (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timer) => {
157                self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer)
158            },
159            (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timeout) => {
160                self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout)
161            },
162            (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) => {
163                self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)))
164            },
165            (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(err)))) => {
166                self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)))
167            },
168            (EitherOutput::First(_), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(_)))) => {
169                panic!("Wrong API usage; the upgrade error doesn't match the outbound open info");
170            },
171            (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timeout) => {
172                self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout)
173            },
174            (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timer) => {
175                self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer)
176            },
177            (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) => {
178                self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)))
179            },
180            (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(err)))) => {
181                self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)))
182            },
183            (EitherOutput::Second(_), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(_)))) => {
184                panic!("Wrong API usage; the upgrade error doesn't match the outbound open info");
185            },
186        }
187    }
188
189    fn inject_listen_upgrade_error(&mut self, (i1, i2): Self::InboundOpenInfo, error: ProtocolsHandlerUpgrErr<<Self::InboundProtocol as InboundUpgradeSend>::Error>) {
190        match error {
191            ProtocolsHandlerUpgrErr::Timer => {
192                self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timer);
193                self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timer)
194            }
195            ProtocolsHandlerUpgrErr::Timeout => {
196                self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timeout);
197                self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timeout)
198            }
199            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
200                self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)));
201                self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)));
202            }
203            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) => {
204                let (e1, e2);
205                match e {
206                    ProtocolError::IoError(e) => {
207                        e1 = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into()));
208                        e2 = NegotiationError::ProtocolError(ProtocolError::IoError(e))
209                    }
210                    ProtocolError::InvalidMessage => {
211                        e1 = NegotiationError::ProtocolError(ProtocolError::InvalidMessage);
212                        e2 = NegotiationError::ProtocolError(ProtocolError::InvalidMessage)
213                    }
214                    ProtocolError::InvalidProtocol => {
215                        e1 = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol);
216                        e2 = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol)
217                    }
218                    ProtocolError::TooManyProtocols => {
219                        e1 = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols);
220                        e2 = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols)
221                    }
222                }
223                self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e1)));
224                self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e2)))
225            }
226            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(e))) => {
227                self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)))
228            }
229            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(e))) => {
230                self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)))
231            }
232        }
233    }
234
235    fn connection_keep_alive(&self) -> KeepAlive {
236        cmp::max(self.proto1.connection_keep_alive(), self.proto2.connection_keep_alive())
237    }
238
239    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>> {
240        match self.proto1.poll(cx) {
241            Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => {
242                return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::First(event)));
243            },
244            Poll::Ready(ProtocolsHandlerEvent::Close(event)) => {
245                return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::A(event)));
246            },
247            Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => {
248                return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest {
249                    protocol: protocol
250                        .map_upgrade(|u| EitherUpgrade::A(SendWrapper(u)))
251                        .map_info(EitherOutput::First)
252                });
253            },
254            Poll::Pending => ()
255        };
256
257        match self.proto2.poll(cx) {
258            Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => {
259                return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::Second(event)));
260            },
261            Poll::Ready(ProtocolsHandlerEvent::Close(event)) => {
262                return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::B(event)));
263            },
264            Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => {
265                return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest {
266                    protocol: protocol
267                        .map_upgrade(|u| EitherUpgrade::B(SendWrapper(u)))
268                        .map_info(EitherOutput::Second)
269                });
270            },
271            Poll::Pending => ()
272        };
273
274        Poll::Pending
275    }
276}