use thiserror::Error;
/// Two-byte start-of-packet marker, written little-endian at the front of every packet.
pub const PACKET_TOKEN: u16 = 0xB5_AD;
/// Fixed header size in bytes:
/// token (2) + sync (1) + protocol/type (1) + payload length (2) + header checksum (2).
pub const HEADER_LEN: usize = 2 + 1 + 1 + 2 + 2;
/// Largest payload length representable in the header's 16-bit length field.
pub const MAX_PAYLOAD: usize = u16::MAX as usize;
/// A packet to encode or one produced by decoding.
///
/// The payload is borrowed rather than owned, so a `Packet` is a cheap view into a
/// caller-owned buffer. It is therefore `Copy`, like the slice it wraps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Packet<'a> {
    /// Sync byte, carried verbatim in byte 2 of the wire header.
    pub sync: u8,
    /// Protocol id. Only values `0..=15` are encodable (high nibble of header byte 3).
    pub protocol: u8,
    /// Packet type. Only values `0..=15` are encodable (low nibble of header byte 3).
    pub packet_type: u8,
    /// Payload bytes. Encodable only when at most [`MAX_PAYLOAD`] bytes long.
    pub payload: &'a [u8],
}
/// Reasons [`decode`] can fail to extract a packet from the front of a byte buffer.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum DecodeError {
    /// The buffer ended too early.
    ///
    /// `need` is the byte count the decoder required at the point it stopped (header or
    /// whole packet); `have` is the buffer length.
    #[error("not enough bytes (need {need}, have {have})")]
    Incomplete { need: usize, have: usize },
    /// The first two bytes, read little-endian, are not [`PACKET_TOKEN`].
    #[error("start token mismatch: expected {expected:#06x}, got {got:#06x}")]
    BadToken { expected: u16, got: u16 },
    /// The stored header checksum disagrees with the one recomputed over the header fields.
    #[error("header checksum mismatch: stored {stored:#06x}, computed {computed:#06x}")]
    HeaderChecksum { stored: u16, computed: u16 },
    /// The trailing checksum disagrees with the one recomputed over the header and payload.
    #[error("payload checksum mismatch: stored {stored:#06x}, computed {computed:#06x}")]
    PayloadChecksum { stored: u16, computed: u16 },
}
/// Reasons [`Packet::new`] or [`encode`] can reject a packet.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum EncodeError {
    /// The payload is longer than [`MAX_PAYLOAD`], so its length cannot be stored in the
    /// 16-bit length field.
    #[error("payload length {len} exceeds maximum {MAX_PAYLOAD}")]
    PayloadTooLarge { len: usize },
    /// The protocol id does not fit in its 4-bit nibble.
    #[error("protocol id {0} out of range (0..=15)")]
    BadProtocol(u8),
    /// The packet type does not fit in its 4-bit nibble.
    #[error("packet type {0} out of range (0..=15)")]
    BadPacketType(u8),
}
impl<'a> Packet<'a> {
    /// Builds a packet after checking that every field fits its wire encoding.
    ///
    /// # Errors
    ///
    /// Checks run in this order, and the first failure is returned:
    /// - [`EncodeError::BadProtocol`] if `protocol > 15`
    /// - [`EncodeError::BadPacketType`] if `packet_type > 15`
    /// - [`EncodeError::PayloadTooLarge`] if `payload.len() > MAX_PAYLOAD`
    pub fn new(
        sync: u8,
        protocol: u8,
        packet_type: u8,
        payload: &'a [u8],
    ) -> Result<Self, EncodeError> {
        match (protocol, packet_type, payload.len()) {
            (p, _, _) if p > 0xF => Err(EncodeError::BadProtocol(p)),
            (_, t, _) if t > 0xF => Err(EncodeError::BadPacketType(t)),
            (_, _, len) if len > MAX_PAYLOAD => Err(EncodeError::PayloadTooLarge { len }),
            _ => Ok(Self {
                sync,
                protocol,
                packet_type,
                payload,
            }),
        }
    }

