1use crate::NegotiatedSubstream;
25use crate::protocols_handler::{
26 KeepAlive,
27 IntoProtocolsHandler,
28 ProtocolsHandler,
29 ProtocolsHandlerEvent,
30 ProtocolsHandlerUpgrErr,
31 SubstreamProtocol
32};
33use crate::upgrade::{
34 InboundUpgradeSend,
35 OutboundUpgradeSend,
36 UpgradeInfoSend
37};
38use futures::{future::BoxFuture, prelude::*};
39use mwc_libp2p_core::{ConnectedPoint, Multiaddr, PeerId};
40use mwc_libp2p_core::upgrade::{self, ProtocolName, UpgradeError, NegotiationError, ProtocolError};
41use rand::Rng;
42use std::{
43 cmp,
44 collections::{HashMap, HashSet},
45 error,
46 fmt,
47 hash::Hash,
48 iter::{self, FromIterator},
49 task::{Context, Poll},
50 time::Duration
51};
52
/// A collection of protocols handlers of the same type `H`, each stored under a key `K`.
#[derive(Clone)]
pub struct MultiHandler<K, H> {
    /// The inner handlers, indexed by their key.
    handlers: HashMap<K, H>,
}

/// Formats the handler as a struct with a single `handlers` field listing the inner handlers.
impl<K, H> fmt::Debug for MultiHandler<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MultiHandler");
        out.field("handlers", &self.handlers);
        out.finish()
    }
}
70
71impl<K, H> MultiHandler<K, H>
72where
73 K: Hash + Eq,
74 H: ProtocolsHandler
75{
76 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
83 where
84 I: IntoIterator<Item = (K, H)>
85 {
86 let m = MultiHandler { handlers: HashMap::from_iter(iter) };
87 uniq_proto_names(m.handlers.values().map(|h| h.listen_protocol().into_upgrade().1))?;
88 Ok(m)
89 }
90}
91
92impl<K, H> ProtocolsHandler for MultiHandler<K, H>
93where
94 K: Clone + Hash + Eq + Send + 'static,
95 H: ProtocolsHandler,
96 H::InboundProtocol: InboundUpgradeSend,
97 H::OutboundProtocol: OutboundUpgradeSend
98{
99 type InEvent = (K, <H as ProtocolsHandler>::InEvent);
100 type OutEvent = (K, <H as ProtocolsHandler>::OutEvent);
101 type Error = <H as ProtocolsHandler>::Error;
102 type InboundProtocol = Upgrade<K, <H as ProtocolsHandler>::InboundProtocol>;
103 type OutboundProtocol = <H as ProtocolsHandler>::OutboundProtocol;
104 type InboundOpenInfo = Info<K, <H as ProtocolsHandler>::InboundOpenInfo>;
105 type OutboundOpenInfo = (K, <H as ProtocolsHandler>::OutboundOpenInfo);
106
107 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
108 let (upgrade, info, timeout, version) = self.handlers.iter()
109 .map(|(key, handler)| {
110 let proto = handler.listen_protocol();
111 let timeout = *proto.timeout();
112 let (version, upgrade, info) = proto.into_upgrade();
113 (key.clone(), (version, upgrade, info, timeout))
114 })
115 .fold((Upgrade::new(), Info::new(), Duration::from_secs(0), None),
116 |(mut upg, mut inf, mut timeout, mut version), (k, (v, u, i, t))| {
117 upg.upgrades.push((k.clone(), u));
118 inf.infos.push((k, i));
119 timeout = cmp::max(timeout, t);
120 version = version.map_or(Some(v), |vv|
121 if v != vv {
122 log::warn!("Differing upgrade versions. Defaulting to V1.");
126 Some(upgrade::Version::V1)
127 } else {
128 Some(v)
129 });
130 (upg, inf, timeout, version)
131 }
132 );
133 SubstreamProtocol::new(upgrade, info)
134 .with_timeout(timeout)
135 .with_upgrade_protocol(version.unwrap_or(upgrade::Version::V1))
136 }
137
138 fn inject_fully_negotiated_outbound (
139 &mut self,
140 protocol: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
141 (key, arg): Self::OutboundOpenInfo
142 ) {
143 if let Some(h) = self.handlers.get_mut(&key) {
144 h.inject_fully_negotiated_outbound(protocol, arg)
145 } else {
146 log::error!("inject_fully_negotiated_outbound: no handler for key")
147 }
148 }
149
150 fn inject_fully_negotiated_inbound (
151 &mut self,
152 (key, arg): <Self::InboundProtocol as InboundUpgradeSend>::Output,
153 mut info: Self::InboundOpenInfo
154 ) {
155 if let Some(h) = self.handlers.get_mut(&key) {
156 if let Some(i) = info.take(&key) {
157 h.inject_fully_negotiated_inbound(arg, i)
158 }
159 } else {
160 log::error!("inject_fully_negotiated_inbound: no handler for key")
161 }
162 }
163
164 fn inject_event(&mut self, (key, event): Self::InEvent) {
165 if let Some(h) = self.handlers.get_mut(&key) {
166 h.inject_event(event)
167 } else {
168 log::error!("inject_event: no handler for key")
169 }
170 }
171
172 fn inject_address_change(&mut self, addr: &Multiaddr) {
173 for h in self.handlers.values_mut() {
174 h.inject_address_change(addr)
175 }
176 }
177
178 fn inject_dial_upgrade_error (
179 &mut self,
180 (key, arg): Self::OutboundOpenInfo,
181 error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>
182 ) {
183 if let Some(h) = self.handlers.get_mut(&key) {
184 h.inject_dial_upgrade_error(arg, error)
185 } else {
186 log::error!("inject_dial_upgrade_error: no handler for protocol")
187 }
188 }
189
190 fn inject_listen_upgrade_error(
191 &mut self,
192 mut info: Self::InboundOpenInfo,
193 error: ProtocolsHandlerUpgrErr<<Self::InboundProtocol as InboundUpgradeSend>::Error>
194 ) {
195 match error {
196 ProtocolsHandlerUpgrErr::Timer =>
197 for (k, h) in &mut self.handlers {
198 if let Some(i) = info.take(k) {
199 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timer)
200 }
201 }
202 ProtocolsHandlerUpgrErr::Timeout =>
203 for (k, h) in &mut self.handlers {
204 if let Some(i) = info.take(k) {
205 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timeout)
206 }
207 }
208 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) =>
209 for (k, h) in &mut self.handlers {
210 if let Some(i) = info.take(k) {
211 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)))
212 }
213 }
214 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) =>
215 match e {
216 ProtocolError::IoError(e) =>
217 for (k, h) in &mut self.handlers {
218 if let Some(i) = info.take(k) {
219 let e = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into()));
220 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
221 }
222 }
223 ProtocolError::InvalidMessage =>
224 for (k, h) in &mut self.handlers {
225 if let Some(i) = info.take(k) {
226 let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage);
227 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
228 }
229 }
230 ProtocolError::InvalidProtocol =>
231 for (k, h) in &mut self.handlers {
232 if let Some(i) = info.take(k) {
233 let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol);
234 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
235 }
236 }
237 ProtocolError::TooManyProtocols =>
238 for (k, h) in &mut self.handlers {
239 if let Some(i) = info.take(k) {
240 let e = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols);
241 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)))
242 }
243 }
244 }
245 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) =>
246 if let Some(h) = self.handlers.get_mut(&k) {
247 if let Some(i) = info.take(&k) {
248 h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)))
249 }
250 }
251 }
252 }
253
254 fn connection_keep_alive(&self) -> KeepAlive {
255 self.handlers.values()
256 .map(|h| h.connection_keep_alive())
257 .max()
258 .unwrap_or(KeepAlive::No)
259 }
260
261 fn poll(&mut self, cx: &mut Context<'_>)
262 -> Poll<ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>>
263 {
264 if self.handlers.is_empty() {
267 return Poll::Pending;
268 }
269
270 let pos = rand::thread_rng().gen_range(0, self.handlers.len());
272
273 for (k, h) in self.handlers.iter_mut().skip(pos) {
274 if let Poll::Ready(e) = h.poll(cx) {
275 let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
276 return Poll::Ready(e)
277 }
278 }
279
280 for (k, h) in self.handlers.iter_mut().take(pos) {
281 if let Poll::Ready(e) = h.poll(cx) {
282 let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p));
283 return Poll::Ready(e)
284 }
285 }
286
287 Poll::Pending
288 }
289}
290
/// A keyed collection of `IntoProtocolsHandler`s of the same type `H`.
/// All of them are turned into their handlers together.
#[derive(Clone)]
pub struct IntoMultiHandler<K, H> {
    /// The inner handler builders, indexed by their key.
    handlers: HashMap<K, H>,
}

