1use bytes::{BufMut, BytesMut};
53use once_cell::sync::Lazy;
54use regex::Regex;
55use std::sync::Arc;
56
57use mssql_types::{SqlValue, ToSql, TypeError};
58use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
59use tds_protocol::token::{DoneStatus, TokenType};
60
61use crate::error::Error;
62
/// Options controlling how an `INSERT BULK` statement is built.
///
/// Each flag maps to a hint inside the statement's `WITH (...)` clause, as assembled
/// by `BulkInsertBuilder::build_insert_bulk_statement`.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Emitted as `ROWS_PER_BATCH = n` when greater than zero; `0` omits the hint.
    pub batch_size: usize,

    /// Emits the `CHECK_CONSTRAINTS` hint when `true`.
    pub check_constraints: bool,

    /// Emits the `FIRE_TRIGGERS` hint when `true`.
    pub fire_triggers: bool,

    /// Emits the `KEEP_NULLS` hint when `true`.
    pub keep_nulls: bool,

    /// Emits the `TABLOCK` hint when `true`.
    pub table_lock: bool,

    /// Column names rendered into an `ORDER(...)` hint; each name is validated as an
    /// identifier before being spliced into the statement.
    pub order_hint: Option<Vec<String>>,

    /// NOTE(review): not read anywhere in this module's visible code — confirm whether
    /// any caller honors it.
    pub max_errors: u32,
}
110
111impl Default for BulkOptions {
112 fn default() -> Self {
113 Self {
114 batch_size: 0,
115 check_constraints: true,
116 fire_triggers: false,
117 keep_nulls: true,
118 table_lock: false,
119 order_hint: None,
120 max_errors: 0,
121 }
122 }
123}
124
/// Describes one destination column of a bulk load.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name; emitted verbatim in `INSERT BULK` and in COLMETADATA.
    pub name: String,
    /// SQL type text as supplied by the caller (e.g. `"NVARCHAR(100)"`).
    pub sql_type: String,
    /// Whether the column accepts NULL; `BulkColumn::new` sets this to `true`.
    pub nullable: bool,
    /// Caller-supplied position (zero-based when assigned by
    /// `BulkInsertBuilder::with_columns`).
    pub ordinal: usize,
    /// TDS type byte derived from `sql_type` by `parse_sql_type`.
    type_id: u8,
    /// Maximum byte length for length-bounded types; `0xFFFF` marks `(MAX)`.
    max_length: Option<u32>,
    /// Numeric precision (set for DECIMAL/NUMERIC).
    precision: Option<u8>,
    /// Scale: DECIMAL/NUMERIC scale, or fractional-second scale for
    /// TIME/DATETIME2/DATETIMEOFFSET.
    scale: Option<u8>,
}
145
146impl BulkColumn {
147 pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Self {
149 let sql_type_str: String = sql_type.into();
150 let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
151
152 Self {
153 name: name.into(),
154 sql_type: sql_type_str,
155 nullable: true,
156 ordinal,
157 type_id,
158 max_length,
159 precision,
160 scale,
161 }
162 }
163
164 #[must_use]
166 pub fn with_nullable(mut self, nullable: bool) -> Self {
167 self.nullable = nullable;
168 self
169 }
170}
171
/// Maps SQL type text (e.g. `"NVARCHAR(100)"`, `"decimal(10, 2)"`) to TDS metadata.
///
/// Returns `(type_id, max_length, precision, scale)`:
/// - `max_length` is a byte length; `0xFFFF` marks a `(MAX)` type. NVARCHAR/NCHAR
///   character counts are doubled (UTF-16), saturating instead of overflowing.
/// - `precision`/`scale` are set for DECIMAL/NUMERIC; `scale` alone is set for
///   TIME/DATETIME2/DATETIMEOFFSET (default 7).
///
/// Matching is case-insensitive and tolerates whitespace around the base name and
/// parameters (`"VARCHAR (50)"`, `"NVARCHAR( MAX )"`). Unknown types fall back to
/// NVARCHAR with an 8000-byte length.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();

    // Split "BASE(params)" into its base name and the parameter text.
    let (base, params) = match upper.find('(') {
        Some(paren_pos) => (
            upper[..paren_pos].trim_end(),
            Some(upper[paren_pos + 1..].trim_end_matches(')').trim()),
        ),
        None => (upper.as_str(), None),
    };

    // Fractional-second scale for the time-bearing types; SQL Server's default is 7.
    let time_scale = || params.and_then(|p| p.parse::<u8>().ok()).unwrap_or(7);

    // Byte length for the single-byte char/binary families; `MAX` maps to 0xFFFF.
    let byte_length = || {
        params
            .and_then(|p| if p == "MAX" { Some(0xFFFF_u32) } else { p.parse().ok() })
            .unwrap_or(8000)
    };

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        "FLOAT" => (0x3E, None, None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        // DATETIM4TYPE is 0x3A in TDS (0x3F is the legacy NUMERIC type id).
        "SMALLDATETIME" => (0x3A, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_length()), None, None),
        "NVARCHAR" | "NCHAR" => match params {
            Some("MAX") => (0xE7, Some(0xFFFF), None, None),
            _ => {
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                // Two bytes per UTF-16 code unit; saturate rather than overflow on
                // absurd sizes (which the server will reject anyway).
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        },
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_length()), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = match params {
                Some(p) => {
                    let mut parts = p.split(',').map(str::trim);
                    let precision = parts.next().and_then(|s| s.parse().ok()).unwrap_or(18);
                    let scale = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
                    (precision, scale)
                }
                None => (18, 0),
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        _ => (0xE7, Some(8000), None, None),
    }
}
272
/// Summary of a bulk load, as produced by `BulkInsert::result`.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Number of rows appended to the bulk-load stream.
    pub rows_affected: u64,
    /// Copied from `BulkInsert`'s batch counter.
    /// NOTE(review): that counter is never incremented in the visible code.
    pub batches_committed: u32,
    /// `BulkInsert::result` currently always reports `false`.
    pub has_errors: bool,
}
283
/// Builder for the `INSERT BULK` statement text that describes a bulk load.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    /// Target table; validated as a (possibly qualified) identifier when the
    /// statement is built.
    table_name: String,
    /// Destination columns, in statement order.
    columns: Vec<BulkColumn>,
    /// Hints and batching options.
    options: BulkOptions,
}
291
292impl BulkInsertBuilder {
293 pub fn new<S: Into<String>>(table_name: S) -> Self {
295 Self {
296 table_name: table_name.into(),
297 columns: Vec::new(),
298 options: BulkOptions::default(),
299 }
300 }
301
302 #[must_use]
307 pub fn with_columns(mut self, column_names: &[&str]) -> Self {
308 self.columns = column_names
309 .iter()
310 .enumerate()
311 .map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
312 .collect();
313 self
314 }
315
316 #[must_use]
318 pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
319 self.columns = columns;
320 self
321 }
322
323 #[must_use]
325 pub fn with_options(mut self, options: BulkOptions) -> Self {
326 self.options = options;
327 self
328 }
329
330 #[must_use]
332 pub fn batch_size(mut self, size: usize) -> Self {
333 self.options.batch_size = size;
334 self
335 }
336
337 #[must_use]
339 pub fn table_lock(mut self, enabled: bool) -> Self {
340 self.options.table_lock = enabled;
341 self
342 }
343
344 #[must_use]
346 pub fn fire_triggers(mut self, enabled: bool) -> Self {
347 self.options.fire_triggers = enabled;
348 self
349 }
350
351 pub fn table_name(&self) -> &str {
353 &self.table_name
354 }
355
356 pub fn columns(&self) -> &[BulkColumn] {
358 &self.columns
359 }
360
361 pub fn options(&self) -> &BulkOptions {
363 &self.options
364 }
365
366 pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
373 crate::validation::validate_qualified_identifier(&self.table_name)?;
375
376 for col in &self.columns {
378 crate::validation::validate_identifier(&col.name)?;
379 }
380
381 let mut sql = format!("INSERT BULK {}", self.table_name);
382
383 if !self.columns.is_empty() {
385 sql.push_str(" (");
386 let cols: Vec<String> = self
387 .columns
388 .iter()
389 .map(|c| {
390 validate_sql_type(&c.sql_type)?;
396 Ok(format!("{} {}", c.name, c.sql_type))
397 })
398 .collect::<Result<Vec<_>, Error>>()?;
399 sql.push_str(&cols.join(", "));
400 sql.push(')');
401 }
402
403 let mut hints: Vec<String> = Vec::new();
405
406 if self.options.check_constraints {
407 hints.push("CHECK_CONSTRAINTS".to_string());
408 }
409 if self.options.fire_triggers {
410 hints.push("FIRE_TRIGGERS".to_string());
411 }
412 if self.options.keep_nulls {
413 hints.push("KEEP_NULLS".to_string());
414 }
415 if self.options.table_lock {
416 hints.push("TABLOCK".to_string());
417 }
418 if self.options.batch_size > 0 {
419 hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
420 }
421
422 if let Some(ref order) = self.options.order_hint {
423 for col_name in order {
425 crate::validation::validate_identifier(col_name)?;
426 }
427 hints.push(format!("ORDER({})", order.join(", ")));
428 }
429
430 if !hints.is_empty() {
431 sql.push_str(" WITH (");
432 sql.push_str(&hints.join(", "));
433 sql.push(')');
434 }
435
436 Ok(sql)
437 }
438}
439
440fn validate_sql_type(type_str: &str) -> Result<(), Error> {
446 #[allow(clippy::expect_used)] static SQL_TYPE_RE: Lazy<Regex> =
448 Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
449
450 if type_str.is_empty() {
451 return Err(Error::Config("SQL type cannot be empty".into()));
452 }
453
454 if !SQL_TYPE_RE.is_match(type_str) {
455 return Err(Error::Config(format!(
456 "invalid SQL type '{type_str}': contains disallowed characters"
457 )));
458 }
459
460 Ok(())
461}
462
/// Incremental encoder for a bulk-load token stream (COLMETADATA, ROW..., DONE) that
/// is split into `BulkLoad` packets on demand.
pub struct BulkInsert {
    /// Column descriptions, shared so per-row encoding can iterate them cheaply.
    columns: Arc<[BulkColumn]>,
    /// Encoded token bytes not yet handed out by `take_packets`.
    buffer: BytesMut,
    /// Rows appended so far; compared against `batch_size` by `should_flush`.
    /// NOTE(review): never reset in the visible code.
    rows_in_batch: usize,
    /// Total rows appended; written into the DONE token and reported by `result`.
    total_rows: u64,
    /// Row threshold for `should_flush`; `0` disables the signal.
    batch_size: usize,
    /// NOTE(review): initialised to 0 and never incremented in the visible code.
    batches_committed: u32,
    /// Sequence number for the next outgoing packet; starts at 1 and wraps at 255.
    packet_id: u8,
}
483
484impl BulkInsert {
485 pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
487 let mut bulk = Self {
488 columns: columns.into(),
489 buffer: BytesMut::with_capacity(64 * 1024), rows_in_batch: 0,
491 total_rows: 0,
492 batch_size,
493 batches_committed: 0,
494 packet_id: 1,
495 };
496
497 bulk.write_colmetadata();
499
500 bulk
501 }
502
503 fn write_colmetadata(&mut self) {
505 let buf = &mut self.buffer;
506
507 buf.put_u8(TokenType::ColMetaData as u8);
509
510 buf.put_u16_le(self.columns.len() as u16);
512
513 for col in self.columns.iter() {
514 buf.put_u32_le(0);
516
517 let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
519 buf.put_u16_le(flags);
520
521 buf.put_u8(col.type_id);
523
524 match col.type_id {
526 0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}
528
529 0xE7 | 0xA7 | 0xA5 | 0xAD => {
531 let max_len = col.max_length.unwrap_or(8000);
533 if max_len == 0xFFFF {
534 buf.put_u16_le(0xFFFF);
535 } else {
536 buf.put_u16_le(max_len as u16);
537 }
538
539 if col.type_id == 0xE7 || col.type_id == 0xA7 {
541 buf.put_u32_le(0x0409_0904); buf.put_u8(52); }
545 }
546
547 0x6C | 0x6A => {
549 let precision = col.precision.unwrap_or(18);
551 let len = decimal_byte_length(precision);
552 buf.put_u8(len);
553 buf.put_u8(precision);
554 buf.put_u8(col.scale.unwrap_or(0));
555 }
556
557 0x29..=0x2B => {
559 buf.put_u8(col.scale.unwrap_or(7));
560 }
561
562 0x24 => {
564 buf.put_u8(16);
565 }
566
567 _ => {
569 if let Some(len) = col.max_length {
570 if len <= 0xFFFF {
571 buf.put_u16_le(len as u16);
572 }
573 }
574 }
575 }
576
577 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
579 buf.put_u8(name_utf16.len() as u8);
580 for code_unit in name_utf16 {
581 buf.put_u16_le(code_unit);
582 }
583 }
584 }
585
586 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
597 if values.len() != self.columns.len() {
598 return Err(Error::Config(format!(
599 "expected {} values, got {}",
600 self.columns.len(),
601 values.len()
602 )));
603 }
604
605 let sql_values: Result<Vec<SqlValue>, TypeError> =
607 values.iter().map(|v| v.to_sql()).collect();
608 let sql_values = sql_values.map_err(Error::from)?;
609
610 self.write_row(&sql_values)?;
611
612 self.rows_in_batch += 1;
613 self.total_rows += 1;
614
615 Ok(())
616 }
617
618 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
620 if values.len() != self.columns.len() {
621 return Err(Error::Config(format!(
622 "expected {} values, got {}",
623 self.columns.len(),
624 values.len()
625 )));
626 }
627
628 self.write_row(values)?;
629
630 self.rows_in_batch += 1;
631 self.total_rows += 1;
632
633 Ok(())
634 }
635
636 fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
638 self.buffer.put_u8(TokenType::Row as u8);
640
641 let columns: Vec<_> = self.columns.iter().cloned().collect();
643
644 for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
646 self.encode_column_value(col, value)
647 .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
648 }
649
650 Ok(())
651 }
652
653 fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
655 let buf = &mut self.buffer;
656
657 let is_plp_type =
660 col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
661
662 match value {
663 SqlValue::Null => {
664 match col.type_id {
666 0xE7 | 0xA7 | 0xA5 | 0xAD => {
668 if is_plp_type {
669 buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
671 } else {
672 buf.put_u16_le(0xFFFF);
674 }
675 }
676 0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
678 buf.put_u8(0);
679 }
680 _ => {
682 if col.nullable {
683 buf.put_u8(0);
684 } else {
685 return Err(TypeError::UnexpectedNull);
686 }
687 }
688 }
689 }
690
691 SqlValue::Bool(v) => {
692 buf.put_u8(1); buf.put_u8(if *v { 1 } else { 0 });
694 }
695
696 SqlValue::TinyInt(v) => {
697 buf.put_u8(1); buf.put_u8(*v);
699 }
700
701 SqlValue::SmallInt(v) => {
702 buf.put_u8(2); buf.put_i16_le(*v);
704 }
705
706 SqlValue::Int(v) => {
707 buf.put_u8(4); buf.put_i32_le(*v);
709 }
710
711 SqlValue::BigInt(v) => {
712 buf.put_u8(8); buf.put_i64_le(*v);
714 }
715
716 SqlValue::Float(v) => {
717 buf.put_u8(4); buf.put_f32_le(*v);
719 }
720
721 SqlValue::Double(v) => {
722 buf.put_u8(8); buf.put_f64_le(*v);
724 }
725
726 SqlValue::String(s) => {
727 let utf16: Vec<u16> = s.encode_utf16().collect();
729 let byte_len = utf16.len() * 2;
730
731 if is_plp_type {
732 encode_plp_string(&utf16, buf);
735 } else if byte_len > 0xFFFF {
736 return Err(TypeError::BufferTooSmall {
738 needed: byte_len,
739 available: 0xFFFF,
740 });
741 } else {
742 buf.put_u16_le(byte_len as u16);
744 for code_unit in utf16 {
745 buf.put_u16_le(code_unit);
746 }
747 }
748 }
749
750 SqlValue::Binary(b) => {
751 if is_plp_type {
752 encode_plp_binary(b, buf);
754 } else if b.len() > 0xFFFF {
755 return Err(TypeError::BufferTooSmall {
757 needed: b.len(),
758 available: 0xFFFF,
759 });
760 } else {
761 buf.put_u16_le(b.len() as u16);
763 buf.put_slice(b);
764 }
765 }
766
767 #[cfg(feature = "decimal")]
769 SqlValue::Decimal(d) => {
770 let precision = col.precision.unwrap_or(18);
771 let len = decimal_byte_length(precision);
772 buf.put_u8(len);
773
774 buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
776
777 let mantissa = d.mantissa().unsigned_abs();
779 let mantissa_bytes = mantissa.to_le_bytes();
780 buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
781 }
782
783 #[cfg(feature = "uuid")]
784 SqlValue::Uuid(u) => {
785 buf.put_u8(16); mssql_types::encode::encode_uuid(*u, buf);
788 }
789
790 #[cfg(feature = "chrono")]
791 SqlValue::Date(d) => {
792 buf.put_u8(3); mssql_types::encode::encode_date(*d, buf);
794 }
795
796 #[cfg(feature = "chrono")]
797 SqlValue::Time(t) => {
798 let scale = col.scale.unwrap_or(7);
799 let len = time_byte_length(scale);
800 buf.put_u8(len);
801 encode_time_with_scale(*t, scale, buf);
803 }
804
805 #[cfg(feature = "chrono")]
806 SqlValue::DateTime(dt) => {
807 let scale = col.scale.unwrap_or(7);
808 let time_len = time_byte_length(scale);
809 let total_len = time_len + 3;
810 buf.put_u8(total_len);
811 encode_time_with_scale(dt.time(), scale, buf);
813 mssql_types::encode::encode_date(dt.date(), buf);
814 }
815
816 #[cfg(feature = "chrono")]
817 SqlValue::DateTimeOffset(dto) => {
818 let scale = col.scale.unwrap_or(7);
819 let time_len = time_byte_length(scale);
820 let total_len = time_len + 3 + 2;
821 buf.put_u8(total_len);
822 encode_time_with_scale(dto.time(), scale, buf);
824 mssql_types::encode::encode_date(dto.date_naive(), buf);
825 use chrono::Offset;
827 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
828 buf.put_i16_le(offset_minutes);
829 }
830
831 #[cfg(feature = "json")]
832 SqlValue::Json(j) => {
833 let s = j.to_string();
834 encode_nvarchar_value(&s, buf)?;
835 }
836
837 SqlValue::Xml(x) => {
838 encode_nvarchar_value(x, buf)?;
839 }
840
841 SqlValue::Tvp(_) => {
842 return Err(TypeError::UnsupportedConversion {
844 from: "TVP".to_string(),
845 to: "bulk copy value",
846 });
847 }
848 _ => {
850 return Err(TypeError::UnsupportedConversion {
851 from: value.type_name().to_string(),
852 to: "bulk copy value",
853 });
854 }
855 }
856
857 Ok(())
858 }
859}
860
861fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
863 let utf16: Vec<u16> = s.encode_utf16().collect();
864 let byte_len = utf16.len() * 2;
865
866 if byte_len > 0xFFFF {
867 return Err(TypeError::BufferTooSmall {
868 needed: byte_len,
869 available: 0xFFFF,
870 });
871 }
872
873 buf.put_u16_le(byte_len as u16);
874 for code_unit in utf16 {
875 buf.put_u16_le(code_unit);
876 }
877 Ok(())
878}
879
880fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
890 let byte_len = utf16.len() * 2;
891
892 buf.put_u64_le(byte_len as u64);
894
895 if byte_len > 0 {
896 buf.put_u32_le(byte_len as u32);
898 for code_unit in utf16 {
899 buf.put_u16_le(*code_unit);
900 }
901 }
902
903 buf.put_u32_le(0);
905}
906
907fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
916 buf.put_u64_le(data.len() as u64);
918
919 if !data.is_empty() {
920 buf.put_u32_le(data.len() as u32);
922 buf.put_slice(data);
923 }
924
925 buf.put_u32_le(0);
927}
928
/// Writes `time` as the number of `10^-scale`-second ticks since midnight, using the
/// low `time_byte_length(scale)` bytes in little-endian order.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let since_midnight_ns = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000
        + u64::from(time.nanosecond());
    let ticks = since_midnight_ns / time_scale_divisor(scale);
    let width = usize::from(time_byte_length(scale));

    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
