use serde::{Deserialize, Serialize};
use crate::{PacketBuilder, PacketError, PacketHeader, Checksumable};
/// ICMP message type.
///
/// Each variant's explicit discriminant is the numeric type value, so
/// `variant as u8` yields the byte that is serialized as the first byte of
/// the header.
///
/// `PartialEq`, `Eq` and `Hash` are derived so values can be compared and
/// used as map keys directly instead of going through `as u8` casts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum IcmpType {
    /// Echo reply (type 0).
    EchoReply = 0,
    /// Destination unreachable (type 3).
    DestinationUnreachable = 3,
    /// Source quench (type 4).
    SourceQuench = 4,
    /// Redirect (type 5).
    Redirect = 5,
    /// Echo request (type 8).
    EchoRequest = 8,
    /// Router advertisement (type 9).
    RouterAdvertisement = 9,
    /// Router solicitation (type 10).
    RouterSolicitation = 10,
    /// Time exceeded (type 11).
    TimeExceeded = 11,
    /// Parameter problem (type 12).
    ParameterProblem = 12,
    /// Timestamp request (type 13).
    TimestampRequest = 13,
    /// Timestamp reply (type 14).
    TimestampReply = 14,
}
/// Code values used with [`IcmpType::DestinationUnreachable`] messages.
///
/// Each variant's explicit discriminant is the numeric code, so
/// `variant as u8` can be passed wherever a raw ICMP code byte is expected.
///
/// `PartialEq`, `Eq` and `Hash` are derived so codes can be compared
/// directly instead of through `as u8` casts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum DestUnreachableCode {
    /// Network unreachable (code 0).
    NetworkUnreachable = 0,
    /// Host unreachable (code 1).
    HostUnreachable = 1,
    /// Protocol unreachable (code 2).
    ProtocolUnreachable = 2,
    /// Port unreachable (code 3).
    PortUnreachable = 3,
    /// Fragmentation needed (code 4).
    FragmentationNeeded = 4,
    /// Source route failed (code 5).
    SourceRouteFailed = 5,
}
/// The fixed eight-byte ICMP header.
///
/// Fields are listed in the order they are serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IcmpHeader {
    /// Message type; serialized as a single byte.
    pub message_type: IcmpType,
    /// Raw code byte qualifying the message type.
    pub code: u8,
    /// Checksum field; serialized big-endian.
    pub checksum: u16,
    /// Remaining four header bytes; serialized big-endian. The echo
    /// constructors place the identifier in the upper 16 bits and the
    /// sequence number in the lower 16 bits.
    pub rest_of_header: u32,
}
impl IcmpHeader {
fn new(message_type: IcmpType, code: u8, rest_of_header: u32) -> Self {
Self {
message_type,
code,
checksum: 0,
rest_of_header,
}
}
}
/// An ICMP message: the fixed header followed by an arbitrary payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IcmpPacket {
    /// The ICMP header.
    pub header: IcmpHeader,
    /// Bytes appended after the header when the packet is serialized.
    pub payload: Vec<u8>,
}
/// Step-by-step builder for [`IcmpPacket`].
///
/// Every field has a default (`None`, zero, or an empty payload); only the
/// message type must be set before `build` succeeds.
#[derive(Debug, Default)]
pub struct IcmpBuilder {
    /// Message type; `None` until explicitly set.
    message_type: Option<IcmpType>,
    /// Raw code byte; defaults to 0.
    code: u8,
    /// Trailing four header bytes; defaults to 0.
    rest_of_header: u32,
    /// Payload bytes; defaults to empty.
    payload: Vec<u8>,
}
impl IcmpBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn message_type(mut self, message_type: IcmpType) -> Self {
self.message_type = Some(message_type);
self
}
pub fn code(mut self, code: u8) -> Self {
self.code = code;
self
}
pub fn rest_of_header(mut self, rest_of_header: u32) -> Self {
self.rest_of_header = rest_of_header;
self
}
pub fn payload(mut self, payload: Vec<u8>) -> Self {
self.payload = payload;
self
}
pub fn build(self) -> Result<IcmpPacket, PacketError> {
let message_type = self.message_type.ok_or_else(||
PacketError::InvalidFieldValue("ICMP message type not set".to_string()))?;
let packet = IcmpPacket {
header: IcmpHeader::new(message_type, self.code, self.rest_of_header),
payload: self.payload,
};
Ok(packet)
}
}
impl IcmpPacket {
pub fn builder() -> IcmpBuilder {
IcmpBuilder::new()
}
pub fn echo_request(identifier: u16, sequence: u16, payload: Vec<u8>) -> Result<Self, PacketError> {
let rest_of_header = ((identifier as u32) << 16) | (sequence as u32);
IcmpBuilder::new()
.message_type(IcmpType::EchoRequest)
.code(0)
.rest_of_header(rest_of_header)
.payload(payload)
.build()
}
pub fn echo_reply(identifier: u16, sequence: u16, payload: Vec<u8>) -> Result<Self, PacketError> {
let rest_of_header = ((identifier as u32) << 16) | (sequence as u32);
IcmpBuilder::new()
.message_type(IcmpType::EchoReply)
.code(0)
.rest_of_header(rest_of_header)
.payload(payload)
.build()
}
}
impl PacketHeader for IcmpHeader {
    /// The ICMP header is always eight bytes long.
    fn header_length(&self) -> usize {
        8
    }

