use bytes::{Buf, BufMut, Bytes};
use crate::codec::{read_b_varchar, read_us_varchar};
use crate::error::ProtocolError;
use crate::prelude::*;
use crate::types::TypeId;
/// Identifiers of the token kinds known to this module.
///
/// The enum is `#[repr(u8)]`; each discriminant is the byte value that
/// `TokenType::from_u8` maps back to the corresponding variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenType {
ColMetaData = 0x81,
Error = 0xAA,
Info = 0xAB,
LoginAck = 0xAD,
Row = 0xD1,
NbcRow = 0xD2,
EnvChange = 0xE3,
Sspi = 0xED,
Done = 0xFD,
DoneInProc = 0xFF,
DoneProc = 0xFE,
ReturnStatus = 0x79,
ReturnValue = 0xAC,
Order = 0xA9,
FeatureExtAck = 0xAE,
SessionState = 0xE4,
FedAuthInfo = 0xEE,
ColInfo = 0xA5,
TabName = 0xA4,
Offset = 0x78,
}
impl TokenType {
    /// Maps a raw byte to its [`TokenType`].
    ///
    /// Returns `None` when `value` is not one of the known token identifiers.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Arms are listed in ascending byte order.
        let token_type = match value {
            0x78 => Self::Offset,
            0x79 => Self::ReturnStatus,
            0x81 => Self::ColMetaData,
            0xA4 => Self::TabName,
            0xA5 => Self::ColInfo,
            0xA9 => Self::Order,
            0xAA => Self::Error,
            0xAB => Self::Info,
            0xAC => Self::ReturnValue,
            0xAD => Self::LoginAck,
            0xAE => Self::FeatureExtAck,
            0xD1 => Self::Row,
            0xD2 => Self::NbcRow,
            0xE3 => Self::EnvChange,
            0xE4 => Self::SessionState,
            0xED => Self::Sspi,
            0xEE => Self::FedAuthInfo,
            0xFD => Self::Done,
            0xFE => Self::DoneProc,
            0xFF => Self::DoneInProc,
            _ => return None,
        };
        Some(token_type)
    }
}
/// A decoded token; each variant wraps the payload type of one token kind.
///
/// `ReturnStatus` carries its value directly as an `i32`; every other
/// variant wraps a dedicated struct or enum declared in this module.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Token {
ColMetaData(ColMetaData),
Row(RawRow),
NbcRow(NbcRow),
Done(Done),
DoneProc(DoneProc),
DoneInProc(DoneInProc),
ReturnStatus(i32),
ReturnValue(ReturnValue),
Error(ServerError),
Info(ServerInfo),
LoginAck(LoginAck),
EnvChange(EnvChange),
Order(Order),
FeatureExtAck(FeatureExtAck),
Sspi(SspiToken),
SessionState(SessionState),
FedAuthInfo(FedAuthInfo),
}
/// Column metadata: the ordered list of column descriptions for a result set.
#[derive(Debug, Clone, Default)]
pub struct ColMetaData {
/// One entry per column, in the order they were decoded.
pub columns: Vec<ColumnData>,
}
/// Description of a single column as decoded from column metadata.
#[derive(Debug, Clone)]
pub struct ColumnData {
/// Column name.
pub name: String,
/// Parsed type identifier derived from `col_type`.
pub type_id: TypeId,
/// Raw type byte as it appeared on the wire.
pub col_type: u8,
/// Raw column flags; bit `0x0001` is what `is_nullable` tests.
pub flags: u16,
/// Raw user-type value (read as a little-endian `u32`).
pub user_type: u32,
/// Type-specific details (length, precision, scale, collation).
pub type_info: TypeInfo,
}
/// Type-specific metadata; each field is `Some` only for types whose
/// metadata carries that piece of information.
#[derive(Debug, Clone, Default)]
pub struct TypeInfo {
/// Maximum length as declared in the type metadata, when present.
pub max_length: Option<u32>,
/// Precision, set for the decimal/numeric types.
pub precision: Option<u8>,
/// Scale, set for decimal/numeric and the scaled time types.
pub scale: Option<u8>,
/// Collation, set for the character types that carry one.
pub collation: Option<Collation>,
}
/// Five-byte collation descriptor: a little-endian `u32` followed by one byte.
#[derive(Debug, Clone, Copy, Default)]
pub struct Collation {
/// The four-byte value read first; used as the key for encoding lookups.
pub lcid: u32,
/// The single byte that follows `lcid`.
pub sort_id: u8,
}
/// Encoding helpers; all are only compiled with the `encoding` feature and
/// delegate to `crate::collation`, keyed by `self.lcid`.
impl Collation {
/// Returns the result of `crate::collation::encoding_for_lcid` for this collation's `lcid`.
#[cfg(feature = "encoding")]
pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
crate::collation::encoding_for_lcid(self.lcid)
}
/// Returns the result of `crate::collation::is_utf8_collation` for this collation's `lcid`.
#[cfg(feature = "encoding")]
pub fn is_utf8(&self) -> bool {
crate::collation::is_utf8_collation(self.lcid)
}
/// Returns the result of `crate::collation::code_page_for_lcid` for this collation's `lcid`.
#[cfg(feature = "encoding")]
pub fn code_page(&self) -> Option<u16> {
crate::collation::code_page_for_lcid(self.lcid)
}
/// Returns the result of `crate::collation::encoding_name_for_lcid` for this collation's `lcid`.
#[cfg(feature = "encoding")]
pub fn encoding_name(&self) -> &'static str {
crate::collation::encoding_name_for_lcid(self.lcid)
}
}
/// A row whose column values are kept as undecoded bytes.
#[derive(Debug, Clone)]
pub struct RawRow {
/// Concatenated column values in column order, each stored with its
/// length prefix as written by `RawRow::decode`.
pub data: bytes::Bytes,
}
/// A row preceded by a null bitmap; null columns contribute no bytes to `data`.
#[derive(Debug, Clone)]
pub struct NbcRow {
/// One bit per column (least-significant bit first within each byte);
/// a set bit marks the column as null.
pub null_bitmap: Vec<u8>,
/// Concatenated raw values of the non-null columns, in column order.
pub data: bytes::Bytes,
}
/// Payload of a `Done` token.
#[derive(Debug, Clone, Copy)]
pub struct Done {
/// Decoded status flags.
pub status: DoneStatus,
/// Raw 16-bit value following the status word.
pub cur_cmd: u16,
/// Raw 64-bit count that ends the token.
pub row_count: u64,
}
/// Status flags of the done-family tokens, unpacked from a 16-bit word.
///
/// Each field mirrors one bit constant in `done_status_bits`.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct DoneStatus {
/// Bit `0x0001` (`DONE_MORE`).
pub more: bool,
/// Bit `0x0002` (`DONE_ERROR`).
pub error: bool,
/// Bit `0x0004` (`DONE_INXACT`).
pub in_xact: bool,
/// Bit `0x0010` (`DONE_COUNT`).
pub count: bool,
/// Bit `0x0020` (`DONE_ATTN`).
pub attn: bool,
/// Bit `0x0100` (`DONE_SRVERROR`).
pub srverror: bool,
}
/// Payload of a `DoneInProc` token; same layout as `Done`.
#[derive(Debug, Clone, Copy)]
pub struct DoneInProc {
/// Decoded status flags.
pub status: DoneStatus,
/// Raw 16-bit value following the status word.
pub cur_cmd: u16,
/// Raw 64-bit count that ends the token.
pub row_count: u64,
}
/// Payload of a `DoneProc` token; same layout as `Done`.
#[derive(Debug, Clone, Copy)]
pub struct DoneProc {
/// Decoded status flags.
pub status: DoneStatus,
/// Raw 16-bit value following the status word.
pub cur_cmd: u16,
/// Raw 64-bit count that ends the token.
pub row_count: u64,
}
/// Payload of a `ReturnValue` token.
#[derive(Debug, Clone)]
pub struct ReturnValue {
/// Ordinal of the parameter.
pub param_ordinal: u16,
/// Name of the parameter.
pub param_name: String,
/// Raw status byte.
pub status: u8,
/// Raw user-type value.
pub user_type: u32,
/// Raw flags word.
pub flags: u16,
/// Type metadata describing `value`.
pub type_info: TypeInfo,
/// Undecoded value bytes, stored in the same form `RawRow` uses for a column.
pub value: bytes::Bytes,
}
/// Payload of an `Error` token.
#[derive(Debug, Clone)]
pub struct ServerError {
/// Error number.
pub number: i32,
/// State byte.
pub state: u8,
/// Class byte; `is_fatal` and `is_batch_abort` compare against it.
pub class: u8,
/// Message text.
pub message: String,
/// Server name.
pub server: String,
/// Procedure name.
pub procedure: String,
/// Line number.
pub line: i32,
}
/// Payload of an `Info` token; same field layout as `ServerError`.
#[derive(Debug, Clone)]
pub struct ServerInfo {
/// Message number.
pub number: i32,
/// State byte.
pub state: u8,
/// Class byte.
pub class: u8,
/// Message text.
pub message: String,
/// Server name.
pub server: String,
/// Procedure name.
pub procedure: String,
/// Line number.
pub line: i32,
}
/// Payload of a `LoginAck` token.
#[derive(Debug, Clone)]
pub struct LoginAck {
/// Raw interface byte.
pub interface: u8,
/// Raw TDS version value (decoded as a little-endian `u32`).
pub tds_version: u32,
/// Program name reported in the token.
pub prog_name: String,
/// Raw program version value (decoded as a little-endian `u32`).
pub prog_version: u32,
}
/// Payload of an `EnvChange` token: what changed, plus the new and old values.
#[derive(Debug, Clone)]
pub struct EnvChange {
/// Which environment setting changed.
pub env_type: EnvChangeType,
/// Value after the change.
pub new_value: EnvChangeValue,
/// Value before the change.
pub old_value: EnvChangeValue,
}
/// Kind of environment change, identified by a one-byte code.
///
/// The discriminants are the byte values accepted by `EnvChangeType::from_u8`;
/// note that no variant is assigned the value 14.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EnvChangeType {
Database = 1,
Language = 2,
CharacterSet = 3,
PacketSize = 4,
UnicodeSortingLocalId = 5,
UnicodeComparisonFlags = 6,
SqlCollation = 7,
BeginTransaction = 8,
CommitTransaction = 9,
RollbackTransaction = 10,
EnlistDtcTransaction = 11,
DefectTransaction = 12,
RealTimeLogShipping = 13,
PromoteTransaction = 15,
TransactionManagerAddress = 16,
TransactionEnded = 17,
ResetConnectionCompletionAck = 18,
UserInstanceStarted = 19,
Routing = 20,
}
/// Value carried by an environment change.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EnvChangeValue {
/// Textual value.
String(String),
/// Opaque binary value.
Binary(bytes::Bytes),
/// Routing target produced for `EnvChangeType::Routing`.
Routing {
/// Host name of the routing target.
host: String,
/// Port of the routing target.
port: u16,
},
}
/// Payload of an `Order` token.
#[derive(Debug, Clone)]
pub struct Order {
/// The 16-bit column values listed in the token, in wire order.
pub columns: Vec<u16>,
}
/// Payload of a `FeatureExtAck` token.
#[derive(Debug, Clone)]
pub struct FeatureExtAck {
/// Acknowledged features, in wire order (the terminator is not included).
pub features: Vec<FeatureAck>,
}
/// A single feature acknowledgement inside a `FeatureExtAck` token.
#[derive(Debug, Clone)]
pub struct FeatureAck {
/// Feature identifier byte.
pub feature_id: u8,
/// Feature-specific data, left undecoded.
pub data: bytes::Bytes,
}
/// Payload of an `Sspi` token.
#[derive(Debug, Clone)]
pub struct SspiToken {
/// The token's length-prefixed payload, left undecoded.
pub data: bytes::Bytes,
}
/// Payload of a `SessionState` token.
#[derive(Debug, Clone)]
pub struct SessionState {
/// The token's length-prefixed payload, left undecoded.
pub data: bytes::Bytes,
}
/// Payload of a `FedAuthInfo` token.
#[derive(Debug, Clone)]
pub struct FedAuthInfo {
/// STS URL reported by the server; empty when the token did not carry one.
pub sts_url: String,
/// Service principal name reported by the server; empty when absent.
pub spn: String,
}
impl ColMetaData {
pub const NO_METADATA: u16 = 0xFFFF;
pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
if src.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let column_count = src.get_u16_le();
if column_count == Self::NO_METADATA {
return Ok(Self {
columns: Vec::new(),
});
}
let mut columns = Vec::with_capacity(column_count as usize);
for _ in 0..column_count {
let column = Self::decode_column(src)?;
columns.push(column);
}
Ok(Self { columns })
}
fn decode_column(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
if src.remaining() < 7 {
return Err(ProtocolError::UnexpectedEof);
}
let user_type = src.get_u32_le();
let flags = src.get_u16_le();
let col_type = src.get_u8();
let type_id = TypeId::from_u8(col_type).unwrap_or(TypeId::Null);
let type_info = Self::decode_type_info(src, type_id, col_type)?;
let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
Ok(ColumnData {
name,
type_id,
col_type,
flags,
user_type,
type_info,
})
}
fn decode_type_info(
src: &mut impl Buf,
type_id: TypeId,
col_type: u8,
) -> Result<TypeInfo, ProtocolError> {
match type_id {
TypeId::Null => Ok(TypeInfo::default()),
TypeId::Int1 | TypeId::Bit => Ok(TypeInfo::default()),
TypeId::Int2 => Ok(TypeInfo::default()),
TypeId::Int4 => Ok(TypeInfo::default()),
TypeId::Int8 => Ok(TypeInfo::default()),
TypeId::Float4 => Ok(TypeInfo::default()),
TypeId::Float8 => Ok(TypeInfo::default()),
TypeId::Money => Ok(TypeInfo::default()),
TypeId::Money4 => Ok(TypeInfo::default()),
TypeId::DateTime => Ok(TypeInfo::default()),
TypeId::DateTime4 => Ok(TypeInfo::default()),
TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::Guid => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
if src.remaining() < 3 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
let precision = src.get_u8();
let scale = src.get_u8();
Ok(TypeInfo {
max_length: Some(max_length),
precision: Some(precision),
scale: Some(scale),
..Default::default()
})
}
TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u8() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::BigVarChar | TypeId::BigChar => {
if src.remaining() < 7 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
let collation = Self::decode_collation(src)?;
Ok(TypeInfo {
max_length: Some(max_length),
collation: Some(collation),
..Default::default()
})
}
TypeId::BigVarBinary | TypeId::BigBinary => {
if src.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::NChar | TypeId::NVarChar => {
if src.remaining() < 7 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
let collation = Self::decode_collation(src)?;
Ok(TypeInfo {
max_length: Some(max_length),
collation: Some(collation),
..Default::default()
})
}
TypeId::Date => Ok(TypeInfo::default()),
TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let scale = src.get_u8();
Ok(TypeInfo {
scale: Some(scale),
..Default::default()
})
}
TypeId::Text | TypeId::NText | TypeId::Image => {
if src.remaining() < 4 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u32_le();
let collation = if type_id == TypeId::Text || type_id == TypeId::NText {
if src.remaining() < 5 {
return Err(ProtocolError::UnexpectedEof);
}
Some(Self::decode_collation(src)?)
} else {
None
};
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let num_parts = src.get_u8();
for _ in 0..num_parts {
let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
}
Ok(TypeInfo {
max_length: Some(max_length),
collation,
..Default::default()
})
}
TypeId::Xml => {
if src.remaining() < 1 {
return Err(ProtocolError::UnexpectedEof);
}
let schema_present = src.get_u8();
if schema_present != 0 {
let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; }
Ok(TypeInfo::default())
}
TypeId::Udt => {
if src.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u16_le() as u32;
let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
TypeId::Tvp => {
Err(ProtocolError::InvalidTokenType(col_type))
}
TypeId::Variant => {
if src.remaining() < 4 {
return Err(ProtocolError::UnexpectedEof);
}
let max_length = src.get_u32_le();
Ok(TypeInfo {
max_length: Some(max_length),
..Default::default()
})
}
}
}
fn decode_collation(src: &mut impl Buf) -> Result<Collation, ProtocolError> {
if src.remaining() < 5 {
return Err(ProtocolError::UnexpectedEof);
}
let lcid = src.get_u32_le();
let sort_id = src.get_u8();
Ok(Collation { lcid, sort_id })
}
#[must_use]
pub fn column_count(&self) -> usize {
self.columns.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.columns.is_empty()
}
}
impl ColumnData {
    /// Returns `true` when bit `0x0001` of `flags` is set.
    #[must_use]
    pub fn is_nullable(&self) -> bool {
        const NULLABLE: u16 = 0x0001;
        self.flags & NULLABLE == NULLABLE
    }

