use crate::error::{NetError, NetResult};
use crate::IPV4_HEADER_MIN_SIZE;
/// An IPv4 address held as four octets in network (big-endian) order.
///
/// `Default` yields `0.0.0.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[repr(transparent)]
pub struct Ipv4Addr(pub [u8; 4]);

impl Ipv4Addr {
    /// `255.255.255.255`.
    pub const BROADCAST: Self = Self([255; 4]);
    /// `0.0.0.0`.
    pub const UNSPECIFIED: Self = Self([0; 4]);
    /// `127.0.0.1`.
    pub const LOCALHOST: Self = Self([127, 0, 0, 1]);

    /// Builds the address `a.b.c.d`.
    #[inline]
    #[must_use]
    pub const fn new(a: u8, b: u8, c: u8, d: u8) -> Self {
        Self::from_bytes([a, b, c, d])
    }

    /// Wraps four octets, most significant first.
    #[inline]
    #[must_use]
    pub const fn from_bytes(bytes: [u8; 4]) -> Self {
        Self(bytes)
    }

    /// Borrows the four octets.
    #[inline]
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8; 4] {
        &self.0
    }

    /// Interprets the octets as a big-endian `u32` (first octet is the high byte).
    #[inline]
    #[must_use]
    pub const fn to_u32(&self) -> u32 {
        u32::from_be_bytes(self.0)
    }

    /// Inverse of [`Ipv4Addr::to_u32`].
    #[inline]
    #[must_use]
    pub const fn from_u32(addr: u32) -> Self {
        Self(addr.to_be_bytes())
    }

    /// `true` only for `255.255.255.255`.
    #[inline]
    #[must_use]
    pub const fn is_broadcast(&self) -> bool {
        self.to_u32() == u32::MAX
    }

    /// `true` only for `0.0.0.0`.
    #[inline]
    #[must_use]
    pub const fn is_unspecified(&self) -> bool {
        self.to_u32() == 0
    }

    /// `true` for any address in `127.0.0.0/8`.
    #[inline]
    #[must_use]
    pub const fn is_loopback(&self) -> bool {
        matches!(self.0, [127, ..])
    }

    /// `true` when the first octet is in `224..=239` (`224.0.0.0/4`).
    #[inline]
    #[must_use]
    pub const fn is_multicast(&self) -> bool {
        matches!(self.0[0], 224..=239)
    }

    /// `true` for any address in `169.254.0.0/16`.
    #[inline]
    #[must_use]
    pub const fn is_link_local(&self) -> bool {
        matches!(self.0, [169, 254, ..])
    }

    /// `true` for `10.0.0.0/8`, `172.16.0.0/12`, or `192.168.0.0/16`.
    #[inline]
    #[must_use]
    pub const fn is_private(&self) -> bool {
        matches!(self.0, [10, ..] | [172, 16..=31, ..] | [192, 168, ..])
    }

    /// Reads an address from the first four bytes of `bytes`; extra bytes are ignored.
    ///
    /// # Errors
    /// Returns `NetError::PacketTooShort` when fewer than four bytes are available.
    #[inline]
    pub fn parse(bytes: &[u8]) -> NetResult<Self> {
        let Some(head) = bytes.get(..4) else {
            return Err(NetError::PacketTooShort);
        };
        let mut octets = [0u8; 4];
        octets.copy_from_slice(head);
        Ok(Self(octets))
    }

    /// `true` when `self` and `other` agree on every bit selected by `mask`.
    #[inline]
    #[must_use]
    pub const fn is_same_subnet(&self, other: &Self, mask: &Self) -> bool {
        let netmask = mask.to_u32();
        (self.to_u32() & netmask) == (other.to_u32() & netmask)
    }
}
/// Value of the IPv4 header's protocol field.
///
/// ICMP, TCP and UDP get named variants; every other number is carried
/// verbatim in [`Protocol::Unknown`]. The derived `PartialEq`/`Hash` compare
/// variants, so `Protocol::Unknown(6)` is *not* equal to `Protocol::Tcp`;
/// construct values through [`Protocol::from_u8`] to get the canonical variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Protocol {
    /// Protocol number 1.
    Icmp = 1,
    /// Protocol number 6.
    Tcp = 6,
    /// Protocol number 17.
    Udp = 17,
    /// Any other protocol number. `= 255` is only the enum's internal
    /// discriminant; the wire value is the carried `u8`.
    Unknown(u8) = 255,
}

