#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "checksum.h"
/*
 * Straightforward RFC 1071 Internet checksum, used as the test oracle.
 *
 * The input is summed as big-endian 16-bit words. An odd trailing byte is
 * treated as the high byte of a word whose low byte is zero. The ones'
 * complement of the folded 16-bit sum is returned.
 *
 * The end-around carry is folded after every addition. Before each fold the
 * accumulator is at most 0x1FFFE, and after it at most 0xFFFF, so the 32-bit
 * accumulator can never wrap regardless of len. Accumulating everything
 * first and folding at the end wraps once more than ~65537 words (~128 KiB)
 * have been added, which silently corrupts the result.
 */
static uint16_t reference_checksum(const unsigned char *data, size_t len) {
    uint32_t sum = 0;
    size_t i = 0;
    for (; i + 1 < len; i += 2) {
        sum += ((uint32_t)data[i] << 8) | (uint32_t)data[i + 1];
        sum = (sum & 0xFFFFu) + (sum >> 16);
    }
    if (i < len) {
        /* Odd length: pad the final byte with a zero low byte. */
        sum += (uint32_t)data[i] << 8;
        sum = (sum & 0xFFFFu) + (sum >> 16);
    }
    /* sum <= 0xFFFF here; complement in 32 bits, then keep the low half. */
    return (uint16_t)(~sum & 0xFFFFu);
}
/*
 * Writes a reproducible, non-uniform byte pattern into buf[0..len):
 * byte k receives (k * 31 + seed) reduced to 8 bits. Nothing is written
 * when len is 0.
 */
static void fill_deterministic(unsigned char *buf, size_t len,
                               unsigned char seed) {
    size_t k = len;
    while (k > 0) {
        --k;
        buf[k] = (unsigned char)((k * 31u + seed) & 0xFFu);
    }
}
/* Zero-length input: the one-shot API must agree with the oracle. */
static void test_empty(void) {
    const unsigned char *const nothing = (const unsigned char *)"";
    const uint16_t actual = ip4sum_checksum_oneshot(nothing, 0);
    const uint16_t expected = reference_checksum(nothing, 0);
    assert(actual == expected);
}
/* One lone byte: exercises only the odd-trailing-byte path. */
static void test_single_byte(void) {
    static const unsigned char bytes[] = {0x45};
    const size_t n = sizeof bytes;
    const uint16_t lib = ip4sum_checksum_oneshot(bytes, n);
    assert(lib == reference_checksum(bytes, n));
}
/* Exactly one 16-bit word. */
static void test_two_bytes(void) {
    const unsigned char word[2] = {0x45, 0x00};
    uint16_t lib = ip4sum_checksum_oneshot(word, sizeof word);
    uint16_t ref = reference_checksum(word, sizeof word);
    assert(lib == ref);
}
/* One full word followed by an odd trailing byte. */
static void test_three_bytes(void) {
    static const unsigned char bytes[3] = {0x45, 0x00, 0xAB};
    const size_t n = sizeof bytes / sizeof bytes[0];
    const uint16_t got = ip4sum_checksum_oneshot(bytes, n);
    const uint16_t want = reference_checksum(bytes, n);
    assert(got == want);
}
/* Two whole words (the first four bytes of an IPv4 header). */
static void test_four_bytes(void) {
    const unsigned char head[] = {0x45, 0x00, 0x00, 0x30};
    const size_t count = sizeof head;
    uint16_t from_lib = ip4sum_checksum_oneshot(head, count);
    uint16_t from_ref = reference_checksum(head, count);
    assert(from_lib == from_ref);
}
/* Two whole words plus a high-valued odd trailing byte. */
static void test_five_bytes(void) {
    static const unsigned char input[5] = {0x45, 0x00, 0x00, 0x30, 0xFF};
    const uint16_t expected = reference_checksum(input, sizeof input);
    const uint16_t observed = ip4sum_checksum_oneshot(input, sizeof input);
    assert(observed == expected);
}
/*
 * Real 20-byte IPv4 header (protocol 1, 10.0.0.1 -> 10.0.0.2) with its
 * checksum field (bytes 10-11) zeroed.
 *
 * The computed checksum is checked against the hand-computed value 0x26CB
 * and against the oracle. The test then verifies that storing it in the
 * header makes the whole header checksum to zero. The round-trip check
 * alone is vacuous: an implementation that always returned 0 would leave
 * the header unchanged and still "verify" to 0.
 */
