use bytes::{Buf, BufMut, BytesMut};
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use tokio_util::codec::{Decoder, Encoder};
/// Upper bound, in bytes, on the JSON payload of a single frame (16 MiB).
///
/// Applies to the payload only; the 4-byte length prefix is not counted.
/// Checked on decode against the received length prefix and on encode
/// against the serialized JSON.
const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
/// Errors produced by [`TunnelCodec`] while framing or (de)serializing messages.
#[derive(Debug, Error)]
pub enum CodecError {
/// A frame's payload length exceeds `MAX_FRAME_SIZE`; carries the offending length.
#[error("Frame too large: {0} bytes (max {MAX_FRAME_SIZE})")]
FrameTooLarge(usize),
/// An underlying I/O error. Not constructed directly in this file.
// NOTE(review): presumably kept because tokio_util's Decoder/Encoder require
// `Error: From<std::io::Error>`, which `#[from]` provides — confirm before removing.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// JSON serialization (encode) or deserialization (decode) failed.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
}
/// Length-prefixed JSON codec for tunnel messages of type `T`.
///
/// Each frame on the wire is a 4-byte big-endian payload length followed by
/// that many bytes of JSON.
pub struct TunnelCodec<T> {
    // `fn() -> T` rather than `T`: the codec never stores or drops a `T`, so its
    // Send/Sync/Unpin-ness (and drop-check behaviour) should not depend on the
    // message type. Still covariant in `T`, exactly like `PhantomData<T>`.
    _phantom: std::marker::PhantomData<fn() -> T>,
}
impl<T> TunnelCodec<T> {
    /// Creates a codec for messages of type `T`.
    ///
    /// The codec carries no runtime state, so this is a `const fn` and can be
    /// used in const contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<T> Default for TunnelCodec<T> {
    /// Builds the same value as `TunnelCodec::new`: the codec only holds a
    /// type marker, so there is no state to initialise.
    fn default() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<T: DeserializeOwned> Decoder for TunnelCodec<T> {
    type Item = T;
    type Error = CodecError;

    /// Attempts to pull one complete frame off the front of `src`.
    ///
    /// A frame is a 4-byte big-endian payload length followed by that many
    /// bytes of JSON.
    ///
    /// Returns `Ok(None)` while the frame is still incomplete. In that case
    /// `src` is not consumed, and room is reserved for the rest of the frame
    /// once its length is known. Returns `Ok(Some(item))` once the whole
    /// frame has been consumed and parsed.
    ///
    /// # Errors
    ///
    /// * [`CodecError::FrameTooLarge`] if the length prefix exceeds
    ///   `MAX_FRAME_SIZE`; `src` is left unconsumed.
    /// * [`CodecError::Json`] if the payload does not deserialize into `T`;
    ///   the offending frame has already been removed from `src`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const HEADER_LEN: usize = 4;

        let available = src.len();
        if available < HEADER_LEN {
            return Ok(None);
        }

        // Peek at the length prefix without consuming it, so an incomplete
        // frame leaves `src` exactly as it was for the next call.
        let mut prefix = [0u8; HEADER_LEN];
        prefix.copy_from_slice(&src[..HEADER_LEN]);
        let payload_len = u32::from_be_bytes(prefix) as usize;

        // Reject before reserving, so a hostile prefix cannot force a huge allocation.
        if payload_len > MAX_FRAME_SIZE {
            return Err(CodecError::FrameTooLarge(payload_len));
        }

        let frame_len = HEADER_LEN + payload_len;
        if available < frame_len {
            // Grow once for the whole frame instead of piecemeal as bytes arrive.
            src.reserve(frame_len - available);
            return Ok(None);
        }

        src.advance(HEADER_LEN);
        let payload = src.split_to(payload_len);
        Ok(Some(serde_json::from_slice(&payload)?))
    }
}
impl<T: Serialize> Encoder<T> for TunnelCodec<T> {
    type Error = CodecError;

