use std::fmt::{self, Write};
use sqlx_core::arguments::Arguments;
use sqlx_core::encode::{Encode, IsNull};
use sqlx_core::error::BoxDynError;
use sqlx_core::types::Type;
use crate::{Mssql, MssqlType, MssqlTypeInfo};
// TDS TYPE_INFO tags for the (nullable) wire types this module emits.
const DATA_TYPE_INTN: u8 = 0x26;
const DATA_TYPE_BITN: u8 = 0x68;
const DATA_TYPE_FLOATN: u8 = 0x6d;
const DATA_TYPE_BIGVARBINARY: u8 = 0xa5;
const DATA_TYPE_BIGVARCHAR: u8 = 0xa7;
const DATA_TYPE_NVARCHAR: u8 = 0xe7;
/// Five collation bytes written after the size of every `varchar`/`nvarchar` TYPE_INFO.
/// NOTE(review): read as a TDS COLLATION (20-bit LE LCID first) these bytes give LCID
/// 0x00481; an LCID of 1033 would start with 0x09 — confirm the leading 0x81 is intended.
const DEFAULT_COLLATION: [u8; 5] = [0x81, 0x04, 0xd0, 0x00, 0x34];
/// PLP total-length value that marks the parameter as NULL (all bits set).
const PLP_NULL: u64 = u64::MAX;
/// Maximum number of payload bytes written per PLP chunk.
const PLP_CHUNK_SIZE: usize = 8 * 1024;
/// RPC parameter status flag used for by-reference (output) parameters.
const STATUS_BY_REF_VALUE: u8 = 0x01;
/// Parameters bound to one SQL Server query, kept both as the RPC parameter payload and
/// as the matching `@pN <type>` declaration list.
#[derive(Clone, Debug, Default)]
pub struct MssqlArguments {
    /// Number of parameters bound so far; the most recent one is named `@p{len}`.
    len: usize,
    /// Concatenated encoded RPC parameters, in bind order.
    data: Vec<u8>,
    /// Comma-separated `@pN <sql type>` declarations, in bind order.
    declarations: String,
}
impl MssqlArguments {
pub const fn is_empty(&self) -> bool {
self.len == 0
}
pub(crate) fn data(&self) -> &[u8] {
&self.data
}
pub(crate) fn declarations(&self) -> &str {
&self.declarations
}
fn add_parameter(
&mut self,
type_info: MssqlTypeInfo,
encoded: Vec<u8>,
is_null: bool,
) -> Result<(), BoxDynError> {
self.len += 1;
let name = format!("@p{}", self.len);
if !self.declarations.is_empty() {
self.declarations.push(',');
}
write!(
self.declarations,
"{name} {}",
declaration(&type_info, encoded.len(), is_null)?
)?;
write_parameter(&mut self.data, &name, &type_info, &encoded, is_null)?;
Ok(())
}
}
impl Arguments for MssqlArguments {
    type Database = Mssql;

    /// No-op: each value is encoded immediately when it is added.
    fn reserve(&mut self, _additional: usize, _size: usize) {}

    /// Encodes `value` and appends it as the next `@pN` parameter.
    fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
    where
        T: Encode<'t, Self::Database> + Type<Self::Database>,
    {
        // The type the value reports for itself wins over the static `T::type_info()`.
        let type_info = match value.produces() {
            Some(produced) => produced,
            None => T::type_info(),
        };
        let mut buf = Vec::with_capacity(value.size_hint());
        let outcome = value.encode(&mut buf)?;
        self.add_parameter(type_info, buf, matches!(outcome, IsNull::Yes))
    }

    fn len(&self) -> usize {
        self.len
    }

