use prost::Message;
use crate::error::{MotorcortexError, Result};
use crate::msg::{Hash, get_hash, get_hash_size};
/// Builds a wire frame for `msg`: the type hash of `M` in little-endian byte
/// order, immediately followed by the prost-encoded message body.
///
/// The output buffer is sized up front (`get_hash_size()` plus the message's
/// `encoded_len()`).
///
/// # Errors
///
/// Returns `MotorcortexError::Encode` carrying prost's error text if
/// encoding the body fails.
pub(crate) fn encode_with_hash<M: Message + Hash>(msg: &M) -> Result<Vec<u8>> {
    let capacity = get_hash_size() + msg.encoded_len();
    let mut frame: Vec<u8> = Vec::with_capacity(capacity);

    // Hash prefix first so the receiver can check the type before decoding.
    let hash_prefix = get_hash::<M>().to_le_bytes();
    frame.extend_from_slice(&hash_prefix);

    match msg.encode(&mut frame) {
        Ok(()) => Ok(frame),
        Err(err) => Err(MotorcortexError::Encode(err.to_string())),
    }
}
/// Parses a frame produced by `encode_with_hash` back into a `T`.
///
/// The first `get_hash_size()` bytes are read as a little-endian `u32` hash.
/// That hash must equal `get_hash::<T>()`. The remaining bytes are then
/// decoded as the prost body of `T`.
///
/// # Errors
///
/// - `MotorcortexError::Decode` if `bytes` is shorter than the hash prefix.
/// - `MotorcortexError::Decode` if the prefix cannot be turned into a
///   `[u8; 4]`, which happens when `get_hash_size()` is not 4.
/// - `MotorcortexError::Decode` if the hash does not match `T`.
/// - Whatever `MotorcortexError::from` produces for a prost decode error in
///   the body.
pub(crate) fn decode_message<T: Message + Default + Hash>(bytes: &[u8]) -> Result<T> {
    let prefix_len = get_hash_size();
    if prefix_len > bytes.len() {
        return Err(MotorcortexError::Decode(
            "Invalid message length, hash missing".into(),
        ));
    }

    // Safe to split: the guard above ensures prefix_len <= bytes.len().
    let (prefix, body) = bytes.split_at(prefix_len);
    let prefix_array: [u8; 4] = match prefix.try_into() {
        Ok(array) => array,
        Err(_) => {
            return Err(MotorcortexError::Decode("Failed to extract hash".into()));
        }
    };

    let expected = get_hash::<T>();
    if u32::from_le_bytes(prefix_array) != expected {
        return Err(MotorcortexError::Decode("Invalid message hash".into()));
    }

    T::decode(body).map_err(MotorcortexError::from)
}
#[cfg(test)]
mod tests {
    // Unit tests for the hash-prefixed framing: round trips, plus every
    // rejection path in `decode_message` (short input, hash mismatch,
    // malformed body).
    use super::*;
    use crate::msg::{LoginMsg, StatusCode, StatusMsg};

    #[test]
    fn round_trips_a_status_msg() {
        let msg = StatusMsg {
            header: None,
            status: StatusCode::Ok as i32,
        };
        let wire = encode_with_hash(&msg).unwrap();
        let decoded: StatusMsg = decode_message(&wire).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn round_trips_a_login_msg_with_payload() {
        let msg = LoginMsg {
            header: None,
            login: "operator".into(),
            password: "secret".into(),
        };
        let wire = encode_with_hash(&msg).unwrap();
        assert!(wire.len() > get_hash_size(), "hash + body");
        let decoded: LoginMsg = decode_message(&wire).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn decode_rejects_empty_input() {
        let err = decode_message::<StatusMsg>(&[]).expect_err("empty input must fail");
        assert!(matches!(err, MotorcortexError::Decode(_)));
    }

    #[test]
    fn decode_rejects_input_shorter_than_hash() {
        // One byte short of the hash prefix, whatever its configured size.
        let short = vec![0u8; get_hash_size() - 1];
        let err = decode_message::<StatusMsg>(&short)
            .expect_err("input shorter than the hash must fail");
        assert!(matches!(err, MotorcortexError::Decode(_)));
    }

    #[test]
    fn decode_rejects_wrong_hash() {
        // Derive the wrong hash from the real one so the test cannot pass
        // vacuously if the hard-coded value happened to be the real hash.
        let wrong = get_hash::<StatusMsg>().wrapping_add(1);
        let err = decode_message::<StatusMsg>(&wrong.to_le_bytes())
            .expect_err("bad hash must fail");
        assert!(matches!(err, MotorcortexError::Decode(_)));
    }

    #[test]
    fn decode_rejects_message_of_another_type() {
        let login = LoginMsg {
            header: None,
            login: "operator".into(),
            password: "secret".into(),
        };
        let wire = encode_with_hash(&login).unwrap();
        let err = decode_message::<StatusMsg>(&wire)
            .expect_err("a LoginMsg frame must not decode as StatusMsg");
        assert!(matches!(err, MotorcortexError::Decode(_)));
    }

    #[test]
    fn decode_rejects_malformed_body() {
        let mut buf = get_hash::<StatusMsg>().to_le_bytes().to_vec();
        // 0xFF is a varint byte with the continuation bit set and nothing
        // after it, so the body is truncated mid-tag.
        buf.push(0xFF);
        let err = decode_message::<StatusMsg>(&buf).expect_err("bad body must fail");
        assert!(matches!(err, MotorcortexError::Decode(_)));
    }
}