/// Formats the value as a struct with a single `handlers` field.
impl<K, H> fmt::Debug for IntoMultiHandler<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("IntoMultiHandler");
        out.field("handlers", &self.handlers);
        out.finish()
    }
}
308
309
310impl<K, H> IntoMultiHandler<K, H>
311where
312 K: Hash + Eq,
313 H: IntoProtocolsHandler
314{
315 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
322 where
323 I: IntoIterator<Item = (K, H)>
324 {
325 let m = IntoMultiHandler { handlers: HashMap::from_iter(iter) };
326 uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?;
327 Ok(m)
328 }
329}
330
331impl<K, H> IntoProtocolsHandler for IntoMultiHandler<K, H>
332where
333 K: Clone + Eq + Hash + Send + 'static,
334 H: IntoProtocolsHandler
335{
336 type Handler = MultiHandler<K, H::Handler>;
337
338 fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler {
339 MultiHandler {
340 handlers: self.handlers.into_iter()
341 .map(|(k, h)| (k, h.into_handler(p, c)))
342 .collect()
343 }
344 }
345
346 fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol {
347 Upgrade {
348 upgrades: self.handlers.iter()
349 .map(|(k, h)| (k.clone(), h.inbound_protocol()))
350 .collect()
351 }
352 }
353}
354
355#[derive(Debug, Clone)]
357pub struct IndexedProtoName<H>(usize, H);
358
359impl<H: ProtocolName> ProtocolName for IndexedProtoName<H> {
360 fn protocol_name(&self) -> &[u8] {
361 self.1.protocol_name()
362 }
363}
364
/// Inbound open infos of the inner handlers, each stored next to its handler's key.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty collection.
    fn new() -> Self {
        Self { infos: Vec::new() }
    }

    /// Removes and returns the first info stored under `k`, or `None` if there is none.
    ///
    /// The relative order of the remaining entries is preserved.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(key, _)| key == k)?;
        let (_, info) = self.infos.remove(index);
        Some(info)
    }
}
383
/// An upgrade combining the upgrades of several inner handlers.
/// Each inner upgrade is tagged with the key of the handler it belongs to.
#[derive(Clone)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates an upgrade with no inner upgrades.
    fn new() -> Self {
        Self { upgrades: Vec::new() }
    }
}

