cs_mwc_libp2p_swarm/protocols_handler/
multi.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! A [`ProtocolsHandler`] implementation that combines multiple other `ProtocolsHandler`s
22//! indexed by some key.
23
24use crate::NegotiatedSubstream;
25use crate::protocols_handler::{
26    KeepAlive,
27    IntoProtocolsHandler,
28    ProtocolsHandler,
29    ProtocolsHandlerEvent,
30    ProtocolsHandlerUpgrErr,
31    SubstreamProtocol
32};
33use crate::upgrade::{
34    InboundUpgradeSend,
35    OutboundUpgradeSend,
36    UpgradeInfoSend
37};
38use futures::{future::BoxFuture, prelude::*};
39use mwc_libp2p_core::{ConnectedPoint, Multiaddr, PeerId};
40use mwc_libp2p_core::upgrade::{self, ProtocolName, UpgradeError, NegotiationError, ProtocolError};
41use rand::Rng;
42use std::{
43    cmp,
44    collections::{HashMap, HashSet},
45    error,
46    fmt,
47    hash::Hash,
48    iter::{self, FromIterator},
49    task::{Context, Poll},
50    time::Duration
51};
52
/// A [`ProtocolsHandler`] for multiple `ProtocolsHandler`s of the same type.
///
/// Every inner handler is stored under a key of type `K`; events and
/// substreams are routed to the inner handler registered under the key they
/// carry.
#[derive(Clone)]
pub struct MultiHandler<K, H> {
    handlers: HashMap<K, H>
}

