use crate::error::{Result, TensogramError};
use crate::types::HashDescriptor;
use crate::wire::{FRAME_HEADER_SIZE, FrameType, footer_size_for};
/// Hash algorithms this module knows how to compute and verify.
///
/// Each variant has a canonical textual name used in hash descriptors
/// (see `HashAlgorithm::as_str` / `HashAlgorithm::parse`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HashAlgorithm {
    /// 64-bit XXH3; digests are rendered as 16 lowercase hex characters.
    Xxh3,
}
impl HashAlgorithm {
    /// Returns the canonical lowercase name of this algorithm (e.g. `"xxh3"`).
    ///
    /// `HashAlgorithm::parse` accepts exactly the strings returned here.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Xxh3 => "xxh3",
        }
    }

    /// Parses a canonical algorithm name such as `"xxh3"`.
    ///
    /// Matching is exact (case-sensitive, no trimming).
    ///
    /// # Errors
    ///
    /// Returns `TensogramError::Metadata` naming the offending string when it
    /// is not a recognised algorithm.
    pub fn parse(s: &str) -> Result<Self> {
        match s {
            "xxh3" => Ok(Self::Xxh3),
            unknown => Err(TensogramError::Metadata(format!(
                "unknown hash type: {unknown}"
            ))),
        }
    }
}
/// Hashes `data` with `algorithm` and returns the digest as a hex string.
///
/// For `HashAlgorithm::Xxh3` this is the 64-bit XXH3 digest, zero-padded to
/// 16 lowercase hex characters.
pub fn compute_hash(data: &[u8], algorithm: HashAlgorithm) -> String {
    use xxhash_rust::xxh3::xxh3_64;

    match algorithm {
        HashAlgorithm::Xxh3 => {
            let digest = xxh3_64(data);
            format_xxh3_digest(digest)
        }
    }
}
/// Computes the 64-bit XXH3 digest of a frame's body.
///
/// The body is everything between the fixed `FRAME_HEADER_SIZE`-byte header
/// and the trailing `footer_size_for(frame_type)`-byte footer. Header and
/// footer bytes (including any stored hash slot) are excluded from the hash.
///
/// # Errors
///
/// Returns `TensogramError::Framing` when `frame_bytes` is shorter than
/// header + footer, which typically means the frame was truncated.
pub fn hash_frame_body(frame_bytes: &[u8], frame_type: FrameType) -> Result<u64> {
    let footer_len = footer_size_for(frame_type);
    let required = FRAME_HEADER_SIZE + footer_len;
    let total = frame_bytes.len();

    if total >= required {
        // `total >= FRAME_HEADER_SIZE + footer_len`, so the range below is
        // well-formed (possibly empty) and the subtraction cannot underflow.
        let body = &frame_bytes[FRAME_HEADER_SIZE..total - footer_len];
        return Ok(xxhash_rust::xxh3::xxh3_64(body));
    }

    Err(TensogramError::Framing(format!(
        "frame too small to hash: frame_bytes.len() = {total} < \
         header({FRAME_HEADER_SIZE}) + footer({footer_len}) = {required}; \
         for frame_type = {frame_type:?}. Likely truncated; re-read from source."
    )))
}
/// Checks a complete frame against the digest stored in its common footer.
///
/// The frame must end with `FRAME_END`, and the 8-byte hash slot (read as a
/// big-endian `u64`) starts `FRAME_COMMON_FOOTER_SIZE` bytes before the end.
/// The stored value is compared with `hash_frame_body` of the same frame.
///
/// # Errors
///
/// * `TensogramError::Framing` if the frame is shorter than
///   `FRAME_COMMON_FOOTER_SIZE`, does not end with `FRAME_END`, or (via
///   `hash_frame_body`) is shorter than header + type-specific footer.
/// * `TensogramError::HashMismatch` if the digests differ; `expected` is the
///   stored slot and `actual` the recomputed digest, both as 16-char hex.
pub fn verify_frame_hash(frame_bytes: &[u8], frame_type: FrameType) -> Result<()> {
    use crate::wire::{FRAME_COMMON_FOOTER_SIZE, FRAME_END, read_u64_be};

    let frame_len = frame_bytes.len();

    // The hash slot and terminator must both fit before anything is read.
    if frame_len < FRAME_COMMON_FOOTER_SIZE {
        return Err(TensogramError::Framing(format!(
            "frame too small to read hash slot: frame_bytes.len() = {frame_len} \
             < FRAME_COMMON_FOOTER_SIZE ({FRAME_COMMON_FOOTER_SIZE}); \
             truncated or not a v3 frame"
        )));
    }

    // Without the terminator the slot position cannot be trusted.
    if !frame_bytes.ends_with(FRAME_END) {
        return Err(TensogramError::Framing(String::from(
            "frame missing ENDF marker while verifying inline hash — \
             likely truncated or not a v3 frame; re-read from source",
        )));
    }

    let stored = read_u64_be(frame_bytes, frame_len - FRAME_COMMON_FOOTER_SIZE);
    let computed = hash_frame_body(frame_bytes, frame_type)?;

    if stored == computed {
        Ok(())
    } else {
        Err(TensogramError::HashMismatch {
            expected: format_xxh3_digest(stored),
            actual: format_xxh3_digest(computed),
        })
    }
}
/// Renders a 64-bit digest as exactly 16 lowercase hex characters,
/// left-padded with zeros (e.g. `0xab` becomes `"00000000000000ab"`).
#[inline]
pub(crate) fn format_xxh3_digest(digest: u64) -> String {
    format!("{:016x}", digest)
}
/// Verifies `data` against a textual hash descriptor.
///
/// `descriptor.algorithm` is resolved with `HashAlgorithm::parse`. An
/// unrecognised algorithm is skipped rather than rejected: a warning is
/// logged through `tracing` and `Ok(())` is returned.
///
/// The digest comparison ignores ASCII case. Hex digests spelled with
/// upper- or lowercase letters denote the same value. Digests computed here
/// are always lowercase, while `descriptor.value` may come from elsewhere.
///
/// # Errors
///
/// Returns `TensogramError::HashMismatch` when the recomputed digest differs
/// from `descriptor.value`. `expected` carries the descriptor's value
/// verbatim and `actual` is the digest computed from `data`.
pub fn verify_hash(data: &[u8], descriptor: &HashDescriptor) -> Result<()> {
    let Ok(algorithm) = HashAlgorithm::parse(&descriptor.algorithm) else {
        tracing::warn!(
            algorithm = %descriptor.algorithm,
            "unknown hash algorithm, skipping verification"
        );
        return Ok(());
    };

    let actual = compute_hash(data, algorithm);
    // Hex is case-insensitive: an uppercase digest for identical data must
    // not be reported as a mismatch.
    if !actual.eq_ignore_ascii_case(&descriptor.value) {
        return Err(TensogramError::HashMismatch {
            expected: descriptor.value.clone(),
            actual,
        });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire::{FRAME_END, FRAME_MAGIC};

    /// Hand-builds a frame around `body` with `slot` written into the
    /// common-footer hash slot.
    ///
    /// Layout: `FRAME_MAGIC`, three big-endian `u16` header fields (1, 1, 0),
    /// a big-endian `u64` total length, `body`, `slot` as a big-endian `u64`,
    /// then `FRAME_END`. NOTE(review): the meaning of the three `u16` fields
    /// isn't visible from this module. The hashing code under test never
    /// inspects them (it hashes past the header and reads only the tail).
    fn build_frame(body: &[u8], slot: u64) -> Vec<u8> {
        let total_length =
            FRAME_HEADER_SIZE + body.len() + std::mem::size_of::<u64>() + FRAME_END.len();
        let mut buf = Vec::with_capacity(total_length);
        buf.extend_from_slice(FRAME_MAGIC);
        buf.extend_from_slice(&1u16.to_be_bytes());
        buf.extend_from_slice(&1u16.to_be_bytes());
        buf.extend_from_slice(&0u16.to_be_bytes());
        buf.extend_from_slice(&(total_length as u64).to_be_bytes());
        buf.extend_from_slice(body);
        buf.extend_from_slice(&slot.to_be_bytes());
        buf.extend_from_slice(FRAME_END);
        buf
    }

    #[test]
    fn test_xxh3() {
        let data = b"hello world";
        let hash = compute_hash(data, HashAlgorithm::Xxh3);
        assert_eq!(hash.len(), 16);
        assert_eq!(hash, compute_hash(data, HashAlgorithm::Xxh3));
    }

    #[test]
    fn test_xxh3_known_answer_empty_input() {
        // Published XXH3-64 (seed 0) test vector for zero-length input.
        assert_eq!(compute_hash(b"", HashAlgorithm::Xxh3), "2d06800538d394c2");
    }

    #[test]
    fn hash_algorithm_name_round_trips() {
        assert!(matches!(
            HashAlgorithm::parse(HashAlgorithm::Xxh3.as_str()),
            Ok(HashAlgorithm::Xxh3)
        ));
        assert!(HashAlgorithm::parse("sha256").is_err());
    }

    #[test]
    fn test_verify_hash() {
        let data = b"test data";
        let descriptor = HashDescriptor {
            algorithm: "xxh3".to_string(),
            value: compute_hash(data, HashAlgorithm::Xxh3),
        };
        assert!(verify_hash(data, &descriptor).is_ok());
    }

    #[test]
    fn test_verify_hash_mismatch() {
        let descriptor = HashDescriptor {
            algorithm: "xxh3".to_string(),
            value: "0000000000000000".to_string(),
        };
        assert!(verify_hash(b"test data", &descriptor).is_err());
    }

    #[test]
    fn test_unknown_hash_type_skips_verification() {
        let descriptor = HashDescriptor {
            algorithm: "sha256".to_string(),
            value: "abc123".to_string(),
        };
        assert!(verify_hash(b"test data", &descriptor).is_ok());
    }

    #[test]
    fn hash_frame_body_rejects_below_minimum_size() {
        let buf = vec![0u8; 30];
        let err = hash_frame_body(&buf, FrameType::NTensorFrame).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("frame too small to hash"));
        assert!(msg.contains("frame_bytes.len() = 30"));
        assert!(msg.contains("NTensorFrame"));
    }

    #[test]
    fn verify_frame_hash_rejects_below_common_footer_size() {
        let buf = vec![0u8; 10];
        let err = verify_frame_hash(&buf, FrameType::HeaderMetadata).unwrap_err();
        assert!(err.to_string().contains("frame too small to read hash slot"));
    }

    #[test]
    fn verify_frame_hash_rejects_missing_endf() {
        let mut buf = vec![0u8; 12];
        buf[8..12].copy_from_slice(b"XXXX");
        let err = verify_frame_hash(&buf, FrameType::HeaderMetadata).unwrap_err();
        assert!(err.to_string().contains("ENDF"));
    }

    #[test]
    fn verify_frame_hash_accepts_matching_slot() {
        let body = b"hello";
        // The slot lies outside the hashed body, so a probe frame with a
        // zero slot yields the digest the real frame must carry.
        let probe = build_frame(body, 0);
        let Ok(digest) = hash_frame_body(&probe, FrameType::HeaderMetadata) else {
            panic!("probe frame should be large enough to hash");
        };
        let frame = build_frame(body, digest);
        assert!(verify_frame_hash(&frame, FrameType::HeaderMetadata).is_ok());
    }

    #[test]
    fn verify_frame_hash_rejects_zero_slot_when_body_is_nonempty() {
        let frame = build_frame(b"hello", 0);
        let err = verify_frame_hash(&frame, FrameType::HeaderMetadata).unwrap_err();
        assert!(matches!(err, TensogramError::HashMismatch { .. }));
    }

    #[test]
    fn verify_frame_hash_accepts_zero_slot_only_when_body_hashes_to_zero() {
        // XXH3 of the empty body is non-zero, so a zero slot must be rejected.
        let frame = build_frame(b"", 0);
        let result = verify_frame_hash(&frame, FrameType::HeaderMetadata);
        assert!(matches!(result, Err(TensogramError::HashMismatch { .. })));
    }

    #[test]
    fn verify_frame_hash_reports_mismatch_on_tampered_slot() {
        let frame = build_frame(b"hello", 0xDEAD_BEEF_CAFE_BABE);
        let err = verify_frame_hash(&frame, FrameType::HeaderMetadata).unwrap_err();
        assert!(matches!(err, TensogramError::HashMismatch { .. }));
        assert!(err.to_string().contains("deadbeefcafebabe"));
    }
}