#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
use bytes::Buf;
use mssql_types::SqlValue;
use tds_protocol::token::{ColMetaData, Collation, ColumnData, NbcRow, RawRow};
use tds_protocol::types::TypeId;
use crate::error::{Error, Result};
/// Converts a SMALLDATETIME wire value into a `NaiveDateTime`.
///
/// `days` is the day offset from 1900-01-01 and `minutes` the number of minutes since
/// midnight.
///
/// # Errors
///
/// Returns `Error::Protocol` when the day offset cannot be added to the epoch, or when the
/// minute count does not describe a valid time of day.
#[cfg(feature = "chrono")]
fn smalldatetime_from_wire(days: i64, minutes: u32) -> Result<chrono::NaiveDateTime> {
    let epoch = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).expect("epoch 1900-01-01 is valid");
    let date = match epoch.checked_add_signed(chrono::Duration::days(days)) {
        Some(date) => date,
        None => {
            return Err(Error::Protocol(format!(
                "SMALLDATETIME days out of range: {days}"
            )))
        }
    };
    // Widen before multiplying so a large minute count cannot overflow u32.
    let seconds_since_midnight = u32::try_from(u64::from(minutes) * 60).ok();
    let time = match seconds_since_midnight
        .and_then(|secs| chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, 0))
    {
        Some(time) => time,
        None => {
            return Err(Error::Protocol(format!(
                "SMALLDATETIME minutes out of range: {minutes}"
            )))
        }
    };
    Ok(chrono::NaiveDateTime::new(date, time))
}
/// Converts a DATETIME wire value into a `NaiveDateTime`.
///
/// `days` is the signed day offset from 1900-01-01 and `time_300ths` the number of
/// 1/300-second ticks since midnight.
///
/// Each tick is 10/3 ms. The tick count is rounded half-up to the nearest millisecond,
/// giving the familiar `.000` / `.003` / `.007` fractions SQL Server displays. The previous
/// truncation turned `.007` (2 ticks) into `.006`, and that would not round-trip back to
/// the same tick.
///
/// # Errors
///
/// Returns `Error::Protocol` when the day offset cannot be added to the epoch, or when the
/// tick count does not describe a valid time of day.
#[cfg(feature = "chrono")]
fn datetime_from_wire(days: i64, time_300ths: u64) -> Result<chrono::NaiveDateTime> {
    let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).expect("epoch 1900-01-01 is valid");
    let date = base
        .checked_add_signed(chrono::Duration::days(days))
        .ok_or_else(|| Error::Protocol(format!("DATETIME days out of range: {days}")))?;
    // Round to nearest millisecond: +150/300 adds half a millisecond before the division.
    // Per-second fractions never exceed 299 ticks (= 996.67 ms -> 997), so rounding cannot
    // carry into the next second. `time_300ths` comes from a u32, so `* 1000` cannot overflow.
    let total_ms = (time_300ths * 1000 + 150) / 300;
    let nanos = ((total_ms % 1000) * 1_000_000) as u32;
    let time = u32::try_from(total_ms / 1000)
        .ok()
        .and_then(|secs| chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos))
        .ok_or_else(|| {
            Error::Protocol(format!(
                "DATETIME time component out of range: {time_300ths}"
            ))
        })?;
    Ok(date.and_time(time))
}
/// Decodes every column of a plain ROW token into a `Row`.
///
/// Columns are parsed sequentially from `raw.data` in the order given by `meta.columns`.
/// Parsing stops at the first column that fails.
pub(crate) fn convert_raw_row(
    raw: &RawRow,
    meta: &ColMetaData,
    columns: &[crate::row::Column],
) -> Result<crate::row::Row> {
    let mut cursor: &[u8] = raw.data.as_ref();
    let values = meta
        .columns
        .iter()
        .map(|col| parse_column_value(&mut cursor, col))
        .collect::<Result<Vec<_>>>()?;
    Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
pub(crate) fn convert_nbc_row(
nbc: &NbcRow,
meta: &ColMetaData,
columns: &[crate::row::Column],
) -> Result<crate::row::Row> {
let mut values = Vec::with_capacity(meta.columns.len());
let mut buf = nbc.data.as_ref();
for (i, col) in meta.columns.iter().enumerate() {
if nbc.is_null(i) {
values.push(mssql_types::SqlValue::Null);
} else {
let value = parse_column_value(&mut buf, col)?;
values.push(value);
}
}
Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
/// Decodes a plain ROW token, decrypting columns the `decryptor` marks as encrypted.
///
/// Encrypted columns go through `decrypt_column`; all others are parsed directly. Parsing
/// stops at the first failure.
#[cfg(feature = "always-encrypted")]
pub(crate) fn convert_raw_row_decrypted(
    raw: &RawRow,
    meta: &ColMetaData,
    columns: &[crate::row::Column],
    decryptor: &crate::column_decryptor::ColumnDecryptor,
) -> Result<crate::row::Row> {
    let mut cursor: &[u8] = raw.data.as_ref();
    let values = meta
        .columns
        .iter()
        .enumerate()
        .map(|(ordinal, col)| {
            if decryptor.is_encrypted(ordinal) {
                decrypt_column(&mut cursor, col, decryptor, ordinal)
            } else {
                parse_column_value(&mut cursor, col)
            }
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
/// Decodes an NBCROW token, decrypting columns the `decryptor` marks as encrypted.
///
/// Bitmap-NULL columns become `SqlValue::Null` and consume no bytes. Encrypted non-NULL
/// columns go through `decrypt_column`, and the rest are parsed directly.
#[cfg(feature = "always-encrypted")]
pub(crate) fn convert_nbc_row_decrypted(
    nbc: &NbcRow,
    meta: &ColMetaData,
    columns: &[crate::row::Column],
    decryptor: &crate::column_decryptor::ColumnDecryptor,
) -> Result<crate::row::Row> {
    let mut cursor: &[u8] = nbc.data.as_ref();
    let values = meta
        .columns
        .iter()
        .enumerate()
        .map(|(ordinal, col)| {
            if nbc.is_null(ordinal) {
                Ok(SqlValue::Null)
            } else if decryptor.is_encrypted(ordinal) {
                decrypt_column(&mut cursor, col, decryptor, ordinal)
            } else {
                parse_column_value(&mut cursor, col)
            }
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(crate::row::Row::from_values(columns.to_vec(), values))
}
/// Reads one encrypted column value and decodes its plaintext.
///
/// The ciphertext is prefixed by a little-endian u16 length, where `0xFFFF` means NULL. The
/// ciphertext is handed to the `decryptor` for column `ordinal`. The returned plaintext is
/// then parsed with the base (unencrypted) column description that the decryptor returns.
///
/// # Errors
///
/// Returns `Error::Protocol` on a truncated length or payload. Errors from decryption or
/// from parsing the plaintext are propagated.
#[cfg(feature = "always-encrypted")]
fn decrypt_column(
    buf: &mut &[u8],
    _col: &ColumnData,
    decryptor: &crate::column_decryptor::ColumnDecryptor,
    ordinal: usize,
) -> Result<SqlValue> {
    if buf.remaining() < 2 {
        return Err(Error::Protocol(
            "unexpected EOF reading encrypted column length".to_string(),
        ));
    }
    let declared = buf.get_u16_le();
    if declared == 0xFFFF {
        return Ok(SqlValue::Null);
    }
    let needed = usize::from(declared);
    let available = buf.remaining();
    if available < needed {
        return Err(Error::Protocol(format!(
            "unexpected EOF reading encrypted column data: need {needed} bytes, have {available}"
        )));
    }
    let ciphertext = &buf[..needed];
    buf.advance(needed);

    let (plaintext, base_col) = decryptor.decrypt_column_value(ordinal, ciphertext)?;
    let mut plain_cursor: &[u8] = &plaintext;
    parse_column_value(&mut plain_cursor, base_col)
}
/// Decodes a MONEY (8-byte) or SMALLMONEY (4-byte) value from `buf`.
///
/// The value is a fixed-point integer in units of 1/10000. The 8-byte form is sent as the
/// signed high 32 bits followed by the unsigned low 32 bits. A length of 0 yields
/// `SqlValue::Null`. The caller must ensure `buf` holds `bytes` bytes.
///
/// With the `decimal` feature the result is a scale-4 `Decimal`; without it, or if the
/// decimal cannot be built, it is an `f64`.
///
/// # Errors
///
/// Returns `Error::Protocol` for any length other than 0, 4 or 8.
fn parse_money_value(buf: &mut &[u8], bytes: usize) -> Result<SqlValue> {
    let cents: i64 = match bytes {
        0 => return Ok(SqlValue::Null),
        4 => i64::from(buf.get_i32_le()),
        8 => {
            let high = i64::from(buf.get_i32_le());
            let low = i64::from(buf.get_u32_le());
            (high << 32) | low
        }
        _ => return Err(Error::Protocol(format!("invalid money length: {bytes}"))),
    };
    #[cfg(feature = "decimal")]
    {
        Ok(
            rust_decimal::Decimal::try_from_i128_with_scale(i128::from(cents), 4)
                .map_or_else(
                    |_| SqlValue::Double((cents as f64) / 10000.0),
                    SqlValue::Decimal,
                ),
        )
    }
    #[cfg(not(feature = "decimal"))]
    {
        Ok(SqlValue::Double((cents as f64) / 10000.0))
    }
}
/// Decodes one column value of a ROW/NBCROW payload from `buf`, advancing `buf` past the
/// bytes the value occupies.
///
/// The wire layout is chosen by `col.type_id`. Scale and collation come from
/// `col.type_info`. Length-prefixed types return `SqlValue::Null` for their NULL marker.
///
/// Every arm must consume exactly the bytes of its value. Leaving bytes unread (or reading
/// too many) would make every following column in the row decode garbage, so invalid
/// length prefixes are rejected rather than partially consumed.
///
/// # Errors
///
/// - `Error::Protocol` when `buf` ends before the value is complete, or when a length prefix
///   is invalid for the type.
/// - With the `decimal` feature, a converted `TypeError::InvalidDecimal` when a
///   DECIMAL/NUMERIC does not fit `rust_decimal`.
pub fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<SqlValue> {
    let value = match col.type_id {
        TypeId::Null => SqlValue::Null,
        TypeId::Int1 => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
            }
            SqlValue::TinyInt(buf.get_u8())
        }
        TypeId::Bit => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading BIT".into()));
            }
            SqlValue::Bool(buf.get_u8() != 0)
        }
        TypeId::Int2 => {
            if buf.remaining() < 2 {
                return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
            }
            SqlValue::SmallInt(buf.get_i16_le())
        }
        TypeId::Int4 => {
            if buf.remaining() < 4 {
                return Err(Error::Protocol("unexpected EOF reading INT".into()));
            }
            SqlValue::Int(buf.get_i32_le())
        }
        TypeId::Float4 => {
            if buf.remaining() < 4 {
                return Err(Error::Protocol("unexpected EOF reading REAL".into()));
            }
            SqlValue::Float(buf.get_f32_le())
        }
        TypeId::Int8 => {
            if buf.remaining() < 8 {
                return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
            }
            SqlValue::BigInt(buf.get_i64_le())
        }
        TypeId::Float8 => {
            if buf.remaining() < 8 {
                return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
            }
            SqlValue::Double(buf.get_f64_le())
        }
        TypeId::Money | TypeId::Money4 | TypeId::MoneyN => {
            // Fixed MONEY/SMALLMONEY have an implied width; MONEYN carries a 1-byte length.
            let bytes = match col.type_id {
                TypeId::Money => 8,
                TypeId::Money4 => 4,
                TypeId::MoneyN => {
                    if buf.remaining() < 1 {
                        return Err(Error::Protocol(
                            "unexpected EOF reading MoneyN length".into(),
                        ));
                    }
                    buf.get_u8() as usize
                }
                _ => unreachable!("inner match is bounded by outer Money|Money4|MoneyN arm"),
            };
            if buf.remaining() < bytes {
                return Err(Error::Protocol(format!(
                    "unexpected EOF reading money data ({bytes} bytes)"
                )));
            }
            parse_money_value(buf, bytes)?
        }
        TypeId::IntN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
            }
            let len = buf.get_u8();
            if buf.remaining() < len as usize {
                return Err(Error::Protocol("unexpected EOF reading IntN data".into()));
            }
            match len {
                0 => SqlValue::Null,
                1 => SqlValue::TinyInt(buf.get_u8()),
                2 => SqlValue::SmallInt(buf.get_i16_le()),
                4 => SqlValue::Int(buf.get_i32_le()),
                8 => SqlValue::BigInt(buf.get_i64_le()),
                _ => {
                    return Err(Error::Protocol(format!("invalid IntN length: {len}")));
                }
            }
        }
        TypeId::FloatN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading FloatN length".into(),
                ));
            }
            let len = buf.get_u8();
            if buf.remaining() < len as usize {
                return Err(Error::Protocol("unexpected EOF reading FloatN data".into()));
            }
            match len {
                0 => SqlValue::Null,
                4 => SqlValue::Float(buf.get_f32_le()),
                8 => SqlValue::Double(buf.get_f64_le()),
                _ => {
                    return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
                }
            }
        }
        TypeId::BitN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
            }
            let len = buf.get_u8();
            if buf.remaining() < len as usize {
                return Err(Error::Protocol("unexpected EOF reading BitN data".into()));
            }
            match len {
                0 => SqlValue::Null,
                1 => SqlValue::Bool(buf.get_u8() != 0),
                _ => {
                    return Err(Error::Protocol(format!("invalid BitN length: {len}")));
                }
            }
        }
        TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading DECIMAL/NUMERIC length".into(),
                ));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else {
                if buf.remaining() < len {
                    return Err(Error::Protocol(
                        "unexpected EOF reading DECIMAL/NUMERIC data".into(),
                    ));
                }
                // Layout: 1 sign byte (0 = negative) followed by a little-endian magnitude.
                let sign = buf.get_u8();
                let mantissa_len = len - 1;
                let mut mantissa_bytes = [0u8; 16];
                let kept = mantissa_len.min(16);
                buf.copy_to_slice(&mut mantissa_bytes[..kept]);
                // Bytes beyond 16 cannot fit a u128; skip them so the cursor stays aligned.
                if mantissa_len > 16 {
                    buf.advance(mantissa_len - 16);
                }
                let mantissa = u128::from_le_bytes(mantissa_bytes);
                let scale = col.type_info.scale.unwrap_or(0) as u32;
                #[cfg(feature = "decimal")]
                {
                    use rust_decimal::Decimal;
                    let decimal = i128::try_from(mantissa)
                        .ok()
                        .and_then(|m| Decimal::try_from_i128_with_scale(m, scale).ok());
                    match decimal {
                        Some(mut decimal) => {
                            if sign == 0 {
                                decimal.set_sign_negative(true);
                            }
                            SqlValue::Decimal(decimal)
                        }
                        None => {
                            return Err(mssql_types::TypeError::InvalidDecimal(format!(
                                "NUMERIC value (mantissa {mantissa}, scale {scale}) exceeds \
                                 rust_decimal's 96-bit/scale-28 range; CAST the column to a \
                                 narrower NUMERIC, FLOAT, or VARCHAR in the query"
                            ))
                            .into());
                        }
                    }
                }
                #[cfg(not(feature = "decimal"))]
                {
                    let divisor = 10f64.powi(scale as i32);
                    let value = (mantissa as f64) / divisor;
                    let value = if sign == 0 { -value } else { value };
                    SqlValue::Double(value)
                }
            }
        }
        TypeId::DateTimeN => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading DateTimeN length".into(),
                ));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if buf.remaining() < len {
                return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
            } else {
                match len {
                    // 4 bytes: SMALLDATETIME (u16 days since 1900, u16 minutes).
                    4 => {
                        let days = buf.get_u16_le() as i64;
                        let minutes = buf.get_u16_le() as u32;
                        #[cfg(feature = "chrono")]
                        {
                            SqlValue::DateTime(smalldatetime_from_wire(days, minutes)?)
                        }
                        #[cfg(not(feature = "chrono"))]
                        {
                            SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
                        }
                    }
                    // 8 bytes: DATETIME (i32 days since 1900, u32 1/300-second ticks).
                    8 => {
                        let days = buf.get_i32_le() as i64;
                        let time_300ths = buf.get_u32_le() as u64;
                        #[cfg(feature = "chrono")]
                        {
                            SqlValue::DateTime(datetime_from_wire(days, time_300ths)?)
                        }
                        #[cfg(not(feature = "chrono"))]
                        {
                            SqlValue::String(format!("DATETIME({days},{time_300ths})"))
                        }
                    }
                    _ => {
                        return Err(Error::Protocol(format!("invalid DateTimeN length: {len}")));
                    }
                }
            }
        }
        TypeId::DateTime => {
            if buf.remaining() < 8 {
                return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
            }
            let days = buf.get_i32_le() as i64;
            let time_300ths = buf.get_u32_le() as u64;
            #[cfg(feature = "chrono")]
            {
                SqlValue::DateTime(datetime_from_wire(days, time_300ths)?)
            }
            #[cfg(not(feature = "chrono"))]
            {
                SqlValue::String(format!("DATETIME({days},{time_300ths})"))
            }
        }
        TypeId::DateTime4 => {
            if buf.remaining() < 4 {
                return Err(Error::Protocol(
                    "unexpected EOF reading SMALLDATETIME".into(),
                ));
            }
            let days = buf.get_u16_le() as i64;
            let minutes = buf.get_u16_le() as u32;
            #[cfg(feature = "chrono")]
            {
                SqlValue::DateTime(smalldatetime_from_wire(days, minutes)?)
            }
            #[cfg(not(feature = "chrono"))]
            {
                SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
            }
        }
        TypeId::Date => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if len != 3 {
                return Err(Error::Protocol(format!("invalid DATE length: {len}")));
            } else if buf.remaining() < 3 {
                return Err(Error::Protocol("unexpected EOF reading DATE".into()));
            } else {
                // 3-byte little-endian day count since 0001-01-01.
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                #[cfg(feature = "chrono")]
                {
                    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
                        .expect("epoch 0001-01-01 is valid");
                    let date = base + chrono::Duration::days(days as i64);
                    SqlValue::Date(date)
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!("DATE({days})"))
                }
            }
        }
        TypeId::Time => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if len > 5 {
                // TIME occupies at most 5 bytes (scale 7). Longer prefixes used to be accepted,
                // and above 8 bytes the tail was never consumed, corrupting later columns.
                return Err(Error::Protocol(format!("invalid TIME length: {len}")));
            } else if buf.remaining() < len {
                return Err(Error::Protocol("unexpected EOF reading TIME".into()));
            } else {
                let mut time_bytes = [0u8; 8];
                buf.copy_to_slice(&mut time_bytes[..len]);
                let intervals = u64::from_le_bytes(time_bytes);
                #[cfg(feature = "chrono")]
                {
                    let scale = col.type_info.scale.unwrap_or(7);
                    let time = intervals_to_time(intervals, scale);
                    SqlValue::Time(time)
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!("TIME({intervals})"))
                }
            }
        }
        TypeId::DateTime2 => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading DATETIME2 length".into(),
                ));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if buf.remaining() < len {
                return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
            } else {
                // Layout: time (3-5 bytes, by scale) followed by a 3-byte day count.
                let scale = col.type_info.scale.unwrap_or(7);
                let time_len = time_bytes_for_scale(scale);
                if len < time_len + 3 {
                    return Err(Error::Protocol(format!(
                        "DATETIME2 length {len} too short for scale {scale}"
                    )));
                }
                if len > time_len + 3 {
                    // Surplus bytes would otherwise stay in `buf` and be read as the next column.
                    return Err(Error::Protocol(format!(
                        "DATETIME2 length {len} too long for scale {scale}"
                    )));
                }
                let mut time_bytes = [0u8; 8];
                buf.copy_to_slice(&mut time_bytes[..time_len]);
                let intervals = u64::from_le_bytes(time_bytes);
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                #[cfg(feature = "chrono")]
                {
                    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
                        .expect("epoch 0001-01-01 is valid");
                    let date = base + chrono::Duration::days(days as i64);
                    let time = intervals_to_time(intervals, scale);
                    SqlValue::DateTime(date.and_time(time))
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!("DATETIME2({days},{intervals})"))
                }
            }
        }
        TypeId::DateTimeOffset => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading DATETIMEOFFSET length".into(),
                ));
            }
            let len = buf.get_u8() as usize;
            if len == 0 {
                SqlValue::Null
            } else if buf.remaining() < len {
                return Err(Error::Protocol(
                    "unexpected EOF reading DATETIMEOFFSET".into(),
                ));
            } else {
                // Layout: time (3-5 bytes), 3-byte day count, then an i16 offset in minutes.
                let scale = col.type_info.scale.unwrap_or(7);
                let time_len = time_bytes_for_scale(scale);
                if len < time_len + 5 {
                    return Err(Error::Protocol(format!(
                        "DATETIMEOFFSET length {len} too short for scale {scale}"
                    )));
                }
                if len > time_len + 5 {
                    // Surplus bytes would otherwise stay in `buf` and be read as the next column.
                    return Err(Error::Protocol(format!(
                        "DATETIMEOFFSET length {len} too long for scale {scale}"
                    )));
                }
                let mut time_bytes = [0u8; 8];
                buf.copy_to_slice(&mut time_bytes[..time_len]);
                let intervals = u64::from_le_bytes(time_bytes);
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                let offset_minutes = buf.get_i16_le();
                #[cfg(feature = "chrono")]
                {
                    use chrono::TimeZone;
                    let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1)
                        .expect("epoch 0001-01-01 is valid");
                    let date = base + chrono::Duration::days(days as i64);
                    let time = intervals_to_time(intervals, scale);
                    // An out-of-range offset falls back to UTC rather than failing the row.
                    let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
                        .unwrap_or_else(|| {
                            chrono::FixedOffset::east_opt(0).expect("UTC offset 0 is valid")
                        });
                    // The wire date/time is interpreted as UTC and then shifted to `offset`.
                    let datetime = offset.from_utc_datetime(&date.and_time(time));
                    SqlValue::DateTimeOffset(datetime)
                }
                #[cfg(not(feature = "chrono"))]
                {
                    SqlValue::String(format!(
                        "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
                    ))
                }
            }
        }
        // NOTE(review): MS-TDS describes TEXT/NTEXT/IMAGE row data as a text pointer +
        // timestamp + 4-byte length rather than PLP chunks. Confirm that the token layer
        // delivers these in PLP form before relying on the TEXT/NTEXT/IMAGE arms.
        TypeId::Text => parse_plp_varchar(buf, col.type_info.collation.as_ref())?,
        TypeId::Char | TypeId::VarChar => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varchar length".into(),
                ));
            }
            let len = buf.get_u8();
            if len == 0xFF {
                SqlValue::Null
            } else if len == 0 {
                SqlValue::String(String::new())
            } else if buf.remaining() < len as usize {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varchar data".into(),
                ));
            } else {
                let data = &buf[..len as usize];
                let s = decode_varchar_string(data, col.type_info.collation.as_ref());
                buf.advance(len as usize);
                SqlValue::String(s)
            }
        }
        TypeId::BigVarChar | TypeId::BigChar => {
            // A max_length of 0xFFFF marks VARCHAR(MAX), which is sent as PLP chunks.
            if col.type_info.max_length == Some(0xFFFF) {
                parse_plp_varchar(buf, col.type_info.collation.as_ref())?
            } else {
                if buf.remaining() < 2 {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varchar length".into(),
                    ));
                }
                let len = buf.get_u16_le();
                if len == 0xFFFF {
                    SqlValue::Null
                } else if buf.remaining() < len as usize {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varchar data".into(),
                    ));
                } else {
                    let data = &buf[..len as usize];
                    let s = decode_varchar_string(data, col.type_info.collation.as_ref());
                    buf.advance(len as usize);
                    SqlValue::String(s)
                }
            }
        }
        TypeId::NText => parse_plp_nvarchar(buf)?,
        TypeId::NVarChar | TypeId::NChar => {
            // A max_length of 0xFFFF marks NVARCHAR(MAX), which is sent as PLP chunks.
            if col.type_info.max_length == Some(0xFFFF) {
                parse_plp_nvarchar(buf)?
            } else {
                if buf.remaining() < 2 {
                    return Err(Error::Protocol(
                        "unexpected EOF reading nvarchar length".into(),
                    ));
                }
                let len = buf.get_u16_le();
                if len == 0xFFFF {
                    SqlValue::Null
                } else if buf.remaining() < len as usize {
                    return Err(Error::Protocol(
                        "unexpected EOF reading nvarchar data".into(),
                    ));
                } else {
                    let data = &buf[..len as usize];
                    let utf16: Vec<u16> = data
                        .chunks_exact(2)
                        .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
                        .collect();
                    let s = String::from_utf16(&utf16)
                        .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
                    buf.advance(len as usize);
                    SqlValue::String(s)
                }
            }
        }
        TypeId::Image => parse_plp_varbinary(buf)?,
        TypeId::Binary | TypeId::VarBinary => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varbinary length".into(),
                ));
            }
            let len = buf.get_u8();
            if len == 0xFF {
                SqlValue::Null
            } else if len == 0 {
                SqlValue::Binary(bytes::Bytes::new())
            } else if buf.remaining() < len as usize {
                return Err(Error::Protocol(
                    "unexpected EOF reading legacy varbinary data".into(),
                ));
            } else {
                let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
                buf.advance(len as usize);
                SqlValue::Binary(data)
            }
        }
        TypeId::BigVarBinary | TypeId::BigBinary => {
            // A max_length of 0xFFFF marks VARBINARY(MAX), which is sent as PLP chunks.
            if col.type_info.max_length == Some(0xFFFF) {
                parse_plp_varbinary(buf)?
            } else {
                if buf.remaining() < 2 {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varbinary length".into(),
                    ));
                }
                let len = buf.get_u16_le();
                if len == 0xFFFF {
                    SqlValue::Null
                } else if buf.remaining() < len as usize {
                    return Err(Error::Protocol(
                        "unexpected EOF reading varbinary data".into(),
                    ));
                } else {
                    let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
                    buf.advance(len as usize);
                    SqlValue::Binary(data)
                }
            }
        }
        TypeId::Xml => {
            // XML is decoded as PLP UTF-16 text and re-tagged as SqlValue::Xml.
            match parse_plp_nvarchar(buf)? {
                SqlValue::Null => SqlValue::Null,
                SqlValue::String(s) => SqlValue::Xml(s),
                _ => {
                    return Err(Error::Protocol(
                        "unexpected value type when parsing XML".into(),
                    ));
                }
            }
        }
        TypeId::Guid => {
            if buf.remaining() < 1 {
                return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
            }
            let len = buf.get_u8();
            if len == 0 {
                SqlValue::Null
            } else if len != 16 {
                return Err(Error::Protocol(format!("invalid GUID length: {len}")));
            } else if buf.remaining() < 16 {
                return Err(Error::Protocol("unexpected EOF reading GUID".into()));
            } else {
                decode_guid_bytes(buf)
            }
        }
        TypeId::Variant => parse_sql_variant(buf)?,
        TypeId::Udt => parse_plp_varbinary(buf)?,
        _ => {
            // Fallback for unhandled types: a u16-length-prefixed blob returned as raw bytes.
            if buf.remaining() < 2 {
                return Err(Error::Protocol(format!(
                    "unexpected EOF reading {:?}",
                    col.type_id
                )));
            }
            let len = buf.get_u16_le();
            if len == 0xFFFF {
                SqlValue::Null
            } else if buf.remaining() < len as usize {
                return Err(Error::Protocol(format!(
                    "unexpected EOF reading {:?} data",
                    col.type_id
                )));
            } else {
                let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
                buf.advance(len as usize);
                SqlValue::Binary(data)
            }
        }
    };
    Ok(value)
}
pub(crate) fn parse_plp_nvarchar(buf: &mut &[u8]) -> Result<SqlValue> {
if buf.remaining() < 8 {
return Err(Error::Protocol(
"unexpected EOF reading PLP total length".into(),
));
}
let total_len = buf.get_u64_le();
if total_len == 0xFFFFFFFFFFFFFFFF {
return Ok(SqlValue::Null);
}
let mut all_data = Vec::new();
loop {
if buf.remaining() < 4 {
return Err(Error::Protocol(
"unexpected EOF reading PLP chunk length".into(),
));
}
let chunk_len = buf.get_u32_le() as usize;
if chunk_len == 0 {
break; }
if buf.remaining() < chunk_len {
return Err(Error::Protocol(
"unexpected EOF reading PLP chunk data".into(),
));
}
all_data.extend_from_slice(&buf[..chunk_len]);
buf.advance(chunk_len);
}
let utf16: Vec<u16> = all_data
.chunks_exact(2)
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
.collect();
let s = String::from_utf16(&utf16)
.map_err(|_| Error::Protocol("invalid UTF-16 in PLP nvarchar".into()))?;
Ok(SqlValue::String(s))
}
/// Decodes single-byte/code-page `varchar` bytes into a `String`.
///
/// With the `encoding` feature, the collation's code page (if it maps to an encoding) is
/// tried first. Any decoding error, a missing collation, or a missing feature falls back to
/// lossy UTF-8.
// `collation` is only read when the `encoding` feature is enabled.
#[allow(unused_variables)]
fn decode_varchar_string(data: &[u8], collation: Option<&Collation>) -> String {
    #[cfg(feature = "encoding")]
    if let Some(encoding) = collation.and_then(|coll| coll.encoding()) {
        let (text, _, had_errors) = encoding.decode(data);
        if !had_errors {
            return text.into_owned();
        }
    }
    String::from_utf8_lossy(data).into_owned()
}
/// Reads a PLP (partially length-prefixed) `varchar(max)`/text value from `buf`.
///
/// Same chunk framing as `parse_plp_nvarchar`: a u64 total length (all-ones = NULL), then
/// u32-length chunks ending with a zero-length chunk. The concatenated bytes are decoded
/// with `decode_varchar_string` using `collation`.
///
/// # Errors
///
/// Returns `Error::Protocol` on truncated input.
fn parse_plp_varchar(buf: &mut &[u8], collation: Option<&Collation>) -> Result<SqlValue> {
    if buf.remaining() < 8 {
        return Err(Error::Protocol(
            "unexpected EOF reading PLP total length".into(),
        ));
    }
    if buf.get_u64_le() == u64::MAX {
        return Ok(SqlValue::Null);
    }
    let mut payload: Vec<u8> = Vec::new();
    loop {
        if buf.remaining() < 4 {
            return Err(Error::Protocol(
                "unexpected EOF reading PLP chunk length".into(),
            ));
        }
        match buf.get_u32_le() as usize {
            0 => break,
            n if buf.remaining() < n => {
                return Err(Error::Protocol(
                    "unexpected EOF reading PLP chunk data".into(),
                ));
            }
            n => {
                payload.extend_from_slice(&buf[..n]);
                buf.advance(n);
            }
        }
    }
    Ok(SqlValue::String(decode_varchar_string(&payload, collation)))
}
/// Reads a PLP (partially length-prefixed) binary value from `buf`.
///
/// Same chunk framing as the PLP string readers: a u64 total length (all-ones = NULL), then
/// u32-length chunks ending with a zero-length chunk. The concatenated bytes are returned
/// as `SqlValue::Binary`.
///
/// # Errors
///
/// Returns `Error::Protocol` on truncated input.
pub(crate) fn parse_plp_varbinary(buf: &mut &[u8]) -> Result<SqlValue> {
    if buf.remaining() < 8 {
        return Err(Error::Protocol(
            "unexpected EOF reading PLP total length".into(),
        ));
    }
    if buf.get_u64_le() == u64::MAX {
        return Ok(SqlValue::Null);
    }
    let mut payload: Vec<u8> = Vec::new();
    loop {
        if buf.remaining() < 4 {
            return Err(Error::Protocol(
                "unexpected EOF reading PLP chunk length".into(),
            ));
        }
        match buf.get_u32_le() as usize {
            0 => break,
            n if buf.remaining() < n => {
                return Err(Error::Protocol(
                    "unexpected EOF reading PLP chunk data".into(),
                ));
            }
            n => {
                payload.extend_from_slice(&buf[..n]);
                buf.advance(n);
            }
        }
    }
    Ok(SqlValue::Binary(bytes::Bytes::from(payload)))
}
/// Decodes a `SQL_VARIANT` value from `buf`.
///
/// # Wire layout
///
/// - A u32 total length (0 = NULL).
/// - A base-type byte.
/// - A property-byte count.
/// - `prop_count` type-specific property bytes (precision/scale, collation, max length, ...).
/// - The value bytes.
///
/// # Behavior
///
/// Exactly `total_len` bytes after the length prefix are consumed from `buf`, whatever the
/// base type: the variant body is split off first and every arm reads from that bounded
/// slice. A short, unusual or unrecognised payload therefore cannot desynchronise the
/// columns that follow.
///
/// Values too short for their base type decode as `SqlValue::Null`. Unknown base types are
/// returned as raw `SqlValue::Binary`.
///
/// # Errors
///
/// - `Error::Protocol` on truncated input or invalid UTF-16.
/// - Errors from the DATETIME helpers.
/// - With the `decimal` feature, a converted `TypeError::InvalidDecimal` for out-of-range
///   NUMERIC values.
fn parse_sql_variant(buf: &mut &[u8]) -> Result<SqlValue> {
    if buf.remaining() < 4 {
        return Err(Error::Protocol(
            "unexpected EOF reading SQL_VARIANT length".into(),
        ));
    }
    let total_len = buf.get_u32_le() as usize;
    if total_len == 0 {
        return Ok(SqlValue::Null);
    }
    if buf.remaining() < total_len {
        return Err(Error::Protocol(
            "unexpected EOF reading SQL_VARIANT data".into(),
        ));
    }
    if total_len < 2 {
        return Err(Error::Protocol(
            "SQL_VARIANT too short for type info".into(),
        ));
    }
    // Split off exactly this variant's bytes and advance the caller's cursor past them up
    // front. From here on `buf` refers to the bounded variant body only.
    let whole: &[u8] = *buf;
    let (variant_body, rest) = whole.split_at(total_len);
    *buf = rest;
    let mut body_cursor: &[u8] = variant_body;
    let buf = &mut body_cursor;

    let base_type = buf.get_u8();
    let prop_count = buf.get_u8() as usize;
    if buf.remaining() < prop_count {
        return Err(Error::Protocol(
            "unexpected EOF reading SQL_VARIANT properties".into(),
        ));
    }
    let data_len = total_len.saturating_sub(2).saturating_sub(prop_count);
    match base_type {
        // INT1 (TINYINT)
        0x30 => {
            buf.advance(prop_count);
            if data_len < 1 {
                return Ok(SqlValue::Null);
            }
            let v = buf.get_u8();
            Ok(SqlValue::TinyInt(v))
        }
        // BIT
        0x32 => {
            buf.advance(prop_count);
            if data_len < 1 {
                return Ok(SqlValue::Null);
            }
            let v = buf.get_u8();
            Ok(SqlValue::Bool(v != 0))
        }
        // INT2 (SMALLINT)
        0x34 => {
            buf.advance(prop_count);
            if data_len < 2 {
                return Ok(SqlValue::Null);
            }
            let v = buf.get_i16_le();
            Ok(SqlValue::SmallInt(v))
        }
        // INT4 (INT)
        0x38 => {
            buf.advance(prop_count);
            if data_len < 4 {
                return Ok(SqlValue::Null);
            }
            let v = buf.get_i32_le();
            Ok(SqlValue::Int(v))
        }
        // INT8 (BIGINT)
        0x7F => {
            buf.advance(prop_count);
            if data_len < 8 {
                return Ok(SqlValue::Null);
            }
            let v = buf.get_i64_le();
            Ok(SqlValue::BigInt(v))
        }
        // FLT4 (REAL): fixed 4-byte float. Previously unhandled; it fell through to Binary.
        0x3B => {
            buf.advance(prop_count);
            if data_len < 4 {
                return Ok(SqlValue::Null);
            }
            Ok(SqlValue::Float(buf.get_f32_le()))
        }
        // FLT8 (FLOAT): fixed 8-byte float. Previously unhandled; it fell through to Binary.
        0x3E => {
            buf.advance(prop_count);
            if data_len < 8 {
                return Ok(SqlValue::Null);
            }
            Ok(SqlValue::Double(buf.get_f64_le()))
        }
        // MONEY (8 bytes) / MONEY4 (SMALLMONEY, 4 bytes). Previously fell through to Binary.
        0x3C | 0x7A => {
            buf.advance(prop_count);
            let width = if base_type == 0x3C { 8 } else { 4 };
            if data_len < width {
                return Ok(SqlValue::Null);
            }
            parse_money_value(buf, width)
        }
        // DATETIME: i32 days since 1900 + u32 1/300-second ticks. Previously fell through
        // to Binary.
        0x3D => {
            buf.advance(prop_count);
            if data_len < 8 {
                return Ok(SqlValue::Null);
            }
            let days = buf.get_i32_le() as i64;
            let time_300ths = buf.get_u32_le() as u64;
            #[cfg(feature = "chrono")]
            {
                Ok(SqlValue::DateTime(datetime_from_wire(days, time_300ths)?))
            }
            #[cfg(not(feature = "chrono"))]
            {
                Ok(SqlValue::String(format!("DATETIME({days},{time_300ths})")))
            }
        }
        // DATETIM4 (SMALLDATETIME): u16 days since 1900 + u16 minutes. Previously fell
        // through to Binary.
        0x3A => {
            buf.advance(prop_count);
            if data_len < 4 {
                return Ok(SqlValue::Null);
            }
            let days = buf.get_u16_le() as i64;
            let minutes = buf.get_u16_le() as u32;
            #[cfg(feature = "chrono")]
            {
                Ok(SqlValue::DateTime(smalldatetime_from_wire(days, minutes)?))
            }
            #[cfg(not(feature = "chrono"))]
            {
                Ok(SqlValue::String(format!("SMALLDATETIME({days},{minutes})")))
            }
        }
        // FLTN: the first property byte, when present, gives the float width.
        0x6D => {
            let float_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
            buf.advance(prop_count.saturating_sub(1));
            if float_len == 4 && data_len >= 4 {
                let v = buf.get_f32_le();
                Ok(SqlValue::Float(v))
            } else if data_len >= 8 {
                let v = buf.get_f64_le();
                Ok(SqlValue::Double(v))
            } else {
                Ok(SqlValue::Null)
            }
        }
        // MONEYN: the first property byte, when present, gives the money width.
        0x6E => {
            let money_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
            buf.advance(prop_count.saturating_sub(1));
            if money_len == 0 || data_len == 0 {
                Ok(SqlValue::Null)
            } else if (money_len == 4 && data_len >= 4) || (money_len == 8 && data_len >= 8) {
                parse_money_value(buf, money_len as usize)
            } else {
                buf.advance(data_len);
                Ok(SqlValue::Null)
            }
        }
        // DATETIMN: the first property byte, when present, gives the width (4 or 8).
        0x6F => {
            #[cfg(feature = "chrono")]
            let dt_len = if prop_count >= 1 { buf.get_u8() } else { 8 };
            #[cfg(not(feature = "chrono"))]
            if prop_count >= 1 {
                buf.get_u8();
            }
            buf.advance(prop_count.saturating_sub(1));
            #[cfg(feature = "chrono")]
            {
                if dt_len == 4 && data_len >= 4 {
                    let days = buf.get_u16_le() as i64;
                    let mins = buf.get_u16_le() as u32;
                    Ok(SqlValue::DateTime(smalldatetime_from_wire(days, mins)?))
                } else if data_len >= 8 {
                    let days = buf.get_i32_le() as i64;
                    let ticks = buf.get_u32_le() as u64;
                    Ok(SqlValue::DateTime(datetime_from_wire(days, ticks)?))
                } else {
                    Ok(SqlValue::Null)
                }
            }
            #[cfg(not(feature = "chrono"))]
            {
                buf.advance(data_len);
                Ok(SqlValue::Null)
            }
        }
        // DECIMALN / NUMERICN: properties are precision then scale. The data is a sign
        // byte plus a little-endian magnitude.
        0x6A | 0x6C => {
            let _precision = if prop_count >= 1 { buf.get_u8() } else { 18 };
            let scale = if prop_count >= 2 { buf.get_u8() } else { 0 };
            buf.advance(prop_count.saturating_sub(2));
            if data_len < 1 {
                return Ok(SqlValue::Null);
            }
            let sign = buf.get_u8();
            let mantissa_len = data_len - 1;
            if mantissa_len > 16 {
                buf.advance(mantissa_len);
                return Ok(SqlValue::Null);
            }
            let mut mantissa_bytes = [0u8; 16];
            buf.copy_to_slice(&mut mantissa_bytes[..mantissa_len]);
            let mantissa = u128::from_le_bytes(mantissa_bytes);
            #[cfg(feature = "decimal")]
            {
                use rust_decimal::Decimal;
                let decimal = i128::try_from(mantissa)
                    .ok()
                    .and_then(|m| Decimal::try_from_i128_with_scale(m, scale as u32).ok());
                match decimal {
                    Some(mut decimal) => {
                        if sign == 0 {
                            decimal.set_sign_negative(true);
                        }
                        Ok(SqlValue::Decimal(decimal))
                    }
                    None => Err(mssql_types::TypeError::InvalidDecimal(format!(
                        "NUMERIC value in sql_variant (mantissa {mantissa}, scale {scale}) \
                         exceeds rust_decimal's 96-bit/scale-28 range; CAST the column to a \
                         narrower NUMERIC, FLOAT, or VARCHAR in the query"
                    ))
                    .into()),
                }
            }
            #[cfg(not(feature = "decimal"))]
            {
                let divisor = 10f64.powi(scale as i32);
                let value = (mantissa as f64) / divisor;
                let value = if sign == 0 { -value } else { value };
                Ok(SqlValue::Double(value))
            }
        }
        // GUID
        0x24 => {
            buf.advance(prop_count);
            if data_len < 16 {
                return Ok(SqlValue::Null);
            }
            Ok(decode_guid_bytes(buf))
        }
        // DATE: 3-byte day count since 0001-01-01.
        0x28 => {
            buf.advance(prop_count);
            #[cfg(feature = "chrono")]
            {
                if data_len < 3 {
                    return Ok(SqlValue::Null);
                }
                let mut date_bytes = [0u8; 4];
                date_bytes[0] = buf.get_u8();
                date_bytes[1] = buf.get_u8();
                date_bytes[2] = buf.get_u8();
                let days = u32::from_le_bytes(date_bytes);
                let base =
                    chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("epoch 0001-01-01 is valid");
                let date = base + chrono::Duration::days(days as i64);
                Ok(SqlValue::Date(date))
            }
            #[cfg(not(feature = "chrono"))]
            {
                buf.advance(data_len);
                Ok(SqlValue::Null)
            }
        }
        // TIME: property byte is the fractional-seconds scale.
        0x29 => {
            #[cfg_attr(not(feature = "chrono"), allow(unused_variables))]
            let scale = if prop_count >= 1 { buf.get_u8() } else { 7 };
            buf.advance(prop_count.saturating_sub(1));
            #[cfg(feature = "chrono")]
            {
                if data_len == 0 {
                    return Ok(SqlValue::Null);
                }
                let time_len = time_bytes_for_scale(scale);
                if data_len < time_len {
                    return Ok(SqlValue::Null);
                }
                let mut time_bytes = [0u8; 8];
                buf.copy_to_slice(&mut time_bytes[..time_len]);
                if data_len > time_len {
                    buf.advance(data_len - time_len);
                }
                let intervals = u64::from_le_bytes(time_bytes);
                Ok(SqlValue::Time(intervals_to_time(intervals, scale)))
            }
            #[cfg(not(feature = "chrono"))]
            {
                buf.advance(data_len);
                Ok(SqlValue::Null)
            }
        }
        // DATETIME2: property byte is the scale; data is time then a 3-byte day count.
        0x2A => {
            #[cfg_attr(not(feature = "chrono"), allow(unused_variables))]
            let scale = if prop_count >= 1 { buf.get_u8() } else { 7 };
            buf.advance(prop_count.saturating_sub(1));
            #[cfg(feature = "chrono")]
            {
                let time_len = time_bytes_for_scale(scale);
                if data_len < time_len + 3 {
                    return Ok(SqlValue::Null);
                }
                let mut time_bytes = [0u8; 8];
                buf.copy_to_slice(&mut time_bytes[..time_len]);
                let intervals = u64::from_le_bytes(time_bytes);
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                let consumed = time_len + 3;
                if data_len > consumed {
                    buf.advance(data_len - consumed);
                }
                let base =
                    chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("epoch 0001-01-01 is valid");
                let date = base + chrono::Duration::days(days as i64);
                let time = intervals_to_time(intervals, scale);
                Ok(SqlValue::DateTime(date.and_time(time)))
            }
            #[cfg(not(feature = "chrono"))]
            {
                buf.advance(data_len);
                Ok(SqlValue::Null)
            }
        }
        // DATETIMEOFFSET: as DATETIME2 plus an i16 offset in minutes. The wire date/time is
        // interpreted as UTC.
        0x2B => {
            #[cfg_attr(not(feature = "chrono"), allow(unused_variables))]
            let scale = if prop_count >= 1 { buf.get_u8() } else { 7 };
            buf.advance(prop_count.saturating_sub(1));
            #[cfg(feature = "chrono")]
            {
                let time_len = time_bytes_for_scale(scale);
                if data_len < time_len + 3 + 2 {
                    return Ok(SqlValue::Null);
                }
                let mut time_bytes = [0u8; 8];
                buf.copy_to_slice(&mut time_bytes[..time_len]);
                let intervals = u64::from_le_bytes(time_bytes);
                let days = buf.get_u8() as u32
                    | ((buf.get_u8() as u32) << 8)
                    | ((buf.get_u8() as u32) << 16);
                let offset_minutes = buf.get_i16_le();
                let consumed = time_len + 3 + 2;
                if data_len > consumed {
                    buf.advance(data_len - consumed);
                }
                use chrono::TimeZone;
                let base =
                    chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("epoch 0001-01-01 is valid");
                let date = base + chrono::Duration::days(days as i64);
                let time = intervals_to_time(intervals, scale);
                let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
                    .unwrap_or_else(|| {
                        chrono::FixedOffset::east_opt(0).expect("UTC offset 0 is valid")
                    });
                let datetime = offset.from_utc_datetime(&date.and_time(time));
                Ok(SqlValue::DateTimeOffset(datetime))
            }
            #[cfg(not(feature = "chrono"))]
            {
                buf.advance(data_len);
                Ok(SqlValue::Null)
            }
        }
        // Single-byte char types. The first 5 property bytes are the collation
        // (u32 LCID + sort id); any remaining property bytes are skipped.
        0xA7 | 0x2F | 0x27 => {
            let collation = if prop_count >= 5 && buf.remaining() >= 5 {
                let lcid = buf.get_u32_le();
                let sort_id = buf.get_u8();
                buf.advance(prop_count.saturating_sub(5));
                Some(Collation { lcid, sort_id })
            } else {
                buf.advance(prop_count);
                None
            };
            if data_len == 0 {
                return Ok(SqlValue::String(String::new()));
            }
            let data = &buf[..data_len];
            let s = decode_varchar_string(data, collation.as_ref());
            buf.advance(data_len);
            Ok(SqlValue::String(s))
        }
        // NVARCHAR / NCHAR: UTF-16LE text.
        0xE7 | 0xEF => {
            buf.advance(prop_count);
            if data_len == 0 {
                return Ok(SqlValue::String(String::new()));
            }
            let utf16: Vec<u16> = buf[..data_len]
                .chunks_exact(2)
                .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
                .collect();
            buf.advance(data_len);
            let s = String::from_utf16(&utf16)
                .map_err(|_| Error::Protocol("invalid UTF-16 in SQL_VARIANT nvarchar".into()))?;
            Ok(SqlValue::String(s))
        }
        // Binary types.
        0xA5 | 0x2D | 0x25 => {
            buf.advance(prop_count);
            let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
            buf.advance(data_len);
            Ok(SqlValue::Binary(data))
        }
        // Unknown base type: surface the raw value bytes.
        _ => {
            buf.advance(prop_count);
            let data = bytes::Bytes::copy_from_slice(&buf[..data_len]);
            buf.advance(data_len);
            Ok(SqlValue::Binary(data))
        }
    }
}
/// Returns how many bytes the time-of-day part occupies on the wire for a given
/// fractional-seconds `scale`.
///
/// - Scales 0-2 use 3 bytes.
/// - Scales 3-4 use 4 bytes.
/// - Scales 5-7 use 5 bytes.
/// - Out-of-range scales are treated like scale 7.
fn time_bytes_for_scale(scale: u8) -> usize {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
/// Converts a count of `10^-scale`-second intervals since midnight into a `NaiveTime`.
///
/// Scales above 7 are treated as 7 (100 ns units). The interval count is scaled to
/// nanoseconds with saturating arithmetic.
///
/// Values that do not form a valid time of day fall back to midnight. This includes second
/// counts that do not fit `u32`; previously those were truncated by an `as u32` cast and
/// could wrap into a plausible but wrong time (e.g. 2^32 + 5 seconds became 00:00:05).
#[cfg(feature = "chrono")]
fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
    // One interval at scale `s` is 10^(9 - s) nanoseconds (s = 7 -> 100 ns, s = 0 -> 1 s).
    let nanos_per_interval = 10u64.pow(9 - u32::from(scale.min(7)));
    let nanos = intervals.saturating_mul(nanos_per_interval);
    let nano_part = (nanos % 1_000_000_000) as u32;
    u32::try_from(nanos / 1_000_000_000)
        .ok()
        .and_then(|secs| chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part))
        .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).expect("midnight is valid"))
}
/// Reads a 16-byte `uniqueidentifier` from `buf` and reorders it for
/// `uuid::Uuid::from_bytes`.
///
/// The first eight wire bytes are byte-swapped within their 4/2/2-byte groups,
/// following SQL Server's mixed-endian GUID layout. The remaining eight bytes
/// are copied unchanged.
///
/// Returns `SqlValue::Uuid` when the `uuid` feature is enabled. Otherwise it
/// returns the reordered bytes as `SqlValue::Binary`.
///
/// # Panics
///
/// Panics (via `Buf::get_u8`) if fewer than 16 bytes remain in `buf`.
fn decode_guid_bytes(buf: &mut &[u8]) -> SqlValue {
    // Destination slot for each of the first eight wire bytes, in read order.
    const HEAD_SLOTS: [usize; 8] = [3, 2, 1, 0, 5, 4, 7, 6];

    let mut guid = [0u8; 16];
    for &slot in &HEAD_SLOTS {
        guid[slot] = buf.get_u8();
    }
    guid.iter_mut()
        .skip(8)
        .for_each(|tail_byte| *tail_byte = buf.get_u8());

    #[cfg(feature = "uuid")]
    {
        SqlValue::Uuid(uuid::Uuid::from_bytes(guid))
    }
    #[cfg(not(feature = "uuid"))]
    {
        SqlValue::Binary(bytes::Bytes::copy_from_slice(&guid))
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use tds_protocol::token::TypeInfo;

    /// Encodes a PLP stream:
    /// - an 8-byte little-endian total length,
    /// - each chunk as a 4-byte length followed by its bytes,
    /// - a zero-length terminator chunk.
    fn make_plp_data(total_len: u64, chunks: &[&[u8]]) -> Vec<u8> {
        let mut data = Vec::new();
        data.extend_from_slice(&total_len.to_le_bytes());
        for chunk in chunks {
            let len = chunk.len() as u32;
            data.extend_from_slice(&len.to_le_bytes());
            data.extend_from_slice(chunk);
        }
        data.extend_from_slice(&0u32.to_le_bytes());
        data
    }

    /// Builds nullable (`flags = 0x01`) column metadata.
    ///
    /// The result has no precision, scale, collation, or crypto metadata.
    /// Tests tweak individual fields afterwards when they need them.
    fn make_col(
        name: &str,
        type_id: TypeId,
        col_type: u8,
        max_length: Option<u32>,
    ) -> ColumnData {
        ColumnData {
            name: name.to_string(),
            type_id,
            col_type,
            flags: 0x01,
            user_type: 0,
            type_info: TypeInfo {
                max_length,
                precision: None,
                scale: None,
                collation: None,
            },
            crypto_metadata: None,
        }
    }

    /// Appends `s` as a 2-byte little-endian byte length followed by its
    /// UTF-16LE code units.
    fn push_nvarchar(data: &mut Vec<u8>, s: &str) {
        let utf16: Vec<u16> = s.encode_utf16().collect();
        data.extend_from_slice(&((utf16.len() * 2) as u16).to_le_bytes());
        for code_unit in utf16 {
            data.extend_from_slice(&code_unit.to_le_bytes());
        }
    }

    /// Row bytes for an NVARCHAR value followed by a 4-byte IntN value.
    fn make_nvarchar_int_row(nvarchar_value: &str, int_value: i32) -> Vec<u8> {
        let mut data = Vec::new();
        push_nvarchar(&mut data, nvarchar_value);
        data.push(4); // IntN length byte
        data.extend_from_slice(&int_value.to_le_bytes());
        data
    }

    #[test]
    fn test_parse_plp_nvarchar_simple() {
        // "Hello" in UTF-16LE.
        let utf16_data = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
        let plp = make_plp_data(10, &[&utf16_data]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, "Hello"),
            _ => panic!("expected String, got {result:?}"),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_null() {
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        assert!(matches!(result, SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_nvarchar_empty() {
        let plp = make_plp_data(0, &[]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, ""),
            _ => panic!("expected empty String"),
        }
    }

    #[test]
    fn test_parse_plp_nvarchar_multi_chunk() {
        let chunk1 = [0x48, 0x00, 0x65, 0x00, 0x6C, 0x00]; // "Hel"
        let chunk2 = [0x6C, 0x00, 0x6F, 0x00]; // "lo"
        let plp = make_plp_data(10, &[&chunk1, &chunk2]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_nvarchar(&mut buf).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, "Hello"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_parse_plp_varchar_simple() {
        let data = b"Hello World";
        let plp = make_plp_data(11, &[data]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varchar(&mut buf, None).unwrap();
        match result {
            SqlValue::String(s) => assert_eq!(s, "Hello World"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_parse_plp_varchar_null() {
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varchar(&mut buf, None).unwrap();
        assert!(matches!(result, SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_varbinary_simple() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05];
        let plp = make_plp_data(5, &[&data]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varbinary(&mut buf).unwrap();
        match result {
            SqlValue::Binary(b) => assert_eq!(&b[..], &[0x01, 0x02, 0x03, 0x04, 0x05]),
            _ => panic!("expected Binary"),
        }
    }

    #[test]
    fn test_parse_plp_varbinary_null() {
        let plp = 0xFFFFFFFFFFFFFFFFu64.to_le_bytes();
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varbinary(&mut buf).unwrap();
        assert!(matches!(result, SqlValue::Null));
    }

    #[test]
    fn test_parse_plp_varbinary_large() {
        let chunk1: Vec<u8> = (0..100u8).collect();
        let chunk2: Vec<u8> = (100..200u8).collect();
        let chunk3: Vec<u8> = (200..255u8).collect();
        let total_len = chunk1.len() + chunk2.len() + chunk3.len();
        let plp = make_plp_data(total_len as u64, &[&chunk1, &chunk2, &chunk3]);
        let mut buf: &[u8] = &plp;
        let result = parse_plp_varbinary(&mut buf).unwrap();
        match result {
            SqlValue::Binary(b) => {
                assert_eq!(b.len(), 255);
                for (i, &byte) in b.iter().enumerate() {
                    assert_eq!(byte, i as u8);
                }
            }
            _ => panic!("expected Binary"),
        }
    }

    #[test]
    fn test_time_bytes_for_scale() {
        for (scale, expected) in [(0u8, 3usize), (2, 3), (3, 4), (4, 4), (5, 5), (7, 5), (8, 5)] {
            assert_eq!(time_bytes_for_scale(scale), expected, "scale {scale}");
        }
    }

    #[cfg(feature = "uuid")]
    #[test]
    fn test_decode_guid_bytes_reorders_mixed_endian_groups() {
        let wire: Vec<u8> = (0..16u8).collect();
        let mut buf: &[u8] = &wire;
        let value = decode_guid_bytes(&mut buf);
        let expected =
            uuid::Uuid::from_bytes([3, 2, 1, 0, 5, 4, 7, 6, 8, 9, 10, 11, 12, 13, 14, 15]);
        match value {
            SqlValue::Uuid(u) => assert_eq!(u, expected),
            _ => panic!("expected Uuid, got {value:?}"),
        }
        assert!(buf.is_empty(), "all 16 GUID bytes should be consumed");
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_smalldatetime_minutes_is_error_not_panic() {
        let data = [4u8, 0x00, 0x00, 0xFF, 0xFF];
        let col = make_col("c", TypeId::DateTimeN, 0x6F, Some(4));
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());

        let data = [0x00, 0x00, 0xFF, 0xFF];
        let col = make_col("c", TypeId::DateTime4, 0x3A, None);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_datetime_days_overflow_is_error_not_panic() {
        let mut data = vec![8u8];
        data.extend_from_slice(&i32::MAX.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        let col = make_col("c", TypeId::DateTimeN, 0x6F, Some(8));
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());

        let mut data = Vec::new();
        data.extend_from_slice(&i32::MIN.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        let col = make_col("c", TypeId::DateTime, 0x3D, None);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_datetime_time_300ths_is_error_not_panic() {
        let mut data = vec![8u8];
        data.extend_from_slice(&0i32.to_le_bytes());
        data.extend_from_slice(&u32::MAX.to_le_bytes());
        let col = make_col("c", TypeId::DateTimeN, 0x6F, Some(8));
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    #[test]
    fn hostile_truncated_n_types_are_error_not_panic() {
        // A length byte with no payload behind it.
        for (type_id, col_type, len) in [
            (TypeId::IntN, 0x26u8, 8u8),
            (TypeId::FloatN, 0x6D, 4),
            (TypeId::BitN, 0x68, 1),
        ] {
            let data = [len];
            let col = make_col("c", type_id, col_type, Some(u32::from(len)));
            let mut buf: &[u8] = &data;
            assert!(
                parse_column_value(&mut buf, &col).is_err(),
                "{type_id:?} must error on truncated payload"
            );
        }
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_short_datetime2_len_is_error_not_panic() {
        let data = [1u8, 0xAA];
        let mut col = make_col("c", TypeId::DateTime2, 0x2A, None);
        col.type_info.scale = Some(7);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());

        let data = [1u8, 0xAA];
        let mut col = make_col("c", TypeId::DateTimeOffset, 0x2B, None);
        col.type_info.scale = Some(7);
        let mut buf: &[u8] = &data;
        assert!(parse_column_value(&mut buf, &col).is_err());
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_variant_datetime_days_overflow_is_error_not_panic() {
        let mut data = Vec::new();
        data.extend_from_slice(&11u32.to_le_bytes());
        data.push(0x6F);
        data.push(0x01);
        data.push(0x08);
        data.extend_from_slice(&i32::MAX.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        let mut buf: &[u8] = &data;
        assert!(parse_sql_variant(&mut buf).is_err());
    }

    #[cfg(feature = "decimal")]
    #[test]
    fn hostile_variant_decimal_over_96bit_is_error_not_panic() {
        let mantissa = 1u128 << 100;
        let mut data = Vec::new();
        data.extend_from_slice(&21u32.to_le_bytes());
        data.push(0x6A); // base type: NUMERIC
        data.push(0x02); // property byte count
        data.push(38); // precision
        data.push(10); // scale
        data.push(0x01); // sign
        data.extend_from_slice(&mantissa.to_le_bytes());
        let mut buf: &[u8] = &data;
        let err = parse_sql_variant(&mut buf).expect_err("oversized NUMERIC must error");
        assert!(
            err.to_string().contains("rust_decimal"),
            "error should explain the range limitation: {err}"
        );
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn hostile_time_intervals_do_not_panic() {
        // A saturated interval count cannot be a time of day; the decoder
        // must fall back to midnight rather than panic.
        let t = intervals_to_time(u64::MAX, 0);
        assert_eq!(t, chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap());
    }

    #[test]
    fn test_parse_row_nvarchar_then_int() {
        let raw_data = make_nvarchar_int_row("World", 42);
        // 5 UTF-16 code units = 10 bytes.
        let col0 = make_col("greeting", TypeId::NVarChar, 0xE7, Some(10));
        let col1 = make_col("number", TypeId::IntN, 0x26, Some(4));

        let mut buf: &[u8] = &raw_data;
        let value0 = parse_column_value(&mut buf, &col0).unwrap();
        match value0 {
            SqlValue::String(s) => assert_eq!(s, "World"),
            _ => panic!("expected String, got {value0:?}"),
        }
        let value1 = parse_column_value(&mut buf, &col1).unwrap();
        match value1 {
            SqlValue::Int(i) => assert_eq!(i, 42),
            _ => panic!("expected Int, got {value1:?}"),
        }
        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_multiple_types() {
        let mut data = Vec::new();
        data.extend_from_slice(&0xFFFFu16.to_le_bytes()); // col0: NULL NVARCHAR
        data.push(4); // col1: IntN, 4 bytes
        data.extend_from_slice(&123i32.to_le_bytes());
        push_nvarchar(&mut data, "Test"); // col2
        data.push(0); // col3: zero-length IntN, i.e. NULL

        let col0 = make_col("col0", TypeId::NVarChar, 0xE7, Some(100));
        let col1 = make_col("col1", TypeId::IntN, 0x26, Some(4));
        let col2 = col0.clone();
        let col3 = col1.clone();

        let mut buf: &[u8] = &data;
        let v0 = parse_column_value(&mut buf, &col0).unwrap();
        assert!(matches!(v0, SqlValue::Null), "col0 should be Null");
        let v1 = parse_column_value(&mut buf, &col1).unwrap();
        assert!(matches!(v1, SqlValue::Int(123)), "col1 should be 123");
        let v2 = parse_column_value(&mut buf, &col2).unwrap();
        match v2 {
            SqlValue::String(s) => assert_eq!(s, "Test"),
            _ => panic!("col2 should be 'Test'"),
        }
        let v3 = parse_column_value(&mut buf, &col3).unwrap();
        assert!(matches!(v3, SqlValue::Null), "col3 should be Null");
        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }

    #[test]
    fn test_parse_row_with_unicode() {
        let test_str = "Héllo Wörld 日本語";
        let mut data = Vec::new();
        push_nvarchar(&mut data, test_str);
        data.push(8); // IntN, 8 bytes
        data.extend_from_slice(&9999999999i64.to_le_bytes());

        let col0 = make_col("text", TypeId::NVarChar, 0xE7, Some(100));
        let col1 = make_col("num", TypeId::IntN, 0x26, Some(8));

        let mut buf: &[u8] = &data;
        let v0 = parse_column_value(&mut buf, &col0).unwrap();
        match v0 {
            SqlValue::String(s) => assert_eq!(s, test_str),
            _ => panic!("expected String"),
        }
        let v1 = parse_column_value(&mut buf, &col1).unwrap();
        match v1 {
            SqlValue::BigInt(i) => assert_eq!(i, 9999999999),
            _ => panic!("expected BigInt"),
        }
        assert_eq!(buf.len(), 0, "buffer should be fully consumed");
    }
}