use serde::{Deserialize, Serialize};
pub use tokio_util::codec::LengthDelimitedCodec;
use crate::protocol::{ClientConnect, ClientHello, ServerConnectResponse};
pub const MAX_FRAME_LENGTH: usize = 8 * 1024 * 1024;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClientMessage {
Hello(ClientHello),
Connect(ClientConnect),
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerMessage {
ConnectResponse(ServerConnectResponse),
}
pub fn length_codec() -> LengthDelimitedCodec {
LengthDelimitedCodec::builder()
.length_field_offset(0)
.length_field_length(4)
.length_adjustment(0)
.num_skip(4)
.max_frame_length(MAX_FRAME_LENGTH)
.new_codec()
}
pub const LENGTH_PREFIX_SIZE: usize = 4;
#[cfg(test)]
mod tests {
use bytes::{Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use super::*;
use crate::protocol::{
ClientHello, ConnectErrorKind, ServerConnectResponse, PROTOCOL_VERSION, decode, encode,
};
#[test]
fn test_length_codec_encode_decode_roundtrip() {
let mut codec = length_codec();
let mut buf = BytesMut::new();
Encoder::<Bytes>::encode(&mut codec, Bytes::from_static(b"hello world"), &mut buf).unwrap();
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded.as_ref(), b"hello world");
}
#[test]
fn test_length_codec_encode_decode_empty_frame() {
let mut codec = length_codec();
let mut buf = BytesMut::new();
Encoder::<Bytes>::encode(&mut codec, Bytes::new(), &mut buf).unwrap();
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert!(decoded.is_empty());
}
#[test]
fn test_length_codec_partial_data_returns_none() {
let mut codec = length_codec();
let mut buf = BytesMut::from(&[0u8, 0][..]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_none());
}
#[test]
fn test_length_codec_multi_frame_sequence() {
let mut codec = length_codec();
let mut buf = BytesMut::new();
Encoder::<Bytes>::encode(&mut codec, Bytes::from_static(b"first"), &mut buf).unwrap();
Encoder::<Bytes>::encode(&mut codec, Bytes::from_static(b"second"), &mut buf).unwrap();
let first = codec.decode(&mut buf).unwrap().unwrap();
let second = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(first.as_ref(), b"first");
assert_eq!(second.as_ref(), b"second");
}
#[test]
fn test_length_codec_frame_at_exact_max_size() {
let mut codec = length_codec();
let mut buf = BytesMut::new();
let payload = Bytes::from(vec![0u8; MAX_FRAME_LENGTH]);
Encoder::<Bytes>::encode(&mut codec, payload.clone(), &mut buf).unwrap();
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(decoded.len(), MAX_FRAME_LENGTH);
}
#[test]
fn test_length_codec_frame_exceeds_max_size_rejected() {
let mut codec = length_codec();
let oversized_len = (MAX_FRAME_LENGTH + 1) as u32;
let mut buf = BytesMut::new();
buf.extend_from_slice(&oversized_len.to_be_bytes());
buf.extend_from_slice(&[0u8; 16]); let result = codec.decode(&mut buf);
assert!(result.is_err(), "expected error for oversized frame");
}
#[test]
fn test_length_codec_length_prefix_is_4_bytes() {
let mut codec = length_codec();
let mut buf = BytesMut::new();
let payload = Bytes::from(vec![99u8; 16]);
Encoder::<Bytes>::encode(&mut codec, payload, &mut buf).unwrap();
assert_eq!(buf.len(), LENGTH_PREFIX_SIZE + 16);
}
#[test]
fn test_client_message_hello_roundtrip() {
let msg = ClientMessage::Hello(ClientHello {
version: PROTOCOL_VERSION,
secret: [1u8; 32],
options: Bytes::new(),
});
let bytes = encode(&msg).unwrap();
let decoded: ClientMessage = decode(&bytes).unwrap();
assert_eq!(decoded, msg);
}
#[test]
fn test_server_message_connect_response_roundtrip() {
let ok = ServerMessage::ConnectResponse(ServerConnectResponse::Ok);
let bytes = encode(&ok).unwrap();
let decoded: ServerMessage = decode(&bytes).unwrap();
assert_eq!(decoded, ok);
let err_msg = ServerMessage::ConnectResponse(ServerConnectResponse::Err {
kind: ConnectErrorKind::TimedOut,
message: "timed out".to_string(),
});
let bytes = encode(&err_msg).unwrap();
let decoded: ServerMessage = decode(&bytes).unwrap();
assert_eq!(decoded, err_msg);
}
}