use core::net::{Ipv4Addr, Ipv6Addr};
/// Adds `data` as a sequence of big-endian 16-bit words, without folding or
/// complementing (the accumulation step of RFC 1071).
///
/// An odd trailing byte is treated as the high byte of a final word whose low
/// byte is zero. The result is the plain arithmetic sum whenever that fits in
/// a `u32`. For very long inputs (more than 65_537 words) the carries above
/// bit 31 are folded back in. That keeps the value congruent modulo `0xffff`,
/// so it still represents the same ones'-complement sum and [`fold_sum`] /
/// [`finalize_checksum`] produce the correct checksum.
pub fn ones_complement_sum(data: &[u8]) -> u32 {
    let mut words = data.chunks_exact(2);
    // Accumulate in u64: a u32 can overflow once the input exceeds 65_537
    // words of 0xffff.
    let mut sum: u64 = words
        .by_ref()
        .map(|word| u64::from(u16::from_be_bytes([word[0], word[1]])))
        .sum();
    if let [last] = words.remainder() {
        sum += u64::from(*last) << 8;
    }
    // 2^32 is congruent to 1 modulo 0xffff, so end-around folding at bit 32
    // preserves the ones'-complement value while bringing it into u32 range.
    while sum > u64::from(u32::MAX) {
        sum = (sum & 0xffff_ffff) + (sum >> 32);
    }
    sum as u32
}
/// Reduces a 32-bit word-sum accumulator to 16 bits with end-around carry:
/// whatever sits above bit 15 is added back into the low half until nothing
/// is left above it.
pub fn fold_sum(sum: u32) -> u16 {
    let mut acc = sum;
    loop {
        let carry = acc >> 16;
        if carry == 0 {
            break acc as u16;
        }
        acc = (acc & 0xffff) + carry;
    }
}
/// Turns a raw word sum into a checksum value: folds it down to 16 bits with
/// [`fold_sum`] and then inverts every bit.
pub fn finalize_checksum(sum: u32) -> u16 {
    fold_sum(sum) ^ u16::MAX
}
/// Computes the RFC 1071 Internet checksum of `data`: the bitwise complement
/// of the folded sum of its big-endian 16-bit words.
pub fn internet_checksum(data: &[u8]) -> u16 {
    let raw_sum = ones_complement_sum(data);
    finalize_checksum(raw_sum)
}
/// Computes the Internet checksum of the concatenation of `chunks` without
/// first copying them into a single buffer.
///
/// Chunk boundaries may fall anywhere, including in the middle of a 16-bit
/// word. A dangling byte is carried over and paired with the first byte of
/// the next chunk, so the result equals [`internet_checksum`] of the joined
/// bytes. A byte left over at the very end is padded with a zero low byte.
pub fn internet_checksum_chunks<'a>(chunks: impl IntoIterator<Item = &'a [u8]>) -> u16 {
    // u64 so that arbitrarily long streams cannot overflow the accumulator
    // (a u32 overflows after 65_537 words of 0xffff).
    let mut sum = 0u64;
    let mut pending_high: Option<u8> = None;
    for &byte in chunks.into_iter().flatten() {
        match pending_high.take() {
            Some(high) => sum += u64::from(u16::from_be_bytes([high, byte])),
            None => pending_high = Some(byte),
        }
    }
    if let Some(high) = pending_high {
        sum += u64::from(high) << 8;
    }
    // 2^32 is congruent to 1 modulo 0xffff, so folding at bit 32 keeps the
    // ones'-complement value while making it fit the u32 helper below.
    while sum > u64::from(u32::MAX) {
        sum = (sum & 0xffff_ffff) + (sum >> 32);
    }
    finalize_checksum(sum as u32)
}
/// Returns whether `data`, including its embedded checksum field, carries a
/// correct Internet checksum. That is the case when its words fold to `0xffff`,
/// so recomputing the checksum over the whole buffer yields zero.
pub fn verify_internet_checksum(data: &[u8]) -> bool {
    let residue = internet_checksum(data);
    residue == 0
}
pub fn ipv4_header_checksum(header: &[u8]) -> u16 {
internet_checksum(header)
}
pub fn ipv4_pseudo_header_checksum(
source: Ipv4Addr,
destination: Ipv4Addr,
protocol: u8,
transport: &[u8],
) -> u16 {
let source = source.octets();
let destination = destination.octets();
let length = transport.len() as u16;
let mut sum = 0u32;
sum += u16::from_be_bytes([source[0], source[1]]) as u32;
sum += u16::from_be_bytes([source[2], source[3]]) as u32;
sum += u16::from_be_bytes([destination[0], destination[1]]) as u32;
sum += u16::from_be_bytes([destination[2], destination[3]]) as u32;
sum += protocol as u32;
sum += length as u32;
sum += ones_complement_sum(transport);
finalize_checksum(sum)
}
pub fn ipv6_pseudo_header_checksum(
source: Ipv6Addr,
destination: Ipv6Addr,
next_header: u8,
transport: &[u8],
) -> u16 {
let source = source.octets();
let destination = destination.octets();
let length = transport.len() as u32;
let mut sum = 0u32;
sum += u16::from_be_bytes([source[0], source[1]]) as u32;
sum += u16::from_be_bytes([source[2], source[3]]) as u32;
sum += u16::from_be_bytes([source[4], source[5]]) as u32;
sum += u16::from_be_bytes([source[6], source[7]]) as u32;
sum += u16::from_be_bytes([source[8], source[9]]) as u32;
sum += u16::from_be_bytes([source[10], source[11]]) as u32;
sum += u16::from_be_bytes([source[12], source[13]]) as u32;
sum += u16::from_be_bytes([source[14], source[15]]) as u32;
sum += u16::from_be_bytes([destination[0], destination[1]]) as u32;
sum += u16::from_be_bytes([destination[2], destination[3]]) as u32;
sum += u16::from_be_bytes([destination[4], destination[5]]) as u32;
sum += u16::from_be_bytes([destination[6], destination[7]]) as u32;
sum += u16::from_be_bytes([destination[8], destination[9]]) as u32;
sum += u16::from_be_bytes([destination[10], destination[11]]) as u32;
sum += u16::from_be_bytes([destination[12], destination[13]]) as u32;
sum += u16::from_be_bytes([destination[14], destination[15]]) as u32;
sum += (length >> 16) + (length & 0xffff);
sum += next_header as u32;
sum += ones_complement_sum(transport);
finalize_checksum(sum)
}
/// Computes the two Fletcher-16 check octets to store at `checksum_offset`
/// and `checksum_offset + 1`, so that `data` then passes `fletcher16_valid`.
///
/// Two running sums `c0` and `c1` are taken modulo 255 over the buffer. The
/// check octets `x` and `y` are then solved for so that both sums become zero
/// once the octets are written into the field (the ISO 8473 / OSPF-style
/// construction).
///
/// The two octets of the checksum field are treated as zero while summing, so
/// a buffer that still holds a stale checksum gives the same result as one
/// whose field was cleared. Both returned octets are in `1..=255`. If the
/// field lies outside `data`, no panic occurs, but the result is meaningless.
pub fn fletcher16_checkbytes(data: &[u8], checksum_offset: usize) -> [u8; 2] {
    let mut c0: i32 = 0;
    let mut c1: i32 = 0;
    for (index, &byte) in data.iter().enumerate() {
        // Ignore whatever the checksum field currently holds.
        let in_field = index >= checksum_offset && index - checksum_offset < 2;
        let value = if in_field { 0 } else { i32::from(byte) };
        c0 = (c0 + value) % 255;
        c1 = (c1 + c0) % 255;
    }
    // Count of octets after the first check octet (len - offset - 1), reduced
    // modulo 255 up front so the product below stays small for any length.
    let distance = (data.len() % 255) as i32 - (checksum_offset % 255) as i32 - 1;
    // A residue of 0 is encoded as 255 (also zero modulo 255), as in the
    // classic algorithm's `if x <= 0 { x += 255 }`.
    let x = match (distance * c0 - c1).rem_euclid(255) {
        0 => 255,
        r => r,
    };
    let mut y = 510 - c0 - x;
    if y > 255 {
        y -= 255;
    }
    [x as u8, y as u8]
}
/// Returns whether `data` carries valid Fletcher-16 check octets, meaning both
/// running sums, taken modulo 255 over the whole buffer, end at zero.
pub fn fletcher16_valid(data: &[u8]) -> bool {
    let (first, second) = data.iter().fold((0i32, 0i32), |(first, second), &octet| {
        let first = (first + i32::from(octet)) % 255;
        (first, (second + first) % 255)
    });
    first == 0 && second == 0
}
/// Computes CRC-32C (Castagnoli) one bit at a time.
///
/// Uses the reflected polynomial `0x82f6_3b78`, an all-ones initial register,
/// and a final bitwise inversion.
pub(crate) fn crc32c(data: &[u8]) -> u32 {
    const POLY_REFLECTED: u32 = 0x82f6_3b78;
    let register = data.iter().fold(u32::MAX, |register, &octet| {
        (0..8).fold(register ^ u32::from(octet), |bits, _| {
            if bits & 1 == 0 {
                bits >> 1
            } else {
                (bits >> 1) ^ POLY_REFLECTED
            }
        })
    });
    !register
}
// Unit tests for the checksum helpers. Expected values are fixed fixtures;
// comments show how the less obvious ones follow from the inputs.
#[cfg(test)]
mod tests {
    use super::{
        crc32c, finalize_checksum, fletcher16_checkbytes, fletcher16_valid, fold_sum,
        internet_checksum, internet_checksum_chunks, ipv4_header_checksum,
        ipv4_pseudo_header_checksum, ipv6_pseudo_header_checksum, ones_complement_sum,
        verify_internet_checksum,
    };
    use core::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn checksum_handles_even_length_input() {
        // 0x0001 + 0xf203 = 0xf204, whose complement is 0x0dfb.
        assert_eq!(internet_checksum(&[0x00, 0x01, 0xf2, 0x03]), 0x0dfb);
    }