    /// Serializes `item` to JSON and appends it to `dst` as one frame: a
    /// 4-byte big-endian length prefix followed by the JSON bytes.
    ///
    /// # Errors
    ///
    /// * [`CodecError::Json`] if `item` cannot be serialized.
    /// * [`CodecError::FrameTooLarge`] if the JSON exceeds `MAX_FRAME_SIZE`;
    ///   nothing is written to `dst` in that case.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let payload = serde_json::to_vec(&item)?;
        match payload.len() {
            oversized if oversized > MAX_FRAME_SIZE => Err(CodecError::FrameTooLarge(oversized)),
            payload_len => {
                dst.reserve(4 + payload_len);
                // Lossless: the guard above caps `payload_len` at MAX_FRAME_SIZE (16 MiB).
                dst.put_u32(payload_len as u32);
                dst.put_slice(&payload);
                Ok(())
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{ClientMessage, ServerMessage, TunnelType};

    /// A `RequestTunnel` survives an encode/decode round trip and the frame is fully consumed.
    #[test]
    fn test_roundtrip_client_message() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let msg = ClientMessage::RequestTunnel {
            subdomain: Some("test".to_string()),
            tunnel_type: TunnelType::Http,
            local_port: 8080,
        };
        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        match decoded {
            ClientMessage::RequestTunnel {
                subdomain,
                tunnel_type,
                local_port,
            } => {
                assert_eq!(subdomain, Some("test".to_string()));
                assert_eq!(tunnel_type, TunnelType::Http);
                assert_eq!(local_port, 8080);
            }
            _ => panic!("Wrong variant"),
        }
        assert!(buf.is_empty());
    }

    /// An `HttpRequest` survives an encode/decode round trip and the frame is fully consumed.
    #[test]
    fn test_roundtrip_server_message() {
        let mut codec = TunnelCodec::<ServerMessage>::new();
        let msg = ServerMessage::HttpRequest {
            stream_id: 42,
            method: "GET".to_string(),
            uri: "/api/test".to_string(),
            headers: vec![("Host".to_string(), "example.com".to_string())],
            body: vec![],
        };
        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        match decoded {
            ServerMessage::HttpRequest {
                stream_id,
                method,
                uri,
                ..
            } => {
                assert_eq!(stream_id, 42);
                assert_eq!(method, "GET");
                assert_eq!(uri, "/api/test");
            }
            _ => panic!("Wrong variant"),
        }
        assert!(buf.is_empty());
    }

    /// A frame split mid-payload yields `None` until the remainder arrives.
    #[test]
    fn test_partial_frame() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let msg = ClientMessage::Ping { timestamp: 12345 };
        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();
        let full_len = buf.len();
        let mut partial = buf.split_to(full_len / 2);
        assert!(codec.decode(&mut partial).unwrap().is_none());
        partial.unsplit(buf);
        let decoded = codec.decode(&mut partial).unwrap().unwrap();
        match decoded {
            ClientMessage::Ping { timestamp } => assert_eq!(timestamp, 12345),
            _ => panic!("Wrong variant"),
        }
    }

    /// Fewer than 4 buffered bytes (an incomplete length prefix) yields `None`
    /// and consumes nothing.
    #[test]
    fn test_incomplete_header_returns_none() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut empty = BytesMut::new();
        assert!(codec.decode(&mut empty).unwrap().is_none());

        let mut buf = BytesMut::from(&[0u8, 0, 0][..]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 3);
    }

    /// A length prefix above `MAX_FRAME_SIZE` is rejected before any payload is read.
    #[test]
    fn test_oversized_frame_is_rejected() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        buf.put_u32((MAX_FRAME_SIZE + 1) as u32);
        assert!(matches!(
            codec.decode(&mut buf),
            Err(CodecError::FrameTooLarge(len)) if len == MAX_FRAME_SIZE + 1
        ));
    }

    /// A complete frame with an invalid JSON payload reports a JSON error and is consumed.
    #[test]
    fn test_invalid_json_consumes_frame() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        buf.put_u32(3);
        buf.put_slice(b"abc");
        assert!(matches!(codec.decode(&mut buf), Err(CodecError::Json(_))));
        assert!(buf.is_empty());
    }

    /// Back-to-back frames in one buffer decode in order, then `None`.
    #[test]
    fn test_multiple_frames_in_one_buffer() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        codec.encode(ClientMessage::Ping { timestamp: 1 }, &mut buf).unwrap();
        codec.encode(ClientMessage::Ping { timestamp: 2 }, &mut buf).unwrap();

        match codec.decode(&mut buf).unwrap() {
            Some(ClientMessage::Ping { timestamp }) => assert_eq!(timestamp, 1),
            _ => panic!("expected first Ping"),
        }
        match codec.decode(&mut buf).unwrap() {
            Some(ClientMessage::Ping { timestamp }) => assert_eq!(timestamp, 2),
            _ => panic!("expected second Ping"),
        }
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert!(buf.is_empty());
    }
}