use bytes::{BufMut, BytesMut};
use tds_protocol::packet::{MAX_PACKET_SIZE, PACKET_HEADER_SIZE, PacketHeader};
use tokio_util::codec::{Decoder, Encoder};
use crate::error::CodecError;
/// A single TDS packet: the fixed-size header plus the bytes that follow it.
#[derive(Clone, Debug)]
pub struct Packet {
    /// Packet header. When the packet is encoded by `TdsCodec`, its `length`
    /// and `packet_id` fields are overwritten by the codec.
    pub header: PacketHeader,
    /// Packet body, i.e. everything after the header.
    pub payload: BytesMut,
}
impl Packet {
#[must_use]
pub fn new(header: PacketHeader, payload: BytesMut) -> Self {
Self { header, payload }
}
#[must_use]
pub fn total_size(&self) -> usize {
PACKET_HEADER_SIZE + self.payload.len()
}
#[must_use]
pub fn is_end_of_message(&self) -> bool {
self.header.is_end_of_message()
}
}
/// Tokio codec that frames and unframes TDS packets.
///
/// Decoding validates the header's length field against `max_packet_size`.
/// Encoding stamps each outgoing packet with a sequential packet id.
#[derive(Debug)]
pub struct TdsCodec {
    /// Largest packet (header included) accepted by `encode`/`decode`.
    max_packet_size: usize,
    /// Id that will be written into the next encoded packet's header.
    packet_id: u8,
}
impl TdsCodec {
    /// Creates a codec that accepts packets up to `MAX_PACKET_SIZE` bytes and
    /// numbers outgoing packets starting at 1.
    #[must_use]
    pub fn new() -> Self {
        let max_packet_size = MAX_PACKET_SIZE;
        let packet_id = 1;
        Self {
            max_packet_size,
            packet_id,
        }
    }

    /// Sets the largest packet size (header included) this codec will encode
    /// or decode. Requests above `MAX_PACKET_SIZE` are capped at that limit.
    #[must_use]
    pub fn with_max_packet_size(self, size: usize) -> Self {
        Self {
            max_packet_size: size.min(MAX_PACKET_SIZE),
            ..self
        }
    }

    /// Returns the id for the packet being encoded and advances the counter.
    /// The counter never advances to 0: after 255 it continues at 1.
    fn next_packet_id(&mut self) -> u8 {
        let current = self.packet_id;
        // `wrapping_add` turns 255 into 0; `max(1)` lifts that back up to 1.
        self.packet_id = current.wrapping_add(1).max(1);
        current
    }

    /// Restarts outgoing packet numbering at 1.
    pub fn reset_packet_id(&mut self) {
        self.packet_id = 1;
    }
}
impl Default for TdsCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for TdsCodec {
    type Item = Packet;
    type Error = CodecError;

    /// Decodes one complete TDS packet from the front of `src`.
    ///
    /// Returns `Ok(None)` without consuming anything in two cases:
    /// - `src` does not yet hold a full header.
    /// - `src` does not yet hold the whole packet announced by the header. Here
    ///   capacity for the missing bytes is reserved first.
    ///
    /// # Errors
    ///
    /// * [`CodecError::InvalidHeader`] if the length field is smaller than the
    ///   header itself.
    /// * [`CodecError::PacketTooLarge`] if the length field exceeds the
    ///   configured maximum packet size.
    /// * Any error returned by `PacketHeader::decode`, converted via `?`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if src.len() < PACKET_HEADER_SIZE {
            return Ok(None);
        }

        // Header bytes 2..4 carry the big-endian length of the whole packet,
        // header included.
        let length = usize::from(u16::from_be_bytes([src[2], src[3]]));
        if length < PACKET_HEADER_SIZE {
            return Err(CodecError::InvalidHeader);
        }
        if length > self.max_packet_size {
            return Err(CodecError::PacketTooLarge {
                size: length,
                max: self.max_packet_size,
            });
        }
        if src.len() < length {
            // Grow the buffer once so a subsequent read can complete the packet.
            src.reserve(length - src.len());
            return Ok(None);
        }

        let mut packet_bytes = src.split_to(length);
        // Detach the payload as a view into the same allocation instead of
        // copying it; `packet_bytes` is left holding only the header bytes.
        let payload = packet_bytes.split_off(PACKET_HEADER_SIZE);
        let mut cursor = packet_bytes.as_ref();
        let header = PacketHeader::decode(&mut cursor)?;

        tracing::trace!(
            packet_type = ?header.packet_type,
            length = length,
            is_eom = header.is_end_of_message(),
            "decoded TDS packet"
        );
        Ok(Some(Packet::new(header, payload)))
    }
}
impl Encoder<Packet> for TdsCodec {
    type Error = CodecError;

    /// Serializes `item` into `dst` as a single TDS packet.
    ///
    /// The codec overwrites two header fields, so callers need not set them:
    /// - `length` is recomputed from the payload size.
    /// - `packet_id` is replaced by the codec's next sequence number.
    ///
    /// Nothing is written, and no packet id is consumed, when the packet is
    /// rejected.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::PacketTooLarge`] if header plus payload exceeds
    /// the configured maximum packet size or cannot be represented in the
    /// 16-bit header length field.
    fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let total_length = PACKET_HEADER_SIZE + item.payload.len();
        let max = self.max_packet_size;
        // The header length field is a u16. Convert with a check instead of an
        // `as` cast, so the length can never silently wrap even if the
        // configured maximum is ever raised above `u16::MAX`.
        let wire_length = match u16::try_from(total_length) {
            Ok(len) if total_length <= max => len,
            _ => {
                return Err(CodecError::PacketTooLarge {
                    size: total_length,
                    max: max.min(usize::from(u16::MAX)),
                });
            }
        };

