use bytes::{Buf, BufMut, Bytes};
use crate::codec::{read_b_varchar, read_us_varchar};
use crate::error::ProtocolError;
use crate::prelude::*;
use crate::types::TypeId;
/// Kind of a TDS token, identified by the token's leading type byte.
///
/// Each variant's discriminant is the byte value that introduces the token.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenType {
    ColMetaData = 0x81,
    Error = 0xAA,
    Info = 0xAB,
    LoginAck = 0xAD,
    Row = 0xD1,
    NbcRow = 0xD2,
    EnvChange = 0xE3,
    Sspi = 0xED,
    Done = 0xFD,
    DoneInProc = 0xFF,
    DoneProc = 0xFE,
    ReturnStatus = 0x79,
    ReturnValue = 0xAC,
    Order = 0xA9,
    FeatureExtAck = 0xAE,
    SessionState = 0xE4,
    FedAuthInfo = 0xEE,
    ColInfo = 0xA5,
    TabName = 0xA4,
    Offset = 0x78,
}

impl TokenType {
    /// Every token type recognised by [`TokenType::from_u8`].
    const KNOWN: [TokenType; 20] = [
        TokenType::ColMetaData,
        TokenType::Error,
        TokenType::Info,
        TokenType::LoginAck,
        TokenType::Row,
        TokenType::NbcRow,
        TokenType::EnvChange,
        TokenType::Sspi,
        TokenType::Done,
        TokenType::DoneInProc,
        TokenType::DoneProc,
        TokenType::ReturnStatus,
        TokenType::ReturnValue,
        TokenType::Order,
        TokenType::FeatureExtAck,
        TokenType::SessionState,
        TokenType::FedAuthInfo,
        TokenType::ColInfo,
        TokenType::TabName,
        TokenType::Offset,
    ];

    /// Maps a raw token byte to its [`TokenType`].
    ///
    /// Returns `None` when the byte is not one of the recognised token types.
    pub fn from_u8(value: u8) -> Option<Self> {
        Self::KNOWN
            .iter()
            .copied()
            .find(|candidate| *candidate as u8 == value)
    }
}
/// A decoded TDS token.
///
/// `ReturnStatus` carries its status value directly; every other variant wraps the
/// corresponding token struct defined in this module.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Token {
ColMetaData(ColMetaData),
Row(RawRow),
NbcRow(NbcRow),
Done(Done),
DoneProc(DoneProc),
DoneInProc(DoneInProc),
/// Status value of a RETURNSTATUS token.
ReturnStatus(i32),
ReturnValue(ReturnValue),
Error(ServerError),
Info(ServerInfo),
LoginAck(LoginAck),
EnvChange(EnvChange),
Order(Order),
FeatureExtAck(FeatureExtAck),
Sspi(SspiToken),
SessionState(SessionState),
FedAuthInfo(FedAuthInfo),
}
/// Result-set column descriptions decoded from a COLMETADATA token.
#[derive(Debug, Clone, Default)]
pub struct ColMetaData {
/// One entry per column, in the order they appear on the wire.
pub columns: Vec<ColumnData>,
/// Column encryption key table; only populated by `ColMetaData::decode_encrypted`.
pub cek_table: Option<crate::crypto::CekTable>,
}
/// Description of a single result-set column from a COLMETADATA token.
#[derive(Debug, Clone)]
pub struct ColumnData {
/// Column name.
pub name: String,
/// Decoded type, derived from `col_type`.
pub type_id: TypeId,
/// Raw type byte as read from the wire.
pub col_type: u8,
/// Column flags; bit 0 (`0x0001`) is the nullable bit tested by `is_nullable`.
pub flags: u16,
/// User type value as read from the wire.
pub user_type: u32,
/// Type-specific details (length, precision/scale, collation) from TYPE_INFO.
pub type_info: TypeInfo,
/// Always Encrypted metadata; only set by `ColMetaData::decode_encrypted`, and only for
/// columns whose flags mark them as encrypted.
pub crypto_metadata: Option<crate::crypto::CryptoMetadata>,
}
/// Type-specific details decoded from a column's TYPE_INFO section.
///
/// Each field is `None` when the column's type does not carry that piece of information.
#[derive(Debug, Clone, Default)]
pub struct TypeInfo {
/// Declared maximum length (1, 2 or 4 bytes on the wire depending on the type).
pub max_length: Option<u32>,
/// Precision of decimal/numeric types.
pub precision: Option<u8>,
/// Scale of decimal/numeric and time-based types.
pub scale: Option<u8>,
/// Collation of character types that carry one.
pub collation: Option<Collation>,
}
/// A five-byte TDS collation: a little-endian 32-bit word followed by a sort id.
///
/// NOTE(review): `lcid` holds the whole first 32-bit word. In MS-TDS that word also
/// packs comparison flags and a version nibble next to the LCID. Confirm that consumers
/// mask it wherever a pure LCID is required.
#[derive(Debug, Clone, Copy, Default)]
pub struct Collation {
/// First four bytes of the collation, read little-endian.
pub lcid: u32,
/// Fifth byte of the collation.
pub sort_id: u8,
}
impl Collation {
    /// Builds a collation from its five-byte wire form: a little-endian `u32` word
    /// (stored as `lcid`) followed by the sort id.
    pub fn from_bytes(bytes: &[u8; 5]) -> Self {
        let [b0, b1, b2, b3, sort_id] = *bytes;
        Self {
            lcid: u32::from_le_bytes([b0, b1, b2, b3]),
            sort_id,
        }
    }

    /// Serialises the collation back into its five-byte wire form.
    pub fn to_bytes(&self) -> [u8; 5] {
        let mut out = [0u8; 5];
        out[..4].copy_from_slice(&self.lcid.to_le_bytes());
        out[4] = self.sort_id;
        out
    }

    /// Text encoding associated with this collation's `lcid`, as reported by
    /// `crate::collation::encoding_for_lcid`.
    #[cfg(feature = "encoding")]
    pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
        crate::collation::encoding_for_lcid(self.lcid)
    }

    /// Whether `crate::collation::is_utf8_collation` classifies this `lcid` as UTF-8.
    #[cfg(feature = "encoding")]
    pub fn is_utf8(&self) -> bool {
        crate::collation::is_utf8_collation(self.lcid)
    }

    /// Code page for this `lcid`, as reported by `crate::collation::code_page_for_lcid`.
    #[cfg(feature = "encoding")]
    pub fn code_page(&self) -> Option<u16> {
        crate::collation::code_page_for_lcid(self.lcid)
    }

    /// Encoding name for this `lcid`, as reported by
    /// `crate::collation::encoding_name_for_lcid`.
    #[cfg(feature = "encoding")]
    pub fn encoding_name(&self) -> &'static str {
        crate::collation::encoding_name_for_lcid(self.lcid)
    }
}
/// Column values of one ROW token, concatenated in column order by `RawRow::decode`.
///
/// Fixed-length values are stored without a prefix; other values keep a length prefix
/// or a chunked layout (see `RawRow::decode_column_value`).
#[derive(Debug, Clone)]
pub struct RawRow {
/// The concatenated column values.
pub data: bytes::Bytes,
}
/// A decoded NBCROW (null-bitmap-compressed row) token.
#[derive(Debug, Clone)]
pub struct NbcRow {
/// One bit per column, least-significant bit first within each byte; a set bit means
/// the column is NULL.
pub null_bitmap: Vec<u8>,
/// Values of the non-NULL columns only, in the same layout `RawRow` uses.
pub data: bytes::Bytes,
}
/// A decoded DONE token body; see `Done::decode`.
#[derive(Debug, Clone, Copy)]
pub struct Done {
/// Flags unpacked from the status word.
pub status: DoneStatus,
/// Current-command value as read from the wire.
pub cur_cmd: u16,
/// Row count as read from the wire; presumably meaningful only when `status.count`
/// is set (confirm).
pub row_count: u64,
}
/// Flags of a DONE / DONEPROC / DONEINPROC status word.
///
/// Each field mirrors one bit of `done_status_bits`; see `DoneStatus::from_bits` and
/// `DoneStatus::to_bits`. Bits without a field are dropped.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct DoneStatus {
/// `DONE_MORE` (0x0001).
pub more: bool,
/// `DONE_ERROR` (0x0002).
pub error: bool,
/// `DONE_INXACT` (0x0004).
pub in_xact: bool,
/// `DONE_COUNT` (0x0010).
pub count: bool,
/// `DONE_ATTN` (0x0020).
pub attn: bool,
/// `DONE_SRVERROR` (0x0100).
pub srverror: bool,
}
/// A decoded DONEINPROC token body; same layout as `Done`.
#[derive(Debug, Clone, Copy)]
pub struct DoneInProc {
/// Flags unpacked from the status word.
pub status: DoneStatus,
/// Current-command value as read from the wire.
pub cur_cmd: u16,
/// Row count as read from the wire.
pub row_count: u64,
}
/// A decoded DONEPROC token body; same layout as `Done`.
#[derive(Debug, Clone, Copy)]
pub struct DoneProc {
/// Flags unpacked from the status word.
pub status: DoneStatus,
/// Current-command value as read from the wire.
pub cur_cmd: u16,
/// Row count as read from the wire.
pub row_count: u64,
}
/// A decoded RETURNVALUE token (an output parameter or function return value).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ReturnValue {
/// Ordinal position of the parameter.
pub param_ordinal: u16,
/// Parameter name.
pub param_name: String,
/// Status byte as read from the wire.
pub status: u8,
/// User type value as read from the wire.
pub user_type: u32,
/// Flags as read from the wire.
pub flags: u16,
/// Raw type byte.
pub col_type: u8,
/// Type-specific details from TYPE_INFO.
pub type_info: TypeInfo,
/// The value, in the same layout `RawRow::decode_column_value` produces for a column.
pub value: bytes::Bytes,
}
/// A decoded ERROR token sent by the server.
#[derive(Debug, Clone)]
pub struct ServerError {
/// Error number.
pub number: i32,
/// Error state.
pub state: u8,
/// Severity class; `is_fatal` treats 20 and above as fatal and `is_batch_abort`
/// treats 16 and above as aborting the batch.
pub class: u8,
/// Message text.
pub message: String,
/// Server name.
pub server: String,
/// Procedure name (may be empty).
pub procedure: String,
/// Line number.
pub line: i32,
}
/// A decoded INFO token sent by the server; same fields as `ServerError`.
#[derive(Debug, Clone)]
pub struct ServerInfo {
/// Message number.
pub number: i32,
/// Message state.
pub state: u8,
/// Severity class.
pub class: u8,
/// Message text.
pub message: String,
/// Server name.
pub server: String,
/// Procedure name (may be empty).
pub procedure: String,
/// Line number.
pub line: i32,
}
/// A decoded LOGINACK token.
#[derive(Debug, Clone)]
pub struct LoginAck {
/// Interface byte as read from the wire.
pub interface: u8,
/// Raw TDS version word as read (little-endian) by `LoginAck::decode`; see the
/// `tds_version()` method for the typed form.
pub tds_version: u32,
/// Server program name.
pub prog_name: String,
/// Raw program version word.
pub prog_version: u32,
}
/// A decoded ENVCHANGE token: one environment change reported by the server.
#[derive(Debug, Clone)]
pub struct EnvChange {
/// Which environment setting changed.
pub env_type: EnvChangeType,
/// New value; `EnvChangeValue::Routing` for routing changes.
pub new_value: EnvChangeValue,
/// Previous value; always empty binary for routing changes.
pub old_value: EnvChangeValue,
}
/// ENVCHANGE type codes; the discriminant is the type byte on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EnvChangeType {
Database = 1,
Language = 2,
CharacterSet = 3,
PacketSize = 4,
UnicodeSortingLocalId = 5,
UnicodeComparisonFlags = 6,
SqlCollation = 7,
BeginTransaction = 8,
CommitTransaction = 9,
RollbackTransaction = 10,
EnlistDtcTransaction = 11,
DefectTransaction = 12,
RealTimeLogShipping = 13,
// Code 14 has no variant here.
PromoteTransaction = 15,
TransactionManagerAddress = 16,
TransactionEnded = 17,
ResetConnectionCompletionAck = 18,
UserInstanceStarted = 19,
Routing = 20,
}
/// The new or old value carried by an ENVCHANGE token.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EnvChangeValue {
/// A textual value (decoded from a UTF-16 `B_VARCHAR`).
String(String),
/// An opaque binary value.
Binary(bytes::Bytes),
/// Routing target decoded from a routing ENVCHANGE.
Routing {
/// Alternate server host name.
host: String,
/// Alternate server port.
port: u16,
},
}
/// A decoded ORDER token.
#[derive(Debug, Clone)]
pub struct Order {
/// Column numbers, as read from the wire.
pub columns: Vec<u16>,
}
/// A decoded FEATUREEXTACK token: the server's acknowledgement of requested features.
#[derive(Debug, Clone)]
pub struct FeatureExtAck {
/// Acknowledged features, in wire order (the terminator is not included).
pub features: Vec<FeatureAck>,
}
/// One acknowledged feature inside a FEATUREEXTACK token.
#[derive(Debug, Clone)]
pub struct FeatureAck {
/// Feature id byte.
pub feature_id: u8,
/// Feature-specific acknowledgement data, kept opaque.
pub data: bytes::Bytes,
}
/// A decoded SSPI token.
#[derive(Debug, Clone)]
pub struct SspiToken {
/// Opaque SSPI payload.
pub data: bytes::Bytes,
}
/// A decoded SESSIONSTATE token.
#[derive(Debug, Clone)]
pub struct SessionState {
/// Opaque session-state payload.
pub data: bytes::Bytes,
}
/// A decoded FEDAUTHINFO token.
#[derive(Debug, Clone)]
pub struct FedAuthInfo {
/// Security token service URL; empty if the token did not carry one.
pub sts_url: String,
/// Service principal name; empty if the token did not carry one.
pub spn: String,
}
pub(crate) fn decode_collation(src: &mut impl Buf) -> Result<Collation, ProtocolError> {
if src.remaining() < 5 {
return Err(ProtocolError::UnexpectedEof);
}
let lcid = src.get_u32_le();
let sort_id = src.get_u8();
Ok(Collation { lcid, sort_id })
}
pub(crate) fn decode_type_info(
src: &mut impl Buf,
type_id: TypeId,
col_type: u8,
) -> Result<TypeInfo, ProtocolError> {
match type_id {
TypeId::Null => Ok(TypeInfo::default()),
TypeId::Int1 | TypeId::Bit => Ok(TypeInfo::default()),
TypeId::Int2 => Ok(TypeInfo::default()),
TypeId::Int4 => Ok(TypeInfo::default()),
TypeId::Int8 => Ok(TypeInfo::default()),
TypeId::Float4 => Ok(TypeInfo::default()),
TypeId::Float8 => Ok(TypeInfo::default()),
TypeId::Money => Ok(TypeInfo::default()),
TypeId::Money4 => Ok(TypeInfo::default()),
TypeId::DateTime => Ok(TypeInfo::default()),
TypeId::DateTime4 => Ok(TypeInfo::default()),
TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::Guid => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
if src.remaining() < 3 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
let precision = src.get_u8();
let scale = src.get_u8();
Ok(TypeInfo {
max_length: Some(max_length),
precision: Some(precision),
scale: Some(scale),
..Default::default()
})
}
TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::BigVarChar | TypeId::BigChar => {
if src.remaining() < 7 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
let collation = decode_collation(src)?;
Ok(TypeInfo {
max_length: Some(max_length),
collation: Some(collation),
..Default::default()
})
}
TypeId::BigVarBinary | TypeId::BigBinary => {
if src.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::NChar | TypeId::NVarChar => {
if src.remaining() < 7 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
let collation = decode_collation(src)?;
Ok(TypeInfo {
max_length: Some(max_length),
collation: Some(collation),
..Default::default()
})
}
TypeId::Date => Ok(TypeInfo::default()),
TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let scale = src.get_u8();
Ok(TypeInfo {
scale: Some(scale),
..Default::default()
})
}
TypeId::Text | TypeId::NText | TypeId::Image => {
if src.remaining() < 4 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u32_le();
let collation = if type_id == TypeId::Text || type_id == TypeId::NText {
if src.remaining() < 5 {
return Err(ProtocolError::UnexpectedEof);
}
Some(decode_collation(src)?)
} else {
None
};
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let num_parts = src.get_u8();
for _ in 0..num_parts {
let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
}
Ok(TypeInfo {
max_length: Some(max_length),
collation,
..Default::default()
})
}
TypeId::Xml => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let schema_present = src.get_u8();
if schema_present != 0 {
let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; }
Ok(TypeInfo::default())
}
TypeId::Udt => {
if src.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::Tvp => {
Err(ProtocolError::InvalidTokenType(col_type))
}
TypeId::Variant => {
if src.remaining() < 4 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u32_le();
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
}
}
impl ColMetaData {
    /// Column count that marks a COLMETADATA token carrying no column descriptions.
    pub const NO_METADATA: u16 = 0xFFFF;