    #[test]
    fn checksum_handles_odd_length_input() {
        // The trailing 0xf2 is padded to 0xf200: 0x0001 + 0xf200 = 0xf201 -> 0x0dfe.
        assert_eq!(internet_checksum(&[0x00, 0x01, 0xf2]), 0x0dfe);
    }

    #[test]
    fn checksum_folds_carries() {
        // The carry above bit 15 is added back into the low half.
        assert_eq!(fold_sum(0x0001_0001), 0x0002);
        // A folded sum of 0xffff complements to zero.
        assert_eq!(finalize_checksum(0xffff), 0x0000);
    }

    #[test]
    fn checksum_chunks_match_contiguous_bytes() {
        // Splitting on an even boundary must not change the result.
        let contiguous = [0x45, 0x00, 0x00, 0x54, 0xab, 0xcd, 0x00, 0x00];
        let chunked = internet_checksum_chunks([&contiguous[..4], &contiguous[4..]]);
        assert_eq!(chunked, internet_checksum(&contiguous));
    }

    #[test]
    fn checksum_chunks_allow_odd_boundaries() {
        // Boundaries that split 16-bit words must be re-paired across chunks.
        let contiguous = [0x01, 0x02, 0x03, 0x04, 0x05];
        let chunked =
            internet_checksum_chunks([&contiguous[..1], &contiguous[1..3], &contiguous[3..]]);
        assert_eq!(chunked, internet_checksum(&contiguous));
    }