        dst.reserve(total_length);
        let mut header = item.header;
        header.length = wire_length;
        header.packet_id = self.next_packet_id();
        header.encode(dst);
        dst.put_slice(&item.payload);

        tracing::trace!(
            packet_type = ?header.packet_type,
            length = total_length,
            packet_id = header.packet_id,
            "encoded TDS packet"
        );
        Ok(())
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tds_protocol::packet::{PacketStatus, PacketType};

    /// Builds raw bytes for an end-of-message SQL batch packet whose header
    /// announces `length` bytes, followed by `payload`. `payload` is not
    /// required to match `length`, so truncated or malformed packets can be
    /// built too.
    fn raw_packet(length: u16, payload: &[u8]) -> BytesMut {
        let mut data = BytesMut::new();
        data.put_u8(PacketType::SqlBatch as u8);
        data.put_u8(PacketStatus::END_OF_MESSAGE.bits());
        data.put_u16(length);
        data.put_u16(0);
        data.put_u8(1);
        data.put_u8(0);
        data.put_slice(payload);
        data
    }

    /// Builds an end-of-message SQL batch `Packet` carrying `payload`.
    fn sql_batch(payload: &[u8]) -> Packet {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 0);
        Packet::new(header, BytesMut::from(payload))
    }

    #[test]
    fn test_decode_packet() {
        let mut codec = TdsCodec::new();
        let mut data = raw_packet(12, b"test");
        let packet = codec.decode(&mut data).unwrap().unwrap();
        assert_eq!(packet.header.packet_type, PacketType::SqlBatch);
        assert!(packet.header.is_end_of_message());
        assert_eq!(&packet.payload[..], b"test");
        assert_eq!(packet.total_size(), 12);
        assert!(data.is_empty());
    }

    #[test]
    fn test_decode_back_to_back_packets() {
        let mut codec = TdsCodec::new();
        let mut data = raw_packet(12, b"test");
        data.extend_from_slice(&raw_packet(10, b"hi"));

        let first = codec.decode(&mut data).unwrap().unwrap();
        assert_eq!(&first.payload[..], b"test");
        let second = codec.decode(&mut data).unwrap().unwrap();
        assert_eq!(&second.payload[..], b"hi");
        assert!(data.is_empty());
        assert!(codec.decode(&mut data).unwrap().is_none());
    }

    #[test]
    fn test_incomplete_packet() {
        let mut codec = TdsCodec::new();
        let mut data = raw_packet(12, b"");
        let result = codec.decode(&mut data).unwrap();
        assert!(result.is_none());
        // Nothing may be consumed until the whole packet has arrived.
        assert_eq!(data.len(), PACKET_HEADER_SIZE);
    }

    #[test]
    fn test_decode_rejects_length_shorter_than_header() {
        let mut codec = TdsCodec::new();
        let mut data = raw_packet(4, b"");
        let err = codec.decode(&mut data).unwrap_err();
        assert!(matches!(err, CodecError::InvalidHeader));
    }

    #[test]
    fn test_decode_rejects_oversized_packet() {
        let mut codec = TdsCodec::new().with_max_packet_size(512);
        let mut data = raw_packet(1000, b"");
        let err = codec.decode(&mut data).unwrap_err();
        assert!(matches!(
            err,
            CodecError::PacketTooLarge {
                size: 1000,
                max: 512
            }
        ));
    }

    #[test]
    fn test_encode_packet() {
        let mut codec = TdsCodec::new();
        let mut dst = BytesMut::new();
        codec.encode(sql_batch(b"test"), &mut dst).unwrap();
        assert_eq!(dst.len(), 12);
        assert_eq!(dst[0], PacketType::SqlBatch as u8);
        assert_eq!(&dst[2..4], &[0, 12]);
        assert_eq!(&dst[PACKET_HEADER_SIZE..], b"test");
    }

    #[test]
    fn test_encode_rejects_oversized_packet() {
        let mut codec = TdsCodec::new().with_max_packet_size(512);
        let mut dst = BytesMut::new();
        let err = codec.encode(sql_batch(&[0u8; 600]), &mut dst).unwrap_err();
        assert!(matches!(
            err,
            CodecError::PacketTooLarge {
                size: 608,
                max: 512
            }
        ));
        assert!(dst.is_empty());
    }

    #[test]
    fn test_encode_increments_packet_id() {
        let mut codec = TdsCodec::new();
        let mut dst = BytesMut::new();
        codec.encode(sql_batch(b"test"), &mut dst).unwrap();
        codec.encode(sql_batch(b"test"), &mut dst).unwrap();
        assert_eq!(dst[6], 1);
        assert_eq!(dst[12 + 6], 2);
    }

    #[test]
    fn test_packet_id_skips_zero_on_wrap() {
        let mut codec = TdsCodec::new();
        codec.packet_id = 255;
        assert_eq!(codec.next_packet_id(), 255);
        assert_eq!(codec.next_packet_id(), 1);
        codec.reset_packet_id();
        assert_eq!(codec.next_packet_id(), 1);
    }
}