1use bytes::{BufMut, BytesMut};
53use once_cell::sync::Lazy;
54use regex::Regex;
55use std::sync::Arc;
56
57use mssql_types::{SqlValue, ToSql, TypeError};
58use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
59use tds_protocol::token::{DoneStatus, TokenType};
60
61use crate::error::Error;
62
/// Tunable options for a bulk insert.
///
/// The boolean flags and `batch_size` are turned into hints in the
/// `WITH (...)` clause produced by
/// `BulkInsertBuilder::build_insert_bulk_statement`.
#[derive(Clone, Debug)]
pub struct BulkOptions {
    /// Rows per batch. Emitted as `ROWS_PER_BATCH = n` when greater than
    /// zero; also used by `BulkInsert::should_flush` when passed to it.
    pub batch_size: usize,

    /// Emit the `CHECK_CONSTRAINTS` hint.
    pub check_constraints: bool,

    /// Emit the `FIRE_TRIGGERS` hint.
    pub fire_triggers: bool,

    /// Emit the `KEEP_NULLS` hint.
    pub keep_nulls: bool,

    /// Emit the `TABLOCK` hint.
    pub table_lock: bool,

    /// Column names for an `ORDER(...)` hint; each name is validated as an
    /// identifier before being emitted.
    pub order_hint: Option<Vec<String>>,

