mssql_codec/
packet_codec.rs

1//! TDS packet codec implementation.
2
3use bytes::{BufMut, BytesMut};
4use tds_protocol::packet::{MAX_PACKET_SIZE, PACKET_HEADER_SIZE, PacketHeader};
5use tokio_util::codec::{Decoder, Encoder};
6
7use crate::error::CodecError;
8
9/// A TDS packet with header and payload.
10#[derive(Debug, Clone)]
11pub struct Packet {
12    /// Packet header.
13    pub header: PacketHeader,
14    /// Packet payload (excluding header).
15    pub payload: BytesMut,
16}
17
18impl Packet {
19    /// Create a new packet with the given header and payload.
20    #[must_use]
21    pub fn new(header: PacketHeader, payload: BytesMut) -> Self {
22        Self { header, payload }
23    }
24
25    /// Get the total packet size including header.
26    #[must_use]
27    pub fn total_size(&self) -> usize {
28        PACKET_HEADER_SIZE + self.payload.len()
29    }
30
31    /// Check if this is the last packet in a message.
32    #[must_use]
33    pub fn is_end_of_message(&self) -> bool {
34        self.header.is_end_of_message()
35    }
36}
37
/// TDS packet codec for tokio-util framing.
///
/// This codec handles the low-level encoding and decoding of TDS packets
/// over a byte stream. When encoding, it stamps every outgoing packet with a
/// sequence number taken from an internal counter (see `reset_packet_id`).
#[derive(Debug)]
pub struct TdsCodec {
    /// Maximum total packet size (header included) that the decoder accepts
    /// and the encoder produces.
    max_packet_size: usize,
    /// Packet ID that will be assigned to the next encoded packet.
    packet_id: u8,
}
48
49impl TdsCodec {
50    /// Create a new TDS codec with default settings.
51    #[must_use]
52    pub fn new() -> Self {
53        Self {
54            max_packet_size: MAX_PACKET_SIZE,
55            packet_id: 1,
56        }
57    }
58
59    /// Create a new TDS codec with a custom maximum packet size.
60    #[must_use]
61    pub fn with_max_packet_size(mut self, size: usize) -> Self {
62        self.max_packet_size = size.min(MAX_PACKET_SIZE);
63        self
64    }
65
66    /// Get the next packet ID and increment the counter.
67    fn next_packet_id(&mut self) -> u8 {
68        let id = self.packet_id;
69        self.packet_id = self.packet_id.wrapping_add(1);
70        if self.packet_id == 0 {
71            self.packet_id = 1;
72        }
73        id
74    }
75
76    /// Reset the packet ID counter.
77    pub fn reset_packet_id(&mut self) {
78        self.packet_id = 1;
79    }
80}
81
82impl Default for TdsCodec {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl Decoder for TdsCodec {
89    type Item = Packet;
90    type Error = CodecError;
91
92    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
93        // Need at least a header to proceed
94        if src.len() < PACKET_HEADER_SIZE {
95            return Ok(None);
96        }
97
98        // Peek at the header to get the length
99        let length = u16::from_be_bytes([src[2], src[3]]) as usize;
100
101        // Validate packet length
102        if length < PACKET_HEADER_SIZE {
103            return Err(CodecError::InvalidHeader);
104        }
105        if length > self.max_packet_size {
106            return Err(CodecError::PacketTooLarge {
107                size: length,
108                max: self.max_packet_size,
109            });
110        }
111
112        // Check if we have the complete packet
113        if src.len() < length {
114            // Reserve space for the full packet
115            src.reserve(length - src.len());
116            return Ok(None);
117        }
118
119        // Extract the packet bytes
120        let packet_bytes = src.split_to(length);
121        let mut cursor = packet_bytes.as_ref();
122
123        // Parse the header
124        let header = PacketHeader::decode(&mut cursor)?;
125
126        // Extract payload
127        let payload = BytesMut::from(&packet_bytes[PACKET_HEADER_SIZE..]);
128
129        tracing::trace!(
130            packet_type = ?header.packet_type,
131            length = length,
132            is_eom = header.is_end_of_message(),
133            "decoded TDS packet"
134        );
135
136        Ok(Some(Packet::new(header, payload)))
137    }
138}
139
140impl Encoder<Packet> for TdsCodec {
141    type Error = CodecError;
142
143    fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
144        let total_length = PACKET_HEADER_SIZE + item.payload.len();
145
146        if total_length > self.max_packet_size {
147            return Err(CodecError::PacketTooLarge {
148                size: total_length,
149                max: self.max_packet_size,
150            });
151        }
152
153        // Reserve space
154        dst.reserve(total_length);
155
156        // Create header with correct length and packet ID
157        let mut header = item.header;
158        header.length = total_length as u16;
159        header.packet_id = self.next_packet_id();
160
161        // Encode header
162        header.encode(dst);
163
164        // Encode payload
165        dst.put_slice(&item.payload);
166
167        tracing::trace!(
168            packet_type = ?header.packet_type,
169            length = total_length,
170            packet_id = header.packet_id,
171            "encoded TDS packet"
172        );
173
174        Ok(())
175    }
176}
177
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is the idiomatic way to fail a test on an unexpected error
mod tests {
    use super::*;
    use tds_protocol::packet::{PacketStatus, PacketType};