    /// Decodes a COLMETADATA token body (without the leading token byte) whose columns
    /// carry no Always Encrypted metadata.
    ///
    /// A column count of [`Self::NO_METADATA`] yields an empty result.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::UnexpectedEof`] if the body is truncated.
    /// - [`ProtocolError::InvalidTokenType`] with the raw type byte if a column's type
    ///   byte is not a known [`TypeId`] or is a table-valued parameter. An unknown type
    ///   cannot be skipped safely because its TYPE_INFO layout is unknown.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let Some(column_count) = Self::read_column_count(src)? else {
            return Ok(Self::default());
        };
        let mut columns = Vec::with_capacity(usize::from(column_count));
        for _ in 0..column_count {
            columns.push(Self::decode_column(src)?);
        }
        Ok(Self {
            columns,
            cek_table: None,
        })
    }

    /// Decodes a COLMETADATA token body for a connection with Always Encrypted enabled.
    ///
    /// The column count is followed by a CEK table, and each column whose flags mark it
    /// as encrypted carries crypto metadata between its TYPE_INFO and its name.
    ///
    /// # Errors
    ///
    /// As for [`Self::decode`], plus any error from the CEK table or crypto metadata
    /// decoders.
    pub fn decode_encrypted(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let Some(column_count) = Self::read_column_count(src)? else {
            return Ok(Self::default());
        };
        let cek_table = crate::crypto::CekTable::decode(src)?;
        let mut columns = Vec::with_capacity(usize::from(column_count));
        for _ in 0..column_count {
            columns.push(Self::decode_column_encrypted(src)?);
        }
        Ok(Self {
            columns,
            cek_table: Some(cek_table),
        })
    }

    /// Reads the two-byte column count, returning `None` for [`Self::NO_METADATA`].
    fn read_column_count(src: &mut impl Buf) -> Result<Option<u16>, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let count = src.get_u16_le();
        Ok((count != Self::NO_METADATA).then_some(count))
    }

    /// Reads the user type, flags, type byte and TYPE_INFO that plain and encrypted
    /// column descriptions share. The returned column has an empty name and no crypto
    /// metadata yet.
    fn decode_column_prefix(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
        if src.remaining() < 7 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let user_type = src.get_u32_le();
        let flags = src.get_u16_le();
        let col_type = src.get_u8();
        // An unrecognised type byte must not be treated as NULL: that would skip its
        // TYPE_INFO and values and misparse everything after it.
        let type_id =
            TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidTokenType(col_type))?;
        let type_info = decode_type_info(src, type_id, col_type)?;
        Ok(ColumnData {
            name: String::new(),
            type_id,
            col_type,
            flags,
            user_type,
            type_info,
            crypto_metadata: None,
        })
    }

    /// Decodes one column description without crypto metadata.
    fn decode_column(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
        let mut column = Self::decode_column_prefix(src)?;
        column.name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
        Ok(column)
    }

    /// Decodes one column description, reading crypto metadata when the column's flags
    /// mark it as encrypted.
    fn decode_column_encrypted(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
        let mut column = Self::decode_column_prefix(src)?;
        if crate::crypto::is_column_encrypted(column.flags) {
            column.crypto_metadata = Some(crate::crypto::CryptoMetadata::decode(src)?);
        }
        column.name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
        Ok(column)
    }

    /// Number of described columns.
    #[must_use]
    pub fn column_count(&self) -> usize {
        self.columns.len()
    }

    /// `true` when no columns are described.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.columns.is_empty()
    }
}
impl ColumnData {
    /// Returns `true` when bit 0 (`0x0001`) of the column flags is set.
    #[must_use]
    pub fn is_nullable(&self) -> bool {
        (self.flags & 0x0001) != 0
    }

    /// Size in bytes of a value of this column when the type is fixed-length on the
    /// wire (stored with no length prefix); `None` for every length-prefixed or
    /// otherwise variable-length type.
    ///
    /// `Date` is `None`: like the other one-byte-length types, its values carry a length
    /// prefix, and `RawRow::decode_column_value` stores them with that prefix. Reporting
    /// a fixed size of 3 would make readers skip the prefix and misread the value.
    #[must_use]
    pub fn fixed_size(&self) -> Option<usize> {
        match self.type_id {
            TypeId::Null => Some(0),
            TypeId::Int1 | TypeId::Bit => Some(1),
            TypeId::Int2 => Some(2),
            TypeId::Int4 | TypeId::Float4 | TypeId::Money4 | TypeId::DateTime4 => Some(4),
            TypeId::Int8 | TypeId::Float8 | TypeId::Money | TypeId::DateTime => Some(8),
            _ => None,
        }
    }
}
impl RawRow {
    /// `max_length` announcing that a variable-length column uses the chunked (PLP)
    /// value encoding instead of a two-byte length prefix.
    const PLP_MAX_LENGTH: u32 = 0xFFFF;

