use async_trait::async_trait;
use rand::{distributions, prelude::*};
use std::num::NonZeroU32;
use std::time::{Duration, Instant};
use std::{error::Error, io};
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use libp2prs_core::transport::TransportError;
use libp2prs_core::upgrade::UpgradeInfo;
use libp2prs_runtime::task;
use crate::connection::Connection;
use crate::protocol_handler::{IProtocolHandler, Notifiee, ProtocolHandler};
use crate::substream::Substream;
use libp2prs_core::ProtocolId;
/// Configuration for the ping protocol.
///
/// Start from [`PingConfig::new`] (equivalently [`Default::default`]) and adjust
/// it with the `with_*` builder methods.
///
/// Defaults: `timeout` 20 s, `interval` 15 s, `max_failures` 1,
/// `unsolicited` false, `keep_alive` false.
#[derive(Clone, Debug)]
pub struct PingConfig {
    /// Timeout handed to the connection's ping service.
    timeout: Duration,
    /// Interval handed to the connection's ping service
    /// (presumably the delay between pings — semantics live in `Connection::start_ping`).
    interval: Duration,
    /// Failure budget handed to the connection's ping service
    /// (presumably failures tolerated before giving up — confirm in `Connection::start_ping`).
    max_failures: NonZeroU32,
    /// When true, a ping service is started on every newly established connection.
    unsolicited: bool,
    /// Keep-alive flag; not consumed in this module — NOTE(review): confirm where it is used.
    keep_alive: bool,
}

impl Default for PingConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl PingConfig {
    /// Creates a configuration populated with the default values listed on [`PingConfig`].
    pub fn new() -> Self {
        Self {
            timeout: Duration::from_secs(20),
            interval: Duration::from_secs(15),
            max_failures: NonZeroU32::new(1).expect("1 != 0"),
            unsolicited: false,
            keep_alive: false,
        }
    }

    /// Returns the configured ping timeout.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// Returns the configured ping interval.
    pub fn interval(&self) -> Duration {
        self.interval
    }

    /// Returns the configured failure budget as a plain `u32` (always >= 1).
    pub fn max_failures(&self) -> u32 {
        self.max_failures.get()
    }

    /// Returns whether pinging is started automatically on new connections.
    pub fn unsolicited(&self) -> bool {
        self.unsolicited
    }

    /// Returns the keep-alive flag set via [`PingConfig::with_keep_alive`].
    ///
    /// Previously this flag could be set but never read back.
    pub fn keep_alive(&self) -> bool {
        self.keep_alive
    }

    /// Sets the ping timeout.
    pub fn with_timeout(mut self, d: Duration) -> Self {
        self.timeout = d;
        self
    }

    /// Sets the ping interval.
    pub fn with_interval(mut self, d: Duration) -> Self {
        self.interval = d;
        self
    }

    /// Sets the failure budget; `NonZeroU32` guarantees it is at least 1.
    pub fn with_max_failures(mut self, n: NonZeroU32) -> Self {
        self.max_failures = n;
        self
    }

    /// Enables or disables automatic pinging on new connections.
    pub fn with_unsolicited(mut self, b: bool) -> Self {
        self.unsolicited = b;
        self
    }

    /// Sets the keep-alive flag.
    pub fn with_keep_alive(mut self, b: bool) -> Self {
        self.keep_alive = b;
        self
    }
}
/// Performs a single ping round-trip over `stream`.
///
/// A random `PING_SIZE`-byte payload is written and flushed. The same number of
/// bytes is then read back, the stream is closed, and the echo is compared with
/// what was sent.
///
/// # Returns
///
/// The round-trip time, measured from just after the payload was flushed until
/// the echo has been fully read.
///
/// # Errors
///
/// * `TransportError::Timeout` if the whole exchange (including closing the
///   stream) does not complete within `timeout`.
/// * An I/O error converted into `TransportError` if writing, flushing,
///   reading or closing fails, or if the echoed payload differs from the one sent.
pub async fn ping<T: AsyncRead + AsyncWrite + Send + Unpin + std::fmt::Debug>(
    mut stream: T,
    timeout: Duration,
) -> Result<Duration, TransportError> {
    let ping = async {
        let payload: [u8; PING_SIZE] = thread_rng().sample(distributions::Standard);
        log::trace!("Preparing ping payload {:?}", payload);
        stream.write_all(&payload).await?;
        // Flush so the payload actually leaves a buffered stream before we block
        // waiting for the echo; otherwise the read can stall until the timeout.
        stream.flush().await?;
        let started = Instant::now();
        let mut recv_payload = [0u8; PING_SIZE];
        stream.read_exact(&mut recv_payload).await?;
        stream.close().await?;
        if recv_payload == payload {
            log::trace!("ping succeeded for {:?}", stream);
            Ok(started.elapsed())
        } else {
            // Log what the peer actually sent back, not our own payload.
            log::info!("Invalid ping payload received {:?}", recv_payload);
            Err(io::Error::new(io::ErrorKind::InvalidData, "Ping payload mismatch"))
        }
    };
    task::timeout(timeout, ping)
        .await
        .map_or(Err(TransportError::Timeout), |r| r.map_err(Into::into))
}
/// Handler for the ping protocol.
///
/// It echoes inbound ping payloads through its `ProtocolHandler` impl. Through
/// its `Notifiee` impl it may start a ping service on new connections.
#[derive(Debug, Clone)]
pub(crate) struct PingHandler {
    /// Ping settings; consulted when a connection is established.
    config: PingConfig,
}

