use crate::error::{Error, Result};
use crate::header::Header;
use crate::keyword::HeaderValue;
/// Computes the 32-bit ones'-complement checksum of `data`.
///
/// The input is consumed as consecutive big-endian 32-bit words. A trailing
/// group of fewer than four bytes is zero-padded on the right before being
/// added. The upper and lower 16-bit halves of every word are summed
/// separately and the carries are folded back end-around (a carry out of the
/// low half goes into the high half and vice versa) until both halves fit in
/// 16 bits.
///
/// Returns `0` for empty or all-zero input.
pub fn checksum(data: &[u8]) -> u32 {
    // Each word adds at most 0xFFFF to either half, so a `u32` accumulator
    // would overflow after roughly 65k words (~256 KiB of input). 64-bit
    // accumulators cannot overflow for any slice that fits in memory.
    let mut hi: u64 = 0;
    let mut lo: u64 = 0;

    for chunk in data.chunks(4) {
        // Zero-pad a short final chunk to a full word.
        let mut word = [0u8; 4];
        word[..chunk.len()].copy_from_slice(chunk);
        hi += u64::from(u16::from_be_bytes([word[0], word[1]]));
        lo += u64::from(u16::from_be_bytes([word[2], word[3]]));
    }

    // End-around carry: repeat until neither half has bits above bit 15.
    while hi > 0xFFFF || lo > 0xFFFF {
        let (hicarry, locarry) = (hi >> 16, lo >> 16);
        hi = (hi & 0xFFFF) + locarry;
        lo = (lo & 0xFFFF) + hicarry;
    }

    // Both halves are now <= 0xFFFF, so the combined value fits in 32 bits.
    ((hi << 16) | lo) as u32
}
/// Folds the checksum of `new_data` into the running checksum `existing`.
///
/// `new_data` is checksummed on its own with [`checksum`] and the two values
/// are combined with ones'-complement addition. Because [`checksum`]
/// zero-pads a trailing partial word, the combined value only equals the
/// checksum of the concatenated bytes when every earlier piece had a length
/// that is a multiple of four.
pub fn checksum_accumulate(existing: u32, new_data: &[u8]) -> u32 {
    ones_complement_add(existing, checksum(new_data))
}
/// Adds two 32-bit values using ones'-complement (end-around carry)
/// arithmetic.
///
/// The operands are split into 16-bit halves which are summed independently.
/// A carry out of the lower half is added to the upper half, and a carry out
/// of the upper half wraps around into the lower half, until neither half
/// exceeds 16 bits.
fn ones_complement_add(a: u32, b: u32) -> u32 {
    const HALF_MASK: u32 = 0xFFFF;

    let mut upper = (a >> 16) + (b >> 16);
    let mut lower = (a & HALF_MASK) + (b & HALF_MASK);

    while upper > HALF_MASK || lower > HALF_MASK {
        // Capture both carries before updating either half.
        let (carry_from_upper, carry_from_lower) = (upper >> 16, lower >> 16);
        upper = (upper & HALF_MASK) + carry_from_lower;
        lower = (lower & HALF_MASK) + carry_from_upper;
    }

    (upper << 16) | lower
}
/// Encodes a 32-bit checksum as a 16-character ASCII string.
///
/// When `complement` is true the bitwise complement of `sum` is encoded
/// instead of `sum` itself.
///
/// Each of the four bytes of the value (most significant first) is spread
/// over four characters whose offsets from ASCII `'0'` add up to the byte.
/// Characters that land on the punctuation codes `0x3a..=0x40` or
/// `0x5b..=0x60` are nudged in +1/-1 pairs, which keeps the total unchanged.
/// The characters are interleaved (character `j` of byte `i` goes to
/// position `4 * j + i`) and the whole string is finally rotated one place
/// to the right.
pub fn encode_checksum(sum: u32, complement: bool) -> String {
    // Punctuation codes that must not appear in the encoded string.
    const EXCLUDED: [u8; 13] = [
        0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60,
    ];
    const ASCII_ZERO: u32 = 0x30;

    let value = if complement { !sum } else { sum };
    let value_bytes = value.to_be_bytes();
    let mut interleaved = [0u8; 16];

    for (column, &raw) in value_bytes.iter().enumerate() {
        let byte = u32::from(raw);
        let base = byte / 4 + ASCII_ZERO;

        // Four equal shares of the byte; the first also takes the remainder.
        let mut quad = [base as u8; 4];
        quad[0] = (base + byte % 4) as u8;

        // Re-scan until a full pass makes no adjustment. The scan order
        // (excluded code outermost, then the two character pairs) is
        // significant and must not be changed.
        loop {
            let mut adjusted = false;
            for &forbidden in &EXCLUDED {
                for pair in quad.chunks_exact_mut(2) {
                    if pair[0] == forbidden || pair[1] == forbidden {
                        pair[0] = pair[0].wrapping_add(1);
                        pair[1] = pair[1].wrapping_sub(1);
                        adjusted = true;
                    }
                }
            }
            if !adjusted {
                break;
            }
        }

        for (row, &ch) in quad.iter().enumerate() {
            interleaved[4 * row + column] = ch;
        }
    }

    // Shift every character one position to the right, wrapping the last
    // character around to the front.
    let mut rotated = interleaved;
    rotated.rotate_right(1);
    String::from_utf8(rotated.to_vec()).expect("checksum encoding produced non-UTF8")
}
/// Decodes a 16-character ASCII checksum string back into its 32-bit value.
///
/// Only the first 16 bytes of `ascii` are read. The string is rotated one
/// place to the left, the ASCII `'0'` offset (0x30) is removed from every
/// byte, and the resulting bytes are summed as big-endian 32-bit words with
/// end-around carry. When `complement` is true the bitwise complement of the
/// sum is returned.
///
/// # Panics
///
/// Panics if `ascii` is shorter than 16 bytes.
pub fn decode_checksum(ascii: &str, complement: bool) -> u32 {
    let bytes = ascii.as_bytes();
    assert!(bytes.len() >= 16, "checksum string must be 16 characters");

    // Undo the one-character rotation and strip the '0' offset.
    let mut digits = [0u8; 16];
    for (slot, digit) in digits.iter_mut().enumerate() {
        *digit = bytes[(slot + 1) % 16].wrapping_sub(0x30);
    }

    // Sum the four words, keeping upper and lower 16-bit halves apart.
    let (mut upper, mut lower) = (0u32, 0u32);
    for quad in digits.chunks_exact(4) {
        upper += (u32::from(quad[0]) << 8) + u32::from(quad[1]);
        lower += (u32::from(quad[2]) << 8) + u32::from(quad[3]);
    }

    // Fold carries end-around until both halves fit in 16 bits.
    while upper >> 16 != 0 || lower >> 16 != 0 {
        let (carry_from_upper, carry_from_lower) = (upper >> 16, lower >> 16);
        upper = (upper & 0xFFFF) + carry_from_lower;
        lower = (lower & 0xFFFF) + carry_from_upper;
    }

    let sum = (upper << 16) | lower;
    if complement {
        !sum
    } else {
        sum
    }
}
/// Computes the checksum of a data unit.
///
/// This is a thin wrapper that forwards `data` unchanged to [`checksum`] and
/// returns its result; it exists so call sites can name the data-unit sum
/// explicitly.
pub fn datasum(data: &[u8]) -> u32 {
checksum(data)
}
/// Checks the combined checksum of a serialized header and its data unit.
///
/// The header bytes are checksummed first and the data bytes are then
/// accumulated onto that value. Returns `true` when the combined
/// ones'-complement sum is either all ones (`0xFFFF_FFFF`) or zero, and
/// `false` otherwise.
pub fn verify_hdu(header_bytes: &[u8], data_bytes: &[u8]) -> bool {
    let header_sum = checksum(header_bytes);
    let total = checksum_accumulate(header_sum, data_bytes);
    matches!(total, 0 | 0xFFFF_FFFF)
}
/// Writes `DATASUM` and `CHECKSUM` keywords into `header` and returns the
/// serialized header bytes.
///
/// Steps:
/// 1. `DATASUM` is set to the decimal checksum of `data_bytes`.
/// 2. `CHECKSUM` is set to a placeholder of sixteen `'0'` characters and the
///    header is serialized once to obtain its checksum.
/// 3. The header and data checksums are combined, the complement of the
///    result is encoded, and `CHECKSUM` is set to that string.
/// 4. The header is serialized again and those bytes are returned.
///
/// # Errors
///
/// Propagates any error returned by `Header::write_to`.
///
/// NOTE(review): this presumably relies on `Header::set` replacing an
/// existing keyword in place so both serializations have the same layout —
/// confirm against `Header`.
pub fn stamp_hdu(header: &mut Header, data_bytes: &[u8]) -> Result<Vec<u8>> {
    const PLACEHOLDER: &str = "0000000000000000";
    const CHECKSUM_COMMENT: &str = "HDU checksum";

    let data_sum = datasum(data_bytes);
    header.set(
        "DATASUM",
        HeaderValue::String(data_sum.to_string()),
        Some("data unit checksum"),
    );
    header.set(
        "CHECKSUM",
        HeaderValue::String(PLACEHOLDER.into()),
        Some(CHECKSUM_COMMENT),
    );

    // First pass: serialize with the placeholder to measure the header sum.
    let mut serialized = Vec::new();
    header.write_to(&mut serialized)?;
    let hdu_sum = ones_complement_add(checksum(&serialized), data_sum);

    header.set(
        "CHECKSUM",
        HeaderValue::String(encode_checksum(hdu_sum, true)),
        Some(CHECKSUM_COMMENT),
    );

    // Second pass: serialize again now that the real CHECKSUM is in place.
    serialized.clear();
    header.write_to(&mut serialized)?;
    Ok(serialized)
}
pub fn verify_from_header(header: &Header, data_bytes: &[u8]) -> Result<()> {
if let Some(stored_datasum_str) = header.get_string("DATASUM") {
if let Ok(stored) = stored_datasum_str.parse::<u64>() {
let stored = stored as u32;
let computed = datasum(data_bytes);
if stored != 0 && computed != stored {
return Err(Error::ChecksumMismatch {
expected: stored,
actual: computed,
});
}
}
}
Ok(())
}
// Unit tests for the checksum, encoding and HDU stamping/verification
// helpers defined in this module.
#[cfg(test)]
mod tests {
use super::*;
use crate::io_utils;
// An empty input must checksum to zero.
#[test]
fn checksum_empty() {
assert_eq!(checksum(&[]), 0);
}
// A 2880-byte buffer of zero bytes must checksum to zero.
#[test]
fn checksum_zeros() {
let data = vec![0u8; 2880];
assert_eq!(checksum(&data), 0);
}
// Checksumming the same non-trivial buffer twice gives the same non-zero
// value.
#[test]
fn checksum_deterministic() {
let data: Vec<u8> = (0..2880).map(|i| (i % 256) as u8).collect();
let c1 = checksum(&data);
let c2 = checksum(&data);
assert_eq!(c1, c2);
assert_ne!(c1, 0);
}
// Without complementing, the encoded string is 16 ASCII alphanumeric
// characters for a handful of representative values.
#[test]
fn encode_produces_valid_ascii() {
for val in [0u32, 1, 0xDEADBEEF, 0xFFFFFFFF, 0x12345678] {
let s = encode_checksum(val, false);
assert_eq!(s.len(), 16);
for b in s.bytes() {
assert!(
b.is_ascii_alphanumeric(),
"non-alphanumeric char {b:#x} in encoded checksum"
);
}
}
}
// Same property as above, with the complement flag set.
#[test]
fn encode_complement_produces_valid_ascii() {
for val in [0u32, 1, 0xDEADBEEF, 0xFFFFFFFF] {
let s = encode_checksum(val, true);
assert_eq!(s.len(), 16);
for b in s.bytes() {
assert!(b.is_ascii_alphanumeric());
}
}
}
// Encoding then decoding (no complement) returns the original value.
#[test]
fn encode_decode_round_trip() {
for val in [0u32, 1, 42, 0xDEADBEEF, 0xFFFFFFFF, 0x12345678, 2503531142] {
let encoded = encode_checksum(val, false);
let decoded = decode_checksum(&encoded, false);
assert_eq!(decoded, val, "round-trip failed for {val:#x}");
}
}
// Encoding then decoding with the complement flag on both sides returns
// the original value.
#[test]
fn encode_decode_complement_round_trip() {
for val in [0u32, 1, 0xDEADBEEF, 0xFFFFFFFF] {
let encoded = encode_checksum(val, true);
let decoded = decode_checksum(&encoded, true);
assert_eq!(decoded, val, "complement round-trip failed for {val:#x}");
}
}
// A header stamped for an empty data unit (after `pad_to_block`) must pass
// `verify_hdu`.
#[test]
fn stamp_and_verify() {
let mut header = Header::new();
header.set("SIMPLE", HeaderValue::Logical(true), Some("standard"));
header.set("BITPIX", HeaderValue::Integer(8), None);
header.set("NAXIS", HeaderValue::Integer(0), None);
let data_bytes = vec![0u8; 0]; let padded_data = io_utils::pad_to_block(&data_bytes);
let header_bytes = stamp_hdu(&mut header, &padded_data).unwrap();
assert!(verify_hdu(&header_bytes, &padded_data));
}
// A header stamped for 100 data bytes (after `pad_to_block`) must pass
// `verify_hdu`, and its DATASUM keyword must equal `datasum` of the padded
// data.
#[test]
fn stamp_and_verify_with_data() {
let mut header = Header::new();
header.set("SIMPLE", HeaderValue::Logical(true), Some("standard"));
header.set("BITPIX", HeaderValue::Integer(8), None);
header.set("NAXIS", HeaderValue::Integer(1), None);
header.set("NAXIS1", HeaderValue::Integer(100), None);
let data_bytes: Vec<u8> = (0..100).map(|i| (i * 3) as u8).collect();
let padded_data = io_utils::pad_to_block(&data_bytes);
let header_bytes = stamp_hdu(&mut header, &padded_data).unwrap();
assert!(verify_hdu(&header_bytes, &padded_data));
let dsum_str = header.get_string("DATASUM").unwrap();
let dsum: u32 = dsum_str.parse().unwrap();
assert_eq!(dsum, datasum(&padded_data));
}
// Zero is the identity element for ones'-complement addition.
#[test]
fn ones_complement_add_identity() {
assert_eq!(ones_complement_add(0, 0), 0);
assert_eq!(ones_complement_add(42, 0), 42);
assert_eq!(ones_complement_add(0, 42), 42);
}
// A value added to its bitwise complement yields all ones.
#[test]
fn ones_complement_add_complement() {
let x = 0x12345678u32;
let result = ones_complement_add(x, !x);
assert_eq!(result, 0xFFFFFFFF);
}
// Flipping every bit of the first data byte after stamping must make
// `verify_hdu` fail.
#[test]
fn verify_corruption_detected() {
let mut header = Header::new();
header.set("SIMPLE", HeaderValue::Logical(true), Some("standard"));
header.set("BITPIX", HeaderValue::Integer(8), None);
header.set("NAXIS", HeaderValue::Integer(1), None);
header.set("NAXIS1", HeaderValue::Integer(100), None);
let data_bytes: Vec<u8> = (0..100).collect();
let padded_data = io_utils::pad_to_block(&data_bytes);
let header_bytes = stamp_hdu(&mut header, &padded_data).unwrap();
let mut corrupted = padded_data.clone();
corrupted[0] ^= 0xFF;
assert!(!verify_hdu(&header_bytes, &corrupted));
}
}