use bytes::{BufMut, Bytes, BytesMut};
use crate::codec::Error;
pub mod tcp {
use super::*;
pub const LENGTH_PREFIX_SIZE: usize = 2;
pub fn try_encode_length_prefix(message: &Bytes) -> Result<Bytes, Error> {
let msg_len = message.len();
if msg_len > usize::from(u16::MAX) {
return Err(Error::MessageTooLong(msg_len));
}
let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + msg_len);
buf.put_u16(msg_len as u16);
buf.put_slice(message);
Ok(buf.freeze())
}
#[must_use]
pub fn encode_length_prefix(message: &Bytes) -> Bytes {
try_encode_length_prefix(message).expect("DNS message too large for TCP framing")
}
pub fn decode_length_prefix(frame_start: &[u8]) -> Result<u16, Error> {
if frame_start.len() < LENGTH_PREFIX_SIZE {
return Err(Error::TruncatedLengthPrefix(frame_start.len()));
}
Ok(u16::from_be_bytes([frame_start[0], frame_start[1]]))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- encoding ----------------------------------------------------------

    #[test]
    fn tcp_encode_prepends_big_endian_length() {
        let msg = Bytes::from_static(b"hello");
        let framed = tcp::encode_length_prefix(&msg);
        assert_eq!(framed[0], 0x00);
        assert_eq!(framed[1], 0x05);
        assert_eq!(&framed[2..], b"hello");
    }

    #[test]
    fn tcp_encode_length_larger_than_255() {
        let msg = Bytes::from(vec![0u8; 300]);
        let framed = tcp::encode_length_prefix(&msg);
        // 300 == 0x012C; the high byte comes first.
        assert_eq!(framed[0], 0x01);
        assert_eq!(framed[1], 0x2C);
        assert_eq!(framed.len(), 302);
    }

    #[test]
    fn tcp_encode_empty_message() {
        let msg = Bytes::new();
        let framed = tcp::encode_length_prefix(&msg);
        assert_eq!(&framed[..], &[0x00, 0x00]);
    }

    #[test]
    fn tcp_try_encode_accepts_max_length_message() {
        // u16::MAX is the largest length the prefix can carry; it must still be accepted.
        let max = usize::from(u16::MAX);
        let msg = Bytes::from(vec![0u8; max]);
        let framed = tcp::try_encode_length_prefix(&msg).unwrap();
        assert_eq!(&framed[..tcp::LENGTH_PREFIX_SIZE], &[0xFF, 0xFF]);
        assert_eq!(framed.len(), tcp::LENGTH_PREFIX_SIZE + max);
    }

    #[test]
    fn tcp_try_encode_rejects_oversized_message() {
        let msg = Bytes::from(vec![0u8; usize::from(u16::MAX) + 1]);
        let err = tcp::try_encode_length_prefix(&msg).unwrap_err();
        assert!(matches!(err, Error::MessageTooLong(65536)));
    }

    // ---- decoding ----------------------------------------------------------

    #[test]
    fn tcp_decode_reads_length() {
        let prefix = [0x00, 0x05u8];
        assert_eq!(tcp::decode_length_prefix(&prefix).unwrap(), 5u16);
    }

    #[test]
    fn tcp_decode_reads_larger_length() {
        let prefix = [0x01, 0x2Cu8];
        assert_eq!(tcp::decode_length_prefix(&prefix).unwrap(), 300u16);
    }

    #[test]
    fn tcp_decode_reads_max_length() {
        assert_eq!(tcp::decode_length_prefix(&[0xFF, 0xFF]).unwrap(), u16::MAX);
    }

    #[test]
    fn tcp_decode_ignores_trailing_bytes() {
        let frame = [0x00, 0x03, 0xFF, 0xFF, 0xFF];
        assert_eq!(tcp::decode_length_prefix(&frame).unwrap(), 3u16);
    }

    #[test]
    fn tcp_decode_truncated_0_bytes_returns_error() {
        let err = tcp::decode_length_prefix(&[]).unwrap_err();
        assert!(
            matches!(err, Error::TruncatedLengthPrefix(0)),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn tcp_decode_truncated_1_byte_returns_error() {
        let err = tcp::decode_length_prefix(&[0x01]).unwrap_err();
        assert!(
            matches!(err, Error::TruncatedLengthPrefix(1)),
            "unexpected error: {err}"
        );
    }

    // ---- round trips -------------------------------------------------------

    #[test]
    fn tcp_round_trip() {
        let msg = Bytes::from_static(b"round-trip-test");
        let framed = tcp::encode_length_prefix(&msg);
        let declared_len = tcp::decode_length_prefix(&framed).unwrap();
        assert_eq!(declared_len as usize, msg.len());
        let body = &framed[tcp::LENGTH_PREFIX_SIZE..];
        assert_eq!(body, &msg[..]);
    }

    #[test]
    fn tcp_round_trip_via_reader() {
        use crate::codec::reader::Reader;
        let msg = Bytes::from_static(b"dns-test-message");
        let framed = tcp::encode_length_prefix(&msg);
        let mut r = Reader::new(framed);
        let len = r.read_u16().unwrap() as usize;
        let body = r.read_slice(len).unwrap();
        assert_eq!(&body[..], &msg[..]);
    }
}