use serde::Serialize;
use serde::de::DeserializeOwned;
use std::fmt;
use std::io::{self, BufRead, Read};
use std::sync::OnceLock;
/// Upper bound, in bytes, on a single wire message (64 MiB).
///
/// `Codec::read_message` rejects JSON lines and MessagePack frames that are
/// larger than this value.
pub const MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
/// Maximum container nesting accepted when decoding MessagePack payloads.
///
/// Passed to the pre-decode depth scan in `Codec::decode` and also used as the
/// recursion cut-off while converting `rmpv` values to JSON values.
const MAX_RMPV_DEPTH: usize = 128;
/// Process-wide codec selection: written through `Codec::set_global` and read
/// through `Codec::get_global`.
static WIRE_CODEC: OnceLock<Codec> = OnceLock::new();
/// Wire formats supported for exchanging messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Codec {
/// Newline-delimited JSON: one JSON document per line.
Json,
/// MessagePack payloads, each preceded by a 4-byte big-endian length prefix.
MsgPack,
}
/// Renders the codec as its lowercase name: `json` or `msgpack`.
impl fmt::Display for Codec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            Codec::Json => "json",
            Codec::MsgPack => "msgpack",
        };
        f.write_str(label)
    }
}
impl Codec {
pub fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, String> {
match self {
Codec::Json => {
let mut bytes =
serde_json::to_vec(value).map_err(|e| format!("json encode: {e}"))?;
bytes.push(b'\n');
Ok(bytes)
}
Codec::MsgPack => {
let payload =
rmp_serde::to_vec_named(value).map_err(|e| format!("msgpack encode: {e}"))?;
let len = u32::try_from(payload.len()).map_err(|_| {
format!(
"payload exceeds 4 GiB frame limit ({} bytes)",
payload.len()
)
})?;
let mut bytes = Vec::with_capacity(4 + payload.len());
bytes.extend_from_slice(&len.to_be_bytes());
bytes.extend_from_slice(&payload);
Ok(bytes)
}
}
}
pub fn encode_binary_message(
&self,
mut map: serde_json::Map<String, serde_json::Value>,
binary_field: Option<(&str, &[u8])>,
) -> Result<Vec<u8>, String> {
match self {
Codec::Json => {
if let Some((key, bytes)) = binary_field
&& !bytes.is_empty()
{
use base64::Engine;
let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
map.insert(key.to_string(), serde_json::Value::String(b64));
}
let val = serde_json::Value::Object(map);
let mut bytes =
serde_json::to_vec(&val).map_err(|e| format!("json encode: {e}"))?;
bytes.push(b'\n');
Ok(bytes)
}
Codec::MsgPack => {
use rmpv::Value as V;
let mut entries: Vec<(V, V)> = map
.into_iter()
.map(|(k, v)| (V::String(k.into()), json_to_rmpv(v)))
.collect();
if let Some((key, bytes)) = binary_field
&& !bytes.is_empty()
{
entries.push((V::String(key.into()), V::Binary(bytes.to_vec())));
}
let msg = V::Map(entries);
let mut payload = Vec::new();
rmpv::encode::write_value(&mut payload, &msg)
.map_err(|e| format!("msgpack encode: {e}"))?;
let len = u32::try_from(payload.len()).map_err(|_| {
format!(
"payload exceeds 4 GiB frame limit ({} bytes)",
payload.len()
)
})?;
let mut bytes = Vec::with_capacity(4 + payload.len());
bytes.extend_from_slice(&len.to_be_bytes());
bytes.extend_from_slice(&payload);
Ok(bytes)
}
}
}
pub fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> Result<T, String> {
match self {
Codec::Json => serde_json::from_slice(bytes).map_err(|e| format!("json decode: {e}")),
Codec::MsgPack => {
check_msgpack_depth(bytes, MAX_RMPV_DEPTH)
.map_err(|e| format!("msgpack depth check: {e}"))?;
let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &bytes[..])
.map_err(|e| format!("msgpack decode (rmpv): {e}"))?;
let json_val = rmpv_to_json(rmpv_val);
serde_json::from_value(json_val.clone()).map_err(|e| {
let msg = format!("msgpack decode (tag dispatch): {e}");
#[cfg(debug_assertions)]
let msg = {
let dump = json_val.to_string();
let truncated = if dump.len() > 512 {
format!("{}...", &dump[..512])
} else {
dump
};
format!("{msg} | json: {truncated}")
};
msg
})
}
}
}
pub fn read_message<R: BufRead>(&self, reader: &mut R) -> io::Result<Option<Vec<u8>>> {
match self {
Codec::Json => loop {
let mut line = String::new();
let limit = (MAX_MESSAGE_SIZE + 1) as u64;
let n = (&mut *reader).take(limit).read_line(&mut line)?;
if n == 0 {
return Ok(None);
}
if line.len() > MAX_MESSAGE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"JSON message exceeds {} byte limit ({} bytes)",
MAX_MESSAGE_SIZE,
line.len()
),
));
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
return Ok(Some(trimmed.as_bytes().to_vec()));
},
Codec::MsgPack => {
let mut len_buf = [0u8; 4];
match reader.read_exact(&mut len_buf) {
Ok(()) => {}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(e),
}
let len = u32::from_be_bytes(len_buf) as usize;
if len == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"empty frame received",
));
}
if len > MAX_MESSAGE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"msgpack frame exceeds {} byte limit ({} bytes)",
MAX_MESSAGE_SIZE, len
),
));
}
let mut payload = vec![0u8; len];
reader.read_exact(&mut payload)?;
Ok(Some(payload))
}
}
}
pub fn detect_from_first_byte(byte: u8) -> Codec {
if byte == b'{' {
Codec::Json
} else {
Codec::MsgPack
}
}
pub fn set_global(codec: Codec) {
WIRE_CODEC
.set(codec)
.expect("WIRE_CODEC already initialized");
}
pub fn get_global() -> &'static Codec {
WIRE_CODEC.get().unwrap_or(&Codec::MsgPack)
}
}
/// Scans a raw MessagePack payload and rejects suspicious container structure
/// without building any values.
///
/// Two conditions produce an error:
/// * a map or array declares more child values than there are bytes left in
///   the payload (every value occupies at least one byte), or
/// * opening a container would make the nesting deeper than `max_depth`.
///
/// If the payload ends in the middle of a length header the scan simply stops
/// and returns `Ok(())`.
fn check_msgpack_depth(bytes: &[u8], max_depth: usize) -> Result<(), String> {
    /// Reads a big-endian unsigned integer of `width` bytes starting at `at`.
    /// Returns `None` when the slice ends before the field does.
    fn be_len(buf: &[u8], at: usize, width: usize) -> Option<usize> {
        let field = buf.get(at..at + width)?;
        Some(
            field
                .iter()
                .fold(0usize, |acc, &byte| (acc << 8) | usize::from(byte)),
        )
    }

    let len = bytes.len();
    let mut pos = 0usize;
    // One entry per open container: the number of child values still expected.
    // The stack height is the current nesting depth.
    let mut open: Vec<usize> = Vec::new();

    while pos < len {
        let marker = bytes[pos];
        pos += 1;

        // For each marker: (bytes to skip after the marker, child values that follow).
        // `None` means a length header was truncated.
        let layout = match marker {
            // positive fixint, nil / (unused) / false / true, negative fixint
            0x00..=0x7f | 0xc0..=0xc3 | 0xe0..=0xff => Some((0, 0)),
            // fixmap: each entry is a key plus a value
            0x80..=0x8f => Some((0, usize::from(marker & 0x0f) * 2)),
            // fixarray
            0x90..=0x9f => Some((0, usize::from(marker & 0x0f))),
            // fixstr
            0xa0..=0xbf => Some((usize::from(marker & 0x1f), 0)),
            // bin 8 / str 8, bin 16 / str 16, bin 32 / str 32
            0xc4 | 0xd9 => be_len(bytes, pos, 1).map(|n| (1 + n, 0)),
            0xc5 | 0xda => be_len(bytes, pos, 2).map(|n| (2 + n, 0)),
            0xc6 | 0xdb => be_len(bytes, pos, 4).map(|n| (4 + n, 0)),
            // ext 8 / 16 / 32: length field, one type byte, then the data
            0xc7 => be_len(bytes, pos, 1).map(|n| (2 + n, 0)),
            0xc8 => be_len(bytes, pos, 2).map(|n| (3 + n, 0)),
            0xc9 => be_len(bytes, pos, 4).map(|n| (5 + n, 0)),
            // fixed-width numbers
            0xcc | 0xd0 => Some((1, 0)),
            0xcd | 0xd1 => Some((2, 0)),
            0xca | 0xce | 0xd2 => Some((4, 0)),
            0xcb | 0xcf | 0xd3 => Some((8, 0)),
            // fixext 1 / 2 / 4 / 8 / 16: one type byte plus the data
            0xd4 => Some((2, 0)),
            0xd5 => Some((3, 0)),
            0xd6 => Some((5, 0)),
            0xd7 => Some((9, 0)),
            0xd8 => Some((17, 0)),
            // array 16 / 32: skip the count field, children follow
            0xdc => be_len(bytes, pos, 2).map(|n| (2, n)),
            0xdd => be_len(bytes, pos, 4).map(|n| (4, n)),
            // map 16 / 32: skip the count field, two values per entry
            0xde => be_len(bytes, pos, 2).map(|n| (2, n * 2)),
            0xdf => be_len(bytes, pos, 4).map(|n| (4, n * 2)),
        };
        let Some((skip, children)) = layout else {
            break;
        };
        pos += skip;

        if children == 0 {
            // A complete value: it fills one slot of its parent, and every
            // container that becomes full as a result fills a slot of its own parent.
            while let Some(left) = open.last_mut() {
                *left -= 1;
                if *left > 0 {
                    break;
                }
                open.pop();
            }
            continue;
        }

        let remaining_bytes = len.saturating_sub(pos);
        if children > remaining_bytes {
            return Err(format!(
                "msgpack container declares {children} elements but only {remaining_bytes} bytes remain"
            ));
        }
        // Opening this container would raise the depth to `open.len() + 1`.
        if open.len() >= max_depth {
            return Err(format!("msgpack nesting depth exceeds limit ({max_depth})"));
        }
        open.push(children);
    }
    Ok(())
}
/// Converts a decoded `rmpv::Value` into a `serde_json::Value`.
///
/// Entry point for the recursive conversion; starts the depth counter at 0.
fn rmpv_to_json(val: rmpv::Value) -> serde_json::Value {
rmpv_to_json_inner(val, 0)
}
fn rmpv_to_json_inner(val: rmpv::Value, depth: usize) -> serde_json::Value {
if depth > MAX_RMPV_DEPTH {
log::error!("rmpv_to_json: recursion depth exceeded {MAX_RMPV_DEPTH}, replaced with null");
return serde_json::Value::Null;
}
match val {
rmpv::Value::Nil => serde_json::Value::Null,
rmpv::Value::Boolean(b) => serde_json::Value::Bool(b),
rmpv::Value::Integer(n) => {
if let Some(i) = n.as_i64() {
serde_json::Value::Number(i.into())
} else if let Some(u) = n.as_u64() {
serde_json::Value::Number(u.into())
} else {
serde_json::Value::Null
}
}
rmpv::Value::F32(f) => serde_json::Number::from_f64(f as f64)
.map(serde_json::Value::Number)
.unwrap_or_else(|| {
log::warn!("rmpv_to_json: non-finite f32 ({f}) replaced with 0.0");
serde_json::Value::Number(serde_json::Number::from_f64(0.0).unwrap())
}),
rmpv::Value::F64(f) => serde_json::Number::from_f64(f)
.map(serde_json::Value::Number)
.unwrap_or_else(|| {
log::warn!("rmpv_to_json: non-finite f64 ({f}) replaced with 0.0");
serde_json::Value::Number(serde_json::Number::from_f64(0.0).unwrap())
}),
rmpv::Value::String(s) => {
serde_json::Value::String(String::from_utf8_lossy(s.as_bytes()).into_owned())
}
rmpv::Value::Binary(bytes) => {
serde_json::Value::Array(
bytes
.into_iter()
.map(|b| serde_json::Value::Number(b.into()))
.collect(),
)
}
rmpv::Value::Array(arr) => serde_json::Value::Array(
arr.into_iter()
.map(|v| rmpv_to_json_inner(v, depth + 1))
.collect(),
),
rmpv::Value::Map(entries) => {
let mut map = serde_json::Map::new();
for (k, v) in entries {
let key = match k {
rmpv::Value::String(s) => s.into_str().unwrap_or_default().to_string(),
rmpv::Value::Integer(n) => n.to_string(),
other => format!("{other}"),
};
map.insert(key, rmpv_to_json_inner(v, depth + 1));
}
serde_json::Value::Object(map)
}
rmpv::Value::Ext(type_id, _bytes) => {
log::warn!(
"rmpv_to_json: msgpack ext type {type_id} not supported, replaced with null"
);
serde_json::Value::Null
}
}
}
fn json_to_rmpv(val: serde_json::Value) -> rmpv::Value {
match val {
serde_json::Value::Null => rmpv::Value::Nil,
serde_json::Value::Bool(b) => rmpv::Value::Boolean(b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
rmpv::Value::Integer(i.into())
} else if let Some(u) = n.as_u64() {
rmpv::Value::Integer(u.into())
} else if let Some(f) = n.as_f64() {
rmpv::Value::F64(f)
} else {
rmpv::Value::Nil
}
}
serde_json::Value::String(s) => rmpv::Value::String(s.into()),
serde_json::Value::Array(arr) => {
rmpv::Value::Array(arr.into_iter().map(json_to_rmpv).collect())
}
serde_json::Value::Object(map) => rmpv::Value::Map(
map.into_iter()
.map(|(k, v)| (rmpv::Value::String(k.into()), json_to_rmpv(v)))
.collect(),
),
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the codec: encode/decode round-trips, stream framing,
    //! the MessagePack depth scan, and the rmpv <-> JSON value conversions.
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Plain struct fixture.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Simple {
        name: String,
        count: u32,
    }

    /// Internally tagged enum fixture (tag field `type`).
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    #[serde(tag = "type", rename_all = "snake_case")]
    enum Tagged {
        Alpha { value: String },
        Beta { x: f64, y: f64 },
    }

    /// Fixture with a flattened catch-all field.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct WithFlatten {
        op: String,
        #[serde(flatten)]
        rest: serde_json::Value,
    }

    /// JSON encode appends a newline and decodes back to the same struct.
    #[test]
    fn json_roundtrip_simple() {
        let original = Simple {
            name: "test".into(),
            count: 42,
        };
        let bytes = Codec::Json.encode(&original).unwrap();
        assert!(bytes.ends_with(b"\n"));
        let decoded: Simple = Codec::Json.decode(&bytes[..bytes.len() - 1]).unwrap();
        assert_eq!(decoded, original);
    }

    /// JSON round-trip of an internally tagged enum.
    #[test]
    fn json_roundtrip_tagged_enum() {
        let original = Tagged::Beta { x: 1.5, y: 2.5 };
        let bytes = Codec::Json.encode(&original).unwrap();
        let decoded: Tagged = Codec::Json.decode(&bytes[..bytes.len() - 1]).unwrap();
        assert_eq!(decoded, original);
    }

    /// MessagePack encode writes a correct length prefix and round-trips.
    #[test]
    fn msgpack_roundtrip_simple() {
        let original = Simple {
            name: "test".into(),
            count: 42,
        };
        let bytes = Codec::MsgPack.encode(&original).unwrap();
        let len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
        assert_eq!(len, bytes.len() - 4);
        let decoded: Simple = Codec::MsgPack.decode(&bytes[4..]).unwrap();
        assert_eq!(decoded, original);
    }

    /// MessagePack round-trip of a tagged enum with a string field.
    #[test]
    fn msgpack_roundtrip_tagged_enum() {
        let original = Tagged::Alpha {
            value: "hello".into(),
        };
        let bytes = Codec::MsgPack.encode(&original).unwrap();
        let payload = &bytes[4..];
        let decoded: Tagged = Codec::MsgPack.decode(payload).unwrap();
        assert_eq!(decoded, original);
    }

    /// MessagePack round-trip of a tagged enum with float fields.
    #[test]
    fn msgpack_roundtrip_tagged_enum_beta() {
        let original = Tagged::Beta {
            x: std::f64::consts::PI,
            y: -1.0,
        };
        let bytes = Codec::MsgPack.encode(&original).unwrap();
        let payload = &bytes[4..];
        let decoded: Tagged = Codec::MsgPack.decode(payload).unwrap();
        assert_eq!(decoded, original);
    }

    /// `rmp_serde` itself can deserialize a struct with a flattened field.
    #[test]
    fn msgpack_flatten_deserialize() {
        let input = json!({"op": "props", "path": [0, 1], "props": {"label": "hi"}});
        let bytes = rmp_serde::to_vec_named(&input).unwrap();
        let decoded: WithFlatten = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(decoded.op, "props");
        assert_eq!(decoded.rest["path"], json!([0, 1]));
        assert_eq!(decoded.rest["props"]["label"], "hi");
    }

    /// Blank lines between JSON messages are skipped.
    #[test]
    fn json_read_message_skips_blank_lines() {
        let data = b"\n\n{\"name\":\"a\",\"count\":1}\n\n{\"name\":\"b\",\"count\":2}\n\n";
        let mut reader = io::BufReader::new(&data[..]);
        let msg1 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s1: Simple = Codec::Json.decode(&msg1).unwrap();
        assert_eq!(s1.name, "a");
        let msg2 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s2: Simple = Codec::Json.decode(&msg2).unwrap();
        assert_eq!(s2.name, "b");
        assert!(Codec::Json.read_message(&mut reader).unwrap().is_none());
    }

    /// Consecutive JSON lines are returned one at a time, then `None` at EOF.
    #[test]
    fn json_read_message() {
        let data = b"{\"name\":\"a\",\"count\":1}\n{\"name\":\"b\",\"count\":2}\n";
        let mut reader = io::BufReader::new(&data[..]);
        let msg1 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s1: Simple = Codec::Json.decode(&msg1).unwrap();
        assert_eq!(s1.name, "a");
        let msg2 = Codec::Json.read_message(&mut reader).unwrap().unwrap();
        let s2: Simple = Codec::Json.decode(&msg2).unwrap();
        assert_eq!(s2.name, "b");
        assert!(Codec::Json.read_message(&mut reader).unwrap().is_none());
    }

    /// Length-prefixed MessagePack frames are returned one at a time, then
    /// `None` at EOF.
    #[test]
    fn msgpack_read_message() {
        let s1 = Simple {
            name: "x".into(),
            count: 10,
        };
        let s2 = Simple {
            name: "y".into(),
            count: 20,
        };
        let p1 = rmp_serde::to_vec_named(&s1).unwrap();
        let p2 = rmp_serde::to_vec_named(&s2).unwrap();
        let mut data = Vec::new();
        data.extend_from_slice(&(p1.len() as u32).to_be_bytes());
        data.extend_from_slice(&p1);
        data.extend_from_slice(&(p2.len() as u32).to_be_bytes());
        data.extend_from_slice(&p2);
        let mut reader = io::BufReader::new(&data[..]);
        let msg1 = Codec::MsgPack.read_message(&mut reader).unwrap().unwrap();
        let d1: Simple = Codec::MsgPack.decode(&msg1).unwrap();
        assert_eq!(d1, s1);
        let msg2 = Codec::MsgPack.read_message(&mut reader).unwrap().unwrap();
        let d2: Simple = Codec::MsgPack.decode(&msg2).unwrap();
        assert_eq!(d2, s2);
        assert!(Codec::MsgPack.read_message(&mut reader).unwrap().is_none());
    }

    /// A JSON line longer than `MAX_MESSAGE_SIZE` is rejected by `read_message`.
    ///
    /// The input is generated lazily with `io::repeat`, so only the reader's
    /// own line buffer grows to the limit.
    #[test]
    fn json_read_message_rejects_oversized_line() {
        let oversized = io::repeat(b'x').take((MAX_MESSAGE_SIZE + 10) as u64);
        let mut reader = io::BufReader::new(oversized);
        let err = Codec::Json.read_message(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("byte limit"));
    }

    /// A frame whose declared length exceeds the limit is rejected before the
    /// payload is read.
    #[test]
    fn msgpack_read_message_rejects_oversized_frame() {
        let len = (MAX_MESSAGE_SIZE + 1) as u32;
        let mut data = Vec::new();
        data.extend_from_slice(&len.to_be_bytes());
        data.extend_from_slice(&[0u8; 64]);
        let mut reader = io::BufReader::new(&data[..]);
        let result = Codec::MsgPack.read_message(&mut reader);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("byte limit"));
    }

    /// A zero-length frame is an error.
    #[test]
    fn msgpack_read_message_rejects_zero_length_frame() {
        let mut data = Vec::new();
        data.extend_from_slice(&0u32.to_be_bytes());
        let mut reader = io::BufReader::new(&data[..]);
        let result = Codec::MsgPack.read_message(&mut reader);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("empty frame"));
    }

    /// A tagged enum decodes from a MessagePack map built outside the codec.
    #[test]
    fn msgpack_external_tagged_enum_alpha() {
        let external = json!({"type": "alpha", "value": "hello"});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: Tagged = Codec::MsgPack.decode(&bytes).unwrap();
        assert_eq!(
            decoded,
            Tagged::Alpha {
                value: "hello".into()
            }
        );
    }

    /// Same as above for the float-carrying variant.
    #[test]
    fn msgpack_external_tagged_enum_beta() {
        let external = json!({"type": "beta", "x": 1.5, "y": -2.0});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: Tagged = Codec::MsgPack.decode(&bytes).unwrap();
        assert_eq!(decoded, Tagged::Beta { x: 1.5, y: -2.0 });
    }

    /// The protocol's `settings` message decodes from MessagePack.
    #[test]
    fn msgpack_external_incoming_settings() {
        use crate::protocol::IncomingMessage;
        let external = json!({"type": "settings", "settings": {"antialiasing": false}});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: IncomingMessage = Codec::MsgPack.decode(&bytes).unwrap();
        assert!(matches!(decoded, IncomingMessage::Settings { .. }));
    }

    /// The protocol's `snapshot` message decodes from MessagePack.
    #[test]
    fn msgpack_external_incoming_snapshot() {
        use crate::protocol::IncomingMessage;
        let external = json!({"type": "snapshot", "tree": {"id": "root", "type": "column", "props": {}, "children": []}});
        let bytes = rmp_serde::to_vec_named(&external).unwrap();
        let decoded: IncomingMessage = Codec::MsgPack.decode(&bytes).unwrap();
        assert!(matches!(decoded, IncomingMessage::Snapshot { .. }));
    }

    /// A native MessagePack binary field reaches `ImageOp::pixels` intact.
    #[test]
    fn msgpack_image_op_with_native_binary() {
        use rmpv::Value as RmpvValue;
        let pixel_bytes: Vec<u8> = vec![255, 0, 0, 255, 0, 255, 0, 255];
        let msg = RmpvValue::Map(vec![
            (
                RmpvValue::String("type".into()),
                RmpvValue::String("image_op".into()),
            ),
            (
                RmpvValue::String("op".into()),
                RmpvValue::String("create_image".into()),
            ),
            (
                RmpvValue::String("handle".into()),
                RmpvValue::String("test_img".into()),
            ),
            (
                RmpvValue::String("pixels".into()),
                RmpvValue::Binary(pixel_bytes.clone()),
            ),
            (
                RmpvValue::String("width".into()),
                RmpvValue::Integer(1.into()),
            ),
            (
                RmpvValue::String("height".into()),
                RmpvValue::Integer(2.into()),
            ),
        ]);
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &msg).unwrap();
        let decoded: crate::protocol::IncomingMessage = Codec::MsgPack.decode(&buf).unwrap();
        match decoded {
            crate::protocol::IncomingMessage::ImageOp {
                op,
                handle,
                pixels,
                width,
                height,
                data,
            } => {
                assert_eq!(op, "create_image");
                assert_eq!(handle, "test_img");
                assert_eq!(pixels, Some(pixel_bytes));
                assert_eq!(width, Some(1));
                assert_eq!(height, Some(2));
                assert!(data.is_none());
            }
            other => panic!("expected ImageOp, got {other:?}"),
        }
    }

    /// A base64 string in a JSON message decodes to the same pixel bytes.
    #[test]
    fn msgpack_image_op_with_base64_string() {
        use crate::protocol::IncomingMessage;
        use base64::Engine as _;
        let pixel_bytes: Vec<u8> = vec![255, 0, 0, 255];
        let b64 = base64::engine::general_purpose::STANDARD.encode(&pixel_bytes);
        let json_msg = json!({
            "type": "image_op",
            "op": "create_image",
            "handle": "test_img",
            "pixels": b64,
            "width": 1,
            "height": 1
        });
        let json_str = serde_json::to_string(&json_msg).unwrap();
        let decoded: IncomingMessage = Codec::Json.decode(json_str.as_bytes()).unwrap();
        match decoded {
            IncomingMessage::ImageOp { pixels, .. } => {
                assert_eq!(pixels, Some(pixel_bytes));
            }
            other => panic!("expected ImageOp, got {other:?}"),
        }
    }

    /// Binary values become JSON arrays of byte values.
    #[test]
    fn rmpv_to_json_preserves_binary_as_array() {
        let binary = rmpv::Value::Binary(vec![1, 2, 3]);
        let result = rmpv_to_json(binary);
        assert_eq!(result, json!([1, 2, 3]));
    }

    /// A map with string and integer values converts to a JSON object.
    #[test]
    fn rmpv_to_json_handles_nested_map() {
        let val = rmpv::Value::Map(vec![
            (
                rmpv::Value::String("key".into()),
                rmpv::Value::String("val".into()),
            ),
            (
                rmpv::Value::String("num".into()),
                rmpv::Value::Integer(42.into()),
            ),
        ]);
        let result = rmpv_to_json(val);
        assert_eq!(result, json!({"key": "val", "num": 42}));
    }

    /// `{` selects JSON.
    #[test]
    fn detect_json_from_brace() {
        assert_eq!(Codec::detect_from_first_byte(b'{'), Codec::Json);
    }

    /// A zero byte selects MessagePack.
    #[test]
    fn detect_msgpack_from_zero() {
        assert_eq!(Codec::detect_from_first_byte(0x00), Codec::MsgPack);
    }

    /// A fixmap marker selects MessagePack.
    #[test]
    fn detect_msgpack_from_fixmap() {
        assert_eq!(Codec::detect_from_first_byte(0x85), Codec::MsgPack);
    }

    /// `Display` yields the lowercase codec names.
    #[test]
    fn display_format() {
        assert_eq!(Codec::Json.to_string(), "json");
        assert_eq!(Codec::MsgPack.to_string(), "msgpack");
    }

    /// Three levels of nested maps convert structurally.
    #[test]
    fn rmpv_to_json_deeply_nested_maps() {
        let val = rmpv::Value::Map(vec![(
            rmpv::Value::String("outer".into()),
            rmpv::Value::Map(vec![(
                rmpv::Value::String("inner".into()),
                rmpv::Value::Map(vec![(
                    rmpv::Value::String("deep".into()),
                    rmpv::Value::Integer(42.into()),
                )]),
            )]),
        )]);
        let result = rmpv_to_json(val);
        assert_eq!(result, json!({"outer": {"inner": {"deep": 42}}}));
    }

    /// Binary values nested in a map are converted as well.
    #[test]
    fn rmpv_to_json_binary_in_nested_map() {
        let val = rmpv::Value::Map(vec![
            (
                rmpv::Value::String("name".into()),
                rmpv::Value::String("img".into()),
            ),
            (
                rmpv::Value::String("pixels".into()),
                rmpv::Value::Binary(vec![255, 128, 0, 255]),
            ),
        ]);
        let result = rmpv_to_json(val);
        assert_eq!(result["name"], json!("img"));
        assert_eq!(result["pixels"], json!([255, 128, 0, 255]));
    }

    /// Binary bytes survive write -> read -> JSON conversion unchanged.
    #[test]
    fn msgpack_roundtrip_with_binary_field() {
        use rmpv::Value as RmpvValue;
        let raw_bytes: Vec<u8> = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let msg = RmpvValue::Map(vec![
            (
                RmpvValue::String("type".into()),
                RmpvValue::String("alpha".into()),
            ),
            (
                RmpvValue::String("value".into()),
                RmpvValue::String("hello".into()),
            ),
            (
                RmpvValue::String("payload".into()),
                RmpvValue::Binary(raw_bytes.clone()),
            ),
        ]);
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &msg).unwrap();
        let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &buf[..]).unwrap();
        let json_val = rmpv_to_json(rmpv_val);
        assert_eq!(json_val["type"], "alpha");
        assert_eq!(json_val["value"], "hello");
        let payload = json_val["payload"].as_array().unwrap();
        let bytes: Vec<u8> = payload.iter().map(|v| v.as_u64().unwrap() as u8).collect();
        assert_eq!(bytes, raw_bytes);
    }

    /// Nil and booleans map to their JSON counterparts.
    #[test]
    fn rmpv_to_json_handles_nil_and_bool() {
        assert_eq!(rmpv_to_json(rmpv::Value::Nil), json!(null));
        assert_eq!(rmpv_to_json(rmpv::Value::Boolean(true)), json!(true));
        assert_eq!(rmpv_to_json(rmpv::Value::Boolean(false)), json!(false));
    }

    /// A flat map passes the depth scan.
    #[test]
    fn msgpack_depth_check_accepts_flat_map() {
        let val = json!({"a": 1, "b": "hello", "c": true});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 128).is_ok());
    }

    /// Nesting exactly at the limit is accepted.
    #[test]
    fn msgpack_depth_check_accepts_nested_within_limit() {
        let val = json!({"outer": {"middle": {"inner": 42}}});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 3).is_ok());
    }

    /// Nesting one level beyond the limit is rejected.
    #[test]
    fn msgpack_depth_check_rejects_beyond_limit() {
        let val = json!({"a": {"b": {"c": 1}}});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 2).is_err());
    }

    /// A flat array has depth 1.
    #[test]
    fn msgpack_depth_check_accepts_flat_array() {
        let val = json!([1, 2, 3, 4, 5]);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 1).is_ok());
    }

    /// Nested arrays count one level each.
    #[test]
    fn msgpack_depth_check_nested_arrays() {
        let val = json!([[[42]]]);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 3).is_ok());
        assert!(check_msgpack_depth(&bytes, 2).is_err());
    }

    /// Maps and arrays contribute to the same depth count.
    #[test]
    fn msgpack_depth_check_mixed_containers() {
        let val = json!({"list": [{"nested": true}]});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 3).is_ok());
        assert!(check_msgpack_depth(&bytes, 2).is_err());
    }

    /// Empty containers are accepted within the limit.
    #[test]
    fn msgpack_depth_check_empty_containers() {
        let val = json!({"empty_map": {}, "empty_arr": []});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 2).is_ok());
    }

    /// Sibling containers do not accumulate depth.
    #[test]
    fn msgpack_depth_check_sibling_arrays_dont_add_depth() {
        let val = json!([[1, 2], [3, 4]]);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 2).is_ok());
    }

    /// Bytes inside a binary value are skipped, not parsed as markers.
    #[test]
    fn msgpack_depth_check_binary_data() {
        use rmpv::Value as V;
        let val = V::Map(vec![(
            V::String("data".into()),
            V::Binary(vec![0xDE, 0xAD]),
        )]);
        let mut bytes = Vec::new();
        rmpv::encode::write_value(&mut bytes, &val).unwrap();
        assert!(check_msgpack_depth(&bytes, 1).is_ok());
    }

    /// 200 nested maps fail a limit of 128 and pass a limit of 200.
    #[test]
    fn msgpack_depth_check_deeply_nested_rejects() {
        use rmpv::Value as V;
        let depth = 200;
        let mut val = V::Integer(1.into());
        for _ in 0..depth {
            val = V::Map(vec![(V::String("a".into()), val)]);
        }
        let mut bytes = Vec::new();
        rmpv::encode::write_value(&mut bytes, &val).unwrap();
        assert!(check_msgpack_depth(&bytes, 128).is_err());
        assert!(check_msgpack_depth(&bytes, 200).is_ok());
    }

    /// `decode` surfaces the depth-scan failure.
    #[test]
    fn msgpack_decode_rejects_deeply_nested() {
        use rmpv::Value as V;
        let mut val = V::Integer(1.into());
        for _ in 0..200 {
            val = V::Map(vec![(V::String("a".into()), val)]);
        }
        let mut bytes = Vec::new();
        rmpv::encode::write_value(&mut bytes, &val).unwrap();
        let result: Result<serde_json::Value, _> = Codec::MsgPack.decode(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("depth"));
    }

    /// Truncated input never panics; containers with no room for their children
    /// are rejected, truncated length headers are tolerated.
    #[test]
    fn msgpack_depth_check_truncated_payload_does_not_panic() {
        let val = json!({"a": {"b": [1, 2, 3]}});
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        for cut in [1, 3, 5, bytes.len() / 2] {
            let _ = check_msgpack_depth(&bytes[..cut], 128);
        }
        assert!(check_msgpack_depth(&[0x81], 128).is_err());
        assert!(check_msgpack_depth(&[0x91], 128).is_err());
        assert!(check_msgpack_depth(&[0xdc], 128).is_ok());
        assert!(check_msgpack_depth(&[0xde, 0x00], 128).is_ok());
    }

    /// Empty input is accepted.
    #[test]
    fn msgpack_depth_check_empty_input() {
        assert!(check_msgpack_depth(&[], 128).is_ok());
    }

    /// A lone scalar is accepted even with a depth limit of zero.
    #[test]
    fn msgpack_depth_check_scalars_only() {
        let val = json!(42);
        let bytes = rmp_serde::to_vec_named(&val).unwrap();
        assert!(check_msgpack_depth(&bytes, 0).is_ok());
    }

    /// A map32 header claiming far more entries than bytes remain is rejected.
    #[test]
    fn msgpack_depth_check_rejects_forged_element_count() {
        let mut bytes = vec![0xdf];
        bytes.extend_from_slice(&0xFFFF_FFFFu32.to_be_bytes());
        bytes.extend_from_slice(&[0xa1, b'k', 0x01]);
        let result = check_msgpack_depth(&bytes, 128);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("elements"));
    }

    /// `decode` rejects a forged array32 element count.
    #[test]
    fn msgpack_decode_rejects_forged_element_count() {
        let mut bytes = vec![0xdd];
        bytes.extend_from_slice(&0x7FFF_FFFFu32.to_be_bytes());
        bytes.push(0x01);
        let result: Result<serde_json::Value, _> = Codec::MsgPack.decode(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("elements"));
    }

    /// Scalar JSON values map to the matching rmpv variants.
    #[test]
    fn json_to_rmpv_scalars() {
        assert_eq!(json_to_rmpv(json!(null)), rmpv::Value::Nil);
        assert_eq!(json_to_rmpv(json!(true)), rmpv::Value::Boolean(true));
        assert_eq!(json_to_rmpv(json!(42)), rmpv::Value::Integer(42.into()));
        assert_eq!(json_to_rmpv(json!(2.5)), rmpv::Value::F64(2.5));
        assert_eq!(
            json_to_rmpv(json!("hello")),
            rmpv::Value::String("hello".into())
        );
    }

    /// Nested JSON containers convert recursively.
    #[test]
    fn json_to_rmpv_nested() {
        let val = json!({"key": [1, "two", null]});
        let rmpv = json_to_rmpv(val);
        match rmpv {
            rmpv::Value::Map(entries) => {
                assert_eq!(entries.len(), 1);
                let (k, v) = &entries[0];
                assert_eq!(k, &rmpv::Value::String("key".into()));
                match v {
                    rmpv::Value::Array(arr) => {
                        assert_eq!(arr.len(), 3);
                        assert_eq!(arr[0], rmpv::Value::Integer(1.into()));
                        assert_eq!(arr[2], rmpv::Value::Nil);
                    }
                    other => panic!("expected array, got {other:?}"),
                }
            }
            other => panic!("expected map, got {other:?}"),
        }
    }

    /// JSON output without a binary field contains only the map entries.
    #[test]
    fn encode_binary_message_json_without_binary() {
        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("test"));
        map.insert("id".to_string(), json!("t1"));
        let bytes = Codec::Json.encode_binary_message(map, None).unwrap();
        let s = std::str::from_utf8(&bytes).unwrap();
        assert!(s.ends_with('\n'));
        let parsed: serde_json::Value = serde_json::from_str(s.trim()).unwrap();
        assert_eq!(parsed["type"], "test");
        assert_eq!(parsed["id"], "t1");
        assert!(parsed.get("rgba").is_none());
    }

    /// JSON output carries the binary field as base64.
    #[test]
    fn encode_binary_message_json_with_binary() {
        use base64::Engine as _;
        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("screenshot"));
        let pixel_data = vec![255u8, 0, 128, 64];
        let bytes = Codec::Json
            .encode_binary_message(map, Some(("rgba", &pixel_data)))
            .unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&bytes[..bytes.len() - 1]).unwrap();
        let b64 = parsed["rgba"].as_str().unwrap();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(b64)
            .unwrap();
        assert_eq!(decoded, pixel_data);
    }

    /// MessagePack output carries the binary field as a native binary value.
    #[test]
    fn encode_binary_message_msgpack_with_binary() {
        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("screenshot"));
        map.insert("id".to_string(), json!("s1"));
        let pixel_data = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let bytes = Codec::MsgPack
            .encode_binary_message(map, Some(("rgba", &pixel_data)))
            .unwrap();
        let payload = &bytes[4..];
        let rmpv_val: rmpv::Value = rmpv::decode::read_value(&mut &payload[..]).unwrap();
        match rmpv_val {
            rmpv::Value::Map(entries) => {
                let rgba_entry = entries
                    .iter()
                    .find(|(k, _)| k == &rmpv::Value::String("rgba".into()));
                match rgba_entry {
                    Some((_, rmpv::Value::Binary(data))) => {
                        assert_eq!(data, &pixel_data);
                    }
                    other => panic!("expected Binary rgba field, got {other:?}"),
                }
            }
            other => panic!("expected Map, got {other:?}"),
        }
    }

    /// Non-binary fields survive a MessagePack encode/decode round-trip.
    #[test]
    fn encode_binary_message_msgpack_roundtrip_non_binary_fields() {
        let mut map = serde_json::Map::new();
        map.insert("type".to_string(), json!("test"));
        map.insert("count".to_string(), json!(42));
        map.insert("nested".to_string(), json!({"a": [1, 2]}));
        let bytes = Codec::MsgPack.encode_binary_message(map, None).unwrap();
        let decoded: serde_json::Value = Codec::MsgPack.decode(&bytes[4..]).unwrap();
        assert_eq!(decoded["type"], "test");
        assert_eq!(decoded["count"], 42);
        assert_eq!(decoded["nested"]["a"][0], 1);
    }

    /// Property-based round-trip tests.
    mod proptest_codec {
        use super::*;
        use proptest::prelude::*;

        /// Strategy producing JSON values built from null, bool, i64 and short
        /// ASCII string leaves, recursively combined into arrays and objects.
        fn arb_json_value() -> impl Strategy<Value = serde_json::Value> {
            let leaf = prop_oneof![
                Just(serde_json::Value::Null),
                any::<bool>().prop_map(serde_json::Value::Bool),
                any::<i64>().prop_map(|n| serde_json::Value::Number(n.into())),
                "[a-zA-Z0-9_ ]{0,20}".prop_map(serde_json::Value::String),
            ];
            leaf.prop_recursive(3, 32, 8, |inner| {
                prop_oneof![
                    prop::collection::vec(inner.clone(), 0..5)
                        .prop_map(serde_json::Value::Array),
                    prop::collection::vec(("[a-z_]{1,8}", inner), 0..5).prop_map(|pairs| {
                        serde_json::Value::Object(pairs.into_iter().collect())
                    }),
                ]
            })
        }

        proptest! {
            /// Any generated JSON value survives encode followed by decode.
            #[test]
            fn json_encode_decode_roundtrip(val in arb_json_value()) {
                let bytes = Codec::Json.encode(&val).unwrap();
                let decoded: serde_json::Value =
                    Codec::Json.decode(&bytes[..bytes.len() - 1]).unwrap();
                prop_assert_eq!(decoded, val);
            }
        }
    }
}