// Only `Debug` is required on the key and handler types: formatting a
// `HashMap` never hashes or compares its keys.
impl<K, H> fmt::Debug for MultiHandler<K, H>
where
    K: fmt::Debug,
    H: fmt::Debug
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MultiHandler")
            .field("handlers", &self.handlers)
            .finish()
    }
}
70
71impl<K, H> MultiHandler<K, H>
72where
73    K: Hash + Eq,
74    H: ProtocolsHandler
75{
76    /// Create and populate a `MultiHandler` from the given handler iterator.
77    ///
78    /// It is an error for any two protocols handlers to share the same protocol name.
79    ///
80    /// > **Note**: All handlers should use the same [`upgrade::Version`] for
81    /// > the inbound and outbound [`SubstreamProtocol`]s.
82    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
83    where
84        I: IntoIterator<Item = (K, H)>
85    {
86        let m = MultiHandler { handlers: HashMap::from_iter(iter) };
87        uniq_proto_names(m.handlers.values().map(|h| h.listen_protocol().into_upgrade().1))?;
88        Ok(m)
89    }
90}
91
92impl<K, H> ProtocolsHandler for MultiHandler<K, H>
93where
94    K: Clone + Hash + Eq + Send + 'static,
95    H: ProtocolsHandler,
96    H::InboundProtocol: InboundUpgradeSend,
97    H::OutboundProtocol: OutboundUpgradeSend
98{
99    type InEvent = (K, <H as ProtocolsHandler>::InEvent);
100    type OutEvent = (K, <H as ProtocolsHandler>::OutEvent);
101    type Error = <H as ProtocolsHandler>::Error;
102    type InboundProtocol = Upgrade<K, <H as ProtocolsHandler>::InboundProtocol>;
103    type OutboundProtocol = <H as ProtocolsHandler>::OutboundProtocol;
104    type InboundOpenInfo = Info<K, <H as ProtocolsHandler>::InboundOpenInfo>;
105    type OutboundOpenInfo = (K, <H as ProtocolsHandler>::OutboundOpenInfo);
106
107    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
108        let (upgrade, info, timeout, version) = self.handlers.iter()
109            .map(|(key, handler)| {
110                let proto = handler.listen_protocol();
111                let timeout = *proto.timeout();
112                let (version, upgrade, info) = proto.into_upgrade();
113                (key.clone(), (version, upgrade, info, timeout))
114            })
115            .fold((Upgrade::new(), Info::new(), Duration::from_secs(0), None),
116                |(mut upg, mut inf, mut timeout, mut version), (k, (v, u, i, t))| {
117                    upg.upgrades.push((k.clone(), u));
118                    inf.infos.push((k, i));
119                    timeout = cmp::max(timeout, t);
120                    version = version.map_or(Some(v), |vv|
121                        if v != vv {
122                            // Different upgrade (i.e. protocol negotiation) protocol
123                            // versions are usually incompatible and not negotiated
124                            // themselves, so a protocol upgrade may fail.
125                            log::warn!("Differing upgrade versions. Defaulting to V1.");
126                            Some(upgrade::Version::V1)
127                        } else {
128                            Some(v)
129                        });
130                    (upg, inf, timeout, version)
131                }
132            );
133        SubstreamProtocol::new(upgrade, info)
134            .with_timeout(timeout)
135            .with_upgrade_protocol(version.unwrap_or(upgrade::Version::V1))
136    }
137
138    fn inject_fully_negotiated_outbound (
139        &mut self,
140        protocol: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
141        (key, arg): Self::OutboundOpenInfo
142    ) {
143        if let Some(h) = self.handlers.get_mut(&key) {
144            h.inject_fully_negotiated_outbound(protocol, arg)
145        } else {
146            log::error!("inject_fully_negotiated_outbound: no handler for key")
147        }
148    }
149
150    fn inject_fully_negotiated_inbound (
151        &mut self,
152        (key, arg): <Self::InboundProtocol as InboundUpgradeSend>::Output,
153        mut info: Self::InboundOpenInfo
154    ) {
155        if let Some(h) = self.handlers.get_mut(&key) {
156            if let Some(i) = info.take(&key) {
157                h.inject_fully_negotiated_inbound(arg, i)
158            }
159        } else {
160            log::error!("inject_fully_negotiated_inbound: no handler for key")
161        }
162    }
163
164    fn inject_event(&mut self, (key, event): Self::InEvent) {
165        if let Some(h) = self.handlers.get_mut(&key) {
166            h.inject_event(event)
167        } else {
168            log::error!("inject_event: no handler for key")
169        }
170    }
171
172    fn inject_address_change(&mut self, addr: &Multiaddr) {
173        for h in self.handlers.values_mut() {
174            h.inject_address_change(addr)
175        }
176    }
177
178    fn inject_dial_upgrade_error (
179        &mut self,
180        (key, arg): Self::OutboundOpenInfo,
181        error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>
182    ) {
183        if let Some(h) = self.handlers.get_mut(&key) {
184            h.inject_dial_upgrade_error(arg, error)
185        } else {
186            log::error!("inject_dial_upgrade_error: no handler for protocol")
187        }
188    }
189
190    fn inject_listen_upgrade_error(
191        &mut self,
192        mut info: Self::InboundOpenInfo,
193        error: ProtocolsHandlerUpgrErr<<Self::InboundProtocol as InboundUpgradeSend>::Error>
194    ) {
195        match error {
196            ProtocolsHandlerUpgrErr::Timer =>
197                for (k, h) in &mut self.handlers {
198                    if let Some(i) = info.take(k) {
199                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timer)
200                    }
201                }
202            ProtocolsHandlerUpgrErr::Timeout =>
203                for (k, h) in &mut self.handlers {
204                    if let Some(i) = info.take(k) {
205                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timeout)
206                    }
207                }
208            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) =>
209                for (k, h) in &mut self.handlers {
210                    if let Some(i) = info.take(k) {
211                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)))
212                    }
213                }
214            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) =>
215                match e {
216                    ProtocolError::IoError(e) =>
217                        for (k, h) in &mut self.handlers {
218                            if let Some(i) = info.take(k) {
219                                let e = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into()));
220                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
221                            }
222                        }
223                    ProtocolError::InvalidMessage =>
224                        for (k, h) in &mut self.handlers {
225                            if let Some(i) = info.take(k) {
226                                let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage);
227                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
228                            }
229                        }
230                    ProtocolError::InvalidProtocol =>
231                        for (k, h) in &mut self.handlers {
232                            if let Some(i) = info.take(k) {
233                                let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol);
234                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
235                            }
236                        }
237                    ProtocolError::TooManyProtocols =>
238                        for (k, h) in &mut self.handlers {
239                            if let Some(i) = info.take(k) {
240                                let e = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols);
241                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
242                            }
243                        }
244                }
245            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) =>
246                if let Some(h) = self.handlers.get_mut(&k) {
247                    if let Some(i) = info.take(&k) {
248                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)))
249                    }
250                }
251        }
252    }
253
254    fn connection_keep_alive(&self) -> KeepAlive {
255        self.handlers.values()
256            .map(|h| h.connection_keep_alive())
257            .max()
258            .unwrap_or(KeepAlive::No)
259    }
260
261    fn poll(&mut self, cx: &mut Context<'_>)
262        -> Poll<ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>>
263    {
264        // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid
265        // that situation.
266        if self.handlers.is_empty() {
267            return Poll::Pending;
268        }
269
270        // Not always polling handlers in the same order should give anyone the chance to make progress.
271        let pos = rand::thread_rng().gen_range(0, self.handlers.len());
272
273        for (k, h) in self.handlers.iter_mut().skip(pos) {
274            if let Poll::Ready(e) = h.poll(cx) {
275                let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
276                return Poll::Ready(e)
277            }
278        }
279
280        for (k, h) in self.handlers.iter_mut().take(pos) {
281            if let Poll::Ready(e) = h.poll(cx) {
282                let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
283                return Poll::Ready(e)
284            }
285        }
286
287        Poll::Pending
288    }
289}
290
/// An [`IntoProtocolsHandler`] for multiple other `IntoProtocolsHandler`s.
///
/// Each inner `IntoProtocolsHandler` is stored under a key of type `K`. The
/// same key identifies the resulting handler inside the [`MultiHandler`].
#[derive(Clone)]
pub struct IntoMultiHandler<K, H> {
    handlers: HashMap<K, H>
}

