use serde::{Deserialize, Serialize};
/// Magic bytes that open every AXON frame: the ASCII text `"AXON"`.
pub const AXON_MAGIC: [u8; 4] = *b"AXON";

/// Protocol version byte written at header offset 4 by [`serialize`].
pub const AXON_VERSION: u8 = 1;

/// Size in bytes of the fixed frame header:
/// magic (4) + version (1) + message type (1) + flags (2) + payload length (4) + checksum (4).
pub const HEADER_SIZE: usize = 4 + 1 + 1 + 2 + 4 + 4;
/// Kind of message carried by an AXON frame.
///
/// The explicit discriminants are the byte stored at header offset 5.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum WireMessageType {
    Delta = 0,
    Signal = 1,
    Presence = 2,
    Intent = 3,
}

impl WireMessageType {
    /// Every variant, searched by [`WireMessageType::from_u8`] to map a wire byte back.
    const VARIANTS: [Self; 4] = [Self::Delta, Self::Signal, Self::Presence, Self::Intent];

    /// Decodes a wire byte into a message type.
    ///
    /// Returns `None` when `v` matches no variant's discriminant.
    pub fn from_u8(v: u8) -> Option<Self> {
        Self::VARIANTS
            .iter()
            .copied()
            .find(|&candidate| candidate as u8 == v)
    }
}
#[derive(Debug, Clone)]
pub struct AxonHeader {
pub version: u8,
pub msg_type: WireMessageType,
pub flags: u16,
pub payload_len: u32,
pub checksum: u32,
}
/// Errors produced while decoding AXON frames.
///
/// All payloads are plain integers, so the error is cheap to clone and can be
/// compared with `==` (handy in tests and retry logic).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireError {
    /// The first four bytes of the buffer were not [`AXON_MAGIC`].
    InvalidMagic,
    /// The message-type byte was not one the decoder accepts; carries the offending byte.
    InvalidMessageType(u8),
    /// Payload integrity check failed.
    ///
    /// `expected` is the checksum stored in the header; `actual` is the one
    /// computed over the received payload.
    ChecksumMismatch { expected: u32, actual: u32 },
    /// Not enough bytes: `needed` were required but only `got` were available.
    BufferTooShort { needed: usize, got: usize },
}

impl std::fmt::Display for WireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMagic => f.write_str("Invalid AXON magic number"),
            Self::InvalidMessageType(t) => write!(f, "Invalid message type: {}", t),
            Self::ChecksumMismatch { expected, actual } => write!(
                f,
                "Checksum mismatch: expected {:#x}, got {:#x}",
                expected, actual
            ),
            Self::BufferTooShort { needed, got } => {
                write!(f, "Buffer too short: needed {} bytes, got {}", needed, got)
            }
        }
    }
}

