use std::io;
use log::debug;
use crate::store::checksum_input::ChecksumIndexInput;
use crate::store::{DataInput, DataOutput, IndexInput, IndexOutput};
/// Marker written as the first big-endian int of every codec header.
pub const CODEC_MAGIC: i32 = 0x3FD7_6C17;
/// Marker opening every codec footer; the bitwise complement of [`CODEC_MAGIC`].
pub const FOOTER_MAGIC: i32 = !CODEC_MAGIC;
/// Size of a codec footer in bytes: magic (4) + algorithm id (4) + checksum (8).
pub const FOOTER_LENGTH: usize = 4 + 4 + 8;
/// Size in bytes of the identifier stored in an index header.
pub const ID_LENGTH: usize = 16;
/// Writes a codec header: [`CODEC_MAGIC`], the codec name, then `version`.
///
/// The name is validated before anything is written.
///
/// # Errors
/// Returns `InvalidInput` if `codec` is not an acceptable codec name, and
/// propagates any I/O error from `out`.
///
/// Returns the number of header bytes, as computed by [`header_length`].
pub fn write_header(out: &mut dyn DataOutput, codec: &str, version: i32) -> io::Result<usize> {
    validate_codec_name(codec)?;
    // Pure size computation; doing it first does not affect what is written.
    let written = header_length(codec);

    out.write_be_int(CODEC_MAGIC)?;
    out.write_string(codec)?;
    out.write_be_int(version)?;

    Ok(written)
}
pub fn write_index_header(
out: &mut dyn DataOutput,
codec: &str,
version: i32,
id: &[u8; ID_LENGTH],
suffix: &str,
) -> io::Result<usize> {
write_header(out, codec, version)?;
out.write_bytes(id)?;
let suffix_bytes = suffix.as_bytes();
if suffix_bytes.len() > 255 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("suffix too long: {}", suffix_bytes.len()),
));
}
out.write_byte(suffix_bytes.len() as u8)?;
out.write_bytes(suffix_bytes)?;
debug!(
"write_index_header: codec={codec:?}, version={version}, suffix={suffix:?}, id={id:02x?}"
);
Ok(index_header_length(codec, suffix))
}
/// Writes the codec footer: [`FOOTER_MAGIC`], a zero algorithm id, and the
/// output's running checksum as a big-endian 64-bit value.
///
/// The checksum is read after the magic and algorithm id are written, so it
/// covers them as well as everything written before the footer.
///
/// # Errors
/// Returns an error if the checksum does not fit in 32 bits. `read_crc` rejects
/// such a value, so writing it would produce an unreadable footer. Also
/// propagates any I/O error from `out`.
pub fn write_footer(out: &mut dyn IndexOutput) -> io::Result<()> {
    out.write_be_int(FOOTER_MAGIC)?;
    // Algorithm id: readers accept only 0.
    out.write_be_int(0)?;

    let checksum = out.checksum();
    // Mirror the reader-side check in `read_crc`: a CRC-32 occupies only the low
    // 32 bits of the stored long.
    if u32::try_from(checksum).is_err() {
        return Err(io::Error::other(format!(
            "illegal CRC-32 checksum: {checksum}"
        )));
    }
    debug!("write_footer: checksum=0x{checksum:08x}");
    out.write_be_long(checksum as i64)?;
    Ok(())
}
/// Number of bytes [`write_header`] emits for `codec`.
///
/// Layout: 4-byte magic, vint-encoded name length, name bytes, 4-byte version.
pub fn header_length(codec: &str) -> usize {
    let name_len = codec.len();
    let fixed_ints = 2 * 4; // magic + version
    fixed_ints + vint_size(name_len as u32) + name_len
}
/// Number of bytes [`write_index_header`] emits for `codec` and `suffix`.
///
/// Layout: codec header, [`ID_LENGTH`]-byte id, one length byte, suffix bytes.
pub fn index_header_length(codec: &str, suffix: &str) -> usize {
    let suffix_part = 1 + suffix.len();
    header_length(codec) + ID_LENGTH + suffix_part
}
/// Number of bytes needed to encode `val` as a 7-bits-per-byte variable-length int.
///
/// Each byte carries 7 payload bits, so the size is the count of significant
/// bits rounded up to a multiple of 7. Zero still takes one byte.
fn vint_size(val: u32) -> usize {
    let significant_bits = (u32::BITS - val.leading_zeros()).max(1);
    significant_bits.div_ceil(7) as usize
}
/// Reads and verifies a codec header written by [`write_header`].
///
/// Checks, in order:
/// - the [`CODEC_MAGIC`] marker;
/// - the codec name, which must equal `codec`;
/// - the version, which must lie in `min_version..=max_version`.
///
/// # Errors
/// Returns an error describing the first mismatch, or propagates a read error.
///
/// Returns the version stored in the header.
pub fn check_header(
    input: &mut dyn DataInput,
    codec: &str,
    min_version: i32,
    max_version: i32,
) -> io::Result<i32> {
    let magic = input.read_be_int()?;
    if magic != CODEC_MAGIC {
        let msg = format!("codec header mismatch: expected 0x{CODEC_MAGIC:08X}, got 0x{magic:08X}");
        return Err(io::Error::other(msg));
    }

    let stored_codec = input.read_string()?;
    if stored_codec != codec {
        let msg = format!("codec mismatch: expected {codec:?}, got {stored_codec:?}");
        return Err(io::Error::other(msg));
    }

    let version = input.read_be_int()?;
    if !(min_version..=max_version).contains(&version) {
        let msg = format!(
            "version {version} out of range [{min_version}, {max_version}] for codec {codec:?}"
        );
        return Err(io::Error::other(msg));
    }
    Ok(version)
}
/// Reads and verifies an index header written by [`write_index_header`].
///
/// First verifies the codec header via [`check_header`]. Then requires the
/// stored id to equal `expected_id` and the stored suffix, decoded as UTF-8, to
/// equal `expected_suffix`.
///
/// # Errors
/// Returns an error on the first mismatch or invalid UTF-8 in the suffix, or
/// propagates a read error.
///
/// Returns the version stored in the header.
pub fn check_index_header(
    input: &mut dyn DataInput,
    codec: &str,
    min_version: i32,
    max_version: i32,
    expected_id: &[u8; ID_LENGTH],
    expected_suffix: &str,
) -> io::Result<i32> {
    let version = check_header(input, codec, min_version, max_version)?;

    let mut stored_id = [0u8; ID_LENGTH];
    input.read_bytes(&mut stored_id)?;
    if stored_id != *expected_id {
        let msg = format!("segment ID mismatch: expected {expected_id:02x?}, got {stored_id:02x?}");
        return Err(io::Error::other(msg));
    }

    // Suffix is stored as a one-byte length followed by that many bytes.
    let mut raw_suffix = vec![0u8; input.read_byte()? as usize];
    input.read_bytes(&mut raw_suffix)?;
    let stored_suffix =
        String::from_utf8(raw_suffix).map_err(|err| io::Error::other(err.to_string()))?;
    if stored_suffix != expected_suffix {
        let msg = format!("suffix mismatch: expected {expected_suffix:?}, got {stored_suffix:?}");
        return Err(io::Error::other(msg));
    }

    Ok(version)
}
/// Reads and verifies the codec footer of a checksummed input.
///
/// `input` must be positioned exactly [`FOOTER_LENGTH`] bytes before its end.
/// The footer magic and algorithm id are validated. Then the stored checksum is
/// compared with the checksum `input` reports just before that field is read.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `input` is positioned past its end, or not exactly at the footer;
/// - the magic or algorithm id is wrong;
/// - the checksums differ;
/// - a read fails.
pub fn check_footer(input: &mut ChecksumIndexInput) -> io::Result<()> {
    let length = input.length();
    let position = input.file_pointer();
    // Plain `length - position` underflows if the input is positioned past EOF
    // (panic in debug builds, wrap-around in release); report it as an error instead.
    let remaining = length.checked_sub(position).ok_or_else(|| {
        io::Error::other(format!(
            "misplaced codec footer: file pointer {position} is past end of file (length={length})"
        ))
    })?;
    if remaining != FOOTER_LENGTH as u64 {
        return Err(io::Error::other(format!(
            "expected {FOOTER_LENGTH} footer bytes remaining, got {remaining}"
        )));
    }
    let magic = input.read_be_int()?;
    if magic != FOOTER_MAGIC {
        return Err(io::Error::other(format!(
            "footer magic mismatch: expected 0x{:08X}, got 0x{magic:08X}",
            FOOTER_MAGIC as u32
        )));
    }
    let algorithm_id = input.read_be_int()?;
    if algorithm_id != 0 {
        return Err(io::Error::other(format!(
            "unsupported checksum algorithm: {algorithm_id}"
        )));
    }
    // Capture the running checksum before reading the stored value, so it covers
    // exactly the bytes the writer had produced when it took its checksum.
    let checksum_before_crc = input.checksum();
    let stored_checksum = input.read_be_long()? as u64;
    if stored_checksum != checksum_before_crc {
        return Err(io::Error::other(format!(
            "checksum mismatch: stored 0x{stored_checksum:08X}, computed 0x{checksum_before_crc:08X}"
        )));
    }
    Ok(())
}
/// Returns the CRC-32 stored in the footer at the end of `input`.
///
/// Seeks to the last [`FOOTER_LENGTH`] bytes and validates the footer magic
/// and algorithm id. The stored value is returned without being checked
/// against the file contents.
///
/// # Errors
/// Returns an error if the file is shorter than a footer, if the footer is
/// malformed, or if a seek/read fails.
pub fn retrieve_checksum(input: &mut dyn IndexInput) -> io::Result<i64> {
    let len = input.length();
    let footer_start = len.checked_sub(FOOTER_LENGTH as u64).ok_or_else(|| {
        io::Error::other(format!(
            "misplaced codec footer (file truncated?): length={len} but footerLength=={FOOTER_LENGTH}"
        ))
    })?;
    input.seek(footer_start)?;
    validate_footer(input)?;
    read_crc(input)
}
/// Verifies that `input` is exactly `expected_length` bytes long, then returns
/// the CRC-32 stored in its footer (see [`retrieve_checksum`]).
///
/// # Errors
/// - `InvalidInput` if `expected_length` is smaller than [`FOOTER_LENGTH`].
///   This is a caller error, reported with the same kind as the argument checks
///   in [`write_header`] / [`write_index_header`].
/// - An error if the actual length is shorter or longer than expected.
/// - Otherwise, any error from [`retrieve_checksum`].
pub fn retrieve_checksum_with_length(
    input: &mut dyn IndexInput,
    expected_length: i64,
) -> io::Result<i64> {
    if expected_length < FOOTER_LENGTH as i64 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "expectedLength cannot be less than the footer length",
        ));
    }
    // `expected_length` is non-negative after the check above, so widening it to
    // u64 is lossless. Comparing in u64 avoids narrowing `input.length()` to i64.
    let expected = expected_length as u64;
    let actual = input.length();
    if actual < expected {
        return Err(io::Error::other(format!(
            "truncated file: length={actual} but expectedLength=={expected_length}"
        )));
    } else if actual > expected {
        return Err(io::Error::other(format!(
            "file too long: length={actual} but expectedLength=={expected_length}"
        )));
    }
    retrieve_checksum(input)
}
/// Reads the first two footer fields and checks them: the magic must be
/// [`FOOTER_MAGIC`] and the algorithm id must be 0.
///
/// The caller is responsible for positioning `input` at the footer.
fn validate_footer(input: &mut dyn IndexInput) -> io::Result<()> {
    let magic = input.read_be_int()?;
    if magic != FOOTER_MAGIC {
        let (actual, expected) = (magic as u32, FOOTER_MAGIC as u32);
        return Err(io::Error::other(format!(
            "codec footer mismatch (file truncated?): actual footer=0x{actual:08X} vs expected footer=0x{expected:08X}"
        )));
    }
    match input.read_be_int()? {
        0 => Ok(()),
        algorithm_id => Err(io::Error::other(format!(
            "codec footer mismatch: unknown algorithmID: {algorithm_id}"
        ))),
    }
}
/// Reads the stored checksum (a big-endian long) and rejects values that do
/// not fit in 32 bits, since a CRC-32 only uses the low half of the field.
fn read_crc(input: &mut dyn IndexInput) -> io::Result<i64> {
    let value = input.read_be_long()?;
    // In range for u32 <=> non-negative with all upper 32 bits clear.
    match u32::try_from(value) {
        Ok(_) => Ok(value),
        Err(_) => Err(io::Error::other(format!(
            "illegal CRC-32 checksum: {value}"
        ))),
    }
}
/// Test helper: returns the checksum of every byte of `input`.
///
/// Wraps `input` in a [`ChecksumIndexInput`], rewinds it, and skips to the end.
#[cfg(test)]
pub fn checksum_entire_file(input: Box<dyn IndexInput>) -> io::Result<u64> {
    let total_bytes = input.length();
    let mut summing_reader = ChecksumIndexInput::new(input);
    summing_reader.seek(0)?;
    summing_reader.skip_bytes(total_bytes)?;
    Ok(summing_reader.checksum())
}
/// Checks that `codec` may be stored in a codec header.
///
/// The name must be shorter than 128 bytes. It may contain only printable
/// ASCII characters and the space character.
///
/// # Errors
/// Returns an `InvalidInput` error for the first violated rule. The length
/// rule is checked before the character rule.
fn validate_codec_name(codec: &str) -> io::Result<()> {
    let byte_len = codec.len();
    if byte_len >= 128 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("codec name too long: {byte_len}"),
        ));
    }
    let has_forbidden_byte = codec
        .bytes()
        .any(|byte| byte != b' ' && !byte.is_ascii_graphic());
    if has_forbidden_byte {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "codec name must be simple ASCII",
        ));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // Round-trip, layout and corruption tests for the codec header/footer helpers.
    use super::*;
    use crate::store::byte_slice_input::ByteSliceIndexInput;
    use crate::store::memory::MemoryIndexOutput;

    #[test]
    fn test_header_length() {
        assert_eq!(header_length("FooBar"), 9 + 6);
    }

    #[test]
    fn test_write_header() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        let len = write_header(&mut out, "FooBar", 5).unwrap();
        let bytes = out.bytes();
        assert_eq!(len, 15);
        assert_len_eq_x!(&bytes, 15);
        // CODEC_MAGIC, big-endian.
        assert_eq!(bytes[0], 0x3f);
        assert_eq!(bytes[1], 0xd7);
        assert_eq!(bytes[2], 0x6c);
        assert_eq!(bytes[3], 0x17);
        // Codec name: one-byte length prefix, then the ASCII bytes.
        assert_eq!(bytes[4], 6);
        assert_eq!(&bytes[5..11], b"FooBar");
        // Version, big-endian.
        assert_eq!(bytes[11], 0);
        assert_eq!(bytes[12], 0);
        assert_eq!(bytes[13], 0);
        assert_eq!(bytes[14], 5);
    }

    #[test]
    fn test_write_index_header() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        let id = [1u8; 16];
        let len = write_index_header(&mut out, "FooBar", 5, &id, "xyz").unwrap();
        let bytes = out.bytes();
        assert_eq!(len, 35);
        assert_len_eq_x!(&bytes, 35);
        assert_eq!(&bytes[15..31], &[1u8; 16]);
        assert_eq!(bytes[31], 3);
        assert_eq!(&bytes[32..35], b"xyz");
    }

    #[test]
    fn test_write_footer() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        out.write_bytes(b"hello").unwrap();
        write_footer(&mut out).unwrap();
        let bytes = out.bytes();
        assert_len_eq_x!(&bytes, 21);
        let footer_start = 5;
        // FOOTER_MAGIC (= !CODEC_MAGIC), big-endian.
        assert_eq!(bytes[footer_start], 0xc0);
        assert_eq!(bytes[footer_start + 1], 0x28);
        assert_eq!(bytes[footer_start + 2], 0x93);
        assert_eq!(bytes[footer_start + 3], 0xe8);
        // Algorithm id.
        assert_eq!(&bytes[footer_start + 4..footer_start + 8], &[0, 0, 0, 0]);
        // The checksum fits in 32 bits, so the high half of the stored long is zero.
        assert_eq!(&bytes[footer_start + 8..footer_start + 12], &[0, 0, 0, 0]);
    }

    #[test]
    fn test_footer_magic_is_not_of_codec_magic() {
        assert_eq!(FOOTER_MAGIC, !CODEC_MAGIC);
        assert_eq!(CODEC_MAGIC, 0x3fd76c17_u32 as i32);
    }

    #[test]
    fn test_validate_codec_name_empty() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        assert_ok!(write_header(&mut out, "", 0));
    }

    #[test]
    fn test_validate_codec_name_too_long() {
        let long_name: String = "a".repeat(128);
        let mut out = MemoryIndexOutput::new("test".to_string());
        assert_err!(write_header(&mut out, &long_name, 0));
    }

    #[test]
    fn test_validate_codec_name_non_ascii() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        // ASCII control character.
        assert_err!(write_header(&mut out, "bad\x01name", 0));
        // Genuinely non-ASCII (multi-byte UTF-8) character.
        assert_err!(write_header(&mut out, "caf\u{e9}", 0));
    }

    #[test]
    fn test_index_header_suffix_too_long() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        let id = [0u8; 16];
        let long_suffix: String = "x".repeat(256);
        assert_err!(write_index_header(&mut out, "Test", 1, &id, &long_suffix));
    }

    #[test]
    fn test_header_length_max_single_byte_vint() {
        // 127 is the largest name length that still encodes in a single vint byte.
        let name = "a".repeat(127);
        assert_eq!(header_length(&name), 4 + 1 + 127 + 4);
    }

    #[test]
    fn test_vint_size_multi_byte() {
        assert_eq!(vint_size(0), 1);
        assert_eq!(vint_size(0x7F), 1);
        assert_eq!(vint_size(0x80), 2);
        assert_eq!(vint_size(0x3FFF), 2);
        assert_eq!(vint_size(0x4000), 3);
        assert_eq!(vint_size(u32::MAX), 5);
        // header_length does not validate the name, so it must account for a
        // two-byte length prefix as well.
        assert_eq!(header_length(&"a".repeat(128)), 4 + 2 + 128 + 4);
    }

    #[test]
    fn test_check_header_roundtrip() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        write_header(&mut out, "FooBar", 5).unwrap();
        let bytes = out.bytes().to_vec();
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        let version = check_header(&mut input, "FooBar", 1, 10).unwrap();
        assert_eq!(version, 5);
    }

    #[test]
    fn test_check_header_wrong_magic() {
        let bytes = vec![0x00, 0x00, 0x00, 0x00];
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        assert_err!(check_header(&mut input, "Test", 1, 1));
    }

    #[test]
    fn test_check_header_wrong_codec() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        write_header(&mut out, "FooBar", 5).unwrap();
        let bytes = out.bytes().to_vec();
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        assert_err!(check_header(&mut input, "WrongName", 1, 10));
    }

    #[test]
    fn test_check_header_version_too_low() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        write_header(&mut out, "Test", 3).unwrap();
        let bytes = out.bytes().to_vec();
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        assert_err!(check_header(&mut input, "Test", 5, 10));
    }

    #[test]
    fn test_check_header_version_too_high() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        write_header(&mut out, "Test", 15).unwrap();
        let bytes = out.bytes().to_vec();
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        assert_err!(check_header(&mut input, "Test", 1, 10));
    }

    #[test]
    fn test_check_index_header_roundtrip() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        let id = [0xABu8; ID_LENGTH];
        write_index_header(&mut out, "FooBar", 5, &id, "xyz").unwrap();
        let bytes = out.bytes().to_vec();
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        let version = check_index_header(&mut input, "FooBar", 1, 10, &id, "xyz").unwrap();
        assert_eq!(version, 5);
    }

    #[test]
    fn test_check_index_header_wrong_id() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        let id = [0xABu8; ID_LENGTH];
        write_index_header(&mut out, "Test", 1, &id, "s").unwrap();
        let bytes = out.bytes().to_vec();
        let wrong_id = [0xCDu8; ID_LENGTH];
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        assert_err!(check_index_header(&mut input, "Test", 1, 1, &wrong_id, "s"));
    }

    #[test]
    fn test_check_index_header_wrong_suffix() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        let id = [1u8; ID_LENGTH];
        write_index_header(&mut out, "Test", 1, &id, "abc").unwrap();
        let bytes = out.bytes().to_vec();
        let mut input = ByteSliceIndexInput::new("test".into(), bytes);
        assert_err!(check_index_header(&mut input, "Test", 1, 1, &id, "xyz"));
    }

    #[test]
    fn test_check_footer_roundtrip() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        write_header(&mut out, "Test", 1).unwrap();
        out.write_bytes(b"payload data").unwrap();
        write_footer(&mut out).unwrap();
        let bytes = out.bytes().to_vec();
        let inner = ByteSliceIndexInput::new("test".into(), bytes.clone());
        let mut input = ChecksumIndexInput::new(Box::new(inner));
        let footer_pos = bytes.len() as u64 - FOOTER_LENGTH as u64;
        input.seek(footer_pos).unwrap();
        check_footer(&mut input).unwrap();
    }

    #[test]
    fn test_check_footer_corrupted_crc() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        write_header(&mut out, "Test", 1).unwrap();
        out.write_bytes(b"payload data").unwrap();
        write_footer(&mut out).unwrap();
        let mut bytes = out.bytes().to_vec();
        // Flip the first payload byte so the stored checksum no longer matches.
        bytes[header_length("Test")] ^= 0xFF;
        let inner = ByteSliceIndexInput::new("test".into(), bytes.clone());
        let mut input = ChecksumIndexInput::new(Box::new(inner));
        let footer_pos = bytes.len() as u64 - FOOTER_LENGTH as u64;
        input.seek(footer_pos).unwrap();
        assert_err!(check_footer(&mut input));
    }

    #[test]
    fn test_check_footer_misplaced() {
        let data = make_valid_file();
        let inner = ByteSliceIndexInput::new("test".into(), data);
        let mut input = ChecksumIndexInput::new(Box::new(inner));
        // Still at the start of the file: far more than FOOTER_LENGTH bytes remain.
        assert_err!(check_footer(&mut input));
    }

    #[test]
    fn test_checksum_entire_file() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        out.write_bytes(b"hello world").unwrap();
        let expected = out.checksum();
        let bytes = out.bytes().to_vec();
        let input = ByteSliceIndexInput::new("test".into(), bytes);
        let actual = checksum_entire_file(Box::new(input)).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_footer_covers_preceding_bytes() {
        let mut out = MemoryIndexOutput::new("test".to_string());
        out.write_string("test data").unwrap();
        let checksum_before_crc = {
            let mut out2 = MemoryIndexOutput::new("test2".to_string());
            out2.write_string("test data").unwrap();
            out2.write_be_int(FOOTER_MAGIC).unwrap();
            out2.write_be_int(0).unwrap();
            out2.checksum()
        };
        write_footer(&mut out).unwrap();
        let bytes = out.bytes();
        let footer_crc_offset = bytes.len() - 8;
        let written_crc = u64::from_be_bytes(
            bytes[footer_crc_offset..footer_crc_offset + 8]
                .try_into()
                .unwrap(),
        );
        assert_eq!(written_crc, checksum_before_crc);
    }

    /// Builds a small well-formed file: index header, payload, footer.
    fn make_valid_file() -> Vec<u8> {
        let mut out = MemoryIndexOutput::new("test".to_string());
        write_index_header(&mut out, "TestCodec", 1, &[0u8; 16], "").unwrap();
        out.write_bytes(b"payload").unwrap();
        write_footer(&mut out).unwrap();
        out.into_inner().data
    }

    #[test]
    fn test_retrieve_checksum_valid() {
        let data = make_valid_file();
        let mut input = ByteSliceIndexInput::new("test".into(), data);
        let crc = retrieve_checksum(&mut input).unwrap();
        assert_ge!(crc, 0);
    }

    #[test]
    fn test_retrieve_checksum_truncated() {
        let mut input = ByteSliceIndexInput::new("test".into(), vec![0u8; 4]);
        assert!(retrieve_checksum(&mut input).is_err());
    }

    #[test]
    fn test_retrieve_checksum_with_length_valid() {
        let data = make_valid_file();
        let expected_length = data.len() as i64;
        let mut input = ByteSliceIndexInput::new("test".into(), data);
        let crc = retrieve_checksum_with_length(&mut input, expected_length).unwrap();
        assert_ge!(crc, 0);
    }

    #[test]
    fn test_retrieve_checksum_with_length_too_short() {
        let data = make_valid_file();
        let mut input = ByteSliceIndexInput::new("test".into(), data.clone());
        let result = retrieve_checksum_with_length(&mut input, data.len() as i64 + 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_retrieve_checksum_with_length_too_long() {
        let data = make_valid_file();
        let mut input = ByteSliceIndexInput::new("test".into(), data);
        let result = retrieve_checksum_with_length(&mut input, FOOTER_LENGTH as i64 + 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_retrieve_checksum_with_length_below_footer_length() {
        let data = make_valid_file();
        let mut input = ByteSliceIndexInput::new("test".into(), data);
        let result = retrieve_checksum_with_length(&mut input, FOOTER_LENGTH as i64 - 1);
        assert!(result.is_err());
    }
}