    /// Decodes one ROW token body (without the leading token byte), using `metadata` to
    /// determine the layout of each column value.
    ///
    /// All values are appended, in column order, to a single buffer:
    /// - fixed-length values are copied verbatim with no prefix;
    /// - one-, two- and four-byte length-prefixed values keep their prefix (a zero or
    ///   all-ones prefix is copied alone, without payload);
    /// - chunked (PLP) values keep their eight-byte total length and every chunk header,
    ///   including the zero-length terminator;
    /// - TEXT/NTEXT/IMAGE values are rewritten into that same chunked layout, with the
    ///   text pointer and timestamp dropped.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::UnexpectedEof`] if `src` ends inside a value.
    /// - [`ProtocolError::InvalidTokenType`] with the raw type byte for table-valued
    ///   parameter columns.
    pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
        let mut data = bytes::BytesMut::new();
        for col in &metadata.columns {
            Self::decode_column_value(src, col, &mut data)?;
        }
        Ok(Self {
            data: data.freeze(),
        })
    }

    /// Reads one value of column `col` from `src` and appends it to `dst` in the layout
    /// described on [`RawRow::decode`].
    fn decode_column_value(
        src: &mut impl Buf,
        col: &ColumnData,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        match col.type_id {
            TypeId::Null => {}
            // Fixed-length values are little-endian on the wire and kept byte-for-byte.
            TypeId::Int1 | TypeId::Bit => Self::copy_fixed(src, dst, 1)?,
            TypeId::Int2 => Self::copy_fixed(src, dst, 2)?,
            TypeId::Int4 | TypeId::Float4 | TypeId::Money4 | TypeId::DateTime4 => {
                Self::copy_fixed(src, dst, 4)?
            }
            TypeId::Int8 | TypeId::Float8 | TypeId::Money | TypeId::DateTime => {
                Self::copy_fixed(src, dst, 8)?
            }
            TypeId::Date
            | TypeId::IntN
            | TypeId::BitN
            | TypeId::FloatN
            | TypeId::MoneyN
            | TypeId::DateTimeN
            | TypeId::Guid
            | TypeId::Decimal
            | TypeId::Numeric
            | TypeId::DecimalN
            | TypeId::NumericN
            | TypeId::Char
            | TypeId::VarChar
            | TypeId::Binary
            | TypeId::VarBinary
            | TypeId::Time
            | TypeId::DateTime2
            | TypeId::DateTimeOffset => Self::decode_bytelen_type(src, dst)?,
            TypeId::BigVarChar | TypeId::BigVarBinary | TypeId::NVarChar => {
                if col.type_info.max_length == Some(Self::PLP_MAX_LENGTH) {
                    Self::decode_plp_type(src, dst)?;
                } else {
                    Self::decode_ushortlen_type(src, dst)?;
                }
            }
            TypeId::BigChar | TypeId::BigBinary | TypeId::NChar => {
                Self::decode_ushortlen_type(src, dst)?
            }
            TypeId::Text | TypeId::NText | TypeId::Image => Self::decode_textptr_type(src, dst)?,
            TypeId::Xml | TypeId::Udt => Self::decode_plp_type(src, dst)?,
            TypeId::Variant => Self::decode_intlen_type(src, dst)?,
            TypeId::Tvp => return Err(ProtocolError::InvalidTokenType(col.col_type)),
        }
        Ok(())
    }

    /// Copies a fixed-length value of `len` bytes from `src` to `dst`.
    fn copy_fixed(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
        len: usize,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < len {
            return Err(ProtocolError::UnexpectedEof);
        }
        Self::copy_exact(src, dst, len);
        Ok(())
    }

    /// Moves exactly `len` bytes from `src` to `dst`, one contiguous chunk at a time.
    ///
    /// Callers must already have checked that `src.remaining() >= len`. Under that
    /// precondition, `Buf::chunk` never returns an empty slice, so the loop terminates.
    fn copy_exact(src: &mut impl Buf, dst: &mut bytes::BytesMut, mut len: usize) {
        while len > 0 {
            let chunk = src.chunk();
            let take = chunk.len().min(len);
            dst.extend_from_slice(&chunk[..take]);
            src.advance(take);
            len -= take;
        }
    }

    /// One-byte length prefix followed by that many bytes. The markers `0x00` and `0xFF`
    /// are copied alone, with no payload.
    fn decode_bytelen_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 1 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let len = src.get_u8();
        let payload = if len == 0 || len == 0xFF { 0 } else { usize::from(len) };
        if src.remaining() < payload {
            return Err(ProtocolError::UnexpectedEof);
        }
        dst.extend_from_slice(&[len]);
        Self::copy_exact(src, dst, payload);
        Ok(())
    }

    /// Two-byte little-endian length prefix followed by that many bytes. The markers
    /// `0x0000` and `0xFFFF` are copied alone, with no payload.
    fn decode_ushortlen_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let len = src.get_u16_le();
        let payload = if len == 0 || len == 0xFFFF { 0 } else { usize::from(len) };
        if src.remaining() < payload {
            return Err(ProtocolError::UnexpectedEof);
        }
        dst.extend_from_slice(&len.to_le_bytes());
        Self::copy_exact(src, dst, payload);
        Ok(())
    }

    /// Four-byte little-endian length prefix followed by that many bytes. The markers `0`
    /// and `0xFFFF_FFFF` are copied alone, with no payload.
    fn decode_intlen_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let len = src.get_u32_le();
        let payload = if len == 0 || len == u32::MAX { 0 } else { len as usize };
        if src.remaining() < payload {
            return Err(ProtocolError::UnexpectedEof);
        }
        dst.extend_from_slice(&len.to_le_bytes());
        Self::copy_exact(src, dst, payload);
        Ok(())
    }

    /// Reads a TEXT/NTEXT/IMAGE value and rewrites it in the chunked layout that
    /// [`RawRow::decode_plp_type`] emits.
    ///
    /// Wire layout: one-byte text-pointer length (0 means NULL), the text pointer, an
    /// eight-byte timestamp, a four-byte data length, then the data.
    fn decode_textptr_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 1 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let textptr_len = usize::from(src.get_u8());
        if textptr_len == 0 {
            // No text pointer: NULL, written as the all-ones total length.
            dst.extend_from_slice(&u64::MAX.to_le_bytes());
            return Ok(());
        }
        // Text pointer + 8-byte timestamp + 4-byte data length.
        if src.remaining() < textptr_len + 8 + 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        // Neither the text pointer nor the timestamp is needed to read the value.
        src.advance(textptr_len + 8);
        let data_len = src.get_u32_le();
        let payload = data_len as usize;
        if src.remaining() < payload {
            return Err(ProtocolError::UnexpectedEof);
        }
        dst.extend_from_slice(&u64::from(data_len).to_le_bytes());
        // An empty value is the total length followed directly by the terminator, as
        // `decode_plp_type` emits it. Writing a zero-length chunk header here as well
        // would act as a premature terminator and leave four stray bytes in the buffer.
        if data_len > 0 {
            dst.extend_from_slice(&data_len.to_le_bytes());
            Self::copy_exact(src, dst, payload);
        }
        dst.extend_from_slice(&0u32.to_le_bytes());
        Ok(())
    }

    /// Copies a chunked (PLP) value: an eight-byte total length, then (unless the total
    /// is the all-ones NULL marker) chunks of a four-byte length plus data, terminated
    /// by a zero-length chunk. Every header, including the terminator, is kept in `dst`.
    fn decode_plp_type(src: &mut impl Buf, dst: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
        if src.remaining() < 8 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let total_len = src.get_u64_le();
        dst.extend_from_slice(&total_len.to_le_bytes());
        if total_len == u64::MAX {
            return Ok(());
        }
        loop {
            if src.remaining() < 4 {
                return Err(ProtocolError::UnexpectedEof);
            }
            let chunk_len = src.get_u32_le();
            dst.extend_from_slice(&chunk_len.to_le_bytes());
            if chunk_len == 0 {
                break;
            }
            let payload = chunk_len as usize;
            if src.remaining() < payload {
                return Err(ProtocolError::UnexpectedEof);
            }
            Self::copy_exact(src, dst, payload);
        }
        Ok(())
    }
}
impl NbcRow {
    /// Decodes an NBCROW token body: a null bitmap with one bit per column, followed by
    /// values for the columns whose bit is clear.
    ///
    /// Bits are numbered least-significant first within each byte, and a set bit marks
    /// the column as NULL. Values are laid out as `RawRow` would lay them out.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::UnexpectedEof`] if the bitmap is truncated, plus any error from
    /// decoding an individual value.
    pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
        let bitmap_len = metadata.columns.len().div_ceil(8);
        if src.remaining() < bitmap_len {
            return Err(ProtocolError::UnexpectedEof);
        }
        let mut null_bitmap = vec![0u8; bitmap_len];
        src.copy_to_slice(&mut null_bitmap);

        let mut data = bytes::BytesMut::new();
        for (index, column) in metadata.columns.iter().enumerate() {
            if !Self::bit_is_set(&null_bitmap, index) {
                RawRow::decode_column_value(src, column, &mut data)?;
            }
        }
        Ok(Self {
            null_bitmap,
            data: data.freeze(),
        })
    }

    /// Whether column `column_index` is NULL; indices beyond the bitmap report `true`.
    #[must_use]
    pub fn is_null(&self, column_index: usize) -> bool {
        Self::bit_is_set(&self.null_bitmap, column_index)
    }

    /// Tests bit `index` of `bitmap` (least-significant bit first within each byte).
    /// Indices past the end of the bitmap count as set.
    fn bit_is_set(bitmap: &[u8], index: usize) -> bool {
        bitmap
            .get(index / 8)
            .map_or(true, |byte| byte & (1u8 << (index % 8)) != 0)
    }
}
impl ReturnValue {
    /// Decodes a RETURNVALUE token body (without the leading token byte).
    ///
    /// Layout:
    /// - `u16` parameter ordinal and a `B_VARCHAR` parameter name;
    /// - status byte, `u32` user type, `u16` flags and the type byte;
    /// - TYPE_INFO;
    /// - the value, stored in the layout `RawRow::decode_column_value` produces.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::UnexpectedEof`] if the body is truncated.
    /// - [`ProtocolError::InvalidTokenType`] with the raw type byte if it is not a known
    ///   [`TypeId`] or is a table-valued parameter. An unknown type cannot be decoded
    ///   safely because its TYPE_INFO and value layouts are unknown.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let param_ordinal = src.get_u16_le();
        let param_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
        // status (1) + user type (4) + flags (2) + type byte (1)
        if src.remaining() < 8 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let status = src.get_u8();
        let user_type = src.get_u32_le();
        let flags = src.get_u16_le();
        let col_type = src.get_u8();
        let type_id =
            TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidTokenType(col_type))?;
        let type_info = decode_type_info(src, type_id, col_type)?;

        // The value decoder is driven by a column description; build a nameless one and
        // move its TypeInfo back out afterwards instead of cloning it.
        let column = ColumnData {
            name: String::new(),
            type_id,
            col_type,
            flags,
            user_type,
            type_info,
            crypto_metadata: None,
        };
        let mut value = bytes::BytesMut::new();
        RawRow::decode_column_value(src, &column, &mut value)?;
        Ok(Self {
            param_ordinal,
            param_name,
            status,
            user_type,
            flags,
            col_type,
            type_info: column.type_info,
            value: value.freeze(),
        })
    }
}
impl SessionState {
    /// Decodes a SESSIONSTATE token body: a little-endian `u32` length followed by that
    /// many opaque bytes.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::UnexpectedEof`] if the length field is truncated.
    /// - [`ProtocolError::IncompletePacket`] if fewer than `length` bytes follow it.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let length = match src.remaining() {
            available if available >= 4 => src.get_u32_le() as usize,
            _ => return Err(ProtocolError::UnexpectedEof),
        };
        let available = src.remaining();
        if available < length {
            return Err(ProtocolError::IncompletePacket {
                expected: length,
                actual: available,
            });
        }
        Ok(Self {
            data: src.copy_to_bytes(length),
        })
    }
}
/// Bit masks of the status word shared by DONE, DONEPROC and DONEINPROC tokens.
mod done_status_bits {
    /// Bit 0.
    pub const DONE_MORE: u16 = 1 << 0;
    /// Bit 1.
    pub const DONE_ERROR: u16 = 1 << 1;
    /// Bit 2.
    pub const DONE_INXACT: u16 = 1 << 2;
    /// Bit 4.
    pub const DONE_COUNT: u16 = 1 << 4;
    /// Bit 5.
    pub const DONE_ATTN: u16 = 1 << 5;
    /// Bit 8.
    pub const DONE_SRVERROR: u16 = 1 << 8;
}
impl DoneStatus {
    /// Unpacks a raw status word into flags. Bits without a named flag are ignored.
    #[must_use]
    pub fn from_bits(bits: u16) -> Self {
        use done_status_bits::*;
        let has = |mask: u16| bits & mask != 0;
        Self {
            more: has(DONE_MORE),
            error: has(DONE_ERROR),
            in_xact: has(DONE_INXACT),
            count: has(DONE_COUNT),
            attn: has(DONE_ATTN),
            srverror: has(DONE_SRVERROR),
        }
    }

    /// Packs the flags back into a status word; only the named bits can be produced.
    #[must_use]
    pub fn to_bits(&self) -> u16 {
        use done_status_bits::*;
        [
            (self.more, DONE_MORE),
            (self.error, DONE_ERROR),
            (self.in_xact, DONE_INXACT),
            (self.count, DONE_COUNT),
            (self.attn, DONE_ATTN),
            (self.srverror, DONE_SRVERROR),
        ]
        .iter()
        .filter(|(set, _)| *set)
        .fold(0u16, |acc, (_, mask)| acc | mask)
    }
}
impl Done {
    /// Body size in bytes (status `u16`, current command `u16`, row count `u64`),
    /// excluding the token byte.
    pub const SIZE: usize = 12;

