use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned, big_endian};
use crate::packet::{IpNextProtocol, Udp};
/// Returns `size` after checking that it equals `size_of::<T>()`.
///
/// Being a `const fn`, this can be evaluated in a constant context, in which
/// case a mismatch surfaces as a compile-time error instead of a runtime panic.
///
/// # Panics
///
/// Panics with `"Size of T is wrong!"` when `size_of::<T>()` differs from `size`.
pub(crate) const fn size_must_be<T>(size: usize) -> usize {
    assert!(size_of::<T>() == size, "Size of T is wrong!");
    size
}
/// Pseudo-header with 32-bit addresses, used as checksum input.
///
/// The struct is `repr(C)` and `Unaligned`, and derives `IntoBytes`, so its
/// in-memory bytes are exactly the fields in declaration order: source,
/// destination, one zero byte, the protocol, and a big-endian 16-bit length.
///
/// NOTE(review): this looks like the IPv4 pseudo-header that is prepended to a
/// UDP datagram for checksumming; presumably `IpNextProtocol` is a single byte,
/// making the struct 12 bytes long — confirm against its definition.
#[repr(C)]
#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
pub struct PseudoHeaderV4 {
/// Source address, stored big-endian.
pub source: big_endian::U32,
/// Destination address, stored big-endian.
pub destination: big_endian::U32,
// Private filler byte; `from_udp` always sets it to 0.
_zero: u8,
/// Protocol identifier of the payload being checksummed.
pub protocol: IpNextProtocol,
/// Length in bytes of the payload being checksummed, stored big-endian.
pub length: big_endian::U16,
}
impl PseudoHeaderV4 {
    /// Builds the pseudo-header for the UDP datagram `udp` travelling from
    /// `source` to `destination`.
    ///
    /// The protocol field is set to `IpNextProtocol::Udp` and the length field
    /// to the number of bytes reported by `udp.as_bytes()`.
    ///
    /// # Panics
    ///
    /// Panics if that byte count does not fit in the 16-bit length field.
    pub fn from_udp(source: big_endian::U32, destination: big_endian::U32, udp: &Udp) -> Self {
        let datagram_len = udp.as_bytes().len();
        let length = big_endian::U16::try_from(datagram_len).unwrap();
        Self {
            source,
            destination,
            _zero: 0,
            protocol: IpNextProtocol::Udp,
            length,
        }
    }
}
/// Pseudo-header with 128-bit addresses, used as checksum input.
///
/// The struct is `repr(C)` and `Unaligned`, and derives `IntoBytes`, so its
/// in-memory bytes are exactly the fields in declaration order: source,
/// destination, one zero byte, the protocol, and a big-endian 16-bit length.
///
/// NOTE(review): the trailing fields mirror `PseudoHeaderV4` (zero byte,
/// protocol, 16-bit length) rather than the 32-bit length / 24-bit zero /
/// next-header order of the IPv6 pseudo-header in RFC 8200. That should yield
/// the same ones'-complement sum whenever the length fits in 16 bits, but the
/// bytes are not the RFC layout — confirm this is intentional.
#[repr(C)]
#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
pub struct PseudoHeaderV6 {
/// Source address, stored big-endian.
pub source: big_endian::U128,
/// Destination address, stored big-endian.
pub destination: big_endian::U128,
// Private filler byte; `from_udp` always sets it to 0.
_zero: u8,
/// Protocol identifier of the payload being checksummed.
pub protocol: IpNextProtocol,
/// Length in bytes of the payload being checksummed, stored big-endian.
pub length: big_endian::U16,
}
impl PseudoHeaderV6 {
    /// Builds the pseudo-header for the UDP datagram `udp` travelling from
    /// `source` to `destination`.
    ///
    /// The protocol field is set to `IpNextProtocol::Udp` and the length field
    /// to the number of bytes reported by `udp.as_bytes()`.
    ///
    /// # Panics
    ///
    /// Panics if that byte count does not fit in the 16-bit length field.
    pub fn from_udp(source: big_endian::U128, destination: big_endian::U128, udp: &Udp) -> Self {
        let datagram_len = udp.as_bytes().len();
        let length = big_endian::U16::try_from(datagram_len).unwrap();
        Self {
            source,
            destination,
            _zero: 0,
            protocol: IpNextProtocol::Udp,
            length,
        }
    }
}
/// Computes the 16-bit ones'-complement Internet checksum (RFC 1071) over the
/// concatenation of all slices in `payload`.
///
/// The slices are treated as one contiguous byte stream: when a slice has an
/// odd length, its last byte is paired with the first byte of the next
/// non-empty slice instead of being zero-padded on its own. Only the very end
/// of the stream is padded with a zero byte. For inputs where every slice
/// except the last has an even length, the result is the same as summing each
/// slice independently.
///
/// Returns the complemented, folded sum in host byte order.
pub fn checksum(payload: &[&[u8]]) -> u16 {
    // 64-bit accumulator so that many or very large slices cannot overflow.
    let mut sum: u64 = 0;
    // High byte of a 16-bit word left dangling by an odd-length slice.
    let mut pending: Option<u8> = None;

    for &chunk in payload {
        let mut bytes = chunk;

        if let Some(high) = pending {
            // Complete the dangling word with this slice's first byte; an
            // empty slice leaves the byte pending for the next one.
            let Some((&low, rest)) = bytes.split_first() else {
                continue;
            };
            sum += u64::from(u16::from_be_bytes([high, low]));
            bytes = rest;
        }

        // Sum the even-length prefix; keep a trailing odd byte for later.
        let (even, tail) = bytes.split_at(bytes.len() & !1);
        sum += u64::from(checksum_payload(even));
        pending = tail.first().copied();
    }

    // End of stream: pad a final dangling byte with zero.
    if let Some(high) = pending {
        sum += u64::from(u16::from_be_bytes([high, 0]));
    }

    // End-around-carry folding preserves the ones'-complement value while
    // bringing the total back into the range `finalize_csum` accepts.
    while sum > u64::from(u32::MAX) {
        sum = (sum >> 16) + (sum & 0xffff);
    }
    finalize_csum(sum as u32)
}
/// Computes the checksum over `header` followed by `payload`, never returning zero.
///
/// `header` is serialized with `as_bytes()` and summed together with `payload`
/// via [`checksum`]. If the computed checksum is `0`, `0xffff` is returned
/// instead (presumably because a zero UDP checksum field means "no checksum").
pub fn checksum_udp<H: IntoBytes + Immutable>(header: H, payload: &[u8]) -> u16 {
    match checksum(&[header.as_bytes(), payload]) {
        0 => u16::MAX,
        nonzero => nonzero,
    }
}
/// Sums `bytes` as a sequence of big-endian 16-bit words.
///
/// A trailing odd byte is treated as the high byte of a final word whose low
/// byte is zero. The returned value is the plain (unfolded) sum as long as it
/// fits in a `u32`; for inputs large enough to exceed that, the carries are
/// folded back in (end-around carry), which keeps the ones'-complement value
/// intact instead of overflowing.
fn checksum_payload(bytes: &[u8]) -> u32 {
    let mut words = bytes.chunks_exact(2);

    // Accumulate in 64 bits: a u32 would overflow after roughly 128 KiB of
    // 0xff bytes (panic in debug builds, silent wrap-around in release).
    let mut sum: u64 = words
        .by_ref()
        .map(|w| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();

    if let [last] = words.remainder() {
        sum += u64::from(u16::from_be_bytes([*last, 0]));
    }

    // Fold only when needed so that ordinary inputs return the exact sum.
    while sum > u64::from(u32::MAX) {
        sum = (sum >> 16) + (sum & 0xffff);
    }
    // The loop above guarantees the value fits in 32 bits.
    sum as u32
}
/// Folds a 32-bit running sum into 16 bits and returns its bitwise complement.
///
/// The upper 16 bits are repeatedly added back into the lower 16 bits until no
/// carry remains; the low 16 bits of the result are then inverted.
fn finalize_csum(sum: u32) -> u16 {
    let mut folded = sum;
    loop {
        let carry = folded >> 16;
        if carry == 0 {
            break;
        }
        folded = carry + (folded & 0xffff);
    }
    // No bits above the low 16 remain here, so the cast loses nothing.
    !(folded as u16)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::Ipv4Header;
    use std::net::Ipv4Addr;

    /// Checks the checksum of a freshly built IPv4 header against a fixed
    /// value, then verifies that re-summing the header with the checksum
    /// filled in yields zero.
    #[test]
    fn test_ipv4_header_checksum() {
        let source = Ipv4Addr::new(10, 0, 0, 1);
        let destination = Ipv4Addr::new(192, 168, 1, 1);
        let mut ip = Ipv4Header::new_for_length(source, destination, IpNextProtocol::Udp, 23);

        let computed = checksum(&[ip.as_bytes()]);
        ip.header_checksum = computed.into();
        assert_eq!(ip.header_checksum.get(), 0xAF18);

        // A header carrying its own correct checksum must sum to zero.
        let verified = checksum(&[ip.as_bytes()]);
        assert_eq!(verified, 0);
    }
}