1use bitflags::bitflags;
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5
6use crate::error::ProtocolError;
7
/// Size in bytes of the fixed packet header: type (1) + status (1) +
/// length (2) + spid (2) + packet id (1) + window (1).
pub const PACKET_HEADER_SIZE: usize = 8;

/// Largest packet size representable by the header's 16-bit `length` field.
pub const MAX_PACKET_SIZE: usize = u16::MAX as usize;

/// Default packet size of 4 KiB (header included).
///
/// NOTE(review): how this default is applied is decided by callers outside
/// this module — confirm against the connection setup code.
pub const DEFAULT_PACKET_SIZE: usize = 4 * 1024;
16
/// Message type carried in the first byte of every packet header.
///
/// The enum is `#[repr(u8)]`, so `packet_type as u8` yields exactly the wire
/// value listed on each variant (this is how `PacketHeader::encode` writes it).
/// It is `#[non_exhaustive]`, so code outside this crate must keep a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum PacketType {
    /// SQL batch (`0x01`).
    SqlBatch = 0x01,
    /// Pre-TDS 7 login (`0x02`).
    PreTds7Login = 0x02,
    /// Remote procedure call (`0x03`).
    Rpc = 0x03,
    /// Tabular result (`0x04`).
    TabularResult = 0x04,
    /// Attention (`0x06`). No variant uses `0x05`.
    Attention = 0x06,
    /// Bulk load (`0x07`).
    BulkLoad = 0x07,
    /// Federated-authentication token (`0x08`).
    FedAuthToken = 0x08,
    /// Transaction manager (`0x0E`).
    TransactionManager = 0x0E,
    /// TDS 7 login (`0x10`).
    Tds7Login = 0x10,
    /// SSPI (`0x11`).
    Sspi = 0x11,
    /// Pre-login (`0x12`).
    PreLogin = 0x12,
}
45
46impl PacketType {
47 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
49 match value {
50 0x01 => Ok(Self::SqlBatch),
51 0x02 => Ok(Self::PreTds7Login),
52 0x03 => Ok(Self::Rpc),
53 0x04 => Ok(Self::TabularResult),
54 0x06 => Ok(Self::Attention),
55 0x07 => Ok(Self::BulkLoad),
56 0x08 => Ok(Self::FedAuthToken),
57 0x0E => Ok(Self::TransactionManager),
58 0x10 => Ok(Self::Tds7Login),
59 0x11 => Ok(Self::Sspi),
60 0x12 => Ok(Self::PreLogin),
61 _ => Err(ProtocolError::InvalidPacketType(value)),
62 }
63 }
64}
65
66bitflags! {
67 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
69 pub struct PacketStatus: u8 {
70 const NORMAL = 0x00;
72 const END_OF_MESSAGE = 0x01;
74 const IGNORE_EVENT = 0x02;
76 const RESET_CONNECTION = 0x08;
78 const RESET_CONNECTION_KEEP_TRANSACTION = 0x10;
80 }
81}
82
/// The fixed 8-byte header that precedes every packet's payload.
///
/// Fields are declared in the same order they appear on the wire
/// (`PacketHeader::encode`/`decode` read and write them in this order); the
/// two `u16` fields are encoded big-endian.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketHeader {
    /// Message type (byte 0).
    pub packet_type: PacketType,
    /// Status flags (byte 1).
    pub status: PacketStatus,
    /// Total packet length in bytes, header included (bytes 2-3).
    /// `payload_length` subtracts the header size from this value.
    pub length: u16,
    /// SPID (bytes 4-5). Presumably the server process/session id — confirm.
    pub spid: u16,
    /// Packet id (byte 6). `new` starts it at 0 and `Default` at 1.
    pub packet_id: u8,
    /// Window byte (byte 7); always 0 from the constructors in this module.
    pub window: u8,
}
102
103impl PacketHeader {
104 #[must_use]
106 pub const fn new(packet_type: PacketType, status: PacketStatus, length: u16) -> Self {
107 Self {
108 packet_type,
109 status,
110 length,
111 spid: 0,
112 packet_id: 0,
113 window: 0,
114 }
115 }
116
117 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
119 if src.remaining() < PACKET_HEADER_SIZE {
120 return Err(ProtocolError::IncompletePacket {
121 expected: PACKET_HEADER_SIZE,
122 actual: src.remaining(),
123 });
124 }
125
126 let packet_type = PacketType::from_u8(src.get_u8())?;
127 let status_byte = src.get_u8();
128 let status = PacketStatus::from_bits(status_byte)
129 .ok_or(ProtocolError::InvalidPacketStatus(status_byte))?;
130 let length = src.get_u16();
131 let spid = src.get_u16();
132 let packet_id = src.get_u8();
133 let window = src.get_u8();
134
135 Ok(Self {
136 packet_type,
137 status,
138 length,
139 spid,
140 packet_id,
141 window,
142 })
143 }
144
145 pub fn encode(&self, dst: &mut impl BufMut) {
147 dst.put_u8(self.packet_type as u8);
148 dst.put_u8(self.status.bits());
149 dst.put_u16(self.length);
150 dst.put_u16(self.spid);
151 dst.put_u8(self.packet_id);
152 dst.put_u8(self.window);
153 }
154
155 #[must_use]
157 pub fn encode_to_bytes(&self) -> Bytes {
158 let mut buf = BytesMut::with_capacity(PACKET_HEADER_SIZE);
159 self.encode(&mut buf);
160 buf.freeze()
161 }
162
163 #[must_use]
165 pub const fn payload_length(&self) -> usize {
166 self.length.saturating_sub(PACKET_HEADER_SIZE as u16) as usize
167 }
168
169 #[must_use]
171 pub const fn is_end_of_message(&self) -> bool {
172 self.status.contains(PacketStatus::END_OF_MESSAGE)
173 }
174
175 #[must_use]
177 pub const fn with_packet_id(mut self, id: u8) -> Self {
178 self.packet_id = id;
179 self
180 }
181
182 #[must_use]
184 pub const fn with_spid(mut self, spid: u16) -> Self {
185 self.spid = spid;
186 self
187 }
188}
189
190impl Default for PacketHeader {
191 fn default() -> Self {
192 Self {
193 packet_type: PacketType::SqlBatch,
194 status: PacketStatus::END_OF_MESSAGE,
195 length: PACKET_HEADER_SIZE as u16,
196 spid: 0,
197 packet_id: 1,
198 window: 0,
199 }
200 }
201}
202
#[cfg(test)]
// Tests decode known-good fixtures; panicking on an unexpected `Ok`/`Err` is
// exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a raw 8-byte header with the given type and status bytes.
    fn raw_header(type_byte: u8, status_byte: u8) -> [u8; PACKET_HEADER_SIZE] {
        [type_byte, status_byte, 0x00, 0x08, 0x00, 0x00, 0x01, 0x00]
    }

    #[test]
    fn test_header_roundtrip() {
        let header = PacketHeader {
            packet_type: PacketType::SqlBatch,
            status: PacketStatus::END_OF_MESSAGE,
            length: 100,
            spid: 54,
            packet_id: 1,
            window: 0,
        };

        let bytes = header.encode_to_bytes();
        assert_eq!(bytes.len(), PACKET_HEADER_SIZE);

        let mut cursor = bytes.as_ref();
        let decoded = PacketHeader::decode(&mut cursor).unwrap();
        assert_eq!(header, decoded);
        assert!(cursor.is_empty());
    }

    #[test]
    fn test_encode_wire_layout_is_big_endian() {
        let header = PacketHeader::new(PacketType::Rpc, PacketStatus::END_OF_MESSAGE, 0x0102)
            .with_spid(0x0304)
            .with_packet_id(5);
        let bytes = header.encode_to_bytes();
        assert_eq!(
            &bytes[..],
            &[0x03u8, 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00][..]
        );
    }

    #[test]
    fn test_decode_incomplete_consumes_nothing() {
        let raw = [0x01u8, 0x01, 0x00];
        let mut cursor = &raw[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::IncompletePacket {
                expected: PACKET_HEADER_SIZE,
                actual: 3
            }
        ));
        assert_eq!(cursor.len(), 3);
    }

    #[test]
    fn test_decode_rejects_unknown_packet_type() {
        let raw = raw_header(0xFF, 0x01);
        let mut cursor = &raw[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidPacketType(0xFF)));
    }

    #[test]
    fn test_decode_rejects_undeclared_status_bits() {
        // 0x04 is not one of the declared `PacketStatus` flags.
        let raw = raw_header(0x01, 0x04);
        let mut cursor = &raw[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidPacketStatus(0x04)));
    }

    #[test]
    fn test_combined_status_roundtrip() {
        let status = PacketStatus::END_OF_MESSAGE | PacketStatus::RESET_CONNECTION;
        let header = PacketHeader::new(PacketType::SqlBatch, status, 8);
        let bytes = header.encode_to_bytes();
        assert_eq!(bytes[1], 0x09);

        let mut cursor = &bytes[..];
        let decoded = PacketHeader::decode(&mut cursor).unwrap();
        assert_eq!(decoded.status, status);
        assert!(decoded.is_end_of_message());
    }

    #[test]
    fn test_payload_length() {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 100);
        assert_eq!(header.payload_length(), 92);

        // A length shorter than the header saturates to zero instead of
        // underflowing.
        let runt = PacketHeader::new(PacketType::SqlBatch, PacketStatus::NORMAL, 3);
        assert_eq!(runt.payload_length(), 0);
        assert!(!runt.is_end_of_message());
    }

    #[test]
    fn test_default_header() {
        let header = PacketHeader::default();
        assert_eq!(header.packet_type, PacketType::SqlBatch);
        assert!(header.is_end_of_message());
        assert_eq!(header.length as usize, PACKET_HEADER_SIZE);
        assert_eq!(header.packet_id, 1);
        assert_eq!(header.spid, 0);
        assert_eq!(header.window, 0);
        assert_eq!(header.payload_length(), 0);
    }

    #[test]
    fn test_packet_type_from_u8() {
        assert_eq!(PacketType::from_u8(0x01).unwrap(), PacketType::SqlBatch);
        assert_eq!(PacketType::from_u8(0x12).unwrap(), PacketType::PreLogin);
        assert!(PacketType::from_u8(0xFF).is_err());
        // Unassigned values inside the range must also be rejected.
        assert!(PacketType::from_u8(0x00).is_err());
        assert!(PacketType::from_u8(0x05).is_err());
    }

    #[test]
    fn test_packet_type_roundtrips_through_u8() {
        for ty in [
            PacketType::SqlBatch,
            PacketType::PreTds7Login,
            PacketType::Rpc,
            PacketType::TabularResult,
            PacketType::Attention,
            PacketType::BulkLoad,
            PacketType::FedAuthToken,
            PacketType::TransactionManager,
            PacketType::Tds7Login,
            PacketType::Sspi,
            PacketType::PreLogin,
        ] {
            assert_eq!(PacketType::from_u8(ty as u8).unwrap(), ty);
        }
    }
}