// Only `Debug` is required on the key and handler types: formatting a
// `HashMap` never hashes or compares its keys.
impl<K, H> fmt::Debug for IntoMultiHandler<K, H>
where
    K: fmt::Debug,
    H: fmt::Debug
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IntoMultiHandler")
            .field("handlers", &self.handlers)
            .finish()
    }
}
308
309
310impl<K, H> IntoMultiHandler<K, H>
311where
312    K: Hash + Eq,
313    H: IntoProtocolsHandler
314{
315    /// Create and populate an `IntoMultiHandler` from the given iterator.
316    ///
317    /// It is an error for any two protocols handlers to share the same protocol name.
318    ///
319    /// > **Note**: All handlers should use the same [`upgrade::Version`] for
320    /// > the inbound and outbound [`SubstreamProtocol`]s.
321    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
322    where
323        I: IntoIterator<Item = (K, H)>
324    {
325        let m = IntoMultiHandler { handlers: HashMap::from_iter(iter) };
326        uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?;
327        Ok(m)
328    }
329}
330
331impl<K, H> IntoProtocolsHandler for IntoMultiHandler<K, H>
332where
333    K: Clone + Eq + Hash + Send + 'static,
334    H: IntoProtocolsHandler
335{
336    type Handler = MultiHandler<K, H::Handler>;
337
338    fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler {
339        MultiHandler {
340            handlers: self.handlers.into_iter()
341                .map(|(k, h)| (k, h.into_handler(p, c)))
342                .collect()
343        }
344    }
345
346    fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol {
347        Upgrade {
348            upgrades: self.handlers.iter()
349                .map(|(k, h)| (k.clone(), h.inbound_protocol()))
350                .collect()
351        }
352    }
353}
354
355/// Index and protocol name pair used as `UpgradeInfo::Info`.
356#[derive(Debug, Clone)]
357pub struct IndexedProtoName<H>(usize, H);
358
359impl<H: ProtocolName> ProtocolName for IndexedProtoName<H> {
360    fn protocol_name(&self) -> &[u8] {
361        self.1.protocol_name()
362    }
363}
364
/// The aggregated `InboundOpenInfo`s of supported inbound substream protocols,
/// stored as `(key, open info)` pairs.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>
}

