use log::debug;
use std::{fmt, io, net, sync, time};
use tokio::sync::{mpsc as tmpsc, oneshot};
mod multiplexer_task;
#[cfg(test)]
mod tests;
use crate::{
message::echo::{EchoId, EchoSeq, IcmpEchoRequest},
ping::multiplexer_task::{MultiplexTask, MultiplexerCommand, SendSessionState},
platform,
socket::{SocketConfig, SocketPair},
Icmpv4, Icmpv6, IpVersion,
};
pub use multiplexer_task::{
AddSessionError, LifecycleError, ReplyTimestamp, SendPingError, SessionHandle,
};
/// Cheaply-cloneable client handle to a background task that multiplexes
/// many ping sessions over a single ICMPv4 / ICMPv6 socket pair.
///
/// All clones share the same background task, sockets, and session table
/// through the inner [`sync::Arc`].
#[derive(Clone)]
pub struct PingMultiplexer {
    // Shared with every clone; the background task is joined via
    // `MultiplexerClientState::task_handle` on shutdown.
    state: sync::Arc<MultiplexerClientState>,
}
impl PingMultiplexer {
    /// Opens the ICMPv4/ICMPv6 socket pair and spawns the background task
    /// that owns session bookkeeping and dispatches echo replies.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error raised while creating/binding the sockets.
    pub fn new(
        icmpv4_config: SocketConfig<Icmpv4>,
        icmpv6_config: SocketConfig<Icmpv6>,
    ) -> io::Result<Self> {
        let (mut inner, ipv4_local_port, ipv6_local_port, sockets, tx, send_state) =
            MultiplexTask::new(icmpv4_config, icmpv6_config)?;
        // Runs until a `MultiplexerCommand::Shutdown` is delivered (see `shutdown`).
        let handle = tokio::spawn(async move {
            inner.run().await;
        });
        Ok(Self {
            state: sync::Arc::new(MultiplexerClientState {
                commands: tx,
                ipv4_local_port,
                ipv6_local_port,
                sockets,
                task_handle: Some(handle).into(),
                send_sessions: send_state,
                // Pool of reusable echo-request buffers so `send_ping` does not
                // allocate a fresh packet for every probe.
                req_pool: opool::Pool::new(4, ReqAllocator),
            }),
        })
    }

    /// Registers a new echo session for `ip` with echo identifier `id` and
    /// payload `data`. On success returns the session handle plus the
    /// receiver on which reply timestamps for this session are delivered.
    ///
    /// # Errors
    ///
    /// Fails if the background task rejects the session or has shut down.
    pub async fn add_session(
        &self,
        ip: net::IpAddr,
        id: EchoId,
        data: Vec<u8>,
    ) -> Result<(SessionHandle, tmpsc::Receiver<ReplyTimestamp>), AddSessionError> {
        let (tx, rx) = oneshot::channel();
        // `send_cmd` yields the task's reply, which is itself a Result; the
        // outer `?` maps a dead task into an `AddSessionError`.
        self.send_cmd(
            MultiplexerCommand::AddSession {
                ip,
                id,
                data,
                reply: tx,
            },
            rx,
        )
        .await?
    }

    /// Sends one echo request for `session_handle` with sequence number
    /// `seq`, returning the instant the packet was handed to the socket.
    ///
    /// # Errors
    ///
    /// Returns [`SendPingError::InvalidSessionHandle`] for an unknown or
    /// closed session, or the underlying socket error if the send fails.
    pub async fn send_ping(
        &self,
        session_handle: SessionHandle,
        seq: EchoSeq,
    ) -> Result<time::Instant, SendPingError> {
        // Build the request inside a scope so the std `RwLock` read guard is
        // dropped before the `.await` below: holding a std sync guard across
        // an await point would stall other tasks on this runtime thread.
        let (mut req, ip) = {
            let sessions = self.state.send_sessions.read().unwrap();
            if let Some(session_send_state) = sessions.get(&session_handle) {
                let mut req = self.state.req_pool.get();
                req.set_id(session_send_state.echo_data.id);
                req.set_seq(seq);
                req.set_data(&session_send_state.echo_data.data);
                (req, session_send_state.ip)
            } else {
                return Err(SendPingError::InvalidSessionHandle);
            }
        };
        self.state.sockets.send_to_either(&mut *req, ip).await?;
        debug!("Sent {session_handle:?} seq {seq:?}");
        Ok(time::Instant::now())
    }

