use crate::packets::ip::ProtocolNumber;
use anyhow::{anyhow, Result};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::slice;
/// Pseudo-header fields that are folded into a transport-layer checksum.
///
/// Call `sum` to obtain the folded 16-bit one's-complement sum of these fields.
/// NOTE(review): presumably used as the seed for TCP/UDP-style checksums —
/// confirm against callers.
#[derive(Debug)]
pub enum PseudoHeader {
    /// IPv4 pseudo-header.
    V4 {
        /// Source address; contributes its two 16-bit halves to the sum.
        src: Ipv4Addr,
        /// Destination address; contributes its two 16-bit halves to the sum.
        dst: Ipv4Addr,
        /// Length value, added to the sum as a single 16-bit word.
        packet_len: u16,
        /// Protocol number; its inner byte is added to the sum as a 16-bit word.
        protocol: ProtocolNumber,
    },
    /// IPv6 pseudo-header.
    V6 {
        /// Source address; contributes its eight 16-bit segments to the sum.
        src: Ipv6Addr,
        /// Destination address; contributes its eight 16-bit segments to the sum.
        dst: Ipv6Addr,
        /// Length value, added to the sum as a single 16-bit word.
        // NOTE(review): the IPv6 pseudo-header length field is 32 bits wide;
        // storing it as u16 limits lengths to 65535 — confirm this is intended.
        packet_len: u16,
        /// Next-header protocol number; its inner byte is added as a 16-bit word.
        protocol: ProtocolNumber,
    },
}
impl PseudoHeader {
    /// Returns the one's-complement sum of the pseudo-header fields, folded
    /// down to 16 bits.
    ///
    /// The result is *not* complemented. Presumably it is intended as the
    /// `pseudo_header_sum` seed passed to `compute`, which performs the final
    /// complement.
    pub fn sum(&self) -> u16 {
        let wide = match *self {
            PseudoHeader::V4 {
                src,
                dst,
                packet_len,
                protocol,
            } => v4_csum(src, dst, packet_len, protocol),
            PseudoHeader::V6 {
                src,
                dst,
                packet_len,
                protocol,
            } => v6_csum(src, dst, packet_len, protocol),
        };

        // End-around carry: keep adding the bits above bit 15 back into the
        // low half until nothing is left above it.
        let mut folded = wide;
        loop {
            let carry = folded >> 16;
            if carry == 0 {
                break folded as u16;
            }
            folded = (folded & 0xFFFF) + carry;
        }
    }
}
/// Raw (unfolded) sum of the IPv4 pseudo-header fields.
///
/// Each address contributes its two big-endian 16-bit halves; the length and
/// the protocol byte are each added as one word. Carries are left in the upper
/// bits for the caller to fold.
fn v4_csum(src: Ipv4Addr, dst: Ipv4Addr, packet_len: u16, protocol: ProtocolNumber) -> u32 {
    let word_sum = |addr: Ipv4Addr| -> u32 {
        addr.octets()
            .chunks_exact(2)
            .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
            .sum()
    };

    let addresses = word_sum(src) + word_sum(dst);
    addresses + u32::from(protocol.0) + u32::from(packet_len)
}
/// Raw (unfolded) sum of the IPv6 pseudo-header fields.
///
/// Adds all sixteen 16-bit segments of the two addresses plus the length and
/// the protocol byte, each as one word. Carries are left in the upper bits for
/// the caller to fold.
fn v6_csum(src: Ipv6Addr, dst: Ipv6Addr, packet_len: u16, protocol: ProtocolNumber) -> u32 {
    let (src_words, dst_words) = (src.segments(), dst.segments());

    let mut total: u32 = 0;
    for &word in src_words.iter().chain(dst_words.iter()) {
        total += u32::from(word);
    }

    total + u32::from(packet_len) + u32::from(protocol.0)
}
/// Computes the Internet checksum (RFC 1071) of `payload`, seeded with an
/// already-folded pseudo-header sum.
///
/// `payload` is read as a sequence of big-endian 16-bit words. An odd trailing
/// byte is treated as the high byte of a final word whose low byte is zero.
/// The returned value is the one's complement of the folded 16-bit sum.
///
/// Pass `0` as `pseudo_header_sum` when no pseudo-header applies.
pub fn compute(pseudo_header_sum: u16, payload: &[u8]) -> u16 {
    let words = payload.chunks_exact(2);
    // `remainder()` borrows from `payload`, not from the iterator, so it stays
    // usable after `words` is consumed below. It holds zero or one byte.
    let trailing = words.remainder();

    // Words are decoded with `from_be_bytes` rather than by reinterpreting the
    // byte slice as `&[u16]`: `payload` has no 2-byte alignment guarantee, and
    // creating a misaligned `&[u16]` is undefined behavior. The accumulator is
    // u64 because a u32 one can overflow once the payload exceeds ~128 KiB;
    // u64 only overflows beyond roughly 2^49 bytes of input.
    let mut checksum = words.fold(u64::from(pseudo_header_sum), |acc, word| {
        acc + u64::from(u16::from_be_bytes([word[0], word[1]]))
    });
    if let Some(&last) = trailing.first() {
        checksum += u64::from(last) << 8;
    }

    // End-around carry until the sum fits in 16 bits; the final `as u16` is
    // therefore lossless.
    while checksum >> 16 != 0 {
        checksum = (checksum >> 16) + (checksum & 0xFFFF);
    }
    !(checksum as u16)
}
/// Incrementally updates a checksum after some 16-bit words have changed
/// (RFC 1624, eqn. 3: `HC' = ~(~HC + ~m + m')`).
///
/// `old_value[i]` is paired with `new_value[i]`. If the two slices differ in
/// length, the surplus words of the longer one are ignored.
pub fn compute_inc(old_checksum: u16, old_value: &[u16], new_value: &[u16]) -> u16 {
    let mut acc = u32::from(!old_checksum);
    for (&before, &after) in old_value.iter().zip(new_value) {
        acc = acc + u32::from(!before) + u32::from(after);
    }

    // Fold carries back into the low 16 bits.
    while acc > 0xFFFF {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }
    !(acc as u16)
}
pub fn compute_with_ipaddr(
old_checksum: u16,
old_value: &IpAddr,
new_value: &IpAddr,
) -> Result<u16> {
match (old_value, new_value) {
(IpAddr::V4(old), IpAddr::V4(new)) => {
let old: u32 = (*old).into();
let old = [(old >> 16) as u16, (old & 0xFFFF) as u16];
let new: u32 = (*new).into();
let new = [(new >> 16) as u16, (new & 0xFFFF) as u16];
Ok(compute_inc(old_checksum, &old, &new))
}
(IpAddr::V6(old), IpAddr::V6(new)) => {
Ok(compute_inc(old_checksum, &old.segments(), &new.segments()))
}
_ => Err(anyhow!("cannot mix IPv4 and IPv6 addresses.")),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// IPv4 header (src 192.168.0.1, dst 192.168.0.199) with its checksum
    /// field (bytes 10..12) zeroed; its correct checksum is 0xb861.
    const IPV4_HEADER: [u8; 20] = [
        0x45, 0x00, 0x00, 0x73, 0x00, 0x00, 0x40, 0x00, 0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8,
        0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7,
    ];

    #[test]
    fn compute_checksum_incrementally() {
        // Example from RFC 1624: m = 0x5555 -> m' = 0x3285 with HC = 0xDD2F.
        assert_eq!(0x0000, compute_inc(0xdd2f, &[0x5555], &[0x3285]));
    }

    #[test]
    fn compute_ipv4_header_checksum() {
        assert_eq!(0xb861, compute(0, &IPV4_HEADER));
    }

    #[test]
    fn compute_pads_odd_trailing_byte() {
        // 0x0102 + 0x0300 = 0x0402, complemented.
        assert_eq!(0xfbfd, compute(0, &[0x01, 0x02, 0x03]));
    }

    #[test]
    fn compute_empty_payload_complements_seed() {
        assert_eq!(0xedcb, compute(0x1234, &[]));
    }

    #[test]
    fn incremental_ipaddr_update_matches_full_recompute() {
        let old = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1));
        let new = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));

        let mut header = IPV4_HEADER;
        header[12..16].copy_from_slice(&[10, 0, 0, 1]);

        let updated = compute_with_ipaddr(0xb861, &old, &new).unwrap();
        assert_eq!(compute(0, &header), updated);
        assert_eq!(0x6f0a, updated);
    }

    #[test]
    fn compute_with_ipaddr_rejects_mixed_families() {
        let v4 = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let v6 = IpAddr::V6(Ipv6Addr::LOCALHOST);
        assert!(compute_with_ipaddr(0, &v4, &v6).is_err());
        assert!(compute_with_ipaddr(0, &v6, &v4).is_err());
    }
}