use bytes::{BufMut, BytesMut};
use std::sync::Arc;

use mssql_types::{SqlValue, ToSql, TypeError};
use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
use tds_protocol::token::{DoneStatus, TokenType};

use crate::error::Error;

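/// Options that control how a bulk insert is performed.
///
/// Most of these options surface as hints in the generated
/// `INSERT BULK ... WITH (...)` statement built by
/// `BulkInsertBuilder::build_insert_bulk_statement`.
///
/// A minimal sketch of overriding a couple of defaults (crate paths assumed):
///
/// ```ignore
/// let options = BulkOptions {
///     batch_size: 5_000,
///     table_lock: true,
///     ..BulkOptions::default()
/// };
/// ```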
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Number of rows per batch; `0` disables batching and sends everything in one batch.
    pub batch_size: usize,

    /// Add the `CHECK_CONSTRAINTS` hint so constraints are checked during the load.
    pub check_constraints: bool,

    /// Add the `FIRE_TRIGGERS` hint so insert triggers fire on the target table.
    pub fire_triggers: bool,

    /// Add the `KEEP_NULLS` hint so NULL values are preserved instead of column defaults.
    pub keep_nulls: bool,

    /// Add the `TABLOCK` hint to take a table-level lock for the load.
    pub table_lock: bool,

    /// Columns for an `ORDER(...)` hint when the incoming rows are already sorted.
    pub order_hint: Option<Vec<String>>,

    /// Maximum number of errors tolerated during the load.
    pub max_errors: u32,
}

impl Default for BulkOptions {
    fn default() -> Self {
        Self {
            batch_size: 0,
            check_constraints: true,
            fire_triggers: false,
            keep_nulls: true,
            table_lock: false,
            order_hint: None,
            max_errors: 0,
        }
    }
}

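/// Describes a single column taking part in a bulk insert.
///
/// The TDS type token, length, precision, and scale are derived from the SQL
/// type declaration. A small sketch (crate paths assumed):
///
/// ```ignore
/// let col = BulkColumn::new("email", "NVARCHAR(255)", 2).with_nullable(false);
/// assert_eq!(col.sql_type, "NVARCHAR(255)");
/// ```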
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name in the target table.
    pub name: String,
    /// SQL type declaration, e.g. `NVARCHAR(100)` or `DECIMAL(10,2)`.
    pub sql_type: String,
    /// Whether the column accepts NULLs (defaults to `true`).
    pub nullable: bool,
    /// Zero-based position of the column in the row data.
    pub ordinal: usize,
    /// TDS type token derived from `sql_type`.
    type_id: u8,
    /// Maximum length in bytes, where applicable (`0xFFFF` marks MAX types).
    max_length: Option<u32>,
    /// Precision for DECIMAL/NUMERIC columns.
    precision: Option<u8>,
    /// Scale for DECIMAL/NUMERIC and time columns.
    scale: Option<u8>,
}

impl BulkColumn {
    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Self {
        let sql_type_str: String = sql_type.into();
        let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);

        Self {
            name: name.into(),
            sql_type: sql_type_str,
            nullable: true,
            ordinal,
            type_id,
            max_length,
            precision,
            scale,
        }
    }

    #[must_use]
    pub fn with_nullable(mut self, nullable: bool) -> Self {
        self.nullable = nullable;
        self
    }
}

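/// Maps a SQL type declaration such as `INT`, `NVARCHAR(100)`, or
/// `DECIMAL(10,2)` to a TDS type token plus optional maximum length (in
/// bytes), precision, and scale. `MAX` types use `0xFFFF` as a length
/// sentinel, and unrecognized types fall back to the NVARCHAR token with an
/// 8000-byte limit.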
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.to_uppercase();

    let (base, params) = if let Some(paren_pos) = upper.find('(') {
        let base = &upper[..paren_pos];
        let params_str = upper[paren_pos + 1..].trim_end_matches(')');
        (base, Some(params_str))
    } else {
        (upper.as_str(), None)
    };

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        "FLOAT" => (0x3E, None, None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x29, None, None, Some(scale))
        }
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x2A, None, None, Some(scale))
        }
        "DATETIMEOFFSET" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x2B, None, None, Some(scale))
        }
        "SMALLDATETIME" => (0x3F, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => {
            let len = params
                .and_then(|p| {
                    if p == "MAX" {
                        Some(0xFFFF_u32)
                    } else {
                        p.parse().ok()
                    }
                })
                .unwrap_or(8000);
            (0xA7, Some(len), None, None)
        }
        "NVARCHAR" | "NCHAR" => {
            let is_max = params.map(|p| p == "MAX").unwrap_or(false);
            if is_max {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Length is declared in characters but stored in bytes (UTF-16).
                let len = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(len * 2), None, None)
            }
        }
        "VARBINARY" | "BINARY" => {
            let len = params
                .and_then(|p| {
                    if p == "MAX" {
                        Some(0xFFFF_u32)
                    } else {
                        p.parse().ok()
                    }
                })
                .unwrap_or(8000);
            (0xA5, Some(len), None, None)
        }
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = if let Some(p) = params {
                let parts: Vec<&str> = p.split(',').map(|s| s.trim()).collect();
                (
                    parts.first().and_then(|s| s.parse().ok()).unwrap_or(18),
                    parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
                )
            } else {
                (18, 0)
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        // Unknown types fall back to NVARCHAR with an 8000-byte limit.
        _ => (0xE7, Some(8000), None, None),
    }
}

#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Total number of rows sent to the server.
    pub rows_affected: u64,
    /// Number of batches that were committed.
    pub batches_committed: u32,
    /// Whether any errors were reported during the load.
    pub has_errors: bool,
}

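/// Builder for configuring a bulk insert before any rows are streamed.
///
/// A minimal sketch of building the `INSERT BULK` statement (crate paths
/// assumed; the exact hint list depends on the options):
///
/// ```ignore
/// let builder = BulkInsertBuilder::new("dbo.Users")
///     .with_typed_columns(vec![
///         BulkColumn::new("id", "INT", 0),
///         BulkColumn::new("name", "NVARCHAR(100)", 1),
///     ])
///     .batch_size(10_000)
///     .table_lock(true);
///
/// let sql = builder.build_insert_bulk_statement();
/// // e.g. "INSERT BULK dbo.Users (id INT, name NVARCHAR(100))
/// //       WITH (CHECK_CONSTRAINTS, KEEP_NULLS, TABLOCK, ROWS_PER_BATCH = 10000)"
/// ```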
#[derive(Debug)]
pub struct BulkInsertBuilder {
    table_name: String,
    columns: Vec<BulkColumn>,
    options: BulkOptions,
}

impl BulkInsertBuilder {
    pub fn new<S: Into<String>>(table_name: S) -> Self {
        Self {
            table_name: table_name.into(),
            columns: Vec::new(),
            options: BulkOptions::default(),
        }
    }

    /// Adds columns by name only; each column is typed as `NVARCHAR(MAX)`.
    #[must_use]
    pub fn with_columns(mut self, column_names: &[&str]) -> Self {
        self.columns = column_names
            .iter()
            .enumerate()
            .map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
            .collect();
        self
    }

    /// Adds fully typed column definitions, replacing any existing ones.
    #[must_use]
    pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
        self.columns = columns;
        self
    }

    #[must_use]
    pub fn with_options(mut self, options: BulkOptions) -> Self {
        self.options = options;
        self
    }

    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.options.batch_size = size;
        self
    }

    #[must_use]
    pub fn table_lock(mut self, enabled: bool) -> Self {
        self.options.table_lock = enabled;
        self
    }

    #[must_use]
    pub fn fire_triggers(mut self, enabled: bool) -> Self {
        self.options.fire_triggers = enabled;
        self
    }

    pub fn table_name(&self) -> &str {
        &self.table_name
    }

    pub fn columns(&self) -> &[BulkColumn] {
        &self.columns
    }

    pub fn options(&self) -> &BulkOptions {
        &self.options
    }

    /// Builds the `INSERT BULK` statement that announces the operation to the server.
    pub fn build_insert_bulk_statement(&self) -> String {
        let mut sql = format!("INSERT BULK {}", self.table_name);

        if !self.columns.is_empty() {
            sql.push_str(" (");
            let cols: Vec<String> = self
                .columns
                .iter()
                .map(|c| format!("{} {}", c.name, c.sql_type))
                .collect();
            sql.push_str(&cols.join(", "));
            sql.push(')');
        }

        let mut hints: Vec<String> = Vec::new();

        if self.options.check_constraints {
            hints.push("CHECK_CONSTRAINTS".to_string());
        }
        if self.options.fire_triggers {
            hints.push("FIRE_TRIGGERS".to_string());
        }
        if self.options.keep_nulls {
            hints.push("KEEP_NULLS".to_string());
        }
        if self.options.table_lock {
            hints.push("TABLOCK".to_string());
        }
        if self.options.batch_size > 0 {
            hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
        }

        if let Some(ref order) = self.options.order_hint {
            hints.push(format!("ORDER({})", order.join(", ")));
        }

        if !hints.is_empty() {
            sql.push_str(" WITH (");
            sql.push_str(&hints.join(", "));
            sql.push(')');
        }

        sql
    }
}

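/// Encodes rows for a bulk insert as a TDS BulkLoad token stream.
///
/// Rows are written into an internal buffer (starting with a COLMETADATA
/// token) and drained as packets by `take_packets` / `finish_packets`;
/// writing those packets to the connection happens elsewhere. A minimal
/// sketch of the flow (crate paths assumed, `?` used for brevity):
///
/// ```ignore
/// let columns = vec![
///     BulkColumn::new("id", "INT", 0),
///     BulkColumn::new("name", "NVARCHAR(50)", 1),
/// ];
/// let mut bulk = BulkInsert::new(columns, 1_000);
///
/// bulk.send_row_values(&[SqlValue::Int(1), SqlValue::String("alice".into())])?;
///
/// if bulk.should_flush() {
///     for packet in bulk.take_packets() {
///         // write `packet` to the TDS transport here
///     }
/// }
///
/// // The final DONE token plus any still-buffered rows:
/// let packets = bulk.finish_packets();
/// ```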
pub struct BulkInsert {
    /// Column metadata describing the rows being sent.
    columns: Arc<[BulkColumn]>,
    /// Accumulated TDS token stream (COLMETADATA followed by ROW tokens).
    buffer: BytesMut,
    /// Rows added since the last flush.
    rows_in_batch: usize,
    /// Total rows added over the lifetime of this bulk insert.
    total_rows: u64,
    /// Rows per batch; `0` disables batch-based flushing.
    batch_size: usize,
    /// Number of batches committed so far.
    batches_committed: u32,
    /// Packet id used for the next outgoing TDS packet.
    packet_id: u8,
}

impl BulkInsert {
    pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
        let mut bulk = Self {
            columns: columns.into(),
            buffer: BytesMut::with_capacity(64 * 1024),
            rows_in_batch: 0,
            total_rows: 0,
            batch_size,
            batches_committed: 0,
            packet_id: 1,
        };

        // The token stream always starts with column metadata.
        bulk.write_colmetadata();

        bulk
    }

    fn write_colmetadata(&mut self) {
        let buf = &mut self.buffer;

        buf.put_u8(TokenType::ColMetaData as u8);

        // Column count.
        buf.put_u16_le(self.columns.len() as u16);

        for col in self.columns.iter() {
            // UserType (unused here, always zero).
            buf.put_u32_le(0);

            // Flags: bit 0 marks the column as nullable.
            let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
            buf.put_u16_le(flags);

            buf.put_u8(col.type_id);

            // Type-specific metadata.
            match col.type_id {
                // Fixed-length types carry no extra length info.
                0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}

                // Variable-length character and binary types: 2-byte max length.
                0xE7 | 0xA7 | 0xA5 | 0xAD => {
                    let max_len = col.max_length.unwrap_or(8000);
                    if max_len == 0xFFFF {
                        buf.put_u16_le(0xFFFF);
                    } else {
                        buf.put_u16_le(max_len as u16);
                    }

                    // Character types additionally carry collation bytes.
                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
                        buf.put_u32_le(0x0409_0904);
                        buf.put_u8(52);
                    }
                }

                // DECIMAL / NUMERIC: storage length, precision, scale.
                0x6C | 0x6A => {
                    let precision = col.precision.unwrap_or(18);
                    let len = decimal_byte_length(precision);
                    buf.put_u8(len);
                    buf.put_u8(precision);
                    buf.put_u8(col.scale.unwrap_or(0));
                }

                // TIME, DATETIME2, DATETIMEOFFSET: scale only.
                0x29..=0x2B => {
                    buf.put_u8(col.scale.unwrap_or(7));
                }

                // UNIQUEIDENTIFIER: fixed 16 bytes.
                0x24 => {
                    buf.put_u8(16);
                }

                _ => {
                    if let Some(len) = col.max_length {
                        if len <= 0xFFFF {
                            buf.put_u16_le(len as u16);
                        }
                    }
                }
            }

            // Column name as a length-prefixed UTF-16 string.
            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
            buf.put_u8(name_utf16.len() as u8);
            for code_unit in name_utf16 {
                buf.put_u16_le(code_unit);
            }
        }
    }

    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
        if values.len() != self.columns.len() {
            return Err(Error::Config(format!(
                "expected {} values, got {}",
                self.columns.len(),
                values.len()
            )));
        }

        let sql_values: Result<Vec<SqlValue>, TypeError> =
            values.iter().map(|v| v.to_sql()).collect();
        let sql_values = sql_values.map_err(Error::from)?;

        self.write_row(&sql_values)?;

        self.rows_in_batch += 1;
        self.total_rows += 1;

        Ok(())
    }

    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
        if values.len() != self.columns.len() {
            return Err(Error::Config(format!(
                "expected {} values, got {}",
                self.columns.len(),
                values.len()
            )));
        }

        self.write_row(values)?;

        self.rows_in_batch += 1;
        self.total_rows += 1;

        Ok(())
    }

    fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
        self.buffer.put_u8(TokenType::Row as u8);

        // Clone the column metadata so `self` can be borrowed mutably below.
        let columns: Vec<_> = self.columns.iter().cloned().collect();

        for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
            self.encode_column_value(col, value)
                .map_err(|e| Error::Config(format!("failed to encode column {}: {}", i, e)))?;
        }

        Ok(())
    }

    fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
        let buf = &mut self.buffer;

        // MAX-length character/binary columns are streamed in the PLP
        // (partially length-prefixed) format rather than with a u16 length.
        let is_plp_type =
            col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);

        match value {
            SqlValue::Null => {
                match col.type_id {
                    // Character and binary types.
                    0xE7 | 0xA7 | 0xA5 | 0xAD => {
                        if is_plp_type {
                            // PLP NULL sentinel.
                            buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
                        } else {
                            // CHARBIN NULL sentinel.
                            buf.put_u16_le(0xFFFF);
                        }
                    }
                    // Variable-length types where a zero length denotes NULL.
                    0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
                        buf.put_u8(0);
                    }
                    _ => {
                        if col.nullable {
                            buf.put_u8(0);
                        } else {
                            return Err(TypeError::UnexpectedNull);
                        }
                    }
                }
            }

            // Scalar values are written as a one-byte length prefix followed by
            // the little-endian value.
            SqlValue::Bool(v) => {
                buf.put_u8(1);
                buf.put_u8(if *v { 1 } else { 0 });
            }

            SqlValue::TinyInt(v) => {
                buf.put_u8(1);
                buf.put_u8(*v);
            }

            SqlValue::SmallInt(v) => {
                buf.put_u8(2);
                buf.put_i16_le(*v);
            }

            SqlValue::Int(v) => {
                buf.put_u8(4);
                buf.put_i32_le(*v);
            }

            SqlValue::BigInt(v) => {
                buf.put_u8(8);
                buf.put_i64_le(*v);
            }

            SqlValue::Float(v) => {
                buf.put_u8(4);
                buf.put_f32_le(*v);
            }

            SqlValue::Double(v) => {
                buf.put_u8(8);
                buf.put_f64_le(*v);
            }

            SqlValue::String(s) => {
                let utf16: Vec<u16> = s.encode_utf16().collect();
                let byte_len = utf16.len() * 2;

                if is_plp_type {
                    encode_plp_string(&utf16, buf);
                } else if byte_len > 0xFFFF {
                    return Err(TypeError::BufferTooSmall {
                        needed: byte_len,
                        available: 0xFFFF,
                    });
                } else {
                    buf.put_u16_le(byte_len as u16);
                    for code_unit in utf16 {
                        buf.put_u16_le(code_unit);
                    }
                }
            }

            SqlValue::Binary(b) => {
                if is_plp_type {
                    encode_plp_binary(b, buf);
                } else if b.len() > 0xFFFF {
                    return Err(TypeError::BufferTooSmall {
                        needed: b.len(),
                        available: 0xFFFF,
                    });
                } else {
                    buf.put_u16_le(b.len() as u16);
                    buf.put_slice(b);
                }
            }

            #[cfg(feature = "decimal")]
            SqlValue::Decimal(d) => {
                let precision = col.precision.unwrap_or(18);
                let len = decimal_byte_length(precision);
                buf.put_u8(len);

                // Sign byte: 1 = positive, 0 = negative.
                buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });

                // Magnitude as a little-endian integer in `len - 1` bytes.
                let mantissa = d.mantissa().unsigned_abs();
                let mantissa_bytes = mantissa.to_le_bytes();
                buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
            }

            #[cfg(feature = "uuid")]
            SqlValue::Uuid(u) => {
                buf.put_u8(16);
                mssql_types::encode::encode_uuid(*u, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::Date(d) => {
                buf.put_u8(3);
                mssql_types::encode::encode_date(*d, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::Time(t) => {
                let scale = col.scale.unwrap_or(7);
                let len = time_byte_length(scale);
                buf.put_u8(len);
                encode_time_with_scale(*t, scale, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::DateTime(dt) => {
                // DATETIME2: the time part followed by the 3-byte date part.
                let scale = col.scale.unwrap_or(7);
                let time_len = time_byte_length(scale);
                let total_len = time_len + 3;
                buf.put_u8(total_len);
                encode_time_with_scale(dt.time(), scale, buf);
                mssql_types::encode::encode_date(dt.date(), buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::DateTimeOffset(dto) => {
                // DATETIMEOFFSET adds a 2-byte offset (in minutes) after the date.
                let scale = col.scale.unwrap_or(7);
                let time_len = time_byte_length(scale);
                let total_len = time_len + 3 + 2;
                buf.put_u8(total_len);
                encode_time_with_scale(dto.time(), scale, buf);
                mssql_types::encode::encode_date(dto.date_naive(), buf);
                use chrono::Offset;
                let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
                buf.put_i16_le(offset_minutes);
            }

            #[cfg(feature = "json")]
            SqlValue::Json(j) => {
                // JSON is sent as its NVARCHAR text representation.
                let s = j.to_string();
                encode_nvarchar_value(&s, buf)?;
            }

            SqlValue::Xml(x) => {
                encode_nvarchar_value(x, buf)?;
            }
        }

        Ok(())
    }
}


fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
    let utf16: Vec<u16> = s.encode_utf16().collect();
    let byte_len = utf16.len() * 2;

    if byte_len > 0xFFFF {
        return Err(TypeError::BufferTooSmall {
            needed: byte_len,
            available: 0xFFFF,
        });
    }

    buf.put_u16_le(byte_len as u16);
    for code_unit in utf16 {
        buf.put_u16_le(code_unit);
    }
    Ok(())
}

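/// Writes a string in the PLP (partially length-prefixed) format used for
/// `NVARCHAR(MAX)` values: an 8-byte total length, a single data chunk
/// prefixed by its 4-byte length, and a 4-byte zero terminator chunk. Empty
/// strings skip the data chunk and emit only the total length and terminator.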
fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
    let byte_len = utf16.len() * 2;

    // Total length of the value in bytes.
    buf.put_u64_le(byte_len as u64);

    if byte_len > 0 {
        // Single data chunk: 4-byte chunk length followed by the UTF-16 data.
        buf.put_u32_le(byte_len as u32);
        for code_unit in utf16 {
            buf.put_u16_le(*code_unit);
        }
    }

    // A zero-length chunk terminates the PLP stream.
    buf.put_u32_le(0);
}

/// Writes a binary value in the PLP format used for `VARBINARY(MAX)` columns;
/// the layout mirrors `encode_plp_string`.
fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
    buf.put_u64_le(data.len() as u64);

    if !data.is_empty() {
        buf.put_u32_le(data.len() as u32);
        buf.put_slice(data);
    }

    buf.put_u32_le(0);
}

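/// Encodes a TIME value as the number of 10^(-scale)-second units since
/// midnight, written little-endian in 3 to 5 bytes depending on the scale.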
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let nanos = time.num_seconds_from_midnight() as u64 * 1_000_000_000 + time.nanosecond() as u64;
    let intervals = nanos / time_scale_divisor(scale);
    let len = time_byte_length(scale);

    // Write `len` bytes of the interval count, least significant byte first.
    for i in 0..len {
        buf.put_u8(((intervals >> (i * 8)) & 0xFF) as u8);
    }
}

impl BulkInsert {
    /// Appends the DONE token that closes the bulk load token stream.
    fn write_done(&mut self) {
        let buf = &mut self.buffer;

        buf.put_u8(TokenType::Done as u8);

        let status = DoneStatus {
            more: false,
            error: false,
            in_xact: false,
            count: true,
            attn: false,
            srverror: false,
        };
        buf.put_u16_le(status.to_bits());

        // Current command, set to zero here.
        buf.put_u16_le(0);

        // Row count reported alongside the DONE token.
        buf.put_u64_le(self.total_rows);
    }

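    /// Drains the internal buffer and splits it into TDS packets of at most
    /// 4096 bytes (an 8-byte header plus up to 4088 bytes of payload). The
    /// last packet produced is flagged END_OF_MESSAGE.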
    pub fn take_packets(&mut self) -> Vec<BytesMut> {
        const MAX_PACKET_SIZE: usize = 4096;
        const HEADER_SIZE: usize = 8;
        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;

        let data = self.buffer.split();
        let mut packets = Vec::new();
        let mut offset = 0;

        while offset < data.len() {
            let remaining = data.len() - offset;
            let payload_size = remaining.min(MAX_PAYLOAD);
            let is_last = offset + payload_size >= data.len();

            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);

            let header = PacketHeader {
                packet_type: PacketType::BulkLoad,
                status: if is_last {
                    PacketStatus::END_OF_MESSAGE
                } else {
                    PacketStatus::NORMAL
                },
                length: (HEADER_SIZE + payload_size) as u16,
                spid: 0,
                packet_id: self.packet_id,
                window: 0,
            };

            header.encode(&mut packet);

            packet.put_slice(&data[offset..offset + payload_size]);

            packets.push(packet);
            offset += payload_size;
            self.packet_id = self.packet_id.wrapping_add(1);
        }

        packets
    }

    /// Total number of rows added so far.
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// Number of rows added since the last flush.
    pub fn rows_in_batch(&self) -> usize {
        self.rows_in_batch
    }

    /// Returns `true` when batching is enabled and the current batch is full.
    pub fn should_flush(&self) -> bool {
        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
    }

    /// Writes the closing DONE token and returns the remaining packets.
    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
        self.write_done();
        self.take_packets()
    }

    /// Summarizes the bulk insert in its current state.
    pub fn result(&self) -> BulkInsertResult {
        BulkInsertResult {
            rows_affected: self.total_rows,
            batches_committed: self.batches_committed,
            has_errors: false,
        }
    }
}

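/// Storage size in bytes (including the sign byte) for a DECIMAL/NUMERIC
/// value of the given precision.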
fn decimal_byte_length(precision: u8) -> u8 {
    match precision {
        1..=9 => 5,
        10..=19 => 9,
        20..=28 => 13,
        29..=38 => 17,
        _ => 17,
    }
}

/// Storage size in bytes for a TIME value of the given scale.
fn time_byte_length(scale: u8) -> u8 {
    match scale {
        0..=2 => 3,
        3..=4 => 4,
        5..=7 => 5,
        _ => 5,
    }
}

/// Nanoseconds per unit at the given scale (e.g. scale 7 counts 100 ns ticks).
fn time_scale_divisor(scale: u8) -> u64 {
    match scale {
        0 => 1_000_000_000,
        1 => 100_000_000,
        2 => 10_000_000,
        3 => 1_000_000,
        4 => 100_000,
        5 => 10_000,
        6 => 1_000,
        7 => 100,
        _ => 100,
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_bulk_options_default() {
        let opts = BulkOptions::default();
        assert_eq!(opts.batch_size, 0);
        assert!(opts.check_constraints);
        assert!(!opts.fire_triggers);
        assert!(opts.keep_nulls);
        assert!(!opts.table_lock);
    }

    #[test]
    fn test_bulk_column_creation() {
        let col = BulkColumn::new("id", "INT", 0);
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, 0x38);
        assert!(col.nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x38);
        assert!(len.is_none());

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));

        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ])
            .table_lock(true);

        let sql = builder.build_insert_bulk_statement();
        assert!(sql.contains("INSERT BULK dbo.Users"));
        assert!(sql.contains("TABLOCK"));
    }

    #[test]
    fn test_bulk_insert_creation() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];

        let bulk = BulkInsert::new(columns, 1000);
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
    }

    #[test]
    fn test_decimal_byte_length() {
        assert_eq!(decimal_byte_length(5), 5);
        assert_eq!(decimal_byte_length(15), 9);
        assert_eq!(decimal_byte_length(25), 13);
        assert_eq!(decimal_byte_length(35), 17);
    }

    #[test]
    fn test_time_byte_length() {
        assert_eq!(time_byte_length(0), 3);
        assert_eq!(time_byte_length(3), 4);
        assert_eq!(time_byte_length(7), 5);
    }

    #[test]
    fn test_plp_string_encoding() {
        let mut buf = BytesMut::new();
        let text = "Hello";
        let utf16: Vec<u16> = text.encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // 8-byte total length + 4-byte chunk length + 10 data bytes + 4-byte terminator.
        assert_eq!(buf.len(), 8 + 4 + 10 + 4);

        assert_eq!(&buf[0..8], &10u64.to_le_bytes());

        assert_eq!(&buf[8..12], &10u32.to_le_bytes());

        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let mut buf = BytesMut::new();
        let data = b"test binary data";

        encode_plp_binary(data, &mut buf);

        // 8-byte total length + 4-byte chunk length + 16 data bytes + 4-byte terminator.
        assert_eq!(buf.len(), 8 + 4 + 16 + 4);

        assert_eq!(&buf[0..8], &16u64.to_le_bytes());

        assert_eq!(&buf[8..12], &16u32.to_le_bytes());

        assert_eq!(&buf[12..28], data);

        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let mut buf = BytesMut::new();
        let utf16: Vec<u16> = "".encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // Only the total length and the zero terminator are written.
        assert_eq!(buf.len(), 8 + 4);

        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();

        encode_plp_binary(&[], &mut buf);

        assert_eq!(buf.len(), 8 + 4);

        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_parse_sql_type_max() {
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
    }
}
1198}