use bytes::BytesMut;
use tds_protocol::rpc::{RpcParam, TypeInfo as RpcTypeInfo};
#[cfg(feature = "decimal")]
use tds_protocol::tvp::encode_tvp_decimal;
use tds_protocol::tvp::{
TvpColumnDef as TvpWireColumnDef, TvpEncoder, TvpWireType, encode_tvp_bit, encode_tvp_float,
encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar, encode_tvp_varbinary, encode_tvp_varchar,
};
#[cfg(feature = "chrono")]
use tds_protocol::tvp::{encode_tvp_datetime, encode_tvp_smalldatetime};
#[cfg(feature = "decimal")]
use tds_protocol::tvp::{encode_tvp_money, encode_tvp_smallmoney};
use crate::error::{Error, Result};
use crate::state::ConnectionState;
use super::Client;
impl<S: ConnectionState> Client<S> {
pub(crate) fn sql_value_to_rpc_param(
name: &str,
sql_value: &mssql_types::SqlValue,
send_unicode: bool,
collation: Option<&tds_protocol::token::Collation>,
) -> Result<RpcParam> {
use bytes::{BufMut, BytesMut};
use mssql_types::SqlValue;
Ok(match sql_value {
SqlValue::Null => RpcParam::null(name, RpcTypeInfo::nvarchar(1)),
SqlValue::Bool(v) => {
let mut buf = BytesMut::with_capacity(1);
buf.put_u8(if *v { 1 } else { 0 });
RpcParam::new(name, RpcTypeInfo::bit(), buf.freeze())
}
SqlValue::TinyInt(v) => {
let mut buf = BytesMut::with_capacity(1);
buf.put_u8(*v);
RpcParam::new(name, RpcTypeInfo::tinyint(), buf.freeze())
}
SqlValue::SmallInt(v) => {
let mut buf = BytesMut::with_capacity(2);
buf.put_i16_le(*v);
RpcParam::new(name, RpcTypeInfo::smallint(), buf.freeze())
}
SqlValue::Int(v) => RpcParam::int(name, *v),
SqlValue::BigInt(v) => RpcParam::bigint(name, *v),
SqlValue::Float(v) => {
let mut buf = BytesMut::with_capacity(4);
buf.put_f32_le(*v);
RpcParam::new(name, RpcTypeInfo::real(), buf.freeze())
}
SqlValue::Double(v) => {
let mut buf = BytesMut::with_capacity(8);
buf.put_f64_le(*v);
RpcParam::new(name, RpcTypeInfo::float(), buf.freeze())
}
SqlValue::String(s) => {
if send_unicode {
RpcParam::nvarchar(name, s)
} else if let Some(c) = collation {
RpcParam::varchar_with_collation(name, s, c)
} else {
RpcParam::varchar(name, s)
}
}
SqlValue::Binary(b) => {
let type_info = if b.len() > 8000 {
RpcTypeInfo::varbinary_max()
} else {
RpcTypeInfo::varbinary(b.len().max(1) as u16)
};
RpcParam::new(name, type_info, b.clone())
}
SqlValue::Xml(s) => RpcParam::nvarchar(name, s),
#[cfg(feature = "uuid")]
SqlValue::Uuid(u) => {
let bytes = u.as_bytes();
let mut buf = BytesMut::with_capacity(16);
buf.put_u32_le(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
buf.put_slice(&bytes[8..16]);
RpcParam::new(name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
}
#[cfg(feature = "decimal")]
SqlValue::Decimal(d) => {
let mut buf = BytesMut::with_capacity(17);
mssql_types::encode::encode_decimal(*d, &mut buf);
let scale = d.scale() as u8;
RpcParam::new(name, RpcTypeInfo::decimal(38, scale), buf.freeze())
}
#[cfg(feature = "decimal")]
SqlValue::Money(d) => {
let mut buf = BytesMut::with_capacity(8);
mssql_types::encode::encode_money(*d, &mut buf)?;
RpcParam::new(name, RpcTypeInfo::money(), buf.freeze())
}
#[cfg(feature = "decimal")]
SqlValue::SmallMoney(d) => {
let mut buf = BytesMut::with_capacity(4);
mssql_types::encode::encode_smallmoney(*d, &mut buf)?;
RpcParam::new(name, RpcTypeInfo::smallmoney(), buf.freeze())
}
#[cfg(feature = "chrono")]
SqlValue::Date(d) => {
let mut buf = BytesMut::with_capacity(3);
mssql_types::encode::encode_date(*d, &mut buf);
RpcParam::new(name, RpcTypeInfo::date(), buf.freeze())
}
#[cfg(feature = "chrono")]
SqlValue::Time(t) => {
let mut buf = BytesMut::with_capacity(5);
mssql_types::encode::encode_time(*t, &mut buf);
RpcParam::new(name, RpcTypeInfo::time(7), buf.freeze())
}
#[cfg(feature = "chrono")]
SqlValue::DateTime(dt) => {
let mut buf = BytesMut::with_capacity(8);
mssql_types::encode::encode_datetime2(*dt, &mut buf);
RpcParam::new(name, RpcTypeInfo::datetime2(7), buf.freeze())
}
#[cfg(feature = "chrono")]
SqlValue::SmallDateTime(dt) => {
let mut buf = BytesMut::with_capacity(4);
mssql_types::encode::encode_smalldatetime(*dt, &mut buf)?;
RpcParam::new(name, RpcTypeInfo::smalldatetime(), buf.freeze())
}
#[cfg(feature = "chrono")]
SqlValue::DateTimeOffset(dto) => {
let mut buf = BytesMut::with_capacity(10);
mssql_types::encode::encode_datetimeoffset(*dto, &mut buf);
RpcParam::new(name, RpcTypeInfo::datetimeoffset(7), buf.freeze())
}
#[cfg(feature = "json")]
SqlValue::Json(j) => RpcParam::nvarchar(name, &j.to_string()),
SqlValue::Tvp(tvp_data) => Self::encode_tvp_param(name, tvp_data)?,
_ => {
return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
from: sql_value.type_name().to_string(),
to: "RPC parameter",
}));
}
})
}
/// Converts one `ToSql` value into an `RpcParam` with the given `name`.
///
/// The value is first turned into a `SqlValue` via `ToSql::to_sql`, then handed
/// to `sql_value_to_rpc_param` together with `send_unicode` and `collation`.
///
/// # Errors
///
/// Propagates any error from `to_sql` or from the `SqlValue` conversion.
pub(crate) fn convert_single_param(
    name: &str,
    value: &(dyn crate::ToSql + Sync),
    send_unicode: bool,
    collation: Option<&tds_protocol::token::Collation>,
) -> Result<RpcParam> {
    Self::sql_value_to_rpc_param(name, &value.to_sql()?, send_unicode, collation)
}
/// Converts a slice of values into RPC parameters named `@p1`, `@p2`, ... in
/// slice order.
///
/// # Errors
///
/// Stops at, and returns, the first conversion error.
pub(crate) fn convert_params(
    params: &[&(dyn crate::ToSql + Sync)],
    send_unicode: bool,
    collation: Option<&tds_protocol::token::Collation>,
) -> Result<Vec<RpcParam>> {
    let mut converted = Vec::with_capacity(params.len());
    // Parameter ordinals are 1-based.
    for (ordinal, param) in (1usize..).zip(params.iter()) {
        let name = format!("@p{}", ordinal);
        converted.push(Self::convert_single_param(
            &name,
            *param,
            send_unicode,
            collation,
        )?);
    }
    Ok(converted)
}
/// Converts a slice of values into RPC parameters that all carry an empty name,
/// preserving slice order.
///
/// # Errors
///
/// Stops at, and returns, the first conversion error.
pub(crate) fn convert_params_positional(
    params: &[&(dyn crate::ToSql + Sync)],
    send_unicode: bool,
    collation: Option<&tds_protocol::token::Collation>,
) -> Result<Vec<RpcParam>> {
    let mut converted = Vec::with_capacity(params.len());
    for param in params {
        // Every parameter is converted with the empty string as its name.
        converted.push(Self::convert_single_param(
            "",
            *param,
            send_unicode,
            collation,
        )?);
    }
    Ok(converted)
}
/// Converts named parameters into RPC parameters, making sure every name starts
/// with `@`.
///
/// A name that already begins with `@` is used unchanged; otherwise `@` is
/// prepended.
///
/// # Errors
///
/// Stops at, and returns, the first conversion error.
pub(crate) fn convert_named_params(
    params: &[crate::to_params::NamedParam],
    send_unicode: bool,
    collation: Option<&tds_protocol::token::Collation>,
) -> Result<Vec<RpcParam>> {
    use std::borrow::Cow;

    let mut converted = Vec::with_capacity(params.len());
    for param in params {
        // Borrow the name when it is already prefixed; allocate only to add `@`.
        let name: Cow<'_, str> = if param.name.starts_with('@') {
            Cow::Borrowed(param.name.as_str())
        } else {
            Cow::Owned(format!("@{}", param.name))
        };
        converted.push(Self::sql_value_to_rpc_param(
            &name,
            &param.value,
            send_unicode,
            collation,
        )?);
    }
    Ok(converted)
}
fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
let wire_columns: Vec<TvpWireColumnDef> = tvp_data
.columns
.iter()
.map(|col| {
let wire_type = Self::convert_tvp_column_type(&col.column_type)?;
Ok(if col.nullable {
TvpWireColumnDef::nullable(wire_type)
} else {
TvpWireColumnDef::new(wire_type)
})
})
.collect::<Result<Vec<_>>>()?;
let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
let mut buf = BytesMut::with_capacity(256);
encoder.encode_metadata(&mut buf);
for row in &tvp_data.rows {
encoder.encode_row(&mut buf, |row_buf| {
for (col_idx, value) in row.iter().enumerate() {
let wire_type = &wire_columns[col_idx].wire_type;
Self::encode_tvp_value(value, wire_type, row_buf);
}
});
}
encoder.encode_end(&mut buf);
let full_type_name = if tvp_data.schema.is_empty() {
tvp_data.type_name.clone()
} else {
format!("{}.{}", tvp_data.schema, tvp_data.type_name)
};
let type_info = RpcTypeInfo::tvp(&full_type_name);
Ok(RpcParam {
name: name.to_string(),
flags: tds_protocol::rpc::ParamFlags::default(),
type_info,
value: Some(buf.freeze()),
})
}
fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> Result<TvpWireType> {
#[allow(unreachable_patterns)]
Ok(match col_type {
mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
precision: *precision,
scale: *scale,
},
mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
max_length: *max_length,
},
mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
max_length: *max_length,
},
mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
max_length: *max_length,
},
mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
mssql_types::TvpColumnType::Date => TvpWireType::Date,
mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
mssql_types::TvpColumnType::DateTime2 { scale } => {
TvpWireType::DateTime2 { scale: *scale }
}
mssql_types::TvpColumnType::DateTimeOffset { scale } => {
TvpWireType::DateTimeOffset { scale: *scale }
}
mssql_types::TvpColumnType::Money => TvpWireType::Money,
mssql_types::TvpColumnType::SmallMoney => TvpWireType::SmallMoney,
mssql_types::TvpColumnType::DateTime => TvpWireType::DateTime,
mssql_types::TvpColumnType::SmallDateTime => TvpWireType::SmallDateTime,
mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
_ => {
return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
from: format!("{col_type:?}"),
to: "TVP wire type",
}));
}
})
}
/// Appends the TVP wire encoding of one cell `value` to `buf`.
///
/// `wire_type` is the wire type of the column the cell belongs to. It selects
/// between alternative encodings (nvarchar vs. varchar, money vs. decimal,
/// datetime vs. smalldatetime vs. datetime2) and supplies max lengths and
/// scales; when it does not match the value's family a fixed default is used
/// (4000 for strings, 8000 for binary, scale 7 for temporal values).
///
/// The function cannot fail: a value that cannot be represented (money or
/// smalldatetime conversion error, nested TVP, variant without an arm) is
/// written as NULL for the column.
/// NOTE(review): that fallback replaces an out-of-range value with NULL
/// without telling the caller — confirm this best-effort behavior is intended.
/// NOTE(review): integer and float widths are taken from the value's variant,
/// not from `wire_type` — confirm callers always supply matching variants.
fn encode_tvp_value(
    value: &mssql_types::SqlValue,
    wire_type: &TvpWireType,
    buf: &mut BytesMut,
) {
    use mssql_types::SqlValue;

    // Whole days between 0001-01-01 and `date`.
    // NOTE(review): the `as u32` cast wraps for dates before the base date.
    #[cfg(feature = "chrono")]
    fn days_since_base(date: chrono::NaiveDate) -> u32 {
        let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).expect("epoch 0001-01-01 is valid");
        date.signed_duration_since(base).num_days() as u32
    }

    // Number of 100-nanosecond intervals elapsed since midnight.
    #[cfg(feature = "chrono")]
    fn intervals_since_midnight(time: chrono::NaiveTime) -> u64 {
        use chrono::Timelike;
        let nanos =
            time.num_seconds_from_midnight() as u64 * 1_000_000_000 + time.nanosecond() as u64;
        nanos / 100
    }

    match value {
        SqlValue::Null => {
            encode_tvp_null(wire_type, buf);
        }
        SqlValue::Bool(flag) => {
            encode_tvp_bit(*flag, buf);
        }

        // Integers are widened to i64 and written with the variant's byte width.
        SqlValue::TinyInt(n) => {
            encode_tvp_int(*n as i64, 1, buf);
        }
        SqlValue::SmallInt(n) => {
            encode_tvp_int(*n as i64, 2, buf);
        }
        SqlValue::Int(n) => {
            encode_tvp_int(*n as i64, 4, buf);
        }
        SqlValue::BigInt(n) => {
            encode_tvp_int(*n, 8, buf);
        }

        // Floats are widened to f64 and written with the variant's byte width.
        SqlValue::Float(x) => {
            encode_tvp_float(*x as f64, 4, buf);
        }
        SqlValue::Double(x) => {
            encode_tvp_float(*x, 8, buf);
        }

        SqlValue::String(text) => {
            if let TvpWireType::VarChar { max_length } = wire_type {
                encode_tvp_varchar(text, *max_length, buf);
            } else {
                // nvarchar column, or any other column type: encode as nvarchar,
                // falling back to a max length of 4000.
                let limit = match wire_type {
                    TvpWireType::NVarChar { max_length } => *max_length,
                    _ => 4000,
                };
                encode_tvp_nvarchar(text, limit, buf);
            }
        }
        SqlValue::Binary(blob) => {
            let limit = if let TvpWireType::VarBinary { max_length } = wire_type {
                *max_length
            } else {
                8000
            };
            encode_tvp_varbinary(blob, limit, buf);
        }

        #[cfg(feature = "decimal")]
        SqlValue::Decimal(amount) | SqlValue::Money(amount) | SqlValue::SmallMoney(amount) => {
            // The column's wire type, not the value's variant, picks the encoding.
            match wire_type {
                TvpWireType::Money => {
                    match mssql_types::encode::decimal_to_money_cents_i64(*amount) {
                        Ok(raw) => {
                            encode_tvp_money(raw, buf);
                        }
                        Err(_) => {
                            encode_tvp_null(wire_type, buf);
                        }
                    }
                }
                TvpWireType::SmallMoney => {
                    match mssql_types::encode::decimal_to_smallmoney_cents_i32(*amount) {
                        Ok(raw) => {
                            encode_tvp_smallmoney(raw, buf);
                        }
                        Err(_) => {
                            encode_tvp_null(wire_type, buf);
                        }
                    }
                }
                _ => {
                    // Sign byte is 1 for non-negative values and 0 for negative ones.
                    // NOTE(review): the mantissa is sent at the value's own scale —
                    // confirm it does not need rescaling to the column's scale.
                    let sign = u8::from(!amount.is_sign_negative());
                    let magnitude = amount.mantissa().unsigned_abs();
                    encode_tvp_decimal(sign, magnitude, buf);
                }
            }
        }

        #[cfg(feature = "uuid")]
        SqlValue::Uuid(id) => {
            tds_protocol::tvp::encode_tvp_guid(id.as_bytes(), buf);
        }

        #[cfg(feature = "chrono")]
        SqlValue::Date(date) => {
            tds_protocol::tvp::encode_tvp_date(days_since_base(*date), buf);
        }
        #[cfg(feature = "chrono")]
        SqlValue::Time(time) => {
            let scale = if let TvpWireType::Time { scale } = wire_type {
                *scale
            } else {
                7
            };
            tds_protocol::tvp::encode_tvp_time(intervals_since_midnight(*time), scale, buf);
        }
        #[cfg(feature = "chrono")]
        SqlValue::DateTime(stamp) | SqlValue::SmallDateTime(stamp) => {
            // The column's wire type, not the value's variant, picks the encoding.
            match wire_type {
                TvpWireType::DateTime => {
                    let (days, ticks) =
                        mssql_types::encode::datetime_to_legacy_days_ticks(*stamp);
                    encode_tvp_datetime(days, ticks, buf);
                }
                TvpWireType::SmallDateTime => {
                    if let Ok((days, minutes)) =
                        mssql_types::encode::datetime_to_smalldatetime_days_minutes(*stamp)
                    {
                        encode_tvp_smalldatetime(days, minutes, buf);
                    } else {
                        encode_tvp_null(wire_type, buf);
                    }
                }
                _ => {
                    // Everything else is written in datetime2 form.
                    let scale = if let TvpWireType::DateTime2 { scale } = wire_type {
                        *scale
                    } else {
                        7
                    };
                    tds_protocol::tvp::encode_tvp_datetime2(
                        intervals_since_midnight(stamp.time()),
                        days_since_base(stamp.date()),
                        scale,
                        buf,
                    );
                }
            }
        }
        #[cfg(feature = "chrono")]
        SqlValue::DateTimeOffset(stamp) => {
            use chrono::Offset;
            // NOTE(review): `time()` and `date_naive()` are the value's local
            // (offset-applied) components — confirm `encode_tvp_datetimeoffset`
            // expects local rather than UTC components.
            let intervals = intervals_since_midnight(stamp.time());
            let days = days_since_base(stamp.date_naive());
            // Offset from UTC in whole minutes.
            let offset_minutes = (stamp.offset().fix().local_minus_utc() / 60) as i16;
            let scale = if let TvpWireType::DateTimeOffset { scale } = wire_type {
                *scale
            } else {
                7
            };
            tds_protocol::tvp::encode_tvp_datetimeoffset(
                intervals,
                days,
                offset_minutes,
                scale,
                buf,
            );
        }

        // JSON and XML are both written as nvarchar text with max length 0xFFFF.
        #[cfg(feature = "json")]
        SqlValue::Json(doc) => {
            let rendered = doc.to_string();
            encode_tvp_nvarchar(&rendered, 0xFFFF, buf);
        }
        SqlValue::Xml(text) => {
            encode_tvp_nvarchar(text, 0xFFFF, buf);
        }

        // A TVP nested inside a TVP cell is written as NULL.
        SqlValue::Tvp(_) => {
            encode_tvp_null(wire_type, buf);
        }
        // Any variant without an arm above is written as NULL.
        _ => {
            encode_tvp_null(wire_type, buf);
        }
    }
}
}