/// Formats the value as a struct with a single `upgrades` field.
impl<K, H> fmt::Debug for Upgrade<K, H>
where
    K: fmt::Debug + Eq + Hash,
    H: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Upgrade");
        out.field("upgrades", &self.upgrades);
        out.finish()
    }
}
407
408impl<K, H> UpgradeInfoSend for Upgrade<K, H>
409where
410 H: UpgradeInfoSend,
411 K: Send + 'static
412{
413 type Info = IndexedProtoName<H::Info>;
414 type InfoIter = std::vec::IntoIter<Self::Info>;
415
416 fn protocol_info(&self) -> Self::InfoIter {
417 self.upgrades.iter().enumerate()
418 .map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
419 .flatten()
420 .map(|(i, h)| IndexedProtoName(i, h))
421 .collect::<Vec<_>>()
422 .into_iter()
423 }
424}
425
426impl<K, H> InboundUpgradeSend for Upgrade<K, H>
427where
428 H: InboundUpgradeSend,
429 K: Send + 'static
430{
431 type Output = (K, <H as InboundUpgradeSend>::Output);
432 type Error = (K, <H as InboundUpgradeSend>::Error);
433 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
434
435 fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
436 let IndexedProtoName(index, info) = info;
437 let (key, upgrade) = self.upgrades.remove(index);
438 upgrade.upgrade_inbound(resource, info)
439 .map(move |out| {
440 match out {
441 Ok(o) => Ok((key, o)),
442 Err(e) => Err((key, e))
443 }
444 })
445 .boxed()
446 }
447}
448
449impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
450where
451 H: OutboundUpgradeSend,
452 K: Send + 'static
453{
454 type Output = (K, <H as OutboundUpgradeSend>::Output);
455 type Error = (K, <H as OutboundUpgradeSend>::Error);
456 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
457
458 fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
459 let IndexedProtoName(index, info) = info;
460 let (key, upgrade) = self.upgrades.remove(index);
461 upgrade.upgrade_outbound(resource, info)
462 .map(move |out| {
463 match out {
464 Ok(o) => Ok((key, o)),
465 Err(e) => Err((key, e))
466 }
467 })
468 .boxed()
469 }
470}
471
472fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
474where
475 I: Iterator<Item = T>,
476 T: UpgradeInfoSend
477{
478 let mut set = HashSet::new();
479 for infos in iter {
480 for i in infos.protocol_info() {
481 let v = Vec::from(i.protocol_name());
482 if set.contains(&v) {
483 return Err(DuplicateProtonameError(v))
484 } else {
485 set.insert(v);
486 }
487 }
488 }
489 Ok(())
490}
491
/// Error returned when two inner handlers advertise the same protocol name.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The protocol name that occurred more than once.
    pub fn protocol_name(&self) -> &[u8] {
        &self.0
    }
}

/// Shows the name as text when it is valid UTF-8, and as a byte list otherwise.
impl fmt::Display for DuplicateProtonameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(name) => write!(f, "duplicate protocol name: {}", name),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}