use std::marker::PhantomData;
use bytes::{Bytes, BytesMut};
use crate::IpVersion;
/// Size of the fixed ICMP echo header: type (1 byte), code (1 byte),
/// checksum (2 bytes), identifier (2 bytes), sequence number (2 bytes).
const ICMP_HEADER_LEN: usize = 8;
/// ICMPv4 "Echo Request" message type (RFC 792).
const ICMPV4_TYPE_ECHO_REQUEST: u8 = 0x08;
/// ICMPv4 "Echo Reply" message type (RFC 792).
const ICMPV4_TYPE_ECHO_REPLY: u8 = 0x00;
/// ICMPv6 "Echo Request" message type (RFC 4443).
const ICMPV6_TYPE_ECHO_REQUEST: u8 = 0x80;
/// ICMPv6 "Echo Reply" message type (RFC 4443).
const ICMPV6_TYPE_ECHO_REPLY: u8 = 0x81;
/// An encoded ICMP Echo Request message for IP version `V`.
///
/// Built with [`EchoRequestPacket::new`]; the encoded bytes are exposed to the
/// crate via `as_bytes`.
#[derive(Debug, Clone)]
pub struct EchoRequestPacket<V: IpVersion> {
    // Complete encoded message: 8-byte header followed by the payload.
    buf: Bytes,
    // Ties the packet to its IP version without storing a value of type `V`.
    _version: PhantomData<V>,
}
/// A parsed ICMP Echo Reply received from an address of IP version `V`.
#[derive(Debug, Clone)]
pub struct EchoReplyPacket<V: IpVersion> {
    // Address the reply was received from, as supplied by the caller.
    source: V,
    // Echo identifier (header bytes 4..6, big-endian).
    identifier: u16,
    // Echo sequence number (header bytes 6..8, big-endian).
    sequence_number: u16,
    // Everything after the 8-byte header.
    payload: Bytes,
}
impl<V: IpVersion> EchoRequestPacket<V> {
    /// Builds an ICMP Echo Request for IP version `V`.
    ///
    /// Layout (multi-byte fields big-endian):
    /// byte 0 type, byte 1 code (always 0), bytes 2..4 checksum,
    /// bytes 4..6 `identifier`, bytes 6..8 `sequence_number`, then `payload`.
    ///
    /// The checksum is the Internet checksum of the whole message computed with
    /// the checksum field set to zero.
    ///
    /// NOTE(review): for ICMPv6, RFC 4443 defines the checksum over an IPv6
    /// pseudo-header as well, which is not included here. Presumably the kernel
    /// fills in the ICMPv6 checksum for these sockets; confirm against the
    /// socket setup.
    pub fn new(identifier: u16, sequence_number: u16, payload: &[u8]) -> Self {
        let icmp_type = if V::IS_V4 {
            ICMPV4_TYPE_ECHO_REQUEST
        } else {
            ICMPV6_TYPE_ECHO_REQUEST
        };
        let [id_hi, id_lo] = identifier.to_be_bytes();
        let [seq_hi, seq_lo] = sequence_number.to_be_bytes();

        let mut message = BytesMut::with_capacity(ICMP_HEADER_LEN + payload.len());
        // Type, code, zeroed checksum placeholder, identifier, sequence number.
        message.extend_from_slice(&[icmp_type, 0, 0, 0, id_hi, id_lo, seq_hi, seq_lo]);
        message.extend_from_slice(payload);

        // Checksum is computed while bytes 2..4 are still zero, then patched in.
        let [sum_hi, sum_lo] = internet_checksum(&message).to_be_bytes();
        message[2] = sum_hi;
        message[3] = sum_lo;

        Self {
            buf: message.freeze(),
            _version: PhantomData,
        }
    }

    /// The complete encoded ICMP message (header followed by payload).
    pub(crate) fn as_bytes(&self) -> &[u8] {
        &self.buf[..]
    }
}
impl<V: IpVersion> EchoReplyPacket<V> {
    /// Parses an ICMP Echo Reply received from `source`.
    ///
    /// `buf` must begin at the ICMP header, with byte 0 being the ICMP type.
    /// NOTE(review): presumably the caller has already stripped any IP header;
    /// confirm.
    ///
    /// Returns `None` if `buf` is shorter than the 8-byte echo header, or if its
    /// type byte is not the Echo Reply type for `V` (0 for IPv4, 129 for IPv6).
    /// The code byte and the checksum are not validated. The payload is taken
    /// with `Bytes::slice`, so it shares `buf`'s storage instead of copying.
    pub(crate) fn from_reply(source: V, buf: Bytes) -> Option<Self> {
        let header = buf.get(..ICMP_HEADER_LEN)?;
        let wanted_type = if V::IS_V4 {
            ICMPV4_TYPE_ECHO_REPLY
        } else {
            ICMPV6_TYPE_ECHO_REPLY
        };
        if header[0] != wanted_type {
            return None;
        }

        let identifier = u16::from_be_bytes([header[4], header[5]]);
        let sequence_number = u16::from_be_bytes([header[6], header[7]]);
        Some(Self {
            source,
            identifier,
            sequence_number,
            payload: buf.slice(ICMP_HEADER_LEN..),
        })
    }

