/// Reduces a 32-bit ones'-complement accumulator to 16 bits by end-around
/// carry and returns its bitwise complement, i.e. the final Internet checksum.
///
/// Two folding steps always suffice for a `u32`: the first leaves at most
/// `0x1FFFE`, the second at most `0xFFFF`. Folding a value that already fits
/// in 16 bits is a no-op.
#[inline(always)]
pub fn checksum_fold(sum: u32) -> u16 {
    let once = (sum & 0xFFFF) + (sum >> 16);
    let twice = (once & 0xFFFF) + (once >> 16);
    !(twice as u16)
}
/// Computes the raw (unfolded, uncomplemented) ones'-complement sum of `data`
/// read as big-endian 16-bit words, as described in RFC 1071.
///
/// An odd trailing byte is treated as the high byte of a final word whose low
/// byte is zero. The result is meant to be reduced with [`checksum_fold`] or
/// combined with other partial sums.
///
/// The sum is accumulated in 64 bits and folded into a `u32` with end-around
/// carry only when it would not fit. The returned value is therefore always
/// congruent (mod `0xFFFF`) to the true sum, and is zero only if every word is
/// zero. Inputs of at most 128 KiB return the exact integer sum.
#[inline]
pub fn checksum_add(data: &[u8]) -> u32 {
    let mut words = data.chunks_exact(2);
    let mut sum: u64 = words
        .by_ref()
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();
    // An odd trailing byte is the high byte of a final, zero-padded word.
    if let [last] = words.remainder() {
        sum += u64::from(*last) << 8;
    }
    // A 32-bit accumulator would silently drop carries once the sum passes
    // u32::MAX. Because 2^32 ≡ 1 (mod 0xFFFF), adding the upper 32 bits back
    // in (end-around carry) keeps the ones'-complement value and brings the
    // sum into range.
    while sum > u64::from(u32::MAX) {
        sum = (sum & 0xFFFF_FFFF) + (sum >> 32);
    }
    // Lossless: the loop above guarantees `sum <= u32::MAX`.
    sum as u32
}
/// Computes the RFC 1071 Internet checksum of `data`.
///
/// This is the complement of the ones'-complement sum of `data`'s big-endian
/// 16-bit words, with an odd trailing byte zero-padded. An empty slice yields
/// `0xFFFF`.
#[inline]
pub fn checksum(data: &[u8]) -> u16 {
    let raw_sum = checksum_add(data);
    checksum_fold(raw_sum)
}
/// Incrementally updates `old_checksum` after one 16-bit word of the covered
/// data changes from `old_value` to `new_value`.
///
/// Implements RFC 1624 equation 3: `HC' = ~(~HC + ~m + m')`.
#[inline(always)]
pub fn incremental_checksum_update(old_checksum: u16, old_value: u16, new_value: u16) -> u16 {
    // Three terms of at most 0xFFFF each cannot overflow a u32.
    let terms = [!old_checksum, !old_value, new_value];
    let sum: u32 = terms.iter().map(|&term| u32::from(term)).sum();
    checksum_fold(sum)
}
/// Incrementally updates `old_checksum` after a 32-bit field of the covered
/// data changes from `old_value` to `new_value`.
///
/// The field is split into its high and low 16-bit halves. Each half is then
/// updated with [`incremental_checksum_update`], high half first.
#[inline]
pub fn incremental_checksum_update_32(old_checksum: u16, old_value: u32, new_value: u32) -> u16 {
    let halves = |value: u32| ((value >> 16) as u16, (value & 0xFFFF) as u16);
    let (old_hi, old_lo) = halves(old_value);
    let (new_hi, new_lo) = halves(new_value);
    let after_hi = incremental_checksum_update(old_checksum, old_hi, new_hi);
    incremental_checksum_update(after_hi, old_lo, new_lo)
}
/// Incrementally updates `old_checksum` after an IPv4 address in the covered
/// data changes from `old_ip` to `new_ip`.
///
/// Assumes the address starts at an even offset, so that its bytes form the
/// two 16-bit words `[ip[0], ip[1]]` and `[ip[2], ip[3]]` of the checksummed
/// data.
#[inline]
pub fn update_checksum_for_ip(old_checksum: u16, old_ip: [u8; 4], new_ip: [u8; 4]) -> u16 {
    incremental_checksum_update_32(
        old_checksum,
        u32::from_be_bytes(old_ip),
        u32::from_be_bytes(new_ip),
    )
}
/// Incrementally updates `old_checksum` after a 16-bit port number in the
/// covered data changes from `old_port` to `new_port`.
///
/// Assumes the port occupies one aligned 16-bit word of the checksummed data.
#[inline]
pub fn update_checksum_for_port(old_checksum: u16, old_port: u16, new_port: u16) -> u16 {
    // A port is a single word, so the plain one-word update applies directly.
    incremental_checksum_update(old_checksum, old_port, new_port)
}
/// Incrementally updates `old_checksum` after a NAT rewrite that replaces
/// both an IPv4 address (`old_ip` → `new_ip`) and a port
/// (`old_port` → `new_port`) in the covered data.
///
/// The address is applied first, then the port.
#[inline]
pub fn update_checksum_for_nat(
    old_checksum: u16,
    old_ip: [u8; 4],
    old_port: u16,
    new_ip: [u8; 4],
    new_port: u16,
) -> u16 {
    update_checksum_for_port(
        update_checksum_for_ip(old_checksum, old_ip, new_ip),
        old_port,
        new_port,
    )
}
/// Computes the Internet checksum of an IPv4 header.
///
/// The whole slice is summed, so `header` should be exactly the header,
/// including any options.
///
/// - To generate a checksum, zero the checksum field first.
/// - For a header that already carries a correct checksum, the result is 0.
///
/// # Panics
///
/// In debug builds, panics if `header` is shorter than the 20-byte minimum
/// IPv4 header.
#[inline]
pub fn ipv4_header_checksum(header: &[u8]) -> u16 {
    // Debug-only sanity check; release builds checksum whatever they are given.
    debug_assert!(header.len() >= 20, "IPv4 header too short");
    checksum(header)
}
/// Computes the complemented ones'-complement checksum of `payload` prefixed
/// by the IPv4 pseudo-header used by TCP and UDP.
///
/// The pseudo-header consists of the source address, the destination address,
/// a zero byte, `protocol`, and the payload length.
///
/// This helper is shared by [`tcp_checksum`] and [`udp_checksum`]. UDP's
/// zero-to-`0xFFFF` remapping is left to the caller.
#[inline]
fn ipv4_pseudo_header_checksum(
    src_ip: [u8; 4],
    dst_ip: [u8; 4],
    protocol: u8,
    payload: &[u8],
) -> u16 {
    // Sum of the two big-endian 16-bit words of an IPv4 address.
    let address_sum = |ip: [u8; 4]| {
        u64::from(u16::from_be_bytes([ip[0], ip[1]]))
            + u64::from(u16::from_be_bytes([ip[2], ip[3]]))
    };
    // Accumulate in u64. `checksum_add` may return values close to u32::MAX,
    // and adding the pseudo-header to them in 32 bits could drop a carry.
    // The zero byte followed by `protocol` forms one 16-bit word, whose value
    // is simply `protocol`.
    let mut sum = address_sum(src_ip)
        + address_sum(dst_ip)
        + u64::from(protocol)
        + payload.len() as u64
        + u64::from(checksum_add(payload));
    // 2^32 ≡ 1 (mod 0xFFFF), so adding the upper half back in (end-around
    // carry) keeps the ones'-complement value and makes it fit in a u32.
    while sum > u64::from(u32::MAX) {
        sum = (sum & 0xFFFF_FFFF) + (sum >> 32);
    }
    // Lossless: the loop above guarantees `sum <= u32::MAX`.
    checksum_fold(sum as u32)
}

