use bytes::{BufMut, BytesMut};
use once_cell::sync::Lazy;
use regex::Regex;
use std::sync::Arc;
use mssql_types::{SqlValue, ToSql, TypeError};
use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
use tds_protocol::token::{Collation, DoneStatus, TokenType};
use crate::error::Error;
/// Tuning knobs for an `INSERT BULK` operation.
///
/// Every field is translated into a hint inside the statement's
/// `WITH (...)` clause by `BulkInsertBuilder::build_insert_bulk_statement`.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Emitted as `ROWS_PER_BATCH = n` when non-zero; `0` omits the hint.
    pub batch_size: usize,
    /// Emits `CHECK_CONSTRAINTS` when `true`.
    pub check_constraints: bool,
    /// Emits `FIRE_TRIGGERS` when `true`.
    pub fire_triggers: bool,
    /// Emits `KEEP_NULLS` when `true`.
    pub keep_nulls: bool,
    /// Emits `TABLOCK` when `true`.
    pub table_lock: bool,
    /// Column names for an `ORDER(...)` hint, if any.
    pub order_hint: Option<Vec<String>>,
}

impl Default for BulkOptions {
    /// Constraints checked and NULLs kept; no triggers, no table lock, no
    /// batch-size hint and no ordering hint.
    fn default() -> Self {
        let order_hint = None;
        Self {
            order_hint,
            table_lock: false,
            keep_nulls: true,
            fire_triggers: false,
            check_constraints: true,
            batch_size: 0,
        }
    }
}
/// Description of one destination column in a bulk insert.
///
/// The public fields mirror what the caller declared; the private fields are
/// the TDS type information derived from `sql_type` by `parse_sql_type`
/// inside `BulkColumn::new`.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name, used both in the `INSERT BULK` statement and in COLMETADATA.
    pub name: String,
    /// SQL type declaration as given by the caller (e.g. `"NVARCHAR(100)"`).
    pub sql_type: String,
    /// Whether NULL is allowed; `BulkColumn::new` starts every column as nullable.
    pub nullable: bool,
    /// Position supplied by the caller (`with_columns` passes the enumeration index).
    /// NOTE(review): not read by the encoders in this file.
    pub ordinal: usize,
    /// TDS type byte derived from `sql_type` (e.g. 0x26 for the INT family).
    type_id: u8,
    /// Maximum length in bytes, where the type has one; 0xFFFF marks a (MAX) type.
    max_length: Option<u32>,
    /// Decimal precision, only set for DECIMAL/NUMERIC.
    precision: Option<u8>,
    /// Decimal scale, or fractional-seconds scale for TIME/DATETIME2/DATETIMEOFFSET.
    scale: Option<u8>,
    /// Collation written for VARCHAR/NVARCHAR metadata; `None` selects a built-in default.
    collation: Option<Collation>,
}
impl BulkColumn {
    /// Describes a column from its name, SQL type declaration and position.
    ///
    /// The new column is nullable and has no explicit collation; use
    /// [`BulkColumn::with_nullable`] / [`BulkColumn::with_collation`] to change that.
    ///
    /// # Errors
    ///
    /// Returns `TypeError::UnsupportedType` for the deprecated `TEXT`, `NTEXT`
    /// and `IMAGE` types.
    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Result<Self, TypeError> {
        let declared: String = sql_type.into();
        reject_unsupported_bulk_type(&declared)?;
        let (type_id, max_length, precision, scale) = parse_sql_type(&declared);
        let column = Self {
            name: name.into(),
            sql_type: declared,
            nullable: true,
            ordinal,
            type_id,
            max_length,
            precision,
            scale,
            collation: None,
        };
        Ok(column)
    }

    /// Returns the column with its nullability replaced.
    #[must_use]
    pub fn with_nullable(self, nullable: bool) -> Self {
        Self { nullable, ..self }
    }

