1use bytes::{BufMut, BytesMut};
57use once_cell::sync::Lazy;
58use regex::Regex;
59use std::sync::Arc;
60
61use mssql_types::{SqlValue, ToSql, TypeError};
62use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
63use tds_protocol::token::{Collation, DoneStatus, TokenType};
64
65use crate::error::Error;
66
/// Options controlling the hints attached to an `INSERT BULK` statement.
///
/// Each field is rendered by `BulkInsertBuilder::build_insert_bulk_statement`.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Emitted as `ROWS_PER_BATCH = n` when non-zero; `0` leaves it unspecified.
    pub batch_size: usize,

    /// Emits the `CHECK_CONSTRAINTS` hint when `true`.
    pub check_constraints: bool,

    /// Emits the `FIRE_TRIGGERS` hint when `true`.
    pub fire_triggers: bool,

    /// Emits the `KEEP_NULLS` hint when `true`.
    pub keep_nulls: bool,

    /// Emits the `TABLOCK` hint when `true`.
    pub table_lock: bool,

    /// Column names rendered as an `ORDER(...)` hint, if any.
    pub order_hint: Option<Vec<String>>,
}

impl Default for BulkOptions {
    /// Constraint checking and NULL preservation enabled; triggers, table lock,
    /// batch size and ordering hint disabled.
    fn default() -> Self {
        Self {
            order_hint: None,
            table_lock: false,
            keep_nulls: true,
            fire_triggers: false,
            check_constraints: true,
            batch_size: 0,
        }
    }
}
126
127#[derive(Debug, Clone)]
129pub struct BulkColumn {
130 pub name: String,
132 pub sql_type: String,
134 pub nullable: bool,
136 pub ordinal: usize,
138 type_id: u8,
140 max_length: Option<u32>,
142 precision: Option<u8>,
144 scale: Option<u8>,
146 collation: Option<Collation>,
154}
155
156impl BulkColumn {
157 pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Result<Self, TypeError> {
167 let sql_type_str: String = sql_type.into();
168 reject_unsupported_bulk_type(&sql_type_str)?;
169 let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
170
171 Ok(Self {
172 name: name.into(),
173 sql_type: sql_type_str,
174 nullable: true,
175 ordinal,
176 type_id,
177 max_length,
178 precision,
179 scale,
180 collation: None,
181 })
182 }
183
184 #[must_use]
186 pub fn with_nullable(mut self, nullable: bool) -> Self {
187 self.nullable = nullable;
188 self
189 }
190
191 #[must_use]
198 pub fn with_collation(mut self, collation: Collation) -> Self {
199 self.collation = Some(collation);
200 self
201 }
202}
203
/// Parses a SQL Server type declaration into its TDS bulk-load description.
///
/// Returns `(type_id, max_length, precision, scale)`:
/// - `type_id`: the nullable TDS type byte (e.g. `0x26` INTN, `0xE7` NVARCHAR);
/// - `max_length`: declared byte length, with `0xFFFF` marking a `MAX` type;
/// - `precision` / `scale`: set for DECIMAL/NUMERIC; `scale` also for the
///   TIME / DATETIME2 / DATETIMEOFFSET family (default 7).
///
/// Matching is case-insensitive and tolerates whitespace around the declaration,
/// before the parenthesis and inside it (`"decimal (10, 2)"`, `"NVARCHAR( MAX )"`).
/// `FLOAT(1..=24)` maps to the 4-byte REAL representation. Unrecognised types
/// fall back to `NVARCHAR(4000)`.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();

    // Split `BASE(params)` into its base name and trimmed parameter list.
    let (base, params): (&str, Option<&str>) = match upper.find('(') {
        Some(paren_pos) => {
            let params_str = upper[paren_pos + 1..].trim_end_matches(')').trim();
            (upper[..paren_pos].trim_end(), Some(params_str))
        }
        None => (upper.as_str(), None),
    };

    // Fractional-seconds scale for the TIME family; SQL Server defaults to 7.
    let fractional_scale = || params.and_then(|p| p.parse::<u8>().ok()).unwrap_or(7);
    // Byte length argument, where `MAX` maps to the 0xFFFF sentinel.
    let byte_length = |default: u32| {
        params
            .and_then(|p| if p == "MAX" { Some(0xFFFF_u32) } else { p.parse().ok() })
            .unwrap_or(default)
    };

    match base {
        "BIT" => (0x68, Some(1), None, None),
        "TINYINT" => (0x26, Some(1), None, None),
        "SMALLINT" => (0x26, Some(2), None, None),
        "INT" => (0x26, Some(4), None, None),
        "BIGINT" => (0x26, Some(8), None, None),
        "REAL" => (0x6D, Some(4), None, None),
        "FLOAT" => {
            // FLOAT(n) with n in 1..=24 is stored as REAL (4 bytes); plain FLOAT and
            // FLOAT(25..=53) are 8 bytes.
            let mantissa_bits: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(53);
            let len = if (1..=24).contains(&mantissa_bits) { 4 } else { 8 };
            (0x6D, Some(len), None, None)
        }
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(fractional_scale())),
        "DATETIME" => (0x6F, Some(8), None, None),
        "DATETIME2" => (0x2A, None, None, Some(fractional_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(fractional_scale())),
        "SMALLDATETIME" => (0x6F, Some(4), None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_length(8000)), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Declared in characters; each UTF-16 code unit is two bytes.
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_length(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            let mut parts = params.unwrap_or("").split(',').map(str::trim);
            let precision = parts.next().and_then(|s| s.parse().ok()).unwrap_or(18);
            let scale = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x6E, Some(8), None, None),
        "SMALLMONEY" => (0x6E, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        // Unknown types are sent as NVARCHAR(4000).
        _ => (0xE7, Some(8000), None, None),
    }
}
305
306fn reject_unsupported_bulk_type(sql_type: &str) -> Result<(), TypeError> {
312 let base = sql_type
313 .split('(')
314 .next()
315 .unwrap_or("")
316 .trim()
317 .to_uppercase();
318 match base.as_str() {
319 "TEXT" | "NTEXT" => Err(TypeError::UnsupportedType {
320 sql_type: base,
321 reason: "TEXT/NTEXT are not supported. Use VARCHAR(MAX) / \
322 NVARCHAR(MAX) instead (Microsoft deprecated TEXT/NTEXT in \
323 SQL Server 2005)."
324 .to_string(),
325 }),
326 "IMAGE" => Err(TypeError::UnsupportedType {
327 sql_type: base,
328 reason: "IMAGE is not supported. Use VARBINARY(MAX) instead \
329 (Microsoft deprecated IMAGE in SQL Server 2005)."
330 .to_string(),
331 }),
332 _ => Ok(()),
333 }
334}
335
/// Outcome of a bulk-load operation.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Number of rows reported for the operation.
    pub rows_affected: u64,
    /// Number of batches counted as committed.
    pub batches_committed: u32,
    /// Error flag; every constructor in this module sets it to `false`.
    pub has_errors: bool,
}
348
349#[derive(Debug)]
351pub struct BulkInsertBuilder {
352 table_name: String,
353 columns: Vec<BulkColumn>,
354 options: BulkOptions,
355}
356
357impl BulkInsertBuilder {
358 pub fn new<S: Into<String>>(table_name: S) -> Self {
360 Self {
361 table_name: table_name.into(),
362 columns: Vec::new(),
363 options: BulkOptions::default(),
364 }
365 }
366
367 #[must_use]
372 #[allow(clippy::expect_used)] pub fn with_columns(mut self, column_names: &[&str]) -> Self {
374 self.columns = column_names
375 .iter()
376 .enumerate()
377 .map(|(i, name)| {
378 BulkColumn::new(*name, "NVARCHAR(MAX)", i)
379 .expect("NVARCHAR(MAX) is always a supported type")
380 })
381 .collect();
382 self
383 }
384
385 #[must_use]
387 pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
388 self.columns = columns;
389 self
390 }
391
392 #[must_use]
394 pub fn with_options(mut self, options: BulkOptions) -> Self {
395 self.options = options;
396 self
397 }
398
399 #[must_use]
401 pub fn batch_size(mut self, size: usize) -> Self {
402 self.options.batch_size = size;
403 self
404 }
405
406 #[must_use]
408 pub fn table_lock(mut self, enabled: bool) -> Self {
409 self.options.table_lock = enabled;
410 self
411 }
412
413 #[must_use]
415 pub fn fire_triggers(mut self, enabled: bool) -> Self {
416 self.options.fire_triggers = enabled;
417 self
418 }
419
420 pub fn table_name(&self) -> &str {
422 &self.table_name
423 }
424
425 pub fn columns(&self) -> &[BulkColumn] {
427 &self.columns
428 }
429
430 pub fn options(&self) -> &BulkOptions {
432 &self.options
433 }
434
435 pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
442 crate::validation::validate_qualified_identifier(&self.table_name)?;
444
445 for col in &self.columns {
447 crate::validation::validate_identifier(&col.name)?;
448 }
449
450 let mut sql = format!("INSERT BULK {}", self.table_name);
451
452 if !self.columns.is_empty() {
454 sql.push_str(" (");
455 let cols: Vec<String> = self
456 .columns
457 .iter()
458 .map(|c| {
459 validate_sql_type(&c.sql_type)?;
465 Ok(format!("{} {}", c.name, c.sql_type))
466 })
467 .collect::<Result<Vec<_>, Error>>()?;
468 sql.push_str(&cols.join(", "));
469 sql.push(')');
470 }
471
472 let mut hints: Vec<String> = Vec::new();
474
475 if self.options.check_constraints {
476 hints.push("CHECK_CONSTRAINTS".to_string());
477 }
478 if self.options.fire_triggers {
479 hints.push("FIRE_TRIGGERS".to_string());
480 }
481 if self.options.keep_nulls {
482 hints.push("KEEP_NULLS".to_string());
483 }
484 if self.options.table_lock {
485 hints.push("TABLOCK".to_string());
486 }
487 if self.options.batch_size > 0 {
488 hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
489 }
490
491 if let Some(ref order) = self.options.order_hint {
492 for col_name in order {
494 crate::validation::validate_identifier(col_name)?;
495 }
496 hints.push(format!("ORDER({})", order.join(", ")));
497 }
498
499 if !hints.is_empty() {
500 sql.push_str(" WITH (");
501 sql.push_str(&hints.join(", "));
502 sql.push(')');
503 }
504
505 Ok(sql)
506 }
507}
508
509fn validate_sql_type(type_str: &str) -> Result<(), Error> {
515 #[allow(clippy::expect_used)] static SQL_TYPE_RE: Lazy<Regex> =
517 Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
518
519 if type_str.is_empty() {
520 return Err(Error::Config("SQL type cannot be empty".into()));
521 }
522
523 if !SQL_TYPE_RE.is_match(type_str) {
524 return Err(Error::Config(format!(
525 "invalid SQL type '{type_str}': contains disallowed characters"
526 )));
527 }
528
529 Ok(())
530}
531
532pub struct BulkInsert {
537 columns: Arc<[BulkColumn]>,
539 fixed_len: Arc<[bool]>,
542 buffer: BytesMut,
544 rows_in_batch: usize,
546 total_rows: u64,
548 batch_size: usize,
550 batches_committed: u32,
552 packet_id: u8,
554}
555
556impl BulkInsert {
557 pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
559 Self::new_with_server_metadata(columns, batch_size, None, None)
560 }
561
562 pub fn new_with_server_metadata(
573 mut columns: Vec<BulkColumn>,
574 batch_size: usize,
575 raw_colmetadata: Option<bytes::Bytes>,
576 server_columns: Option<&[tds_protocol::token::ColumnData]>,
577 ) -> Self {
578 let fixed_len: Vec<bool> = if let Some(srv_cols) = server_columns {
581 for (col, srv) in columns.iter_mut().zip(srv_cols.iter()) {
587 if col.collation.is_none() {
588 col.collation = srv.type_info.collation;
589 }
590 }
591 srv_cols
592 .iter()
593 .map(|c| c.type_id.is_fixed_length())
594 .collect()
595 } else {
596 columns
601 .iter()
602 .map(|c| !c.nullable && nullable_to_fixed_type(c.type_id, c.max_length).is_some())
603 .collect()
604 };
605
606 let mut bulk = Self {
607 columns: columns.into(),
608 fixed_len: fixed_len.into(),
609 buffer: BytesMut::with_capacity(64 * 1024),
610 rows_in_batch: 0,
611 total_rows: 0,
612 batch_size,
613 batches_committed: 0,
614 packet_id: 1,
615 };
616
617 if let Some(raw) = raw_colmetadata {
618 bulk.buffer.extend_from_slice(&raw);
619 } else {
620 bulk.write_colmetadata();
621 }
622
623 bulk
624 }
625
626 fn write_colmetadata(&mut self) {
628 let buf = &mut self.buffer;
629
630 buf.put_u8(TokenType::ColMetaData as u8);
632
633 buf.put_u16_le(self.columns.len() as u16);
635
636 for col in self.columns.iter() {
637 buf.put_u32_le(0);
639
640 let effective_type_id = if !col.nullable {
644 nullable_to_fixed_type(col.type_id, col.max_length).unwrap_or(col.type_id)
645 } else {
646 col.type_id
647 };
648 let is_fixed_variant = effective_type_id != col.type_id;
649
650 let mut flags: u16 = 0x0008; if col.nullable {
654 flags |= 0x0001; }
656 buf.put_u16_le(flags);
657
658 buf.put_u8(effective_type_id);
660
661 if is_fixed_variant {
664 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
665 buf.put_u8(name_utf16.len() as u8);
666 for code_unit in name_utf16 {
667 buf.put_u16_le(code_unit);
668 }
669 continue;
670 }
671
672 match col.type_id {
674 0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
677 buf.put_u8(col.max_length.unwrap_or(4) as u8);
678 }
679
680 0x28 => {}
682
683 0xE7 | 0xA7 | 0xA5 | 0xAD => {
685 let max_len = col.max_length.unwrap_or(8000);
687 if max_len == 0xFFFF {
688 buf.put_u16_le(0xFFFF);
689 } else {
690 buf.put_u16_le(max_len as u16);
691 }
692
693 if col.type_id == 0xE7 || col.type_id == 0xA7 {
697 if let Some(coll) = col.collation.as_ref() {
698 buf.put_slice(&coll.to_bytes());
699 } else {
700 buf.put_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]);
703 }
704 }
705 }
706
707 0x6C | 0x6A => {
709 let precision = col.precision.unwrap_or(18);
711 let len = decimal_byte_length(precision);
712 buf.put_u8(len);
713 buf.put_u8(precision);
714 buf.put_u8(col.scale.unwrap_or(0));
715 }
716
717 0x29..=0x2B => {
719 buf.put_u8(col.scale.unwrap_or(7));
720 }
721
722 0x24 => {
724 buf.put_u8(16);
725 }
726
727 _ => {
729 if let Some(len) = col.max_length {
730 if len <= 0xFFFF {
731 buf.put_u16_le(len as u16);
732 }
733 }
734 }
735 }
736
737 let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
739 buf.put_u8(name_utf16.len() as u8);
740 for code_unit in name_utf16 {
741 buf.put_u16_le(code_unit);
742 }
743 }
744 }
745
746 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
757 if values.len() != self.columns.len() {
758 return Err(Error::Config(format!(
759 "expected {} values, got {}",
760 self.columns.len(),
761 values.len()
762 )));
763 }
764
765 let sql_values: Result<Vec<SqlValue>, TypeError> =
767 values.iter().map(|v| v.to_sql()).collect();
768 let sql_values = sql_values.map_err(Error::from)?;
769
770 self.write_row(&sql_values)?;
771
772 self.rows_in_batch += 1;
773 self.total_rows += 1;
774
775 Ok(())
776 }
777
778 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
780 if values.len() != self.columns.len() {
781 return Err(Error::Config(format!(
782 "expected {} values, got {}",
783 self.columns.len(),
784 values.len()
785 )));
786 }
787
788 self.write_row(values)?;
789
790 self.rows_in_batch += 1;
791 self.total_rows += 1;
792
793 Ok(())
794 }
795
796 fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
798 self.buffer.put_u8(TokenType::Row as u8);
800
801 let columns: Vec<_> = self.columns.iter().cloned().collect();
803 let fixed_len = self.fixed_len.clone();
804
805 for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
807 let is_fixed = *fixed_len.get(i).unwrap_or(&false);
808 self.encode_column_value(col, value, is_fixed)
809 .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
810 }
811
812 Ok(())
813 }
814
815 fn encode_column_value(
821 &mut self,
822 col: &BulkColumn,
823 value: &SqlValue,
824 is_fixed: bool,
825 ) -> Result<(), TypeError> {
826 let buf = &mut self.buffer;
827
828 let is_plp_type =
831 col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
832
833 match value {
834 SqlValue::Null => {
835 match col.type_id {
837 0xE7 | 0xA7 | 0xA5 | 0xAD => {
839 if is_plp_type {
840 buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
842 } else {
843 buf.put_u16_le(0xFFFF);
845 }
846 }
847 0x26 | 0x68 | 0x6D | 0x6E | 0x6F | 0x6C | 0x6A | 0x24 | 0x28 | 0x29 | 0x2A
850 | 0x2B => {
851 buf.put_u8(0);
852 }
853 _ => {
855 if col.nullable {
856 buf.put_u8(0);
857 } else {
858 return Err(TypeError::UnexpectedNull);
859 }
860 }
861 }
862 }
863
864 SqlValue::Bool(v) => {
865 if !is_fixed {
866 buf.put_u8(1);
867 }
868 buf.put_u8(if *v { 1 } else { 0 });
869 }
870
871 SqlValue::TinyInt(v) => {
872 if !is_fixed {
873 buf.put_u8(1);
874 }
875 buf.put_u8(*v);
876 }
877
878 SqlValue::SmallInt(v) => {
879 if !is_fixed {
880 buf.put_u8(2);
881 }
882 buf.put_i16_le(*v);
883 }
884
885 SqlValue::Int(v) => {
886 if !is_fixed {
887 buf.put_u8(4);
888 }
889 buf.put_i32_le(*v);
890 }
891
892 SqlValue::BigInt(v) => {
893 if !is_fixed {
894 buf.put_u8(8);
895 }
896 buf.put_i64_le(*v);
897 }
898
899 SqlValue::Float(v) => {
900 if !is_fixed {
901 buf.put_u8(4);
902 }
903 buf.put_f32_le(*v);
904 }
905
906 SqlValue::Double(v) => {
907 if !is_fixed {
908 buf.put_u8(8);
909 }
910 buf.put_f64_le(*v);
911 }
912
913 SqlValue::String(s) => {
914 let is_varchar = matches!(col.type_id, 0xA7 | 0x2F | 0xAF);
920
921 if is_varchar {
922 let encoded = encode_varchar_for_collation(s, col.collation.as_ref());
923 let byte_len = encoded.len();
924
925 if is_plp_type {
926 encode_plp_binary(&encoded, buf);
927 } else if byte_len > 0xFFFF {
928 return Err(TypeError::BufferTooSmall {
929 needed: byte_len,
930 available: 0xFFFF,
931 });
932 } else {
933 buf.put_u16_le(byte_len as u16);
934 buf.put_slice(&encoded);
935 }
936 } else {
937 let utf16: Vec<u16> = s.encode_utf16().collect();
939 let byte_len = utf16.len() * 2;
940
941 if is_plp_type {
942 encode_plp_string(&utf16, buf);
945 } else if byte_len > 0xFFFF {
946 return Err(TypeError::BufferTooSmall {
948 needed: byte_len,
949 available: 0xFFFF,
950 });
951 } else {
952 buf.put_u16_le(byte_len as u16);
954 for code_unit in utf16 {
955 buf.put_u16_le(code_unit);
956 }
957 }
958 }
959 }
960
961 SqlValue::Binary(b) => {
962 if is_plp_type {
963 encode_plp_binary(b, buf);
965 } else if b.len() > 0xFFFF {
966 return Err(TypeError::BufferTooSmall {
968 needed: b.len(),
969 available: 0xFFFF,
970 });
971 } else {
972 buf.put_u16_le(b.len() as u16);
974 buf.put_slice(b);
975 }
976 }
977
978 #[cfg(feature = "decimal")]
980 SqlValue::Decimal(d) => {
981 if col.type_id == 0x6E {
982 encode_money_value(*d, col, buf, is_fixed)?;
984 } else {
985 let precision = col.precision.unwrap_or(18);
986 let len = decimal_byte_length(precision);
987 buf.put_u8(len);
988
989 buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
991
992 let mantissa = d.mantissa().unsigned_abs();
994 let mantissa_bytes = mantissa.to_le_bytes();
995 buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
996 }
997 }
998
999 #[cfg(feature = "uuid")]
1000 SqlValue::Uuid(u) => {
1001 buf.put_u8(16); mssql_types::encode::encode_uuid(*u, buf);
1004 }
1005
1006 #[cfg(feature = "chrono")]
1007 SqlValue::Date(d) => {
1008 buf.put_u8(3); mssql_types::encode::encode_date(*d, buf)?;
1010 }
1011
1012 #[cfg(feature = "chrono")]
1013 SqlValue::Time(t) => {
1014 let scale = col.scale.unwrap_or(7);
1015 let len = time_byte_length(scale);
1016 buf.put_u8(len);
1017 encode_time_with_scale(*t, scale, buf);
1019 }
1020
1021 #[cfg(feature = "chrono")]
1022 SqlValue::DateTime(dt) => {
1023 if col.type_id == 0x6F {
1028 let total_len = col.max_length.unwrap_or(8) as u8;
1029 if !is_fixed {
1030 buf.put_u8(total_len);
1031 }
1032 match total_len {
1033 8 => mssql_types::encode::encode_datetime_legacy(*dt, buf),
1034 4 => mssql_types::encode::encode_smalldatetime(*dt, buf)?,
1035 _ => {
1036 return Err(TypeError::InvalidDateTime(format!(
1037 "DATETIMEN max_length must be 4 or 8, got {total_len}"
1038 )));
1039 }
1040 }
1041 } else {
1042 let scale = col.scale.unwrap_or(7);
1043 let time_len = time_byte_length(scale);
1044 let total_len = time_len + 3;
1045 buf.put_u8(total_len);
1046 encode_time_with_scale(dt.time(), scale, buf);
1048 mssql_types::encode::encode_date(dt.date(), buf)?;
1049 }
1050 }
1051 #[cfg(feature = "chrono")]
1052 SqlValue::SmallDateTime(dt) => {
1053 if !is_fixed {
1056 buf.put_u8(4);
1057 }
1058 mssql_types::encode::encode_smalldatetime(*dt, buf)?;
1059 }
1060 #[cfg(feature = "decimal")]
1061 SqlValue::Money(d) => {
1062 if !is_fixed {
1064 buf.put_u8(8);
1065 }
1066 mssql_types::encode::encode_money(*d, buf)?;
1067 }
1068 #[cfg(feature = "decimal")]
1069 SqlValue::SmallMoney(d) => {
1070 if !is_fixed {
1071 buf.put_u8(4);
1072 }
1073 mssql_types::encode::encode_smallmoney(*d, buf)?;
1074 }
1075
1076 #[cfg(feature = "chrono")]
1077 SqlValue::DateTimeOffset(dto) => {
1078 let scale = col.scale.unwrap_or(7);
1079 let time_len = time_byte_length(scale);
1080 let total_len = time_len + 3 + 2;
1081 buf.put_u8(total_len);
1082 let utc = dto.naive_utc();
1085 encode_time_with_scale(utc.time(), scale, buf);
1086 mssql_types::encode::encode_date(utc.date(), buf)?;
1087 use chrono::Offset;
1089 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1090 buf.put_i16_le(offset_minutes);
1091 }
1092
1093 #[cfg(feature = "json")]
1094 SqlValue::Json(j) => {
1095 let s = j.to_string();
1096 encode_nvarchar_value(&s, buf)?;
1097 }
1098
1099 SqlValue::Xml(x) => {
1100 encode_nvarchar_value(x, buf)?;
1101 }
1102
1103 SqlValue::Tvp(_) => {
1104 return Err(TypeError::UnsupportedConversion {
1106 from: "TVP".to_string(),
1107 to: "bulk copy value",
1108 });
1109 }
1110 _ => {
1112 return Err(TypeError::UnsupportedConversion {
1113 from: value.type_name().to_string(),
1114 to: "bulk copy value",
1115 });
1116 }
1117 }
1118
1119 Ok(())
1120 }
1121}
1122
/// Writes a decimal into a MONEYN column, choosing MONEY or SMALLMONEY from the
/// column's declared length (8 bytes when none is declared).
///
/// The length byte is written (unless `is_fixed`) before the width is
/// validated, so an invalid width still leaves that byte in `buf`.
///
/// # Errors
///
/// Propagates encoder errors. Returns [`TypeError::InvalidDecimal`] when the
/// declared length is neither 4 nor 8.
#[cfg(feature = "decimal")]
fn encode_money_value(
    value: rust_decimal::Decimal,
    col: &BulkColumn,
    buf: &mut BytesMut,
    is_fixed: bool,
) -> Result<(), TypeError> {
    let money_bytes = col.max_length.unwrap_or(8) as u8;
    if !is_fixed {
        buf.put_u8(money_bytes);
    }

    if money_bytes == 4 {
        mssql_types::encode::encode_smallmoney(value, buf)
    } else if money_bytes == 8 {
        mssql_types::encode::encode_money(value, buf)
    } else {
        Err(TypeError::InvalidDecimal(format!(
            "MONEY column has invalid max_length: {money_bytes}"
        )))
    }
}
1147
1148fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
1150 let utf16: Vec<u16> = s.encode_utf16().collect();
1151 let byte_len = utf16.len() * 2;
1152
1153 if byte_len > 0xFFFF {
1154 return Err(TypeError::BufferTooSmall {
1155 needed: byte_len,
1156 available: 0xFFFF,
1157 });
1158 }
1159
1160 buf.put_u16_le(byte_len as u16);
1161 for code_unit in utf16 {
1162 buf.put_u16_le(code_unit);
1163 }
1164 Ok(())
1165}
1166
/// PLP total-length marker meaning "length not announced"; the value is then
/// delimited by its chunks and a zero-length terminator.
const PLP_UNKNOWN_LEN: u64 = u64::MAX - 1;
1172
1173fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
1186 let byte_len = utf16.len() * 2;
1187
1188 buf.put_u64_le(PLP_UNKNOWN_LEN);
1189
1190 if byte_len > 0 {
1191 buf.put_u32_le(byte_len as u32);
1192 for code_unit in utf16 {
1193 buf.put_u16_le(*code_unit);
1194 }
1195 }
1196
1197 buf.put_u32_le(0);
1198}
1199
1200fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
1203 buf.put_u64_le(PLP_UNKNOWN_LEN);
1204
1205 if !data.is_empty() {
1206 buf.put_u32_le(data.len() as u32);
1207 buf.put_slice(data);
1208 }
1209
1210 buf.put_u32_le(0);
1211}
1212
1213fn encode_varchar_for_collation(value: &str, collation: Option<&Collation>) -> Vec<u8> {
1218 tds_protocol::collation::encode_str_for_collation(value, collation)
1219}
1220
/// Writes `time` as the number of `10^-scale`-second ticks since midnight.
///
/// The count is written as a little-endian integer of `time_byte_length(scale)`
/// bytes.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let since_midnight_ns = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000
        + u64::from(time.nanosecond());
    let ticks = since_midnight_ns / time_scale_divisor(scale);
    let width = usize::from(time_byte_length(scale));

    // Low `width` bytes of the tick count, least significant first.
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
1234
1235impl BulkInsert {
1236 fn write_done(&mut self) {
1238 let buf = &mut self.buffer;
1239
1240 buf.put_u8(TokenType::Done as u8);
1241
1242 let status = DoneStatus::from_bits(0x0010); buf.put_u16_le(status.to_bits());
1245
1246 buf.put_u16_le(0);
1248
1249 buf.put_u64_le(self.total_rows);
1251 }
1252
1253 pub fn take_packets(&mut self) -> Vec<BytesMut> {
1257 const MAX_PACKET_SIZE: usize = 4096;
1258 const HEADER_SIZE: usize = 8;
1259 const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1260
1261 let data = self.buffer.split();
1262 let mut packets = Vec::new();
1263 let mut offset = 0;
1264
1265 while offset < data.len() {
1266 let remaining = data.len() - offset;
1267 let payload_size = remaining.min(MAX_PAYLOAD);
1268 let is_last = offset + payload_size >= data.len();
1269
1270 let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1271
1272 let header = PacketHeader {
1274 packet_type: PacketType::BulkLoad,
1275 status: if is_last {
1276 PacketStatus::END_OF_MESSAGE
1277 } else {
1278 PacketStatus::NORMAL
1279 },
1280 length: (HEADER_SIZE + payload_size) as u16,
1281 spid: 0,
1282 packet_id: self.packet_id,
1283 window: 0,
1284 };
1285
1286 header.encode(&mut packet);
1287
1288 packet.put_slice(&data[offset..offset + payload_size]);
1290
1291 packets.push(packet);
1292 offset += payload_size;
1293 self.packet_id = self.packet_id.wrapping_add(1);
1294 }
1295
1296 packets
1297 }
1298
1299 pub fn total_rows(&self) -> u64 {
1301 self.total_rows
1302 }
1303
1304 pub fn rows_in_batch(&self) -> usize {
1306 self.rows_in_batch
1307 }
1308
1309 pub fn should_flush(&self) -> bool {
1315 self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1316 }
1317
1318 pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1321 self.write_done();
1322 self.take_packets()
1323 }
1324
1325 pub fn result(&self) -> BulkInsertResult {
1327 BulkInsertResult {
1328 rows_affected: self.total_rows,
1329 batches_committed: self.batches_committed,
1330 has_errors: false,
1331 }
1332 }
1333}
1334
1335pub struct BulkWriter<'a, S: crate::state::ConnectionState> {
1364 client: &'a mut crate::client::Client<S>,
1365 bulk: BulkInsert,
1366}
1367
1368impl<'a, S: crate::state::ConnectionState> BulkWriter<'a, S> {
1369 pub(crate) fn new(client: &'a mut crate::client::Client<S>, bulk: BulkInsert) -> Self {
1371 Self { client, bulk }
1372 }
1373
1374 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
1380 self.bulk.send_row(values)
1381 }
1382
1383 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
1385 self.bulk.send_row_values(values)
1386 }
1387
1388 pub fn total_rows(&self) -> u64 {
1390 self.bulk.total_rows()
1391 }
1392
1393 pub async fn finish(mut self) -> Result<BulkInsertResult, Error> {
1405 let deadline = self.client.command_deadline();
1406 let total_rows = self.bulk.total_rows();
1407 tracing::debug!(total_rows = total_rows, "finishing bulk insert");
1408
1409 self.bulk.write_done();
1411 let payload = self.bulk.buffer.split().freeze();
1412
1413 let send_and_read = self.client.send_and_read_bulk_load(payload);
1422 let rows_affected = match deadline {
1423 Some(d) => tokio::time::timeout(d, send_and_read)
1424 .await
1425 .map_err(|_| Error::CommandTimeout)??,
1426 None => send_and_read.await?,
1427 };
1428
1429 Ok(BulkInsertResult {
1430 rows_affected,
1431 batches_committed: 1,
1432 has_errors: false,
1433 })
1434 }
1435}
1436
/// Maps a nullable TDS type byte (and its declared length) to its fixed-length
/// counterpart, or `None` when no counterpart exists.
fn nullable_to_fixed_type(type_id: u8, max_length: Option<u32>) -> Option<u8> {
    let fixed = match type_id {
        // BITN -> BIT, whatever the declared length.
        0x68 => 0x32,
        // INTN -> INT1 / INT2 / INT4 / INT8
        0x26 => match max_length? {
            1 => 0x30,
            2 => 0x34,
            4 => 0x38,
            8 => 0x7F,
            _ => return None,
        },
        // FLTN -> FLT4 / FLT8
        0x6D => match max_length? {
            4 => 0x3B,
            8 => 0x3E,
            _ => return None,
        },
        // MONEYN -> MONEY4 / MONEY
        0x6E => match max_length? {
            4 => 0x7A,
            8 => 0x3C,
            _ => return None,
        },
        // DATETIMN -> DATETIM4 / DATETIME
        0x6F => match max_length? {
            4 => 0x3A,
            8 => 0x3D,
            _ => return None,
        },
        _ => return None,
    };
    Some(fixed)
}
1462
/// Storage length of a DECIMAL/NUMERIC value: one sign byte plus 4, 8, 12 or
/// 16 magnitude bytes depending on `precision`.
///
/// Out-of-range precisions (including 0) use the widest form.
fn decimal_byte_length(precision: u8) -> u8 {
    if (1..=9).contains(&precision) {
        5
    } else if (10..=19).contains(&precision) {
        9
    } else if (20..=28).contains(&precision) {
        13
    } else {
        17
    }
}
1473
/// Width in bytes of a TIME value at fractional-seconds `scale`.
///
/// Scales above 7 use the scale-7 width.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1484
/// Nanoseconds per tick at fractional-seconds `scale` (`10^(9 - scale)`).
///
/// Scales above 7 are treated as 7 (100 ns ticks).
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = u32::from(scale.min(7));
    10_u64.pow(9 - effective_scale)
}
1500
1501#[cfg(test)]
1502#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
1503mod tests {
1504 use super::*;
1505
1506 #[test]
1507 fn test_bulk_options_default() {
1508 let opts = BulkOptions::default();
1509 assert_eq!(opts.batch_size, 0);
1510 assert!(opts.check_constraints);
1511 assert!(!opts.fire_triggers);
1512 assert!(opts.keep_nulls);
1513 assert!(!opts.table_lock);
1514 }
1515
1516 #[test]
1517 fn test_bulk_column_creation() {
1518 let col = BulkColumn::new("id", "INT", 0).unwrap();
1519 assert_eq!(col.name, "id");
1520 assert_eq!(col.type_id, 0x26); assert_eq!(col.max_length, Some(4));
1522 assert!(col.nullable);
1523 }
1524
1525 #[test]
1526 fn test_bulk_column_rejects_text() {
1527 let err = BulkColumn::new("body", "TEXT", 0).unwrap_err();
1528 match err {
1529 TypeError::UnsupportedType { sql_type, reason } => {
1530 assert_eq!(sql_type, "TEXT");
1531 assert!(
1532 reason.contains("VARCHAR(MAX)"),
1533 "error should redirect to VARCHAR(MAX), got: {reason}"
1534 );
1535 assert!(
1536 reason.contains("deprecated"),
1537 "error should mention deprecation, got: {reason}"
1538 );
1539 }
1540 other => panic!("expected UnsupportedType, got {other:?}"),
1541 }
1542 }
1543
1544 #[test]
1545 fn test_bulk_column_rejects_ntext() {
1546 let err = BulkColumn::new("body", "NTEXT", 0).unwrap_err();
1547 match err {
1548 TypeError::UnsupportedType { sql_type, reason } => {
1549 assert_eq!(sql_type, "NTEXT");
1550 assert!(
1551 reason.contains("NVARCHAR(MAX)"),
1552 "error should redirect to NVARCHAR(MAX), got: {reason}"
1553 );
1554 assert!(
1555 reason.contains("deprecated"),
1556 "error should mention deprecation, got: {reason}"
1557 );
1558 }
1559 other => panic!("expected UnsupportedType, got {other:?}"),
1560 }
1561 }
1562
1563 #[test]
1564 fn test_bulk_column_rejects_text_case_insensitive() {
1565 assert!(matches!(
1566 BulkColumn::new("body", "text", 0),
1567 Err(TypeError::UnsupportedType { .. })
1568 ));
1569 assert!(matches!(
1570 BulkColumn::new("body", "Ntext", 0),
1571 Err(TypeError::UnsupportedType { .. })
1572 ));
1573 }
1574
1575 #[test]
1576 fn test_bulk_column_rejects_image() {
1577 let err = BulkColumn::new("blob", "IMAGE", 0).unwrap_err();
1578 match err {
1579 TypeError::UnsupportedType { sql_type, reason } => {
1580 assert_eq!(sql_type, "IMAGE");
1581 assert!(
1582 reason.contains("VARBINARY(MAX)"),
1583 "error should redirect to VARBINARY(MAX), got: {reason}"
1584 );
1585 assert!(
1586 reason.contains("deprecated"),
1587 "error should mention deprecation, got: {reason}"
1588 );
1589 }
1590 other => panic!("expected UnsupportedType, got {other:?}"),
1591 }
1592 }
1593
1594 #[test]
1595 fn test_bulk_column_rejects_image_case_insensitive() {
1596 assert!(matches!(
1597 BulkColumn::new("blob", "image", 0),
1598 Err(TypeError::UnsupportedType { .. })
1599 ));
1600 assert!(matches!(
1601 BulkColumn::new("blob", "Image", 0),
1602 Err(TypeError::UnsupportedType { .. })
1603 ));
1604 }
1605
1606 #[test]
1607 fn test_parse_sql_type() {
1608 let (type_id, len, _prec, _scale) = parse_sql_type("INT");
1610 assert_eq!(type_id, 0x26);
1611 assert_eq!(len, Some(4));
1612
1613 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
1614 assert_eq!(type_id, 0xE7);
1615 assert_eq!(len, Some(200)); let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
1618 assert_eq!(type_id, 0x6C);
1619 assert_eq!(prec, Some(10));
1620 assert_eq!(scale, Some(2));
1621
1622 let (type_id, len, _, _) = parse_sql_type("SMALLDATETIME");
1624 assert_eq!(type_id, 0x6F);
1625 assert_eq!(len, Some(4));
1626
1627 let (type_id, len, _, _) = parse_sql_type("DATETIME");
1628 assert_eq!(type_id, 0x6F);
1629 assert_eq!(len, Some(8));
1630 }
1631
1632 #[test]
1633 fn test_insert_bulk_statement() {
1634 let builder = BulkInsertBuilder::new("dbo.Users")
1635 .with_typed_columns(vec![
1636 BulkColumn::new("id", "INT", 0).unwrap(),
1637 BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1638 ])
1639 .table_lock(true);
1640
1641 let sql = builder.build_insert_bulk_statement().unwrap();
1642 assert!(sql.contains("INSERT BULK dbo.Users"));
1643 assert!(sql.contains("TABLOCK"));
1644 }
1645
1646 #[test]
1647 fn test_bulk_insert_rejects_injection() {
1648 let builder = BulkInsertBuilder::new("table;DROP TABLE users")
1649 .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1650
1651 assert!(builder.build_insert_bulk_statement().is_err());
1652 }
1653
1654 #[test]
1655 fn test_bulk_insert_validates_column_names() {
1656 let builder = BulkInsertBuilder::new("Users")
1657 .with_typed_columns(vec![BulkColumn::new("col;DROP TABLE x", "INT", 0).unwrap()]);
1658
1659 assert!(builder.build_insert_bulk_statement().is_err());
1660 }
1661
1662 #[test]
1663 fn test_bulk_insert_accepts_qualified_names() {
1664 let builder = BulkInsertBuilder::new("catalog.dbo.Users")
1665 .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1666
1667 assert!(builder.build_insert_bulk_statement().is_ok());
1668 }
1669
1670 #[test]
1671 fn test_bulk_insert_creation() {
1672 let columns = vec![
1673 BulkColumn::new("id", "INT", 0).unwrap(),
1674 BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1675 ];
1676
1677 let bulk = BulkInsert::new(columns, 1000);
1678 assert_eq!(bulk.total_rows(), 0);
1679 assert_eq!(bulk.rows_in_batch(), 0);
1680 assert!(!bulk.should_flush());
1681 }
1682
1683 #[test]
1684 fn test_decimal_byte_length() {
1685 assert_eq!(decimal_byte_length(5), 5);
1686 assert_eq!(decimal_byte_length(15), 9);
1687 assert_eq!(decimal_byte_length(25), 13);
1688 assert_eq!(decimal_byte_length(35), 17);
1689 }
1690
1691 #[test]
1692 #[cfg(feature = "chrono")]
1693 fn test_time_byte_length() {
1694 assert_eq!(time_byte_length(0), 3);
1695 assert_eq!(time_byte_length(3), 4);
1696 assert_eq!(time_byte_length(7), 5);
1697 }
1698
1699 #[test]
1700 fn test_plp_string_encoding() {
1701 let mut buf = BytesMut::new();
1702 let text = "Hello";
1703 let utf16: Vec<u16> = text.encode_utf16().collect();
1704
1705 encode_plp_string(&utf16, &mut buf);
1706
1707 assert_eq!(buf.len(), 8 + 4 + 10 + 4);
1713
1714 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1716
1717 assert_eq!(&buf[8..12], &10u32.to_le_bytes());
1719
1720 assert_eq!(&buf[22..26], &0u32.to_le_bytes());
1722 }
1723
1724 #[test]
1725 fn test_plp_binary_encoding() {
1726 let mut buf = BytesMut::new();
1727 let data = b"test binary data";
1728
1729 encode_plp_binary(data, &mut buf);
1730
1731 assert_eq!(buf.len(), 8 + 4 + 16 + 4);
1737
1738 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1740
1741 assert_eq!(&buf[8..12], &16u32.to_le_bytes());
1743
1744 assert_eq!(&buf[12..28], data);
1746
1747 assert_eq!(&buf[28..32], &0u32.to_le_bytes());
1749 }
1750
1751 #[test]
1752 fn test_plp_empty_string() {
1753 let mut buf = BytesMut::new();
1754 let utf16: Vec<u16> = "".encode_utf16().collect();
1755
1756 encode_plp_string(&utf16, &mut buf);
1757
1758 assert_eq!(buf.len(), 8 + 4);
1760
1761 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1763
1764 assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1766 }
1767
1768 #[test]
1769 fn test_plp_empty_binary() {
1770 let mut buf = BytesMut::new();
1771
1772 encode_plp_binary(&[], &mut buf);
1773
1774 assert_eq!(buf.len(), 8 + 4);
1776
1777 assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1779
1780 assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1782 }
1783
1784 #[test]
1787 fn test_write_colmetadata_roundtrip() {
1788 use tds_protocol::token::ColMetaData;
1789
1790 let columns = vec![
1791 BulkColumn::new("id", "INT", 0).unwrap(),
1792 BulkColumn::new("tiny", "TINYINT", 1).unwrap(),
1793 BulkColumn::new("small", "SMALLINT", 2).unwrap(),
1794 BulkColumn::new("big", "BIGINT", 3).unwrap(),
1795 BulkColumn::new("flag", "BIT", 4).unwrap(),
1796 BulkColumn::new("r", "REAL", 5).unwrap(),
1797 BulkColumn::new("f", "FLOAT", 6).unwrap(),
1798 BulkColumn::new("name", "NVARCHAR(100)", 7).unwrap(),
1799 BulkColumn::new("code", "VARCHAR(50)", 8).unwrap(),
1800 BulkColumn::new("data", "VARBINARY(200)", 9).unwrap(),
1801 BulkColumn::new("d", "DATE", 10).unwrap(),
1802 BulkColumn::new("t", "TIME(3)", 11).unwrap(),
1803 BulkColumn::new("dt", "DATETIME", 12).unwrap(),
1804 BulkColumn::new("dt2", "DATETIME2(7)", 13).unwrap(),
1805 BulkColumn::new("dto", "DATETIMEOFFSET(7)", 14).unwrap(),
1806 BulkColumn::new("sdt", "SMALLDATETIME", 15).unwrap(),
1807 BulkColumn::new("uid", "UNIQUEIDENTIFIER", 16).unwrap(),
1808 BulkColumn::new("amt", "DECIMAL(18,2)", 17).unwrap(),
1809 BulkColumn::new("price", "MONEY", 18).unwrap(),
1810 BulkColumn::new("smoney", "SMALLMONEY", 19).unwrap(),
1811 BulkColumn::new("nmax", "NVARCHAR(MAX)", 20).unwrap(),
1812 BulkColumn::new("vmax", "VARCHAR(MAX)", 21).unwrap(),
1813 BulkColumn::new("bmax", "VARBINARY(MAX)", 22).unwrap(),
1814 ];
1815
1816 let bulk = BulkInsert::new(columns.clone(), 0);
1817
1818 let buf = &bulk.buffer[1..];
1820 let mut cursor = bytes::Bytes::copy_from_slice(buf);
1821 let meta = ColMetaData::decode(&mut cursor)
1822 .expect("write_colmetadata output should be parseable by TDS decoder");
1823
1824 assert_eq!(meta.columns.len(), columns.len());
1825
1826 for (i, (parsed, original)) in meta.columns.iter().zip(columns.iter()).enumerate() {
1828 assert_eq!(parsed.name, original.name, "column {i} name mismatch");
1829 assert_eq!(
1830 parsed.col_type, original.type_id,
1831 "column {i} ({}) type mismatch",
1832 original.name
1833 );
1834
1835 match original.type_id {
1837 0x26 => {
1839 assert_eq!(
1840 parsed.type_info.max_length, original.max_length,
1841 "column {i} ({}) INTN max_length",
1842 original.name
1843 );
1844 }
1845 0x68 => {
1847 assert_eq!(parsed.type_info.max_length, Some(1));
1848 }
1849 0x6D => {
1851 assert_eq!(
1852 parsed.type_info.max_length, original.max_length,
1853 "column {i} ({}) FLTN max_length",
1854 original.name
1855 );
1856 }
1857 0x6E => {
1859 assert_eq!(
1860 parsed.type_info.max_length, original.max_length,
1861 "column {i} ({}) MONEYN max_length",
1862 original.name
1863 );
1864 }
1865 0x6F => {
1867 assert_eq!(
1868 parsed.type_info.max_length, original.max_length,
1869 "column {i} ({}) DATETIMEN max_length",
1870 original.name
1871 );
1872 }
1873 0x24 => {
1875 assert_eq!(parsed.type_info.max_length, Some(16));
1876 }
1877 0x28 => {}
1879 0x29..=0x2B => {
1881 assert_eq!(
1882 parsed.type_info.scale, original.scale,
1883 "column {i} ({}) scale",
1884 original.name
1885 );
1886 }
1887 0xE7 | 0xA7 => {
1889 assert_eq!(
1890 parsed.type_info.max_length, original.max_length,
1891 "column {i} ({}) string max_length",
1892 original.name
1893 );
1894 assert!(
1895 parsed.type_info.collation.is_some(),
1896 "column {i} ({}) should have collation",
1897 original.name
1898 );
1899 }
1900 0xA5 => {
1902 assert_eq!(
1903 parsed.type_info.max_length, original.max_length,
1904 "column {i} ({}) binary max_length",
1905 original.name
1906 );
1907 assert!(
1908 parsed.type_info.collation.is_none(),
1909 "column {i} ({}) should not have collation",
1910 original.name
1911 );
1912 }
1913 0x6C => {
1915 assert_eq!(
1916 parsed.type_info.precision, original.precision,
1917 "column {i} ({}) precision",
1918 original.name
1919 );
1920 assert_eq!(
1921 parsed.type_info.scale, original.scale,
1922 "column {i} ({}) scale",
1923 original.name
1924 );
1925 }
1926 _ => {}
1927 }
1928 }
1929 }
1930
1931 #[test]
1935 fn test_write_colmetadata_not_null_uses_fixed_types() {
1936 use tds_protocol::token::ColMetaData;
1937 use tds_protocol::types::TypeId;
1938
1939 let columns = vec![
1940 BulkColumn::new("id", "INT", 0)
1941 .unwrap()
1942 .with_nullable(false),
1943 BulkColumn::new("tiny", "TINYINT", 1)
1944 .unwrap()
1945 .with_nullable(false),
1946 BulkColumn::new("small", "SMALLINT", 2)
1947 .unwrap()
1948 .with_nullable(false),
1949 BulkColumn::new("big", "BIGINT", 3)
1950 .unwrap()
1951 .with_nullable(false),
1952 BulkColumn::new("flag", "BIT", 4)
1953 .unwrap()
1954 .with_nullable(false),
1955 BulkColumn::new("r", "REAL", 5)
1956 .unwrap()
1957 .with_nullable(false),
1958 BulkColumn::new("f", "FLOAT", 6)
1959 .unwrap()
1960 .with_nullable(false),
1961 BulkColumn::new("dt", "DATETIME", 7)
1962 .unwrap()
1963 .with_nullable(false),
1964 BulkColumn::new("sdt", "SMALLDATETIME", 8)
1965 .unwrap()
1966 .with_nullable(false),
1967 BulkColumn::new("mny", "MONEY", 9)
1968 .unwrap()
1969 .with_nullable(false),
1970 BulkColumn::new("smny", "SMALLMONEY", 10)
1971 .unwrap()
1972 .with_nullable(false),
1973 ];
1974
1975 let bulk = BulkInsert::new(columns.clone(), 0);
1976
1977 for (i, fixed) in bulk.fixed_len.iter().enumerate() {
1979 assert!(
1980 *fixed,
1981 "column {i} ({}) should be fixed_len",
1982 columns[i].name
1983 );
1984 }
1985
1986 let buf = &bulk.buffer[1..]; let mut cursor = bytes::Bytes::copy_from_slice(buf);
1989 let meta = ColMetaData::decode(&mut cursor).expect("parseable");
1990
1991 let expected: &[(&str, TypeId)] = &[
1993 ("id", TypeId::Int4),
1994 ("tiny", TypeId::Int1),
1995 ("small", TypeId::Int2),
1996 ("big", TypeId::Int8),
1997 ("flag", TypeId::Bit),
1998 ("r", TypeId::Float4),
1999 ("f", TypeId::Float8),
2000 ("dt", TypeId::DateTime),
2001 ("sdt", TypeId::DateTime4),
2002 ("mny", TypeId::Money),
2003 ("smny", TypeId::Money4),
2004 ];
2005
2006 for (i, (name, ty)) in expected.iter().enumerate() {
2007 assert_eq!(meta.columns[i].name, *name, "column {i} name");
2008 assert_eq!(meta.columns[i].type_id, *ty, "column {i} ({name}) type");
2009 assert_eq!(
2010 meta.columns[i].flags & 0x0001,
2011 0,
2012 "column {i} ({name}) should not have Nullable flag set"
2013 );
2014 }
2015 }
2016
2017 #[test]
2021 fn test_write_colmetadata_uses_caller_collation() {
2022 use tds_protocol::token::{ColMetaData, Collation};
2023
2024 let chinese = Collation {
2026 lcid: 0x0804,
2027 sort_id: 0x52,
2028 };
2029
2030 let columns = vec![
2031 BulkColumn::new("s", "VARCHAR(50)", 0)
2032 .unwrap()
2033 .with_collation(chinese),
2034 BulkColumn::new("n", "NVARCHAR(50)", 1)
2036 .unwrap()
2037 .with_collation(chinese),
2038 BulkColumn::new("d", "VARCHAR(10)", 2).unwrap(),
2040 ];
2041 let bulk = BulkInsert::new(columns, 0);
2042
2043 let buf = &bulk.buffer[1..];
2044 let mut cursor = bytes::Bytes::copy_from_slice(buf);
2045 let meta = ColMetaData::decode(&mut cursor).expect("parseable");
2046
2047 let c0 = meta.columns[0]
2048 .type_info
2049 .collation
2050 .as_ref()
2051 .expect("VARCHAR has collation");
2052 assert_eq!(c0.lcid, chinese.lcid, "VARCHAR caller LCID");
2053 assert_eq!(c0.sort_id, chinese.sort_id, "VARCHAR caller sort_id");
2054
2055 let c1 = meta.columns[1]
2056 .type_info
2057 .collation
2058 .as_ref()
2059 .expect("NVARCHAR has collation");
2060 assert_eq!(c1.lcid, chinese.lcid, "NVARCHAR caller LCID");
2061 assert_eq!(c1.sort_id, chinese.sort_id, "NVARCHAR caller sort_id");
2062
2063 let default = meta.columns[2]
2066 .type_info
2067 .collation
2068 .as_ref()
2069 .expect("VARCHAR has default collation");
2070 assert_eq!(default.to_bytes(), [0x09, 0x04, 0xD0, 0x00, 0x34]);
2071 }
2072
2073 #[test]
2074 fn test_parse_sql_type_max() {
2075 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
2077 assert_eq!(type_id, 0xE7);
2078 assert_eq!(len, Some(0xFFFF)); let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
2082 assert_eq!(type_id, 0xA5);
2083 assert_eq!(len, Some(0xFFFF));
2084
2085 let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
2087 assert_eq!(type_id, 0xA7);
2088 assert_eq!(len, Some(0xFFFF));
2089
2090 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
2092 assert_eq!(type_id, 0xE7);
2093 assert_eq!(len, Some(200)); }
2095}