use bitflags::bitflags;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::error::ProtocolError;
/// Size in bytes of the fixed packet header handled by `PacketHeader::encode`
/// and `PacketHeader::decode` (1 + 1 + 2 + 2 + 1 + 1 bytes).
pub const PACKET_HEADER_SIZE: usize = 8;
/// Largest packet size in bytes; equal to `u16::MAX`, the biggest value the
/// 16-bit `length` header field can hold.
pub const MAX_PACKET_SIZE: usize = 65535;
// NOTE(review): not referenced anywhere in this file — presumably the packet
// size assumed before any size negotiation; confirm against callers.
/// Default packet size in bytes.
pub const DEFAULT_PACKET_SIZE: usize = 4096;
/// Packet type identifier stored in the first byte of a packet header.
///
/// Each discriminant is the exact byte written by `PacketHeader::encode`
/// (via `as u8`) and recognised by `PacketType::from_u8`. Byte values without
/// a variant here are rejected by `from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum PacketType {
/// SQL batch (`0x01`).
SqlBatch = 0x01,
/// Pre-TDS7 login (`0x02`).
PreTds7Login = 0x02,
/// RPC (`0x03`).
Rpc = 0x03,
/// Tabular result (`0x04`).
TabularResult = 0x04,
/// Attention (`0x06`).
Attention = 0x06,
/// Bulk load (`0x07`).
BulkLoad = 0x07,
/// Federated-authentication token (`0x08`).
FedAuthToken = 0x08,
/// Transaction manager (`0x0E`).
TransactionManager = 0x0E,
/// TDS7 login (`0x10`).
Tds7Login = 0x10,
/// SSPI (`0x11`).
Sspi = 0x11,
/// Pre-login (`0x12`).
PreLogin = 0x12,
}
impl PacketType {
    /// Maps a raw header byte to the matching [`PacketType`] variant.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::InvalidPacketType`] carrying `value` when the
    /// byte is not the discriminant of any variant.
    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
        // Resolve the variant first; the single `Ok` wrap happens at the end.
        let packet_type = match value {
            0x01 => Self::SqlBatch,
            0x02 => Self::PreTds7Login,
            0x03 => Self::Rpc,
            0x04 => Self::TabularResult,
            0x06 => Self::Attention,
            0x07 => Self::BulkLoad,
            0x08 => Self::FedAuthToken,
            0x0E => Self::TransactionManager,
            0x10 => Self::Tds7Login,
            0x11 => Self::Sspi,
            0x12 => Self::PreLogin,
            unknown => return Err(ProtocolError::InvalidPacketType(unknown)),
        };
        Ok(packet_type)
    }
}
bitflags! {
/// Status bit flags stored in the second byte of a packet header.
///
/// `PacketHeader::decode` builds this with `from_bits`, so a status byte
/// containing any bit not declared below is rejected there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PacketStatus: u8 {
/// No status bits set (`0x00`).
const NORMAL = 0x00;
/// End-of-message bit (`0x01`); tested by `PacketHeader::is_end_of_message`.
const END_OF_MESSAGE = 0x01;
/// Ignore-event bit (`0x02`).
const IGNORE_EVENT = 0x02;
/// Reset-connection bit (`0x08`).
const RESET_CONNECTION = 0x08;
/// Reset-connection-keep-transaction bit (`0x10`).
const RESET_CONNECTION_KEEP_TRANSACTION = 0x10;
}
}
/// In-memory form of the fixed-size packet header.
///
/// Fields are declared in the same order in which `encode` writes them and
/// `decode` reads them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketHeader {
/// Packet type; the first header byte.
pub packet_type: PacketType,
/// Status flags; the second header byte.
pub status: PacketStatus,
/// Total packet length in bytes, header included (`payload_length`
/// subtracts `PACKET_HEADER_SIZE` from it).
pub length: u16,
/// 16-bit SPID field; presumably the server process ID — confirm against
/// the protocol specification.
pub spid: u16,
/// Packet identifier byte.
pub packet_id: u8,
/// Window byte; `new` and `Default` both set it to 0.
pub window: u8,
}
impl PacketHeader {
    /// Builds a header from its type, status flags and total length.
    ///
    /// `spid`, `packet_id` and `window` all start at zero; use
    /// [`Self::with_spid`] and [`Self::with_packet_id`] to override the first
    /// two.
    #[must_use]
    pub const fn new(packet_type: PacketType, status: PacketStatus, length: u16) -> Self {
        Self {
            spid: 0,
            packet_id: 0,
            window: 0,
            packet_type,
            status,
            length,
        }
    }

    /// Reads one header from the front of `src`.
    ///
    /// Fields are consumed in wire order: type, status, length, SPID, packet
    /// id, window. The two 16-bit fields are read with `Buf::get_u16`, i.e.
    /// big-endian.
    ///
    /// # Errors
    ///
    /// * [`ProtocolError::IncompletePacket`] when fewer than
    ///   [`PACKET_HEADER_SIZE`] bytes remain; nothing has been consumed.
    /// * [`ProtocolError::InvalidPacketType`] when the first byte is not a
    ///   known [`PacketType`]; that one byte has already been consumed.
    /// * [`ProtocolError::InvalidPacketStatus`] when the status byte has bits
    ///   set that [`PacketStatus`] does not declare; two bytes have been
    ///   consumed.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let available = src.remaining();
        if available < PACKET_HEADER_SIZE {
            return Err(ProtocolError::IncompletePacket {
                expected: PACKET_HEADER_SIZE,
                actual: available,
            });
        }

