Skip to main content

mssql_client/
bulk.rs

1//! Bulk Copy Protocol (BCP) support.
2//!
3//! This module provides first-class support for bulk insert operations using
4//! the TDS Bulk Load protocol (packet type 0x07). BCP is significantly more
5//! efficient than individual INSERT statements for loading large amounts of data.
6//!
7//! ## Performance Benefits
8//!
//! - Minimal logging (under the simple or bulk-logged recovery model)
10//! - Batch commits reduce lock contention
11//! - Direct data streaming without SQL parsing overhead
12//! - Optional table lock for maximum throughput
13//!
14//! ## Usage
15//!
16//! ```rust,ignore
17//! use mssql_client::{Client, BulkInsert, BulkOptions};
18//!
19//! let mut bulk = client
20//!     .bulk_insert("dbo.Users")
21//!     .with_columns(&["id", "name", "email"])
//!     .with_options(BulkOptions {
//!         batch_size: 1000,
//!         check_constraints: true,
//!         fire_triggers: false,
//!         keep_nulls: true,
//!         table_lock: true,
//!         ..Default::default()
//!     })
29//!     .build()
30//!     .await?;
31//!
32//! // Send rows
33//! for user in users {
34//!     bulk.send_row(&[&user.id, &user.name, &user.email]).await?;
35//! }
36//!
37//! let result = bulk.finish().await?;
38//! println!("Inserted {} rows", result.rows_affected);
39//! ```
40//!
41//! ## Implementation Notes
42//!
43//! The bulk load protocol uses:
44//! - Packet type 0x07 (BulkLoad)
45//! - COLMETADATA token describing column structure
46//! - ROW tokens containing actual data
47//! - DONE token signaling completion
48//!
49//! Per MS-TDS specification, the row data format matches the server output format
50//! (same as SELECT results) rather than storage format.
51
52use bytes::{BufMut, BytesMut};
53use once_cell::sync::Lazy;
54use regex::Regex;
55use std::sync::Arc;
56
57use mssql_types::{SqlValue, ToSql, TypeError};
58use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
59use tds_protocol::token::{DoneStatus, TokenType};
60
61use crate::error::Error;
62
/// Options controlling bulk insert behavior.
///
/// Most fields correspond to hints on SQL Server's `INSERT BULK` statement
/// and trade off throughput, logging, locking and constraint checking.
/// Build one with struct-update syntax, e.g.
/// `BulkOptions { table_lock: true, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct BulkOptions {
    /// Number of rows per batch commit; `0` means the whole operation is a
    /// single batch (the default).
    ///
    /// Smaller batches keep less data in flight at the cost of more round
    /// trips; larger batches are faster but hold locks longer. When non-zero
    /// the statement builder also emits it as the `ROWS_PER_BATCH` hint.
    pub batch_size: usize,

    /// Enforce CHECK and foreign-key constraints while loading
    /// (`CHECK_CONSTRAINTS` hint). Default: `true`.
    pub check_constraints: bool,

    /// Run the table's INSERT triggers (`FIRE_TRIGGERS` hint).
    /// Default: `false`, which is faster.
    pub fire_triggers: bool,

    /// Insert NULLs as NULL rather than substituting column defaults
    /// (`KEEP_NULLS` hint). Default: `true`.
    pub keep_nulls: bool,

    /// Take a table-level lock for the whole load (`TABLOCK` hint).
    ///
    /// Avoids row/page lock overhead and escalation, but blocks all other
    /// access to the table while the load runs. Default: `false`.
    pub table_lock: bool,

    /// Columns the incoming data is already sorted by (`ORDER(...)` hint).
    ///
    /// Supplying the clustered-index order lets the server skip its own sort.
    /// Default: `None`.
    pub order_hint: Option<Vec<String>>,

    /// Number of errors tolerated before the load is aborted.
    ///
    /// Default: `0` (abort on the first error).
    /// NOTE(review): not emitted as a hint by
    /// `BulkInsertBuilder::build_insert_bulk_statement`; confirm where it is
    /// enforced.
    pub max_errors: u32,
}

