use bytes::{Buf, BufMut, BytesMut};
use capnweb_core::protocol::Message;
use serde_json;
use std::io;
use tokio_util::codec::{Decoder, Encoder};
/// Codec that frames each `Message` as a 4-byte big-endian length prefix followed by that
/// many bytes of JSON.
///
/// Payloads larger than the configured limit (10 MiB by default) are rejected with
/// `CodecError::FrameTooLarge` on both encode and decode.
#[derive(Debug, Clone)]
pub struct CapnWebCodec {
    /// Maximum payload size in bytes, excluding the 4-byte length prefix.
    max_frame_size: usize,
}

impl CapnWebCodec {
    /// Default maximum payload size: 10 MiB.
    pub const DEFAULT_MAX_FRAME_SIZE: usize = 10 * 1024 * 1024;

    /// Creates a codec limited to [`Self::DEFAULT_MAX_FRAME_SIZE`] bytes per frame.
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_frame_size(Self::DEFAULT_MAX_FRAME_SIZE)
    }

    /// Creates a codec that accepts payloads of at most `max_frame_size` bytes
    /// (not counting the length prefix).
    #[must_use]
    pub fn with_max_frame_size(max_frame_size: usize) -> Self {
        Self { max_frame_size }
    }
}

impl Default for CapnWebCodec {
    /// Equivalent to [`CapnWebCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Wire framing strategies implemented by the codecs in this module.
///
/// NOTE(review): not referenced by the codecs themselves; presumably callers use it to choose
/// between `CapnWebCodec` and `NewlineDelimitedCodec` — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameFormat {
    /// 4-byte big-endian length prefix followed by the JSON payload (as in `CapnWebCodec`).
    LengthPrefixed,
    /// One JSON value per `\n`-terminated line (as in `NewlineDelimitedCodec`).
    NewlineDelimited,
}
impl Decoder for CapnWebCodec {
    type Item = Message;
    type Error = CodecError;

    /// Decodes one length-prefixed JSON frame from `src`.
    ///
    /// Wire format: a 4-byte big-endian `u32` payload length followed by that many bytes of
    /// JSON. Returns `Ok(None)` when `src` does not yet hold a complete frame; in that case
    /// capacity for the remainder of the frame is reserved up front.
    ///
    /// # Errors
    ///
    /// - `CodecError::FrameTooLarge` if the declared length exceeds `max_frame_size`
    ///   (the header is left in `src`).
    /// - `CodecError::JsonError` / `CodecError::MessageError` if the payload is not valid
    ///   JSON or not a valid `Message`. The bad frame has already been removed from `src`,
    ///   so the stream stays aligned on frame boundaries.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const HEADER_LEN: usize = 4;

        if src.len() < HEADER_LEN {
            return Ok(None);
        }

        let mut length_bytes = [0u8; HEADER_LEN];
        length_bytes.copy_from_slice(&src[..HEADER_LEN]);
        let frame_len = u32::from_be_bytes(length_bytes) as usize;

        // Check before reserving so an untrusted header cannot force a huge allocation.
        if frame_len > self.max_frame_size {
            return Err(CodecError::FrameTooLarge(frame_len));
        }

        // Compare against the bytes available after the header instead of computing
        // `HEADER_LEN + frame_len`, which can overflow `usize` on 32-bit targets when a very
        // large `max_frame_size` has been configured.
        let available = src.len() - HEADER_LEN;
        if available < frame_len {
            src.reserve(frame_len - available);
            return Ok(None);
        }

        src.advance(HEADER_LEN);
        let frame_data = src.split_to(frame_len);
        let json_value: serde_json::Value = serde_json::from_slice(&frame_data)
            .map_err(|e| CodecError::JsonError(e.to_string()))?;
        let message =
            Message::from_json(&json_value).map_err(|e| CodecError::MessageError(e.to_string()))?;
        Ok(Some(message))
    }
}
impl Encoder<Message> for CapnWebCodec {
    type Error = CodecError;

    /// Serializes `item` to JSON and appends it to `dst` as a length-prefixed frame.
    ///
    /// The frame is a 4-byte big-endian `u32` length followed by the JSON bytes.
    ///
    /// # Errors
    ///
    /// - `CodecError::JsonError` if the JSON value cannot be serialized.
    /// - `CodecError::FrameTooLarge` if the payload exceeds `max_frame_size` or cannot be
    ///   represented in the `u32` length prefix.
    ///
    /// `dst` is left unchanged whenever an error is returned.
    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let json_value = item.to_json();
        let json_bytes =
            serde_json::to_vec(&json_value).map_err(|e| CodecError::JsonError(e.to_string()))?;
        let payload_len = json_bytes.len();