        let packet_type = PacketType::from_u8(src.get_u8())?;

        // NOTE(review): undeclared status bits are rejected here rather than
        // masked off — confirm against the protocol specification that strict
        // rejection is the intended behaviour.
        let raw_status = src.get_u8();
        let status = PacketStatus::from_bits(raw_status)
            .ok_or(ProtocolError::InvalidPacketStatus(raw_status))?;

        // Struct-literal field initialisers run in the order written, so the
        // remaining reads below follow the wire layout.
        Ok(Self {
            packet_type,
            status,
            length: src.get_u16(),
            spid: src.get_u16(),
            packet_id: src.get_u8(),
            window: src.get_u8(),
        })
    }

    /// Appends the [`PACKET_HEADER_SIZE`]-byte wire form of this header to
    /// `dst`.
    ///
    /// # Panics
    ///
    /// The underlying `BufMut::put_*` calls panic if `dst` does not have
    /// enough remaining capacity.
    pub fn encode(&self, dst: &mut impl BufMut) {
        // Exhaustive destructuring: adding a field to the struct forces this
        // method to be revisited.
        let Self {
            packet_type,
            status,
            length,
            spid,
            packet_id,
            window,
        } = *self;

        dst.put_u8(packet_type as u8);
        dst.put_u8(status.bits());
        dst.put_u16(length);
        dst.put_u16(spid);
        dst.put_u8(packet_id);
        dst.put_u8(window);
    }

    /// Encodes this header into a freshly allocated, immutable [`Bytes`]
    /// buffer.
    #[must_use]
    pub fn encode_to_bytes(&self) -> Bytes {
        let mut wire = BytesMut::with_capacity(PACKET_HEADER_SIZE);
        self.encode(&mut wire);
        wire.freeze()
    }

    /// Number of payload bytes that follow the header, derived from `length`.
    ///
    /// Returns 0 when `length` is smaller than the header itself.
    #[must_use]
    pub const fn payload_length(&self) -> usize {
        // Saturating, so an undersized `length` yields 0 instead of wrapping.
        let header_len = PACKET_HEADER_SIZE as u16;
        self.length.saturating_sub(header_len) as usize
    }

    /// Whether the [`PacketStatus::END_OF_MESSAGE`] flag is set.
    #[must_use]
    pub const fn is_end_of_message(&self) -> bool {
        self.status.contains(PacketStatus::END_OF_MESSAGE)
    }

    /// Returns a copy of this header with `packet_id` replaced by `id`.
    #[must_use]
    pub const fn with_packet_id(self, id: u8) -> Self {
        Self {
            packet_id: id,
            ..self
        }
    }

    /// Returns a copy of this header with `spid` replaced by the given value.
    #[must_use]
    pub const fn with_spid(self, spid: u16) -> Self {
        Self { spid, ..self }
    }
}
impl Default for PacketHeader {
    /// An empty SQL-batch packet: `END_OF_MESSAGE` status, `length` equal to
    /// the bare header size, packet id 1, and zero `spid` / `window`.
    fn default() -> Self {
        // `new` zeroes `spid`, `packet_id` and `window`; only the packet id
        // differs from that baseline.
        Self::new(
            PacketType::SqlBatch,
            PacketStatus::END_OF_MESSAGE,
            PACKET_HEADER_SIZE as u16,
        )
        .with_packet_id(1)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encoding a header and decoding the bytes must reproduce it exactly.
    #[test]
    fn test_header_roundtrip() {
        let original = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 100)
            .with_spid(54)
            .with_packet_id(1);

        let wire = original.encode_to_bytes();
        assert_eq!(wire.len(), PACKET_HEADER_SIZE);

        // `&[u8]` implements `Buf`, so the slice itself acts as the cursor.
        let mut remaining: &[u8] = &wire;
        let roundtripped = PacketHeader::decode(&mut remaining).unwrap();
        assert_eq!(original, roundtripped);
    }

    /// The payload length is the total length minus the 8-byte header.
    #[test]
    fn test_payload_length() {
        let total_len: u16 = 100;
        let header =
            PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, total_len);
        assert_eq!(header.payload_length(), 92);
    }

    /// Known bytes map to their variants; an unknown byte is an error.
    #[test]
    fn test_packet_type_from_u8() {
        let sql_batch = PacketType::from_u8(0x01).unwrap();
        assert_eq!(sql_batch, PacketType::SqlBatch);

        let pre_login = PacketType::from_u8(0x12).unwrap();
        assert_eq!(pre_login, PacketType::PreLogin);

        assert!(PacketType::from_u8(0xFF).is_err());
    }
}