use crate::*;
use async_trait::async_trait;
use socket2::{Domain, SockAddr, Socket, Type};
use std::io;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::Instant;
pub struct PingClientTcp {
config: PingClientConfig,
}
impl PingClientTcp {
pub fn new(config: &PingClientConfig) -> PingClientTcp {
return PingClientTcp { config: config.clone() };
}
#[tracing::instrument(name = "Running TCP ping in ping client", level = "debug", skip(self))]
async fn ping_target(&self, source: &SocketAddr, target: &SocketAddr) -> PingClientResult<PingClientPingResultDetails> {
let socket = self.prepare_socket_for_ping(source).map_err(|e| PingClientError::PreparationFailed(Box::new(e)))?;
let start_time = Instant::now();
let connect_result = socket.connect_timeout(&SockAddr::from(target.clone()), self.config.wait_timeout);
let rtt = Instant::now().duration_since(start_time);
match connect_result {
Err(e) if e.kind() == io::ErrorKind::TimedOut => return Ok(PingClientPingResultDetails::new(None, rtt, true, None)),
Err(e) => return Err(PingClientError::PingFailed(Box::new(e))),
Ok(()) => (),
}
let local_addr = socket.local_addr();
let mut warning: Option<PingClientWarning> = None;
if self.config.check_disconnect {
warning = match self.shutdown_connection(socket, &target).await {
Err(e) => Some(PingClientWarning::DisconnectFailed(Box::new(e))),
Ok(_) => None,
}
} else {
drop(socket);
}
return match local_addr {
Ok(addr) => Ok(PingClientPingResultDetails::new(Some(addr.as_socket().unwrap()), rtt, false, warning)),
Err(_) => Ok(PingClientPingResultDetails::new(None, rtt, false, warning)),
};
}
#[tracing::instrument(name = "Creating socket for ping", level = "debug", skip(self))]
fn prepare_socket_for_ping(&self, source: &SocketAddr) -> io::Result<Socket> {
let socket_domain = if source.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 };
let socket = Socket::new(socket_domain, Type::STREAM, None)?;
socket.set_read_timeout(Some(self.config.wait_timeout))?;
if !self.config.check_disconnect {
socket.set_linger(Some(Duration::from_secs(0)))?;
}
if let Some(ttl) = self.config.time_to_live {
socket.set_ttl(ttl)?;
}
socket.bind(&SockAddr::from(source.clone()))?;
return Ok(socket);
}
#[tracing::instrument(name = "Shutdown connection after ping", level = "debug", skip(self))]
async fn shutdown_connection(&self, socket: Socket, target: &SocketAddr) -> io::Result<()> {
if !self.config.wait_before_disconnect.is_zero() {
tracing::debug!("Waiting {:?} before disconnect; target={}", self.config.wait_before_disconnect, target);
tokio::time::sleep(self.config.wait_before_disconnect).await;
}
let mut connection = TcpStream::from_std(socket.into())?;
let mut read_buffer = vec![0 as u8; 128];
tracing::debug!("Checking if connection is already closed; target={}", target);
loop {
match connection.try_read(&mut read_buffer) {
Ok(0) => {
return Err(io::Error::new(io::ErrorKind::ConnectionAborted, "Connection is already half shutdown by remote side."));
}
Ok(_) => (),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
Err(e) => {
return Err(e.into());
}
}
}
tracing::debug!("Shutdown connection write; target={}", target);
connection.shutdown().await?;
tracing::debug!("Wait until shutdown completes; timeout={:?}, target={}", self.config.disconnect_timeout, target);
if self.config.disconnect_timeout.is_zero() {
self.wait_for_connection_shutdown(&mut connection, &mut read_buffer).await?;
} else {
let disconnect_deadline = Instant::now() + self.config.disconnect_timeout;
tokio::select! {
wait_result = self.wait_for_connection_shutdown(&mut connection, &mut read_buffer) => {
if wait_result.is_err() {
return wait_result;
}
}
_ = tokio::time::sleep_until(disconnect_deadline.clone()) => {
return Err(io::Error::new(io::ErrorKind::TimedOut, "Disconnect timed out."));
}
}
}
return Ok(());
}
async fn wait_for_connection_shutdown(&self, connection: &mut TcpStream, read_buffer: &mut Vec<u8>) -> io::Result<()> {
while connection.read(&mut read_buffer[..]).await? > 0 {
continue;
}
return Ok(());
}
}
// `PingClient` implementation that delegates the actual probing to `PingClientTcp::ping_target`.
#[async_trait]
impl PingClient for PingClientTcp {
    /// Protocol label reported for pings made by this client.
    fn protocol(&self) -> &'static str {
        "TCP"
    }

    /// Per-target preparation hook. The TCP client has nothing to prepare, so it always
    /// succeeds.
    async fn prepare_ping(&mut self, _target: &SocketAddr) -> Result<(), PingClientError> {
        Ok(())
    }

    /// Runs one TCP connect ping from `source` to `target`.
    async fn ping(&self, source: &SocketAddr, target: &SocketAddr) -> PingClientResult<PingClientPingResultDetails> {
        self.ping_target(source, target).await
    }
}