use bytes::BytesMut;
use tds_protocol::rpc::{RpcParam, TypeInfo as RpcTypeInfo};
#[cfg(feature = "decimal")]
use tds_protocol::tvp::encode_tvp_decimal;
use tds_protocol::tvp::{
TvpColumnDef as TvpWireColumnDef, TvpEncoder, TvpWireType, encode_tvp_bit, encode_tvp_float,
encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar, encode_tvp_varbinary,
};
use crate::error::{Error, Result};
use crate::state::ConnectionState;
use super::Client;
impl<S: ConnectionState> Client<S> {
pub(crate) fn convert_single_param(
name: &str,
value: &(dyn crate::ToSql + Sync),
) -> Result<RpcParam> {
use bytes::{BufMut, BytesMut};
use mssql_types::SqlValue;
let sql_value = value.to_sql()?;
Ok(match sql_value {
SqlValue::Null => RpcParam::null(name, RpcTypeInfo::nvarchar(1)),
SqlValue::Bool(v) => {
let mut buf = BytesMut::with_capacity(1);
buf.put_u8(if v { 1 } else { 0 });
RpcParam::new(name, RpcTypeInfo::bit(), buf.freeze())
}
SqlValue::TinyInt(v) => {
let mut buf = BytesMut::with_capacity(1);
buf.put_u8(v);
RpcParam::new(name, RpcTypeInfo::tinyint(), buf.freeze())
}
SqlValue::SmallInt(v) => {
let mut buf = BytesMut::with_capacity(2);
buf.put_i16_le(v);
RpcParam::new(name, RpcTypeInfo::smallint(), buf.freeze())
}
SqlValue::Int(v) => RpcParam::int(name, v),
SqlValue::BigInt(v) => RpcParam::bigint(name, v),
SqlValue::Float(v) => {
let mut buf = BytesMut::with_capacity(4);
buf.put_f32_le(v);
RpcParam::new(name, RpcTypeInfo::real(), buf.freeze())
}
SqlValue::Double(v) => {
let mut buf = BytesMut::with_capacity(8);
buf.put_f64_le(v);
RpcParam::new(name, RpcTypeInfo::float(), buf.freeze())
}
SqlValue::String(ref s) => RpcParam::nvarchar(name, s),
SqlValue::Binary(ref b) => {
RpcParam::new(name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
}
SqlValue::Xml(ref s) => RpcParam::nvarchar(name, s),
#[cfg(feature = "uuid")]
SqlValue::Uuid(u) => {
let bytes = u.as_bytes();
let mut buf = BytesMut::with_capacity(16);
buf.put_u32_le(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
buf.put_slice(&bytes[8..16]);
RpcParam::new(name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
}
#[cfg(feature = "decimal")]
SqlValue::Decimal(d) => RpcParam::nvarchar(name, &d.to_string()),
#[cfg(feature = "chrono")]
SqlValue::Date(_)
| SqlValue::Time(_)
| SqlValue::DateTime(_)
| SqlValue::DateTimeOffset(_) => {
let s = match &sql_value {
SqlValue::Date(d) => d.to_string(),
SqlValue::Time(t) => t.to_string(),
SqlValue::DateTime(dt) => dt.to_string(),
SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
_ => unreachable!(),
};
RpcParam::nvarchar(name, &s)
}
#[cfg(feature = "json")]
SqlValue::Json(ref j) => RpcParam::nvarchar(name, &j.to_string()),
SqlValue::Tvp(ref tvp_data) => Self::encode_tvp_param(name, tvp_data)?,
_ => {
return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
from: sql_value.type_name().to_string(),
to: "RPC parameter",
}));
}
})
}
pub(crate) fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
params
.iter()
.enumerate()
.map(|(i, p)| {
let name = format!("@p{}", i + 1);
Self::convert_single_param(&name, *p)
})
.collect()
}
fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
let wire_columns: Vec<TvpWireColumnDef> = tvp_data
.columns
.iter()
.map(|col| {
let wire_type = Self::convert_tvp_column_type(&col.column_type);
if col.nullable {
TvpWireColumnDef::nullable(wire_type)
} else {
TvpWireColumnDef::new(wire_type)
}
})
.collect();
let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
let mut buf = BytesMut::with_capacity(256);
encoder.encode_metadata(&mut buf);
for row in &tvp_data.rows {
encoder.encode_row(&mut buf, |row_buf| {
for (col_idx, value) in row.iter().enumerate() {
let wire_type = &wire_columns[col_idx].wire_type;
Self::encode_tvp_value(value, wire_type, row_buf);
}
});
}
encoder.encode_end(&mut buf);
let full_type_name = if tvp_data.schema.is_empty() {
tvp_data.type_name.clone()
} else {
format!("{}.{}", tvp_data.schema, tvp_data.type_name)
};
let type_info = RpcTypeInfo::tvp(&full_type_name);
Ok(RpcParam {
name: name.to_string(),
flags: tds_protocol::rpc::ParamFlags::default(),
type_info,
value: Some(buf.freeze()),
})
}
fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
#[allow(unreachable_patterns)]
match col_type {
mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
precision: *precision,
scale: *scale,
},
mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
max_length: *max_length,
},
mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
max_length: *max_length,
},
mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
max_length: *max_length,
},
mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
mssql_types::TvpColumnType::Date => TvpWireType::Date,
mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
mssql_types::TvpColumnType::DateTime2 { scale } => {
TvpWireType::DateTime2 { scale: *scale }
}
mssql_types::TvpColumnType::DateTimeOffset { scale } => {
TvpWireType::DateTimeOffset { scale: *scale }
}
mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
_ => unreachable!("unknown TvpColumnType variant"),
}
}
fn encode_tvp_value(
value: &mssql_types::SqlValue,
wire_type: &TvpWireType,
buf: &mut BytesMut,
) {
use mssql_types::SqlValue;
match value {
SqlValue::Null => {
encode_tvp_null(wire_type, buf);
}
SqlValue::Bool(v) => {
encode_tvp_bit(*v, buf);
}
SqlValue::TinyInt(v) => {
encode_tvp_int(*v as i64, 1, buf);
}
SqlValue::SmallInt(v) => {
encode_tvp_int(*v as i64, 2, buf);
}
SqlValue::Int(v) => {
encode_tvp_int(*v as i64, 4, buf);
}
SqlValue::BigInt(v) => {
encode_tvp_int(*v, 8, buf);
}
SqlValue::Float(v) => {
encode_tvp_float(*v as f64, 4, buf);
}
SqlValue::Double(v) => {
encode_tvp_float(*v, 8, buf);
}
SqlValue::String(s) => {
let max_len = match wire_type {
TvpWireType::NVarChar { max_length } => *max_length,
_ => 4000,
};
encode_tvp_nvarchar(s, max_len, buf);
}
SqlValue::Binary(b) => {
let max_len = match wire_type {
TvpWireType::VarBinary { max_length } => *max_length,
_ => 8000,
};
encode_tvp_varbinary(b, max_len, buf);
}
#[cfg(feature = "decimal")]
SqlValue::Decimal(d) => {
let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
let mantissa = d.mantissa().unsigned_abs();
encode_tvp_decimal(sign, mantissa, buf);
}
#[cfg(feature = "uuid")]
SqlValue::Uuid(u) => {
let bytes = u.as_bytes();
tds_protocol::tvp::encode_tvp_guid(bytes, buf);
}
#[cfg(feature = "chrono")]
SqlValue::Date(d) => {
let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
let days = d.signed_duration_since(base).num_days() as u32;
tds_protocol::tvp::encode_tvp_date(days, buf);
}
#[cfg(feature = "chrono")]
SqlValue::Time(t) => {
use chrono::Timelike;
let nanos =
t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
let intervals = nanos / 100;
let scale = match wire_type {
TvpWireType::Time { scale } => *scale,
_ => 7,
};
tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
}
#[cfg(feature = "chrono")]
SqlValue::DateTime(dt) => {
use chrono::Timelike;
let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
+ dt.time().nanosecond() as u64;
let intervals = nanos / 100;
let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
let days = dt.date().signed_duration_since(base).num_days() as u32;
let scale = match wire_type {
TvpWireType::DateTime2 { scale } => *scale,
_ => 7,
};
tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
}
#[cfg(feature = "chrono")]
SqlValue::DateTimeOffset(dto) => {
use chrono::{Offset, Timelike};
let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
+ dto.time().nanosecond() as u64;
let intervals = nanos / 100;
let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
let scale = match wire_type {
TvpWireType::DateTimeOffset { scale } => *scale,
_ => 7,
};
tds_protocol::tvp::encode_tvp_datetimeoffset(
intervals,
days,
offset_minutes,
scale,
buf,
);
}
#[cfg(feature = "json")]
SqlValue::Json(j) => {
encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
}
SqlValue::Xml(s) => {
encode_tvp_nvarchar(s, 0xFFFF, buf);
}
SqlValue::Tvp(_) => {
encode_tvp_null(wire_type, buf);
}
_ => {
encode_tvp_null(wire_type, buf);
}
}
}
}