impl<K, I> Info<K, I>
where
    K: Eq
{
    /// Creates an empty collection.
    fn new() -> Self {
        Info { infos: Vec::new() }
    }

    /// Removes and returns the first open info stored under `k`, or `None`
    /// when there is no such entry.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(key, _)| key == k)?;
        let (_, open_info) = self.infos.remove(index);
        Some(open_info)
    }
}
383
/// Inbound and outbound upgrade for all `ProtocolsHandler`s.
///
/// Stores one `(key, upgrade)` pair per inner handler, in insertion order.
#[derive(Clone)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>
}

impl<K, H> Upgrade<K, H> {
    /// Creates an upgrade with no inner upgrades.
    fn new() -> Self {
        Upgrade { upgrades: Vec::new() }
    }
}

// The pairs live in a `Vec`, so formatting only needs `Debug` on the key and
// upgrade types; no `Eq`/`Hash` bound is required.
impl<K, H> fmt::Debug for Upgrade<K, H>
where
    K: fmt::Debug,
    H: fmt::Debug
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Upgrade")
            .field("upgrades", &self.upgrades)
            .finish()
    }
}
407
408impl<K, H> UpgradeInfoSend for Upgrade<K, H>
409where
410    H: UpgradeInfoSend,
411    K: Send + 'static
412{
413    type Info = IndexedProtoName<H::Info>;
414    type InfoIter = std::vec::IntoIter<Self::Info>;
415
416    fn protocol_info(&self) -> Self::InfoIter {
417        self.upgrades.iter().enumerate()
418            .map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
419            .flatten()
420            .map(|(i, h)| IndexedProtoName(i, h))
421            .collect::<Vec<_>>()
422            .into_iter()
423    }
424}
425
426impl<K, H> InboundUpgradeSend for Upgrade<K, H>
427where
428    H: InboundUpgradeSend,
429    K: Send + 'static
430{
431    type Output = (K, <H as InboundUpgradeSend>::Output);
432    type Error  = (K, <H as InboundUpgradeSend>::Error);
433    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
434
435    fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
436        let IndexedProtoName(index, info) = info;
437        let (key, upgrade) = self.upgrades.remove(index);
438        upgrade.upgrade_inbound(resource, info)
439            .map(move |out| {
440                match out {
441                    Ok(o) => Ok((key, o)),
442                    Err(e) => Err((key, e))
443                }
444            })
445            .boxed()
446    }
447}
448
449impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
450where
451    H: OutboundUpgradeSend,
452    K: Send + 'static
453{
454    type Output = (K, <H as OutboundUpgradeSend>::Output);
455    type Error  = (K, <H as OutboundUpgradeSend>::Error);
456    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
457
458    fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
459        let IndexedProtoName(index, info) = info;
460        let (key, upgrade) = self.upgrades.remove(index);
461        upgrade.upgrade_outbound(resource, info)
462            .map(move |out| {
463                match out {
464                    Ok(o) => Ok((key, o)),
465                    Err(e) => Err((key, e))
466                }
467            })
468            .boxed()
469    }
470}
471
472/// Check that no two protocol names are equal.
473fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
474where
475    I: Iterator<Item = T>,
476    T: UpgradeInfoSend
477{
478    let mut set = HashSet::new();
479    for infos in iter {
480        for i in infos.protocol_info() {
481            let v = Vec::from(i.protocol_name());
482            if set.contains(&v) {
483                return Err(DuplicateProtonameError(v))
484            } else {
485                set.insert(v);
486            }
487        }
488    }
489    Ok(())
490}
491
/// Error returned when two handlers share the same protocol name.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The protocol name bytes that occurred in more than one handler.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl fmt::Display for DuplicateProtonameError {
    /// Shows the name as text when it is valid UTF-8, otherwise as a list of
    /// byte values.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(name) => write!(f, "duplicate protocol name: {}", name),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}