libp2p_swarm/protocols_handler/
multi.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! A [`ProtocolsHandler`] implementation that combines multiple other `ProtocolsHandler`s
22//! indexed by some key.
23
24use crate::NegotiatedSubstream;
25use crate::protocols_handler::{
26    KeepAlive,
27    IntoProtocolsHandler,
28    ProtocolsHandler,
29    ProtocolsHandlerEvent,
30    ProtocolsHandlerUpgrErr,
31    SubstreamProtocol
32};
33use crate::upgrade::{
34    InboundUpgradeSend,
35    OutboundUpgradeSend,
36    UpgradeInfoSend
37};
38use futures::{future::BoxFuture, prelude::*};
39use libp2p_core::{ConnectedPoint, Multiaddr, PeerId};
40use libp2p_core::upgrade::{ProtocolName, UpgradeError, NegotiationError, ProtocolError};
41use rand::Rng;
42use std::{
43    cmp,
44    collections::{HashMap, HashSet},
45    error,
46    fmt,
47    hash::Hash,
48    iter::{self, FromIterator},
49    task::{Context, Poll},
50    time::Duration
51};
52
/// A [`ProtocolsHandler`] that owns several `ProtocolsHandler`s of one type, each
/// stored under its own key.
#[derive(Clone)]
pub struct MultiHandler<K, H> {
    handlers: HashMap<K, H>
}

