#![cfg_attr(feature = "benchmark", feature(test))]
#[cfg(all(test, feature = "benchmark"))]
extern crate test;
/// Computes the checksum of `bytes` in a single call.
///
/// This is shorthand for creating a fresh [`Checksum`], feeding it `bytes`
/// through [`Checksum::add_bytes`], and reading the two-byte result with
/// [`Checksum::checksum`].
#[inline]
pub fn checksum(bytes: &[u8]) -> [u8; 2] {
    let mut state = Checksum::new();
    state.add_bytes(bytes);
    state.checksum()
}
/// Integer type in which 16-bit words are summed before being folded back to
/// 16 bits. On x86_64 it is `u128`.
#[cfg(target_arch = "x86_64")]
type Accumulator = u128;
/// Integer type in which 16-bit words are summed before being folded back to
/// 16 bits. On every target other than x86_64 it is `u64`.
#[cfg(not(target_arch = "x86_64"))]
type Accumulator = u64;
/// Length cut-off in bytes: `Checksum::add_bytes` hands any input shorter than
/// this to `Checksum::add_bytes_small` instead of running its unrolled loop.
const SMALL_BUF_THRESHOLD: usize = 64;
// Drives a step macro over a byte-slice variable in the widest chunks the
// target uses.
//
// `loop_unroll!(buf, step)` repeatedly invokes `step!(N, uN)` while `buf`
// still holds at least `N` bytes (16 / `u128` on x86_64, 8 / `u64` elsewhere),
// then delegates the remaining bytes to `unroll_tail!`. The step macro is
// responsible for advancing `buf`; otherwise the `while` loop never ends.
macro_rules! loop_unroll {
    // Public entry point: choose the chunk width for the current target.
    ($buf: ident, $step: ident) => {
        #[cfg(target_arch = "x86_64")]
        loop_unroll!(@inner $buf, 16, $step);
        #[cfg(not(target_arch = "x86_64"))]
        loop_unroll!(@inner $buf, 8, $step);
    };
    // 16-byte chunks read as `u128`, followed by the sub-16-byte tail.
    (@inner $buf: ident, 16, $step: ident) => {
        while $buf.len() >= 16 {
            $step!(16, u128);
        }
        unroll_tail!($buf, 16, $step);
    };
    // 8-byte chunks read as `u64`, followed by the sub-8-byte tail.
    (@inner $buf: ident, 8, $step: ident) => {
        while $buf.len() >= 8 {
            $step!(8, u64);
        }
        unroll_tail!($buf, 8, $step);
    };
}
// Handles the bytes left over after the main unrolled loop.
//
// `unroll_tail!(buf, W, step)` with `W` in {16, 8, 4} walks the power-of-two
// widths below `W` from largest to smallest, down to 2. For each width whose
// bit is set in `buf.len()` it invokes `step!(width, uN)` once, with the
// unsigned type of that width. The step macro is expected to advance `buf`
// by `width` bytes, so a single odd byte may remain afterwards.
macro_rules! unroll_tail {
    // Fewer than 16 bytes left: try an 8-byte step, then the narrower ones.
    ($buf: ident, 16, $step: ident) => {
        unroll_tail!($buf, 8, u64, $step);
        unroll_tail!($buf, 8, $step);
    };
    // Fewer than 8 bytes left: try a 4-byte step, then the narrower ones.
    ($buf: ident, 8, $step: ident) => {
        unroll_tail!($buf, 4, u32, $step);
        unroll_tail!($buf, 4, $step);
    };
    // Fewer than 4 bytes left: only a 2-byte step remains to try.
    ($buf: ident, 4, $step: ident) => {
        unroll_tail!($buf, 2, u16, $step);
    };
    // Leaf rule: run one `$width`-byte step if that bit of the length is set.
    ($buf: ident, $width: literal, $word: ident, $step: ident) => {
        if ($buf.len() & $width) != 0 {
            $step!($width, $word);
        }
    };
}
#[inline]
pub fn update(checksum: [u8; 2], old: &[u8], new: &[u8]) -> [u8; 2] {
assert_eq!(old.len(), new.len());
let mut sum = !u16::from_ne_bytes(checksum) as Accumulator;
let mut c1 = Checksum::new();
let mut c2 = Checksum::new();
c1.add_bytes(old);
c2.add_bytes(new);
sum = adc_accumulator(sum, c1.checksum_inner() as Accumulator);
sum = adc_accumulator(sum, !c2.checksum_inner() as Accumulator);
(!normalize(sum)).to_ne_bytes()
}
/// Incremental checksum state: bytes are fed in with `add_bytes` and the
/// two-byte result is read with `checksum`.
#[derive(Default)]
pub struct Checksum {
/// Wide running sum of the 16-bit words consumed so far; it is folded down
/// to 16 bits only when the checksum is read.
sum: Accumulator,
/// Final byte of the input so far when an odd number of bytes has been
/// consumed; it is paired with the first byte of the next input.
trailing_byte: Option<u8>,
}
impl Checksum {
#[inline]
pub const fn new() -> Self {
Checksum { sum: 0, trailing_byte: None }
}
#[inline]
pub fn add_bytes(&mut self, mut bytes: &[u8]) {
use std::convert::TryInto;
if bytes.len() < SMALL_BUF_THRESHOLD {
self.add_bytes_small(bytes);
return;
}
let mut sum = self.sum;
let mut carry = false;
macro_rules! update_sum_carry {
($step: literal, $ty: ident, $bytes: expr) => {
let (s, c) = sum
.overflowing_add($ty::from_ne_bytes($bytes.try_into().unwrap()) as Accumulator);
sum = s + (carry as Accumulator);
carry = c;
bytes = &bytes[$step..];
};
($step: literal, $ty: ident) => {
update_sum_carry!($step, $ty, bytes[..$step]);
};
}
if let Some(byte) = self.trailing_byte {
update_sum_carry!(1, u16, [byte, bytes[0]]);
self.trailing_byte = None;
}
loop_unroll!(bytes, update_sum_carry);
if bytes.len() == 1 {
self.trailing_byte = Some(bytes[0]);
}
self.sum = sum + (carry as Accumulator);
}
#[inline(always)]
fn add_bytes_small(&mut self, mut bytes: &[u8]) {
if bytes.is_empty() {
return;
}
let mut sum = self.sum;
fn update_sum(acc: Accumulator, rhs: u16) -> Accumulator {
if let Some(updated) = acc.checked_add(rhs as Accumulator) {
updated
} else {
(normalize(acc) + rhs) as Accumulator
}
}
if let Some(byte) = self.trailing_byte {
sum = update_sum(sum, u16::from_ne_bytes([byte, bytes[0]]));
bytes = &bytes[1..];
self.trailing_byte = None;
}
bytes.chunks(2).for_each(|chunk| match chunk {
[byte] => self.trailing_byte = Some(*byte),
[first, second] => {
sum = update_sum(sum, u16::from_ne_bytes([*first, *second]));
}
bytes => unreachable!("{:?}", bytes),
});
self.sum = sum;
}
fn checksum_inner(&self) -> u16 {
let mut sum = self.sum;
if let Some(byte) = self.trailing_byte {
sum = adc_accumulator(sum, u16::from_ne_bytes([byte, 0]) as Accumulator);
}
!normalize(sum)
}
#[inline]
pub fn checksum(&self) -> [u8; 2] {
self.checksum_inner().to_ne_bytes()
}
}
// Generates `fn $func(a, b)` for the unsigned integer type `$int`: an addition
// whose carry out of the top bit is added back into the low end of the result
// (end-around carry).
macro_rules! impl_adc {
    ($func: ident, $int: ty) => {
        fn $func(a: $int, b: $int) -> $int {
            match a.overflowing_add(b) {
                // After an overflow the wrapped value is at most `MAX - 1`,
                // so adding the carry back cannot overflow again.
                (wrapped, true) => wrapped + 1,
                (total, false) => total,
            }
        }
    };
}
// End-around-carry adders for the narrow widths used by `normalize_64`.
impl_adc!(adc_u16, u16);
impl_adc!(adc_u32, u32);
// `adc_u64` is generated only on x86_64, where `normalize` uses it to fold
// the 128-bit accumulator into 64 bits.
#[cfg(target_arch = "x86_64")]
impl_adc!(adc_u64, u64);
// End-around-carry adder over the platform-dependent `Accumulator` type.
impl_adc!(adc_accumulator, Accumulator);
/// Folds an accumulator value down to 16 bits using end-around-carry adds.
///
/// On x86_64 the two 64-bit halves of the 128-bit accumulator are added first;
/// on other targets the accumulator is already 64 bits wide. The 64-bit value
/// is then reduced by `normalize_64`.
fn normalize(a: Accumulator) -> u16 {
    #[cfg(target_arch = "x86_64")]
    let folded = adc_u64(a as u64, (a >> 64) as u64);
    #[cfg(not(target_arch = "x86_64"))]
    let folded = a;

    normalize_64(folded)
}
/// Folds a 64-bit sum down to 16 bits: the two 32-bit halves are added with
/// end-around carry, then the two 16-bit halves of that result likewise.
fn normalize_64(a: u64) -> u16 {
    let (low, high) = (a as u32, (a >> 32) as u32);
    let folded = adc_u32(low, high);
    adc_u16(folded as u16, (folded >> 16) as u16)
}
// Micro-benchmarks, built only for test builds with the `benchmark` feature
// enabled.
#[cfg(all(test, feature = "benchmark"))]
mod benchmarks {
    extern crate test;
    use super::*;

