mssql_client/
bulk.rs

1//! Bulk Copy Protocol (BCP) support.
2//!
3//! This module provides first-class support for bulk insert operations using
4//! the TDS Bulk Load protocol (packet type 0x07). BCP is significantly more
5//! efficient than individual INSERT statements for loading large amounts of data.
6//!
7//! ## Performance Benefits
8//!
9//! - Minimal logging (when using simple recovery model)
10//! - Batch commits reduce lock contention
11//! - Direct data streaming without SQL parsing overhead
12//! - Optional table lock for maximum throughput
13//!
14//! ## Usage
15//!
16//! ```rust,ignore
17//! use mssql_client::{Client, BulkInsert, BulkOptions};
18//!
19//! let mut bulk = client
20//!     .bulk_insert("dbo.Users")
21//!     .with_columns(&["id", "name", "email"])
22//!     .with_options(BulkOptions {
23//!         batch_size: 1000,
24//!         check_constraints: true,
25//!         fire_triggers: false,
26//!         keep_nulls: true,
27//!         table_lock: true,
28//!     })
29//!     .build()
30//!     .await?;
31//!
32//! // Send rows
33//! for user in users {
34//!     bulk.send_row(&[&user.id, &user.name, &user.email]).await?;
35//! }
36//!
37//! let result = bulk.finish().await?;
38//! println!("Inserted {} rows", result.rows_affected);
39//! ```
40//!
41//! ## Implementation Notes
42//!
43//! The bulk load protocol uses:
44//! - Packet type 0x07 (BulkLoad)
45//! - COLMETADATA token describing column structure
46//! - ROW tokens containing actual data
47//! - DONE token signaling completion
48//!
49//! Per MS-TDS specification, the row data format matches the server output format
50//! (same as SELECT results) rather than storage format.
51
52use bytes::{BufMut, BytesMut};
53use std::sync::Arc;
54
55use mssql_types::{SqlValue, ToSql, TypeError};
56use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
57use tds_protocol::token::{DoneStatus, TokenType};
58
59use crate::error::Error;
60
/// Options controlling bulk insert behavior.
///
/// These options map to SQL Server's BULK INSERT hints and
/// affect performance, logging, and constraint checking.
/// Each flag is rendered into the `WITH (...)` clause of the
/// `INSERT BULK` statement built by `BulkInsertBuilder::build_insert_bulk_statement`.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Number of rows per batch commit (ROWS_PER_BATCH hint when > 0).
    ///
    /// Smaller batches use less memory but have more overhead.
    /// Larger batches are more efficient but hold locks longer.
    /// Default: 0 (single batch for entire operation).
    pub batch_size: usize,

    /// Check constraints during insert (CHECK_CONSTRAINTS hint).
    ///
    /// Default: true
    pub check_constraints: bool,

    /// Fire INSERT triggers on the table (FIRE_TRIGGERS hint).
    ///
    /// Default: false (better performance)
    pub fire_triggers: bool,

    /// Keep NULL values instead of using column defaults (KEEP_NULLS hint).
    ///
    /// Default: true
    pub keep_nulls: bool,

    /// Acquire a table-level lock for the duration of the bulk operation
    /// (TABLOCK hint).
    ///
    /// This can significantly improve performance by reducing lock
    /// escalation overhead, but blocks all other access to the table.
    /// Default: false
    pub table_lock: bool,

    /// Order hint for the data being inserted (ORDER(...) hint).
    ///
    /// If data is pre-sorted by the clustered index, specify the columns
    /// here to avoid a sort operation on the server.
    /// Default: None
    pub order_hint: Option<Vec<String>>,

    /// Maximum errors allowed before aborting.
    ///
    /// Default: 0 (abort on first error)
    pub max_errors: u32,
}
108
109impl Default for BulkOptions {
110    fn default() -> Self {
111        Self {
112            batch_size: 0,
113            check_constraints: true,
114            fire_triggers: false,
115            keep_nulls: true,
116            table_lock: false,
117            order_hint: None,
118            max_errors: 0,
119        }
120    }
121}
122
/// Column definition for bulk insert.
///
/// The public fields describe the column as the caller sees it; the
/// private fields hold the TDS wire-format details derived from
/// `sql_type` by `parse_sql_type` in [`BulkColumn::new`].
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name.
    pub name: String,
    /// SQL Server type (e.g., "INT", "NVARCHAR(100)").
    pub sql_type: String,
    /// Whether the column allows NULL values.
    pub nullable: bool,
    /// Column ordinal (0-based).
    pub ordinal: usize,
    /// TDS type ID (e.g. 0x38 for INT, 0xE7 for NVARCHAR).
    type_id: u8,
    /// Maximum length for variable-length types, in bytes
    /// (N-types store twice the character count; 0xFFFF marks MAX).
    max_length: Option<u32>,
    /// Precision for decimal types.
    precision: Option<u8>,
    /// Scale for decimal types.
    scale: Option<u8>,
}
143
144impl BulkColumn {
145    /// Create a new bulk column definition.
146    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Self {
147        let sql_type_str: String = sql_type.into();
148        let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
149
150        Self {
151            name: name.into(),
152            sql_type: sql_type_str,
153            nullable: true,
154            ordinal,
155            type_id,
156            max_length,
157            precision,
158            scale,
159        }
160    }
161
162    /// Set whether this column allows NULL values.
163    #[must_use]
164    pub fn with_nullable(mut self, nullable: bool) -> Self {
165        self.nullable = nullable;
166        self
167    }
168}
169
/// Parse SQL type string into TDS type information.
///
/// Returns `(tds_type_id, max_length_bytes, precision, scale)`.
/// Matching is case-insensitive. Parenthesized parameters such as
/// `NVARCHAR(100)` or `DECIMAL(10, 2)` are parsed, and surrounding
/// whitespace is tolerated (`"NVARCHAR ( 100 )"` works). Unknown
/// types fall back to NVARCHAR(4000).
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.to_uppercase();

    // Extract base type and parameters. Trim whitespace on both pieces
    // so inputs like "NVARCHAR (100)" or "varchar( 50 )" parse instead
    // of silently falling back to the default length.
    let (base, params) = if let Some(paren_pos) = upper.find('(') {
        let base = upper[..paren_pos].trim_end();
        let params_str = upper[paren_pos + 1..].trim_end_matches(')').trim();
        (base, Some(params_str))
    } else {
        (upper.as_str(), None)
    };

    match base {
        "BIT" => (0x32, None, None, None),
        "TINYINT" => (0x30, None, None, None),
        "SMALLINT" => (0x34, None, None, None),
        "INT" => (0x38, None, None, None),
        "BIGINT" => (0x7F, None, None, None),
        "REAL" => (0x3B, None, None, None),
        "FLOAT" => (0x3E, None, None, None),
        "DATE" => (0x28, None, None, None),
        "TIME" => {
            // Optional fractional-seconds scale; SQL Server default is 7.
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x29, None, None, Some(scale))
        }
        "DATETIME" => (0x3D, None, None, None),
        "DATETIME2" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x2A, None, None, Some(scale))
        }
        "DATETIMEOFFSET" => {
            let scale = params.and_then(|p| p.parse().ok()).unwrap_or(7);
            (0x2B, None, None, Some(scale))
        }
        "SMALLDATETIME" => (0x3F, None, None, None),
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => {
            // Length is in bytes; MAX is encoded as the 0xFFFF marker.
            let len = params
                .and_then(|p| {
                    if p == "MAX" {
                        Some(0xFFFF_u32)
                    } else {
                        p.parse().ok()
                    }
                })
                .unwrap_or(8000);
            (0xA7, Some(len), None, None)
        }
        "NVARCHAR" | "NCHAR" => {
            let is_max = params.map(|p| p == "MAX").unwrap_or(false);
            if is_max {
                // MAX types use 0xFFFF marker (not doubled)
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Normal lengths are in characters, double for UTF-16 byte length
                let len = params.and_then(|p| p.parse().ok()).unwrap_or(4000);
                (0xE7, Some(len * 2), None, None)
            }
        }
        "VARBINARY" | "BINARY" => {
            let len = params
                .and_then(|p| {
                    if p == "MAX" {
                        Some(0xFFFF_u32)
                    } else {
                        p.parse().ok()
                    }
                })
                .unwrap_or(8000);
            (0xA5, Some(len), None, None)
        }
        "DECIMAL" | "NUMERIC" => {
            // Parameters are "precision" or "precision, scale";
            // SQL Server defaults are (18, 0).
            let (precision, scale) = if let Some(p) = params {
                let parts: Vec<&str> = p.split(',').map(|s| s.trim()).collect();
                (
                    parts.first().and_then(|s| s.parse().ok()).unwrap_or(18),
                    parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
                )
            } else {
                (18, 0)
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x3C, Some(8), None, None),
        "SMALLMONEY" => (0x7A, Some(4), None, None),
        "XML" => (0xF1, Some(0xFFFF), None, None),
        "TEXT" => (0x23, Some(0x7FFF_FFFF), None, None),
        "NTEXT" => (0x63, Some(0x7FFF_FFFF), None, None),
        "IMAGE" => (0x22, Some(0x7FFF_FFFF), None, None),
        _ => (0xE7, Some(8000), None, None), // Default to NVARCHAR(4000)
    }
}
263
/// Result of a bulk insert operation.
///
/// Returned when the operation completes; summarizes what was sent
/// and committed.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Total number of rows inserted.
    pub rows_affected: u64,
    /// Number of batches committed.
    pub batches_committed: u32,
    /// Whether any errors were encountered.
    pub has_errors: bool,
}
274
/// Builder for configuring a bulk insert operation.
///
/// Collects the target table, column definitions, and [`BulkOptions`],
/// and can render the `INSERT BULK` statement that initiates the
/// server-side bulk load.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    // Target table, e.g. "dbo.Users"; used verbatim in the statement.
    table_name: String,
    // Column definitions; empty means "let the server decide".
    columns: Vec<BulkColumn>,
    // Hints rendered into the WITH (...) clause.
    options: BulkOptions,
}
282
283impl BulkInsertBuilder {
284    /// Create a new bulk insert builder for the specified table.
285    pub fn new<S: Into<String>>(table_name: S) -> Self {
286        Self {
287            table_name: table_name.into(),
288            columns: Vec::new(),
289            options: BulkOptions::default(),
290        }
291    }
292
293    /// Specify the columns to insert.
294    ///
295    /// Columns will be queried from the server if not specified,
296    /// but providing them explicitly is more efficient.
297    #[must_use]
298    pub fn with_columns(mut self, column_names: &[&str]) -> Self {
299        self.columns = column_names
300            .iter()
301            .enumerate()
302            .map(|(i, name)| BulkColumn::new(*name, "NVARCHAR(MAX)", i))
303            .collect();
304        self
305    }
306
307    /// Specify columns with full type information.
308    #[must_use]
309    pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
310        self.columns = columns;
311        self
312    }
313
314    /// Set bulk insert options.
315    #[must_use]
316    pub fn with_options(mut self, options: BulkOptions) -> Self {
317        self.options = options;
318        self
319    }
320
321    /// Set the batch size.
322    #[must_use]
323    pub fn batch_size(mut self, size: usize) -> Self {
324        self.options.batch_size = size;
325        self
326    }
327
328    /// Enable or disable table lock.
329    #[must_use]
330    pub fn table_lock(mut self, enabled: bool) -> Self {
331        self.options.table_lock = enabled;
332        self
333    }
334
335    /// Enable or disable trigger firing.
336    #[must_use]
337    pub fn fire_triggers(mut self, enabled: bool) -> Self {
338        self.options.fire_triggers = enabled;
339        self
340    }
341
342    /// Get the table name.
343    pub fn table_name(&self) -> &str {
344        &self.table_name
345    }
346
347    /// Get the columns.
348    pub fn columns(&self) -> &[BulkColumn] {
349        &self.columns
350    }
351
352    /// Get the options.
353    pub fn options(&self) -> &BulkOptions {
354        &self.options
355    }
356
357    /// Build the INSERT BULK SQL statement.
358    pub fn build_insert_bulk_statement(&self) -> String {
359        let mut sql = format!("INSERT BULK {}", self.table_name);
360
361        // Add column definitions
362        if !self.columns.is_empty() {
363            sql.push_str(" (");
364            let cols: Vec<String> = self
365                .columns
366                .iter()
367                .map(|c| format!("{} {}", c.name, c.sql_type))
368                .collect();
369            sql.push_str(&cols.join(", "));
370            sql.push(')');
371        }
372
373        // Add WITH clause for options
374        let mut hints: Vec<String> = Vec::new();
375
376        if self.options.check_constraints {
377            hints.push("CHECK_CONSTRAINTS".to_string());
378        }
379        if self.options.fire_triggers {
380            hints.push("FIRE_TRIGGERS".to_string());
381        }
382        if self.options.keep_nulls {
383            hints.push("KEEP_NULLS".to_string());
384        }
385        if self.options.table_lock {
386            hints.push("TABLOCK".to_string());
387        }
388        if self.options.batch_size > 0 {
389            hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
390        }
391
392        if let Some(ref order) = self.options.order_hint {
393            hints.push(format!("ORDER({})", order.join(", ")));
394        }
395
396        if !hints.is_empty() {
397            sql.push_str(" WITH (");
398            sql.push_str(&hints.join(", "));
399            sql.push(')');
400        }
401
402        sql
403    }
404}
405
/// Active bulk insert operation.
///
/// This struct manages the streaming of row data to the server.
/// Call `send_row()` for each row, then `finish()` to complete.
/// Token bytes (COLMETADATA, ROW, DONE) accumulate in `buffer` and are
/// framed into TDS packets by `take_packets()`.
pub struct BulkInsert {
    /// Column metadata. `Arc<[_]>` makes clones O(1) when the list must
    /// be detached from `self` during encoding.
    columns: Arc<[BulkColumn]>,
    /// Buffer for accumulating token bytes.
    buffer: BytesMut,
    /// Rows in current batch.
    rows_in_batch: usize,
    /// Total rows sent.
    total_rows: u64,
    /// Batch size (0 = single batch).
    batch_size: usize,
    /// Number of batches committed.
    batches_committed: u32,
    /// Packet ID counter; incremented (wrapping) per packet emitted.
    packet_id: u8,
}
426
427impl BulkInsert {
428    /// Create a new bulk insert operation.
429    pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
430        let mut bulk = Self {
431            columns: columns.into(),
432            buffer: BytesMut::with_capacity(64 * 1024), // 64KB initial buffer
433            rows_in_batch: 0,
434            total_rows: 0,
435            batch_size,
436            batches_committed: 0,
437            packet_id: 1,
438        };
439
440        // Write COLMETADATA token
441        bulk.write_colmetadata();
442
443        bulk
444    }
445
    /// Write the COLMETADATA token to the buffer.
    ///
    /// Layout per column: 4-byte user type, 2-byte flags, 1-byte TDS
    /// type id, type-specific length/precision/scale info, then the
    /// column name as B_VARCHAR (1-byte character count + UTF-16LE).
    fn write_colmetadata(&mut self) {
        let buf = &mut self.buffer;

        // Token type
        buf.put_u8(TokenType::ColMetaData as u8);

        // Column count
        buf.put_u16_le(self.columns.len() as u16);

        for col in self.columns.iter() {
            // User type (always 0 for basic types)
            buf.put_u32_le(0);

            // Flags: Nullable (bit 0) | CaseSen (bit 1) | Updateable (bits 2-3) | etc.
            // Only the nullable bit is ever set here.
            let flags: u16 = if col.nullable { 0x0001 } else { 0x0000 };
            buf.put_u16_le(flags);

            // Type info
            buf.put_u8(col.type_id);

            // Type-specific length/precision/scale
            match col.type_id {
                // Fixed-length types - no additional info needed
                0x32 | 0x30 | 0x34 | 0x38 | 0x7F | 0x3B | 0x3E | 0x3D | 0x3F | 0x28 => {}

                // Variable-length string/binary types
                0xE7 | 0xA7 | 0xA5 | 0xAD => {
                    // 2-byte max length; MAX types carry the 0xFFFF
                    // marker in the same 2-byte field (both branches
                    // below emit two bytes).
                    let max_len = col.max_length.unwrap_or(8000);
                    if max_len == 0xFFFF {
                        buf.put_u16_le(0xFFFF);
                    } else {
                        buf.put_u16_le(max_len as u16);
                    }

                    // Collation for string types (5 bytes)
                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
                        // Default collation (Latin1_General_CI_AS)
                        buf.put_u32_le(0x0409_0904); // LCID + flags
                        buf.put_u8(52); // Sort ID
                    }
                }

                // Decimal/Numeric
                0x6C | 0x6A => {
                    // Length (calculated from precision)
                    let precision = col.precision.unwrap_or(18);
                    let len = decimal_byte_length(precision);
                    buf.put_u8(len);
                    buf.put_u8(precision);
                    buf.put_u8(col.scale.unwrap_or(0));
                }

                // Time-based with scale (TIME, DATETIME2, DATETIMEOFFSET)
                0x29..=0x2B => {
                    buf.put_u8(col.scale.unwrap_or(7));
                }

                // GUID — fixed 16-byte length
                0x24 => {
                    buf.put_u8(16);
                }

                // Other types - write max length if present.
                // NOTE(review): legacy LOB ids produced by parse_sql_type
                // (0x23 TEXT, 0x63 NTEXT, 0x22 IMAGE) land here with
                // max_length = 0x7FFF_FFFF, so NO length bytes are
                // written; MS-TDS LONGLEN types need a 4-byte length —
                // TODO confirm against the spec.
                _ => {
                    if let Some(len) = col.max_length {
                        if len <= 0xFFFF {
                            buf.put_u16_le(len as u16);
                        }
                    }
                }
            }

            // Column name (B_VARCHAR format: 1-byte length prefix).
            // NOTE(review): `as u8` silently truncates names longer than
            // 255 UTF-16 units — unlikely, but worth a guard upstream.
            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
            buf.put_u8(name_utf16.len() as u8);
            for code_unit in name_utf16 {
                buf.put_u16_le(code_unit);
            }
        }
    }