    /// Returns the column with an explicit collation for its metadata.
    #[must_use]
    pub fn with_collation(self, collation: Collation) -> Self {
        Self {
            collation: Some(collation),
            ..self
        }
    }
}
/// Maps a SQL Server type declaration to the TDS type information used for
/// bulk-copy metadata and row encoding.
///
/// Returns `(type_id, max_length, precision, scale)`:
/// * `type_id` — the TDS type byte (nullable "N" variants, e.g. 0x26 INTN);
/// * `max_length` — byte length where the type has one, `0xFFFF` for `(MAX)`;
/// * `precision` / `scale` — for DECIMAL/NUMERIC, and `scale` also carries the
///   fractional-seconds precision of TIME/DATETIME2/DATETIMEOFFSET.
///
/// Matching is case-insensitive and tolerant of whitespace around the type
/// name and inside the parentheses (`"decimal ( 10 , 2 )"`). Anything after
/// the first closing parenthesis is ignored. Unrecognized base types fall
/// back to NVARCHAR with an 8000-byte maximum length.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();
    let (base, params) = match upper.find('(') {
        Some(paren_pos) => {
            // Only the text between '(' and the first ')' is the parameter list.
            let inner = upper[paren_pos + 1..]
                .split(')')
                .next()
                .unwrap_or("")
                .trim();
            (upper[..paren_pos].trim_end(), Some(inner))
        }
        None => (upper.as_str(), None),
    };

    // Length parameter of a variable-length type: a number, or MAX (0xFFFF).
    let var_len = |default: u32| -> u32 {
        params
            .and_then(|p| if p == "MAX" { Some(0xFFFF_u32) } else { p.parse().ok() })
            .unwrap_or(default)
    };
    // Fractional-seconds precision for the time-based types (default 7).
    let time_scale = || -> u8 { params.and_then(|p| p.parse().ok()).unwrap_or(7) };

    match base {
        "BIT" => (0x68, Some(1), None, None),
        "TINYINT" => (0x26, Some(1), None, None),
        "SMALLINT" => (0x26, Some(2), None, None),
        "INT" => (0x26, Some(4), None, None),
        "BIGINT" => (0x26, Some(8), None, None),
        "REAL" => (0x6D, Some(4), None, None),
        "FLOAT" => (0x6D, Some(8), None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x6F, Some(8), None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        "SMALLDATETIME" => (0x6F, Some(4), None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(var_len(8000)), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                // Declared in characters, but TDS wants bytes (UTF-16 = 2 per unit).
                // Saturate instead of overflowing on absurd declarations.
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(var_len(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = match params {
                Some(p) => {
                    let mut parts = p.split(',').map(str::trim);
                    (
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(18),
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(0),
                    )
                }
                None => (18, 0),
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x6E, Some(8), None, None),
        "SMALLMONEY" => (0x6E, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        _ => (0xE7, Some(8000), None, None),
    }
}
/// Refuses the deprecated LOB types that this bulk-copy encoder does not handle.
///
/// Only the base name (text before any `(`, trimmed, case-insensitive) is
/// inspected.
///
/// # Errors
///
/// `TypeError::UnsupportedType` for `TEXT`, `NTEXT` and `IMAGE`, with a
/// reason pointing at the `(MAX)` replacement type.
fn reject_unsupported_bulk_type(sql_type: &str) -> Result<(), TypeError> {
    let head = sql_type.split('(').next().unwrap_or("");
    let base = head.trim().to_uppercase();
    let reason = match base.as_str() {
        "TEXT" | "NTEXT" => {
            "TEXT/NTEXT are not supported. Use VARCHAR(MAX) / \
             NVARCHAR(MAX) instead (Microsoft deprecated TEXT/NTEXT in \
             SQL Server 2005)."
        }
        "IMAGE" => {
            "IMAGE is not supported. Use VARBINARY(MAX) instead \
             (Microsoft deprecated IMAGE in SQL Server 2005)."
        }
        _ => return Ok(()),
    };
    Err(TypeError::UnsupportedType {
        sql_type: base,
        reason: reason.to_string(),
    })
}
/// Summary of a completed bulk insert.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Rows inserted: `BulkWriter::finish` takes this from the server's reply,
    /// while `BulkInsert::result` reports the locally counted rows.
    pub rows_affected: u64,
    /// Committed batches: `BulkWriter::finish` always reports 1, and
    /// `BulkInsert::result` reports its internal counter.
    pub batches_committed: u32,
    /// NOTE(review): always `false` in the code paths visible in this file.
    pub has_errors: bool,
}
/// Builder for the `INSERT BULK` statement that announces a bulk-load stream.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    /// Target table; may be qualified (validated when the statement is built).
    table_name: String,
    /// Columns listed in the statement, in declaration order.
    columns: Vec<BulkColumn>,
    /// Options rendered into the statement's `WITH (...)` hint list.
    options: BulkOptions,
}
impl BulkInsertBuilder {
pub fn new<S: Into<String>>(table_name: S) -> Self {
Self {
table_name: table_name.into(),
columns: Vec::new(),
options: BulkOptions::default(),
}
}
#[must_use]
#[allow(clippy::expect_used)] pub fn with_columns(mut self, column_names: &[&str]) -> Self {
self.columns = column_names
.iter()
.enumerate()
.map(|(i, name)| {
BulkColumn::new(*name, "NVARCHAR(MAX)", i)
.expect("NVARCHAR(MAX) is always a supported type")
})
.collect();
self
}
#[must_use]
pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
self.columns = columns;
self
}
#[must_use]
pub fn with_options(mut self, options: BulkOptions) -> Self {
self.options = options;
self
}
#[must_use]
pub fn batch_size(mut self, size: usize) -> Self {
self.options.batch_size = size;
self
}
#[must_use]
pub fn table_lock(mut self, enabled: bool) -> Self {
self.options.table_lock = enabled;
self
}
#[must_use]
pub fn fire_triggers(mut self, enabled: bool) -> Self {
self.options.fire_triggers = enabled;
self
}
pub fn table_name(&self) -> &str {
&self.table_name
}
pub fn columns(&self) -> &[BulkColumn] {
&self.columns
}
pub fn options(&self) -> &BulkOptions {
&self.options
}
pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
crate::validation::validate_qualified_identifier(&self.table_name)?;
for col in &self.columns {
crate::validation::validate_identifier(&col.name)?;
}
let mut sql = format!("INSERT BULK {}", self.table_name);
if !self.columns.is_empty() {
sql.push_str(" (");
let cols: Vec<String> = self
.columns
.iter()
.map(|c| {
validate_sql_type(&c.sql_type)?;
Ok(format!("{} {}", c.name, c.sql_type))
})
.collect::<Result<Vec<_>, Error>>()?;
sql.push_str(&cols.join(", "));
sql.push(')');
}
let mut hints: Vec<String> = Vec::new();
if self.options.check_constraints {
hints.push("CHECK_CONSTRAINTS".to_string());
}
if self.options.fire_triggers {
hints.push("FIRE_TRIGGERS".to_string());
}
if self.options.keep_nulls {
hints.push("KEEP_NULLS".to_string());
}
if self.options.table_lock {
hints.push("TABLOCK".to_string());
}
if self.options.batch_size > 0 {
hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
}
if let Some(ref order) = self.options.order_hint {
for col_name in order {
crate::validation::validate_identifier(col_name)?;
}
hints.push(format!("ORDER({})", order.join(", ")));
}
if !hints.is_empty() {
sql.push_str(" WITH (");
sql.push_str(&hints.join(", "));
sql.push(')');
}
Ok(sql)
}
}
/// Checks that a column type declaration is safe to splice into an
/// `INSERT BULK` statement.
///
/// The declaration must start with an ASCII letter, be at most 128
/// characters, and contain only ASCII letters, digits, `_`, spaces,
/// parentheses, `.` and `,`. Parentheses must also be balanced. Without that
/// check a type such as `INT) DROP TABLE t (x` passes the character
/// whitelist and closes the statement's column list, letting arbitrary SQL
/// follow.
///
/// # Errors
///
/// `Error::Config` if the declaration is empty, contains a disallowed
/// character, or has unbalanced parentheses.
fn validate_sql_type(type_str: &str) -> Result<(), Error> {
    // The pattern is a constant literal, so compilation cannot fail at runtime.
    #[allow(clippy::expect_used)]
    static SQL_TYPE_RE: Lazy<Regex> =
        Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));

    if type_str.is_empty() {
        return Err(Error::Config("SQL type cannot be empty".into()));
    }
    if !SQL_TYPE_RE.is_match(type_str) {
        return Err(Error::Config(format!(
            "invalid SQL type '{type_str}': contains disallowed characters"
        )));
    }

    // Depth must never go negative (a ')' before its '(') and must end at zero.
    let final_depth = type_str.chars().try_fold(0usize, |depth, c| match c {
        '(' => Some(depth + 1),
        ')' => depth.checked_sub(1),
        _ => Some(depth),
    });
    if final_depth != Some(0) {
        return Err(Error::Config(format!(
            "invalid SQL type '{type_str}': unbalanced parentheses"
        )));
    }
    Ok(())
}
/// Client-side encoder for a bulk-load data stream.
///
/// The buffer starts with a COLMETADATA token (generated locally or copied
/// from the server). Rows are appended as ROW tokens, and the buffer is
/// split into `BulkLoad` packets by `take_packets` / `finish_packets`.
pub struct BulkInsert {
    /// Column descriptions, shared via `Arc` so per-row encoding can hold them cheaply.
    columns: Arc<[BulkColumn]>,
    /// Per-column flag: `true` means values are written without a length prefix.
    fixed_len: Arc<[bool]>,
    /// Encoded tokens not yet split into packets.
    buffer: BytesMut,
    /// Rows appended so far; compared against `batch_size` by `should_flush`.
    /// NOTE(review): never reset in the code visible here.
    rows_in_batch: usize,
    /// Total rows appended; also written into the final DONE token.
    total_rows: u64,
    /// Threshold for `should_flush`; `0` disables it.
    batch_size: usize,
    /// Reported by `result()`. NOTE(review): not incremented anywhere in this file.
    batches_committed: u32,
    /// Id stamped on the next packet; starts at 1 and wraps around after 255.
    packet_id: u8,
}
impl BulkInsert {
pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
Self::new_with_server_metadata(columns, batch_size, None, None)
}
pub fn new_with_server_metadata(
mut columns: Vec<BulkColumn>,
batch_size: usize,
raw_colmetadata: Option<bytes::Bytes>,
server_columns: Option<&[tds_protocol::token::ColumnData]>,
) -> Self {
let fixed_len: Vec<bool> = if let Some(srv_cols) = server_columns {
for (col, srv) in columns.iter_mut().zip(srv_cols.iter()) {
if col.collation.is_none() {
col.collation = srv.type_info.collation;
}
}
srv_cols
.iter()
.map(|c| c.type_id.is_fixed_length())
.collect()
} else {
columns
.iter()
.map(|c| !c.nullable && nullable_to_fixed_type(c.type_id, c.max_length).is_some())
.collect()
};
let mut bulk = Self {
columns: columns.into(),
fixed_len: fixed_len.into(),
buffer: BytesMut::with_capacity(64 * 1024),
rows_in_batch: 0,
total_rows: 0,
batch_size,
batches_committed: 0,
packet_id: 1,
};
if let Some(raw) = raw_colmetadata {
bulk.buffer.extend_from_slice(&raw);
} else {
bulk.write_colmetadata();
}
bulk
}
fn write_colmetadata(&mut self) {
let buf = &mut self.buffer;
buf.put_u8(TokenType::ColMetaData as u8);
buf.put_u16_le(self.columns.len() as u16);
for col in self.columns.iter() {
buf.put_u32_le(0);
let effective_type_id = if !col.nullable {
nullable_to_fixed_type(col.type_id, col.max_length).unwrap_or(col.type_id)
} else {
col.type_id
};
let is_fixed_variant = effective_type_id != col.type_id;
let mut flags: u16 = 0x0008; if col.nullable {
flags |= 0x0001; }
buf.put_u16_le(flags);
buf.put_u8(effective_type_id);
if is_fixed_variant {
let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
buf.put_u8(name_utf16.len() as u8);
for code_unit in name_utf16 {
buf.put_u16_le(code_unit);
}
continue;
}
match col.type_id {
0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
buf.put_u8(col.max_length.unwrap_or(4) as u8);
}
0x28 => {}
0xE7 | 0xA7 | 0xA5 | 0xAD => {
let max_len = col.max_length.unwrap_or(8000);
if max_len == 0xFFFF {
buf.put_u16_le(0xFFFF);
} else {
buf.put_u16_le(max_len as u16);
}
if col.type_id == 0xE7 || col.type_id == 0xA7 {
if let Some(coll) = col.collation.as_ref() {
buf.put_slice(&coll.to_bytes());
} else {
buf.put_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]);
}
}
}
0x6C | 0x6A => {
let precision = col.precision.unwrap_or(18);
let len = decimal_byte_length(precision);
buf.put_u8(len);
buf.put_u8(precision);
buf.put_u8(col.scale.unwrap_or(0));
}
0x29..=0x2B => {
buf.put_u8(col.scale.unwrap_or(7));
}
0x24 => {
buf.put_u8(16);
}
_ => {
if let Some(len) = col.max_length {
if len <= 0xFFFF {
buf.put_u16_le(len as u16);
}
}
}
}
let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
buf.put_u8(name_utf16.len() as u8);
for code_unit in name_utf16 {
buf.put_u16_le(code_unit);
}
}
}
pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
if values.len() != self.columns.len() {
return Err(Error::Config(format!(
"expected {} values, got {}",
self.columns.len(),
values.len()
)));
}
let sql_values: Result<Vec<SqlValue>, TypeError> =
values.iter().map(|v| v.to_sql()).collect();
let sql_values = sql_values.map_err(Error::from)?;
self.write_row(&sql_values)?;
self.rows_in_batch += 1;
self.total_rows += 1;
Ok(())
}
pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
if values.len() != self.columns.len() {
return Err(Error::Config(format!(
"expected {} values, got {}",
self.columns.len(),
values.len()
)));
}
self.write_row(values)?;
self.rows_in_batch += 1;
self.total_rows += 1;
Ok(())
}
fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
self.buffer.put_u8(TokenType::Row as u8);
let columns: Vec<_> = self.columns.iter().cloned().collect();
let fixed_len = self.fixed_len.clone();
for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
let is_fixed = *fixed_len.get(i).unwrap_or(&false);
self.encode_column_value(col, value, is_fixed)
.map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
}
Ok(())
}
fn encode_column_value(
&mut self,
col: &BulkColumn,
value: &SqlValue,
is_fixed: bool,
) -> Result<(), TypeError> {
let buf = &mut self.buffer;
let is_plp_type =
col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
match value {
SqlValue::Null => {
match col.type_id {
0xE7 | 0xA7 | 0xA5 | 0xAD => {
if is_plp_type {
buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
} else {
buf.put_u16_le(0xFFFF);
}
}
0x26 | 0x68 | 0x6D | 0x6E | 0x6F | 0x6C | 0x6A | 0x24 | 0x28 | 0x29 | 0x2A
| 0x2B => {
buf.put_u8(0);
}
_ => {
if col.nullable {
buf.put_u8(0);
} else {
return Err(TypeError::UnexpectedNull);
}
}
}
}
SqlValue::Bool(v) => {
if !is_fixed {
buf.put_u8(1);
}
buf.put_u8(if *v { 1 } else { 0 });
}
SqlValue::TinyInt(v) => {
if !is_fixed {
buf.put_u8(1);
}
buf.put_u8(*v);
}
SqlValue::SmallInt(v) => {
if !is_fixed {
buf.put_u8(2);
}
buf.put_i16_le(*v);
}
SqlValue::Int(v) => {
if !is_fixed {
buf.put_u8(4);
}
buf.put_i32_le(*v);
}
SqlValue::BigInt(v) => {
if !is_fixed {
buf.put_u8(8);
}
buf.put_i64_le(*v);
}
SqlValue::Float(v) => {
if !is_fixed {
buf.put_u8(4);
}
buf.put_f32_le(*v);
}
SqlValue::Double(v) => {
if !is_fixed {
buf.put_u8(8);
}
buf.put_f64_le(*v);
}
SqlValue::String(s) => {
let is_varchar = matches!(col.type_id, 0xA7 | 0x2F | 0xAF);
if is_varchar {
let encoded = encode_varchar_for_collation(s, col.collation.as_ref());
let byte_len = encoded.len();
if is_plp_type {
encode_plp_binary(&encoded, buf);
} else if byte_len > 0xFFFF {
return Err(TypeError::BufferTooSmall {
needed: byte_len,
available: 0xFFFF,
});
} else {
buf.put_u16_le(byte_len as u16);
buf.put_slice(&encoded);
}
} else {
let utf16: Vec<u16> = s.encode_utf16().collect();
let byte_len = utf16.len() * 2;
if is_plp_type {
encode_plp_string(&utf16, buf);
} else if byte_len > 0xFFFF {
return Err(TypeError::BufferTooSmall {
needed: byte_len,
available: 0xFFFF,
});
} else {
buf.put_u16_le(byte_len as u16);
for code_unit in utf16 {
buf.put_u16_le(code_unit);
}
}
}
}
SqlValue::Binary(b) => {
if is_plp_type {
encode_plp_binary(b, buf);
} else if b.len() > 0xFFFF {
return Err(TypeError::BufferTooSmall {
needed: b.len(),
available: 0xFFFF,
});
} else {
buf.put_u16_le(b.len() as u16);
buf.put_slice(b);
}
}
#[cfg(feature = "decimal")]
SqlValue::Decimal(d) => {
if col.type_id == 0x6E {
encode_money_value(*d, col, buf, is_fixed)?;
} else {
let precision = col.precision.unwrap_or(18);
let len = decimal_byte_length(precision);
buf.put_u8(len);
buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
let mantissa = d.mantissa().unsigned_abs();
let mantissa_bytes = mantissa.to_le_bytes();
buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
}
}
#[cfg(feature = "uuid")]
SqlValue::Uuid(u) => {
buf.put_u8(16); mssql_types::encode::encode_uuid(*u, buf);
}
#[cfg(feature = "chrono")]
SqlValue::Date(d) => {
buf.put_u8(3); mssql_types::encode::encode_date(*d, buf);
}
#[cfg(feature = "chrono")]
SqlValue::Time(t) => {
let scale = col.scale.unwrap_or(7);
let len = time_byte_length(scale);
buf.put_u8(len);
encode_time_with_scale(*t, scale, buf);
}
#[cfg(feature = "chrono")]
SqlValue::DateTime(dt) => {
if col.type_id == 0x6F {
let total_len = col.max_length.unwrap_or(8) as u8;
if !is_fixed {
buf.put_u8(total_len);
}
match total_len {
8 => mssql_types::encode::encode_datetime_legacy(*dt, buf),
4 => mssql_types::encode::encode_smalldatetime(*dt, buf)?,
_ => {
return Err(TypeError::InvalidDateTime(format!(
"DATETIMEN max_length must be 4 or 8, got {total_len}"
)));
}
}
} else {
let scale = col.scale.unwrap_or(7);
let time_len = time_byte_length(scale);
let total_len = time_len + 3;
buf.put_u8(total_len);
encode_time_with_scale(dt.time(), scale, buf);
mssql_types::encode::encode_date(dt.date(), buf);
}
}
#[cfg(feature = "chrono")]
SqlValue::SmallDateTime(dt) => {
if !is_fixed {
buf.put_u8(4);
}
mssql_types::encode::encode_smalldatetime(*dt, buf)?;
}
#[cfg(feature = "decimal")]
SqlValue::Money(d) => {
if !is_fixed {
buf.put_u8(8);
}
mssql_types::encode::encode_money(*d, buf)?;
}
#[cfg(feature = "decimal")]
SqlValue::SmallMoney(d) => {
if !is_fixed {
buf.put_u8(4);
}
mssql_types::encode::encode_smallmoney(*d, buf)?;
}
#[cfg(feature = "chrono")]
SqlValue::DateTimeOffset(dto) => {
let scale = col.scale.unwrap_or(7);
let time_len = time_byte_length(scale);
let total_len = time_len + 3 + 2;
buf.put_u8(total_len);
encode_time_with_scale(dto.time(), scale, buf);
mssql_types::encode::encode_date(dto.date_naive(), buf);
use chrono::Offset;
let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
buf.put_i16_le(offset_minutes);
}
#[cfg(feature = "json")]
SqlValue::Json(j) => {
let s = j.to_string();
encode_nvarchar_value(&s, buf)?;
}
SqlValue::Xml(x) => {
encode_nvarchar_value(x, buf)?;
}
SqlValue::Tvp(_) => {
return Err(TypeError::UnsupportedConversion {
from: "TVP".to_string(),
to: "bulk copy value",
});
}
_ => {
return Err(TypeError::UnsupportedConversion {
from: value.type_name().to_string(),
to: "bulk copy value",
});
}
}
Ok(())
}
}
/// Encodes a decimal into a MONEYN column whose width comes from `max_length`.
///
/// A width byte is written first unless the column is fixed-length; 4 selects
/// SMALLMONEY encoding, 8 (also the default when `max_length` is unset)
/// selects MONEY.
///
/// # Errors
///
/// Propagates encoder errors. Returns `TypeError::InvalidDecimal` for any
/// other width; note the width byte has already been written in that case.
#[cfg(feature = "decimal")]
fn encode_money_value(
    value: rust_decimal::Decimal,
    col: &BulkColumn,
    buf: &mut BytesMut,
    is_fixed: bool,
) -> Result<(), TypeError> {
    // NOTE(review): `as u8` truncates widths above 255, as in the original.
    let width = col.max_length.unwrap_or(8) as u8;
    if !is_fixed {
        buf.put_u8(width);
    }
    if width == 4 {
        mssql_types::encode::encode_smallmoney(value, buf)
    } else if width == 8 {
        mssql_types::encode::encode_money(value, buf)
    } else {
        Err(TypeError::InvalidDecimal(format!(
            "MONEY column has invalid max_length: {width}"
        )))
    }
}
/// Writes `s` as UTF-16LE preceded by its byte count as a little-endian `u16`.
///
/// # Errors
///
/// `TypeError::BufferTooSmall` when the UTF-16 form is longer than 0xFFFF
/// bytes; nothing is written in that case.
fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
    let units: Vec<u16> = s.encode_utf16().collect();
    let needed = units.len() * 2;
    let prefix = u16::try_from(needed).map_err(|_| TypeError::BufferTooSmall {
        needed,
        available: 0xFFFF,
    })?;
    buf.put_u16_le(prefix);
    units.into_iter().for_each(|unit| buf.put_u16_le(unit));
    Ok(())
}
/// Total-length header of a PLP value whose size is not announced up front;
/// the data follows as length-prefixed chunks ending with a zero-length chunk.
const PLP_UNKNOWN_LEN: u64 = 0xFFFF_FFFF_FFFF_FFFE;
/// Writes UTF-16 text as a PLP value.
///
/// Layout: unknown-length header, then a single chunk (`u32` byte length
/// plus UTF-16LE units) when the text is non-empty, then the zero-length
/// terminator.
fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
    buf.put_u64_le(PLP_UNKNOWN_LEN);
    if !utf16.is_empty() {
        buf.put_u32_le((utf16.len() * 2) as u32);
        utf16.iter().for_each(|&unit| buf.put_u16_le(unit));
    }
    buf.put_u32_le(0);
}
/// Writes raw bytes as a PLP value.
///
/// Layout: unknown-length header, then a single `u32`-length chunk when
/// `data` is non-empty, then the zero-length terminator.
fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
    buf.put_u64_le(PLP_UNKNOWN_LEN);
    match data.len() {
        0 => {}
        n => {
            buf.put_u32_le(n as u32);
            buf.put_slice(data);
        }
    }
    buf.put_u32_le(0);
}
/// Converts `value` to the bytes sent for a non-Unicode (VARCHAR-family) column.
///
/// Delegates entirely to `tds_protocol::collation::encode_str_for_collation`.
/// NOTE(review): presumably that helper picks the collation's code page, and
/// it defines what `None` means; confirm there.
fn encode_varchar_for_collation(value: &str, collation: Option<&Collation>) -> Vec<u8> {
    use tds_protocol::collation::encode_str_for_collation;
    let bytes = encode_str_for_collation(value, collation);
    bytes
}
/// Writes a time of day as the little-endian count of `10^-scale`-second
/// units since midnight, using `time_byte_length(scale)` bytes.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;
    let whole_secs = u64::from(time.num_seconds_from_midnight());
    let nanos = whole_secs * 1_000_000_000 + u64::from(time.nanosecond());
    let ticks = nanos / time_scale_divisor(scale);
    let width = usize::from(time_byte_length(scale));
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
impl BulkInsert {
    /// Appends the closing DONE token.
    ///
    /// The token holds status 0x0010 (NOTE(review): presumably DONE_COUNT),
    /// a zero current-command word, and the running row count.
    fn write_done(&mut self) {
        let row_count = self.total_rows;
        let out = &mut self.buffer;
        out.put_u8(TokenType::Done as u8);
        let status = DoneStatus::from_bits(0x0010);
        out.put_u16_le(status.to_bits());
        out.put_u16_le(0);
        out.put_u64_le(row_count);
    }