    /// Serializes the header as: type, code, checksum (big-endian),
    /// rest-of-header (big-endian).
    ///
    /// This implementation never returns an error.
    fn as_bytes(&self) -> Result<Vec<u8>, PacketError> {
        let [checksum_hi, checksum_lo] = self.checksum.to_be_bytes();
        let type_byte = self.message_type as u8;

        let mut out = Vec::with_capacity(self.header_length());
        out.extend_from_slice(&[type_byte, self.code, checksum_hi, checksum_lo]);
        out.extend_from_slice(&self.rest_of_header.to_be_bytes());
        Ok(out)
    }
}
impl Checksumable for IcmpHeader {
    /// Computes the 16-bit one's-complement checksum of the serialized header.
    ///
    /// The sum runs over the bytes returned by `as_bytes`, which include the
    /// currently stored `checksum` field; with that field at zero the result
    /// is the value to store. Only the header is summed; the payload is not
    /// reachable from here.
    ///
    /// NOTE(review): RFC 792 defines the ICMP checksum over the whole message
    /// (header plus data) - confirm how callers account for the payload.
    fn calculate_checksum(&self) -> u16 {
        let bytes = self.as_bytes().unwrap();

        // Add up big-endian 16-bit words; a trailing odd byte is treated as
        // the high byte of a word padded with zero.
        let mut folded: u32 = bytes
            .chunks(2)
            .map(|pair| {
                let high = u32::from(pair[0]) << 8;
                let low = pair.get(1).map_or(0, |&b| u32::from(b));
                high | low
            })
            .sum();

        // Fold carries back into the low 16 bits until none remain.
        while folded > 0xFFFF {
            folded = (folded & 0xFFFF) + (folded >> 16);
        }

        (!folded) as u16
    }

    /// Returns `true` when recomputing the checksum over the header,
    /// including its stored checksum field, yields zero.
    fn verify_checksum(&self) -> bool {
        0 == self.calculate_checksum()
    }
}
impl PacketBuilder for IcmpPacket {
    /// Serializes the packet as the header bytes followed by the payload.
    ///
    /// NOTE(review): the checksum is emitted exactly as stored in
    /// `header.checksum`; nothing here computes it and `IcmpHeader::new`
    /// initializes it to 0 - confirm that callers fill it in before sending.
    ///
    /// # Errors
    ///
    /// Propagates any error from serializing the header.
    fn build(&self) -> Result<Vec<u8>, PacketError> {
        self.header.as_bytes().map(|mut wire| {
            wire.extend_from_slice(&self.payload);
            wire
        })
    }

    /// Total serialized length: header length plus payload length.
    fn length(&self) -> usize {
        self.payload.len() + self.header.header_length()
    }

    /// Checks the payload size against a per-type upper bound.
    ///
    /// Echo request/reply payloads may be up to 65507 bytes; every other
    /// message type is limited to 1500 bytes.
    ///
    /// # Errors
    ///
    /// Returns `PacketError::InvalidLength` when the payload exceeds the bound.
    fn validate(&self) -> Result<(), PacketError> {
        let max_payload = match self.header.message_type {
            IcmpType::EchoRequest | IcmpType::EchoReply => 65507,
            _ => 1500,
        };

        if self.payload.len() > max_payload {
            Err(PacketError::InvalidLength)
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload shared by the echo tests.
    fn greeting() -> Vec<u8> {
        b"Hello, World!".to_vec()
    }

    /// An echo request keeps its type, code and payload and passes validation.
    #[test]
    fn test_icmp_echo_request() {
        let body = greeting();
        let request = IcmpPacket::echo_request(1, 1, body.clone()).unwrap();

        assert_eq!(request.header.message_type as u8, IcmpType::EchoRequest as u8);
        assert_eq!(request.header.code, 0);
        assert_eq!(request.payload, body);
        assert!(request.validate().is_ok());
    }

    /// An echo reply keeps its type, code and payload and passes validation.
    #[test]
    fn test_icmp_echo_reply() {
        let body = greeting();
        let reply = IcmpPacket::echo_reply(1, 1, body.clone()).unwrap();

        assert_eq!(reply.header.message_type as u8, IcmpType::EchoReply as u8);
        assert_eq!(reply.header.code, 0);
        assert_eq!(reply.payload, body);
        assert!(reply.validate().is_ok());
    }

    /// The builder carries an arbitrary type/code pair through to the packet.
    #[test]
    fn test_icmp_builder() {
        let port_unreachable = DestUnreachableCode::PortUnreachable as u8;
        let built = IcmpPacket::builder()
            .message_type(IcmpType::DestinationUnreachable)
            .code(port_unreachable)
            .rest_of_header(0)
            .payload(vec![1, 2, 3, 4])
            .build()
            .unwrap();

        assert_eq!(built.header.message_type as u8, IcmpType::DestinationUnreachable as u8);
        assert_eq!(built.header.code, port_unreachable);
        assert!(built.validate().is_ok());
    }

    /// Construction accepts an oversized echo payload; validation rejects it.
    #[test]
    fn test_invalid_payload_size() {
        // One byte more than the echo payload limit.
        let oversized = vec![0; 65508];
        let constructed = IcmpPacket::echo_request(1, 1, oversized);
        assert!(constructed.is_ok());

        let outcome = constructed.unwrap().validate();
        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(PacketError::InvalidLength)),
            "Expected InvalidLength error"
        );
    }
}