Skip to main content

mssql_client/
bulk.rs

1//! Bulk Copy Protocol (BCP) support.
2//!
3//! This module provides first-class support for bulk insert operations using
4//! the TDS Bulk Load protocol (packet type 0x07). BCP is significantly more
5//! efficient than individual INSERT statements for loading large amounts of data.
6//!
7//! ## Performance Benefits
8//!
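//! - Minimal logging (when the database uses the simple or bulk-logged recovery model and the load meets SQL Server's other prerequisites, such as `TABLOCK`)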
10//! - Batch commits reduce lock contention
11//! - Direct data streaming without SQL parsing overhead
12//! - Optional table lock for maximum throughput
13//!
14//! ## Usage
15//!
16//! ```rust,ignore
17//! use mssql_client::{Client, BulkInsert, BulkOptions};
18//!
19//! let mut bulk = client
20//!     .bulk_insert("dbo.Users")
21//!     .with_columns(&["id", "name", "email"])
22//!     .with_options(BulkOptions {
23//!         batch_size: 1000,
24//!         check_constraints: true,
25//!         fire_triggers: false,
26//!         keep_nulls: true,
//!         table_lock: true,
//!         ..Default::default()
//!     })
29//!     .build()
30//!     .await?;
31//!
32//! // Send rows
33//! for user in users {
34//!     bulk.send_row(&[&user.id, &user.name, &user.email]).await?;
35//! }
36//!
37//! let result = bulk.finish().await?;
38//! println!("Inserted {} rows", result.rows_affected);
39//! ```
40//!
41//! ## Implementation Notes
42//!
43//! The bulk load protocol uses:
44//! - Packet type 0x07 (BulkLoad)
45//! - COLMETADATA token describing column structure
46//! - ROW tokens containing actual data
47//! - DONE token signaling completion
48//!
49//! Per MS-TDS specification, the row data format matches the server output format
50//! (same as SELECT results) rather than storage format.
51
52use bytes::{BufMut, BytesMut};
53use once_cell::sync::Lazy;
54use regex::Regex;
55use std::sync::Arc;
56
57use mssql_types::{SqlValue, ToSql, TypeError};
58use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
59use tds_protocol::token::{DoneStatus, TokenType};
60
61use crate::error::Error;
62
/// Tuning knobs for a bulk insert.
///
/// Most fields become hints in the `WITH (...)` clause of the generated
/// `INSERT BULK` statement. Together they trade throughput against logging,
/// locking and constraint enforcement.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Rows per batch; `0` (the default) sends everything as one batch.
    ///
    /// When non-zero it is also emitted as the `ROWS_PER_BATCH` hint.
    /// Small batches bound memory use at the cost of per-batch overhead;
    /// large batches are cheaper per row but hold locks longer.
    pub batch_size: usize,

    /// Enforce CHECK/FOREIGN KEY constraints while loading (`CHECK_CONSTRAINTS`).
    ///
    /// Default: `true`.
    pub check_constraints: bool,

    /// Run INSERT triggers on the target table (`FIRE_TRIGGERS`).
    ///
    /// Default: `false`, which is faster.
    pub fire_triggers: bool,

    /// Store NULLs as NULL rather than substituting column defaults (`KEEP_NULLS`).
    ///
    /// Default: `true`.
    pub keep_nulls: bool,

    /// Take a table-level lock for the whole load (`TABLOCK`).
    ///
    /// This avoids row/page lock escalation and is usually faster, but it
    /// blocks every other reader and writer of the table.
    /// Default: `false`.
    pub table_lock: bool,

    /// Columns the incoming rows are already sorted by (`ORDER(...)`).
    ///
    /// If the data matches the clustered index order, the server can skip
    /// its own sort. Default: `None`.
    pub order_hint: Option<Vec<String>>,

    /// Errors tolerated before the load is aborted.
    ///
    /// Default: `0` (abort on the first error).
    /// NOTE(review): not referenced by `build_insert_bulk_statement`; confirm
    /// where it is applied.
    pub max_errors: u32,
}

impl Default for BulkOptions {
    /// Single batch, constraints checked, NULLs kept, no triggers, no table
    /// lock, no order hint, zero tolerated errors.
    fn default() -> Self {
        BulkOptions {
            max_errors: 0,
            order_hint: None,
            table_lock: false,
            keep_nulls: true,
            fire_triggers: false,
            check_constraints: true,
            batch_size: 0,
        }
    }
}
124
125/// Column definition for bulk insert.
126#[derive(Debug, Clone)]
127pub struct BulkColumn {
128    /// Column name.
129    pub name: String,
130    /// SQL Server type (e.g., "INT", "NVARCHAR(100)").
131    pub sql_type: String,
132    /// Whether the column allows NULL values.
133    pub nullable: bool,
134    /// Column ordinal (0-based).
135    pub ordinal: usize,
136    /// TDS type ID.
137    type_id: u8,
138    /// Maximum length for variable-length types.
139    max_length: Option<u32>,
140    /// Precision for decimal types.
141    precision: Option<u8>,
142    /// Scale for decimal types.
143    scale: Option<u8>,
144}
145
146impl BulkColumn {
147    /// Create a new bulk column definition.
148    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Self {
149        let sql_type_str: String = sql_type.into();
150        let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
151
152        Self {
153            name: name.into(),
154            sql_type: sql_type_str,
155            nullable: true,
156            ordinal,
157            type_id,
158            max_length,
159            precision,
160            scale,
161        }
162    }
163
164    /// Set whether this column allows NULL values.
165    #[must_use]
166    pub fn with_nullable(mut self, nullable: bool) -> Self {
167        self.nullable = nullable;
168        self
169    }
170}
171
/// Map a T-SQL type declaration to TDS `TYPE_INFO` details.
///
/// Returns `(type_id, max_length, precision, scale)`:
/// - `type_id`: the TDS type byte written to COLMETADATA;
/// - `max_length`: byte length for variable-length types. `0xFFFF` marks a
///   `MAX` (PLP) column; NVARCHAR/NCHAR lengths are doubled to bytes;
/// - `precision` / `scale`: DECIMAL/NUMERIC precision and scale, or the
///   fractional-second scale for TIME/DATETIME2/DATETIMEOFFSET.
///
/// Matching is case-insensitive and ignores whitespace around the base type
/// and its parameters (`"nvarchar ( 100 )"` is accepted).
///
/// Type parameters (e.g. the "100" in `VARCHAR(100)`) are parsed leniently.
/// A malformed parameter falls back to the type's SQL Server default, such as
/// 8000 for VARCHAR or 4000 characters for NVARCHAR. This is intentional:
/// bulk-insert column definitions come from user code, and defaulting to the
/// maximum length is safer than rejecting the operation when the base type
/// is valid. Unknown base types are treated as `NVARCHAR(4000)`.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();

    // Split "BASE(PARAMS)" into its two parts, ignoring surrounding whitespace.
    let (base, params) = match upper.find('(') {
        Some(paren_pos) => {
            let base = upper[..paren_pos].trim_end();
            let params = upper[paren_pos + 1..]
                .trim()
                .trim_end_matches(')')
                .trim();
            (base, Some(params))
        }
        None => (upper.as_str(), None),
    };

    // Byte length for CHAR/BINARY families: "MAX" becomes the 0xFFFF PLP marker.
    let byte_length = |default: u32| -> u32 {
        match params {
            Some("MAX") => 0xFFFF,
            Some(p) => p.parse().unwrap_or(default),
            None => default,
        }
    };
    // Fractional-seconds scale; SQL Server's default is 7.
    let time_scale = || params.and_then(|p| p.parse::<u8>().ok()).unwrap_or(7);

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        // SQL Server stores FLOAT(n) with 1 <= n <= 24 as REAL (4 bytes);
        // every other FLOAT is the 8-byte double.
        "FLOAT" => match params.and_then(|p| p.parse::<u8>().ok()) {
            Some(1..=24) => (0x3B, None, None, None),
            _ => (0x3E, None, None, None),
        },
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        // DATETIM4TYPE. 0x3F would be the legacy NUMERICTYPE.
        "SMALLDATETIME" => (0x3A, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_length(8000)), None, None),
        "NVARCHAR" | "NCHAR" => match params {
            // MAX uses the 0xFFFF marker as-is (not doubled).
            Some("MAX") => (0xE7, Some(0xFFFF), None, None),
            _ => {
                // Lengths are in characters; UTF-16 needs two bytes each.
                // Saturate so an absurd length cannot overflow and panic.
                let chars = params
                    .and_then(|p| p.parse::<u32>().ok())
                    .unwrap_or(4000);
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        },
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_length(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            // "p" or "p, s"; missing or malformed parts use SQL Server's 18 / 0.
            let mut parts = params.unwrap_or("").split(',').map(str::trim);
            let precision = parts.next().and_then(|s| s.parse().ok()).unwrap_or(18);
            let scale = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        _ => (0xE7, Some(8000), None, None), // Default to NVARCHAR(4000)
    }
}
272
/// Summary of a completed bulk insert.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Rows inserted across every batch.
    pub rows_affected: u64,
    /// How many batches were committed.
    pub batches_committed: u32,
    /// `true` when at least one error occurred during the load.
    pub has_errors: bool,
}
283
284/// Builder for configuring a bulk insert operation.
285#[derive(Debug)]
286pub struct BulkInsertBuilder {
287    table_name: String,
288    columns: Vec<BulkColumn>,
289    options: BulkOptions,
290}
291
292impl BulkInsertBuilder {
293    /// Create a new bulk insert builder for the specified table.
294    pub fn new<S: Into<String>>(table_name: S) -> Self {
295        Self {
296            table_name: table_name.into(),
297            columns: Vec::new(),
298            options: BulkOptions::default(),
299        }
300    }
301
302    /// Specify the columns to insert.
303    ///
304    /// Columns will be queried from the server if not specified,
305    /// but providing them explicitly is more efficient.
306    #[must_use]
307    pub fn with_columns(mut self, column_names: &[&str]) -> Self {
308        self.columns = column_names
309            .iter()
310            .enumerate()
311            .map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
312            .collect();
313        self
314    }
315
316    /// Specify columns with full type information.
317    #[must_use]
318    pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
319        self.columns = columns;
320        self
321    }
322
323    /// Set bulk insert options.
324    #[must_use]
325    pub fn with_options(mut self, options: BulkOptions) -> Self {
326        self.options = options;
327        self
328    }
329
330    /// Set the batch size.
331    #[must_use]
332    pub fn batch_size(mut self, size: usize) -> Self {
333        self.options.batch_size = size;
334        self
335    }
336
337    /// Enable or disable table lock.
338    #[must_use]
339    pub fn table_lock(mut self, enabled: bool) -> Self {
340        self.options.table_lock = enabled;
341        self
342    }
343
344    /// Enable or disable trigger firing.
345    #[must_use]
346    pub fn fire_triggers(mut self, enabled: bool) -> Self {
347        self.options.fire_triggers = enabled;
348        self
349    }
350
351    /// Get the table name.
352    pub fn table_name(&self) -> &str {
353        &self.table_name
354    }
355
356    /// Get the columns.
357    pub fn columns(&self) -> &[BulkColumn] {
358        &self.columns
359    }
360
361    /// Get the options.
362    pub fn options(&self) -> &BulkOptions {
363        &self.options
364    }
365
366    /// Build the INSERT BULK SQL statement.
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if the table name or any column name fails identifier
371    /// validation, preventing SQL injection.
372    pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
373        // Validate table name (may be schema-qualified: dbo.Users, catalog.schema.table)
374        crate::validation::validate_qualified_identifier(&self.table_name)?;
375
376        // Validate column names
377        for col in &self.columns {
378            crate::validation::validate_identifier(&col.name)?;
379        }
380
381        let mut sql = format!("INSERT BULK {}", self.table_name);
382
383        // Add column definitions
384        if !self.columns.is_empty() {
385            sql.push_str(" (");
386            let cols: Vec<String> = self
387                .columns
388                .iter()
389                .map(|c| {
390                    // Validate sql_type to prevent SQL injection: only allow
391                    // alphanumerics, parentheses (for length/precision), commas,
392                    // spaces, and the MAX keyword — which covers all valid T-SQL
393                    // type specifiers like "NVARCHAR(100)", "DECIMAL(18, 2)",
394                    // "VARBINARY(MAX)", etc.
395                    validate_sql_type(&c.sql_type)?;
396                    Ok(format!("{} {}", c.name, c.sql_type))
397                })
398                .collect::<Result<Vec<_>, Error>>()?;
399            sql.push_str(&cols.join(", "));
400            sql.push(')');
401        }
402
403        // Add WITH clause for options
404        let mut hints: Vec<String> = Vec::new();
405
406        if self.options.check_constraints {
407            hints.push("CHECK_CONSTRAINTS".to_string());
408        }
409        if self.options.fire_triggers {
410            hints.push("FIRE_TRIGGERS".to_string());
411        }
412        if self.options.keep_nulls {
413            hints.push("KEEP_NULLS".to_string());
414        }
415        if self.options.table_lock {
416            hints.push("TABLOCK".to_string());
417        }
418        if self.options.batch_size > 0 {
419            hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
420        }
421
422        if let Some(ref order) = self.options.order_hint {
423            // Validate order hint column names
424            for col_name in order {
425                crate::validation::validate_identifier(col_name)?;
426            }
427            hints.push(format!("ORDER({})", order.join(", ")));
428        }
429
430        if !hints.is_empty() {
431            sql.push_str(" WITH (");
432            sql.push_str(&hints.join(", "));
433            sql.push(')');
434        }
435
436        Ok(sql)
437    }
438}
439
440/// Validate a SQL type specifier to prevent SQL injection.
441///
442/// Allows only characters that can appear in valid T-SQL type declarations:
443/// letters, digits, parentheses, commas, spaces, and periods.
444/// Examples: "INT", "NVARCHAR(100)", "DECIMAL(18, 2)", "VARBINARY(MAX)".
445fn validate_sql_type(type_str: &str) -> Result<(), Error> {
446    #[allow(clippy::expect_used)] // Static regex compilation with known-valid pattern
447    static SQL_TYPE_RE: Lazy<Regex> =
448        Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
449
450    if type_str.is_empty() {
451        return Err(Error::Config("SQL type cannot be empty".into()));
452    }
453
454    if !SQL_TYPE_RE.is_match(type_str) {
455        return Err(Error::Config(format!(
456            "invalid SQL type '{type_str}': contains disallowed characters"
457        )));
458    }
459
460    Ok(())
461}
462
463/// Active bulk insert operation.
464///
465/// This struct manages the streaming of row data to the server.
466/// Call `send_row()` for each row, then `finish()` to complete.
467pub struct BulkInsert {
468    /// Column metadata.
469    columns: Arc<[BulkColumn]>,
470    /// Buffer for accumulating rows.
471    buffer: BytesMut,
472    /// Rows in current batch.
473    rows_in_batch: usize,
474    /// Total rows sent.
475    total_rows: u64,
476    /// Batch size (0 = single batch).
477    batch_size: usize,
478    /// Number of batches committed.
479    batches_committed: u32,
480    /// Packet ID counter.
481    packet_id: u8,
482}
483
484impl BulkInsert {
485    /// Create a new bulk insert operation.
486    pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
487        let mut bulk = Self {
488            columns: columns.into(),
489            buffer: BytesMut::with_capacity(64 * 1024), // 64KB initial buffer
490            rows_in_batch: 0,
491            total_rows: 0,
492            batch_size,
493            batches_committed: 0,
494            packet_id: 1,
495        };
496
497        // Write COLMETADATA token
498        bulk.write_colmetadata();
499
500        bulk
501    }
502
503    /// Write the COLMETADATA token to the buffer.
504    fn write_colmetadata(&mut self) {
505        let buf = &mut self.buffer;
506
507        // Token type
508        buf.put_u8(TokenType::ColMetaData as u8);
509
510        // Column count
511        buf.put_u16_le(self.columns.len() as u16);
512
513        for col in self.columns.iter() {
514            // User type (always 0 for basic types)
515            buf.put_u32_le(0);
516
517            // Flags: Nullable (bit 0) | CaseSen (bit 1) | Updateable (bits 2-3) | etc.
518            let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
519            buf.put_u16_le(flags);
520
521            // Type info
522            buf.put_u8(col.type_id);
523
524            // Type-specific length/precision/scale
525            match col.type_id {
526                // Fixed-length types - no additional info needed
527                0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}
528
529                // Variable-length string/binary types
530                0xE7 | 0xA7 | 0xA5 | 0xAD => {
531                    // Max length (2 bytes for normal, 4 bytes for MAX)
532                    let max_len = col.max_length.unwrap_or(8000);
533                    if max_len == 0xFFFF {
534                        buf.put_u16_le(0xFFFF);
535                    } else {
536                        buf.put_u16_le(max_len as u16);
537                    }
538
539                    // Collation for string types (5 bytes)
540                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
541                        // Default collation (Latin1_General_CI_AS)
542                        buf.put_u32_le(0x0409_0904); // LCID + flags
543                        buf.put_u8(52); // Sort ID
544                    }
545                }
546
547                // Decimal/Numeric
548                0x6C | 0x6A => {
549                    // Length (calculated from precision)
550                    let precision = col.precision.unwrap_or(18);
551                    let len = decimal_byte_length(precision);
552                    buf.put_u8(len);
553                    buf.put_u8(precision);
554                    buf.put_u8(col.scale.unwrap_or(0));
555                }
556
557                // Time-based with scale
558                0x29..=0x2B => {
559                    buf.put_u8(col.scale.unwrap_or(7));
560                }
561
562                // GUID
563                0x24 => {
564                    buf.put_u8(16);
565                }
566
567                // Other types - write max length if present
568                _ => {
569                    if let Some(len) = col.max_length {
570                        if len <= 0xFFFF {
571                            buf.put_u16_le(len as u16);
572                        }
573                    }
574                }
575            }
576
577            // Column name (B_VARCHAR format: 1-byte length prefix)
578            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
579            buf.put_u8(name_utf16.len() as u8);
580            for code_unit in name_utf16 {
581                buf.put_u16_le(code_unit);
582            }
583        }
584    }
585
586    /// Send a row of data.
587    ///
588    /// The values must match the column order and types specified
589    /// when creating the bulk insert.
590    ///
591    /// # Errors
592    ///
593    /// Returns an error if:
594    /// - Wrong number of values provided
595    /// - A value cannot be converted to the expected type
596    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
597        if values.len() != self.columns.len() {
598            return Err(Error::Config(format!(
599                "expected {} values, got {}",
600                self.columns.len(),
601                values.len()
602            )));
603        }
604
605        // Convert all values to SqlValue
606        let sql_values: Result<Vec<SqlValue>, TypeError> =
607            values.iter().map(|v| v.to_sql()).collect();
608        let sql_values = sql_values.map_err(Error::from)?;
609
610        self.write_row(&sql_values)?;
611
612        self.rows_in_batch += 1;
613        self.total_rows += 1;
614
615        Ok(())
616    }
617
618    /// Send a row of pre-converted SQL values.
619    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
620        if values.len() != self.columns.len() {
621            return Err(Error::Config(format!(
622                "expected {} values, got {}",
623                self.columns.len(),
624                values.len()
625            )));
626        }
627
628        self.write_row(values)?;
629
630        self.rows_in_batch += 1;
631        self.total_rows += 1;
632
633        Ok(())
634    }
635
636    /// Write a ROW token to the buffer.
637    fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
638        // ROW token type
639        self.buffer.put_u8(TokenType::Row as u8);
640
641        // Collect column info needed for encoding to avoid borrow conflict
642        let columns: Vec<_> = self.columns.iter().cloned().collect();
643
644        // Write each column value
645        for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
646            self.encode_column_value(col, value)
647                .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
648        }
649
650        Ok(())
651    }
652
653    /// Encode a column value according to its type.
654    fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
655        let buf = &mut self.buffer;
656
657        // Check if this column uses PLP (Partially Length-Prefixed) encoding
658        // MAX types (max_length == 0xFFFF) use PLP format
659        let is_plp_type =
660            col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
661
662        match value {
663            SqlValue::Null => {
664                // NULL encoding depends on type
665                match col.type_id {
666                    // Variable-length types
667                    0xE7 | 0xA7 | 0xA5 | 0xAD => {
668                        if is_plp_type {
669                            // PLP NULL: 0xFFFFFFFFFFFFFFFF
670                            buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
671                        } else {
672                            // Standard NULL: 0xFFFF length marker
673                            buf.put_u16_le(0xFFFF);
674                        }
675                    }
676                    // Nullable fixed types use 0 length
677                    0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
678                        buf.put_u8(0);
679                    }
680                    // Fixed types without nullable variant
681                    _ => {
682                        if col.nullable {
683                            buf.put_u8(0);
684                        } else {
685                            return Err(TypeError::UnexpectedNull);
686                        }
687                    }
688                }
689            }
690
691            SqlValue::Bool(v) => {
692                buf.put_u8(1); // Length
693                buf.put_u8(if *v { 1 } else { 0 });
694            }
695
696            SqlValue::TinyInt(v) => {
697                buf.put_u8(1); // Length
698                buf.put_u8(*v);
699            }
700
701            SqlValue::SmallInt(v) => {
702                buf.put_u8(2); // Length
703                buf.put_i16_le(*v);
704            }
705
706            SqlValue::Int(v) => {
707                buf.put_u8(4); // Length
708                buf.put_i32_le(*v);
709            }
710
711            SqlValue::BigInt(v) => {
712                buf.put_u8(8); // Length
713                buf.put_i64_le(*v);
714            }
715
716            SqlValue::Float(v) => {
717                buf.put_u8(4); // Length
718                buf.put_f32_le(*v);
719            }
720
721            SqlValue::Double(v) => {
722                buf.put_u8(8); // Length
723                buf.put_f64_le(*v);
724            }
725
726            SqlValue::String(s) => {
727                // UTF-16LE encoding for NVARCHAR
728                let utf16: Vec<u16> = s.encode_utf16().collect();
729                let byte_len = utf16.len() * 2;
730
731                if is_plp_type {
732                    // PLP format for MAX types - supports unlimited size
733                    // Send as a single chunk for simplicity
734                    encode_plp_string(&utf16, buf);
735                } else if byte_len > 0xFFFF {
736                    // Non-MAX column can't hold this much data
737                    return Err(TypeError::BufferTooSmall {
738                        needed: byte_len,
739                        available: 0xFFFF,
740                    });
741                } else {
742                    // Standard encoding with 2-byte length prefix
743                    buf.put_u16_le(byte_len as u16);
744                    for code_unit in utf16 {
745                        buf.put_u16_le(code_unit);
746                    }
747                }
748            }
749
750            SqlValue::Binary(b) => {
751                if is_plp_type {
752                    // PLP format for MAX types - supports unlimited size
753                    encode_plp_binary(b, buf);
754                } else if b.len() > 0xFFFF {
755                    // Non-MAX column can't hold this much data
756                    return Err(TypeError::BufferTooSmall {
757                        needed: b.len(),
758                        available: 0xFFFF,
759                    });
760                } else {
761                    // Standard encoding with 2-byte length prefix
762                    buf.put_u16_le(b.len() as u16);
763                    buf.put_slice(b);
764                }
765            }
766
767            // Feature-gated types - use mssql_types::encode module
768            #[cfg(feature = "decimal")]
769            SqlValue::Decimal(d) => {
770                let precision = col.precision.unwrap_or(18);
771                let len = decimal_byte_length(precision);
772                buf.put_u8(len);
773
774                // Sign: 0 = negative, 1 = positive
775                buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
776
777                // Mantissa as unsigned 128-bit integer
778                let mantissa = d.mantissa().unsigned_abs();
779                let mantissa_bytes = mantissa.to_le_bytes();
780                buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
781            }
782
783            #[cfg(feature = "uuid")]
784            SqlValue::Uuid(u) => {
785                buf.put_u8(16); // Length
786                // Use mssql_types encode function
787                mssql_types::encode::encode_uuid(*u, buf);
788            }
789
790            #[cfg(feature = "chrono")]
791            SqlValue::Date(d) => {
792                buf.put_u8(3); // Length
793                mssql_types::encode::encode_date(*d, buf);
794            }
795
796            #[cfg(feature = "chrono")]
797            SqlValue::Time(t) => {
798                let scale = col.scale.unwrap_or(7);
799                let len = time_byte_length(scale);
800                buf.put_u8(len);
801                // Encode time with proper scale handling
802                encode_time_with_scale(*t, scale, buf);
803            }
804
805            #[cfg(feature = "chrono")]
806            SqlValue::DateTime(dt) => {
807                let scale = col.scale.unwrap_or(7);
808                let time_len = time_byte_length(scale);
809                let total_len = time_len + 3;
810                buf.put_u8(total_len);
811                // Encode time then date
812                encode_time_with_scale(dt.time(), scale, buf);
813                mssql_types::encode::encode_date(dt.date(), buf);
814            }
815
816            #[cfg(feature = "chrono")]
817            SqlValue::DateTimeOffset(dto) => {
818                let scale = col.scale.unwrap_or(7);
819                let time_len = time_byte_length(scale);
820                let total_len = time_len + 3 + 2;
821                buf.put_u8(total_len);
822                // Use mssql_types encode
823                encode_time_with_scale(dto.time(), scale, buf);
824                mssql_types::encode::encode_date(dto.date_naive(), buf);
825                // Timezone offset in minutes
826                use chrono::Offset;
827                let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
828                buf.put_i16_le(offset_minutes);
829            }
830
831            #[cfg(feature = "json")]
832            SqlValue::Json(j) => {
833                let s = j.to_string();
834                encode_nvarchar_value(&s, buf)?;
835            }
836
837            SqlValue::Xml(x) => {
838                encode_nvarchar_value(x, buf)?;
839            }
840
841            SqlValue::Tvp(_) => {
842                // TVPs are not valid in bulk copy operations - they're for RPC parameters only
843                return Err(TypeError::UnsupportedConversion {
844                    from: "TVP".to_string(),
845                    to: "bulk copy value",
846                });
847            }
848            // Handle future SqlValue variants
849            _ => {
850                return Err(TypeError::UnsupportedConversion {
851                    from: value.type_name().to_string(),
852                    to: "bulk copy value",
853                });
854            }
855        }
856
857        Ok(())
858    }
859}
860
861/// Encode a string as NVARCHAR with length prefix.
862fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
863    let utf16: Vec<u16> = s.encode_utf16().collect();
864    let byte_len = utf16.len() * 2;
865
866    if byte_len > 0xFFFF {
867        return Err(TypeError::BufferTooSmall {
868            needed: byte_len,
869            available: 0xFFFF,
870        });
871    }
872
873    buf.put_u16_le(byte_len as u16);
874    for code_unit in utf16 {
875        buf.put_u16_le(code_unit);
876    }
877    Ok(())
878}
879
880/// Encode a UTF-16 string using PLP (Partially Length-Prefixed) format.
881///
882/// PLP format (per MS-TDS specification):
883/// - 8 bytes: total length in bytes (little-endian)
884/// - Chunks: 4-byte chunk length + data, repeated
885/// - Terminator: 4 bytes of zero
886///
887/// For simplicity, we send the entire value as a single chunk.
888/// This is efficient for bulk operations where we already have the complete data.
889fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
890    let byte_len = utf16.len() * 2;
891
892    // Total length (8 bytes)
893    buf.put_u64_le(byte_len as u64);
894
895    if byte_len > 0 {
896        // Single chunk: length (4 bytes) + data
897        buf.put_u32_le(byte_len as u32);
898        for code_unit in utf16 {
899            buf.put_u16_le(*code_unit);
900        }
901    }
902
903    // Terminator chunk (length = 0)
904    buf.put_u32_le(0);
905}
906
907/// Encode binary data using PLP (Partially Length-Prefixed) format.
908///
909/// PLP format (per MS-TDS specification):
910/// - 8 bytes: total length in bytes (little-endian)
911/// - Chunks: 4-byte chunk length + data, repeated
912/// - Terminator: 4 bytes of zero
913///
914/// For simplicity, we send the entire value as a single chunk.
915fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
916    // Total length (8 bytes)
917    buf.put_u64_le(data.len() as u64);
918
919    if !data.is_empty() {
920        // Single chunk: length (4 bytes) + data
921        buf.put_u32_le(data.len() as u32);
922        buf.put_slice(data);
923    }
924
925    // Terminator chunk (length = 0)
926    buf.put_u32_le(0);
927}
928
/// Encode `time` as a TIME value with the given fractional-second `scale` (for bulk copy).
///
/// The value written is the number of `10^-scale`-second intervals since
/// midnight. It is stored little-endian in `time_byte_length(scale)` bytes.
///
/// chrono represents a leap second as a `nanosecond()` value of 1_000_000_000
/// or more. That overflow is clamped to the last nanosecond of the same second,
/// so a leap second at 23:59:59 cannot produce an interval count past midnight.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let subsec_nanos = u64::from(time.nanosecond().min(999_999_999));
    let nanos = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000 + subsec_nanos;
    let intervals = nanos / time_scale_divisor(scale);
    let len = usize::from(time_byte_length(scale));

    // Low `len` bytes of the little-endian representation (len is at most 5).
    buf.put_slice(&intervals.to_le_bytes()[..len]);
}
942
943impl BulkInsert {
944    /// Write the DONE token signaling completion.
945    fn write_done(&mut self) {
946        let buf = &mut self.buffer;
947
948        buf.put_u8(TokenType::Done as u8);
949
950        // Status: FINAL (0x00) | COUNT (0x10)
951        let status = DoneStatus::from_bits(0x0010); // DONE_COUNT
952        buf.put_u16_le(status.to_bits());
953
954        // Current command (0 for bulk load)
955        buf.put_u16_le(0);
956
957        // Row count
958        buf.put_u64_le(self.total_rows);
959    }
960
961    /// Get the buffered data as packets ready to send.
962    ///
963    /// Returns a vector of complete TDS packets with BulkLoad packet type (0x07).
964    pub fn take_packets(&mut self) -> Vec<BytesMut> {
965        const MAX_PACKET_SIZE: usize = 4096;
966        const HEADER_SIZE: usize = 8;
967        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
968
969        let data = self.buffer.split();
970        let mut packets = Vec::new();
971        let mut offset = 0;
972
973        while offset < data.len() {
974            let remaining = data.len() - offset;
975            let payload_size = remaining.min(MAX_PAYLOAD);
976            let is_last = offset + payload_size >= data.len();
977
978            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
979
980            // Write packet header
981            let header = PacketHeader {
982                packet_type: PacketType::BulkLoad,
983                status: if is_last {
984                    PacketStatus::END_OF_MESSAGE
985                } else {
986                    PacketStatus::NORMAL
987                },
988                length: (HEADER_SIZE + payload_size) as u16,
989                spid: 0,
990                packet_id: self.packet_id,
991                window: 0,
992            };
993
994            header.encode(&mut packet);
995
996            // Write payload
997            packet.put_slice(&data[offset..offset + payload_size]);
998
999            packets.push(packet);
1000            offset += payload_size;
1001            self.packet_id = self.packet_id.wrapping_add(1);
1002        }
1003
1004        packets
1005    }
1006
1007    /// Get total rows sent so far.
1008    pub fn total_rows(&self) -> u64 {
1009        self.total_rows
1010    }
1011
1012    /// Get rows in current batch.
1013    pub fn rows_in_batch(&self) -> usize {
1014        self.rows_in_batch
1015    }
1016
1017    /// Check if a batch flush is needed.
1018    pub fn should_flush(&self) -> bool {
1019        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1020    }
1021
1022    /// Prepare for finishing the bulk operation.
1023    /// Writes the DONE token and returns final packets.
1024    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1025        self.write_done();
1026        self.take_packets()
1027    }
1028
1029    /// Create a result from the current state.
1030    pub fn result(&self) -> BulkInsertResult {
1031        BulkInsertResult {
1032            rows_affected: self.total_rows,
1033            batches_committed: self.batches_committed,
1034            has_errors: false,
1035        }
1036    }
1037}
1038
/// Wire length in bytes of a DECIMAL/NUMERIC value with the given `precision`.
///
/// Precisions 1-9 take 5 bytes, 10-19 take 9, and 20-28 take 13. Everything
/// else takes the maximum of 17: that covers 29-38 and also out-of-range
/// inputs such as 0 or anything above 38.
fn decimal_byte_length(precision: u8) -> u8 {
    if (1..=9).contains(&precision) {
        5
    } else if (10..=19).contains(&precision) {
        9
    } else if (20..=28).contains(&precision) {
        13
    } else {
        17
    }
}
1049
/// Number of bytes a TIME value occupies for the given fractional-second `scale`.
///
/// Scales 0-2 use 3 bytes, 3-4 use 4 bytes, and every larger scale
/// (including out-of-range values) uses 5 bytes.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1060
/// Nanoseconds per TIME interval at the given fractional-second `scale`: `10^(9 - scale)`.
///
/// Scales above 7 are treated as 7, giving a divisor of 100.
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = u32::from(scale.min(7));
    10_u64.pow(9 - effective_scale)
}
1076
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_bulk_options_default() {
        let opts = BulkOptions::default();
        assert_eq!(opts.batch_size, 0);
        assert!(opts.check_constraints);
        assert!(!opts.fire_triggers);
        assert!(opts.keep_nulls);
        assert!(!opts.table_lock);
    }

    #[test]
    fn test_bulk_column_creation() {
        let col = BulkColumn::new("id", "INT", 0);
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, 0x38);
        assert!(col.nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x38);
        assert!(len.is_none());

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200)); // UTF-16 doubles

        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ])
            .table_lock(true);

        let sql = builder.build_insert_bulk_statement().unwrap();
        assert!(sql.contains("INSERT BULK dbo.Users"));
        assert!(sql.contains("TABLOCK"));
    }

    #[test]
    fn test_bulk_insert_rejects_injection() {
        let builder = BulkInsertBuilder::new("table;DROP TABLE users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);

        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_validates_column_names() {
        let builder = BulkInsertBuilder::new("Users").with_typed_columns(vec![BulkColumn::new(
            "col;DROP TABLE x",
            "INT",
            0,
        )]);

        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_accepts_qualified_names() {
        let builder = BulkInsertBuilder::new("catalog.dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);

        assert!(builder.build_insert_bulk_statement().is_ok());
    }

    #[test]
    fn test_bulk_insert_creation() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];

        let bulk = BulkInsert::new(columns, 1000);
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
    }

    #[test]
    fn test_take_packets_splits_payload() {
        let columns = vec![BulkColumn::new("id", "INT", 0)];
        let mut bulk = BulkInsert::new(columns, 1000);

        // Start from a known-empty buffer regardless of what `new` writes.
        bulk.buffer.clear();
        assert!(bulk.take_packets().is_empty());

        // 5000 payload bytes: one full 4096-byte packet (8 + 4088) plus the rest.
        bulk.buffer.put_slice(&[0xAB; 5000]);
        let packets = bulk.take_packets();
        assert!(bulk.buffer.is_empty());
        assert_eq!(packets.len(), 2);
        assert_eq!(packets[0].len(), 4096);
        assert_eq!(packets[1].len(), 8 + (5000 - 4088));

        // Header: BulkLoad type, END_OF_MESSAGE only on the last packet,
        // big-endian total length.
        assert_eq!(packets[0][0], 0x07);
        assert_eq!(packets[1][0], 0x07);
        assert_eq!(packets[0][1] & 0x01, 0x00);
        assert_eq!(packets[1][1] & 0x01, 0x01);
        assert_eq!(u16::from_be_bytes([packets[0][2], packets[0][3]]), 4096);
        assert_eq!(
            u16::from_be_bytes([packets[1][2], packets[1][3]]) as usize,
            packets[1].len()
        );

        // Payload bytes survive unchanged.
        assert!(packets.iter().all(|p| p[8..].iter().all(|&b| b == 0xAB)));
    }

    #[test]
    fn test_decimal_byte_length() {
        assert_eq!(decimal_byte_length(5), 5);
        assert_eq!(decimal_byte_length(15), 9);
        assert_eq!(decimal_byte_length(25), 13);
        assert_eq!(decimal_byte_length(35), 17);
    }

    #[test]
    fn test_decimal_byte_length_boundaries() {
        assert_eq!(decimal_byte_length(1), 5);
        assert_eq!(decimal_byte_length(9), 5);
        assert_eq!(decimal_byte_length(10), 9);
        assert_eq!(decimal_byte_length(19), 9);
        assert_eq!(decimal_byte_length(20), 13);
        assert_eq!(decimal_byte_length(28), 13);
        assert_eq!(decimal_byte_length(29), 17);
        assert_eq!(decimal_byte_length(38), 17);
        // Out-of-range precisions fall back to the maximum width.
        assert_eq!(decimal_byte_length(0), 17);
        assert_eq!(decimal_byte_length(39), 17);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        assert_eq!(time_byte_length(0), 3);
        assert_eq!(time_byte_length(3), 4);
        assert_eq!(time_byte_length(7), 5);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_scale_divisor() {
        assert_eq!(time_scale_divisor(0), 1_000_000_000);
        assert_eq!(time_scale_divisor(3), 1_000_000);
        assert_eq!(time_scale_divisor(7), 100);
        // Scales beyond 7 behave like 7.
        assert_eq!(time_scale_divisor(8), 100);
    }

    #[test]
    fn test_encode_nvarchar_value() {
        let mut buf = BytesMut::new();
        encode_nvarchar_value("Hi", &mut buf).unwrap();
        // 2-byte little-endian byte count, then UTF-16LE data.
        assert_eq!(&buf[..], &[4, 0, b'H', 0, b'i', 0]);

        // A non-BMP character becomes a surrogate pair (two code units).
        let mut buf = BytesMut::new();
        encode_nvarchar_value("\u{1F600}", &mut buf).unwrap();
        assert_eq!(&buf[..], &[4, 0, 0x3D, 0xD8, 0x00, 0xDE]);
    }

    #[test]
    fn test_encode_nvarchar_value_length_limit() {
        // 32_767 code units = 65_534 bytes: the largest value that fits.
        let mut buf = BytesMut::new();
        encode_nvarchar_value(&"a".repeat(32_767), &mut buf).unwrap();
        assert_eq!(buf.len(), 2 + 65_534);
        assert_eq!(&buf[0..2], &65_534u16.to_le_bytes());

        // One more code unit pushes the byte length past 0xFFFF.
        let mut buf = BytesMut::new();
        let err = encode_nvarchar_value(&"a".repeat(32_768), &mut buf).unwrap_err();
        assert!(matches!(
            err,
            TypeError::BufferTooSmall { needed: 65_536, .. }
        ));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_plp_string_encoding() {
        let mut buf = BytesMut::new();
        let text = "Hello";
        let utf16: Vec<u16> = text.encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // Verify structure:
        // - 8 bytes total length
        // - 4 bytes chunk length
        // - data (5 chars * 2 bytes = 10 bytes)
        // - 4 bytes terminator (0)
        assert_eq!(buf.len(), 8 + 4 + 10 + 4);

        // Check total length
        assert_eq!(&buf[0..8], &10u64.to_le_bytes());

        // Check chunk length
        assert_eq!(&buf[8..12], &10u32.to_le_bytes());

        // Check data (UTF-16LE "Hello")
        assert_eq!(&buf[12..22], b"H\0e\0l\0l\0o\0");

        // Check terminator
        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let mut buf = BytesMut::new();
        let data = b"test binary data";

        encode_plp_binary(data, &mut buf);

        // Verify structure:
        // - 8 bytes total length
        // - 4 bytes chunk length
        // - data (16 bytes)
        // - 4 bytes terminator (0)
        assert_eq!(buf.len(), 8 + 4 + 16 + 4);

        // Check total length
        assert_eq!(&buf[0..8], &16u64.to_le_bytes());

        // Check chunk length
        assert_eq!(&buf[8..12], &16u32.to_le_bytes());

        // Check data
        assert_eq!(&buf[12..28], data);

        // Check terminator
        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let mut buf = BytesMut::new();
        let utf16: Vec<u16> = "".encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // Empty string: total length (8) + terminator (4)
        assert_eq!(buf.len(), 8 + 4);

        // Check total length is 0
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        // Check terminator
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();

        encode_plp_binary(&[], &mut buf);

        // Empty binary: total length (8) + terminator (4)
        assert_eq!(buf.len(), 8 + 4);

        // Check total length is 0
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        // Check terminator
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_parse_sql_type_max() {
        // Test NVARCHAR(MAX) parsing - uses 0xFFFF marker (not doubled for MAX)
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF)); // MAX marker is 0xFFFF

        // Test VARBINARY(MAX) parsing
        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));

        // Test VARCHAR(MAX) parsing
        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));

        // Verify normal NVARCHAR does double the length
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200)); // 100 * 2 for UTF-16
    }
}