static void test_ipv4_header_1(void) {
    const unsigned char data[20] = {
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00,
        0x40, 0x01, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x01,
        0x0A, 0x00, 0x00, 0x02
    };
    uint16_t csum = ip4sum_checksum_oneshot(data, sizeof data);
    assert(csum == 0x26CB);
    assert(csum == reference_checksum(data, sizeof data));

    unsigned char verified[20];
    memcpy(verified, data, sizeof verified);
    verified[10] = (unsigned char)(csum >> 8);
    verified[11] = (unsigned char)(csum & 0xFF);
    assert(ip4sum_checksum_oneshot(verified, sizeof verified) == 0);
}
/*
 * Real 20-byte IPv4 header (protocol 6, 172.16.10.99 -> 172.16.10.12) with
 * its checksum field (bytes 10-11) zeroed. Its checksum is 0xB1E6.
 *
 * The known value and the oracle are checked before the round trip, because
 * the round trip alone also passes for an implementation that always
 * returns 0.
 */
static void test_ipv4_header_2(void) {
    const unsigned char data[20] = {
        0x45, 0x00, 0x00, 0x3C, 0x1C, 0x46, 0x40, 0x00,
        0x40, 0x06, 0x00, 0x00, 0xAC, 0x10, 0x0A, 0x63,
        0xAC, 0x10, 0x0A, 0x0C
    };
    uint16_t csum = ip4sum_checksum_oneshot(data, sizeof data);
    assert(csum == 0xB1E6);
    assert(csum == reference_checksum(data, sizeof data));

    unsigned char verified[20];
    memcpy(verified, data, sizeof verified);
    verified[10] = (unsigned char)(csum >> 8);
    verified[11] = (unsigned char)(csum & 0xFF);
    assert(ip4sum_checksum_oneshot(verified, sizeof verified) == 0);
}
/* Feeding a 20-byte header as two 10-byte halves must match one-shot. */
static void test_incremental_two_parts(void) {
    const unsigned char hdr[20] = {
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00,
        0x40, 0x01, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x01,
        0x0A, 0x00, 0x00, 0x02
    };
    const size_t split = 10;
    const uint16_t expected = ip4sum_checksum_oneshot(hdr, sizeof hdr);

    ip4sum_checksum state = ip4sum_checksum_new();
    ip4sum_checksum_update(&state, hdr, split);
    ip4sum_checksum_update(&state, hdr + split, sizeof hdr - split);

    const uint16_t streamed = ip4sum_checksum_finalize(state);
    assert(expected == streamed);
}
/* Feeding a 20-byte header as uneven 4 + 12 + 4 chunks must match one-shot. */
static void test_incremental_three_parts(void) {
    const unsigned char hdr[20] = {
        0x45, 0x00, 0x00, 0x3C, 0x1C, 0x46, 0x40, 0x00,
        0x40, 0x06, 0x00, 0x00, 0xAC, 0x10, 0x0A, 0x63,
        0xAC, 0x10, 0x0A, 0x0C
    };
    const uint16_t whole = ip4sum_checksum_oneshot(hdr, sizeof hdr);

    /* Chunk boundaries: [0,4), [4,16), [16,20). */
    static const size_t cuts[] = {0, 4, 16, 20};
    const size_t ncuts = sizeof cuts / sizeof cuts[0];

    ip4sum_checksum acc = ip4sum_checksum_new();
    for (size_t k = 0; k + 1 < ncuts; k++) {
        ip4sum_checksum_update(&acc, hdr + cuts[k], cuts[k + 1] - cuts[k]);
    }
    const uint16_t pieced = ip4sum_checksum_finalize(acc);
    assert(whole == pieced);
}
/*
 * Simulates recomputing a checksum over a header that already carries one:
 * the caller streams zeros in place of the checksum field (bytes 10-11)
 * instead of copying and clearing the header. The result must equal the
 * one-shot checksum of the header with the field actually zeroed.
 *
 * The field in `data` is already zero, so substituting zeros into `data`
 * itself would reduce to a plain split update. The substitution is therefore
 * done on `on_wire`, a copy whose field holds the real, non-zero checksum.
 */