impl<K, H> fmt::Debug for MultiHandler<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug
{
    /// Prints the wrapped key/handler map.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MultiHandler");
        builder.field("handlers", &self.handlers);
        builder.finish()
    }
}
70
71impl<K, H> MultiHandler<K, H>
72where
73    K: Hash + Eq,
74    H: ProtocolsHandler
75{
76    /// Create and populate a `MultiHandler` from the given handler iterator.
77    ///
78    /// It is an error for any two protocols handlers to share the same protocol name.
79    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
80    where
81        I: IntoIterator<Item = (K, H)>
82    {
83        let m = MultiHandler { handlers: HashMap::from_iter(iter) };
84        uniq_proto_names(m.handlers.values().map(|h| h.listen_protocol().into_upgrade().0))?;
85        Ok(m)
86    }
87}
88
89impl<K, H> ProtocolsHandler for MultiHandler<K, H>
90where
91    K: Clone + Hash + Eq + Send + 'static,
92    H: ProtocolsHandler,
93    H::InboundProtocol: InboundUpgradeSend,
94    H::OutboundProtocol: OutboundUpgradeSend
95{
96    type InEvent = (K, <H as ProtocolsHandler>::InEvent);
97    type OutEvent = (K, <H as ProtocolsHandler>::OutEvent);
98    type Error = <H as ProtocolsHandler>::Error;
99    type InboundProtocol = Upgrade<K, <H as ProtocolsHandler>::InboundProtocol>;
100    type OutboundProtocol = <H as ProtocolsHandler>::OutboundProtocol;
101    type InboundOpenInfo = Info<K, <H as ProtocolsHandler>::InboundOpenInfo>;
102    type OutboundOpenInfo = (K, <H as ProtocolsHandler>::OutboundOpenInfo);
103
104    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
105        let (upgrade, info, timeout) = self.handlers.iter()
106            .map(|(key, handler)| {
107                let proto = handler.listen_protocol();
108                let timeout = *proto.timeout();
109                let (upgrade, info) = proto.into_upgrade();
110                (key.clone(), (upgrade, info, timeout))
111            })
112            .fold((Upgrade::new(), Info::new(), Duration::from_secs(0)),
113                |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
114                    upg.upgrades.push((k.clone(), u));
115                    inf.infos.push((k, i));
116                    timeout = cmp::max(timeout, t);
117                    (upg, inf, timeout)
118                }
119            );
120        SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
121    }
122
123    fn inject_fully_negotiated_outbound (
124        &mut self,
125        protocol: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
126        (key, arg): Self::OutboundOpenInfo
127    ) {
128        if let Some(h) = self.handlers.get_mut(&key) {
129            h.inject_fully_negotiated_outbound(protocol, arg)
130        } else {
131            log::error!("inject_fully_negotiated_outbound: no handler for key")
132        }
133    }
134
135    fn inject_fully_negotiated_inbound (
136        &mut self,
137        (key, arg): <Self::InboundProtocol as InboundUpgradeSend>::Output,
138        mut info: Self::InboundOpenInfo
139    ) {
140        if let Some(h) = self.handlers.get_mut(&key) {
141            if let Some(i) = info.take(&key) {
142                h.inject_fully_negotiated_inbound(arg, i)
143            }
144        } else {
145            log::error!("inject_fully_negotiated_inbound: no handler for key")
146        }
147    }
148
149    fn inject_event(&mut self, (key, event): Self::InEvent) {
150        if let Some(h) = self.handlers.get_mut(&key) {
151            h.inject_event(event)
152        } else {
153            log::error!("inject_event: no handler for key")
154        }
155    }
156
157    fn inject_address_change(&mut self, addr: &Multiaddr) {
158        for h in self.handlers.values_mut() {
159            h.inject_address_change(addr)
160        }
161    }
162
163    fn inject_dial_upgrade_error (
164        &mut self,
165        (key, arg): Self::OutboundOpenInfo,
166        error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>
167    ) {
168        if let Some(h) = self.handlers.get_mut(&key) {
169            h.inject_dial_upgrade_error(arg, error)
170        } else {
171            log::error!("inject_dial_upgrade_error: no handler for protocol")
172        }
173    }
174
175    fn inject_listen_upgrade_error(
176        &mut self,
177        mut info: Self::InboundOpenInfo,
178        error: ProtocolsHandlerUpgrErr<<Self::InboundProtocol as InboundUpgradeSend>::Error>
179    ) {
180        match error {
181            ProtocolsHandlerUpgrErr::Timer =>
182                for (k, h) in &mut self.handlers {
183                    if let Some(i) = info.take(k) {
184                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timer)
185                    }
186                }
187            ProtocolsHandlerUpgrErr::Timeout =>
188                for (k, h) in &mut self.handlers {
189                    if let Some(i) = info.take(k) {
190                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timeout)
191                    }
192                }
193            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) =>
194                for (k, h) in &mut self.handlers {
195                    if let Some(i) = info.take(k) {
196                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)))
197                    }
198                }
199            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) =>
200                match e {
201                    ProtocolError::IoError(e) =>
202                        for (k, h) in &mut self.handlers {
203                            if let Some(i) = info.take(k) {
204                                let e = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into()));
205                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
206                            }
207                        }
208                    ProtocolError::InvalidMessage =>
209                        for (k, h) in &mut self.handlers {
210                            if let Some(i) = info.take(k) {
211                                let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage);
212                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
213                            }
214                        }
215                    ProtocolError::InvalidProtocol =>
216                        for (k, h) in &mut self.handlers {
217                            if let Some(i) = info.take(k) {
218                                let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol);
219                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
220                            }
221                        }
222                    ProtocolError::TooManyProtocols =>
223                        for (k, h) in &mut self.handlers {
224                            if let Some(i) = info.take(k) {
225                                let e = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols);
226                                h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
227                            }
228                        }
229                }
230            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) =>
231                if let Some(h) = self.handlers.get_mut(&k) {
232                    if let Some(i) = info.take(&k) {
233                        h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)))
234                    }
235                }
236        }
237    }
238
239    fn connection_keep_alive(&self) -> KeepAlive {
240        self.handlers.values()
241            .map(|h| h.connection_keep_alive())
242            .max()
243            .unwrap_or(KeepAlive::No)
244    }
245
246    fn poll(&mut self, cx: &mut Context<'_>)
247        -> Poll<ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>>
248    {
249        // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid
250        // that situation.
251        if self.handlers.is_empty() {
252            return Poll::Pending;
253        }
254
255        // Not always polling handlers in the same order should give anyone the chance to make progress.
256        let pos = rand::thread_rng().gen_range(0, self.handlers.len());
257
258        for (k, h) in self.handlers.iter_mut().skip(pos) {
259            if let Poll::Ready(e) = h.poll(cx) {
260                let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
261                return Poll::Ready(e)
262            }
263        }
264
265        for (k, h) in self.handlers.iter_mut().take(pos) {
266            if let Poll::Ready(e) = h.poll(cx) {
267                let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
268                return Poll::Ready(e)
269            }
270        }
271
272        Poll::Pending
273    }
274}
275
/// An [`IntoProtocolsHandler`] that holds several other `IntoProtocolsHandler`s,
/// each stored under its own key.
#[derive(Clone)]
pub struct IntoMultiHandler<K, H> {
    handlers: HashMap<K, H>
}