    /// Writes the placeholder of the most recently added parameter, `@p{len}`.
    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
        let index = self.len;
        write!(writer, "@p{index}")
    }
}
/// Writes one RPC parameter with a zero status byte (no `STATUS_BY_REF_VALUE` flag).
pub(crate) fn write_parameter(
    out: &mut Vec<u8>,
    name: &str,
    type_info: &MssqlTypeInfo,
    encoded: &[u8],
    is_null: bool,
) -> Result<(), BoxDynError> {
    const NO_STATUS_FLAGS: u8 = 0;
    write_parameter_with_status(out, name, NO_STATUS_FLAGS, type_info, encoded, is_null)
}
/// Writes a non-NULL `int` parameter flagged with `STATUS_BY_REF_VALUE`, carrying
/// `value` (little-endian) as its initial value.
pub(crate) fn write_output_i32_parameter(
    out: &mut Vec<u8>,
    name: &str,
    value: i32,
) -> Result<(), BoxDynError> {
    let initial = value.to_le_bytes();
    write_parameter_with_status(
        out,
        name,
        STATUS_BY_REF_VALUE,
        &MssqlTypeInfo::INT,
        &initial,
        false,
    )
}
/// Serialises one RPC parameter in order: the B_VARCHAR name, the status byte, the
/// TYPE_INFO, then the length-prefixed value.
fn write_parameter_with_status(
    out: &mut Vec<u8>,
    name: &str,
    status: u8,
    type_info: &MssqlTypeInfo,
    encoded: &[u8],
    is_null: bool,
) -> Result<(), BoxDynError> {
    let encoded_len = encoded.len();
    write_b_varchar(out, name)?;
    out.push(status);
    write_type_info(out, type_info, encoded_len, is_null)?;
    write_param_len_data(out, type_info, encoded, is_null)
}
/// Writes a non-NULL `nvarchar` parameter whose payload is `value` encoded as UTF-16LE.
pub(crate) fn write_nvarchar_parameter(
    out: &mut Vec<u8>,
    name: &str,
    value: &str,
) -> Result<(), BoxDynError> {
    let utf16le: Vec<u8> = value.encode_utf16().flat_map(u16::to_le_bytes).collect();
    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &utf16le, false)
}
/// Writes an `nvarchar` parameter whose value is NULL.
pub(crate) fn write_null_nvarchar_parameter(
    out: &mut Vec<u8>,
    name: &str,
) -> Result<(), BoxDynError> {
    let no_payload: &[u8] = &[];
    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, no_payload, true)
}
/// Maps a type to a static SQL type name. Variable-length types always map to their
/// `(max)` form here, independent of any value length.
///
/// # Errors
///
/// Returns an error for types this module cannot send as arguments.
pub(crate) fn type_declaration(type_info: &MssqlTypeInfo) -> Result<&'static str, BoxDynError> {
    let sql_type = match type_info.kind() {
        MssqlType::Bit => "bit",
        MssqlType::TinyInt => "tinyint",
        MssqlType::SmallInt => "smallint",
        MssqlType::Int => "int",
        MssqlType::BigInt => "bigint",
        MssqlType::Real => "real",
        MssqlType::Float => "float",
        MssqlType::NVarChar => "nvarchar(max)",
        MssqlType::VarChar => "varchar(max)",
        MssqlType::VarBinary => "varbinary(max)",
        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
    };
    Ok(sql_type)
}
/// Writes the TYPE_INFO of one RPC parameter.
///
/// Fixed-width types are written as a nullable type tag plus a one-byte width.
/// Variable-length types are written as their tag plus a little-endian `u16` declared
/// size, and text types also get `DEFAULT_COLLATION`.
///
/// # Errors
///
/// Returns an error for unsupported types, or when the declared size cannot be computed.
fn write_type_info(
    out: &mut Vec<u8>,
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
) -> Result<(), BoxDynError> {
    match type_info.kind() {
        MssqlType::Bit => out.extend_from_slice(&[DATA_TYPE_BITN, 1]),
        MssqlType::TinyInt => out.extend_from_slice(&[DATA_TYPE_INTN, 1]),
        MssqlType::SmallInt => out.extend_from_slice(&[DATA_TYPE_INTN, 2]),
        MssqlType::Int => out.extend_from_slice(&[DATA_TYPE_INTN, 4]),
        MssqlType::BigInt => out.extend_from_slice(&[DATA_TYPE_INTN, 8]),
        MssqlType::Real => out.extend_from_slice(&[DATA_TYPE_FLOATN, 4]),
        MssqlType::Float => out.extend_from_slice(&[DATA_TYPE_FLOATN, 8]),
        MssqlType::NVarChar => {
            // The tag is written before the size is computed, matching the original byte
            // order even when sizing fails.
            out.push(DATA_TYPE_NVARCHAR);
            let max_len = nvarchar_type_size(type_info, encoded_len, is_null)?;
            out.extend_from_slice(&max_len.to_le_bytes());
            out.extend_from_slice(&DEFAULT_COLLATION);
        }
        MssqlType::VarChar => {
            out.push(DATA_TYPE_BIGVARCHAR);
            let max_len = bounded_short_len(type_info, encoded_len, is_null)?;
            out.extend_from_slice(&max_len.to_le_bytes());
            out.extend_from_slice(&DEFAULT_COLLATION);
        }
        MssqlType::VarBinary => {
            out.push(DATA_TYPE_BIGVARBINARY);
            let max_len = bounded_short_len(type_info, encoded_len, is_null)?;
            out.extend_from_slice(&max_len.to_le_bytes());
        }
        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
    }
    Ok(())
}
/// Writes the length prefix and payload of one RPC parameter value, i.e. the part that
/// follows the TYPE_INFO written by `write_type_info`.
///
/// * Fixed-width types: one length byte (`0` for NULL) followed by the value bytes.
/// * `nvarchar` / `varchar` / `varbinary`: the declared size is recomputed with the same
///   helpers `write_type_info` uses, so the value encoding always agrees with the TYPE_INFO.
///   A declared size of `0xFFFF` (`max`) means a PLP stream. Any other size means a
///   two-byte length (`0xFFFF` for NULL) followed by the bytes.
///
/// # Errors
///
/// Returns an error for unsupported types, or when a fixed-width value is longer than
/// 255 bytes.
fn write_param_len_data(
    out: &mut Vec<u8>,
    type_info: &MssqlTypeInfo,
    encoded: &[u8],
    is_null: bool,
) -> Result<(), BoxDynError> {
    match type_info.kind() {
        MssqlType::Bit
        | MssqlType::TinyInt
        | MssqlType::SmallInt
        | MssqlType::Int
        | MssqlType::BigInt
        | MssqlType::Real
        | MssqlType::Float => {
            if is_null {
                out.push(0);
            } else {
                out.push(u8::try_from(encoded.len())?);
                out.extend_from_slice(encoded);
            }
        }
        MssqlType::NVarChar => {
            let declared = nvarchar_type_size(type_info, encoded.len(), is_null)?;
            write_variable_len_data(out, declared, encoded, is_null)?;
        }
        MssqlType::VarChar | MssqlType::VarBinary => {
            let declared = bounded_short_len(type_info, encoded.len(), is_null)?;
            write_variable_len_data(out, declared, encoded, is_null)?;
        }
        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
    }
    Ok(())
}

