use alloc::{string::ToString, vec::Vec};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::VckResult;
/// A payload type that can be round-tripped through [`encode_payload`] and
/// [`decode_payload`] (MessagePack via `messagepack_serde`).
///
/// Implementors also name the slot the encoded bytes are associated with.
/// NOTE(review): `VAR_NAME` / `VAR_GUID` look like a firmware-variable name and
/// vendor GUID — confirm against the code that actually stores the blob.
pub trait HandoverPayload
where
    Self: Serialize + DeserializeOwned,
{
    /// Name of the variable the encoded payload is associated with.
    const VAR_NAME: &'static str;
    /// 16-byte GUID paired with `VAR_NAME`.
    const VAR_GUID: [u8; 16];
}
/// Serializes `payload` into MessagePack bytes.
///
/// # Errors
///
/// Returns `VckError::MsgpackEncode` carrying the encoder's error text if
/// serialization fails.
pub fn encode_payload<P: HandoverPayload>(payload: &P) -> VckResult<Vec<u8>> {
    let encoded = messagepack_serde::to_vec(payload);
    encoded.map_err(|cause| crate::VckError::MsgpackEncode(cause.to_string()))
}
/// Deserializes a `P` from MessagePack `bytes`.
///
/// # Errors
///
/// Returns `VckError::MsgpackDecode` carrying the decoder's error text if the
/// bytes cannot be decoded as `P`.
pub fn decode_payload<P: HandoverPayload>(bytes: &[u8]) -> VckResult<P> {
    let parsed: Result<P, _> = messagepack_serde::from_slice(bytes);
    parsed.map_err(|cause| crate::VckError::MsgpackDecode(cause.to_string()))
}
/// Magic tag expected in `HandoverLocator::magic`: the ASCII bytes `VCKL`
/// interpreted as a little-endian `u32`.
pub const HANDOVER_LOCATOR_MAGIC: u32 = u32::from_le_bytes([b'V', b'C', b'K', b'L']);
/// The only locator format version accepted by `HandoverLocator::validate`.
pub const HANDOVER_LOCATOR_VERSION: u16 = 0x0001;
/// Record describing where a handover blob can be found.
///
/// Encode with [`encode_locator`]; decode with [`decode_locator`], which also
/// runs [`HandoverLocator::validate`].
///
/// NOTE(review): the field names and their order feed the serde derive and
/// therefore the encoded form — do not reorder/rename fields without bumping
/// [`HANDOVER_LOCATOR_VERSION`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HandoverLocator {
    /// Must equal [`HANDOVER_LOCATOR_MAGIC`] for the locator to validate.
    pub magic: u32,
    /// Must equal [`HANDOVER_LOCATOR_VERSION`] for the locator to validate.
    pub version: u16,
    /// Start address of the blob; must be non-zero to validate.
    /// NOTE(review): presumably a physical address — confirm with the producer.
    pub address: u64,
    /// Size of the blob (presumably in bytes — confirm); must be non-zero to validate.
    pub length: u64,
}
impl HandoverLocator {
    /// Builds a locator for the blob at `address` spanning `length`, stamped
    /// with the current [`HANDOVER_LOCATOR_MAGIC`] and
    /// [`HANDOVER_LOCATOR_VERSION`].
    ///
    /// No validation is performed here; call [`HandoverLocator::validate`]
    /// (or go through [`decode_locator`]) to check the result.
    pub fn new(address: u64, length: u64) -> Self {
        Self {
            magic: HANDOVER_LOCATOR_MAGIC,
            version: HANDOVER_LOCATOR_VERSION,
            address,
            length,
        }
    }