    #[test]
    fn checksum_verifies_ipv4_header_fixture() {
        // Header carrying its correct checksum 0x0ec2 at bytes 10..12; its words
        // fold to 0xffff, so verification succeeds.
        let header = [
            0x45, 0x00, 0x00, 0x54, 0xa6, 0xf2, 0x40, 0x00, 0x40, 0x01, 0x0e, 0xc2, 0xc0, 0xa8,
            0x01, 0x65, 0xac, 0xd9, 0x16, 0x0e,
        ];
        assert!(verify_internet_checksum(&header));
        // Recomputing with the field zeroed reproduces the stored value.
        let mut zeroed = header;
        zeroed[10] = 0;
        zeroed[11] = 0;
        assert_eq!(ipv4_header_checksum(&zeroed), 0x0ec2);
    }

    #[test]
    fn checksum_sums_words_without_complementing() {
        // Raw sum: neither folded nor complemented; the odd tail is a high byte.
        assert_eq!(ones_complement_sum(&[0x12, 0x34, 0x56]), 0x1234 + 0x5600);
    }

    #[test]
    fn checksum_builds_ipv4_pseudo_header() {
        // UDP 4660 -> 53, length 8, checksum field zeroed. Addresses, protocol,
        // length and payload fold to 0xfec1, whose complement is 0x013e.
        let udp = [0x12, 0x34, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00];
        let checksum = ipv4_pseudo_header_checksum(
            Ipv4Addr::new(192, 0, 2, 1),
            Ipv4Addr::new(198, 51, 100, 2),
            17,
            &udp,
        );
        assert_eq!(checksum, 0x013e);
    }