    /// Size in bytes reported for the column's type, or `None` for every
    /// type not listed below.
    #[must_use]
    pub fn fixed_size(&self) -> Option<usize> {
        let size = match self.type_id {
            TypeId::Null => 0,
            TypeId::Int1 | TypeId::Bit => 1,
            TypeId::Int2 => 2,
            TypeId::Date => 3,
            TypeId::Int4 | TypeId::Float4 | TypeId::Money4 | TypeId::DateTime4 => 4,
            TypeId::Int8 | TypeId::Float8 | TypeId::Money | TypeId::DateTime => 8,
            _ => return None,
        };
        Some(size)
    }
}
impl RawRow {
    /// Reads one row from `src`, copying every column's raw value (including
    /// its length prefix, where the type has one) into a single buffer.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when `src` ends inside a
    /// value and [`ProtocolError::InvalidTokenType`] for a table-valued
    /// parameter column.
    pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
        let mut data = bytes::BytesMut::new();
        for col in &metadata.columns {
            Self::decode_column_value(src, col, &mut data)?;
        }
        Ok(Self {
            data: data.freeze(),
        })
    }

    /// Copies the raw value of one column from `src` to `dst`, choosing the
    /// wire framing from the column's type.
    fn decode_column_value(
        src: &mut impl Buf,
        col: &ColumnData,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        match col.type_id {
            TypeId::Null => Ok(()),
            // Fixed-width values are copied verbatim without a prefix.
            TypeId::Int1 | TypeId::Bit => Self::copy_prefixed(src, dst, &[], 1),
            TypeId::Int2 => Self::copy_prefixed(src, dst, &[], 2),
            TypeId::Int4 | TypeId::Float4 | TypeId::Money4 | TypeId::DateTime4 => {
                Self::copy_prefixed(src, dst, &[], 4)
            }
            TypeId::Int8 | TypeId::Float8 | TypeId::Money | TypeId::DateTime => {
                Self::copy_prefixed(src, dst, &[], 8)
            }
            // One-byte length prefix.
            TypeId::Date
            | TypeId::IntN
            | TypeId::BitN
            | TypeId::FloatN
            | TypeId::MoneyN
            | TypeId::DateTimeN
            | TypeId::Guid
            | TypeId::Decimal
            | TypeId::Numeric
            | TypeId::DecimalN
            | TypeId::NumericN
            | TypeId::Char
            | TypeId::VarChar
            | TypeId::Binary
            | TypeId::VarBinary
            | TypeId::Time
            | TypeId::DateTime2
            | TypeId::DateTimeOffset => Self::decode_bytelen_type(src, dst),
            // A declared maximum length of 0xFFFF selects the chunked
            // (partially length-prefixed) framing.
            TypeId::BigVarChar | TypeId::BigVarBinary | TypeId::NVarChar => {
                if col.type_info.max_length == Some(0xFFFF) {
                    Self::decode_plp_type(src, dst)
                } else {
                    Self::decode_ushortlen_type(src, dst)
                }
            }
            TypeId::BigChar | TypeId::BigBinary | TypeId::NChar => {
                Self::decode_ushortlen_type(src, dst)
            }
            TypeId::Text | TypeId::NText | TypeId::Image => Self::decode_textptr_type(src, dst),
            TypeId::Xml | TypeId::Udt => Self::decode_plp_type(src, dst),
            TypeId::Variant => Self::decode_intlen_type(src, dst),
            TypeId::Tvp => Err(ProtocolError::InvalidTokenType(col.col_type)),
        }
    }

    /// Writes `prefix` to `dst` and then moves exactly `payload_len` bytes
    /// from `src` to `dst` in one bulk copy.
    ///
    /// Nothing is written when `src` holds fewer than `payload_len` bytes.
    fn copy_prefixed(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
        prefix: &[u8],
        payload_len: usize,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < payload_len {
            return Err(ProtocolError::UnexpectedEof);
        }
        dst.extend_from_slice(prefix);
        dst.put(Buf::take(&mut *src, payload_len));
        Ok(())
    }

    /// Copies a value framed by a one-byte length; a length of `0xFF` is
    /// copied as a bare prefix with no payload.
    fn decode_bytelen_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 1 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let len = src.get_u8();
        let payload_len = if len == 0xFF { 0 } else { usize::from(len) };
        Self::copy_prefixed(src, dst, &[len], payload_len)
    }

    /// Copies a value framed by a little-endian two-byte length; a length of
    /// `0xFFFF` is copied as a bare prefix with no payload.
    fn decode_ushortlen_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let len = src.get_u16_le();
        let payload_len = if len == 0xFFFF { 0 } else { usize::from(len) };
        Self::copy_prefixed(src, dst, &len.to_le_bytes(), payload_len)
    }

    /// Copies a value framed by a little-endian four-byte length; a length of
    /// `0xFFFF_FFFF` is copied as a bare prefix with no payload.
    fn decode_intlen_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let len = src.get_u32_le();
        let payload_len = if len == 0xFFFF_FFFF { 0 } else { len as usize };
        Self::copy_prefixed(src, dst, &len.to_le_bytes(), payload_len)
    }

    /// Copies a text-pointer framed value, re-encoding it in the same shape
    /// `decode_plp_type` produces: 8-byte total length, one chunk, and a
    /// zero-length terminator chunk. A zero-length text pointer is written
    /// as the all-ones 8-byte marker with no chunks.
    fn decode_textptr_type(
        src: &mut impl Buf,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), ProtocolError> {
        if src.remaining() < 1 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let textptr_len = usize::from(src.get_u8());
        if textptr_len == 0 {
            dst.put_u64_le(u64::MAX);
            return Ok(());
        }
        // Skip the text pointer itself.
        if src.remaining() < textptr_len {
            return Err(ProtocolError::UnexpectedEof);
        }
        src.advance(textptr_len);
        // Skip the 8 bytes that follow the pointer (presumably the timestamp
        // field of the text-pointer framing — confirm against MS-TDS).
        if src.remaining() < 8 {
            return Err(ProtocolError::UnexpectedEof);
        }
        src.advance(8);
        if src.remaining() < 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let data_len = src.get_u32_le();
        let payload_len = data_len as usize;
        if src.remaining() < payload_len {
            return Err(ProtocolError::UnexpectedEof);
        }
        dst.put_u64_le(u64::from(data_len));
        dst.put_u32_le(data_len);
        dst.put(Buf::take(&mut *src, payload_len));
        // NOTE(review): when `data_len` is 0 the chunk length written above
        // is already 0, so this emits a second zero word; a chunk reader that
        // stops at the first zero would leave 4 stray bytes. Kept as-is to
        // preserve the existing output format — confirm with the consumer.
        dst.put_u32_le(0);
        Ok(())
    }

    /// Copies a chunked value: an 8-byte total length followed by
    /// length-prefixed chunks up to and including the zero-length terminator.
    /// An all-ones total length is copied alone, with no chunks.
    fn decode_plp_type(src: &mut impl Buf, dst: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
        if src.remaining() < 8 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let total_len = src.get_u64_le();
        dst.put_u64_le(total_len);
        if total_len == u64::MAX {
            return Ok(());
        }
        loop {
            if src.remaining() < 4 {
                return Err(ProtocolError::UnexpectedEof);
            }
            let chunk_len = src.get_u32_le();
            dst.put_u32_le(chunk_len);
            if chunk_len == 0 {
                return Ok(());
            }
            Self::copy_prefixed(src, dst, &[], chunk_len as usize)?;
        }
    }
}
impl NbcRow {
    /// Reads a null-bitmap row: `ceil(columns / 8)` bitmap bytes followed by
    /// the raw values of every column whose bitmap bit is clear.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when `src` is shorter than
    /// the bitmap, or any error produced while copying a column value.
    pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
        let bitmap_len = metadata.columns.len().div_ceil(8);
        if src.remaining() < bitmap_len {
            return Err(ProtocolError::UnexpectedEof);
        }
        let mut null_bitmap = vec![0u8; bitmap_len];
        src.copy_to_slice(&mut null_bitmap);