impl Protocol {
    /// Wire number for [`Protocol::Icmp`].
    const ICMP_NUMBER: u8 = 1;
    /// Wire number for [`Protocol::Tcp`].
    const TCP_NUMBER: u8 = 6;
    /// Wire number for [`Protocol::Udp`].
    const UDP_NUMBER: u8 = 17;

    /// Maps a raw protocol number to its variant, falling back to `Unknown`.
    #[inline]
    #[must_use]
    pub const fn from_u8(value: u8) -> Self {
        match value {
            Self::ICMP_NUMBER => Self::Icmp,
            Self::TCP_NUMBER => Self::Tcp,
            Self::UDP_NUMBER => Self::Udp,
            unrecognised => Self::Unknown(unrecognised),
        }
    }

    /// Returns the raw protocol number to place on the wire.
    #[inline]
    #[must_use]
    pub const fn to_u8(self) -> u8 {
        match self {
            Self::Icmp => Self::ICMP_NUMBER,
            Self::Tcp => Self::TCP_NUMBER,
            Self::Udp => Self::UDP_NUMBER,
            Self::Unknown(raw) => raw,
        }
    }
}
/// The two defined flag bits of the IPv4 flags/fragment-offset word.
///
/// The reserved top bit (0x8000) and the 13 fragment-offset bits are ignored by
/// [`Ipv4Flags::from_raw`] and never set by [`Ipv4Flags::to_raw`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Ipv4Flags {
    /// Don't Fragment (bit 0x4000).
    pub dont_fragment: bool,
    /// More Fragments (bit 0x2000).
    pub more_fragments: bool,
}

impl Ipv4Flags {
    /// Position of the Don't Fragment bit in the 16-bit word.
    const DF_MASK: u16 = 0x4000;
    /// Position of the More Fragments bit in the 16-bit word.
    const MF_MASK: u16 = 0x2000;

    /// Extracts DF and MF from the raw flags/fragment-offset word.
    #[inline]
    #[must_use]
    pub const fn from_raw(raw: u16) -> Self {
        let dont_fragment = (raw & Self::DF_MASK) == Self::DF_MASK;
        let more_fragments = (raw & Self::MF_MASK) == Self::MF_MASK;
        Self {
            dont_fragment,
            more_fragments,
        }
    }

