1use crate::upgrade::SendWrapper;
22use crate::protocols_handler::{
23 KeepAlive,
24 ProtocolsHandler,
25 IntoProtocolsHandler,
26 ProtocolsHandlerEvent,
27 ProtocolsHandlerUpgrErr
28};
29
30use futures::prelude::*;
31use futures::stream::FuturesUnordered;
32use mwc_libp2p_core::{
33 Multiaddr,
34 Connected,
35 connection::{
36 ConnectionHandler,
37 ConnectionHandlerEvent,
38 IntoConnectionHandler,
39 Substream,
40 SubstreamEndpoint,
41 },
42 muxing::StreamMuxerBox,
43 upgrade::{self, InboundUpgradeApply, OutboundUpgradeApply, UpgradeError}
44};
45use std::{error, fmt, pin::Pin, task::Context, task::Poll, time::Duration};
46use wasm_timer::{Delay, Instant};
47
48pub struct NodeHandlerWrapperBuilder<TIntoProtoHandler> {
50 handler: TIntoProtoHandler,
52 substream_upgrade_protocol_override: Option<upgrade::Version>,
54}
55
56impl<TIntoProtoHandler> NodeHandlerWrapperBuilder<TIntoProtoHandler>
57where
58 TIntoProtoHandler: IntoProtocolsHandler
59{
60 pub(crate) fn new(handler: TIntoProtoHandler) -> Self {
62 NodeHandlerWrapperBuilder {
63 handler,
64 substream_upgrade_protocol_override: None,
65 }
66 }
67
68 pub(crate) fn with_substream_upgrade_protocol_override(
69 mut self,
70 version: Option<upgrade::Version>
71 ) -> Self {
72 self.substream_upgrade_protocol_override = version;
73 self
74 }
75}
76
77impl<TIntoProtoHandler, TProtoHandler> IntoConnectionHandler
78 for NodeHandlerWrapperBuilder<TIntoProtoHandler>
79where
80 TIntoProtoHandler: IntoProtocolsHandler<Handler = TProtoHandler>,
81 TProtoHandler: ProtocolsHandler,
82{
83 type Handler = NodeHandlerWrapper<TIntoProtoHandler::Handler>;
84
85 fn into_handler(self, connected: &Connected) -> Self::Handler {
86 NodeHandlerWrapper {
87 handler: self.handler.into_handler(&connected.peer_id, &connected.endpoint),
88 negotiating_in: Default::default(),
89 negotiating_out: Default::default(),
90 queued_dial_upgrades: Vec::new(),
91 unique_dial_upgrade_id: 0,
92 shutdown: Shutdown::None,
93 substream_upgrade_protocol_override: self.substream_upgrade_protocol_override,
94 }
95 }
96}
97
98pub struct NodeHandlerWrapper<TProtoHandler>
102where
103 TProtoHandler: ProtocolsHandler,
104{
105 handler: TProtoHandler,
107 negotiating_in: FuturesUnordered<SubstreamUpgrade<
109 TProtoHandler::InboundOpenInfo,
110 InboundUpgradeApply<Substream<StreamMuxerBox>, SendWrapper<TProtoHandler::InboundProtocol>>,
111 >>,
112 negotiating_out: FuturesUnordered<SubstreamUpgrade<
114 TProtoHandler::OutboundOpenInfo,
115 OutboundUpgradeApply<Substream<StreamMuxerBox>, SendWrapper<TProtoHandler::OutboundProtocol>>,
116 >>,
117 queued_dial_upgrades: Vec<(u64, (upgrade::Version, SendWrapper<TProtoHandler::OutboundProtocol>))>,
120 unique_dial_upgrade_id: u64,
122 shutdown: Shutdown,
124 substream_upgrade_protocol_override: Option<upgrade::Version>,
126}
127
128struct SubstreamUpgrade<UserData, Upgrade> {
129 user_data: Option<UserData>,
130 timeout: Delay,
131 upgrade: Upgrade,
132}
133
134impl<UserData, Upgrade> Unpin for SubstreamUpgrade<UserData, Upgrade> {}
135
136impl<UserData, Upgrade, UpgradeOutput, TUpgradeError> Future for SubstreamUpgrade<UserData, Upgrade>
137where
138 Upgrade: Future<Output = Result<UpgradeOutput, UpgradeError<TUpgradeError>>> + Unpin,
139{
140 type Output = (UserData, Result<UpgradeOutput, ProtocolsHandlerUpgrErr<TUpgradeError>>);
141
142 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
143 match self.timeout.poll_unpin(cx) {
144 Poll::Ready(Ok(_)) => return Poll::Ready((
145 self.user_data.take().expect("Future not to be polled again once ready."),
146 Err(ProtocolsHandlerUpgrErr::Timeout)),
147 ),
148 Poll::Ready(Err(_)) => return Poll::Ready((
149 self.user_data.take().expect("Future not to be polled again once ready."),
150 Err(ProtocolsHandlerUpgrErr::Timer)),
151 ),
152 Poll::Pending => {},
153 }
154
155 match self.upgrade.poll_unpin(cx) {
156 Poll::Ready(Ok(upgrade)) => Poll::Ready((
157 self.user_data.take().expect("Future not to be polled again once ready."),
158 Ok(upgrade),
159 )),
160 Poll::Ready(Err(err)) => Poll::Ready((
161 self.user_data.take().expect("Future not to be polled again once ready."),
162 Err(ProtocolsHandlerUpgrErr::Upgrade(err)),
163 )),
164 Poll::Pending => Poll::Pending,
165 }
166 }
167}
168
169
170enum Shutdown {
180 None,
182 Asap,
184 Later(Delay, Instant)
186}
187
/// Error with which a [`NodeHandlerWrapper`] closes its connection.
#[derive(Debug)]
pub enum NodeHandlerWrapperError<TErr> {
    /// The wrapped protocols handler asked for the connection to be closed
    /// with this error.
    Handler(TErr),
    /// The handler's keep-alive expired with no substream being negotiated.
    KeepAliveTimeout,
}
196
197impl<TErr> From<TErr> for NodeHandlerWrapperError<TErr> {
198 fn from(err: TErr) -> NodeHandlerWrapperError<TErr> {
199 NodeHandlerWrapperError::Handler(err)
200 }
201}
202
203impl<TErr> fmt::Display for NodeHandlerWrapperError<TErr>
204where
205 TErr: fmt::Display
206{
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 match self {
209 NodeHandlerWrapperError::Handler(err) => write!(f, "{}", err),
210 NodeHandlerWrapperError::KeepAliveTimeout =>
211 write!(f, "Connection closed due to expired keep-alive timeout."),
212 }
213 }
214}
215
216impl<TErr> error::Error for NodeHandlerWrapperError<TErr>
217where
218 TErr: error::Error + 'static
219{
220 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
221 match self {
222 NodeHandlerWrapperError::Handler(err) => Some(err),
223 NodeHandlerWrapperError::KeepAliveTimeout => None,
224 }
225 }
226}
227
228impl<TProtoHandler> ConnectionHandler for NodeHandlerWrapper<TProtoHandler>
229where
230 TProtoHandler: ProtocolsHandler,
231{
232 type InEvent = TProtoHandler::InEvent;
233 type OutEvent = TProtoHandler::OutEvent;
234 type Error = NodeHandlerWrapperError<TProtoHandler::Error>;
235 type Substream = Substream<StreamMuxerBox>;
236 type OutboundOpenInfo = (u64, TProtoHandler::OutboundOpenInfo, Duration);
239
240 fn inject_substream(
241 &mut self,
242 substream: Self::Substream,
243 endpoint: SubstreamEndpoint<Self::OutboundOpenInfo>,
244 ) {
245 match endpoint {
246 SubstreamEndpoint::Listener => {
247 let protocol = self.handler.listen_protocol();
248 let timeout = *protocol.timeout();
249 let (_, upgrade, user_data) = protocol.into_upgrade();
250 let upgrade = upgrade::apply_inbound(substream, SendWrapper(upgrade));
251 let timeout = Delay::new(timeout);
252 self.negotiating_in.push(SubstreamUpgrade {
253 user_data: Some(user_data),
254 timeout,
255 upgrade,
256 });
257 }
258 SubstreamEndpoint::Dialer((upgrade_id, user_data, timeout)) => {
259 let pos = match self
260 .queued_dial_upgrades
261 .iter()
262 .position(|(id, _)| id == &upgrade_id)
263 {
264 Some(p) => p,
265 None => {
266 debug_assert!(false, "Received an upgrade with an invalid upgrade ID");
267 return;
268 }
269 };
270
271 let (_, (mut version, upgrade)) = self.queued_dial_upgrades.remove(pos);
272 if let Some(v) = self.substream_upgrade_protocol_override {
273 if v != version {
274 log::debug!("Substream upgrade protocol override: {:?} -> {:?}", version, v);
275 version = v;
276 }
277 }
278 let upgrade = upgrade::apply_outbound(substream, upgrade, version);
279 let timeout = Delay::new(timeout);
280 self.negotiating_out.push(SubstreamUpgrade {
281 user_data: Some(user_data),
282 timeout,
283 upgrade,
284 });
285 }
286 }
287 }
288
289 fn inject_event(&mut self, event: Self::InEvent) {
290 self.handler.inject_event(event);
291 }
292
293 fn inject_address_change(&mut self, new_address: &Multiaddr) {
294 self.handler.inject_address_change(new_address);
295 }
296
297 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<
298 Result<ConnectionHandlerEvent<Self::OutboundOpenInfo, Self::OutEvent>, Self::Error>
299 > {
300 while let Poll::Ready(Some((user_data, res))) = self.negotiating_in.poll_next_unpin(cx) {
301 match res {
302 Ok(upgrade) => self.handler.inject_fully_negotiated_inbound(upgrade, user_data),
303 Err(err) => self.handler.inject_listen_upgrade_error(user_data, err),
304 }
305 }
306
307 while let Poll::Ready(Some((user_data, res))) = self.negotiating_out.poll_next_unpin(cx) {
308 match res {
309 Ok(upgrade) => self.handler.inject_fully_negotiated_outbound(upgrade, user_data),
310 Err(err) => self.handler.inject_dial_upgrade_error(user_data, err),
311 }
312 }
313
314 let poll_result = self.handler.poll(cx);
317
318 match (&mut self.shutdown, self.handler.connection_keep_alive()) {
321 (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) =>
322 if *deadline != t {
323 *deadline = t;
324 timer.reset_at(t)
325 },
326 (_, KeepAlive::Until(t)) => self.shutdown = Shutdown::Later(Delay::new_at(t), t),
327 (_, KeepAlive::No) => self.shutdown = Shutdown::Asap,
328 (_, KeepAlive::Yes) => self.shutdown = Shutdown::None
329 };
330
331 match poll_result {
332 Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => {
333 return Poll::Ready(Ok(ConnectionHandlerEvent::Custom(event)));
334 }
335 Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => {
336 let id = self.unique_dial_upgrade_id;
337 let timeout = *protocol.timeout();
338 self.unique_dial_upgrade_id += 1;
339 let (version, upgrade, info) = protocol.into_upgrade();
340 self.queued_dial_upgrades.push((id, (version, SendWrapper(upgrade))));
341 return Poll::Ready(Ok(
342 ConnectionHandlerEvent::OutboundSubstreamRequest((id, info, timeout)),
343 ));
344 }
345 Poll::Ready(ProtocolsHandlerEvent::Close(err)) => return Poll::Ready(Err(err.into())),
346 Poll::Pending => (),
347 };
348
349 if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() {
352 match self.shutdown {
353 Shutdown::None => {},
354 Shutdown::Asap => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)),
355 Shutdown::Later(ref mut delay, _) => match Future::poll(Pin::new(delay), cx) {
356 Poll::Ready(_) => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)),
357 Poll::Pending => {}
358 }
359 }
360 }
361
362 Poll::Pending
363 }
364}