        let mut data = bytes::BytesMut::new();
        for (index, col) in metadata.columns.iter().enumerate() {
            let mask = 1u8 << (index % 8);
            if (null_bitmap[index / 8] & mask) != 0 {
                // Null columns carry no bytes in the row body.
                continue;
            }
            RawRow::decode_column_value(src, col, &mut data)?;
        }
        Ok(Self {
            null_bitmap,
            data: data.freeze(),
        })
    }

    /// Returns `true` when the bitmap marks `column_index` as null; indexes
    /// beyond the bitmap are reported as null.
    #[must_use]
    pub fn is_null(&self, column_index: usize) -> bool {
        let mask = 1u8 << (column_index % 8);
        match self.null_bitmap.get(column_index / 8) {
            Some(byte) => (byte & mask) != 0,
            None => true,
        }
    }
}
impl ReturnValue {
    /// Decodes a return-value token body (token-type byte already consumed).
    ///
    /// Layout, per the MS-TDS RETURNVALUE definition: parameter ordinal
    /// (`u16`), parameter name (B_VARCHAR), status (1 byte), user type
    /// (`u32`), flags (`u16`), type byte plus type info, then the value.
    /// The token has no length field, so none is read here.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when `src` ends early and
    /// propagates any error from decoding the type info or the value.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let param_ordinal = src.get_u16_le();
        let param_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
        // Status (1) + user type (4) + flags (2) + type byte (1).
        if src.remaining() < 8 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let status = src.get_u8();
        let user_type = src.get_u32_le();
        let flags = src.get_u16_le();
        let col_type = src.get_u8();
        let type_id = TypeId::from_u8(col_type).unwrap_or(TypeId::Null);
        let type_info = ColMetaData::decode_type_info(src, type_id, col_type)?;

