use std::{
collections::VecDeque,
convert::Infallible,
error::Error,
fmt, io,
task::{Context, Poll},
time::Duration,
};
use ant_libp2p_core::upgrade::ReadyUpgrade;
use ant_libp2p_swarm::{
handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound},
ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError,
SubstreamProtocol,
};
use futures::{
future::{BoxFuture, Either},
prelude::*,
};
use futures_timer::Delay;
use crate::{protocol, PROTOCOL_NAME};
/// Tunables for the ping handler.
#[derive(Debug, Clone)]
pub struct Config {
    /// Maximum time a single outbound ping is allowed to take.
    timeout: Duration,
    /// Delay the handler waits before issuing the next outbound ping.
    interval: Duration,
}

impl Config {
    /// Creates a configuration with a 20 second timeout and a 15 second interval.
    pub fn new() -> Self {
        let timeout = Duration::from_secs(20);
        let interval = Duration::from_secs(15);
        Self { timeout, interval }
    }

    /// Returns the configuration with its ping timeout replaced by `d`.
    pub fn with_timeout(self, d: Duration) -> Self {
        Self { timeout: d, ..self }
    }

    /// Returns the configuration with its ping interval replaced by `d`.
    pub fn with_interval(self, d: Duration) -> Self {
        Self { interval: d, ..self }
    }
}

impl Default for Config {
    /// Same values as [`Config::new`].
    fn default() -> Self {
        Config::new()
    }
}
/// Reasons an outbound ping can fail.
#[derive(Debug)]
pub enum Failure {
    /// The ping did not complete before the configured timeout elapsed.
    Timeout,
    /// The remote does not support the ping protocol.
    Unsupported,
    /// Any other failure, carrying the underlying error.
    Other {
        error: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
}

impl Failure {
    /// Wraps an arbitrary error in [`Failure::Other`].
    fn other(e: impl std::error::Error + Send + Sync + 'static) -> Self {
        // `From<E> for Box<dyn Error + Send + Sync>` boxes the value.
        Self::Other { error: e.into() }
    }
}

impl fmt::Display for Failure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Timeout => f.write_str("Ping timeout"),
            Self::Unsupported => f.write_str("Ping protocol not supported"),
            Self::Other { error } => write!(f, "Ping error: {error}"),
        }
    }
}