impl std::error::Error for WireError {}
/// Frames `payload` as an AXON message of type `msg_type`.
///
/// The result is a [`HEADER_SIZE`]-byte header followed by `payload`:
///
/// | bytes  | field                                          |
/// |--------|------------------------------------------------|
/// | 0..4   | magic ([`AXON_MAGIC`])                         |
/// | 4      | version ([`AXON_VERSION`])                     |
/// | 5      | message type                                   |
/// | 6..8   | flags, little-endian `u16` (always 0 here)     |
/// | 8..12  | payload length, little-endian `u32`            |
/// | 12..16 | CRC-32 of the payload, little-endian `u32`     |
///
/// # Panics
///
/// Panics if `payload` is longer than `u32::MAX` bytes, because that length
/// cannot be represented in the header's 32-bit length field.
pub fn serialize(msg_type: WireMessageType, payload: &[u8]) -> Vec<u8> {
    // An `as u32` cast would silently wrap and produce a frame whose length
    // field disagrees with its body, desynchronising every reader.
    assert!(
        payload.len() <= u32::MAX as usize,
        "AXON payload of {} bytes exceeds the 32-bit length field",
        payload.len()
    );
    let payload_len = payload.len() as u32;
    let checksum = crc32fast::hash(payload);

    let mut buf = Vec::with_capacity(HEADER_SIZE + payload.len());
    buf.extend_from_slice(&AXON_MAGIC);
    buf.push(AXON_VERSION);
    buf.push(msg_type as u8);
    // This encoder never sets any flag bits.
    buf.extend_from_slice(&0u16.to_le_bytes());
    buf.extend_from_slice(&payload_len.to_le_bytes());
    buf.extend_from_slice(&checksum.to_le_bytes());
    buf.extend_from_slice(payload);
    buf
}
/// Parses one AXON frame from the start of `buf`.
///
/// On success returns the decoded header and a slice of `buf` holding exactly
/// `payload_len` payload bytes. Any bytes after the frame are ignored, so `buf`
/// may contain further data.
///
/// The version byte is returned as-is in [`AxonHeader::version`]. It is not
/// compared against [`AXON_VERSION`], so callers that require a specific
/// version must check it themselves.
///
/// # Errors
///
/// - [`WireError::BufferTooShort`] if `buf` is shorter than the header, or
///   shorter than the header plus the declared payload length.
/// - [`WireError::InvalidMagic`] if the first four bytes are not [`AXON_MAGIC`].
/// - [`WireError::InvalidMessageType`] if byte 5 is not a known [`WireMessageType`].
/// - [`WireError::ChecksumMismatch`] if the payload's CRC-32 differs from the
///   value stored in the header.
pub fn deserialize(buf: &[u8]) -> Result<(AxonHeader, &[u8]), WireError> {
    if buf.len() < HEADER_SIZE {
        return Err(WireError::BufferTooShort {
            needed: HEADER_SIZE,
            got: buf.len(),
        });
    }
    if buf[0..4] != AXON_MAGIC {
        return Err(WireError::InvalidMagic);
    }
    let version = buf[4];
    let msg_type =
        WireMessageType::from_u8(buf[5]).ok_or(WireError::InvalidMessageType(buf[5]))?;
    let flags = u16::from_le_bytes([buf[6], buf[7]]);
    let payload_len = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
    let checksum = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);

    // `payload_len` comes straight off the wire. With a 32-bit `usize`,
    // `HEADER_SIZE + u32::MAX` overflows and ends in a panic instead of an
    // error. Saturating keeps an oversized claim on the `BufferTooShort` path.
    let total = HEADER_SIZE.saturating_add(payload_len as usize);
    if buf.len() < total {
        return Err(WireError::BufferTooShort {
            needed: total,
            got: buf.len(),
        });
    }

    let payload = &buf[HEADER_SIZE..total];
    let actual_checksum = crc32fast::hash(payload);
    if actual_checksum != checksum {
        return Err(WireError::ChecksumMismatch {
            expected: checksum,
            actual: actual_checksum,
        });
    }

    Ok((
        AxonHeader {
            version,
            msg_type,
            flags,
            payload_len,
            checksum,
        },
        payload,
    ))
}
/// One record carried in the payload of a `Delta` message.
///
/// On the wire each entry occupies 8 bytes: `offset` followed by `value`,
/// both little-endian.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeltaEntry {
    /// Offset field of the record.
    /// NOTE(review): presumably an index into the receiver's state — confirm with consumers.
    pub offset: u32,
    /// Value associated with `offset`.
    pub value: f32,
}
/// Encodes `deltas` as a framed `Delta` message.
///
/// Each entry becomes 8 payload bytes: the little-endian `offset` followed by
/// the little-endian bit pattern of `value`.
pub fn serialize_deltas(deltas: &[DeltaEntry]) -> Vec<u8> {
    const ENTRY_BYTES: usize = 8;

    let mut body = vec![0u8; deltas.len() * ENTRY_BYTES];
    for (slot, entry) in body.chunks_exact_mut(ENTRY_BYTES).zip(deltas) {
        let (offset_bytes, value_bytes) = slot.split_at_mut(4);
        offset_bytes.copy_from_slice(&entry.offset.to_le_bytes());
        value_bytes.copy_from_slice(&entry.value.to_le_bytes());
    }

    serialize(WireMessageType::Delta, &body)
}
pub fn deserialize_deltas(buf: &[u8]) -> Result<Vec<DeltaEntry>, WireError> {
let (header, payload) = deserialize(buf)?;
debug_assert_eq!(header.msg_type, WireMessageType::Delta);
let mut deltas = Vec::with_capacity(payload.len() / 8);
for chunk in payload.chunks_exact(8) {
let offset = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
let value = f32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
deltas.push(DeltaEntry { offset, value });
}
Ok(deltas)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let payload = b"hello axon";
        let buf = serialize(WireMessageType::Signal, payload);
        let (header, decoded) = deserialize(&buf).unwrap();
        assert_eq!(header.version, 1);
        assert_eq!(header.msg_type, WireMessageType::Signal);
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_header_size() {
        let buf = serialize(WireMessageType::Signal, b"test");
        assert_eq!(buf.len(), 16 + 4);
    }

    #[test]
    fn test_magic_validation() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        buf[0] = 0xFF;
        assert!(matches!(deserialize(&buf), Err(WireError::InvalidMagic)));
    }

    #[test]
    fn test_checksum_validation() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        if let Some(last) = buf.last_mut() {
            *last ^= 0xFF;
        }
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::ChecksumMismatch { .. })
        ));
    }

    #[test]
    fn test_invalid_message_type() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        buf[5] = 99;
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::InvalidMessageType(99))
        ));
    }

    #[test]
    fn test_buffer_too_short() {
        assert!(matches!(
            deserialize(&[0; 4]),
            Err(WireError::BufferTooShort { .. })
        ));
    }

    #[test]
    fn test_truncated_payload_is_rejected() {
        let buf = serialize(WireMessageType::Signal, b"hello");
        // Drop the final payload byte while the header still claims 5 bytes.
        let truncated = &buf[..buf.len() - 1];
        assert!(matches!(
            deserialize(truncated),
            Err(WireError::BufferTooShort {
                needed: 21,
                got: 20
            })
        ));
    }

    #[test]
    fn test_trailing_bytes_are_ignored() {
        let mut buf = serialize(WireMessageType::Intent, b"frame");
        buf.extend_from_slice(b"next frame bytes");
        let (header, payload) = deserialize(&buf).unwrap();
        assert_eq!(header.payload_len, 5);
        assert_eq!(payload, b"frame");
    }

    #[test]
    fn test_header_fields_roundtrip() {
        let payload = b"header";
        let buf = serialize(WireMessageType::Presence, payload);
        let (header, _) = deserialize(&buf).unwrap();
        assert_eq!(header.version, AXON_VERSION);
        assert_eq!(header.msg_type, WireMessageType::Presence);
        assert_eq!(header.flags, 0);
        assert_eq!(header.payload_len as usize, payload.len());
        assert_eq!(header.checksum, crc32fast::hash(payload));
    }

    #[test]
    fn test_delta_roundtrip() {
        let deltas = vec![
            DeltaEntry {
                offset: 0,
                value: 1.5,
            },
            DeltaEntry {
                offset: 42,
                value: -3.125,
            },
        ];
        let buf = serialize_deltas(&deltas);
        let decoded = deserialize_deltas(&buf).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].offset, 0);
        assert!((decoded[0].value - 1.5).abs() < 1e-6);
        assert_eq!(decoded[1].offset, 42);
        assert!((decoded[1].value - (-3.125_f32)).abs() < 1e-5);
    }

    #[test]
    fn test_empty_delta_roundtrip() {
        let buf = serialize_deltas(&[]);
        assert_eq!(buf.len(), HEADER_SIZE);
        assert!(deserialize_deltas(&buf).unwrap().is_empty());
    }

    #[test]
    fn test_empty_payload() {
        let buf = serialize(WireMessageType::Presence, &[]);
        let (header, payload) = deserialize(&buf).unwrap();
        assert_eq!(header.payload_len, 0);
        assert_eq!(payload.len(), 0);
    }

    #[test]
    fn test_all_message_types() {
        for (byte, expected) in [
            (0u8, WireMessageType::Delta),
            (1, WireMessageType::Signal),
            (2, WireMessageType::Presence),
            (3, WireMessageType::Intent),
        ] {
            assert_eq!(WireMessageType::from_u8(byte), Some(expected));
        }
        assert!(WireMessageType::from_u8(4).is_none());
        assert!(WireMessageType::from_u8(255).is_none());
    }

    #[test]
    fn test_checksum_is_real_crc32() {
        let payload = b"verify crc";
        let buf = serialize(WireMessageType::Intent, payload);
        let stored_checksum = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
        let expected = crc32fast::hash(payload);
        assert_eq!(stored_checksum, expected);
        assert_ne!(stored_checksum, 0);
    }

    #[test]
    fn test_serialize_sets_correct_payload_len() {
        let payload = b"length check payload";
        let buf = serialize(WireMessageType::Delta, payload);
        let stored_len = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
        assert_eq!(stored_len as usize, payload.len());
    }

    #[test]
    fn test_flags_are_zero() {
        let buf = serialize(WireMessageType::Signal, b"flags");
        let flags = u16::from_le_bytes([buf[6], buf[7]]);
        assert_eq!(flags, 0);
    }
}