1use bytes::{BufMut, BytesMut};
57use once_cell::sync::Lazy;
58use regex::Regex;
59use std::sync::Arc;
60
61use mssql_types::{SqlValue, ToSql, TypeError};
62use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
63use tds_protocol::token::{Collation, DoneStatus, TokenType};
64
65use crate::error::Error;
66
/// Options that shape the `INSERT BULK` statement built by
/// `BulkInsertBuilder::build_insert_bulk_statement`.
///
/// Each field corresponds to one entry in the statement's `WITH (...)` clause.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Rows per batch; rendered as `ROWS_PER_BATCH = n` only when greater than zero.
    pub batch_size: usize,

    /// When `true`, the `CHECK_CONSTRAINTS` hint is emitted.
    pub check_constraints: bool,

    /// When `true`, the `FIRE_TRIGGERS` hint is emitted.
    pub fire_triggers: bool,

    /// When `true`, the `KEEP_NULLS` hint is emitted.
    pub keep_nulls: bool,

    /// When `true`, the `TABLOCK` hint is emitted.
    pub table_lock: bool,

    /// Column names rendered as `ORDER(a, b, ...)`; every name is validated as
    /// an identifier before use. `None` omits the hint.
    pub order_hint: Option<Vec<String>>,
}
109
110impl Default for BulkOptions {
111 fn default() -> Self {
112 Self {
113 batch_size: 0,
114 check_constraints: true,
115 fire_triggers: false,
116 keep_nulls: true,
117 table_lock: false,
118 order_hint: None,
119 }
120 }
121}
122
/// Description of one destination column of a bulk insert.
///
/// The public fields carry what the caller declared; the private fields hold
/// the wire-level details derived from `sql_type` by `parse_sql_type`.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name as it appears in the `INSERT BULK` column list.
    pub name: String,
    /// SQL type exactly as supplied by the caller (e.g. `NVARCHAR(100)`).
    pub sql_type: String,
    /// Whether the column accepts NULL; `BulkColumn::new` sets this to `true`.
    pub nullable: bool,
    /// Caller-supplied position of the column.
    pub ordinal: usize,
    /// Type byte written into the column metadata.
    type_id: u8,
    /// Maximum byte length, when the type has one; `0xFFFF` marks a `MAX` type.
    max_length: Option<u32>,
    /// Precision for DECIMAL/NUMERIC columns.
    precision: Option<u8>,
    /// Scale for DECIMAL/NUMERIC and the time-bearing types.
    scale: Option<u8>,
    /// Collation for character columns; `None` until set explicitly or copied
    /// from server metadata.
    collation: Option<Collation>,
}
151
152impl BulkColumn {
153 pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Result<Self, TypeError> {
163 let sql_type_str: String = sql_type.into();
164 reject_unsupported_bulk_type(&sql_type_str)?;
165 let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
166
167 Ok(Self {
168 name: name.into(),
169 sql_type: sql_type_str,
170 nullable: true,
171 ordinal,
172 type_id,
173 max_length,
174 precision,
175 scale,
176 collation: None,
177 })
178 }
179
180 #[must_use]
182 pub fn with_nullable(mut self, nullable: bool) -> Self {
183 self.nullable = nullable;
184 self
185 }
186
187 #[must_use]
194 pub fn with_collation(mut self, collation: Collation) -> Self {
195 self.collation = Some(collation);
196 self
197 }
198}
199
/// Translates a SQL type declaration into `(type_id, max_length, precision, scale)`.
///
/// Matching is case-insensitive. Whitespace around the type name and inside
/// the parentheses is ignored, so `VARCHAR (50)` and `varchar( max )` parse the
/// same way as their compact spellings (the same trimming that
/// `reject_unsupported_bulk_type` applies to the base name).
///
/// * Integer, float, money, legacy datetime and GUID types report their byte
///   width in `max_length`.
/// * `MAX` lengths are reported as `0xFFFF`.
/// * `NVARCHAR(n)` / `NCHAR(n)` report `n * 2` bytes (saturating, so an absurd
///   `n` cannot overflow).
/// * Unparseable parameters fall back to the per-type default.
/// * Unknown type names fall back to `(0xE7, Some(8000), None, None)`.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();

    // Split "NAME(params)" into the bare name and the text between the parens.
    let (base, params) = match upper.split_once('(') {
        Some((name, rest)) => (name.trim_end(), Some(rest.trim_end_matches(')').trim())),
        None => (upper.as_str(), None),
    };

    // Fractional-seconds scale shared by TIME / DATETIME2 / DATETIMEOFFSET.
    let time_scale = || -> u8 { params.and_then(|p| p.parse().ok()).unwrap_or(7) };

    // Byte length shared by the single-byte character and binary types.
    let byte_length = || -> u32 {
        params
            .and_then(|p| {
                if p == "MAX" {
                    Some(0xFFFF_u32)
                } else {
                    p.parse().ok()
                }
            })
            .unwrap_or(8000)
    };

    match base {
        "BIT" => (0x68, Some(1), None, None),
        "TINYINT" => (0x26, Some(1), None, None),
        "SMALLINT" => (0x26, Some(2), None, None),
        "INT" => (0x26, Some(4), None, None),
        "BIGINT" => (0x26, Some(8), None, None),
        "REAL" => (0x6D, Some(4), None, None),
        "FLOAT" => (0x6D, Some(8), None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x6F, Some(8), None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        "SMALLDATETIME" => (0x6F, Some(4), None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_length()), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // The declared length counts characters; the wire length counts bytes.
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_length()), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = match params {
                Some(p) => {
                    let mut parts = p.split(',').map(str::trim);
                    (
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(18),
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(0),
                    )
                }
                None => (18, 0),
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x6E, Some(8), None, None),
        "SMALLMONEY" => (0x6E, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        // Anything unrecognised is treated as NVARCHAR with an 8000-byte limit.
        _ => (0xE7, Some(8000), None, None),
    }
}
301
302fn reject_unsupported_bulk_type(sql_type: &str) -> Result<(), TypeError> {
308 let base = sql_type
309 .split('(')
310 .next()
311 .unwrap_or("")
312 .trim()
313 .to_uppercase();
314 match base.as_str() {
315 "TEXT" | "NTEXT" => Err(TypeError::UnsupportedType {
316 sql_type: base,
317 reason: "TEXT/NTEXT are not supported. Use VARCHAR(MAX) / \
318 NVARCHAR(MAX) instead (Microsoft deprecated TEXT/NTEXT in \
319 SQL Server 2005)."
320 .to_string(),
321 }),
322 "IMAGE" => Err(TypeError::UnsupportedType {
323 sql_type: base,
324 reason: "IMAGE is not supported. Use VARBINARY(MAX) instead \
325 (Microsoft deprecated IMAGE in SQL Server 2005)."
326 .to_string(),
327 }),
328 _ => Ok(()),
329 }
330}
331
/// Summary returned once a bulk insert has been finished.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Number of rows reported for the operation.
    pub rows_affected: u64,
    /// Number of batches counted as committed.
    pub batches_committed: u32,
    /// Whether errors were recorded; both producers in this file set it to `false`.
    pub has_errors: bool,
}
342
/// Builder that collects a target table, its columns and the bulk options,
/// and renders them into an `INSERT BULK` statement.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    /// Target table name (validated as a qualified identifier when the statement is built).
    table_name: String,
    /// Columns to list in the statement; empty means no column list is emitted.
    columns: Vec<BulkColumn>,
    /// Hints rendered into the `WITH (...)` clause.
    options: BulkOptions,
}
350
351impl BulkInsertBuilder {
352 pub fn new<S: Into<String>>(table_name: S) -> Self {
354 Self {
355 table_name: table_name.into(),
356 columns: Vec::new(),
357 options: BulkOptions::default(),
358 }
359 }
360
361 #[must_use]
366 #[allow(clippy::expect_used)] pub fn with_columns(mut self, column_names: &[&str]) -> Self {
368 self.columns = column_names
369 .iter()
370 .enumerate()
371 .map(|(i, name)| {
372 BulkColumn::new(*name, "NVARCHAR(MAX)", i)
373 .expect("NVARCHAR(MAX) is always a supported type")
374 })
375 .collect();
376 self
377 }
378
379 #[must_use]
381 pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
382 self.columns = columns;
383 self
384 }
385
386 #[must_use]
388 pub fn with_options(mut self, options: BulkOptions) -> Self {
389 self.options = options;
390 self
391 }
392
393 #[must_use]
395 pub fn batch_size(mut self, size: usize) -> Self {
396 self.options.batch_size = size;
397 self
398 }
399
400 #[must_use]
402 pub fn table_lock(mut self, enabled: bool) -> Self {
403 self.options.table_lock = enabled;
404 self
405 }
406
407 #[must_use]
409 pub fn fire_triggers(mut self, enabled: bool) -> Self {
410 self.options.fire_triggers = enabled;
411 self
412 }
413
414 pub fn table_name(&self) -> &str {
416 &self.table_name
417 }
418
419 pub fn columns(&self) -> &[BulkColumn] {
421 &self.columns
422 }
423
424 pub fn options(&self) -> &BulkOptions {
426 &self.options
427 }
428
429 pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
436 crate::validation::validate_qualified_identifier(&self.table_name)?;
438
439 for col in &self.columns {
441 crate::validation::validate_identifier(&col.name)?;
442 }
443
444 let mut sql = format!("INSERT BULK {}", self.table_name);
445
446 if !self.columns.is_empty() {
448 sql.push_str(" (");
449 let cols: Vec<String> = self
450 .columns
451 .iter()
452 .map(|c| {
453 validate_sql_type(&c.sql_type)?;
459 Ok(format!("{} {}", c.name, c.sql_type))
460 })
461 .collect::<Result<Vec<_>, Error>>()?;
462 sql.push_str(&cols.join(", "));
463 sql.push(')');
464 }
465
466 let mut hints: Vec<String> = Vec::new();
468
469 if self.options.check_constraints {
470 hints.push("CHECK_CONSTRAINTS".to_string());
471 }
472 if self.options.fire_triggers {
473 hints.push("FIRE_TRIGGERS".to_string());
474 }
475 if self.options.keep_nulls {
476 hints.push("KEEP_NULLS".to_string());
477 }
478 if self.options.table_lock {
479 hints.push("TABLOCK".to_string());
480 }
481 if self.options.batch_size > 0 {
482 hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
483 }
484
485 if let Some(ref order) = self.options.order_hint {
486 for col_name in order {
488 crate::validation::validate_identifier(col_name)?;
489 }
490 hints.push(format!("ORDER({})", order.join(", ")));
491 }
492
493 if !hints.is_empty() {
494 sql.push_str(" WITH (");
495 sql.push_str(&hints.join(", "));
496 sql.push(')');
497 }
498
499 Ok(sql)
500 }
501}
502
503fn validate_sql_type(type_str: &str) -> Result<(), Error> {
509 #[allow(clippy::expect_used)] static SQL_TYPE_RE: Lazy<Regex> =
511 Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
512
513 if type_str.is_empty() {
514 return Err(Error::Config("SQL type cannot be empty".into()));
515 }
516
517 if !SQL_TYPE_RE.is_match(type_str) {
518 return Err(Error::Config(format!(
519 "invalid SQL type '{type_str}': contains disallowed characters"
520 )));
521 }
522
523 Ok(())
524}
525
/// In-memory encoder for a bulk-load token stream.
///
/// Column metadata is written into `buffer` at construction; each row sent
/// afterwards is appended to it, and the buffer is later split into packets.
pub struct BulkInsert {
    /// Column definitions, shared so they can be cloned cheaply.
    columns: Arc<[BulkColumn]>,
    /// Per-column flag: `true` when values are written without a length prefix.
    fixed_len: Arc<[bool]>,
    /// Encoded bytes accumulated so far.
    buffer: BytesMut,
    /// Rows appended to the current batch.
    rows_in_batch: usize,
    /// Rows appended over the lifetime of this encoder.
    total_rows: u64,
    /// Batch threshold consulted by `should_flush`; zero disables flushing.
    batch_size: usize,
    /// Batch counter surfaced through `result`.
    batches_committed: u32,
    /// Id stamped on the next packet header; starts at 1 and wraps on overflow.
    packet_id: u8,
}
549
550impl BulkInsert {
551 pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
553 Self::new_with_server_metadata(columns, batch_size, None, None)
554 }
555
556 pub fn new_with_server_metadata(
567 mut columns: Vec<BulkColumn>,
568 batch_size: usize,
569 raw_colmetadata: Option<bytes::Bytes>,
570 server_columns: Option<&[tds_protocol::token::ColumnData]>,
571 ) -> Self {
572 let fixed_len: Vec<bool> = if let Some(srv_cols) = server_columns {
575 for (col, srv) in columns.iter_mut().zip(srv_cols.iter()) {
581 if col.collation.is_none() {
582 col.collation = srv.type_info.collation;
583 }
584 }
585 srv_cols
586 .iter()
587 .map(|c| c.type_id.is_fixed_length())
588 .collect()
589 } else {
590 columns
595 .iter()
596 .map(|c| !c.nullable && nullable_to_fixed_type(c.type_id, c.max_length).is_some())
597 .collect()
598 };
599
600 let mut bulk = Self {
601 columns: columns.into(),
602 fixed_len: fixed_len.into(),
603 buffer: BytesMut::with_capacity(64 * 1024),
604 rows_in_batch: 0,
605 total_rows: 0,
606 batch_size,
607 batches_committed: 0,
608 packet_id: 1,
609 };
610
611 if let Some(raw) = raw_colmetadata {
612 bulk.buffer.extend_from_slice(&raw);
613 } else {
614 bulk.write_colmetadata();
615 }
616
617 bulk
618 }
619
    /// Appends a column-metadata token describing `self.columns` to the buffer.
    ///
    /// Layout written: the token byte, a 16-bit column count, then for each
    /// column a 32-bit zero, 16-bit flags, the type byte, type-specific
    /// details, and finally the column name as a length-prefixed UTF-16LE string.
    ///
    /// Non-nullable columns are emitted with the fixed-length type returned by
    /// `nullable_to_fixed_type` when one exists; such columns carry no
    /// type-specific details.
    fn write_colmetadata(&mut self) {
        let buf = &mut self.buffer;

        buf.put_u8(TokenType::ColMetaData as u8);

        // NOTE(review): the count is narrowed to u16 without a range check.
        buf.put_u16_le(self.columns.len() as u16);

        for col in self.columns.iter() {
            // NOTE(review): presumably the TDS UserType field - confirm against MS-TDS.
            buf.put_u32_le(0);

            // Swap in the fixed-length type for NOT NULL columns when one exists.
            let effective_type_id = if !col.nullable {
                nullable_to_fixed_type(col.type_id, col.max_length).unwrap_or(col.type_id)
            } else {
                col.type_id
            };
            let is_fixed_variant = effective_type_id != col.type_id;

            // NOTE(review): 0x0008 is always set - presumably the "updatable" bits; confirm.
            let mut flags: u16 = 0x0008;
            if col.nullable {
                // Bit 0 is set only for nullable columns.
                flags |= 0x0001;
            }
            buf.put_u16_le(flags);

            buf.put_u8(effective_type_id);

            if is_fixed_variant {
                // Fixed-length types: the name follows the type byte directly.
                let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
                buf.put_u8(name_utf16.len() as u8);
                for code_unit in name_utf16 {
                    buf.put_u16_le(code_unit);
                }
                continue;
            }

            match col.type_id {
                // Integer, bit, float, money and legacy datetime types: one length byte.
                0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
                    buf.put_u8(col.max_length.unwrap_or(4) as u8);
                }

                // DATE carries no extra type information.
                0x28 => {}

                // Character and binary types: 16-bit max length (0xFFFF marks MAX).
                0xE7 | 0xA7 | 0xA5 | 0xAD => {
                    let max_len = col.max_length.unwrap_or(8000);
                    // NOTE(review): both branches write the same low 16 bits; lengths
                    // above 0xFFFF are truncated by the cast.
                    if max_len == 0xFFFF {
                        buf.put_u16_le(0xFFFF);
                    } else {
                        buf.put_u16_le(max_len as u16);
                    }

                    // Only the two character types are followed by collation bytes.
                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
                        if let Some(coll) = col.collation.as_ref() {
                            buf.put_slice(&coll.to_bytes());
                        } else {
                            // Fallback collation bytes used when none was configured.
                            buf.put_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]);
                        }
                    }
                }

                // DECIMAL/NUMERIC: storage length, precision, scale.
                0x6C | 0x6A => {
                    let precision = col.precision.unwrap_or(18);
                    let len = decimal_byte_length(precision);
                    buf.put_u8(len);
                    buf.put_u8(precision);
                    buf.put_u8(col.scale.unwrap_or(0));
                }

                // TIME / DATETIME2 / DATETIMEOFFSET: scale byte only.
                0x29..=0x2B => {
                    buf.put_u8(col.scale.unwrap_or(7));
                }

                // UNIQUEIDENTIFIER: always 16 bytes.
                0x24 => {
                    buf.put_u8(16);
                }

                // Any other type: write a 16-bit length if one is known and fits.
                _ => {
                    if let Some(len) = col.max_length {
                        if len <= 0xFFFF {
                            buf.put_u16_le(len as u16);
                        }
                    }
                }
            }

            // Column name: one-byte length (in UTF-16 code units), then the code units.
            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
            buf.put_u8(name_utf16.len() as u8);
            for code_unit in name_utf16 {
                buf.put_u16_le(code_unit);
            }
        }
    }