942
943impl BulkInsert {
944 fn write_done(&mut self) {
946 let buf = &mut self.buffer;
947
948 buf.put_u8(TokenType::Done as u8);
949
950 let status = DoneStatus::from_bits(0x0010); buf.put_u16_le(status.to_bits());
953
954 buf.put_u16_le(0);
956
957 buf.put_u64_le(self.total_rows);
959 }
960
961 pub fn take_packets(&mut self) -> Vec<BytesMut> {
965 const MAX_PACKET_SIZE: usize = 4096;
966 const HEADER_SIZE: usize = 8;
967 const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
968
969 let data = self.buffer.split();
970 let mut packets = Vec::new();
971 let mut offset = 0;
972
973 while offset < data.len() {
974 let remaining = data.len() - offset;
975 let payload_size = remaining.min(MAX_PAYLOAD);
976 let is_last = offset + payload_size >= data.len();
977
978 let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
979
980 let header = PacketHeader {
982 packet_type: PacketType::BulkLoad,
983 status: if is_last {
984 PacketStatus::END_OF_MESSAGE
985 } else {
986 PacketStatus::NORMAL
987 },
988 length: (HEADER_SIZE + payload_size) as u16,
989 spid: 0,
990 packet_id: self.packet_id,
991 window: 0,
992 };
993
994 header.encode(&mut packet);
995
996 packet.put_slice(&data[offset..offset + payload_size]);
998
999 packets.push(packet);
1000 offset += payload_size;
1001 self.packet_id = self.packet_id.wrapping_add(1);
1002 }
1003
1004 packets
1005 }
1006
1007 pub fn total_rows(&self) -> u64 {
1009 self.total_rows
1010 }
1011
1012 pub fn rows_in_batch(&self) -> usize {
1014 self.rows_in_batch
1015 }
1016
1017 pub fn should_flush(&self) -> bool {
1019 self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1020 }
1021
1022 pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1025 self.write_done();
1026 self.take_packets()
1027 }
1028
1029 pub fn result(&self) -> BulkInsertResult {
1031 BulkInsertResult {
1032 rows_affected: self.total_rows,
1033 batches_committed: self.batches_committed,
1034 has_errors: false,
1035 }
1036 }
1037}
1038
/// Wire length of a DECIMAL/NUMERIC value: one sign byte plus a 4-, 8-, 12- or
/// 16-byte mantissa chosen by precision. Out-of-range precisions (0 or above 38) use
/// the widest form.
fn decimal_byte_length(precision: u8) -> u8 {
    let mantissa_bytes = match precision {
        1..=9 => 4,
        10..=19 => 8,
        20..=28 => 12,
        _ => 16,
    };
    1 + mantissa_bytes
}
1049
/// Byte width of a TIME component at the given fractional-second `scale`:
/// 3 bytes for scales 0-2, 4 bytes for 3-4, and 5 bytes otherwise.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1060
/// Nanoseconds per tick at the given fractional-second `scale` (`10^(9 - scale)`).
/// Scales above 7 are treated as 7.
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = u32::from(scale.min(7));
    10_u64.pow(9 - effective_scale)
}
1076
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_bulk_options_default() {
        let opts = BulkOptions::default();
        assert_eq!(opts.batch_size, 0);
        assert!(opts.check_constraints);
        assert!(!opts.fire_triggers);
        assert!(opts.keep_nulls);
        assert!(!opts.table_lock);
        assert!(opts.order_hint.is_none());
        assert_eq!(opts.max_errors, 0);
    }

    #[test]
    fn test_bulk_column_creation() {
        let col = BulkColumn::new("id", "INT", 0);
        assert_eq!(col.name, "id");
        assert_eq!(col.sql_type, "INT");
        assert_eq!(col.ordinal, 0);
        assert_eq!(col.type_id, 0x38);
        assert!(col.nullable);
        assert!(!col.with_nullable(false).nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x38);
        assert!(len.is_none());

        // NVARCHAR lengths are character counts; metadata carries bytes (UTF-16).
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(8000));

        let (type_id, len, _, _) = parse_sql_type("VARCHAR(50)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(50));

        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));

        let (_, _, prec, scale) = parse_sql_type("DECIMAL");
        assert_eq!(prec, Some(18));
        assert_eq!(scale, Some(0));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ])
            .table_lock(true);

        let sql = builder.build_insert_bulk_statement().unwrap();
        assert_eq!(
            sql,
            "INSERT BULK dbo.Users (id INT, name NVARCHAR(100)) \
             WITH (CHECK_CONSTRAINTS, KEEP_NULLS, TABLOCK)"
        );
    }

    #[test]
    fn test_insert_bulk_statement_optional_hints() {
        let sql = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)])
            .with_options(BulkOptions {
                order_hint: Some(vec!["id".to_string()]),
                ..BulkOptions::default()
            })
            .batch_size(500)
            .fire_triggers(true)
            .build_insert_bulk_statement()
            .unwrap();

        assert!(sql.contains("FIRE_TRIGGERS"));
        assert!(sql.contains("ROWS_PER_BATCH = 500"));
        assert!(sql.contains("ORDER(id)"));
    }

    #[test]
    fn test_bulk_insert_rejects_injection() {
        let builder = BulkInsertBuilder::new("table;DROP TABLE users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);

        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_validates_column_names() {
        let builder = BulkInsertBuilder::new("Users").with_typed_columns(vec![BulkColumn::new(
            "col;DROP TABLE x",
            "INT",
            0,
        )]);

        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_rejects_invalid_sql_type() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT; DROP TABLE x", 0)]);

        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_accepts_qualified_names() {
        let builder = BulkInsertBuilder::new("catalog.dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);

        assert!(builder.build_insert_bulk_statement().is_ok());
    }

    #[test]
    fn test_bulk_insert_creation() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];

        let bulk = BulkInsert::new(columns, 1000);
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
        // COLMETADATA is written eagerly by the constructor.
        assert_eq!(bulk.buffer[0], TokenType::ColMetaData as u8);
    }

    #[test]
    fn test_bulk_insert_counts_rows_and_flushes() {
        let mut bulk = BulkInsert::new(vec![BulkColumn::new("id", "INT", 0)], 2);

        bulk.send_row_values(&[SqlValue::Int(1)]).unwrap();
        assert!(!bulk.should_flush());
        bulk.send_row_values(&[SqlValue::Int(2)]).unwrap();
        assert!(bulk.should_flush());
        assert_eq!(bulk.total_rows(), 2);

        // Wrong arity is rejected without touching the counters.
        assert!(bulk.send_row_values(&[]).is_err());
        assert_eq!(bulk.total_rows(), 2);

        let packets = bulk.finish_packets();
        assert_eq!(packets.len(), 1);
        assert!(bulk.buffer.is_empty());
        assert_eq!(bulk.result().rows_affected, 2);
    }

    #[test]
    fn test_decimal_byte_length() {
        assert_eq!(decimal_byte_length(5), 5);
        assert_eq!(decimal_byte_length(15), 9);
        assert_eq!(decimal_byte_length(25), 13);
        assert_eq!(decimal_byte_length(35), 17);
        // Band boundaries.
        assert_eq!(decimal_byte_length(9), 5);
        assert_eq!(decimal_byte_length(10), 9);
        assert_eq!(decimal_byte_length(19), 9);
        assert_eq!(decimal_byte_length(20), 13);
        assert_eq!(decimal_byte_length(28), 13);
        assert_eq!(decimal_byte_length(29), 17);
        assert_eq!(decimal_byte_length(38), 17);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        assert_eq!(time_byte_length(0), 3);
        assert_eq!(time_byte_length(3), 4);
        assert_eq!(time_byte_length(7), 5);
    }

    #[test]
    fn test_plp_string_encoding() {
        let mut buf = BytesMut::new();
        let text = "Hello";
        let utf16: Vec<u16> = text.encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // total length (8) + chunk length (4) + payload (10) + terminator (4)
        assert_eq!(buf.len(), 8 + 4 + 10 + 4);
        assert_eq!(&buf[0..8], &10u64.to_le_bytes());
        assert_eq!(&buf[8..12], &10u32.to_le_bytes());

        let expected_payload: Vec<u8> = utf16.iter().flat_map(|u| u.to_le_bytes()).collect();
        assert_eq!(&buf[12..22], expected_payload.as_slice());

        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let mut buf = BytesMut::new();
        let data = b"test binary data";

        encode_plp_binary(data, &mut buf);

        assert_eq!(buf.len(), 8 + 4 + 16 + 4);
        assert_eq!(&buf[0..8], &16u64.to_le_bytes());
        assert_eq!(&buf[8..12], &16u32.to_le_bytes());
        assert_eq!(&buf[12..28], data);
        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let mut buf = BytesMut::new();
        let utf16: Vec<u16> = "".encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // Empty values carry no chunk: just the total length and the terminator.
        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();

        encode_plp_binary(&[], &mut buf);

        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_parse_sql_type_max() {
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
    }
}