    /// Produces the flag bits only; OR the fragment offset in separately.
    #[inline]
    #[must_use]
    pub const fn to_raw(self) -> u16 {
        let df = if self.dont_fragment { Self::DF_MASK } else { 0 };
        let mf = if self.more_fragments { Self::MF_MASK } else { 0 };
        df | mf
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ipv4Header {
pub version: u8,
pub ihl: u8,
pub dscp: u8,
pub ecn: u8,
pub total_length: u16,
pub identification: u16,
pub flags: Ipv4Flags,
pub fragment_offset: u16,
pub ttl: u8,
pub protocol: Protocol,
pub checksum: u16,
pub src_addr: Ipv4Addr,
pub dst_addr: Ipv4Addr,
}
impl Ipv4Header {
pub const DEFAULT_TTL: u8 = 64;
#[inline]
pub fn parse(bytes: &[u8]) -> NetResult<(Self, &[u8])> {
if bytes.len() < IPV4_HEADER_MIN_SIZE {
return Err(NetError::PacketTooShort);
}
let version_ihl = bytes[0];
let version = version_ihl >> 4;
let ihl = version_ihl & 0x0F;
if version != 4 {
return Err(NetError::InvalidIpVersion);
}
if ihl < 5 {
return Err(NetError::InvalidIpHeaderLength);
}
let header_len = (ihl as usize) * 4;
if bytes.len() < header_len {
return Err(NetError::PacketTooShort);
}
let total_length = u16::from_be_bytes([bytes[2], bytes[3]]);
if (total_length as usize) < header_len {
return Err(NetError::InvalidIpHeaderLength);
}
let tos = bytes[1];
let dscp = tos >> 2;
let ecn = tos & 0x03;
let identification = u16::from_be_bytes([bytes[4], bytes[5]]);
let flags_frag = u16::from_be_bytes([bytes[6], bytes[7]]);
let flags = Ipv4Flags::from_raw(flags_frag);
let fragment_offset = flags_frag & 0x1FFF;
let ttl = bytes[8];
let protocol = Protocol::from_u8(bytes[9]);
let checksum = u16::from_be_bytes([bytes[10], bytes[11]]);
let src_addr = Ipv4Addr::parse(&bytes[12..16])?;
let dst_addr = Ipv4Addr::parse(&bytes[16..20])?;
let header = Self {
version,
ihl,
dscp,
ecn,
total_length,
identification,
flags,
fragment_offset,
ttl,
protocol,
checksum,
src_addr,
dst_addr,
};
let payload_start = header_len;
let payload_end = (total_length as usize).min(bytes.len());
let payload = &bytes[payload_start..payload_end];
Ok((header, payload))
}
#[inline]
pub fn serialize(&self, buf: &mut [u8]) -> NetResult<usize> {
if buf.len() < IPV4_HEADER_MIN_SIZE {
return Err(NetError::BufferTooSmall);
}
buf[0] = (self.version << 4) | self.ihl;
buf[1] = (self.dscp << 2) | self.ecn;
buf[2..4].copy_from_slice(&self.total_length.to_be_bytes());
buf[4..6].copy_from_slice(&self.identification.to_be_bytes());
let flags_frag = self.flags.to_raw() | self.fragment_offset;
buf[6..8].copy_from_slice(&flags_frag.to_be_bytes());
buf[8] = self.ttl;
buf[9] = self.protocol.to_u8();
buf[10] = 0;
buf[11] = 0;
buf[12..16].copy_from_slice(&self.src_addr.0);
buf[16..20].copy_from_slice(&self.dst_addr.0);
let checksum = Self::compute_checksum(&buf[..IPV4_HEADER_MIN_SIZE]);
buf[10..12].copy_from_slice(&checksum.to_be_bytes());
Ok(IPV4_HEADER_MIN_SIZE)
}
#[inline]
#[must_use]
pub fn compute_checksum(header: &[u8]) -> u16 {
let mut sum: u32 = 0;
for chunk in header.chunks(2) {
let word = if chunk.len() == 2 {
u16::from_be_bytes([chunk[0], chunk[1]])
} else {
u16::from_be_bytes([chunk[0], 0])
};
sum += u32::from(word);
}
while sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16);
}
!sum as u16
}
#[inline]
#[must_use]
pub fn verify_checksum(header: &[u8]) -> bool {
Self::compute_checksum(header) == 0
}
#[inline]
#[must_use]
pub const fn header_len(&self) -> usize {
(self.ihl as usize) * 4
}
#[inline]
#[must_use]
pub const fn payload_len(&self) -> usize {
(self.total_length as usize).saturating_sub(self.header_len())
}
#[inline]
#[must_use]
pub const fn new(
src_addr: Ipv4Addr,
dst_addr: Ipv4Addr,
protocol: Protocol,
payload_len: u16,
) -> Self {
Self {
version: 4,
ihl: 5,
dscp: 0,
ecn: 0,
total_length: IPV4_HEADER_MIN_SIZE as u16 + payload_len,
identification: 0,
flags: Ipv4Flags {
dont_fragment: true,
more_fragments: false,
},
fragment_offset: 0,
ttl: Self::DEFAULT_TTL,
protocol,
checksum: 0, src_addr,
dst_addr,
}
}
#[inline]
pub fn decrement_ttl(&mut self) -> NetResult<()> {
if self.ttl == 0 {
return Err(NetError::TtlExpired);
}
self.ttl -= 1;
if self.ttl == 0 {
return Err(NetError::TtlExpired);
}
Ok(())
}
}
/// Fluent builder for [`Ipv4Header`] values.
///
/// The starting point is exactly what [`Ipv4Header::new`] produces for an
/// unspecified source and destination, UDP, and an empty payload. Each setter
/// consumes and returns the builder. [`Ipv4HeaderBuilder::build`] fills in
/// `total_length` from the payload length.
#[derive(Debug, Clone)]
pub struct Ipv4HeaderBuilder {
    header: Ipv4Header,
}

