// ip4sum 0.1.0
//
// Highly optimized IPv4 checksum calculation, no-std compatible.
// Documentation
// SPDX-License-Identifier: MIT | Copyright (c) 2026 Khashayar Fereidani

//! Test suite for ip4sum.

use crate::{Checksum, checksum};

// -- Helper: reference checksum computed byte-by-byte in network order --

/// Reference Internet checksum (RFC 1071), written for clarity rather than speed.
///
/// Bytes are paired into big-endian 16-bit words. An odd trailing byte is
/// padded with a zero low byte. The words are summed in a `u64`, so no carry is
/// lost for any realistic input size. A `u32` accumulator with `wrapping_add`
/// silently drops carries once the input exceeds roughly 128 KiB. The carries
/// are then folded back into 16 bits and the result is complemented.
fn reference_checksum(data: &[u8]) -> u16 {
    let mut sum: u64 = data
        .chunks(2)
        .map(|pair| {
            let hi = pair[0];
            let lo = pair.get(1).copied().unwrap_or(0);
            u64::from(u16::from_be_bytes([hi, lo]))
        })
        .sum();
    // End-around carry: fold everything above bit 15 back into the low 16 bits.
    while sum >> 16 != 0 {
        sum = (sum >> 16) + (sum & 0xffff);
    }
    // The loop guarantees `sum <= 0xffff`, so this cast is lossless.
    !(sum as u16)
}

// =========================================================================
// Basic correctness
// =========================================================================

#[test]
fn test_empty() {
    // A zero-length input must agree with the reference implementation.
    let empty: &[u8] = &[];
    let expected = reference_checksum(empty);
    assert_eq!(checksum(empty), expected);
}

#[test]
fn test_single_byte() {
    // Smallest odd-length input: one byte with no partner.
    let input = [0x45u8];
    assert_eq!(checksum(&input), reference_checksum(&input));
}

#[test]
fn test_two_bytes() {
    // Exactly one 16-bit word.
    let word = [0x45u8, 0x00];
    let expected = reference_checksum(&word);
    assert_eq!(checksum(&word), expected);
}

#[test]
fn test_three_bytes() {
    // One full word followed by an odd trailing byte.
    let input: [u8; 3] = [0x45, 0x00, 0xab];
    let expected = reference_checksum(&input);
    assert_eq!(checksum(&input), expected);
}

#[test]
fn test_four_bytes() {
    // Two full words: the first four bytes of an IPv4 header.
    let input: [u8; 4] = [0x45, 0x00, 0x00, 0x30];
    let expected = reference_checksum(&input);
    assert_eq!(checksum(&input), expected);
}

#[test]
fn test_five_bytes() {
    // Two full words plus a trailing 0xff byte.
    let input: [u8; 5] = [0x45, 0x00, 0x00, 0x30, 0xff];
    let expected = reference_checksum(&input);
    assert_eq!(checksum(&input), expected);
}

// =========================================================================
// Known-good IPv4 headers
// =========================================================================

#[test]
fn test_ipv4_header_1() {
    // Version/IHL=0x45, TOS=0, total_len=48, id=0, flags=0x4000, ttl=64,
    // proto=ICMP, checksum=0, src=10.0.0.1, dst=10.0.0.2
    let data: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];
    let csum = checksum(&data);
    // Pin the exact, hand-computed value. The self-verification check below
    // alone would also accept an implementation that always returns 0.
    assert_eq!(csum, 0x26CB);
    assert_eq!(csum, reference_checksum(&data));

    // Writing the checksum into bytes 10..12 must make the header verify to 0.
    let mut verified = data;
    verified[10..12].copy_from_slice(&csum.to_be_bytes());
    assert_eq!(checksum(&verified), 0);
}

