use crate::{Checksum, checksum};
/// Straightforward reference implementation of the 16-bit ones'-complement
/// checksum, used as the oracle the crate's `checksum` is compared against.
///
/// `data` is read as a sequence of big-endian 16-bit words. If the length is
/// odd, the final byte is treated as the high byte of a word whose low byte is
/// zero. The words are summed, carries above bit 15 are folded back into the
/// low 16 bits, and the bitwise complement of the folded sum is returned.
///
/// The running sum is kept in a `u64` so that carries are never dropped, even
/// for inputs far larger than 128 KiB (where a 32-bit accumulator would wrap).
fn reference_checksum(data: &[u8]) -> u16 {
    let mut words = data.chunks_exact(2);

    // Sum every complete big-endian 16-bit word.
    let mut sum: u64 = words
        .by_ref()
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();

    // An odd trailing byte counts as the high half of a zero-padded word.
    if let [last] = words.remainder() {
        sum += u64::from(*last) << 8;
    }

    // End-around carry: fold everything above bit 15 back into the low 16 bits.
    while (sum >> 16) != 0 {
        sum = (sum >> 16) + (sum & 0xffff);
    }

    // The loop above guarantees `sum <= 0xffff`, so the cast is lossless.
    !(sum as u16)
}
/// `checksum` must agree with `reference_checksum` on a zero-length input.
#[test]
fn test_empty() {
    let input: [u8; 0] = [];
    let actual = checksum(&input);
    let expected = reference_checksum(&input);
    assert_eq!(actual, expected);
}
/// `checksum` must agree with `reference_checksum` on a one-byte input
/// (the shortest odd-length case).
#[test]
fn test_single_byte() {
    let input = [0x45u8];
    let actual = checksum(&input);
    let expected = reference_checksum(&input);
    assert_eq!(actual, expected);
}
/// `checksum` must agree with `reference_checksum` on exactly one 16-bit word.
#[test]
fn test_two_bytes() {
    let input = [0x45u8, 0x00];
    let actual = checksum(&input);
    let expected = reference_checksum(&input);
    assert_eq!(actual, expected);
}
/// `checksum` must agree with `reference_checksum` on one full word followed
/// by a single trailing byte.
#[test]
fn test_three_bytes() {
    let input = [0x45u8, 0x00, 0xab];
    let actual = checksum(&input);
    let expected = reference_checksum(&input);
    assert_eq!(actual, expected);
}
/// `checksum` must agree with `reference_checksum` on two full 16-bit words.
#[test]
fn test_four_bytes() {
    let input = [0x45u8, 0x00, 0x00, 0x30];
    let actual = checksum(&input);
    let expected = reference_checksum(&input);
    assert_eq!(actual, expected);
}
/// `checksum` must agree with `reference_checksum` on two full words followed
/// by a single trailing byte.
#[test]
fn test_five_bytes() {
    let input = [0x45u8, 0x00, 0x00, 0x30, 0xff];
    let actual = checksum(&input);
    let expected = reference_checksum(&input);
    assert_eq!(actual, expected);
}
/// Round-trip check on a 20-byte header whose bytes 10 and 11 start as zero:
/// after the computed checksum is written big-endian into those two bytes,
/// checksumming the patched header must yield 0.
#[test]
fn test_ipv4_header_1() {
    let header: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];
    let [hi, lo] = checksum(&header).to_be_bytes();

    let mut patched = header;
    patched[10] = hi;
    patched[11] = lo;

    assert_eq!(checksum(&patched), 0);
}
/// Same round-trip check as the first header test, on a second 20-byte
/// header: storing the computed checksum big-endian at bytes 10..12 must make
/// the whole header checksum to 0.
#[test]
fn test_ipv4_header_2() {
    let header: [u8; 20] = [
        0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a,
        0x63, 0xac, 0x10, 0x0a, 0x0c,
    ];
    let computed = checksum(&header);

    let mut patched = header;
    patched[10..12].copy_from_slice(&computed.to_be_bytes());

    assert_eq!(checksum(&patched), 0);
}
/// Feeding a 20-byte buffer to `Checksum` as two 10-byte halves must give the
/// same result as a single `checksum` call over the whole buffer.
#[test]
fn test_incremental_two_parts() {
    let packet: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];
    let whole = checksum(&packet);

    let (front, back) = packet.split_at(10);
    let mut state = Checksum::new();
    state.update(front);
    state.update(back);

    assert_eq!(whole, state.finalize());
}
/// Feeding a 20-byte buffer to `Checksum` in three pieces (4, 12 and 4 bytes)
/// must give the same result as a single `checksum` call.
#[test]
fn test_incremental_three_parts() {
    let packet: [u8; 20] = [
        0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a,
        0x63, 0xac, 0x10, 0x0a, 0x0c,
    ];
    let whole = checksum(&packet);

    let mut state = Checksum::new();
    for part in [&packet[..4], &packet[4..16], &packet[16..]] {
        state.update(part);
    }

    assert_eq!(whole, state.finalize());
}
/// Incremental computation where bytes 10..12 are supplied as a separate,
/// explicit pair of zero bytes rather than sliced from the buffer. Those bytes
/// are already zero in the buffer, so the result must equal the one-shot
/// checksum.
#[test]
fn test_incremental_with_zeroed_checksum_field() {
    let packet: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];
    let whole = checksum(&packet);

    let zeroed_field = [0u8; 2];
    let mut state = Checksum::new();
    state.update(&packet[..10]);
    state.update(&zeroed_field);
    state.update(&packet[12..]);

    assert_eq!(whole, state.finalize());
}
/// After `reset`, previously fed bytes must have no effect: the finalized
/// value has to match a one-shot `checksum` over only the data fed afterwards.
#[test]
fn test_reset() {
    let packet: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];
    let want = checksum(&packet);

    // Pollute the state first so that a no-op `reset` would be detected.
    let junk = [0xFFu8; 4];
    let mut state = Checksum::new();
    state.update(&junk);
    state.reset();
    state.update(&packet);

    assert_eq!(state.finalize(), want);
}
/// `Checksum::default()` and `Checksum::new()` must behave the same: fed
/// identical bytes, both finalize to the same value.
#[test]
fn test_default_trait() {
    let bytes = [0x01u8, 0x02];

    let mut from_default = Checksum::default();
    let mut from_new = Checksum::new();
    from_default.update(&bytes);
    from_new.update(&bytes);

    assert_eq!(from_default.finalize(), from_new.finalize());
}
/// A clone taken after some data has been fed must finalize to the same value
/// as the instance it was cloned from.
#[test]
fn test_clone() {
    let fed = [0x45u8, 0x00, 0x00, 0x30];

    let mut source = Checksum::new();
    source.update(&fed);
    let copy = source.clone();

    assert_eq!(source.finalize(), copy.finalize());
}
/// Sweeps every prefix length from 0 through 256 of a patterned buffer and
/// requires `checksum` to match `reference_checksum` at each length.
#[test]
fn test_sizes_0_to_256() {
    let mut buf = [0u8; 256];
    // Byte at position i is (i mod 256) * 31, wrapping; the cycled u8 range
    // supplies the truncated index directly.
    for (slot, idx) in buf.iter_mut().zip((0..=u8::MAX).cycle()) {
        *slot = idx.wrapping_mul(31);
    }

    for len in 0..=buf.len() {
        let prefix = &buf[..len];
        assert_eq!(
            checksum(prefix),
            reference_checksum(prefix),
            "mismatch at len={len}"
        );
    }
}
/// Compares `checksum` with `reference_checksum` for lengths 63, 64 and 65 of
/// a buffer filled with 0xAB.
#[test]
fn test_size_63_64_65() {
    let buf = [0xABu8; 256];
    for len in 63usize..=65 {
        let prefix = &buf[..len];
        assert_eq!(
            checksum(prefix),
            reference_checksum(prefix),
            "mismatch at len={len}"
        );
    }
}
/// Compares `checksum` with `reference_checksum` for lengths 127, 128 and 129
/// of a buffer filled with 0xCD.
#[test]
fn test_size_127_128_129() {
    let buf = [0xCDu8; 256];
    for len in 127usize..=129 {
        let prefix = &buf[..len];
        assert_eq!(
            checksum(prefix),
            reference_checksum(prefix),
            "mismatch at len={len}"
        );
    }
}
/// Checksums three separate buffers (a 12-byte pseudo-header part, a 2-byte
/// length and a 20-byte TCP header) incrementally and requires the result to
/// match a one-shot `checksum` over their 34-byte concatenation.
#[test]
fn test_tcp_checksum_incremental() {
    let pseudo: [u8; 12] = [192, 168, 1, 1, 192, 168, 1, 2, 0, 0, 0, 6];
    let tcp_length: [u8; 2] = [0x00, 0x14];
    let tcp_hdr: [u8; 20] = [
        0x00, 0x50, 0x00, 0x50, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00,
    ];

    // Build the contiguous buffer by streaming the three parts in order;
    // 12 + 2 + 20 bytes fill the 34-byte array exactly.
    let mut joined = [0u8; 34];
    let parts_in_order = pseudo.iter().chain(&tcp_length).chain(&tcp_hdr);
    for (dst, src) in joined.iter_mut().zip(parts_in_order) {
        *dst = *src;
    }
    let whole = checksum(&joined);

    let mut state = Checksum::new();
    for part in [&pseudo[..], &tcp_length[..], &tcp_hdr[..]] {
        state.update(part);
    }

    assert_eq!(whole, state.finalize());
}
/// A freshly constructed `Checksum` that has been fed nothing must finalize
/// to 0xFFFF.
#[test]
fn test_fold_zero() {
    let untouched = Checksum::new().finalize();
    assert_eq!(untouched, 0xFFFF);
}
/// A single all-ones word (0xFF, 0xFF) must checksum to 0.
#[test]
fn test_fold_known_value() {
    let all_ones = [0xFFu8; 2];
    assert_eq!(checksum(&all_ones), 0);
}
/// Verification property on a 60-byte patterned buffer: with bytes 10..12
/// zeroed, compute the checksum, store it big-endian into those bytes, and
/// require the checksum of the completed buffer to be 0.
#[test]
fn test_verification_property() {
    let mut buf = [0u8; 60];
    // Byte at position i is i * 7 + 0x5A, with wrapping u8 arithmetic.
    for (slot, idx) in buf.iter_mut().zip((0..=u8::MAX).cycle()) {
        *slot = idx.wrapping_mul(7).wrapping_add(0x5A);
    }

    // Clear the two bytes that will hold the checksum before computing it.
    buf[10..12].fill(0);
    let computed = checksum(&buf);
    buf[10..12].copy_from_slice(&computed.to_be_bytes());

    assert_eq!(checksum(&buf), 0);
}
/// `checksum` must agree with `reference_checksum` on a 1500-byte patterned
/// buffer (even length).
#[test]
fn test_1500_bytes() {
    let mut buf = [0u8; 1500];
    // Byte at position i is the truncated index plus 0x37, wrapping.
    buf.iter_mut()
        .enumerate()
        .for_each(|(pos, slot)| *slot = (pos as u8).wrapping_add(0x37));

    let actual = checksum(&buf);
    let expected = reference_checksum(&buf);
    assert_eq!(actual, expected);
}
/// `checksum` must agree with `reference_checksum` on a 1501-byte patterned
/// buffer (odd length, so the final byte has no partner).
#[test]
fn test_odd_length_large() {
    let mut buf = [0u8; 1501];
    // Byte at position i is (i mod 256) * 13, wrapping.
    for (slot, idx) in buf.iter_mut().zip((0..=u8::MAX).cycle()) {
        *slot = idx.wrapping_mul(13);
    }

    let actual = checksum(&buf);
    let expected = reference_checksum(&buf);
    assert_eq!(actual, expected);
}