use crate::message::IcmpV6MsgType;
#[cfg(doc)]
use crate::ping::PingMultiplexer;
use crate::{
message::{
decode::DecodedIcmpMsg,
echo::{parse_echo_reply, EchoId, EchoSeq},
IcmpV4MsgType,
},
platform,
socket::{SocketConfig, SocketPair},
Icmpv4, Icmpv6,
};
use hashbrown::hash_map::Entry;
use log::{debug, warn};
use std::{fmt, hash, io, net, sync, time};
use tokio::sync::{mpsc as tmpsc, mpsc::error::TrySendError, oneshot};
/// State owned by the background task that receives ICMP Echo Replies on the
/// IPv4/IPv6 sockets and processes commands sent to it.
pub(crate) struct MultiplexTask {
    /// Receive buffer for the IPv4 socket; only ever grown (by `add_session`).
    v4_buf: Vec<u8>,
    /// Receive buffer for the IPv6 socket; only ever grown (by `add_session`).
    v6_buf: Vec<u8>,
    /// The socket pair; `new` also hands a clone of this `Arc` back to its caller.
    sockets: sync::Arc<SocketPair>,
    /// Send-side per-session state keyed by handle. Shared behind an `RwLock`
    /// with the caller of `new`, which receives a clone of this `Arc`.
    send_session_states:
        sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
    /// Receive-side per-session state keyed by echo id + payload, used to route
    /// incoming replies. Not behind an `Arc`/lock: only the task itself uses it.
    recv_session_states: hashbrown::HashMap<RecvHashKey, RecvSessionState>,
    /// Commands sent through the `Sender` returned by `new`.
    commands: tmpsc::Receiver<MultiplexerCommand>,
    /// Set once a shutdown has been processed; `run` exits when this is true.
    shutdown: bool,
    /// Next candidate id for a new `SessionHandle` (advanced with wrapping).
    next_handle_id: u64,
}
impl MultiplexTask {
    /// Creates the multiplexer task together with the handles its owner needs.
    ///
    /// Opens the socket pair from the two configs and creates the command
    /// channel (capacity 16) whose receiving end is consumed by `run`.
    ///
    /// Returns, in order: the task itself, the local port of the IPv4 socket,
    /// the local port of the IPv6 socket, the shared socket pair, the command
    /// sender, and the send-side session map shared with the task.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from `SocketPair::new`.
    // The return tuple spells out the shared session-map type in full.
    #[allow(clippy::type_complexity)]
    pub(crate) fn new(
        icmpv4_config: SocketConfig<Icmpv4>,
        icmpv6_config: SocketConfig<Icmpv6>,
    ) -> io::Result<(
        Self,
        u16,
        u16,
        sync::Arc<SocketPair>,
        tmpsc::Sender<MultiplexerCommand>,
        sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
    )> {
        let (tx, rx) = tmpsc::channel(16);
        let sockets = sync::Arc::new(SocketPair::new(icmpv4_config, icmpv6_config)?);
        let v4_port = sockets.ipv4.local_port();
        let v6_port = sockets.ipv6.local_port();
        let send_session_states = sync::Arc::new(sync::RwLock::new(hashbrown::HashMap::new()));
        Ok((
            Self {
                // Buffers start empty and are sized as sessions are added.
                v4_buf: Vec::new(),
                v6_buf: Vec::new(),
                sockets: sockets.clone(),
                next_handle_id: 0,
                send_session_states: send_session_states.clone(),
                recv_session_states: hashbrown::HashMap::new(),
                commands: rx,
                shutdown: false,
            },
            v4_port,
            v6_port,
            sockets,
            tx,
            send_session_states,
        ))
    }

    /// Drives the task until a shutdown has been processed.
    ///
    /// Each iteration handles one event (a received datagram or a command).
    /// Errors are logged with `warn!` and do not stop the loop.
    pub(crate) async fn run(&mut self) {
        loop {
            if self.shutdown {
                break;
            }
            // NOTE(review): there is no back-off here, so a socket whose `recv`
            // fails immediately on every call would make this loop spin.
            // Confirm that cannot happen.
            if let Err(e) = self.recv_or_cmd().await {
                warn!("Recv task error: {e}")
            }
        }
    }