    #[test]
    fn checksum_builds_ipv6_pseudo_header() {
        // Same UDP datagram over IPv6; the pseudo-header and payload fold to
        // 0x6dff, whose complement is 0x9200.
        let udp = [0x12, 0x34, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00];
        let checksum = ipv6_pseudo_header_checksum(
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2),
            17,
            &udp,
        );
        assert_eq!(checksum, 0x9200);
    }

    #[test]
    fn fletcher16_round_trips_a_zeroed_field() {
        // Writing the computed octets into the zeroed field must make the
        // buffer validate.
        let mut buffer = [
            0x00, 0x02, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
            0x0d, 0x0e, 0x00, 0x00, 0x10, 0x11,
        ];
        let offset = 16;
        let check = fletcher16_checkbytes(&buffer, offset);
        buffer[offset] = check[0];
        buffer[offset + 1] = check[1];
        assert!(fletcher16_valid(&buffer));
    }

    #[test]
    fn fletcher16_detects_a_corrupted_octet() {
        // Seal the buffer, then flip one bit outside the field: validation must fail.
        let mut buffer = [
            0xde, 0xad, 0xbe, 0xef, 0x00, 0x01, 0x02, 0x03, 0x00, 0x00, 0x55, 0x66,
        ];
        let offset = 8;
        let check = fletcher16_checkbytes(&buffer, offset);
        buffer[offset] = check[0];
        buffer[offset + 1] = check[1];
        assert!(fletcher16_valid(&buffer));
        buffer[2] ^= 0x01;
        assert!(!fletcher16_valid(&buffer));
    }

    #[test]
    fn fletcher16_pins_a_deterministic_pair() {
        // Regression value: here c0 = 55, c1 = 37, and 3 octets follow the
        // first check octet, giving x = 3 * 55 - 37 = 128 and y = 510 - 55 - 128 - 255 = 72.
        let mut buffer = [
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00, 0x00, 0x09, 0x0a,
        ];
        let offset = 8;
        let check = fletcher16_checkbytes(&buffer, offset);
        assert_eq!(check, [0x80, 0x48]);
        buffer[offset] = check[0];
        buffer[offset + 1] = check[1];
        assert!(fletcher16_valid(&buffer));
    }

    #[test]
    fn crc32c_matches_standard_vectors() {
        // Known CRC-32C (Castagnoli) vectors, including the "123456789" check value.
        assert_eq!(crc32c(b""), 0x0000_0000);
        assert_eq!(crc32c(b"123456789"), 0xe306_9283);
        assert_eq!(crc32c(b"abc"), 0x364b_3fb7);
        assert_eq!(crc32c(b"hello world"), 0xc994_65aa);
    }
}