impl PingHandler {
    /// Wraps `config` in a new handler.
    pub(crate) fn new(config: PingConfig) -> Self {
        Self { config }
    }
}
/// Number of bytes in every ping payload, both sent and echoed.
const PING_SIZE: usize = 32;
/// Protocol identifier advertised for ping substreams.
pub(crate) const PING_PROTOCOL: &[u8] = b"/ipfs/ping/1.0.0";
impl UpgradeInfo for PingHandler {
    type Info = ProtocolId;

    /// Advertises exactly one protocol: `PING_PROTOCOL`.
    fn protocol_info(&self) -> Vec<Self::Info> {
        let ping_id: Self::Info = PING_PROTOCOL.into();
        vec![ping_id]
    }
}
impl Notifiee for PingHandler {
    /// Called when a connection is established.
    ///
    /// If the handler's config has `unsolicited` enabled, it asks the
    /// connection to start pinging with the configured timeout, interval and
    /// failure budget. Otherwise it does nothing.
    fn connected(&mut self, connection: &mut Connection) {
        // Borrow instead of cloning: every value read below is `Copy`.
        let config = &self.config;
        if config.unsolicited() {
            log::trace!("starting Ping service for {:?}", connection);
            connection.start_ping(config.timeout(), config.interval(), config.max_failures());
        }
    }
}
#[async_trait]
impl ProtocolHandler for PingHandler {
    /// Serves an inbound ping substream.
    ///
    /// Repeatedly reads a `PING_SIZE`-byte payload and writes it straight back,
    /// flushing after each echo. The loop ends at the first read error, which
    /// includes the peer closing the stream. The substream is then closed.
    ///
    /// # Errors
    ///
    /// Returns an error if writing, flushing or closing the substream fails.
    /// Read errors only end the loop.
    async fn handle(&mut self, mut stream: Substream, _info: <Self as UpgradeInfo>::Info) -> Result<(), Box<dyn Error>> {
        log::trace!("Ping Protocol handling on {:?}", stream);
        let mut payload = [0u8; PING_SIZE];
        loop {
            if let Err(e) = stream.read_exact(&mut payload).await {
                // EOF or any other read failure ends the session (best effort);
                // record why instead of discarding the error silently.
                log::trace!("Ping Protocol stopped reading on {:?}: {}", stream, e);
                break;
            }
            stream.write_all(&payload).await?;
            // Push the echo out immediately so the pinger's RTT measurement is
            // not stalled by write buffering.
            stream.flush().await?;
        }
        stream.close().await?;
        Ok(())
    }

    /// Returns a boxed clone of this handler.
    fn box_clone(&self) -> IProtocolHandler {
        Box::new(self.clone())
    }
}
#[cfg(test)]
mod tests {
    use super::PingHandler;
    use crate::ping::{ping, PingConfig};
    use crate::protocol_handler::ProtocolHandler;
    use crate::substream::Substream;
    use libp2prs_core::transport::ListenerEvent;
    use libp2prs_core::upgrade::UpgradeInfo;
    use libp2prs_core::{
        multiaddr::multiaddr,
        transport::{memory::MemoryTransport, Transport},
    };
    use libp2prs_runtime::task;
    use rand::{thread_rng, Rng};
    use std::time::Duration;

    /// Ping a handler over an in-memory transport and expect a non-zero RTT.
    #[test]
    fn ping_pong() {
        let addr = multiaddr![Memory(thread_rng().gen::<u64>())];
        let mut listener = MemoryTransport.listen_on(addr.clone()).unwrap();

        // Responder: accept one connection and echo pings on it.
        task::spawn(async move {
            let event = listener.accept().await.unwrap();
            let raw = if let ListenerEvent::Accepted(raw) = event {
                raw
            } else {
                panic!("unreachable")
            };
            let substream = Substream::new_with_default(Box::new(raw));
            let mut handler = PingHandler::new(PingConfig::new().with_unsolicited(true));
            let info = handler.protocol_info().into_iter().next().unwrap();
            let _ = handler.handle(substream, info).await;
        });

        // Initiator: dial the responder and run one ping.
        task::block_on(async move {
            let conn = MemoryTransport.dial(addr).await.unwrap();
            let rtt = ping(conn, Duration::from_secs(3)).await.unwrap();
            assert!(rtt > Duration::from_secs(0));
        });
    }
}