        // The value uses the same framing as a row column, so describe it
        // as a temporary unnamed column and reuse the row decoder.
        let column = ColumnData {
            name: String::new(),
            type_id,
            col_type,
            flags,
            user_type,
            type_info,
        };
        let mut value_buf = bytes::BytesMut::new();
        RawRow::decode_column_value(src, &column, &mut value_buf)?;
        Ok(Self {
            param_ordinal,
            param_name,
            status,
            user_type,
            flags,
            // Move the type info back out instead of cloning it up front.
            type_info: column.type_info,
            value: value_buf.freeze(),
        })
    }
}
impl SessionState {
    /// Reads a little-endian `u32` length and then that many payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when the length itself is
    /// missing and [`ProtocolError::IncompletePacket`] when the payload is
    /// shorter than announced.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let expected = src.get_u32_le() as usize;
        let actual = src.remaining();
        if actual < expected {
            return Err(ProtocolError::IncompletePacket { expected, actual });
        }
        Ok(Self {
            data: src.copy_to_bytes(expected),
        })
    }
}
/// Bit masks of the 16-bit status word carried by the done-family tokens.
mod done_status_bits {
    /// Bit 0.
    pub const DONE_MORE: u16 = 1 << 0;
    /// Bit 1.
    pub const DONE_ERROR: u16 = 1 << 1;
    /// Bit 2.
    pub const DONE_INXACT: u16 = 1 << 2;
    /// Bit 4 (bit 3 is not assigned a constant here).
    pub const DONE_COUNT: u16 = 1 << 4;
    /// Bit 5.
    pub const DONE_ATTN: u16 = 1 << 5;
    /// Bit 8.
    pub const DONE_SRVERROR: u16 = 1 << 8;
}
impl DoneStatus {
    /// Unpacks a raw status word into individual flags; bits without a
    /// matching constant in `done_status_bits` are ignored.
    #[must_use]
    pub fn from_bits(bits: u16) -> Self {
        use done_status_bits::*;
        let is_set = |mask: u16| (bits & mask) != 0;
        Self {
            more: is_set(DONE_MORE),
            error: is_set(DONE_ERROR),
            in_xact: is_set(DONE_INXACT),
            count: is_set(DONE_COUNT),
            attn: is_set(DONE_ATTN),
            srverror: is_set(DONE_SRVERROR),
        }
    }

    /// Packs the flags back into a raw status word (the inverse of
    /// [`DoneStatus::from_bits`] for the bits it understands).
    #[must_use]
    pub fn to_bits(&self) -> u16 {
        use done_status_bits::*;
        let flags = [
            (self.more, DONE_MORE),
            (self.error, DONE_ERROR),
            (self.in_xact, DONE_INXACT),
            (self.count, DONE_COUNT),
            (self.attn, DONE_ATTN),
            (self.srverror, DONE_SRVERROR),
        ];
        flags
            .iter()
            .filter(|(enabled, _)| *enabled)
            .fold(0u16, |bits, (_, mask)| bits | mask)
    }
}
impl Done {
    /// Encoded size of the token body in bytes: status (2) + command (2) + row count (8).
    pub const SIZE: usize = 12;

    /// Decodes the fixed-size token body from `src`.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::IncompletePacket`] when fewer than
    /// [`Self::SIZE`] bytes remain.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let actual = src.remaining();
        if actual < Self::SIZE {
            return Err(ProtocolError::IncompletePacket {
                expected: Self::SIZE,
                actual,
            });
        }
        // Struct fields are evaluated in the order written, which matches
        // the wire order.
        Ok(Self {
            status: DoneStatus::from_bits(src.get_u16_le()),
            cur_cmd: src.get_u16_le(),
            row_count: src.get_u64_le(),
        })
    }

    /// Writes the token-type byte followed by the body to `dst`.
    pub fn encode(&self, dst: &mut impl BufMut) {
        let Self {
            status,
            cur_cmd,
            row_count,
        } = *self;
        dst.put_u8(TokenType::Done as u8);
        dst.put_u16_le(status.to_bits());
        dst.put_u16_le(cur_cmd);
        dst.put_u64_le(row_count);
    }

    /// Returns the `more` status flag.
    #[must_use]
    pub const fn has_more(&self) -> bool {
        self.status.more
    }

    /// Returns the `error` status flag.
    #[must_use]
    pub const fn has_error(&self) -> bool {
        self.status.error
    }

    /// Returns the `count` status flag.
    #[must_use]
    pub const fn has_count(&self) -> bool {
        self.status.count
    }
}
impl DoneProc {
    /// Encoded size of the token body in bytes: status (2) + command (2) + row count (8).
    pub const SIZE: usize = 12;

    /// Decodes the fixed-size token body from `src`.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::IncompletePacket`] when fewer than
    /// [`Self::SIZE`] bytes remain.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let actual = src.remaining();
        if actual < Self::SIZE {
            return Err(ProtocolError::IncompletePacket {
                expected: Self::SIZE,
                actual,
            });
        }
        // Struct fields are evaluated in the order written, which matches
        // the wire order.
        Ok(Self {
            status: DoneStatus::from_bits(src.get_u16_le()),
            cur_cmd: src.get_u16_le(),
            row_count: src.get_u64_le(),
        })
    }

    /// Writes the token-type byte followed by the body to `dst`.
    pub fn encode(&self, dst: &mut impl BufMut) {
        let Self {
            status,
            cur_cmd,
            row_count,
        } = *self;
        dst.put_u8(TokenType::DoneProc as u8);
        dst.put_u16_le(status.to_bits());
        dst.put_u16_le(cur_cmd);
        dst.put_u64_le(row_count);
    }
}
impl DoneInProc {
    /// Encoded size of the token body in bytes: status (2) + command (2) + row count (8).
    pub const SIZE: usize = 12;

    /// Decodes the fixed-size token body from `src`.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::IncompletePacket`] when fewer than
    /// [`Self::SIZE`] bytes remain.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let actual = src.remaining();
        if actual < Self::SIZE {
            return Err(ProtocolError::IncompletePacket {
                expected: Self::SIZE,
                actual,
            });
        }
        // Struct fields are evaluated in the order written, which matches
        // the wire order.
        Ok(Self {
            status: DoneStatus::from_bits(src.get_u16_le()),
            cur_cmd: src.get_u16_le(),
            row_count: src.get_u64_le(),
        })
    }

    /// Writes the token-type byte followed by the body to `dst`.
    pub fn encode(&self, dst: &mut impl BufMut) {
        let Self {
            status,
            cur_cmd,
            row_count,
        } = *self;
        dst.put_u8(TokenType::DoneInProc as u8);
        dst.put_u16_le(status.to_bits());
        dst.put_u16_le(cur_cmd);
        dst.put_u64_le(row_count);
    }
}
impl ServerError {
    /// Decodes an error token body: a 2-byte length (skipped), number,
    /// state, class, message, server name, procedure name and line.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when `src` ends early or a
    /// string cannot be read.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let eof = || ProtocolError::UnexpectedEof;

        // The leading length word is not used by this decoder.
        if src.remaining() < 2 {
            return Err(eof());
        }
        src.advance(2);

        // Number (4) + state (1) + class (1).
        if src.remaining() < 6 {
            return Err(eof());
        }
        let number = src.get_i32_le();
        let state = src.get_u8();
        let class = src.get_u8();

        let message = read_us_varchar(src).ok_or_else(eof)?;
        let server = read_b_varchar(src).ok_or_else(eof)?;
        let procedure = read_b_varchar(src).ok_or_else(eof)?;

        if src.remaining() < 4 {
            return Err(eof());
        }
        let line = src.get_i32_le();

        Ok(Self {
            number,
            state,
            class,
            message,
            server,
            procedure,
            line,
        })
    }

    /// Returns `true` when `class` is 20 or higher.
    #[must_use]
    pub const fn is_fatal(&self) -> bool {
        self.class >= 20
    }

    /// Returns `true` when `class` is 16 or higher.
    #[must_use]
    pub const fn is_batch_abort(&self) -> bool {
        self.class >= 16
    }
}
impl ServerInfo {
    /// Decodes an info token body: a 2-byte length (skipped), number, state,
    /// class, message, server name, procedure name and line.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when `src` ends early or a
    /// string cannot be read.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let eof = || ProtocolError::UnexpectedEof;

        // The leading length word is not used by this decoder.
        if src.remaining() < 2 {
            return Err(eof());
        }
        src.advance(2);

        // Number (4) + state (1) + class (1).
        if src.remaining() < 6 {
            return Err(eof());
        }
        let number = src.get_i32_le();
        let state = src.get_u8();
        let class = src.get_u8();

        let message = read_us_varchar(src).ok_or_else(eof)?;
        let server = read_b_varchar(src).ok_or_else(eof)?;
        let procedure = read_b_varchar(src).ok_or_else(eof)?;

        if src.remaining() < 4 {
            return Err(eof());
        }
        let line = src.get_i32_le();

        Ok(Self {
            number,
            state,
            class,
            message,
            server,
            procedure,
            line,
        })
    }
}
impl LoginAck {
    /// Decodes a login-acknowledgement token body: a 2-byte length
    /// (skipped), interface byte, TDS version, program name and program
    /// version.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when `src` ends early or the
    /// program name cannot be read.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let eof = || ProtocolError::UnexpectedEof;

        // The leading length word is not used by this decoder.
        if src.remaining() < 2 {
            return Err(eof());
        }
        src.advance(2);

        // Interface (1) + TDS version (4).
        if src.remaining() < 5 {
            return Err(eof());
        }
        let interface = src.get_u8();
        let tds_version = src.get_u32_le();

        let prog_name = read_b_varchar(src).ok_or_else(eof)?;

        if src.remaining() < 4 {
            return Err(eof());
        }
        let prog_version = src.get_u32_le();

