use crate::{McpError, Result};
use serde::{Serialize, de::DeserializeOwned};
pub use turbomcp_wire::{
AnyCodec, Codec, CodecError, CodecResult, JsonCodec, StreamingJsonDecoder,
};
#[cfg(feature = "wire-simd")]
pub use turbomcp_wire::SimdJsonCodec;
#[cfg(feature = "wire-msgpack")]
pub use turbomcp_wire::MsgPackCodec;
/// Selects which wire codec is used for MCP protocol messages.
///
/// `Json` is always compiled in and is the default. `SimdJson` and
/// `MessagePack` depend on the optional `wire-simd` / `wire-msgpack` cargo
/// features; use [`CodecType::is_available`] to check whether a variant is
/// usable in the current build.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum CodecType {
    /// Standard JSON encoding (always available; the default).
    #[default]
    Json,
    /// SIMD-accelerated JSON; requires the `wire-simd` feature.
    SimdJson,
    /// MessagePack binary encoding; requires the `wire-msgpack` feature.
    MessagePack,
}

impl CodecType {
    /// Returns `true` if this codec was compiled into the current build.
    ///
    /// `Json` is always available; `SimdJson` requires the `wire-simd`
    /// feature and `MessagePack` requires the `wire-msgpack` feature.
    /// This is a `const fn`, so it can also be evaluated in constant contexts.
    #[must_use]
    pub const fn is_available(&self) -> bool {
        match self {
            CodecType::Json => true,
            // `cfg!` expands to a boolean literal at compile time, replacing the
            // duplicated `#[cfg]` / `#[cfg(not)]` arm pairs while staying `const`.
            CodecType::SimdJson => cfg!(feature = "wire-simd"),
            CodecType::MessagePack => cfg!(feature = "wire-msgpack"),
        }
    }

    /// Returns the MIME content type of payloads produced by this codec.
    ///
    /// Both JSON variants use `application/json`; MessagePack uses
    /// `application/msgpack`.
    #[must_use]
    pub const fn content_type(&self) -> &'static str {
        match self {
            CodecType::Json | CodecType::SimdJson => "application/json",
            CodecType::MessagePack => "application/msgpack",
        }
    }
}
/// Protocol-level codec: a concrete wire codec paired with the [`CodecType`]
/// that is actually in use.
#[derive(Debug, Clone)]
pub struct ProtocolCodec {
    /// Concrete codec that performs the encoding and decoding.
    inner: AnyCodec,
    /// Codec kind actually in use. When a feature-gated codec is requested but
    /// not compiled in, `ProtocolCodec::with_type` falls back to JSON and this
    /// records `CodecType::Json`.
    codec_type: CodecType,
}

/// The default codec is standard JSON, identical to `ProtocolCodec::new()`.
impl Default for ProtocolCodec {
    fn default() -> Self {
        ProtocolCodec::new()
    }
}
impl ProtocolCodec {
    /// Creates a codec that uses standard JSON encoding.
    #[must_use]
    pub fn new() -> Self {
        let inner = AnyCodec::Json(JsonCodec::new());
        Self {
            inner,
            codec_type: CodecType::Json,
        }
    }

    /// Creates a codec of the requested kind.
    ///
    /// If the requested codec's cargo feature is disabled, a warning is logged
    /// via `tracing` and the result falls back to standard JSON; in that case
    /// [`ProtocolCodec::codec_type`] reports `CodecType::Json`, not the
    /// requested kind.
    #[must_use]
    pub fn with_type(codec_type: CodecType) -> Self {
        match codec_type {
            CodecType::Json => Self::new(),

            #[cfg(feature = "wire-simd")]
            CodecType::SimdJson => Self {
                inner: AnyCodec::SimdJson(SimdJsonCodec::new()),
                codec_type: CodecType::SimdJson,
            },
            #[cfg(not(feature = "wire-simd"))]
            CodecType::SimdJson => {
                tracing::warn!("SIMD JSON codec not available, falling back to standard JSON");
                Self::new()
            }

            #[cfg(feature = "wire-msgpack")]
            CodecType::MessagePack => Self {
                inner: AnyCodec::MsgPack(MsgPackCodec::new()),
                codec_type: CodecType::MessagePack,
            },
            #[cfg(not(feature = "wire-msgpack"))]
            CodecType::MessagePack => {
                tracing::warn!("MessagePack codec not available, falling back to standard JSON");
                Self::new()
            }
        }
    }