    /// Drains the buffer into `BulkLoad` packets of at most 4096 bytes.
    ///
    /// Only the final packet of this batch is flagged END_OF_MESSAGE. Packet
    /// ids continue from previous calls and wrap around. An empty buffer
    /// yields no packets.
    pub fn take_packets(&mut self) -> Vec<BytesMut> {
        const MAX_PACKET_SIZE: usize = 4096;
        const HEADER_SIZE: usize = 8;
        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;

        let data = self.buffer.split();
        let last_index = data.chunks(MAX_PAYLOAD).len().saturating_sub(1);
        let mut packets = Vec::new();
        for (index, chunk) in data.chunks(MAX_PAYLOAD).enumerate() {
            let status = if index == last_index {
                PacketStatus::END_OF_MESSAGE
            } else {
                PacketStatus::NORMAL
            };
            let header = PacketHeader {
                packet_type: PacketType::BulkLoad,
                status,
                length: (HEADER_SIZE + chunk.len()) as u16,
                spid: 0,
                packet_id: self.packet_id,
                window: 0,
            };
            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
            header.encode(&mut packet);
            packet.put_slice(chunk);
            packets.push(packet);
            self.packet_id = self.packet_id.wrapping_add(1);
        }
        packets
    }

    /// Rows appended so far.
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// Rows counted toward the current flush threshold.
    pub fn rows_in_batch(&self) -> usize {
        self.rows_in_batch
    }