    /// Waits for whichever happens first (a datagram on either socket, or a
    /// command) and handles it.
    ///
    /// A closed command channel (`recv` yielding `None`) is handled as a
    /// shutdown.
    async fn recv_or_cmd(&mut self) -> Result<(), RecvError> {
        // Borrow the two session maps as separate fields so the receive
        // branches can use them alongside the socket futures, which borrow the
        // receive buffers.
        let send_states = &mut self.send_session_states;
        let recv_states = &mut self.recv_session_states;
        tokio::select! {
            v4_res = self.sockets.ipv4.recv(&mut self.v4_buf) => {
                let (msg, _range) = v4_res?;
                handle_recv(msg, IcmpV4MsgType::EchoReply as u8, send_states, recv_states)?;
            }
            v6_res = self.sockets.ipv6.recv(&mut self.v6_buf) => {
                let (msg, _range) = v6_res?;
                handle_recv(msg, IcmpV6MsgType::EchoReply as u8, send_states, recv_states)?;
            }
            cmd_opt = self.commands.recv() => {
                match cmd_opt {
                    None => {
                        // No sender is left to wait for a reply, so a throwaway
                        // oneshot sender (receiver already dropped) is used.
                        self.handle_command(MultiplexerCommand::Shutdown(oneshot::channel().0)).await?
                    }
                    Some(cmd) => self.handle_command(cmd).await?
                }
            }
        }
        Ok(())
    }

    /// Applies a single command.
    ///
    /// Every arm answers on the command's oneshot sender. A reply that cannot be
    /// delivered is only logged, so this currently always returns `Ok(())`.
    async fn handle_command(&mut self, cmd: MultiplexerCommand) -> Result<(), RecvError> {
        match cmd {
            MultiplexerCommand::Shutdown(reply) => {
                // Mark the task for exit and drop every session. Dropping a
                // `RecvSessionState` drops its reply-timestamp `Sender`. Then
                // stop accepting further commands.
                self.shutdown = true;
                self.send_session_states.write().unwrap().clear();
                self.recv_session_states.clear();
                self.commands.close();
                reply_if_possible(reply, ())
            }
            MultiplexerCommand::AddSession {
                ip,
                id,
                data,
                reply,
            } => reply_if_possible(reply, self.add_session(ip, id, data, 16)),
            MultiplexerCommand::CloseSession {
                session_handle,
                reply,
            } => {
                handle_close_session(
                    session_handle,
                    &mut self.send_session_states,
                    &mut self.recv_session_states,
                );
                reply_if_possible(reply, ())
            }
        }
        Ok(())
    }