#[test]
fn test_ipv4_header_2() {
    // Version/IHL=0x45, TOS=0, total_len=60, id=0x1c46, flags=0x4000, ttl=64,
    // proto=TCP, checksum=0, src=172.16.10.99, dst=172.16.10.12
    let data: [u8; 20] = [
        0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a,
        0x63, 0xac, 0x10, 0x0a, 0x0c,
    ];
    let csum = checksum(&data);
    // Pin the exact, hand-computed value. The self-verification check below
    // alone would also accept an implementation that always returns 0.
    assert_eq!(csum, 0xB1E6);
    assert_eq!(csum, reference_checksum(&data));

    // Writing the checksum into bytes 10..12 must make the header verify to 0.
    let mut verified = data;
    verified[10..12].copy_from_slice(&csum.to_be_bytes());
    assert_eq!(checksum(&verified), 0);
}

// =========================================================================
// Checksum struct matches one-shot
// =========================================================================

#[test]
fn test_incremental_two_parts() {
    let data: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];

    // Feeding the header as two 10-byte halves must match the one-shot API.
    let (head, tail) = data.split_at(10);
    let mut hasher = Checksum::new();
    hasher.update(head);
    hasher.update(tail);
    let incremental = hasher.finalize();

    assert_eq!(checksum(&data), incremental);
}

#[test]
fn test_incremental_three_parts() {
    let data: [u8; 20] = [
        0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10, 0x0a,
        0x63, 0xac, 0x10, 0x0a, 0x0c,
    ];

    // Feed the header as chunks of 4, 12 and 4 bytes.
    let mut hasher = Checksum::new();
    for part in [&data[..4], &data[4..16], &data[16..]] {
        hasher.update(part);
    }
    let incremental = hasher.finalize();

    assert_eq!(checksum(&data), incremental);
}

#[test]
fn test_incremental_with_zeroed_checksum_field() {
    let data: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];

    // Skip the stored checksum field (bytes 10..12) and feed explicit zeros in its place.
    let mut hasher = Checksum::new();
    hasher.update(&data[..10]);
    hasher.update(&[0u8; 2]);
    hasher.update(&data[12..]);
    let incremental = hasher.finalize();

    let one_shot = checksum(&data);
    assert_eq!(one_shot, incremental);
}

#[test]
fn test_reset() {
    let data: [u8; 20] = [
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x00,
        0x01, 0x0a, 0x00, 0x00, 0x02,
    ];

    let mut hasher = Checksum::new();
    // Put junk into the state first so that reset() has something to discard.
    hasher.update(&[0xFFu8; 4]);
    hasher.reset();
    hasher.update(&data);

    let expected = checksum(&data);
    assert_eq!(hasher.finalize(), expected);
}

#[test]
fn test_default_trait() {
    // `Default` and `new()` must produce interchangeable initial states.
    let payload = [0x01u8, 0x02];
    let mut from_default = Checksum::default();
    let mut from_new = Checksum::new();
    from_default.update(&payload);
    from_new.update(&payload);
    assert_eq!(from_default.finalize(), from_new.finalize());
}

#[test]
fn test_clone() {
    let prefix = [0x45u8, 0x00, 0x00, 0x30];
    let mut original = Checksum::new();
    original.update(&prefix);
    let cloned = original.clone();

    // Advance only the original. The clone must keep its own copy of the state,
    // so each result must match the one-shot checksum of exactly what it saw.
    original.update(&[0x12, 0x34]);
    assert_eq!(cloned.finalize(), checksum(&prefix));
    assert_eq!(
        original.finalize(),
        checksum(&[0x45, 0x00, 0x00, 0x30, 0x12, 0x34])
    );
}

// =========================================================================
// Boundary sizes (exercise the tiered loop structure)
// =========================================================================

#[test]
fn test_sizes_0_to_256() {
    // Deterministic, non-uniform fill: byte i holds i * 31 (mod 256).
    let mut buf = [0u8; 256];
    for (slot, idx) in buf.iter_mut().zip(0u8..=u8::MAX) {
        *slot = idx.wrapping_mul(31);
    }

    // Check every prefix length, including empty and the full buffer.
    for len in 0..=buf.len() {
        let prefix = &buf[..len];
        assert_eq!(
            checksum(prefix),
            reference_checksum(prefix),
            "mismatch at len={len}"
        );
    }
}