    /// `true` once `batch_size` rows have been appended (never when it is 0).
    pub fn should_flush(&self) -> bool {
        match self.batch_size {
            0 => false,
            limit => self.rows_in_batch >= limit,
        }
    }

    /// Appends the DONE token and returns all remaining data as packets.
    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
        self.write_done();
        self.take_packets()
    }

    /// Locally tracked summary; `has_errors` is always `false`.
    pub fn result(&self) -> BulkInsertResult {
        let rows_affected = self.total_rows;
        BulkInsertResult {
            rows_affected,
            batches_committed: self.batches_committed,
            has_errors: false,
        }
    }
}
/// Collects rows for one bulk insert and sends them over a borrowed client.
///
/// Rows are buffered by the inner `BulkInsert`; the payload is sent to the
/// server only in `BulkWriter::finish`.
pub struct BulkWriter<'a, S: crate::state::ConnectionState> {
    /// Connection used by `finish` to send the buffered payload.
    client: &'a mut crate::client::Client<S>,
    /// Encoder holding the metadata token and the rows written so far.
    bulk: BulkInsert,
}
impl<'a, S: crate::state::ConnectionState> BulkWriter<'a, S> {
    /// Pairs an already-prepared encoder with the client that will send it.
    pub(crate) fn new(client: &'a mut crate::client::Client<S>, bulk: BulkInsert) -> Self {
        Self { bulk, client }
    }