        Ok(Self {
            interface,
            tds_version,
            prog_name,
            prog_version,
        })
    }

    /// Wraps the raw `tds_version` value in a `TdsVersion`.
    #[must_use]
    pub fn tds_version(&self) -> crate::version::TdsVersion {
        crate::version::TdsVersion::new(self.tds_version)
    }
}
impl EnvChangeType {
    /// Maps a raw byte to its [`EnvChangeType`].
    ///
    /// Returns `None` for any value without a variant (including 0 and 14).
    pub fn from_u8(value: u8) -> Option<Self> {
        let env_type = match value {
            1 => Self::Database,
            2 => Self::Language,
            3 => Self::CharacterSet,
            4 => Self::PacketSize,
            5 => Self::UnicodeSortingLocalId,
            6 => Self::UnicodeComparisonFlags,
            7 => Self::SqlCollation,
            8 => Self::BeginTransaction,
            9 => Self::CommitTransaction,
            10 => Self::RollbackTransaction,
            11 => Self::EnlistDtcTransaction,
            12 => Self::DefectTransaction,
            13 => Self::RealTimeLogShipping,
            15 => Self::PromoteTransaction,
            16 => Self::TransactionManagerAddress,
            17 => Self::TransactionEnded,
            18 => Self::ResetConnectionCompletionAck,
            19 => Self::UserInstanceStarted,
            20 => Self::Routing,
            _ => return None,
        };
        Some(env_type)
    }
}
impl EnvChange {
    /// Decodes an environment-change token body (token-type byte already
    /// consumed): a 2-byte length, a type byte, then the new and old values.
    ///
    /// The announced length is authoritative: exactly that many bytes are
    /// consumed from `src` and all value reads are confined to them, so a
    /// value layout this decoder does not fully understand (or trailing
    /// bytes such as the routing old value) cannot desynchronise the stream.
    /// Values that are missing or truncated inside the token decode as empty.
    ///
    /// # Errors
    ///
    /// * [`ProtocolError::UnexpectedEof`] — fewer than 3 bytes remain, the
    ///   token is empty, or routing data is truncated.
    /// * [`ProtocolError::IncompletePacket`] — `src` is shorter than the
    ///   announced length.
    /// * [`ProtocolError::InvalidTokenType`] — unknown type byte.
    /// * [`ProtocolError::StringEncoding`] — routing host is not valid UTF-16.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 3 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let length = usize::from(src.get_u16_le());
        if src.remaining() < length {
            return Err(ProtocolError::IncompletePacket {
                expected: length,
                actual: src.remaining(),
            });
        }
        // Detach the token body so every read below is bounded by `length`.
        let mut body = src.copy_to_bytes(length);
        if !body.has_remaining() {
            return Err(ProtocolError::UnexpectedEof);
        }
        let env_type_byte = body.get_u8();
        let env_type = EnvChangeType::from_u8(env_type_byte)
            .ok_or(ProtocolError::InvalidTokenType(env_type_byte))?;
        let (new_value, old_value) = match env_type {
            EnvChangeType::Routing => {
                let new_value = Self::decode_routing_value(&mut body)?;
                // The old value carries no information; its bytes stay in
                // `body` and are dropped with it.
                (new_value, EnvChangeValue::Binary(Bytes::new()))
            }
            // New value uses a 4-byte length; old value a 1-byte length.
            EnvChangeType::PromoteTransaction => {
                let new_len = (body.remaining() >= 4).then(|| body.get_u32_le() as usize);
                let new_value = Self::binary_value(&mut body, new_len);
                let old_len = Self::read_byte_len(&mut body);
                let old_value = Self::binary_value(&mut body, old_len);
                (new_value, old_value)
            }
            // Binary values with a 1-byte length.
            EnvChangeType::BeginTransaction
            | EnvChangeType::CommitTransaction
            | EnvChangeType::RollbackTransaction
            | EnvChangeType::EnlistDtcTransaction
            | EnvChangeType::DefectTransaction
            | EnvChangeType::TransactionManagerAddress
            | EnvChangeType::TransactionEnded
            | EnvChangeType::SqlCollation => {
                let new_len = Self::read_byte_len(&mut body);
                let new_value = Self::binary_value(&mut body, new_len);
                let old_len = Self::read_byte_len(&mut body);
                let old_value = Self::binary_value(&mut body, old_len);
                (new_value, old_value)
            }
            // Everything else is a pair of strings.
            _ => {
                let new_value = read_b_varchar(&mut body)
                    .map(EnvChangeValue::String)
                    .unwrap_or_else(|| EnvChangeValue::String(String::new()));
                let old_value = read_b_varchar(&mut body)
                    .map(EnvChangeValue::String)
                    .unwrap_or_else(|| EnvChangeValue::String(String::new()));
                (new_value, old_value)
            }
        };
        Ok(Self {
            env_type,
            new_value,
            old_value,
        })
    }

    /// Reads a one-byte length, or `None` when `src` is exhausted.
    fn read_byte_len(src: &mut impl Buf) -> Option<usize> {
        src.has_remaining().then(|| usize::from(src.get_u8()))
    }

    /// Takes `len` bytes from `src` as a binary value; yields an empty value
    /// when the length is absent, zero, or larger than what is left.
    fn binary_value(src: &mut impl Buf, len: Option<usize>) -> EnvChangeValue {
        match len {
            Some(n) if n > 0 && src.remaining() >= n => {
                EnvChangeValue::Binary(src.copy_to_bytes(n))
            }
            _ => EnvChangeValue::Binary(Bytes::new()),
        }
    }

    /// Decodes the routing new value: 2-byte data length (ignored), protocol
    /// byte (ignored), port, then a UTF-16 host name prefixed by its length
    /// in code units.
    fn decode_routing_value(src: &mut impl Buf) -> Result<EnvChangeValue, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let _routing_len = src.get_u16_le();
        // Protocol (1) + port (2) + host length (2).
        if src.remaining() < 5 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let _protocol = src.get_u8();
        let port = src.get_u16_le();
        let server_len = usize::from(src.get_u16_le());
        if src.remaining() < server_len * 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let chars: Vec<u16> = (0..server_len).map(|_| src.get_u16_le()).collect();
        let host = String::from_utf16(&chars).map_err(|_| {
            ProtocolError::StringEncoding(
                #[cfg(feature = "std")]
                "invalid UTF-16 in routing hostname".to_string(),
                #[cfg(not(feature = "std"))]
                "invalid UTF-16 in routing hostname",
            )
        })?;
        Ok(EnvChangeValue::Routing { host, port })
    }

    /// Returns `true` when this change is a routing change.
    #[must_use]
    pub fn is_routing(&self) -> bool {
        self.env_type == EnvChangeType::Routing
    }

    /// Returns the routing host and port when the new value is a routing target.
    #[must_use]
    pub fn routing_info(&self) -> Option<(&str, u16)> {
        match &self.new_value {
            EnvChangeValue::Routing { host, port } => Some((host, *port)),
            _ => None,
        }
    }

    /// Returns the new database name when this is a database change with a
    /// string value.
    #[must_use]
    pub fn new_database(&self) -> Option<&str> {
        match (&self.env_type, &self.new_value) {
            (EnvChangeType::Database, EnvChangeValue::String(name)) => Some(name),
            _ => None,
        }
    }
}
impl Order {
    /// Decodes an order token body: a 2-byte length in bytes followed by
    /// `length / 2` little-endian `u16` values.
    ///
    /// Exactly `length` bytes are consumed after the length word; if the
    /// length is odd, the trailing byte is skipped so the stream stays
    /// aligned on the next token.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when the length word is
    /// missing and [`ProtocolError::IncompletePacket`] when fewer than
    /// `length` bytes follow it.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let length = usize::from(src.get_u16_le());
        if src.remaining() < length {
            return Err(ProtocolError::IncompletePacket {
                expected: length,
                actual: src.remaining(),
            });
        }
        let columns: Vec<u16> = (0..length / 2).map(|_| src.get_u16_le()).collect();
        // Consume the leftover byte of a malformed odd-length token.
        src.advance(length % 2);
        Ok(Self { columns })
    }
}
impl FeatureExtAck {
    /// Feature-id byte that ends the list of acknowledgements.
    pub const TERMINATOR: u8 = 0xFF;