#[test]
fn test_size_63_64_65() {
    // Lengths on either side of 64 bytes.
    let data = [0xABu8; 256];
    for len in [63usize, 64, 65] {
        let prefix = &data[..len];
        assert_eq!(
            checksum(prefix),
            reference_checksum(prefix),
            "mismatch at len={len}"
        );
    }
}

#[test]
fn test_size_127_128_129() {
    // Lengths on either side of 128 bytes.
    let data = [0xCDu8; 256];
    for len in [127usize, 128, 129] {
        let prefix = &data[..len];
        assert_eq!(
            checksum(prefix),
            reference_checksum(prefix),
            "mismatch at len={len}"
        );
    }
}

// =========================================================================
// TCP-style pseudo-header + header checksum
// =========================================================================

#[test]
fn test_tcp_checksum_incremental() {
    // IPv4 TCP pseudo-header (RFC 793): src addr, dst addr, zero byte, protocol (6 = TCP).
    // The 16-bit TCP length completes the 12-byte pseudo-header and is fed separately.
    let pseudo: [u8; 10] = [192, 168, 1, 1, 192, 168, 1, 2, 0, 6];
    let tcp_length: [u8; 2] = [0x00, 0x14];
    let tcp_hdr: [u8; 20] = [
        0x00, 0x50, 0x00, 0x50, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00,
    ];

    let mut combined = [0u8; 32];
    combined[..10].copy_from_slice(&pseudo);
    combined[10..12].copy_from_slice(&tcp_length);
    combined[12..].copy_from_slice(&tcp_hdr);
    let csum_oneshot = checksum(&combined);
    assert_eq!(csum_oneshot, reference_checksum(&combined));

    let mut hasher = Checksum::new();
    hasher.update(&pseudo);
    hasher.update(&tcp_length);
    hasher.update(&tcp_hdr);
    let csum_inc = hasher.finalize();

    assert_eq!(csum_oneshot, csum_inc);
}

// =========================================================================
// Fold unit tests
// =========================================================================

#[test]
fn test_fold_zero() {
    // With no input, finalize() yields the complement of a zero sum.
    assert_eq!(Checksum::new().finalize(), 0xFFFF);
}

#[test]
fn test_fold_known_value() {
    // A single all-ones word sums to 0xFFFF, whose complement is 0.
    let all_ones = [0xFFu8; 2];
    assert_eq!(checksum(&all_ones), 0);
}

// =========================================================================
// Verification property: inserting computed checksum yields zero
// =========================================================================

#[test]
fn test_verification_property() {
    // Byte offset of the 16-bit checksum field being filled in.
    const FIELD: usize = 10;

    let mut packet = [0u8; 60];
    packet
        .iter_mut()
        .enumerate()
        .for_each(|(idx, byte)| *byte = (idx as u8).wrapping_mul(7).wrapping_add(0x5A));

    // Compute with the field zeroed, store it big-endian, then re-check.
    packet[FIELD..FIELD + 2].fill(0);
    let csum = checksum(&packet);
    packet[FIELD..FIELD + 2].copy_from_slice(&csum.to_be_bytes());
    assert_eq!(checksum(&packet), 0);
}

// =========================================================================
// Larger buffers
// =========================================================================

#[test]
fn test_1500_bytes() {
    // A 1500-byte buffer: byte i holds i + 0x37 (mod 256).
    let mut frame = [0u8; 1500];
    frame
        .iter_mut()
        .enumerate()
        .for_each(|(idx, byte)| *byte = (idx as u8).wrapping_add(0x37));
    let expected = reference_checksum(&frame);
    assert_eq!(checksum(&frame), expected);
}

#[test]
fn test_odd_length_large() {
    // A 1501-byte buffer, so the final byte has no partner.
    let mut frame = [0u8; 1501];
    frame
        .iter_mut()
        .enumerate()
        .for_each(|(idx, byte)| *byte = (idx as u8).wrapping_mul(13));
    let expected = reference_checksum(&frame);
    assert_eq!(checksum(&frame), expected);
}