/// Writes a variable-length value for a TYPE_INFO that declared `declared_size` bytes.
/// It emits a PLP stream when the declared size is `u16::MAX` (`max`), and otherwise a
/// USHORT length followed by the bytes, where `0xFFFF` marks NULL.
fn write_variable_len_data(
    out: &mut Vec<u8>,
    declared_size: u16,
    encoded: &[u8],
    is_null: bool,
) -> Result<(), BoxDynError> {
    if declared_size == u16::MAX {
        return write_plp_value(out, encoded, is_null);
    }
    if is_null {
        out.extend_from_slice(&u16::MAX.to_le_bytes());
    } else {
        out.extend_from_slice(&u16::try_from(encoded.len())?.to_le_bytes());
        out.extend_from_slice(encoded);
    }
    Ok(())
}

/// Builds the SQL declaration (e.g. `int`, `nvarchar(2)`, `varbinary(max)`) for a
/// parameter, consistent with the declared size written to its TYPE_INFO.
///
/// # Errors
///
/// Returns an error for types this module cannot send as arguments.
fn declaration(
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
) -> Result<String, BoxDynError> {
    Ok(match type_info.kind() {
        MssqlType::Bit => "bit".to_owned(),
        MssqlType::TinyInt => "tinyint".to_owned(),
        MssqlType::SmallInt => "smallint".to_owned(),
        MssqlType::Int => "int".to_owned(),
        MssqlType::BigInt => "bigint".to_owned(),
        MssqlType::Real => "real".to_owned(),
        MssqlType::Float => "float".to_owned(),
        MssqlType::NVarChar => nvarchar_declaration(type_info, encoded_len, is_null)?,
        MssqlType::VarChar => varchar_declaration(type_info, encoded_len, is_null)?,
        MssqlType::VarBinary => varbinary_declaration(type_info, encoded_len, is_null)?,
        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
    })
}

