use crate::{
checks::internet::ipv4::{
IPV4_MIN_HEADER_LEN, validate_ipv4_header_available, validate_ipv4_header_length,
validate_ipv4_min_length, validate_ipv4_total_length, validate_ipv4_version,
},
errors::internet::ipv4::Ipv4Error,
};
use std::convert::TryFrom;
use std::net::Ipv4Addr;
/// A borrowed, zero-copy view of an IPv4 packet.
///
/// Multi-byte integer fields hold values already decoded from big-endian
/// (network) order by the `TryFrom<&[u8]>` implementation. Packed fields
/// (`version_ihl`, `dscp_ecn`, `flags_fragment`) are kept exactly as read;
/// use the accessor methods to split them into their parts.
///
/// `options` and `payload` borrow from the buffer the packet was parsed from,
/// so the view is cheap to copy.
#[cfg_attr(doc, aquamarine::aquamarine)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ipv4Packet<'a> {
    /// Version in the high nibble, Internet Header Length (in 32-bit words)
    /// in the low nibble.
    pub version_ihl: u8,
    /// DSCP in the upper six bits, ECN in the lower two bits.
    pub dscp_ecn: u8,
    /// Total length in bytes declared by the header (header plus payload).
    pub total_length: u16,
    /// Identification field.
    pub identification: u16,
    /// Flags in the upper three bits, fragment offset in the lower 13 bits.
    pub flags_fragment: u16,
    /// Time-to-live field.
    pub ttl: u8,
    /// Protocol number of the encapsulated payload.
    pub protocol: u8,
    /// Header checksum as read from the buffer; parsing does not verify it.
    pub header_checksum: u16,
    /// Source address.
    pub source_addr: Ipv4Addr,
    /// Destination address.
    pub dest_addr: Ipv4Addr,
    /// Header bytes between the fixed 20-byte header and `header_length()`;
    /// empty when there are none.
    pub options: &'a [u8],
    /// Bytes after the header, ending at `total_length` or at the end of the
    /// parsed buffer, whichever comes first.
    pub payload: &'a [u8],
}
impl Ipv4Packet<'_> {
    /// IP version number: the high nibble of `version_ihl`.
    pub fn version(&self) -> u8 {
        (self.version_ihl & 0b1111_0000) >> 4
    }

    /// Internet Header Length in 32-bit words: the low nibble of `version_ihl`.
    pub fn ihl(&self) -> u8 {
        self.version_ihl & 0b0000_1111
    }

    /// Header length in bytes (`ihl()` words of four bytes each).
    pub fn header_length(&self) -> usize {
        usize::from(self.ihl()) * 4
    }

    /// Differentiated Services Code Point: the upper six bits of `dscp_ecn`.
    pub fn dscp(&self) -> u8 {
        (self.dscp_ecn & 0b1111_1100) >> 2
    }

    /// Explicit Congestion Notification: the lower two bits of `dscp_ecn`.
    pub fn ecn(&self) -> u8 {
        self.dscp_ecn & 0b0000_0011
    }

    /// The three flag bits: the top three bits of `flags_fragment`.
    pub fn flags(&self) -> u8 {
        // The top three bits of the 16-bit word are the top three bits of
        // its high byte, so no narrowing cast is needed.
        let [high_byte, _] = self.flags_fragment.to_be_bytes();
        high_byte >> 5
    }

    /// Fragment offset: the lower 13 bits of `flags_fragment`.
    pub fn fragment_offset(&self) -> u16 {
        self.flags_fragment & 0b0001_1111_1111_1111
    }
}
/// Parses an [`Ipv4Packet`] view from a raw byte buffer without copying.
///
/// Checks run in this order and the first failing one is propagated with `?`:
/// minimum buffer length, version nibble, header length, header bytes
/// available in the buffer, declared total length. Each check is delegated to
/// the matching `validate_ipv4_*` helper.
///
/// Bytes past `total_length` (for example, trailing padding) are not part of
/// the payload. If the buffer is shorter than `total_length`, the payload is
/// whatever follows the header up to the end of the buffer.
/// NOTE(review): that truncated case is only reachable if
/// `validate_ipv4_total_length` accepts such buffers; confirm.
impl<'a> TryFrom<&'a [u8]> for Ipv4Packet<'a> {
    type Error = Ipv4Error;

    fn try_from(data: &'a [u8]) -> Result<Self, Self::Error> {
        // Fixed-offset readers for the base header. Direct indexing assumes
        // `validate_ipv4_min_length` guarantees the first 20 bytes exist.
        let read_u16 = |at: usize| u16::from_be_bytes([data[at], data[at + 1]]);
        let read_addr =
            |at: usize| Ipv4Addr::new(data[at], data[at + 1], data[at + 2], data[at + 3]);

        validate_ipv4_min_length(data)?;

        let version_ihl = data[0];
        validate_ipv4_version(version_ihl >> 4)?;

        // IHL counts 32-bit words; convert it to a byte count.
        let header_len = usize::from(version_ihl & 0x0F) * 4;
        validate_ipv4_header_length(header_len)?;
        validate_ipv4_header_available(data.len(), header_len)?;

        let dscp_ecn = data[1];
        let total_length = read_u16(2);
        validate_ipv4_total_length(total_length, header_len, data.len())?;

        // Whatever lies between the fixed header and `header_len` is options.
        let options: &'a [u8] = if header_len <= IPV4_MIN_HEADER_LEN {
            &[]
        } else {
            &data[IPV4_MIN_HEADER_LEN..header_len]
        };

        // The payload ends at the declared total length, clamped to the bytes
        // actually present; a total length not beyond the header leaves none.
        let payload_end = usize::from(total_length).min(data.len());
        let payload: &'a [u8] = if payload_end > header_len {
            &data[header_len..payload_end]
        } else {
            &[]
        };

        Ok(Self {
            version_ihl,
            dscp_ecn,
            total_length,
            identification: read_u16(4),
            flags_fragment: read_u16(6),
            ttl: data[8],
            protocol: data[9],
            header_checksum: read_u16(10),
            source_addr: read_addr(12),
            dest_addr: read_addr(16),
            options,
            payload,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Parses a 60-byte packet (20-byte header, 40-byte payload) and checks
    /// every decoded field, including the packed ones and the payload slice.
    #[test]
    fn test_ipv4_packet_parsing() {
        let data = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xc0, 0xa8,
            0x01, 0x01, 0xc0, 0xa8, 0x01, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x50, 0x02, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        let packet = Ipv4Packet::try_from(&data[..]).unwrap();
        assert_eq!(packet.version(), 4);
        assert_eq!(packet.ihl(), 5);
        assert_eq!(packet.header_length(), 20);
        assert_eq!(packet.dscp(), 0);
        assert_eq!(packet.ecn(), 0);
        assert_eq!(packet.total_length, 60);
        assert_eq!(packet.identification, 0x1c46);
        // 0x4000: only the second flag bit is set, fragment offset is zero.
        assert_eq!(packet.flags(), 0b010);
        assert_eq!(packet.fragment_offset(), 0);
        assert_eq!(packet.ttl, 64);
        assert_eq!(packet.protocol, 6);
        assert_eq!(packet.header_checksum, 0);
        assert_eq!(packet.source_addr, Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(packet.dest_addr, Ipv4Addr::new(192, 168, 1, 2));
        assert!(packet.options.is_empty());
        assert_eq!(packet.payload.len(), 40);
        assert_eq!(packet.payload, &data[20..60]);
    }

    /// IHL 6 gives a 24-byte header: four option bytes, then the payload.
    #[test]
    fn test_ipv4_packet_with_options() {
        let data = [
            0x46, 0x00, 0x00, 0x1c, 0x00, 0x01, 0x00, 0x00, 0x40, 0x11, 0x00, 0x00, 0x0a, 0x00,
            0x00, 0x01, 0x0a, 0x00, 0x00, 0x02, 0x01, 0x01, 0x01, 0x00, 0xde, 0xad, 0xbe, 0xef,
        ];
        let packet = Ipv4Packet::try_from(&data[..]).unwrap();
        assert_eq!(packet.ihl(), 6);
        assert_eq!(packet.header_length(), 24);
        assert_eq!(packet.total_length, 28);
        assert_eq!(packet.options, &[0x01, 0x01, 0x01, 0x00][..]);
        assert_eq!(packet.payload, &[0xde, 0xad, 0xbe, 0xef][..]);
    }

    /// A version nibble other than 4 is rejected.
    #[test]
    fn test_invalid_version() {
        let data = [
            0x65, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xc0, 0xa8,
            0x01, 0x01, 0xc0, 0xa8, 0x01, 0x02,
        ];
        let result = Ipv4Packet::try_from(&data[..]);
        assert!(matches!(result, Err(Ipv4Error::InvalidVersion(6))));
    }

    /// IHL 1 (a 4-byte header) is rejected with the computed byte length.
    #[test]
    fn test_invalid_header_length() {
        let data = [
            0x41, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xc0, 0xa8,
            0x01, 0x01, 0xc0, 0xa8, 0x01, 0x02,
        ];
        let result = Ipv4Packet::try_from(&data[..]);
        assert!(matches!(result, Err(Ipv4Error::InvalidHeaderLength(4))));
    }
}