        // The length prefix is a `u32`. `with_max_frame_size` can set a limit above
        // `u32::MAX`, and a bare `as u32` cast would then silently truncate the header and
        // desynchronize the peer.
        if payload_len > self.max_frame_size || payload_len > u32::MAX as usize {
            return Err(CodecError::FrameTooLarge(payload_len));
        }
        let length = payload_len as u32; // Cannot truncate: bounded by the check above.

        dst.reserve(4 + payload_len);
        dst.put_u32(length);
        dst.put_slice(&json_bytes);
        Ok(())
    }
}
/// Codec that frames each `Message` as one line of JSON terminated by `\n`.
///
/// Lines longer than the configured limit (1 MiB by default) are rejected with
/// `CodecError::LineTooLong` on both encode and decode.
#[derive(Debug, Clone)]
pub struct NewlineDelimitedCodec {
    /// Maximum line length in bytes, excluding the trailing `\n`.
    max_line_length: usize,
}

impl NewlineDelimitedCodec {
    /// Default maximum line length: 1 MiB.
    pub const DEFAULT_MAX_LINE_LENGTH: usize = 1024 * 1024;

    /// Creates a codec limited to [`Self::DEFAULT_MAX_LINE_LENGTH`] bytes per line.
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_line_length(Self::DEFAULT_MAX_LINE_LENGTH)
    }

    /// Creates a codec that accepts lines of at most `max_line_length` bytes, not counting
    /// the `\n` terminator. This mirrors `CapnWebCodec::with_max_frame_size`.
    #[must_use]
    pub fn with_max_line_length(max_line_length: usize) -> Self {
        Self { max_line_length }
    }
}

impl Default for NewlineDelimitedCodec {
    /// Equivalent to [`NewlineDelimitedCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Decoder for NewlineDelimitedCodec {
    type Item = Message;
    type Error = CodecError;

    /// Pulls the next `\n`-terminated JSON line out of `src` and parses it as a `Message`.
    ///
    /// Returns `Ok(None)` while no terminator has arrived yet and the buffered bytes are
    /// still within `max_line_length`.
    ///
    /// # Errors
    ///
    /// - `CodecError::LineTooLong` if the line before the terminator, or the unterminated
    ///   data buffered so far, is longer than `max_line_length`.
    /// - `CodecError::JsonError` / `CodecError::MessageError` if the line is not valid JSON
    ///   or not a valid `Message`. The line and its terminator have already been consumed.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let newline_at = src.iter().position(|&byte| byte == b'\n');

        match newline_at {
            None if src.len() > self.max_line_length => Err(CodecError::LineTooLong(src.len())),
            None => Ok(None),
            Some(line_len) if line_len > self.max_line_length => {
                Err(CodecError::LineTooLong(line_len))
            }
            Some(line_len) => {
                let line = src.split_to(line_len);
                // Discard the '\n' terminator itself.
                src.advance(1);
                let value = serde_json::from_slice::<serde_json::Value>(&line)
                    .map_err(|err| CodecError::JsonError(err.to_string()))?;
                Message::from_json(&value)
                    .map(Some)
                    .map_err(|err| CodecError::MessageError(err.to_string()))
            }
        }
    }
}
impl Encoder<Message> for NewlineDelimitedCodec {
    type Error = CodecError;

    /// Serializes `item` as compact JSON and appends it to `dst` followed by `\n`.
    ///
    /// `serde_json::to_vec` emits compact output and escapes control characters inside
    /// strings, so the encoded value never contains a raw newline that would split the line.
    ///
    /// # Errors
    ///
    /// - `CodecError::JsonError` if serialization fails.
    /// - `CodecError::LineTooLong` if the JSON (without the terminator) is longer than
    ///   `max_line_length`. `dst` is not modified in this case.
    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let encoded = match serde_json::to_vec(&item.to_json()) {
            Ok(bytes) => bytes,
            Err(err) => return Err(CodecError::JsonError(err.to_string())),
        };

        let line_len = encoded.len();
        if line_len > self.max_line_length {
            return Err(CodecError::LineTooLong(line_len));
        }