    /// Registers a new echo session.
    ///
    /// Returns the new session's handle and the receiver on which a
    /// `ReplyTimestamp` is delivered for each matching Echo Reply.
    /// `channel_buf_size` is the capacity of that per-session channel.
    ///
    /// # Errors
    ///
    /// Returns `AddSessionError::Duplicate` if a session with the same echo id
    /// and payload is already registered.
    ///
    /// # Panics
    ///
    /// Tokio's `mpsc::channel` panics if `channel_buf_size` is 0.
    fn add_session(
        &mut self,
        ip: net::IpAddr,
        id: EchoId,
        data: Vec<u8>,
        channel_buf_size: usize,
    ) -> Result<(SessionHandle, tmpsc::Receiver<ReplyTimestamp>), AddSessionError> {
        // Bytes needed for one reply: presumably the 8-byte ICMP echo header
        // (type/code/checksum + id/seq) followed by the payload.
        let buf_len = 4 + 4 + data.len();
        match ip {
            net::IpAddr::V4(_) => {
                // If the platform prefixes received IPv4 datagrams with the IP
                // header, leave room for the largest one (60 bytes).
                let prefix_len = if platform::ipv4_recv_prefix_ipv4_header() {
                    60
                } else {
                    0
                };
                let buf_len = prefix_len + buf_len;
                // Buffers only ever grow, so they stay large enough for every
                // session added so far.
                if self.v4_buf.len() < buf_len {
                    self.v4_buf.resize(buf_len, 0);
                }
            }
            net::IpAddr::V6(_) => {
                if self.v6_buf.len() < buf_len {
                    self.v6_buf.resize(buf_len, 0);
                }
            }
        }
        // One allocation of the echo data is shared by the send- and receive-side state.
        let echo_data = sync::Arc::new(SessionEchoData { id, data });
        let key = RecvHashKey {
            echo_data: echo_data.clone(),
        };
        let (tx, rx) = tmpsc::channel(channel_buf_size);
        // Reserve the receive-side entry first so a duplicate (id, data) pair is
        // rejected before any handle is allocated. The placeholder handle is
        // overwritten below once a free handle has been found.
        let recv_state = match self.recv_session_states.entry(key) {
            Entry::Occupied(_) => {
                return Err(AddSessionError::Duplicate);
            }
            Entry::Vacant(v) => {
                v.insert(RecvSessionState {
                    tx,
                    session_handle: SessionHandle { id: u64::MAX },
                })
            }
        };
        let send_state = SendSessionState {
            ip,
            echo_data: echo_data.clone(),
        };
        // Allocate handles from a wrapping counter, skipping any still in use.
        loop {
            let handle = SessionHandle {
                id: self.next_handle_id,
            };
            self.next_handle_id = self.next_handle_id.wrapping_add(1);
            match self.send_session_states.write().unwrap().entry(handle) {
                Entry::Occupied(_) => {
                    continue;
                }
                Entry::Vacant(v) => {
                    v.insert(send_state);
                    recv_state.session_handle = handle;
                    debug!(
                        "Added session: handle = {handle:?}, id = {id:?}, data = {}",
                        hex::encode(&echo_data.data)
                    );
                    return Ok((handle, rx));
                }
            }
        }
    }
}
/// Processes one received datagram. If it is an Echo Reply for a registered
/// session, a `ReplyTimestamp` is pushed onto that session's channel.
///
/// These datagrams are logged at debug level and otherwise ignored:
/// - ones that fail to decode;
/// - ones that are not of type `echo_reply_type` with code 0;
/// - ones whose body is not an Echo Reply;
/// - ones that match no session.
///
/// If the session's receiver has been dropped, the session is closed. A full
/// channel only produces a warning. Currently never returns an error.
fn handle_recv(
    msg: &[u8],
    echo_reply_type: u8,
    send_states: &mut sync::Arc<sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>>,
    recv_states: &mut hashbrown::HashMap<RecvHashKey, RecvSessionState>,
) -> Result<(), RecvError> {
    let decoded = match DecodedIcmpMsg::decode(msg) {
        Ok(decoded) => decoded,
        Err(_) => {
            debug!("ICMP message parse failed");
            return Ok(());
        }
    };

    let is_echo_reply = decoded.msg_type() == echo_reply_type && decoded.msg_code() == 0;
    if !is_echo_reply {
        debug!(
            "Skipping irrelevant ICMP message type {} code {}",
            decoded.msg_type(),
            decoded.msg_code()
        );
        return Ok(());
    }

    // The lookup key borrows the reply payload, so no allocation is needed.
    let (seq, key) = match parse_echo_reply(decoded.body()) {
        Some((id, seq, data)) => (seq, RefHashKey { id, data }),
        None => {
            debug!("Couldn't parse body as Echo Reply");
            return Ok(());
        }
    };

    let recv_state = match recv_states.get(&key) {
        Some(state) => state,
        None => {
            debug!("Couldn't find session for {key:?}");
            return Ok(());
        }
    };
    // Copy the handle out so the shared borrow of `recv_states` can end before
    // the map is mutated in the closed-channel branch.
    let session_handle = recv_state.session_handle;
    debug!("Reply for {:?}: seq {:?}", session_handle, seq);

    let timestamp = ReplyTimestamp {
        seq,
        received_at: time::Instant::now(),
    };
    match recv_state.tx.try_send(timestamp) {
        Ok(()) => {}
        Err(TrySendError::Full(_)) => warn!("Session channel overflow"),
        Err(TrySendError::Closed(_)) => {
            debug!("Session channel closed; closing session");
            handle_close_session(session_handle, send_states, recv_states);
        }
    }
    Ok(())
}
/// Removes a session from the shared send-side map and from the task-local
/// receive-side map. Unknown handles are ignored.
///
/// The receive-side entry is found through the echo data stored in the
/// send-side state. A session is therefore removed from `recv_session_states`
/// only if it was still present in `send_session_states`.
fn handle_close_session(
    session_handle: SessionHandle,
    send_session_states: &mut sync::Arc<
        sync::RwLock<hashbrown::HashMap<SessionHandle, SendSessionState>>,
    >,
    recv_session_states: &mut hashbrown::HashMap<RecvHashKey, RecvSessionState>,
) {
    // Take the entry in its own statement so the write guard is released at
    // once. As an `if let` scrutinee temporary it would stay held for the whole
    // body, including while the receive-side entry (and its channel `Sender`)
    // is dropped, which only prolongs contention for readers of the shared map.
    let removed = send_session_states
        .write()
        .unwrap()
        .remove(&session_handle);
    if let Some(send_state) = removed {
        recv_session_states.remove(&RecvHashKey {
            echo_data: send_state.echo_data,
        });
    }
}
/// Sends `val` on a oneshot reply channel. If the requester has already
/// dropped its receiver, this is logged at debug level instead of failing.
fn reply_if_possible<T>(reply: oneshot::Sender<T>, val: T) {
    let delivered = reply.send(val).is_ok();
    if !delivered {
        debug!("Could not reply - channel closed");
    }
}
/// Requests processed by `MultiplexTask` from its command channel.
pub(crate) enum MultiplexerCommand {
    /// Stop the task and drop all sessions. `()` is sent on the oneshot once done.
    Shutdown(oneshot::Sender<()>),
    /// Register a session for address `ip`, echo id `id` and payload `data`.
    /// The reply carries the new handle and the reply-timestamp receiver, or
    /// the reason the session could not be added.
    AddSession {
        ip: net::IpAddr,
        id: EchoId,
        data: Vec<u8>,
        reply: oneshot::Sender<
            Result<(SessionHandle, tmpsc::Receiver<ReplyTimestamp>), AddSessionError>,
        >,
    },
    /// Close the session identified by `session_handle` (a no-op if it is
    /// unknown). `()` is sent on the oneshot once done.
    CloseSession {
        session_handle: SessionHandle,
        reply: oneshot::Sender<()>,
    },
}
/// Opaque identifier for a session registered with the multiplexer task.
///
/// Ids come from a wrapping counter. An id is skipped while a session holding
/// it is still registered.
#[derive(Clone, Copy, Hash, Debug, PartialEq, Eq)]
pub struct SessionHandle {
    id: u64,
}
/// Delivered on a session's channel for each matching Echo Reply.
#[derive(Debug, PartialEq, Eq)]
pub struct ReplyTimestamp {
    /// Sequence number carried by the reply.
    pub seq: EchoSeq,
    /// `Instant::now()` taken when the task dispatched the reply (a user-space
    /// timestamp, not one reported by the socket).
    pub received_at: time::Instant,
}
/// Error indicating the multiplexer is no longer running.
// NOTE(review): not constructed in this section; presumably returned by the
// handle side once the task has shut down. Confirm.
#[derive(Debug, thiserror::Error)]
pub enum LifecycleError {
    #[error("Multiplexer has shut down")]
    Shutdown,
}
/// Why a session could not be added.
#[derive(Debug, thiserror::Error)]
pub enum AddSessionError {
    /// A session with the same echo id and payload is already registered.
    #[error("Duplicate session metadata")]
    Duplicate,
    /// The multiplexer could not process the request.
    #[error("Lifecycle error: {0}")]
    Lifecycle(#[from] LifecycleError),
}
/// Errors surfaced by one iteration of the task's receive/command loop.
/// They are logged and do not stop the task.
#[derive(Debug, thiserror::Error)]
enum RecvError {
    /// A socket receive failed.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
}
/// Errors from sending a ping for a session.
// NOTE(review): not constructed in this section; presumably used by the
// ping-sending side, which looks sessions up in the shared send-state map.
#[derive(Debug, thiserror::Error)]
pub enum SendPingError {
    /// The handle does not identify a registered session (presumably no entry
    /// in the send-side session map; confirm).
    #[error("Invalid session handle")]
    InvalidSessionHandle,
    /// Sending on the socket failed.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// The multiplexer could not process the request.
    #[error("Task error: {0}")]
    Lifecycle(#[from] LifecycleError),
}
/// Send-side per-session state, stored in the lock-protected map shared
/// between the task and its creator.
#[derive(Debug)]
pub(crate) struct SendSessionState {
    /// Address supplied when the session was added.
    pub(crate) ip: net::IpAddr,
    /// Echo id and payload. The same `Arc` is also the receive-side map key.
    pub(crate) echo_data: sync::Arc<SessionEchoData>,
}
/// Receive-side per-session state, owned by the task.
#[derive(Debug)]
struct RecvSessionState {
    /// Handle of the owning session. While `add_session` is still allocating
    /// a handle it briefly holds the placeholder `u64::MAX`.
    session_handle: SessionHandle,
    /// Sender half of the session's reply-timestamp channel.
    tx: tokio::sync::mpsc::Sender<ReplyTimestamp>,
}
/// Owned key of the receive-side session map.
///
/// Equality compares the pointed-to `SessionEchoData`. `Hash` is implemented
/// by hand so it matches `RefHashKey`, which is used for borrowed lookups.
#[derive(Debug, PartialEq, Eq)]
struct RecvHashKey {
    echo_data: sync::Arc<SessionEchoData>,
}
/// Echo identifier and payload that identify a session's replies.
#[derive(PartialEq, Eq)]
pub(crate) struct SessionEchoData {
    /// Echo identifier for the session.
    pub(crate) id: EchoId,
    /// Echo payload; a reply must carry exactly these bytes to match.
    pub(crate) data: Vec<u8>,
}
/// Formats the payload as a hex string rather than a list of bytes.
impl fmt::Debug for SessionEchoData {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex_data = hex::encode(&self.data);
        let mut out = f.debug_struct("SessionEchoData");
        out.field("id", &self.id);
        out.field("data", &hex_data);
        out.finish()
    }
}
impl hash::Hash for RecvHashKey {
fn hash<H: hash::Hasher>(&self, state: &mut H) {
self.echo_data.id.hash(state);
self.echo_data.data.hash(state);
}
}
/// Borrowed counterpart of `RecvHashKey`. It is built from a parsed reply so
/// the receive-side map can be searched without allocating.
#[derive(PartialEq, Eq)]
struct RefHashKey<'a> {
    id: EchoId,
    /// Payload bytes borrowed from the received message.
    data: &'a [u8],
}
/// Hashes the echo id and then the payload bytes, exactly like `RecvHashKey`.
/// Lookups through `hashbrown::Equivalent` depend on this.
impl hash::Hash for RefHashKey<'_> {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.id.hash(state);
        // `[u8]` hashing is shared by `&[u8]` and `Vec<u8>`, so this matches
        // the owned payload in `RecvHashKey`.
        self.data.hash(state);
    }
}
/// Lets a borrowed `RefHashKey` find a session's owned `RecvHashKey`. Two keys
/// match when both the echo id and the payload bytes are equal.
impl hashbrown::Equivalent<RecvHashKey> for RefHashKey<'_> {
    fn equivalent(&self, key: &RecvHashKey) -> bool {
        self.id == key.echo_data.id && self.data == key.echo_data.data
    }
}
/// Formats the borrowed payload as a hex string rather than a list of bytes.
impl fmt::Debug for RefHashKey<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex_data = hex::encode(self.data);
        let mut out = f.debug_struct("RefHashKey");
        out.field("id", &self.id);
        out.field("data", &hex_data);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hashbrown::Equivalent;
    use std::hash::{Hash, Hasher};

    /// Hashes `value` with a fixed, process-deterministic hasher.
    fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn hash_key_hash_equivalent_to_ref_hash_key() {
        let key = RecvHashKey {
            echo_data: SessionEchoData {
                id: EchoId::from_be(1234),
                data: vec![5, 6, 7, 8],
            }
            .into(),
        };
        let mut ref_key = RefHashKey {
            id: key.echo_data.id,
            data: &key.echo_data.data,
        };
        assert!(ref_key.equivalent(&key));
        // `Equivalent` lookups only work if both key types hash identically.
        assert_eq!(hash_of(&ref_key), hash_of(&key));
        ref_key.id = [42_u8; 2].into();
        assert!(!ref_key.equivalent(&key));
    }

    #[test]
    fn ref_hash_key_finds_session_in_map() {
        let mut map = hashbrown::HashMap::new();
        map.insert(
            RecvHashKey {
                echo_data: SessionEchoData {
                    id: EchoId::from_be(1234),
                    data: vec![5, 6, 7, 8],
                }
                .into(),
            },
            7_u32,
        );
        let payload = [5_u8, 6, 7, 8];
        let probe = RefHashKey {
            id: EchoId::from_be(1234),
            data: &payload,
        };
        assert_eq!(map.get(&probe), Some(&7));
        // A payload prefix must not match.
        let prefix = [5_u8, 6, 7];
        let miss = RefHashKey {
            id: EchoId::from_be(1234),
            data: &prefix,
        };
        assert_eq!(map.get(&miss), None);
    }
}