/// `nvarchar(n)` declaration. The declared size is in bytes, so `n` is half of it; the
/// `max` size maps to `nvarchar(max)`.
fn nvarchar_declaration(
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
) -> Result<String, BoxDynError> {
    let size = nvarchar_type_size(type_info, encoded_len, is_null)?;
    if size == u16::MAX {
        Ok("nvarchar(max)".to_owned())
    } else {
        Ok(format!("nvarchar({})", size / 2))
    }
}

/// `varchar(n)` declaration, or `varchar(max)` for the `max` size.
fn varchar_declaration(
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
) -> Result<String, BoxDynError> {
    let size = bounded_short_len(type_info, encoded_len, is_null)?;
    if size == u16::MAX {
        Ok("varchar(max)".to_owned())
    } else {
        Ok(format!("varchar({size})"))
    }
}

/// `varbinary(n)` declaration, or `varbinary(max)` for the `max` size.
fn varbinary_declaration(
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
) -> Result<String, BoxDynError> {
    let size = bounded_short_len(type_info, encoded_len, is_null)?;
    if size == u16::MAX {
        Ok("varbinary(max)".to_owned())
    } else {
        Ok(format!("varbinary({size})"))
    }
}

/// Largest byte size SQL Server accepts for a non-`max` `nvarchar(n)` (n <= 4000
/// characters), `varchar(n)` or `varbinary(n)` (n <= 8000 bytes). Anything larger must be
/// declared `(max)` and sent as a PLP stream.
const MAX_NON_PLP_LEN: usize = 8000;

/// Declared byte size of an `nvarchar` parameter: at least one UTF-16 code unit (2 bytes)
/// even for NULL or empty values, and `u16::MAX` (`max`) once it exceeds 8000 bytes.
fn nvarchar_type_size(
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
) -> Result<u16, BoxDynError> {
    declared_len(type_info, encoded_len, is_null, 2)
}

/// Declared byte size of a `varchar`/`varbinary` parameter: at least 1 byte even for NULL
/// or empty values, and `u16::MAX` (`max`) once it exceeds 8000 bytes.
fn bounded_short_len(
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
) -> Result<u16, BoxDynError> {
    declared_len(type_info, encoded_len, is_null, 1)
}

