use bytes::{BufMut, BytesMut};
use std::sync::Arc;

use mssql_types::{SqlValue, ToSql, TypeError};
use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
use tds_protocol::token::{DoneStatus, TokenType};

use crate::error::Error;

/// Options that control how a bulk insert is performed.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Number of rows per batch; 0 sends all rows as a single batch.
    pub batch_size: usize,

    /// Check constraints while inserting (emits the CHECK_CONSTRAINTS hint).
    pub check_constraints: bool,

    /// Fire insert triggers on the target table (emits the FIRE_TRIGGERS hint).
    pub fire_triggers: bool,

    /// Keep NULL values instead of applying column defaults (emits the KEEP_NULLS hint).
    pub keep_nulls: bool,

    /// Take a table-level lock for the duration of the load (emits the TABLOCK hint).
    pub table_lock: bool,

    /// Columns the incoming data is ordered by (emits the ORDER hint).
    pub order_hint: Option<Vec<String>>,

    /// Maximum number of errors to tolerate before the load is aborted.
    pub max_errors: u32,
}

impl Default for BulkOptions {
    fn default() -> Self {
        Self {
            batch_size: 0,
            check_constraints: true,
            fire_triggers: false,
            keep_nulls: true,
            table_lock: false,
            order_hint: None,
            max_errors: 0,
        }
    }
}

/// A column participating in a bulk insert.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name as it appears in the target table.
    pub name: String,
    /// Declared SQL type, e.g. `INT` or `NVARCHAR(100)`.
    pub sql_type: String,
    /// Whether the column accepts NULL values.
    pub nullable: bool,
    /// Zero-based position of the column in the row data.
    pub ordinal: usize,
    /// TDS type token derived from `sql_type`.
    type_id: u8,
    /// Maximum length in bytes, where applicable.
    max_length: Option<u32>,
    /// Precision for DECIMAL/NUMERIC types.
    precision: Option<u8>,
    /// Scale for DECIMAL/NUMERIC and time-based types.
    scale: Option<u8>,
}

impl BulkColumn {
    /// Creates a new column description, deriving TDS type information from `sql_type`.
    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Self {
        let sql_type_str: String = sql_type.into();
        let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);

        Self {
            name: name.into(),
            sql_type: sql_type_str,
            nullable: true,
            ordinal,
            type_id,
            max_length,
            precision,
            scale,
        }
    }

    /// Sets whether the column is nullable (columns are nullable by default).
    #[must_use]
    pub fn with_nullable(mut self, nullable: bool) -> Self {
        self.nullable = nullable;
        self
    }
}

/// Parses a SQL type declaration into `(tds_type_id, max_length_bytes, precision, scale)`.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.to_uppercase();

    let (base, params) = if let Some(paren_pos) = upper.find('(') {
        let base = &upper[..paren_pos];
        let params_str = upper[paren_pos + 1..].trim_end_matches(')');
        (base, Some(params_str))
    } else {
        (upper.as_str(), None)
    };

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        "FLOAT" => (0x3E, None, None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x29, None, None, Some(scale))
        }
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x2A, None, None, Some(scale))
        }
        "DATETIMEOFFSET" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x2B, None, None, Some(scale))
        }
        "SMALLDATETIME" => (0x3F, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => {
            let len = params
                .and_then(|p| {
                    if p == "MAX" {
                        Some(0xFFFF_u32)
                    } else {
                        p.parse().ok()
                    }
                })
                .unwrap_or(8000);
            (0xA7, Some(len), None, None)
        }
        "NVARCHAR" | "NCHAR" => {
            let is_max = params.map(|p| p == "MAX").unwrap_or(false);
            if is_max {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Length is declared in characters; max_length is stored in bytes (UTF-16).
                let len = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(len * 2), None, None)
            }
        }
        "VARBINARY" | "BINARY" => {
            let len = params
                .and_then(|p| {
                    if p == "MAX" {
                        Some(0xFFFF_u32)
                    } else {
                        p.parse().ok()
                    }
                })
                .unwrap_or(8000);
            (0xA5, Some(len), None, None)
        }
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = if let Some(p) = params {
                let parts: Vec<&str> = p.split(',').map(|s| s.trim()).collect();
                (
                    parts.first().and_then(|s| s.parse().ok()).unwrap_or(18),
                    parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
                )
            } else {
                (18, 0)
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        // Unrecognized types fall back to NVARCHAR(4000).
        _ => (0xE7, Some(8000), None, None),
    }
}

/// Summary of a completed bulk insert.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Total number of rows sent to the server.
    pub rows_affected: u64,
    /// Number of batches committed.
    pub batches_committed: u32,
    /// Whether any errors were reported during the load.
    pub has_errors: bool,
}

/// Builder for configuring a bulk insert's target table, columns, and options.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    table_name: String,
    columns: Vec<BulkColumn>,
    options: BulkOptions,
}

impl BulkInsertBuilder {
    /// Creates a builder targeting `table_name`.
    pub fn new<S: Into<String>>(table_name: S) -> Self {
        Self {
            table_name: table_name.into(),
            columns: Vec::new(),
            options: BulkOptions::default(),
        }
    }

    /// Sets the column names, treating every column as `NVARCHAR(MAX)`.
    #[must_use]
    pub fn with_columns(mut self, column_names: &[&str]) -> Self {
        self.columns = column_names
            .iter()
            .enumerate()
            .map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
            .collect();
        self
    }

    /// Sets fully typed column definitions.
    #[must_use]
    pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
        self.columns = columns;
        self
    }

    /// Sets the bulk options.
    #[must_use]
    pub fn with_options(mut self, options: BulkOptions) -> Self {
        self.options = options;
        self
    }

    /// Sets the number of rows per batch.
    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.options.batch_size = size;
        self
    }

    /// Enables or disables the TABLOCK hint.
    #[must_use]
    pub fn table_lock(mut self, enabled: bool) -> Self {
        self.options.table_lock = enabled;
        self
    }

    /// Enables or disables the FIRE_TRIGGERS hint.
    #[must_use]
    pub fn fire_triggers(mut self, enabled: bool) -> Self {
        self.options.fire_triggers = enabled;
        self
    }

    /// Returns the target table name.
    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    /// Returns the configured columns.
    pub fn columns(&self) -> &[BulkColumn] {
        &self.columns
    }

    /// Returns the configured options.
    pub fn options(&self) -> &BulkOptions {
        &self.options
    }

    /// Builds the `INSERT BULK` statement that starts the bulk load on the server.
    pub fn build_insert_bulk_statement(&self) -> String {
        let mut sql = format!("INSERT BULK {}", self.table_name);

        if !self.columns.is_empty() {
            sql.push_str(" (");
            let cols: Vec<String> = self
                .columns
                .iter()
                .map(|c| format!("{} {}", c.name, c.sql_type))
                .collect();
            sql.push_str(&cols.join(", "));
            sql.push(')');
        }

        let mut hints: Vec<String> = Vec::new();

        if self.options.check_constraints {
            hints.push("CHECK_CONSTRAINTS".to_string());
        }
        if self.options.fire_triggers {
            hints.push("FIRE_TRIGGERS".to_string());
        }
        if self.options.keep_nulls {
            hints.push("KEEP_NULLS".to_string());
        }
        if self.options.table_lock {
            hints.push("TABLOCK".to_string());
        }
        if self.options.batch_size > 0 {
            hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
        }

        if let Some(ref order) = self.options.order_hint {
            hints.push(format!("ORDER({})", order.join(", ")));
        }

        if !hints.is_empty() {
            sql.push_str(" WITH (");
            sql.push_str(&hints.join(", "));
            sql.push(')');
        }

        sql
    }
}

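/// Streams rows to SQL Server as a TDS BulkLoad request: a COLMETADATA token that
/// describes the columns, one ROW token per row, and a closing DONE token, all
/// buffered locally and split into packets by `take_packets` / `finish_packets`.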
pub struct BulkInsert {
    /// Column metadata shared with the row encoder.
    columns: Arc<[BulkColumn]>,
    /// Accumulated token stream (COLMETADATA, ROW tokens, DONE).
    buffer: BytesMut,
    /// Rows written since the last flush.
    rows_in_batch: usize,
    /// Total rows written overall.
    total_rows: u64,
    /// Rows per batch; 0 disables automatic flushing.
    batch_size: usize,
    /// Batches committed so far.
    batches_committed: u32,
    /// Packet id for the next outgoing packet.
    packet_id: u8,
}

impl BulkInsert {
    /// Creates a new bulk insert stream and writes the COLMETADATA token for `columns`.
    pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
        let mut bulk = Self {
            columns: columns.into(),
            buffer: BytesMut::with_capacity(64 * 1024),
            rows_in_batch: 0,
            total_rows: 0,
            batch_size,
            batches_committed: 0,
            packet_id: 1,
        };

        bulk.write_colmetadata();

        bulk
    }

    fn write_colmetadata(&mut self) {
        let buf = &mut self.buffer;

        buf.put_u8(TokenType::ColMetaData as u8);

        buf.put_u16_le(self.columns.len() as u16);

        for col in self.columns.iter() {
            // UserType (unused).
            buf.put_u32_le(0);

            // Flags: bit 0 marks the column nullable.
            let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
            buf.put_u16_le(flags);

            buf.put_u8(col.type_id);

            match col.type_id {
                // Fixed-length types carry no additional type info.
                0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}

                // Variable-length character and binary types: 2-byte max length.
                0xE7 | 0xA7 | 0xA5 | 0xAD => {
                    let max_len = col.max_length.unwrap_or(8000);
                    if max_len == 0xFFFF {
                        buf.put_u16_le(0xFFFF);
                    } else {
                        buf.put_u16_le(max_len as u16);
                    }

                    // Character types also carry a 5-byte collation.
                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
                        buf.put_u32_le(0x0409_0904);
                        buf.put_u8(52);
                    }
                }

                // DECIMAL / NUMERIC: byte length, precision, scale.
                0x6C | 0x6A => {
                    let precision = col.precision.unwrap_or(18);
                    let len = decimal_byte_length(precision);
                    buf.put_u8(len);
                    buf.put_u8(precision);
                    buf.put_u8(col.scale.unwrap_or(0));
                }

                // TIME / DATETIME2 / DATETIMEOFFSET: fractional-second scale.
                0x29..=0x2B => {
                    buf.put_u8(col.scale.unwrap_or(7));
                }

                // UNIQUEIDENTIFIER: fixed 16-byte length.
                0x24 => {
                    buf.put_u8(16);
                }

                _ => {
                    if let Some(len) = col.max_length {
                        if len <= 0xFFFF {
                            buf.put_u16_le(len as u16);
                        }
                    }
                }
            }

            // Column name: length in characters followed by UTF-16LE code units.
            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
            buf.put_u8(name_utf16.len() as u8);
            for code_unit in name_utf16 {
                buf.put_u16_le(code_unit);
            }
        }
    }

    /// Converts `values` with `ToSql` and appends them as a ROW token.
    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
        if values.len() != self.columns.len() {
            return Err(Error::Config(format!(
                "expected {} values, got {}",
                self.columns.len(),
                values.len()
            )));
        }

        let sql_values: Result<Vec<SqlValue>, TypeError> =
            values.iter().map(|v| v.to_sql()).collect();
        let sql_values = sql_values.map_err(Error::from)?;

        self.write_row(&sql_values)?;

        self.rows_in_batch += 1;
        self.total_rows += 1;

        Ok(())
    }

    /// Appends already-converted `SqlValue`s as a ROW token.
    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
        if values.len() != self.columns.len() {
            return Err(Error::Config(format!(
                "expected {} values, got {}",
                self.columns.len(),
                values.len()
            )));
        }

        self.write_row(values)?;

        self.rows_in_batch += 1;
        self.total_rows += 1;

        Ok(())
    }

    fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
        self.buffer.put_u8(TokenType::Row as u8);

        // Clone the column list so the borrow of `self.columns` does not overlap
        // with the mutable borrow taken by `encode_column_value`.
        let columns: Vec<_> = self.columns.iter().cloned().collect();

        for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
            self.encode_column_value(col, value)
                .map_err(|e| Error::Config(format!("failed to encode column {}: {}", i, e)))?;
        }

        Ok(())
    }

    fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
        let buf = &mut self.buffer;

        // MAX-length character/binary columns are encoded as partially length-prefixed (PLP) data.
        let is_plp_type =
            col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);

        match value {
            SqlValue::Null => {
                match col.type_id {
                    0xE7 | 0xA7 | 0xA5 | 0xAD => {
                        if is_plp_type {
                            // PLP NULL sentinel.
                            buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
                        } else {
                            // 2-byte NULL sentinel for short character/binary values.
                            buf.put_u16_le(0xFFFF);
                        }
                    }
                    0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
                        buf.put_u8(0);
                    }
                    _ => {
                        if col.nullable {
                            buf.put_u8(0);
                        } else {
                            return Err(TypeError::UnexpectedNull);
                        }
                    }
                }
            }

            SqlValue::Bool(v) => {
                buf.put_u8(1);
                buf.put_u8(if *v { 1 } else { 0 });
            }

            SqlValue::TinyInt(v) => {
                buf.put_u8(1);
                buf.put_u8(*v);
            }

            SqlValue::SmallInt(v) => {
                buf.put_u8(2);
                buf.put_i16_le(*v);
            }

            SqlValue::Int(v) => {
                buf.put_u8(4);
                buf.put_i32_le(*v);
            }

            SqlValue::BigInt(v) => {
                buf.put_u8(8);
                buf.put_i64_le(*v);
            }

            SqlValue::Float(v) => {
                buf.put_u8(4);
                buf.put_f32_le(*v);
            }

            SqlValue::Double(v) => {
                buf.put_u8(8);
                buf.put_f64_le(*v);
            }

            SqlValue::String(s) => {
                let utf16: Vec<u16> = s.encode_utf16().collect();
                let byte_len = utf16.len() * 2;

                if is_plp_type {
                    encode_plp_string(&utf16, buf);
                } else if byte_len > 0xFFFF {
                    return Err(TypeError::BufferTooSmall {
                        needed: byte_len,
                        available: 0xFFFF,
                    });
                } else {
                    buf.put_u16_le(byte_len as u16);
                    for code_unit in utf16 {
                        buf.put_u16_le(code_unit);
                    }
                }
            }

            SqlValue::Binary(b) => {
                if is_plp_type {
                    encode_plp_binary(b, buf);
                } else if b.len() > 0xFFFF {
                    return Err(TypeError::BufferTooSmall {
                        needed: b.len(),
                        available: 0xFFFF,
                    });
                } else {
                    buf.put_u16_le(b.len() as u16);
                    buf.put_slice(b);
                }
            }

            #[cfg(feature = "decimal")]
            SqlValue::Decimal(d) => {
                let precision = col.precision.unwrap_or(18);
                let len = decimal_byte_length(precision);
                buf.put_u8(len);

                // Sign byte: 1 for positive, 0 for negative.
                buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });

                let mantissa = d.mantissa().unsigned_abs();
                let mantissa_bytes = mantissa.to_le_bytes();
                buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
            }

            #[cfg(feature = "uuid")]
            SqlValue::Uuid(u) => {
                buf.put_u8(16);
                mssql_types::encode::encode_uuid(*u, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::Date(d) => {
                buf.put_u8(3);
                mssql_types::encode::encode_date(*d, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::Time(t) => {
                let scale = col.scale.unwrap_or(7);
                let len = time_byte_length(scale);
                buf.put_u8(len);
                encode_time_with_scale(*t, scale, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::DateTime(dt) => {
                let scale = col.scale.unwrap_or(7);
                let time_len = time_byte_length(scale);
                let total_len = time_len + 3;
                buf.put_u8(total_len);
                encode_time_with_scale(dt.time(), scale, buf);
                mssql_types::encode::encode_date(dt.date(), buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::DateTimeOffset(dto) => {
                let scale = col.scale.unwrap_or(7);
                let time_len = time_byte_length(scale);
                let total_len = time_len + 3 + 2;
                buf.put_u8(total_len);
                encode_time_with_scale(dto.time(), scale, buf);
                mssql_types::encode::encode_date(dto.date_naive(), buf);
                use chrono::Offset;
                let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
                buf.put_i16_le(offset_minutes);
            }

            #[cfg(feature = "json")]
            SqlValue::Json(j) => {
                let s = j.to_string();
                encode_nvarchar_value(&s, buf)?;
            }

            SqlValue::Xml(x) => {
                encode_nvarchar_value(x, buf)?;
            }

            SqlValue::Tvp(_) => {
                return Err(TypeError::UnsupportedConversion {
                    from: "TVP".to_string(),
                    to: "bulk copy value",
                });
            }

            _ => {
                return Err(TypeError::UnsupportedConversion {
                    from: value.type_name().to_string(),
                    to: "bulk copy value",
                });
            }
        }

        Ok(())
    }
}

/// Encodes a string as a length-prefixed NVARCHAR value (2-byte byte length, then UTF-16LE data).
fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
    let utf16: Vec<u16> = s.encode_utf16().collect();
    let byte_len = utf16.len() * 2;

    if byte_len > 0xFFFF {
        return Err(TypeError::BufferTooSmall {
            needed: byte_len,
            available: 0xFFFF,
        });
    }

    buf.put_u16_le(byte_len as u16);
    for code_unit in utf16 {
        buf.put_u16_le(code_unit);
    }
    Ok(())
}

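/// Encodes a UTF-16 string as PLP (partially length-prefixed) data: an 8-byte total
/// length, a single length-prefixed chunk when the value is non-empty, and a zero
/// chunk length as the terminator.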
fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
    let byte_len = utf16.len() * 2;

    // Total length in bytes.
    buf.put_u64_le(byte_len as u64);

    if byte_len > 0 {
        // Single chunk: chunk length followed by the data.
        buf.put_u32_le(byte_len as u32);
        for code_unit in utf16 {
            buf.put_u16_le(*code_unit);
        }
    }

    // Zero-length chunk terminates the PLP stream.
    buf.put_u32_le(0);
}

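/// Encodes a byte slice as PLP data using the same layout as `encode_plp_string`:
/// total length, one data chunk when non-empty, then a zero chunk-length terminator.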
fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
    // Total length in bytes.
    buf.put_u64_le(data.len() as u64);

    if !data.is_empty() {
        // Single chunk: chunk length followed by the data.
        buf.put_u32_le(data.len() as u32);
        buf.put_slice(data);
    }

    // Zero-length chunk terminates the PLP stream.
    buf.put_u32_le(0);
}

#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let nanos = time.num_seconds_from_midnight() as u64 * 1_000_000_000 + time.nanosecond() as u64;
    let intervals = nanos / time_scale_divisor(scale);
    let len = time_byte_length(scale);

    // Write the interval count as a little-endian integer of `len` bytes.
    for i in 0..len {
        buf.put_u8(((intervals >> (i * 8)) & 0xFF) as u8);
    }
}

impl BulkInsert {
    /// Appends the DONE token that terminates the bulk load stream.
    fn write_done(&mut self) {
        let buf = &mut self.buffer;

        buf.put_u8(TokenType::Done as u8);

        let status = DoneStatus {
            more: false,
            error: false,
            in_xact: false,
            count: true,
            attn: false,
            srverror: false,
        };
        buf.put_u16_le(status.to_bits());

        // CurCmd.
        buf.put_u16_le(0);

        // Row count.
        buf.put_u64_le(self.total_rows);
    }

    /// Drains the internal buffer and splits it into BulkLoad packets.
    pub fn take_packets(&mut self) -> Vec<BytesMut> {
        const MAX_PACKET_SIZE: usize = 4096;
        const HEADER_SIZE: usize = 8;
        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;

        let data = self.buffer.split();
        let mut packets = Vec::new();
        let mut offset = 0;

        while offset < data.len() {
            let remaining = data.len() - offset;
            let payload_size = remaining.min(MAX_PAYLOAD);
            let is_last = offset + payload_size >= data.len();

            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);

            let header = PacketHeader {
                packet_type: PacketType::BulkLoad,
                status: if is_last {
                    PacketStatus::END_OF_MESSAGE
                } else {
                    PacketStatus::NORMAL
                },
                length: (HEADER_SIZE + payload_size) as u16,
                spid: 0,
                packet_id: self.packet_id,
                window: 0,
            };

            header.encode(&mut packet);

            packet.put_slice(&data[offset..offset + payload_size]);

            packets.push(packet);
            offset += payload_size;
            self.packet_id = self.packet_id.wrapping_add(1);
        }

        packets
    }

    /// Total number of rows written so far.
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// Number of rows written since the last flush.
    pub fn rows_in_batch(&self) -> usize {
        self.rows_in_batch
    }

    /// Returns `true` when a batch size is configured and the current batch is full.
    pub fn should_flush(&self) -> bool {
        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
    }

    /// Writes the DONE token and returns the remaining packets for the stream.
    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
        self.write_done();
        self.take_packets()
    }

    /// Returns a summary of the bulk insert so far.
    pub fn result(&self) -> BulkInsertResult {
        BulkInsertResult {
            rows_affected: self.total_rows,
            batches_committed: self.batches_committed,
            has_errors: false,
        }
    }
}

/// Number of bytes used to store a DECIMAL/NUMERIC value of the given precision,
/// including the sign byte.
fn decimal_byte_length(precision: u8) -> u8 {
    match precision {
        1..=9 => 5,
        10..=19 => 9,
        20..=28 => 13,
        29..=38 => 17,
        _ => 17,
    }
}

#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    match scale {
        0..=2 => 3,
        3..=4 => 4,
        5..=7 => 5,
        _ => 5,
    }
}

#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    match scale {
        0 => 1_000_000_000,
        1 => 100_000_000,
        2 => 10_000_000,
        3 => 1_000_000,
        4 => 100_000,
        5 => 10_000,
        6 => 1_000,
        7 => 100,
        _ => 100,
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_bulk_options_default() {
        let opts = BulkOptions::default();
        assert_eq!(opts.batch_size, 0);
        assert!(opts.check_constraints);
        assert!(!opts.fire_triggers);
        assert!(opts.keep_nulls);
        assert!(!opts.table_lock);
    }

    #[test]
    fn test_bulk_column_creation() {
        let col = BulkColumn::new("id", "INT", 0);
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, 0x38);
        assert!(col.nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x38);
        assert!(len.is_none());

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        // 100 characters = 200 bytes of UTF-16.
        assert_eq!(len, Some(200));

        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ])
            .table_lock(true);

        let sql = builder.build_insert_bulk_statement();
        assert!(sql.contains("INSERT BULK dbo.Users"));
        assert!(sql.contains("TABLOCK"));
    }
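
    // Sketch of an extra test exercising the ROWS_PER_BATCH and ORDER hints emitted
    // by `build_insert_bulk_statement`, alongside the default CHECK_CONSTRAINTS and
    // KEEP_NULLS hints.
    #[test]
    fn test_insert_bulk_statement_hints() {
        let options = BulkOptions {
            batch_size: 500,
            order_hint: Some(vec!["id".to_string()]),
            ..BulkOptions::default()
        };

        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)])
            .with_options(options);

        let sql = builder.build_insert_bulk_statement();
        assert!(sql.contains("ROWS_PER_BATCH = 500"));
        assert!(sql.contains("ORDER(id)"));
        assert!(sql.contains("CHECK_CONSTRAINTS"));
        assert!(sql.contains("KEEP_NULLS"));
    }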

    #[test]
    fn test_bulk_insert_creation() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];

        let bulk = BulkInsert::new(columns, 1000);
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
    }
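
    // Sketch of an extra test for the row-buffering flow. It assumes `SqlValue::Int`
    // holds an `i32` and `SqlValue::Null` is a unit variant (as the match arms in
    // `encode_column_value` imply), and that `PacketHeader::encode` emits the 8-byte
    // TDS header counted in `take_packets`.
    #[test]
    fn test_send_rows_and_flush() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];

        let mut bulk = BulkInsert::new(columns, 2);
        assert!(bulk.send_row_values(&[SqlValue::Int(1), SqlValue::Null]).is_ok());
        assert!(!bulk.should_flush());
        assert!(bulk.send_row_values(&[SqlValue::Int(2), SqlValue::Null]).is_ok());
        assert!(bulk.should_flush());
        assert_eq!(bulk.total_rows(), 2);

        // A row with the wrong number of values is rejected.
        assert!(bulk.send_row_values(&[SqlValue::Int(3)]).is_err());

        // The finished stream is split into packets no larger than 4096 bytes.
        let packets = bulk.finish_packets();
        assert!(!packets.is_empty());
        assert!(packets.iter().all(|p| p.len() <= 4096));
    }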

    #[test]
    fn test_decimal_byte_length() {
        assert_eq!(decimal_byte_length(5), 5);
        assert_eq!(decimal_byte_length(15), 9);
        assert_eq!(decimal_byte_length(25), 13);
        assert_eq!(decimal_byte_length(35), 17);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        assert_eq!(time_byte_length(0), 3);
        assert_eq!(time_byte_length(3), 4);
        assert_eq!(time_byte_length(7), 5);
    }
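
    // Sketch of an extra test for `time_scale_divisor`: the divisor converts a
    // nanosecond count into 10^(-scale)-second ticks, so scale 7 means 100 ns ticks,
    // and out-of-range scales fall back to the scale-7 divisor.
    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_scale_divisor() {
        assert_eq!(time_scale_divisor(0), 1_000_000_000);
        assert_eq!(time_scale_divisor(3), 1_000_000);
        assert_eq!(time_scale_divisor(7), 100);
        assert_eq!(time_scale_divisor(9), 100);
    }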

    #[test]
    fn test_plp_string_encoding() {
        let mut buf = BytesMut::new();
        let text = "Hello";
        let utf16: Vec<u16> = text.encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // 8-byte total length + 4-byte chunk length + 10 bytes of data + 4-byte terminator.
        assert_eq!(buf.len(), 8 + 4 + 10 + 4);

        // Total length: 5 characters = 10 bytes.
        assert_eq!(&buf[0..8], &10u64.to_le_bytes());

        // Chunk length.
        assert_eq!(&buf[8..12], &10u32.to_le_bytes());

        // Zero terminator.
        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let mut buf = BytesMut::new();
        let data = b"test binary data";

        encode_plp_binary(data, &mut buf);

        // 8-byte total length + 4-byte chunk length + 16 bytes of data + 4-byte terminator.
        assert_eq!(buf.len(), 8 + 4 + 16 + 4);

        // Total length.
        assert_eq!(&buf[0..8], &16u64.to_le_bytes());

        // Chunk length.
        assert_eq!(&buf[8..12], &16u32.to_le_bytes());

        // Chunk data.
        assert_eq!(&buf[12..28], data);

        // Zero terminator.
        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let mut buf = BytesMut::new();
        let utf16: Vec<u16> = "".encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // Empty values skip the data chunk: total length + terminator only.
        assert_eq!(buf.len(), 8 + 4);

        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();

        encode_plp_binary(&[], &mut buf);

        // Empty values skip the data chunk: total length + terminator only.
        assert_eq!(buf.len(), 8 + 4);

        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }
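
    // Sketch of an extra test for `encode_nvarchar_value`: values whose UTF-16
    // encoding exceeds 0xFFFF bytes are rejected and nothing is written to the buffer.
    #[test]
    fn test_nvarchar_value_too_long() {
        let mut buf = BytesMut::new();
        // 40_000 ASCII characters encode to 80_000 bytes of UTF-16.
        let long = "a".repeat(40_000);

        let result = encode_nvarchar_value(&long, &mut buf);
        assert!(matches!(result, Err(TypeError::BufferTooSmall { .. })));
        assert!(buf.is_empty());
    }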

    #[test]
    fn test_parse_sql_type_max() {
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
    }
}