impl Error for Failure {
    /// Only [`Failure::Other`] has an underlying cause to expose.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Other { error } => Some(&**error),
            Self::Timeout | Self::Unsupported => None,
        }
    }
}
/// Per-connection protocol handler that answers inbound pings and
/// periodically sends outbound pings.
pub struct Handler {
/// Timeout and interval settings for outbound pings.
config: Config,
/// Timer polled before opening a stream or sending the next ping.
interval: Delay,
/// Outbound failures waiting to be processed by `poll`
/// (pushed at the front, popped from the back).
pending_errors: VecDeque<Failure>,
/// Number of failures processed since the last successful ping;
/// reset to zero when a ping succeeds.
failures: u32,
/// State of the outbound ping stream, if any.
outbound: Option<OutboundState>,
/// Future answering the next inbound ping, if an inbound stream exists.
inbound: Option<PongFuture>,
/// Whether pinging is active on this connection.
state: State,
}
/// Activity state of the handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
/// Pinging has stopped; entered when outbound protocol negotiation fails.
Inactive {
/// `true` once `Failure::Unsupported` has been emitted from `poll`,
/// so that it is reported only once.
reported: bool,
},
/// Normal operation; the initial state of a new handler.
Active,
}
impl Handler {
pub fn new(config: Config) -> Self {
Handler {
config,
interval: Delay::new(Duration::new(0, 0)),
pending_errors: VecDeque::with_capacity(2),
failures: 0,
outbound: None,
inbound: None,
state: State::Active,
}
}
fn on_dial_upgrade_error(
&mut self,
DialUpgradeError { error, .. }: DialUpgradeError<
<Self as ConnectionHandler>::OutboundOpenInfo,
<Self as ConnectionHandler>::OutboundProtocol,
>,
) {
self.outbound = None;
self.interval.reset(Duration::new(0, 0));
let error = match error {
StreamUpgradeError::NegotiationFailed => {
debug_assert_eq!(self.state, State::Active);
self.state = State::Inactive { reported: false };
return;
}
StreamUpgradeError::Timeout => Failure::Other {
error: Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"ping protocol negotiation timed out",
)),
},
#[allow(unreachable_patterns)]
StreamUpgradeError::Apply(e) => ant_libp2p_core::util::unreachable(e),
StreamUpgradeError::Io(e) => Failure::Other { error: Box::new(e) },
};
self.pending_errors.push_front(error);
}
}
// `ConnectionHandler` implementation: answers inbound pings, drives the
// outbound ping cycle and reports round-trip times or failures.
impl ConnectionHandler for Handler {
// No commands are accepted from the behaviour.
type FromBehaviour = Infallible;
// Each event is either the measured round-trip time or a failure.
type ToBehaviour = Result<Duration, Failure>;
type InboundProtocol = ReadyUpgrade<StreamProtocol>;
type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
type OutboundOpenInfo = ();
type InboundOpenInfo = ();
/// Accepts inbound substreams for `PROTOCOL_NAME`.
fn listen_protocol(&self) -> SubstreamProtocol<ReadyUpgrade<StreamProtocol>, ()> {
SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ())
}
/// Nothing to do: `FromBehaviour` is `Infallible`, so no event can arrive.
fn on_behaviour_event(&mut self, _: Infallible) {}
/// Drives the handler.
///
/// In order: reports `Failure::Unsupported` once if the handler is inactive,
/// polls the inbound pong future, then loops over pending outbound errors
/// and the outbound state machine until it has an event to return or every
/// sub-future is pending.
#[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<ConnectionHandlerEvent<ReadyUpgrade<StreamProtocol>, (), Result<Duration, Failure>>>
{
match self.state {
State::Inactive { reported: true } => {
// Already reported as unsupported: nothing further to do.
return Poll::Pending; }
State::Inactive { reported: false } => {
// Report the unsupported protocol exactly once.
self.state = State::Inactive { reported: true };
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Err(
Failure::Unsupported,
)));
}
State::Active => {}
}
// Answer inbound pings.
if let Some(fut) = self.inbound.as_mut() {
match fut.poll_unpin(cx) {
Poll::Pending => {}
Poll::Ready(Err(e)) => {
// Inbound errors are only logged; the inbound stream is dropped.
tracing::debug!("Inbound ping error: {:?}", e);
self.inbound = None;
}
Poll::Ready(Ok(stream)) => {
tracing::trace!("answered inbound ping from peer");
// Reuse the same stream to wait for the next inbound ping.
self.inbound = Some(protocol::recv_ping(stream).boxed());
}
}
}
loop {
// Process at most one queued outbound failure per iteration.
if let Some(error) = self.pending_errors.pop_back() {
tracing::debug!("Ping failure: {:?}", error);
self.failures += 1;
// The first failure after a success (or after start) is dropped
// silently; only subsequent ones are reported.
if self.failures > 1 {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Err(error)));
}
}
// Advance the outbound state machine. `take()` leaves `None` behind,
// so every arm that wants to keep its state must store it back.
match self.outbound.take() {
Some(OutboundState::Ping(mut ping)) => match ping.poll_unpin(cx) {
Poll::Pending => {
self.outbound = Some(OutboundState::Ping(ping));
break;
}
Poll::Ready(Ok((stream, rtt))) => {
tracing::debug!(?rtt, "ping succeeded");
self.failures = 0;
self.interval.reset(self.config.interval);
// Keep the stream for the next ping.
self.outbound = Some(OutboundState::Idle(stream));
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Ok(rtt)));
}
Poll::Ready(Err(e)) => {
// `outbound` stays `None` here, so a new stream is requested
// once the interval elapses; the error is handled on the
// next loop iteration.
self.interval.reset(self.config.interval);
self.pending_errors.push_front(e);
}
},
Some(OutboundState::Idle(stream)) => match self.interval.poll_unpin(cx) {
Poll::Pending => {
self.outbound = Some(OutboundState::Idle(stream));
break;
}
Poll::Ready(()) => {
// Interval elapsed: send the next ping on the existing stream.
self.outbound = Some(OutboundState::Ping(
send_ping(stream, self.config.timeout).boxed(),
));
}
},
Some(OutboundState::OpenStream) => {
// Still waiting for the requested substream; the result arrives
// through `on_connection_event`.
self.outbound = Some(OutboundState::OpenStream);
break;
}
None => match self.interval.poll_unpin(cx) {
Poll::Pending => break,
Poll::Ready(()) => {
// No outbound stream and the timer has fired: request one.
self.outbound = Some(OutboundState::OpenStream);
let protocol = SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ());
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol,
});
}
},
}
}
Poll::Pending
}
/// Handles connection events.
///
/// A negotiated inbound stream starts waiting for a ping, a negotiated
/// outbound stream immediately starts a ping, and dial-upgrade errors go to
/// `on_dial_upgrade_error`. All other events are ignored.
fn on_connection_event(
&mut self,
event: ConnectionEvent<
Self::InboundProtocol,
Self::OutboundProtocol,
Self::InboundOpenInfo,
Self::OutboundOpenInfo,
>,
) {
match event {
ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
protocol: mut stream,
..
}) => {
// NOTE(review): presumably excludes this stream from the connection
// keep-alive accounting — confirm against the swarm `Stream` docs.
stream.ignore_for_keep_alive();
self.inbound = Some(protocol::recv_ping(stream).boxed());
}
ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
protocol: mut stream,
..
}) => {
stream.ignore_for_keep_alive();
// Send the first ping on the fresh stream right away.
self.outbound = Some(OutboundState::Ping(
send_ping(stream, self.config.timeout).boxed(),
));
}
ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
self.on_dial_upgrade_error(dial_upgrade_error)
}
_ => {}
}
}
}
/// Boxed outbound ping: yields the stream and the measured round-trip time,
/// or a `Failure`.
type PingFuture = BoxFuture<'static, Result<(Stream, Duration), Failure>>;
/// Boxed inbound ping handler: yields the stream back once a ping has been
/// answered, or an I/O error.
type PongFuture = BoxFuture<'static, Result<Stream, io::Error>>;
/// State of the outbound ping stream.
enum OutboundState {
/// A substream has been requested and is not yet negotiated.
OpenStream,
/// The stream is open and waiting for the interval timer before the next ping.
Idle(Stream),
/// A ping is in flight.
Ping(PingFuture),
}
/// Sends one ping over `stream`, racing it against a `timeout` timer.
///
/// Returns the stream and round-trip time on success, `Failure::Other` if the
/// ping itself errors, or `Failure::Timeout` if the timer fires first.
async fn send_ping(stream: Stream, timeout: Duration) -> Result<(Stream, Duration), Failure> {
    let exchange = protocol::send_ping(stream);
    futures::pin_mut!(exchange);
    let deadline = Delay::new(timeout);

    match future::select(exchange, deadline).await {
        // The ping finished first: pass success through, wrap any error.
        Either::Left((outcome, _)) => outcome.map_err(Failure::other),
        // The timer finished first.
        Either::Right(((), _)) => Err(Failure::Timeout),
    }
}