static void test_incremental_with_zeroed_checksum(void) {
    const unsigned char data[20] = {
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00,
        0x40, 0x01, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x01,
        0x0A, 0x00, 0x00, 0x02
    };
    uint16_t csum_oneshot = ip4sum_checksum_oneshot(data, sizeof data);

    unsigned char on_wire[20];
    memcpy(on_wire, data, sizeof on_wire);
    on_wire[10] = (unsigned char)(csum_oneshot >> 8);
    on_wire[11] = (unsigned char)(csum_oneshot & 0xFF);

    const unsigned char zero[2] = {0, 0};
    ip4sum_checksum c = ip4sum_checksum_new();
    ip4sum_checksum_update(&c, on_wire, 10);
    ip4sum_checksum_update(&c, zero, sizeof zero);
    ip4sum_checksum_update(&c, on_wire + 12, 8);
    uint16_t csum_inc = ip4sum_checksum_finalize(c);
    assert(csum_oneshot == csum_inc);
}
/*
 * After junk has been fed, reset() must make the state behave like a fresh
 * one: checksumming a header afterwards must equal the one-shot result.
 *
 * The junk must not be a ones'-complement zero. Four 0xFF bytes sum to
 * 0x1FFFE, which folds to 0xFFFF (negative zero) and leaves any later sum
 * unchanged, so a no-op reset would go undetected. 0x1234 + 0x5678 = 0x68AC
 * is nonzero modulo 0xFFFF and does perturb the result.
 */
static void test_reset(void) {
    const unsigned char data[20] = {
        0x45, 0x00, 0x00, 0x30, 0x00, 0x00, 0x40, 0x00,
        0x40, 0x01, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x01,
        0x0A, 0x00, 0x00, 0x02
    };
    uint16_t expected = ip4sum_checksum_oneshot(data, sizeof data);
    ip4sum_checksum c = ip4sum_checksum_new();
    const unsigned char junk[] = {0x12, 0x34, 0x56, 0x78};
    ip4sum_checksum_update(&c, junk, sizeof(junk));
    ip4sum_checksum_reset(&c);
    ip4sum_checksum_update(&c, data, sizeof data);
    assert(ip4sum_checksum_finalize(c) == expected);
}
/* Every prefix length from 0 through 256 of a non-uniform pattern. */
static void test_sizes_0_to_256(void) {
    unsigned char pattern[256];
    fill_deterministic(pattern, sizeof pattern, 0);

    size_t n = 0;
    do {
        const uint16_t got = ip4sum_checksum_oneshot(pattern, n);
        assert(got == reference_checksum(pattern, n));
    } while (n++ < sizeof pattern);
}
/* Lengths just below, at and above 64 bytes, over a constant 0xAB fill. */
static void test_size_63_64_65(void) {
    unsigned char fill[256];
    memset(fill, 0xAB, sizeof fill);

    static const size_t lengths[] = {63, 64, 65};
    const size_t *const stop = lengths + sizeof lengths / sizeof lengths[0];
    for (const size_t *len = lengths; len != stop; ++len) {
        const uint16_t lib = ip4sum_checksum_oneshot(fill, *len);
        assert(lib == reference_checksum(fill, *len));
    }
}
/* Lengths 127, 128 and 129 over a constant 0xCD fill. */
static void test_size_127_128_129(void) {
    unsigned char fill[256];
    memset(fill, 0xCD, sizeof fill);

    for (size_t len = 127; len <= 129; ++len) {
        const uint16_t want = reference_checksum(fill, len);
        assert(ip4sum_checksum_oneshot(fill, len) == want);
    }
}
/* A freshly created state that is finalized immediately must give 0xFFFF. */
static void test_fold_zero(void) {
    ip4sum_checksum fresh = ip4sum_checksum_new();
    const uint16_t untouched = ip4sum_checksum_finalize(fresh);
    assert(untouched == 0xFFFF);
}
/* A single 0xFFFF word sums to 0xFFFF, so its complemented checksum is 0. */
static void test_fold_known_value(void) {
    const unsigned char all_ones[2] = {0xFF, 0xFF};
    const uint16_t result = ip4sum_checksum_oneshot(all_ones, sizeof all_ones);
    assert(result == 0);
}
/*
 * Sender/receiver property on a 60-byte buffer: compute the checksum with
 * bytes 10-11 zeroed, store it there, and the whole buffer must then
 * checksum to zero.
 *
 * The computed value is also cross-checked against the oracle. Without that
 * check the test is vacuous: an implementation that always returned 0 would
 * leave the field zero and still "verify".
 */