    /// NOTE(review): not read anywhere in this file; presumably the number of
    /// row errors tolerated before aborting — confirm against callers.
    pub max_errors: u32,
}
110
111impl Default for BulkOptions {
112 fn default() -> Self {
113 Self {
114 batch_size: 0,
115 check_constraints: true,
116 fire_triggers: false,
117 keep_nulls: true,
118 table_lock: false,
119 order_hint: None,
120 max_errors: 0,
121 }
122 }
123}
124
/// Description of one destination column in a bulk insert.
///
/// The public fields carry what the caller supplied; the private fields hold
/// the wire-level type information derived from `sql_type` by
/// `parse_sql_type` when the column is built with `BulkColumn::new`.
#[derive(Clone, Debug)]
pub struct BulkColumn {
    /// Column name as it appears in the `INSERT BULK` statement.
    pub name: String,
    /// SQL type declaration exactly as given by the caller (e.g. `NVARCHAR(100)`).
    pub sql_type: String,
    /// Whether the column is flagged nullable in the column metadata.
    pub nullable: bool,
    /// Position of the column as supplied by the caller.
    pub ordinal: usize,
    /// Type identifier byte written to the column metadata.
    type_id: u8,
    /// Maximum byte length for variable-length types; `0xFFFF` marks `(MAX)`.
    max_length: Option<u32>,
    /// Precision for decimal/numeric columns.
    precision: Option<u8>,
    /// Scale for decimal/numeric and fractional-second time columns.
    scale: Option<u8>,
}
145
146impl BulkColumn {
147 pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Self {
149 let sql_type_str: String = sql_type.into();
150 let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
151
152 Self {
153 name: name.into(),
154 sql_type: sql_type_str,
155 nullable: true,
156 ordinal,
157 type_id,
158 max_length,
159 precision,
160 scale,
161 }
162 }
163
164 #[must_use]
166 pub fn with_nullable(mut self, nullable: bool) -> Self {
167 self.nullable = nullable;
168 self
169 }
170}
171
/// Parses a SQL Server type declaration into its wire-level description.
///
/// Returns `(type_id, max_length, precision, scale)`:
///
/// * `type_id` – the TDS type identifier byte.
/// * `max_length` – byte length for variable-length types; `0xFFFF` marks a
///   `(MAX)` declaration. N-prefixed character types report twice the declared
///   character count because each UTF-16 code unit occupies two bytes.
/// * `precision` / `scale` – for `DECIMAL`/`NUMERIC`; `scale` alone for the
///   fractional-second types (`TIME`, `DATETIME2`, `DATETIMEOFFSET`).
///
/// Matching is case-insensitive and tolerant of surrounding whitespace, of a
/// space before the opening parenthesis (`VARCHAR (50)`) and of spaces inside
/// the parameter list (`DECIMAL( 10 , 2 )`). Missing or unparsable parameters
/// fall back to per-type defaults, and unrecognised type names are described
/// as an 8000-byte `NVARCHAR`.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    /// Length value that marks a `(MAX)` column.
    const MAX_MARKER: u32 = 0xFFFF;

    let upper = sql_type.trim().to_uppercase();

    // Split "NAME(params)" into the bare name and the text between the
    // parentheses; both halves are trimmed so stray spaces do not defeat the
    // name match or the numeric parse below.
    let (base, params) = match upper.split_once('(') {
        Some((name, rest)) => (name.trim_end(), Some(rest.trim_end_matches(')').trim())),
        None => (upper.as_str(), None),
    };

    // Fractional-second scale, defaulting to 7 when absent or unparsable.
    let fractional_scale = || -> u8 { params.and_then(|p| p.parse().ok()).unwrap_or(7) };

    // Byte length for single-byte character and binary types.
    let byte_length = |default: u32| -> u32 {
        match params {
            Some("MAX") => MAX_MARKER,
            Some(p) => p.parse().unwrap_or(default),
            None => default,
        }
    };

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        "FLOAT" => (0x3E, None, None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(fractional_scale())),
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => (0x2A, None, None, Some(fractional_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(fractional_scale())),
        // DATETIM4TYPE is 0x3A in MS-TDS; 0x3F is the legacy NUMERIC type.
        "SMALLDATETIME" => (0x3A, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_length(8000)), None, None),
        "NVARCHAR" | "NCHAR" => match params {
            Some("MAX") => (0xE7, Some(MAX_MARKER), None, None),
            other => {
                let chars: u32 = other.and_then(|p| p.parse().ok()).unwrap_or(4000);
                // A character count too large to double cannot be a real
                // declaration; treat it like an unspecified length (8000
                // bytes) instead of overflowing.
                let bytes = chars.checked_mul(2).unwrap_or(8000);
                (0xE7, Some(bytes), None, None)
            }
        },
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_length(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = match params {
                Some(p) => {
                    let mut parts = p.split(',').map(str::trim);
                    (
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(18),
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(0),
                    )
                }
                None => (18, 0),
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(MAX_MARKER), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        // Unknown names are described as an 8000-byte NVARCHAR.
        _ => (0xE7, Some(8000), None, None),
    }
}
265
/// Summary of a bulk insert, as reported by `BulkInsert::result`.
#[derive(Clone, Debug)]
pub struct BulkInsertResult {
    /// Total number of rows written.
    pub rows_affected: u64,
    /// Value of the encoder's batch counter when the result was taken.
    pub batches_committed: u32,
    /// Whether any error was recorded.
    pub has_errors: bool,
}
276
277#[derive(Debug)]
279pub struct BulkInsertBuilder {
280 table_name: String,
281 columns: Vec<BulkColumn>,
282 options: BulkOptions,
283}
284
285impl BulkInsertBuilder {
286 pub fn new<S: Into<String>>(table_name: S) -> Self {
288 Self {
289 table_name: table_name.into(),
290 columns: Vec::new(),
291 options: BulkOptions::default(),
292 }
293 }
294
295 #[must_use]
300 pub fn with_columns(mut self, column_names: &[&str]) -> Self {
301 self.columns = column_names
302 .iter()
303 .enumerate()
304 .map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
305 .collect();
306 self
307 }
308
309 #[must_use]
311 pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
312 self.columns = columns;
313 self
314 }
315
316 #[must_use]
318 pub fn with_options(mut self, options: BulkOptions) -> Self {
319 self.options = options;
320 self
321 }
322
323 #[must_use]
325 pub fn batch_size(mut self, size: usize) -> Self {
326 self.options.batch_size = size;
327 self
328 }
329
330 #[must_use]
332 pub fn table_lock(mut self, enabled: bool) -> Self {
333 self.options.table_lock = enabled;
334 self
335 }
336
337 #[must_use]
339 pub fn fire_triggers(mut self, enabled: bool) -> Self {
340 self.options.fire_triggers = enabled;
341 self
342 }
343
344 pub fn table_name(&self) -> &str {
346 &self.table_name
347 }
348
349 pub fn columns(&self) -> &[BulkColumn] {
351 &self.columns
352 }
353
354 pub fn options(&self) -> &BulkOptions {
356 &self.options
357 }
358
359 pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
366 validate_qualified_identifier(&self.table_name)?;
368
369 for col in &self.columns {
371 validate_identifier(&col.name)?;
372 }
373
374 let mut sql = format!("INSERT BULK {}", self.table_name);
375
376 if !self.columns.is_empty() {
378 sql.push_str(" (");
379 let cols: Vec<String> = self
380 .columns
381 .iter()
382 .map(|c| {
383 validate_sql_type(&c.sql_type)?;
389 Ok(format!("{} {}", c.name, c.sql_type))
390 })
391 .collect::<Result<Vec<_>, Error>>()?;
392 sql.push_str(&cols.join(", "));
393 sql.push(')');
394 }
395
396 let mut hints: Vec<String> = Vec::new();
398
399 if self.options.check_constraints {
400 hints.push("CHECK_CONSTRAINTS".to_string());
401 }
402 if self.options.fire_triggers {
403 hints.push("FIRE_TRIGGERS".to_string());
404 }
405 if self.options.keep_nulls {
406 hints.push("KEEP_NULLS".to_string());
407 }
408 if self.options.table_lock {
409 hints.push("TABLOCK".to_string());
410 }
411 if self.options.batch_size > 0 {
412 hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
413 }
414
415 if let Some(ref order) = self.options.order_hint {
416 for col_name in order {
418 validate_identifier(col_name)?;
419 }
420 hints.push(format!("ORDER({})", order.join(", ")));
421 }
422
423 if !hints.is_empty() {
424 sql.push_str(" WITH (");
425 sql.push_str(&hints.join(", "));
426 sql.push(')');
427 }
428
429 Ok(sql)
430 }
431}
432
433fn validate_sql_type(type_str: &str) -> Result<(), Error> {
439 #[allow(clippy::expect_used)] static SQL_TYPE_RE: Lazy<Regex> =
441 Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
442
443 if type_str.is_empty() {
444 return Err(Error::Config("SQL type cannot be empty".into()));
445 }
446
447 if !SQL_TYPE_RE.is_match(type_str) {
448 return Err(Error::Config(format!(
449 "invalid SQL type '{type_str}': contains disallowed characters"
450 )));
451 }
452
453 Ok(())
454}
455
456fn validate_identifier(name: &str) -> Result<(), Error> {
458 #[allow(clippy::expect_used)] static IDENTIFIER_RE: Lazy<Regex> =
460 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").expect("valid regex"));
461
462 if name.is_empty() {
463 return Err(Error::InvalidIdentifier(
464 "identifier cannot be empty".into(),
465 ));
466 }
467
468 if !IDENTIFIER_RE.is_match(name) {
469 return Err(Error::InvalidIdentifier(format!(
470 "invalid identifier '{name}': must start with letter/underscore, \
471 contain only alphanumerics/_/@/#/$, and be 1-128 characters"
472 )));
473 }
474
475 Ok(())
476}
477
478fn validate_qualified_identifier(name: &str) -> Result<(), Error> {
483 if name.is_empty() {
484 return Err(Error::InvalidIdentifier(
485 "identifier cannot be empty".into(),
486 ));
487 }
488
489 let parts: Vec<&str> = name.split('.').collect();
490 if parts.len() > 4 {
491 return Err(Error::InvalidIdentifier(format!(
492 "invalid qualified identifier '{name}': too many parts (max 4: server.catalog.schema.object)"
493 )));
494 }
495
496 for part in &parts {
497 validate_identifier(part)?;
498 }
499
500 Ok(())
501}
502
503pub struct BulkInsert {
508 columns: Arc<[BulkColumn]>,
510 buffer: BytesMut,
512 rows_in_batch: usize,
514 total_rows: u64,
516 batch_size: usize,
518 batches_committed: u32,
520 packet_id: u8,
522}
523
524impl BulkInsert {
525 pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
527 let mut bulk = Self {
528 columns: columns.into(),
529 buffer: BytesMut::with_capacity(64 * 1024), rows_in_batch: 0,
531 total_rows: 0,
532 batch_size,
533 batches_committed: 0,
534 packet_id: 1,
535 };
536
537 bulk.write_colmetadata();
539
540 bulk
541 }
542
543 fn write_colmetadata(&mut self) {
545 let buf = &mut self.buffer;
546
547 buf.put_u8(TokenType::ColMetaData as u8);
549
550 buf.put_u16_le(self.columns.len() as u16);
552
553 for col in self.columns.iter() {
554 buf.put_u32_le(0);
556
557 let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
559 buf.put_u16_le(flags);
560
561 buf.put_u8(col.type_id);
563
564 match col.type_id {
566 0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}
568
569 0xE7 | 0xA7 | 0xA5 | 0xAD => {
571 let max_len = col.max_length.unwrap_or(8000);
573 if max_len == 0xFFFF {
574 buf.put_u16_le(0xFFFF);
575 } else {
576 buf.put_u16_le(max_len as u16);
577 }
578
579 if col.type_id == 0xE7 || col.type_id == 0xA7 {
581 buf.put_u32_le(0x0409_0904); buf.put_u8(52); }
585 }
586
587 0x6C | 0x6A => {
589 let precision = col.precision.unwrap_or(18);
591 let len = decimal_byte_length(precision);
592 buf.put_u8(len);
593 buf.put_u8(precision);
594 buf.put_u8(col.scale.unwrap_or(0));
595 }
596
597 0x29..=0x2B => {
599 buf.put_u8(col.scale.unwrap_or(7));
600 }
601
602 0x24 => {
604 buf.put_u8(16);
605 }
606
607 _ => {
609 if let Some(len) = col.max_length {
610 if len <= 0xFFFF {
611 buf.put_u16_le(len as u16);
612 }
613 }
614 }
615 }
616
617 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
619 buf.put_u8(name_utf16.len() as u8);
620 for code_unit in name_utf16 {
621 buf.put_u16_le(code_unit);
622 }
623 }
624 }
625
626 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
637 if values.len() != self.columns.len() {
638 return Err(Error::Config(format!(
639 "expected {} values, got {}",
640 self.columns.len(),
641 values.len()
642 )));
643 }
644
645 let sql_values: Result<Vec<SqlValue>, TypeError> =
647 values.iter().map(|v| v.to_sql()).collect();
648 let sql_values = sql_values.map_err(Error::from)?;
649
650 self.write_row(&sql_values)?;
651
652 self.rows_in_batch += 1;
653 self.total_rows += 1;
654
655 Ok(())
656 }
657
658 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
660 if values.len() != self.columns.len() {
661 return Err(Error::Config(format!(
662 "expected {} values, got {}",
663 self.columns.len(),
664 values.len()
665 )));
666 }
667
668 self.write_row(values)?;
669
670 self.rows_in_batch += 1;
671 self.total_rows += 1;
672
673 Ok(())
674 }
675
676 fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
678 self.buffer.put_u8(TokenType::Row as u8);
680
681 let columns: Vec<_> = self.columns.iter().cloned().collect();
683
684 for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
686 self.encode_column_value(col, value)
687 .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
688 }
689
690 Ok(())
691 }
692
693 fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
695 let buf = &mut self.buffer;
696
697 let is_plp_type =
700 col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
701
702 match value {
703 SqlValue::Null => {
704 match col.type_id {
706 0xE7 | 0xA7 | 0xA5 | 0xAD => {
708 if is_plp_type {
709 buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
711 } else {
712 buf.put_u16_le(0xFFFF);
714 }
715 }
716 0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
718 buf.put_u8(0);
719 }
720 _ => {
722 if col.nullable {
723 buf.put_u8(0);
724 } else {
725 return Err(TypeError::UnexpectedNull);
726 }
727 }
728 }
729 }
730
731 SqlValue::Bool(v) => {
732 buf.put_u8(1); buf.put_u8(if *v { 1 } else { 0 });
734 }
735
736 SqlValue::TinyInt(v) => {
737 buf.put_u8(1); buf.put_u8(*v);
739 }
740
741 SqlValue::SmallInt(v) => {
742 buf.put_u8(2); buf.put_i16_le(*v);
744 }
745
746 SqlValue::Int(v) => {
747 buf.put_u8(4); buf.put_i32_le(*v);
749 }
750
751 SqlValue::BigInt(v) => {
752 buf.put_u8(8); buf.put_i64_le(*v);
754 }
755
756 SqlValue::Float(v) => {
757 buf.put_u8(4); buf.put_f32_le(*v);
759 }
760
761 SqlValue::Double(v) => {
762 buf.put_u8(8); buf.put_f64_le(*v);
764 }
765
766 SqlValue::String(s) => {
767 let utf16: Vec<u16> = s.encode_utf16().collect();
769 let byte_len = utf16.len() * 2;
770
771 if is_plp_type {
772 encode_plp_string(&utf16, buf);
775 } else if byte_len > 0xFFFF {
776 return Err(TypeError::BufferTooSmall {
778 needed: byte_len,
779 available: 0xFFFF,
780 });
781 } else {
782 buf.put_u16_le(byte_len as u16);
784 for code_unit in utf16 {
785 buf.put_u16_le(code_unit);
786 }
787 }
788 }
789
790 SqlValue::Binary(b) => {
791 if is_plp_type {
792 encode_plp_binary(b, buf);
794 } else if b.len() > 0xFFFF {
795 return Err(TypeError::BufferTooSmall {
797 needed: b.len(),
798 available: 0xFFFF,
799 });
800 } else {
801 buf.put_u16_le(b.len() as u16);
803 buf.put_slice(b);
804 }
805 }
806
807 #[cfg(feature = "decimal")]
809 SqlValue::Decimal(d) => {
810 let precision = col.precision.unwrap_or(18);
811 let len = decimal_byte_length(precision);
812 buf.put_u8(len);
813
814 buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
816
817 let mantissa = d.mantissa().unsigned_abs();
819 let mantissa_bytes = mantissa.to_le_bytes();
820 buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
821 }
822
823 #[cfg(feature = "uuid")]
824 SqlValue::Uuid(u) => {
825 buf.put_u8(16); mssql_types::encode::encode_uuid(*u, buf);
828 }
829
830 #[cfg(feature = "chrono")]
831 SqlValue::Date(d) => {
832 buf.put_u8(3); mssql_types::encode::encode_date(*d, buf);
834 }
835
836 #[cfg(feature = "chrono")]
837 SqlValue::Time(t) => {
838 let scale = col.scale.unwrap_or(7);
839 let len = time_byte_length(scale);
840 buf.put_u8(len);
841 encode_time_with_scale(*t, scale, buf);
843 }
844
845 #[cfg(feature = "chrono")]
846 SqlValue::DateTime(dt) => {
847 let scale = col.scale.unwrap_or(7);
848 let time_len = time_byte_length(scale);
849 let total_len = time_len + 3;
850 buf.put_u8(total_len);
851 encode_time_with_scale(dt.time(), scale, buf);
853 mssql_types::encode::encode_date(dt.date(), buf);
854 }
855
856 #[cfg(feature = "chrono")]
857 SqlValue::DateTimeOffset(dto) => {
858 let scale = col.scale.unwrap_or(7);
859 let time_len = time_byte_length(scale);
860 let total_len = time_len + 3 + 2;
861 buf.put_u8(total_len);
862 encode_time_with_scale(dto.time(), scale, buf);
864 mssql_types::encode::encode_date(dto.date_naive(), buf);
865 use chrono::Offset;
867 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
868 buf.put_i16_le(offset_minutes);
869 }
870
871 #[cfg(feature = "json")]
872 SqlValue::Json(j) => {
873 let s = j.to_string();
874 encode_nvarchar_value(&s, buf)?;
875 }
876
877 SqlValue::Xml(x) => {
878 encode_nvarchar_value(x, buf)?;
879 }
880
881 SqlValue::Tvp(_) => {
882 return Err(TypeError::UnsupportedConversion {
884 from: "TVP".to_string(),
885 to: "bulk copy value",
886 });
887 }
888 _ => {
890 return Err(TypeError::UnsupportedConversion {
891 from: value.type_name().to_string(),
892 to: "bulk copy value",
893 });
894 }
895 }
896
897 Ok(())
898 }
899}
900
901fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
903 let utf16: Vec<u16> = s.encode_utf16().collect();
904 let byte_len = utf16.len() * 2;
905
906 if byte_len > 0xFFFF {
907 return Err(TypeError::BufferTooSmall {
908 needed: byte_len,
909 available: 0xFFFF,
910 });
911 }
912
913 buf.put_u16_le(byte_len as u16);
914 for code_unit in utf16 {
915 buf.put_u16_le(code_unit);
916 }
917 Ok(())
918}
919
920fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
930 let byte_len = utf16.len() * 2;
931
932 buf.put_u64_le(byte_len as u64);
934
935 if byte_len > 0 {
936 buf.put_u32_le(byte_len as u32);
938 for code_unit in utf16 {
939 buf.put_u16_le(*code_unit);
940 }
941 }
942
943 buf.put_u32_le(0);
945}
946
947fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
956 buf.put_u64_le(data.len() as u64);
958
959 if !data.is_empty() {
960 buf.put_u32_le(data.len() as u32);
962 buf.put_slice(data);
963 }
964
965 buf.put_u32_le(0);
967}
968
/// Writes a time of day as a little-endian count of `10^-scale`-second
/// intervals since midnight, using the byte width for that scale.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let whole_seconds = u64::from(time.num_seconds_from_midnight());
    let nanos_since_midnight = whole_seconds * 1_000_000_000 + u64::from(time.nanosecond());
    let ticks = nanos_since_midnight / time_scale_divisor(scale);

    // Only the low `width` bytes are sent; the width is at most 5.
    let width = usize::from(time_byte_length(scale));
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
982
983impl BulkInsert {
984 fn write_done(&mut self) {
986 let buf = &mut self.buffer;
987
988 buf.put_u8(TokenType::Done as u8);
989
990 let status = DoneStatus::from_bits(0x0010); buf.put_u16_le(status.to_bits());
993
994 buf.put_u16_le(0);
996
997 buf.put_u64_le(self.total_rows);
999 }
1000
1001 pub fn take_packets(&mut self) -> Vec<BytesMut> {
1005 const MAX_PACKET_SIZE: usize = 4096;
1006 const HEADER_SIZE: usize = 8;
1007 const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1008
1009 let data = self.buffer.split();
1010 let mut packets = Vec::new();
1011 let mut offset = 0;
1012
1013 while offset < data.len() {
1014 let remaining = data.len() - offset;
1015 let payload_size = remaining.min(MAX_PAYLOAD);
1016 let is_last = offset + payload_size >= data.len();
1017
1018 let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1019
1020 let header = PacketHeader {
1022 packet_type: PacketType::BulkLoad,
1023 status: if is_last {
1024 PacketStatus::END_OF_MESSAGE
1025 } else {
1026 PacketStatus::NORMAL
1027 },
1028 length: (HEADER_SIZE + payload_size) as u16,
1029 spid: 0,
1030 packet_id: self.packet_id,
1031 window: 0,
1032 };
1033
1034 header.encode(&mut packet);
1035
1036 packet.put_slice(&data[offset..offset + payload_size]);
1038
1039 packets.push(packet);
1040 offset += payload_size;
1041 self.packet_id = self.packet_id.wrapping_add(1);
1042 }
1043
1044 packets
1045 }
1046
1047 pub fn total_rows(&self) -> u64 {
1049 self.total_rows
1050 }
1051
1052 pub fn rows_in_batch(&self) -> usize {
1054 self.rows_in_batch
1055 }
1056
1057 pub fn should_flush(&self) -> bool {
1059 self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1060 }
1061
1062 pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1065 self.write_done();
1066 self.take_packets()
1067 }
1068
1069 pub fn result(&self) -> BulkInsertResult {
1071 BulkInsertResult {
1072 rows_affected: self.total_rows,
1073 batches_committed: self.batches_committed,
1074 has_errors: false,
1075 }
1076 }
1077}
1078
/// Returns the number of bytes used to send a decimal of the given
/// precision: one sign byte plus a 4-, 8-, 12- or 16-byte magnitude.
///
/// Precisions outside `1..=38` (including zero) use the widest form.
fn decimal_byte_length(precision: u8) -> u8 {
    if precision == 0 || precision > 28 {
        17
    } else if precision <= 9 {
        5
    } else if precision <= 19 {
        9
    } else {
        13
    }
}
1089
1090#[cfg(feature = "chrono")]
/// Returns how many bytes the time-of-day portion occupies for a given
/// fractional-second scale: 3 for scales 0–2, 4 for 3–4, 5 otherwise.
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        // Scales 5–7, and any out-of-range scale, use the widest form.
        5
    }
}
1100
1101#[cfg(feature = "chrono")]
/// Returns the number of nanoseconds in one interval at the given
/// fractional-second scale (`10^(9 - scale)`).
///
/// Scales above 7 are treated as 7, i.e. 100-nanosecond intervals.
fn time_scale_divisor(scale: u8) -> u64 {
    let effective = scale.min(7);
    10u64.pow(u32::from(9 - effective))
}
1116
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a single-column `INT` builder for the given table name.
    fn single_int_column_builder(table: &str) -> BulkInsertBuilder {
        BulkInsertBuilder::new(table).with_typed_columns(vec![BulkColumn::new("id", "INT", 0)])
    }

    #[test]
    fn test_bulk_options_default() {
        let defaults = BulkOptions::default();
        assert_eq!(defaults.batch_size, 0);
        assert!(defaults.check_constraints);
        assert!(!defaults.fire_triggers);
        assert!(defaults.keep_nulls);
        assert!(!defaults.table_lock);
    }

    #[test]
    fn test_bulk_column_creation() {
        let column = BulkColumn::new("id", "INT", 0);
        assert_eq!(column.name, "id");
        assert_eq!(column.type_id, 0x38);
        assert!(column.nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        let int = parse_sql_type("INT");
        assert_eq!(int.0, 0x38);
        assert!(int.1.is_none());

        // NVARCHAR lengths are reported in bytes (two per character).
        let nvarchar = parse_sql_type("NVARCHAR(100)");
        assert_eq!(nvarchar.0, 0xE7);
        assert_eq!(nvarchar.1, Some(200));

        let decimal = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(decimal.0, 0x6C);
        assert_eq!(decimal.2, Some(10));
        assert_eq!(decimal.3, Some(2));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];
        let statement = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(columns)
            .table_lock(true)
            .build_insert_bulk_statement()
            .unwrap();

        assert!(statement.contains("INSERT BULK dbo.Users"));
        assert!(statement.contains("TABLOCK"));
    }

    #[test]
    fn test_bulk_insert_rejects_injection() {
        let outcome =
            single_int_column_builder("table;DROP TABLE users").build_insert_bulk_statement();

        assert!(outcome.is_err());
    }

    #[test]
    fn test_bulk_insert_validates_column_names() {
        let hostile = BulkColumn::new("col;DROP TABLE x", "INT", 0);
        let outcome = BulkInsertBuilder::new("Users")
            .with_typed_columns(vec![hostile])
            .build_insert_bulk_statement();

        assert!(outcome.is_err());
    }

    #[test]
    fn test_bulk_insert_accepts_qualified_names() {
        let outcome = single_int_column_builder("catalog.dbo.Users").build_insert_bulk_statement();

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_bulk_insert_creation() {
        let encoder = BulkInsert::new(
            vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ],
            1000,
        );

        assert_eq!(encoder.total_rows(), 0);
        assert_eq!(encoder.rows_in_batch(), 0);
        assert!(!encoder.should_flush());
    }

    #[test]
    fn test_decimal_byte_length() {
        for (precision, expected) in [(5, 5), (15, 9), (25, 13), (35, 17)] {
            assert_eq!(decimal_byte_length(precision), expected);
        }
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        for (scale, expected) in [(0, 3), (3, 4), (7, 5)] {
            assert_eq!(time_byte_length(scale), expected);
        }
    }

    #[test]
    fn test_plp_string_encoding() {
        let units: Vec<u16> = "Hello".encode_utf16().collect();
        let mut out = BytesMut::new();

        encode_plp_string(&units, &mut out);

        // 8-byte total length + 4-byte chunk length + 10 data bytes + 4-byte terminator.
        assert_eq!(out.len(), 8 + 4 + 10 + 4);
        assert_eq!(&out[0..8], &10u64.to_le_bytes());
        assert_eq!(&out[8..12], &10u32.to_le_bytes());
        assert_eq!(&out[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let payload = b"test binary data";
        let mut out = BytesMut::new();

        encode_plp_binary(payload, &mut out);

        // 8-byte total length + 4-byte chunk length + 16 data bytes + 4-byte terminator.
        assert_eq!(out.len(), 8 + 4 + 16 + 4);
        assert_eq!(&out[0..8], &16u64.to_le_bytes());
        assert_eq!(&out[8..12], &16u32.to_le_bytes());
        assert_eq!(&out[12..28], payload);
        assert_eq!(&out[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let units: Vec<u16> = "".encode_utf16().collect();
        let mut out = BytesMut::new();

        encode_plp_string(&units, &mut out);

        // Total length of zero followed directly by the terminator.
        assert_eq!(out.len(), 8 + 4);
        assert_eq!(&out[0..8], &0u64.to_le_bytes());
        assert_eq!(&out[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut out = BytesMut::new();

        encode_plp_binary(&[], &mut out);

        // Total length of zero followed directly by the terminator.
        assert_eq!(out.len(), 8 + 4);
        assert_eq!(&out[0..8], &0u64.to_le_bytes());
        assert_eq!(&out[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_parse_sql_type_max() {
        // (MAX) declarations are reported with the 0xFFFF marker.
        for (declaration, expected_id) in [
            ("NVARCHAR(MAX)", 0xE7),
            ("VARBINARY(MAX)", 0xA5),
            ("VARCHAR(MAX)", 0xA7),
        ] {
            let (type_id, len, _, _) = parse_sql_type(declaration);
            assert_eq!(type_id, expected_id);
            assert_eq!(len, Some(0xFFFF));
        }

        // A sized NVARCHAR still reports its byte length.
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
    }
}
1333}