impl Ipv4HeaderBuilder {
    /// Creates a builder seeded from `Ipv4Header::new`, so the defaults live in one place.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            header: Ipv4Header::new(
                Ipv4Addr::UNSPECIFIED,
                Ipv4Addr::UNSPECIFIED,
                Protocol::Udp,
                0,
            ),
        }
    }

    /// Sets the source address.
    #[inline]
    #[must_use]
    pub const fn src(mut self, addr: Ipv4Addr) -> Self {
        self.header.src_addr = addr;
        self
    }

    /// Sets the destination address.
    #[inline]
    #[must_use]
    pub const fn dst(mut self, addr: Ipv4Addr) -> Self {
        self.header.dst_addr = addr;
        self
    }

    /// Sets the carried protocol.
    #[inline]
    #[must_use]
    pub const fn protocol(mut self, proto: Protocol) -> Self {
        self.header.protocol = proto;
        self
    }

    /// Sets the time-to-live.
    #[inline]
    #[must_use]
    pub const fn ttl(mut self, ttl: u8) -> Self {
        self.header.ttl = ttl;
        self
    }

    /// Sets the identification field.
    #[inline]
    #[must_use]
    pub const fn identification(mut self, id: u16) -> Self {
        self.header.identification = id;
        self
    }

    /// Sets or clears the Don't Fragment flag.
    #[inline]
    #[must_use]
    pub const fn dont_fragment(mut self, df: bool) -> Self {
        self.header.flags.dont_fragment = df;
        self
    }

    /// Sets the 6-bit DSCP field; bits above the low six are discarded.
    #[inline]
    #[must_use]
    pub const fn dscp(mut self, dscp: u8) -> Self {
        self.header.dscp = dscp & 0x3F;
        self
    }

    /// Sets the 2-bit ECN field; bits above the low two are discarded.
    #[inline]
    #[must_use]
    pub const fn ecn(mut self, ecn: u8) -> Self {
        self.header.ecn = ecn & 0x03;
        self
    }

    /// Finishes the header with `total_length = IPV4_HEADER_MIN_SIZE + payload_len`.
    ///
    /// The addition is plain `u16` arithmetic, so it overflows (a panic in debug
    /// builds) when `payload_len` exceeds `u16::MAX - IPV4_HEADER_MIN_SIZE`.
    #[inline]
    #[must_use]
    pub const fn build(mut self, payload_len: u16) -> Ipv4Header {
        self.header.total_length = IPV4_HEADER_MIN_SIZE as u16 + payload_len;
        self.header
    }
}

impl Default for Ipv4HeaderBuilder {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ipv4_addr_new() {
        let addr = Ipv4Addr::new(192, 168, 1, 1);
        assert_eq!(addr.0, [192, 168, 1, 1]);
        assert_eq!(addr, Ipv4Addr::from_bytes([192, 168, 1, 1]));
        assert_eq!(addr.as_bytes(), &[192u8, 168, 1, 1]);
    }

