1use bytes::{BufMut, BytesMut};
57use once_cell::sync::Lazy;
58use regex::Regex;
59use std::sync::Arc;
60
61use mssql_types::{SqlValue, ToSql, TypeError};
62use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
63use tds_protocol::token::{Collation, DoneStatus, TokenType};
64
65use crate::error::Error;
66
/// Tuning knobs for a bulk insert.
///
/// Every enabled option becomes a hint in the `WITH (...)` clause produced by
/// `BulkInsertBuilder::build_insert_bulk_statement`.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Rows per batch; a value greater than zero emits `ROWS_PER_BATCH = n`.
    pub batch_size: usize,

    /// Emit the `CHECK_CONSTRAINTS` hint.
    pub check_constraints: bool,

    /// Emit the `FIRE_TRIGGERS` hint.
    pub fire_triggers: bool,

    /// Emit the `KEEP_NULLS` hint.
    pub keep_nulls: bool,

    /// Emit the `TABLOCK` hint.
    pub table_lock: bool,

    /// Column names for the `ORDER(...)` hint; `None` omits the hint.
    pub order_hint: Option<Vec<String>>,
}

impl Default for BulkOptions {
    /// Constraints are checked and NULLs are kept; everything else is off.
    fn default() -> Self {
        Self {
            // Hints that are on by default.
            check_constraints: true,
            keep_nulls: true,
            // Hints that are off by default.
            fire_triggers: false,
            table_lock: false,
            // No batching and no ordering hint unless requested.
            batch_size: 0,
            order_hint: None,
        }
    }
}
126
127#[derive(Debug, Clone)]
129pub struct BulkColumn {
130 pub name: String,
132 pub sql_type: String,
134 pub nullable: bool,
136 pub ordinal: usize,
138 type_id: u8,
140 max_length: Option<u32>,
142 precision: Option<u8>,
144 scale: Option<u8>,
146 collation: Option<Collation>,
154}
155
156impl BulkColumn {
157 pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Result<Self, TypeError> {
167 let sql_type_str: String = sql_type.into();
168 reject_unsupported_bulk_type(&sql_type_str)?;
169 let (type_id, max_length, precision, scale) =
170 parse_sql_type(&sql_type_str).ok_or_else(|| {
171 let base = sql_type_str
172 .split('(')
173 .next()
174 .unwrap_or("")
175 .trim()
176 .to_uppercase();
177 TypeError::UnsupportedType {
178 sql_type: base,
179 reason: "unsupported bulk-insert column type. Supported types: \
180 BIT, TINYINT, SMALLINT, INT, BIGINT, REAL, FLOAT, \
181 DECIMAL/NUMERIC, MONEY, SMALLMONEY, CHAR/VARCHAR, \
182 NCHAR/NVARCHAR (incl. MAX), BINARY/VARBINARY (incl. MAX), \
183 UNIQUEIDENTIFIER, DATE, TIME, DATETIME, DATETIME2, \
184 DATETIMEOFFSET, SMALLDATETIME, and XML."
185 .to_string(),
186 }
187 })?;
188
189 Ok(Self {
190 name: name.into(),
191 sql_type: sql_type_str,
192 nullable: true,
193 ordinal,
194 type_id,
195 max_length,
196 precision,
197 scale,
198 collation: None,
199 })
200 }
201
202 #[must_use]
204 pub fn with_nullable(mut self, nullable: bool) -> Self {
205 self.nullable = nullable;
206 self
207 }
208
209 #[must_use]
216 pub fn with_collation(mut self, collation: Collation) -> Self {
217 self.collation = Some(collation);
218 self
219 }
220}
221
/// `(type_id, max_length, precision, scale)` as produced by [`parse_sql_type`].
type ParsedSqlType = (u8, Option<u32>, Option<u8>, Option<u8>);