    /// Decodes a DONE token body; the token byte must already have been consumed.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::IncompletePacket`] when fewer than [`Self::SIZE`] bytes remain.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let available = src.remaining();
        if available < Self::SIZE {
            return Err(ProtocolError::IncompletePacket {
                expected: Self::SIZE,
                actual: available,
            });
        }
        // Struct-expression fields are evaluated in the order written, matching the wire.
        Ok(Self {
            status: DoneStatus::from_bits(src.get_u16_le()),
            cur_cmd: src.get_u16_le(),
            row_count: src.get_u64_le(),
        })
    }

    /// Writes the DONE token byte followed by the 12-byte body.
    pub fn encode(&self, dst: &mut impl BufMut) {
        let Self {
            status,
            cur_cmd,
            row_count,
        } = *self;
        dst.put_u8(TokenType::Done as u8);
        dst.put_u16_le(status.to_bits());
        dst.put_u16_le(cur_cmd);
        dst.put_u64_le(row_count);
    }

    /// Whether the `more` status flag is set.
    #[must_use]
    pub const fn has_more(&self) -> bool {
        self.status.more
    }

    /// Whether the `error` status flag is set.
    #[must_use]
    pub const fn has_error(&self) -> bool {
        self.status.error
    }

    /// Whether the `count` status flag is set.
    #[must_use]
    pub const fn has_count(&self) -> bool {
        self.status.count
    }
}
impl DoneProc {
    /// Body size in bytes (status `u16`, current command `u16`, row count `u64`),
    /// excluding the token byte.
    pub const SIZE: usize = 12;

    /// Decodes a DONEPROC token body; the token byte must already have been consumed.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::IncompletePacket`] when fewer than [`Self::SIZE`] bytes remain.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        match src.remaining() {
            available if available < Self::SIZE => Err(ProtocolError::IncompletePacket {
                expected: Self::SIZE,
                actual: available,
            }),
            _ => {
                let status = DoneStatus::from_bits(src.get_u16_le());
                let cur_cmd = src.get_u16_le();
                let row_count = src.get_u64_le();
                Ok(Self {
                    status,
                    cur_cmd,
                    row_count,
                })
            }
        }
    }

    /// Writes the DONEPROC token byte followed by the 12-byte body.
    pub fn encode(&self, dst: &mut impl BufMut) {
        dst.put_u8(TokenType::DoneProc as u8);
        for word in [self.status.to_bits(), self.cur_cmd] {
            dst.put_u16_le(word);
        }
        dst.put_u64_le(self.row_count);
    }
}
impl DoneInProc {
    /// Body size in bytes (status `u16`, current command `u16`, row count `u64`),
    /// excluding the token byte.
    pub const SIZE: usize = 12;

    /// Decodes a DONEINPROC token body; the token byte must already have been consumed.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::IncompletePacket`] when fewer than [`Self::SIZE`] bytes remain.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let available = src.remaining();
        if available >= Self::SIZE {
            let status = DoneStatus::from_bits(src.get_u16_le());
            let cur_cmd = src.get_u16_le();
            let row_count = src.get_u64_le();
            Ok(Self {
                status,
                cur_cmd,
                row_count,
            })
        } else {
            Err(ProtocolError::IncompletePacket {
                expected: Self::SIZE,
                actual: available,
            })
        }
    }

    /// Writes the DONEINPROC token byte followed by the 12-byte body.
    pub fn encode(&self, dst: &mut impl BufMut) {
        let status_bits = self.status.to_bits();
        dst.put_u8(TokenType::DoneInProc as u8);
        dst.put_u16_le(status_bits);
        dst.put_u16_le(self.cur_cmd);
        dst.put_u64_le(self.row_count);
    }
}
impl ServerError {
    /// Decodes an ERROR token body (without the leading token byte).
    ///
    /// Layout:
    /// - a `u16` length, which is read and ignored;
    /// - an `i32` number, a state byte and a class byte;
    /// - a `US_VARCHAR` message;
    /// - `B_VARCHAR` server and procedure names;
    /// - an `i32` line number.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::UnexpectedEof`] if any field is truncated.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let eof = || ProtocolError::UnexpectedEof;
        if src.remaining() < 2 {
            return Err(eof());
        }
        src.advance(2);
        if src.remaining() < 6 {
            return Err(eof());
        }
        let number = src.get_i32_le();
        let state = src.get_u8();
        let class = src.get_u8();
        let message = read_us_varchar(src).ok_or_else(eof)?;
        let server = read_b_varchar(src).ok_or_else(eof)?;
        let procedure = read_b_varchar(src).ok_or_else(eof)?;
        if src.remaining() < 4 {
            return Err(eof());
        }
        let line = src.get_i32_le();
        Ok(Self {
            number,
            state,
            class,
            message,
            server,
            procedure,
            line,
        })
    }

    /// `true` for severity class 20 and above.
    #[must_use]
    pub const fn is_fatal(&self) -> bool {
        matches!(self.class, 20..=u8::MAX)
    }

    /// `true` for severity class 16 and above.
    #[must_use]
    pub const fn is_batch_abort(&self) -> bool {
        matches!(self.class, 16..=u8::MAX)
    }
}
impl ServerInfo {
    /// Decodes an INFO token body (without the leading token byte).
    ///
    /// The layout is the same as the one `ServerError::decode` reads.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::UnexpectedEof`] if any field is truncated.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let _declared_length = src.get_u16_le();
        if src.remaining() < 6 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let (number, state, class) = (src.get_i32_le(), src.get_u8(), src.get_u8());
        let Some(message) = read_us_varchar(src) else {
            return Err(ProtocolError::UnexpectedEof);
        };
        let Some(server) = read_b_varchar(src) else {
            return Err(ProtocolError::UnexpectedEof);
        };
        let Some(procedure) = read_b_varchar(src) else {
            return Err(ProtocolError::UnexpectedEof);
        };
        if src.remaining() < 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        Ok(Self {
            number,
            state,
            class,
            message,
            server,
            procedure,
            line: src.get_i32_le(),
        })
    }
}
impl LoginAck {
    /// Decodes a LOGINACK token body (without the leading token byte).
    ///
    /// Layout:
    /// - a `u16` length, which is read and ignored;
    /// - the interface byte;
    /// - a `u32` TDS version;
    /// - a `B_VARCHAR` program name;
    /// - a `u32` program version.
    ///
    /// NOTE(review): `tds_version` is read little-endian here. The MS-TDS examples show
    /// the server sending this field most-significant byte first (e.g. `74 00 00 04` for
    /// TDS 7.4). Confirm that `TdsVersion::new` expects the value exactly as read.
    ///
    /// # Errors
    ///
    /// [`ProtocolError::UnexpectedEof`] if any field is truncated.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        src.advance(2);
        if src.remaining() < 5 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let interface = src.get_u8();
        let tds_version = src.get_u32_le();
        let prog_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
        match src.remaining() {
            0..=3 => Err(ProtocolError::UnexpectedEof),
            _ => Ok(Self {
                interface,
                tds_version,
                prog_name,
                prog_version: src.get_u32_le(),
            }),
        }
    }

    /// The raw `tds_version` word wrapped in `crate::version::TdsVersion`.
    #[must_use]
    pub fn tds_version(&self) -> crate::version::TdsVersion {
        crate::version::TdsVersion::new(self.tds_version)
    }
}
impl EnvChangeType {
pub fn from_u8(value: u8) -> Option<Self> {
match value {
1 => Some(Self::Database),
2 => Some(Self::Language),
3 => Some(Self::CharacterSet),
4 => Some(Self::PacketSize),
5 => Some(Self::UnicodeSortingLocalId),
6 => Some(Self::UnicodeComparisonFlags),
7 => Some(Self::SqlCollation),
8 => Some(Self::BeginTransaction),
9 => Some(Self::CommitTransaction),
10 => Some(Self::RollbackTransaction),
11 => Some(Self::EnlistDtcTransaction),
12 => Some(Self::DefectTransaction),
13 => Some(Self::RealTimeLogShipping),
15 => Some(Self::PromoteTransaction),
16 => Some(Self::TransactionManagerAddress),
17 => Some(Self::TransactionEnded),
18 => Some(Self::ResetConnectionCompletionAck),
19 => Some(Self::UserInstanceStarted),
20 => Some(Self::Routing),
_ => None,
}
}
}
impl EnvChange {
pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
if src.remaining() < 3 {
return Err(ProtocolError::UnexpectedEof);
}
let length = src.get_u16_le() as usize;
if src.remaining() < length {
return Err(ProtocolError::IncompletePacket {
expected: length,
actual: src.remaining(),
});
}
let env_type_byte = src.get_u8();
let env_type = EnvChangeType::from_u8(env_type_byte)
.ok_or(ProtocolError::InvalidTokenType(env_type_byte))?;
let (new_value, old_value) = match env_type {
EnvChangeType::Routing => {
let new_value = Self::decode_routing_value(src)?;
let old_value = EnvChangeValue::Binary(Bytes::new());
(new_value, old_value)
}
EnvChangeType::BeginTransaction
| EnvChangeType::CommitTransaction
| EnvChangeType::RollbackTransaction
| EnvChangeType::EnlistDtcTransaction
| EnvChangeType::SqlCollation => {
let new_len = src.get_u8() as usize;
let new_value = if new_len > 0 && src.remaining() >= new_len {
EnvChangeValue::Binary(src.copy_to_bytes(new_len))
} else {
EnvChangeValue::Binary(Bytes::new())
};
let old_len = src.get_u8() as usize;
let old_value = if old_len > 0 && src.remaining() >= old_len {
EnvChangeValue::Binary(src.copy_to_bytes(old_len))
} else {
EnvChangeValue::Binary(Bytes::new())
};
(new_value, old_value)
}
_ => {
let new_value = read_b_varchar(src)
.map(EnvChangeValue::String)
.unwrap_or(EnvChangeValue::String(String::new()));
let old_value = read_b_varchar(src)
.map(EnvChangeValue::String)
.unwrap_or(EnvChangeValue::String(String::new()));
(new_value, old_value)
}
};
Ok(Self {
env_type,
new_value,
old_value,
})
}
fn decode_routing_value(src: &mut impl Buf) -> Result<EnvChangeValue, ProtocolError> {
if src.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let _routing_len = src.get_u16_le();
if src.remaining() < 5 {
return Err(ProtocolError::UnexpectedEof);
}
let _protocol = src.get_u8();
let port = src.get_u16_le();
let server_len = src.get_u16_le() as usize;
if src.remaining() < server_len * 2 {
return Err(ProtocolError::UnexpectedEof);
}
let mut chars = Vec::with_capacity(server_len);
for _ in 0..server_len {
chars.push(src.get_u16_le());
}
let host = String::from_utf16(&chars).map_err(|_| {
ProtocolError::StringEncoding(
#[cfg(feature = "std")]
"invalid UTF-16 in routing hostname".to_string(),
#[cfg(not(feature = "std"))]
"invalid UTF-16 in routing hostname",
)
})?;
Ok(EnvChangeValue::Routing { host, port })
}
#[must_use]
pub fn is_routing(&self) -> bool {
self.env_type == EnvChangeType::Routing
}
#[must_use]
pub fn routing_info(&self) -> Option<(&str, u16)> {
if let EnvChangeValue::Routing { host, port } = &self.new_value {
Some((host, *port))
} else {
None
}
}
#[must_use]
pub fn new_database(&self) -> Option<&str> {
if self.env_type == EnvChangeType::Database {
if let EnvChangeValue::String(s) = &self.new_value {
return Some(s);
}
}
None
}
}
impl Order {
    /// Decodes an ORDER token body: a little-endian `u16` byte length followed by
    /// `length / 2` little-endian `u16` column numbers.
    ///
    /// The full declared length is always consumed. If a malformed body declares an odd
    /// length, the stray final byte is skipped so that the next token starts at the right
    /// offset.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::UnexpectedEof`] if the length field is truncated.
    /// - [`ProtocolError::IncompletePacket`] if fewer than `length` bytes follow it.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let length = usize::from(src.get_u16_le());
        if src.remaining() < length {
            return Err(ProtocolError::IncompletePacket {
                expected: length,
                actual: src.remaining(),
            });
        }
        let column_count = length / 2;
        let mut columns = Vec::with_capacity(column_count);
        for _ in 0..column_count {
            columns.push(src.get_u16_le());
        }
        // An odd length leaves one byte that is not part of any column number.
        src.advance(length % 2);
        Ok(Self { columns })
    }
}
impl FeatureExtAck {
    /// Feature id byte that ends the acknowledgement list.
    pub const TERMINATOR: u8 = 0xFF;