    /// Creates a JSON codec whose output is pretty-printed (multi-line).
    #[must_use]
    pub fn json_pretty() -> Self {
        let inner = AnyCodec::Json(JsonCodec::pretty());
        Self {
            inner,
            codec_type: CodecType::Json,
        }
    }

    /// Returns the codec kind actually in use (after any fallback).
    #[must_use]
    pub fn codec_type(&self) -> CodecType {
        self.codec_type
    }

    /// Returns the MIME content type reported by the underlying codec.
    #[must_use]
    pub fn content_type(&self) -> &'static str {
        self.inner.content_type()
    }

    /// Returns the name reported by the underlying codec.
    #[must_use]
    pub fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Serializes `value` to bytes with the underlying codec.
    ///
    /// # Errors
    ///
    /// Returns a parse error carrying the codec's message if encoding fails.
    pub fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>> {
        self.inner
            .encode(value)
            .map_err(|codec_err| McpError::parse_error(codec_err.message))
    }

    /// Deserializes a `T` from `bytes` with the underlying codec.
    ///
    /// # Errors
    ///
    /// Returns a parse error carrying the codec's message if decoding fails.
    pub fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> Result<T> {
        self.inner
            .decode(bytes)
            .map_err(|codec_err| McpError::parse_error(codec_err.message))
    }

    /// Serializes `value` and returns the output as a UTF-8 `String`.
    ///
    /// # Errors
    ///
    /// - An invalid-request error if this codec is MessagePack (binary output).
    /// - A parse error if encoding fails or the output is not valid UTF-8.
    pub fn encode_string<T: Serialize>(&self, value: &T) -> Result<String> {
        if self.codec_type == CodecType::MessagePack {
            return Err(McpError::invalid_request(
                "Cannot encode MessagePack to string",
            ));
        }
        String::from_utf8(self.encode(value)?)
            .map_err(|e| McpError::parse_error(format!("Invalid UTF-8: {e}")))
    }
}
#[derive(Debug)]
pub struct StreamingEncoder {
codec: ProtocolCodec,
}
impl StreamingEncoder {
#[must_use]
pub fn new(codec: ProtocolCodec) -> Self {
Self { codec }
}
pub fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>> {
let mut bytes = self.codec.encode(value)?;
bytes.push(b'\n');
Ok(bytes)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jsonrpc::JsonRpcRequest;
    use crate::types::RequestId;

    /// Builds a parameterless JSON-RPC request with numeric id 1.
    fn test_request(method: &str) -> JsonRpcRequest {
        JsonRpcRequest::without_params(method.to_string(), RequestId::Number(1))
    }

    #[test]
    fn test_protocol_codec_json() {
        let codec = ProtocolCodec::new();
        assert_eq!(codec.codec_type(), CodecType::Json);
        assert_eq!(codec.content_type(), "application/json");
        let request = test_request("test/ping");
        let bytes = codec.encode(&request).unwrap();
        let decoded: JsonRpcRequest = codec.decode(&bytes).unwrap();
        assert_eq!(decoded.method, "test/ping");
    }

    #[test]
    fn test_protocol_codec_pretty() {
        let codec = ProtocolCodec::json_pretty();
        let request = test_request("test");
        let output = codec.encode_string(&request).unwrap();
        assert!(output.contains('\n'));
    }

    #[test]
    fn test_default_matches_new() {
        let request = test_request("test/default");
        let from_default = ProtocolCodec::default();
        let from_new = ProtocolCodec::new();
        assert_eq!(from_default.codec_type(), from_new.codec_type());
        assert_eq!(
            from_default.encode(&request).unwrap(),
            from_new.encode(&request).unwrap()
        );
    }

    #[test]
    fn test_encode_string_json_round_trip() {
        let codec = ProtocolCodec::new();
        let request = test_request("test/string");
        let text = codec.encode_string(&request).unwrap();
        let decoded: JsonRpcRequest = codec.decode(text.as_bytes()).unwrap();
        assert_eq!(decoded.method, "test/string");
    }

    #[test]
    fn test_codec_type_availability() {
        assert!(CodecType::Json.is_available());
        #[cfg(feature = "wire-simd")]
        assert!(CodecType::SimdJson.is_available());
        #[cfg(not(feature = "wire-simd"))]
        assert!(!CodecType::SimdJson.is_available());
        #[cfg(feature = "wire-msgpack")]
        assert!(CodecType::MessagePack.is_available());
        #[cfg(not(feature = "wire-msgpack"))]
        assert!(!CodecType::MessagePack.is_available());
    }

    #[test]
    fn test_codec_type_content_type() {
        assert_eq!(CodecType::default(), CodecType::Json);
        assert_eq!(CodecType::Json.content_type(), "application/json");
        assert_eq!(CodecType::SimdJson.content_type(), "application/json");
        assert_eq!(CodecType::MessagePack.content_type(), "application/msgpack");
    }

    #[test]
    fn test_streaming_encoder() {
        let codec = ProtocolCodec::new();
        let encoder = StreamingEncoder::new(codec);
        let request = test_request("test");
        let bytes = encoder.encode(&request).unwrap();
        assert!(bytes.ends_with(b"\n"));
    }

    #[test]
    fn test_streaming_decoder_integration() {
        let mut decoder = StreamingJsonDecoder::new();
        let request = test_request("ping");
        let codec = ProtocolCodec::new();
        let mut bytes = codec.encode(&request).unwrap();
        bytes.push(b'\n');
        decoder.feed(&bytes);
        let decoded: JsonRpcRequest = decoder.try_decode().unwrap().unwrap();
        assert_eq!(decoded.method, "ping");
    }

    #[cfg(not(feature = "wire-simd"))]
    #[test]
    fn test_simd_falls_back_to_json() {
        // Requesting an unavailable codec must fall back to standard JSON.
        let codec = ProtocolCodec::with_type(CodecType::SimdJson);
        assert_eq!(codec.codec_type(), CodecType::Json);
        assert_eq!(codec.content_type(), "application/json");
    }

    #[cfg(not(feature = "wire-msgpack"))]
    #[test]
    fn test_msgpack_falls_back_to_json() {
        // The fallback codec is JSON, so string encoding must be permitted.
        let codec = ProtocolCodec::with_type(CodecType::MessagePack);
        assert_eq!(codec.codec_type(), CodecType::Json);
        assert_eq!(codec.content_type(), "application/json");
        let request = test_request("msgpack/fallback");
        assert!(codec.encode_string(&request).is_ok());
    }

    #[cfg(feature = "wire-simd")]
    #[test]
    fn test_simd_codec() {
        let codec = ProtocolCodec::with_type(CodecType::SimdJson);
        assert_eq!(codec.codec_type(), CodecType::SimdJson);
        let request = test_request("simd/test");
        let bytes = codec.encode(&request).unwrap();
        let decoded: JsonRpcRequest = codec.decode(&bytes).unwrap();
        assert_eq!(decoded.method, "simd/test");
    }

    #[cfg(feature = "wire-msgpack")]
    #[test]
    fn test_msgpack_codec() {
        let codec = ProtocolCodec::with_type(CodecType::MessagePack);
        assert_eq!(codec.codec_type(), CodecType::MessagePack);
        assert_eq!(codec.content_type(), "application/msgpack");
        let request = test_request("msgpack/test");
        let bytes = codec.encode(&request).unwrap();
        let decoded: JsonRpcRequest = codec.decode(&bytes).unwrap();
        assert_eq!(decoded.method, "msgpack/test");
        assert!(codec.encode_string(&request).is_err());
    }
}