1use std::{
22 collections::VecDeque,
23 convert::Infallible,
24 error::Error,
25 fmt, io,
26 task::{Context, Poll},
27 time::Duration,
28};
29
30use futures::{
31 future::{BoxFuture, Either},
32 prelude::*,
33};
34use futures_timer::Delay;
35use libp2p_core::upgrade::ReadyUpgrade;
36use libp2p_swarm::{
37 handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound},
38 ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError,
39 SubstreamProtocol,
40};
41
42use crate::{protocol, PROTOCOL_NAME};
43
/// Settings for the ping handler.
///
/// `timeout` bounds a single outbound ping exchange; `interval` is the pause
/// between the end of one outbound ping (success or failure) and the next one.
#[derive(Debug, Clone)]
pub struct Config {
    /// Maximum duration of one outbound ping before it fails with a timeout.
    timeout: Duration,
    /// Delay before the next outbound ping is started.
    interval: Duration,
}

impl Config {
    /// Returns a configuration with a 20 second timeout and a 15 second interval.
    pub fn new() -> Self {
        let timeout = Duration::from_secs(20);
        let interval = Duration::from_secs(15);
        Config { timeout, interval }
    }

    /// Returns a copy of this configuration with the ping timeout replaced by `d`.
    pub fn with_timeout(self, d: Duration) -> Self {
        Config { timeout: d, ..self }
    }

    /// Returns a copy of this configuration with the ping interval replaced by `d`.
    pub fn with_interval(self, d: Duration) -> Self {
        Config { interval: d, ..self }
    }
}

impl Default for Config {
    /// Same as [`Config::new`].
    fn default() -> Self {
        Config::new()
    }
}
88
/// Reason an outbound ping did not succeed.
#[derive(Debug)]
pub enum Failure {
    /// The ping did not complete within the configured timeout.
    Timeout,
    /// The remote does not support the ping protocol.
    Unsupported,
    /// Any other error, boxed.
    Other {
        error: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
}

impl Failure {
    /// Wraps an arbitrary error into [`Failure::Other`].
    fn other(e: impl std::error::Error + Send + Sync + 'static) -> Self {
        let error: Box<dyn std::error::Error + Send + Sync + 'static> = Box::new(e);
        Failure::Other { error }
    }
}

impl fmt::Display for Failure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Timeout => f.write_str("Ping timeout"),
            Self::Unsupported => f.write_str("Ping protocol not supported"),
            Self::Other { error } => write!(f, "Ping error: {error}"),
        }
    }
}

impl Error for Failure {
    /// Only [`Failure::Other`] has an underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Other { error } => {
                let cause: &(dyn Error + 'static) = &**error;
                Some(cause)
            }
            Self::Timeout | Self::Unsupported => None,
        }
    }
}
128
/// Per-connection handler that periodically sends outbound pings and answers
/// inbound pings.
pub struct Handler {
    /// Timeout and interval settings.
    config: Config,
    /// Timer gating the next outbound ping or outbound substream request.
    interval: Delay,
    /// Outbound failures awaiting processing in `poll`; pushed to the front and
    /// popped from the back, i.e. handled in FIFO order.
    pending_errors: VecDeque<Failure>,
    /// Consecutive outbound failures since the last successful ping.
    failures: u32,
    /// Outbound substream state; `None` when no substream has been requested
    /// yet or the previous one was dropped after an error.
    outbound: Option<OutboundState>,
    /// Future currently answering an inbound ping, if any.
    inbound: Option<PongFuture>,
    /// Whether the remote supports the protocol and whether that has been reported.
    state: State,
}
151
/// Protocol support status of the remote on this connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Outbound protocol negotiation failed. `reported` becomes `true` once
    /// `Failure::Unsupported` has been handed to the behaviour.
    Inactive { reported: bool },
    /// The remote supports the protocol; pings proceed normally.
    Active,
}
164
165impl Handler {
166 pub fn new(config: Config) -> Self {
168 Handler {
169 config,
170 interval: Delay::new(Duration::new(0, 0)),
171 pending_errors: VecDeque::with_capacity(2),
172 failures: 0,
173 outbound: None,
174 inbound: None,
175 state: State::Active,
176 }
177 }
178
179 fn on_dial_upgrade_error(
180 &mut self,
181 DialUpgradeError { error, .. }: DialUpgradeError<
182 (),
183 <Self as ConnectionHandler>::OutboundProtocol,
184 >,
185 ) {
186 self.outbound = None; self.interval.reset(Duration::new(0, 0));
199
200 let error = match error {
201 StreamUpgradeError::NegotiationFailed => {
202 debug_assert_eq!(self.state, State::Active);
203
204 self.state = State::Inactive { reported: false };
205 return;
206 }
207 StreamUpgradeError::Timeout => Failure::Other {
209 error: Box::new(std::io::Error::new(
210 std::io::ErrorKind::TimedOut,
211 "ping protocol negotiation timed out",
212 )),
213 },
214 #[allow(unreachable_patterns)]
216 StreamUpgradeError::Apply(e) => libp2p_core::util::unreachable(e),
217 StreamUpgradeError::Io(e) => Failure::Other { error: Box::new(e) },
218 };
219
220 self.pending_errors.push_front(error);
221 }
222}
223
// Wires the ping state machine into libp2p's connection handler interface.
impl ConnectionHandler for Handler {
    type FromBehaviour = Infallible;
    type ToBehaviour = Result<Duration, Failure>;
    type InboundProtocol = ReadyUpgrade<StreamProtocol>;
    type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
    type OutboundOpenInfo = ();
    type InboundOpenInfo = ();

