use flate2::Compression;
use flate2::write::GzEncoder;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[cfg(feature = "messagepack")]
use msgpacker::Packable;
use crate::types::{ContentType, ProtocolVersion, Timestamp};
use crate::{McpError as Error, Result};
#[cfg(feature = "messagepack")]
#[derive(Debug, Clone)]
pub enum JsonValue {
Null,
Bool(bool),
Number(f64),
String(String),
Array(Vec<JsonValue>),
Object(std::collections::HashMap<String, JsonValue>),
}
#[cfg(feature = "messagepack")]
impl JsonValue {
pub fn from_serde_json(value: &serde_json::Value) -> Self {
match value {
serde_json::Value::Null => JsonValue::Null,
serde_json::Value::Bool(b) => JsonValue::Bool(*b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
JsonValue::Number(i as f64)
} else if let Some(u) = n.as_u64() {
JsonValue::Number(u as f64)
} else if let Some(f) = n.as_f64() {
JsonValue::Number(f)
} else {
JsonValue::Null
}
}
serde_json::Value::String(s) => JsonValue::String(s.clone()),
serde_json::Value::Array(arr) => {
JsonValue::Array(arr.iter().map(Self::from_serde_json).collect())
}
serde_json::Value::Object(obj) => {
let mut map = std::collections::HashMap::new();
for (k, v) in obj {
map.insert(k.clone(), Self::from_serde_json(v));
}
JsonValue::Object(map)
}
}
}
}
#[cfg(feature = "messagepack")]
impl msgpacker::Packable for JsonValue {
fn pack<T>(&self, buf: &mut T) -> usize
where
T: Extend<u8>,
{
match self {
JsonValue::Null => {
buf.extend([0xc0]);
1
}
JsonValue::Bool(b) => b.pack(buf),
JsonValue::Number(n) => n.pack(buf),
JsonValue::String(s) => s.pack(buf),
JsonValue::Array(arr) => {
let len = arr.len();
let mut bytes_written = 0;
if len <= 15 {
buf.extend([0x90 + len as u8]);
bytes_written += 1;
} else if len <= u16::MAX as usize {
buf.extend([0xdc]);
buf.extend((len as u16).to_be_bytes());
bytes_written += 3;
} else {
buf.extend([0xdd]);
buf.extend((len as u32).to_be_bytes());
bytes_written += 5;
}
for item in arr {
bytes_written += item.pack(buf);
}
bytes_written
}
JsonValue::Object(obj) => {
let len = obj.len();
let mut bytes_written = 0;
if len <= 15 {
buf.extend([0x80 + len as u8]);
bytes_written += 1;
} else if len <= u16::MAX as usize {
buf.extend([0xde]);
buf.extend((len as u16).to_be_bytes());
bytes_written += 3;
} else {
buf.extend([0xdf]);
buf.extend((len as u32).to_be_bytes());
bytes_written += 5;
}
for (k, v) in obj {
bytes_written += k.pack(buf);
bytes_written += v.pack(buf);
}
bytes_written
}
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageId {
String(String),
Number(i64),
Uuid(Uuid),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMetadata {
pub created_at: Timestamp,
pub protocol_version: ProtocolVersion,
pub encoding: Option<String>,
pub content_type: ContentType,
pub size: usize,
pub correlation_id: Option<String>,
pub headers: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct Message {
pub id: MessageId,
pub metadata: MessageMetadata,
pub payload: MessagePayload,
}
#[derive(Debug, Clone)]
pub enum MessagePayload {
Json(JsonPayload),
Binary(BinaryPayload),
Text(String),
Empty,
}
#[derive(Debug, Clone)]
pub struct JsonPayload {
pub raw: Bytes,
pub parsed: Option<Arc<serde_json::Value>>,
pub is_valid: bool,
}
#[derive(Debug, Clone)]
pub struct BinaryPayload {
pub data: Bytes,
pub format: BinaryFormat,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BinaryFormat {
MessagePack,
ProtoBuf,
Cbor,
Custom,
}
#[derive(Debug)]
pub struct MessageSerializer {
default_format: SerializationFormat,
enable_compression: bool,
compression_threshold: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
Json,
#[cfg(feature = "simd")]
SimdJson,
MessagePack,
Cbor,
}
impl Message {
pub fn json(id: MessageId, value: impl Serialize) -> Result<Self> {
let json_bytes = Self::serialize_json(&value)?;
let payload = MessagePayload::Json(JsonPayload {
raw: json_bytes.freeze(),
parsed: Some(Arc::new(serde_json::to_value(value)?)),
is_valid: true,
});
Ok(Self {
id,
metadata: MessageMetadata::new(ContentType::Json, payload.size()),
payload,
})
}
pub fn binary(id: MessageId, data: Bytes, format: BinaryFormat) -> Self {
let size = data.len();
let payload = MessagePayload::Binary(BinaryPayload { data, format });
Self {
id,
metadata: MessageMetadata::new(ContentType::Binary, size),
payload,
}
}
#[must_use]
pub fn text(id: MessageId, text: String) -> Self {
let size = text.len();
let payload = MessagePayload::Text(text);
Self {
id,
metadata: MessageMetadata::new(ContentType::Text, size),
payload,
}
}
#[must_use]
pub fn empty(id: MessageId) -> Self {
Self {
id,
metadata: MessageMetadata::new(ContentType::Json, 0),
payload: MessagePayload::Empty,
}
}
pub const fn size(&self) -> usize {
self.metadata.size
}
pub const fn is_empty(&self) -> bool {
matches!(self.payload, MessagePayload::Empty)
}
pub fn serialize(&self, format: SerializationFormat) -> Result<Bytes> {
match format {
SerializationFormat::Json => self.serialize_json_format(),
#[cfg(feature = "simd")]
SerializationFormat::SimdJson => self.serialize_simd_json(),
SerializationFormat::MessagePack => self.serialize_messagepack(),
SerializationFormat::Cbor => self.serialize_cbor(),
}
}
pub fn deserialize(bytes: Bytes) -> Result<Self> {
let format = Self::detect_format(&bytes);
Self::deserialize_with_format(bytes, format)
}
pub fn deserialize_with_format(bytes: Bytes, format: SerializationFormat) -> Result<Self> {
match format {
SerializationFormat::Json => Ok(Self::deserialize_json(bytes)),
#[cfg(feature = "simd")]
SerializationFormat::SimdJson => Ok(Self::deserialize_simd_json(bytes)),
SerializationFormat::MessagePack => Ok(Self::deserialize_messagepack(bytes)),
SerializationFormat::Cbor => Self::deserialize_cbor(bytes),
}
}
pub fn parse_json<T>(&self) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
match &self.payload {
MessagePayload::Json(json_payload) => json_payload.parsed.as_ref().map_or_else(
|| {
#[cfg(feature = "simd")]
{
let mut json_bytes = json_payload.raw.to_vec();
simd_json::from_slice(&mut json_bytes).map_err(|e| {
Error::serialization(format!("SIMD JSON parsing failed: {e}"))
})
}
#[cfg(not(feature = "simd"))]
{
serde_json::from_slice(&json_payload.raw).map_err(|e| {
Error::serialization(format!("JSON parsing failed: {}", e))
})
}
},
|parsed| {
serde_json::from_value((**parsed).clone())
.map_err(|e| Error::serialization(format!("JSON parsing failed: {e}")))
},
),
_ => Err(Error::invalid_params("Message payload is not JSON")),
}
}
fn serialize_json(value: &impl Serialize) -> Result<BytesMut> {
#[cfg(feature = "simd")]
{
sonic_rs::to_vec(value)
.map(|v| BytesMut::from(v.as_slice()))
.map_err(|e| Error::serialization(format!("SIMD JSON serialization failed: {e}")))
}
#[cfg(not(feature = "simd"))]
{
serde_json::to_vec(value)
.map(|v| BytesMut::from(v.as_slice()))
.map_err(|e| Error::serialization(format!("JSON serialization failed: {}", e)))
}
}
fn serialize_json_format(&self) -> Result<Bytes> {
match &self.payload {
MessagePayload::Json(json_payload) => Ok(json_payload.raw.clone()),
MessagePayload::Text(text) => Ok(Bytes::from(text.clone())),
MessagePayload::Empty => Ok(Bytes::from_static(b"{}")),
MessagePayload::Binary(_) => Err(Error::invalid_params(
"Cannot serialize non-JSON payload as JSON",
)),
}
}
#[cfg(feature = "simd")]
fn serialize_simd_json(&self) -> Result<Bytes> {
match &self.payload {
MessagePayload::Json(json_payload) => {
if json_payload.is_valid {
Ok(json_payload.raw.clone())
} else {
Err(Error::serialization("Invalid JSON payload"))
}
}
_ => Err(Error::invalid_params(
"Cannot serialize non-JSON payload with SIMD JSON",
)),
}
}
fn serialize_messagepack(&self) -> Result<Bytes> {
#[cfg(feature = "messagepack")]
{
match &self.payload {
MessagePayload::Binary(binary) if binary.format == BinaryFormat::MessagePack => {
Ok(binary.data.clone())
}
MessagePayload::Json(json_payload) => json_payload.parsed.as_ref().map_or_else(
|| {
Err(Error::serialization(
"Cannot serialize unparsed JSON to MessagePack",
))
},
|parsed| {
let packable_value = JsonValue::from_serde_json(parsed.as_ref());
let mut buffer = Vec::new();
packable_value.pack(&mut buffer);
Ok(Bytes::from(buffer))
},
),
_ => Err(Error::invalid_params(
"Cannot serialize payload as MessagePack",
)),
}
}
#[cfg(not(feature = "messagepack"))]
{
let _ = self; Err(Error::invalid_params(
"MessagePack serialization not available",
))
}
}
fn serialize_cbor(&self) -> Result<Bytes> {
match &self.payload {
MessagePayload::Binary(binary) if binary.format == BinaryFormat::Cbor => {
Ok(binary.data.clone())
}
MessagePayload::Json(json_payload) => {
if let Some(parsed) = &json_payload.parsed {
{
let mut buffer = Vec::new();
ciborium::into_writer(parsed.as_ref(), &mut buffer)
.map(|_| Bytes::from(buffer))
.map_err(|e| {
Error::serialization(format!("CBOR serialization failed: {e}"))
})
}
} else {
#[cfg(feature = "simd")]
{
let mut json_bytes = json_payload.raw.to_vec();
let value: serde_json::Value = simd_json::from_slice(&mut json_bytes)
.map_err(|e| {
Error::serialization(format!(
"SIMD JSON parsing failed before CBOR: {e}"
))
})?;
{
let mut buffer = Vec::new();
ciborium::into_writer(&value, &mut buffer)
.map(|_| Bytes::from(buffer))
.map_err(|e| {
Error::serialization(format!("CBOR serialization failed: {e}"))
})
}
}
#[cfg(not(feature = "simd"))]
{
let value: serde_json::Value = serde_json::from_slice(&json_payload.raw)
.map_err(|e| {
Error::serialization(format!(
"JSON parsing failed before CBOR: {}",
e
))
})?;
let mut buf = Vec::new();
ciborium::ser::into_writer(&value, &mut buf).map_err(|e| {
Error::serialization(format!("CBOR serialization failed: {}", e))
})?;
Ok(Bytes::from(buf))
}
}
}
_ => Err(Error::invalid_params("Cannot serialize payload as CBOR")),
}
}
fn deserialize_json(bytes: Bytes) -> Self {
let is_valid = serde_json::from_slice::<serde_json::Value>(&bytes).is_ok();
let payload = MessagePayload::Json(JsonPayload {
raw: bytes,
parsed: None, is_valid,
});
Self {
id: MessageId::Uuid(Uuid::new_v4()),
metadata: MessageMetadata::new(ContentType::Json, payload.size()),
payload,
}
}
#[cfg(feature = "simd")]
fn deserialize_simd_json(bytes: Bytes) -> Self {
let mut json_bytes = bytes.to_vec();
let is_valid = simd_json::from_slice::<serde_json::Value>(&mut json_bytes).is_ok();
let payload = MessagePayload::Json(JsonPayload {
raw: bytes,
parsed: None,
is_valid,
});
Self {
id: MessageId::Uuid(Uuid::new_v4()),
metadata: MessageMetadata::new(ContentType::Json, payload.size()),
payload,
}
}
fn deserialize_messagepack(bytes: Bytes) -> Self {
let payload = MessagePayload::Binary(BinaryPayload {
data: bytes,
format: BinaryFormat::MessagePack,
});
Self {
id: MessageId::Uuid(Uuid::new_v4()),
metadata: MessageMetadata::new(ContentType::Binary, payload.size()),
payload,
}
}
fn deserialize_cbor(bytes: Bytes) -> Result<Self> {
if let Ok(value) = ciborium::from_reader::<serde_json::Value, _>(&bytes[..]) {
let raw = serde_json::to_vec(&value)
.map(Bytes::from)
.map_err(|e| Error::serialization(format!("JSON re-encode failed: {e}")))?;
let payload = MessagePayload::Json(JsonPayload {
raw,
parsed: Some(Arc::new(value)),
is_valid: true,
});
return Ok(Self {
id: MessageId::Uuid(Uuid::new_v4()),
metadata: MessageMetadata::new(ContentType::Json, payload.size()),
payload,
});
}
let payload = MessagePayload::Binary(BinaryPayload {
data: bytes,
format: BinaryFormat::Cbor,
});
Ok(Self {
id: MessageId::Uuid(Uuid::new_v4()),
metadata: MessageMetadata::new(ContentType::Binary, payload.size()),
payload,
})
}
fn detect_format(bytes: &[u8]) -> SerializationFormat {
if bytes.is_empty() {
return SerializationFormat::Json;
}
if matches!(bytes[0], b'{' | b'[') {
#[cfg(feature = "simd")]
{
return SerializationFormat::SimdJson;
}
#[cfg(not(feature = "simd"))]
{
return SerializationFormat::Json;
}
}
if bytes.len() >= 2 && (bytes[0] == 0x82 || bytes[0] == 0x83) {
return SerializationFormat::MessagePack;
}
#[cfg(feature = "simd")]
{
SerializationFormat::SimdJson
}
#[cfg(not(feature = "simd"))]
{
SerializationFormat::Json
}
}
}
impl MessagePayload {
pub const fn size(&self) -> usize {
match self {
Self::Json(json) => json.raw.len(),
Self::Binary(binary) => binary.data.len(),
Self::Text(text) => text.len(),
Self::Empty => 0,
}
}
}
impl MessageMetadata {
#[must_use]
pub fn new(content_type: ContentType, size: usize) -> Self {
Self {
created_at: Timestamp::now(),
protocol_version: ProtocolVersion::LATEST.clone(),
encoding: None,
content_type,
size,
correlation_id: None,
headers: HashMap::new(),
}
}
#[must_use]
pub fn with_header(mut self, key: String, value: String) -> Self {
self.headers.insert(key, value);
self
}
#[must_use]
pub fn with_correlation_id(mut self, correlation_id: String) -> Self {
self.correlation_id = Some(correlation_id);
self
}
#[must_use]
pub fn with_encoding(mut self, encoding: String) -> Self {
self.encoding = Some(encoding);
self
}
}
impl MessageSerializer {
#[must_use]
pub const fn new() -> Self {
Self {
default_format: SerializationFormat::Json,
enable_compression: false,
compression_threshold: 1024, }
}
#[must_use]
pub const fn with_format(mut self, format: SerializationFormat) -> Self {
self.default_format = format;
self
}
#[must_use]
pub const fn with_compression(mut self, enable: bool, threshold: usize) -> Self {
self.enable_compression = enable;
self.compression_threshold = threshold;
self
}
pub fn serialize(&self, message: &mut Message) -> Result<Bytes> {
let serialized = message.serialize(self.default_format)?;
if self.enable_compression && serialized.len() > self.compression_threshold {
message.metadata.encoding = Some("gzip".to_string()); Ok(self.compress(serialized))
} else {
Ok(serialized)
}
}
fn compress(&self, data: Bytes) -> Bytes {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
if let Err(e) = std::io::Write::write_all(&mut encoder, &data) {
eprintln!("Failed to compress data: {}", e);
return data; }
match encoder.finish() {
Ok(compressed_data) => Bytes::from(compressed_data),
Err(e) => {
eprintln!("Failed to finish compression: {}", e);
data }
}
}
}
impl Default for MessageSerializer {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for MessageId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::String(s) => write!(f, "{s}"),
Self::Number(n) => write!(f, "{n}"),
Self::Uuid(u) => write!(f, "{u}"),
}
}
}
impl From<String> for MessageId {
fn from(s: String) -> Self {
Self::String(s)
}
}
impl From<&str> for MessageId {
fn from(s: &str) -> Self {
Self::String(s.to_string())
}
}
impl From<i64> for MessageId {
fn from(n: i64) -> Self {
Self::Number(n)
}
}
impl From<Uuid> for MessageId {
fn from(u: Uuid) -> Self {
Self::Uuid(u)
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_message_creation() {
let message = Message::json(MessageId::from("test"), json!({"key": "value"})).unwrap();
assert_eq!(message.id.to_string(), "test");
assert!(!message.is_empty());
}
#[test]
fn test_message_serialization() {
let message = Message::json(MessageId::from(1), json!({"test": true})).unwrap();
let serialized = message.serialize(SerializationFormat::Json).unwrap();
assert!(!serialized.is_empty());
}
#[derive(Deserialize, PartialEq, Debug)]
struct TestData {
number: i32,
}
#[test]
fn test_message_parsing() {
let message = Message::json(MessageId::from("test"), json!({"number": 42})).unwrap();
let parsed: TestData = message.parse_json().unwrap();
assert_eq!(parsed.number, 42);
}
#[test]
fn test_format_detection() {
let json_bytes = Bytes::from(r#"{"test": true}"#);
let format = Message::detect_format(&json_bytes);
#[cfg(feature = "simd")]
assert_eq!(format, SerializationFormat::SimdJson);
#[cfg(not(feature = "simd"))]
assert_eq!(format, SerializationFormat::Json);
}
#[test]
fn test_message_metadata() {
let metadata = MessageMetadata::new(ContentType::Json, 100)
.with_header("custom".to_string(), "value".to_string())
.with_correlation_id("corr-123".to_string());
assert_eq!(metadata.size, 100);
assert_eq!(metadata.headers.get("custom"), Some(&"value".to_string()));
assert_eq!(metadata.correlation_id, Some("corr-123".to_string()));
}
#[test]
fn test_message_serializer_compression() {
use flate2::read::GzDecoder;
use std::io::Read;
let serializer = MessageSerializer::new().with_compression(true, 10);
let large_json = json!({
"data": "a".repeat(100), });
let mut message =
Message::json(MessageId::from("compressed_test"), large_json.clone()).unwrap();
let original_size = message.size();
assert!(
original_size > 10,
"Original message size should be greater than compression threshold"
);
let compressed_bytes = serializer.serialize(&mut message).unwrap();
assert_eq!(message.metadata.encoding, Some("gzip".to_string()));
assert!(
compressed_bytes.len() < original_size,
"Compressed size should be smaller than original"
);
let mut decoder = GzDecoder::new(&compressed_bytes[..]);
let mut decompressed_data = Vec::new();
decoder.read_to_end(&mut decompressed_data).unwrap();
let decompressed_message = Message::deserialize(Bytes::from(decompressed_data)).unwrap();
let parsed_json: serde_json::Value = decompressed_message.parse_json().unwrap();
assert_eq!(parsed_json, large_json);
}
}