    #[test]
    fn test_ipv4_addr_special() {
        assert!(Ipv4Addr::BROADCAST.is_broadcast());
        assert!(Ipv4Addr::UNSPECIFIED.is_unspecified());
        assert!(Ipv4Addr::LOCALHOST.is_loopback());
        assert!(!Ipv4Addr::LOCALHOST.is_broadcast());
        assert!(!Ipv4Addr::LOCALHOST.is_unspecified());
        assert_eq!(Ipv4Addr::default(), Ipv4Addr::UNSPECIFIED);
    }

    #[test]
    fn test_ipv4_addr_classification() {
        assert!(Ipv4Addr::new(10, 0, 0, 1).is_private());
        assert!(Ipv4Addr::new(172, 16, 0, 1).is_private());
        assert!(Ipv4Addr::new(172, 31, 255, 255).is_private());
        assert!(!Ipv4Addr::new(172, 15, 0, 1).is_private());
        assert!(!Ipv4Addr::new(172, 32, 0, 1).is_private());
        assert!(Ipv4Addr::new(192, 168, 1, 1).is_private());
        assert!(!Ipv4Addr::new(8, 8, 8, 8).is_private());

        assert!(Ipv4Addr::new(224, 0, 0, 1).is_multicast());
        assert!(Ipv4Addr::new(239, 255, 255, 255).is_multicast());
        assert!(!Ipv4Addr::new(223, 255, 255, 255).is_multicast());
        assert!(!Ipv4Addr::new(240, 0, 0, 0).is_multicast());
        assert!(!Ipv4Addr::new(192, 168, 1, 1).is_multicast());

        assert!(Ipv4Addr::new(169, 254, 1, 1).is_link_local());
        assert!(!Ipv4Addr::new(169, 253, 1, 1).is_link_local());
        assert!(!Ipv4Addr::new(192, 168, 1, 1).is_link_local());
    }

    #[test]
    fn test_ipv4_addr_same_subnet() {
        let addr1 = Ipv4Addr::new(192, 168, 1, 10);
        let addr2 = Ipv4Addr::new(192, 168, 1, 20);
        let addr3 = Ipv4Addr::new(192, 168, 2, 10);
        let mask24 = Ipv4Addr::new(255, 255, 255, 0);
        let mask16 = Ipv4Addr::new(255, 255, 0, 0);
        assert!(addr1.is_same_subnet(&addr2, &mask24));
        assert!(!addr1.is_same_subnet(&addr3, &mask24));
        assert!(addr1.is_same_subnet(&addr3, &mask16));
    }

    #[test]
    fn test_ipv4_addr_u32_conversion() {
        let addr = Ipv4Addr::new(192, 168, 1, 1);
        let as_u32 = addr.to_u32();
        assert_eq!(as_u32, 0xC0A8_0101);
        assert_eq!(Ipv4Addr::from_u32(as_u32), addr);
    }