    /// Decodes feature acknowledgements until the terminator byte: each
    /// entry is a feature id, a little-endian `u32` length and that many
    /// data bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when `src` ends before the
    /// terminator or inside an entry header, and
    /// [`ProtocolError::IncompletePacket`] when an entry's data is shorter
    /// than announced.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        let mut features = Vec::new();
        loop {
            if src.remaining() == 0 {
                return Err(ProtocolError::UnexpectedEof);
            }
            let feature_id = src.get_u8();
            if feature_id == Self::TERMINATOR {
                return Ok(Self { features });
            }
            if src.remaining() < 4 {
                return Err(ProtocolError::UnexpectedEof);
            }
            let expected = src.get_u32_le() as usize;
            let actual = src.remaining();
            if actual < expected {
                return Err(ProtocolError::IncompletePacket { expected, actual });
            }
            features.push(FeatureAck {
                feature_id,
                data: src.copy_to_bytes(expected),
            });
        }
    }
}
impl SspiToken {
    /// Reads a little-endian `u16` length and then that many payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when the length itself is
    /// missing and [`ProtocolError::IncompletePacket`] when the payload is
    /// shorter than announced.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 2 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let expected = src.get_u16_le() as usize;
        let actual = src.remaining();
        if actual < expected {
            return Err(ProtocolError::IncompletePacket { expected, actual });
        }
        Ok(Self {
            data: src.copy_to_bytes(expected),
        })
    }
}
impl FedAuthInfo {
    /// Info id of the STS URL entry (MS-TDS FEDAUTHINFO: `0x01` = STSURL).
    const INFO_ID_STS_URL: u8 = 0x01;
    /// Info id of the service principal name entry (`0x02` = SPN).
    const INFO_ID_SPN: u8 = 0x02;
    /// Size of one option record: id (1) + data length (4) + data offset (4).
    const OPT_SIZE: usize = 9;

    /// Decodes a federated-authentication info token body (token-type byte
    /// already consumed).
    ///
    /// Layout per MS-TDS: a `u32` token length, then — within that length —
    /// a `u32` option count, `count` option records (id, data length, data
    /// offset) and finally the data area. Offsets are measured from the
    /// start of the count field. Values are UTF-16LE strings.
    ///
    /// Exactly the announced token length is consumed from `src`, so bytes
    /// belonging to following tokens are never touched. Entries with an
    /// unknown id, or whose data is not valid UTF-16, are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::UnexpectedEof`] when the length word is
    /// missing or an option record / data range lies outside the token, and
    /// [`ProtocolError::IncompletePacket`] when `src` is shorter than the
    /// announced token length.
    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
        if src.remaining() < 4 {
            return Err(ProtocolError::UnexpectedEof);
        }
        let length = src.get_u32_le() as usize;
        if src.remaining() < length {
            return Err(ProtocolError::IncompletePacket {
                expected: length,
                actual: src.remaining(),
            });
        }
        // Detach the token body; all offsets below are relative to it.
        let body = src.copy_to_bytes(length);
        let read_u32 = |at: usize| -> Option<usize> {
            let end = at.checked_add(4)?;
            let raw = body.get(at..end)?;
            Some(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]) as usize)
        };

        let count = read_u32(0).ok_or(ProtocolError::UnexpectedEof)?;
        let mut sts_url = String::new();
        let mut spn = String::new();
        for index in 0..count {
            // Option records start right after the 4-byte count.
            let opt_start = index
                .checked_mul(Self::OPT_SIZE)
                .and_then(|offset| offset.checked_add(4))
                .ok_or(ProtocolError::UnexpectedEof)?;
            let info_id = *body.get(opt_start).ok_or(ProtocolError::UnexpectedEof)?;
            let data_len = read_u32(opt_start + 1).ok_or(ProtocolError::UnexpectedEof)?;
            let data_offset = read_u32(opt_start + 5).ok_or(ProtocolError::UnexpectedEof)?;
            let data = data_offset
                .checked_add(data_len)
                .and_then(|end| body.get(data_offset..end))
                .ok_or(ProtocolError::UnexpectedEof)?;

            let units: Vec<u16> = data
                .chunks_exact(2)
                .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
                .collect();
            if let Ok(value) = String::from_utf16(&units) {
                match info_id {
                    Self::INFO_ID_STS_URL => sts_url = value,
                    Self::INFO_ID_SPN => spn = value,
                    _ => {}
                }
            }
        }
        Ok(Self { sts_url, spn })
    }
}
/// Cursor over a byte buffer from which tokens are decoded one at a time.
pub struct TokenParser {
// The bytes being parsed.
data: Bytes,
// Index into `data` of the next unread byte.
position: usize,
}
impl TokenParser {
/// Creates a parser positioned at the first byte of `data`.
#[must_use]
pub fn new(data: Bytes) -> Self {
    let position = 0;
    Self { data, position }
}
/// Returns how many bytes lie between the current position and the end of
/// the buffer (zero if the position is at or past the end).
#[must_use]
pub fn remaining(&self) -> usize {
    let total = self.data.len();
    total.saturating_sub(self.position)
}
/// Returns `true` while at least one unread byte is left in the buffer.
#[must_use]
pub fn has_remaining(&self) -> bool {
    self.data.len() > self.position
}
/// Looks at the byte at the current position and maps it to a
/// [`TokenType`] without advancing.
///
/// Returns `None` when the buffer is exhausted or the byte is not a known
/// token type.
#[must_use]
pub fn peek_token_type(&self) -> Option<TokenType> {
    self.data
        .get(self.position)
        .copied()
        .and_then(TokenType::from_u8)
}
pub fn next_token(&mut self) -> Result<Option<Token>, ProtocolError> {
self.next_token_with_metadata(None)
}
pub fn next_token_with_metadata(
&mut self,
metadata: Option<&ColMetaData>,
) -> Result<Option<Token>, ProtocolError> {
if !self.has_remaining() {
return Ok(None);
}
let mut buf = &self.data[self.position..];
let start_pos = self.position;
let token_type_byte = buf.get_u8();
let token_type = TokenType::from_u8(token_type_byte);
let token = match token_type {
Some(TokenType::Done) => {
let done = Done::decode(&mut buf)?;
Token::Done(done)
}
Some(TokenType::DoneProc) => {
let done = DoneProc::decode(&mut buf)?;
Token::DoneProc(done)
}
Some(TokenType::DoneInProc) => {
let done = DoneInProc::decode(&mut buf)?;
Token::DoneInProc(done)
}
Some(TokenType::Error) => {
let error = ServerError::decode(&mut buf)?;
Token::Error(error)
}
Some(TokenType::Info) => {
let info = ServerInfo::decode(&mut buf)?;
Token::Info(info)
}
Some(TokenType::LoginAck) => {
let login_ack = LoginAck::decode(&mut buf)?;
Token::LoginAck(login_ack)
}
Some(TokenType::EnvChange) => {
let env_change = EnvChange::decode(&mut buf)?;
Token::EnvChange(env_change)
}
Some(TokenType::Order) => {
let order = Order::decode(&mut buf)?;
Token::Order(order)
}
Some(TokenType::FeatureExtAck) => {
let ack = FeatureExtAck::decode(&mut buf)?;
Token::FeatureExtAck(ack)
}
Some(TokenType::Sspi) => {
let sspi = SspiToken::decode(&mut buf)?;
Token::Sspi(sspi)
}
Some(TokenType::FedAuthInfo) => {
let info = FedAuthInfo::decode(&mut buf)?;
Token::FedAuthInfo(info)
}
Some(TokenType::ReturnStatus) => {
if buf.remaining() < 4 {
return Err(ProtocolError::UnexpectedEof);
}
let status = buf.get_i32_le();
Token::ReturnStatus(status)
}
Some(TokenType::ColMetaData) => {
let col_meta = ColMetaData::decode(&mut buf)?;
Token::ColMetaData(col_meta)
}
Some(TokenType::Row) => {
let meta = metadata.ok_or_else(|| {
ProtocolError::StringEncoding(
#[cfg(feature = "std")]
"Row token requires column metadata".to_string(),
#[cfg(not(feature = "std"))]
"Row token requires column metadata",
)
})?;
let row = RawRow::decode(&mut buf, meta)?;
Token::Row(row)
}
Some(TokenType::NbcRow) => {
let meta = metadata.ok_or_else(|| {
ProtocolError::StringEncoding(
#[cfg(feature = "std")]
"NbcRow token requires column metadata".to_string(),
#[cfg(not(feature = "std"))]
"NbcRow token requires column metadata",
)
})?;
let row = NbcRow::decode(&mut buf, meta)?;
Token::NbcRow(row)
}
Some(TokenType::ReturnValue) => {
let ret_val = ReturnValue::decode(&mut buf)?;
Token::ReturnValue(ret_val)
}
Some(TokenType::SessionState) => {
let session = SessionState::decode(&mut buf)?;
Token::SessionState(session)
}
Some(TokenType::ColInfo) | Some(TokenType::TabName) | Some(TokenType::Offset) => {
if buf.remaining() < 2 {
return Err(ProtocolError::UnexpectedEof);
}
let length = buf.get_u16_le() as usize;
if buf.remaining() < length {
return Err(ProtocolError::IncompletePacket {
expected: length,
actual: buf.remaining(),
});
}
buf.advance(length);
self.position = start_pos + (self.data.len() - start_pos - buf.remaining());
return self.next_token_with_metadata(metadata);
}
None => {
return Err(ProtocolError::InvalidTokenType(token_type_byte));
}
};
let consumed = self.data.len() - start_pos - buf.remaining();
self.position = start_pos + consumed;
Ok(Some(token))
}
pub fn skip_token(&mut self) -> Result<(), ProtocolError> {
if !self.has_remaining() {
return Ok(());
}
let token_type_byte = self.data[self.position];
let token_type = TokenType::from_u8(token_type_byte);
let skip_amount = match token_type {
Some(TokenType::Done) | Some(TokenType::DoneProc) | Some(TokenType::DoneInProc) => {
1 + Done::SIZE }
Some(TokenType::ReturnStatus) => {
1 + 4 }
Some(TokenType::Error)
| Some(TokenType::Info)
| Some(TokenType::LoginAck)
| Some(TokenType::EnvChange)
| Some(TokenType::Order)
| Some(TokenType::Sspi)
| Some(TokenType::ColInfo)
| Some(TokenType::TabName)
| Some(TokenType::Offset)
| Some(TokenType::ReturnValue) => {
if self.remaining() < 3 {
return Err(ProtocolError::UnexpectedEof);
}
let length = u16::from_le_bytes([
self.data[self.position + 1],
self.data[self.position + 2],
]) as usize;
1 + 2 + length }
Some(TokenType::SessionState) | Some(TokenType::FedAuthInfo) => {
if self.remaining() < 5 {
return Err(ProtocolError::UnexpectedEof);
}
let length = u32::from_le_bytes([
self.data[self.position + 1],
self.data[self.position + 2],
self.data[self.position + 3],
self.data[self.position + 4],
]) as usize;
1 + 4 + length
}
Some(TokenType::FeatureExtAck) => {
let mut buf = &self.data[self.position + 1..];
let _ = FeatureExtAck::decode(&mut buf)?;
self.data.len() - self.position - buf.remaining()
}
Some(TokenType::ColMetaData) | Some(TokenType::Row) | Some(TokenType::NbcRow) => {
return Err(ProtocolError::InvalidTokenType(token_type_byte));
}
None => {
return Err(ProtocolError::InvalidTokenType(token_type_byte));
}
};
if self.remaining() < skip_amount {
return Err(ProtocolError::UnexpectedEof);
}
self.position += skip_amount;
Ok(())
}
#[must_use]
pub fn position(&self) -> usize {
self.position
}
/// Rewinds the parser to the start of the buffer.
pub fn reset(&mut self) {
    let start = 0;
    self.position = start;
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
use bytes::BytesMut;
/// Encoding a `Done` and decoding it again preserves the checked fields.
#[test]
fn test_done_roundtrip() {
    let status = DoneStatus {
        more: false,
        error: false,
        in_xact: false,
        count: true,
        attn: false,
        srverror: false,
    };
    let original = Done {
        status,
        cur_cmd: 193,
        row_count: 42,
    };

    let mut encoded = BytesMut::new();
    original.encode(&mut encoded);

    // Decoding starts one byte in; presumably `encode` writes the token-type
    // byte first while `decode` expects it to be consumed already.
    let mut body = &encoded[1..];
    let roundtripped = Done::decode(&mut body).unwrap();

    assert_eq!(roundtripped.status.count, original.status.count);
    assert_eq!(roundtripped.cur_cmd, original.cur_cmd);
    assert_eq!(roundtripped.row_count, original.row_count);
}
/// `DoneStatus::to_bits` followed by `from_bits` keeps the four set flags.
#[test]
fn test_done_status_bits() {
    let status = DoneStatus {
        more: true,
        error: true,
        in_xact: true,
        count: true,
        attn: false,
        srverror: false,
    };

    let restored = DoneStatus::from_bits(status.to_bits());

    assert_eq!(status.more, restored.more);
    assert_eq!(status.error, restored.error);
    assert_eq!(status.in_xact, restored.in_xact);
    assert_eq!(status.count, restored.count);
}
/// A lone DONE token is parsed, after which the parser reports end of input.
#[test]
fn test_token_parser_done() {
    let wire: &'static [u8] = &[
        0xFD, // token type: Done
        0x10, 0x00, // status (count flag set, per the assertions below)
        0xC1, 0x00, // cur_cmd = 193
        0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // row_count = 5
    ];
    let mut parser = TokenParser::new(Bytes::from_static(wire));

    let token = parser.next_token().unwrap().unwrap();
    if let Token::Done(done) = token {
        assert!(done.status.count);
        assert!(!done.status.more);
        assert_eq!(done.cur_cmd, 193);
        assert_eq!(done.row_count, 5);
    } else {
        panic!("Expected Done token");
    }

    assert!(parser.next_token().unwrap().is_none());
}
/// Known environment-change codes map to their variants; unknown ones to `None`.
#[test]
fn test_env_change_type_from_u8() {
    let cases = [
        (1, Some(EnvChangeType::Database)),
        (20, Some(EnvChangeType::Routing)),
        (100, None),
    ];
    for (raw, expected) in cases {
        assert_eq!(EnvChangeType::from_u8(raw), expected);
    }
}
/// A column count of 0xFFFF decodes to empty metadata.
#[test]
fn test_colmetadata_no_columns() {
    let wire: [u8; 2] = [0xFF, 0xFF];
    let mut cursor: &[u8] = &wire;

    let meta = ColMetaData::decode(&mut cursor).unwrap();

    assert!(meta.is_empty());
    assert_eq!(meta.column_count(), 0);
}
/// COLMETADATA with one nullable INT column named "id".
#[test]
fn test_colmetadata_single_int_column() {
    let wire: [u8; 14] = [
        0x01, 0x00, // column count = 1
        0x00, 0x00, 0x00, 0x00, // user type
        0x01, 0x00, // flags (nullable, per the assertion below)
        0x38, // type id: Int4
        0x02, // name length in UTF-16 code units
        b'i', 0x00, b'd', 0x00, // "id" in UTF-16LE
    ];
    let mut cursor: &[u8] = &wire;

    let meta = ColMetaData::decode(&mut cursor).unwrap();

    assert_eq!(meta.column_count(), 1);
    let column = &meta.columns[0];
    assert_eq!(column.name, "id");
    assert_eq!(column.type_id, TypeId::Int4);
    assert!(column.is_nullable());
}
/// COLMETADATA with one NVARCHAR column named "name", carrying a maximum
/// length and a collation.
#[test]
fn test_colmetadata_nvarchar_column() {
    let wire: [u8; 25] = [
        0x01, 0x00, // column count = 1
        0x00, 0x00, 0x00, 0x00, // user type
        0x01, 0x00, // flags
        0xE7, // type id: NVarChar
        0x64, 0x00, // max length = 100
        0x09, 0x04, 0xD0, 0x00, 0x34, // collation (5 bytes)
        0x04, // name length in UTF-16 code units
        b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00, // "name" in UTF-16LE
    ];
    let mut cursor: &[u8] = &wire;

    let meta = ColMetaData::decode(&mut cursor).unwrap();

    assert_eq!(meta.column_count(), 1);
    let column = &meta.columns[0];
    assert_eq!(column.name, "name");
    assert_eq!(column.type_id, TypeId::NVarChar);
    assert_eq!(column.type_info.max_length, Some(100));
    assert!(column.type_info.collation.is_some());
}
/// A row with one `Int4` column stores exactly the four value bytes.
#[test]
fn test_raw_row_decode_int() {
    let id_column = ColumnData {
        name: "id".to_string(),
        type_id: TypeId::Int4,
        col_type: 0x38,
        flags: 0,
        user_type: 0,
        type_info: TypeInfo::default(),
    };
    let metadata = ColMetaData {
        columns: vec![id_column],
    };

    // 42 as a little-endian 32-bit integer.
    let wire: [u8; 4] = [0x2A, 0x00, 0x00, 0x00];
    let mut cursor: &[u8] = &wire;

    let row = RawRow::decode(&mut cursor, &metadata).unwrap();

    assert_eq!(row.data.len(), 4);
    assert_eq!(&row.data[..], &wire[..]);
}
/// A non-null `IntN` value keeps its one-byte length prefix in the row data.
#[test]
fn test_raw_row_decode_nullable_int() {
    let id_column = ColumnData {
        name: "id".to_string(),
        type_id: TypeId::IntN,
        col_type: 0x26,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(4),
            ..Default::default()
        },
    };
    let metadata = ColMetaData {
        columns: vec![id_column],
    };