528
529    /// Send a row of data.
530    ///
531    /// The values must match the column order and types specified
532    /// when creating the bulk insert.
533    ///
534    /// # Errors
535    ///
536    /// Returns an error if:
537    /// - Wrong number of values provided
538    /// - A value cannot be converted to the expected type
539    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
540        if values.len() != self.columns.len() {
541            return Err(Error::Config(format!(
542                "expected {} values, got {}",
543                self.columns.len(),
544                values.len()
545            )));
546        }
547
548        // Convert all values to SqlValue
549        let sql_values: Result<Vec<SqlValue>, TypeError> =
550            values.iter().map(|v| v.to_sql()).collect();
551        let sql_values = sql_values.map_err(Error::from)?;
552
553        self.write_row(&sql_values)?;
554
555        self.rows_in_batch += 1;
556        self.total_rows += 1;
557
558        Ok(())
559    }
560
561    /// Send a row of pre-converted SQL values.
562    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
563        if values.len() != self.columns.len() {
564            return Err(Error::Config(format!(
565                "expected {} values, got {}",
566                self.columns.len(),
567                values.len()
568            )));
569        }
570
571        self.write_row(values)?;
572
573        self.rows_in_batch += 1;
574        self.total_rows += 1;
575
576        Ok(())
577    }
578
579    /// Write a ROW token to the buffer.
580    fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
581        // ROW token type
582        self.buffer.put_u8(TokenType::Row as u8);
583
584        // Collect column info needed for encoding to avoid borrow conflict
585        let columns: Vec<_> = self.columns.iter().cloned().collect();
586
587        // Write each column value
588        for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
589            self.encode_column_value(col, value)
590                .map_err(|e| Error::Config(format!("failed to encode column {}: {}", i, e)))?;
591        }
592
593        Ok(())
594    }
595
    /// Encode a column value according to its type.
    ///
    /// NULLs, fixed-width scalars, strings, binary, and the
    /// feature-gated types each get their own wire format; MAX-typed
    /// string/binary columns use PLP encoding instead of a 2-byte
    /// length prefix.
    fn encode_column_value(&mut self, col: &BulkColumn, value: &SqlValue) -> Result<(), TypeError> {
        let buf = &mut self.buffer;

        // Check if this column uses PLP (Partially Length-Prefixed) encoding
        // MAX types (max_length == 0xFFFF) use PLP format
        let is_plp_type =
            col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);

        match value {
            SqlValue::Null => {
                // NULL encoding depends on type
                match col.type_id {
                    // Variable-length types
                    0xE7 | 0xA7 | 0xA5 | 0xAD => {
                        if is_plp_type {
                            // PLP NULL: 0xFFFFFFFFFFFFFFFF
                            buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
                        } else {
                            // Standard NULL: 0xFFFF length marker
                            buf.put_u16_le(0xFFFF);
                        }
                    }
                    // Nullable fixed types use 0 length
                    0x26 | 0x6C | 0x6A | 0x24 | 0x29 | 0x2A | 0x2B => {
                        buf.put_u8(0);
                    }
                    // Fixed types without nullable variant
                    _ => {
                        if col.nullable {
                            buf.put_u8(0);
                        } else {
                            return Err(TypeError::UnexpectedNull);
                        }
                    }
                }
            }

            // NOTE(review): all fixed-width scalar arms below write a
            // 1-byte length prefix before the value, which is the
            // nullable-variant (INTN/FLTN-style) wire format. But
            // write_colmetadata declares the fixed type ids (0x38 etc.),
            // whose values per MS-TDS carry no length prefix — confirm
            // against the spec that the server accepts this combination.
            SqlValue::Bool(v) => {
                buf.put_u8(1); // Length
                buf.put_u8(if *v { 1 } else { 0 });
            }

            SqlValue::TinyInt(v) => {
                buf.put_u8(1); // Length
                buf.put_u8(*v);
            }

            SqlValue::SmallInt(v) => {
                buf.put_u8(2); // Length
                buf.put_i16_le(*v);
            }

            SqlValue::Int(v) => {
                buf.put_u8(4); // Length
                buf.put_i32_le(*v);
            }

            SqlValue::BigInt(v) => {
                buf.put_u8(8); // Length
                buf.put_i64_le(*v);
            }

            SqlValue::Float(v) => {
                buf.put_u8(4); // Length
                buf.put_f32_le(*v);
            }

            SqlValue::Double(v) => {
                buf.put_u8(8); // Length
                buf.put_f64_le(*v);
            }

            SqlValue::String(s) => {
                // UTF-16LE encoding for NVARCHAR
                let utf16: Vec<u16> = s.encode_utf16().collect();
                let byte_len = utf16.len() * 2;

                if is_plp_type {
                    // PLP format for MAX types - supports unlimited size
                    // Send as a single chunk for simplicity
                    encode_plp_string(&utf16, buf);
                } else if byte_len > 0xFFFF {
                    // Non-MAX column can't hold this much data
                    return Err(TypeError::BufferTooSmall {
                        needed: byte_len,
                        available: 0xFFFF,
                    });
                } else {
                    // Standard encoding with 2-byte length prefix
                    buf.put_u16_le(byte_len as u16);
                    for code_unit in utf16 {
                        buf.put_u16_le(code_unit);
                    }
                }
            }

            SqlValue::Binary(b) => {
                if is_plp_type {
                    // PLP format for MAX types - supports unlimited size
                    encode_plp_binary(b, buf);
                } else if b.len() > 0xFFFF {
                    // Non-MAX column can't hold this much data
                    return Err(TypeError::BufferTooSmall {
                        needed: b.len(),
                        available: 0xFFFF,
                    });
                } else {
                    // Standard encoding with 2-byte length prefix
                    buf.put_u16_le(b.len() as u16);
                    buf.put_slice(b);
                }
            }

            // Feature-gated types - use mssql_types::encode module
            #[cfg(feature = "decimal")]
            SqlValue::Decimal(d) => {
                let precision = col.precision.unwrap_or(18);
                let len = decimal_byte_length(precision);
                buf.put_u8(len);

                // Sign: 0 = negative, 1 = positive
                buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });

                // Mantissa as unsigned 128-bit integer, little-endian.
                // NOTE(review): taking only the low (len - 1) bytes
                // silently truncates a mantissa too large for the
                // declared precision — verify values are range-checked
                // upstream.
                let mantissa = d.mantissa().unsigned_abs();
                let mantissa_bytes = mantissa.to_le_bytes();
                buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
            }

            #[cfg(feature = "uuid")]
            SqlValue::Uuid(u) => {
                buf.put_u8(16); // Length
                // Use mssql_types encode function
                mssql_types::encode::encode_uuid(*u, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::Date(d) => {
                buf.put_u8(3); // Length — DATE is a 3-byte day count
                mssql_types::encode::encode_date(*d, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::Time(t) => {
                // Byte width depends on the column's fractional-second scale.
                let scale = col.scale.unwrap_or(7);
                let len = time_byte_length(scale);
                buf.put_u8(len);
                // Encode time with proper scale handling
                encode_time_with_scale(*t, scale, buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::DateTime(dt) => {
                let scale = col.scale.unwrap_or(7);
                let time_len = time_byte_length(scale);
                // DATETIME2 = scaled time followed by 3-byte date.
                let total_len = time_len + 3;
                buf.put_u8(total_len);
                // Encode time then date
                encode_time_with_scale(dt.time(), scale, buf);
                mssql_types::encode::encode_date(dt.date(), buf);
            }

            #[cfg(feature = "chrono")]
            SqlValue::DateTimeOffset(dto) => {
                let scale = col.scale.unwrap_or(7);
                let time_len = time_byte_length(scale);
                // DATETIMEOFFSET = time + 3-byte date + 2-byte offset.
                let total_len = time_len + 3 + 2;
                buf.put_u8(total_len);
                // Use mssql_types encode
                encode_time_with_scale(dto.time(), scale, buf);
                mssql_types::encode::encode_date(dto.date_naive(), buf);
                // Timezone offset in minutes
                use chrono::Offset;
                let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
                buf.put_i16_le(offset_minutes);
            }

            #[cfg(feature = "json")]
            SqlValue::Json(j) => {
                // JSON is serialized and shipped as NVARCHAR text.
                let s = j.to_string();
                encode_nvarchar_value(&s, buf)?;
            }

            SqlValue::Xml(x) => {
                encode_nvarchar_value(x, buf)?;
            }

            SqlValue::Tvp(_) => {
                // TVPs are not valid in bulk copy operations - they're for RPC parameters only
                return Err(TypeError::UnsupportedConversion {
                    from: "TVP".to_string(),
                    to: "bulk copy value",
                });
            }
            // Handle future SqlValue variants
            _ => {
                return Err(TypeError::UnsupportedConversion {
                    from: value.type_name().to_string(),
                    to: "bulk copy value",
                });
            }
        }

        Ok(())
    }
802}
803
804/// Encode a string as NVARCHAR with length prefix.
805fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
806    let utf16: Vec<u16> = s.encode_utf16().collect();
807    let byte_len = utf16.len() * 2;
808
809    if byte_len > 0xFFFF {
810        return Err(TypeError::BufferTooSmall {
811            needed: byte_len,
812            available: 0xFFFF,
813        });
814    }
815
816    buf.put_u16_le(byte_len as u16);
817    for code_unit in utf16 {
818        buf.put_u16_le(code_unit);
819    }
820    Ok(())
821}
822
823/// Encode a UTF-16 string using PLP (Partially Length-Prefixed) format.
824///
825/// PLP format (per MS-TDS specification):
826/// - 8 bytes: total length in bytes (little-endian)
827/// - Chunks: 4-byte chunk length + data, repeated
828/// - Terminator: 4 bytes of zero
829///
830/// For simplicity, we send the entire value as a single chunk.
831/// This is efficient for bulk operations where we already have the complete data.
832fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
833    let byte_len = utf16.len() * 2;
834
835    // Total length (8 bytes)
836    buf.put_u64_le(byte_len as u64);
837
838    if byte_len > 0 {
839        // Single chunk: length (4 bytes) + data
840        buf.put_u32_le(byte_len as u32);
841        for code_unit in utf16 {
842            buf.put_u16_le(*code_unit);
843        }
844    }
845
846    // Terminator chunk (length = 0)
847    buf.put_u32_le(0);
848}
849
850/// Encode binary data using PLP (Partially Length-Prefixed) format.
851///
852/// PLP format (per MS-TDS specification):
853/// - 8 bytes: total length in bytes (little-endian)
854/// - Chunks: 4-byte chunk length + data, repeated
855/// - Terminator: 4 bytes of zero
856///
857/// For simplicity, we send the entire value as a single chunk.
858fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
859    // Total length (8 bytes)
860    buf.put_u64_le(data.len() as u64);
861
862    if !data.is_empty() {
863        // Single chunk: length (4 bytes) + data
864        buf.put_u32_le(data.len() as u32);
865        buf.put_slice(data);
866    }
867
868    // Terminator chunk (length = 0)
869    buf.put_u32_le(0);
870}
871
/// Encode time with specific scale (for bulk copy).
///
/// The time-of-day is converted to the number of `10^(-scale)`-second
/// intervals since midnight, then written little-endian in the
/// variable byte width dictated by the scale.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let total_nanos =
        u64::from(time.num_seconds_from_midnight()) * 1_000_000_000 + u64::from(time.nanosecond());
    let ticks = total_nanos / time_scale_divisor(scale);

    // Emit `time_byte_length(scale)` bytes, least significant first.
    let width = time_byte_length(scale);
    for byte_idx in 0..width {
        buf.put_u8((ticks >> (byte_idx * 8)) as u8);
    }
}
885
886impl BulkInsert {
887    /// Write the DONE token signaling completion.
888    fn write_done(&mut self) {
889        let buf = &mut self.buffer;
890
891        buf.put_u8(TokenType::Done as u8);
892
893        // Status: FINAL (0x00) | COUNT (0x10)
894        let status = DoneStatus {
895            more: false,
896            error: false,
897            in_xact: false,
898            count: true,
899            attn: false,
900            srverror: false,
901        };
902        buf.put_u16_le(status.to_bits());
903
904        // Current command (0 for bulk load)
905        buf.put_u16_le(0);
906
907        // Row count
908        buf.put_u64_le(self.total_rows);
909    }
910
911    /// Get the buffered data as packets ready to send.
912    ///
913    /// Returns a vector of complete TDS packets with BulkLoad packet type (0x07).
914    pub fn take_packets(&mut self) -> Vec<BytesMut> {
915        const MAX_PACKET_SIZE: usize = 4096;
916        const HEADER_SIZE: usize = 8;
917        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
918
919        let data = self.buffer.split();
920        let mut packets = Vec::new();
921        let mut offset = 0;
922
923        while offset < data.len() {
924            let remaining = data.len() - offset;
925            let payload_size = remaining.min(MAX_PAYLOAD);
926            let is_last = offset + payload_size >= data.len();
927
928            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
929
930            // Write packet header
931            let header = PacketHeader {
932                packet_type: PacketType::BulkLoad,
933                status: if is_last {
934                    PacketStatus::END_OF_MESSAGE
935                } else {
936                    PacketStatus::NORMAL
937                },
938                length: (HEADER_SIZE + payload_size) as u16,
939                spid: 0,
940                packet_id: self.packet_id,
941                window: 0,
942            };
943
944            header.encode(&mut packet);
945
946            // Write payload
947            packet.put_slice(&data[offset..offset + payload_size]);
948
949            packets.push(packet);
950            offset += payload_size;
951            self.packet_id = self.packet_id.wrapping_add(1);
952        }
953
954        packets
955    }
956
957    /// Get total rows sent so far.
958    pub fn total_rows(&self) -> u64 {
959        self.total_rows
960    }
961
962    /// Get rows in current batch.
963    pub fn rows_in_batch(&self) -> usize {
964        self.rows_in_batch
965    }
966
967    /// Check if a batch flush is needed.
968    pub fn should_flush(&self) -> bool {
969        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
970    }
971
972    /// Prepare for finishing the bulk operation.
973    /// Writes the DONE token and returns final packets.
974    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
975        self.write_done();
976        self.take_packets()
977    }
978
979    /// Create a result from the current state.
980    pub fn result(&self) -> BulkInsertResult {
981        BulkInsertResult {
982            rows_affected: self.total_rows,
983            batches_committed: self.batches_committed,
984            has_errors: false,
985        }
986    }
987}
988
/// Wire byte length for a DECIMAL/NUMERIC value of the given precision:
/// one sign byte plus 4, 8, 12, or 16 magnitude bytes.
///
/// Out-of-range inputs (0 or > 38) fall back to the widest encoding.
fn decimal_byte_length(precision: u8) -> u8 {
    match precision {
        // 0 is not a valid precision; mirror the widest-encoding fallback.
        0 => 17,
        p if p <= 9 => 5,
        p if p <= 19 => 9,
        p if p <= 28 => 13,
        _ => 17,
    }
}
999
/// Wire byte length of a `time(scale)` value: 3, 4, or 5 bytes depending on
/// how many fractional-second digits the scale allows. Out-of-range scales
/// use the widest (5-byte) form.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1010
/// Number of nanoseconds represented by one tick at the given `time` scale.
///
/// Scale 0 ticks are whole seconds (1e9 ns); each additional digit of scale
/// divides the tick by 10, down to 100 ns at scale 7. Out-of-range scales
/// clamp to the scale-7 divisor.
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    if scale <= 7 {
        10u64.pow(9 - u32::from(scale))
    } else {
        100
    }
}
1026
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_bulk_options_default() {
        let defaults = BulkOptions::default();

        // Defaults: unlimited batch, constraints checked, triggers off,
        // NULLs kept, no table lock.
        assert_eq!(defaults.batch_size, 0);
        assert!(defaults.check_constraints);
        assert!(!defaults.fire_triggers);
        assert!(defaults.keep_nulls);
        assert!(!defaults.table_lock);
    }

    #[test]
    fn test_bulk_column_creation() {
        let column = BulkColumn::new("id", "INT", 0);

        assert_eq!(column.name, "id");
        assert_eq!(column.type_id, 0x38); // TDS type id for INT
        assert!(column.nullable);
    }

    #[test]
    fn test_parse_sql_type() {
        // Fixed-length INT carries no length/precision/scale.
        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
        assert_eq!(type_id, 0x38);
        assert!(len.is_none());

        // NVARCHAR length is in bytes, so the declared char count doubles.
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));

        // DECIMAL carries precision and scale instead of a length.
        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
        assert_eq!(type_id, 0x6C);
        assert_eq!(prec, Some(10));
        assert_eq!(scale, Some(2));
    }

    #[test]
    fn test_insert_bulk_statement() {
        let builder = BulkInsertBuilder::new("dbo.Users")
            .with_typed_columns(vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ])
            .table_lock(true);

        let statement = builder.build_insert_bulk_statement();

        assert!(statement.contains("INSERT BULK dbo.Users"));
        assert!(statement.contains("TABLOCK"));
    }

    #[test]
    fn test_bulk_insert_creation() {
        let bulk = BulkInsert::new(
            vec![
                BulkColumn::new("id", "INT", 0),
                BulkColumn::new("name", "NVARCHAR(100)", 1),
            ],
            1000,
        );

        // A freshly-created operation has sent nothing and needs no flush.
        assert_eq!(bulk.total_rows(), 0);
        assert_eq!(bulk.rows_in_batch(), 0);
        assert!(!bulk.should_flush());
    }

    #[test]
    fn test_decimal_byte_length() {
        // One representative precision per storage-size bucket.
        for (precision, expected) in [(5, 5), (15, 9), (25, 13), (35, 17)] {
            assert_eq!(decimal_byte_length(precision), expected);
        }
    }

    #[test]
    #[cfg(feature = "chrono")]
    fn test_time_byte_length() {
        for (scale, expected) in [(0, 3), (3, 4), (7, 5)] {
            assert_eq!(time_byte_length(scale), expected);
        }
    }

    #[test]
    fn test_plp_string_encoding() {
        let utf16: Vec<u16> = "Hello".encode_utf16().collect();
        let mut buf = BytesMut::new();

        encode_plp_string(&utf16, &mut buf);

        // Layout: 8-byte total + 4-byte chunk length + data + 4-byte terminator.
        let data_len = 10usize; // 5 UTF-16 code units * 2 bytes each
        assert_eq!(buf.len(), 8 + 4 + data_len + 4);

        // Total length field.
        assert_eq!(&buf[0..8], &(data_len as u64).to_le_bytes());

        // Chunk length field.
        assert_eq!(&buf[8..12], &(data_len as u32).to_le_bytes());

        // Zero-length terminator chunk.
        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_binary_encoding() {
        let payload = b"test binary data"; // 16 bytes
        let mut buf = BytesMut::new();

        encode_plp_binary(payload, &mut buf);

        // Layout: 8-byte total + 4-byte chunk length + data + 4-byte terminator.
        assert_eq!(buf.len(), 8 + 4 + 16 + 4);

        // Total length field.
        assert_eq!(&buf[0..8], &16u64.to_le_bytes());

        // Chunk length field.
        assert_eq!(&buf[8..12], &16u32.to_le_bytes());

        // Raw payload bytes.
        assert_eq!(&buf[12..28], payload);

        // Zero-length terminator chunk.
        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_string() {
        let empty: Vec<u16> = String::new().encode_utf16().collect();
        let mut buf = BytesMut::new();

        encode_plp_string(&empty, &mut buf);

        // No data chunk at all: just the zero total length plus terminator.
        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_plp_empty_binary() {
        let mut buf = BytesMut::new();

        encode_plp_binary(&[], &mut buf);

        // No data chunk at all: just the zero total length plus terminator.
        assert_eq!(buf.len(), 8 + 4);
        assert_eq!(&buf[0..8], &0u64.to_le_bytes());
        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
    }

    #[test]
    fn test_parse_sql_type_max() {
        // MAX variants use the 0xFFFF marker and skip UTF-16 doubling.
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
        assert_eq!(type_id, 0xA5);
        assert_eq!(len, Some(0xFFFF));

        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
        assert_eq!(type_id, 0xA7);
        assert_eq!(len, Some(0xFFFF));

        // Sanity check: a bounded NVARCHAR still doubles its char count.
        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
        assert_eq!(type_id, 0xE7);
        assert_eq!(len, Some(200));
    }
}