/// Computes the TCP checksum (RFC 793) over IPv4 for `tcp_segment` sent from
/// `src_ip` to `dst_ip`.
///
/// `tcp_segment` is the TCP header plus data, with the checksum field zeroed.
/// Running it over a segment that already carries a correct checksum yields 0.
#[inline]
pub fn tcp_checksum(src_ip: [u8; 4], dst_ip: [u8; 4], tcp_segment: &[u8]) -> u16 {
    const IPPROTO_TCP: u8 = 6;
    ipv4_pseudo_header_checksum(src_ip, dst_ip, IPPROTO_TCP, tcp_segment)
}

/// Computes the UDP checksum (RFC 768) over IPv4 for `udp_datagram` sent from
/// `src_ip` to `dst_ip`.
///
/// `udp_datagram` is the UDP header plus data, with the checksum field zeroed.
/// A computed value of 0 is returned as `0xFFFF`, because a zero UDP checksum
/// field means "no checksum transmitted".
#[inline]
pub fn udp_checksum(src_ip: [u8; 4], dst_ip: [u8; 4], udp_datagram: &[u8]) -> u16 {
    const IPPROTO_UDP: u8 = 17;
    match ipv4_pseudo_header_checksum(src_ip, dst_ip, IPPROTO_UDP, udp_datagram) {
        0 => 0xFFFF,
        sum => sum,
    }
}
/// NEON implementation of the RFC 1071 Internet checksum of `data`.
///
/// Partial sums are spilled into a 64-bit total every 4096 vector chunks.
/// This way, inputs of any length are summed without losing carries.
///
/// # Safety
///
/// The caller must ensure the CPU supports NEON. Rust's standard `aarch64`
/// targets enable the `neon` target feature by default.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn checksum_simd_neon(data: &[u8]) -> u16 {
    use std::arch::aarch64::*;

    // `vpadalq_u16` adds two u16 words (at most 2 * 0xFFFF) into each u32 lane
    // per 16-byte chunk. After 4096 chunks a lane holds less than 2^30 and the
    // four-lane total is less than 2^32, so `vaddvq_u32` cannot wrap.
    const CHUNKS_PER_BLOCK: usize = 4096;

    let vector_len = data.len() - data.len() % 16;
    let (vector_part, tail) = data.split_at(vector_len);
    let mut total: u64 = 0;

    // SAFETY: NEON availability is guaranteed by the caller (see `# Safety`).
    // Every load reads exactly the 16 bytes of a `chunks_exact(16)` slice.
    unsafe {
        for block in vector_part.chunks(16 * CHUNKS_PER_BLOCK) {
            let mut acc = vdupq_n_u32(0);
            for chunk in block.chunks_exact(16) {
                let bytes = vld1q_u8(chunk.as_ptr());
                // Byte-swap each 16-bit lane so big-endian words become native u16s.
                let words = vreinterpretq_u16_u8(vrev16q_u8(bytes));
                acc = vpadalq_u16(acc, words);
            }
            total += u64::from(vaddvq_u32(acc));
        }
    }

    // The tail is shorter than one vector (< 16 bytes).
    let mut words = tail.chunks_exact(2);
    total += words
        .by_ref()
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum::<u64>();
    // An odd trailing byte is the high byte of a final, zero-padded word.
    if let [last] = words.remainder() {
        total += u64::from(*last) << 8;
    }

    // End-around-carry fold down to 16 bits, then complement.
    while total > 0xFFFF {
        total = (total & 0xFFFF) + (total >> 16);
    }
    !(total as u16)
}
/// Computes the RFC 1071 Internet checksum of `data` using the NEON kernel.
#[cfg(target_arch = "aarch64")]
#[inline]
pub fn checksum_simd(data: &[u8]) -> u16 {
    // SAFETY: assumes the compilation target enables the `neon` feature, as
    // Rust's standard aarch64 targets do by default; that is the only
    // requirement `checksum_simd_neon` places on its caller.
    unsafe { checksum_simd_neon(data) }
}
/// SSSE3 implementation of the RFC 1071 Internet checksum of `data`.
///
/// Partial sums are spilled into a 64-bit total every 4096 vector chunks.
/// This way, inputs of any length are summed without losing carries.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSSE3, e.g. by checking
/// `is_x86_feature_detected!("ssse3")`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn checksum_simd_ssse3(data: &[u8]) -> u16 {
    use std::arch::x86_64::*;

    // Each u32 lane of `acc_lo`/`acc_hi` gains at most 0xFFFF per 16-byte
    // chunk. After 4096 chunks, `acc_lo + acc_hi` is below 2^30 per lane, so
    // neither the vector adds nor the lane sum below can wrap.
    const CHUNKS_PER_BLOCK: usize = 4096;

    let vector_len = data.len() - data.len() % 16;
    let (vector_part, tail) = data.split_at(vector_len);
    let mut total: u64 = 0;

    // SAFETY: SSSE3 availability is guaranteed by the caller (see `# Safety`).
    // Each load reads exactly the 16 bytes of a `chunks_exact(16)` slice, and
    // the store writes exactly 16 bytes into a `[u32; 4]`.
    unsafe {
        let zero = _mm_setzero_si128();
        // Swap the two bytes of every 16-bit lane so big-endian words become
        // native little-endian u16 values.
        let swap_mask = _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14);
        for block in vector_part.chunks(16 * CHUNKS_PER_BLOCK) {
            let mut acc_lo = zero;
            let mut acc_hi = zero;
            for chunk in block.chunks_exact(16) {
                let bytes = _mm_loadu_si128(chunk.as_ptr().cast());
                let swapped = _mm_shuffle_epi8(bytes, swap_mask);
                // Zero-extend the eight u16 words into two vectors of four u32 lanes.
                acc_lo = _mm_add_epi32(acc_lo, _mm_unpacklo_epi16(swapped, zero));
                acc_hi = _mm_add_epi32(acc_hi, _mm_unpackhi_epi16(swapped, zero));
            }
            let mut lanes = [0u32; 4];
            _mm_storeu_si128(lanes.as_mut_ptr().cast(), _mm_add_epi32(acc_lo, acc_hi));
            total += lanes.iter().map(|&lane| u64::from(lane)).sum::<u64>();
        }
    }

    // The tail is shorter than one vector (< 16 bytes).
    let mut words = tail.chunks_exact(2);
    total += words
        .by_ref()
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum::<u64>();
    // An odd trailing byte is the high byte of a final, zero-padded word.
    if let [last] = words.remainder() {
        total += u64::from(*last) << 8;
    }

    // End-around-carry fold down to 16 bits, then complement.
    while total > 0xFFFF {
        total = (total & 0xFFFF) + (total >> 16);
    }
    !(total as u16)
}
/// Computes the RFC 1071 Internet checksum of `data`.
///
/// Uses the SSSE3 kernel when the running CPU supports it, and falls back to
/// the portable scalar `checksum` otherwise.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn checksum_simd(data: &[u8]) -> u16 {
    if !is_x86_feature_detected!("ssse3") {
        return checksum(data);
    }
    // SAFETY: SSSE3 support was confirmed at runtime by the check above.
    unsafe { checksum_simd_ssse3(data) }
}
/// Computes the RFC 1071 Internet checksum of `data`.
///
/// On targets without a dedicated SIMD kernel, this simply delegates to the
/// scalar `checksum`.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
#[inline]
pub fn checksum_simd(data: &[u8]) -> u16 {
    checksum(data)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checksum_basic() {
        // Worked example from RFC 1071, section 3.
        let data = [0x00, 0x01, 0xF2, 0x03, 0xF4, 0xF5, 0xF6, 0xF7];
        assert_eq!(checksum(&data), 0x220D);
    }

    #[test]
    fn test_checksum_empty() {
        assert_eq!(checksum(&[]), 0xFFFF);
    }

    #[test]
    fn test_checksum_odd_length() {
        // The trailing byte is zero-padded: 0x0102 + 0x0300 = 0x0402.
        let data = [0x01, 0x02, 0x03];
        assert_eq!(checksum(&data), 0xFBFD);
    }

    #[test]
    fn test_incremental_update() {
        let data = [0x00, 0x01, 0x02, 0x03];
        let original_checksum = checksum(&data);
        let updated = incremental_checksum_update(original_checksum, 0x0001, 0x0005);
        let new_data = [0x00, 0x05, 0x02, 0x03];
        assert_eq!(updated, checksum(&new_data));
    }

    #[test]
    fn test_incremental_update_ip() {
        let data: [u8; 8] = [192, 168, 1, 100, 192, 168, 1, 1];
        let original_checksum = checksum(&data);
        let new_ip = [10u8, 0, 0, 100];
        let updated = update_checksum_for_ip(original_checksum, [192, 168, 1, 100], new_ip);
        let new_data: [u8; 8] = [10, 0, 0, 100, 192, 168, 1, 1];
        assert_eq!(updated, checksum(&new_data));
    }

    #[test]
    fn test_ipv4_checksum() {
        // Well-known example header, with the checksum field (bytes 10..12) zeroed.
        let mut header: [u8; 20] = [
            0x45, 0x00, 0x00, 0x3c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xac, 0x10,
            0x0a, 0x63, 0xac, 0x10, 0x0a, 0x0c,
        ];
        let sum = ipv4_header_checksum(&header);
        assert_eq!(sum, 0xB1E6);
        // Once the checksum is stored, the header verifies to zero.
        header[10..12].copy_from_slice(&sum.to_be_bytes());
        assert_eq!(ipv4_header_checksum(&header), 0);
    }

    #[test]
    fn test_nat_checksum_update() {
        // Address followed by port, laid out as they appear in the checksummed data.
        let old_data: [u8; 8] = [192, 168, 1, 100, 0x30, 0x39, 0xAB, 0xCD]; // port 12345
        let new_data: [u8; 8] = [10, 0, 0, 1, 0xD4, 0x31, 0xAB, 0xCD]; // port 54321
        let updated = update_checksum_for_nat(
            checksum(&old_data),
            [192, 168, 1, 100],
            12345,
            [10, 0, 0, 1],
            54321,
        );
        assert_eq!(updated, checksum(&new_data));
    }

    #[test]
    fn test_tcp_checksum_verifies() {
        let src = [192, 168, 0, 1];
        let dst = [192, 168, 0, 2];
        // 20-byte TCP header (checksum at bytes 16..18 zeroed) plus an odd-length payload.
        let mut segment: Vec<u8> = vec![
            0x30, 0x39, 0x00, 0x50, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02,
            0x20, 0x00, 0x00, 0x00, 0x00, 0x00, b'h', b'i', b'!',
        ];
        let sum = tcp_checksum(src, dst, &segment);
        segment[16..18].copy_from_slice(&sum.to_be_bytes());
        assert_eq!(tcp_checksum(src, dst, &segment), 0);
    }

    #[test]
    fn test_checksum_simd() {
        // Cover every tail length around the 16-byte vector width.
        for len in 0..=100usize {
            let data: Vec<u8> = (0..len)
                .map(|i| (i as u8).wrapping_mul(37).wrapping_add(11))
                .collect();
            assert_eq!(checksum(&data), checksum_simd(&data), "length {len}");
        }
    }
}