/// Maps a SQL type declaration such as `NVARCHAR(100)` or `decimal(18, 2)` to
/// its TDS type byte and size information.
///
/// Matching is case-insensitive and tolerates whitespace around the base name
/// and inside the parentheses. Arguments that cannot be parsed fall back to
/// the type's default (e.g. `VARCHAR(garbage)` is treated as `VARCHAR`).
///
/// Returns `None` when the base type is not supported, or when an
/// `NVARCHAR`/`NCHAR` character count is so large that its byte length does
/// not fit in a `u32`.
fn parse_sql_type(sql_type: &str) -> Option<ParsedSqlType> {
    /// Largest fractional-seconds scale the time encoders in this file handle.
    const MAX_TIME_SCALE: u8 = 7;

    let upper = sql_type.to_uppercase();

    // Split "BASE(ARGS)" into the base name and the raw argument text.
    let (base, params) = match upper.find('(') {
        Some(paren_pos) => (
            upper[..paren_pos].trim(),
            Some(upper[paren_pos + 1..].trim_end_matches(')').trim()),
        ),
        None => (upper.trim(), None),
    };

    // Scale for TIME / DATETIME2 / DATETIMEOFFSET. Missing, unparseable and
    // out-of-range scales all fall back to the maximum, so the stored scale
    // always agrees with the byte width the encoders use.
    let time_scale = || -> u8 {
        params
            .and_then(|p| p.parse::<u8>().ok())
            .filter(|scale| *scale <= MAX_TIME_SCALE)
            .unwrap_or(MAX_TIME_SCALE)
    };

    // Byte length for single-byte character and binary types; MAX -> 0xFFFF.
    let byte_length = |default: u32| -> u32 {
        params
            .and_then(|p| {
                if p == "MAX" {
                    Some(0xFFFF_u32)
                } else {
                    p.parse().ok()
                }
            })
            .unwrap_or(default)
    };

    Some(match base {
        "BIT" => (0x68, Some(1), None, None),
        "TINYINT" => (0x26, Some(1), None, None),
        "SMALLINT" => (0x26, Some(2), None, None),
        "INT" => (0x26, Some(4), None, None),
        "BIGINT" => (0x26, Some(8), None, None),
        "REAL" => (0x6D, Some(4), None, None),
        "FLOAT" => (0x6D, Some(8), None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x6F, Some(8), None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        "SMALLDATETIME" => (0x6F, Some(4), None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_length(8000)), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Declared in characters, stored in bytes (two per UTF-16
                // code unit). `checked_mul` rejects absurd lengths instead of
                // overflowing.
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(chars.checked_mul(2)?), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_length(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            // "p" or "p, s"; anything missing or unparseable defaults to (18, 0).
            let mut parts = params.unwrap_or("").split(',').map(str::trim);
            let precision: u8 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(18);
            let scale: u8 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x6E, Some(8), None, None),
        "SMALLMONEY" => (0x6E, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        _ => return None,
    })
}
335
336fn reject_unsupported_bulk_type(sql_type: &str) -> Result<(), TypeError> {
342 let base = sql_type
343 .split('(')
344 .next()
345 .unwrap_or("")
346 .trim()
347 .to_uppercase();
348 match base.as_str() {
349 "TEXT" | "NTEXT" => Err(TypeError::UnsupportedType {
350 sql_type: base,
351 reason: "TEXT/NTEXT are not supported. Use VARCHAR(MAX) / \
352 NVARCHAR(MAX) instead (Microsoft deprecated TEXT/NTEXT in \
353 SQL Server 2005)."
354 .to_string(),
355 }),
356 "IMAGE" => Err(TypeError::UnsupportedType {
357 sql_type: base,
358 reason: "IMAGE is not supported. Use VARBINARY(MAX) instead \
359 (Microsoft deprecated IMAGE in SQL Server 2005)."
360 .to_string(),
361 }),
362 _ => Ok(()),
363 }
364}
365
/// Summary of a bulk insert operation.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Number of rows reported for the operation.
    pub rows_affected: u64,
    /// Number of batches counted as committed.
    pub batches_committed: u32,
    /// Whether any error was recorded for the operation.
    pub has_errors: bool,
}
378
379#[derive(Debug)]
381pub struct BulkInsertBuilder {
382 table_name: String,
383 columns: Vec<BulkColumn>,
384 options: BulkOptions,
385}
386
387impl BulkInsertBuilder {
388 pub fn new<S: Into<String>>(table_name: S) -> Self {
390 Self {
391 table_name: table_name.into(),
392 columns: Vec::new(),
393 options: BulkOptions::default(),
394 }
395 }
396
397 #[must_use]
402 #[allow(clippy::expect_used)] pub fn with_columns(mut self, column_names: &[&str]) -> Self {
404 self.columns = column_names
405 .iter()
406 .enumerate()
407 .map(|(i, name)| {
408 BulkColumn::new(*name, "NVARCHAR(MAX)", i)
409 .expect("NVARCHAR(MAX) is always a supported type")
410 })
411 .collect();
412 self
413 }
414
415 #[must_use]
417 pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
418 self.columns = columns;
419 self
420 }
421
422 #[must_use]
424 pub fn with_options(mut self, options: BulkOptions) -> Self {
425 self.options = options;
426 self
427 }
428
429 #[must_use]
431 pub fn batch_size(mut self, size: usize) -> Self {
432 self.options.batch_size = size;
433 self
434 }
435
436 #[must_use]
438 pub fn table_lock(mut self, enabled: bool) -> Self {
439 self.options.table_lock = enabled;
440 self
441 }
442
443 #[must_use]
445 pub fn fire_triggers(mut self, enabled: bool) -> Self {
446 self.options.fire_triggers = enabled;
447 self
448 }
449
450 pub fn table_name(&self) -> &str {
452 &self.table_name
453 }
454
455 pub fn columns(&self) -> &[BulkColumn] {
457 &self.columns
458 }
459
460 pub fn options(&self) -> &BulkOptions {
462 &self.options
463 }
464
465 pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
472 crate::validation::validate_qualified_identifier(&self.table_name)?;
474
475 for col in &self.columns {
477 crate::validation::validate_identifier(&col.name)?;
478 }
479
480 let mut sql = format!("INSERT BULK {}", self.table_name);
481
482 if !self.columns.is_empty() {
484 sql.push_str(" (");
485 let cols: Vec<String> = self
486 .columns
487 .iter()
488 .map(|c| {
489 validate_sql_type(&c.sql_type)?;
495 Ok(format!("{} {}", c.name, c.sql_type))
496 })
497 .collect::<Result<Vec<_>, Error>>()?;
498 sql.push_str(&cols.join(", "));
499 sql.push(')');
500 }
501
502 let mut hints: Vec<String> = Vec::new();
504
505 if self.options.check_constraints {
506 hints.push("CHECK_CONSTRAINTS".to_string());
507 }
508 if self.options.fire_triggers {
509 hints.push("FIRE_TRIGGERS".to_string());
510 }
511 if self.options.keep_nulls {
512 hints.push("KEEP_NULLS".to_string());
513 }
514 if self.options.table_lock {
515 hints.push("TABLOCK".to_string());
516 }
517 if self.options.batch_size > 0 {
518 hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
519 }
520
521 if let Some(ref order) = self.options.order_hint {
522 for col_name in order {
524 crate::validation::validate_identifier(col_name)?;
525 }
526 hints.push(format!("ORDER({})", order.join(", ")));
527 }
528
529 if !hints.is_empty() {
530 sql.push_str(" WITH (");
531 sql.push_str(&hints.join(", "));
532 sql.push(')');
533 }
534
535 Ok(sql)
536 }
537}
538
539fn validate_sql_type(type_str: &str) -> Result<(), Error> {
545 #[allow(clippy::expect_used)] static SQL_TYPE_RE: Lazy<Regex> =
547 Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
548
549 if type_str.is_empty() {
550 return Err(Error::Config("SQL type cannot be empty".into()));
551 }
552
553 if !SQL_TYPE_RE.is_match(type_str) {
554 return Err(Error::Config(format!(
555 "invalid SQL type '{type_str}': contains disallowed characters"
556 )));
557 }
558
559 Ok(())
560}
561
562pub struct BulkInsert {
567 columns: Arc<[BulkColumn]>,
569 fixed_len: Arc<[bool]>,
572 buffer: BytesMut,
574 rows_in_batch: usize,
576 total_rows: u64,
578 batch_size: usize,
580 batches_committed: u32,
582 packet_id: u8,
584}
585
586impl BulkInsert {
587 pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
589 Self::new_with_server_metadata(columns, batch_size, None, None)
590 }
591
592 pub(crate) fn new_with_server_metadata(
603 mut columns: Vec<BulkColumn>,
604 batch_size: usize,
605 raw_colmetadata: Option<bytes::Bytes>,
606 server_columns: Option<&[tds_protocol::token::ColumnData]>,
607 ) -> Self {
608 let fixed_len: Vec<bool> = if let Some(srv_cols) = server_columns {
611 for (col, srv) in columns.iter_mut().zip(srv_cols.iter()) {
617 if col.collation.is_none() {
618 col.collation = srv.type_info.collation;
619 }
620 }
621 srv_cols
622 .iter()
623 .map(|c| c.type_id.is_fixed_length())
624 .collect()
625 } else {
626 columns
631 .iter()
632 .map(|c| !c.nullable && nullable_to_fixed_type(c.type_id, c.max_length).is_some())
633 .collect()
634 };
635
636 let mut bulk = Self {
637 columns: columns.into(),
638 fixed_len: fixed_len.into(),
639 buffer: BytesMut::with_capacity(64 * 1024),
640 rows_in_batch: 0,
641 total_rows: 0,
642 batch_size,
643 batches_committed: 0,
644 packet_id: 1,
645 };
646
647 if let Some(raw) = raw_colmetadata {
648 bulk.buffer.extend_from_slice(&raw);
649 } else {
650 bulk.write_colmetadata();
651 }
652
653 bulk
654 }
655
656 fn write_colmetadata(&mut self) {
658 let buf = &mut self.buffer;
659
660 buf.put_u8(TokenType::ColMetaData as u8);
662
663 buf.put_u16_le(self.columns.len() as u16);
665
666 for col in self.columns.iter() {
667 buf.put_u32_le(0);
669
670 let effective_type_id = if !col.nullable {
674 nullable_to_fixed_type(col.type_id, col.max_length).unwrap_or(col.type_id)
675 } else {
676 col.type_id
677 };
678 let is_fixed_variant = effective_type_id != col.type_id;
679
680 let mut flags: u16 = 0x0008; if col.nullable {
684 flags |= 0x0001; }
686 buf.put_u16_le(flags);
687
688 buf.put_u8(effective_type_id);
690
691 if is_fixed_variant {
694 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
695 buf.put_u8(name_utf16.len() as u8);
696 for code_unit in name_utf16 {
697 buf.put_u16_le(code_unit);
698 }
699 continue;
700 }
701
702 match col.type_id {
704 0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
707 buf.put_u8(col.max_length.unwrap_or(4) as u8);
708 }
709
710 0x28 => {}
712
713 0xE7 | 0xA7 | 0xA5 | 0xAD => {
715 let max_len = col.max_length.unwrap_or(8000);
717 if max_len == 0xFFFF {
718 buf.put_u16_le(0xFFFF);
719 } else {
720 buf.put_u16_le(max_len as u16);
721 }
722
723 if col.type_id == 0xE7 || col.type_id == 0xA7 {
727 if let Some(coll) = col.collation.as_ref() {
728 buf.put_slice(&coll.to_bytes());
729 } else {
730 buf.put_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]);
733 }
734 }
735 }
736
737 0x6C | 0x6A => {
739 let precision = col.precision.unwrap_or(18);
741 let len = decimal_byte_length(precision);
742 buf.put_u8(len);
743 buf.put_u8(precision);
744 buf.put_u8(col.scale.unwrap_or(0));
745 }
746
747 0x29..=0x2B => {
749 buf.put_u8(col.scale.unwrap_or(7));
750 }
751
752 0x24 => {
754 buf.put_u8(16);
755 }
756
757 _ => {
759 if let Some(len) = col.max_length {
760 if len <= 0xFFFF {
761 buf.put_u16_le(len as u16);
762 }
763 }
764 }
765 }
766
767 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
769 buf.put_u8(name_utf16.len() as u8);
770 for code_unit in name_utf16 {
771 buf.put_u16_le(code_unit);
772 }
773 }
774 }
775
776 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
787 if values.len() != self.columns.len() {
788 return Err(Error::Config(format!(
789 "expected {} values, got {}",
790 self.columns.len(),
791 values.len()
792 )));
793 }
794
795 let sql_values: Result<Vec<SqlValue>, TypeError> =
797 values.iter().map(|v| v.to_sql()).collect();
798 let sql_values = sql_values.map_err(Error::from)?;
799
800 self.write_row(&sql_values)?;
801
802 self.rows_in_batch += 1;
803 self.total_rows += 1;
804
805 Ok(())
806 }
807
808 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
810 if values.len() != self.columns.len() {
811 return Err(Error::Config(format!(
812 "expected {} values, got {}",
813 self.columns.len(),
814 values.len()
815 )));
816 }
817
818 self.write_row(values)?;
819
820 self.rows_in_batch += 1;
821 self.total_rows += 1;
822
823 Ok(())
824 }
825
826 fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
828 self.buffer.put_u8(TokenType::Row as u8);
830
831 let columns: Vec<_> = self.columns.iter().cloned().collect();
833 let fixed_len = self.fixed_len.clone();
834
835 for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
837 let is_fixed = *fixed_len.get(i).unwrap_or(&false);
838 self.encode_column_value(col, value, is_fixed)
839 .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
840 }
841
842 Ok(())
843 }
844
845 fn encode_column_value(
851 &mut self,
852 col: &BulkColumn,
853 value: &SqlValue,
854 is_fixed: bool,
855 ) -> Result<(), TypeError> {
856 let buf = &mut self.buffer;
857
858 let is_plp_type =
861 col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
862
863 match value {
864 SqlValue::Null => {
865 match col.type_id {
867 0xE7 | 0xA7 | 0xA5 | 0xAD => {
869 if is_plp_type {
870 buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
872 } else {
873 buf.put_u16_le(0xFFFF);
875 }
876 }
877 0x26 | 0x68 | 0x6D | 0x6E | 0x6F | 0x6C | 0x6A | 0x24 | 0x28 | 0x29 | 0x2A
880 | 0x2B => {
881 buf.put_u8(0);
882 }
883 _ => {
885 if col.nullable {
886 buf.put_u8(0);
887 } else {
888 return Err(TypeError::UnexpectedNull);
889 }
890 }
891 }
892 }
893
894 SqlValue::Bool(v) => {
895 if !is_fixed {
896 buf.put_u8(1);
897 }
898 buf.put_u8(if *v { 1 } else { 0 });
899 }
900
901 SqlValue::TinyInt(v) => {
902 if !is_fixed {
903 buf.put_u8(1);
904 }
905 buf.put_u8(*v);
906 }
907
908 SqlValue::SmallInt(v) => {
909 if !is_fixed {
910 buf.put_u8(2);
911 }
912 buf.put_i16_le(*v);
913 }
914
915 SqlValue::Int(v) => {
916 if !is_fixed {
917 buf.put_u8(4);
918 }
919 buf.put_i32_le(*v);
920 }
921
922 SqlValue::BigInt(v) => {
923 if !is_fixed {
924 buf.put_u8(8);
925 }
926 buf.put_i64_le(*v);
927 }
928
929 SqlValue::Float(v) => {
930 if !is_fixed {
931 buf.put_u8(4);
932 }
933 buf.put_f32_le(*v);
934 }
935
936 SqlValue::Double(v) => {
937 if !is_fixed {
938 buf.put_u8(8);
939 }
940 buf.put_f64_le(*v);
941 }
942
943 SqlValue::String(s) => {
944 let is_varchar = matches!(col.type_id, 0xA7 | 0x2F | 0xAF);
950
951 if is_varchar {
952 let encoded = encode_varchar_for_collation(s, col.collation.as_ref());
953 let byte_len = encoded.len();
954
955 if is_plp_type {
956 encode_plp_binary(&encoded, buf);
957 } else if byte_len > 0xFFFF {
958 return Err(TypeError::BufferTooSmall {
959 needed: byte_len,
960 available: 0xFFFF,
961 });
962 } else {
963 buf.put_u16_le(byte_len as u16);
964 buf.put_slice(&encoded);
965 }
966 } else {
967 let utf16: Vec<u16> = s.encode_utf16().collect();
969 let byte_len = utf16.len() * 2;
970
971 if is_plp_type {
972 encode_plp_string(&utf16, buf);
975 } else if byte_len > 0xFFFF {
976 return Err(TypeError::BufferTooSmall {
978 needed: byte_len,
979 available: 0xFFFF,
980 });
981 } else {
982 buf.put_u16_le(byte_len as u16);
984 for code_unit in utf16 {
985 buf.put_u16_le(code_unit);
986 }
987 }
988 }
989 }
990
991 SqlValue::Binary(b) => {
992 if is_plp_type {
993 encode_plp_binary(b, buf);
995 } else if b.len() > 0xFFFF {
996 return Err(TypeError::BufferTooSmall {
998 needed: b.len(),
999 available: 0xFFFF,
1000 });
1001 } else {
1002 buf.put_u16_le(b.len() as u16);
1004 buf.put_slice(b);
1005 }
1006 }
1007
1008 #[cfg(feature = "decimal")]
1010 SqlValue::Decimal(d) => {
1011 if col.type_id == 0x6E {
1012 encode_money_value(*d, col, buf, is_fixed)?;
1014 } else {
1015 let precision = col.precision.unwrap_or(18);
1016 let len = decimal_byte_length(precision);
1017 buf.put_u8(len);
1018
1019 buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
1021
1022 let mantissa = d.mantissa().unsigned_abs();
1024 let mantissa_bytes = mantissa.to_le_bytes();
1025 buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
1026 }
1027 }
1028
1029 #[cfg(feature = "uuid")]
1030 SqlValue::Uuid(u) => {
1031 buf.put_u8(16); mssql_types::__private::encode_uuid(*u, buf);
1034 }
1035
1036 #[cfg(feature = "chrono")]
1037 SqlValue::Date(d) => {
1038 buf.put_u8(3); mssql_types::__private::encode_date(*d, buf)?;
1040 }
1041
1042 #[cfg(feature = "chrono")]
1043 SqlValue::Time(t) => {
1044 let scale = col.scale.unwrap_or(7);
1045 let len = time_byte_length(scale);
1046 buf.put_u8(len);
1047 encode_time_with_scale(*t, scale, buf);
1049 }
1050
1051 #[cfg(feature = "chrono")]
1052 SqlValue::DateTime(dt) => {
1053 if col.type_id == 0x6F {
1058 let total_len = col.max_length.unwrap_or(8) as u8;
1059 if !is_fixed {
1060 buf.put_u8(total_len);
1061 }
1062 match total_len {
1063 8 => mssql_types::__private::encode_datetime_legacy(*dt, buf),
1064 4 => mssql_types::__private::encode_smalldatetime(*dt, buf)?,
1065 _ => {
1066 return Err(TypeError::InvalidDateTime(format!(
1067 "DATETIMEN max_length must be 4 or 8, got {total_len}"
1068 )));
1069 }
1070 }
1071 } else {
1072 let scale = col.scale.unwrap_or(7);
1073 let time_len = time_byte_length(scale);
1074 let total_len = time_len + 3;
1075 buf.put_u8(total_len);
1076 encode_time_with_scale(dt.time(), scale, buf);
1078 mssql_types::__private::encode_date(dt.date(), buf)?;
1079 }
1080 }
1081 #[cfg(feature = "chrono")]
1082 SqlValue::SmallDateTime(dt) => {
1083 if !is_fixed {
1086 buf.put_u8(4);
1087 }
1088 mssql_types::__private::encode_smalldatetime(*dt, buf)?;
1089 }
1090 #[cfg(feature = "decimal")]
1091 SqlValue::Money(d) => {
1092 if !is_fixed {
1094 buf.put_u8(8);
1095 }
1096 mssql_types::__private::encode_money(*d, buf)?;
1097 }
1098 #[cfg(feature = "decimal")]
1099 SqlValue::SmallMoney(d) => {
1100 if !is_fixed {
1101 buf.put_u8(4);
1102 }
1103 mssql_types::__private::encode_smallmoney(*d, buf)?;
1104 }
1105
1106 #[cfg(feature = "chrono")]
1107 SqlValue::DateTimeOffset(dto) => {
1108 let scale = col.scale.unwrap_or(7);
1109 let time_len = time_byte_length(scale);
1110 let total_len = time_len + 3 + 2;
1111 buf.put_u8(total_len);
1112 let utc = dto.naive_utc();
1115 encode_time_with_scale(utc.time(), scale, buf);
1116 mssql_types::__private::encode_date(utc.date(), buf)?;
1117 use chrono::Offset;
1119 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1120 buf.put_i16_le(offset_minutes);
1121 }
1122
1123 #[cfg(feature = "json")]
1124 SqlValue::Json(j) => {
1125 let s = j.to_string();
1126 encode_nvarchar_value(&s, buf)?;
1127 }
1128
1129 SqlValue::Xml(x) => {
1130 encode_nvarchar_value(x, buf)?;
1131 }
1132
1133 SqlValue::Tvp(_) => {
1134 return Err(TypeError::UnsupportedConversion {
1136 from: "TVP".to_string(),
1137 to: "bulk copy value",
1138 });
1139 }
1140 _ => {
1142 return Err(TypeError::UnsupportedConversion {
1143 from: value.type_name().to_string(),
1144 to: "bulk copy value",
1145 });
1146 }
1147 }
1148
1149 Ok(())
1150 }
1151}
1152
1153#[cfg(feature = "decimal")]
1159fn encode_money_value(
1160 value: rust_decimal::Decimal,
1161 col: &BulkColumn,
1162 buf: &mut BytesMut,
1163 is_fixed: bool,
1164) -> Result<(), TypeError> {
1165 let money_bytes: u8 = col.max_length.unwrap_or(8) as u8;
1166 if !is_fixed {
1167 buf.put_u8(money_bytes);
1168 }
1169 match money_bytes {
1170 4 => mssql_types::__private::encode_smallmoney(value, buf),
1171 8 => mssql_types::__private::encode_money(value, buf),
1172 _ => Err(TypeError::InvalidDecimal(format!(
1173 "MONEY column has invalid max_length: {money_bytes}"
1174 ))),
1175 }
1176}
1177
1178fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
1180 let utf16: Vec<u16> = s.encode_utf16().collect();
1181 let byte_len = utf16.len() * 2;
1182
1183 if byte_len > 0xFFFF {
1184 return Err(TypeError::BufferTooSmall {
1185 needed: byte_len,
1186 available: 0xFFFF,
1187 });
1188 }
1189
1190 buf.put_u16_le(byte_len as u16);
1191 for code_unit in utf16 {
1192 buf.put_u16_le(code_unit);
1193 }
1194 Ok(())
1195}
1196
/// Total-length marker written at the start of every PLP value by the PLP
/// encoders in this file, ahead of the length-prefixed chunks.
const PLP_UNKNOWN_LEN: u64 = 0xFFFF_FFFF_FFFF_FFFE;
1202
1203fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
1216 let byte_len = utf16.len() * 2;
1217
1218 buf.put_u64_le(PLP_UNKNOWN_LEN);
1219
1220 if byte_len > 0 {
1221 buf.put_u32_le(byte_len as u32);
1222 for code_unit in utf16 {
1223 buf.put_u16_le(*code_unit);
1224 }
1225 }
1226
1227 buf.put_u32_le(0);
1228}
1229
1230fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
1233 buf.put_u64_le(PLP_UNKNOWN_LEN);
1234
1235 if !data.is_empty() {
1236 buf.put_u32_le(data.len() as u32);
1237 buf.put_slice(data);
1238 }
1239
1240 buf.put_u32_le(0);
1241}
1242
1243fn encode_varchar_for_collation(value: &str, collation: Option<&Collation>) -> Vec<u8> {
1248 tds_protocol::__private::encode_str_for_collation(value, collation)
1249}
1250
1251#[cfg(feature = "chrono")]
1253fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
1254 use chrono::Timelike;
1255
1256 let nanos = time.num_seconds_from_midnight() as u64 * 1_000_000_000 + time.nanosecond() as u64;
1257 let intervals = nanos / time_scale_divisor(scale);
1258 let len = time_byte_length(scale);
1259
1260 for i in 0..len {
1261 buf.put_u8(((intervals >> (i * 8)) & 0xFF) as u8);
1262 }
1263}
1264
1265impl BulkInsert {
1266 fn write_done(&mut self) {
1268 let buf = &mut self.buffer;
1269
1270 buf.put_u8(TokenType::Done as u8);
1271
1272 let status = DoneStatus::from_bits(0x0010); buf.put_u16_le(status.to_bits());
1275
1276 buf.put_u16_le(0);
1278
1279 buf.put_u64_le(self.total_rows);
1281 }
1282
1283 pub fn take_packets(&mut self) -> Vec<BytesMut> {
1287 const MAX_PACKET_SIZE: usize = 4096;
1288 const HEADER_SIZE: usize = 8;
1289 const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1290
1291 let data = self.buffer.split();
1292 let mut packets = Vec::new();
1293 let mut offset = 0;
1294
1295 while offset < data.len() {
1296 let remaining = data.len() - offset;
1297 let payload_size = remaining.min(MAX_PAYLOAD);
1298 let is_last = offset + payload_size >= data.len();
1299
1300 let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1301
1302 let header = PacketHeader {
1304 packet_type: PacketType::BulkLoad,
1305 status: if is_last {
1306 PacketStatus::END_OF_MESSAGE
1307 } else {
1308 PacketStatus::NORMAL
1309 },
1310 length: (HEADER_SIZE + payload_size) as u16,
1311 spid: 0,
1312 packet_id: self.packet_id,
1313 window: 0,
1314 };
1315
1316 header.encode(&mut packet);
1317
1318 packet.put_slice(&data[offset..offset + payload_size]);
1320
1321 packets.push(packet);
1322 offset += payload_size;
1323 self.packet_id = self.packet_id.wrapping_add(1);
1324 }
1325
1326 packets
1327 }
1328
1329 pub fn total_rows(&self) -> u64 {
1331 self.total_rows
1332 }
1333
1334 pub fn rows_in_batch(&self) -> usize {
1336 self.rows_in_batch
1337 }
1338
1339 pub fn should_flush(&self) -> bool {
1345 self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1346 }
1347
1348 pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1351 self.write_done();
1352 self.take_packets()
1353 }
1354
1355 pub fn result(&self) -> BulkInsertResult {
1357 BulkInsertResult {
1358 rows_affected: self.total_rows,
1359 batches_committed: self.batches_committed,
1360 has_errors: false,
1361 }
1362 }
1363}
1364
1365pub struct BulkWriter<'a, S: crate::state::ConnectionState> {
1394 client: &'a mut crate::client::Client<S>,
1395 bulk: BulkInsert,
1396}
1397
1398impl<'a, S: crate::state::ConnectionState> BulkWriter<'a, S> {
1399 pub(crate) fn new(client: &'a mut crate::client::Client<S>, bulk: BulkInsert) -> Self {
1401 Self { client, bulk }
1402 }
1403
1404 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
1410 self.bulk.send_row(values)
1411 }
1412
1413 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
1415 self.bulk.send_row_values(values)
1416 }
1417
1418 pub fn total_rows(&self) -> u64 {
1420 self.bulk.total_rows()
1421 }
1422
1423 pub async fn finish(mut self) -> Result<BulkInsertResult, Error> {
1435 let deadline = self.client.command_deadline();
1436 let total_rows = self.bulk.total_rows();
1437 tracing::debug!(total_rows = total_rows, "finishing bulk insert");
1438
1439 self.bulk.write_done();
1441 let payload = self.bulk.buffer.split().freeze();
1442
1443 let send_and_read = self.client.send_and_read_bulk_load(payload);
1452 let rows_affected = match deadline {
1453 Some(d) => tokio::time::timeout(d, send_and_read)
1454 .await
1455 .map_err(|_| Error::CommandTimeout)??,
1456 None => send_and_read.await?,
1457 };
1458
1459 Ok(BulkInsertResult {
1460 rows_affected,
1461 batches_committed: 1,
1462 has_errors: false,
1463 })
1464 }
1465}
1466
/// Maps a nullable TDS type byte and declared byte length to the matching
/// fixed-length type byte, or `None` when there is no such variant.
///
/// `0x68` maps to `0x32` regardless of length; every other mapping needs a
/// known length.
fn nullable_to_fixed_type(type_id: u8, max_length: Option<u32>) -> Option<u8> {
    if type_id == 0x68 {
        return Some(0x32);
    }

    let width = max_length?;
    let fixed = match (type_id, width) {
        (0x26, 1) => 0x30,
        (0x26, 2) => 0x34,
        (0x26, 4) => 0x38,
        (0x26, 8) => 0x7F,
        (0x6D, 4) => 0x3B,
        (0x6D, 8) => 0x3E,
        (0x6E, 4) => 0x7A,
        (0x6E, 8) => 0x3C,
        (0x6F, 4) => 0x3A,
        (0x6F, 8) => 0x3D,
        _ => return None,
    };
    Some(fixed)
}
1492
/// Storage length in bytes (sign byte included) for a DECIMAL/NUMERIC of the
/// given precision. Anything outside 1..=28, including 0, yields 17.
fn decimal_byte_length(precision: u8) -> u8 {
    if (1..=9).contains(&precision) {
        5
    } else if (10..=19).contains(&precision) {
        9
    } else if (20..=28).contains(&precision) {
        13
    } else {
        17
    }
}
1503
1504#[cfg(feature = "chrono")]
/// Number of bytes used for the time portion at the given fractional-second
/// scale: 3 for scales 0-2, 4 for 3-4, and 5 for everything above.
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1514
1515#[cfg(feature = "chrono")]
/// Nanoseconds per interval at the given fractional-second scale, i.e.
/// `10^(9 - scale)`. Scales above 7 are treated as 7.
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = scale.min(7);
    10_u64.pow(u32::from(9 - effective_scale))
}
1530
1531#[cfg(test)]
1532#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
1533mod tests {
1534 use super::*;
1535
1536 #[test]
1537 fn test_bulk_options_default() {
1538 let opts = BulkOptions::default();
1539 assert_eq!(opts.batch_size, 0);
1540 assert!(opts.check_constraints);
1541 assert!(!opts.fire_triggers);
1542 assert!(opts.keep_nulls);
1543 assert!(!opts.table_lock);
1544 }
1545
1546 #[test]
1547 fn test_bulk_column_creation() {
1548 let col = BulkColumn::new("id", "INT", 0).unwrap();
1549 assert_eq!(col.name, "id");
1550 assert_eq!(col.type_id, 0x26); assert_eq!(col.max_length, Some(4));
1552 assert!(col.nullable);
1553 }
1554
1555 #[test]
1556 fn test_bulk_column_rejects_text() {
1557 let err = BulkColumn::new("body", "TEXT", 0).unwrap_err();
1558 match err {
1559 TypeError::UnsupportedType { sql_type, reason } => {
1560 assert_eq!(sql_type, "TEXT");
1561 assert!(
1562 reason.contains("VARCHAR(MAX)"),
1563 "error should redirect to VARCHAR(MAX), got: {reason}"
1564 );
1565 assert!(
1566 reason.contains("deprecated"),
1567 "error should mention deprecation, got: {reason}"
1568 );
1569 }
1570 other => panic!("expected UnsupportedType, got {other:?}"),
1571 }
1572 }
1573
1574 #[test]
1575 fn test_bulk_column_rejects_ntext() {
1576 let err = BulkColumn::new("body", "NTEXT", 0).unwrap_err();
1577 match err {
1578 TypeError::UnsupportedType { sql_type, reason } => {
1579 assert_eq!(sql_type, "NTEXT");
1580 assert!(
1581 reason.contains("NVARCHAR(MAX)"),
1582 "error should redirect to NVARCHAR(MAX), got: {reason}"
1583 );
1584 assert!(
1585 reason.contains("deprecated"),
1586 "error should mention deprecation, got: {reason}"
1587 );
1588 }
1589 other => panic!("expected UnsupportedType, got {other:?}"),
1590 }
1591 }
1592
1593 #[test]
1594 fn test_bulk_column_rejects_text_case_insensitive() {
1595 assert!(matches!(
1596 BulkColumn::new("body", "text", 0),
1597 Err(TypeError::UnsupportedType { .. })
1598 ));
1599 assert!(matches!(
1600 BulkColumn::new("body", "Ntext", 0),
1601 Err(TypeError::UnsupportedType { .. })
1602 ));
1603 }
1604
1605 #[test]
1606 fn test_bulk_column_rejects_image() {
1607 let err = BulkColumn::new("blob", "IMAGE", 0).unwrap_err();
1608 match err {
1609 TypeError::UnsupportedType { sql_type, reason } => {
1610 assert_eq!(sql_type, "IMAGE");
1611 assert!(
1612 reason.contains("VARBINARY(MAX)"),
1613 "error should redirect to VARBINARY(MAX), got: {reason}"
1614 );
1615 assert!(
1616 reason.contains("deprecated"),
1617 "error should mention deprecation, got: {reason}"
1618 );
1619 }
1620 other => panic!("expected UnsupportedType, got {other:?}"),
1621 }
1622 }
1623
1624 #[test]
1625 fn test_bulk_column_rejects_image_case_insensitive() {
1626 assert!(matches!(
1627 BulkColumn::new("blob", "image", 0),
1628 Err(TypeError::UnsupportedType { .. })
1629 ));
1630 assert!(matches!(
1631 BulkColumn::new("blob", "Image", 0),
1632 Err(TypeError::UnsupportedType { .. })
1633 ));
1634 }
1635
1636 #[test]
1637 fn test_parse_sql_type() {
1638 let (type_id, len, _prec, _scale) = parse_sql_type("INT").unwrap();
1640 assert_eq!(type_id, 0x26);
1641 assert_eq!(len, Some(4));
1642
1643 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)").unwrap();
1644 assert_eq!(type_id, 0xE7);
1645 assert_eq!(len, Some(200)); let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)").unwrap();
1648 assert_eq!(type_id, 0x6C);
1649 assert_eq!(prec, Some(10));
1650 assert_eq!(scale, Some(2));
1651
1652 let (type_id, len, _, _) = parse_sql_type("SMALLDATETIME").unwrap();
1654 assert_eq!(type_id, 0x6F);
1655 assert_eq!(len, Some(4));
1656
1657 let (type_id, len, _, _) = parse_sql_type("DATETIME").unwrap();
1658 assert_eq!(type_id, 0x6F);
1659 assert_eq!(len, Some(8));
1660
1661 assert_eq!(parse_sql_type("SQL_VARIANT"), None);
1663 assert_eq!(parse_sql_type("NOTATYPE"), None);
1664 }
1665
1666 #[test]
1667 fn test_bulk_column_rejects_unknown_type() {
1668 for bogus in ["SQL_VARIANT", "GEOGRAPHY", "HIERARCHYID", "NOTATYPE"] {
1671 let err = BulkColumn::new("c", bogus, 0).unwrap_err();
1672 assert!(
1673 matches!(err, TypeError::UnsupportedType { .. }),
1674 "expected UnsupportedType for {bogus}, got {err:?}"
1675 );
1676 }
1677 assert!(BulkColumn::new("c", "VARCHAR(garbage)", 0).is_ok());
1680 assert!(BulkColumn::new("c", "MONEY", 0).is_ok());
1681 assert!(BulkColumn::new("c", "DATETIME2(3)", 0).is_ok());
1682 }
1683
/// Whitespace around the type name or inside the parentheses must not prevent
/// a type string from being recognised.
#[test]
fn test_parse_sql_type_tolerates_surrounding_spaces() {
    // Trailing, leading and pre-parenthesis spaces at the parser level.
    for spaced in ["INT ", " INT", "VARCHAR (50)", "DECIMAL (18, 2)"] {
        assert!(parse_sql_type(spaced).is_some());
    }

    // Spaces around MAX still produce the 0xFFFF length used for MAX types.
    let (id, len, _, _) = parse_sql_type("NVARCHAR( MAX )").unwrap();
    assert_eq!(id, 0xE7);
    assert_eq!(len, Some(0xFFFF));

    // The same tolerance must hold through the public constructor.
    for spaced in ["VARCHAR (50)", "INT "] {
        assert!(BulkColumn::new("c", spaced, 0).is_ok());
    }
}
1699
/// The generated statement names the target table and carries the `TABLOCK`
/// hint when `table_lock(true)` was requested.
#[test]
fn test_insert_bulk_statement() {
    let typed_columns = vec![
        BulkColumn::new("id", "INT", 0).unwrap(),
        BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
    ];
    let builder = BulkInsertBuilder::new("dbo.Users")
        .with_typed_columns(typed_columns)
        .table_lock(true);

    let sql = builder.build_insert_bulk_statement().unwrap();

    for fragment in ["INSERT BULK dbo.Users", "TABLOCK"] {
        assert!(sql.contains(fragment));
    }
}
1713
/// A table name containing a statement separator must make statement
/// building fail.
#[test]
fn test_bulk_insert_rejects_injection() {
    let id_column = BulkColumn::new("id", "INT", 0).unwrap();
    let builder =
        BulkInsertBuilder::new("table;DROP TABLE users").with_typed_columns(vec![id_column]);

    let outcome = builder.build_insert_bulk_statement();
    assert!(outcome.is_err());
}
1721
/// Column names are validated when the statement is built:
/// `BulkColumn::new` accepts the malicious name, the builder rejects it.
#[test]
fn test_bulk_insert_validates_column_names() {
    let hostile_column = BulkColumn::new("col;DROP TABLE x", "INT", 0).unwrap();
    let builder = BulkInsertBuilder::new("Users").with_typed_columns(vec![hostile_column]);

    let outcome = builder.build_insert_bulk_statement();
    assert!(outcome.is_err());
}
1729
/// A dot-separated three-part table name is accepted by the builder.
#[test]
fn test_bulk_insert_accepts_qualified_names() {
    let id_column = BulkColumn::new("id", "INT", 0).unwrap();
    let builder = BulkInsertBuilder::new("catalog.dbo.Users").with_typed_columns(vec![id_column]);

    let outcome = builder.build_insert_bulk_statement();
    assert!(outcome.is_ok());
}
1737
/// A freshly constructed `BulkInsert` has no rows counted and does not ask to
/// be flushed.
#[test]
fn test_bulk_insert_creation() {
    // Second argument is presumably the batch size — TODO confirm against
    // `BulkInsert::new`.
    let bulk = BulkInsert::new(
        vec![
            BulkColumn::new("id", "INT", 0).unwrap(),
            BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
        ],
        1000,
    );

    assert_eq!(bulk.total_rows(), 0);
    assert_eq!(bulk.rows_in_batch(), 0);

    let wants_flush = bulk.should_flush();
    assert!(!wants_flush);
}
1750
/// Table-driven check of `decimal_byte_length`.
#[test]
fn test_decimal_byte_length() {
    // (argument — presumably a DECIMAL precision —, expected byte length)
    for (input, expected) in [(5, 5), (15, 9), (25, 13), (35, 17)] {
        assert_eq!(decimal_byte_length(input), expected);
    }
}
1758
/// Table-driven check of `time_byte_length`; only built with the `chrono`
/// feature.
#[test]
#[cfg(feature = "chrono")]
fn test_time_byte_length() {
    // (argument — presumably a fractional-seconds scale —, expected byte length)
    for (input, expected) in [(0, 3), (3, 4), (7, 5)] {
        assert_eq!(time_byte_length(input), expected);
    }
}
1766
/// Checks the complete PLP framing that `encode_plp_string` writes for
/// "Hello": length marker, chunk length, UTF-16LE payload and terminator.
#[test]
fn test_plp_string_encoding() {
    let utf16: Vec<u16> = "Hello".encode_utf16().collect();
    let mut buf = BytesMut::new();

    encode_plp_string(&utf16, &mut buf);

    // 8-byte length marker + 4-byte chunk length + 10 payload bytes
    // (5 code units * 2 bytes) + 4-byte terminator.
    assert_eq!(buf.len(), 8 + 4 + 10 + 4);

    // The total length is announced as "unknown".
    assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());

    // Single chunk of 10 bytes.
    assert_eq!(&buf[8..12], &10u32.to_le_bytes());

    // Payload: each UTF-16 code unit in little-endian byte order. Without this
    // check a byte-order or content bug in the encoder would go unnoticed,
    // because only the framing would be verified.
    assert_eq!(&buf[12..22], b"H\0e\0l\0l\0o\0");

    // Zero-length chunk terminates the value.
    assert_eq!(&buf[22..26], &0u32.to_le_bytes());
}
1791
/// Checks the complete PLP framing that `encode_plp_binary` writes for a
/// 16-byte payload.
#[test]
fn test_plp_binary_encoding() {
    let payload = b"test binary data";
    let mut encoded = BytesMut::new();

    encode_plp_binary(payload, &mut encoded);

    // 8-byte length marker + 4-byte chunk length + 16 payload bytes + 4-byte
    // terminator.
    assert_eq!(encoded.len(), 8 + 4 + 16 + 4);

    let bytes: &[u8] = &encoded[..];

    // The total length is announced as "unknown".
    assert_eq!(&bytes[..8], &PLP_UNKNOWN_LEN.to_le_bytes());

    // Single chunk of 16 bytes, copied verbatim.
    assert_eq!(&bytes[8..12], &16u32.to_le_bytes());
    assert_eq!(&bytes[12..28], payload);

    // Zero-length chunk terminates the value.
    assert_eq!(&bytes[28..32], &0u32.to_le_bytes());
}
1818
/// An empty string is encoded as the length marker followed directly by the
/// zero terminator, with no data chunk.
#[test]
fn test_plp_empty_string() {
    let no_units: Vec<u16> = "".encode_utf16().collect();
    let mut encoded = BytesMut::new();

    encode_plp_string(&no_units, &mut encoded);

    // 8-byte length marker + 4-byte terminator only.
    assert_eq!(encoded.len(), 8 + 4);

    let bytes: &[u8] = &encoded[..];
    assert_eq!(&bytes[..8], &PLP_UNKNOWN_LEN.to_le_bytes());
    assert_eq!(&bytes[8..12], &0u32.to_le_bytes());
}
1835
/// An empty byte slice is encoded as the length marker followed directly by
/// the zero terminator, with no data chunk.
#[test]
fn test_plp_empty_binary() {
    let mut encoded = BytesMut::new();

    encode_plp_binary(&[], &mut encoded);

    // 8-byte length marker + 4-byte terminator only.
    assert_eq!(encoded.len(), 8 + 4);

    let bytes: &[u8] = &encoded[..];
    assert_eq!(&bytes[..8], &PLP_UNKNOWN_LEN.to_le_bytes());
    assert_eq!(&bytes[8..12], &0u32.to_le_bytes());
}
1851
/// Round-trips the column metadata buffered by `BulkInsert::new` through
/// `ColMetaData::decode` and compares every decoded column with the
/// `BulkColumn` it was generated from.
#[test]
fn test_write_colmetadata_roundtrip() {
    use tds_protocol::token::ColMetaData;

    // (name, SQL type) pairs; the position in the table doubles as the ordinal.
    let specs = [
        ("id", "INT"),
        ("tiny", "TINYINT"),
        ("small", "SMALLINT"),
        ("big", "BIGINT"),
        ("flag", "BIT"),
        ("r", "REAL"),
        ("f", "FLOAT"),
        ("name", "NVARCHAR(100)"),
        ("code", "VARCHAR(50)"),
        ("data", "VARBINARY(200)"),
        ("d", "DATE"),
        ("t", "TIME(3)"),
        ("dt", "DATETIME"),
        ("dt2", "DATETIME2(7)"),
        ("dto", "DATETIMEOFFSET(7)"),
        ("sdt", "SMALLDATETIME"),
        ("uid", "UNIQUEIDENTIFIER"),
        ("amt", "DECIMAL(18,2)"),
        ("price", "MONEY"),
        ("smoney", "SMALLMONEY"),
        ("nmax", "NVARCHAR(MAX)"),
        ("vmax", "VARCHAR(MAX)"),
        ("bmax", "VARBINARY(MAX)"),
    ];
    let columns: Vec<BulkColumn> = specs
        .iter()
        .enumerate()
        .map(|(ordinal, &(name, sql_type))| BulkColumn::new(name, sql_type, ordinal).unwrap())
        .collect();

    let bulk = BulkInsert::new(columns.clone(), 0);

    // Decoding starts after the first buffered byte (presumably the token-type
    // byte, which the decoder does not consume — TODO confirm).
    let mut cursor = bytes::Bytes::copy_from_slice(&bulk.buffer[1..]);
    let meta = ColMetaData::decode(&mut cursor)
        .expect("write_colmetadata output should be parseable by TDS decoder");

    assert_eq!(meta.columns.len(), columns.len());

    let pairs = meta.columns.iter().zip(columns.iter());
    for (i, (parsed, original)) in pairs.enumerate() {
        let info = &parsed.type_info;

        assert_eq!(parsed.name, original.name, "column {i} name mismatch");
        assert_eq!(
            parsed.col_type, original.type_id,
            "column {i} ({}) type mismatch",
            original.name
        );

        // Type-specific metadata, keyed by the raw TDS type id.
        match original.type_id {
            // INTN: the length byte must survive the round trip.
            0x26 => {
                assert_eq!(
                    info.max_length, original.max_length,
                    "column {i} ({}) INTN max_length",
                    original.name
                );
            }
            // BITN: always one byte.
            0x68 => {
                assert_eq!(info.max_length, Some(1));
            }
            // FLTN.
            0x6D => {
                assert_eq!(
                    info.max_length, original.max_length,
                    "column {i} ({}) FLTN max_length",
                    original.name
                );
            }
            // MONEYN.
            0x6E => {
                assert_eq!(
                    info.max_length, original.max_length,
                    "column {i} ({}) MONEYN max_length",
                    original.name
                );
            }
            // DATETIMEN.
            0x6F => {
                assert_eq!(
                    info.max_length, original.max_length,
                    "column {i} ({}) DATETIMEN max_length",
                    original.name
                );
            }
            // GUID: always 16 bytes.
            0x24 => {
                assert_eq!(info.max_length, Some(16));
            }
            // DATE: no additional type info is checked.
            0x28 => {}
            // TIME / DATETIME2 / DATETIMEOFFSET: only the scale is checked.
            0x29..=0x2B => {
                assert_eq!(
                    info.scale, original.scale,
                    "column {i} ({}) scale",
                    original.name
                );
            }
            // NVARCHAR / VARCHAR: length plus a collation.
            0xE7 | 0xA7 => {
                assert_eq!(
                    info.max_length, original.max_length,
                    "column {i} ({}) string max_length",
                    original.name
                );
                assert!(
                    info.collation.is_some(),
                    "column {i} ({}) should have collation",
                    original.name
                );
            }
            // VARBINARY: length but no collation.
            0xA5 => {
                assert_eq!(
                    info.max_length, original.max_length,
                    "column {i} ({}) binary max_length",
                    original.name
                );
                assert!(
                    info.collation.is_none(),
                    "column {i} ({}) should not have collation",
                    original.name
                );
            }
            // DECIMAL / NUMERIC: precision and scale.
            0x6C => {
                assert_eq!(
                    info.precision, original.precision,
                    "column {i} ({}) precision",
                    original.name
                );
                assert_eq!(
                    info.scale, original.scale,
                    "column {i} ({}) scale",
                    original.name
                );
            }
            _ => {}
        }
    }
}
1998
/// Verifies that columns built with `with_nullable(false)` are flagged as
/// fixed-length, decode to the fixed-width `TypeId` variants, and do not carry
/// the nullable flag bit.
#[test]
fn test_write_colmetadata_not_null_uses_fixed_types() {
    use tds_protocol::token::ColMetaData;
    use tds_protocol::types::TypeId;

    // (column name, SQL type, `TypeId` expected after decoding). The position
    // in the table doubles as the column ordinal.
    let cases: &[(&str, &str, TypeId)] = &[
        ("id", "INT", TypeId::Int4),
        ("tiny", "TINYINT", TypeId::Int1),
        ("small", "SMALLINT", TypeId::Int2),
        ("big", "BIGINT", TypeId::Int8),
        ("flag", "BIT", TypeId::Bit),
        ("r", "REAL", TypeId::Float4),
        ("f", "FLOAT", TypeId::Float8),
        ("dt", "DATETIME", TypeId::DateTime),
        ("sdt", "SMALLDATETIME", TypeId::DateTime4),
        ("mny", "MONEY", TypeId::Money),
        ("smny", "SMALLMONEY", TypeId::Money4),
    ];

    let columns: Vec<BulkColumn> = cases
        .iter()
        .enumerate()
        .map(|(ordinal, (name, sql_type, _))| {
            BulkColumn::new(*name, *sql_type, ordinal)
                .unwrap()
                .with_nullable(false)
        })
        .collect();

    let bulk = BulkInsert::new(columns.clone(), 0);

    // Without this length check the loop below would pass vacuously if
    // `fixed_len` were empty or shorter than the column list.
    assert_eq!(
        bulk.fixed_len.len(),
        columns.len(),
        "fixed_len should have one entry per column"
    );
    for (i, fixed) in bulk.fixed_len.iter().enumerate() {
        assert!(
            *fixed,
            "column {i} ({}) should be fixed_len",
            columns[i].name
        );
    }

    // Decoding starts after the first buffered byte (presumably the token-type
    // byte, which the decoder does not consume — TODO confirm).
    let mut cursor = bytes::Bytes::copy_from_slice(&bulk.buffer[1..]);
    let meta = ColMetaData::decode(&mut cursor).expect("parseable");

    // Catches both missing columns (clearer than an index panic) and extra
    // decoded columns, which the indexed loop below would never visit.
    assert_eq!(
        meta.columns.len(),
        cases.len(),
        "decoded column count should match the number of columns written"
    );

    for (i, (name, _, ty)) in cases.iter().enumerate() {
        let decoded = &meta.columns[i];
        assert_eq!(decoded.name, *name, "column {i} name");
        assert_eq!(decoded.type_id, *ty, "column {i} ({name}) type");
        // Bit 0 of the flags word is the nullable flag.
        assert_eq!(
            decoded.flags & 0x0001,
            0,
            "column {i} ({name}) should not have Nullable flag set"
        );
    }
}
2084
/// A collation supplied through `with_collation` must be the one written for
/// both VARCHAR and NVARCHAR columns; a column without one still decodes with
/// a collation whose bytes match the expected default.
#[test]
fn test_write_colmetadata_uses_caller_collation() {
    use tds_protocol::token::{ColMetaData, Collation};

    let chinese = Collation {
        lcid: 0x0804,
        sort_id: 0x52,
    };

    let varchar_col = BulkColumn::new("s", "VARCHAR(50)", 0)
        .unwrap()
        .with_collation(chinese);
    let nvarchar_col = BulkColumn::new("n", "NVARCHAR(50)", 1)
        .unwrap()
        .with_collation(chinese);
    // No explicit collation: exercises the fallback.
    let plain_col = BulkColumn::new("d", "VARCHAR(10)", 2).unwrap();

    let bulk = BulkInsert::new(vec![varchar_col, nvarchar_col, plain_col], 0);

    // Decoding starts after the first buffered byte (presumably the token-type
    // byte, which the decoder does not consume — TODO confirm).
    let mut cursor = bytes::Bytes::copy_from_slice(&bulk.buffer[1..]);
    let meta = ColMetaData::decode(&mut cursor).expect("parseable");

    let varchar_collation = meta.columns[0]
        .type_info
        .collation
        .as_ref()
        .expect("VARCHAR has collation");
    assert_eq!(varchar_collation.lcid, chinese.lcid, "VARCHAR caller LCID");
    assert_eq!(
        varchar_collation.sort_id, chinese.sort_id,
        "VARCHAR caller sort_id"
    );

    let nvarchar_collation = meta.columns[1]
        .type_info
        .collation
        .as_ref()
        .expect("NVARCHAR has collation");
    assert_eq!(nvarchar_collation.lcid, chinese.lcid, "NVARCHAR caller LCID");
    assert_eq!(
        nvarchar_collation.sort_id, chinese.sort_id,
        "NVARCHAR caller sort_id"
    );

    let fallback_collation = meta.columns[2]
        .type_info
        .collation
        .as_ref()
        .expect("VARCHAR has default collation");
    assert_eq!(
        fallback_collation.to_bytes(),
        [0x09, 0x04, 0xD0, 0x00, 0x34]
    );
}
2140
/// `(MAX)` variants parse to their usual type id with a length of 0xFFFF,
/// while a sized NVARCHAR keeps its computed length.
#[test]
fn test_parse_sql_type_max() {
    // (type string, expected type id, expected length)
    let cases = [
        ("NVARCHAR(MAX)", 0xE7, 0xFFFF),
        ("VARBINARY(MAX)", 0xA5, 0xFFFF),
        ("VARCHAR(MAX)", 0xA7, 0xFFFF),
        // Control case: a non-MAX NVARCHAR must not be reported as MAX.
        ("NVARCHAR(100)", 0xE7, 200),
    ];

    for (sql_type, expected_id, expected_len) in cases {
        let (type_id, len, _, _) = parse_sql_type(sql_type).unwrap();
        assert_eq!(type_id, expected_id);
        assert_eq!(len, Some(expected_len));
    }
}
2163}