        // One reservation covers the payload plus its terminator.
        dst.reserve(line_len + 1);
        dst.extend_from_slice(&encoded);
        dst.put_u8(b'\n');
        Ok(())
    }
}
/// Errors produced by the length-prefixed and newline-delimited codecs in this module.
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
    /// A length-prefixed frame's payload exceeded the codec's maximum frame size;
    /// carries the offending payload length in bytes.
    #[error("Frame too large: {0} bytes")]
    FrameTooLarge(usize),
    /// A newline-delimited line exceeded the codec's maximum line length; carries the
    /// observed length in bytes (when decoding without a terminator yet, this is the
    /// number of bytes buffered so far).
    #[error("Line too long: {0} bytes")]
    LineTooLong(usize),
    /// JSON serialization or deserialization failed; holds the `serde_json` error text.
    #[error("JSON error: {0}")]
    JsonError(String),
    /// The JSON was well-formed but could not be converted into a `Message`; holds the
    /// underlying error's text.
    #[error("Message parse error: {0}")]
    MessageError(String),
    /// An I/O error. `tokio_util`'s `Decoder`/`Encoder` traits require
    /// `From<io::Error>`, which the `#[from]` conversion provides.
    #[error("IO error: {0}")]
    IoError(#[from] io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use capnweb_core::{
        protocol::{ExportId, ImportId},
        Expression,
    };

    #[test]
    fn test_length_prefixed_encode_decode() {
        let mut codec = CapnWebCodec::new();
        let mut buffer = BytesMut::new();
        let msg = Message::Push(Expression::String("test".to_string()));
        codec.encode(msg.clone(), &mut buffer).unwrap();
        assert!(buffer.len() > 4);
        let decoded = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded {
            Message::Push(expr) => {
                assert_eq!(expr, Expression::String("test".to_string()));
            }
            _ => panic!("Wrong message type"),
        }
        assert_eq!(buffer.len(), 0);
    }

    #[test]
    fn test_newline_delimited_encode_decode() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::new();
        let msg = Message::Pull(ImportId(42));
        codec.encode(msg, &mut buffer).unwrap();
        assert_eq!(buffer[buffer.len() - 1], b'\n');
        let decoded = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded {
            Message::Pull(id) => {
                assert_eq!(id, ImportId(42));
            }
            _ => panic!("Wrong message type"),
        }
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_partial_frame() {
        let mut codec = CapnWebCodec::new();
        let mut buffer = BytesMut::new();
        buffer.put_u8(0);
        buffer.put_u8(0);
        assert!(codec.decode(&mut buffer).unwrap().is_none());
        buffer.put_u8(0);
        buffer.put_u8(10);
        assert!(codec.decode(&mut buffer).unwrap().is_none());
        // The incomplete header must not be consumed.
        assert_eq!(buffer.len(), 4);
    }

    #[test]
    fn test_frame_too_large() {
        let mut codec = CapnWebCodec::with_max_frame_size(100);
        let mut buffer = BytesMut::new();
        let large_string = "x".repeat(200);
        let msg = Message::Push(Expression::String(large_string));
        assert!(matches!(
            codec.encode(msg, &mut buffer),
            Err(CodecError::FrameTooLarge(_))
        ));
        // A rejected frame must not leave partial output behind.
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_decode_rejects_oversized_frame_header() {
        let mut codec = CapnWebCodec::with_max_frame_size(8);
        let mut buffer = BytesMut::new();
        buffer.put_u32(9);
        assert!(matches!(
            codec.decode(&mut buffer),
            Err(CodecError::FrameTooLarge(9))
        ));
    }

    #[test]
    fn test_invalid_json_frame_is_consumed() {
        let mut codec = CapnWebCodec::new();
        let mut buffer = BytesMut::new();
        buffer.put_u32(3);
        buffer.put_slice(b"not");
        assert!(matches!(
            codec.decode(&mut buffer),
            Err(CodecError::JsonError(_))
        ));
        // The bad frame is dropped so the next frame starts at a clean boundary.
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_newline_partial_line() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::from(&b"{\"partial\""[..]);
        assert!(codec.decode(&mut buffer).unwrap().is_none());
        assert_eq!(buffer.len(), 10);
    }

    #[test]
    fn test_newline_line_too_long() {
        let mut codec = NewlineDelimitedCodec::new();
        let data = vec![b'x'; 1024 * 1024 + 1];
        let mut buffer = BytesMut::from(&data[..]);
        assert!(matches!(
            codec.decode(&mut buffer),
            Err(CodecError::LineTooLong(1_048_577))
        ));
    }

    #[test]
    fn test_multiple_messages() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::new();
        let msg1 = Message::Push(Expression::String("first".to_string()));
        let msg2 = Message::Pull(ImportId(1));
        let msg3 = Message::Resolve(
            ExportId(-1),
            Expression::Number(serde_json::Number::from(42)),
        );
        codec.encode(msg1, &mut buffer).unwrap();
        codec.encode(msg2, &mut buffer).unwrap();
        codec.encode(msg3, &mut buffer).unwrap();

        let decoded1 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded1 {
            Message::Push(expr) => {
                assert_eq!(expr, Expression::String("first".to_string()));
            }
            _ => panic!("Wrong message type"),
        }

        let decoded2 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded2 {
            Message::Pull(id) => {
                assert_eq!(id, ImportId(1));
            }
            _ => panic!("Wrong message type"),
        }

        let decoded3 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded3 {
            Message::Resolve(id, expr) => {
                assert_eq!(id, ExportId(-1));
                match expr {
                    Expression::Number(n) => assert_eq!(n.as_i64(), Some(42)),
                    _ => panic!("Wrong expression type"),
                }
            }
            _ => panic!("Wrong message type"),
        }

        assert!(codec.decode(&mut buffer).unwrap().is_none());
    }
}