1use bytes::{BufMut, BytesMut};
57use once_cell::sync::Lazy;
58use regex::Regex;
59use std::sync::Arc;
60
61use mssql_types::{SqlValue, ToSql, TypeError};
62use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
63use tds_protocol::token::{Collation, DoneStatus, TokenType};
64
65use crate::error::Error;
66
/// Options controlling the hints emitted in the `WITH (...)` clause of the
/// `INSERT BULK` statement built by [`BulkInsertBuilder::build_insert_bulk_statement`].
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// When greater than zero, emitted as the `ROWS_PER_BATCH = n` hint; `0` omits the hint.
    pub batch_size: usize,

    /// Emits the `CHECK_CONSTRAINTS` hint when `true`.
    pub check_constraints: bool,

    /// Emits the `FIRE_TRIGGERS` hint when `true`.
    pub fire_triggers: bool,

    /// Emits the `KEEP_NULLS` hint when `true`.
    pub keep_nulls: bool,

    /// Emits the `TABLOCK` hint when `true`.
    pub table_lock: bool,

    /// Column names for an `ORDER(...)` hint; each name is validated as an
    /// identifier before being written into the statement.
    pub order_hint: Option<Vec<String>>,
}
113
114impl Default for BulkOptions {
115 fn default() -> Self {
116 Self {
117 batch_size: 0,
118 check_constraints: true,
119 fire_triggers: false,
120 keep_nulls: true,
121 table_lock: false,
122 order_hint: None,
123 }
124 }
125}
126
/// One destination column of a bulk load: the caller's SQL type declaration plus
/// the TDS type information derived from it by `parse_sql_type`.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name, written into the `INSERT BULK` column list and the COLMETADATA token.
    pub name: String,
    /// SQL type declaration as supplied by the caller (e.g. `"NVARCHAR(100)"`).
    pub sql_type: String,
    /// Whether the column accepts NULL; `BulkColumn::new` starts with `true`.
    pub nullable: bool,
    /// Position supplied by the caller (`BulkInsertBuilder::with_columns` uses the 0-based index).
    pub ordinal: usize,
    /// TDS type id derived from `sql_type` (e.g. `0x26` for the integer family).
    type_id: u8,
    /// Maximum length in bytes; `0xFFFF` marks a `(MAX)` type.
    max_length: Option<u32>,
    /// Precision for DECIMAL/NUMERIC columns.
    precision: Option<u8>,
    /// Scale for DECIMAL/NUMERIC, or fractional-second digits for TIME/DATETIME2/DATETIMEOFFSET.
    scale: Option<u8>,
    /// Collation for character columns; when `None`, COLMETADATA falls back to fixed
    /// default collation bytes unless server metadata supplies one.
    collation: Option<Collation>,
}
155
156impl BulkColumn {
157 pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Result<Self, TypeError> {
167 let sql_type_str: String = sql_type.into();
168 reject_unsupported_bulk_type(&sql_type_str)?;
169 let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
170
171 Ok(Self {
172 name: name.into(),
173 sql_type: sql_type_str,
174 nullable: true,
175 ordinal,
176 type_id,
177 max_length,
178 precision,
179 scale,
180 collation: None,
181 })
182 }
183
184 #[must_use]
186 pub fn with_nullable(mut self, nullable: bool) -> Self {
187 self.nullable = nullable;
188 self
189 }
190
191 #[must_use]
198 pub fn with_collation(mut self, collation: Collation) -> Self {
199 self.collation = Some(collation);
200 self
201 }
202}
203
/// Derives TDS type information from a SQL Server type declaration.
///
/// Returns `(type_id, max_length, precision, scale)`:
/// * `type_id`: the TDS type id written to COLMETADATA (e.g. `0x26` integer family,
///   `0xE7` NVARCHAR);
/// * `max_length`: the length in bytes, with `0xFFFF` marking a `(MAX)` type;
/// * `precision` / `scale`: set for DECIMAL/NUMERIC. `scale` also carries the
///   fractional-second digits of TIME, DATETIME2 and DATETIMEOFFSET (default 7).
///
/// Parsing is case-insensitive. The base type is the leading word, so surrounding
/// whitespace, a space before the parameter list (`"VARCHAR (50)"`) and trailing
/// column modifiers (`"INT NOT NULL"`, `"NVARCHAR(50) COLLATE ..."`) are tolerated.
/// Unrecognised base types fall back to NVARCHAR with an 8000-byte length.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();

    // The base type ends at the first '(' or whitespace character.
    let base_end = upper
        .find(|c: char| c == '(' || c.is_whitespace())
        .unwrap_or(upper.len());
    let base = &upper[..base_end];

    // Parameters are the contents of the parenthesised group that follows the base
    // type (whitespace allowed before it), up to the first ')'.
    let params = upper[base_end..]
        .trim_start()
        .strip_prefix('(')
        .map(|rest| rest.split(')').next().unwrap_or(rest).trim());

    // Length of a (VAR)CHAR / (VAR)BINARY declaration; "MAX" maps to 0xFFFF.
    let byte_length = |default: u32| {
        params
            .and_then(|p| {
                if p == "MAX" {
                    Some(0xFFFF_u32)
                } else {
                    p.parse().ok()
                }
            })
            .unwrap_or(default)
    };
    // Fractional-second digits for TIME / DATETIME2 / DATETIMEOFFSET.
    let time_scale = || params.and_then(|p| p.parse().ok()).unwrap_or(7u8);

    match base {
        "BIT" => (0x68, Some(1), None, None),
        "TINYINT" => (0x26, Some(1), None, None),
        "SMALLINT" => (0x26, Some(2), None, None),
        "INT" => (0x26, Some(4), None, None),
        "BIGINT" => (0x26, Some(8), None, None),
        "REAL" => (0x6D, Some(4), None, None),
        "FLOAT" => (0x6D, Some(8), None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x6F, Some(8), None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        "SMALLDATETIME" => (0x6F, Some(4), None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_length(8000)), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // The declared length counts UTF-16 code units; the wire length is bytes.
                // Saturate so an absurd declaration cannot overflow.
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_length(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = match params {
                Some(p) => {
                    let mut parts = p.split(',').map(str::trim);
                    (
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(18),
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(0),
                    )
                }
                None => (18, 0),
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x6E, Some(8), None, None),
        "SMALLMONEY" => (0x6E, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        _ => (0xE7, Some(8000), None, None),
    }
}
305
306fn reject_unsupported_bulk_type(sql_type: &str) -> Result<(), TypeError> {
312 let base = sql_type
313 .split('(')
314 .next()
315 .unwrap_or("")
316 .trim()
317 .to_uppercase();
318 match base.as_str() {
319 "TEXT" | "NTEXT" => Err(TypeError::UnsupportedType {
320 sql_type: base,
321 reason: "TEXT/NTEXT are not supported. Use VARCHAR(MAX) / \
322 NVARCHAR(MAX) instead (Microsoft deprecated TEXT/NTEXT in \
323 SQL Server 2005)."
324 .to_string(),
325 }),
326 "IMAGE" => Err(TypeError::UnsupportedType {
327 sql_type: base,
328 reason: "IMAGE is not supported. Use VARBINARY(MAX) instead \
329 (Microsoft deprecated IMAGE in SQL Server 2005)."
330 .to_string(),
331 }),
332 _ => Ok(()),
333 }
334}
335
/// Summary of a completed bulk insert.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Row count: reported by the server in `BulkWriter::finish`, counted locally in
    /// `BulkInsert::result`.
    pub rows_affected: u64,
    /// Number of batches committed.
    pub batches_committed: u32,
    /// Whether errors occurred; the constructors in this module currently always set `false`.
    pub has_errors: bool,
}
348
/// Builder for the `INSERT BULK` statement that precedes a bulk-load data stream.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    /// Destination table; may be qualified (e.g. `dbo.Users`).
    table_name: String,
    /// Destination columns, in the order their values will be sent.
    columns: Vec<BulkColumn>,
    /// Hints emitted in the statement's `WITH (...)` clause.
    options: BulkOptions,
}
356
357impl BulkInsertBuilder {
358 pub fn new<S: Into<String>>(table_name: S) -> Self {
360 Self {
361 table_name: table_name.into(),
362 columns: Vec::new(),
363 options: BulkOptions::default(),
364 }
365 }
366
367 #[must_use]
372 #[allow(clippy::expect_used)] pub fn with_columns(mut self, column_names: &[&str]) -> Self {
374 self.columns = column_names
375 .iter()
376 .enumerate()
377 .map(|(i, name)| {
378 BulkColumn::new(*name, "NVARCHAR(MAX)", i)
379 .expect("NVARCHAR(MAX) is always a supported type")
380 })
381 .collect();
382 self
383 }
384
385 #[must_use]
387 pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
388 self.columns = columns;
389 self
390 }
391
392 #[must_use]
394 pub fn with_options(mut self, options: BulkOptions) -> Self {
395 self.options = options;
396 self
397 }
398
399 #[must_use]
401 pub fn batch_size(mut self, size: usize) -> Self {
402 self.options.batch_size = size;
403 self
404 }
405
406 #[must_use]
408 pub fn table_lock(mut self, enabled: bool) -> Self {
409 self.options.table_lock = enabled;
410 self
411 }
412
413 #[must_use]
415 pub fn fire_triggers(mut self, enabled: bool) -> Self {
416 self.options.fire_triggers = enabled;
417 self
418 }
419
420 pub fn table_name(&self) -> &str {
422 &self.table_name
423 }
424
425 pub fn columns(&self) -> &[BulkColumn] {
427 &self.columns
428 }
429
430 pub fn options(&self) -> &BulkOptions {
432 &self.options
433 }
434
435 pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
442 crate::validation::validate_qualified_identifier(&self.table_name)?;
444
445 for col in &self.columns {
447 crate::validation::validate_identifier(&col.name)?;
448 }
449
450 let mut sql = format!("INSERT BULK {}", self.table_name);
451
452 if !self.columns.is_empty() {
454 sql.push_str(" (");
455 let cols: Vec<String> = self
456 .columns
457 .iter()
458 .map(|c| {
459 validate_sql_type(&c.sql_type)?;
465 Ok(format!("{} {}", c.name, c.sql_type))
466 })
467 .collect::<Result<Vec<_>, Error>>()?;
468 sql.push_str(&cols.join(", "));
469 sql.push(')');
470 }
471
472 let mut hints: Vec<String> = Vec::new();
474
475 if self.options.check_constraints {
476 hints.push("CHECK_CONSTRAINTS".to_string());
477 }
478 if self.options.fire_triggers {
479 hints.push("FIRE_TRIGGERS".to_string());
480 }
481 if self.options.keep_nulls {
482 hints.push("KEEP_NULLS".to_string());
483 }
484 if self.options.table_lock {
485 hints.push("TABLOCK".to_string());
486 }
487 if self.options.batch_size > 0 {
488 hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
489 }
490
491 if let Some(ref order) = self.options.order_hint {
492 for col_name in order {
494 crate::validation::validate_identifier(col_name)?;
495 }
496 hints.push(format!("ORDER({})", order.join(", ")));
497 }
498
499 if !hints.is_empty() {
500 sql.push_str(" WITH (");
501 sql.push_str(&hints.join(", "));
502 sql.push(')');
503 }
504
505 Ok(sql)
506 }
507}
508
509fn validate_sql_type(type_str: &str) -> Result<(), Error> {
515 #[allow(clippy::expect_used)] static SQL_TYPE_RE: Lazy<Regex> =
517 Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
518
519 if type_str.is_empty() {
520 return Err(Error::Config("SQL type cannot be empty".into()));
521 }
522
523 if !SQL_TYPE_RE.is_match(type_str) {
524 return Err(Error::Config(format!(
525 "invalid SQL type '{type_str}': contains disallowed characters"
526 )));
527 }
528
529 Ok(())
530}
531
/// Encoder that accumulates a bulk-load token stream (COLMETADATA, ROW tokens, DONE)
/// in memory and splits it into `BulkLoad` packets.
pub struct BulkInsert {
    /// Destination columns, in the order row values are written.
    columns: Arc<[BulkColumn]>,
    /// Per-column flag: `true` when values are written without a length prefix.
    fixed_len: Arc<[bool]>,
    /// Token stream not yet split into packets.
    buffer: BytesMut,
    /// Rows counted toward the current batch; compared with `batch_size` by `should_flush`.
    rows_in_batch: usize,
    /// Rows written in total; also written as the row count of the DONE token.
    total_rows: u64,
    /// `should_flush` threshold; `0` disables it.
    batch_size: usize,
    /// Batch count reported by `result()`.
    batches_committed: u32,
    /// Packet id stamped on the next outgoing packet header (starts at 1, wraps).
    packet_id: u8,
}
555
556impl BulkInsert {
557 pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
559 Self::new_with_server_metadata(columns, batch_size, None, None)
560 }
561
562 pub fn new_with_server_metadata(
573 mut columns: Vec<BulkColumn>,
574 batch_size: usize,
575 raw_colmetadata: Option<bytes::Bytes>,
576 server_columns: Option<&[tds_protocol::token::ColumnData]>,
577 ) -> Self {
578 let fixed_len: Vec<bool> = if let Some(srv_cols) = server_columns {
581 for (col, srv) in columns.iter_mut().zip(srv_cols.iter()) {
587 if col.collation.is_none() {
588 col.collation = srv.type_info.collation;
589 }
590 }
591 srv_cols
592 .iter()
593 .map(|c| c.type_id.is_fixed_length())
594 .collect()
595 } else {
596 columns
601 .iter()
602 .map(|c| !c.nullable && nullable_to_fixed_type(c.type_id, c.max_length).is_some())
603 .collect()
604 };
605
606 let mut bulk = Self {
607 columns: columns.into(),
608 fixed_len: fixed_len.into(),
609 buffer: BytesMut::with_capacity(64 * 1024),
610 rows_in_batch: 0,
611 total_rows: 0,
612 batch_size,
613 batches_committed: 0,
614 packet_id: 1,
615 };
616
617 if let Some(raw) = raw_colmetadata {
618 bulk.buffer.extend_from_slice(&raw);
619 } else {
620 bulk.write_colmetadata();
621 }
622
623 bulk
624 }
625
626 fn write_colmetadata(&mut self) {
628 let buf = &mut self.buffer;
629
630 buf.put_u8(TokenType::ColMetaData as u8);
632
633 buf.put_u16_le(self.columns.len() as u16);
635
636 for col in self.columns.iter() {
637 buf.put_u32_le(0);
639
640 let effective_type_id = if !col.nullable {
644 nullable_to_fixed_type(col.type_id, col.max_length).unwrap_or(col.type_id)
645 } else {
646 col.type_id
647 };
648 let is_fixed_variant = effective_type_id != col.type_id;
649
650 let mut flags: u16 = 0x0008; if col.nullable {
654 flags |= 0x0001; }
656 buf.put_u16_le(flags);
657
658 buf.put_u8(effective_type_id);
660
661 if is_fixed_variant {
664 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
665 buf.put_u8(name_utf16.len() as u8);
666 for code_unit in name_utf16 {
667 buf.put_u16_le(code_unit);
668 }
669 continue;
670 }
671
672 match col.type_id {
674 0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
677 buf.put_u8(col.max_length.unwrap_or(4) as u8);
678 }
679
680 0x28 => {}
682
683 0xE7 | 0xA7 | 0xA5 | 0xAD => {
685 let max_len = col.max_length.unwrap_or(8000);
687 if max_len == 0xFFFF {
688 buf.put_u16_le(0xFFFF);
689 } else {
690 buf.put_u16_le(max_len as u16);
691 }
692
693 if col.type_id == 0xE7 || col.type_id == 0xA7 {
697 if let Some(coll) = col.collation.as_ref() {
698 buf.put_slice(&coll.to_bytes());
699 } else {
700 buf.put_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]);
703 }
704 }
705 }
706
707 0x6C | 0x6A => {
709 let precision = col.precision.unwrap_or(18);
711 let len = decimal_byte_length(precision);
712 buf.put_u8(len);
713 buf.put_u8(precision);
714 buf.put_u8(col.scale.unwrap_or(0));
715 }
716
717 0x29..=0x2B => {
719 buf.put_u8(col.scale.unwrap_or(7));
720 }
721
722 0x24 => {
724 buf.put_u8(16);
725 }
726
727 _ => {
729 if let Some(len) = col.max_length {
730 if len <= 0xFFFF {
731 buf.put_u16_le(len as u16);
732 }
733 }
734 }
735 }
736
737 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
739 buf.put_u8(name_utf16.len() as u8);
740 for code_unit in name_utf16 {
741 buf.put_u16_le(code_unit);
742 }
743 }
744 }
745
746 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
757 if values.len() != self.columns.len() {
758 return Err(Error::Config(format!(
759 "expected {} values, got {}",
760 self.columns.len(),
761 values.len()
762 )));
763 }
764
765 let sql_values: Result<Vec<SqlValue>, TypeError> =
767 values.iter().map(|v| v.to_sql()).collect();
768 let sql_values = sql_values.map_err(Error::from)?;
769
770 self.write_row(&sql_values)?;
771
772 self.rows_in_batch += 1;
773 self.total_rows += 1;
774
775 Ok(())
776 }
777
778 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
780 if values.len() != self.columns.len() {
781 return Err(Error::Config(format!(
782 "expected {} values, got {}",
783 self.columns.len(),
784 values.len()
785 )));
786 }
787
788 self.write_row(values)?;
789
790 self.rows_in_batch += 1;
791 self.total_rows += 1;
792
793 Ok(())
794 }
795
796 fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
798 self.buffer.put_u8(TokenType::Row as u8);
800
801 let columns: Vec<_> = self.columns.iter().cloned().collect();
803 let fixed_len = self.fixed_len.clone();
804
805 for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
807 let is_fixed = *fixed_len.get(i).unwrap_or(&false);
808 self.encode_column_value(col, value, is_fixed)
809 .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
810 }
811
812 Ok(())
813 }
814
815 fn encode_column_value(
821 &mut self,
822 col: &BulkColumn,
823 value: &SqlValue,
824 is_fixed: bool,
825 ) -> Result<(), TypeError> {
826 let buf = &mut self.buffer;
827
828 let is_plp_type =
831 col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
832
833 match value {
834 SqlValue::Null => {
835 match col.type_id {
837 0xE7 | 0xA7 | 0xA5 | 0xAD => {
839 if is_plp_type {
840 buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
842 } else {
843 buf.put_u16_le(0xFFFF);
845 }
846 }
847 0x26 | 0x68 | 0x6D | 0x6E | 0x6F | 0x6C | 0x6A | 0x24 | 0x28 | 0x29 | 0x2A
850 | 0x2B => {
851 buf.put_u8(0);
852 }
853 _ => {
855 if col.nullable {
856 buf.put_u8(0);
857 } else {
858 return Err(TypeError::UnexpectedNull);
859 }
860 }
861 }
862 }
863
864 SqlValue::Bool(v) => {
865 if !is_fixed {
866 buf.put_u8(1);
867 }
868 buf.put_u8(if *v { 1 } else { 0 });
869 }
870
871 SqlValue::TinyInt(v) => {
872 if !is_fixed {
873 buf.put_u8(1);
874 }
875 buf.put_u8(*v);
876 }
877
878 SqlValue::SmallInt(v) => {
879 if !is_fixed {
880 buf.put_u8(2);
881 }
882 buf.put_i16_le(*v);
883 }
884
885 SqlValue::Int(v) => {
886 if !is_fixed {
887 buf.put_u8(4);
888 }
889 buf.put_i32_le(*v);
890 }
891
892 SqlValue::BigInt(v) => {
893 if !is_fixed {
894 buf.put_u8(8);
895 }
896 buf.put_i64_le(*v);
897 }
898
899 SqlValue::Float(v) => {
900 if !is_fixed {
901 buf.put_u8(4);
902 }
903 buf.put_f32_le(*v);
904 }
905
906 SqlValue::Double(v) => {
907 if !is_fixed {
908 buf.put_u8(8);
909 }
910 buf.put_f64_le(*v);
911 }
912
913 SqlValue::String(s) => {
914 let is_varchar = matches!(col.type_id, 0xA7 | 0x2F | 0xAF);
920
921 if is_varchar {
922 let encoded = encode_varchar_for_collation(s, col.collation.as_ref());
923 let byte_len = encoded.len();
924
925 if is_plp_type {
926 encode_plp_binary(&encoded, buf);
927 } else if byte_len > 0xFFFF {
928 return Err(TypeError::BufferTooSmall {
929 needed: byte_len,
930 available: 0xFFFF,
931 });
932 } else {
933 buf.put_u16_le(byte_len as u16);
934 buf.put_slice(&encoded);
935 }
936 } else {
937 let utf16: Vec<u16> = s.encode_utf16().collect();
939 let byte_len = utf16.len() * 2;
940
941 if is_plp_type {
942 encode_plp_string(&utf16, buf);
945 } else if byte_len > 0xFFFF {
946 return Err(TypeError::BufferTooSmall {
948 needed: byte_len,
949 available: 0xFFFF,
950 });
951 } else {
952 buf.put_u16_le(byte_len as u16);
954 for code_unit in utf16 {
955 buf.put_u16_le(code_unit);
956 }
957 }
958 }
959 }
960
961 SqlValue::Binary(b) => {
962 if is_plp_type {
963 encode_plp_binary(b, buf);
965 } else if b.len() > 0xFFFF {
966 return Err(TypeError::BufferTooSmall {
968 needed: b.len(),
969 available: 0xFFFF,
970 });
971 } else {
972 buf.put_u16_le(b.len() as u16);
974 buf.put_slice(b);
975 }
976 }
977
978 #[cfg(feature = "decimal")]
980 SqlValue::Decimal(d) => {
981 if col.type_id == 0x6E {
982 encode_money_value(*d, col, buf, is_fixed)?;
984 } else {
985 let precision = col.precision.unwrap_or(18);
986 let len = decimal_byte_length(precision);
987 buf.put_u8(len);
988
989 buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
991
992 let mantissa = d.mantissa().unsigned_abs();
994 let mantissa_bytes = mantissa.to_le_bytes();
995 buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
996 }
997 }
998
999 #[cfg(feature = "uuid")]
1000 SqlValue::Uuid(u) => {
1001 buf.put_u8(16); mssql_types::encode::encode_uuid(*u, buf);
1004 }
1005
1006 #[cfg(feature = "chrono")]
1007 SqlValue::Date(d) => {
1008 buf.put_u8(3); mssql_types::encode::encode_date(*d, buf);
1010 }
1011
1012 #[cfg(feature = "chrono")]
1013 SqlValue::Time(t) => {
1014 let scale = col.scale.unwrap_or(7);
1015 let len = time_byte_length(scale);
1016 buf.put_u8(len);
1017 encode_time_with_scale(*t, scale, buf);
1019 }
1020
1021 #[cfg(feature = "chrono")]
1022 SqlValue::DateTime(dt) => {
1023 if col.type_id == 0x6F {
1028 let total_len = col.max_length.unwrap_or(8) as u8;
1029 if !is_fixed {
1030 buf.put_u8(total_len);
1031 }
1032 match total_len {
1033 8 => mssql_types::encode::encode_datetime_legacy(*dt, buf),
1034 4 => mssql_types::encode::encode_smalldatetime(*dt, buf)?,
1035 _ => {
1036 return Err(TypeError::InvalidDateTime(format!(
1037 "DATETIMEN max_length must be 4 or 8, got {total_len}"
1038 )));
1039 }
1040 }
1041 } else {
1042 let scale = col.scale.unwrap_or(7);
1043 let time_len = time_byte_length(scale);
1044 let total_len = time_len + 3;
1045 buf.put_u8(total_len);
1046 encode_time_with_scale(dt.time(), scale, buf);
1048 mssql_types::encode::encode_date(dt.date(), buf);
1049 }
1050 }
1051 #[cfg(feature = "chrono")]
1052 SqlValue::SmallDateTime(dt) => {
1053 if !is_fixed {
1056 buf.put_u8(4);
1057 }
1058 mssql_types::encode::encode_smalldatetime(*dt, buf)?;
1059 }
1060 #[cfg(feature = "decimal")]
1061 SqlValue::Money(d) => {
1062 if !is_fixed {
1064 buf.put_u8(8);
1065 }
1066 mssql_types::encode::encode_money(*d, buf)?;
1067 }
1068 #[cfg(feature = "decimal")]
1069 SqlValue::SmallMoney(d) => {
1070 if !is_fixed {
1071 buf.put_u8(4);
1072 }
1073 mssql_types::encode::encode_smallmoney(*d, buf)?;
1074 }
1075
1076 #[cfg(feature = "chrono")]
1077 SqlValue::DateTimeOffset(dto) => {
1078 let scale = col.scale.unwrap_or(7);
1079 let time_len = time_byte_length(scale);
1080 let total_len = time_len + 3 + 2;
1081 buf.put_u8(total_len);
1082 let utc = dto.naive_utc();
1085 encode_time_with_scale(utc.time(), scale, buf);
1086 mssql_types::encode::encode_date(utc.date(), buf);
1087 use chrono::Offset;
1089 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1090 buf.put_i16_le(offset_minutes);
1091 }
1092
1093 #[cfg(feature = "json")]
1094 SqlValue::Json(j) => {
1095 let s = j.to_string();
1096 encode_nvarchar_value(&s, buf)?;
1097 }
1098
1099 SqlValue::Xml(x) => {
1100 encode_nvarchar_value(x, buf)?;
1101 }
1102
1103 SqlValue::Tvp(_) => {
1104 return Err(TypeError::UnsupportedConversion {
1106 from: "TVP".to_string(),
1107 to: "bulk copy value",
1108 });
1109 }
1110 _ => {
1112 return Err(TypeError::UnsupportedConversion {
1113 from: value.type_name().to_string(),
1114 to: "bulk copy value",
1115 });
1116 }
1117 }
1118
1119 Ok(())
1120 }
1121}
1122
/// Encodes `value` for a MONEY-family column whose width comes from `max_length`
/// (4 = SMALLMONEY, 8 = MONEY, default 8).
///
/// The width byte is written first unless the column uses its fixed-length form.
///
/// # Errors
///
/// Propagates encoder errors. Returns [`TypeError::InvalidDecimal`] for any width
/// other than 4 or 8.
#[cfg(feature = "decimal")]
fn encode_money_value(
    value: rust_decimal::Decimal,
    col: &BulkColumn,
    buf: &mut BytesMut,
    is_fixed: bool,
) -> Result<(), TypeError> {
    let width = col.max_length.unwrap_or(8) as u8;

    if !is_fixed {
        buf.put_u8(width);
    }

    if width == 8 {
        mssql_types::encode::encode_money(value, buf)
    } else if width == 4 {
        mssql_types::encode::encode_smallmoney(value, buf)
    } else {
        Err(TypeError::InvalidDecimal(format!(
            "MONEY column has invalid max_length: {}",
            width
        )))
    }
}
1147
1148fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
1150 let utf16: Vec<u16> = s.encode_utf16().collect();
1151 let byte_len = utf16.len() * 2;
1152
1153 if byte_len > 0xFFFF {
1154 return Err(TypeError::BufferTooSmall {
1155 needed: byte_len,
1156 available: 0xFFFF,
1157 });
1158 }
1159
1160 buf.put_u16_le(byte_len as u16);
1161 for code_unit in utf16 {
1162 buf.put_u16_le(code_unit);
1163 }
1164 Ok(())
1165}
1166
/// PLP length prefix used when the total length is not announced up front; the
/// data follows as u32-length-prefixed chunks ending with a zero-length terminator.
const PLP_UNKNOWN_LEN: u64 = u64::MAX - 1;
1172
1173fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
1186 let byte_len = utf16.len() * 2;
1187
1188 buf.put_u64_le(PLP_UNKNOWN_LEN);
1189
1190 if byte_len > 0 {
1191 buf.put_u32_le(byte_len as u32);
1192 for code_unit in utf16 {
1193 buf.put_u16_le(*code_unit);
1194 }
1195 }
1196
1197 buf.put_u32_le(0);
1198}
1199
1200fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
1203 buf.put_u64_le(PLP_UNKNOWN_LEN);
1204
1205 if !data.is_empty() {
1206 buf.put_u32_le(data.len() as u32);
1207 buf.put_slice(data);
1208 }
1209
1210 buf.put_u32_le(0);
1211}
1212
1213fn encode_varchar_for_collation(value: &str, collation: Option<&Collation>) -> Vec<u8> {
1218 tds_protocol::collation::encode_str_for_collation(value, collation)
1219}
1220
/// Writes `time` as the number of 10^-`scale` second units since midnight, as a
/// little-endian integer of `time_byte_length(scale)` bytes.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let whole_seconds = u64::from(time.num_seconds_from_midnight());
    let nanos = whole_seconds * 1_000_000_000 + u64::from(time.nanosecond());
    let ticks = nanos / time_scale_divisor(scale);

    let width = usize::from(time_byte_length(scale));
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
1234
1235impl BulkInsert {
1236 fn write_done(&mut self) {
1238 let buf = &mut self.buffer;
1239
1240 buf.put_u8(TokenType::Done as u8);
1241
1242 let status = DoneStatus::from_bits(0x0010); buf.put_u16_le(status.to_bits());
1245
1246 buf.put_u16_le(0);
1248
1249 buf.put_u64_le(self.total_rows);
1251 }
1252
1253 pub fn take_packets(&mut self) -> Vec<BytesMut> {
1257 const MAX_PACKET_SIZE: usize = 4096;
1258 const HEADER_SIZE: usize = 8;
1259 const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1260
1261 let data = self.buffer.split();
1262 let mut packets = Vec::new();
1263 let mut offset = 0;
1264
1265 while offset < data.len() {
1266 let remaining = data.len() - offset;
1267 let payload_size = remaining.min(MAX_PAYLOAD);
1268 let is_last = offset + payload_size >= data.len();
1269
1270 let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1271
1272 let header = PacketHeader {
1274 packet_type: PacketType::BulkLoad,
1275 status: if is_last {
1276 PacketStatus::END_OF_MESSAGE
1277 } else {
1278 PacketStatus::NORMAL
1279 },
1280 length: (HEADER_SIZE + payload_size) as u16,
1281 spid: 0,
1282 packet_id: self.packet_id,
1283 window: 0,
1284 };
1285
1286 header.encode(&mut packet);
1287
1288 packet.put_slice(&data[offset..offset + payload_size]);
1290
1291 packets.push(packet);
1292 offset += payload_size;
1293 self.packet_id = self.packet_id.wrapping_add(1);
1294 }
1295
1296 packets
1297 }
1298
1299 pub fn total_rows(&self) -> u64 {
1301 self.total_rows
1302 }
1303
1304 pub fn rows_in_batch(&self) -> usize {
1306 self.rows_in_batch
1307 }
1308
1309 pub fn should_flush(&self) -> bool {
1315 self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1316 }
1317
1318 pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1321 self.write_done();
1322 self.take_packets()
1323 }
1324
1325 pub fn result(&self) -> BulkInsertResult {
1327 BulkInsertResult {
1328 rows_affected: self.total_rows,
1329 batches_committed: self.batches_committed,
1330 has_errors: false,
1331 }
1332 }
1333}
1334
/// Writes rows for an in-progress bulk load on a client connection.
///
/// Rows are encoded into the wrapped [`BulkInsert`] buffer. Nothing is sent until
/// [`BulkWriter::finish`] appends the DONE token and sends the whole payload.
pub struct BulkWriter<'a, S: crate::state::ConnectionState> {
    /// Connection that `finish` sends the payload on.
    client: &'a mut crate::client::Client<S>,
    /// Row encoder holding the pending token stream.
    bulk: BulkInsert,
}
1367
1368impl<'a, S: crate::state::ConnectionState> BulkWriter<'a, S> {
1369 pub(crate) fn new(client: &'a mut crate::client::Client<S>, bulk: BulkInsert) -> Self {
1371 Self { client, bulk }
1372 }
1373
1374 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
1380 self.bulk.send_row(values)
1381 }
1382
1383 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
1385 self.bulk.send_row_values(values)
1386 }
1387
1388 pub fn total_rows(&self) -> u64 {
1390 self.bulk.total_rows()
1391 }
1392
1393 pub async fn finish(mut self) -> Result<BulkInsertResult, Error> {
1398 let total_rows = self.bulk.total_rows();
1399 tracing::debug!(total_rows = total_rows, "finishing bulk insert");
1400
1401 self.bulk.write_done();
1403 let payload = self.bulk.buffer.split().freeze();
1404
1405 let rows_affected = self.client.send_and_read_bulk_load(payload).await?;
1407
1408 Ok(BulkInsertResult {
1409 rows_affected,
1410 batches_committed: 1,
1411 has_errors: false,
1412 })
1413 }
1414}
1415
/// Maps a nullable variable-length TDS type id (with its byte length) to the
/// fixed-length type id used for NOT NULL columns, or `None` if there is none.
///
/// BIT maps regardless of length. The integer, float, money and datetime families
/// need a length that selects a concrete fixed type.
fn nullable_to_fixed_type(type_id: u8, max_length: Option<u32>) -> Option<u8> {
    if type_id == 0x68 {
        return Some(0x32);
    }

    let len = max_length?;
    let fixed = match (type_id, len) {
        // Integer family: 1 / 2 / 4 / 8 bytes.
        (0x26, 1) => 0x30,
        (0x26, 2) => 0x34,
        (0x26, 4) => 0x38,
        (0x26, 8) => 0x7F,
        // Float family: REAL / FLOAT.
        (0x6D, 4) => 0x3B,
        (0x6D, 8) => 0x3E,
        // Money family: SMALLMONEY / MONEY.
        (0x6E, 4) => 0x7A,
        (0x6E, 8) => 0x3C,
        // Datetime family: SMALLDATETIME / DATETIME.
        (0x6F, 4) => 0x3A,
        (0x6F, 8) => 0x3D,
        _ => return None,
    };
    Some(fixed)
}
1441
/// Storage length in bytes (sign byte included) of a DECIMAL/NUMERIC value with
/// the given precision; out-of-range precisions use the widest form (17 bytes).
fn decimal_byte_length(precision: u8) -> u8 {
    // (highest precision in the band, storage length)
    const BANDS: [(u8, u8); 3] = [(9, 5), (19, 9), (28, 13)];

    if precision == 0 {
        return 17;
    }
    BANDS
        .iter()
        .find(|&&(max_precision, _)| precision <= max_precision)
        .map_or(17, |&(_, width)| width)
}
1452
/// Byte width of a TIME value with `scale` fractional-second digits: 3 bytes for
/// scales 0-2, 4 for 3-4, and 5 otherwise.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1463
/// Nanoseconds per unit at the given fractional-second `scale`, i.e.
/// `10^(9 - scale)`; scales above 7 are treated as 7 (100 ns units).
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = u32::from(scale.min(7));
    10u64.pow(9 - effective_scale)
}
1479
1480#[cfg(test)]
1481#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
1482mod tests {
1483 use super::*;
1484
1485 #[test]
1486 fn test_bulk_options_default() {
1487 let opts = BulkOptions::default();
1488 assert_eq!(opts.batch_size, 0);
1489 assert!(opts.check_constraints);
1490 assert!(!opts.fire_triggers);
1491 assert!(opts.keep_nulls);
1492 assert!(!opts.table_lock);
1493 }
1494
1495 #[test]
1496 fn test_bulk_column_creation() {
1497 let col = BulkColumn::new("id", "INT", 0).unwrap();
1498 assert_eq!(col.name, "id");
1499 assert_eq!(col.type_id, 0x26); assert_eq!(col.max_length, Some(4));
1501 assert!(col.nullable);
1502 }
1503
1504 #[test]
1505 fn test_bulk_column_rejects_text() {
1506 let err = BulkColumn::new("body", "TEXT", 0).unwrap_err();
1507 match err {
1508 TypeError::UnsupportedType { sql_type, reason } => {
1509 assert_eq!(sql_type, "TEXT");
1510 assert!(
1511 reason.contains("VARCHAR(MAX)"),
1512 "error should redirect to VARCHAR(MAX), got: {reason}"
1513 );
1514 assert!(
1515 reason.contains("deprecated"),
1516 "error should mention deprecation, got: {reason}"
1517 );
1518 }
1519 other => panic!("expected UnsupportedType, got {other:?}"),
1520 }
1521 }
1522
1523 #[test]
1524 fn test_bulk_column_rejects_ntext() {
1525 let err = BulkColumn::new("body", "NTEXT", 0).unwrap_err();
1526 match err {
1527 TypeError::UnsupportedType { sql_type, reason } => {
1528 assert_eq!(sql_type, "NTEXT");
1529 assert!(
1530 reason.contains("NVARCHAR(MAX)"),
1531 "error should redirect to NVARCHAR(MAX), got: {reason}"
1532 );
1533 assert!(
1534 reason.contains("deprecated"),
1535 "error should mention deprecation, got: {reason}"
1536 );
1537 }
1538 other => panic!("expected UnsupportedType, got {other:?}"),
1539 }
1540 }
1541
1542 #[test]
1543 fn test_bulk_column_rejects_text_case_insensitive() {
1544 assert!(matches!(
1545 BulkColumn::new("body", "text", 0),
1546 Err(TypeError::UnsupportedType { .. })
1547 ));
1548 assert!(matches!(
1549 BulkColumn::new("body", "Ntext", 0),
1550 Err(TypeError::UnsupportedType { .. })
1551 ));
1552 }
1553
1554 #[test]
1555 fn test_bulk_column_rejects_image() {
1556 let err = BulkColumn::new("blob", "IMAGE", 0).unwrap_err();
1557 match err {
1558 TypeError::UnsupportedType { sql_type, reason } => {
1559 assert_eq!(sql_type, "IMAGE");
1560 assert!(
1561 reason.contains("VARBINARY(MAX)"),
1562 "error should redirect to VARBINARY(MAX), got: {reason}"
1563 );
1564 assert!(
1565 reason.contains("deprecated"),
1566 "error should mention deprecation, got: {reason}"
1567 );
1568 }
1569 other => panic!("expected UnsupportedType, got {other:?}"),
1570 }
1571 }
1572
1573 #[test]
1574 fn test_bulk_column_rejects_image_case_insensitive() {
1575 assert!(matches!(
1576 BulkColumn::new("blob", "image", 0),
1577 Err(TypeError::UnsupportedType { .. })
1578 ));
1579 assert!(matches!(
1580 BulkColumn::new("blob", "Image", 0),
1581 Err(TypeError::UnsupportedType { .. })
1582 ));
1583 }
1584
1585 #[test]
1586 fn test_parse_sql_type() {
1587 let (type_id, len, _prec, _scale) = parse_sql_type("INT");
1589 assert_eq!(type_id, 0x26);
1590 assert_eq!(len, Some(4));
1591
1592 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
1593 assert_eq!(type_id, 0xE7);
1594 assert_eq!(len, Some(200)); let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
1597 assert_eq!(type_id, 0x6C);
1598 assert_eq!(prec, Some(10));
1599 assert_eq!(scale, Some(2));
1600
1601 let (type_id, len, _, _) = parse_sql_type("SMALLDATETIME");
1603 assert_eq!(type_id, 0x6F);
1604 assert_eq!(len, Some(4));
1605
1606 let (type_id, len, _, _) = parse_sql_type("DATETIME");
1607 assert_eq!(type_id, 0x6F);
1608 assert_eq!(len, Some(8));
1609 }
1610
1611 #[test]
1612 fn test_insert_bulk_statement() {
1613 let builder = BulkInsertBuilder::new("dbo.Users")
1614 .with_typed_columns(vec![
1615 BulkColumn::new("id", "INT", 0).unwrap(),
1616 BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1617 ])
1618 .table_lock(true);
1619
1620 let sql = builder.build_insert_bulk_statement().unwrap();
1621 assert!(sql.contains("INSERT BULK dbo.Users"));
1622 assert!(sql.contains("TABLOCK"));
1623 }
1624
1625 #[test]
1626 fn test_bulk_insert_rejects_injection() {
1627 let builder = BulkInsertBuilder::new("table;DROP TABLE users")
1628 .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1629
1630 assert!(builder.build_insert_bulk_statement().is_err());
1631 }
1632
1633 #[test]
1634 fn test_bulk_insert_validates_column_names() {
1635 let builder = BulkInsertBuilder::new("Users")
1636 .with_typed_columns(vec![BulkColumn::new("col;DROP TABLE x", "INT", 0).unwrap()]);
1637
1638 assert!(builder.build_insert_bulk_statement().is_err());
1639 }
1640
1641 #[test]
1642 fn test_bulk_insert_accepts_qualified_names() {
1643 let builder = BulkInsertBuilder::new("catalog.dbo.Users")
1644 .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1645
1646 assert!(builder.build_insert_bulk_statement().is_ok());
1647 }
1648
1649 #[test]
1650 fn test_bulk_insert_creation() {
1651 let columns = vec![
1652 BulkColumn::new("id", "INT", 0).unwrap(),
1653 BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1654 ];
1655
1656 let bulk = BulkInsert::new(columns, 1000);
1657 assert_eq!(bulk.total_rows(), 0);
1658 assert_eq!(bulk.rows_in_batch(), 0);
1659 assert!(!bulk.should_flush());
1660 }
1661
1662 #[test]
1663 fn test_decimal_byte_length() {
1664 assert_eq!(decimal_byte_length(5), 5);
1665 assert_eq!(decimal_byte_length(15), 9);
1666 assert_eq!(decimal_byte_length(25), 13);
1667 assert_eq!(decimal_byte_length(35), 17);
1668 }
1669
1670 #[test]
1671 #[cfg(feature = "chrono")]
1672 fn test_time_byte_length() {
1673 assert_eq!(time_byte_length(0), 3);
1674 assert_eq!(time_byte_length(3), 4);
1675 assert_eq!(time_byte_length(7), 5);
1676 }
1677
1678 #[test]
1679 fn test_plp_string_encoding() {
1680 let mut buf = BytesMut::new();
1681 let text = "Hello";
1682 let utf16: Vec<u16> = text.encode_utf16().collect();
1683
1684 encode_plp_string(&utf16, &mut buf);
1685
1686 assert_eq!(buf.len(), 8 + 4 + 10 + 4);
1692
1693 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1695
1696 assert_eq!(&buf[8..12], &10u32.to_le_bytes());
1698
1699 assert_eq!(&buf[22..26], &0u32.to_le_bytes());
1701 }
1702
1703 #[test]
1704 fn test_plp_binary_encoding() {
1705 let mut buf = BytesMut::new();
1706 let data = b"test binary data";
1707
1708 encode_plp_binary(data, &mut buf);
1709
1710 assert_eq!(buf.len(), 8 + 4 + 16 + 4);
1716
1717 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1719
1720 assert_eq!(&buf[8..12], &16u32.to_le_bytes());
1722
1723 assert_eq!(&buf[12..28], data);
1725
1726 assert_eq!(&buf[28..32], &0u32.to_le_bytes());
1728 }
1729
1730 #[test]
1731 fn test_plp_empty_string() {
1732 let mut buf = BytesMut::new();
1733 let utf16: Vec<u16> = "".encode_utf16().collect();
1734
1735 encode_plp_string(&utf16, &mut buf);
1736
1737 assert_eq!(buf.len(), 8 + 4);
1739
1740 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1742
1743 assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1745 }
1746
1747 #[test]
1748 fn test_plp_empty_binary() {
1749 let mut buf = BytesMut::new();
1750
1751 encode_plp_binary(&[], &mut buf);
1752
1753 assert_eq!(buf.len(), 8 + 4);
1755
1756 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1758
1759 assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1761 }
1762
1763 #[test]
1766 fn test_write_colmetadata_roundtrip() {
1767 use tds_protocol::token::ColMetaData;
1768
1769 let columns = vec![
1770 BulkColumn::new("id", "INT", 0).unwrap(),
1771 BulkColumn::new("tiny", "TINYINT", 1).unwrap(),
1772 BulkColumn::new("small", "SMALLINT", 2).unwrap(),
1773 BulkColumn::new("big", "BIGINT", 3).unwrap(),
1774 BulkColumn::new("flag", "BIT", 4).unwrap(),
1775 BulkColumn::new("r", "REAL", 5).unwrap(),
1776 BulkColumn::new("f", "FLOAT", 6).unwrap(),
1777 BulkColumn::new("name", "NVARCHAR(100)", 7).unwrap(),
1778 BulkColumn::new("code", "VARCHAR(50)", 8).unwrap(),
1779 BulkColumn::new("data", "VARBINARY(200)", 9).unwrap(),
1780 BulkColumn::new("d", "DATE", 10).unwrap(),
1781 BulkColumn::new("t", "TIME(3)", 11).unwrap(),
1782 BulkColumn::new("dt", "DATETIME", 12).unwrap(),
1783 BulkColumn::new("dt2", "DATETIME2(7)", 13).unwrap(),
1784 BulkColumn::new("dto", "DATETIMEOFFSET(7)", 14).unwrap(),
1785 BulkColumn::new("sdt", "SMALLDATETIME", 15).unwrap(),
1786 BulkColumn::new("uid", "UNIQUEIDENTIFIER", 16).unwrap(),
1787 BulkColumn::new("amt", "DECIMAL(18,2)", 17).unwrap(),
1788 BulkColumn::new("price", "MONEY", 18).unwrap(),
1789 BulkColumn::new("smoney", "SMALLMONEY", 19).unwrap(),
1790 BulkColumn::new("nmax", "NVARCHAR(MAX)", 20).unwrap(),
1791 BulkColumn::new("vmax", "VARCHAR(MAX)", 21).unwrap(),
1792 BulkColumn::new("bmax", "VARBINARY(MAX)", 22).unwrap(),
1793 ];
1794
1795 let bulk = BulkInsert::new(columns.clone(), 0);
1796
1797 let buf = &bulk.buffer[1..];
1799 let mut cursor = bytes::Bytes::copy_from_slice(buf);
1800 let meta = ColMetaData::decode(&mut cursor)
1801 .expect("write_colmetadata output should be parseable by TDS decoder");
1802
1803 assert_eq!(meta.columns.len(), columns.len());
1804
1805 for (i, (parsed, original)) in meta.columns.iter().zip(columns.iter()).enumerate() {
1807 assert_eq!(parsed.name, original.name, "column {i} name mismatch");
1808 assert_eq!(
1809 parsed.col_type, original.type_id,
1810 "column {i} ({}) type mismatch",
1811 original.name
1812 );
1813
1814 match original.type_id {
1816 0x26 => {
1818 assert_eq!(
1819 parsed.type_info.max_length, original.max_length,
1820 "column {i} ({}) INTN max_length",
1821 original.name
1822 );
1823 }
1824 0x68 => {
1826 assert_eq!(parsed.type_info.max_length, Some(1));
1827 }
1828 0x6D => {
1830 assert_eq!(
1831 parsed.type_info.max_length, original.max_length,
1832 "column {i} ({}) FLTN max_length",
1833 original.name
1834 );
1835 }
1836 0x6E => {
1838 assert_eq!(
1839 parsed.type_info.max_length, original.max_length,
1840 "column {i} ({}) MONEYN max_length",
1841 original.name
1842 );
1843 }
1844 0x6F => {
1846 assert_eq!(
1847 parsed.type_info.max_length, original.max_length,
1848 "column {i} ({}) DATETIMEN max_length",
1849 original.name
1850 );
1851 }
1852 0x24 => {
1854 assert_eq!(parsed.type_info.max_length, Some(16));
1855 }
1856 0x28 => {}
1858 0x29..=0x2B => {
1860 assert_eq!(
1861 parsed.type_info.scale, original.scale,
1862 "column {i} ({}) scale",
1863 original.name
1864 );
1865 }
1866 0xE7 | 0xA7 => {
1868 assert_eq!(
1869 parsed.type_info.max_length, original.max_length,
1870 "column {i} ({}) string max_length",
1871 original.name
1872 );
1873 assert!(
1874 parsed.type_info.collation.is_some(),
1875 "column {i} ({}) should have collation",
1876 original.name
1877 );
1878 }
1879 0xA5 => {
1881 assert_eq!(
1882 parsed.type_info.max_length, original.max_length,
1883 "column {i} ({}) binary max_length",
1884 original.name
1885 );
1886 assert!(
1887 parsed.type_info.collation.is_none(),
1888 "column {i} ({}) should not have collation",
1889 original.name
1890 );
1891 }
1892 0x6C => {
1894 assert_eq!(
1895 parsed.type_info.precision, original.precision,
1896 "column {i} ({}) precision",
1897 original.name
1898 );
1899 assert_eq!(
1900 parsed.type_info.scale, original.scale,
1901 "column {i} ({}) scale",
1902 original.name
1903 );
1904 }
1905 _ => {}
1906 }
1907 }
1908 }
1909
1910 #[test]
1914 fn test_write_colmetadata_not_null_uses_fixed_types() {
1915 use tds_protocol::token::ColMetaData;
1916 use tds_protocol::types::TypeId;
1917
1918 let columns = vec![
1919 BulkColumn::new("id", "INT", 0)
1920 .unwrap()
1921 .with_nullable(false),
1922 BulkColumn::new("tiny", "TINYINT", 1)
1923 .unwrap()
1924 .with_nullable(false),
1925 BulkColumn::new("small", "SMALLINT", 2)
1926 .unwrap()
1927 .with_nullable(false),
1928 BulkColumn::new("big", "BIGINT", 3)
1929 .unwrap()
1930 .with_nullable(false),
1931 BulkColumn::new("flag", "BIT", 4)
1932 .unwrap()
1933 .with_nullable(false),
1934 BulkColumn::new("r", "REAL", 5)
1935 .unwrap()
1936 .with_nullable(false),
1937 BulkColumn::new("f", "FLOAT", 6)
1938 .unwrap()
1939 .with_nullable(false),
1940 BulkColumn::new("dt", "DATETIME", 7)
1941 .unwrap()
1942 .with_nullable(false),
1943 BulkColumn::new("sdt", "SMALLDATETIME", 8)
1944 .unwrap()
1945 .with_nullable(false),
1946 BulkColumn::new("mny", "MONEY", 9)
1947 .unwrap()
1948 .with_nullable(false),
1949 BulkColumn::new("smny", "SMALLMONEY", 10)
1950 .unwrap()
1951 .with_nullable(false),
1952 ];
1953
1954 let bulk = BulkInsert::new(columns.clone(), 0);
1955
1956 for (i, fixed) in bulk.fixed_len.iter().enumerate() {
1958 assert!(
1959 *fixed,
1960 "column {i} ({}) should be fixed_len",
1961 columns[i].name
1962 );
1963 }
1964
1965 let buf = &bulk.buffer[1..]; let mut cursor = bytes::Bytes::copy_from_slice(buf);
1968 let meta = ColMetaData::decode(&mut cursor).expect("parseable");
1969
1970 let expected: &[(&str, TypeId)] = &[
1972 ("id", TypeId::Int4),
1973 ("tiny", TypeId::Int1),
1974 ("small", TypeId::Int2),
1975 ("big", TypeId::Int8),
1976 ("flag", TypeId::Bit),
1977 ("r", TypeId::Float4),
1978 ("f", TypeId::Float8),
1979 ("dt", TypeId::DateTime),
1980 ("sdt", TypeId::DateTime4),
1981 ("mny", TypeId::Money),
1982 ("smny", TypeId::Money4),
1983 ];
1984
1985 for (i, (name, ty)) in expected.iter().enumerate() {
1986 assert_eq!(meta.columns[i].name, *name, "column {i} name");
1987 assert_eq!(meta.columns[i].type_id, *ty, "column {i} ({name}) type");
1988 assert_eq!(
1989 meta.columns[i].flags & 0x0001,
1990 0,
1991 "column {i} ({name}) should not have Nullable flag set"
1992 );
1993 }
1994 }
1995
1996 #[test]
2000 fn test_write_colmetadata_uses_caller_collation() {
2001 use tds_protocol::token::{ColMetaData, Collation};
2002
2003 let chinese = Collation {
2005 lcid: 0x0804,
2006 sort_id: 0x52,
2007 };
2008
2009 let columns = vec![
2010 BulkColumn::new("s", "VARCHAR(50)", 0)
2011 .unwrap()
2012 .with_collation(chinese),
2013 BulkColumn::new("n", "NVARCHAR(50)", 1)
2015 .unwrap()
2016 .with_collation(chinese),
2017 BulkColumn::new("d", "VARCHAR(10)", 2).unwrap(),
2019 ];
2020 let bulk = BulkInsert::new(columns, 0);
2021
2022 let buf = &bulk.buffer[1..];
2023 let mut cursor = bytes::Bytes::copy_from_slice(buf);
2024 let meta = ColMetaData::decode(&mut cursor).expect("parseable");
2025
2026 let c0 = meta.columns[0]
2027 .type_info
2028 .collation
2029 .as_ref()
2030 .expect("VARCHAR has collation");
2031 assert_eq!(c0.lcid, chinese.lcid, "VARCHAR caller LCID");
2032 assert_eq!(c0.sort_id, chinese.sort_id, "VARCHAR caller sort_id");
2033
2034 let c1 = meta.columns[1]
2035 .type_info
2036 .collation
2037 .as_ref()
2038 .expect("NVARCHAR has collation");
2039 assert_eq!(c1.lcid, chinese.lcid, "NVARCHAR caller LCID");
2040 assert_eq!(c1.sort_id, chinese.sort_id, "NVARCHAR caller sort_id");
2041
2042 let default = meta.columns[2]
2045 .type_info
2046 .collation
2047 .as_ref()
2048 .expect("VARCHAR has default collation");
2049 assert_eq!(default.to_bytes(), [0x09, 0x04, 0xD0, 0x00, 0x34]);
2050 }
2051
2052 #[test]
2053 fn test_parse_sql_type_max() {
2054 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
2056 assert_eq!(type_id, 0xE7);
2057 assert_eq!(len, Some(0xFFFF)); let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
2061 assert_eq!(type_id, 0xA5);
2062 assert_eq!(len, Some(0xFFFF));
2063
2064 let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
2066 assert_eq!(type_id, 0xA7);
2067 assert_eq!(len, Some(0xFFFF));
2068
2069 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
2071 assert_eq!(type_id, 0xE7);
2072 assert_eq!(len, Some(200)); }
2074}