mssql_codec/
packet_codec.rs1use bytes::{BufMut, BytesMut};
4use tds_protocol::packet::{MAX_PACKET_SIZE, PACKET_HEADER_SIZE, PacketHeader};
5use tokio_util::codec::{Decoder, Encoder};
6
7use crate::error::CodecError;
8
9#[derive(Debug, Clone)]
11pub struct Packet {
12 pub header: PacketHeader,
14 pub payload: BytesMut,
16}
17
18impl Packet {
19 #[must_use]
21 pub fn new(header: PacketHeader, payload: BytesMut) -> Self {
22 Self { header, payload }
23 }
24
25 #[must_use]
27 pub fn total_size(&self) -> usize {
28 PACKET_HEADER_SIZE + self.payload.len()
29 }
30
31 #[must_use]
33 pub fn is_end_of_message(&self) -> bool {
34 self.header.is_end_of_message()
35 }
36}
37
/// Stateful codec that frames TDS packets.
///
/// Decoding splits a byte stream into [`Packet`]s using the length field in each
/// packet header. Encoding writes packets and fills in their length and a running
/// packet ID.
#[derive(Debug)]
pub struct TdsCodec {
    /// Largest packet (header included) that decoding accepts and encoding produces.
    max_packet_size: usize,
    /// Packet ID to assign to the next encoded packet.
    packet_id: u8,
}
48
49impl TdsCodec {
50 #[must_use]
52 pub fn new() -> Self {
53 Self {
54 max_packet_size: MAX_PACKET_SIZE,
55 packet_id: 1,
56 }
57 }
58
59 #[must_use]
61 pub fn with_max_packet_size(mut self, size: usize) -> Self {
62 self.max_packet_size = size.min(MAX_PACKET_SIZE);
63 self
64 }
65
66 fn next_packet_id(&mut self) -> u8 {
68 let id = self.packet_id;
69 self.packet_id = self.packet_id.wrapping_add(1);
70 if self.packet_id == 0 {
71 self.packet_id = 1;
72 }
73 id
74 }
75
76 pub fn reset_packet_id(&mut self) {
78 self.packet_id = 1;
79 }
80}
81
82impl Default for TdsCodec {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88impl Decoder for TdsCodec {
89 type Item = Packet;
90 type Error = CodecError;
91
92 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
93 if src.len() < PACKET_HEADER_SIZE {
95 return Ok(None);
96 }
97
98 let length = u16::from_be_bytes([src[2], src[3]]) as usize;
100
101 if length < PACKET_HEADER_SIZE {
103 return Err(CodecError::InvalidHeader);
104 }
105 if length > self.max_packet_size {
106 return Err(CodecError::PacketTooLarge {
107 size: length,
108 max: self.max_packet_size,
109 });
110 }
111
112 if src.len() < length {
114 src.reserve(length - src.len());
116 return Ok(None);
117 }
118
119 let packet_bytes = src.split_to(length);
121 let mut cursor = packet_bytes.as_ref();
122
123 let header = PacketHeader::decode(&mut cursor)?;
125
126 let payload = BytesMut::from(&packet_bytes[PACKET_HEADER_SIZE..]);
128
129 tracing::trace!(
130 packet_type = ?header.packet_type,
131 length = length,
132 is_eom = header.is_end_of_message(),
133 "decoded TDS packet"
134 );
135
136 Ok(Some(Packet::new(header, payload)))
137 }
138}
139
140impl Encoder<Packet> for TdsCodec {
141 type Error = CodecError;
142
143 fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
144 let total_length = PACKET_HEADER_SIZE + item.payload.len();
145
146 if total_length > self.max_packet_size {
147 return Err(CodecError::PacketTooLarge {
148 size: total_length,
149 max: self.max_packet_size,
150 });
151 }
152
153 dst.reserve(total_length);
155
156 let mut header = item.header;
158 header.length = total_length as u16;
159 header.packet_id = self.next_packet_id();
160
161 header.encode(dst);
163
164 dst.put_slice(&item.payload);
166
167 tracing::trace!(
168 packet_type = ?header.packet_type,
169 length = total_length,
170 packet_id = header.packet_id,
171 "encoded TDS packet"
172 );
173
174 Ok(())
175 }
176}
177
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tds_protocol::packet::{PacketStatus, PacketType};

    /// Appends an 8-byte SQL-batch, end-of-message header that announces
    /// `length` total bytes (header included) and carries `packet_id`.
    fn put_header(dst: &mut BytesMut, length: u16, packet_id: u8) {
        dst.put_u8(PacketType::SqlBatch as u8);
        dst.put_u8(PacketStatus::END_OF_MESSAGE.bits());
        dst.put_u16(length);
        dst.put_u16(0); // SPID
        dst.put_u8(packet_id);
        dst.put_u8(0); // window
    }

    /// Builds an end-of-message SQL-batch packet carrying `payload`.
    fn batch_packet(payload: &[u8]) -> Packet {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 0);
        Packet::new(header, BytesMut::from(payload))
    }

    #[test]
    fn test_decode_packet() {
        let mut codec = TdsCodec::new();

        let mut data = BytesMut::new();
        put_header(&mut data, 12, 1);
        data.put_slice(b"test");

        let packet = codec.decode(&mut data).unwrap().unwrap();

        assert_eq!(packet.header.packet_type, PacketType::SqlBatch);
        assert!(packet.header.is_end_of_message());
        assert_eq!(&packet.payload[..], b"test");
        assert!(data.is_empty(), "decoded bytes must be consumed");
    }

    #[test]
    fn test_encode_packet() {
        let mut codec = TdsCodec::new();

        let mut dst = BytesMut::new();
        codec.encode(batch_packet(b"test"), &mut dst).unwrap();

        assert_eq!(dst.len(), 12);
        assert_eq!(dst[0], PacketType::SqlBatch as u8);
        // Length field is big-endian and covers header plus payload.
        assert_eq!(&dst[2..4], &12u16.to_be_bytes());
        assert_eq!(&dst[PACKET_HEADER_SIZE..], b"test");
    }

    #[test]
    fn test_incomplete_packet() {
        let mut codec = TdsCodec::new();

        // The header announces 12 bytes, but the 4-byte payload has not arrived yet.
        let mut data = BytesMut::new();
        put_header(&mut data, 12, 1);

        assert!(codec.decode(&mut data).unwrap().is_none());
        assert_eq!(
            data.len(),
            PACKET_HEADER_SIZE,
            "a partial packet must not be consumed"
        );
    }

    #[test]
    fn test_decode_rejects_length_below_header_size() {
        let mut codec = TdsCodec::new();

        let mut data = BytesMut::new();
        put_header(&mut data, 4, 1);

        let err = codec.decode(&mut data).unwrap_err();
        assert!(matches!(err, CodecError::InvalidHeader));
    }

    #[test]
    fn test_decode_rejects_oversized_packet() {
        let mut codec = TdsCodec::new().with_max_packet_size(16);

        let mut data = BytesMut::new();
        put_header(&mut data, 32, 1);

        let err = codec.decode(&mut data).unwrap_err();
        assert!(matches!(err, CodecError::PacketTooLarge { size: 32, max: 16 }));
    }

    #[test]
    fn test_encode_rejects_oversized_packet() {
        let mut codec = TdsCodec::new().with_max_packet_size(16);

        let mut dst = BytesMut::new();
        let err = codec.encode(batch_packet(&[0u8; 9]), &mut dst).unwrap_err();

        assert!(matches!(err, CodecError::PacketTooLarge { size: 17, max: 16 }));
        assert!(dst.is_empty(), "nothing may be written on error");
    }

    #[test]
    fn test_round_trip_assigns_sequential_packet_ids() {
        let mut codec = TdsCodec::new();

        let mut wire = BytesMut::new();
        codec.encode(batch_packet(b"first"), &mut wire).unwrap();
        codec.encode(batch_packet(b"second"), &mut wire).unwrap();

        let first = codec.decode(&mut wire).unwrap().unwrap();
        let second = codec.decode(&mut wire).unwrap().unwrap();

        assert_eq!(first.header.packet_id, 1);
        assert_eq!(second.header.packet_id, 2);
        assert_eq!(&first.payload[..], b"first");
        assert_eq!(&second.payload[..], b"second");
        assert_eq!(second.total_size(), PACKET_HEADER_SIZE + 6);
        assert!(wire.is_empty());
    }

    #[test]
    fn test_packet_id_skips_zero_on_wrap() {
        let mut codec = TdsCodec::new();
        codec.packet_id = u8::MAX;

        assert_eq!(codec.next_packet_id(), u8::MAX);
        assert_eq!(codec.next_packet_id(), 1);
    }
}