use core::{fmt, hash::Hasher, num::Wrapping};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86;
#[inline]
pub fn checksum(data: &[u8]) -> u16 {
let mut checksum = Checksum::default();
checksum.write(data);
checksum.finish()
}
/// Minimum number of bytes remaining in a write (after any pending odd byte has been
/// completed) before the bytes are handed to `Checksum::write_large`.
const LARGE_WRITE_LEN: usize = 32;
/// Integer type the running sum is accumulated in before being folded down to 16 bits.
type Accumulator = u64;
/// Running sum. `Wrapping` makes `+=` wrap on overflow instead of panicking in debug
/// builds; note that a plain `+=` on this type discards any carry out of the top bit.
type State = Wrapping<Accumulator>;
/// Routine that adds (a prefix of) `bytes` to the state and returns the bytes it did not
/// consume.
///
/// NOTE(review): the safety contract of this `unsafe fn` is defined by the implementations
/// in the `x86` module (not visible here); presumably a routine may only be called when the
/// CPU features it needs were detected by `x86::probe` — confirm there.
type LargeWriteFn = for<'a> unsafe fn(&mut State, bytes: &'a [u8]) -> &'a [u8];
/// Sums `bytes` into `state` in windows of `MAX_LEN` bytes.
///
/// Each `CHUNK_LEN`-byte piece of a window is passed to `on_chunk`, which adds it into a
/// per-window partial sum. That partial sum is then added to `state` with end-around carry.
///
/// Only whole windows are consumed. The unconsumed tail (shorter than `MAX_LEN`) is
/// returned so the caller can finish it with a smaller window size.
///
/// `MAX_LEN` must be a non-zero multiple of `CHUNK_LEN`. Otherwise `chunks_exact` would
/// silently skip the trailing bytes of every window, and a zero `MAX_LEN` would never
/// make progress.
#[inline(always)]
fn write_sized_generic<'a, const MAX_LEN: usize, const CHUNK_LEN: usize>(
    state: &mut State,
    mut bytes: &'a [u8],
    on_chunk: impl Fn(&[u8; CHUNK_LEN], &mut Accumulator),
) -> &'a [u8] {
    debug_assert!(CHUNK_LEN > 0 && MAX_LEN > 0 && MAX_LEN % CHUNK_LEN == 0);

    while bytes.len() >= MAX_LEN {
        // cannot panic: the loop condition guarantees `bytes.len() >= MAX_LEN`
        let (window, remaining) = bytes.split_at(MAX_LEN);
        bytes = remaining;

        let mut sum = 0;
        for chunk in window.chunks_exact(CHUNK_LEN) {
            // SAFETY: `chunks_exact` only yields slices of exactly `CHUNK_LEN` bytes, so the
            // pointer refers to a valid `[u8; CHUNK_LEN]` (byte arrays have alignment 1).
            let chunk = unsafe {
                debug_assert_eq!(chunk.len(), CHUNK_LEN);
                &*(chunk.as_ptr() as *const [u8; CHUNK_LEN])
            };
            on_chunk(chunk, &mut sum);
        }

        // One's-complement addition: a carry out of the 64-bit accumulator is folded back
        // into the low bit instead of being dropped (2^64 ≡ 1 mod 0xffff). `res + 1`
        // cannot overflow because `res < sum` whenever the addition overflowed.
        let (res, carry) = state.0.overflowing_add(sum);
        state.0 = res + carry as Accumulator;
    }

    bytes
}
#[inline(always)]
fn write_sized_generic_u16<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
write_sized_generic::<LEN, 2>(
state,
bytes,
#[inline(always)]
|&bytes, acc| {
*acc += u16::from_ne_bytes(bytes) as Accumulator;
},
)
}
#[inline(always)]
fn write_sized_generic_u32<'a, const LEN: usize>(state: &mut State, bytes: &'a [u8]) -> &'a [u8] {
write_sized_generic::<LEN, 4>(
state,
bytes,
#[inline(always)]
|&bytes, acc| {
*acc += u32::from_ne_bytes(bytes) as Accumulator;
},
)
}
/// Selects the routine used for large writes, computing the choice only once per process.
///
/// On x86/x86_64 the routine offered by `x86::probe` is preferred when it returns one.
/// Otherwise the portable 16-byte-window generic routine is used.
#[inline]
#[cfg(all(feature = "once_cell", not(any(kani, miri))))]
fn probe_write_large() -> LargeWriteFn {
    fn detect() -> LargeWriteFn {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(accelerated) = x86::probe() {
                return accelerated;
            }
        }
        write_sized_generic_u32::<16>
    }

    static LARGE_WRITE_FN: once_cell::sync::Lazy<LargeWriteFn> =
        once_cell::sync::Lazy::new(detect);

    *LARGE_WRITE_FN
}
/// Large-write routine used when the `once_cell` feature is disabled or when building
/// under Kani or Miri. It always selects the portable 16-byte-window generic routine.
#[inline]
#[cfg(not(all(feature = "once_cell", not(any(kani, miri)))))]
fn probe_write_large() -> LargeWriteFn {
    write_sized_generic_u32::<16> as LargeWriteFn
}
/// Incremental Internet checksum (RFC 1071) over the bytes fed to it via
/// [`Hasher::write`].
#[derive(Clone, Copy)]
pub struct Checksum {
    /// Running, not-yet-folded sum of the native-endian words written so far.
    state: State,
    /// Routine used for writes of at least `LARGE_WRITE_LEN` bytes.
    write_large: LargeWriteFn,
    /// Set when an odd byte has been added and still awaits the second byte of its word.
    partial_write: bool,
}
impl Default for Checksum {
    /// Starts an empty checksum that uses the large-write routine chosen by
    /// `probe_write_large` for this platform and build configuration.
    fn default() -> Self {
        Self {
            write_large: probe_write_large(),
            ..Self::generic()
        }
    }
}
impl fmt::Debug for Checksum {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut v = *self;
v.carry();
f.debug_tuple("Checksum").field(&v.finish()).finish()
}
}
impl Checksum {
#[inline]
pub fn generic() -> Self {
Self {
state: Default::default(),
partial_write: false,
write_large: write_sized_generic_u32::<16>,
}
}
#[inline]
fn write_byte(&mut self, byte: u8, shift: bool) {
if shift {
self.state += (byte as Accumulator) << 8;
} else {
self.state += byte as Accumulator;
}
}
#[inline]
fn carry(&mut self) {
#[cfg(kani)]
self.carry_rfc();
#[cfg(not(kani))]
self.carry_optimized();
}
#[inline]
#[allow(dead_code)]
fn carry_rfc(&mut self) {
let mut state = self.state.0;
for _ in 0..core::mem::size_of::<Accumulator>() {
state = (state & 0xffff) + (state >> 16);
}
self.state.0 = state;
}
#[inline]
#[allow(dead_code)]
fn carry_optimized(&mut self) {
let values: [u16; core::mem::size_of::<Accumulator>() / 2] = unsafe {
debug_assert!(core::mem::align_of::<State>() >= core::mem::align_of::<u16>());
core::mem::transmute(self.state.0)
};
let mut sum = 0u16;
for value in values {
let (res, overflowed) = sum.overflowing_add(value);
sum = res;
if overflowed {
sum += 1;
}
}
self.state.0 = sum as _;
}
#[inline]
pub fn write_padded(&mut self, bytes: &[u8]) {
self.write(bytes);
if core::mem::take(&mut self.partial_write) {
self.write_byte(0, cfg!(target_endian = "little"));
}
}
#[inline]
pub fn finish(self) -> u16 {
self.finish_be().to_be()
}
#[inline]
pub fn finish_be(mut self) -> u16 {
self.carry();
let value = self.state.0 as u16;
let value = !value;
if value == 0 {
return 0xffff;
}
value
}
}
impl Hasher for Checksum {
#[inline]
fn write(&mut self, mut bytes: &[u8]) {
if bytes.is_empty() {
return;
}
if core::mem::take(&mut self.partial_write) {
let (chunk, remaining) = bytes.split_at(1);
bytes = remaining;
self.write_byte(chunk[0], cfg!(target_endian = "little"));
}
if bytes.len() >= LARGE_WRITE_LEN {
bytes = unsafe { (self.write_large)(&mut self.state, bytes) };
}
#[cfg(not(kani))]
{
bytes = write_sized_generic_u32::<4>(&mut self.state, bytes);
}
bytes = write_sized_generic_u16::<2>(&mut self.state, bytes);
if let Some(byte) = bytes.first().copied() {
self.partial_write = true;
self.write_byte(byte, cfg!(target_endian = "big"));
}
}
#[inline]
fn finish(&self) -> u64 {
Self::finish(*self) as _
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;

    /// Numerical example from RFC 1071: the folded (uncomplemented) sum of
    /// `00 01 f2 03 f4 f5 f6 f7` is `dd f2` in network byte order.
    #[test]
    fn rfc_example_test() {
        let bytes = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
        let mut checksum = Checksum::default();
        checksum.write(&bytes);
        checksum.carry();
        // The state accumulates native-endian words, so after folding, its native-endian
        // bytes are the network-order sum on both little- and big-endian hosts
        // (`to_le_bytes` would only pass on little-endian targets).
        assert_eq!((checksum.state.0 as u16).to_ne_bytes(), [0xdd, 0xf2]);
        assert_eq!((!rfc_c_port(&bytes)).to_be_bytes(), [0xdd, 0xf2]);
    }

    /// Port of the sample C routine in RFC 1071 §4.1, reading words in big-endian order.
    /// Returns the complemented 16-bit sum.
    fn rfc_c_port(data: &[u8]) -> u16 {
        let mut addr = data.as_ptr();
        let mut count = data.len();

        // SAFETY: `addr` always points `data.len() - count` bytes into `data`. Two bytes
        // are read only while `count > 1`, and one byte only while `count > 0`.
        unsafe {
            let mut sum = 0u32;

            while count > 1 {
                let value = u16::from_be_bytes([*addr, *addr.add(1)]);
                sum = sum.wrapping_add(value as u32);
                addr = addr.add(2);
                count -= 2;
            }

            if count > 0 {
                let value = u16::from_be_bytes([*addr, 0]);
                sum = sum.wrapping_add(value as u32);
            }

            while sum >> 16 != 0 {
                sum = (sum & 0xffff) + (sum >> 16);
            }

            !(sum as u16)
        }
    }

    /// Upper bound on the generated input length under Kani/Miri.
    #[cfg(any(kani, miri))]
    const LEN: usize = if cfg!(kani) { 16 } else { 32 };

    /// Splitting the input at an arbitrary point must not change the checksum. The result
    /// must match the RFC port, with a zero result reported as `0xffff`.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(17), kani::solver(cadical))]
    fn differential() {
        #[cfg(any(kani, miri))]
        type Bytes = crate::testing::InlineVec<u8, LEN>;
        #[cfg(not(any(kani, miri)))]
        type Bytes = Vec<u8>;

        check!()
            .with_type::<(usize, Bytes)>()
            .for_each(|(index, bytes)| {
                let index = if bytes.is_empty() {
                    0
                } else {
                    *index % bytes.len()
                };
                let (a, b) = bytes.split_at(index);
                let mut cs = Checksum::default();
                cs.write(a);
                cs.write(b);

                let mut rfc_value = rfc_c_port(bytes);
                if rfc_value == 0 {
                    rfc_value = 0xffff;
                }

                assert_eq!(rfc_value.to_be_bytes(), cs.finish().to_be_bytes());
            });
    }

    /// The 4-byte-then-2-byte passes must give the same checksum as the 2-byte pass alone.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn u32_u16_differential() {
        #[cfg(any(kani, miri))]
        type Bytes = crate::testing::InlineVec<u8, 8>;
        #[cfg(not(any(kani, miri)))]
        type Bytes = Vec<u8>;

        check!().with_type::<Bytes>().for_each(|bytes| {
            let a = {
                let mut cs = Checksum::generic();
                let bytes = write_sized_generic_u32::<4>(&mut cs.state, bytes);
                write_sized_generic_u16::<2>(&mut cs.state, bytes);
                cs.finish()
            };

            let b = {
                let mut cs = Checksum::generic();
                write_sized_generic_u16::<2>(&mut cs.state, bytes);
                cs.finish()
            };

            assert_eq!(a, b);
        });
    }

    /// Both carry-folding strategies must agree on every 64-bit accumulator value.
    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(9), kani::solver(kissat))]
    fn carry_differential() {
        check!().with_type::<u64>().cloned().for_each(|state| {
            let mut opt = Checksum::generic();
            opt.state.0 = state;
            opt.carry_optimized();

            let mut rfc = Checksum::generic();
            rfc.state.0 = state;
            rfc.carry_rfc();

            assert_eq!(opt.state.0, rfc.state.0);
        });
    }
}