    /// Address the reply came from, as passed to `from_reply`.
    pub fn source(&self) -> V {
        self.source
    }

    /// Echo identifier from header bytes 4..6.
    pub fn identifier(&self) -> u16 {
        self.identifier
    }

    /// Echo sequence number from header bytes 6..8.
    pub fn sequence_number(&self) -> u16 {
        self.sequence_number
    }

    /// Bytes following the 8-byte echo header.
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }
}
/// Computes the RFC 1071 Internet checksum of `data`.
///
/// `data` is summed as big-endian 16-bit words using ones'-complement
/// arithmetic. An odd trailing byte is treated as the high byte of a final
/// word padded with zero. The result is the ones' complement of that sum.
/// If `data` already contains a correct checksum, the result is 0.
fn internet_checksum(data: &[u8]) -> u16 {
    let mut words = data.chunks_exact(2);
    // A u64 accumulator cannot overflow for any in-memory input. A u32 would
    // overflow after ~65_537 words of 0xFFFF (~128 KiB): a panic in debug
    // builds, a silently wrong checksum in release builds.
    let mut sum: u64 = words
        .by_ref()
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();
    if let [last] = words.remainder() {
        // Odd length: pad the final byte with a zero low byte.
        sum += u64::from(*last) << 8;
    }
    // Fold carries back into the low 16 bits (end-around carry).
    while sum >> 16 != 0 {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    // After folding, `sum` fits in 16 bits, so this cast is lossless.
    !(sum as u16)
}
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use bytes::Bytes;

    use super::{internet_checksum, EchoReplyPacket, EchoRequestPacket};

    #[test]
    fn from_reply_rejects_truncated_packet() {
        let buf = Bytes::from_static(&[0x00, 0x00]);
        let result = EchoReplyPacket::<Ipv4Addr>::from_reply(Ipv4Addr::LOCALHOST, buf);
        assert!(result.is_none(), "Should reject truncated packet");
    }

    #[test]
    fn from_reply_rejects_wrong_icmp_type() {
        let buf = Bytes::from_static(&[0x08, 0x00, 0x00, 0x00, 0x12, 0x34, 0x00, 0x01]);
        let result = EchoReplyPacket::<Ipv4Addr>::from_reply(Ipv4Addr::LOCALHOST, buf);
        assert!(result.is_none(), "Should reject Echo Request (type 8)");
    }

    #[test]
    fn from_reply_rejects_destination_unreachable() {
        let buf = Bytes::from_static(&[0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        let result = EchoReplyPacket::<Ipv4Addr>::from_reply(Ipv4Addr::LOCALHOST, buf);
        assert!(result.is_none(), "Should reject Destination Unreachable");
    }

    #[test]
    fn from_reply_accepts_valid_echo_reply() {
        let buf = Bytes::from_static(&[
            0x00, 0x00, 0x00, 0x00, 0x12, 0x34, 0x00, 0x01, b't', b'e', b's', b't',
        ]);
        let packet = EchoReplyPacket::<Ipv4Addr>::from_reply(Ipv4Addr::LOCALHOST, buf).unwrap();
        assert_eq!(packet.source(), Ipv4Addr::LOCALHOST);
        assert_eq!(packet.identifier(), 0x1234);
        assert_eq!(packet.sequence_number(), 0x0001);
        assert_eq!(packet.payload(), b"test");
    }

    #[test]
    fn internet_checksum_matches_rfc1071_example() {
        // RFC 1071 section 3: folded sum 0xddf2, so the checksum is 0x220d.
        let data = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
        assert_eq!(internet_checksum(&data), 0x220d);
    }

    #[test]
    fn internet_checksum_handles_empty_and_odd_length() {
        assert_eq!(internet_checksum(&[]), 0xFFFF);
        // Trailing byte 0x56 is padded to the word 0x5600.
        assert_eq!(internet_checksum(&[0x12, 0x34, 0x56]), 0x97CB);
    }

    #[test]
    fn echo_request_v4_has_expected_layout_and_valid_checksum() {
        let packet = EchoRequestPacket::<Ipv4Addr>::new(0x1234, 0x0001, b"test");
        let bytes = packet.as_bytes();
        assert_eq!(bytes.len(), 12);
        assert_eq!(bytes[0], 8, "IPv4 Echo Request type");
        assert_eq!(bytes[1], 0, "code");
        assert_eq!(&bytes[2..4], &[0xFD, 0xF0], "checksum");
        assert_eq!(&bytes[4..8], &[0x12, 0x34, 0x00, 0x01]);
        assert_eq!(&bytes[8..], b"test");
        // A message carrying a correct checksum sums to zero.
        assert_eq!(internet_checksum(bytes), 0);
    }
}