    /// Decodes a FEATUREEXTACK token body (everything after the token-type byte).
    ///
    /// The body is a sequence of `(feature_id: u8, data_len: u32 LE, data)`
    /// entries, ended by a single [`Self::TERMINATOR`] byte, which is consumed.
    ///
    /// # Errors
    ///
    /// - `ProtocolError::UnexpectedEof` when the input ends before the terminator
    ///   or inside an entry's length field.
    /// - `ProtocolError::IncompletePacket` when an entry's data is shorter than
    ///   its declared length.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let mut features = Vec::new();
        loop {
            if src.remaining() == 0 {
                return Err(ProtocolError::UnexpectedEof);
            }
            let feature_id = src.get_u8();
            if feature_id == Self::TERMINATOR {
                return Ok(Self { features });
            }

            if src.remaining() < 4 {
                return Err(ProtocolError::UnexpectedEof);
            }
            let data_len = src.get_u32_le() as usize;
            let available = src.remaining();
            if available < data_len {
                return Err(ProtocolError::IncompletePacket {
                    expected: data_len,
                    actual: available,
                });
            }
            let data = src.copy_to_bytes(data_len);
            features.push(FeatureAck { feature_id, data });
        }
    }
}
impl SspiToken {
    /// Decodes an SSPI token body: a `u16` little-endian length, followed by that
    /// many bytes of opaque payload, which is stored as-is in `data`.
    ///
    /// # Errors
    ///
    /// `ProtocolError::UnexpectedEof` if the length prefix is truncated, or
    /// `ProtocolError::IncompletePacket` if fewer than `length` bytes follow it.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let declared = usize::from(src.get_u16_le());
        let available = src.remaining();
        if available >= declared {
            Ok(Self {
                data: src.copy_to_bytes(declared),
            })
        } else {
            Err(ProtocolError::IncompletePacket {
                expected: declared,
                actual: available,
            })
        }
    }
}
impl FedAuthInfo {
    /// Decodes the body of a FEDAUTHINFO token (everything after the token-type byte).
    ///
    /// Wire layout:
    ///
    /// ```text
    /// TokenLength      DWORD  number of bytes that follow this field
    /// CountOfInfoIDs   DWORD
    /// FedAuthInfoOpt   CountOfInfoIDs x { FedAuthInfoID      BYTE,
    ///                                     FedAuthInfoDataLen DWORD,
    ///                                     FedAuthInfoDataOffset DWORD }
    /// FedAuthInfoData  UTF-16LE strings addressed by the options above
    /// ```
    ///
    /// Each `FedAuthInfoDataOffset` is measured from the start of
    /// `CountOfInfoIDs`. Exactly `TokenLength` bytes are consumed from `src`, so
    /// the next token is left intact. Option id `0x01` (STSURL) fills `sts_url`
    /// and `0x02` (SPN) fills `spn`. Other ids are ignored, and a field whose
    /// option is absent stays empty.
    ///
    /// # Errors
    ///
    /// - `ProtocolError::UnexpectedEof` if the `TokenLength` prefix is truncated.
    /// - `ProtocolError::IncompletePacket` if the token, its option table, or an
    ///   option's data extends past the available bytes.
    /// - `ProtocolError::StringEncoding` if an option's data offset is smaller
    ///   than the `CountOfInfoIDs` field, or its data is not valid UTF-16.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        // Size of CountOfInfoIDs; option data offsets are relative to its start.
        const COUNT_SIZE: usize = 4;
        // FedAuthInfoID (1) + FedAuthInfoDataLen (4) + FedAuthInfoDataOffset (4).
        const OPTION_SIZE: usize = 9;
        const INFO_ID_STSURL: u8 = 0x01;
        const INFO_ID_SPN: u8 = 0x02;

        fn malformed(message: &'static str) -> ProtocolError {
            ProtocolError::StringEncoding(
                #[cfg(feature = "std")]
                message.to_string(),
                #[cfg(not(feature = "std"))]
                message,
            )
        }

        if src.remaining() < 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let token_len = src.get_u32_le() as usize;
        if src.remaining() < token_len {
            return Err(ProtocolError::IncompletePacket {
                expected: token_len,
                actual: src.remaining(),
            });
        }
        if token_len < COUNT_SIZE {
            return Err(ProtocolError::IncompletePacket {
                expected: COUNT_SIZE,
                actual: token_len,
            });
        }
        let count = src.get_u32_le() as usize;
        // The rest of the token: the option table followed by the data area.
        // Offsets into `body` are the wire offsets minus COUNT_SIZE.
        let body = src.copy_to_bytes(token_len - COUNT_SIZE);

        let table_len = count.saturating_mul(OPTION_SIZE);
        if table_len > body.len() {
            return Err(ProtocolError::IncompletePacket {
                expected: table_len,
                actual: body.len(),
            });
        }

