use super::types::PayloadFormat;
use bytes::Bytes;
use sonic_rs::JsonValueTrait as _;
use thiserror::Error;
/// Errors produced while decoding or encoding a payload.
#[derive(Debug, Error)]
pub enum CodecError {
    /// The JSON backend (`sonic_rs`) returned an error.
    ///
    /// NOTE(review): besides `parse`, this variant is also produced by
    /// `to_json_bytes` when `sonic_rs::to_vec` fails, so the "parse error"
    /// wording in the message can be misleading for encode failures.
    #[error("json parse error: {0}")]
    Json(#[from] sonic_rs::Error),
    /// `rmpv` could not decode a MessagePack value (e.g. malformed or truncated input).
    #[error("msgpack parse error: {0}")]
    MsgPack(#[from] rmpv::decode::Error),
    /// `rmpv` could not encode a MessagePack value.
    #[error("msgpack encode error: {0}")]
    Encode(#[from] rmpv::encode::Error),
    /// A MessagePack payload had bytes left over after its first complete
    /// value; the field is the number of unread bytes.
    #[error("msgpack trailing bytes: {0} byte(s) remain after value")]
    TrailingBytes(usize),
}
/// A decoded payload, tagged with the wire format it was decoded from.
///
/// Produced by `parse`; `ParsedPayload::to_bytes` re-encodes it in the same format.
#[derive(Debug, Clone)]
pub enum ParsedPayload {
    /// Value decoded from JSON.
    Json(sonic_rs::Value),
    /// Value decoded from MessagePack.
    MsgPack(rmpv::Value),
}
/// A borrowed, format-independent view of one top-level field value.
///
/// Returned by `ParsedPayload::field` so callers can inspect a field the same
/// way regardless of whether the payload was JSON or MessagePack.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FieldRef<'a> {
    /// A string, borrowed from the parsed payload.
    Str(&'a str),
    /// An integer that the backend could convert to `i64`.
    Int(i64),
    /// A floating-point number. Integers that do not fit in `i64` also land
    /// here when the backend's `as_f64` conversion accepts them.
    Float(f64),
    /// A boolean.
    Bool(bool),
    /// JSON `null` / MessagePack `nil`.
    Null,
    /// Anything else: objects/maps, arrays, MessagePack binary and extension
    /// values, and MessagePack strings that are not valid UTF-8.
    Other,
}
/// Decodes `payload` into a [`ParsedPayload`].
///
/// With `PayloadFormat::Auto` the concrete format is chosen by
/// `PayloadFormat::detect`; should detection itself yield `Auto`, the payload
/// is treated as JSON.
///
/// MessagePack input must consist of exactly one value: any bytes remaining
/// after the first decoded value are rejected.
///
/// # Errors
///
/// * [`CodecError::Json`] if the JSON backend rejects the input.
/// * [`CodecError::MsgPack`] if the MessagePack input is malformed or truncated.
/// * [`CodecError::TrailingBytes`] if MessagePack input has data after its value.
pub fn parse(payload: &Bytes, format: PayloadFormat) -> Result<ParsedPayload, CodecError> {
    let resolved = if matches!(format, PayloadFormat::Auto) {
        PayloadFormat::detect(payload)
    } else {
        format
    };

    match resolved {
        PayloadFormat::MsgPack => {
            // `read_value` advances this slice past everything it consumed,
            // so whatever is left afterwards is trailing data.
            let mut rest: &[u8] = &payload[..];
            let decoded = rmpv::decode::read_value(&mut rest)?;
            match rest.len() {
                0 => Ok(ParsedPayload::MsgPack(decoded)),
                leftover => Err(CodecError::TrailingBytes(leftover)),
            }
        }
        // `Auto` only reaches here if detection could not pick a format.
        PayloadFormat::Json | PayloadFormat::Auto => {
            let json: sonic_rs::Value = sonic_rs::from_slice(payload)?;
            Ok(ParsedPayload::Json(json))
        }
    }
}
pub fn to_json_bytes(value: &sonic_rs::Value) -> Result<Bytes, CodecError> {
let buf = sonic_rs::to_vec(value)?;
Ok(Bytes::from(buf))
}
/// Encodes a MessagePack value into a freshly allocated `Bytes` buffer.
///
/// # Errors
///
/// Returns [`CodecError::Encode`] if `rmpv` fails to write the value.
pub fn to_msgpack_bytes(value: &rmpv::Value) -> Result<Bytes, CodecError> {
    let mut encoded = Vec::<u8>::new();
    rmpv::encode::write_value(&mut encoded, value)?;
    Ok(encoded.into())
}
impl ParsedPayload {
#[must_use]
pub fn is_json(&self) -> bool {
matches!(self, Self::Json(_))
}
#[must_use]
pub fn is_msgpack(&self) -> bool {
matches!(self, Self::MsgPack(_))
}
#[must_use]
pub fn field_str(&self, name: &str) -> Option<&str> {
match self {
Self::Json(v) => v.get(name).and_then(|f| f.as_str()),
Self::MsgPack(v) => msgpack_field(v, name).and_then(rmpv::Value::as_str),
}
}
#[must_use]
pub fn field(&self, name: &str) -> Option<FieldRef<'_>> {
match self {
Self::Json(v) => v.get(name).map(json_field_ref),
Self::MsgPack(v) => msgpack_field(v, name).map(msgpack_field_ref),
}
}
pub fn to_bytes(&self) -> Result<Bytes, CodecError> {
match self {
Self::Json(v) => to_json_bytes(v),
Self::MsgPack(v) => to_msgpack_bytes(v),
}
}
}
/// Classifies a JSON value as a [`FieldRef`].
///
/// Numbers become `Int` when `as_i64` accepts them, otherwise `Float` when
/// `as_f64` does; values no accessor recognises (objects, arrays) become `Other`.
fn json_field_ref(v: &sonic_rs::Value) -> FieldRef<'_> {
    if v.is_null() {
        return FieldRef::Null;
    }
    // Accessors are tried in the same priority order as before: string,
    // bool, i64, then f64 (so integral numbers prefer `Int`).
    v.as_str()
        .map(FieldRef::Str)
        .or_else(|| v.as_bool().map(FieldRef::Bool))
        .or_else(|| v.as_i64().map(FieldRef::Int))
        .or_else(|| v.as_f64().map(FieldRef::Float))
        .unwrap_or(FieldRef::Other)
}
/// Finds the value stored under `name` in a top-level MessagePack map.
///
/// Only entries whose key's `as_str()` equals `name` match. When a key
/// repeats, the first entry wins. Non-map values always yield `None`.
fn msgpack_field<'a>(v: &'a rmpv::Value, name: &str) -> Option<&'a rmpv::Value> {
    let rmpv::Value::Map(entries) = v else {
        return None;
    };
    entries
        .iter()
        .find_map(|(key, value)| (key.as_str() == Some(name)).then_some(value))
}
/// Classifies a MessagePack value as a [`FieldRef`].
///
/// Integers become `Int` when `as_i64` accepts them and otherwise fall back
/// to `Float` via `as_f64`. Strings that are not valid UTF-8, binary and
/// extension values, arrays and maps all become `Other`.
fn msgpack_field_ref(v: &rmpv::Value) -> FieldRef<'_> {
    match v {
        rmpv::Value::Nil => FieldRef::Null,
        rmpv::Value::Boolean(flag) => FieldRef::Bool(*flag),
        rmpv::Value::F64(x) => FieldRef::Float(*x),
        rmpv::Value::F32(x) => FieldRef::Float((*x).into()),
        rmpv::Value::Integer(_) => {
            if let Some(int) = v.as_i64() {
                FieldRef::Int(int)
            } else if let Some(float) = v.as_f64() {
                FieldRef::Float(float)
            } else {
                FieldRef::Other
            }
        }
        rmpv::Value::String(text) => match text.as_str() {
            Some(s) => FieldRef::Str(s),
            None => FieldRef::Other,
        },
        _ => FieldRef::Other,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `s` as a MessagePack fixstr (strings shorter than 32 bytes only).
    fn fixstr(s: &str) -> Vec<u8> {
        let bytes = s.as_bytes();
        assert!(bytes.len() < 32, "fixstr helper only handles len < 32");
        let len = u8::try_from(bytes.len()).expect("len < 32 fits u8");
        let mut out = vec![0xa0 | len];
        out.extend_from_slice(bytes);
        out
    }

    /// Returns the fixmap marker byte for a map with `n` (< 16) entries.
    fn fixmap_header(n: u8) -> u8 {
        assert!(n < 16, "fixmap helper only handles < 16 entries");
        0x80 | n
    }

    /// `{"k": "v"}` encoded as MessagePack.
    fn single_kv_msgpack() -> Vec<u8> {
        let mut buf = vec![fixmap_header(1)];
        buf.extend(fixstr("k"));
        buf.extend(fixstr("v"));
        buf
    }

    /// MessagePack encoding of the same record as `sample_json`.
    fn sample_msgpack() -> Bytes {
        let mut buf = vec![fixmap_header(5)];
        buf.extend(fixstr("_table"));
        buf.extend(fixstr("events"));
        buf.extend(fixstr("org_id"));
        buf.push(42); // positive fixint
        buf.extend(fixstr("live"));
        buf.push(0xc3); // true
        buf.extend(fixstr("ratio"));
        buf.push(0xcb); // float 64
        buf.extend_from_slice(&1.5f64.to_be_bytes());
        buf.extend(fixstr("missing"));
        buf.push(0xc0); // nil
        Bytes::from(buf)
    }

    fn sample_json() -> Bytes {
        Bytes::from_static(
            br#"{"_table":"events","org_id":42,"live":true,"ratio":1.5,"missing":null}"#,
        )
    }

    #[test]
    fn parse_json_object() {
        let parsed = parse(&sample_json(), PayloadFormat::Json).unwrap();
        assert!(parsed.is_json());
        assert!(!parsed.is_msgpack());
        assert_eq!(parsed.field_str("_table"), Some("events"));
    }

    #[test]
    fn parse_json_array_is_ok() {
        let parsed = parse(&Bytes::from_static(b"[1,2,3]"), PayloadFormat::Json).unwrap();
        assert!(parsed.is_json());
        assert_eq!(parsed.field_str("anything"), None);
    }

    #[test]
    fn parse_msgpack_map() {
        let parsed = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert!(parsed.is_msgpack());
        assert!(!parsed.is_json());
        assert_eq!(parsed.field_str("_table"), Some("events"));
    }

    #[test]
    fn parse_auto_dispatches_to_json() {
        let parsed = parse(&sample_json(), PayloadFormat::Auto).unwrap();
        assert!(parsed.is_json(), "object byte '{{' must detect as JSON");
        assert_eq!(parsed.field_str("_table"), Some("events"));
    }

    #[test]
    fn parse_auto_dispatches_to_msgpack() {
        let parsed = parse(&sample_msgpack(), PayloadFormat::Auto).unwrap();
        assert!(parsed.is_msgpack(), "fixmap byte 0x85 must detect as MsgPack");
        assert_eq!(parsed.field_str("_table"), Some("events"));
    }

    #[test]
    fn field_str_identical_across_formats() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field_str("_table"), m.field_str("_table"));
        assert_eq!(j.field_str("_table"), Some("events"));
    }

    #[test]
    fn field_str_returns_none_for_non_string() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field_str("org_id"), None);
        assert_eq!(m.field_str("org_id"), None);
    }

    #[test]
    fn field_str_returns_none_for_missing_key() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field_str("nope"), None);
        assert_eq!(m.field_str("nope"), None);
    }

    #[test]
    fn field_str_value_is_present_via_field_too() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        assert_eq!(j.field("_table"), Some(FieldRef::Str("events")));
    }

    #[test]
    fn field_int_identical_across_formats() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field("org_id"), Some(FieldRef::Int(42)));
        assert_eq!(m.field("org_id"), Some(FieldRef::Int(42)));
    }

    #[test]
    fn field_bool_identical_across_formats() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field("live"), Some(FieldRef::Bool(true)));
        assert_eq!(m.field("live"), Some(FieldRef::Bool(true)));
    }

    #[test]
    fn field_float_identical_across_formats() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field("ratio"), Some(FieldRef::Float(1.5)));
        assert_eq!(m.field("ratio"), Some(FieldRef::Float(1.5)));
    }

    #[test]
    fn field_null_identical_across_formats() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field("missing"), Some(FieldRef::Null));
        assert_eq!(m.field("missing"), Some(FieldRef::Null));
    }

    #[test]
    fn field_missing_key_is_none_for_both() {
        let j = parse(&sample_json(), PayloadFormat::Json).unwrap();
        let m = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert_eq!(j.field("nope"), None);
        assert_eq!(m.field("nope"), None);
    }

    #[test]
    fn field_nested_object_is_other() {
        let j = parse(&Bytes::from_static(br#"{"k":{"nested":1}}"#), PayloadFormat::Json).unwrap();
        assert_eq!(j.field("k"), Some(FieldRef::Other));
        assert_eq!(j.field_str("k"), None);

        let mut buf = vec![fixmap_header(1)];
        buf.extend(fixstr("k"));
        buf.push(0x91); // fixarray with one element...
        buf.push(0x01); // ...the positive fixint 1
        let m = parse(&Bytes::from(buf), PayloadFormat::MsgPack).unwrap();
        assert_eq!(m.field("k"), Some(FieldRef::Other));
    }

    #[test]
    fn field_on_non_object_top_level_is_none() {
        let j = parse(&Bytes::from_static(b"[1,2,3]"), PayloadFormat::Json).unwrap();
        assert_eq!(j.field("0"), None);
        let m = parse(&Bytes::from(vec![0x92, 0x01, 0x02]), PayloadFormat::MsgPack).unwrap();
        assert_eq!(m.field("0"), None);
    }

    #[test]
    fn msgpack_float32_is_float() {
        let mut buf = vec![fixmap_header(1)];
        buf.extend(fixstr("ratio"));
        buf.push(0xca); // float 32
        buf.extend_from_slice(&1.5f32.to_be_bytes());
        let m = parse(&Bytes::from(buf), PayloadFormat::MsgPack).unwrap();
        assert_eq!(m.field("ratio"), Some(FieldRef::Float(1.5)));
    }

    #[test]
    fn msgpack_u64_beyond_i64_falls_back_to_float() {
        let mut buf = vec![fixmap_header(1)];
        buf.extend(fixstr("big"));
        buf.push(0xcf); // uint 64
        buf.extend_from_slice(&u64::MAX.to_be_bytes());
        let m = parse(&Bytes::from(buf), PayloadFormat::MsgPack).unwrap();
        assert_eq!(m.field("big"), Some(FieldRef::Float(u64::MAX as f64)));
    }

    #[test]
    fn msgpack_invalid_utf8_string_is_other() {
        let mut buf = vec![fixmap_header(1)];
        buf.extend(fixstr("k"));
        buf.push(0xa1); // fixstr of length 1...
        buf.push(0xff); // ...holding a byte that is never valid UTF-8
        let m = parse(&Bytes::from(buf), PayloadFormat::MsgPack).unwrap();
        assert_eq!(m.field("k"), Some(FieldRef::Other));
        assert_eq!(m.field_str("k"), None);
    }

    #[test]
    fn msgpack_duplicate_key_returns_first_match() {
        let mut buf = vec![fixmap_header(2)];
        buf.extend(fixstr("k"));
        buf.push(0x01);
        buf.extend(fixstr("k"));
        buf.push(0x02);
        let m = parse(&Bytes::from(buf), PayloadFormat::MsgPack).unwrap();
        assert_eq!(m.field("k"), Some(FieldRef::Int(1)));
    }

    #[test]
    fn malformed_json_errors() {
        let err = parse(&Bytes::from_static(b"{not valid json"), PayloadFormat::Json).unwrap_err();
        assert!(matches!(err, CodecError::Json(_)), "got {err:?}");
        assert!(!err.to_string().is_empty());
    }

    #[test]
    fn empty_blob_auto_errors_as_json() {
        let err = parse(&Bytes::new(), PayloadFormat::Auto).unwrap_err();
        assert!(matches!(err, CodecError::Json(_)), "got {err:?}");
    }

    #[test]
    fn empty_blob_msgpack_errors() {
        let err = parse(&Bytes::new(), PayloadFormat::MsgPack).unwrap_err();
        assert!(matches!(err, CodecError::MsgPack(_)), "got {err:?}");
    }

    #[test]
    fn malformed_msgpack_errors() {
        // A fixmap(1) header with no key/value following it.
        let err = parse(&Bytes::from_static(&[0x81]), PayloadFormat::MsgPack).unwrap_err();
        assert!(matches!(err, CodecError::MsgPack(_)), "got {err:?}");
        assert!(!err.to_string().is_empty());
    }

    #[test]
    fn msgpack_truncated_float_errors() {
        let mut buf = vec![fixmap_header(1)];
        buf.extend(fixstr("ratio"));
        buf.push(0xcb); // float 64...
        buf.extend_from_slice(&[0x00, 0x01, 0x02]); // ...but only 3 of its 8 bytes
        let err = parse(&Bytes::from(buf), PayloadFormat::MsgPack).unwrap_err();
        assert!(matches!(err, CodecError::MsgPack(_)), "got {err:?}");
    }

    fn assert_sample_fields_eq(a: &ParsedPayload, b: &ParsedPayload) {
        assert_eq!(a.field("_table"), b.field("_table"));
        assert_eq!(a.field("org_id"), b.field("org_id"));
        assert_eq!(a.field("live"), b.field("live"));
        assert_eq!(a.field("ratio"), b.field("ratio"));
        assert_eq!(a.field("missing"), b.field("missing"));
    }

    #[test]
    fn json_to_bytes_round_trips() {
        let original = parse(&sample_json(), PayloadFormat::Json).unwrap();
        assert!(original.is_json());
        let bytes = original.to_bytes().unwrap();
        assert!(!bytes.is_empty());
        let reparsed = parse(&bytes, PayloadFormat::Json).unwrap();
        assert!(reparsed.is_json(), "JSON must round-trip as JSON");
        assert_sample_fields_eq(&original, &reparsed);
    }

    #[test]
    fn msgpack_to_bytes_round_trips_via_native_bytes() {
        let original = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        assert!(original.is_msgpack());
        let bytes = original.to_bytes().unwrap();
        assert!(!bytes.is_empty());
        assert_eq!(bytes[0], fixmap_header(5), "expected fixmap(5) wire marker");
        let reparsed = parse(&bytes, PayloadFormat::MsgPack).unwrap();
        assert!(reparsed.is_msgpack(), "MsgPack must round-trip as MsgPack");
        assert_sample_fields_eq(&original, &reparsed);
    }

    #[test]
    fn to_json_bytes_reparses_to_same_value() {
        let ParsedPayload::Json(value) = parse(&sample_json(), PayloadFormat::Json).unwrap() else {
            panic!("expected JSON");
        };
        let bytes = to_json_bytes(&value).unwrap();
        let reparsed = parse(&bytes, PayloadFormat::Json).unwrap();
        assert_eq!(reparsed.field("_table"), Some(FieldRef::Str("events")));
        assert_eq!(reparsed.field("org_id"), Some(FieldRef::Int(42)));
        assert_eq!(reparsed.field("ratio"), Some(FieldRef::Float(1.5)));
    }

    #[test]
    fn to_msgpack_bytes_reparses_to_same_value() {
        let ParsedPayload::MsgPack(value) =
            parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap()
        else {
            panic!("expected MsgPack");
        };
        let bytes = to_msgpack_bytes(&value).unwrap();
        let reparsed = parse(&bytes, PayloadFormat::MsgPack).unwrap();
        assert_eq!(reparsed.field("_table"), Some(FieldRef::Str("events")));
        assert_eq!(reparsed.field("org_id"), Some(FieldRef::Int(42)));
        assert_eq!(reparsed.field("live"), Some(FieldRef::Bool(true)));
        assert_eq!(reparsed.field("ratio"), Some(FieldRef::Float(1.5)));
        assert_eq!(reparsed.field("missing"), Some(FieldRef::Null));
    }

    #[test]
    fn to_bytes_preserves_a_mutated_json_field() {
        let ParsedPayload::Json(mut value) = parse(&sample_json(), PayloadFormat::Json).unwrap()
        else {
            panic!("expected JSON");
        };
        value.insert("_table", sonic_rs::Value::from("audit"));
        let bytes = to_json_bytes(&value).unwrap();
        let reparsed = parse(&bytes, PayloadFormat::Json).unwrap();
        assert_eq!(reparsed.field_str("_table"), Some("audit"));
        assert_eq!(reparsed.field("org_id"), Some(FieldRef::Int(42)));
    }

    #[test]
    fn json_to_bytes_handles_top_level_array() {
        let parsed = parse(&Bytes::from_static(b"[1,2,3]"), PayloadFormat::Json).unwrap();
        let bytes = parsed.to_bytes().unwrap();
        let reparsed = parse(&bytes, PayloadFormat::Json).unwrap();
        assert!(reparsed.is_json());
        assert_eq!(reparsed.field_str("anything"), None);
    }

    #[test]
    fn msgpack_to_bytes_handles_top_level_scalar() {
        let parsed = parse(&Bytes::from(vec![42u8]), PayloadFormat::MsgPack).unwrap();
        let bytes = parsed.to_bytes().unwrap();
        assert_eq!(bytes.as_ref(), &[42u8], "fixint must re-emit byte-identical");
        let reparsed = parse(&bytes, PayloadFormat::MsgPack).unwrap();
        assert!(reparsed.is_msgpack());
    }

    #[test]
    fn double_round_trip_is_stable() {
        let first = parse(&sample_msgpack(), PayloadFormat::MsgPack).unwrap();
        let b1 = first.to_bytes().unwrap();
        let second = parse(&b1, PayloadFormat::MsgPack).unwrap();
        let b2 = second.to_bytes().unwrap();
        assert_eq!(b1, b2, "re-serialising a re-parsed value must be stable");
    }

    #[test]
    fn msgpack_rejects_trailing_bytes() {
        let mut buf = single_kv_msgpack();
        buf.push(0xc0); // a stray nil after the complete map
        let err = parse(&Bytes::from(buf), PayloadFormat::MsgPack)
            .expect_err("trailing byte must be rejected");
        match err {
            CodecError::TrailingBytes(n) => assert_eq!(n, 1, "expected 1 trailing byte"),
            other => panic!("expected TrailingBytes(1), got {other:?}"),
        }
    }

    #[test]
    fn msgpack_rejects_concatenated_values() {
        // Two positive fixints back to back: only the first one is consumed.
        let buf = vec![0x01u8, 0x02u8];
        let err = parse(&Bytes::from(buf), PayloadFormat::MsgPack)
            .expect_err("concatenated values must be rejected");
        match err {
            CodecError::TrailingBytes(n) => assert_eq!(n, 1, "expected 1 trailing byte"),
            other => panic!("expected TrailingBytes(1), got {other:?}"),
        }
    }

    #[test]
    fn msgpack_clean_single_value_still_parses_ok() {
        let parsed = parse(&Bytes::from(single_kv_msgpack()), PayloadFormat::MsgPack).unwrap();
        assert_eq!(parsed.field_str("k"), Some("v"));
    }

    #[test]
    fn json_rejects_trailing_garbage() {
        let mut buf = br#"{"_table":"events"}"#.to_vec();
        buf.extend_from_slice(b"garbage");
        let err = parse(&Bytes::from(buf), PayloadFormat::Json)
            .expect_err("trailing non-whitespace garbage must be rejected");
        assert!(
            matches!(err, CodecError::Json(_)),
            "expected CodecError::Json for trailing garbage, got {err:?}"
        );
    }

    #[test]
    fn json_accepts_trailing_whitespace() {
        let mut buf = br#"{"_table":"events"}"#.to_vec();
        buf.extend_from_slice(b" \t\r\n");
        let parsed = parse(&Bytes::from(buf), PayloadFormat::Json)
            .expect("trailing whitespace must be accepted");
        assert_eq!(parsed.field_str("_table"), Some("events"));
    }

    #[test]
    fn json_parsed_as_msgpack_errors() {
        // '{' (0x7b) decodes as the positive fixint 123, leaving the rest as trailing bytes.
        let err = parse(&sample_json(), PayloadFormat::MsgPack)
            .expect_err("JSON fed to MsgPack path must error after trailing-bytes hardening");
        assert!(
            matches!(err, CodecError::TrailingBytes(_)),
            "expected TrailingBytes, got {err:?}"
        );
    }
}