    #[test]
    fn test_ipv4_addr_parse() {
        assert_eq!(Ipv4Addr::parse(&[10, 0, 0, 1, 99]), Ok(Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(Ipv4Addr::parse(&[10, 0, 0]), Err(NetError::PacketTooShort));
    }

    #[test]
    fn test_protocol_conversion() {
        assert_eq!(Protocol::from_u8(1), Protocol::Icmp);
        assert_eq!(Protocol::from_u8(6), Protocol::Tcp);
        assert_eq!(Protocol::from_u8(17), Protocol::Udp);
        assert!(matches!(Protocol::from_u8(99), Protocol::Unknown(99)));
        assert_eq!(Protocol::Icmp.to_u8(), 1);
        assert_eq!(Protocol::Tcp.to_u8(), 6);
        assert_eq!(Protocol::Udp.to_u8(), 17);
        for value in 0..=u8::MAX {
            assert_eq!(Protocol::from_u8(value).to_u8(), value);
        }
    }

    #[test]
    fn test_ipv4_header_parse() {
        #[rustfmt::skip]
        let packet = [
            0x45, 0x00, 0x00, 0x1C, // version/IHL, ToS, total length = 28
            0x00, 0x01, 0x40, 0x00, // identification = 1, DF set, offset 0
            0x40, 0x11, 0x00, 0x00, // TTL = 64, protocol = UDP, checksum
            0xC0, 0xA8, 0x01, 0x01, // src 192.168.1.1
            0xC0, 0xA8, 0x01, 0x02, // dst 192.168.1.2
            0xDE, 0xAD, 0xBE, 0xEF, // payload
        ];
        let (header, payload) = Ipv4Header::parse(&packet).unwrap();
        assert_eq!(header.version, 4);
        assert_eq!(header.ihl, 5);
        assert_eq!(header.dscp, 0);
        assert_eq!(header.ecn, 0);
        assert_eq!(header.total_length, 28);
        assert_eq!(header.identification, 1);
        assert_eq!(header.fragment_offset, 0);
        assert_eq!(header.ttl, 64);
        assert_eq!(header.protocol, Protocol::Udp);
        assert!(header.flags.dont_fragment);
        assert!(!header.flags.more_fragments);
        assert_eq!(header.src_addr, Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(header.dst_addr, Ipv4Addr::new(192, 168, 1, 2));
        assert_eq!(header.header_len(), 20);
        assert_eq!(header.payload_len(), 8);
        assert_eq!(payload, &[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn test_ipv4_header_parse_truncated_payload() {
        #[rustfmt::skip]
        let packet = [
            0x45, 0x00, 0x00, 0x30, // total length = 48, but only 24 bytes present
            0x00, 0x01, 0x40, 0x00,
            0x40, 0x11, 0x00, 0x00,
            0xC0, 0xA8, 0x01, 0x01,
            0xC0, 0xA8, 0x01, 0x02,
            0xDE, 0xAD, 0xBE, 0xEF,
        ];
        let (header, payload) = Ipv4Header::parse(&packet).unwrap();
        assert_eq!(header.total_length, 48);
        assert_eq!(header.payload_len(), 28);
        assert_eq!(payload, &[0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn test_ipv4_header_serialize() {
        let header = Ipv4Header::new(
            Ipv4Addr::new(192, 168, 1, 1),
            Ipv4Addr::new(192, 168, 1, 2),
            Protocol::Udp,
            8,
        );
        let mut buf = [0u8; 64];
        let len = header.serialize(&mut buf).unwrap();
        assert_eq!(len, 20);
        assert_eq!(buf[0], 0x45);
        assert_eq!(&buf[2..4], &[0x00, 0x1C]);
        assert_eq!(&buf[6..8], &[0x40, 0x00]);
        assert_eq!(buf[8], Ipv4Header::DEFAULT_TTL);
        assert_eq!(buf[9], 17);
        assert_eq!(&buf[12..16], &[192, 168, 1, 1]);
        assert_eq!(&buf[16..20], &[192, 168, 1, 2]);
        assert!(Ipv4Header::verify_checksum(&buf[..20]));
    }

    #[test]
    fn test_ipv4_header_serialize_buffer_too_small() {
        let header = Ipv4Header::new(Ipv4Addr::LOCALHOST, Ipv4Addr::LOCALHOST, Protocol::Udp, 0);
        let mut buf = [0u8; 19];
        assert_eq!(header.serialize(&mut buf), Err(NetError::BufferTooSmall));
    }

    #[test]
    fn test_ipv4_header_roundtrip() {
        let original = Ipv4HeaderBuilder::new()
            .src(Ipv4Addr::new(10, 0, 0, 1))
            .dst(Ipv4Addr::new(10, 0, 0, 2))
            .protocol(Protocol::Icmp)
            .ttl(128)
            .identification(0x1234)
            .build(100);
        let mut buf = [0u8; 256];
        let len = original.serialize(&mut buf).unwrap();
        assert!(Ipv4Header::verify_checksum(&buf[..len]));
        let (parsed, payload) = Ipv4Header::parse(&buf).unwrap();
        // Everything except the freshly computed checksum must survive the trip.
        assert_eq!(Ipv4Header { checksum: parsed.checksum, ..original }, parsed);
        assert_eq!(payload.len(), 100);
    }

    #[test]
    fn test_ipv4_header_checksum() {
        #[rustfmt::skip]
        let mut header = [
            0x45, 0x00, 0x00, 0x73,
            0x00, 0x00, 0x40, 0x00,
            0x40, 0x11, 0x00, 0x00,
            0xc0, 0xa8, 0x00, 0x01,
            0xc0, 0xa8, 0x00, 0xc7,
        ];
        // Widely published worked example; its correct checksum is 0xB861.
        let checksum = Ipv4Header::compute_checksum(&header);
        assert_eq!(checksum, 0xB861);
        assert!(!Ipv4Header::verify_checksum(&header));
        header[10..12].copy_from_slice(&checksum.to_be_bytes());
        assert!(Ipv4Header::verify_checksum(&header));
    }

    #[test]
    fn test_ipv4_header_too_short() {
        let short = [0x45, 0x00, 0x00];
        assert_eq!(Ipv4Header::parse(&short), Err(NetError::PacketTooShort));
    }

    #[test]
    fn test_ipv4_header_ihl_beyond_buffer() {
        let mut packet = [0u8; 20];
        packet[0] = 0x46; // IHL 6 needs 24 bytes
        packet[3] = 24;
        assert_eq!(Ipv4Header::parse(&packet), Err(NetError::PacketTooShort));
    }

    #[test]
    fn test_ipv4_header_wrong_version() {
        let mut packet = [0u8; 20];
        packet[0] = 0x65; // version 6
        assert_eq!(Ipv4Header::parse(&packet), Err(NetError::InvalidIpVersion));
    }

    #[test]
    fn test_ipv4_header_invalid_ihl() {
        let mut packet = [0u8; 20];
        packet[0] = 0x43; // IHL 3 is below the 5-word minimum
        assert_eq!(
            Ipv4Header::parse(&packet),
            Err(NetError::InvalidIpHeaderLength)
        );
    }

    #[test]
    fn test_ipv4_header_total_length_smaller_than_header() {
        let mut packet = [0u8; 20];
        packet[0] = 0x45;
        packet[3] = 19;
        assert_eq!(
            Ipv4Header::parse(&packet),
            Err(NetError::InvalidIpHeaderLength)
        );
    }

    #[test]
    fn test_ipv4_flags() {
        let df_only = Ipv4Flags::from_raw(0x4000);
        assert!(df_only.dont_fragment);
        assert!(!df_only.more_fragments);
        let mf_only = Ipv4Flags::from_raw(0x2000);
        assert!(!mf_only.dont_fragment);
        assert!(mf_only.more_fragments);
        let both = Ipv4Flags::from_raw(0x6000);
        assert!(both.dont_fragment);
        assert!(both.more_fragments);
        assert_eq!(Ipv4Flags::from_raw(0x8000), Ipv4Flags::default());
        assert_eq!(df_only.to_raw(), 0x4000);
        assert_eq!(mf_only.to_raw(), 0x2000);
        assert_eq!(both.to_raw(), 0x6000);
        assert_eq!(Ipv4Flags::default().to_raw(), 0);
    }

    #[test]
    fn test_decrement_ttl() {
        let mut header = Ipv4Header::new(
            Ipv4Addr::LOCALHOST,
            Ipv4Addr::LOCALHOST,
            Protocol::Icmp,
            0,
        );
        header.ttl = 2;
        assert!(header.decrement_ttl().is_ok());
        assert_eq!(header.ttl, 1);
        assert_eq!(header.decrement_ttl(), Err(NetError::TtlExpired));
        assert_eq!(header.ttl, 0);
        assert_eq!(header.decrement_ttl(), Err(NetError::TtlExpired));
        assert_eq!(header.ttl, 0);
    }
}