use bytes::{BufMut, BytesMut};
use once_cell::sync::Lazy;
use regex::Regex;
use std::sync::Arc;
use mssql_types::{SqlValue, ToSql, TypeError};
use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
use tds_protocol::token::{DoneStatus, TokenType};
use crate::error::Error;
/// Hints and limits that control an `INSERT BULK` operation.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Rendered as `ROWS_PER_BATCH = n` when non-zero; `0` omits the hint.
    pub batch_size: usize,
    /// Rendered as the `CHECK_CONSTRAINTS` hint when `true`.
    pub check_constraints: bool,
    /// Rendered as the `FIRE_TRIGGERS` hint when `true`.
    pub fire_triggers: bool,
    /// Rendered as the `KEEP_NULLS` hint when `true`.
    pub keep_nulls: bool,
    /// Rendered as the `TABLOCK` hint when `true`.
    pub table_lock: bool,
    /// Column names rendered as `ORDER(a, b, ...)`; each must be a valid identifier.
    pub order_hint: Option<Vec<String>>,
    /// Error budget for the load.
    /// NOTE(review): not read by any code visible in this module — confirm its consumer.
    pub max_errors: u32,
}

impl Default for BulkOptions {
    /// Constraint checking and NULL preservation enabled; every other hint off,
    /// no batching, no ordering, no error budget.
    fn default() -> Self {
        let enabled = true;
        let disabled = false;
        Self {
            max_errors: 0,
            order_hint: None,
            table_lock: disabled,
            keep_nulls: enabled,
            fire_triggers: disabled,
            check_constraints: enabled,
            batch_size: 0,
        }
    }
}
/// One destination column of a bulk load.
///
/// The public `sql_type` text is parsed once, at construction, into the TDS
/// type information used when writing COLMETADATA and encoding row values.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name, rendered verbatim into `INSERT BULK` and COLMETADATA.
    pub name: String,
    /// SQL type text as supplied, e.g. `"NVARCHAR(100)"` or `"DECIMAL(10,2)"`.
    pub sql_type: String,
    /// Whether NULL is permitted; `true` unless changed via `with_nullable`.
    pub nullable: bool,
    /// Caller-supplied position of the column.
    pub ordinal: usize,
    /// TDS type byte derived from `sql_type`.
    type_id: u8,
    /// Byte length for sized types (`0xFFFF` denotes `MAX`).
    max_length: Option<u32>,
    /// Decimal precision, when applicable.
    precision: Option<u8>,
    /// Decimal scale or fractional-seconds scale, when applicable.
    scale: Option<u8>,
}