    /// Buffers one row; see `BulkInsert::send_row`.
    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
        let encoder = &mut self.bulk;
        encoder.send_row(values)
    }

    /// Buffers one row of already-converted values; see `BulkInsert::send_row_values`.
    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
        let encoder = &mut self.bulk;
        encoder.send_row_values(values)
    }

    /// Rows buffered so far.
    pub fn total_rows(&self) -> u64 {
        self.bulk.total_rows()
    }

    /// Terminates the stream with a DONE token, sends the whole buffered
    /// payload, and reports the row count returned by the client.
    ///
    /// # Errors
    ///
    /// Propagates any error from `send_and_read_bulk_load`.
    pub async fn finish(mut self) -> Result<BulkInsertResult, Error> {
        tracing::debug!(total_rows = self.bulk.total_rows(), "finishing bulk insert");
        self.bulk.write_done();
        let payload = self.bulk.buffer.split().freeze();
        let rows_affected = self.client.send_and_read_bulk_load(payload).await?;
        let summary = BulkInsertResult {
            rows_affected,
            batches_committed: 1,
            has_errors: false,
        };
        Ok(summary)
    }
}
/// Maps a nullable ("N") TDS type plus byte width to its fixed-length
/// counterpart, or `None` when there is none.
///
/// BITN (0x68) always maps to BIT (0x32). The others need a matching width:
/// * INTN 1/2/4/8 → INT1/INT2/INT4/INT8;
/// * FLTN 4/8 → FLT4/FLT8;
/// * MONEYN 4/8 → MONEY4/MONEY;
/// * DATETIMEN 4/8 → DATETIM4/DATETIME.
fn nullable_to_fixed_type(type_id: u8, max_length: Option<u32>) -> Option<u8> {
    if type_id == 0x68 {
        return Some(0x32);
    }
    let width = max_length?;
    let fixed = match (type_id, width) {
        (0x26, 1) => 0x30,
        (0x26, 2) => 0x34,
        (0x26, 4) => 0x38,
        (0x26, 8) => 0x7F,
        (0x6D, 4) => 0x3B,
        (0x6D, 8) => 0x3E,
        (0x6E, 4) => 0x7A,
        (0x6E, 8) => 0x3C,
        (0x6F, 4) => 0x3A,
        (0x6F, 8) => 0x3D,
        _ => return None,
    };
    Some(fixed)
}
/// Byte length of a DECIMAL/NUMERIC value (sign byte plus mantissa) for a
/// given precision: 5, 9, 13 or 17.
///
/// Precisions outside 1..=38 (including 0) use the widest form, 17.
fn decimal_byte_length(precision: u8) -> u8 {
    if (1..=9).contains(&precision) {
        5
    } else if (10..=19).contains(&precision) {
        9
    } else if (20..=28).contains(&precision) {
        13
    } else {
        17
    }
}
/// Number of bytes used for a TIME component with fractional-seconds `scale`.
///
/// Scales 0-2 use 3 bytes, 3-4 use 4 bytes, and anything higher uses 5.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
/// Nanoseconds per tick at fractional-seconds `scale`, i.e. `10^(9 - scale)`.
///
/// Scales above 7 are treated as 7 (100 ns ticks).
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    let effective = u32::from(scale.min(7));
    10u64.pow(9 - effective)
}
/// Unit tests for type parsing, statement building, and COLMETADATA / PLP encoding.
#[cfg(test)]
// Tests use unwrap/expect/panic! as their assertion mechanism.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    // Pins every default of BulkOptions except order_hint.
    #[test]
    fn test_bulk_options_default() {
        let opts = BulkOptions::default();
        assert_eq!(opts.batch_size, 0);
        assert!(opts.check_constraints);
        assert!(!opts.fire_triggers);
        assert!(opts.keep_nulls);
        assert!(!opts.table_lock);
    }

    // INT parses to INTN (0x26) with width 4, and new columns start nullable.
    #[test]
    fn test_bulk_column_creation() {
        let col = BulkColumn::new("id", "INT", 0).unwrap();
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, 0x26);
        assert_eq!(col.max_length, Some(4));
        assert!(col.nullable);
    }

    // TEXT is rejected with a reason that points at VARCHAR(MAX).
    #[test]
    fn test_bulk_column_rejects_text() {
        let err = BulkColumn::new("body", "TEXT", 0).unwrap_err();
        match err {
            TypeError::UnsupportedType { sql_type, reason } => {
                assert_eq!(sql_type, "TEXT");
                assert!(
                    reason.contains("VARCHAR(MAX)"),
                    "error should redirect to VARCHAR(MAX), got: {reason}"
                );
                assert!(
                    reason.contains("deprecated"),
                    "error should mention deprecation, got: {reason}"
                );
            }
            other => panic!("expected UnsupportedType, got {other:?}"),
        }
    }

    // NTEXT is rejected with a reason that points at NVARCHAR(MAX).
    #[test]
    fn test_bulk_column_rejects_ntext() {
        let err = BulkColumn::new("body", "NTEXT", 0).unwrap_err();
        match err {
            TypeError::UnsupportedType { sql_type, reason } => {
                assert_eq!(sql_type, "NTEXT");
                assert!(
                    reason.contains("NVARCHAR(MAX)"),
                    "error should redirect to NVARCHAR(MAX), got: {reason}"
                );
                assert!(
                    reason.contains("deprecated"),
                    "error should mention deprecation, got: {reason}"
                );
            }
            other => panic!("expected UnsupportedType, got {other:?}"),
        }
    }

    // The rejection check is case-insensitive.
    #[test]
    fn test_bulk_column_rejects_text_case_insensitive() {
        assert!(matches!(
            BulkColumn::new("body", "text", 0),
            Err(TypeError::UnsupportedType { .. })
        ));
        assert!(matches!(
            BulkColumn::new("body", "Ntext", 0),
            Err(TypeError::UnsupportedType { .. })
        ));
    }

    // IMAGE is rejected with a reason that points at VARBINARY(MAX).
    #[test]
    fn test_bulk_column_rejects_image() {
        let err = BulkColumn::new("blob", "IMAGE", 0).unwrap_err();
        match err {
            TypeError::UnsupportedType { sql_type, reason } => {
                assert_eq!(sql_type, "IMAGE");
                assert!(
                    reason.contains("VARBINARY(MAX)"),
                    "error should redirect to VARBINARY(MAX), got: {reason}"
                );
                assert!(
                    reason.contains("deprecated"),
                    "error should mention deprecation, got: {reason}"
                );
            }
            other => panic!("expected UnsupportedType, got {other:?}"),
        }
    }

    #[test]
    fn test_bulk_column_rejects_image_case_insensitive() {
        assert!(matches!(
            BulkColumn::new("blob", "image", 0),
            Err(TypeError::UnsupportedType { .. })
        ));
        assert!(matches!(
            BulkColumn::new("blob", "Image", 0),
            Err(TypeError::UnsupportedType { .. })
        ));
    }

    // Spot-checks type ids, byte lengths (NVARCHAR length doubles for UTF-16),
    // and decimal precision/scale.
    #[test]
    fn test_parse_sql_type() {
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x26);
        assert_eq!(len, Some(4));
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));
        let (type_id, len, _, _) = parse_sql_type("SMALLDATETIME");
        assert_eq!(type_id, 0x6F);
        assert_eq!(len, Some(4));
        let (type_id, len, _, _) = parse_sql_type("DATETIME");
        assert_eq!(type_id, 0x6F);
        assert_eq!(len, Some(8));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0).unwrap(),
                BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
            ])
            .table_lock(true);
        let sql = builder.build_insert_bulk_statement().unwrap();
        assert!(sql.contains("INSERT BULK dbo.Users"));
        assert!(sql.contains("TABLOCK"));
    }

    // Identifier validation must block SQL smuggled through the table name.
    #[test]
    fn test_bulk_insert_rejects_injection() {
        let builder = BulkInsertBuilder::new("table;DROP TABLE users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
        assert!(builder.build_insert_bulk_statement().is_err());
    }

    // ...and through a column name.
    #[test]
    fn test_bulk_insert_validates_column_names() {
        let builder = BulkInsertBuilder::new("Users")
            .with_typed_columns(vec![BulkColumn::new("col;DROP TABLE x", "INT", 0).unwrap()]);
        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_accepts_qualified_names() {
        let builder = BulkInsertBuilder::new("catalog.dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
        assert!(builder.build_insert_bulk_statement().is_ok());
    }

    #[test]
    fn test_bulk_insert_creation() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0).unwrap(),
            BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
        ];
        let bulk = BulkInsert::new(columns, 1000);
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
    }

    #[test]
    fn test_decimal_byte_length() {
        assert_eq!(decimal_byte_length(5), 5);
        assert_eq!(decimal_byte_length(15), 9);
        assert_eq!(decimal_byte_length(25), 13);
        assert_eq!(decimal_byte_length(35), 17);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        assert_eq!(time_byte_length(0), 3);
        assert_eq!(time_byte_length(3), 4);
        assert_eq!(time_byte_length(7), 5);
    }

    // Layout: 8-byte unknown-length header + u32 chunk length + 10 UTF-16
    // bytes + u32 zero terminator.
    #[test]
    fn test_plp_string_encoding() {
        let mut buf = BytesMut::new();
        let text = "Hello";
        let utf16: Vec<u16> = text.encode_utf16().collect();
        encode_plp_string(&utf16, &mut buf);
        assert_eq!(buf.len(), 8 + 4 + 10 + 4);
        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
        assert_eq!(&buf[8..12], &10u32.to_le_bytes());
        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let mut buf = BytesMut::new();
        let data = b"test binary data";
        encode_plp_binary(data, &mut buf);
        assert_eq!(buf.len(), 8 + 4 + 16 + 4);
        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
        assert_eq!(&buf[8..12], &16u32.to_le_bytes());
        assert_eq!(&buf[12..28], data);
        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    // Empty values emit no data chunk: just the header and the terminator.
    #[test]
    fn test_plp_empty_string() {
        let mut buf = BytesMut::new();
        let utf16: Vec<u16> = "".encode_utf16().collect();
        encode_plp_string(&utf16, &mut buf);
        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();
        encode_plp_binary(&[], &mut buf);
        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    // Generated COLMETADATA for every supported nullable type must decode with
    // the protocol crate's parser and reproduce names, types and type info.
    #[test]
    fn test_write_colmetadata_roundtrip() {
        use tds_protocol::token::ColMetaData;
        let columns = vec![
            BulkColumn::new("id", "INT", 0).unwrap(),
            BulkColumn::new("tiny", "TINYINT", 1).unwrap(),
            BulkColumn::new("small", "SMALLINT", 2).unwrap(),
            BulkColumn::new("big", "BIGINT", 3).unwrap(),
            BulkColumn::new("flag", "BIT", 4).unwrap(),
            BulkColumn::new("r", "REAL", 5).unwrap(),
            BulkColumn::new("f", "FLOAT", 6).unwrap(),
            BulkColumn::new("name", "NVARCHAR(100)", 7).unwrap(),
            BulkColumn::new("code", "VARCHAR(50)", 8).unwrap(),
            BulkColumn::new("data", "VARBINARY(200)", 9).unwrap(),
            BulkColumn::new("d", "DATE", 10).unwrap(),
            BulkColumn::new("t", "TIME(3)", 11).unwrap(),
            BulkColumn::new("dt", "DATETIME", 12).unwrap(),
            BulkColumn::new("dt2", "DATETIME2(7)", 13).unwrap(),
            BulkColumn::new("dto", "DATETIMEOFFSET(7)", 14).unwrap(),
            BulkColumn::new("sdt", "SMALLDATETIME", 15).unwrap(),
            BulkColumn::new("uid", "UNIQUEIDENTIFIER", 16).unwrap(),
            BulkColumn::new("amt", "DECIMAL(18,2)", 17).unwrap(),
            BulkColumn::new("price", "MONEY", 18).unwrap(),
            BulkColumn::new("smoney", "SMALLMONEY", 19).unwrap(),
            BulkColumn::new("nmax", "NVARCHAR(MAX)", 20).unwrap(),
            BulkColumn::new("vmax", "VARCHAR(MAX)", 21).unwrap(),
            BulkColumn::new("bmax", "VARBINARY(MAX)", 22).unwrap(),
        ];
        let bulk = BulkInsert::new(columns.clone(), 0);
        // Skip the leading COLMETADATA token byte; decode() parses the token body.
        let buf = &bulk.buffer[1..];
        let mut cursor = bytes::Bytes::copy_from_slice(buf);
        let meta = ColMetaData::decode(&mut cursor)
            .expect("write_colmetadata output should be parseable by TDS decoder");
        assert_eq!(meta.columns.len(), columns.len());
        for (i, (parsed, original)) in meta.columns.iter().zip(columns.iter()).enumerate() {
            assert_eq!(parsed.name, original.name, "column {i} name mismatch");
            assert_eq!(
                parsed.col_type, original.type_id,
                "column {i} ({}) type mismatch",
                original.name
            );
            match original.type_id {
                0x26 => {
                    assert_eq!(
                        parsed.type_info.max_length, original.max_length,
                        "column {i} ({}) INTN max_length",
                        original.name
                    );
                }
                0x68 => {
                    assert_eq!(parsed.type_info.max_length, Some(1));
                }
                0x6D => {
                    assert_eq!(
                        parsed.type_info.max_length, original.max_length,
                        "column {i} ({}) FLTN max_length",
                        original.name
                    );
                }
                0x6E => {
                    assert_eq!(
                        parsed.type_info.max_length, original.max_length,
                        "column {i} ({}) MONEYN max_length",
                        original.name
                    );
                }
                0x6F => {
                    assert_eq!(
                        parsed.type_info.max_length, original.max_length,
                        "column {i} ({}) DATETIMEN max_length",
                        original.name
                    );
                }
                0x24 => {
                    assert_eq!(parsed.type_info.max_length, Some(16));
                }
                0x28 => {}
                0x29..=0x2B => {
                    assert_eq!(
                        parsed.type_info.scale, original.scale,
                        "column {i} ({}) scale",
                        original.name
                    );
                }
                0xE7 | 0xA7 => {
                    assert_eq!(
                        parsed.type_info.max_length, original.max_length,
                        "column {i} ({}) string max_length",
                        original.name
                    );
                    assert!(
                        parsed.type_info.collation.is_some(),
                        "column {i} ({}) should have collation",
                        original.name
                    );
                }
                0xA5 => {
                    assert_eq!(
                        parsed.type_info.max_length, original.max_length,
                        "column {i} ({}) binary max_length",
                        original.name
                    );
                    assert!(
                        parsed.type_info.collation.is_none(),
                        "column {i} ({}) should not have collation",
                        original.name
                    );
                }
                0x6C => {
                    assert_eq!(
                        parsed.type_info.precision, original.precision,
                        "column {i} ({}) precision",
                        original.name
                    );
                    assert_eq!(
                        parsed.type_info.scale, original.scale,
                        "column {i} ({}) scale",
                        original.name
                    );
                }
                _ => {}
            }
        }
    }

    // NOT NULL columns with a fixed-length counterpart must be described with
    // the fixed type, be flagged fixed_len, and not carry the Nullable flag.
    #[test]
    fn test_write_colmetadata_not_null_uses_fixed_types() {
        use tds_protocol::token::ColMetaData;
        use tds_protocol::types::TypeId;
        let columns = vec![
            BulkColumn::new("id", "INT", 0)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("tiny", "TINYINT", 1)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("small", "SMALLINT", 2)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("big", "BIGINT", 3)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("flag", "BIT", 4)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("r", "REAL", 5)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("f", "FLOAT", 6)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("dt", "DATETIME", 7)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("sdt", "SMALLDATETIME", 8)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("mny", "MONEY", 9)
                .unwrap()
                .with_nullable(false),
            BulkColumn::new("smny", "SMALLMONEY", 10)
                .unwrap()
                .with_nullable(false),
        ];
        let bulk = BulkInsert::new(columns.clone(), 0);
        for (i, fixed) in bulk.fixed_len.iter().enumerate() {
            assert!(
                *fixed,
                "column {i} ({}) should be fixed_len",
                columns[i].name
            );
        }
        // Skip the COLMETADATA token byte before decoding.
        let buf = &bulk.buffer[1..];
        let mut cursor = bytes::Bytes::copy_from_slice(buf);
        let meta = ColMetaData::decode(&mut cursor).expect("parseable");
        let expected: &[(&str, TypeId)] = &[
            ("id", TypeId::Int4),
            ("tiny", TypeId::Int1),
            ("small", TypeId::Int2),
            ("big", TypeId::Int8),
            ("flag", TypeId::Bit),
            ("r", TypeId::Float4),
            ("f", TypeId::Float8),
            ("dt", TypeId::DateTime),
            ("sdt", TypeId::DateTime4),
            ("mny", TypeId::Money),
            ("smny", TypeId::Money4),
        ];
        for (i, (name, ty)) in expected.iter().enumerate() {
            assert_eq!(meta.columns[i].name, *name, "column {i} name");
            assert_eq!(meta.columns[i].type_id, *ty, "column {i} ({name}) type");
            assert_eq!(
                meta.columns[i].flags & 0x0001,
                0,
                "column {i} ({name}) should not have Nullable flag set"
            );
        }
    }

    // A caller-supplied collation is written for VARCHAR and NVARCHAR; without
    // one, the built-in default 5-byte collation is used.
    #[test]
    fn test_write_colmetadata_uses_caller_collation() {
        use tds_protocol::token::{ColMetaData, Collation};
        let chinese = Collation {
            lcid: 0x0804,
            sort_id: 0x52,
        };
        let columns = vec![
            BulkColumn::new("s", "VARCHAR(50)", 0)
                .unwrap()
                .with_collation(chinese),
            BulkColumn::new("n", "NVARCHAR(50)", 1)
                .unwrap()
                .with_collation(chinese),
            BulkColumn::new("d", "VARCHAR(10)", 2).unwrap(),
        ];
        let bulk = BulkInsert::new(columns, 0);
        let buf = &bulk.buffer[1..];
        let mut cursor = bytes::Bytes::copy_from_slice(buf);
        let meta = ColMetaData::decode(&mut cursor).expect("parseable");
        let c0 = meta.columns[0]
            .type_info
            .collation
            .as_ref()
            .expect("VARCHAR has collation");
        assert_eq!(c0.lcid, chinese.lcid, "VARCHAR caller LCID");
        assert_eq!(c0.sort_id, chinese.sort_id, "VARCHAR caller sort_id");
        let c1 = meta.columns[1]
            .type_info
            .collation
            .as_ref()
            .expect("NVARCHAR has collation");
        assert_eq!(c1.lcid, chinese.lcid, "NVARCHAR caller LCID");
        assert_eq!(c1.sort_id, chinese.sort_id, "NVARCHAR caller sort_id");
        let default = meta.columns[2]
            .type_info
            .collation
            .as_ref()
            .expect("VARCHAR has default collation");
        assert_eq!(default.to_bytes(), [0x09, 0x04, 0xD0, 0x00, 0x34]);
    }

    // (MAX) declarations map to the 0xFFFF length sentinel.
    #[test]
    fn test_parse_sql_type_max() {
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF));
        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));
        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
    }
}