    /// Appends a raw 8-byte SQL-batch header (EOM set, packet ID 1) announcing
    /// `length` total bytes.
    fn put_sql_batch_header(dst: &mut BytesMut, length: u16) {
        dst.put_u8(PacketType::SqlBatch as u8); // type
        dst.put_u8(PacketStatus::END_OF_MESSAGE.bits()); // status
        dst.put_u16(length); // length (header + payload)
        dst.put_u16(0); // spid
        dst.put_u8(1); // packet_id
        dst.put_u8(0); // window
    }

    /// Builds an end-of-message SQL-batch packet carrying `payload`.
    fn sql_batch_packet(payload: &[u8]) -> Packet {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 0);
        Packet::new(header, BytesMut::from(payload))
    }

    #[test]
    fn test_decode_packet() {
        let mut codec = TdsCodec::new();

        // Minimal packet: header (8 bytes) + 4 bytes payload
        let mut data = BytesMut::new();
        put_sql_batch_header(&mut data, 12);
        data.put_slice(b"test"); // payload

        let packet = codec.decode(&mut data).unwrap().unwrap();
        assert_eq!(packet.header.packet_type, PacketType::SqlBatch);
        assert!(packet.header.is_end_of_message());
        assert_eq!(&packet.payload[..], b"test");
        assert_eq!(packet.total_size(), 12);
        // The whole packet must have been consumed from the input buffer.
        assert!(data.is_empty());
    }

    #[test]
    fn test_encode_packet() {
        let mut codec = TdsCodec::new();

        let mut dst = BytesMut::new();
        codec.encode(sql_batch_packet(b"test"), &mut dst).unwrap();

        assert_eq!(dst.len(), 12); // 8 header + 4 payload
        assert_eq!(dst[0], PacketType::SqlBatch as u8);
        // The length field (bytes 2..4) must be what the decoder reads back.
        assert_eq!(u16::from_be_bytes([dst[2], dst[3]]), 12);
    }

    #[test]
    fn test_incomplete_packet() {
        let mut codec = TdsCodec::new();

        // Header claims 12 bytes, but the 4 payload bytes are missing.
        let mut data = BytesMut::new();
        put_sql_batch_header(&mut data, 12);

        let result = codec.decode(&mut data).unwrap();
        assert!(result.is_none()); // Should return None for incomplete
        // Nothing may be consumed while waiting for the rest of the packet.
        assert_eq!(data.len(), PACKET_HEADER_SIZE);
    }

    #[test]
    fn test_decode_rejects_length_smaller_than_header() {
        let mut codec = TdsCodec::new();

        let mut data = BytesMut::new();
        put_sql_batch_header(&mut data, 4);

        let err = codec.decode(&mut data).unwrap_err();
        assert!(matches!(err, CodecError::InvalidHeader));
    }

    #[test]
    fn test_decode_rejects_oversized_packet() {
        let mut codec = TdsCodec::new().with_max_packet_size(512);

        // Only the header is needed: the size check happens before waiting for the payload.
        let mut data = BytesMut::new();
        put_sql_batch_header(&mut data, 1000);

        let err = codec.decode(&mut data).unwrap_err();
        assert!(matches!(
            err,
            CodecError::PacketTooLarge {
                size: 1000,
                max: 512
            }
        ));
    }

    #[test]
    fn test_encode_rejects_oversized_packet_without_side_effects() {
        let mut codec = TdsCodec::new().with_max_packet_size(512);

        let mut dst = BytesMut::new();
        let err = codec.encode(sql_batch_packet(&[0u8; 600]), &mut dst).unwrap_err();
        assert!(matches!(err, CodecError::PacketTooLarge { size: 608, max: 512 }));
        assert!(dst.is_empty());

        // The rejected packet must not have consumed a packet ID.
        codec.encode(sql_batch_packet(b"ok"), &mut dst).unwrap();
        let packet = codec.decode(&mut dst).unwrap().unwrap();
        assert_eq!(packet.header.packet_id, 1);
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let mut codec = TdsCodec::new();

        let mut wire = BytesMut::new();
        codec.encode(sql_batch_packet(b"first"), &mut wire).unwrap();
        codec.encode(sql_batch_packet(b"second"), &mut wire).unwrap();

        let first = codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(first.header.packet_id, 1);
        assert_eq!(first.header.length, 13);
        assert_eq!(&first.payload[..], b"first");

        let second = codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(second.header.packet_id, 2);
        assert_eq!(second.header.length, 14);
        assert_eq!(&second.payload[..], b"second");

        assert!(codec.decode(&mut wire).unwrap().is_none());
        assert!(wire.is_empty());
    }

    #[test]
    fn test_packet_id_wraps_and_resets() {
        let mut codec = TdsCodec::new();
        codec.packet_id = 254;

        assert_eq!(codec.next_packet_id(), 254);
        assert_eq!(codec.next_packet_id(), 255);
        // Wrap-around skips 0 and restarts at 1.
        assert_eq!(codec.next_packet_id(), 1);
        assert_eq!(codec.next_packet_id(), 2);

        codec.reset_packet_id();
        assert_eq!(codec.next_packet_id(), 1);
    }
}