739
740 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
751 if values.len() != self.columns.len() {
752 return Err(Error::Config(format!(
753 "expected {} values, got {}",
754 self.columns.len(),
755 values.len()
756 )));
757 }
758
759 let sql_values: Result<Vec<SqlValue>, TypeError> =
761 values.iter().map(|v| v.to_sql()).collect();
762 let sql_values = sql_values.map_err(Error::from)?;
763
764 self.write_row(&sql_values)?;
765
766 self.rows_in_batch += 1;
767 self.total_rows += 1;
768
769 Ok(())
770 }
771
772 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
774 if values.len() != self.columns.len() {
775 return Err(Error::Config(format!(
776 "expected {} values, got {}",
777 self.columns.len(),
778 values.len()
779 )));
780 }
781
782 self.write_row(values)?;
783
784 self.rows_in_batch += 1;
785 self.total_rows += 1;
786
787 Ok(())
788 }
789
790 fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
792 self.buffer.put_u8(TokenType::Row as u8);
794
795 let columns: Vec<_> = self.columns.iter().cloned().collect();
797 let fixed_len = self.fixed_len.clone();
798
799 for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
801 let is_fixed = *fixed_len.get(i).unwrap_or(&false);
802 self.encode_column_value(col, value, is_fixed)
803 .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
804 }
805
806 Ok(())
807 }
808
809 fn encode_column_value(
815 &mut self,
816 col: &BulkColumn,
817 value: &SqlValue,
818 is_fixed: bool,
819 ) -> Result<(), TypeError> {
820 let buf = &mut self.buffer;
821
822 let is_plp_type =
825 col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
826
827 match value {
828 SqlValue::Null => {
829 match col.type_id {
831 0xE7 | 0xA7 | 0xA5 | 0xAD => {
833 if is_plp_type {
834 buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
836 } else {
837 buf.put_u16_le(0xFFFF);
839 }
840 }
841 0x26 | 0x68 | 0x6D | 0x6E | 0x6F | 0x6C | 0x6A | 0x24 | 0x28 | 0x29 | 0x2A
844 | 0x2B => {
845 buf.put_u8(0);
846 }
847 _ => {
849 if col.nullable {
850 buf.put_u8(0);
851 } else {
852 return Err(TypeError::UnexpectedNull);
853 }
854 }
855 }
856 }
857
858 SqlValue::Bool(v) => {
859 if !is_fixed {
860 buf.put_u8(1);
861 }
862 buf.put_u8(if *v { 1 } else { 0 });
863 }
864
865 SqlValue::TinyInt(v) => {
866 if !is_fixed {
867 buf.put_u8(1);
868 }
869 buf.put_u8(*v);
870 }
871
872 SqlValue::SmallInt(v) => {
873 if !is_fixed {
874 buf.put_u8(2);
875 }
876 buf.put_i16_le(*v);
877 }
878
879 SqlValue::Int(v) => {
880 if !is_fixed {
881 buf.put_u8(4);
882 }
883 buf.put_i32_le(*v);
884 }
885
886 SqlValue::BigInt(v) => {
887 if !is_fixed {
888 buf.put_u8(8);
889 }
890 buf.put_i64_le(*v);
891 }
892
893 SqlValue::Float(v) => {
894 if !is_fixed {
895 buf.put_u8(4);
896 }
897 buf.put_f32_le(*v);
898 }
899
900 SqlValue::Double(v) => {
901 if !is_fixed {
902 buf.put_u8(8);
903 }
904 buf.put_f64_le(*v);
905 }
906
907 SqlValue::String(s) => {
908 let is_varchar = matches!(col.type_id, 0xA7 | 0x2F | 0xAF);
914
915 if is_varchar {
916 let encoded = encode_varchar_for_collation(s, col.collation.as_ref());
917 let byte_len = encoded.len();
918
919 if is_plp_type {
920 encode_plp_binary(&encoded, buf);
921 } else if byte_len > 0xFFFF {
922 return Err(TypeError::BufferTooSmall {
923 needed: byte_len,
924 available: 0xFFFF,
925 });
926 } else {
927 buf.put_u16_le(byte_len as u16);
928 buf.put_slice(&encoded);
929 }
930 } else {
931 let utf16: Vec<u16> = s.encode_utf16().collect();
933 let byte_len = utf16.len() * 2;
934
935 if is_plp_type {
936 encode_plp_string(&utf16, buf);
939 } else if byte_len > 0xFFFF {
940 return Err(TypeError::BufferTooSmall {
942 needed: byte_len,
943 available: 0xFFFF,
944 });
945 } else {
946 buf.put_u16_le(byte_len as u16);
948 for code_unit in utf16 {
949 buf.put_u16_le(code_unit);
950 }
951 }
952 }
953 }
954
955 SqlValue::Binary(b) => {
956 if is_plp_type {
957 encode_plp_binary(b, buf);
959 } else if b.len() > 0xFFFF {
960 return Err(TypeError::BufferTooSmall {
962 needed: b.len(),
963 available: 0xFFFF,
964 });
965 } else {
966 buf.put_u16_le(b.len() as u16);
968 buf.put_slice(b);
969 }
970 }
971
972 #[cfg(feature = "decimal")]
974 SqlValue::Decimal(d) => {
975 if col.type_id == 0x6E {
976 encode_money_value(*d, col, buf, is_fixed)?;
978 } else {
979 let precision = col.precision.unwrap_or(18);
980 let len = decimal_byte_length(precision);
981 buf.put_u8(len);
982
983 buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
985
986 let mantissa = d.mantissa().unsigned_abs();
988 let mantissa_bytes = mantissa.to_le_bytes();
989 buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
990 }
991 }
992
993 #[cfg(feature = "uuid")]
994 SqlValue::Uuid(u) => {
995 buf.put_u8(16); mssql_types::encode::encode_uuid(*u, buf);
998 }
999
1000 #[cfg(feature = "chrono")]
1001 SqlValue::Date(d) => {
1002 buf.put_u8(3); mssql_types::encode::encode_date(*d, buf);
1004 }
1005
1006 #[cfg(feature = "chrono")]
1007 SqlValue::Time(t) => {
1008 let scale = col.scale.unwrap_or(7);
1009 let len = time_byte_length(scale);
1010 buf.put_u8(len);
1011 encode_time_with_scale(*t, scale, buf);
1013 }
1014
1015 #[cfg(feature = "chrono")]
1016 SqlValue::DateTime(dt) => {
1017 if col.type_id == 0x6F {
1022 let total_len = col.max_length.unwrap_or(8) as u8;
1023 if !is_fixed {
1024 buf.put_u8(total_len);
1025 }
1026 match total_len {
1027 8 => mssql_types::encode::encode_datetime_legacy(*dt, buf),
1028 4 => mssql_types::encode::encode_smalldatetime(*dt, buf)?,
1029 _ => {
1030 return Err(TypeError::InvalidDateTime(format!(
1031 "DATETIMEN max_length must be 4 or 8, got {total_len}"
1032 )));
1033 }
1034 }
1035 } else {
1036 let scale = col.scale.unwrap_or(7);
1037 let time_len = time_byte_length(scale);
1038 let total_len = time_len + 3;
1039 buf.put_u8(total_len);
1040 encode_time_with_scale(dt.time(), scale, buf);
1042 mssql_types::encode::encode_date(dt.date(), buf);
1043 }
1044 }
1045 #[cfg(feature = "chrono")]
1046 SqlValue::SmallDateTime(dt) => {
1047 if !is_fixed {
1050 buf.put_u8(4);
1051 }
1052 mssql_types::encode::encode_smalldatetime(*dt, buf)?;
1053 }
1054 #[cfg(feature = "decimal")]
1055 SqlValue::Money(d) => {
1056 if !is_fixed {
1058 buf.put_u8(8);
1059 }
1060 mssql_types::encode::encode_money(*d, buf)?;
1061 }
1062 #[cfg(feature = "decimal")]
1063 SqlValue::SmallMoney(d) => {
1064 if !is_fixed {
1065 buf.put_u8(4);
1066 }
1067 mssql_types::encode::encode_smallmoney(*d, buf)?;
1068 }
1069
1070 #[cfg(feature = "chrono")]
1071 SqlValue::DateTimeOffset(dto) => {
1072 let scale = col.scale.unwrap_or(7);
1073 let time_len = time_byte_length(scale);
1074 let total_len = time_len + 3 + 2;
1075 buf.put_u8(total_len);
1076 encode_time_with_scale(dto.time(), scale, buf);
1078 mssql_types::encode::encode_date(dto.date_naive(), buf);
1079 use chrono::Offset;
1081 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1082 buf.put_i16_le(offset_minutes);
1083 }
1084
1085 #[cfg(feature = "json")]
1086 SqlValue::Json(j) => {
1087 let s = j.to_string();
1088 encode_nvarchar_value(&s, buf)?;
1089 }
1090
1091 SqlValue::Xml(x) => {
1092 encode_nvarchar_value(x, buf)?;
1093 }
1094
1095 SqlValue::Tvp(_) => {
1096 return Err(TypeError::UnsupportedConversion {
1098 from: "TVP".to_string(),
1099 to: "bulk copy value",
1100 });
1101 }
1102 _ => {
1104 return Err(TypeError::UnsupportedConversion {
1105 from: value.type_name().to_string(),
1106 to: "bulk copy value",
1107 });
1108 }
1109 }
1110
1111 Ok(())
1112 }
1113}
1114
/// Encodes a decimal destined for a MONEY-family column.
///
/// The column's `max_length` (default 8) selects the 4-byte or 8-byte form.
/// Unless `is_fixed` is set, that width is written first as a length prefix.
///
/// # Errors
///
/// Returns `TypeError::InvalidDecimal` when the width is neither 4 nor 8, and
/// propagates errors from the underlying money encoders.
#[cfg(feature = "decimal")]
fn encode_money_value(
    value: rust_decimal::Decimal,
    col: &BulkColumn,
    buf: &mut BytesMut,
    is_fixed: bool,
) -> Result<(), TypeError> {
    let width: u8 = col.max_length.unwrap_or(8) as u8;

    // Nullable (variable) form carries a one-byte length ahead of the payload.
    if !is_fixed {
        buf.put_u8(width);
    }

    if width == 4 {
        return mssql_types::encode::encode_smallmoney(value, buf);
    }
    if width == 8 {
        return mssql_types::encode::encode_money(value, buf);
    }
    Err(TypeError::InvalidDecimal(format!(
        "MONEY column has invalid max_length: {width}"
    )))
}
1139
1140fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
1142 let utf16: Vec<u16> = s.encode_utf16().collect();
1143 let byte_len = utf16.len() * 2;
1144
1145 if byte_len > 0xFFFF {
1146 return Err(TypeError::BufferTooSmall {
1147 needed: byte_len,
1148 available: 0xFFFF,
1149 });
1150 }
1151
1152 buf.put_u16_le(byte_len as u16);
1153 for code_unit in utf16 {
1154 buf.put_u16_le(code_unit);
1155 }
1156 Ok(())
1157}
1158
/// Total-length marker written at the start of every PLP value produced by
/// `encode_plp_string` / `encode_plp_binary` (all bits set except the lowest).
const PLP_UNKNOWN_LEN: u64 = u64::MAX - 1;
1164
1165fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
1178 let byte_len = utf16.len() * 2;
1179
1180 buf.put_u64_le(PLP_UNKNOWN_LEN);
1181
1182 if byte_len > 0 {
1183 buf.put_u32_le(byte_len as u32);
1184 for code_unit in utf16 {
1185 buf.put_u16_le(*code_unit);
1186 }
1187 }
1188
1189 buf.put_u32_le(0);
1190}
1191
1192fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
1195 buf.put_u64_le(PLP_UNKNOWN_LEN);
1196
1197 if !data.is_empty() {
1198 buf.put_u32_le(data.len() as u32);
1199 buf.put_slice(data);
1200 }
1201
1202 buf.put_u32_le(0);
1203}
1204
/// Encodes `value` as bytes for a single-byte character column.
///
/// Thin wrapper that forwards to
/// `tds_protocol::collation::encode_str_for_collation`; the handling of a
/// `None` collation is decided entirely by that function.
fn encode_varchar_for_collation(value: &str, collation: Option<&Collation>) -> Vec<u8> {
    tds_protocol::collation::encode_str_for_collation(value, collation)
}
1212
/// Writes a time of day as a little-endian integer count of `10^-scale`-second
/// intervals since midnight, using the byte width `time_byte_length(scale)`.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let whole_seconds = time.num_seconds_from_midnight() as u64;
    let sub_second_nanos = time.nanosecond() as u64;
    let nanos_since_midnight = whole_seconds * 1_000_000_000 + sub_second_nanos;

    let ticks = nanos_since_midnight / time_scale_divisor(scale);

    // Only the low `width` bytes go on the wire (width is at most 5).
    let width = time_byte_length(scale) as usize;
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
1226
1227impl BulkInsert {
1228 fn write_done(&mut self) {
1230 let buf = &mut self.buffer;
1231
1232 buf.put_u8(TokenType::Done as u8);
1233
1234 let status = DoneStatus::from_bits(0x0010); buf.put_u16_le(status.to_bits());
1237
1238 buf.put_u16_le(0);
1240
1241 buf.put_u64_le(self.total_rows);
1243 }
1244
1245 pub fn take_packets(&mut self) -> Vec<BytesMut> {
1249 const MAX_PACKET_SIZE: usize = 4096;
1250 const HEADER_SIZE: usize = 8;
1251 const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1252
1253 let data = self.buffer.split();
1254 let mut packets = Vec::new();
1255 let mut offset = 0;
1256
1257 while offset < data.len() {
1258 let remaining = data.len() - offset;
1259 let payload_size = remaining.min(MAX_PAYLOAD);
1260 let is_last = offset + payload_size >= data.len();
1261
1262 let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1263
1264 let header = PacketHeader {
1266 packet_type: PacketType::BulkLoad,
1267 status: if is_last {
1268 PacketStatus::END_OF_MESSAGE
1269 } else {
1270 PacketStatus::NORMAL
1271 },
1272 length: (HEADER_SIZE + payload_size) as u16,
1273 spid: 0,
1274 packet_id: self.packet_id,
1275 window: 0,
1276 };
1277
1278 header.encode(&mut packet);
1279
1280 packet.put_slice(&data[offset..offset + payload_size]);
1282
1283 packets.push(packet);
1284 offset += payload_size;
1285 self.packet_id = self.packet_id.wrapping_add(1);
1286 }
1287
1288 packets
1289 }
1290
1291 pub fn total_rows(&self) -> u64 {
1293 self.total_rows
1294 }
1295
1296 pub fn rows_in_batch(&self) -> usize {
1298 self.rows_in_batch
1299 }
1300
1301 pub fn should_flush(&self) -> bool {
1303 self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1304 }
1305
1306 pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1309 self.write_done();
1310 self.take_packets()
1311 }
1312
1313 pub fn result(&self) -> BulkInsertResult {
1315 BulkInsertResult {
1316 rows_affected: self.total_rows,
1317 batches_committed: self.batches_committed,
1318 has_errors: false,
1319 }
1320 }
1321}
1322
/// Pairs a mutably borrowed client with a [`BulkInsert`] encoder so rows can be
/// buffered and then sent over that client when finished.
pub struct BulkWriter<'a, S: crate::state::ConnectionState> {
    /// Connection used to transmit the encoded payload in `finish`.
    client: &'a mut crate::client::Client<S>,
    /// Encoder accumulating the rows.
    bulk: BulkInsert,
}
1350
1351impl<'a, S: crate::state::ConnectionState> BulkWriter<'a, S> {
1352 pub(crate) fn new(client: &'a mut crate::client::Client<S>, bulk: BulkInsert) -> Self {
1354 Self { client, bulk }
1355 }
1356
1357 pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
1363 self.bulk.send_row(values)
1364 }
1365
1366 pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
1368 self.bulk.send_row_values(values)
1369 }
1370
1371 pub fn total_rows(&self) -> u64 {
1373 self.bulk.total_rows()
1374 }
1375
1376 pub async fn finish(mut self) -> Result<BulkInsertResult, Error> {
1381 let total_rows = self.bulk.total_rows();
1382 tracing::debug!(total_rows = total_rows, "finishing bulk insert");
1383
1384 self.bulk.write_done();
1386 let payload = self.bulk.buffer.split().freeze();
1387
1388 let rows_affected = self.client.send_and_read_bulk_load(payload).await?;
1390
1391 Ok(BulkInsertResult {
1392 rows_affected,
1393 batches_committed: 1,
1394 has_errors: false,
1395 })
1396 }
1397}
1398
/// Maps a nullable type id plus byte width to its fixed-length counterpart.
///
/// Returns `None` when the type has no fixed-length variant or the width is
/// not one of the widths that type supports. The bit type (`0x68`) maps to
/// `0x32` regardless of width.
fn nullable_to_fixed_type(type_id: u8, max_length: Option<u32>) -> Option<u8> {
    // The bit type is the only one whose mapping ignores the length.
    if type_id == 0x68 {
        return Some(0x32);
    }

    let width = max_length?;
    let fixed = match (type_id, width) {
        // Integers
        (0x26, 1) => 0x30,
        (0x26, 2) => 0x34,
        (0x26, 4) => 0x38,
        (0x26, 8) => 0x7F,
        // Floating point
        (0x6D, 4) => 0x3B,
        (0x6D, 8) => 0x3E,
        // Money
        (0x6E, 4) => 0x7A,
        (0x6E, 8) => 0x3C,
        // Legacy date/time
        (0x6F, 4) => 0x3A,
        (0x6F, 8) => 0x3D,
        _ => return None,
    };
    Some(fixed)
}
1424
/// Storage size in bytes (sign byte included) for a DECIMAL of the given precision.
///
/// Precisions outside `1..=38`, including zero, are treated like the widest
/// size (17 bytes).
fn decimal_byte_length(precision: u8) -> u8 {
    if (1..=9).contains(&precision) {
        5
    } else if (10..=19).contains(&precision) {
        9
    } else if (20..=28).contains(&precision) {
        13
    } else {
        // 29..=38 and every out-of-range precision.
        17
    }
}
1435
1436#[cfg(feature = "chrono")]
/// Number of bytes used to store a time value at the given fractional-second scale.
///
/// Scales above 7 are treated like scale 7.
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1446
1447#[cfg(feature = "chrono")]
/// Number of nanoseconds in one interval at the given fractional-second scale,
/// i.e. `10^(9 - scale)`.
///
/// Scales above 7 are treated like scale 7 (100 ns).
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = u32::from(scale.min(7));
    10_u64.pow(9 - effective_scale)
}
1462
1463#[cfg(test)]
1464#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
1465mod tests {
1466 use super::*;
1467
1468 #[test]
1469 fn test_bulk_options_default() {
1470 let opts = BulkOptions::default();
1471 assert_eq!(opts.batch_size, 0);
1472 assert!(opts.check_constraints);
1473 assert!(!opts.fire_triggers);
1474 assert!(opts.keep_nulls);
1475 assert!(!opts.table_lock);
1476 }
1477
1478 #[test]
1479 fn test_bulk_column_creation() {
1480 let col = BulkColumn::new("id", "INT", 0).unwrap();
1481 assert_eq!(col.name, "id");
1482 assert_eq!(col.type_id, 0x26); assert_eq!(col.max_length, Some(4));
1484 assert!(col.nullable);
1485 }
1486
1487 #[test]
1488 fn test_bulk_column_rejects_text() {
1489 let err = BulkColumn::new("body", "TEXT", 0).unwrap_err();
1490 match err {
1491 TypeError::UnsupportedType { sql_type, reason } => {
1492 assert_eq!(sql_type, "TEXT");
1493 assert!(
1494 reason.contains("VARCHAR(MAX)"),
1495 "error should redirect to VARCHAR(MAX), got: {reason}"
1496 );
1497 assert!(
1498 reason.contains("deprecated"),
1499 "error should mention deprecation, got: {reason}"
1500 );
1501 }
1502 other => panic!("expected UnsupportedType, got {other:?}"),
1503 }
1504 }
1505
1506 #[test]
1507 fn test_bulk_column_rejects_ntext() {
1508 let err = BulkColumn::new("body", "NTEXT", 0).unwrap_err();
1509 match err {
1510 TypeError::UnsupportedType { sql_type, reason } => {
1511 assert_eq!(sql_type, "NTEXT");
1512 assert!(
1513 reason.contains("NVARCHAR(MAX)"),
1514 "error should redirect to NVARCHAR(MAX), got: {reason}"
1515 );
1516 assert!(
1517 reason.contains("deprecated"),
1518 "error should mention deprecation, got: {reason}"
1519 );
1520 }
1521 other => panic!("expected UnsupportedType, got {other:?}"),
1522 }
1523 }
1524
1525 #[test]
1526 fn test_bulk_column_rejects_text_case_insensitive() {
1527 assert!(matches!(
1528 BulkColumn::new("body", "text", 0),
1529 Err(TypeError::UnsupportedType { .. })
1530 ));
1531 assert!(matches!(
1532 BulkColumn::new("body", "Ntext", 0),
1533 Err(TypeError::UnsupportedType { .. })
1534 ));
1535 }
1536
1537 #[test]
1538 fn test_bulk_column_rejects_image() {
1539 let err = BulkColumn::new("blob", "IMAGE", 0).unwrap_err();
1540 match err {
1541 TypeError::UnsupportedType { sql_type, reason } => {
1542 assert_eq!(sql_type, "IMAGE");
1543 assert!(
1544 reason.contains("VARBINARY(MAX)"),
1545 "error should redirect to VARBINARY(MAX), got: {reason}"
1546 );
1547 assert!(
1548 reason.contains("deprecated"),
1549 "error should mention deprecation, got: {reason}"
1550 );
1551 }
1552 other => panic!("expected UnsupportedType, got {other:?}"),
1553 }
1554 }
1555
1556 #[test]
1557 fn test_bulk_column_rejects_image_case_insensitive() {
1558 assert!(matches!(
1559 BulkColumn::new("blob", "image", 0),
1560 Err(TypeError::UnsupportedType { .. })
1561 ));
1562 assert!(matches!(
1563 BulkColumn::new("blob", "Image", 0),
1564 Err(TypeError::UnsupportedType { .. })
1565 ));
1566 }
1567
1568 #[test]
1569 fn test_parse_sql_type() {
1570 let (type_id, len, _prec, _scale) = parse_sql_type("INT");
1572 assert_eq!(type_id, 0x26);
1573 assert_eq!(len, Some(4));
1574
1575 let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
1576 assert_eq!(type_id, 0xE7);
1577 assert_eq!(len, Some(200)); let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
1580 assert_eq!(type_id, 0x6C);
1581 assert_eq!(prec, Some(10));
1582 assert_eq!(scale, Some(2));
1583
1584 let (type_id, len, _, _) = parse_sql_type("SMALLDATETIME");
1586 assert_eq!(type_id, 0x6F);
1587 assert_eq!(len, Some(4));
1588
1589 let (type_id, len, _, _) = parse_sql_type("DATETIME");
1590 assert_eq!(type_id, 0x6F);
1591 assert_eq!(len, Some(8));
1592 }
1593
1594 #[test]
1595 fn test_insert_bulk_statement() {
1596 let builder = BulkInsertBuilder::new("dbo.Users")
1597 .with_typed_columns(vec![
1598 BulkColumn::new("id", "INT", 0).unwrap(),
1599 BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1600 ])
1601 .table_lock(true);
1602
1603 let sql = builder.build_insert_bulk_statement().unwrap();
1604 assert!(sql.contains("INSERT BULK dbo.Users"));
1605 assert!(sql.contains("TABLOCK"));
1606 }
1607
1608 #[test]
1609 fn test_bulk_insert_rejects_injection() {
1610 let builder = BulkInsertBuilder::new("table;DROP TABLE users")
1611 .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1612
1613 assert!(builder.build_insert_bulk_statement().is_err());
1614 }
1615
1616 #[test]
1617 fn test_bulk_insert_validates_column_names() {
1618 let builder = BulkInsertBuilder::new("Users")
1619 .with_typed_columns(vec![BulkColumn::new("col;DROP TABLE x", "INT", 0).unwrap()]);
1620
1621 assert!(builder.build_insert_bulk_statement().is_err());
1622 }
1623
1624 #[test]
1625 fn test_bulk_insert_accepts_qualified_names() {
1626 let builder = BulkInsertBuilder::new("catalog.dbo.Users")
1627 .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1628
1629 assert!(builder.build_insert_bulk_statement().is_ok());
1630 }
1631
1632 #[test]
1633 fn test_bulk_insert_creation() {
1634 let columns = vec![
1635 BulkColumn::new("id", "INT", 0).unwrap(),
1636 BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1637 ];
1638
1639 let bulk = BulkInsert::new(columns, 1000);
1640 assert_eq!(bulk.total_rows(), 0);
1641 assert_eq!(bulk.rows_in_batch(), 0);
1642 assert!(!bulk.should_flush());
1643 }
1644
#[test]
/// Spot-checks `decimal_byte_length` at one precision per expected size.
fn test_decimal_byte_length() {
    // (precision, expected byte length)
    for (precision, expected) in [(5, 5), (15, 9), (25, 13), (35, 17)] {
        assert_eq!(decimal_byte_length(precision), expected);
    }
}
1652
#[test]
#[cfg(feature = "chrono")]
/// Spot-checks `time_byte_length` at scales 0, 3 and 7.
fn test_time_byte_length() {
    // (scale, expected byte length)
    for (scale, expected) in [(0, 3), (3, 4), (7, 5)] {
        assert_eq!(time_byte_length(scale), expected);
    }
}
1660
#[test]
/// Encodes a short string as PLP and checks every region of the output:
/// the 8-byte length marker, the 4-byte chunk length, the payload itself,
/// and the 4-byte zero terminator.
fn test_plp_string_encoding() {
    let mut buf = BytesMut::new();
    let text = "Hello";
    let utf16: Vec<u16> = text.encode_utf16().collect();

    encode_plp_string(&utf16, &mut buf);

    // 8 (length marker) + 4 (chunk length) + 10 (5 UTF-16 units) + 4 (terminator)
    assert_eq!(buf.len(), 8 + 4 + 10 + 4);

    assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());

    assert_eq!(&buf[8..12], &10u32.to_le_bytes());

    // The payload must be the code units in little-endian order; without
    // this check a byte-order or payload bug would go unnoticed, since the
    // surrounding framing would still look correct.
    let expected_payload: Vec<u8> = utf16
        .iter()
        .flat_map(|unit| unit.to_le_bytes())
        .collect();
    assert_eq!(&buf[12..22], expected_payload.as_slice());

    assert_eq!(&buf[22..26], &0u32.to_le_bytes());
}
1685
#[test]
/// Encodes a 16-byte blob as PLP and checks the length marker, the chunk
/// length, the payload bytes and the zero terminator.
fn test_plp_binary_encoding() {
    let payload = b"test binary data";
    let mut encoded = BytesMut::new();

    encode_plp_binary(payload, &mut encoded);

    // 8 (length marker) + 4 (chunk length) + 16 (payload) + 4 (terminator)
    assert_eq!(encoded.len(), 32);

    let bytes: &[u8] = &encoded;
    let (marker, rest) = bytes.split_at(8);
    let (chunk_len, rest) = rest.split_at(4);
    let (body, terminator) = rest.split_at(16);

    assert_eq!(marker, &PLP_UNKNOWN_LEN.to_le_bytes());
    assert_eq!(chunk_len, &16u32.to_le_bytes());
    assert_eq!(body, payload);
    assert_eq!(terminator, &0u32.to_le_bytes());
}
1712
#[test]
/// An empty string encodes to just the length marker followed by a zero
/// u32: no data chunk is emitted.
fn test_plp_empty_string() {
    let no_units: Vec<u16> = Vec::new();
    let mut encoded = BytesMut::new();

    encode_plp_string(&no_units, &mut encoded);

    assert_eq!(encoded.len(), 12);

    let bytes: &[u8] = &encoded;
    let (marker, terminator) = bytes.split_at(8);
    assert_eq!(marker, &PLP_UNKNOWN_LEN.to_le_bytes());
    assert_eq!(terminator, &0u32.to_le_bytes());
}
1729
#[test]
/// An empty byte slice encodes to just the length marker followed by a
/// zero u32: no data chunk is emitted.
fn test_plp_empty_binary() {
    let mut encoded = BytesMut::new();

    encode_plp_binary(&[], &mut encoded);

    assert_eq!(encoded.len(), 12);

    let bytes: &[u8] = &encoded;
    let (marker, terminator) = bytes.split_at(8);
    assert_eq!(marker, &PLP_UNKNOWN_LEN.to_le_bytes());
    assert_eq!(terminator, &0u32.to_le_bytes());
}
1745
#[test]
/// Round-trips column metadata: builds a `BulkInsert` over a broad mix of
/// SQL types, decodes what it wrote with the protocol crate's
/// `ColMetaData` decoder, and checks that names, wire type ids and the
/// per-type attributes (length, precision, scale, collation) survive.
fn test_write_colmetadata_roundtrip() {
    use tds_protocol::token::ColMetaData;

    // (column name, declared SQL type); the ordinal is the table position.
    let specs: &[(&str, &str)] = &[
        ("id", "INT"),
        ("tiny", "TINYINT"),
        ("small", "SMALLINT"),
        ("big", "BIGINT"),
        ("flag", "BIT"),
        ("r", "REAL"),
        ("f", "FLOAT"),
        ("name", "NVARCHAR(100)"),
        ("code", "VARCHAR(50)"),
        ("data", "VARBINARY(200)"),
        ("d", "DATE"),
        ("t", "TIME(3)"),
        ("dt", "DATETIME"),
        ("dt2", "DATETIME2(7)"),
        ("dto", "DATETIMEOFFSET(7)"),
        ("sdt", "SMALLDATETIME"),
        ("uid", "UNIQUEIDENTIFIER"),
        ("amt", "DECIMAL(18,2)"),
        ("price", "MONEY"),
        ("smoney", "SMALLMONEY"),
        ("nmax", "NVARCHAR(MAX)"),
        ("vmax", "VARCHAR(MAX)"),
        ("bmax", "VARBINARY(MAX)"),
    ];
    let columns: Vec<BulkColumn> = specs
        .iter()
        .enumerate()
        .map(|(ordinal, &(name, sql_type))| BulkColumn::new(name, sql_type, ordinal).unwrap())
        .collect();

    let bulk = BulkInsert::new(columns.clone(), 0);

    // The first buffered byte is skipped before decoding — presumably the
    // token-type byte, which the decoder expects to be consumed already
    // (TODO confirm against `ColMetaData::decode`).
    let mut payload = bytes::Bytes::copy_from_slice(&bulk.buffer[1..]);
    let meta = ColMetaData::decode(&mut payload)
        .expect("write_colmetadata output should be parseable by TDS decoder");

    assert_eq!(meta.columns.len(), columns.len());

    for (i, parsed) in meta.columns.iter().enumerate() {
        let original = &columns[i];
        let name = &original.name;
        let info = &parsed.type_info;

        assert_eq!(parsed.name, original.name, "column {i} name mismatch");
        assert_eq!(
            parsed.col_type, original.type_id,
            "column {i} ({name}) type mismatch"
        );

        // Types whose declared max_length must round-trip unchanged, with
        // the label used in the failure message.
        let length_label = match original.type_id {
            0x26 => Some("INTN"),
            0x6D => Some("FLTN"),
            0x6E => Some("MONEYN"),
            0x6F => Some("DATETIMEN"),
            0xE7 | 0xA7 => Some("string"),
            0xA5 => Some("binary"),
            _ => None,
        };
        if let Some(label) = length_label {
            assert_eq!(
                info.max_length, original.max_length,
                "column {i} ({name}) {label} max_length"
            );
        }

        // Remaining per-type attribute checks.
        match original.type_id {
            0x68 => assert_eq!(info.max_length, Some(1)),
            0x24 => assert_eq!(info.max_length, Some(16)),
            0x29..=0x2B => {
                assert_eq!(info.scale, original.scale, "column {i} ({name}) scale");
            }
            0xE7 | 0xA7 => assert!(
                info.collation.is_some(),
                "column {i} ({name}) should have collation"
            ),
            0xA5 => assert!(
                info.collation.is_none(),
                "column {i} ({name}) should not have collation"
            ),
            0x6C => {
                assert_eq!(
                    info.precision, original.precision,
                    "column {i} ({name}) precision"
                );
                assert_eq!(info.scale, original.scale, "column {i} ({name}) scale");
            }
            // 0x28 and anything else carry no extra attributes to compare.
            _ => {}
        }
    }
}
1892
#[test]
/// NOT NULL columns of fixed-width SQL types must be flagged `fixed_len`
/// and be written with the fixed (non-nullable) wire type, without the
/// Nullable column flag.
///
/// A single case table drives both the column definitions and the expected
/// decoded types, so the two cannot drift apart. Explicit length checks
/// keep the per-column loops from passing vacuously when fewer entries
/// than expected are produced.
fn test_write_colmetadata_not_null_uses_fixed_types() {
    use tds_protocol::token::ColMetaData;
    use tds_protocol::types::TypeId;

    // (column name, declared SQL type, expected fixed wire type)
    let cases: &[(&str, &str, TypeId)] = &[
        ("id", "INT", TypeId::Int4),
        ("tiny", "TINYINT", TypeId::Int1),
        ("small", "SMALLINT", TypeId::Int2),
        ("big", "BIGINT", TypeId::Int8),
        ("flag", "BIT", TypeId::Bit),
        ("r", "REAL", TypeId::Float4),
        ("f", "FLOAT", TypeId::Float8),
        ("dt", "DATETIME", TypeId::DateTime),
        ("sdt", "SMALLDATETIME", TypeId::DateTime4),
        ("mny", "MONEY", TypeId::Money),
        ("smny", "SMALLMONEY", TypeId::Money4),
    ];

    let columns: Vec<BulkColumn> = cases
        .iter()
        .enumerate()
        .map(|(ordinal, (name, sql_type, _))| {
            BulkColumn::new(*name, *sql_type, ordinal)
                .unwrap()
                .with_nullable(false)
        })
        .collect();

    let bulk = BulkInsert::new(columns, 0);

    // Without this, an empty or short `fixed_len` would make the loop below
    // succeed without checking every column.
    assert_eq!(
        bulk.fixed_len.len(),
        cases.len(),
        "fixed_len should have one entry per column"
    );
    for (i, (fixed, (name, _, _))) in bulk.fixed_len.iter().zip(cases.iter()).enumerate() {
        assert!(*fixed, "column {i} ({name}) should be fixed_len");
    }

    // The first buffered byte is skipped before decoding — presumably the
    // token-type byte (TODO confirm against `ColMetaData::decode`).
    let mut cursor = bytes::Bytes::copy_from_slice(&bulk.buffer[1..]);
    let meta = ColMetaData::decode(&mut cursor).expect("parseable");

    assert_eq!(
        meta.columns.len(),
        cases.len(),
        "decoded column count should match the columns written"
    );

    for (i, (parsed, (name, _, ty))) in meta.columns.iter().zip(cases.iter()).enumerate() {
        assert_eq!(parsed.name, *name, "column {i} name");
        assert_eq!(parsed.type_id, *ty, "column {i} ({name}) type");
        assert_eq!(
            parsed.flags & 0x0001,
            0,
            "column {i} ({name}) should not have Nullable flag set"
        );
    }
}
1978
#[test]
/// A collation supplied through `with_collation` must be the one written
/// into the column metadata for both VARCHAR and NVARCHAR columns, while a
/// column without one falls back to the built-in default collation bytes.
fn test_write_colmetadata_uses_caller_collation() {
    use tds_protocol::token::{ColMetaData, Collation};

    let chinese = Collation {
        lcid: 0x0804,
        sort_id: 0x52,
    };

    let with_caller_collation = |name: &str, sql_type: &str, ordinal: usize| {
        BulkColumn::new(name, sql_type, ordinal)
            .unwrap()
            .with_collation(chinese)
    };
    let columns = vec![
        with_caller_collation("s", "VARCHAR(50)", 0),
        with_caller_collation("n", "NVARCHAR(50)", 1),
        // No explicit collation: exercises the default.
        BulkColumn::new("d", "VARCHAR(10)", 2).unwrap(),
    ];
    let bulk = BulkInsert::new(columns, 0);

    let mut cursor = bytes::Bytes::copy_from_slice(&bulk.buffer[1..]);
    let meta = ColMetaData::decode(&mut cursor).expect("parseable");

    // Columns 0 and 1 carry the caller-supplied collation.
    for (idx, kind) in ["VARCHAR", "NVARCHAR"].into_iter().enumerate() {
        let decoded = meta.columns[idx]
            .type_info
            .collation
            .as_ref()
            .unwrap_or_else(|| panic!("{kind} has collation"));
        assert_eq!(decoded.lcid, chinese.lcid, "{kind} caller LCID");
        assert_eq!(decoded.sort_id, chinese.sort_id, "{kind} caller sort_id");
    }

    // Column 2 was given no collation and must carry the default bytes.
    let fallback = meta.columns[2]
        .type_info
        .collation
        .as_ref()
        .expect("VARCHAR has default collation");
    assert_eq!(fallback.to_bytes(), [0x09, 0x04, 0xD0, 0x00, 0x34]);
}
2034
#[test]
/// `(MAX)` variants of the variable-length types parse to their usual type
/// id with the 0xFFFF length marker; a sized NVARCHAR keeps a concrete
/// length instead.
fn test_parse_sql_type_max() {
    // (declared type, expected type id, expected max length)
    let cases: [(&str, u8, u32); 4] = [
        ("NVARCHAR(MAX)", 0xE7, 0xFFFF),
        ("VARBINARY(MAX)", 0xA5, 0xFFFF),
        ("VARCHAR(MAX)", 0xA7, 0xFFFF),
        // 100 characters reported as 200 — presumably a byte length at two
        // bytes per UTF-16 unit (TODO confirm in `parse_sql_type`).
        ("NVARCHAR(100)", 0xE7, 200),
    ];

    for (declared, want_id, want_len) in cases {
        let (type_id, len, _, _) = parse_sql_type(declared);
        assert_eq!(type_id, want_id);
        assert_eq!(len, Some(want_len));
    }
}
2057}