impl BulkColumn {
    /// Creates a nullable column `name` of SQL type `sql_type` at `ordinal`.
    ///
    /// `name` and `sql_type` may be different string-like types (for example
    /// a `String` name together with a `&str` type literal).
    pub fn new<N: Into<String>, T: Into<String>>(name: N, sql_type: T, ordinal: usize) -> Self {
        let sql_type: String = sql_type.into();
        let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type);
        Self {
            name: name.into(),
            sql_type,
            nullable: true,
            ordinal,
            type_id,
            max_length,
            precision,
            scale,
        }
    }

    /// Returns this column with its nullability set to `nullable`.
    #[must_use]
    pub fn with_nullable(mut self, nullable: bool) -> Self {
        self.nullable = nullable;
        self
    }
}
/// Parses SQL type text such as `"NVARCHAR(100)"` into TDS type information.
///
/// Returns `(type_id, max_length, precision, scale)`:
/// * `type_id` — the TDS type byte written into COLMETADATA;
/// * `max_length` — byte length for sized types (`0xFFFF` denotes `MAX`);
/// * `precision` / `scale` — for `DECIMAL`/`NUMERIC`; `scale` also carries the
///   fractional-seconds scale of `TIME`, `DATETIME2` and `DATETIMEOFFSET`.
///
/// Matching is case-insensitive and tolerates whitespace around the type,
/// before the parenthesis and around parameters. Time scales are clamped to 7,
/// decimal precision to 1..=38 and decimal scale to the precision. Unknown
/// types fall back to NVARCHAR with an 8000-byte length.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    const MAX_TIME_SCALE: u8 = 7;
    const MAX_DECIMAL_PRECISION: u8 = 38;

    let upper = sql_type.trim().to_uppercase();
    let (base, params) = match upper.find('(') {
        Some(open) => (
            upper[..open].trim_end(),
            Some(upper[open + 1..].trim_end_matches(')').trim()),
        ),
        None => (upper.as_str(), None),
    };

    // Fractional-seconds scale: default 7, never above 7.
    let time_scale = || {
        params
            .and_then(|p| p.parse::<u8>().ok())
            .map_or(MAX_TIME_SCALE, |s| s.min(MAX_TIME_SCALE))
    };
    // Length of a sized type; `MAX` is represented as 0xFFFF.
    let sized_len = |default: u32| {
        params
            .and_then(|p| if p == "MAX" { Some(0xFFFF_u32) } else { p.parse().ok() })
            .unwrap_or(default)
    };

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        "FLOAT" => (0x3E, None, None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        // DATETIM4TYPE; 0x3F is the legacy NUMERIC type id.
        "SMALLDATETIME" => (0x3A, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(sized_len(8000)), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                // Two bytes per UTF-16 code unit; saturate instead of overflowing.
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(sized_len(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale): (u8, u8) = params.map_or((18, 0), |p| {
                let mut parts = p.split(',').map(str::trim);
                let precision = parts.next().and_then(|s| s.parse().ok()).unwrap_or(18);
                let scale = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
                (precision, scale)
            });
            let precision = precision.clamp(1, MAX_DECIMAL_PRECISION);
            let scale = scale.min(precision);
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        _ => (0xE7, Some(8000), None, None),
    }
}
/// Summary of a bulk load, as produced by `BulkInsert::result`.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Number of rows encoded (taken from the encoder's `total_rows`).
    pub rows_affected: u64,
    /// Copied from the encoder's `batches_committed` counter.
    pub batches_committed: u32,
    /// `BulkInsert::result` in this module always sets this to `false`.
    pub has_errors: bool,
}
/// Builder for the `INSERT BULK` statement that precedes a bulk-load stream.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    /// Target table; validated as a dotted name of at most four parts when
    /// the statement is built.
    table_name: String,
    /// Columns rendered as `name type` pairs in the statement.
    columns: Vec<BulkColumn>,
    /// Options rendered into the statement's `WITH (...)` clause.
    options: BulkOptions,
}
impl BulkInsertBuilder {
pub fn new<S: Into<String>>(table_name: S) -> Self {
Self {
table_name: table_name.into(),
columns: Vec::new(),
options: BulkOptions::default(),
}
}
#[must_use]
pub fn with_columns(mut self, column_names: &[&str]) -> Self {
self.columns = column_names
.iter()
.enumerate()
.map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
.collect();
self
}
#[must_use]
pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
self.columns = columns;
self
}
#[must_use]
pub fn with_options(mut self, options: BulkOptions) -> Self {
self.options = options;
self
}
#[must_use]
pub fn batch_size(mut self, size: usize) -> Self {
self.options.batch_size = size;
self
}
#[must_use]
pub fn table_lock(mut self, enabled: bool) -> Self {
self.options.table_lock = enabled;
self
}
#[must_use]
pub fn fire_triggers(mut self, enabled: bool) -> Self {
self.options.fire_triggers = enabled;
self
}
pub fn table_name(&self) -> &str {
&self.table_name
}
pub fn columns(&self) -> &[BulkColumn] {
&self.columns
}
pub fn options(&self) -> &BulkOptions {
&self.options
}
pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
validate_qualified_identifier(&self.table_name)?;
for col in &self.columns {
validate_identifier(&col.name)?;
}
let mut sql = format!("INSERT BULK {}", self.table_name);
if !self.columns.is_empty() {
sql.push_str(" (");
let cols: Vec<String> = self
.columns
.iter()
.map(|c| {
validate_sql_type(&c.sql_type)?;
Ok(format!("{} {}", c.name, c.sql_type))
})
.collect::<Result<Vec<_>, Error>>()?;
sql.push_str(&cols.join(", "));
sql.push(')');
}
let mut hints: Vec<String> = Vec::new();
if self.options.check_constraints {
hints.push("CHECK_CONSTRAINTS".to_string());
}
if self.options.fire_triggers {
hints.push("FIRE_TRIGGERS".to_string());
}
if self.options.keep_nulls {
hints.push("KEEP_NULLS".to_string());
}
if self.options.table_lock {
hints.push("TABLOCK".to_string());
}
if self.options.batch_size > 0 {
hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
}
if let Some(ref order) = self.options.order_hint {
for col_name in order {
validate_identifier(col_name)?;
}
hints.push(format!("ORDER({})", order.join(", ")));
}
if !hints.is_empty() {
sql.push_str(" WITH (");
sql.push_str(&hints.join(", "));
sql.push(')');
}
Ok(sql)
}
}
fn validate_sql_type(type_str: &str) -> Result<(), Error> {
#[allow(clippy::expect_used)] static SQL_TYPE_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
if type_str.is_empty() {
return Err(Error::Config("SQL type cannot be empty".into()));
}
if !SQL_TYPE_RE.is_match(type_str) {
return Err(Error::Config(format!(
"invalid SQL type '{type_str}': contains disallowed characters"
)));
}
Ok(())
}
fn validate_identifier(name: &str) -> Result<(), Error> {
#[allow(clippy::expect_used)] static IDENTIFIER_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").expect("valid regex"));
if name.is_empty() {
return Err(Error::InvalidIdentifier(
"identifier cannot be empty".into(),
));
}
if !IDENTIFIER_RE.is_match(name) {
return Err(Error::InvalidIdentifier(format!(
"invalid identifier '{name}': must start with letter/underscore, \
contain only alphanumerics/_/@/#/$, and be 1-128 characters"
)));
}
Ok(())
}
fn validate_qualified_identifier(name: &str) -> Result<(), Error> {
if name.is_empty() {
return Err(Error::InvalidIdentifier(
"identifier cannot be empty".into(),
));
}
let parts: Vec<&str> = name.split('.').collect();
if parts.len() > 4 {
return Err(Error::InvalidIdentifier(format!(
"invalid qualified identifier '{name}': too many parts (max 4: server.catalog.schema.object)"
)));
}
for part in &parts {
validate_identifier(part)?;
}
Ok(())
}
/// Incremental encoder for a bulk-load token stream: a COLMETADATA token,
/// one ROW token per appended row and a trailing DONE token, split into
/// BulkLoad packets on demand.
pub struct BulkInsert {
    /// Column descriptions shared by the metadata writer and the row encoder.
    columns: Arc<[BulkColumn]>,
    /// Encoded tokens not yet handed out as packets.
    buffer: BytesMut,
    /// Row counter compared against `batch_size` by `should_flush`.
    rows_in_batch: usize,
    /// Total rows appended; written into the DONE token and reported by `result`.
    total_rows: u64,
    /// Threshold used by `should_flush`; `0` disables it.
    batch_size: usize,
    /// Reported by `result`.
    /// NOTE(review): no code in this module increments it — confirm the intended owner.
    batches_committed: u32,
    /// Sequence number stamped into each packet header; starts at 1 and wraps.
    packet_id: u8,
}
impl BulkInsert {
pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
let mut bulk = Self {
columns: columns.into(),
buffer: BytesMut::with_capacity(64 * 1024), rows_in_batch: 0,
total_rows: 0,
batch_size,
batches_committed: 0,
packet_id: 1,
};
bulk.write_colmetadata();
bulk
}
fn write_colmetadata(&mut self) {
let buf = &mut self.buffer;
buf.put_u8(TokenType::ColMetaData as u8);
buf.put_u16_le(self.columns.len() as u16);
for col in self.columns.iter() {
buf.put_u32_le(0);
let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
buf.put_u16_le(flags);
buf.put_u8(col.type_id);
match col.type_id {
0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}
0xE7 | 0xA7 | 0xA5 | 0xAD => {
let max_len = col.max_length.unwrap_or(8000);
if max_len == 0xFFFF {
buf.put_u16_le(0xFFFF);
} else {
buf.put_u16_le(max_len as u16);
}
if col.type_id == 0xE7 || col.type_id == 0xA7 {
buf.put_u32_le(0x0409_0904); buf.put_u8(52); }
}
0x6C | 0x6A => {
let precision = col.precision.unwrap_or(18);
let len = decimal_byte_length(precision);
buf.put_u8(len);
buf.put_u8(precision);
buf.put_u8(col.scale.unwrap_or(0));
}
0x29..=0x2B => {
buf.put_u8(col.scale.unwrap_or(7));
}
0x24 => {
buf.put_u8(16);
}
_ => {
if let Some(len) = col.max_length {
if len <= 0xFFFF {
buf.put_u16_le(len as u16);
}
}
}
}
let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
buf.put_u8(name_utf16.len() as u8);
for code_unit in name_utf16 {
buf.put_u16_le(code_unit);
}
}
}
pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
if values.len() != self.columns.len() {
return Err(Error::Config(format!(
"expected {} values, got {}",
self.columns.len(),
values.len()
)));
}
let sql_values: Result<Vec<SqlValue>, TypeError> =
values.iter().map(|v| v.to_sql()).collect();
let sql_values = sql_values.map_err(Error::from)?;
self.write_row(&sql_values)?;
self.rows_in_batch += 1;
self.total_rows += 1;
Ok(())
}
pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
if values.len() != self.columns.len() {
return Err(Error::Config(format!(
"expected {} values, got {}",
self.columns.len(),
values.len()
)));
}
self.write_row(values)?;
self.rows_in_batch += 1;
self.total_rows += 1;
Ok(())
}
fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
self.buffer.put_u8(TokenType::Row as u8);
let columns: Vec<_> = self.columns.iter().cloned().collect();
for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
self.encode_column_value(col, value)
.map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
}
Ok(())
}
fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
let buf = &mut self.buffer;
let is_plp_type =
col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
match value {
SqlValue::Null => {
match col.type_id {
0xE7 | 0xA7 | 0xA5 | 0xAD => {
if is_plp_type {
buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
} else {
buf.put_u16_le(0xFFFF);
}
}
0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
buf.put_u8(0);
}
_ => {
if col.nullable {
buf.put_u8(0);
} else {
return Err(TypeError::UnexpectedNull);
}
}
}
}
SqlValue::Bool(v) => {
buf.put_u8(1); buf.put_u8(if *v { 1 } else { 0 });
}
SqlValue::TinyInt(v) => {
buf.put_u8(1); buf.put_u8(*v);
}
SqlValue::SmallInt(v) => {
buf.put_u8(2); buf.put_i16_le(*v);
}
SqlValue::Int(v) => {
buf.put_u8(4); buf.put_i32_le(*v);
}
SqlValue::BigInt(v) => {
buf.put_u8(8); buf.put_i64_le(*v);
}
SqlValue::Float(v) => {
buf.put_u8(4); buf.put_f32_le(*v);
}
SqlValue::Double(v) => {
buf.put_u8(8); buf.put_f64_le(*v);
}
SqlValue::String(s) => {
let utf16: Vec<u16> = s.encode_utf16().collect();
let byte_len = utf16.len() * 2;
if is_plp_type {
encode_plp_string(&utf16, buf);
} else if byte_len > 0xFFFF {
return Err(TypeError::BufferTooSmall {
needed: byte_len,
available: 0xFFFF,
});
} else {
buf.put_u16_le(byte_len as u16);
for code_unit in utf16 {
buf.put_u16_le(code_unit);
}
}
}
SqlValue::Binary(b) => {
if is_plp_type {
encode_plp_binary(b, buf);
} else if b.len() > 0xFFFF {
return Err(TypeError::BufferTooSmall {
needed: b.len(),
available: 0xFFFF,
});
} else {
buf.put_u16_le(b.len() as u16);
buf.put_slice(b);
}
}
#[cfg(feature = "decimal")]
SqlValue::Decimal(d) => {
let precision = col.precision.unwrap_or(18);
let len = decimal_byte_length(precision);
buf.put_u8(len);
buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
let mantissa = d.mantissa().unsigned_abs();
let mantissa_bytes = mantissa.to_le_bytes();
buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
}
#[cfg(feature = "uuid")]
SqlValue::Uuid(u) => {
buf.put_u8(16); mssql_types::encode::encode_uuid(*u, buf);
}
#[cfg(feature = "chrono")]
SqlValue::Date(d) => {
buf.put_u8(3); mssql_types::encode::encode_date(*d, buf);
}
#[cfg(feature = "chrono")]
SqlValue::Time(t) => {
let scale = col.scale.unwrap_or(7);
let len = time_byte_length(scale);
buf.put_u8(len);
encode_time_with_scale(*t, scale, buf);
}
#[cfg(feature = "chrono")]
SqlValue::DateTime(dt) => {
let scale = col.scale.unwrap_or(7);
let time_len = time_byte_length(scale);
let total_len = time_len + 3;
buf.put_u8(total_len);
encode_time_with_scale(dt.time(), scale, buf);
mssql_types::encode::encode_date(dt.date(), buf);
}
#[cfg(feature = "chrono")]
SqlValue::DateTimeOffset(dto) => {
let scale = col.scale.unwrap_or(7);
let time_len = time_byte_length(scale);
let total_len = time_len + 3 + 2;
buf.put_u8(total_len);
encode_time_with_scale(dto.time(), scale, buf);
mssql_types::encode::encode_date(dto.date_naive(), buf);
use chrono::Offset;
let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
buf.put_i16_le(offset_minutes);
}
#[cfg(feature = "json")]
SqlValue::Json(j) => {
let s = j.to_string();
encode_nvarchar_value(&s, buf)?;
}
SqlValue::Xml(x) => {
encode_nvarchar_value(x, buf)?;
}
SqlValue::Tvp(_) => {
return Err(TypeError::UnsupportedConversion {
from: "TVP".to_string(),
to: "bulk copy value",
});
}
_ => {
return Err(TypeError::UnsupportedConversion {
from: value.type_name().to_string(),
to: "bulk copy value",
});
}
}
Ok(())
}
}
/// Writes `s` as UTF-16LE preceded by its byte length as a little-endian u16.
///
/// # Errors
/// `TypeError::BufferTooSmall` when the encoded form exceeds 0xFFFF bytes;
/// nothing is written in that case.
fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
    let units: Vec<u16> = s.encode_utf16().collect();
    let byte_len = units.len() * 2;
    if byte_len <= 0xFFFF {
        buf.put_u16_le(byte_len as u16);
        units.iter().for_each(|&unit| buf.put_u16_le(unit));
        Ok(())
    } else {
        Err(TypeError::BufferTooSmall {
            needed: byte_len,
            available: 0xFFFF,
        })
    }
}
/// Writes UTF-16 code units as a PLP value with a known total length.
///
/// The layout is the total byte length (u64 LE), then one chunk (u32 LE
/// length followed by the UTF-16LE bytes) when non-empty, then a zero u32
/// terminator.
fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
    let total = utf16.len() * 2;
    buf.put_u64_le(total as u64);
    if total != 0 {
        buf.put_u32_le(total as u32);
        utf16
            .iter()
            .for_each(|unit| buf.put_slice(&unit.to_le_bytes()));
    }
    buf.put_u32_le(0);
}
/// Writes raw bytes as a PLP value with a known total length.
///
/// The layout is the total length (u64 LE), then one chunk (u32 LE length
/// plus data) when non-empty, then a zero u32 terminator.
fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
    let total = data.len();
    buf.put_u64_le(total as u64);
    if total == 0 {
        buf.put_u32_le(0);
        return;
    }
    buf.put_u32_le(total as u32);
    buf.put_slice(data);
    buf.put_u32_le(0);
}
#[cfg(feature = "chrono")]
/// Writes `time` as the number of `10^-scale`-second ticks since midnight.
///
/// The tick count is written little-endian, using `time_byte_length(scale)`
/// bytes.
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;
    let total_nanos = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000
        + u64::from(time.nanosecond());
    let ticks = total_nanos / time_scale_divisor(scale);
    let width = usize::from(time_byte_length(scale));
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
impl BulkInsert {
    /// Size of every emitted TDS packet, header included.
    const MAX_PACKET_SIZE: usize = 4096;
    /// Size of the TDS packet header.
    const HEADER_SIZE: usize = 8;

