Skip to main content

mssql_codec/
packet_codec.rs

1//! TDS packet codec implementation.
2
3use bytes::{BufMut, BytesMut};
4use tds_protocol::packet::{MAX_PACKET_SIZE, PACKET_HEADER_SIZE, PacketHeader};
5use tokio_util::codec::{Decoder, Encoder};
6
7use crate::error::CodecError;
8
9/// A TDS packet with header and payload.
10#[derive(Debug, Clone)]
11pub struct Packet {
12    /// Packet header.
13    pub header: PacketHeader,
14    /// Packet payload (excluding header).
15    pub payload: BytesMut,
16}
17
18impl Packet {
19    /// Create a new packet with the given header and payload.
20    #[must_use]
21    pub fn new(header: PacketHeader, payload: BytesMut) -> Self {
22        Self { header, payload }
23    }
24
25    /// Get the total packet size including header.
26    #[must_use]
27    pub fn total_size(&self) -> usize {
28        PACKET_HEADER_SIZE + self.payload.len()
29    }
30
31    /// Check if this is the last packet in a message.
32    #[must_use]
33    pub fn is_end_of_message(&self) -> bool {
34        self.header.is_end_of_message()
35    }
36}
37
/// TDS packet codec for tokio-util framing.
///
/// This codec handles the low-level encoding and decoding of TDS packets
/// over a byte stream. It carries two pieces of state: the size limit applied
/// to both incoming and outgoing packets, and the sequence number stamped on
/// the next outgoing packet.
///
/// `Debug` and `Clone` are derived so that types embedding the codec can
/// derive them too; a clone starts with the same limit and packet counter.
#[derive(Debug, Clone)]
pub struct TdsCodec {
    /// Maximum packet size (header included) to accept or emit, in bytes.
    max_packet_size: usize,
    /// Sequence number that will be written into the next encoded packet.
    packet_id: u8,
}
48
49impl TdsCodec {
50    /// Create a new TDS codec with default settings.
51    #[must_use]
52    pub fn new() -> Self {
53        Self {
54            max_packet_size: MAX_PACKET_SIZE,
55            packet_id: 1,
56        }
57    }
58
59    /// Create a new TDS codec with a custom maximum packet size.
60    #[must_use]
61    pub fn with_max_packet_size(mut self, size: usize) -> Self {
62        self.max_packet_size = size.min(MAX_PACKET_SIZE);
63        self
64    }
65
66    /// Get the next packet ID and increment the counter.
67    fn next_packet_id(&mut self) -> u8 {
68        let id = self.packet_id;
69        self.packet_id = self.packet_id.wrapping_add(1);
70        if self.packet_id == 0 {
71            self.packet_id = 1;
72        }
73        id
74    }
75
76    /// Reset the packet ID counter.
77    pub fn reset_packet_id(&mut self) {
78        self.packet_id = 1;
79    }
80}
81
82impl Default for TdsCodec {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl Decoder for TdsCodec {
89    type Item = Packet;
90    type Error = CodecError;
91
92    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
93        // Need at least a header to proceed
94        if src.len() < PACKET_HEADER_SIZE {
95            return Ok(None);
96        }
97
98        // Peek at the header to get the length
99        let length = u16::from_be_bytes([src[2], src[3]]) as usize;
100
101        // Validate packet length
102        if length < PACKET_HEADER_SIZE {
103            return Err(CodecError::InvalidHeader);
104        }
105        if length > self.max_packet_size {
106            return Err(CodecError::PacketTooLarge {
107                size: length,
108                max: self.max_packet_size,
109            });
110        }
111
112        // Check if we have the complete packet
113        if src.len() < length {
114            // Reserve space for the full packet
115            src.reserve(length - src.len());
116            return Ok(None);
117        }
118
119        // Extract the packet bytes
120        let packet_bytes = src.split_to(length);
121        let mut cursor = packet_bytes.as_ref();
122
123        // Parse the header
124        let header = PacketHeader::decode(&mut cursor)?;
125
126        // Extract payload
127        let payload = BytesMut::from(&packet_bytes[PACKET_HEADER_SIZE..]);
128
129        tracing::trace!(
130            packet_type = ?header.packet_type,
131            length = length,
132            is_eom = header.is_end_of_message(),
133            "decoded TDS packet"
134        );
135
136        Ok(Some(Packet::new(header, payload)))
137    }
138}
139
140impl Encoder<Packet> for TdsCodec {
141    type Error = CodecError;
142
143    fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
144        let total_length = PACKET_HEADER_SIZE + item.payload.len();
145
146        if total_length > self.max_packet_size {
147            return Err(CodecError::PacketTooLarge {
148                size: total_length,
149                max: self.max_packet_size,
150            });
151        }
152
153        // Reserve space
154        dst.reserve(total_length);
155
156        // Create header with correct length and packet ID
157        let mut header = item.header;
158        header.length = total_length as u16;
159        header.packet_id = self.next_packet_id();
160
161        // Encode header
162        header.encode(dst);
163
164        // Encode payload
165        dst.put_slice(&item.payload);
166
167        tracing::trace!(
168            packet_type = ?header.packet_type,
169            length = total_length,
170            packet_id = header.packet_id,
171            "encoded TDS packet"
172        );
173
174        Ok(())
175    }
176}
177
// Tests are allowed to unwrap/expect/panic: a failure there is the test failing.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use tds_protocol::packet::{PacketStatus, PacketType};

    /// Build an 8-byte TDS header with an arbitrary length field (which may be
    /// intentionally invalid for testing).
    fn header_with_length(len: u16) -> BytesMut {
        let mut data = BytesMut::new();
        data.put_u8(PacketType::SqlBatch as u8); // type
        data.put_u8(PacketStatus::END_OF_MESSAGE.bits()); // status
        data.put_u16(len); // length
        data.put_u16(0); // spid
        data.put_u8(1); // packet_id
        data.put_u8(0); // window
        data
    }

    /// Build a complete packet: a header whose length field matches
    /// `payload`, followed by `payload` itself.
    fn packet_with_payload(payload: &[u8]) -> BytesMut {
        let len = u16::try_from(PACKET_HEADER_SIZE + payload.len())
            .expect("test payload must fit the 16-bit length field");
        let mut data = header_with_length(len);
        data.put_slice(payload);
        data
    }

    #[test]
    fn test_decode_packet() {
        let mut codec = TdsCodec::new();

        // Minimal packet: header (8 bytes) + 4 bytes payload.
        let mut data = packet_with_payload(b"test");

        let packet = codec.decode(&mut data).unwrap().unwrap();
        assert_eq!(packet.header.packet_type, PacketType::SqlBatch);
        assert!(packet.header.is_end_of_message());
        assert_eq!(&packet.payload[..], b"test");
    }

    #[test]
    fn test_encode_packet() {
        let mut codec = TdsCodec::new();

        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 0);
        let payload = BytesMut::from(&b"test"[..]);
        let packet = Packet::new(header, payload);

        let mut dst = BytesMut::new();
        codec.encode(packet, &mut dst).unwrap();

        assert_eq!(dst.len(), 12); // 8 header + 4 payload
        assert_eq!(dst[0], PacketType::SqlBatch as u8);

        // Round trip: what the encoder wrote must decode back to the same
        // packet, carrying the first sequence number the codec hands out.
        let decoded = codec
            .decode(&mut dst)
            .unwrap()
            .expect("encoded packet must decode");
        assert_eq!(decoded.header.packet_type, PacketType::SqlBatch);
        assert!(decoded.is_end_of_message());
        assert_eq!(decoded.header.packet_id, 1);
        assert_eq!(&decoded.payload[..], b"test");
        assert!(dst.is_empty(), "decode must consume the whole packet");
    }

    #[test]
    fn test_incomplete_packet() {
        let mut codec = TdsCodec::new();

        // Header claims 12 bytes, but the 4 payload bytes are missing.
        let mut data = header_with_length(12);

        let result = codec.decode(&mut data).unwrap();
        assert!(result.is_none()); // Should return None for incomplete
    }

    /// Issue #165: a length field smaller than the 8-byte header is malformed
    /// and must be rejected, not silently accepted or panicked on.
    #[test]
    fn test_decode_rejects_length_below_header_size() {
        let mut codec = TdsCodec::new();
        let mut data = header_with_length(4); // < PACKET_HEADER_SIZE (8)
        assert!(matches!(
            codec.decode(&mut data),
            Err(CodecError::InvalidHeader)
        ));
    }

    /// Issue #165: a declared length above the negotiated maximum must be
    /// rejected before any allocation/read of the claimed size.
    #[test]
    fn test_decode_rejects_packet_too_large() {
        let mut codec = TdsCodec::new().with_max_packet_size(16);
        // Header claims 20 bytes; only the 8-byte header is present, but the
        // size check must fire before the completeness check.
        let mut data = header_with_length(20);
        match codec.decode(&mut data) {
            Err(CodecError::PacketTooLarge { size, max }) => {
                assert_eq!(size, 20);
                assert_eq!(max, 16);
            }
            other => panic!("expected PacketTooLarge, got {other:?}"),
        }
    }

    /// Issue #165: encoding a packet whose total length exceeds the maximum
    /// must error rather than emit a truncated/overflowing length field.
    #[test]
    fn test_encode_rejects_packet_too_large() {
        let mut codec = TdsCodec::new().with_max_packet_size(16);
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 0);
        // 8-byte header + 16-byte payload = 24 > 16.
        let payload = BytesMut::from(&[0u8; 16][..]);
        let mut dst = BytesMut::new();
        match codec.encode(Packet::new(header, payload), &mut dst) {
            Err(CodecError::PacketTooLarge { size, max }) => {
                assert_eq!(size, 24);
                assert_eq!(max, 16);
            }
            other => panic!("expected PacketTooLarge, got {other:?}"),
        }
        assert!(dst.is_empty(), "a rejected packet must not write anything");
    }

    /// Issue #165: the packet-id counter wraps 255 → 1, skipping 0 (TDS
    /// packet IDs are 1-based; a 0 would be misread by the server).
    #[test]
    fn test_packet_id_wraps_past_zero() {
        let mut codec = TdsCodec::new();
        let mut saw_zero = false;
        let mut saw_wrap_to_one = false;
        let mut prev = codec.next_packet_id(); // first id (1)
        for _ in 0..600 {
            let id = codec.next_packet_id();
            if id == 0 {
                saw_zero = true;
            }
            if prev == 255 {
                assert_eq!(id, 1, "after 255 the id must skip 0 and become 1");
                saw_wrap_to_one = true;
            }
            prev = id;
        }
        assert!(!saw_zero, "packet id 0 must never be emitted");
        assert!(saw_wrap_to_one, "the test must exercise the 255→1 wrap");
    }

    /// Issue #165: two complete packets concatenated in one buffer must both
    /// decode, with the buffer fully consumed.
    #[test]
    fn test_decode_two_packets_from_one_buffer() {
        let mut codec = TdsCodec::new();
        let mut data = packet_with_payload(b"aaaa");
        data.put_slice(&packet_with_payload(b"bbbb"));

        let p1 = codec.decode(&mut data).unwrap().expect("first packet");
        assert_eq!(&p1.payload[..], b"aaaa");
        let p2 = codec.decode(&mut data).unwrap().expect("second packet");
        assert_eq!(&p2.payload[..], b"bbbb");
        assert!(data.is_empty(), "buffer must be fully consumed");
        assert!(codec.decode(&mut data).unwrap().is_none());
    }

    /// Issue #165: a packet arriving in two reads (partial header, then the
    /// rest) must decode once the full packet is present.
    #[test]
    fn test_decode_incremental_feed() {
        let mut codec = TdsCodec::new();
        let full = packet_with_payload(b"test");

        // Feed only the first 5 bytes (partial header).
        let mut data = BytesMut::new();
        data.put_slice(&full[..5]);
        assert!(codec.decode(&mut data).unwrap().is_none());

        // Feed the remaining bytes; now it decodes.
        data.put_slice(&full[5..]);
        let p = codec
            .decode(&mut data)
            .unwrap()
            .expect("packet after full feed");
        assert_eq!(&p.payload[..], b"test");
        assert!(data.is_empty());
    }
}