static void test_verification_property(void) {
    unsigned char buf[60];
    for (size_t i = 0; i < sizeof(buf); i++) {
        buf[i] = (unsigned char)((i * 7u + 0x5Au) & 0xFFu);
    }
    /* Clear the checksum field before computing, as a sender would. */
    buf[10] = 0;
    buf[11] = 0;
    uint16_t csum = ip4sum_checksum_oneshot(buf, sizeof buf);
    assert(csum == reference_checksum(buf, sizeof buf));

    buf[10] = (unsigned char)(csum >> 8);
    buf[11] = (unsigned char)(csum & 0xFF);
    assert(ip4sum_checksum_oneshot(buf, sizeof buf) == 0);
}
/* A 1500-byte buffer (the common Ethernet MTU) holding an incrementing byte ramp. */
static void test_1500_bytes(void) {
    unsigned char frame[1500];
    unsigned char v = 0x37;
    for (size_t i = 0; i < sizeof frame; i++) {
        frame[i] = v++; /* wraps modulo 256, same as (i + 0x37) & 0xFF */
    }
    const uint16_t want = reference_checksum(frame, sizeof frame);
    assert(ip4sum_checksum_oneshot(frame, sizeof frame) == want);
}
/* A large odd-length buffer (1501 bytes) filled with multiples of 13 mod 256. */
static void test_odd_length_large(void) {
    unsigned char buf[1501];
    unsigned char v = 0;
    for (size_t i = 0; i < sizeof buf; i++, v = (unsigned char)(v + 13u)) {
        buf[i] = v;
    }
    const uint16_t got = ip4sum_checksum_oneshot(buf, sizeof buf);
    assert(got == reference_checksum(buf, sizeof buf));
}
/*
 * TCP checksum assembled incrementally from its parts, compared with the
 * one-shot checksum of the concatenation.
 *
 * The IPv4 pseudo-header (RFC 793 section 3.1) is 12 bytes: source address,
 * destination address, one zero byte, the protocol byte and the 16-bit TCP
 * length. Here it is split into `pseudo` (first 10 bytes) and `tcp_length`
 * (last 2 bytes) so that three separate updates are exercised. The TCP
 * header follows, with its checksum field zeroed.
 */
static void test_tcp_checksum_incremental(void) {
    /* src 192.168.1.1, dst 192.168.1.2, zero byte, protocol 6 (TCP) */
    const unsigned char pseudo[10] = {
        192, 168, 1, 1, 192, 168, 1, 2, 0, 6
    };
    /* TCP segment length: 20 (header only, no payload) */
    const unsigned char tcp_length[2] = {0x00, 0x14};
    /* ports 80 -> 80, seq 1, ack 0, data offset 5, SYN, window 0, csum 0 */
    const unsigned char tcp_hdr[20] = {
        0x00, 0x50, 0x00, 0x50, 0x00, 0x00, 0x00, 0x01,
        0x00, 0x00, 0x00, 0x00, 0x50, 0x02, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00
    };

    unsigned char combined[sizeof pseudo + sizeof tcp_length + sizeof tcp_hdr];
    memcpy(combined, pseudo, sizeof pseudo);
    memcpy(combined + sizeof pseudo, tcp_length, sizeof tcp_length);
    memcpy(combined + sizeof pseudo + sizeof tcp_length, tcp_hdr,
           sizeof tcp_hdr);

    uint16_t csum_oneshot = ip4sum_checksum_oneshot(combined, sizeof combined);
    assert(csum_oneshot == reference_checksum(combined, sizeof combined));

    ip4sum_checksum c = ip4sum_checksum_new();
    ip4sum_checksum_update(&c, pseudo, sizeof pseudo);
    ip4sum_checksum_update(&c, tcp_length, sizeof tcp_length);
    ip4sum_checksum_update(&c, tcp_hdr, sizeof tcp_hdr);
    uint16_t csum_inc = ip4sum_checksum_finalize(c);
    assert(csum_oneshot == csum_inc);
}
int main(void) {
int pass = 0;
test_empty(); pass++;
test_single_byte(); pass++;
test_two_bytes(); pass++;
test_three_bytes(); pass++;
test_four_bytes(); pass++;
test_five_bytes(); pass++;
test_ipv4_header_1(); pass++;
test_ipv4_header_2(); pass++;
test_incremental_two_parts(); pass++;
test_incremental_three_parts(); pass++;
test_incremental_with_zeroed_checksum(); pass++;
test_reset(); pass++;
test_sizes_0_to_256(); pass++;
test_size_63_64_65(); pass++;
test_size_127_128_129(); pass++;
test_fold_zero(); pass++;
test_fold_known_value(); pass++;
test_verification_property(); pass++;
test_1500_bytes(); pass++;
test_odd_length_large(); pass++;
test_tcp_checksum_incremental(); pass++;
printf("All %d tests passed.\n", pass);
return 0;
}