/// Shared sizing rule. An explicit size on the type info wins, otherwise the value length
/// (at least `min_len`) is used. Sizes that cannot be declared as `(n)` are promoted to
/// `u16::MAX`, which also makes values longer than `u16::MAX` bytes sendable via PLP.
fn declared_len(
    type_info: &MssqlTypeInfo,
    encoded_len: usize,
    is_null: bool,
    min_len: usize,
) -> Result<u16, BoxDynError> {
    let len = match type_info.size() {
        Some(size) => usize::from(size),
        None if is_null => min_len,
        None => encoded_len.max(min_len),
    };
    if len > MAX_NON_PLP_LEN {
        // Also covers an explicit `Some(u16::MAX)` size.
        return Ok(u16::MAX);
    }
    Ok(u16::try_from(len)?)
}
/// Writes `encoded` as a PLP (partially length-prefixed) stream with this layout:
/// - a little-endian `u64` total length;
/// - `u32`-length-prefixed chunks of at most `PLP_CHUNK_SIZE` bytes;
/// - a zero-length terminator chunk.
///
/// NULL is written as the `PLP_NULL` total length alone, with no chunks or terminator.
fn write_plp_value(out: &mut Vec<u8>, encoded: &[u8], is_null: bool) -> Result<(), BoxDynError> {
    let total_len = if is_null {
        PLP_NULL
    } else {
        u64::try_from(encoded.len())?
    };
    out.extend_from_slice(&total_len.to_le_bytes());
    if is_null {
        return Ok(());
    }
    encoded
        .chunks(PLP_CHUNK_SIZE)
        .try_for_each(|chunk| -> Result<(), BoxDynError> {
            out.extend_from_slice(&u32::try_from(chunk.len())?.to_le_bytes());
            out.extend_from_slice(chunk);
            Ok(())
        })?;
    // A zero-length chunk terminates the stream.
    out.extend_from_slice(&[0_u8; 4]);
    Ok(())
}
/// Writes `value` as a B_VARCHAR: a one-byte length counted in UTF-16 code units, then
/// the UTF-16LE text.
///
/// # Errors
///
/// Returns an error (and writes nothing) if `value` has more than 255 UTF-16 code units.
fn write_b_varchar(out: &mut Vec<u8>, value: &str) -> Result<(), BoxDynError> {
    let length_prefix = [u8::try_from(value.encode_utf16().count())?];
    out.extend_from_slice(&length_prefix);
    write_utf16(out, value);
    Ok(())
}
/// Appends `value` to `out` as UTF-16 little-endian code units.
fn write_utf16(out: &mut Vec<u8>, value: &str) {
    out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `needle` occurs as a contiguous run anywhere in `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn formats_sql_server_style_placeholders() {
        let args = MssqlArguments {
            len: 3,
            ..MssqlArguments::default()
        };
        let mut placeholder = String::new();
        args.format_placeholder(&mut placeholder).unwrap();
        assert_eq!("@p3", placeholder);
    }

    #[test]
    fn records_declarations_and_rpc_argument_data() {
        let mut args = MssqlArguments::default();
        args.add(7_i32).unwrap();
        args.add("hi").unwrap();
        assert_eq!("@p1 int,@p2 nvarchar(2)", args.declarations());

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 4]));
        // nvarchar tag, 4-byte declared size (LE), then the collation bytes.
        assert!(contains_bytes(
            data,
            &[DATA_TYPE_NVARCHAR, 4, 0, 0x81, 0x04, 0xd0, 0x00, 0x34]
        ));
    }

    #[test]
    fn declares_lossless_integer_parameter_types() {
        let mut args = MssqlArguments::default();
        args.add(-5_i8).unwrap();
        args.add(255_u8).unwrap();
        args.add(65_535_u16).unwrap();
        args.add(u32::MAX).unwrap();
        assert_eq!(
            "@p1 smallint,@p2 tinyint,@p3 int,@p4 bigint",
            args.declarations()
        );

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 1]));
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 8]));
    }

    #[test]
    fn declares_large_text_and_binary_parameters_as_max() {
        let text = "x".repeat(4001);
        let bytes = vec![0x5a; 8001];
        let mut args = MssqlArguments::default();
        args.add(text.as_str()).unwrap();
        args.add(bytes.as_slice()).unwrap();
        assert_eq!("@p1 nvarchar(max),@p2 varbinary(max)", args.declarations());

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_NVARCHAR, 0xff, 0xff]));
        assert!(contains_bytes(data, &[DATA_TYPE_BIGVARBINARY, 0xff, 0xff]));
        // PLP total lengths: 4001 UTF-16 units = 8002 bytes, and the 8001 raw bytes.
        assert!(contains_bytes(data, &8002_u64.to_le_bytes()));
        assert!(contains_bytes(data, &8001_u64.to_le_bytes()));
    }
}