#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
use bytes::Buf;
use mssql_types::SqlValue;
#[cfg(feature = "chrono")]
use mssql_types::__private::intervals_to_time;
use mssql_types::__private::time_bytes_for_scale;
use tds_protocol::token::{ColMetaData, Collation, ColumnData, NbcRow, RawRow};
use tds_protocol::types::TypeId;
use crate::error::{Error, Result};
#[cfg(feature = "chrono")]
fn smalldatetime_from_wire(days: i64, minutes: u32) -> Result<chrono::NaiveDateTime> {
let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).expect("epoch 1900-01-01 is valid");
let date = base
.checked_add_signed(chrono::Duration::days(days))
.ok_or_else(|| Error::Protocol(format!("SMALLDATETIME days out of range: {days}")))?;
let secs = u64::from(minutes) * 60;
let time = u32::try_from(secs)
.ok()
.and_then(|s| chrono::NaiveTime::from_num_seconds_from_midnight_opt(s, 0))
.ok_or_else(|| Error::Protocol(format!("SMALLDATETIME minutes out of range: {minutes}")))?;
Ok(date.and_time(time))
}
#[cfg(feature = "chrono")]
/// Builds a `NaiveDateTime` from the two DATETIME wire fields.
///
/// `days` is counted from 1900-01-01. `time_300ths` is the time of day in
/// 1/300-second ticks.
///
/// # Errors
/// Returns `Error::Protocol` if the date overflows chrono's range or the time
/// component lies past the end of a day.
fn datetime_from_wire(days: i64, time_300ths: u64) -> Result<chrono::NaiveDateTime> {
    let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).expect("epoch 1900-01-01 is valid");
    let date = base
        .checked_add_signed(chrono::Duration::days(days))
        .ok_or_else(|| Error::Protocol(format!("DATETIME days out of range: {days}")))?;
    // One tick is 3.333... ms. Round to the nearest millisecond instead of
    // truncating. This reproduces SQL Server's own .000/.003/.007 values
    // (2 ticks -> .007, not .006). `time_300ths` comes from a u32, so the
    // multiplication cannot overflow u64.
    let total_ms = (time_300ths * 1000 + 150) / 300;
    let nanos = ((total_ms % 1000) * 1_000_000) as u32;
    let time = u32::try_from(total_ms / 1000)
        .ok()
        .and_then(|secs| chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos))
        .ok_or_else(|| {
            Error::Protocol(format!(
                "DATETIME time component out of range: {time_300ths}"
            ))
        })?;
    Ok(date.and_time(time))
}
pub(crate) fn convert_raw_row(
raw: &RawRow,
meta: &ColMetaData,
columns: &[crate::row::Column],
) -> Result<crate::row::Row> {
let mut values = Vec::with_capacity(meta.columns.len());
let mut buf = raw.data.as_ref();
for col in &meta.columns {
let value = parse_column_value(&mut buf, col)?;
values.push(value);
}
Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
pub(crate) fn convert_nbc_row(
nbc: &NbcRow,
meta: &ColMetaData,
columns: &[crate::row::Column],
) -> Result<crate::row::Row> {
let mut values = Vec::with_capacity(meta.columns.len());
let mut buf = nbc.data.as_ref();
for (i, col) in meta.columns.iter().enumerate() {
if nbc.is_null(i) {
values.push(mssql_types::SqlValue::Null);
} else {
let value = parse_column_value(&mut buf, col)?;
values.push(value);
}
}
Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
#[cfg(feature = "always-encrypted")]
pub(crate) fn convert_raw_row_decrypted(
raw: &RawRow,
meta: &ColMetaData,
columns: &[crate::row::Column],
decryptor: &crate::column_decryptor::ColumnDecryptor,
) -> Result<crate::row::Row> {
let mut values = Vec::with_capacity(meta.columns.len());
let mut buf = raw.data.as_ref();
for (i, col) in meta.columns.iter().enumerate() {
let value = if decryptor.is_encrypted(i) {
decrypt_column(&mut buf, col, decryptor, i)?
} else {
parse_column_value(&mut buf, col)?
};
values.push(value);
}
Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
#[cfg(feature = "always-encrypted")]
pub(crate) fn convert_nbc_row_decrypted(
nbc: &NbcRow,
meta: &ColMetaData,
columns: &[crate::row::Column],
decryptor: &crate::column_decryptor::ColumnDecryptor,
) -> Result<crate::row::Row> {
let mut values = Vec::with_capacity(meta.columns.len());
let mut buf = nbc.data.as_ref();
for (i, col) in meta.columns.iter().enumerate() {
if nbc.is_null(i) {
values.push(SqlValue::Null);
} else {
let value = if decryptor.is_encrypted(i) {
decrypt_column(&mut buf, col, decryptor, i)?
} else {
parse_column_value(&mut buf, col)?
};
values.push(value);
}
}
Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
#[cfg(feature = "always-encrypted")]
/// Reads one encrypted column value and returns its decrypted, denormalized form.
///
/// The wire value is a u16 little-endian length (`0xFFFF` = NULL) followed by
/// that many ciphertext bytes. The ciphertext is handed to the decryptor for
/// column `ordinal`. The plaintext is then converted according to the base
/// column metadata the decryptor returns.
fn decrypt_column(
    buf: &mut &[u8],
    _col: &ColumnData,
    decryptor: &crate::column_decryptor::ColumnDecryptor,
    ordinal: usize,
) -> Result<SqlValue> {
    if buf.remaining() < 2 {
        return Err(Error::Protocol(
            "unexpected EOF reading encrypted column length".to_string(),
        ));
    }
    let declared = buf.get_u16_le();
    if declared == 0xFFFF {
        return Ok(SqlValue::Null);
    }

    let length = usize::from(declared);
    let available = buf.remaining();
    if available < length {
        return Err(Error::Protocol(format!(
            "unexpected EOF reading encrypted column data: need {length} bytes, have {available}"
        )));
    }
    let data: &[u8] = *buf;
    let (ciphertext, rest) = data.split_at(length);
    *buf = rest;

    let (plaintext, base_col) = decryptor.decrypt_column_value(ordinal, ciphertext)?;
    denormalize_decrypted(plaintext, base_col)
}
#[cfg(feature = "always-encrypted")]
/// Converts Always Encrypted plaintext into a `SqlValue` for the column's base type.
///
/// `plaintext` is the decrypted payload. Its layout is fixed per type in the
/// arms below; for example, all integer types are read from exactly 8
/// little-endian bytes.
///
/// # Errors
/// Returns `Error::Encryption` when the plaintext has the wrong size or holds
/// an out-of-range value, or when the base type is not supported yet.
fn denormalize_decrypted(plaintext: Vec<u8>, base_col: &ColumnData) -> Result<SqlValue> {
    match base_col.type_id {
        TypeId::Bit
        | TypeId::BitN
        | TypeId::Int1
        | TypeId::Int2
        | TypeId::Int4
        | TypeId::Int8
        | TypeId::IntN => {
            // Every integer/bit type is read as an 8-byte i64, then narrowed to the
            // declared width.
            let v = i64::from_le_bytes(decrypted_array::<8>(&plaintext, "integer")?);
            Ok(match base_col.type_id {
                TypeId::Bit | TypeId::BitN => SqlValue::Bool(v != 0),
                TypeId::Int1 => SqlValue::TinyInt(v as u8),
                TypeId::Int2 => SqlValue::SmallInt(v as i16),
                TypeId::Int8 => SqlValue::BigInt(v),
                TypeId::IntN => match base_col.type_info.max_length {
                    Some(1) => SqlValue::TinyInt(v as u8),
                    Some(2) => SqlValue::SmallInt(v as i16),
                    Some(8) => SqlValue::BigInt(v),
                    _ => SqlValue::Int(v as i32),
                },
                _ => SqlValue::Int(v as i32),
            })
        }
        TypeId::Float4 => Ok(SqlValue::Float(f32::from_le_bytes(decrypted_array::<4>(
            &plaintext, "REAL",
        )?))),
        TypeId::Float8 => Ok(SqlValue::Double(f64::from_le_bytes(decrypted_array::<8>(
            &plaintext, "FLOAT",
        )?))),
        TypeId::FloatN => match base_col.type_info.max_length {
            Some(4) => Ok(SqlValue::Float(f32::from_le_bytes(decrypted_array::<4>(
                &plaintext, "REAL",
            )?))),
            _ => Ok(SqlValue::Double(f64::from_le_bytes(decrypted_array::<8>(
                &plaintext, "FLOAT",
            )?))),
        },
        TypeId::NVarChar | TypeId::NChar => {
            if plaintext.len() % 2 != 0 {
                return Err(Error::Encryption(format!(
                    "decrypted NVARCHAR has an odd byte length ({})",
                    plaintext.len()
                )));
            }
            let units: Vec<u16> = plaintext
                .chunks_exact(2)
                .map(|c| u16::from_le_bytes([c[0], c[1]]))
                .collect();
            let s = String::from_utf16(&units).map_err(|_| {
                Error::Encryption("decrypted NVARCHAR is not valid UTF-16".to_string())
            })?;
            // NCHAR max_length is in bytes; two bytes per UTF-16 unit.
            let s = if base_col.type_id == TypeId::NChar {
                pad_fixed_char(s, base_col.type_info.max_length.unwrap_or(0) as usize / 2)
            } else {
                s
            };
            Ok(SqlValue::String(s))
        }
        TypeId::BigChar | TypeId::Char | TypeId::BigVarChar | TypeId::VarChar => {
            // NOTE(review): always decoded as Windows-1252 regardless of the
            // column collation — confirm this matches the collations in use.
            let (s, _, had_errors) = encoding_rs::WINDOWS_1252.decode(&plaintext);
            if had_errors {
                return Err(Error::Encryption(
                    "decrypted CHAR is not valid Windows-1252".to_string(),
                ));
            }
            let s = s.into_owned();
            let s = if matches!(base_col.type_id, TypeId::Char | TypeId::BigChar) {
                pad_fixed_char(s, base_col.type_info.max_length.unwrap_or(0) as usize)
            } else {
                s
            };
            Ok(SqlValue::String(s))
        }
        TypeId::BigVarBinary | TypeId::BigBinary | TypeId::VarBinary | TypeId::Binary => {
            Ok(SqlValue::Binary(bytes::Bytes::from(plaintext)))
        }
        #[cfg(feature = "uuid")]
        TypeId::Guid => {
            // The first three GUID fields are stored little-endian; swap them into
            // the big-endian RFC 4122 byte order `Uuid::from_bytes` expects.
            let b = decrypted_array::<16>(&plaintext, "uniqueidentifier")?;
            Ok(SqlValue::Uuid(uuid::Uuid::from_bytes([
                b[3], b[2], b[1], b[0], b[5], b[4], b[7], b[6], b[8], b[9], b[10], b[11], b[12],
                b[13], b[14], b[15],
            ])))
        }
        #[cfg(feature = "chrono")]
        TypeId::Date => {
            let b = decrypted_array::<3>(&plaintext, "date")?;
            let days = u32::from(b[0]) | (u32::from(b[1]) << 8) | (u32::from(b[2]) << 16);
            chrono::NaiveDate::from_num_days_from_ce_opt(days as i32 + 1)
                .map(SqlValue::Date)
                .ok_or_else(|| {
                    Error::Encryption(format!("decrypted DATE day count {days} is out of range"))
                })
        }
        #[cfg(feature = "decimal")]
        TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
            // Layout: 1 sign byte (0 = negative) + 16-byte little-endian magnitude.
            let b = decrypted_array::<17>(&plaintext, "decimal")?;
            let mut mag = [0u8; 16];
            mag.copy_from_slice(&b[1..17]);
            // A checked conversion: `as i128` would wrap magnitudes >= 2^127 to
            // negative values, and negating i128::MIN panics in debug builds.
            let magnitude = i128::try_from(u128::from_le_bytes(mag)).map_err(|_| {
                Error::Encryption("decrypted DECIMAL magnitude out of range".to_string())
            })?;
            let signed = if b[0] == 0 { -magnitude } else { magnitude };
            let scale = u32::from(base_col.type_info.scale.unwrap_or(0));
            rust_decimal::Decimal::try_from_i128_with_scale(signed, scale)
                .map(SqlValue::Decimal)
                .map_err(|e| Error::Encryption(format!("decrypted DECIMAL out of range: {e}")))
        }
        #[cfg(feature = "decimal")]
        TypeId::Money | TypeId::Money4 | TypeId::MoneyN => {
            // Both MONEY and SMALLMONEY are expected as 8 bytes here.
            if plaintext.len() != 8 {
                return Err(Error::Encryption(format!(
                    "decrypted MONEY has {} bytes, expected 8",
                    plaintext.len()
                )));
            }
            parse_money_value(&mut plaintext.as_slice(), 8)
        }
        #[cfg(feature = "chrono")]
        TypeId::Time => ae_time_from_bytes(&plaintext).map(SqlValue::Time),
        #[cfg(feature = "chrono")]
        TypeId::DateTime2 => {
            // 5-byte time (100 ns ticks) followed by a 3-byte date.
            if plaintext.len() != 8 {
                return Err(Error::Encryption(format!(
                    "decrypted DATETIME2 has {} bytes, expected 8",
                    plaintext.len()
                )));
            }
            let time = ae_time_from_bytes(&plaintext[..5])?;
            let date = ae_date_from_bytes(&plaintext[5..8])?;
            Ok(SqlValue::DateTime(date.and_time(time)))
        }
        #[cfg(feature = "chrono")]
        TypeId::DateTimeOffset => {
            use chrono::TimeZone;
            // 5-byte time + 3-byte date (UTC) + 2-byte signed offset in minutes.
            if plaintext.len() != 10 {
                return Err(Error::Encryption(format!(
                    "decrypted DATETIMEOFFSET has {} bytes, expected 10",
                    plaintext.len()
                )));
            }
            let time = ae_time_from_bytes(&plaintext[..5])?;
            let date = ae_date_from_bytes(&plaintext[5..8])?;
            let offset_min = i16::from_le_bytes([plaintext[8], plaintext[9]]);
            let offset =
                chrono::FixedOffset::east_opt(i32::from(offset_min) * 60).ok_or_else(|| {
                    Error::Encryption(format!(
                        "decrypted DATETIMEOFFSET offset {offset_min} invalid"
                    ))
                })?;
            Ok(SqlValue::DateTimeOffset(
                offset.from_utc_datetime(&date.and_time(time)),
            ))
        }
        #[cfg(feature = "chrono")]
        TypeId::DateTime | TypeId::DateTime4 | TypeId::DateTimeN => match plaintext.len() {
            8 => {
                let b = decrypted_array::<8>(&plaintext, "datetime")?;
                let days = i64::from(i32::from_le_bytes([b[0], b[1], b[2], b[3]]));
                let ticks = u64::from(u32::from_le_bytes([b[4], b[5], b[6], b[7]]));
                datetime_from_wire(days, ticks).map(SqlValue::DateTime)
            }
            4 => {
                let b = decrypted_array::<4>(&plaintext, "smalldatetime")?;
                let days = i64::from(u16::from_le_bytes([b[0], b[1]]));
                let minutes = u32::from(u16::from_le_bytes([b[2], b[3]]));
                smalldatetime_from_wire(days, minutes).map(SqlValue::DateTime)
            }
            n => Err(Error::Encryption(format!(
                "decrypted DATETIME has {n} bytes, expected 4 or 8"
            ))),
        },
        other => Err(Error::Encryption(format!(
            "Always Encrypted read is not yet implemented for base type {other:?}"
        ))),
    }
}
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn ae_time_from_bytes(b: &[u8]) -> Result<chrono::NaiveTime> {
if b.len() != 5 {
return Err(Error::Encryption(format!(
"decrypted TIME has {} bytes, expected 5",
b.len()
)));
}
let mut buf = [0u8; 8];
buf[..5].copy_from_slice(b);
let ticks7 = u64::from_le_bytes(buf);
let nanos = ticks7
.checked_mul(100)
.ok_or_else(|| Error::Encryption("decrypted TIME out of range".to_string()))?;
let secs = u32::try_from(nanos / 1_000_000_000)
.map_err(|_| Error::Encryption("decrypted TIME out of range".to_string()))?;
let nsub = (nanos % 1_000_000_000) as u32;
chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nsub)
.ok_or_else(|| Error::Encryption(format!("decrypted TIME {secs}s out of range")))
}
#[cfg(feature = "always-encrypted")]
/// Right-pads `s` with ASCII spaces until it holds `target_chars` characters.
///
/// Lengths are counted in `char`s, not bytes. Strings already at or beyond the
/// target are returned unchanged (never truncated).
fn pad_fixed_char(s: String, target_chars: usize) -> String {
    let missing = target_chars.saturating_sub(s.chars().count());
    if missing == 0 {
        return s;
    }
    let mut padded = s;
    padded.push_str(&" ".repeat(missing));
    padded
}
#[cfg(all(feature = "always-encrypted", feature = "chrono"))]
fn ae_date_from_bytes(b: &[u8]) -> Result<chrono::NaiveDate> {
if b.len() != 3 {
return Err(Error::Encryption(format!(
"decrypted date has {} bytes, expected 3",
b.len()
)));
}
let days = u32::from(b[0]) | (u32::from(b[1]) << 8) | (u32::from(b[2]) << 16);
chrono::NaiveDate::from_num_days_from_ce_opt(days as i32 + 1).ok_or_else(|| {
Error::Encryption(format!("decrypted date day count {days} is out of range"))
})
}
#[cfg(feature = "always-encrypted")]
fn decrypted_array<const N: usize>(plaintext: &[u8], what: &str) -> Result<[u8; N]> {
plaintext.try_into().map_err(|_| {
Error::Encryption(format!(
"decrypted {what} has {} bytes, expected {N}",
plaintext.len()
))
})
}
fn parse_money_value(buf: &mut &[u8], bytes: usize) -> Result<SqlValue> {
if bytes == 0 {
return Ok(SqlValue::Null);
}
let cents = match bytes {
4 => buf.get_i32_le() as i64,
8 => {
let high = buf.get_i32_le();
let low = buf.get_u32_le();
((high as i64) << 32) | (low as i64)
}
_ => return Err(Error::Protocol(format!("invalid money length: {bytes}"))),
};
#[cfg(feature = "decimal")]
{
use rust_decimal::Decimal;
match Decimal::try_from_i128_with_scale(cents as i128, 4) {
Ok(decimal) => Ok(SqlValue::Decimal(decimal)),
Err(_) => Ok(SqlValue::Double((cents as f64) / 10000.0)),
}
}
#[cfg(not(feature = "decimal"))]
{
Ok(SqlValue::Double((cents as f64) / 10000.0))
}
}
/// Parses one column value from a ROW/NBCROW payload, advancing `buf` past it.
///
/// The wire format is chosen from `col.type_id`. Length-prefixed types read
/// their prefix first. The declared length is always consumed in full so the
/// next column starts at the right offset.
///
/// # Errors
/// Returns `Error::Protocol` on truncated input, an invalid length prefix, or
/// undecodable data (for example, malformed UTF-16).
pub fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<SqlValue> {
    let value = match col.type_id {
        TypeId::Null => SqlValue::Null,
        TypeId::Int1 => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
            }
            SqlValue::TinyInt(buf.get_u8())
        }
        TypeId::Bit => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading BIT".into()));
            }
            SqlValue::Bool(buf.get_u8() != 0)
        }
        TypeId::Int2 => {
            if buf.remaining() < 2 {
                return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
            }
            SqlValue::SmallInt(buf.get_i16_le())
        }
        TypeId::Int4 => {
            if buf.remaining() < 4 {
                return Err(Error::Protocol("unexpected EOF reading INT".into()));
            }
            SqlValue::Int(buf.get_i32_le())
        }
        TypeId::Float4 => {
            if buf.remaining() < 4 {
                return Err(Error::Protocol("unexpected EOF reading REAL".into()));
            }
            SqlValue::Float(buf.get_f32_le())
        }
        TypeId::Int8 => {
            if buf.remaining() < 8 {
                return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
            }
            SqlValue::BigInt(buf.get_i64_le())
        }
        TypeId::Float8 => {
            if buf.remaining() < 8 {
                return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
            }
            SqlValue::Double(buf.get_f64_le())
        }
        TypeId::Money | TypeId::Money4 | TypeId::MoneyN => {
            // Fixed-width MONEY/SMALLMONEY have implicit lengths; MONEYN carries a
            // 1-byte length prefix (0 = NULL).
            let bytes = match col.type_id {
                TypeId::Money => 8,
                TypeId::Money4 => 4,
                TypeId::MoneyN => {
                    if buf.remaining() < 1 {
                        return Err(Error::Protocol(
                            "unexpected EOF reading MoneyN length".into(),
                        ));
                    }
                    buf.get_u8() as usize
                }
                _ => unreachable!("inner match is bounded by outer Money|Money4|MoneyN arm"),
            };
            if buf.remaining() < bytes {
                return Err(Error::Protocol(format!(
                    "unexpected EOF reading money data ({bytes} bytes)"
                )));
            }
            parse_money_value(buf, bytes)?
        }
        TypeId::IntN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
            }
            let len = buf.get_u8();
            if buf.remaining() < len as usize {
                return Err(Error::Protocol("unexpected EOF reading IntN data".into()));
            }
            match len {
                0 => SqlValue::Null,
                1 => SqlValue::TinyInt(buf.get_u8()),
                2 => SqlValue::SmallInt(buf.get_i16_le()),
                4 => SqlValue::Int(buf.get_i32_le()),
                8 => SqlValue::BigInt(buf.get_i64_le()),
                _ => {
                    return Err(Error::Protocol(format!("invalid IntN length: {len}")));
                }
            }
        }
        TypeId::FloatN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading FloatN length".into(),
                ));
            }
            let len = buf.get_u8();
            if buf.remaining() < len as usize {
                return Err(Error::Protocol("unexpected EOF reading FloatN data".into()));
            }
            match len {
                0 => SqlValue::Null,
                4 => SqlValue::Float(buf.get_f32_le()),
                8 => SqlValue::Double(buf.get_f64_le()),
                _ => {
                    return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
                }
            }
        }
        TypeId::BitN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
            }
            let len = buf.get_u8();
            if buf.remaining() < len as usize {
                return Err(Error::Protocol("unexpected EOF reading BitN data".into()));
            }
            match len {
                0 => SqlValue::Null,
                1 => SqlValue::Bool(buf.get_u8() != 0),
                _ => {
                    return Err(Error::Protocol(format!("invalid BitN length: {len}")));
                }
            }
        }
        TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
            let type_info = mssql_types::TypeInfo::decimal(
                col.type_info.precision.unwrap_or(18),
                col.type_info.scale.unwrap_or(0),
            );
            mssql_types::__private::decode_decimal(buf, &type_info)?
        }
        TypeId::DateTimeN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading DateTimeN length".into(),
                ));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if buf.remaining() < len {
                return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
            } else {
                match len {
                    4 => {
                        let days = buf.get_u16_le() as i64;
                        let minutes = buf.get_u16_le() as u32;
                        #[cfg(feature = "chrono")]
                        {
                            SqlValue::DateTime(smalldatetime_from_wire(days, minutes)?)
                        }
                        #[cfg(not(feature = "chrono"))]
                        {
                            SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
                        }
                    }
                    8 => {
                        let days = buf.get_i32_le() as i64;
                        let time_300ths = buf.get_u32_le() as u64;
                        #[cfg(feature = "chrono")]
                        {
                            SqlValue::DateTime(datetime_from_wire(days, time_300ths)?)
                        }
                        #[cfg(not(feature = "chrono"))]
                        {
                            SqlValue::String(format!("DATETIME({days},{time_300ths})"))
                        }
                    }
                    _ => {
                        return Err(Error::Protocol(format!("invalid DateTimeN length: {len}")));
                    }
                }
            }
        }
        TypeId::DateTime => {
            if buf.remaining() < 8 {
                return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
            }
            let days = buf.get_i32_le() as i64;
            let time_300ths = buf.get_u32_le() as u64;
            #[cfg(feature = "chrono")]
            {
                SqlValue::DateTime(datetime_from_wire(days, time_300ths)?)
            }
            #[cfg(not(feature = "chrono"))]
            {
                SqlValue::String(format!("DATETIME({days},{time_300ths})"))
            }
        }
        TypeId::DateTime4 => {
            if buf.remaining() < 4 {
                return Err(Error::Protocol(
                    "unexpected EOF reading SMALLDATETIME".into(),
                ));
            }
            let days = buf.get_u16_le() as i64;
            let minutes = buf.get_u16_le() as u32;
            #[cfg(feature = "chrono")]
            {
                SqlValue::DateTime(smalldatetime_from_wire(days, minutes)?)
            }
            #[cfg(not(feature = "chrono"))]
            {
                SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
            }
        }
        TypeId::Date => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if len != 3 {
                return Err(Error::Protocol(format!("invalid DATE length: {len}")));
            } else if buf.remaining() < 3 {
                return Err(Error::Protocol("unexpected EOF reading DATE".into()));
            } else {
                // 3-byte little-endian day count since 0001-01-01.
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                #[cfg(feature = "chrono")]
                {
                    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
                        .expect("epoch 0001-01-01 is valid");
                    let date = base + chrono::Duration::days(days as i64);
                    SqlValue::Date(date)
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!("DATE({days})"))
                }
            }
        }
        TypeId::Time => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if len > 8 {
                // Only 8 bytes fit in the u64 below. Reading fewer than `len` would
                // leave the rest of the value in `buf` and misalign later columns.
                return Err(Error::Protocol(format!("invalid TIME length: {len}")));
            } else if buf.remaining() < len {
                return Err(Error::Protocol("unexpected EOF reading TIME".into()));
            } else {
                let mut time_bytes = [0u8; 8];
                for byte in time_bytes.iter_mut().take(len) {
                    *byte = buf.get_u8();
                }
                let intervals = u64::from_le_bytes(time_bytes);
                #[cfg(feature = "chrono")]
                {
                    let scale = col.type_info.scale.unwrap_or(7);
                    let time = intervals_to_time(intervals, scale);
                    SqlValue::Time(time)
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!("TIME({intervals})"))
                }
            }
        }
        TypeId::DateTime2 => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading DATETIME2 length".into(),
                ));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if buf.remaining() < len {
                return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
            } else {
                let scale = col.type_info.scale.unwrap_or(7);
                let time_len = time_bytes_for_scale(scale);
                if len < time_len + 3 {
                    return Err(Error::Protocol(format!(
                        "DATETIME2 length {len} too short for scale {scale}"
                    )));
                }
                let mut time_bytes = [0u8; 8];
                for byte in time_bytes.iter_mut().take(time_len) {
                    *byte = buf.get_u8();
                }
                let intervals = u64::from_le_bytes(time_bytes);
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                // Consume the full declared length. Any bytes beyond time+date
                // would otherwise be read as the start of the next column.
                buf.advance(len - (time_len + 3));
                #[cfg(feature = "chrono")]
                {
                    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
                        .expect("epoch 0001-01-01 is valid");
                    let date = base + chrono::Duration::days(days as i64);
                    let time = intervals_to_time(intervals, scale);
                    SqlValue::DateTime(date.and_time(time))
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!("DATETIME2({days},{intervals})"))
                }
            }
        }
        TypeId::DateTimeOffset => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading DATETIMEOFFSET length".into(),
                ));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if buf.remaining() < len {
                return Err(Error::Protocol(
                    "unexpected EOF reading DATETIMEOFFSET".into(),
                ));
            } else {
                let scale = col.type_info.scale.unwrap_or(7);
                let time_len = time_bytes_for_scale(scale);
                if len < time_len + 5 {
                    return Err(Error::Protocol(format!(
                        "DATETIMEOFFSET length {len} too short for scale {scale}"
                    )));
                }
                let mut time_bytes = [0u8; 8];
                for byte in time_bytes.iter_mut().take(time_len) {
                    *byte = buf.get_u8();
                }
                let intervals = u64::from_le_bytes(time_bytes);
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                let offset_minutes = buf.get_i16_le();
                // Consume the full declared length so the next column stays aligned.
                buf.advance(len - (time_len + 5));
                #[cfg(feature = "chrono")]
                {
                    use chrono::TimeZone;
                    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
                        .expect("epoch 0001-01-01 is valid");
                    let date = base + chrono::Duration::days(days as i64);
                    let time = intervals_to_time(intervals, scale);
                    // The date/time on the wire is treated as UTC. An offset chrono
                    // rejects falls back to UTC rather than failing the row.
                    let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
                        .unwrap_or_else(|| {
                            chrono::FixedOffset::east_opt(0).expect("UTC offset 0 is valid")
                        });
                    let datetime = offset.from_utc_datetime(&date.and_time(time));
                    SqlValue::DateTimeOffset(datetime)
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!(
                        "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
                    ))
                }
            }
        }
        // NOTE(review): MS-TDS encodes TEXT/NTEXT/IMAGE row values with a text
        // pointer + timestamp prefix rather than PLP chunks — confirm the server
        // actually sends PLP for these columns here.
        TypeId::Text => parse_plp_varchar(buf, col.type_info.collation.as_ref())?,
        TypeId::Char | TypeId::VarChar => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varchar length".into(),
                ));
            }
            let len = buf.get_u8();
            if len == 0xFF {
                SqlValue::Null
            } else if len == 0 {
                SqlValue::String(String::new())
            } else if buf.remaining() < len as usize {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varchar data".into(),
                ));
            } else {
                let data = &buf[..len as usize];
                let s = decode_varchar_string(data, col.type_info.collation.as_ref());
                buf.advance(len as usize);
                SqlValue::String(s)
            }
        }
        TypeId::BigVarChar | TypeId::BigChar => {
            // max_length 0xFFFF marks VARCHAR(MAX), which is sent as PLP.
            if col.type_info.max_length == Some(0xFFFF) {
                parse_plp_varchar(buf, col.type_info.collation.as_ref())?
            } else {
                if buf.remaining() < 2 {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varchar length".into(),
                    ));
                }
                let len = buf.get_u16_le();
                if len == 0xFFFF {
                    SqlValue::Null
                } else if buf.remaining() < len as usize {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varchar data".into(),
                    ));
                } else {
                    let data = &buf[..len as usize];
                    let s = decode_varchar_string(data, col.type_info.collation.as_ref());
                    buf.advance(len as usize);
                    SqlValue::String(s)
                }
            }
        }
        TypeId::NText => parse_plp_nvarchar(buf)?,
        TypeId::NVarChar | TypeId::NChar => {
            if col.type_info.max_length == Some(0xFFFF) {
                parse_plp_nvarchar(buf)?
            } else {
                if buf.remaining() < 2 {
                    return Err(Error::Protocol(
                        "unexpected EOF reading nvarchar length".into(),
                    ));
                }
                let len = buf.get_u16_le();
                if len == 0xFFFF {
                    SqlValue::Null
                } else if buf.remaining() < len as usize {
                    return Err(Error::Protocol(
                        "unexpected EOF reading nvarchar data".into(),
                    ));
                } else if len % 2 != 0 {
                    // UTF-16 data is always an even number of bytes; `chunks_exact`
                    // would otherwise silently drop the trailing byte.
                    return Err(Error::Protocol(format!(
                        "invalid nvarchar byte length: {len} (must be even)"
                    )));
                } else {
                    let data = &buf[..len as usize];
                    let utf16: Vec<u16> = data
                        .chunks_exact(2)
                        .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
                        .collect();
                    let s = String::from_utf16(&utf16)
                        .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
                    buf.advance(len as usize);
                    SqlValue::String(s)
                }
            }
        }
        TypeId::Image => parse_plp_varbinary(buf)?,
        TypeId::Binary | TypeId::VarBinary => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varbinary length".into(),
                ));
            }
            let len = buf.get_u8();
            if len == 0xFF {
                SqlValue::Null
            } else if len == 0 {
                SqlValue::Binary(bytes::Bytes::new())
            } else if buf.remaining() < len as usize {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varbinary data".into(),
                ));
            } else {
                let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
                buf.advance(len as usize);
                SqlValue::Binary(data)
            }
        }
        TypeId::BigVarBinary | TypeId::BigBinary => {
            if col.type_info.max_length == Some(0xFFFF) {
                parse_plp_varbinary(buf)?
            } else {
                if buf.remaining() < 2 {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varbinary length".into(),
                    ));
                }
                let len = buf.get_u16_le();
                if len == 0xFFFF {
                    SqlValue::Null
                } else if buf.remaining() < len as usize {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varbinary data".into(),
                    ));
                } else {
                    let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
                    buf.advance(len as usize);
                    SqlValue::Binary(data)
                }
            }
        }
        TypeId::Xml => match parse_plp_nvarchar(buf)? {
            SqlValue::Null => SqlValue::Null,
            SqlValue::String(s) => SqlValue::Xml(s),
            _ => {
                return Err(Error::Protocol(
                    "unexpected value type when parsing XML".into(),
                ));
            }
        },
        TypeId::Guid => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
            }
            let len = buf.get_u8();
            if len == 0 {
                SqlValue::Null
            } else if len != 16 {
                return Err(Error::Protocol(format!("invalid GUID length: {len}")));
            } else if buf.remaining() < 16 {
                return Err(Error::Protocol("unexpected EOF reading GUID".into()));
            } else {
                decode_guid_bytes(buf)
            }
        }
        TypeId::Variant => parse_sql_variant(buf)?,
        TypeId::Udt => parse_plp_varbinary(buf)?,
        _ => {
            // Unknown types: assume a u16 length prefix (0xFFFF = NULL) and return
            // the raw bytes.
            if buf.remaining() < 2 {
                return Err(Error::Protocol(format!(
                    "unexpected EOF reading {:?}",
                    col.type_id
                )));
            }
            let len = buf.get_u16_le();
            if len == 0xFFFF {
                SqlValue::Null
            } else if buf.remaining() < len as usize {
                return Err(Error::Protocol(format!(
                    "unexpected EOF reading {:?} data",
                    col.type_id
                )));
            } else {
                let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
                buf.advance(len as usize);
                SqlValue::Binary(data)
            }
        }
    };
    Ok(value)
}
pub(crate) fn parse_plp_nvarchar(buf: &mut &[u8]) -> Result<SqlValue> {
if buf.remaining() < 8 {
return Err(Error::Protocol(
"unexpected EOF reading PLP total length".into(),
));
}
let total_len = buf.get_u64_le();
if total_len == 0xFFFFFFFFFFFFFFFF {
return Ok(SqlValue::Null);
}
let mut all_data = Vec::new();
loop {
if buf.remaining() < 4 {
return Err(Error::Protocol(
"unexpected EOF reading PLP chunk length".into(),
));
}
let chunk_len = buf.get_u32_le() as usize;
if chunk_len == 0 {
break; }
if buf.remaining() < chunk_len {
return Err(Error::Protocol(
"unexpected EOF reading PLP chunk data".into(),
));
}
all_data.extend_from_slice(&buf[..chunk_len]);
buf.advance(chunk_len);
}
let utf16: Vec<u16> = all_data
.chunks_exact(2)
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
.collect();
let s = String::from_utf16(&utf16)
.map_err(|_| Error::Protocol("invalid UTF-16 in PLP nvarchar".into()))?;
Ok(SqlValue::String(s))
}
#[allow(unused_variables)]
/// Decodes single-byte character data into a `String`.
///
/// With the `encoding` feature, the encoding reported by the collation is tried
/// first. Its output is used only if decoding produced no errors. In every other
/// case the bytes are decoded as UTF-8, lossily.
fn decode_varchar_string(data: &[u8], collation: Option<&Collation>) -> String {
    #[cfg(feature = "encoding")]
    {
        let decoded = collation
            .and_then(|coll| coll.encoding())
            .map(|encoding| encoding.decode(data))
            .filter(|(_, _, had_errors)| !*had_errors);
        if let Some((text, _, _)) = decoded {
            return text.into_owned();
        }
    }
    String::from_utf8_lossy(data).into_owned()
}
/// Reads a PLP single-byte string from `buf` and decodes it with `collation`.
///
/// Layout: an 8-byte total length (`u64::MAX` = NULL), then u32-length-prefixed
/// chunks terminated by a zero-length chunk. Returns `Error::Protocol` if the
/// stream is truncated.
fn parse_plp_varchar(buf: &mut &[u8], collation: Option<&Collation>) -> Result<SqlValue> {
    if buf.remaining() < 8 {
        return Err(Error::Protocol(
            "unexpected EOF reading PLP total length".into(),
        ));
    }
    // The declared total is consulted only for the NULL marker; the chunk
    // terminator decides where the value ends.
    if buf.get_u64_le() == u64::MAX {
        return Ok(SqlValue::Null);
    }

    let mut payload: Vec<u8> = Vec::new();
    loop {
        if buf.remaining() < 4 {
            return Err(Error::Protocol(
                "unexpected EOF reading PLP chunk length".into(),
            ));
        }
        let chunk_len = buf.get_u32_le() as usize;
        if chunk_len == 0 {
            break;
        }
        let data: &[u8] = *buf;
        let Some(chunk) = data.get(..chunk_len) else {
            return Err(Error::Protocol(
                "unexpected EOF reading PLP chunk data".into(),
            ));
        };
        payload.extend_from_slice(chunk);
        *buf = &data[chunk_len..];
    }

    Ok(SqlValue::String(decode_varchar_string(&payload, collation)))
}
/// Reads a PLP binary value from `buf` into `SqlValue::Binary`.
///
/// Layout: an 8-byte total length (`u64::MAX` = NULL), then u32-length-prefixed
/// chunks terminated by a zero-length chunk. Returns `Error::Protocol` if the
/// stream is truncated.
pub(crate) fn parse_plp_varbinary(buf: &mut &[u8]) -> Result<SqlValue> {
    if buf.remaining() < 8 {
        return Err(Error::Protocol(
            "unexpected EOF reading PLP total length".into(),
        ));
    }
    // Only the NULL marker matters here; chunks delimit the actual payload.
    if buf.get_u64_le() == u64::MAX {
        return Ok(SqlValue::Null);
    }

    let mut payload: Vec<u8> = Vec::new();
    loop {
        if buf.remaining() < 4 {
            return Err(Error::Protocol(
                "unexpected EOF reading PLP chunk length".into(),
            ));
        }
        let chunk_len = buf.get_u32_le() as usize;
        if chunk_len == 0 {
            break;
        }
        let data: &[u8] = *buf;
        let Some(chunk) = data.get(..chunk_len) else {
            return Err(Error::Protocol(
                "unexpected EOF reading PLP chunk data".into(),
            ));
        };
        payload.extend_from_slice(chunk);
        *buf = &data[chunk_len..];
    }

    Ok(SqlValue::Binary(bytes::Bytes::from(payload)))
}
fn parse_sql_variant(buf: &mut &[u8]) -> Result<SqlValue> {
if buf.remaining() < 4 {
return Err(Error::Protocol(
"unexpected EOF reading SQL_VARIANT length".into(),
));
}
let total_len = buf.get_u32_le() as usize;
if total_len == 0 {
return Ok(SqlValue::Null);
}
if buf.remaining() < total_len {
return Err(Error::Protocol(
"unexpected EOF reading SQL_VARIANT data".into(),
));
}
if total_len < 2 {
return Err(Error::Protocol(
"SQL_VARIANT too short for type info".into(),
));
}
let base_type = buf.get_u8();
let prop_count = buf.get_u8() as usize;
if buf.remaining() < prop_count {
return Err(Error::Protocol(
"unexpected EOF reading SQL_VARIANT properties".into(),
));
}
let data_len = total_len.saturating_sub(2).saturating_sub(prop_count);
match base_type {
0x30 => {
buf.advance(prop_count);
if data_len < 1 {
return Ok(SqlValue::Null);
}
let v = buf.get_u8();
Ok(SqlValue::TinyInt(v))
}
0x32 => {
buf.advance(prop_count);
if data_len < 1 {
return Ok(SqlValue::Null);
}
let v = buf.get_u8();
Ok(SqlValue::Bool(v != 0))
}
0x34 => {
buf.advance(prop_count);
if data_len < 2 {
return Ok(SqlValue::Null);
}
let v = buf.get_i16_le();
Ok(SqlValue::SmallInt(v))
}
0x38 => {
buf.advance(prop_count);
if data_len < 4 {
return Ok(SqlValue::Null);
}
let v = buf.get_i32_le();
Ok(SqlValue::Int(v))
}
0x7F => {
buf.advance(prop_count);
if data_len < 8 {
return Ok(SqlValue::Null);
}
let v = buf.get_i64_le();
Ok(SqlValue::BigInt(v))
}
0x6D => {
let float_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
buf.advance(prop_count.saturating_sub(1));
if float_len == 4 && data_len >= 4 {
let v = buf.get_f32_le();
Ok(SqlValue::Float(v))
} else if data_len >= 8 {
let v = buf.get_f64_le();
Ok(SqlValue::Double(v))
} else {
Ok(SqlValue::Null)
}
}
0x6E => {
let money_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
buf.advance(prop_count.saturating_sub(1));
if money_len == 0 || data_len == 0 {
Ok(SqlValue::Null)
} else if (money_len == 4 && data_len >= 4) || (money_len == 8 && data_len >= 8) {
parse_money_value(buf, money_len as usize)
} else {
buf.advance(data_len);
Ok(SqlValue::Null)
}
}
0x6F => {
#[cfg(feature = "chrono")]
let dt_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
#[cfg(not(feature = "chrono"))]
if prop_count >= 1 {
buf.get_u8();
}
buf.advance(prop_count.saturating_sub(1));
#[cfg(feature = "chrono")]
{
if dt_len == 4 && data_len >= 4 {
let days = buf.get_u16_le() as i64;
let mins = buf.get_u16_le() as u32;
Ok(SqlValue::DateTime(smalldatetime_from_wire(days, mins)?))
} else if data_len >= 8 {
let days = buf.get_i32_le() as i64;
let ticks = buf.get_u32_le() as u64;
Ok(SqlValue::DateTime(datetime_from_wire(days, ticks)?))
} else {
Ok(SqlValue::Null)
}
}
#[cfg(not(feature = "chrono"))]
{
buf.advance(data_len);
Ok(SqlValue::Null)
}
}
0x6A | 0x6C => {
let precision = if prop_count >= 1 { buf.get_u8() } else { 18 };
let scale = if prop_count >= 2 { buf.get_u8() } else { 0 };
buf.advance(prop_count.saturating_sub(2));
if data_len > 17 {
buf.advance(data_len);
return Ok(SqlValue::Null);
}
let type_info = mssql_types::TypeInfo::decimal(precision, scale);
let result = {
let len_prefix = [data_len as u8];
let mut framed = (&len_prefix[..]).chain(&buf[..data_len]);
mssql_types::__private::decode_decimal(&mut framed, &type_info)
};
buf.advance(data_len);
result.map_err(Into::into)
}
0x24 => {
buf.advance(prop_count);
if data_len < 16 {
return Ok(SqlValue::Null);
}
Ok(decode_guid_bytes(buf))
}
0x28 => {
buf.advance(prop_count);
#[cfg(feature = "chrono")]
{
if data_len < 3 {
return Ok(SqlValue::Null);
}
let mut date_bytes = [0u8; 4];
date_bytes[0] = buf.get_u8();
date_bytes[1] = buf.get_u8();
date_bytes[2] = buf.get_u8();
let days = u32::from_le_bytes(date_bytes);
let base =
chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("epoch 0001-01-01 is valid");
let date = base + chrono::Duration::days(days as i64);
Ok(SqlValue::Date(date))
}
#[cfg(not(feature = "chrono"))]
{
buf.advance(data_len);
Ok(SqlValue::Null)
}
}
0x29 => {
#[cfg_attr(not(feature = "chrono"), allow(unused_variables))]
let scale = if prop_count >= 1 { buf.get_u8() } else { 7 };
buf.advance(prop_count.saturating_sub(1));
#[cfg(feature = "chrono")]
{
if data_len == 0 {
return Ok(SqlValue::Null);
}
let time_len = time_bytes_for_scale(scale);
if data_len < time_len {
return Ok(SqlValue::Null);
}
let mut time_bytes = [0u8; 8];
for byte in time_bytes.iter_mut().take(time_len) {
*byte = buf.get_u8();
}
if data_len > time_len {
buf.advance(data_len - time_len);
}
let intervals = u64::from_le_bytes(time_bytes);
Ok(SqlValue::Time(intervals_to_time(intervals, scale)))
}
#[cfg(not(feature = "chrono"))]
{
buf.advance(data_len);
Ok(SqlValue::Null)
}
}
0x2A => {
#[cfg_attr(not(feature = "chrono"), allow(unused_variables))]
let scale = if prop_count >= 1 { buf.get_u8() } else { 7 };
buf.advance(prop_count.saturating_sub(1));
#[cfg(feature = "chrono")]
{
let time_len = time_bytes_for_scale(scale);
if data_len < time_len + 3 {
return Ok(SqlValue::Null);
}
let mut time_bytes = [0u8; 8];
for byte in time_bytes.iter_mut().take(time_len) {
*byte = buf.get_u8();
}
let intervals = u64::from_le_bytes(time_bytes);
let days = buf.get_u8() as u32
| ((buf.get_u8() as u32) << 8)
| ((buf.get_u8() as u32) << 16);
let consumed = time_len + 3;
if data_len > consumed {
buf.advance(data_len - consumed);
}
let base =
chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("epoch 0001-01-01 is valid");
let date = base + chrono::Duration::days(days as i64);
let time = intervals_to_time(intervals, scale);
Ok(SqlValue::DateTime(date.and_time(time)))
}
#[cfg(not(feature = "chrono"))]
{
buf.advance(data_len);
Ok(SqlValue::Null)
}
}
0x2B => {
#[cfg_attr(not(feature = "chrono"), allow(unused_variables))]
let scale = if prop_count >= 1 { buf.get_u8() } else { 7 };
buf.advance(prop_count.saturating_sub(1));
#[cfg(feature = "chrono")]
{
let time_len = time_bytes_for_scale(scale);
if data_len < time_len + 3 + 2 {
return Ok(SqlValue::Null);
}
let mut time_bytes = [0u8; 8];
for byte in time_bytes.iter_mut().take(time_len) {
*byte = buf.get_u8();
}
let intervals = u64::from_le_bytes(time_bytes);
let days = buf.get_u8() as u32
| ((buf.get_u8() as u32) << 8)
| ((buf.get_u8() as u32) << 16);
let offset_minutes = buf.get_i16_le();
let consumed = time_len + 3 + 2;
if data_len > consumed {
buf.advance(data_len - consumed);
}
use chrono::TimeZone;
let base =
chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("epoch 0001-01-01 is valid");
let date = base + chrono::Duration::days(days as i64);
let time = intervals_to_time(intervals, scale);
let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
.unwrap_or_else(|| {
chrono::FixedOffset::east_opt(0).expect("UTC offset 0 is valid")
});
let datetime = offset.from_utc_datetime(&date.and_time(time));
Ok(SqlValue::DateTimeOffset(datetime))
}
#[cfg(not(feature = "chrono"))]
{
buf.advance(data_len);
Ok(SqlValue::Null)
}
}
0xA7 | 0x2F | 0x27 => {
let collation = if prop_count >= 5 && buf.remaining() >= 5 {
let lcid = buf.get_u32_le();
let sort_id = buf.get_u8();
buf.advance(prop_count.saturating_sub(5)); Some(Collation { lcid, sort_id })
} else {
buf.advance(prop_count);
None
};
if data_len == 0 {
return Ok(SqlValue::String(String::new()));
}
let data = &buf[..data_len];
let s = decode_varchar_string(data, collation.as_ref());
buf.advance(data_len);
Ok(SqlValue::String(s))
}
0xE7 | 0xEF => {
buf.advance(prop_count);
if data_len == 0 {
return Ok(SqlValue::String(String::new()));
}
let utf16: Vec<u16> = buf[..data_len]
.chunks_exact(2)
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
.collect();
buf.advance(data_len);
let s = String::from_utf16(&utf16)
.map_err(|_| Error::Protocol("invalid UTF-16 in SQL_VARIANT nvarchar".into()))?;
Ok(SqlValue::String(s))
}
0xA5 | 0x2D | 0x25 => {
buf.advance(prop_count);
let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
buf.advance(data_len);
Ok(SqlValue::Binary(data))
}
_ => {
buf.advance(prop_count);
let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
buf.advance(data_len);
Ok(SqlValue::Binary(data))
}
}
}
fn decode_guid_bytes(buf: &mut &[u8]) -> SqlValue {
let mut bytes = [0u8; 16];
bytes[3] = buf.get_u8();
bytes[2] = buf.get_u8();
bytes[1] = buf.get_u8();
bytes[0] = buf.get_u8();
bytes[5] = buf.get_u8();
bytes[4] = buf.get_u8();
bytes[7] = buf.get_u8();
bytes[6] = buf.get_u8();
for byte in &mut bytes[8..16] {
*byte = buf.get_u8();
}
#[cfg(feature = "uuid")]
{
SqlValue::Uuid(uuid::Uuid::from_bytes(bytes))
}
#[cfg(not(feature = "uuid"))]
{
SqlValue::Binary(bytes::Bytes::copy_from_slice(&bytes))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    // Unit tests for the decoders in this file. They cover:
    // - PLP (chunked) streams;
    // - sequential column parsing over one shared buffer;
    // - SQL_VARIANT payloads;
    // - "hostile" wire data, which must yield an `Err` (or NULL) instead of panicking.
    use super::*;
    use tds_protocol::token::TypeInfo;

    /// Builds a PLP byte stream for the `parse_plp_*` tests.
    ///
    /// Layout produced:
    /// 1. `total_len` as an 8-byte little-endian integer.
    /// 2. For each chunk, a 4-byte little-endian length followed by the chunk bytes.
    /// 3. A zero-length chunk (four zero bytes) as the terminator.
    ///
    /// `total_len` is written verbatim and is not checked against the chunk sizes.
    fn make_plp_data(total_len: u64, chunks: &[&[u8]]) -> Vec<u8> {
        let mut data = Vec::new();
        data.extend_from_slice(&total_len.to_le_bytes());
        for chunk in chunks {
            let len = chunk.len() as u32;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(chunk);
        }
        // Zero-length chunk terminates the PLP stream.
        data.extend_from_slice(&0u32.to_le_bytes());
        data
    }

    #[test]
    fn test_parse_plp_nvarchar_simple() {
        // "Hello" as UTF-16LE (10 bytes) in a single chunk.
        let utf16_data = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
        let plp = make_plp_data(10, &[&utf16_data]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, "Hello"),
            _ => panic!("expected String, got {result:?}"),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_null() {
        // An all-ones total length with no chunk data must decode as SQL NULL.
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        assert!(matches!(result, SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_nvarchar_empty() {
        // A zero total length followed only by the terminator is an empty
        // string, not NULL.
        let plp = make_plp_data(0, &[]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, ""),
            _ => panic!("expected empty String"),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_multi_chunk() {
        // UTF-16LE "Hel" + "lo" split across two chunks; the parser must
        // concatenate the chunks before decoding.
        let chunk1 = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00];
        let chunk2 = [0x6C, 0x00, 0x6F, 0x00];
        let plp = make_plp_data(10, &[&chunk1, &chunk2]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, "Hello"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_parse_plp_varchar_simple() {
        // Single-byte text with no collation supplied (`None`).
        let data = b"Hello World";
        let plp = make_plp_data(11, &[data]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varchar(&mut buf, None).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, "Hello World"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_parse_plp_varchar_null() {
        // All-ones total length: NULL for the VARCHAR(MAX) decoder too.
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varchar(&mut buf, None).unwrap();
        assert!(matches!(result, SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_varbinary_simple() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05];
        let plp = make_plp_data(5, &[&data]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varbinary(&mut buf).unwrap();
        match result {
            SqlValue::Binary(b) => assert_eq!(&b[..], &[0x01, 0x02, 0x03, 0x04, 0x05]),
            _ => panic!("expected Binary"),
        }
    }

    #[test]
    fn test_parse_plp_varbinary_null() {
        // All-ones total length: NULL for the VARBINARY(MAX) decoder too.
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varbinary(&mut buf).unwrap();
        assert!(matches!(result, SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_varbinary_large() {
        // 255 bytes with values 0..=254, split into chunks of 100, 100 and 55
        // bytes. Byte i must equal i, which proves the chunk order is preserved.
        let chunk1: Vec<u8> = (0..100u8).collect();
        let chunk2: Vec<u8> = (100..200u8).collect();
        let chunk3: Vec<u8> = (200..255u8).collect();
        let total_len = chunk1.len() + chunk2.len() + chunk3.len();
        let plp = make_plp_data(total_len as u64, &[&chunk1, &chunk2, &chunk3]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varbinary(&mut buf).unwrap();
        match result {
            SqlValue::Binary(b) => {
                assert_eq!(b.len(), 255);
                for (i, &byte) in b.iter().enumerate() {
                    assert_eq!(byte, i as u8);
                }
            }
            _ => panic!("expected Binary"),
        }
    }

    /// Builds the raw bytes of a two-column row.
    ///
    /// 1. An NVARCHAR value: a 2-byte little-endian byte length, then the
    ///    UTF-16LE code units.
    /// 2. An INTN value: a 1-byte length of 4, then the `i32` in little-endian.
    fn make_nvarchar_int_row(nvarchar_value: &str, int_value: i32) -> Vec<u8> {
        let mut data = Vec::new();
        let utf16: Vec<u16> = nvarchar_value.encode_utf16().collect();
        let byte_len = (utf16.len() * 2) as u16;
        data.extend_from_slice(&byte_len.to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }
        // INTN: 1-byte length prefix (4), then the 4-byte value.
        data.push(4);
        data.extend_from_slice(&int_value.to_le_bytes());
        data
    }

    /// Builds a minimal `ColumnData` named "c" from the given type id, raw
    /// column type byte and optional max length.
    ///
    /// Precision, scale and collation are left unset; tests that need a scale
    /// assign `type_info.scale` themselves. Despite the name, the helper is also
    /// used for non-date types (see `hostile_truncated_n_types_are_error_not_panic`).
    #[cfg(feature = "chrono")]
    fn datetime_col(type_id: TypeId, col_type: u8, max_length: Option<u32>) -> ColumnData {
        ColumnData {
            name: "c".to_string(),
            type_id,
            col_type,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length,
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        }
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_smalldatetime_minutes_is_error_not_panic() {
        // DATETIMEN with a 4-byte (SMALLDATETIME-sized) payload: days = 0,
        // minutes = 0xFFFF. That is far more than one day's worth of minutes,
        // so decoding must fail cleanly.
        let data = [4u8, 0x00, 0x00, 0xFF, 0xFF];
        let col = datetime_col(TypeId::DateTimeN, 0x6F, Some(4));
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
        // Same payload for the fixed-length DATETIME4 type, which has no length prefix.
        let data = [0x00, 0x00, 0xFF, 0xFF];
        let col = datetime_col(TypeId::DateTime4, 0x3A, None);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_datetime_days_overflow_is_error_not_panic() {
        // DATETIMEN (length 8) with days = i32::MAX: adding that to the 1900
        // epoch overflows the date range and must be reported as an error.
        let mut data = vec![8u8];
        data.extend_from_slice(&i32::MAX.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        let col = datetime_col(TypeId::DateTimeN, 0x6F, Some(8));
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
        // Fixed-length DATETIME (no length prefix) with days = i32::MIN.
        let mut data = Vec::new();
        data.extend_from_slice(&i32::MIN.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        let col = datetime_col(TypeId::DateTime, 0x3D, None);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_datetime_time_300ths_is_error_not_panic() {
        // Valid day count, but a time component of u32::MAX 1/300-second ticks,
        // which is far more than one day and must be rejected.
        let mut data = vec![8u8];
        data.extend_from_slice(&0i32.to_le_bytes());
        data.extend_from_slice(&u32::MAX.to_le_bytes());
        let col = datetime_col(TypeId::DateTimeN, 0x6F, Some(8));
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    // Gated on `chrono` only because it reuses `datetime_col`.
    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_truncated_n_types_are_error_not_panic() {
        // Each case writes only the length-prefix byte and no payload, so the
        // declared value length overruns the buffer.
        for (type_id, col_type, len) in [
            (TypeId::IntN, 0x26u8, 8u8),
            (TypeId::FloatN, 0x6D, 4),
            (TypeId::BitN, 0x68, 1),
        ] {
            let data = [len];
            let col = datetime_col(type_id, col_type, Some(len as u32));
            let mut buf: &[u8] = &data;
            assert!(
                parse_column_value(&mut buf, &col).is_err(),
                "{type_id:?} must error on truncated payload"
            );
        }
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_short_datetime2_len_is_error_not_panic() {
        // DATETIME2 at scale 7 with a 1-byte payload: too short to hold both
        // the time and date parts.
        let data = [1u8, 0xAA];
        let mut col = datetime_col(TypeId::DateTime2, 0x2A, None);
        col.type_info.scale = Some(7);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
        // Same short payload for DATETIMEOFFSET.
        let data = [1u8, 0xAA];
        let mut col = datetime_col(TypeId::DateTimeOffset, 0x2B, None);
        col.type_info.scale = Some(7);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_variant_datetime_days_overflow_is_error_not_panic() {
        // SQL_VARIANT bytes:
        // - 4-byte LE length 11 (1 base-type byte + 1 property-count byte
        //   + 1 property byte + 8 value bytes);
        // - base type 0x6F with one property, the value length 8;
        // - a DATETIME value with days = i32::MAX, which must error.
        let mut data = Vec::new();
        data.extend_from_slice(&11u32.to_le_bytes());
        data.push(0x6F);
        data.push(0x01);
        data.push(0x08);
        data.extend_from_slice(&i32::MAX.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        let mut buf: &[u8] = &data;
        assert!(parse_sql_variant(&mut buf).is_err());
    }

    #[cfg(feature = "decimal")]
    #[test]
    fn hostile_variant_decimal_over_96bit_is_error_not_panic() {
        // A mantissa of 2^100 does not fit in 96 bits, so it cannot be
        // represented by the decimal type used here.
        let mantissa = 1u128 << 100;
        let mut data = Vec::new();
        // Length 21 = type (1) + prop count (1) + precision/scale (2)
        // + sign (1) + 16-byte mantissa.
        data.extend_from_slice(&21u32.to_le_bytes());
        // Base type 0x6A with 2 properties: precision 38, scale 10.
        data.push(0x6A);
        data.push(0x02);
        data.push(38);
        data.push(10);
        // Sign byte 0x01, then the little-endian mantissa.
        data.push(0x01);
        data.extend_from_slice(&mantissa.to_le_bytes());
        let mut buf: &[u8] = &data;
        let err = parse_sql_variant(&mut buf).expect_err("oversized NUMERIC must error");
        assert!(
            err.to_string().contains("rust_decimal"),
            "error should explain the range limitation: {err}"
        );
    }

    #[cfg(feature = "decimal")]
    #[test]
    fn variant_decimal_decodes_via_shared_decoder() {
        let mut data = Vec::new();
        // Length 7 = type (1) + prop count (1) + precision/scale (2)
        // + sign (1) + 2-byte mantissa.
        data.extend_from_slice(&7u32.to_le_bytes());
        // Base type 0x6A with 2 properties: precision 5, scale 2.
        data.push(0x6A);
        data.push(0x02);
        data.push(5);
        data.push(2);
        // Sign byte 0x01 with mantissa 12345, which at scale 2 is 123.45.
        data.push(0x01);
        data.extend_from_slice(&12345u16.to_le_bytes());
        let mut buf: &[u8] = &data;
        let value = parse_sql_variant(&mut buf).expect("valid NUMERIC must decode");
        assert_eq!(value, SqlValue::Decimal("123.45".parse().unwrap()));
    }

    #[test]
    fn variant_decimal_oversized_payload_is_null() {
        let mut data = Vec::new();
        // Length 22 = type (1) + prop count (1) + precision/scale (2)
        // + 18 value bytes.
        data.extend_from_slice(&22u32.to_le_bytes());
        // Base type 0x6A with 2 properties: precision 38, scale 0.
        data.push(0x6A);
        data.push(0x02);
        data.push(38);
        data.push(0);
        // Sign byte plus 17 zero bytes gives an 18-byte value. That exceeds the
        // 17-byte limit of the DECIMAL/NUMERIC variant arm, so the value is
        // skipped and reported as NULL.
        data.push(0x01);
        data.extend_from_slice(&[0u8; 17]);
        let mut buf: &[u8] = &data;
        let value = parse_sql_variant(&mut buf).expect("oversized payload must not error");
        assert_eq!(value, SqlValue::Null);
        assert!(buf.is_empty(), "the whole payload must be consumed");
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_time_intervals_do_not_panic() {
        // Only checks that an absurd interval count at scale 0 does not panic;
        // the resulting time value is deliberately not inspected.
        let t = intervals_to_time(u64::MAX, 0);
        let _ = t;
    }

    #[test]
    fn test_parse_row_nvarchar_then_int() {
        // Two columns parsed in sequence from one buffer. The second parse must
        // start exactly where the first stopped, and nothing may be left over.
        let raw_data = make_nvarchar_int_row("World", 42);
        let col0 = ColumnData {
            name: "greeting".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(10),
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        };
        let col1 = ColumnData {
            name: "number".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        };
        let mut buf: &[u8] = &raw_data;
        let value0 = parse_column_value(&mut buf, &col0).unwrap();
        match value0 {
            SqlValue::String(s) => assert_eq!(s, "World"),
            _ => panic!("expected String, got {value0:?}"),
        }
        let value1 = parse_column_value(&mut buf, &col1).unwrap();
        match value1 {
            SqlValue::Int(i) => assert_eq!(i, 42),
            _ => panic!("expected Int, got {value1:?}"),
        }
        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_multiple_types() {
        // Row layout: NVARCHAR NULL, INTN 123, NVARCHAR "Test", INTN NULL.
        let mut data = Vec::new();
        // NVARCHAR with byte length 0xFFFF: NULL.
        data.extend_from_slice(&0xFFFFu16.to_le_bytes());
        // INTN: length 4, value 123.
        data.push(4);
        data.extend_from_slice(&123i32.to_le_bytes());
        // NVARCHAR "Test" as a 2-byte LE byte length plus UTF-16LE code units.
        let utf16: Vec<u16> = "Test".encode_utf16().collect();
        data.extend_from_slice(&((utf16.len() * 2) as u16).to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }
        // INTN with length 0: NULL.
        data.push(0);
        let col0 = ColumnData {
            name: "col0".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        };
        let col1 = ColumnData {
            name: "col1".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        };
        let col2 = col0.clone();
        let col3 = col1.clone();
        let mut buf: &[u8] = &data;
        let v0 = parse_column_value(&mut buf, &col0).unwrap();
        assert!(matches!(v0, SqlValue::Null), "col0 should be Null");
        let v1 = parse_column_value(&mut buf, &col1).unwrap();
        assert!(matches!(v1, SqlValue::Int(123)), "col1 should be 123");
        let v2 = parse_column_value(&mut buf, &col2).unwrap();
        match v2 {
            SqlValue::String(s) => assert_eq!(s, "Test"),
            _ => panic!("col2 should be 'Test'"),
        }
        let v3 = parse_column_value(&mut buf, &col3).unwrap();
        assert!(matches!(v3, SqlValue::Null), "col3 should be Null");
        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_with_unicode() {
        // Non-ASCII text, including CJK characters, followed by an 8-byte INTN.
        // The INTN is expected to decode as BigInt.
        let test_str = "Héllo Wörld 日本語";
        let mut data = Vec::new();
        let utf16: Vec<u16> = test_str.encode_utf16().collect();
        data.extend_from_slice(&((utf16.len() * 2) as u16).to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }
        // INTN: length 8, then the i64 value.
        data.push(8);
        data.extend_from_slice(&9999999999i64.to_le_bytes());
        let col0 = ColumnData {
            name: "text".to_string(),
            type_id: TypeId::NVarChar,
            col_type: 0xE7,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        };
        let col1 = ColumnData {
            name: "num".to_string(),
            type_id: TypeId::IntN,
            col_type: 0x26,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length: Some(8),
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        };
        let mut buf: &[u8] = &data;
        let v0 = parse_column_value(&mut buf, &col0).unwrap();
        match v0 {
            SqlValue::String(s) => assert_eq!(s, test_str),
            _ => panic!("expected String"),
        }
        let v1 = parse_column_value(&mut buf, &col1).unwrap();
        match v1 {
            SqlValue::BigInt(i) => assert_eq!(i, 9999999999),
            _ => panic!("expected BigInt"),
        }
        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }
}