use socket2::Type as SockType;
use std::{io, net::SocketAddr, time::Duration};
use crate::SocketFamily;
/// ICMP protocol version a configuration targets.
///
/// `IcmpConfig::new` maps `V4` to `SocketFamily::IPV4` and `V6` to
/// `SocketFamily::IPV6`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IcmpKind {
    /// ICMP over IPv4.
    V4,
    /// ICMPv6 over IPv6.
    V6,
}

/// Kind of OS socket to open for ICMP traffic.
///
/// Converted to and from `socket2`'s socket type by the crate-internal
/// helpers on this type (`DGRAM` and `RAW` respectively).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IcmpSocketType {
    /// Datagram socket (`SOCK_DGRAM`).
    Dgram,
    /// Raw socket (`SOCK_RAW`).
    Raw,
}
impl IcmpSocketType {
    /// Returns `true` if this is [`IcmpSocketType::Dgram`].
    pub fn is_dgram(&self) -> bool {
        *self == Self::Dgram
    }

    /// Returns `true` if this is [`IcmpSocketType::Raw`].
    pub fn is_raw(&self) -> bool {
        *self == Self::Raw
    }

    /// Converts a `socket2` socket type into the matching ICMP socket kind.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error for any socket type
    /// other than `DGRAM` or `RAW`.
    pub(crate) fn try_from_sock_type(sock_type: SockType) -> io::Result<Self> {
        if sock_type == SockType::DGRAM {
            Ok(Self::Dgram)
        } else if sock_type == SockType::RAW {
            Ok(Self::Raw)
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid ICMP socket type",
            ))
        }
    }

    /// Returns the `socket2` socket type to request for this ICMP kind.
    pub(crate) fn to_sock_type(&self) -> SockType {
        match *self {
            Self::Dgram => SockType::DGRAM,
            Self::Raw => SockType::RAW,
        }
    }
}
/// Settings used to open an ICMP socket.
///
/// Build one with `IcmpConfig::new` / `IcmpConfig::from_family` plus the
/// `with_*` builders, then check it with `IcmpConfig::validate`.
#[derive(Debug, Clone)]
pub struct IcmpConfig {
    /// Address family of the socket (IPv4 or IPv6).
    pub socket_family: SocketFamily,
    /// Optional local address to bind; its family must match `socket_family`.
    pub bind: Option<SocketAddr>,
    /// IPv4 time-to-live; rejected by `validate` on IPv6 configurations.
    pub ttl: Option<u32>,
    /// IPv6 hop limit; rejected by `validate` on IPv4 configurations.
    pub hoplimit: Option<u32>,
    /// Receive timeout; must be non-zero when set.
    pub read_timeout: Option<Duration>,
    /// Send timeout; must be non-zero when set.
    pub write_timeout: Option<Duration>,
    /// Network interface name; must be non-empty when set.
    pub interface: Option<String>,
    /// Preferred socket type (`IcmpConfig::new` defaults this to `Dgram`).
    pub sock_type_hint: IcmpSocketType,
    /// NOTE(review): presumably a routing-table (FIB) number for platforms
    /// that support it — confirm; `validate` does not check it.
    pub fib: Option<u32>,
}
impl IcmpConfig {
    /// Creates a configuration for the given ICMP version.
    ///
    /// Every optional setting starts unset and the socket-type hint defaults
    /// to [`IcmpSocketType::Dgram`].
    pub fn new(kind: IcmpKind) -> Self {
        Self {
            socket_family: match kind {
                IcmpKind::V4 => SocketFamily::IPV4,
                IcmpKind::V6 => SocketFamily::IPV6,
            },
            bind: None,
            ttl: None,
            hoplimit: None,
            read_timeout: None,
            write_timeout: None,
            interface: None,
            sock_type_hint: IcmpSocketType::Dgram,
            fib: None,
        }
    }

    /// Creates a configuration from a socket family.
    ///
    /// Equivalent to [`IcmpConfig::new`] with the matching [`IcmpKind`]; the
    /// resulting `socket_family` is exactly `socket_family`.
    pub fn from_family(socket_family: SocketFamily) -> Self {
        Self::new(match socket_family {
            SocketFamily::IPV4 => IcmpKind::V4,
            SocketFamily::IPV6 => IcmpKind::V6,
        })
    }

    /// Sets the local address to bind to.
    #[must_use]
    pub fn with_bind(mut self, addr: SocketAddr) -> Self {
        self.bind = Some(addr);
        self
    }

    /// Sets the IPv4 TTL (rejected by [`IcmpConfig::validate`] for IPv6).
    #[must_use]
    pub fn with_ttl(mut self, ttl: u32) -> Self {
        self.ttl = Some(ttl);
        self
    }

    /// Sets the IPv6 hop limit (rejected by [`IcmpConfig::validate`] for IPv4).
    #[must_use]
    pub fn with_hoplimit(mut self, hops: u32) -> Self {
        self.hoplimit = Some(hops);
        self
    }

    /// Alias for [`IcmpConfig::with_hoplimit`].
    #[must_use]
    pub fn with_hop_limit(self, hops: u32) -> Self {
        self.with_hoplimit(hops)
    }