    /// Appends the DONE token that ends the bulk-load stream.
    ///
    /// The status is 0x0010 (DONE_COUNT in MS-TDS), CurCmd is 0, and the row
    /// count is `total_rows`.
    fn write_done(&mut self) {
        let buf = &mut self.buffer;
        buf.put_u8(TokenType::Done as u8);
        let status = DoneStatus::from_bits(0x0010);
        buf.put_u16_le(status.to_bits());
        buf.put_u16_le(0);
        buf.put_u64_le(self.total_rows);
    }

    /// Hands out buffered data as BulkLoad packets while the stream is still
    /// open.
    ///
    /// Only full-size packets are produced, and none is marked
    /// END_OF_MESSAGE, because more rows and the DONE token follow. A trailing
    /// partial packet stays buffered until the next call or `finish_packets`.
    /// Resets `rows_in_batch`, so `should_flush` fires again only after
    /// another `batch_size` rows.
    pub fn take_packets(&mut self) -> Vec<BytesMut> {
        self.rows_in_batch = 0;
        self.packetize(false)
    }

    /// Total rows appended so far.
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// Rows appended since the last `take_packets` / `finish_packets`.
    pub fn rows_in_batch(&self) -> usize {
        self.rows_in_batch
    }

    /// `true` once `batch_size` (when non-zero) rows are pending.
    pub fn should_flush(&self) -> bool {
        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
    }