impl<K, H> fmt::Debug for IntoMultiHandler<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug
{
    /// Prints the wrapped key/handler map.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("IntoMultiHandler");
        builder.field("handlers", &self.handlers);
        builder.finish()
    }
}
293
294
295impl<K, H> IntoMultiHandler<K, H>
296where
297    K: Hash + Eq,
298    H: IntoProtocolsHandler
299{
300    /// Create and populate an `IntoMultiHandler` from the given iterator.
301    ///
302    /// It is an error for any two protocols handlers to share the same protocol name.
303    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
304    where
305        I: IntoIterator<Item = (K, H)>
306    {
307        let m = IntoMultiHandler { handlers: HashMap::from_iter(iter) };
308        uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?;
309        Ok(m)
310    }
311}
312
313impl<K, H> IntoProtocolsHandler for IntoMultiHandler<K, H>
314where
315    K: Clone + Eq + Hash + Send + 'static,
316    H: IntoProtocolsHandler
317{
318    type Handler = MultiHandler<K, H::Handler>;
319
320    fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler {
321        MultiHandler {
322            handlers: self.handlers.into_iter()
323                .map(|(k, h)| (k, h.into_handler(p, c)))
324                .collect()
325        }
326    }
327
328    fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol {
329        Upgrade {
330            upgrades: self.handlers.iter()
331                .map(|(k, h)| (k.clone(), h.inbound_protocol()))
332                .collect()
333        }
334    }
335}
336
337/// Index and protocol name pair used as `UpgradeInfo::Info`.
338#[derive(Debug, Clone)]
339pub struct IndexedProtoName<H>(usize, H);
340
341impl<H: ProtocolName> ProtocolName for IndexedProtoName<H> {
342    fn protocol_name(&self) -> &[u8] {
343        self.1.protocol_name()
344    }
345}
346
/// The `InboundOpenInfo`s of all supported inbound substream protocols, each
/// stored next to the key of the handler it belongs to.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty collection.
    fn new() -> Self {
        Self { infos: Vec::new() }
    }

    /// Removes and returns the first info stored under `k`, or `None` if there is
    /// none. The remaining entries keep their order.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|entry| &entry.0 == k)?;
        let (_, info) = self.infos.remove(index);
        Some(info)
    }
}
365
/// An upgrade that bundles the inbound or outbound upgrades of several
/// `ProtocolsHandler`s, each stored next to its key.
#[derive(Clone)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>
}

impl<K, H> Upgrade<K, H> {
    /// Creates an upgrade that does not wrap anything yet.
    fn new() -> Self {
        Self { upgrades: Vec::new() }
    }
}

impl<K, H> fmt::Debug for Upgrade<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug
{
    /// Prints the wrapped `(key, upgrade)` list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Upgrade");
        builder.field("upgrades", &self.upgrades);
        builder.finish()
    }
}
389
390impl<K, H> UpgradeInfoSend for Upgrade<K, H>
391where
392    H: UpgradeInfoSend,
393    K: Send + 'static
394{
395    type Info = IndexedProtoName<H::Info>;
396    type InfoIter = std::vec::IntoIter<Self::Info>;
397
398    fn protocol_info(&self) -> Self::InfoIter {
399        self.upgrades.iter().enumerate()
400            .map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
401            .flatten()
402            .map(|(i, h)| IndexedProtoName(i, h))
403            .collect::<Vec<_>>()
404            .into_iter()
405    }
406}
407
408impl<K, H> InboundUpgradeSend for Upgrade<K, H>
409where
410    H: InboundUpgradeSend,
411    K: Send + 'static
412{
413    type Output = (K, <H as InboundUpgradeSend>::Output);
414    type Error  = (K, <H as InboundUpgradeSend>::Error);
415    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
416
417    fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
418        let IndexedProtoName(index, info) = info;
419        let (key, upgrade) = self.upgrades.remove(index);
420        upgrade.upgrade_inbound(resource, info)
421            .map(move |out| {
422                match out {
423                    Ok(o) => Ok((key, o)),
424                    Err(e) => Err((key, e))
425                }
426            })
427            .boxed()
428    }
429}
430
431impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
432where
433    H: OutboundUpgradeSend,
434    K: Send + 'static
435{
436    type Output = (K, <H as OutboundUpgradeSend>::Output);
437    type Error  = (K, <H as OutboundUpgradeSend>::Error);
438    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
439
440    fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
441        let IndexedProtoName(index, info) = info;
442        let (key, upgrade) = self.upgrades.remove(index);
443        upgrade.upgrade_outbound(resource, info)
444            .map(move |out| {
445                match out {
446                    Ok(o) => Ok((key, o)),
447                    Err(e) => Err((key, e))
448                }
449            })
450            .boxed()
451    }
452}
453
454/// Check that no two protocol names are equal.
455fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
456where
457    I: Iterator<Item = T>,
458    T: UpgradeInfoSend
459{
460    let mut set = HashSet::new();
461    for infos in iter {
462        for i in infos.protocol_info() {
463            let v = Vec::from(i.protocol_name());
464            if set.contains(&v) {
465                return Err(DuplicateProtonameError(v))
466            } else {
467                set.insert(v);
468            }
469        }
470    }
471    Ok(())
472}
473
/// Error returned when two handlers share the same protocol name.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The protocol name bytes that occurred more than once.
    pub fn protocol_name(&self) -> &[u8] {
        &self.0
    }
}

impl fmt::Display for DuplicateProtonameError {
    /// Shows the duplicate name as text when it is valid UTF-8. Otherwise shows
    /// the debug form of its raw bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Ok(s) = std::str::from_utf8(&self.0) {
            write!(f, "duplicate protocol name: {}", s)
        } else {
            write!(f, "duplicate protocol name: {:?}", self.0)
        }
    }
}

impl error::Error for DuplicateProtonameError {}