    /// Sets the receive timeout (must be non-zero to pass validation).
    #[must_use]
    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = Some(timeout);
        self
    }

    /// Sets the send timeout (must be non-zero to pass validation).
    #[must_use]
    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
        self.write_timeout = Some(timeout);
        self
    }

    /// Sets the network interface name (must be non-empty to pass validation).
    #[must_use]
    pub fn with_interface(mut self, iface: impl Into<String>) -> Self {
        self.interface = Some(iface.into());
        self
    }

    /// Sets the preferred socket type.
    #[must_use]
    pub fn with_sock_type(mut self, ty: IcmpSocketType) -> Self {
        self.sock_type_hint = ty;
        self
    }

    /// Sets the `fib` field (not checked by [`IcmpConfig::validate`]).
    #[must_use]
    pub fn with_fib(mut self, fib: u32) -> Self {
        self.fib = Some(fib);
        self
    }

    /// Checks the configuration for internally inconsistent or invalid values.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error, checked in this order, when:
    /// - `bind` is set and its address family differs from `socket_family`;
    /// - `hoplimit` is set on an IPv4 configuration;
    /// - `ttl` is set on an IPv6 configuration;
    /// - `ttl` or `hoplimit` exceeds 255;
    /// - `read_timeout` or `write_timeout` is zero;
    /// - `interface` is the empty string.
    pub fn validate(&self) -> io::Result<()> {
        // IPv4 TTL and IPv6 Hop Limit are both 8-bit header fields, so larger
        // values can never be applied to the socket.
        const MAX_HOPS: u32 = u8::MAX as u32;
        let invalid = |msg: &'static str| io::Error::new(io::ErrorKind::InvalidInput, msg);

        if let Some(addr) = self.bind {
            if SocketFamily::from_socket_addr(&addr) != self.socket_family {
                return Err(invalid("bind address family does not match socket_family"));
            }
        }
        if self.socket_family.is_v4() && self.hoplimit.is_some() {
            return Err(invalid("hoplimit is only supported for IPv6 ICMP sockets"));
        }
        if self.socket_family.is_v6() && self.ttl.is_some() {
            return Err(invalid("ttl is only supported for IPv4 ICMP sockets"));
        }
        if matches!(self.ttl, Some(ttl) if ttl > MAX_HOPS) {
            return Err(invalid("ttl must not exceed 255"));
        }
        if matches!(self.hoplimit, Some(hops) if hops > MAX_HOPS) {
            return Err(invalid("hoplimit must not exceed 255"));
        }
        if matches!(self.read_timeout, Some(timeout) if timeout.is_zero()) {
            return Err(invalid("read_timeout must be greater than zero"));
        }
        if matches!(self.write_timeout, Some(timeout) if timeout.is_zero()) {
            return Err(invalid("write_timeout must be greater than zero"));
        }
        if matches!(self.interface.as_deref(), Some("")) {
            return Err(invalid("interface must not be empty"));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn icmp_config_builders() {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let cfg = IcmpConfig::new(IcmpKind::V4)
            .with_bind(addr)
            .with_ttl(4)
            .with_interface("eth0")
            .with_sock_type(IcmpSocketType::Raw);
        assert_eq!(cfg.socket_family, SocketFamily::IPV4);
        assert_eq!(cfg.bind, Some(addr));
        assert_eq!(cfg.ttl, Some(4));
        assert_eq!(cfg.interface.as_deref(), Some("eth0"));
        assert_eq!(cfg.sock_type_hint, IcmpSocketType::Raw);
    }

    #[test]
    fn hop_limit_alias_sets_hoplimit() {
        let cfg = IcmpConfig::new(IcmpKind::V6).with_hop_limit(7);
        assert_eq!(cfg.hoplimit, Some(7));
    }

    #[test]
    fn from_family_sets_expected_kind() {
        let v4 = IcmpConfig::from_family(SocketFamily::IPV4);
        let v6 = IcmpConfig::from_family(SocketFamily::IPV6);
        assert_eq!(v4.socket_family, SocketFamily::IPV4);
        assert_eq!(v6.socket_family, SocketFamily::IPV6);
    }

    #[test]
    fn icmp_config_validate_rejects_family_mismatch() {
        let cfg = IcmpConfig::new(IcmpKind::V4).with_bind("[::1]:0".parse().unwrap());
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_accepts_well_formed_configs() {
        let v4 = IcmpConfig::new(IcmpKind::V4)
            .with_ttl(64)
            .with_read_timeout(Duration::from_secs(1))
            .with_write_timeout(Duration::from_secs(1))
            .with_interface("eth0");
        assert!(v4.validate().is_ok());

        let v6 = IcmpConfig::new(IcmpKind::V6)
            .with_hoplimit(64)
            .with_bind("[::1]:0".parse().unwrap());
        assert!(v6.validate().is_ok());
    }

    #[test]
    fn validate_rejects_options_for_the_other_family() {
        let err = IcmpConfig::new(IcmpKind::V4)
            .with_hoplimit(1)
            .validate()
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);

        let err = IcmpConfig::new(IcmpKind::V6).with_ttl(1).validate().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn validate_rejects_zero_timeouts_and_empty_interface() {
        let base = || IcmpConfig::new(IcmpKind::V4);
        assert!(base().with_read_timeout(Duration::ZERO).validate().is_err());
        assert!(base().with_write_timeout(Duration::ZERO).validate().is_err());
        assert!(base().with_interface("").validate().is_err());
    }

    #[test]
    fn socket_type_predicates_and_round_trip() {
        assert!(IcmpSocketType::Dgram.is_dgram());
        assert!(!IcmpSocketType::Dgram.is_raw());
        assert!(IcmpSocketType::Raw.is_raw());
        assert!(!IcmpSocketType::Raw.is_dgram());

        for ty in [IcmpSocketType::Dgram, IcmpSocketType::Raw] {
            let back = IcmpSocketType::try_from_sock_type(ty.to_sock_type()).unwrap();
            assert_eq!(back, ty);
        }
        let err = IcmpSocketType::try_from_sock_type(SockType::STREAM).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}