#![doc = include_str!("../README.md")]
use std::io;
use serde::{Serialize, de::DeserializeOwned};
/// Errors returned by [`encode`] and [`decode`].
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
    /// [`decode`] was given zero bytes, so there was no version byte to read.
    #[error("payload empty")]
    Empty,
    /// The leading version byte differs from the version the caller asked for.
    /// `expected` is the caller's version; `actual` is the byte found in the payload.
    #[error("version mismatch: expected {expected}, got {actual}")]
    Version { expected: u8, actual: u8 },
    /// postcard failed to serialize the value.
    // The postcard error is both embedded in the message and exposed via `source()`:
    // `codec_io_error` flattens errors into a plain string, which would otherwise
    // drop the postcard detail.
    #[error("encode failed: {0}")]
    Encode(#[source] postcard::Error),
    /// postcard failed to deserialize the body that follows the version byte.
    #[error("decode failed: {0}")]
    Decode(#[source] postcard::Error),
    /// The body decoded successfully, but `extra` bytes were left unconsumed after it.
    #[error("trailing bytes: {extra} unconsumed after a valid body")]
    TrailingBytes { extra: usize },
}
/// Serializes `value` with postcard and prefixes the result with a one-byte `version` tag.
///
/// The output layout is `[version, body...]`, where `body` is the postcard encoding
/// of `value`. [`decode`] reads this exact layout back.
///
/// `T` may be unsized (for example `str` or `[u8]`), so string and byte slices can be
/// passed directly instead of through an extra reference.
///
/// # Errors
///
/// Returns [`CodecError::Encode`] if postcard fails to serialize `value`.
pub fn encode<T: Serialize + ?Sized>(version: u8, value: &T) -> Result<Vec<u8>, CodecError> {
    let body = postcard::to_stdvec(value).map_err(CodecError::Encode)?;
    // Reserve room for the header byte plus the body so the copy below never reallocates.
    let mut out = Vec::with_capacity(1 + body.len());
    out.push(version);
    out.extend_from_slice(&body);
    Ok(out)
}
/// Decodes a payload produced by [`encode`], checking its version byte first.
///
/// The first byte of `bytes` must equal `expected_version`. The remaining bytes
/// must hold exactly one postcard-encoded `T` and nothing after it.
///
/// # Errors
///
/// - [`CodecError::Empty`] if `bytes` is empty.
/// - [`CodecError::Version`] if the leading byte is not `expected_version`.
/// - [`CodecError::Decode`] if postcard cannot deserialize the body.
/// - [`CodecError::TrailingBytes`] if bytes remain after a successfully decoded `T`.
pub fn decode<T: DeserializeOwned>(expected_version: u8, bytes: &[u8]) -> Result<T, CodecError> {
    let Some((&actual, body)) = bytes.split_first() else {
        return Err(CodecError::Empty);
    };
    if actual != expected_version {
        return Err(CodecError::Version {
            expected: expected_version,
            actual,
        });
    }
    match postcard::take_from_bytes(body) {
        // An empty remainder means the body was consumed exactly.
        Ok((value, [])) => Ok(value),
        // A valid value followed by junk is rejected rather than silently accepted.
        Ok((_, leftover)) => Err(CodecError::TrailingBytes {
            extra: leftover.len(),
        }),
        Err(source) => Err(CodecError::Decode(source)),
    }
}
/// Converts a [`CodecError`] into an [`io::Error`] whose message is `"{context}: {err}"`.
///
/// Version mismatches map to [`io::ErrorKind::InvalidData`]. Every other variant
/// maps to [`io::ErrorKind::Other`]. The codec error is rendered into the message
/// string, so it is not available through the returned error's `source()`.
pub fn codec_io_error(context: &str, err: CodecError) -> io::Error {
    let message = format!("{context}: {err}");
    if matches!(err, CodecError::Version { .. }) {
        io::Error::new(io::ErrorKind::InvalidData, message)
    } else {
        io::Error::other(message)
    }
}
#[cfg(test)]
mod tests {
    //! Unit and property tests for the version-prefixed postcard codec.
    use super::*;
    use proptest::prelude::*;
    use serde::{Deserialize, Serialize};

    /// Small serde-friendly fixture shared by every test.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Sample {
        idx: u64,
        name: String,
    }

    #[test]
    fn encode_decode_roundtrip() {
        let original = Sample {
            idx: 42,
            name: "tsoracle".into(),
        };
        let bytes = encode(1, &original).expect("encode");
        assert_eq!(bytes[0], 1);
        let decoded: Sample = decode(1, &bytes).expect("decode");
        assert_eq!(original, decoded);
    }

    #[test]
    fn decode_rejects_wrong_version() {
        let bytes = encode(
            2,
            &Sample {
                idx: 1,
                name: "x".into(),
            },
        )
        .expect("encode");
        let err = decode::<Sample>(1, &bytes).expect_err("must reject");
        assert!(matches!(
            err,
            CodecError::Version {
                expected: 1,
                actual: 2
            }
        ));
    }

    #[test]
    fn decode_rejects_empty() {
        let err = decode::<Sample>(1, &[]).expect_err("must reject");
        assert!(matches!(err, CodecError::Empty));
    }

    #[test]
    fn decode_rejects_truncated_input() {
        let original = Sample {
            idx: u64::MAX,
            name: "hello-world-storage-roundtrip".into(),
        };
        let bytes = encode(1, &original).expect("encode");
        assert!(bytes.len() >= 16, "payload should be non-trivial");
        // Cutting the payload in half leaves the string body incomplete.
        let truncated = &bytes[..bytes.len() / 2];
        assert!(matches!(
            decode::<Sample>(1, truncated),
            Err(CodecError::Decode(_))
        ));
    }

    #[test]
    fn decode_rejects_trailing_bytes() {
        let original = Sample {
            idx: 7,
            name: "trailing".into(),
        };
        let mut bytes = encode(1, &original).expect("encode");
        bytes.extend_from_slice(&[0xAB, 0xCD, 0xEF]);
        assert!(matches!(
            decode::<Sample>(1, &bytes),
            Err(CodecError::TrailingBytes { extra: 3 })
        ));
    }

    #[test]
    fn codec_io_error_maps_version_mismatch_to_invalid_data() {
        let v2_bytes = encode(
            2,
            &Sample {
                idx: 1,
                name: "x".into(),
            },
        )
        .unwrap();
        let err = decode::<Sample>(1, &v2_bytes).expect_err("must reject");
        assert!(matches!(err, CodecError::Version { .. }));
        let io_err = codec_io_error("vote decode", err);
        assert_eq!(io_err.kind(), io::ErrorKind::InvalidData);
        // Pin the whole message: context prefix plus the CodecError's Display text.
        assert_eq!(
            io_err.to_string(),
            "vote decode: version mismatch: expected 1, got 2"
        );
    }

    #[test]
    fn codec_io_error_maps_other_variants_to_other_kind() {
        let cases = [
            (CodecError::Empty, "vote decode: payload empty"),
            (
                CodecError::TrailingBytes { extra: 3 },
                "vote decode: trailing bytes: 3 unconsumed after a valid body",
            ),
        ];
        for (err, expected_message) in cases {
            let io_err = codec_io_error("vote decode", err);
            // The test name promises ErrorKind::Other, so assert exactly that rather
            // than only "not InvalidData".
            assert_eq!(io_err.kind(), io::ErrorKind::Other);
            assert_eq!(io_err.to_string(), expected_message);
        }
    }

    proptest! {
        // Any version byte and any Sample survive an encode/decode round trip.
        #[test]
        fn encode_decode_roundtrip_any(
            version in any::<u8>(),
            idx in any::<u64>(),
            name in any::<String>(),
        ) {
            let s = Sample { idx, name };
            let bytes = encode(version, &s).unwrap();
            prop_assert_eq!(bytes[0], version);
            let back: Sample = decode(version, &bytes).unwrap();
            prop_assert_eq!(s, back);
        }

        // Every mismatched (encoded, expected) pair yields a Version error carrying both bytes.
        #[test]
        fn decode_rejects_any_version_mismatch(
            encoded in any::<u8>(),
            expected in any::<u8>(),
            idx in any::<u64>(),
            name in any::<String>(),
        ) {
            prop_assume!(encoded != expected);
            let bytes = encode(encoded, &Sample { idx, name }).unwrap();
            match decode::<Sample>(expected, &bytes) {
                Err(CodecError::Version { expected: e, actual: a }) => {
                    prop_assert_eq!(e, expected);
                    prop_assert_eq!(a, encoded);
                }
                other => prop_assert!(false, "expected Version mismatch; got {other:?}"),
            }
        }
    }
}