    // Length byte (4) followed by 42 as a little-endian 32-bit integer.
    let wire: [u8; 5] = [0x04, 0x2A, 0x00, 0x00, 0x00];
    let mut cursor: &[u8] = &wire;

    let row = RawRow::decode(&mut cursor, &metadata).unwrap();

    assert_eq!(row.data.len(), 5);
    assert_eq!(row.data[0], 4);
    assert_eq!(&row.data[1..], &[0x2A, 0x00, 0x00, 0x00]);
}
/// An `IntN` column whose single wire byte is 0xFF is stored as that one byte.
#[test]
fn test_raw_row_decode_null_value() {
    let id_column = ColumnData {
        name: "id".to_string(),
        type_id: TypeId::IntN,
        col_type: 0x26,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(4),
            ..Default::default()
        },
    };
    let metadata = ColMetaData {
        columns: vec![id_column],
    };

    let wire: [u8; 1] = [0xFF];
    let mut cursor: &[u8] = &wire;

    let row = RawRow::decode(&mut cursor, &metadata).unwrap();

    assert_eq!(row.data.len(), 1);
    assert_eq!(row.data[0], 0xFF);
}
/// Bits 0 and 2 of the bitmap are set, so only columns 0 and 2 are null.
#[test]
fn test_nbcrow_null_bitmap() {
    let bitmap = vec![0b0000_0101];
    let row = NbcRow {
        null_bitmap: bitmap,
        data: Bytes::new(),
    };

    // Columns whose bit is set.
    assert!(row.is_null(0));
    assert!(row.is_null(2));

    // Columns whose bit is clear.
    assert!(!row.is_null(1));
    assert!(!row.is_null(3));
}
/// The parser recognises a COLMETADATA token and decodes its single column.
#[test]
fn test_token_parser_colmetadata() {
    let wire: &'static [u8] = &[
        0x81, // token type: ColMetaData
        0x01, 0x00, // column count = 1
        0x00, 0x00, 0x00, 0x00, // user type
        0x01, 0x00, // flags
        0x38, // type id: Int4
        0x02, // name length in UTF-16 code units
        b'i', 0x00, b'd', 0x00, // "id" in UTF-16LE
    ];
    let mut parser = TokenParser::new(Bytes::from_static(wire));

    let token = parser.next_token().unwrap().unwrap();
    if let Token::ColMetaData(meta) = token {
        assert_eq!(meta.column_count(), 1);
        assert_eq!(meta.columns[0].name, "id");
        assert_eq!(meta.columns[0].type_id, TypeId::Int4);
    } else {
        panic!("Expected ColMetaData token");
    }
}
/// With column metadata supplied, a ROW token decodes into a `Token::Row`.
#[test]
fn test_token_parser_row_with_metadata() {
    let id_column = ColumnData {
        name: "id".to_string(),
        type_id: TypeId::Int4,
        col_type: 0x38,
        flags: 0,
        user_type: 0,
        type_info: TypeInfo::default(),
    };
    let metadata = ColMetaData {
        columns: vec![id_column],
    };

    let wire: &'static [u8] = &[
        0xD1, // token type: Row
        0x2A, 0x00, 0x00, 0x00, // one 4-byte column value
    ];
    let mut parser = TokenParser::new(Bytes::from_static(wire));

    let token = parser
        .next_token_with_metadata(Some(&metadata))
        .unwrap()
        .unwrap();
    if let Token::Row(row) = token {
        assert_eq!(row.data.len(), 4);
    } else {
        panic!("Expected Row token");
    }
}
/// A ROW token cannot be parsed when no column metadata is supplied.
#[test]
fn test_token_parser_row_without_metadata_fails() {
    let wire: &'static [u8] = &[
        0xD1, // token type: Row
        0x2A, 0x00, 0x00, 0x00, // column bytes that are never reached
    ];
    let mut parser = TokenParser::new(Bytes::from_static(wire));

    assert!(parser.next_token().is_err());
}
/// Peeking reports the type of the first token in the buffer.
#[test]
fn test_token_parser_peek() {
    let wire: &'static [u8] = &[
        0xFD, // token type: Done
        0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    let parser = TokenParser::new(Bytes::from_static(wire));

    let peeked = parser.peek_token_type();

    assert_eq!(peeked, Some(TokenType::Done));
}
/// `fixed_size` is `Some(4)` for `Int4` and `None` for `NVarChar`.
#[test]
fn test_column_data_fixed_size() {
    let int_column = ColumnData {
        name: String::new(),
        type_id: TypeId::Int4,
        col_type: 0x38,
        flags: 0,
        user_type: 0,
        type_info: TypeInfo::default(),
    };
    let nvarchar_column = ColumnData {
        name: String::new(),
        type_id: TypeId::NVarChar,
        col_type: 0xE7,
        flags: 0,
        user_type: 0,
        type_info: TypeInfo::default(),
    };

    assert_eq!(int_column.fixed_size(), Some(4));
    assert_eq!(nvarchar_column.fixed_size(), None);
}
/// Decodes a two-column row (NVARCHAR followed by IntN) and then re-reads the
/// stored row bytes to check that both columns, including their length
/// prefixes, were captured intact and in order.
#[test]
fn test_decode_nvarchar_then_intn_roundtrip() {
let mut wire_data = BytesMut::new();
let word = "World";
let utf16: Vec<u16> = word.encode_utf16().collect();
// Column 0 on the wire: u16 byte length, then the UTF-16LE code units.
wire_data.put_u16_le((utf16.len() * 2) as u16); for code_unit in &utf16 {
wire_data.put_u16_le(*code_unit);
}
// Column 1 on the wire: one-byte length (4), then the i32 value 42.
wire_data.put_u8(4); wire_data.put_i32_le(42);
let metadata = ColMetaData {
columns: vec![
ColumnData {
name: "greeting".to_string(),
type_id: TypeId::NVarChar,
col_type: 0xE7,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
max_length: Some(10), precision: None,
scale: None,
collation: None,
},
},
ColumnData {
name: "number".to_string(),
type_id: TypeId::IntN,
col_type: 0x26,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
max_length: Some(4),
precision: None,
scale: None,
collation: None,
},
},
],
};
let mut wire_cursor = wire_data.freeze();
let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
// The decoder must consume the whole row and nothing more.
assert_eq!(
wire_cursor.remaining(),
0,
"wire data should be fully consumed"
);
// Walk the stored bytes the same way a value reader would.
let mut stored_cursor: &[u8] = &raw_row.data;
assert!(
stored_cursor.remaining() >= 2,
"need at least 2 bytes for length"
);
let len0 = stored_cursor.get_u16_le() as usize;
assert_eq!(len0, 10, "NVarChar length should be 10 bytes");
assert!(
stored_cursor.remaining() >= len0,
"need {len0} bytes for data"
);
let mut utf16_read = Vec::new();
for _ in 0..(len0 / 2) {
utf16_read.push(stored_cursor.get_u16_le());
}
let string0 = String::from_utf16(&utf16_read).unwrap();
assert_eq!(string0, "World", "column 0 should be 'World'");
assert!(
stored_cursor.remaining() >= 1,
"need at least 1 byte for length"
);
let len1 = stored_cursor.get_u8();
assert_eq!(len1, 4, "IntN length should be 4");
assert!(stored_cursor.remaining() >= 4, "need 4 bytes for INT data");
let int1 = stored_cursor.get_i32_le();
assert_eq!(int1, 42, "column 1 should be 42");
assert_eq!(
stored_cursor.remaining(),
0,
"stored data should be fully consumed"
);
}
/// Decodes a row whose first column uses the PLP (chunked) encoding and whose
/// second column is an IntN, then re-reads the stored row bytes to check that
/// the PLP framing and the trailing integer were captured intact.
#[test]
fn test_decode_nvarchar_max_then_intn_roundtrip() {
let mut wire_data = BytesMut::new();
let word = "Hello";
let utf16: Vec<u16> = word.encode_utf16().collect();
let byte_len = (utf16.len() * 2) as u64;
// Column 0 on the wire: u64 total length, u32 chunk length, chunk data.
wire_data.put_u64_le(byte_len); wire_data.put_u32_le(byte_len as u32); for code_unit in &utf16 {
wire_data.put_u16_le(*code_unit);
}
// Zero-length chunk that terminates the PLP value.
wire_data.put_u32_le(0);
// Column 1 on the wire: one-byte length (4), then the i32 value 99.
wire_data.put_u8(4);
wire_data.put_i32_le(99);
let metadata = ColMetaData {
columns: vec![
ColumnData {
name: "text".to_string(),
type_id: TypeId::NVarChar,
col_type: 0xE7,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
// NOTE(review): 0xFFFF presumably marks the column as NVARCHAR(MAX),
// selecting the PLP wire format — confirm against RawRow::decode.
max_length: Some(0xFFFF), precision: None,
scale: None,
collation: None,
},
},
ColumnData {
name: "num".to_string(),
type_id: TypeId::IntN,
col_type: 0x26,
flags: 0x01,
user_type: 0,
type_info: TypeInfo {
max_length: Some(4),
precision: None,
scale: None,
collation: None,
},
},
],
};
let mut wire_cursor = wire_data.freeze();
let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
// The decoder must consume the whole row and nothing more.
assert_eq!(
wire_cursor.remaining(),
0,
"wire data should be fully consumed"
);
// Walk the stored bytes: PLP header, chunk, terminator, then the integer.
let mut stored_cursor: &[u8] = &raw_row.data;
let total_len = stored_cursor.get_u64_le();
assert_eq!(total_len, 10, "PLP total length should be 10");
let chunk_len = stored_cursor.get_u32_le();
assert_eq!(chunk_len, 10, "PLP chunk length should be 10");
let mut utf16_read = Vec::new();
for _ in 0..(chunk_len / 2) {
utf16_read.push(stored_cursor.get_u16_le());
}
let string0 = String::from_utf16(&utf16_read).unwrap();
assert_eq!(string0, "Hello", "column 0 should be 'Hello'");
let terminator = stored_cursor.get_u32_le();
assert_eq!(terminator, 0, "PLP should end with 0");
let len1 = stored_cursor.get_u8();
assert_eq!(len1, 4);
let int1 = stored_cursor.get_i32_le();
assert_eq!(int1, 99, "column 1 should be 99");
assert_eq!(
stored_cursor.remaining(),
0,
"stored data should be fully consumed"
);
}
}