    // Generates a benchmark named `$name` that checksums a `$len`-byte buffer
    // filled with `0xFF`, feeding it through the `Checksum` method `$method`.
    macro_rules! bench_checksum {
        ($name: ident, $method: ident, $len: literal) => {
            #[bench]
            fn $name(b: &mut test::Bencher) {
                b.iter(|| {
                    let buf = test::black_box([0xFF; $len]);
                    let mut c = Checksum::new();
                    c.$method(&buf);
                    test::black_box(c.checksum());
                });
            }
        };
    }

    // Generates a benchmark named `$name` that calls `update` to replace
    // `$len` bytes of `0x42` with `$len` bytes of `0xa0`.
    macro_rules! bench_update {
        ($name: ident, $len: literal) => {
            #[bench]
            fn $name(b: &mut test::Bencher) {
                b.iter(|| {
                    let old = test::black_box([0x42; $len]);
                    let new = test::black_box([0xa0; $len]);
                    test::black_box(update([42; 2], &old[..], &new[..]));
                });
            }
        };
    }

    // Buffers fed through the public entry point `add_bytes`.
    bench_checksum!(bench_checksum_31, add_bytes, 31);
    bench_checksum!(bench_checksum_32, add_bytes, 32);
    bench_checksum!(bench_checksum_64, add_bytes, 64);
    bench_checksum!(bench_checksum_128, add_bytes, 128);
    bench_checksum!(bench_checksum_256, add_bytes, 256);
    bench_checksum!(bench_checksum_1024, add_bytes, 1024);
    bench_checksum!(bench_checksum_1023, add_bytes, 1023);
    bench_checksum!(bench_checksum_20, add_bytes, 20);