    /// Accepts inbound substreams negotiated for `PROTOCOL_NAME`.
    fn listen_protocol(&self) -> SubstreamProtocol<ReadyUpgrade<StreamProtocol>> {
        SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ())
    }

    /// `FromBehaviour` is `Infallible`, so this can never be called with a value.
    fn on_behaviour_event(&mut self, _: Infallible) {}

    /// Advances the handler, returning at most one event per call.
    ///
    /// Order of work:
    /// 1. If the remote does not support the protocol, emit
    ///    `Failure::Unsupported` once, then stay `Pending` forever. Inbound
    ///    pings are not polled in that state either.
    /// 2. Make progress on answering the current inbound ping.
    /// 3. Loop: process one queued outbound failure, then advance the outbound
    ///    state machine until it must wait or has an event to return.
    #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<ConnectionHandlerEvent<ReadyUpgrade<StreamProtocol>, (), Result<Duration, Failure>>>
    {
        match self.state {
            State::Inactive { reported: true } => {
                return Poll::Pending; // nothing left to do on this connection
            }
            State::Inactive { reported: false } => {
                // Report the missing protocol support exactly once.
                self.state = State::Inactive { reported: true };
                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Err(
                    Failure::Unsupported,
                )));
            }
            State::Active => {}
        }

        // Respond to inbound pings.
        if let Some(fut) = self.inbound.as_mut() {
            match fut.poll_unpin(cx) {
                Poll::Pending => {}
                Poll::Ready(Err(e)) => {
                    // Inbound errors are only logged; the inbound stream is dropped.
                    tracing::debug!("Inbound ping error: {:?}", e);
                    self.inbound = None;
                }
                Poll::Ready(Ok(stream)) => {
                    tracing::trace!("answered inbound ping from peer");

                    // One ping answered; wait for the next one on the same stream.
                    self.inbound = Some(protocol::recv_ping(stream).boxed());
                }
            }
        }

        loop {
            // Process queued outbound failures (oldest first).
            if let Some(error) = self.pending_errors.pop_back() {
                tracing::debug!("Ping failure: {:?}", error);

                self.failures += 1;

                // The first failure since the last success is swallowed (only
                // logged above); events are emitted from the second consecutive
                // failure onwards. `failures` is reset on every successful ping.
                // NOTE(review): presumably this tolerates one transient failure,
                // e.g. from peers that use a new substream per ping — confirm.
                if self.failures > 1 {
                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Err(error)));
                }
            }

            // Advance the outbound ping state machine.
            match self.outbound.take() {
                Some(OutboundState::Ping(mut ping)) => match ping.poll_unpin(cx) {
                    Poll::Pending => {
                        self.outbound = Some(OutboundState::Ping(ping));
                        break;
                    }
                    Poll::Ready(Ok((stream, rtt))) => {
                        // Success: reset the failure count, keep the stream for
                        // the next ping, and report the round-trip time.
                        tracing::debug!(?rtt, "ping succeeded");
                        self.failures = 0;
                        self.interval.reset(self.config.interval);
                        self.outbound = Some(OutboundState::Idle(stream));
                        return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Ok(rtt)));
                    }
                    Poll::Ready(Err(e)) => {
                        // The stream is dropped and `outbound` stays `None`, so a
                        // new substream is requested once `interval` fires; the
                        // error is handled at the top of the next iteration.
                        self.interval.reset(self.config.interval);
                        self.pending_errors.push_front(e);
                    }
                },
                Some(OutboundState::Idle(stream)) => match self.interval.poll_unpin(cx) {
                    Poll::Pending => {
                        self.outbound = Some(OutboundState::Idle(stream));
                        break;
                    }
                    Poll::Ready(()) => {
                        // Time for the next ping on the existing stream; the new
                        // future is polled on the next loop iteration.
                        self.outbound = Some(OutboundState::Ping(
                            send_ping(stream, self.config.timeout).boxed(),
                        ));
                    }
                },
                Some(OutboundState::OpenStream) => {
                    // Waiting for `FullyNegotiatedOutbound` or `DialUpgradeError`.
                    self.outbound = Some(OutboundState::OpenStream);
                    break;
                }
                None => match self.interval.poll_unpin(cx) {
                    Poll::Pending => break,
                    Poll::Ready(()) => {
                        // No usable substream: ask for a new outbound one.
                        self.outbound = Some(OutboundState::OpenStream);
                        let protocol = SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ());
                        return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                            protocol,
                        });
                    }
                },
            }
        }

        Poll::Pending
    }

    /// Handles newly negotiated substreams and failed outbound requests;
    /// all other connection events are ignored.
    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
    ) {
        match event {
            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                protocol: mut stream,
                ..
            }) => {
                // Ping streams should not by themselves keep the connection
                // alive (see `Stream::ignore_for_keep_alive`).
                stream.ignore_for_keep_alive();
                // A new inbound stream replaces (drops) any in-progress one.
                self.inbound = Some(protocol::recv_ping(stream).boxed());
            }
            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                protocol: mut stream,
                ..
            }) => {
                stream.ignore_for_keep_alive();
                // Start the first ping on the freshly opened stream immediately.
                self.outbound = Some(OutboundState::Ping(
                    send_ping(stream, self.config.timeout).boxed(),
                ));
            }
            ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
                self.on_dial_upgrade_error(dial_upgrade_error)
            }
            _ => {}
        }
    }
}
369
/// An outbound ping in flight. On success it resolves to the stream (for reuse)
/// and the round-trip time.
type PingFuture = BoxFuture<'static, Result<(Stream, Duration), Failure>>;
/// Answers one inbound ping and resolves to the stream so the next ping can be
/// awaited on it.
type PongFuture = BoxFuture<'static, Result<Stream, io::Error>>;
372
/// Progress of the outbound ping substream.
enum OutboundState {
    /// A substream has been requested and negotiation has not finished yet.
    OpenStream,
    /// The substream is open and waiting for the interval timer before pinging again.
    Idle(Stream),
    /// A ping is currently in flight on the substream.
    Ping(PingFuture),
}
382
383async fn send_ping(stream: Stream, timeout: Duration) -> Result<(Stream, Duration), Failure> {
385 let ping = protocol::send_ping(stream);
386 futures::pin_mut!(ping);
387
388 match future::select(ping, Delay::new(timeout)).await {
389 Either::Left((Ok((stream, rtt)), _)) => Ok((stream, rtt)),
390 Either::Left((Err(e), _)) => Err(Failure::other(e)),
391 Either::Right(((), _)) => Err(Failure::Timeout),
392 }
393}