    /// Checks that the locator is well formed.
    ///
    /// # Errors
    ///
    /// Returns `VckError::InvalidData` when any of the following hold:
    /// - `magic` is not [`HANDOVER_LOCATOR_MAGIC`];
    /// - `version` is not [`HANDOVER_LOCATOR_VERSION`];
    /// - `address` or `length` is zero;
    /// - `address + length` does not fit in a `u64`, i.e. the described range
    ///   wraps past the end of the address space.
    pub fn validate(&self) -> VckResult<()> {
        if self.magic != HANDOVER_LOCATOR_MAGIC {
            return Err(crate::VckError::InvalidData("handover locator: bad magic"));
        }
        if self.version != HANDOVER_LOCATOR_VERSION {
            return Err(crate::VckError::InvalidData(
                "handover locator: unsupported version",
            ));
        }
        if self.address == 0 || self.length == 0 {
            return Err(crate::VckError::InvalidData(
                "handover locator: empty address/length",
            ));
        }
        // A wrapping range cannot describe a real buffer, and any consumer that
        // computes the end address (`address + length`) would overflow on it.
        if self.address.checked_add(self.length).is_none() {
            return Err(crate::VckError::InvalidData(
                "handover locator: address range overflows",
            ));
        }
        Ok(())
    }
}
/// Serializes `locator` into MessagePack bytes.
///
/// The locator is NOT validated before encoding, so malformed locators can be
/// encoded; [`decode_locator`] is where validation happens.
///
/// # Errors
///
/// Returns `VckError::MsgpackEncode` carrying the encoder's error text if
/// serialization fails.
pub fn encode_locator(locator: &HandoverLocator) -> VckResult<Vec<u8>> {
    let encoded: Result<Vec<u8>, _> = messagepack_serde::to_vec(locator);
    encoded.map_err(|cause| crate::VckError::MsgpackEncode(cause.to_string()))
}
/// Decodes a [`HandoverLocator`] from MessagePack `bytes` and validates it.
///
/// # Errors
///
/// - `VckError::MsgpackDecode` if the bytes cannot be decoded as a locator.
/// - Whatever [`HandoverLocator::validate`] reports for a decoded but
///   malformed locator.
pub fn decode_locator(bytes: &[u8]) -> VckResult<HandoverLocator> {
    let parsed: Result<HandoverLocator, _> = messagepack_serde::from_slice(bytes);
    let locator =
        parsed.map_err(|cause| crate::VckError::MsgpackDecode(cause.to_string()))?;
    // Only hand back the locator once its invariants have been checked.
    locator.validate().map(|()| locator)
}
#[cfg(test)]
mod tests {
    // Unit tests for payload and locator (de)serialization.
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal payload used to exercise the generic encode/decode helpers.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestPayload {
        partition_guid: [u8; 16],
        vmk: Vec<u8>,
    }

    impl HandoverPayload for TestPayload {
        const VAR_NAME: &'static str = "TestHandover";
        const VAR_GUID: [u8; 16] = [0u8; 16];
    }

    /// Encodes `locator`, panicking if the encoder itself fails.
    fn encoded(locator: &HandoverLocator) -> Vec<u8> {
        encode_locator(locator).expect("encode locator")
    }

    #[test]
    fn encode_decode_round_trip() {
        const GUID: [u8; 16] = [
            0x5a, 0x95, 0x77, 0x0f, 0x3e, 0xf6, 0x11, 0xf1, 0x8b, 0x5c, 0xb4, 0x2e, 0x99, 0x11,
            0x84, 0x0a,
        ];
        let original = TestPayload {
            partition_guid: GUID,
            vmk: (0..32u8).collect(),
        };
        let wire = encode_payload(&original).expect("encode");
        let restored = decode_payload::<TestPayload>(&wire).expect("decode");
        assert_eq!(restored, original);
    }

    #[test]
    fn decode_rejects_garbage() {
        let garbage = [0xff_u8, 0x00, 0x12, 0x34];
        let outcome = decode_payload::<TestPayload>(&garbage);
        assert!(outcome.is_err());
    }

    #[test]
    fn locator_round_trip() {
        let expected = HandoverLocator::new(0x1_2345_6000, 4096);
        let restored = decode_locator(&encoded(&expected)).expect("decode locator");
        assert_eq!(restored, expected);
    }

    #[test]
    fn locator_rejects_bad_magic() {
        let mut corrupted = HandoverLocator::new(0x1000, 16);
        // Flip every bit so the magic can no longer match.
        corrupted.magic = !corrupted.magic;
        assert!(decode_locator(&encoded(&corrupted)).is_err());
    }

    #[test]
    fn locator_rejects_empty() {
        let empty = HandoverLocator::new(0, 0);
        assert!(decode_locator(&encoded(&empty)).is_err());
    }
}