    // Buffers fed directly through the word-at-a-time path.
    bench_checksum!(bench_checksum_small_20, add_bytes_small, 20);
    bench_checksum!(bench_checksum_small_31, add_bytes_small, 31);

    // Incremental updates of 2, 4 and 8 bytes.
    bench_update!(bench_update_2, 2);
    bench_update!(bench_update_4, 4);
    bench_update!(bench_update_8, 8);
}
// Unit tests for `Checksum`, `update`, and the agreement between the two
// accumulation paths (`add_bytes` and `add_bytes_small`).
#[cfg(test)]
mod tests {
use rand::{Rng, SeedableRng};
use rand_xorshift::XorShiftRng;
use super::*;
/// Builds a deterministic `XorShiftRng` from `seed`. A seed of zero is
/// replaced by one.
// NOTE(review): presumably an all-zero seed is unusable for this generator —
// confirm against the rand_xorshift documentation.
fn new_rng(mut seed: u128) -> XorShiftRng {
if seed == 0 {
seed = 1;
}
XorShiftRng::from_seed(seed.to_ne_bytes())
}
/// Each sample header must checksum to zero, whether it is fed in one call or
/// one byte at a time.
#[test]
fn test_checksum() {
for buf in IPV4_HEADERS {
// Whole header in a single call.
let mut c = Checksum::new();
c.add_bytes(&buf);
assert_eq!(c.checksum(), [0u8; 2]);
// One byte per call, exercising the trailing-byte pairing.
let mut c = Checksum::new();
for byte in *buf {
c.add_bytes(&[*byte]);
}
assert_eq!(c.checksum(), [0u8; 2]);
// Add the word 0xFFFF 2 * 2^16 times in total; the checksum must still be
// zero. This block does not depend on `buf`.
let mut c = Checksum::new();
c.add_bytes(&[0xFF, 0xFF]);
for _ in 0..((2 * (1 << 16)) - 1) {
c.add_bytes(&[0xFF, 0xFF]);
}
assert_eq!(c.checksum(), [0u8; 2]);
}
}
/// Overwriting bytes 16..20 of each sample header with `127.0.0.1` and
/// patching the checksum with `update` must match a from-scratch computation.
#[test]
fn test_update() {
for b in IPV4_HEADERS {
let mut buf = Vec::new();
buf.extend_from_slice(b);
let mut c = Checksum::new();
c.add_bytes(&buf);
assert_eq!(c.checksum(), [0u8; 2]);
// Remember the bytes being replaced before overwriting them.
let old = [buf[16], buf[17], buf[18], buf[19]];
(&mut buf[16..20]).copy_from_slice(&[127, 0, 0, 1]);
let updated = update(c.checksum(), &old, &[127, 0, 0, 1]);
let from_scratch = {
let mut c = Checksum::new();
c.add_bytes(&buf);
c.checksum()
};
assert_eq!(updated, from_scratch);
}
}
/// Randomised check of `update`: rewrite a random sub-range of a random
/// 31-byte buffer and compare against recomputing the checksum from scratch.
#[test]
fn test_smoke_update() {
let mut rng = new_rng(70_812_476_915_813);
for _ in 0..2048 {
const BUF_LEN: usize = 31;
let buf: [u8; BUF_LEN] = rng.gen();
let mut c = Checksum::new();
c.add_bytes(&buf);
// Keep drawing until the range starts at an even offset and ends either at
// an even offset or at the end of the buffer.
let (begin, end) = loop {
let begin = rng.gen::<usize>() % BUF_LEN;
let end = begin + (rng.gen::<usize>() % (BUF_LEN + 1 - begin));
if begin % 2 == 0 && (end % 2 == 0 || end == BUF_LEN) {
break (begin, end);
}
};
let mut new_buf = buf;
for i in begin..end {
new_buf[i] = rng.gen();
}
let updated = update(c.checksum(), &buf[begin..end], &new_buf[begin..end]);
let from_scratch = {
let mut c = Checksum::new();
c.add_bytes(&new_buf);
c.checksum()
};
assert_eq!(updated, from_scratch);
}
}
/// Property test: `add_bytes` and `add_bytes_small` must agree on random
/// buffers of 64 to 1024 bytes, and a buffer followed by its own checksum
/// must checksum to zero through either path.
#[test]
fn test_add_bytes_small_prop_test() {
let mut rng = new_rng(123478012483);
// `c1` and `c2` are shared across all iterations, so they keep accumulating.
let mut c1 = Checksum::new();
let mut c2 = Checksum::new();
for len in 64..1_025 {
for _ in 0..4 {
let mut buf = vec![];
for _ in 0..len {
buf.push(rng.gen());
}
c1.add_bytes(&buf[..]);
c2.add_bytes_small(&buf[..]);
assert_eq!(c1.checksum(), c2.checksum());
let n1 = c1.checksum_inner();
let n2 = c2.checksum_inner();
assert_eq!(n1, n2);
let mut t1 = Checksum::new();
let mut t2 = Checksum::new();
let mut t3 = Checksum::new();
t3.add_bytes(&buf[..]);
// Pad to an even length so the appended checksum sits on a word boundary.
if buf.len() % 2 == 1 {
buf.push(0);
}
assert_eq!(buf.len() % 2, 0);
buf.extend_from_slice(&t3.checksum());
t1.add_bytes(&buf[..]);
t2.add_bytes_small(&buf[..]);
assert_eq!(t1.checksum(), [0, 0]);
assert_eq!(t2.checksum(), [0, 0]);
}
}
}
/// Sample 20-byte buffers, named as IPv4 headers. `test_checksum` asserts
/// that each of them checksums to zero.
const IPV4_HEADERS: &[&[u8]] = &[
&[
0x45, 0x00, 0x00, 0x34, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xae, 0xea, 0xc0, 0xa8,
0x01, 0x0f, 0xc0, 0xb8, 0x09, 0x6a,
],
&[
0x45, 0x20, 0x00, 0x74, 0x5b, 0x6e, 0x40, 0x00, 0x37, 0x06, 0x5c, 0x1c, 0xc0, 0xb8,
0x09, 0x6a, 0xc0, 0xa8, 0x01, 0x0f,
],
&[
0x45, 0x20, 0x02, 0x8f, 0x00, 0x00, 0x40, 0x00, 0x3b, 0x11, 0xc9, 0x3f, 0xac, 0xd9,
0x05, 0x6e, 0xc0, 0xa8, 0x01, 0x0f,
],
];
}