    /// Appends the DONE token and returns all remaining data as packets.
    ///
    /// The last packet is marked END_OF_MESSAGE.
    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
        self.write_done();
        self.rows_in_batch = 0;
        self.packetize(true)
    }

    /// Summary of the load so far; `has_errors` is always `false` here.
    pub fn result(&self) -> BulkInsertResult {
        BulkInsertResult {
            rows_affected: self.total_rows,
            batches_committed: self.batches_committed,
            has_errors: false,
        }
    }

    /// Splits buffered data into BulkLoad packets.
    ///
    /// With `end_of_message`, everything is emitted and the final packet gets
    /// END_OF_MESSAGE. Otherwise only whole packets are emitted and the
    /// remainder stays buffered. Each packet consumes one `packet_id`.
    fn packetize(&mut self, end_of_message: bool) -> Vec<BytesMut> {
        let max_payload = Self::MAX_PACKET_SIZE - Self::HEADER_SIZE;
        let buffered = self.buffer.len();
        let ready = if end_of_message {
            buffered
        } else {
            buffered - buffered % max_payload
        };
        let data = self.buffer.split_to(ready);

        let chunk_count = data.chunks(max_payload).len();
        let mut packets = Vec::with_capacity(chunk_count);
        for (index, chunk) in data.chunks(max_payload).enumerate() {
            let is_last = end_of_message && index + 1 == chunk_count;
            let mut packet = BytesMut::with_capacity(Self::HEADER_SIZE + chunk.len());
            let header = PacketHeader {
                packet_type: PacketType::BulkLoad,
                status: if is_last {
                    PacketStatus::END_OF_MESSAGE
                } else {
                    PacketStatus::NORMAL
                },
                length: (Self::HEADER_SIZE + chunk.len()) as u16,
                spid: 0,
                packet_id: self.packet_id,
                window: 0,
            };
            header.encode(&mut packet);
            packet.put_slice(chunk);
            packets.push(packet);
            self.packet_id = self.packet_id.wrapping_add(1);
        }
        packets
    }
}
/// Wire length of a DECIMAL/NUMERIC value of the given precision.
///
/// The length is one sign byte plus a 4/8/12/16-byte little-endian magnitude.
/// Precisions outside 1..=38 (including 0) use the 16-byte form.
fn decimal_byte_length(precision: u8) -> u8 {
    let magnitude_bytes = match precision {
        1..=9 => 4,
        10..=19 => 8,
        20..=28 => 12,
        _ => 16,
    };
    1 + magnitude_bytes
}
#[cfg(feature = "chrono")]
/// Number of bytes used for the time-of-day tick count at `scale`.
///
/// Scales 0–2 use 3 bytes, 3–4 use 4 bytes, and anything higher uses 5.
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
#[cfg(feature = "chrono")]
/// Nanoseconds per tick at `scale`, i.e. `10^(9 - scale)`.
///
/// Scales above 7 are treated as 7 (100 ns ticks).
fn time_scale_divisor(scale: u8) -> u64 {
    let effective = u32::from(scale.min(7));
    10_u64.pow(9 - effective)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_bulk_options_default() {
        let opts = BulkOptions::default();
        assert_eq!(opts.batch_size, 0);
        assert!(opts.check_constraints);
        assert!(!opts.fire_triggers);
        assert!(opts.keep_nulls);
        assert!(!opts.table_lock);
    }

    #[test]
    fn test_bulk_column_creation() {
        let col = BulkColumn::new("id", "INT", 0);
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, 0x38);
        assert!(col.nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x38);
        assert!(len.is_none());
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ])
            .table_lock(true);
        let sql = builder.build_insert_bulk_statement().unwrap();
        assert!(sql.contains("INSERT BULK dbo.Users"));
        assert!(sql.contains("TABLOCK"));
    }

    #[test]
    fn test_insert_bulk_statement_hints() {
        // Default options contribute CHECK_CONSTRAINTS and KEEP_NULLS.
        let sql = BulkInsertBuilder::new("t")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)])
            .batch_size(500)
            .build_insert_bulk_statement()
            .unwrap();
        assert_eq!(
            sql,
            "INSERT BULK t (id INT) WITH (CHECK_CONSTRAINTS, KEEP_NULLS, ROWS_PER_BATCH = 500)"
        );
    }

    #[test]
    fn test_bulk_insert_rejects_injection() {
        let builder = BulkInsertBuilder::new("table;DROP TABLE users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);
        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_validates_column_names() {
        let builder = BulkInsertBuilder::new("Users").with_typed_columns(vec![BulkColumn::new(
            "col;DROP TABLE x",
            "INT",
            0,
        )]);
        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_accepts_qualified_names() {
        let builder = BulkInsertBuilder::new("catalog.dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);
        assert!(builder.build_insert_bulk_statement().is_ok());
    }

    #[test]
    fn test_qualified_identifier_limits() {
        assert!(validate_qualified_identifier("a.b.c.d").is_ok());
        assert!(validate_qualified_identifier("a.b.c.d.e").is_err());
        assert!(validate_qualified_identifier("a..b").is_err());
        assert!(validate_qualified_identifier("").is_err());
    }

    #[test]
    fn test_validate_sql_type() {
        assert!(validate_sql_type("DECIMAL(10, 2)").is_ok());
        assert!(validate_sql_type("NVARCHAR(MAX)").is_ok());
        assert!(validate_sql_type("").is_err());
        assert!(validate_sql_type("INT; DROP TABLE x").is_err());
    }

    #[test]
    fn test_bulk_insert_creation() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];
        let bulk = BulkInsert::new(columns, 1000);
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
    }

    #[test]
    fn test_send_row_values_counts_and_arity() {
        let mut bulk = BulkInsert::new(vec![BulkColumn::new("id", "INT", 0)], 2);
        bulk.send_row_values(&[SqlValue::Int(1)]).unwrap();
        assert!(!bulk.should_flush());
        bulk.send_row_values(&[SqlValue::Int(2)]).unwrap();
        assert!(bulk.should_flush());
        assert_eq!(bulk.total_rows(), 2);
        // Wrong value count is rejected without touching the counters.
        assert!(bulk.send_row_values(&[]).is_err());
        assert_eq!(bulk.total_rows(), 2);
    }

    #[test]
    fn test_decimal_byte_length() {
        assert_eq!(decimal_byte_length(5), 5);
        assert_eq!(decimal_byte_length(15), 9);
        assert_eq!(decimal_byte_length(25), 13);
        assert_eq!(decimal_byte_length(35), 17);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        assert_eq!(time_byte_length(0), 3);
        assert_eq!(time_byte_length(3), 4);
        assert_eq!(time_byte_length(7), 5);
    }

    #[test]
    fn test_plp_string_encoding() {
        let mut buf = BytesMut::new();
        let text = "Hello";
        let utf16: Vec<u16> = text.encode_utf16().collect();
        encode_plp_string(&utf16, &mut buf);
        // total length + chunk length + payload + terminator
        assert_eq!(buf.len(), 8 + 4 + 10 + 4);
        assert_eq!(&buf[0..8], &10u64.to_le_bytes());
        assert_eq!(&buf[8..12], &10u32.to_le_bytes());
        let payload: Vec<u8> = utf16.iter().flat_map(|u| u.to_le_bytes()).collect();
        assert_eq!(&buf[12..22], payload.as_slice());
        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let mut buf = BytesMut::new();
        let data = b"test binary data";
        encode_plp_binary(data, &mut buf);
        assert_eq!(buf.len(), 8 + 4 + 16 + 4);
        assert_eq!(&buf[0..8], &16u64.to_le_bytes());
        assert_eq!(&buf[8..12], &16u32.to_le_bytes());
        assert_eq!(&buf[12..28], data);
        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let mut buf = BytesMut::new();
        let utf16: Vec<u16> = "".encode_utf16().collect();
        encode_plp_string(&utf16, &mut buf);
        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();
        encode_plp_binary(&[], &mut buf);
        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_parse_sql_type_max() {
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF));
        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));
        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
    }
}