    /// Number of bytes this packet occupies on the wire.
    ///
    /// That is the header alone for an empty payload; otherwise the header, the payload,
    /// and a 2-byte payload checksum.
    pub fn wire_len(&self) -> usize {
        match self.payload.len() {
            0 => HEADER_LEN,
            n => HEADER_LEN + n + 2,
        }
    }
}
/// Fletcher-16 checksum of `buf`.
///
/// Keeps two running sums modulo 255 and returns `(sum2 << 8) | sum1`:
/// - `sum1` adds each byte;
/// - `sum2` adds each updated `sum1`.
///
/// An empty buffer yields `0`.
pub fn fletcher16(buf: &[u8]) -> u16 {
    let (sum1, sum2) = buf.iter().fold((0u16, 0u16), |(s1, s2), &byte| {
        // Both sums stay below 255, so neither addition can overflow a u16.
        let s1 = (s1 + u16::from(byte)) % 255;
        (s1, (s2 + s1) % 255)
    });
    (sum2 << 8) | sum1
}
pub fn encode(packet: &Packet<'_>, out: &mut Vec<u8>) -> Result<usize, EncodeError> {
if packet.protocol > 0xF {
return Err(EncodeError::BadProtocol(packet.protocol));
}
if packet.packet_type > 0xF {
return Err(EncodeError::BadPacketType(packet.packet_type));
}
if packet.payload.len() > MAX_PAYLOAD {
return Err(EncodeError::PayloadTooLarge {
len: packet.payload.len(),
});
}
let start = out.len();
out.extend_from_slice(&PACKET_TOKEN.to_le_bytes());
let header_payload_start = out.len();
out.push(packet.sync);
out.push(((packet.protocol & 0xF) << 4) | (packet.packet_type & 0xF));
out.extend_from_slice(&(packet.payload.len() as u16).to_le_bytes());
let header_cs = fletcher16(&out[header_payload_start..]);
out.extend_from_slice(&header_cs.to_le_bytes());
if !packet.payload.is_empty() {
out.extend_from_slice(packet.payload);
let body_end = out.len();
let payload_cs = fletcher16(&out[header_payload_start..body_end]);
out.extend_from_slice(&payload_cs.to_le_bytes());
}
Ok(out.len() - start)
}
/// Decodes one packet from the front of `buf`.
///
/// On success, returns the packet together with the number of bytes consumed. The
/// packet's payload borrows from `buf`. Bytes after the packet are ignored, so callers
/// can advance by the returned count and decode again.
///
/// # Errors
///
/// - [`DecodeError::Incomplete`] if `buf` is shorter than the header, or shorter than
///   the full packet its header announces.
/// - [`DecodeError::BadToken`] if the first two bytes are not [`PACKET_TOKEN`].
/// - [`DecodeError::HeaderChecksum`] / [`DecodeError::PayloadChecksum`] if a stored
///   checksum disagrees with the recomputed one.
pub fn decode(buf: &[u8]) -> Result<(Packet<'_>, usize), DecodeError> {
    // Reads the little-endian u16 stored at `buf[at..at + 2]`.
    fn le16(buf: &[u8], at: usize) -> u16 {
        u16::from_le_bytes([buf[at], buf[at + 1]])
    }

    let have = buf.len();
    if have < HEADER_LEN {
        return Err(DecodeError::Incomplete {
            need: HEADER_LEN,
            have,
        });
    }

    let got = le16(buf, 0);
    if got != PACKET_TOKEN {
        return Err(DecodeError::BadToken {
            expected: PACKET_TOKEN,
            got,
        });
    }

    // The header checksum covers sync, protocol/type and length (bytes 2..6).
    let stored = le16(buf, 6);
    let computed = fletcher16(&buf[2..6]);
    if stored != computed {
        return Err(DecodeError::HeaderChecksum { stored, computed });
    }

    let header = Packet {
        sync: buf[2],
        protocol: buf[3] >> 4,
        packet_type: buf[3] & 0x0F,
        payload: &[],
    };
    let payload_len = usize::from(le16(buf, 4));
    if payload_len == 0 {
        // An empty payload has no body and no trailing checksum.
        return Ok((header, HEADER_LEN));
    }

    let cs_off = HEADER_LEN + payload_len;
    let need = cs_off + 2;
    if have < need {
        return Err(DecodeError::Incomplete { need, have });
    }

    // The payload checksum spans everything after the token up to the checksum itself.
    let stored = le16(buf, cs_off);
    let computed = fletcher16(&buf[2..cs_off]);
    if stored != computed {
        return Err(DecodeError::PayloadChecksum { stored, computed });
    }

    Ok((
        Packet {
            payload: &buf[HEADER_LEN..cs_off],
            ..header
        },
        need,
    ))
}