    /// Tears down a session previously created with [`Self::add_session`].
    ///
    /// # Errors
    ///
    /// Fails with [`LifecycleError::Shutdown`] if the background task is gone.
    pub async fn close_session(&self, session_handle: SessionHandle) -> Result<(), LifecycleError> {
        let (tx, rx) = oneshot::channel();
        self.send_cmd(
            MultiplexerCommand::CloseSession {
                session_handle,
                reply: tx,
            },
            rx,
        )
        .await
    }

    /// Asks the background task to shut down and waits for it to exit.
    /// Safe to call more than once; only the first caller joins the task.
    pub async fn shutdown(&self) {
        let (tx, rx) = oneshot::channel();
        if let Err(e) = self.send_cmd(MultiplexerCommand::Shutdown(tx), rx).await {
            // Exhaustive match on purpose: adding a `LifecycleError` variant
            // must force this site to be revisited.
            match e {
                // Task already gone — nothing to do beyond joining below.
                LifecycleError::Shutdown => {}
            }
        }
        // Take the JoinHandle under the lock; concurrent callers find `None`.
        let handle = match self.state.task_handle.lock().unwrap().take() {
            Some(h) => h,
            None => return,
        };
        if let Err(e) = handle.await {
            debug!("Inner task exited with error: {}", e);
        }
    }

    /// Local port bound by the ICMPv4 socket.
    pub fn ipv4_local_port(&self) -> u16 {
        self.state.ipv4_local_port
    }

    /// Local port bound by the ICMPv6 socket.
    pub fn ipv6_local_port(&self) -> u16 {
        self.state.ipv6_local_port
    }

    /// On platforms where the kernel overwrites the outgoing echo identifier
    /// with the socket's local port, returns the echo id replies will carry
    /// for `ip_version`; returns `None` on other platforms.
    pub fn platform_echo_id(&self, ip_version: IpVersion) -> Option<EchoId> {
        if platform::icmp_send_overwrite_echo_id_with_local_port() {
            let port = match ip_version {
                IpVersion::V4 => self.ipv4_local_port(),
                IpVersion::V6 => self.ipv6_local_port(),
            };
            // NOTE(review): the port is interpreted as big-endian here —
            // confirm this matches how the kernel encodes the id on the wire.
            Some(EchoId::from_be(port))
        } else {
            None
        }
    }

    /// Sends `cmd` to the background task and awaits its reply on `rx`.
    /// Both a closed command channel and a dropped reply sender are mapped
    /// to [`LifecycleError::Shutdown`].
    async fn send_cmd<T>(
        &self,
        cmd: MultiplexerCommand,
        rx: oneshot::Receiver<T>,
    ) -> Result<T, LifecycleError> {
        self.state
            .commands
            .send(cmd)
            .await
            .map_err(|_| LifecycleError::Shutdown)?;
        rx.await.map_err(|_| LifecycleError::Shutdown)
    }
}
impl fmt::Debug for PingMultiplexer {
    /// Opaque representation: internal multiplexer state is deliberately
    /// not exposed in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PingMultiplexer")
    }
}
/// State shared by every clone of [`PingMultiplexer`], behind an `Arc`.
struct MultiplexerClientState {
    /// Command channel into the background multiplexer task.
    commands: tmpsc::Sender<MultiplexerCommand>,
    /// ICMPv4/ICMPv6 socket pair, shared with the background task.
    sockets: sync::Arc<SocketPair>,
    /// Local port bound by the ICMPv4 socket.
    ipv4_local_port: u16,
    /// Local port bound by the ICMPv6 socket.
    ipv6_local_port: u16,
    /// Join handle for the spawned task; taken (set to `None`) by the first
    /// caller of `shutdown`.
    task_handle: sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
    /// Per-session send state, shared read-mostly with the background task.
    send_sessions: sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
    /// Pool of reusable ICMP echo request buffers used by `send_ping`.
    req_pool: opool::Pool<ReqAllocator, IcmpEchoRequest>,
}
/// `opool` allocator that creates fresh ICMP echo request buffers for the
/// request pool; pooled buffers are reused across pings.
struct ReqAllocator;
impl opool::PoolAllocator<IcmpEchoRequest> for ReqAllocator {
    fn allocate(&self) -> IcmpEchoRequest {
        IcmpEchoRequest::new()
    }
}