        let mut sts_url = String::new();
        let mut spn = String::new();
        for option in body[..table_len].chunks_exact(OPTION_SIZE) {
            let target = match option[0] {
                INFO_ID_STSURL => &mut sts_url,
                INFO_ID_SPN => &mut spn,
                // Unknown ids are skipped for forward compatibility.
                _ => continue,
            };
            let data_len =
                u32::from_le_bytes([option[1], option[2], option[3], option[4]]) as usize;
            let data_offset =
                u32::from_le_bytes([option[5], option[6], option[7], option[8]]) as usize;
            let start = data_offset
                .checked_sub(COUNT_SIZE)
                .ok_or_else(|| malformed("invalid FedAuthInfo data offset"))?;
            let end = start.saturating_add(data_len);
            if end > body.len() {
                return Err(ProtocolError::IncompletePacket {
                    expected: end,
                    actual: body.len(),
                });
            }
            // Decode whole UTF-16LE code units; an odd trailing byte is ignored.
            let units: Vec<u16> = body[start..end]
                .chunks_exact(2)
                .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
                .collect();
            *target = String::from_utf16(&units)
                .map_err(|_| malformed("invalid UTF-16 in FedAuthInfo data"))?;
        }
        Ok(Self { sts_url, spn })
    }
}
/// Sequential decoder over a buffer holding a stream of TDS tokens.
///
/// Each call to `next_token` / `next_token_with_metadata` decodes the token
/// at `position` and advances past it.
pub struct TokenParser {
    /// The token stream being decoded.
    data: Bytes,
    /// Byte offset into `data` of the next token's type byte.
    position: usize,
    /// When set, COLMETADATA tokens are decoded with encryption metadata.
    encryption_enabled: bool,
}
impl TokenParser {
#[must_use]
pub fn new(data: Bytes) -> Self {
Self {
data,
position: 0,
encryption_enabled: false,
}
}
#[must_use]
pub fn with_encryption(mut self, enabled: bool) -> Self {
self.encryption_enabled = enabled;
self
}
#[must_use]
pub fn remaining(&self) -> usize {
self.data.len().saturating_sub(self.position)
}
#[must_use]
pub fn has_remaining(&self) -> bool {
self.position < self.data.len()
}
#[must_use]
pub fn peek_token_type(&self) -> Option<TokenType> {
if self.position < self.data.len() {
TokenType::from_u8(self.data[self.position])
} else {
None
}
}
pub fn next_token(&mut self) -> Result<Option<Token>, ProtocolError> {
self.next_token_with_metadata(None)
}
pub fn next_token_with_metadata(
&mut self,
metadata: Option<&ColMetaData>,
) -> Result<Option<Token>, ProtocolError> {
if !self.has_remaining() {
return Ok(None);
}
let mut buf = &self.data[self.position..];
let start_pos = self.position;
let token_type_byte = buf.get_u8();
let token_type = TokenType::from_u8(token_type_byte);
let token = match token_type {
Some(TokenType::Done) => {
let done = Done::decode(&mut buf)?;
Token::Done(done)
}
Some(TokenType::DoneProc) => {
let done = DoneProc::decode(&mut buf)?;
Token::DoneProc(done)
}
Some(TokenType::DoneInProc) => {
let done = DoneInProc::decode(&mut buf)?;
Token::DoneInProc(done)
}
Some(TokenType::Error) => {
let error = ServerError::decode(&mut buf)?;
Token::Error(error)
}
Some(TokenType::Info) => {
let info = ServerInfo::decode(&mut buf)?;
Token::Info(info)
}
Some(TokenType::LoginAck) => {
let login_ack = LoginAck::decode(&mut buf)?;
Token::LoginAck(login_ack)
}
Some(TokenType::EnvChange) => {
let env_change = EnvChange::decode(&mut buf)?;
Token::EnvChange(env_change)
}
Some(TokenType::Order) => {
let order = Order::decode(&mut buf)?;
Token::Order(order)
}
Some(TokenType::FeatureExtAck) => {
let ack = FeatureExtAck::decode(&mut buf)?;
Token::FeatureExtAck(ack)
}
Some(TokenType::Sspi) => {
let sspi = SspiToken::decode(&mut buf)?;
Token::Sspi(sspi)
}
Some(TokenType::FedAuthInfo) => {
let info = FedAuthInfo::decode(&mut buf)?;
Token::FedAuthInfo(info)
}
Some(TokenType::ReturnStatus) => {
if buf.remaining() < 4 {
return Err(ProtocolError::UnexpectedEof);
}
let status = buf.get_i32_le();
Token::ReturnStatus(status)
}
Some(TokenType::ColMetaData) => {
let col_meta = if self.encryption_enabled {
ColMetaData::decode_encrypted(&mut buf)?
} else {
ColMetaData::decode(&mut buf)?
};
Token::ColMetaData(col_meta)
}
Some(TokenType::Row) => {
let meta = metadata.ok_or_else(|| {
ProtocolError::StringEncoding(
#[cfg(feature = "std")]
"Row token requires column metadata".to_string(),
#[cfg(not(feature = "std"))]
"Row token requires column metadata",
)
})?;
let row = RawRow::decode(&mut buf, meta)?;
Token::Row(row)
}
Some(TokenType::NbcRow) => {
let meta = metadata.ok_or_else(|| {
ProtocolError::StringEncoding(
#[cfg(feature = "std")]
"NbcRow token requires column metadata".to_string(),
#[cfg(not(feature = "std"))]
"NbcRow token requires column metadata",
)
})?;
let row = NbcRow::decode(&mut buf, meta)?;
Token::NbcRow(row)
}
Some(TokenType::ReturnValue) => {
let ret_val = ReturnValue::decode(&mut buf)?;
Token::ReturnValue(ret_val)
}
Some(TokenType::SessionState) => {
let session = SessionState::decode(&mut buf)?;
Token::SessionState(session)
}
Some(TokenType::ColInfo) | Some(TokenType::TabName) | Some(TokenType::Offset) => {
if buf.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let length = buf.get_u16_le() as usize;
if buf.remaining() < length {
return Err(ProtocolError::IncompletePacket {
expected: length,
actual: buf.remaining(),
});
}
buf.advance(length);
self.position = start_pos + (self.data.len() - start_pos - buf.remaining());
return self.next_token_with_metadata(metadata);
}
None => {
return Err(ProtocolError::InvalidTokenType(token_type_byte));
}
};
let consumed = self.data.len() - start_pos - buf.remaining();
self.position = start_pos + consumed;
Ok(Some(token))
}
pub fn skip_token(&mut self) -> Result<(), ProtocolError> {
if !self.has_remaining() {
return Ok(());
}
let token_type_byte = self.data[self.position];
let token_type = TokenType::from_u8(token_type_byte);
let skip_amount = match token_type {
Some(TokenType::Done) | Some(TokenType::DoneProc) | Some(TokenType::DoneInProc) => {
1 + Done::SIZE }
Some(TokenType::ReturnStatus) => {
1 + 4 }
Some(TokenType::Error)
| Some(TokenType::Info)
| Some(TokenType::LoginAck)
| Some(TokenType::EnvChange)
| Some(TokenType::Order)
| Some(TokenType::Sspi)
| Some(TokenType::ColInfo)
| Some(TokenType::TabName)
| Some(TokenType::Offset)
| Some(TokenType::ReturnValue) => {
if self.remaining() < 3 {
return Err(ProtocolError::UnexpectedEof);
}
let length = u16::from_le_bytes([
self.data[self.position + 1],
self.data[self.position + 2],
]) as usize;
1 + 2 + length }
Some(TokenType::SessionState) | Some(TokenType::FedAuthInfo) => {
if self.remaining() < 5 {
return Err(ProtocolError::UnexpectedEof);
}
let length = u32::from_le_bytes([
self.data[self.position + 1],
self.data[self.position + 2],
self.data[self.position + 3],
self.data[self.position + 4],
]) as usize;
1 + 4 + length
}
Some(TokenType::FeatureExtAck) => {
let mut buf = &self.data[self.position + 1..];
let _ = FeatureExtAck::decode(&mut buf)?;
self.data.len() - self.position - buf.remaining()
}
Some(TokenType::ColMetaData) | Some(TokenType::Row) | Some(TokenType::NbcRow) => {
return Err(ProtocolError::InvalidTokenType(token_type_byte));
}
None => {
return Err(ProtocolError::InvalidTokenType(token_type_byte));
}
};
if self.remaining() < skip_amount {
return Err(ProtocolError::UnexpectedEof);
}
self.position += skip_amount;
Ok(())
}
#[must_use]
pub fn position(&self) -> usize {
self.position
}
pub fn reset(&mut self) {
self.position = 0;
}
}
// Unit tests for DONE/DONEPROC/DONEINPROC, COLMETADATA, ROW/NBCROW, RETURNSTATUS,
// RETURNVALUE and ERROR decoding, plus `TokenParser` over multi-token streams.
// Fixtures are hand-built byte sequences in little-endian wire order.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
use bytes::BytesMut;
// `Done::encode` output decodes back to the same status/cur_cmd/row_count.
// `buf[1..]` skips the leading byte, presumably the token type (the
// DoneProc/DoneInProc roundtrips below assert `buf[0]` explicitly).
#[test]
fn test_done_roundtrip() {
let done = Done {
status: DoneStatus {
more: false,
error: false,
in_xact: false,
count: true,
attn: false,
srverror: false,
},
cur_cmd: 193, row_count: 42,
};
let mut buf = BytesMut::new();
done.encode(&mut buf);
let mut cursor = &buf[1..];
let decoded = Done::decode(&mut cursor).unwrap();
assert_eq!(decoded.status.count, done.status.count);
assert_eq!(decoded.cur_cmd, done.cur_cmd);
assert_eq!(decoded.row_count, done.row_count);
}
// `DoneStatus::to_bits` / `from_bits` preserve the more/error/in_xact/count flags.
#[test]
fn test_done_status_bits() {
let status = DoneStatus {
more: true,
error: true,
in_xact: true,
count: true,
attn: false,
srverror: false,
};
let bits = status.to_bits();
let restored = DoneStatus::from_bits(bits);
assert_eq!(status.more, restored.more);
assert_eq!(status.error, restored.error);
assert_eq!(status.in_xact, restored.in_xact);
assert_eq!(status.count, restored.count);
}
// Parser decodes a 13-byte DONE token: type 0xFD, status 0x0010 (asserted
// below as `count`), cur_cmd 0x00C1 = 193, then an 8-byte row count of 5.
// The stream is then exhausted.
#[test]
fn test_token_parser_done() {
let data = Bytes::from_static(&[
0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
let mut parser = TokenParser::new(data);
let token = parser.next_token().unwrap().unwrap();
match token {
Token::Done(done) => {
assert!(done.status.count);
assert!(!done.status.more);
assert_eq!(done.cur_cmd, 193);
assert_eq!(done.row_count, 5);
}
_ => panic!("Expected Done token"),
}
assert!(parser.next_token().unwrap().is_none());
}
// Known and unknown ENVCHANGE type bytes.
#[test]
fn test_env_change_type_from_u8() {
assert_eq!(EnvChangeType::from_u8(1), Some(EnvChangeType::Database));
assert_eq!(EnvChangeType::from_u8(20), Some(EnvChangeType::Routing));
assert_eq!(EnvChangeType::from_u8(100), None);
}
// A column count of 0xFFFF decodes to empty metadata.
#[test]
fn test_colmetadata_no_columns() {
let data = Bytes::from_static(&[0xFF, 0xFF]);
let mut cursor: &[u8] = &data;
let meta = ColMetaData::decode(&mut cursor).unwrap();
assert!(meta.is_empty());
assert_eq!(meta.column_count(), 0);
}
// One column: count 1, user type 0 (u32), flags 0x0001 (asserted nullable),
// type 0x38 (Int4), name length 2, then UTF-16LE "id".
#[test]
fn test_colmetadata_single_int_column() {
let mut data = BytesMut::new();
data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]);
let mut cursor: &[u8] = &data;
let meta = ColMetaData::decode(&mut cursor).unwrap();
assert_eq!(meta.column_count(), 1);
assert_eq!(meta.columns[0].name, "id");
assert_eq!(meta.columns[0].type_id, TypeId::Int4);
assert!(meta.columns[0].is_nullable());
}
// One NVarChar column: count 1, user type 0, flags 0x0001, type 0xE7,
// max length 0x0064 = 100, 5 collation bytes, name length 4, UTF-16LE "name".
#[test]
fn test_colmetadata_nvarchar_column() {
let mut data = BytesMut::new();
data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0xE7]); data.extend_from_slice(&[0x64, 0x00]); data.extend_from_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]); data.extend_from_slice(&[0x04]); data.extend_from_slice(&[b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00]);
let mut cursor: &[u8] = &data;
let meta = ColMetaData::decode(&mut cursor).unwrap();
assert_eq!(meta.column_count(), 1);
assert_eq!(meta.columns[0].name, "name");
assert_eq!(meta.columns[0].type_id, TypeId::NVarChar);
assert_eq!(meta.columns[0].type_info.max_length, Some(100));
assert!(meta.columns[0].type_info.collation.is_some());
}
// A fixed-size Int4 column is stored as its 4 raw bytes, with no length prefix.
#[test]
fn test_raw_row_decode_int() {
let metadata = ColMetaData {
cek_table: None,
columns: vec![ColumnData {
name: "id".to_string(),
type_id: TypeId::Int4,
col_type: 0x38,
flags: 0,
user_type: 0,
type_info: TypeInfo::default(),
crypto_metadata: None,
}],
};
let data = Bytes::from_static(&[0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
let row = RawRow::decode(&mut cursor, &metadata).unwrap();
assert_eq!(row.data.len(), 4);
assert_eq!(&row.data[..], &[0x2A, 0x00, 0x00, 0x00]);
}
// An IntN value keeps its 1-byte length prefix in the raw row data.
#[test]
fn test_raw_row_decode_nullable_int() {
let metadata = ColMetaData {
cek_table: None,
columns: vec![ColumnData {
name: "id".to_string(),
type_id: TypeId::IntN,
col_type: 0x26,
flags: 0x01, user_type: 0,
type_info: TypeInfo {
max_length: Some(4),
..Default::default()
},
crypto_metadata: None,
}],
};
let data = Bytes::from_static(&[0x04, 0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
let row = RawRow::decode(&mut cursor, &metadata).unwrap();
assert_eq!(row.data.len(), 5);
assert_eq!(row.data[0], 4); assert_eq!(&row.data[1..], &[0x2A, 0x00, 0x00, 0x00]);
}
// A single 0xFF IntN length byte decodes to a one-byte raw row.
// NOTE(review): `build_return_value_intn` below encodes a NULL IntN as a zero
// length byte; confirm that 0xFF is the intended NULL marker here.
#[test]
fn test_raw_row_decode_null_value() {
let metadata = ColMetaData {
cek_table: None,
columns: vec![ColumnData {
name: "id".to_string(),
type_id: TypeId::IntN,
col_type: 0x26,
flags: 0x01, user_type: 0,
type_info: TypeInfo {
max_length: Some(4),
..Default::default()
},
crypto_metadata: None,
}],
};
let data = Bytes::from_static(&[0xFF]);
let mut cursor: &[u8] = &data;
let row = RawRow::decode(&mut cursor, &metadata).unwrap();
assert_eq!(row.data.len(), 1);
assert_eq!(row.data[0], 0xFF); }
// Bit i of the NBCROW null bitmap (least-significant bit first) marks column i as NULL.
#[test]
fn test_nbcrow_null_bitmap() {
let row = NbcRow {
null_bitmap: vec![0b00000101], data: Bytes::new(),
};
assert!(row.is_null(0));
assert!(!row.is_null(1));
assert!(row.is_null(2));
assert!(!row.is_null(3));
}
// Parser dispatches type byte 0x81 to COLMETADATA decoding; the body matches
// `test_colmetadata_single_int_column`.
#[test]
fn test_token_parser_colmetadata() {
let mut data = BytesMut::new();
data.extend_from_slice(&[0x81]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]);
let mut parser = TokenParser::new(data.freeze());
let token = parser.next_token().unwrap().unwrap();
match token {
Token::ColMetaData(meta) => {
assert_eq!(meta.column_count(), 1);
assert_eq!(meta.columns[0].name, "id");
assert_eq!(meta.columns[0].type_id, TypeId::Int4);
}
_ => panic!("Expected ColMetaData token"),
}
}
// ROW token (0xD1) decodes when column metadata is supplied.
#[test]
fn test_token_parser_row_with_metadata() {
let metadata = ColMetaData {
cek_table: None,
columns: vec![ColumnData {
name: "id".to_string(),
type_id: TypeId::Int4,
col_type: 0x38,
flags: 0,
user_type: 0,
type_info: TypeInfo::default(),
crypto_metadata: None,
}],
};
let mut data = BytesMut::new();
data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]);
let mut parser = TokenParser::new(data.freeze());
let token = parser
.next_token_with_metadata(Some(&metadata))
.unwrap()
.unwrap();
match token {
Token::Row(row) => {
assert_eq!(row.data.len(), 4);
}
_ => panic!("Expected Row token"),
}
}
// ROW token without metadata is an error, not a panic.
#[test]
fn test_token_parser_row_without_metadata_fails() {
let mut data = BytesMut::new();
data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]);
let mut parser = TokenParser::new(data.freeze());
let result = parser.next_token();
assert!(result.is_err());
}
// `peek_token_type` reports the next token type without a mutable borrow.
#[test]
fn test_token_parser_peek() {
let data = Bytes::from_static(&[
0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
let parser = TokenParser::new(data);
assert_eq!(parser.peek_token_type(), Some(TokenType::Done));
}
// Int4 has a fixed 4-byte size; NVarChar has none.
#[test]
fn test_column_data_fixed_size() {
let col = ColumnData {
name: String::new(),
type_id: TypeId::Int4,
col_type: 0x38,
flags: 0,
user_type: 0,
type_info: TypeInfo::default(),
crypto_metadata: None,
};
assert_eq!(col.fixed_size(), Some(4));
let col2 = ColumnData {
name: String::new(),
type_id: TypeId::NVarChar,
col_type: 0xE7,
flags: 0,
user_type: 0,
type_info: TypeInfo::default(),
crypto_metadata: None,
};
assert_eq!(col2.fixed_size(), None);
}
// Two-column row: NVarChar (u16 byte length + UTF-16LE) then IntN (u8 length + i32).
// RawRow must consume exactly the wire bytes and store them verbatim, so
// re-reading `raw_row.data` yields the original values.
#[test]
fn test_decode_nvarchar_then_intn_roundtrip() {
let mut wire_data = BytesMut::new();
let word = "World";
let utf16: Vec<u16> = word.encode_utf16().collect();
wire_data.put_u16_le((utf16.len() * 2) as u16); for code_unit in &utf16 {
wire_data.put_u16_le(*code_unit);
}
wire_data.put_u8(4); wire_data.put_i32_le(42);
let metadata = ColMetaData {
cek_table: None,
columns: vec![
ColumnData {
name: "greeting".to_string(),
type_id: TypeId::NVarChar,
col_type: 0xE7,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
max_length: Some(10), precision: None,
scale: None,
collation: None,
},
crypto_metadata: None,
},
ColumnData {
name: "number".to_string(),
type_id: TypeId::IntN,
col_type: 0x26,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
max_length: Some(4),
precision: None,
scale: None,
collation: None,
},
crypto_metadata: None,
},
],
};
let mut wire_cursor = wire_data.freeze();
let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
assert_eq!(
wire_cursor.remaining(),
0,
"wire data should be fully consumed"
);
let mut stored_cursor: &[u8] = &raw_row.data;
assert!(
stored_cursor.remaining() >= 2,
"need at least 2 bytes for length"
);
let len0 = stored_cursor.get_u16_le() as usize;
assert_eq!(len0, 10, "NVarChar length should be 10 bytes");
assert!(
stored_cursor.remaining() >= len0,
"need {len0} bytes for data"
);
let mut utf16_read = Vec::new();
for _ in 0..(len0 / 2) {
utf16_read.push(stored_cursor.get_u16_le());
}
let string0 = String::from_utf16(&utf16_read).unwrap();
assert_eq!(string0, "World", "column 0 should be 'World'");
assert!(
stored_cursor.remaining() >= 1,
"need at least 1 byte for length"
);
let len1 = stored_cursor.get_u8();
assert_eq!(len1, 4, "IntN length should be 4");
assert!(stored_cursor.remaining() >= 4, "need 4 bytes for INT data");
let int1 = stored_cursor.get_i32_le();
assert_eq!(int1, 42, "column 1 should be 42");
assert_eq!(
stored_cursor.remaining(),
0,
"stored data should be fully consumed"
);
}
// Same as above, but column 0 uses max_length 0xFFFF and is sent as a chunked
// (PLP) value: u64 total length, u32 chunk length, chunk bytes, u32 zero terminator.
#[test]
fn test_decode_nvarchar_max_then_intn_roundtrip() {
let mut wire_data = BytesMut::new();
let word = "Hello";
let utf16: Vec<u16> = word.encode_utf16().collect();
let byte_len = (utf16.len() * 2) as u64;
wire_data.put_u64_le(byte_len); wire_data.put_u32_le(byte_len as u32); for code_unit in &utf16 {
wire_data.put_u16_le(*code_unit);
}
wire_data.put_u32_le(0);
wire_data.put_u8(4);
wire_data.put_i32_le(99);
let metadata = ColMetaData {
cek_table: None,
columns: vec![
ColumnData {
name: "text".to_string(),
type_id: TypeId::NVarChar,
col_type: 0xE7,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
max_length: Some(0xFFFF), precision: None,
scale: None,
collation: None,
},
crypto_metadata: None,
},
ColumnData {
name: "num".to_string(),
type_id: TypeId::IntN,
col_type: 0x26,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
max_length: Some(4),
precision: None,
scale: None,
collation: None,
},
crypto_metadata: None,
},
],
};
let mut wire_cursor = wire_data.freeze();
let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
assert_eq!(
wire_cursor.remaining(),
0,
"wire data should be fully consumed"
);
let mut stored_cursor: &[u8] = &raw_row.data;
let total_len = stored_cursor.get_u64_le();
assert_eq!(total_len, 10, "PLP total length should be 10");
let chunk_len = stored_cursor.get_u32_le();
assert_eq!(chunk_len, 10, "PLP chunk length should be 10");
let mut utf16_read = Vec::new();
for _ in 0..(chunk_len / 2) {
utf16_read.push(stored_cursor.get_u16_le());
}
let string0 = String::from_utf16(&utf16_read).unwrap();
assert_eq!(string0, "Hello", "column 0 should be 'Hello'");
let terminator = stored_cursor.get_u32_le();
assert_eq!(terminator, 0, "PLP should end with 0");
let len1 = stored_cursor.get_u8();
assert_eq!(len1, 4);
let int1 = stored_cursor.get_i32_le();
assert_eq!(int1, 99, "column 1 should be 99");
assert_eq!(
stored_cursor.remaining(),
0,
"stored data should be fully consumed"
);
}
// RETURNSTATUS (0x79) carries a 4-byte little-endian i32 and consumes exactly that.
#[test]
fn test_return_status_via_parser() {
let data = Bytes::from_static(&[
0x79, 0x00, 0x00, 0x00, 0x00, ]);
let mut parser = TokenParser::new(data);
let token = parser.next_token().unwrap().unwrap();
match token {
Token::ReturnStatus(status) => {
assert_eq!(status, 0);
}
_ => panic!("Expected ReturnStatus token, got {token:?}"),
}
assert!(parser.next_token().unwrap().is_none());
}
// Negative return status decodes as a signed value.
#[test]
fn test_return_status_nonzero() {
let mut buf = BytesMut::new();
buf.put_u8(0x79); buf.put_i32_le(-6);
let mut parser = TokenParser::new(buf.freeze());
let token = parser.next_token().unwrap().unwrap();
match token {
Token::ReturnStatus(status) => {
assert_eq!(status, -6);
}
_ => panic!("Expected ReturnStatus token"),
}
}
// `DoneProc::encode` writes type byte 0xFE followed by a body that decodes back.
#[test]
fn test_done_proc_roundtrip() {
let done = DoneProc {
status: DoneStatus {
more: false,
error: false,
in_xact: false,
count: true,
attn: false,
srverror: false,
},
cur_cmd: 0x00C6, row_count: 100,
};
let mut buf = BytesMut::new();
done.encode(&mut buf);
assert_eq!(buf[0], 0xFE);
let mut cursor = &buf[1..];
let decoded = DoneProc::decode(&mut cursor).unwrap();
assert!(decoded.status.count);
assert!(!decoded.status.more);
assert!(!decoded.status.error);
assert_eq!(decoded.cur_cmd, 0x00C6);
assert_eq!(decoded.row_count, 100);
}
// Parser decodes DONEPROC with status 0, cur_cmd 0x00C6 = 198, row count 0.
#[test]
fn test_done_proc_via_parser() {
let data = Bytes::from_static(&[
0xFE, 0x00, 0x00, 0xC6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
let mut parser = TokenParser::new(data);
let token = parser.next_token().unwrap().unwrap();
match token {
Token::DoneProc(done) => {
assert!(!done.status.count);
assert!(!done.status.more);
assert_eq!(done.cur_cmd, 198);
assert_eq!(done.row_count, 0);
}
_ => panic!("Expected DoneProc token"),
}
}
// Status 0x0002 decodes as the `error` flag only.
#[test]
fn test_done_proc_with_error_flag() {
let mut buf = BytesMut::new();
buf.put_u8(0xFE); buf.put_u16_le(0x0002); buf.put_u16_le(0x00C6); buf.put_u64_le(0);
let mut parser = TokenParser::new(buf.freeze());
let token = parser.next_token().unwrap().unwrap();
match token {
Token::DoneProc(done) => {
assert!(done.status.error);
assert!(!done.status.count);
assert!(!done.status.more);
}
_ => panic!("Expected DoneProc token"),
}
}
// `DoneInProc::encode` writes type byte 0xFF followed by a body that decodes back.
#[test]
fn test_done_in_proc_roundtrip() {
let done = DoneInProc {
status: DoneStatus {
more: true,
error: false,
in_xact: false,
count: true,
attn: false,
srverror: false,
},
cur_cmd: 193, row_count: 7,
};
let mut buf = BytesMut::new();
done.encode(&mut buf);
assert_eq!(buf[0], 0xFF);
let mut cursor = &buf[1..];
let decoded = DoneInProc::decode(&mut cursor).unwrap();
assert!(decoded.status.more);
assert!(decoded.status.count);
assert!(!decoded.status.error);
assert_eq!(decoded.cur_cmd, 193);
assert_eq!(decoded.row_count, 7);
}
// Status 0x0011 decodes as `more` + `count`.
#[test]
fn test_done_in_proc_via_parser() {
let data = Bytes::from_static(&[
0xFF, 0x11, 0x00, 0xC1, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
let mut parser = TokenParser::new(data);
let token = parser.next_token().unwrap().unwrap();
match token {
Token::DoneInProc(done) => {
assert!(done.status.more);
assert!(done.status.count);
assert_eq!(done.cur_cmd, 193);
assert_eq!(done.row_count, 3);
}
_ => panic!("Expected DoneInProc token"),
}
}
// ERROR body: u16 length, i32 number, u8 state, u8 class, u16-counted UTF-16
// message, u8-counted server, u8-counted procedure (empty here), i32 line.
#[test]
fn test_server_error_decode() {
let mut buf = BytesMut::new();
let msg_utf16: Vec<u16> = "Invalid column name 'foo'.".encode_utf16().collect();
let srv_utf16: Vec<u16> = "SQLDB01".encode_utf16().collect();
let proc_utf16: Vec<u16> = "".encode_utf16().collect();
let length: u16 = (4
+ 1
+ 1
+ 2
+ (msg_utf16.len() * 2)
+ 1
+ (srv_utf16.len() * 2)
+ 1
+ (proc_utf16.len() * 2)
+ 4) as u16;
buf.put_u16_le(length);
buf.put_i32_le(207); buf.put_u8(1); buf.put_u8(16);
buf.put_u16_le(msg_utf16.len() as u16);
for &c in &msg_utf16 {
buf.put_u16_le(c);
}
buf.put_u8(srv_utf16.len() as u8);
for &c in &srv_utf16 {
buf.put_u16_le(c);
}
buf.put_u8(proc_utf16.len() as u8);
buf.put_i32_le(42);
let mut cursor = buf.freeze();
let error = ServerError::decode(&mut cursor).unwrap();
assert_eq!(error.number, 207);
assert_eq!(error.state, 1);
assert_eq!(error.class, 16);
assert_eq!(error.message, "Invalid column name 'foo'.");
assert_eq!(error.server, "SQLDB01");
assert_eq!(error.procedure, "");
assert_eq!(error.line, 42);
}
// Severity classification: class 20 is fatal and batch-aborting, class 16
// is batch-aborting only, and class 10 is neither.
#[test]
fn test_server_error_severity_helpers() {
let fatal = ServerError {
number: 4014,
state: 1,
class: 20,
message: "Fatal error".to_string(),
server: String::new(),
procedure: String::new(),
line: 0,
};
assert!(fatal.is_fatal());
assert!(fatal.is_batch_abort());
let batch_abort = ServerError {
number: 547,
state: 0,
class: 16,
message: "Constraint violation".to_string(),
server: String::new(),
procedure: String::new(),
line: 1,
};
assert!(!batch_abort.is_fatal());
assert!(batch_abort.is_batch_abort());
let informational = ServerError {
number: 5701,
state: 2,
class: 10,
message: "Changed db context".to_string(),
server: String::new(),
procedure: String::new(),
line: 0,
};
assert!(!informational.is_fatal());
assert!(!informational.is_batch_abort());
}
// Parser dispatches 0xAA to ERROR decoding, including a non-empty procedure name.
#[test]
fn test_server_error_via_parser() {
let mut buf = BytesMut::new();
buf.put_u8(0xAA);
let msg_utf16: Vec<u16> = "Syntax error".encode_utf16().collect();
let srv_utf16: Vec<u16> = "SRV".encode_utf16().collect();
let proc_utf16: Vec<u16> = "sp_test".encode_utf16().collect();
let length: u16 = (4
+ 1
+ 1
+ 2
+ (msg_utf16.len() * 2)
+ 1
+ (srv_utf16.len() * 2)
+ 1
+ (proc_utf16.len() * 2)
+ 4) as u16;
buf.put_u16_le(length);
buf.put_i32_le(102); buf.put_u8(1);
buf.put_u8(15);
buf.put_u16_le(msg_utf16.len() as u16);
for &c in &msg_utf16 {
buf.put_u16_le(c);
}
buf.put_u8(srv_utf16.len() as u8);
for &c in &srv_utf16 {
buf.put_u16_le(c);
}
buf.put_u8(proc_utf16.len() as u8);
for &c in &proc_utf16 {
buf.put_u16_le(c);
}
buf.put_i32_le(5);
let mut parser = TokenParser::new(buf.freeze());
let token = parser.next_token().unwrap().unwrap();
match token {
Token::Error(err) => {
assert_eq!(err.number, 102);
assert_eq!(err.class, 15);
assert_eq!(err.message, "Syntax error");
assert_eq!(err.server, "SRV");
assert_eq!(err.procedure, "sp_test");
assert_eq!(err.line, 5);
}
_ => panic!("Expected Error token"),
}
}
// Builds a RETURNVALUE body without the token-type byte. Fields in order:
// u16 ordinal, u8-counted UTF-16 name, status byte, u32 user type, u16 flags
// 0x0001, type 0x26 with max length 4. The value is a length byte 4 followed
// by an i32, or a single zero length byte for `None`.
fn build_return_value_intn(
ordinal: u16,
name: &str,
status: u8,
value: Option<i32>,
) -> BytesMut {
let mut inner = BytesMut::new();
inner.put_u16_le(ordinal);
let name_utf16: Vec<u16> = name.encode_utf16().collect();
inner.put_u8(name_utf16.len() as u8);
for &c in &name_utf16 {
inner.put_u16_le(c);
}
inner.put_u8(status);
inner.put_u32_le(0);
inner.put_u16_le(0x0001);
inner.put_u8(0x26);
inner.put_u8(4);
match value {
Some(v) => {
inner.put_u8(4); inner.put_i32_le(v);
}
None => {
inner.put_u8(0); }
}
inner
}
// IntN output parameter: the raw value keeps its length byte.
#[test]
fn test_return_value_int_output() {
let buf = build_return_value_intn(1, "@result", 0x01, Some(42));
let mut cursor = buf.freeze();
let rv = ReturnValue::decode(&mut cursor).unwrap();
assert_eq!(rv.param_ordinal, 1);
assert_eq!(rv.param_name, "@result");
assert_eq!(rv.status, 0x01); assert_eq!(rv.col_type, 0x26); assert_eq!(rv.type_info.max_length, Some(4));
assert_eq!(rv.value.len(), 5);
assert_eq!(rv.value[0], 4);
assert_eq!(
i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
42
);
}
// A zero length byte yields a one-byte raw value.
#[test]
fn test_return_value_null_output() {
let buf = build_return_value_intn(2, "@count", 0x01, None);
let mut cursor = buf.freeze();
let rv = ReturnValue::decode(&mut cursor).unwrap();
assert_eq!(rv.param_ordinal, 2);
assert_eq!(rv.param_name, "@count");
assert_eq!(rv.status, 0x01);
assert_eq!(rv.col_type, 0x26);
assert_eq!(rv.value.len(), 1);
assert_eq!(rv.value[0], 0);
}
// Status byte 0x02 is passed through unchanged.
#[test]
fn test_return_value_udf_status() {
let buf = build_return_value_intn(0, "@RETURN_VALUE", 0x02, Some(-1));
let mut cursor = buf.freeze();
let rv = ReturnValue::decode(&mut cursor).unwrap();
assert_eq!(rv.param_ordinal, 0);
assert_eq!(rv.param_name, "@RETURN_VALUE");
assert_eq!(rv.status, 0x02); assert_eq!(rv.value[0], 4);
assert_eq!(
i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
-1
);
}
// NVarChar output parameter: u16 max length 200, 5 collation bytes, then a
// u16 byte length (10) and UTF-16LE "Hello". The raw value is 2 + 10 bytes.
#[test]
fn test_return_value_nvarchar_output() {
let mut inner = BytesMut::new();
inner.put_u16_le(1);
let name_utf16: Vec<u16> = "@name".encode_utf16().collect();
inner.put_u8(name_utf16.len() as u8);
for &c in &name_utf16 {
inner.put_u16_le(c);
}
inner.put_u8(0x01);
inner.put_u32_le(0);
inner.put_u16_le(0x0001);
inner.put_u8(0xE7);
inner.put_u16_le(200); inner.put_u32_le(0x0904D000); inner.put_u8(0x34);
let val_utf16: Vec<u16> = "Hello".encode_utf16().collect();
let byte_len = (val_utf16.len() * 2) as u16;
inner.put_u16_le(byte_len);
for &c in &val_utf16 {
inner.put_u16_le(c);
}
let mut cursor = inner.freeze();
let rv = ReturnValue::decode(&mut cursor).unwrap();
assert_eq!(rv.param_ordinal, 1);
assert_eq!(rv.param_name, "@name");
assert_eq!(rv.status, 0x01);
assert_eq!(rv.col_type, 0xE7); assert_eq!(rv.type_info.max_length, Some(200));
assert!(rv.type_info.collation.is_some());
assert_eq!(rv.value.len(), 12); let val_len = u16::from_le_bytes([rv.value[0], rv.value[1]]);
assert_eq!(val_len, 10);
}
// Parser dispatches 0xAC to RETURNVALUE decoding.
#[test]
fn test_return_value_via_parser() {
let mut data = BytesMut::new();
data.put_u8(0xAC); data.extend_from_slice(&build_return_value_intn(0, "@out", 0x01, Some(99)));
let mut parser = TokenParser::new(data.freeze());
let token = parser.next_token().unwrap().unwrap();
match token {
Token::ReturnValue(rv) => {
assert_eq!(rv.param_name, "@out");
assert_eq!(rv.param_ordinal, 0);
assert_eq!(rv.status, 0x01);
assert_eq!(rv.col_type, 0x26);
}
_ => panic!("Expected ReturnValue token"),
}
}
// DONEINPROC, RETURNSTATUS, DONEPROC in sequence; each token consumes
// exactly its own bytes.
#[test]
fn test_multi_token_stored_proc_response() {
let mut data = BytesMut::new();
data.put_u8(0xFF); data.put_u16_le(0x0010); data.put_u16_le(0x00C1); data.put_u64_le(3);
data.put_u8(0x79); data.put_i32_le(0);
data.put_u8(0xFE); data.put_u16_le(0x0000); data.put_u16_le(0x00C6); data.put_u64_le(0);
let mut parser = TokenParser::new(data.freeze());
let t1 = parser.next_token().unwrap().unwrap();
match t1 {
Token::DoneInProc(done) => {
assert!(done.status.count);
assert_eq!(done.row_count, 3);
assert_eq!(done.cur_cmd, 193);
}
_ => panic!("Expected DoneInProc, got {t1:?}"),
}
let t2 = parser.next_token().unwrap().unwrap();
match t2 {
Token::ReturnStatus(status) => {
assert_eq!(status, 0);
}
_ => panic!("Expected ReturnStatus, got {t2:?}"),
}
let t3 = parser.next_token().unwrap().unwrap();
match t3 {
Token::DoneProc(done) => {
assert!(!done.status.count);
assert!(!done.status.more);
assert_eq!(done.cur_cmd, 198);
}
_ => panic!("Expected DoneProc, got {t3:?}"),
}
assert!(parser.next_token().unwrap().is_none());
}
// ERROR followed by DONE with the error flag (status 0x0002).
#[test]
fn test_multi_token_error_in_stream() {
let mut data = BytesMut::new();
data.put_u8(0xAA);
let msg_utf16: Vec<u16> = "Deadlock".encode_utf16().collect();
let srv_utf16: Vec<u16> = "DB1".encode_utf16().collect();
let length: u16 = (4 + 1 + 1
+ 2 + (msg_utf16.len() * 2)
+ 1 + (srv_utf16.len() * 2)
+ 1 + 4) as u16;
data.put_u16_le(length);
data.put_i32_le(1205); data.put_u8(51); data.put_u8(13);
data.put_u16_le(msg_utf16.len() as u16);
for &c in &msg_utf16 {
data.put_u16_le(c);
}
data.put_u8(srv_utf16.len() as u8);
for &c in &srv_utf16 {
data.put_u16_le(c);
}
data.put_u8(0); data.put_i32_le(0);
data.put_u8(0xFD);
data.put_u16_le(0x0002); data.put_u16_le(0x00C1); data.put_u64_le(0);
let mut parser = TokenParser::new(data.freeze());
let t1 = parser.next_token().unwrap().unwrap();
match t1 {
Token::Error(err) => {
assert_eq!(err.number, 1205);
assert_eq!(err.class, 13);
assert_eq!(err.message, "Deadlock");
assert_eq!(err.server, "DB1");
}
_ => panic!("Expected Error token, got {t1:?}"),
}
let t2 = parser.next_token().unwrap().unwrap();
match t2 {
Token::Done(done) => {
assert!(done.status.error);
assert!(!done.status.count);
}
_ => panic!("Expected Done token, got {t2:?}"),
}
assert!(parser.next_token().unwrap().is_none());
}
// RETURNVALUE, RETURNSTATUS, DONEPROC: the RETURNVALUE decoder must stop
// exactly at the next token's type byte.
#[test]
fn test_multi_token_proc_with_return_value() {
let mut data = BytesMut::new();
data.put_u8(0xAC);
data.extend_from_slice(&build_return_value_intn(1, "@result", 0x01, Some(42)));
data.put_u8(0x79);
data.put_i32_le(0);
data.put_u8(0xFE);
data.put_u16_le(0x0000);
data.put_u16_le(0x00C6);
data.put_u64_le(0);
let mut parser = TokenParser::new(data.freeze());
let t1 = parser.next_token().unwrap().unwrap();
match t1 {
Token::ReturnValue(rv) => {
assert_eq!(rv.param_name, "@result");
assert_eq!(rv.param_ordinal, 1);
}
_ => panic!("Expected ReturnValue, got {t1:?}"),
}
let t2 = parser.next_token().unwrap().unwrap();
assert!(matches!(t2, Token::ReturnStatus(0)));
let t3 = parser.next_token().unwrap().unwrap();
assert!(matches!(t3, Token::DoneProc(_)));
assert!(parser.next_token().unwrap().is_none());
}
// RETURNSTATUS with only 3 of its 4 value bytes must error.
#[test]
fn test_return_status_truncated() {
let data = Bytes::from_static(&[0x79, 0x01, 0x02, 0x03]);
let mut parser = TokenParser::new(data);
assert!(parser.next_token().is_err());
}
// DONEPROC cut short inside its row count must error.
#[test]
fn test_done_proc_truncated() {
let data = Bytes::from_static(&[0xFE, 0x00, 0x00, 0xC1, 0x00, 0x01, 0x00, 0x00, 0x00]);
let mut parser = TokenParser::new(data);
assert!(parser.next_token().is_err());
}
// ERROR declaring a 0x0020-byte body with no body bytes present must error.
#[test]
fn test_server_error_truncated() {
let data = Bytes::from_static(&[0xAA, 0x20, 0x00]);
let mut parser = TokenParser::new(data);
assert!(parser.next_token().is_err());
}
}