use rkyv::rancor::Error as RkyvError;
use rkyv::util::AlignedVec;
use crate::error::{PartialError, PartialFormatErrorKind};
use crate::traits::PartialExpectation;
/// Four-byte tag that opens every encoded partial.
pub const MAGIC: [u8; 4] = [b'V', b'R', b'P', b'S'];
/// Envelope format version written by `encode` and required by the decoder.
pub const FORMAT_VERSION: u8 = 2;
/// Size of the smallest possible frame: magic, one version byte, a `u32` header-length
/// prefix and a `u32` CRC trailer (i.e. empty header and body archives).
const MIN_PARTIAL_BYTES: usize =
    MAGIC.len() + core::mem::size_of::<u8>() + 2 * core::mem::size_of::<u32>();
/// Alignment of the scratch buffer the header archive is copied into before `rkyv::access`.
const RKYV_ALIGN: usize = 16;
/// Identity header stored, as an rkyv archive, at the front of every encoded
/// partial; on decode each field except `rank_id` is compared against a
/// `PartialExpectation`.
///
/// NOTE(review): this type is serialized with rkyv, so adding, removing,
/// reordering or retyping fields presumably changes the archived byte layout —
/// confirm, and bump `FORMAT_VERSION` when doing so.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize, Debug, Clone)]
pub struct WireEnvelopeHeader {
/// Paradigm tag; on decode it must equal `expected.paradigm.as_u8()`.
pub paradigm_kind: u8,
/// Must equal the expected discriminator. A mismatch is reported as
/// `KernelMismatch`, so this presumably identifies the kernel — confirm.
pub discriminator: u32,
/// Must equal the expected parity mode (`ParityMismatch` otherwise).
pub parity_mode: u8,
/// Optional rank id; not validated on decode, read back via `rank_id_from_archive`.
pub rank_id: Option<u32>,
/// Must equal the expected dataset hash (`PartialError::DatasetMismatch` otherwise).
pub dataset_hash: [u8; 32],
/// Must equal the expected params hash (`PartialError::ParamsMismatch` otherwise).
pub params_hash: [u8; 32],
/// Four-element shape fingerprint; must equal the expected one (`GridMismatch` otherwise).
pub shape_fingerprint: [u32; 4],
}
/// Serializes `header` with rkyv and wraps it, together with an already-archived
/// body, in a checksummed frame.
///
/// Frame layout (integers little-endian):
///
/// ```text
/// MAGIC (4) | FORMAT_VERSION (1) | header_len: u32 (4) | header archive | body archive | CRC-32 (4)
/// ```
///
/// `body_archive` is copied verbatim: it is neither inspected nor padded, so it
/// starts at whatever offset follows the header archive. The CRC-32 trailer
/// covers every byte that precedes it.
///
/// # Errors
///
/// Returns `PartialError::Format` with kind `RkyvDecode` if rkyv fails to
/// serialize the header, or if the header archive is longer than `u32::MAX`
/// bytes. The `RkyvDecode` kind is reused here for these encode-side failures.
pub fn encode(header: &WireEnvelopeHeader, body_archive: &[u8]) -> Result<Vec<u8>, PartialError> {
    let header_archive = rkyv::to_bytes::<RkyvError>(header).map_err(|err| PartialError::Format {
        kind: PartialFormatErrorKind::RkyvDecode {
            detail: format!("rkyv::to_bytes(header) failed: {err}"),
        },
    })?;
    let archive_len = header_archive.len();
    let Ok(header_len) = u32::try_from(archive_len) else {
        return Err(PartialError::Format {
            kind: PartialFormatErrorKind::RkyvDecode {
                detail: format!("header archive too large: {} bytes", archive_len),
            },
        });
    };

    // Everything the CRC covers, in wire order.
    let sections: [&[u8]; 5] = [
        &MAGIC,
        &[FORMAT_VERSION],
        &header_len.to_le_bytes(),
        &header_archive,
        body_archive,
    ];
    let mut frame = Vec::with_capacity(MIN_PARTIAL_BYTES + archive_len + body_archive.len());
    for section in sections {
        frame.extend_from_slice(section);
    }

    let checksum = crc32fast::hash(&frame);
    frame.extend_from_slice(&checksum.to_le_bytes());
    Ok(frame)
}
/// Borrowed view handed to the `with_validated_envelope` callback once the frame
/// and the header fields have passed validation.
pub struct ValidatedView<'a> {
    /// Bytes between the header archive and the CRC trailer. They are CRC-covered,
    /// but not rkyv-validated, and carry no alignment guarantee.
    pub body_archive: &'a [u8],
    /// The accessed header archive, already checked against the caller's expectation.
    pub header: &'a ArchivedWireEnvelopeHeader,
}
/// Validates an encoded partial and, on success, passes a borrowed view of it to
/// `body_callback`, returning whatever the callback returns.
///
/// Steps, in order:
/// 1. `validate_framing` checks length, magic, version, CRC and the header length.
/// 2. The header archive is copied into a `RKYV_ALIGN`-aligned buffer and opened
///    with checked `rkyv::access`.
/// 3. `validate_header_fields` compares it against `expected`.
///
/// The aligned header copy is a local of this function, which is why the header
/// is exposed through a callback rather than returned. The body slice is not
/// copied or realigned; a callback that accesses it with rkyv presumably must
/// align it first — confirm at call sites.
///
/// # Errors
///
/// Any framing or header-field error, `PartialFormatErrorKind::RkyvDecode` if the
/// header archive fails access/validation, or the callback's own error.
pub fn with_validated_envelope<R>(
    bytes: &[u8],
    expected: &PartialExpectation,
    body_callback: impl FnOnce(ValidatedView<'_>) -> Result<R, PartialError>,
) -> Result<R, PartialError> {
    let (header_slice, body_archive) = validate_framing(bytes)?;

    // The header sits at an arbitrary offset inside `bytes`; rkyv needs it aligned.
    let mut header_buf = AlignedVec::<RKYV_ALIGN>::with_capacity(header_slice.len());
    header_buf.extend_from_slice(header_slice);

    let header = match rkyv::access::<ArchivedWireEnvelopeHeader, RkyvError>(&header_buf) {
        Ok(accessed) => accessed,
        Err(e) => {
            return Err(PartialError::Format {
                kind: PartialFormatErrorKind::RkyvDecode {
                    detail: format!("rkyv::access(header) failed: {e}"),
                },
            });
        }
    };

    validate_header_fields(header, expected)?;
    body_callback(ValidatedView {
        header,
        body_archive,
    })
}
/// Checks the outer frame of an encoded partial and splits out its two archives.
///
/// Frame layout, as written by `encode` (integers little-endian):
///
/// ```text
/// MAGIC (4) | FORMAT_VERSION (1) | header_len: u32 (4) | header archive | body archive | CRC-32 (4)
/// ```
///
/// The CRC-32 covers every byte before the trailer. The checks run in a fixed
/// order (minimum length, magic, version, CRC, declared header length), so the
/// first failing check decides which error is returned.
///
/// Returns `(header_archive, body_archive)`, both borrowed from `bytes`. Neither
/// archive is rkyv-validated here.
///
/// # Errors
///
/// `PartialError::Format` with one of these kinds:
/// - `TooShort`: the input is shorter than the minimum frame, or the declared
///   header length runs past the CRC trailer.
/// - `WrongMagic`.
/// - `WrongVersion`.
/// - `Crc`.
fn validate_framing(bytes: &[u8]) -> Result<(&[u8], &[u8]), PartialError> {
    // Offset of the header archive: magic + version byte + u32 length prefix.
    const HEADER_START: usize = MAGIC.len() + 1 + 4;

    let too_short = |minimum: usize| PartialError::Format {
        kind: PartialFormatErrorKind::TooShort {
            observed: bytes.len(),
            minimum,
        },
    };

    if bytes.len() < MIN_PARTIAL_BYTES {
        return Err(too_short(MIN_PARTIAL_BYTES));
    }
    // From here on `bytes.len() >= MIN_PARTIAL_BYTES`. The fixed-size slices below
    // are therefore in bounds, and their `try_into` conversions cannot fail; they
    // map to an error only defensively, rather than panicking.

    let magic: [u8; 4] = bytes[..4]
        .try_into()
        .map_err(|_| too_short(MIN_PARTIAL_BYTES))?;
    if magic != MAGIC {
        return Err(PartialError::Format {
            kind: PartialFormatErrorKind::WrongMagic { found: magic },
        });
    }

    let version = bytes[4];
    if version != FORMAT_VERSION {
        return Err(PartialError::Format {
            kind: PartialFormatErrorKind::WrongVersion {
                expected: FORMAT_VERSION,
                found: version,
            },
        });
    }

    // Verify integrity before interpreting the length prefix.
    let crc_split = bytes.len() - 4;
    let (covered, trailer) = bytes.split_at(crc_split);
    let stored_crc = u32::from_le_bytes(trailer.try_into().map_err(|_| PartialError::Format {
        kind: PartialFormatErrorKind::Crc,
    })?);
    if stored_crc != crc32fast::hash(covered) {
        return Err(PartialError::Format {
            kind: PartialFormatErrorKind::Crc,
        });
    }

    let declared_len = u32::from_le_bytes(
        bytes[5..HEADER_START]
            .try_into()
            .map_err(|_| too_short(MIN_PARTIAL_BYTES))?,
    );
    // CRC-32 is not tamper-proof, so `declared_len` is untrusted input. Where
    // `usize` is 32 bits or narrower, `HEADER_START + declared_len` (and the `+ 4`
    // reported below) can exceed `usize::MAX`. Saturating everywhere turns an
    // oversize length into `TooShort` instead of an overflow panic (debug) or a
    // wrapped, nonsensical `minimum` (release).
    let header_len = usize::try_from(declared_len).unwrap_or(usize::MAX);
    let header_end = HEADER_START.saturating_add(header_len);
    if header_end > crc_split {
        return Err(too_short(header_end.saturating_add(4)));
    }

    Ok((&bytes[HEADER_START..header_end], &bytes[header_end..crc_split]))
}
fn validate_header_fields(
archived: &ArchivedWireEnvelopeHeader,
expected: &PartialExpectation,
) -> Result<(), PartialError> {
let paradigm = archived.paradigm_kind;
if paradigm != expected.paradigm.as_u8() {
return Err(PartialError::Format {
kind: PartialFormatErrorKind::ParadigmMismatch {
expected: expected.paradigm.as_u8(),
found: paradigm,
},
});
}
let discriminator = archived.discriminator.to_native();
if discriminator != expected.discriminator {
return Err(PartialError::Format {
kind: PartialFormatErrorKind::KernelMismatch {
expected: expected.discriminator,
found: discriminator,
},
});
}
let parity_mode = archived.parity_mode;
if parity_mode != expected.parity_mode {
return Err(PartialError::Format {
kind: PartialFormatErrorKind::ParityMismatch {
expected: expected.parity_mode,
found: parity_mode,
},
});
}
let fingerprint = [
archived.shape_fingerprint[0].to_native(),
archived.shape_fingerprint[1].to_native(),
archived.shape_fingerprint[2].to_native(),
archived.shape_fingerprint[3].to_native(),
];
if fingerprint != expected.shape_fingerprint {
return Err(PartialError::Format {
kind: PartialFormatErrorKind::GridMismatch {
detail: format!(
"expected {:?}, got {:?}",
expected.shape_fingerprint, fingerprint
),
},
});
}
let dataset_hash: [u8; 32] = archived.dataset_hash;
if dataset_hash != expected.dataset_hash {
return Err(PartialError::DatasetMismatch {
expected: expected.dataset_hash,
actual: dataset_hash,
});
}
let params_hash: [u8; 32] = archived.params_hash;
if params_hash != expected.params_hash {
return Err(PartialError::ParamsMismatch {
expected: expected.params_hash,
actual: params_hash,
});
}
Ok(())
}
/// Returns the rank id recorded in an archived header as a native `u32`, or
/// `None` if the header carries no rank id.
pub fn rank_id_from_archive(header: &ArchivedWireEnvelopeHeader) -> Option<u32> {
    let archived_rank = header.rank_id.as_ref()?;
    Some(archived_rank.to_native())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::ParadigmKind;

    /// Expectation that matches `fake_header()` field for field.
    fn fake_expectation() -> PartialExpectation {
        PartialExpectation {
            paradigm: ParadigmKind::Instance,
            discriminator: 0,
            parity_mode: 1,
            dataset_hash: [0xAB; 32],
            params_hash: [0xCD; 32],
            shape_fingerprint: [80, 4, 5000, 0],
            strict_mode: false,
        }
    }

    /// Header shared by most tests; agrees with `fake_expectation()`.
    fn fake_header() -> WireEnvelopeHeader {
        WireEnvelopeHeader {
            paradigm_kind: ParadigmKind::Instance.as_u8(),
            discriminator: 0,
            parity_mode: 1,
            rank_id: None,
            dataset_hash: [0xAB; 32],
            params_hash: [0xCD; 32],
            shape_fingerprint: [80, 4, 5000, 0],
        }
    }

    /// Builds a frame by hand (valid magic, version and CRC) with an arbitrary
    /// declared header length, bypassing `encode`.
    fn hand_built_frame(declared_header_len: u32, payload: &[u8]) -> Vec<u8> {
        let mut frame = Vec::new();
        frame.extend_from_slice(&MAGIC);
        frame.push(FORMAT_VERSION);
        frame.extend_from_slice(&declared_header_len.to_le_bytes());
        frame.extend_from_slice(payload);
        let crc = crc32fast::hash(&frame);
        frame.extend_from_slice(&crc.to_le_bytes());
        frame
    }

    #[test]
    fn round_trip_empty_body() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let exp = fake_expectation();
        with_validated_envelope(&bytes, &exp, |view| {
            assert!(view.body_archive.is_empty());
            assert_eq!(view.header.paradigm_kind, ParadigmKind::Instance.as_u8());
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn rejects_too_short() {
        let err = validate_framing(b"VRP").unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::TooShort { .. }
            }
        ));
    }

    #[test]
    fn rejects_wrong_magic() {
        let mut bytes = vec![0u8; MIN_PARTIAL_BYTES + 8];
        bytes[..4].copy_from_slice(b"FAKE");
        bytes[4] = FORMAT_VERSION;
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::WrongMagic { .. }
            }
        ));
    }

    #[test]
    fn rejects_wrong_version() {
        let mut bytes = vec![0u8; MIN_PARTIAL_BYTES + 8];
        bytes[..4].copy_from_slice(&MAGIC);
        bytes[4] = 99;
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::WrongVersion { .. }
            }
        ));
    }

    #[test]
    fn rejects_bad_crc() {
        let mut bytes = encode(&fake_header(), &[]).unwrap();
        let n = bytes.len();
        bytes[n - 1] ^= 0xFF;
        let exp = fake_expectation();
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::Crc
            }
        ));
    }

    #[test]
    fn rejects_corrupted_body() {
        let mut bytes = encode(&fake_header(), b"payload").unwrap();
        // Last body byte sits just before the 4-byte CRC trailer.
        let last_body_byte = bytes.len() - 5;
        bytes[last_body_byte] ^= 0x01;
        let exp = fake_expectation();
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::Crc
            }
        ));
    }

    #[test]
    fn splits_hand_built_frame() {
        let bytes = hand_built_frame(3, b"abcdef");
        let (header, body) = validate_framing(&bytes).unwrap();
        assert_eq!(header, b"abc");
        assert_eq!(body, b"def");
    }

    #[test]
    fn rejects_header_len_past_crc() {
        // CRC is valid, but the declared header length cannot fit in the frame.
        let bytes = hand_built_frame(u32::MAX, &[0u8; 8]);
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::TooShort { .. }
            }
        ));
    }

    #[test]
    fn rejects_paradigm_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.paradigm = ParadigmKind::Semantic;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        match err {
            PartialError::Format {
                kind: PartialFormatErrorKind::ParadigmMismatch { expected, found },
            } => {
                assert_eq!(expected, ParadigmKind::Semantic.as_u8());
                assert_eq!(found, ParadigmKind::Instance.as_u8());
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn rejects_discriminator_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.discriminator = 1;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::KernelMismatch { .. }
            }
        ));
    }

    #[test]
    fn rejects_parity_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.parity_mode = 0;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        match err {
            PartialError::Format {
                kind: PartialFormatErrorKind::ParityMismatch { expected, found },
            } => {
                assert_eq!(expected, 0);
                assert_eq!(found, 1);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn rejects_shape_fingerprint_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.shape_fingerprint = [1, 2, 3, 4];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::GridMismatch { .. }
            }
        ));
    }

    #[test]
    fn rejects_dataset_hash_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.dataset_hash = [0; 32];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(err, PartialError::DatasetMismatch { .. }));
    }

    #[test]
    fn rejects_params_hash_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.params_hash = [0; 32];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(err, PartialError::ParamsMismatch { .. }));
    }

    #[test]
    fn round_trip_with_body() {
        let body = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b";
        let bytes = encode(&fake_header(), body).unwrap();
        let exp = fake_expectation();
        with_validated_envelope(&bytes, &exp, |view| {
            assert_eq!(view.body_archive, body);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn rank_id_round_trips() {
        let exp = fake_expectation();

        let mut header = fake_header();
        header.rank_id = Some(7);
        let bytes = encode(&header, &[]).unwrap();
        let rank = with_validated_envelope(&bytes, &exp, |view| {
            Ok(rank_id_from_archive(view.header))
        })
        .unwrap();
        assert_eq!(rank, Some(7));

        let bytes = encode(&fake_header(), &[]).unwrap();
        let rank = with_validated_envelope(&bytes, &exp, |view| {
            Ok(rank_id_from_archive(view.header))
        })
        .unwrap();
        assert_eq!(rank, None);
    }
}