impl Default for BulkOptions {
    /// Conservative defaults: constraints checked, NULLs kept, no triggers,
    /// no table lock, no ordering hint, a single batch, and no error budget.
    fn default() -> Self {
        BulkOptions {
            check_constraints: true,
            keep_nulls: true,
            fire_triggers: false,
            table_lock: false,
            order_hint: None,
            batch_size: 0,
            max_errors: 0,
        }
    }
}
124
125/// Column definition for bulk insert.
126#[derive(Debug, Clone)]
127pub struct BulkColumn {
128    /// Column name.
129    pub name: String,
130    /// SQL Server type (e.g., "INT", "NVARCHAR(100)").
131    pub sql_type: String,
132    /// Whether the column allows NULL values.
133    pub nullable: bool,
134    /// Column ordinal (0-based).
135    pub ordinal: usize,
136    /// TDS type ID.
137    type_id: u8,
138    /// Maximum length for variable-length types.
139    max_length: Option<u32>,
140    /// Precision for decimal types.
141    precision: Option<u8>,
142    /// Scale for decimal types.
143    scale: Option<u8>,
144}
145
146impl BulkColumn {
147    /// Create a new bulk column definition.
148    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Self {
149        let sql_type_str: String = sql_type.into();
150        let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
151
152        Self {
153            name: name.into(),
154            sql_type: sql_type_str,
155            nullable: true,
156            ordinal,
157            type_id,
158            max_length,
159            precision,
160            scale,
161        }
162    }
163
164    /// Set whether this column allows NULL values.
165    #[must_use]
166    pub fn with_nullable(mut self, nullable: bool) -> Self {
167        self.nullable = nullable;
168        self
169    }
170}
171
/// Parse a T-SQL type declaration into TDS type information.
///
/// Returns `(type_id, max_length, precision, scale)`:
/// - `type_id`: the TDS type byte written into COLMETADATA.
/// - `max_length`: byte length for variable-length types, with `0xFFFF`
///   marking a `MAX` type. NVARCHAR/NCHAR lengths are declared in characters
///   and doubled to UTF-16 bytes (saturating, so absurd lengths cannot
///   overflow).
/// - `precision` / `scale`: set for DECIMAL/NUMERIC (defaults 18 and 0);
///   `scale` is also set for TIME, DATETIME2 and DATETIMEOFFSET (default 7).
///
/// Matching is case-insensitive and tolerates whitespace around the type
/// name and its arguments (e.g. `"decimal ( 18, 2 )"`); anything after the
/// closing parenthesis is ignored. Unrecognised types fall back to
/// NVARCHAR(4000).
///
/// NOTE(review): CHAR/NCHAR/BINARY are mapped to their variable-length
/// counterparts; confirm that is intended for bulk load.
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();

    // Split "BASE(args)..." into the base name and the text inside the parentheses.
    let (base, params) = match upper.split_once('(') {
        Some((base, rest)) => {
            let inner = rest.split_once(')').map_or(rest, |(inner, _)| inner);
            (base.trim_end(), Some(inner.trim()))
        }
        None => (upper.as_str(), None),
    };

    // Byte length for VARCHAR/VARBINARY-style types; MAX maps to the 0xFFFF marker.
    let length_or = |default: u32| -> u32 {
        params
            .and_then(|p| if p == "MAX" { Some(0xFFFF) } else { p.parse().ok() })
            .unwrap_or(default)
    };
    // Fractional-seconds scale for TIME/DATETIME2/DATETIMEOFFSET.
    let time_scale = || -> u8 { params.and_then(|p| p.parse().ok()).unwrap_or(7) };

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        "FLOAT" => (0x3E, None, None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        // DATETIM4TYPE. (0x3F is NUMERICTYPE in MS-TDS, not SMALLDATETIME.)
        "SMALLDATETIME" => (0x3A, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(length_or(8000)), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                // MAX types use the 0xFFFF marker (not doubled).
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Declared lengths are in characters; UTF-16 needs two bytes each.
                let chars: u32 = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(length_or(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = match params {
                Some(p) => {
                    let mut parts = p.split(',').map(str::trim);
                    (
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(18),
                        parts.next().and_then(|s| s.parse().ok()).unwrap_or(0),
                    )
                }
                None => (18, 0),
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        _ => (0xE7, Some(8000), None, None), // Default to NVARCHAR(4000)
    }
}
265
/// Result of a bulk insert operation.
///
/// The module-level example reads `rows_affected` from the value returned by
/// `finish()`.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Total number of rows inserted.
    pub rows_affected: u64,
    /// Number of batches committed.
    pub batches_committed: u32,
    /// Whether any errors were encountered.
    pub has_errors: bool,
}
276
277/// Builder for configuring a bulk insert operation.
278#[derive(Debug)]
279pub struct BulkInsertBuilder {
280    table_name: String,
281    columns: Vec<BulkColumn>,
282    options: BulkOptions,
283}
284
285impl BulkInsertBuilder {
286    /// Create a new bulk insert builder for the specified table.
287    pub fn new<S: Into<String>>(table_name: S) -> Self {
288        Self {
289            table_name: table_name.into(),
290            columns: Vec::new(),
291            options: BulkOptions::default(),
292        }
293    }
294
295    /// Specify the columns to insert.
296    ///
297    /// Columns will be queried from the server if not specified,
298    /// but providing them explicitly is more efficient.
299    #[must_use]
300    pub fn with_columns(mut self, column_names: &[&str]) -> Self {
301        self.columns = column_names
302            .iter()
303            .enumerate()
304            .map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
305            .collect();
306        self
307    }
308
309    /// Specify columns with full type information.
310    #[must_use]
311    pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
312        self.columns = columns;
313        self
314    }
315
316    /// Set bulk insert options.
317    #[must_use]
318    pub fn with_options(mut self, options: BulkOptions) -> Self {
319        self.options = options;
320        self
321    }
322
323    /// Set the batch size.
324    #[must_use]
325    pub fn batch_size(mut self, size: usize) -> Self {
326        self.options.batch_size = size;
327        self
328    }
329
330    /// Enable or disable table lock.
331    #[must_use]
332    pub fn table_lock(mut self, enabled: bool) -> Self {
333        self.options.table_lock = enabled;
334        self
335    }
336
337    /// Enable or disable trigger firing.
338    #[must_use]
339    pub fn fire_triggers(mut self, enabled: bool) -> Self {
340        self.options.fire_triggers = enabled;
341        self
342    }
343
344    /// Get the table name.
345    pub fn table_name(&self) -> &str {
346        &self.table_name
347    }
348
349    /// Get the columns.
350    pub fn columns(&self) -> &[BulkColumn] {
351        &self.columns
352    }
353
354    /// Get the options.
355    pub fn options(&self) -> &BulkOptions {
356        &self.options
357    }
358
359    /// Build the INSERT BULK SQL statement.
360    ///
361    /// # Errors
362    ///
363    /// Returns an error if the table name or any column name fails identifier
364    /// validation, preventing SQL injection.
365    pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
366        // Validate table name (may be schema-qualified: dbo.Users, catalog.schema.table)
367        validate_qualified_identifier(&self.table_name)?;
368
369        // Validate column names
370        for col in &self.columns {
371            validate_identifier(&col.name)?;
372        }
373
374        let mut sql = format!("INSERT BULK {}", self.table_name);
375
376        // Add column definitions
377        if !self.columns.is_empty() {
378            sql.push_str(" (");
379            let cols: Vec<String> = self
380                .columns
381                .iter()
382                .map(|c| {
383                    // Validate sql_type to prevent SQL injection: only allow
384                    // alphanumerics, parentheses (for length/precision), commas,
385                    // spaces, and the MAX keyword — which covers all valid T-SQL
386                    // type specifiers like "NVARCHAR(100)", "DECIMAL(18, 2)",
387                    // "VARBINARY(MAX)", etc.
388                    validate_sql_type(&c.sql_type)?;
389                    Ok(format!("{} {}", c.name, c.sql_type))
390                })
391                .collect::<Result<Vec<_>, Error>>()?;
392            sql.push_str(&cols.join(", "));
393            sql.push(')');
394        }
395
396        // Add WITH clause for options
397        let mut hints: Vec<String> = Vec::new();
398
399        if self.options.check_constraints {
400            hints.push("CHECK_CONSTRAINTS".to_string());
401        }
402        if self.options.fire_triggers {
403            hints.push("FIRE_TRIGGERS".to_string());
404        }
405        if self.options.keep_nulls {
406            hints.push("KEEP_NULLS".to_string());
407        }
408        if self.options.table_lock {
409            hints.push("TABLOCK".to_string());
410        }
411        if self.options.batch_size > 0 {
412            hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
413        }
414
415        if let Some(ref order) = self.options.order_hint {
416            // Validate order hint column names
417            for col_name in order {
418                validate_identifier(col_name)?;
419            }
420            hints.push(format!("ORDER({})", order.join(", ")));
421        }
422
423        if !hints.is_empty() {
424            sql.push_str(" WITH (");
425            sql.push_str(&hints.join(", "));
426            sql.push(')');
427        }
428
429        Ok(sql)
430    }
431}
432
433/// Validate a SQL type specifier to prevent SQL injection.
434///
435/// Allows only characters that can appear in valid T-SQL type declarations:
436/// letters, digits, parentheses, commas, spaces, and periods.
437/// Examples: "INT", "NVARCHAR(100)", "DECIMAL(18, 2)", "VARBINARY(MAX)".
438fn validate_sql_type(type_str: &str) -> Result<(), Error> {
439    #[allow(clippy::expect_used)] // Static regex compilation with known-valid pattern
440    static SQL_TYPE_RE: Lazy<Regex> =
441        Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
442
443    if type_str.is_empty() {
444        return Err(Error::Config("SQL type cannot be empty".into()));
445    }
446
447    if !SQL_TYPE_RE.is_match(type_str) {
448        return Err(Error::Config(format!(
449            "invalid SQL type '{type_str}': contains disallowed characters"
450        )));
451    }
452
453    Ok(())
454}
455
456/// Validate a single SQL identifier to prevent SQL injection.
457fn validate_identifier(name: &str) -> Result<(), Error> {
458    #[allow(clippy::expect_used)] // Static regex compilation with known-valid pattern
459    static IDENTIFIER_RE: Lazy<Regex> =
460        Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").expect("valid regex"));
461
462    if name.is_empty() {
463        return Err(Error::InvalidIdentifier(
464            "identifier cannot be empty".into(),
465        ));
466    }
467
468    if !IDENTIFIER_RE.is_match(name) {
469        return Err(Error::InvalidIdentifier(format!(
470            "invalid identifier '{name}': must start with letter/underscore, \
471             contain only alphanumerics/_/@/#/$, and be 1-128 characters"
472        )));
473    }
474
475    Ok(())
476}
477
478/// Validate a potentially schema-qualified identifier (e.g., `dbo.Users`, `catalog.schema.table`).
479///
480/// Splits on `.` and validates each part. SQL Server allows up to 4 parts:
481/// `server.catalog.schema.object`.
482fn validate_qualified_identifier(name: &str) -> Result<(), Error> {
483    if name.is_empty() {
484        return Err(Error::InvalidIdentifier(
485            "identifier cannot be empty".into(),
486        ));
487    }
488
489    let parts: Vec<&str> = name.split('.').collect();
490    if parts.len() > 4 {
491        return Err(Error::InvalidIdentifier(format!(
492            "invalid qualified identifier '{name}': too many parts (max 4: server.catalog.schema.object)"
493        )));
494    }
495
496    for part in &parts {
497        validate_identifier(part)?;
498    }
499
500    Ok(())
501}
502
503/// Active bulk insert operation.
504///
505/// This struct manages the streaming of row data to the server.
506/// Call `send_row()` for each row, then `finish()` to complete.
507pub struct BulkInsert {
508    /// Column metadata.
509    columns: Arc<[BulkColumn]>,
510    /// Buffer for accumulating rows.
511    buffer: BytesMut,
512    /// Rows in current batch.
513    rows_in_batch: usize,
514    /// Total rows sent.
515    total_rows: u64,
516    /// Batch size (0 = single batch).
517    batch_size: usize,
518    /// Number of batches committed.
519    batches_committed: u32,
520    /// Packet ID counter.
521    packet_id: u8,
522}
523
524impl BulkInsert {
525    /// Create a new bulk insert operation.
526    pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
527        let mut bulk = Self {
528            columns: columns.into(),
529            buffer: BytesMut::with_capacity(64 * 1024), // 64KB initial buffer
530            rows_in_batch: 0,
531            total_rows: 0,
532            batch_size,
533            batches_committed: 0,
534            packet_id: 1,
535        };
536
537        // Write COLMETADATA token
538        bulk.write_colmetadata();
539
540        bulk
541    }
542
543    /// Write the COLMETADATA token to the buffer.
544    fn write_colmetadata(&mut self) {
545        let buf = &mut self.buffer;
546
547        // Token type
548        buf.put_u8(TokenType::ColMetaData as u8);
549
550        // Column count
551        buf.put_u16_le(self.columns.len() as u16);
552
553        for col in self.columns.iter() {
554            // User type (always 0 for basic types)
555            buf.put_u32_le(0);
556
557            // Flags: Nullable (bit 0) | CaseSen (bit 1) | Updateable (bits 2-3) | etc.
558            let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
559            buf.put_u16_le(flags);
560
561            // Type info
562            buf.put_u8(col.type_id);
563
564            // Type-specific length/precision/scale
565            match col.type_id {
566                // Fixed-length types - no additional info needed
567                0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}
568
569                // Variable-length string/binary types
570                0xE7 | 0xA7 | 0xA5 | 0xAD => {
571                    // Max length (2 bytes for normal, 4 bytes for MAX)
572                    let max_len = col.max_length.unwrap_or(8000);
573                    if max_len == 0xFFFF {
574                        buf.put_u16_le(0xFFFF);
575                    } else {
576                        buf.put_u16_le(max_len as u16);
577                    }
578
579                    // Collation for string types (5 bytes)
580                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
581                        // Default collation (Latin1_General_CI_AS)
582                        buf.put_u32_le(0x0409_0904); // LCID + flags
583                        buf.put_u8(52); // Sort ID
584                    }
585                }
586
587                // Decimal/Numeric
588                0x6C | 0x6A => {
589                    // Length (calculated from precision)
590                    let precision = col.precision.unwrap_or(18);
591                    let len = decimal_byte_length(precision);
592                    buf.put_u8(len);
593                    buf.put_u8(precision);
594                    buf.put_u8(col.scale.unwrap_or(0));
595                }
596
597                // Time-based with scale
598                0x29..=0x2B => {
599                    buf.put_u8(col.scale.unwrap_or(7));
600                }
601
602                // GUID
603                0x24 => {
604                    buf.put_u8(16);
605                }
606
607                // Other types - write max length if present
608                _ => {
609                    if let Some(len) = col.max_length {
610                        if len <= 0xFFFF {
611                            buf.put_u16_le(len as u16);
612                        }
613                    }
614                }
615            }
616
617            // Column name (B_VARCHAR format: 1-byte length prefix)
618            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
619            buf.put_u8(name_utf16.len() as u8);
620            for code_unit in name_utf16 {
621                buf.put_u16_le(code_unit);
622            }
623        }
624    }
625
626    /// Send a row of data.
627    ///
628    /// The values must match the column order and types specified
629    /// when creating the bulk insert.
630    ///
631    /// # Errors
632    ///
633    /// Returns an error if:
634    /// - Wrong number of values provided
635    /// - A value cannot be converted to the expected type
636    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
637        if values.len() != self.columns.len() {
638            return Err(Error::Config(format!(
639                "expected {} values, got {}",
640                self.columns.len(),
641                values.len()
642            )));
643        }
644
645        // Convert all values to SqlValue
646        let sql_values: Result<Vec<SqlValue>, TypeError> =
647            values.iter().map(|v| v.to_sql()).collect();
648        let sql_values = sql_values.map_err(Error::from)?;
649
650        self.write_row(&sql_values)?;
651
652        self.rows_in_batch += 1;
653        self.total_rows += 1;
654
655        Ok(())
656    }
657
658    /// Send a row of pre-converted SQL values.
659    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
660        if values.len() != self.columns.len() {
661            return Err(Error::Config(format!(
662                "expected {} values, got {}",
663                self.columns.len(),
664                values.len()
665            )));
666        }
667
668        self.write_row(values)?;
669
670        self.rows_in_batch += 1;
671        self.total_rows += 1;
672
673        Ok(())
674    }
675
676    /// Write a ROW token to the buffer.
677    fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
678        // ROW token type
679        self.buffer.put_u8(TokenType::Row as u8);
680
681        // Collect column info needed for encoding to avoid borrow conflict
682        let columns: Vec<_> = self.columns.iter().cloned().collect();
683
684        // Write each column value
685        for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
686            self.encode_column_value(col, value)
687                .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
688        }
689
690        Ok(())
691    }
692
693    /// Encode a column value according to its type.
694    fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
695        let buf = &mut self.buffer;
696
697        // Check if this column uses PLP (Partially Length-Prefixed) encoding
698        // MAX types (max_length == 0xFFFF) use PLP format
699        let is_plp_type =
700            col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
701
702        match value {
703            SqlValue::Null => {
704                // NULL encoding depends on type
705                match col.type_id {
706                    // Variable-length types
707                    0xE7 | 0xA7 | 0xA5 | 0xAD => {
708                        if is_plp_type {
709                            // PLP NULL: 0xFFFFFFFFFFFFFFFF
710                            buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
711                        } else {
712                            // Standard NULL: 0xFFFF length marker
713                            buf.put_u16_le(0xFFFF);
714                        }
715                    }
716                    // Nullable fixed types use 0 length
717                    0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
718                        buf.put_u8(0);
719                    }
720                    // Fixed types without nullable variant
721                    _ => {
722                        if col.nullable {
723                            buf.put_u8(0);
724                        } else {
725                            return Err(TypeError::UnexpectedNull);
726                        }
727                    }
728                }
729            }
730
731            SqlValue::Bool(v) => {
732                buf.put_u8(1); // Length
733                buf.put_u8(if *v { 1 } else { 0 });
734            }
735
736            SqlValue::TinyInt(v) => {
737                buf.put_u8(1); // Length
738                buf.put_u8(*v);
739            }
740
741            SqlValue::SmallInt(v) => {
742                buf.put_u8(2); // Length
743                buf.put_i16_le(*v);
744            }
745
746            SqlValue::Int(v) => {
747                buf.put_u8(4); // Length
748                buf.put_i32_le(*v);
749            }
750
751            SqlValue::BigInt(v) => {
752                buf.put_u8(8); // Length
753                buf.put_i64_le(*v);
754            }
755
756            SqlValue::Float(v) => {
757                buf.put_u8(4); // Length
758                buf.put_f32_le(*v);
759            }
760
761            SqlValue::Double(v) => {
762                buf.put_u8(8); // Length
763                buf.put_f64_le(*v);
764            }
765
766            SqlValue::String(s) => {
767                // UTF-16LE encoding for NVARCHAR
768                let utf16: Vec<u16> = s.encode_utf16().collect();
769                let byte_len = utf16.len() * 2;
770
771                if is_plp_type {
772                    // PLP format for MAX types - supports unlimited size
773                    // Send as a single chunk for simplicity
774                    encode_plp_string(&utf16, buf);
775                } else if byte_len > 0xFFFF {
776                    // Non-MAX column can't hold this much data
777                    return Err(TypeError::BufferTooSmall {
778                        needed: byte_len,
779                        available: 0xFFFF,
780                    });
781                } else {
782                    // Standard encoding with 2-byte length prefix
783                    buf.put_u16_le(byte_len as u16);
784                    for code_unit in utf16 {
785                        buf.put_u16_le(code_unit);
786                    }
787                }
788            }
789
790            SqlValue::Binary(b) => {
791                if is_plp_type {
792                    // PLP format for MAX types - supports unlimited size
793                    encode_plp_binary(b, buf);
794                } else if b.len() > 0xFFFF {
795                    // Non-MAX column can't hold this much data
796                    return Err(TypeError::BufferTooSmall {
797                        needed: b.len(),
798                        available: 0xFFFF,
799                    });
800                } else {
801                    // Standard encoding with 2-byte length prefix
802                    buf.put_u16_le(b.len() as u16);
803                    buf.put_slice(b);
804                }
805            }
806
807            // Feature-gated types - use mssql_types::encode module
808            #[cfg(feature = "decimal")]
809            SqlValue::Decimal(d) => {
810                let precision = col.precision.unwrap_or(18);
811                let len = decimal_byte_length(precision);
812                buf.put_u8(len);
813
814                // Sign: 0 = negative, 1 = positive
815                buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
816
817                // Mantissa as unsigned 128-bit integer
818                let mantissa = d.mantissa().unsigned_abs();
819                let mantissa_bytes = mantissa.to_le_bytes();
820                buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
821            }
822
823            #[cfg(feature = "uuid")]
824            SqlValue::Uuid(u) => {
825                buf.put_u8(16); // Length
826                // Use mssql_types encode function
827                mssql_types::encode::encode_uuid(*u, buf);
828            }
829
830            #[cfg(feature = "chrono")]
831            SqlValue::Date(d) => {
832                buf.put_u8(3); // Length
833                mssql_types::encode::encode_date(*d, buf);
834            }
835
836            #[cfg(feature = "chrono")]
837            SqlValue::Time(t) => {
838                let scale = col.scale.unwrap_or(7);
839                let len = time_byte_length(scale);
840                buf.put_u8(len);
841                // Encode time with proper scale handling
842                encode_time_with_scale(*t, scale, buf);
843            }
844
845            #[cfg(feature = "chrono")]
846            SqlValue::DateTime(dt) => {
847                let scale = col.scale.unwrap_or(7);
848                let time_len = time_byte_length(scale);
849                let total_len = time_len + 3;
850                buf.put_u8(total_len);
851                // Encode time then date
852                encode_time_with_scale(dt.time(), scale, buf);
853                mssql_types::encode::encode_date(dt.date(), buf);
854            }
855
856            #[cfg(feature = "chrono")]
857            SqlValue::DateTimeOffset(dto) => {
858                let scale = col.scale.unwrap_or(7);
859                let time_len = time_byte_length(scale);
860                let total_len = time_len + 3 + 2;
861                buf.put_u8(total_len);
862                // Use mssql_types encode
863                encode_time_with_scale(dto.time(), scale, buf);
864                mssql_types::encode::encode_date(dto.date_naive(), buf);
865                // Timezone offset in minutes
866                use chrono::Offset;
867                let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
868                buf.put_i16_le(offset_minutes);
869            }
870
871            #[cfg(feature = "json")]
872            SqlValue::Json(j) => {
873                let s = j.to_string();
874                encode_nvarchar_value(&s, buf)?;
875            }
876
877            SqlValue::Xml(x) => {
878                encode_nvarchar_value(x, buf)?;
879            }
880
881            SqlValue::Tvp(_) => {
882                // TVPs are not valid in bulk copy operations - they're for RPC parameters only
883                return Err(TypeError::UnsupportedConversion {
884                    from: "TVP".to_string(),
885                    to: "bulk copy value",
886                });
887            }
888            // Handle future SqlValue variants
889            _ => {
890                return Err(TypeError::UnsupportedConversion {
891                    from: value.type_name().to_string(),
892                    to: "bulk copy value",
893                });
894            }
895        }
896
897        Ok(())
898    }
899}
900
901/// Encode a string as NVARCHAR with length prefix.
902fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
903    let utf16: Vec<u16> = s.encode_utf16().collect();
904    let byte_len = utf16.len() * 2;
905
906    if byte_len > 0xFFFF {
907        return Err(TypeError::BufferTooSmall {
908            needed: byte_len,
909            available: 0xFFFF,
910        });
911    }
912
913    buf.put_u16_le(byte_len as u16);
914    for code_unit in utf16 {
915        buf.put_u16_le(code_unit);
916    }
917    Ok(())
918}
919
920/// Encode a UTF-16 string using PLP (Partially Length-Prefixed) format.
921///
922/// PLP format (per MS-TDS specification):
923/// - 8 bytes: total length in bytes (little-endian)
924/// - Chunks: 4-byte chunk length + data, repeated
925/// - Terminator: 4 bytes of zero
926///
927/// For simplicity, we send the entire value as a single chunk.
928/// This is efficient for bulk operations where we already have the complete data.
929fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
930    let byte_len = utf16.len() * 2;
931
932    // Total length (8 bytes)
933    buf.put_u64_le(byte_len as u64);
934
935    if byte_len > 0 {
936        // Single chunk: length (4 bytes) + data
937        buf.put_u32_le(byte_len as u32);
938        for code_unit in utf16 {
939            buf.put_u16_le(*code_unit);
940        }
941    }
942
943    // Terminator chunk (length = 0)
944    buf.put_u32_le(0);
945}
946
947/// Encode binary data using PLP (Partially Length-Prefixed) format.
948///
949/// PLP format (per MS-TDS specification):
950/// - 8 bytes: total length in bytes (little-endian)
951/// - Chunks: 4-byte chunk length + data, repeated
952/// - Terminator: 4 bytes of zero
953///
954/// For simplicity, we send the entire value as a single chunk.
955fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
956    // Total length (8 bytes)
957    buf.put_u64_le(data.len() as u64);
958
959    if !data.is_empty() {
960        // Single chunk: length (4 bytes) + data
961        buf.put_u32_le(data.len() as u32);
962        buf.put_slice(data);
963    }
964
965    // Terminator chunk (length = 0)
966    buf.put_u32_le(0);
967}
968
/// Encode a time of day at the given fractional-second `scale` for bulk copy.
///
/// The time is converted to nanoseconds since midnight, divided by
/// `time_scale_divisor(scale)`, and the result is written little-endian using
/// the `time_byte_length(scale)` low-order bytes.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let since_midnight_ns = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000
        + u64::from(time.nanosecond());
    let ticks = since_midnight_ns / time_scale_divisor(scale);

    // time_byte_length returns at most 5, so the slice always fits in the
    // 8 little-endian bytes of a u64.
    let width = usize::from(time_byte_length(scale));
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
982
983impl BulkInsert {
984    /// Write the DONE token signaling completion.
985    fn write_done(&mut self) {
986        let buf = &mut self.buffer;
987
988        buf.put_u8(TokenType::Done as u8);
989
990        // Status: FINAL (0x00) | COUNT (0x10)
991        let status = DoneStatus::from_bits(0x0010); // DONE_COUNT
992        buf.put_u16_le(status.to_bits());
993
994        // Current command (0 for bulk load)
995        buf.put_u16_le(0);
996
997        // Row count
998        buf.put_u64_le(self.total_rows);
999    }
1000
1001    /// Get the buffered data as packets ready to send.
1002    ///
1003    /// Returns a vector of complete TDS packets with BulkLoad packet type (0x07).
1004    pub fn take_packets(&mut self) -> Vec<BytesMut> {
1005        const MAX_PACKET_SIZE: usize = 4096;
1006        const HEADER_SIZE: usize = 8;
1007        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1008
1009        let data = self.buffer.split();
1010        let mut packets = Vec::new();
1011        let mut offset = 0;
1012
1013        while offset < data.len() {
1014            let remaining = data.len() - offset;
1015            let payload_size = remaining.min(MAX_PAYLOAD);
1016            let is_last = offset + payload_size >= data.len();
1017
1018            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1019
1020            // Write packet header
1021            let header = PacketHeader {
1022                packet_type: PacketType::BulkLoad,
1023                status: if is_last {
1024                    PacketStatus::END_OF_MESSAGE
1025                } else {
1026                    PacketStatus::NORMAL
1027                },
1028                length: (HEADER_SIZE + payload_size) as u16,
1029                spid: 0,
1030                packet_id: self.packet_id,
1031                window: 0,
1032            };
1033
1034            header.encode(&mut packet);
1035
1036            // Write payload
1037            packet.put_slice(&data[offset..offset + payload_size]);
1038
1039            packets.push(packet);
1040            offset += payload_size;
1041            self.packet_id = self.packet_id.wrapping_add(1);
1042        }
1043
1044        packets
1045    }
1046
1047    /// Get total rows sent so far.
1048    pub fn total_rows(&self) -> u64 {
1049        self.total_rows
1050    }
1051
1052    /// Get rows in current batch.
1053    pub fn rows_in_batch(&self) -> usize {
1054        self.rows_in_batch
1055    }
1056
1057    /// Check if a batch flush is needed.
1058    pub fn should_flush(&self) -> bool {
1059        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1060    }
1061
1062    /// Prepare for finishing the bulk operation.
1063    /// Writes the DONE token and returns final packets.
1064    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1065        self.write_done();
1066        self.take_packets()
1067    }
1068
1069    /// Create a result from the current state.
1070    pub fn result(&self) -> BulkInsertResult {
1071        BulkInsertResult {
1072            rows_affected: self.total_rows,
1073            batches_committed: self.batches_committed,
1074            has_errors: false,
1075        }
1076    }
1077}
1078
/// Return the encoded byte length of a DECIMAL/NUMERIC value for `precision`.
///
/// The lengths 5, 9, 13 and 17 are one sign byte plus a 4-, 8-, 12- or 16-byte
/// magnitude. Precisions outside `1..=38`, including 0, fall back to the
/// maximum of 17.
fn decimal_byte_length(precision: u8) -> u8 {
    const MAX_LEN: u8 = 17;
    // (highest precision in tier, encoded length)
    const TIERS: [(u8, u8); 3] = [(9, 5), (19, 9), (28, 13)];

    if precision == 0 {
        return MAX_LEN;
    }

    TIERS
        .iter()
        .find(|&&(upper, _)| precision <= upper)
        .map_or(MAX_LEN, |&(_, len)| len)
}
1089
/// Return how many bytes a TIME value occupies at fractional-second `scale`.
///
/// Scales 0-2 use 3 bytes and 3-4 use 4 bytes. Every larger scale, including
/// out-of-range values above 7, uses 5 bytes.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1100
/// Return the number of nanoseconds in one unit of fractional-second `scale`.
///
/// That is `10^(9 - scale)`: `1_000_000_000` for scale 0 down to `100` for
/// scale 7. Scales above 7 are treated as 7.
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = u32::from(scale.min(7));
    10u64.pow(9 - effective_scale)
}
1116
/// Unit tests for bulk-copy option defaults, SQL type parsing, INSERT BULK
/// statement validation, and the low-level value encoders.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_bulk_options_default() {
        let opts = BulkOptions::default();
        assert_eq!(opts.batch_size, 0);
        assert!(opts.check_constraints);
        assert!(!opts.fire_triggers);
        assert!(opts.keep_nulls);
        assert!(!opts.table_lock);
    }

    #[test]
    fn test_bulk_column_creation() {
        let col = BulkColumn::new("id", "INT", 0);
        assert_eq!(col.name, "id");
        assert_eq!(col.type_id, 0x38);
        assert!(col.nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x38);
        assert!(len.is_none());

        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200)); // UTF-16 doubles

        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ])
            .table_lock(true);

        let sql = builder.build_insert_bulk_statement().unwrap();
        assert!(sql.contains("INSERT BULK dbo.Users"));
        assert!(sql.contains("TABLOCK"));
    }

    #[test]
    fn test_bulk_insert_rejects_injection() {
        let builder = BulkInsertBuilder::new("table;DROP TABLE users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);

        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_validates_column_names() {
        let builder = BulkInsertBuilder::new("Users").with_typed_columns(vec![BulkColumn::new(
            "col;DROP TABLE x",
            "INT",
            0,
        )]);

        assert!(builder.build_insert_bulk_statement().is_err());
    }

    #[test]
    fn test_bulk_insert_accepts_qualified_names() {
        let builder = BulkInsertBuilder::new("catalog.dbo.Users")
            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0)]);

        assert!(builder.build_insert_bulk_statement().is_ok());
    }

    #[test]
    fn test_bulk_insert_creation() {
        let columns = vec![
            BulkColumn::new("id", "INT", 0),
            BulkColumn::new("name", "NVARCHAR(100)", 1),
        ];

        let bulk = BulkInsert::new(columns, 1000);
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
    }

    #[test]
    fn test_decimal_byte_length() {
        assert_eq!(decimal_byte_length(5), 5);
        assert_eq!(decimal_byte_length(15), 9);
        assert_eq!(decimal_byte_length(25), 13);
        assert_eq!(decimal_byte_length(35), 17);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        assert_eq!(time_byte_length(0), 3);
        assert_eq!(time_byte_length(3), 4);
        assert_eq!(time_byte_length(7), 5);
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_scale_divisor() {
        assert_eq!(time_scale_divisor(0), 1_000_000_000);
        assert_eq!(time_scale_divisor(3), 1_000_000);
        assert_eq!(time_scale_divisor(7), 100);
        // Out-of-range scales fall back to 100 ns units.
        assert_eq!(time_scale_divisor(9), 100);
    }

    #[test]
    fn test_nvarchar_value_encoding() {
        let mut buf = BytesMut::new();
        encode_nvarchar_value("Hi", &mut buf).unwrap();

        // 2-byte LE byte count (4) followed by UTF-16LE "H", "i".
        assert_eq!(&buf[..], &[4, 0, b'H', 0, b'i', 0]);
    }

    #[test]
    fn test_nvarchar_value_length_limits() {
        // 0x7FFF code units = 0xFFFE bytes: the largest accepted value.
        let mut buf = BytesMut::new();
        encode_nvarchar_value(&"a".repeat(0x7FFF), &mut buf).unwrap();
        assert_eq!(&buf[0..2], &0xFFFEu16.to_le_bytes());

        // 0x8000 code units = 0x10000 bytes: rejected, nothing written.
        let mut buf = BytesMut::new();
        assert!(encode_nvarchar_value(&"a".repeat(0x8000), &mut buf).is_err());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_plp_string_encoding() {
        let mut buf = BytesMut::new();
        let text = "Hello";
        let utf16: Vec<u16> = text.encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // Verify structure:
        // - 8 bytes total length
        // - 4 bytes chunk length
        // - data (5 chars * 2 bytes = 10 bytes)
        // - 4 bytes terminator (0)
        assert_eq!(buf.len(), 8 + 4 + 10 + 4);

        // Check total length
        assert_eq!(&buf[0..8], &10u64.to_le_bytes());

        // Check chunk length
        assert_eq!(&buf[8..12], &10u32.to_le_bytes());

        // Check payload: "Hello" as UTF-16LE code units
        assert_eq!(&buf[12..22], b"H\0e\0l\0l\0o\0");

        // Check terminator
        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let mut buf = BytesMut::new();
        let data = b"test binary data";

        encode_plp_binary(data, &mut buf);

        // Verify structure:
        // - 8 bytes total length
        // - 4 bytes chunk length
        // - data (16 bytes)
        // - 4 bytes terminator (0)
        assert_eq!(buf.len(), 8 + 4 + 16 + 4);

        // Check total length
        assert_eq!(&buf[0..8], &16u64.to_le_bytes());

        // Check chunk length
        assert_eq!(&buf[8..12], &16u32.to_le_bytes());

        // Check data
        assert_eq!(&buf[12..28], data);

        // Check terminator
        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let mut buf = BytesMut::new();
        let utf16: Vec<u16> = "".encode_utf16().collect();

        encode_plp_string(&utf16, &mut buf);

        // Empty string: total length (8) + terminator (4)
        assert_eq!(buf.len(), 8 + 4);

        // Check total length is 0
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        // Check terminator
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();

        encode_plp_binary(&[], &mut buf);

        // Empty binary: total length (8) + terminator (4)
        assert_eq!(buf.len(), 8 + 4);

        // Check total length is 0
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());

        // Check terminator
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_parse_sql_type_max() {
        // Test NVARCHAR(MAX) parsing - uses 0xFFFF marker (not doubled for MAX)
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF)); // MAX marker is 0xFFFF

        // Test VARBINARY(MAX) parsing
        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));

        // Test VARCHAR(MAX) parsing
        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));

        // Verify normal NVARCHAR does double the length
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200)); // 100 * 2 for UTF-16
    }
}