Skip to main content

mssql_client/
bulk.rs

1//! Bulk Copy Protocol (BCP) support.
2//!
3//! This module provides first-class support for bulk insert operations using
4//! the TDS Bulk Load protocol (packet type 0x07). BCP is significantly more
5//! efficient than individual INSERT statements for loading large amounts of data.
6//!
7//! ## Performance Benefits
8//!
9//! - Minimal logging (when using simple recovery model)
10//! - Batch commits reduce lock contention
11//! - Direct data streaming without SQL parsing overhead
12//! - Optional table lock for maximum throughput
13//!
14//! ## Usage
15//!
16//! ```text
17//! use mssql_client::{BulkInsertBuilder, BulkColumn, BulkOptions};
18//!
19//! let builder = BulkInsertBuilder::new("dbo.Users")
20//!     .with_typed_columns(vec![
21//!         BulkColumn::new("id", "INT", 0)?,
22//!         BulkColumn::new("name", "NVARCHAR(100)", 1)?,
23//!         BulkColumn::new("email", "NVARCHAR(200)", 2)?,
24//!     ])
25//!     .with_options(BulkOptions {
26//!         batch_size: 1000,
27//!         check_constraints: true,
28//!         fire_triggers: false,
29//!         keep_nulls: true,
30//!         table_lock: true,
31//!         order_hint: None,
32//!     });
33//!
34//! let mut writer = client.bulk_insert(&builder).await?;
35//!
36//! // Send rows — buffered in memory, sent on finish()
37//! for user in users {
38//!     writer.send_row(&[&user.id, &user.name, &user.email])?;
39//! }
40//!
41//! let result = writer.finish().await?;
42//! println!("Inserted {} rows", result.rows_affected);
43//! ```
44//!
45//! ## Implementation Notes
46//!
47//! The bulk load protocol uses:
48//! - Packet type 0x07 (BulkLoad)
49//! - COLMETADATA token describing column structure
50//! - ROW tokens containing actual data
51//! - DONE token signaling completion
52//!
53//! Per MS-TDS specification, the row data format matches the server output format
54//! (same as SELECT results) rather than storage format.
55
56use bytes::{BufMut, BytesMut};
57use once_cell::sync::Lazy;
58use regex::Regex;
59use std::sync::Arc;
60
61use mssql_types::{SqlValue, ToSql, TypeError};
62use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
63use tds_protocol::token::{Collation, DoneStatus, TokenType};
64
65use crate::error::Error;
66
/// Tunables for a bulk insert.
///
/// Each boolean field corresponds to one hint in the `WITH (...)` clause of
/// the generated `INSERT BULK` statement; together they trade off speed,
/// logging, and constraint enforcement.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Rows-per-batch hint forwarded to the server.
    ///
    /// A non-zero value is emitted as `ROWS_PER_BATCH` on the `INSERT BULK`
    /// statement to help the server choose a query plan. Client-side
    /// behavior is unaffected: [`BulkWriter`] keeps every row in memory and
    /// ships them as one batch from [`finish()`](BulkWriter::finish).
    /// Incremental flushing is planned alongside the response-streaming work.
    /// Default: 0 (no hint).
    pub batch_size: usize,

    /// Enforce CHECK constraints while inserting.
    ///
    /// Default: true
    pub check_constraints: bool,

    /// Run the table's INSERT triggers.
    ///
    /// Default: false (better performance)
    pub fire_triggers: bool,

    /// Preserve NULL values rather than substituting column defaults.
    ///
    /// Default: true
    pub keep_nulls: bool,

    /// Hold a table-level lock while the bulk operation runs.
    ///
    /// Reduces lock-escalation overhead and can speed the load up
    /// considerably, at the cost of blocking all other access to the table.
    /// Default: false
    pub table_lock: bool,

    /// Sort order of the incoming data.
    ///
    /// When rows are already sorted by the clustered index, listing the
    /// columns here lets the server skip its own sort.
    /// Default: None
    pub order_hint: Option<Vec<String>>,
}

impl Default for BulkOptions {
    /// Returns the documented defaults: no batch hint, constraints checked,
    /// triggers off, NULLs kept, no table lock, no order hint.
    fn default() -> Self {
        let order_hint = None;
        BulkOptions {
            order_hint,
            table_lock: false,
            keep_nulls: true,
            fire_triggers: false,
            check_constraints: true,
            batch_size: 0,
        }
    }
}
126
/// Column definition for bulk insert.
///
/// The public fields carry what the caller declared; the private fields hold
/// the TDS descriptor that `BulkColumn::new` derives from `sql_type`.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name.
    pub name: String,
    /// SQL Server type (e.g., "INT", "NVARCHAR(100)").
    pub sql_type: String,
    /// Whether the column allows NULL values.
    pub nullable: bool,
    /// Column ordinal (0-based).
    pub ordinal: usize,
    /// TDS type ID.
    ///
    /// This is the nullable-variant ID derived from `sql_type`
    /// (e.g. 0x26 INTN for `INT`).
    type_id: u8,
    /// Maximum length for variable-length types.
    ///
    /// `0xFFFF` marks a `MAX` type; for NVARCHAR/NCHAR the stored value is
    /// the byte length (declared character count doubled).
    max_length: Option<u32>,
    /// Precision for decimal types.
    precision: Option<u8>,
    /// Scale for decimal types.
    ///
    /// Also holds the fractional-seconds scale for TIME / DATETIME2 /
    /// DATETIMEOFFSET columns.
    scale: Option<u8>,
    /// Collation for VARCHAR/CHAR columns.
    ///
    /// Populated automatically from the server's COLMETADATA when
    /// [`Client::bulk_insert`](crate::Client::bulk_insert) is used. Can be set
    /// manually via [`with_collation`](Self::with_collation) for the
    /// schema-discovery-free path. When `None`, VARCHAR values fall back to
    /// the default Latin1_General_CI_AS collation (Windows-1252).
    collation: Option<Collation>,
}
155
156impl BulkColumn {
157    /// Create a new bulk column definition.
158    ///
159    /// # Errors
160    ///
161    /// Returns [`TypeError::UnsupportedType`] when `sql_type` names a deprecated
162    /// large object type (`TEXT`, `NTEXT`, `IMAGE`). Use `VARCHAR(MAX)` /
163    /// `NVARCHAR(MAX)` / `VARBINARY(MAX)` instead — Microsoft deprecated
164    /// `TEXT` / `NTEXT` / `IMAGE` in SQL Server 2005 and recommends the `MAX`
165    /// types for all new development.
166    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Result<Self, TypeError> {
167        let sql_type_str: String = sql_type.into();
168        reject_unsupported_bulk_type(&sql_type_str)?;
169        let (type_id, max_length, precision, scale) =
170            parse_sql_type(&sql_type_str).ok_or_else(|| {
171                let base = sql_type_str
172                    .split('(')
173                    .next()
174                    .unwrap_or("")
175                    .trim()
176                    .to_uppercase();
177                TypeError::UnsupportedType {
178                    sql_type: base,
179                    reason: "unsupported bulk-insert column type. Supported types: \
180                             BIT, TINYINT, SMALLINT, INT, BIGINT, REAL, FLOAT, \
181                             DECIMAL/NUMERIC, MONEY, SMALLMONEY, CHAR/VARCHAR, \
182                             NCHAR/NVARCHAR (incl. MAX), BINARY/VARBINARY (incl. MAX), \
183                             UNIQUEIDENTIFIER, DATE, TIME, DATETIME, DATETIME2, \
184                             DATETIMEOFFSET, SMALLDATETIME, and XML."
185                        .to_string(),
186                }
187            })?;
188
189        Ok(Self {
190            name: name.into(),
191            sql_type: sql_type_str,
192            nullable: true,
193            ordinal,
194            type_id,
195            max_length,
196            precision,
197            scale,
198            collation: None,
199        })
200    }
201
202    /// Set whether this column allows NULL values.
203    #[must_use]
204    pub fn with_nullable(mut self, nullable: bool) -> Self {
205        self.nullable = nullable;
206        self
207    }
208
209    /// Set the collation used for VARCHAR/CHAR columns.
210    ///
211    /// Required when [`Client::bulk_insert_without_schema_discovery`](crate::Client::bulk_insert_without_schema_discovery)
212    /// targets VARCHAR columns on a server whose default collation is not
213    /// Latin1_General_CI_AS and the target column uses a different code page.
214    /// Ignored for NVARCHAR/NCHAR columns (always UTF-16).
215    #[must_use]
216    pub fn with_collation(mut self, collation: Collation) -> Self {
217        self.collation = Some(collation);
218        self
219    }
220}
221
/// Parsed TDS type descriptor: `(type_id, max_length, precision, scale)`.
type ParsedSqlType = (u8, Option<u32>, Option<u8>, Option<u8>);

/// Parse a T-SQL type string into TDS type information.
///
/// Matching is case-insensitive and tolerant of whitespace around the base
/// type and its parameters (`VARCHAR (50)`, `NVARCHAR( MAX )`).
///
/// Type parameters (e.g., the "100" in `VARCHAR(100)`) are parsed leniently:
/// a malformed parameter falls through to the type's default (8000 bytes for
/// VARCHAR/VARBINARY, 4000 characters for NVARCHAR, scale 7 for the
/// time types, precision 18 / scale 0 for DECIMAL). This is intentional —
/// column definitions come from user code, and defaulting is safer than
/// rejecting the operation when the base type is valid. An NVARCHAR/NCHAR
/// length whose doubled byte size does not fit in `u32` is treated the same
/// way, instead of overflowing.
///
/// Returns `None` for an unrecognized *base* type so the caller can report
/// an error rather than silently coercing it.
///
/// The returned type ID is the nullable variant (e.g. 0x26 INTN for `INT`).
fn parse_sql_type(sql_type: &str) -> Option<ParsedSqlType> {
    let upper = sql_type.to_uppercase();

    // Split "BASE(params)" into the trimmed base name and the trimmed text
    // between the parentheses (if any).
    let (base, params) = match upper.find('(') {
        Some(paren_pos) => (
            upper[..paren_pos].trim(),
            Some(upper[paren_pos + 1..].trim_end_matches(')').trim()),
        ),
        None => (upper.trim(), None),
    };

    // Fractional-seconds scale for TIME / DATETIME2 / DATETIMEOFFSET.
    let time_scale = || -> u8 { params.and_then(|p| p.parse().ok()).unwrap_or(7) };

    // Byte length for the single-byte string / binary families. "MAX" maps
    // to the 0xFFFF marker.
    let byte_len = || -> u32 {
        params
            .and_then(|p| {
                if p == "MAX" {
                    Some(0xFFFF_u32)
                } else {
                    p.parse().ok()
                }
            })
            .unwrap_or(8000)
    };

    let result = match base {
        "BIT" => (0x68, Some(1), None, None),      // BITN
        "TINYINT" => (0x26, Some(1), None, None),  // INTN(1)
        "SMALLINT" => (0x26, Some(2), None, None), // INTN(2)
        "INT" => (0x26, Some(4), None, None),      // INTN(4)
        "BIGINT" => (0x26, Some(8), None, None),   // INTN(8)
        "REAL" => (0x6D, Some(4), None, None),     // FLTN(4)
        "FLOAT" => (0x6D, Some(8), None, None),    // FLTN(8)
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(time_scale())),
        "DATETIME" => (0x6F, Some(8), None, None), // DATETIMEN(8)
        "DATETIME2" => (0x2A, None, None, Some(time_scale())),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(time_scale())),
        "SMALLDATETIME" => (0x6F, Some(4), None, None), // DATETIMEN(4)
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_len()), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                // MAX types use the 0xFFFF marker (not doubled).
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Declared lengths are in characters; the wire length is in
                // bytes (two per UTF-16 code unit). `checked_mul` keeps an
                // absurdly large declared length from overflowing `u32`
                // (a panic in debug builds, a silent wrap in release); such
                // a length falls back to the default like any other
                // malformed parameter.
                let len = params
                    .and_then(|p| p.parse::<u32>().ok())
                    .and_then(|chars| chars.checked_mul(2))
                    .unwrap_or(4000 * 2);
                (0xE7, Some(len), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_len()), None, None),
        "DECIMAL" | "NUMERIC" => {
            let (precision, scale) = match params {
                Some(p) => {
                    let parts: Vec<&str> = p.split(',').map(|s| s.trim()).collect();
                    (
                        parts.first().and_then(|s| s.parse().ok()).unwrap_or(18),
                        parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
                    )
                }
                None => (18, 0),
            };
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x6E, Some(8), None, None), // MONEYN(8)
        "SMALLMONEY" => (0x6E, Some(4), None, None), // MONEYN(4)
        "XML" => (0xF1, Some(0xFFFF), None, None),
        // Unknown base type: signal it to the caller instead of coercing to
        // NVARCHAR, which would put wrong bytes on the wire under a
        // mislabeled type.
        _ => return None,
    };
    Some(result)
}
335
336/// Reject deprecated large object types that this driver does not support in
337/// bulk insert. `TEXT` / `NTEXT` / `IMAGE` have been deprecated since SQL
338/// Server 2005 and use a legacy TEXTPTR wire format. Users should use
339/// `VARCHAR(MAX)` / `NVARCHAR(MAX)` / `VARBINARY(MAX)` which the driver
340/// supports end-to-end.
341fn reject_unsupported_bulk_type(sql_type: &str) -> Result<(), TypeError> {
342    let base = sql_type
343        .split('(')
344        .next()
345        .unwrap_or("")
346        .trim()
347        .to_uppercase();
348    match base.as_str() {
349        "TEXT" | "NTEXT" => Err(TypeError::UnsupportedType {
350            sql_type: base,
351            reason: "TEXT/NTEXT are not supported. Use VARCHAR(MAX) / \
352                     NVARCHAR(MAX) instead (Microsoft deprecated TEXT/NTEXT in \
353                     SQL Server 2005)."
354                .to_string(),
355        }),
356        "IMAGE" => Err(TypeError::UnsupportedType {
357            sql_type: base,
358            reason: "IMAGE is not supported. Use VARBINARY(MAX) instead \
359                     (Microsoft deprecated IMAGE in SQL Server 2005)."
360                .to_string(),
361        }),
362        _ => Ok(()),
363    }
364}
365
/// Result of a bulk insert operation.
///
/// A plain data carrier; all fields are public.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Total number of rows inserted.
    pub rows_affected: u64,
    /// Number of batches committed.
    ///
    /// [`BulkWriter`] always sends a single batch, so this is currently 1.
    pub batches_committed: u32,
    /// Whether any errors were encountered.
    pub has_errors: bool,
}
378
/// Builder for configuring a bulk insert operation.
///
/// Collects the target table, the column list, and the [`BulkOptions`] that
/// `build_insert_bulk_statement` turns into an `INSERT BULK` statement.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    /// Target table name; may be schema-qualified (e.g. `dbo.Users`).
    table_name: String,
    /// Columns to insert, in order. Empty until one of the column setters
    /// is called.
    columns: Vec<BulkColumn>,
    /// Hints applied to the generated statement.
    options: BulkOptions,
}
386
387impl BulkInsertBuilder {
388    /// Create a new bulk insert builder for the specified table.
389    pub fn new<S: Into<String>>(table_name: S) -> Self {
390        Self {
391            table_name: table_name.into(),
392            columns: Vec::new(),
393            options: BulkOptions::default(),
394        }
395    }
396
397    /// Specify the columns to insert.
398    ///
399    /// Columns will be queried from the server if not specified,
400    /// but providing them explicitly is more efficient.
401    #[must_use]
402    #[allow(clippy::expect_used)] // NVARCHAR(MAX) is always a supported bulk type
403    pub fn with_columns(mut self, column_names: &[&str]) -> Self {
404        self.columns = column_names
405            .iter()
406            .enumerate()
407            .map(|(i, name)| {
408                BulkColumn::new(*name, "NVARCHAR(MAX)", i)
409                    .expect("NVARCHAR(MAX) is always a supported type")
410            })
411            .collect();
412        self
413    }
414
415    /// Specify columns with full type information.
416    #[must_use]
417    pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
418        self.columns = columns;
419        self
420    }
421
422    /// Set bulk insert options.
423    #[must_use]
424    pub fn with_options(mut self, options: BulkOptions) -> Self {
425        self.options = options;
426        self
427    }
428
429    /// Set the batch size.
430    #[must_use]
431    pub fn batch_size(mut self, size: usize) -> Self {
432        self.options.batch_size = size;
433        self
434    }
435
436    /// Enable or disable table lock.
437    #[must_use]
438    pub fn table_lock(mut self, enabled: bool) -> Self {
439        self.options.table_lock = enabled;
440        self
441    }
442
443    /// Enable or disable trigger firing.
444    #[must_use]
445    pub fn fire_triggers(mut self, enabled: bool) -> Self {
446        self.options.fire_triggers = enabled;
447        self
448    }
449
450    /// Get the table name.
451    pub fn table_name(&self) -> &str {
452        &self.table_name
453    }
454
455    /// Get the columns.
456    pub fn columns(&self) -> &[BulkColumn] {
457        &self.columns
458    }
459
460    /// Get the options.
461    pub fn options(&self) -> &BulkOptions {
462        &self.options
463    }
464
465    /// Build the INSERT BULK SQL statement.
466    ///
467    /// # Errors
468    ///
469    /// Returns an error if the table name or any column name fails identifier
470    /// validation, preventing SQL injection.
471    pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
472        // Validate table name (may be schema-qualified: dbo.Users, catalog.schema.table)
473        crate::validation::validate_qualified_identifier(&self.table_name)?;
474
475        // Validate column names
476        for col in &self.columns {
477            crate::validation::validate_identifier(&col.name)?;
478        }
479
480        let mut sql = format!("INSERT BULK {}", self.table_name);
481
482        // Add column definitions
483        if !self.columns.is_empty() {
484            sql.push_str(" (");
485            let cols: Vec<String> = self
486                .columns
487                .iter()
488                .map(|c| {
489                    // Validate sql_type to prevent SQL injection: only allow
490                    // alphanumerics, parentheses (for length/precision), commas,
491                    // spaces, and the MAX keyword — which covers all valid T-SQL
492                    // type specifiers like "NVARCHAR(100)", "DECIMAL(18, 2)",
493                    // "VARBINARY(MAX)", etc.
494                    validate_sql_type(&c.sql_type)?;
495                    Ok(format!("{} {}", c.name, c.sql_type))
496                })
497                .collect::<Result<Vec<_>, Error>>()?;
498            sql.push_str(&cols.join(", "));
499            sql.push(')');
500        }
501
502        // Add WITH clause for options
503        let mut hints: Vec<String> = Vec::new();
504
505        if self.options.check_constraints {
506            hints.push("CHECK_CONSTRAINTS".to_string());
507        }
508        if self.options.fire_triggers {
509            hints.push("FIRE_TRIGGERS".to_string());
510        }
511        if self.options.keep_nulls {
512            hints.push("KEEP_NULLS".to_string());
513        }
514        if self.options.table_lock {
515            hints.push("TABLOCK".to_string());
516        }
517        if self.options.batch_size > 0 {
518            hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
519        }
520
521        if let Some(ref order) = self.options.order_hint {
522            // Validate order hint column names
523            for col_name in order {
524                crate::validation::validate_identifier(col_name)?;
525            }
526            hints.push(format!("ORDER({})", order.join(", ")));
527        }
528
529        if !hints.is_empty() {
530            sql.push_str(" WITH (");
531            sql.push_str(&hints.join(", "));
532            sql.push(')');
533        }
534
535        Ok(sql)
536    }
537}
538
539/// Validate a SQL type specifier to prevent SQL injection.
540///
541/// Allows only characters that can appear in valid T-SQL type declarations:
542/// letters, digits, parentheses, commas, spaces, and periods.
543/// Examples: "INT", "NVARCHAR(100)", "DECIMAL(18, 2)", "VARBINARY(MAX)".
544fn validate_sql_type(type_str: &str) -> Result<(), Error> {
545    #[allow(clippy::expect_used)] // Static regex compilation with known-valid pattern
546    static SQL_TYPE_RE: Lazy<Regex> =
547        Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
548
549    if type_str.is_empty() {
550        return Err(Error::Config("SQL type cannot be empty".into()));
551    }
552
553    if !SQL_TYPE_RE.is_match(type_str) {
554        return Err(Error::Config(format!(
555            "invalid SQL type '{type_str}': contains disallowed characters"
556        )));
557    }
558
559    Ok(())
560}
561
/// Active bulk insert operation.
///
/// This struct manages the streaming of row data to the server.
/// Call `send_row()` for each row, then `finish()` to complete.
///
/// The buffer is seeded with the COLMETADATA token at construction time;
/// each accepted row appends one ROW token after it.
pub struct BulkInsert {
    /// Column metadata.
    columns: Arc<[BulkColumn]>,
    /// Whether each column uses a fixed-length type on the wire.
    /// When true, row values for that column are written without a length prefix.
    fixed_len: Arc<[bool]>,
    /// Buffer for accumulating rows.
    buffer: BytesMut,
    /// Rows in current batch.
    rows_in_batch: usize,
    /// Total rows sent.
    total_rows: u64,
    /// Batch size (0 = single batch).
    batch_size: usize,
    /// Number of batches committed.
    batches_committed: u32,
    /// Packet ID counter.
    packet_id: u8,
}
585
586impl BulkInsert {
587    /// Create a new bulk insert operation.
588    pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
589        Self::new_with_server_metadata(columns, batch_size, None, None)
590    }
591
592    /// Create a new bulk insert operation using server metadata.
593    ///
594    /// When `raw_colmetadata` is provided, it is written directly into the
595    /// BulkLoad buffer, ensuring the COLMETADATA matches the server's exact
596    /// encoding. `server_columns` provides per-column type info so row values
597    /// are encoded correctly (fixed-length types have no length prefix).
598    ///
599    /// This follows the pattern used by Tiberius: the server's own metadata
600    /// from `SELECT TOP 0` is echoed back rather than constructing it from
601    /// user-specified types.
602    pub(crate) fn new_with_server_metadata(
603        mut columns: Vec<BulkColumn>,
604        batch_size: usize,
605        raw_colmetadata: Option<bytes::Bytes>,
606        server_columns: Option<&[tds_protocol::token::ColumnData]>,
607    ) -> Self {
608        // Determine which columns use fixed-length types on the wire.
609        // Fixed-length types omit the per-row length prefix.
610        let fixed_len: Vec<bool> = if let Some(srv_cols) = server_columns {
611            // Propagate collation from server metadata for VARCHAR/CHAR columns.
612            // The user's BulkColumn is constructed from type strings alone and
613            // has no collation until we see the server's COLMETADATA — falling
614            // back to the default Latin1 on NON-Latin servers would silently
615            // corrupt extended characters.
616            for (col, srv) in columns.iter_mut().zip(srv_cols.iter()) {
617                if col.collation.is_none() {
618                    col.collation = srv.type_info.collation;
619                }
620            }
621            srv_cols
622                .iter()
623                .map(|c| c.type_id.is_fixed_length())
624                .collect()
625        } else {
626            // Without server metadata, NOT NULL columns of fixed-width types
627            // must use the fixed type ID variant (e.g. INT NOT NULL uses 0x38
628            // Int4, not 0x26 INTN). SQL Server rejects nullable type IDs for
629            // NOT NULL target columns with error 4816.
630            columns
631                .iter()
632                .map(|c| !c.nullable && nullable_to_fixed_type(c.type_id, c.max_length).is_some())
633                .collect()
634        };
635
636        let mut bulk = Self {
637            columns: columns.into(),
638            fixed_len: fixed_len.into(),
639            buffer: BytesMut::with_capacity(64 * 1024),
640            rows_in_batch: 0,
641            total_rows: 0,
642            batch_size,
643            batches_committed: 0,
644            packet_id: 1,
645        };
646
647        if let Some(raw) = raw_colmetadata {
648            bulk.buffer.extend_from_slice(&raw);
649        } else {
650            bulk.write_colmetadata();
651        }
652
653        bulk
654    }
655
656    /// Write the COLMETADATA token to the buffer.
657    fn write_colmetadata(&mut self) {
658        let buf = &mut self.buffer;
659
660        // Token type
661        buf.put_u8(TokenType::ColMetaData as u8);
662
663        // Column count
664        buf.put_u16_le(self.columns.len() as u16);
665
666        for col in self.columns.iter() {
667            // User type (always 0 for basic types)
668            buf.put_u32_le(0);
669
670            // For NOT NULL columns with a fixed-width type, use the fixed type ID
671            // variant (e.g. INT NOT NULL → 0x38 Int4 instead of 0x26 INTN).
672            // SQL Server's BCP rejects nullable type IDs for NOT NULL columns.
673            let effective_type_id = if !col.nullable {
674                nullable_to_fixed_type(col.type_id, col.max_length).unwrap_or(col.type_id)
675            } else {
676                col.type_id
677            };
678            let is_fixed_variant = effective_type_id != col.type_id;
679
680            // Flags: Nullable (bit 0) | Updateable (bit 3)
681            // BulkLoad columns must have Updateable set to indicate they accept writes.
682            let mut flags: u16 = 0x0008; // Updateable
683            if col.nullable {
684                flags |= 0x0001; // Nullable
685            }
686            buf.put_u16_le(flags);
687
688            // Type info
689            buf.put_u8(effective_type_id);
690
691            // Fixed-width types have no additional TYPE_INFO bytes — skip straight
692            // to the column name.
693            if is_fixed_variant {
694                let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
695                buf.put_u8(name_utf16.len() as u8);
696                for code_unit in name_utf16 {
697                    buf.put_u16_le(code_unit);
698                }
699                continue;
700            }
701
702            // Type-specific length/precision/scale
703            match col.type_id {
704                // Nullable fixed-length types — 1-byte max-length follows type ID
705                // INTN(0x26), BITN(0x68), FLTN(0x6D), MONEYN(0x6E), DATETIMEN(0x6F)
706                0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
707                    buf.put_u8(col.max_length.unwrap_or(4) as u8);
708                }
709
710                // DATE has no length byte (fixed 3-byte value)
711                0x28 => {}
712
713                // Variable-length string/binary types
714                0xE7 | 0xA7 | 0xA5 | 0xAD => {
715                    // Max length (2 bytes for normal, 4 bytes for MAX)
716                    let max_len = col.max_length.unwrap_or(8000);
717                    if max_len == 0xFFFF {
718                        buf.put_u16_le(0xFFFF);
719                    } else {
720                        buf.put_u16_le(max_len as u16);
721                    }
722
723                    // Collation for string types (5 bytes). Use the caller-
724                    // supplied collation when present (via `with_collation()`),
725                    // otherwise fall back to Latin1_General_CI_AS.
726                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
727                        if let Some(coll) = col.collation.as_ref() {
728                            buf.put_slice(&coll.to_bytes());
729                        } else {
730                            // Default collation: Latin1_General_CI_AS
731                            // Bytes: LCID(0x0409) + flags(0xD000) + SortId(0x34)
732                            buf.put_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]);
733                        }
734                    }
735                }
736
737                // Decimal/Numeric
738                0x6C | 0x6A => {
739                    // Length (calculated from precision)
740                    let precision = col.precision.unwrap_or(18);
741                    let len = decimal_byte_length(precision);
742                    buf.put_u8(len);
743                    buf.put_u8(precision);
744                    buf.put_u8(col.scale.unwrap_or(0));
745                }
746
747                // Time-based with scale
748                0x29..=0x2B => {
749                    buf.put_u8(col.scale.unwrap_or(7));
750                }
751
752                // GUID
753                0x24 => {
754                    buf.put_u8(16);
755                }
756
757                // Other types - write max length if present
758                _ => {
759                    if let Some(len) = col.max_length {
760                        if len <= 0xFFFF {
761                            buf.put_u16_le(len as u16);
762                        }
763                    }
764                }
765            }
766
767            // Column name (B_VARCHAR format: 1-byte length prefix)
768            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
769            buf.put_u8(name_utf16.len() as u8);
770            for code_unit in name_utf16 {
771                buf.put_u16_le(code_unit);
772            }
773        }
774    }
775
776    /// Send a row of data.
777    ///
778    /// The values must match the column order and types specified
779    /// when creating the bulk insert.
780    ///
781    /// # Errors
782    ///
783    /// Returns an error if:
784    /// - Wrong number of values provided
785    /// - A value cannot be converted to the expected type
786    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
787        if values.len() != self.columns.len() {
788            return Err(Error::Config(format!(
789                "expected {} values, got {}",
790                self.columns.len(),
791                values.len()
792            )));
793        }
794
795        // Convert all values to SqlValue
796        let sql_values: Result<Vec<SqlValue>, TypeError> =
797            values.iter().map(|v| v.to_sql()).collect();
798        let sql_values = sql_values.map_err(Error::from)?;
799
800        self.write_row(&sql_values)?;
801
802        self.rows_in_batch += 1;
803        self.total_rows += 1;
804
805        Ok(())
806    }
807
808    /// Send a row of pre-converted SQL values.
809    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
810        if values.len() != self.columns.len() {
811            return Err(Error::Config(format!(
812                "expected {} values, got {}",
813                self.columns.len(),
814                values.len()
815            )));
816        }
817
818        self.write_row(values)?;
819
820        self.rows_in_batch += 1;
821        self.total_rows += 1;
822
823        Ok(())
824    }
825
826    /// Write a ROW token to the buffer.
827    fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
828        // ROW token type
829        self.buffer.put_u8(TokenType::Row as u8);
830
831        // Collect column info needed for encoding to avoid borrow conflict
832        let columns: Vec<_> = self.columns.iter().cloned().collect();
833        let fixed_len = self.fixed_len.clone();
834
835        // Write each column value
836        for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
837            let is_fixed = *fixed_len.get(i).unwrap_or(&false);
838            self.encode_column_value(col, value, is_fixed)
839                .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
840        }
841
842        Ok(())
843    }
844
845    /// Encode a column value according to its type.
846    ///
847    /// When `is_fixed` is true, the column uses a fixed-length type on the wire
848    /// and values are written without a length prefix. When false, values include
849    /// a length prefix (1 byte for numeric nullable types, 2 bytes for strings).
850    fn encode_column_value(
851        &mut self,
852        col: &BulkColumn,
853        value: &SqlValue,
854        is_fixed: bool,
855    ) -> Result<(), TypeError> {
856        let buf = &mut self.buffer;
857
858        // Check if this column uses PLP (Partially Length-Prefixed) encoding
859        // MAX types (max_length == 0xFFFF) use PLP format
860        let is_plp_type =
861            col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
862
863        match value {
864            SqlValue::Null => {
865                // NULL encoding depends on type
866                match col.type_id {
867                    // Variable-length types
868                    0xE7 | 0xA7 | 0xA5 | 0xAD => {
869                        if is_plp_type {
870                            // PLP NULL: 0xFFFFFFFFFFFFFFFF
871                            buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
872                        } else {
873                            // Standard NULL: 0xFFFF length marker
874                            buf.put_u16_le(0xFFFF);
875                        }
876                    }
877                    // Nullable fixed types use 0 length
878                    // INTN, BITN, FLTN, MONEYN, DATETIMEN, Decimal, GUID, temporal
879                    0x26 | 0x68 | 0x6D | 0x6E | 0x6F | 0x6C | 0x6A | 0x24 | 0x28 | 0x29 | 0x2A
880                    | 0x2B => {
881                        buf.put_u8(0);
882                    }
883                    // Fixed types without nullable variant
884                    _ => {
885                        if col.nullable {
886                            buf.put_u8(0);
887                        } else {
888                            return Err(TypeError::UnexpectedNull);
889                        }
890                    }
891                }
892            }
893
894            SqlValue::Bool(v) => {
895                if !is_fixed {
896                    buf.put_u8(1);
897                }
898                buf.put_u8(if *v { 1 } else { 0 });
899            }
900
901            SqlValue::TinyInt(v) => {
902                if !is_fixed {
903                    buf.put_u8(1);
904                }
905                buf.put_u8(*v);
906            }
907
908            SqlValue::SmallInt(v) => {
909                if !is_fixed {
910                    buf.put_u8(2);
911                }
912                buf.put_i16_le(*v);
913            }
914
915            SqlValue::Int(v) => {
916                if !is_fixed {
917                    buf.put_u8(4);
918                }
919                buf.put_i32_le(*v);
920            }
921
922            SqlValue::BigInt(v) => {
923                if !is_fixed {
924                    buf.put_u8(8);
925                }
926                buf.put_i64_le(*v);
927            }
928
929            SqlValue::Float(v) => {
930                if !is_fixed {
931                    buf.put_u8(4);
932                }
933                buf.put_f32_le(*v);
934            }
935
936            SqlValue::Double(v) => {
937                if !is_fixed {
938                    buf.put_u8(8);
939                }
940                buf.put_f64_le(*v);
941            }
942
943            SqlValue::String(s) => {
944                // NVARCHAR/NCHAR columns (0xE7/0xEF) use UTF-16LE on the wire.
945                // VARCHAR/CHAR/BIGCHAR columns (0xA7/0x2F/0xAF) use the
946                // collation's code page for single-byte encoding — writing UTF-16
947                // into a VARCHAR column lands each surrogate half in its own
948                // single-byte slot and silently corrupts the data.
949                let is_varchar = matches!(col.type_id, 0xA7 | 0x2F | 0xAF);
950
951                if is_varchar {
952                    let encoded = encode_varchar_for_collation(s, col.collation.as_ref());
953                    let byte_len = encoded.len();
954
955                    if is_plp_type {
956                        encode_plp_binary(&encoded, buf);
957                    } else if byte_len > 0xFFFF {
958                        return Err(TypeError::BufferTooSmall {
959                            needed: byte_len,
960                            available: 0xFFFF,
961                        });
962                    } else {
963                        buf.put_u16_le(byte_len as u16);
964                        buf.put_slice(&encoded);
965                    }
966                } else {
967                    // UTF-16LE encoding for NVARCHAR
968                    let utf16: Vec<u16> = s.encode_utf16().collect();
969                    let byte_len = utf16.len() * 2;
970
971                    if is_plp_type {
972                        // PLP format for MAX types - supports unlimited size
973                        // Send as a single chunk for simplicity
974                        encode_plp_string(&utf16, buf);
975                    } else if byte_len > 0xFFFF {
976                        // Non-MAX column can't hold this much data
977                        return Err(TypeError::BufferTooSmall {
978                            needed: byte_len,
979                            available: 0xFFFF,
980                        });
981                    } else {
982                        // Standard encoding with 2-byte length prefix
983                        buf.put_u16_le(byte_len as u16);
984                        for code_unit in utf16 {
985                            buf.put_u16_le(code_unit);
986                        }
987                    }
988                }
989            }
990
991            SqlValue::Binary(b) => {
992                if is_plp_type {
993                    // PLP format for MAX types - supports unlimited size
994                    encode_plp_binary(b, buf);
995                } else if b.len() > 0xFFFF {
996                    // Non-MAX column can't hold this much data
997                    return Err(TypeError::BufferTooSmall {
998                        needed: b.len(),
999                        available: 0xFFFF,
1000                    });
1001                } else {
1002                    // Standard encoding with 2-byte length prefix
1003                    buf.put_u16_le(b.len() as u16);
1004                    buf.put_slice(b);
1005                }
1006            }
1007
1008            // Feature-gated types - use mssql_types::encode module
1009            #[cfg(feature = "decimal")]
1010            SqlValue::Decimal(d) => {
1011                if col.type_id == 0x6E {
1012                    // MONEY / SMALLMONEY — fixed-point scaled by 10_000, not DECIMAL format.
1013                    encode_money_value(*d, col, buf, is_fixed)?;
1014                } else {
1015                    let precision = col.precision.unwrap_or(18);
1016                    let len = decimal_byte_length(precision);
1017                    buf.put_u8(len);
1018
1019                    // Sign: 0 = negative, 1 = positive
1020                    buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
1021
1022                    // Mantissa as unsigned 128-bit integer
1023                    let mantissa = d.mantissa().unsigned_abs();
1024                    let mantissa_bytes = mantissa.to_le_bytes();
1025                    buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
1026                }
1027            }
1028
1029            #[cfg(feature = "uuid")]
1030            SqlValue::Uuid(u) => {
1031                buf.put_u8(16); // Length
1032                // Use mssql_types encode function
1033                mssql_types::__private::encode_uuid(*u, buf);
1034            }
1035
1036            #[cfg(feature = "chrono")]
1037            SqlValue::Date(d) => {
1038                buf.put_u8(3); // Length
1039                mssql_types::__private::encode_date(*d, buf)?;
1040            }
1041
1042            #[cfg(feature = "chrono")]
1043            SqlValue::Time(t) => {
1044                let scale = col.scale.unwrap_or(7);
1045                let len = time_byte_length(scale);
1046                buf.put_u8(len);
1047                // Encode time with proper scale handling
1048                encode_time_with_scale(*t, scale, buf);
1049            }
1050
1051            #[cfg(feature = "chrono")]
1052            SqlValue::DateTime(dt) => {
1053                // Type 0x6F is DATETIMEN — legacy DATETIME (8 bytes) or
1054                // SMALLDATETIME (4 bytes) format selected by max_length. The
1055                // wire format differs from DATETIME2 (type 0x2A), which uses a
1056                // scale-aware time-then-date layout.
1057                if col.type_id == 0x6F {
1058                    let total_len = col.max_length.unwrap_or(8) as u8;
1059                    if !is_fixed {
1060                        buf.put_u8(total_len);
1061                    }
1062                    match total_len {
1063                        8 => mssql_types::__private::encode_datetime_legacy(*dt, buf),
1064                        4 => mssql_types::__private::encode_smalldatetime(*dt, buf)?,
1065                        _ => {
1066                            return Err(TypeError::InvalidDateTime(format!(
1067                                "DATETIMEN max_length must be 4 or 8, got {total_len}"
1068                            )));
1069                        }
1070                    }
1071                } else {
1072                    let scale = col.scale.unwrap_or(7);
1073                    let time_len = time_byte_length(scale);
1074                    let total_len = time_len + 3;
1075                    buf.put_u8(total_len);
1076                    // Encode time then date
1077                    encode_time_with_scale(dt.time(), scale, buf);
1078                    mssql_types::__private::encode_date(dt.date(), buf)?;
1079                }
1080            }
1081            #[cfg(feature = "chrono")]
1082            SqlValue::SmallDateTime(dt) => {
1083                // Explicit SMALLDATETIME variant — always 4-byte days+minutes,
1084                // regardless of column metadata.
1085                if !is_fixed {
1086                    buf.put_u8(4);
1087                }
1088                mssql_types::__private::encode_smalldatetime(*dt, buf)?;
1089            }
1090            #[cfg(feature = "decimal")]
1091            SqlValue::Money(d) => {
1092                // Force 8-byte MONEY encoding regardless of column metadata.
1093                if !is_fixed {
1094                    buf.put_u8(8);
1095                }
1096                mssql_types::__private::encode_money(*d, buf)?;
1097            }
1098            #[cfg(feature = "decimal")]
1099            SqlValue::SmallMoney(d) => {
1100                if !is_fixed {
1101                    buf.put_u8(4);
1102                }
1103                mssql_types::__private::encode_smallmoney(*d, buf)?;
1104            }
1105
1106            #[cfg(feature = "chrono")]
1107            SqlValue::DateTimeOffset(dto) => {
1108                let scale = col.scale.unwrap_or(7);
1109                let time_len = time_byte_length(scale);
1110                let total_len = time_len + 3 + 2;
1111                buf.put_u8(total_len);
1112                // The wire date/time portion is UTC per MS-TDS §2.2.5.5.1.9,
1113                // not the local wall-clock.
1114                let utc = dto.naive_utc();
1115                encode_time_with_scale(utc.time(), scale, buf);
1116                mssql_types::__private::encode_date(utc.date(), buf)?;
1117                // Timezone offset in minutes
1118                use chrono::Offset;
1119                let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1120                buf.put_i16_le(offset_minutes);
1121            }
1122
1123            #[cfg(feature = "json")]
1124            SqlValue::Json(j) => {
1125                let s = j.to_string();
1126                encode_nvarchar_value(&s, buf)?;
1127            }
1128
1129            SqlValue::Xml(x) => {
1130                encode_nvarchar_value(x, buf)?;
1131            }
1132
1133            SqlValue::Tvp(_) => {
1134                // TVPs are not valid in bulk copy operations - they're for RPC parameters only
1135                return Err(TypeError::UnsupportedConversion {
1136                    from: "TVP".to_string(),
1137                    to: "bulk copy value",
1138                });
1139            }
1140            // Handle future SqlValue variants
1141            _ => {
1142                return Err(TypeError::UnsupportedConversion {
1143                    from: value.type_name().to_string(),
1144                    to: "bulk copy value",
1145                });
1146            }
1147        }
1148
1149        Ok(())
1150    }
1151}
1152
/// Encode a MONEY or SMALLMONEY column value with the appropriate length prefix.
///
/// The column's `max_length` selects the format: 4 bytes for SMALLMONEY,
/// 8 bytes for MONEY (also the default when `max_length` is absent).
///
/// When `is_fixed` is true (fixed type ID 0x3C or 0x7A), no length byte
/// precedes the payload. Otherwise a 1-byte length prefix is written
/// (matching the MONEYN nullable variant).
///
/// # Errors
///
/// Returns [`TypeError::InvalidDecimal`] if `max_length` is neither 4 nor 8;
/// in that case nothing is written to `buf`. Errors from the underlying
/// `mssql_types` encoders are passed through.
#[cfg(feature = "decimal")]
fn encode_money_value(
    value: rust_decimal::Decimal,
    col: &BulkColumn,
    buf: &mut BytesMut,
    is_fixed: bool,
) -> Result<(), TypeError> {
    // Validate on the full-width value before writing anything: a truncating
    // cast would turn e.g. 260 into 4, and an early length prefix would be
    // left in the buffer when the width turns out to be invalid.
    let money_bytes: u8 = match col.max_length.unwrap_or(8) {
        4 => 4,
        8 => 8,
        money_bytes => {
            return Err(TypeError::InvalidDecimal(format!(
                "MONEY column has invalid max_length: {money_bytes}"
            )));
        }
    };

    if !is_fixed {
        buf.put_u8(money_bytes);
    }

    if money_bytes == 4 {
        mssql_types::__private::encode_smallmoney(value, buf)
    } else {
        mssql_types::__private::encode_money(value, buf)
    }
}
1177
1178/// Encode a string as NVARCHAR with length prefix.
1179fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
1180    let utf16: Vec<u16> = s.encode_utf16().collect();
1181    let byte_len = utf16.len() * 2;
1182
1183    if byte_len > 0xFFFF {
1184        return Err(TypeError::BufferTooSmall {
1185            needed: byte_len,
1186            available: 0xFFFF,
1187        });
1188    }
1189
1190    buf.put_u16_le(byte_len as u16);
1191    for code_unit in utf16 {
1192        buf.put_u16_le(code_unit);
1193    }
1194    Ok(())
1195}
1196
/// PLP total-length marker meaning "length unknown" (MS-TDS 2.2.5.2.3).
///
/// Written in place of the 8-byte ULONGLONGLEN when the total size is not
/// announced up front; the end of the value is then found from the chunk
/// framing and the 4-byte zero terminator. The value is one below `u64::MAX`.
const PLP_UNKNOWN_LEN: u64 = u64::MAX - 1;
1202
1203/// Encode a UTF-16 string using PLP (Partially Length-Prefixed) format.
1204///
1205/// PLP format (per MS-TDS 2.2.5.2.3):
1206/// - 8 bytes: ULONGLONGLEN — PLP_UNKNOWN_LEN or actual total byte count
1207/// - One or more chunks: 4-byte chunk length + chunk bytes
1208/// - Terminator: 4-byte zero
1209///
1210/// We emit `PLP_UNKNOWN_LEN` for compatibility with SQL Server's BulkLoad
1211/// parser. Empirically, some server versions reject a concrete total length
1212/// in the BulkLoad (0x07) path even though the token-stream spec allows it
1213/// ("premature end-of-message" errors for NVARCHAR(MAX) bulk inserts).
1214/// Tiberius uses the same approach.
1215fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
1216    let byte_len = utf16.len() * 2;
1217
1218    buf.put_u64_le(PLP_UNKNOWN_LEN);
1219
1220    if byte_len > 0 {
1221        buf.put_u32_le(byte_len as u32);
1222        for code_unit in utf16 {
1223            buf.put_u16_le(*code_unit);
1224        }
1225    }
1226
1227    buf.put_u32_le(0);
1228}
1229
1230/// Encode binary data using PLP (Partially Length-Prefixed) format.
1231/// See [`encode_plp_string`] for the format specification.
1232fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
1233    buf.put_u64_le(PLP_UNKNOWN_LEN);
1234
1235    if !data.is_empty() {
1236        buf.put_u32_le(data.len() as u32);
1237        buf.put_slice(data);
1238    }
1239
1240    buf.put_u32_le(0);
1241}
1242
/// Encode a Rust string into single-byte VARCHAR bytes using the column's collation.
///
/// Delegates to [`tds_protocol::__private::encode_str_for_collation`] so the
/// RPC parameter path and the bulk insert path share one implementation.
///
/// Returns the encoded bytes as an owned buffer; callers use its length for
/// the wire length prefix.
///
/// NOTE(review): what happens when `collation` is `None`, or when a character
/// has no representation in the target code page, is decided entirely by the
/// delegate and is not visible here — confirm before relying on it.
fn encode_varchar_for_collation(value: &str, collation: Option<&Collation>) -> Vec<u8> {
    // Thin wrapper: no additional processing is done on either side of the call.
    tds_protocol::__private::encode_str_for_collation(value, collation)
}
1250
/// Write a time of day as a scale-dependent little-endian integer (for bulk copy).
///
/// The time is converted to nanoseconds since midnight, divided (truncating)
/// by [`time_scale_divisor`] for `scale`, and the low [`time_byte_length`]
/// bytes of the result are appended to `buf`. No length prefix is written.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let whole_seconds = u64::from(time.num_seconds_from_midnight());
    let sub_second_nanos = u64::from(time.nanosecond());
    let ticks = (whole_seconds * 1_000_000_000 + sub_second_nanos) / time_scale_divisor(scale);

    // Only the low `width` bytes go on the wire, least-significant first.
    let width = usize::from(time_byte_length(scale));
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
1264
1265impl BulkInsert {
1266    /// Write the DONE token signaling completion.
1267    fn write_done(&mut self) {
1268        let buf = &mut self.buffer;
1269
1270        buf.put_u8(TokenType::Done as u8);
1271
1272        // Status: FINAL (0x00) | COUNT (0x10)
1273        let status = DoneStatus::from_bits(0x0010); // DONE_COUNT
1274        buf.put_u16_le(status.to_bits());
1275
1276        // Current command (0 for bulk load)
1277        buf.put_u16_le(0);
1278
1279        // Row count
1280        buf.put_u64_le(self.total_rows);
1281    }
1282
1283    /// Get the buffered data as packets ready to send.
1284    ///
1285    /// Returns a vector of complete TDS packets with BulkLoad packet type (0x07).
1286    pub fn take_packets(&mut self) -> Vec<BytesMut> {
1287        const MAX_PACKET_SIZE: usize = 4096;
1288        const HEADER_SIZE: usize = 8;
1289        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1290
1291        let data = self.buffer.split();
1292        let mut packets = Vec::new();
1293        let mut offset = 0;
1294
1295        while offset < data.len() {
1296            let remaining = data.len() - offset;
1297            let payload_size = remaining.min(MAX_PAYLOAD);
1298            let is_last = offset + payload_size >= data.len();
1299
1300            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1301
1302            // Write packet header
1303            let header = PacketHeader {
1304                packet_type: PacketType::BulkLoad,
1305                status: if is_last {
1306                    PacketStatus::END_OF_MESSAGE
1307                } else {
1308                    PacketStatus::NORMAL
1309                },
1310                length: (HEADER_SIZE + payload_size) as u16,
1311                spid: 0,
1312                packet_id: self.packet_id,
1313                window: 0,
1314            };
1315
1316            header.encode(&mut packet);
1317
1318            // Write payload
1319            packet.put_slice(&data[offset..offset + payload_size]);
1320
1321            packets.push(packet);
1322            offset += payload_size;
1323            self.packet_id = self.packet_id.wrapping_add(1);
1324        }
1325
1326        packets
1327    }
1328
1329    /// Get total rows sent so far.
1330    pub fn total_rows(&self) -> u64 {
1331        self.total_rows
1332    }
1333
1334    /// Get rows in current batch.
1335    pub fn rows_in_batch(&self) -> usize {
1336        self.rows_in_batch
1337    }
1338
1339    /// Check if a batch flush is needed.
1340    ///
1341    /// Note: [`BulkWriter`] does not consult this — it buffers every row and
1342    /// sends one batch on finish. This is a helper for callers driving the
1343    /// lower-level [`BulkInsert`] packet API manually.
1344    pub fn should_flush(&self) -> bool {
1345        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1346    }
1347
1348    /// Prepare for finishing the bulk operation.
1349    /// Writes the DONE token and returns final packets.
1350    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1351        self.write_done();
1352        self.take_packets()
1353    }
1354
1355    /// Create a result from the current state.
1356    pub fn result(&self) -> BulkInsertResult {
1357        BulkInsertResult {
1358            rows_affected: self.total_rows,
1359            batches_committed: self.batches_committed,
1360            has_errors: false,
1361        }
1362    }
1363}
1364
/// Active streaming writer for bulk insert operations.
///
/// Created via [`crate::client::Client::bulk_insert()`]. Rows are buffered in
/// memory as they are added with [`send_row()`](BulkWriter::send_row), then
/// transmitted to the server when [`finish()`](BulkWriter::finish) is called.
///
/// The writer holds a mutable reference to the [`crate::Client`], preventing
/// other operations on the connection while the bulk insert is in progress.
///
/// # Example
///
/// ```rust,no_run
/// # use mssql_client::{BulkInsertBuilder, BulkColumn, SqlValue};
/// # async fn ex(client: &mut mssql_client::Client<mssql_client::Ready>) -> Result<(), mssql_client::Error> {
/// let builder = BulkInsertBuilder::new("dbo.Users")
///     .with_typed_columns(vec![
///         BulkColumn::new("id", "INT", 0)?,
///         BulkColumn::new("name", "NVARCHAR(100)", 1)?,
///     ]);
///
/// let mut writer = client.bulk_insert(&builder).await?;
/// writer.send_row_values(&[SqlValue::Int(1), SqlValue::String("Alice".into())])?;
/// writer.send_row_values(&[SqlValue::Int(2), SqlValue::String("Bob".into())])?;
/// let result = writer.finish().await?;
/// # let _ = result;
/// # Ok(())
/// # }
/// ```
pub struct BulkWriter<'a, S: crate::state::ConnectionState> {
    /// Exclusive borrow of the client, used by `finish()` to read the command
    /// deadline and to send the buffered payload.
    client: &'a mut crate::client::Client<S>,
    /// Row encoder and buffer that every `send_row*` call is forwarded to.
    bulk: BulkInsert,
}
1397
1398impl<'a, S: crate::state::ConnectionState> BulkWriter<'a, S> {
1399    /// Create a new bulk writer.
1400    pub(crate) fn new(client: &'a mut crate::client::Client<S>, bulk: BulkInsert) -> Self {
1401        Self { client, bulk }
1402    }
1403
1404    /// Add a row to the bulk insert buffer.
1405    ///
1406    /// Values are encoded immediately but not sent to the server until
1407    /// [`finish()`](BulkWriter::finish) is called. The number of values must
1408    /// match the number of columns defined for this bulk insert.
1409    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
1410        self.bulk.send_row(values)
1411    }
1412
1413    /// Add a row of pre-converted SQL values to the buffer.
1414    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
1415        self.bulk.send_row_values(values)
1416    }
1417
1418    /// Get the number of rows buffered so far.
1419    pub fn total_rows(&self) -> u64 {
1420        self.bulk.total_rows()
1421    }
1422
1423    /// Finish the bulk insert operation and send all buffered data to the server.
1424    ///
1425    /// Writes the DONE token, sends the accumulated row data as a BulkLoad
1426    /// (0x07) message, and reads the server's response.
1427    ///
1428    /// The transfer runs under
1429    /// [`command_timeout`](crate::Config::command_timeout). On expiry it
1430    /// returns [`Error::CommandTimeout`] and
1431    /// the connection is abandoned mid-request (a cancel cannot be
1432    /// interleaved into a partially-sent BulkLoad message), so the pool
1433    /// discards it — the same semantics as `SqlBulkCopy.BulkCopyTimeout`.
1434    pub async fn finish(mut self) -> Result<BulkInsertResult, Error> {
1435        let deadline = self.client.command_deadline();
1436        let total_rows = self.bulk.total_rows();
1437        tracing::debug!(total_rows = total_rows, "finishing bulk insert");
1438
1439        // Write DONE token and freeze the payload
1440        self.bulk.write_done();
1441        let payload = self.bulk.buffer.split().freeze();
1442
1443        // Send BulkLoad data and read server response.
1444        //
1445        // The command timeout here is a hard abandon, not an ATTENTION
1446        // cancel: an Attention packet cannot be interleaved into a
1447        // partially-sent BulkLoad message without corrupting the framing.
1448        // On expiry the connection is left mid-request (in_flight stays
1449        // set), so the pool discards it instead of reusing it —
1450        // SqlBulkCopy's BulkCopyTimeout behaves the same way.
1451        let send_and_read = self.client.send_and_read_bulk_load(payload);
1452        let rows_affected = match deadline {
1453            Some(d) => tokio::time::timeout(d, send_and_read)
1454                .await
1455                .map_err(|_| Error::CommandTimeout)??,
1456            None => send_and_read.await?,
1457        };
1458
1459        Ok(BulkInsertResult {
1460            rows_affected,
1461            batches_committed: 1,
1462            has_errors: false,
1463        })
1464    }
1465}
1466
/// Translate a nullable type ID into the matching fixed-width type ID.
///
/// SQL Server's BulkLoad protocol rejects nullable type IDs (INTN, BITN, etc.)
/// when the target column is NOT NULL. Such columns must be described with
/// the fixed type ID instead — without a max_length and without a per-row
/// length prefix.
///
/// BITN maps regardless of `max_length`; every other family needs a
/// `max_length` that names one of its supported widths.
///
/// Returns `None` for types that have no fixed-width variant (e.g. NVARCHAR,
/// VARBINARY, DECIMAL, temporal types other than DATETIME/SMALLDATETIME) and
/// for widths the family does not support.
fn nullable_to_fixed_type(type_id: u8, max_length: Option<u32>) -> Option<u8> {
    // BITN → Bit, whatever length was declared.
    if type_id == 0x68 {
        return Some(0x32);
    }

    // All remaining families are selected by their byte width.
    let width = max_length?;
    let fixed = match type_id {
        // INTN → Int1 (TINYINT) / Int2 (SMALLINT) / Int4 (INT) / Int8 (BIGINT)
        0x26 => match width {
            1 => 0x30,
            2 => 0x34,
            4 => 0x38,
            8 => 0x7F,
            _ => return None,
        },
        // FLTN → Float4 (REAL) / Float8 (FLOAT)
        0x6D => match width {
            4 => 0x3B,
            8 => 0x3E,
            _ => return None,
        },
        // MONEYN → Money4 (SMALLMONEY) / Money (MONEY)
        0x6E => match width {
            4 => 0x7A,
            8 => 0x3C,
            _ => return None,
        },
        // DATETIMEN → DateTime4 (SMALLDATETIME) / DateTime (DATETIME)
        0x6F => match width {
            4 => 0x3A,
            8 => 0x3D,
            _ => return None,
        },
        _ => return None,
    };
    Some(fixed)
}
1492
/// Wire length in bytes of a decimal value for the given precision.
///
/// The result is one byte more than the mantissa width: 5 for precision 1–9,
/// 9 for 10–19, 13 for 20–28, and 17 for everything else (29–38, as well as
/// 0 and values above 38).
fn decimal_byte_length(precision: u8) -> u8 {
    let mantissa_bytes = match precision {
        1..=9 => 4,
        10..=19 => 8,
        20..=28 => 12,
        // 29..=38 and any out-of-range precision use the widest form.
        _ => 16,
    };
    mantissa_bytes + 1
}
1503
1504/// Calculate byte length for time based on scale.
1505#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    // Scales 0–2 need 3 bytes, 3–4 need 4 bytes, and 5–7 need 5 bytes.
    // Scales above 7 are treated like 7.
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
1514
1515/// Get the divisor for time scale.
1516#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    // Nanoseconds per tick: 10^9 at scale 0, shrinking tenfold per scale step
    // down to 100 at scale 7. Scales above 7 are clamped to 7.
    let effective_scale = u32::from(scale.min(7));
    10u64.pow(9 - effective_scale)
}
1530
1531#[cfg(test)]
1532#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
1533mod tests {
1534    use super::*;
1535
1536    #[test]
1537    fn test_bulk_options_default() {
1538        let opts = BulkOptions::default();
1539        assert_eq!(opts.batch_size, 0);
1540        assert!(opts.check_constraints);
1541        assert!(!opts.fire_triggers);
1542        assert!(opts.keep_nulls);
1543        assert!(!opts.table_lock);
1544    }
1545
1546    #[test]
1547    fn test_bulk_column_creation() {
1548        let col = BulkColumn::new("id", "INT", 0).unwrap();
1549        assert_eq!(col.name, "id");
1550        assert_eq!(col.type_id, 0x26); // INTN
1551        assert_eq!(col.max_length, Some(4));
1552        assert!(col.nullable);
1553    }
1554
1555    #[test]
1556    fn test_bulk_column_rejects_text() {
1557        let err = BulkColumn::new("body", "TEXT", 0).unwrap_err();
1558        match err {
1559            TypeError::UnsupportedType { sql_type, reason } => {
1560                assert_eq!(sql_type, "TEXT");
1561                assert!(
1562                    reason.contains("VARCHAR(MAX)"),
1563                    "error should redirect to VARCHAR(MAX), got: {reason}"
1564                );
1565                assert!(
1566                    reason.contains("deprecated"),
1567                    "error should mention deprecation, got: {reason}"
1568                );
1569            }
1570            other => panic!("expected UnsupportedType, got {other:?}"),
1571        }
1572    }
1573
1574    #[test]
1575    fn test_bulk_column_rejects_ntext() {
1576        let err = BulkColumn::new("body", "NTEXT", 0).unwrap_err();
1577        match err {
1578            TypeError::UnsupportedType { sql_type, reason } => {
1579                assert_eq!(sql_type, "NTEXT");
1580                assert!(
1581                    reason.contains("NVARCHAR(MAX)"),
1582                    "error should redirect to NVARCHAR(MAX), got: {reason}"
1583                );
1584                assert!(
1585                    reason.contains("deprecated"),
1586                    "error should mention deprecation, got: {reason}"
1587                );
1588            }
1589            other => panic!("expected UnsupportedType, got {other:?}"),
1590        }
1591    }
1592
1593    #[test]
1594    fn test_bulk_column_rejects_text_case_insensitive() {
1595        assert!(matches!(
1596            BulkColumn::new("body", "text", 0),
1597            Err(TypeError::UnsupportedType { .. })
1598        ));
1599        assert!(matches!(
1600            BulkColumn::new("body", "Ntext", 0),
1601            Err(TypeError::UnsupportedType { .. })
1602        ));
1603    }
1604
1605    #[test]
1606    fn test_bulk_column_rejects_image() {
1607        let err = BulkColumn::new("blob", "IMAGE", 0).unwrap_err();
1608        match err {
1609            TypeError::UnsupportedType { sql_type, reason } => {
1610                assert_eq!(sql_type, "IMAGE");
1611                assert!(
1612                    reason.contains("VARBINARY(MAX)"),
1613                    "error should redirect to VARBINARY(MAX), got: {reason}"
1614                );
1615                assert!(
1616                    reason.contains("deprecated"),
1617                    "error should mention deprecation, got: {reason}"
1618                );
1619            }
1620            other => panic!("expected UnsupportedType, got {other:?}"),
1621        }
1622    }
1623
1624    #[test]
1625    fn test_bulk_column_rejects_image_case_insensitive() {
1626        assert!(matches!(
1627            BulkColumn::new("blob", "image", 0),
1628            Err(TypeError::UnsupportedType { .. })
1629        ));
1630        assert!(matches!(
1631            BulkColumn::new("blob", "Image", 0),
1632            Err(TypeError::UnsupportedType { .. })
1633        ));
1634    }
1635
1636    #[test]
1637    fn test_parse_sql_type() {
1638        // Integer types → INTN (0x26) with appropriate length
1639        let (type_id, len, _prec, _scale) = parse_sql_type("INT").unwrap();
1640        assert_eq!(type_id, 0x26);
1641        assert_eq!(len, Some(4));
1642
1643        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)").unwrap();
1644        assert_eq!(type_id, 0xE7);
1645        assert_eq!(len, Some(200)); // UTF-16 doubles
1646
1647        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)").unwrap();
1648        assert_eq!(type_id, 0x6C);
1649        assert_eq!(prec, Some(10));
1650        assert_eq!(scale, Some(2));
1651
1652        // SMALLDATETIME/DATETIME → DATETIMEN (0x6F)
1653        let (type_id, len, _, _) = parse_sql_type("SMALLDATETIME").unwrap();
1654        assert_eq!(type_id, 0x6F);
1655        assert_eq!(len, Some(4));
1656
1657        let (type_id, len, _, _) = parse_sql_type("DATETIME").unwrap();
1658        assert_eq!(type_id, 0x6F);
1659        assert_eq!(len, Some(8));
1660
1661        // Unknown base types are rejected (return None), not silently coerced.
1662        assert_eq!(parse_sql_type("SQL_VARIANT"), None);
1663        assert_eq!(parse_sql_type("NOTATYPE"), None);
1664    }
1665
1666    #[test]
1667    fn test_bulk_column_rejects_unknown_type() {
1668        // A base type outside the supported set is an UnsupportedType error,
1669        // not a silent NVARCHAR coercion that would corrupt the wire data.
1670        for bogus in ["SQL_VARIANT", "GEOGRAPHY", "HIERARCHYID", "NOTATYPE"] {
1671            let err = BulkColumn::new("c", bogus, 0).unwrap_err();
1672            assert!(
1673                matches!(err, TypeError::UnsupportedType { .. }),
1674                "expected UnsupportedType for {bogus}, got {err:?}"
1675            );
1676        }
1677        // Malformed *parameters* on a valid base type still fall back to the
1678        // type default (unchanged behavior), so these must still succeed.
1679        assert!(BulkColumn::new("c", "VARCHAR(garbage)", 0).is_ok());
1680        assert!(BulkColumn::new("c", "MONEY", 0).is_ok());
1681        assert!(BulkColumn::new("c", "DATETIME2(3)", 0).is_ok());
1682    }
1683
1684    #[test]
1685    fn test_parse_sql_type_tolerates_surrounding_spaces() {
1686        // validate_sql_type permits spaces in type declarations, so the base
1687        // (and MAX/param detection) must tolerate them — otherwise the new
1688        // unknown-type rejection regresses valid spellings to UnsupportedType.
1689        assert!(parse_sql_type("INT ").is_some());
1690        assert!(parse_sql_type(" INT").is_some());
1691        assert!(parse_sql_type("VARCHAR (50)").is_some());
1692        assert!(parse_sql_type("DECIMAL (18, 2)").is_some());
1693        let (id, len, _, _) = parse_sql_type("NVARCHAR( MAX )").unwrap();
1694        assert_eq!(id, 0xE7);
1695        assert_eq!(len, Some(0xFFFF)); // MAX detected despite inner spaces
1696        assert!(BulkColumn::new("c", "VARCHAR (50)", 0).is_ok());
1697        assert!(BulkColumn::new("c", "INT ", 0).is_ok());
1698    }
1699
1700    #[test]
1701    fn test_insert_bulk_statement() {
1702        let builder = BulkInsertBuilder::new("dbo.Users")
1703            .with_typed_columns(vec![
1704                BulkColumn::new("id", "INT", 0).unwrap(),
1705                BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1706            ])
1707            .table_lock(true);
1708
1709        let sql = builder.build_insert_bulk_statement().unwrap();
1710        assert!(sql.contains("INSERT BULK dbo.Users"));
1711        assert!(sql.contains("TABLOCK"));
1712    }
1713
1714    #[test]
1715    fn test_bulk_insert_rejects_injection() {
1716        let builder = BulkInsertBuilder::new("table;DROP TABLE users")
1717            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1718
1719        assert!(builder.build_insert_bulk_statement().is_err());
1720    }
1721
1722    #[test]
1723    fn test_bulk_insert_validates_column_names() {
1724        let builder = BulkInsertBuilder::new("Users")
1725            .with_typed_columns(vec![BulkColumn::new("col;DROP TABLE x", "INT", 0).unwrap()]);
1726
1727        assert!(builder.build_insert_bulk_statement().is_err());
1728    }
1729
1730    #[test]
1731    fn test_bulk_insert_accepts_qualified_names() {
1732        let builder = BulkInsertBuilder::new("catalog.dbo.Users")
1733            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1734
1735        assert!(builder.build_insert_bulk_statement().is_ok());
1736    }
1737
1738    #[test]
1739    fn test_bulk_insert_creation() {
1740        let columns = vec![
1741            BulkColumn::new("id", "INT", 0).unwrap(),
1742            BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1743        ];
1744
1745        let bulk = BulkInsert::new(columns, 1000);
1746        assert_eq!(bulk.total_rows(), 0);
1747        assert_eq!(bulk.rows_in_batch(), 0);
1748        assert!(!bulk.should_flush());
1749    }
1750
1751    #[test]
1752    fn test_decimal_byte_length() {
1753        assert_eq!(decimal_byte_length(5), 5);
1754        assert_eq!(decimal_byte_length(15), 9);
1755        assert_eq!(decimal_byte_length(25), 13);
1756        assert_eq!(decimal_byte_length(35), 17);
1757    }
1758
1759    #[test]
1760    #[cfg(feature = "chrono")]
1761    fn test_time_byte_length() {
1762        assert_eq!(time_byte_length(0), 3);
1763        assert_eq!(time_byte_length(3), 4);
1764        assert_eq!(time_byte_length(7), 5);
1765    }
1766
1767    #[test]
1768    fn test_plp_string_encoding() {
1769        let mut buf = BytesMut::new();
1770        let text = "Hello";
1771        let utf16: Vec<u16> = text.encode_utf16().collect();
1772
1773        encode_plp_string(&utf16, &mut buf);
1774
1775        // Verify structure:
1776        // - 8 bytes PLP_UNKNOWN_LEN marker
1777        // - 4 bytes chunk length
1778        // - data (5 chars * 2 bytes = 10 bytes)
1779        // - 4 bytes terminator (0)
1780        assert_eq!(buf.len(), 8 + 4 + 10 + 4);
1781
1782        // Check total length marker (PLP_UNKNOWN_LEN)
1783        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1784
1785        // Check chunk length
1786        assert_eq!(&buf[8..12], &10u32.to_le_bytes());
1787
1788        // Check terminator
1789        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
1790    }
1791
1792    #[test]
1793    fn test_plp_binary_encoding() {
1794        let mut buf = BytesMut::new();
1795        let data = b"test binary data";
1796
1797        encode_plp_binary(data, &mut buf);
1798
1799        // Verify structure:
1800        // - 8 bytes PLP_UNKNOWN_LEN marker
1801        // - 4 bytes chunk length
1802        // - data (16 bytes)
1803        // - 4 bytes terminator (0)
1804        assert_eq!(buf.len(), 8 + 4 + 16 + 4);
1805
1806        // Check total length marker
1807        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1808
1809        // Check chunk length
1810        assert_eq!(&buf[8..12], &16u32.to_le_bytes());
1811
1812        // Check data
1813        assert_eq!(&buf[12..28], data);
1814
1815        // Check terminator
1816        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
1817    }
1818
1819    #[test]
1820    fn test_plp_empty_string() {
1821        let mut buf = BytesMut::new();
1822        let utf16: Vec<u16> = "".encode_utf16().collect();
1823
1824        encode_plp_string(&utf16, &mut buf);
1825
1826        // Empty string: PLP_UNKNOWN_LEN (8) + terminator (4)
1827        assert_eq!(buf.len(), 8 + 4);
1828
1829        // Check total length marker
1830        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1831
1832        // Check terminator
1833        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1834    }
1835
1836    #[test]
1837    fn test_plp_empty_binary() {
1838        let mut buf = BytesMut::new();
1839
1840        encode_plp_binary(&[], &mut buf);
1841
1842        // Empty binary: PLP_UNKNOWN_LEN (8) + terminator (4)
1843        assert_eq!(buf.len(), 8 + 4);
1844
1845        // Check total length marker
1846        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1847
1848        // Check terminator
1849        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1850    }
1851
1852    /// Verify that write_colmetadata() produces bytes that the TDS parser can
1853    /// decode correctly for all supported column types (nullable variants).
1854    #[test]
1855    fn test_write_colmetadata_roundtrip() {
1856        use tds_protocol::token::ColMetaData;
1857
1858        let columns = vec![
1859            BulkColumn::new("id", "INT", 0).unwrap(),
1860            BulkColumn::new("tiny", "TINYINT", 1).unwrap(),
1861            BulkColumn::new("small", "SMALLINT", 2).unwrap(),
1862            BulkColumn::new("big", "BIGINT", 3).unwrap(),
1863            BulkColumn::new("flag", "BIT", 4).unwrap(),
1864            BulkColumn::new("r", "REAL", 5).unwrap(),
1865            BulkColumn::new("f", "FLOAT", 6).unwrap(),
1866            BulkColumn::new("name", "NVARCHAR(100)", 7).unwrap(),
1867            BulkColumn::new("code", "VARCHAR(50)", 8).unwrap(),
1868            BulkColumn::new("data", "VARBINARY(200)", 9).unwrap(),
1869            BulkColumn::new("d", "DATE", 10).unwrap(),
1870            BulkColumn::new("t", "TIME(3)", 11).unwrap(),
1871            BulkColumn::new("dt", "DATETIME", 12).unwrap(),
1872            BulkColumn::new("dt2", "DATETIME2(7)", 13).unwrap(),
1873            BulkColumn::new("dto", "DATETIMEOFFSET(7)", 14).unwrap(),
1874            BulkColumn::new("sdt", "SMALLDATETIME", 15).unwrap(),
1875            BulkColumn::new("uid", "UNIQUEIDENTIFIER", 16).unwrap(),
1876            BulkColumn::new("amt", "DECIMAL(18,2)", 17).unwrap(),
1877            BulkColumn::new("price", "MONEY", 18).unwrap(),
1878            BulkColumn::new("smoney", "SMALLMONEY", 19).unwrap(),
1879            BulkColumn::new("nmax", "NVARCHAR(MAX)", 20).unwrap(),
1880            BulkColumn::new("vmax", "VARCHAR(MAX)", 21).unwrap(),
1881            BulkColumn::new("bmax", "VARBINARY(MAX)", 22).unwrap(),
1882        ];
1883
1884        let bulk = BulkInsert::new(columns.clone(), 0);
1885
1886        // Extract COLMETADATA bytes (skip the 0x81 token type byte)
1887        let buf = &bulk.buffer[1..];
1888        let mut cursor = bytes::Bytes::copy_from_slice(buf);
1889        let meta = ColMetaData::decode(&mut cursor)
1890            .expect("write_colmetadata output should be parseable by TDS decoder");
1891
1892        assert_eq!(meta.columns.len(), columns.len());
1893
1894        // Verify each column parsed correctly
1895        for (i, (parsed, original)) in meta.columns.iter().zip(columns.iter()).enumerate() {
1896            assert_eq!(parsed.name, original.name, "column {i} name mismatch");
1897            assert_eq!(
1898                parsed.col_type, original.type_id,
1899                "column {i} ({}) type mismatch",
1900                original.name
1901            );
1902
1903            // Verify type-specific metadata
1904            match original.type_id {
1905                // INTN — max_length should match
1906                0x26 => {
1907                    assert_eq!(
1908                        parsed.type_info.max_length, original.max_length,
1909                        "column {i} ({}) INTN max_length",
1910                        original.name
1911                    );
1912                }
1913                // BITN
1914                0x68 => {
1915                    assert_eq!(parsed.type_info.max_length, Some(1));
1916                }
1917                // FLTN
1918                0x6D => {
1919                    assert_eq!(
1920                        parsed.type_info.max_length, original.max_length,
1921                        "column {i} ({}) FLTN max_length",
1922                        original.name
1923                    );
1924                }
1925                // MONEYN
1926                0x6E => {
1927                    assert_eq!(
1928                        parsed.type_info.max_length, original.max_length,
1929                        "column {i} ({}) MONEYN max_length",
1930                        original.name
1931                    );
1932                }
1933                // DATETIMEN
1934                0x6F => {
1935                    assert_eq!(
1936                        parsed.type_info.max_length, original.max_length,
1937                        "column {i} ({}) DATETIMEN max_length",
1938                        original.name
1939                    );
1940                }
1941                // GUID
1942                0x24 => {
1943                    assert_eq!(parsed.type_info.max_length, Some(16));
1944                }
1945                // DATE — no extra metadata
1946                0x28 => {}
1947                // TIME/DATETIME2/DATETIMEOFFSET — scale
1948                0x29..=0x2B => {
1949                    assert_eq!(
1950                        parsed.type_info.scale, original.scale,
1951                        "column {i} ({}) scale",
1952                        original.name
1953                    );
1954                }
1955                // NVARCHAR/VARCHAR — max_length + collation
1956                0xE7 | 0xA7 => {
1957                    assert_eq!(
1958                        parsed.type_info.max_length, original.max_length,
1959                        "column {i} ({}) string max_length",
1960                        original.name
1961                    );
1962                    assert!(
1963                        parsed.type_info.collation.is_some(),
1964                        "column {i} ({}) should have collation",
1965                        original.name
1966                    );
1967                }
1968                // VARBINARY — max_length, no collation
1969                0xA5 => {
1970                    assert_eq!(
1971                        parsed.type_info.max_length, original.max_length,
1972                        "column {i} ({}) binary max_length",
1973                        original.name
1974                    );
1975                    assert!(
1976                        parsed.type_info.collation.is_none(),
1977                        "column {i} ({}) should not have collation",
1978                        original.name
1979                    );
1980                }
1981                // DECIMAL
1982                0x6C => {
1983                    assert_eq!(
1984                        parsed.type_info.precision, original.precision,
1985                        "column {i} ({}) precision",
1986                        original.name
1987                    );
1988                    assert_eq!(
1989                        parsed.type_info.scale, original.scale,
1990                        "column {i} ({}) scale",
1991                        original.name
1992                    );
1993                }
1994                _ => {}
1995            }
1996        }
1997    }
1998
1999    /// Verify that NOT NULL columns use fixed-width type IDs (0x38 Int4,
2000    /// 0x32 Bit, etc.) rather than nullable type IDs (0x26 INTN, 0x68 BITN).
2001    /// SQL Server's BulkLoad rejects nullable IDs for NOT NULL columns.
2002    #[test]
2003    fn test_write_colmetadata_not_null_uses_fixed_types() {
2004        use tds_protocol::token::ColMetaData;
2005        use tds_protocol::types::TypeId;
2006
2007        let columns = vec![
2008            BulkColumn::new("id", "INT", 0)
2009                .unwrap()
2010                .with_nullable(false),
2011            BulkColumn::new("tiny", "TINYINT", 1)
2012                .unwrap()
2013                .with_nullable(false),
2014            BulkColumn::new("small", "SMALLINT", 2)
2015                .unwrap()
2016                .with_nullable(false),
2017            BulkColumn::new("big", "BIGINT", 3)
2018                .unwrap()
2019                .with_nullable(false),
2020            BulkColumn::new("flag", "BIT", 4)
2021                .unwrap()
2022                .with_nullable(false),
2023            BulkColumn::new("r", "REAL", 5)
2024                .unwrap()
2025                .with_nullable(false),
2026            BulkColumn::new("f", "FLOAT", 6)
2027                .unwrap()
2028                .with_nullable(false),
2029            BulkColumn::new("dt", "DATETIME", 7)
2030                .unwrap()
2031                .with_nullable(false),
2032            BulkColumn::new("sdt", "SMALLDATETIME", 8)
2033                .unwrap()
2034                .with_nullable(false),
2035            BulkColumn::new("mny", "MONEY", 9)
2036                .unwrap()
2037                .with_nullable(false),
2038            BulkColumn::new("smny", "SMALLMONEY", 10)
2039                .unwrap()
2040                .with_nullable(false),
2041        ];
2042
2043        let bulk = BulkInsert::new(columns.clone(), 0);
2044
2045        // Every NOT NULL fixed-width column should have fixed_len=true
2046        for (i, fixed) in bulk.fixed_len.iter().enumerate() {
2047            assert!(
2048                *fixed,
2049                "column {i} ({}) should be fixed_len",
2050                columns[i].name
2051            );
2052        }
2053
2054        // Parse the generated COLMETADATA
2055        let buf = &bulk.buffer[1..]; // skip token type byte
2056        let mut cursor = bytes::Bytes::copy_from_slice(buf);
2057        let meta = ColMetaData::decode(&mut cursor).expect("parseable");
2058
2059        // Verify each column has the expected fixed type ID and no Nullable flag
2060        let expected: &[(&str, TypeId)] = &[
2061            ("id", TypeId::Int4),
2062            ("tiny", TypeId::Int1),
2063            ("small", TypeId::Int2),
2064            ("big", TypeId::Int8),
2065            ("flag", TypeId::Bit),
2066            ("r", TypeId::Float4),
2067            ("f", TypeId::Float8),
2068            ("dt", TypeId::DateTime),
2069            ("sdt", TypeId::DateTime4),
2070            ("mny", TypeId::Money),
2071            ("smny", TypeId::Money4),
2072        ];
2073
2074        for (i, (name, ty)) in expected.iter().enumerate() {
2075            assert_eq!(meta.columns[i].name, *name, "column {i} name");
2076            assert_eq!(meta.columns[i].type_id, *ty, "column {i} ({name}) type");
2077            assert_eq!(
2078                meta.columns[i].flags & 0x0001,
2079                0,
2080                "column {i} ({name}) should not have Nullable flag set"
2081            );
2082        }
2083    }
2084
2085    /// Verify that `with_collation()` on a VARCHAR column propagates into
2086    /// the COLMETADATA token — the hand-crafted path previously hardcoded
2087    /// Latin1_General_CI_AS regardless of the caller-supplied collation.
2088    #[test]
2089    fn test_write_colmetadata_uses_caller_collation() {
2090        use tds_protocol::token::{ColMetaData, Collation};
2091
2092        // Chinese_PRC_CI_AS: LCID 0x0804, sort_id 0x52 (just a non-default pair)
2093        let chinese = Collation {
2094            lcid: 0x0804,
2095            sort_id: 0x52,
2096        };
2097
2098        let columns = vec![
2099            BulkColumn::new("s", "VARCHAR(50)", 0)
2100                .unwrap()
2101                .with_collation(chinese),
2102            // NVARCHAR also writes 5 collation bytes — should honor caller too
2103            BulkColumn::new("n", "NVARCHAR(50)", 1)
2104                .unwrap()
2105                .with_collation(chinese),
2106            // VARCHAR without with_collation should keep the Latin1 default
2107            BulkColumn::new("d", "VARCHAR(10)", 2).unwrap(),
2108        ];
2109        let bulk = BulkInsert::new(columns, 0);
2110
2111        let buf = &bulk.buffer[1..];
2112        let mut cursor = bytes::Bytes::copy_from_slice(buf);
2113        let meta = ColMetaData::decode(&mut cursor).expect("parseable");
2114
2115        let c0 = meta.columns[0]
2116            .type_info
2117            .collation
2118            .as_ref()
2119            .expect("VARCHAR has collation");
2120        assert_eq!(c0.lcid, chinese.lcid, "VARCHAR caller LCID");
2121        assert_eq!(c0.sort_id, chinese.sort_id, "VARCHAR caller sort_id");
2122
2123        let c1 = meta.columns[1]
2124            .type_info
2125            .collation
2126            .as_ref()
2127            .expect("NVARCHAR has collation");
2128        assert_eq!(c1.lcid, chinese.lcid, "NVARCHAR caller LCID");
2129        assert_eq!(c1.sort_id, chinese.sort_id, "NVARCHAR caller sort_id");
2130
2131        // Default collation: Latin1_General_CI_AS wire bytes
2132        // [0x09, 0x04, 0xD0, 0x00, 0x34] → lcid u32 LE = 0x00D0_0409, sort_id = 0x34
2133        let default = meta.columns[2]
2134            .type_info
2135            .collation
2136            .as_ref()
2137            .expect("VARCHAR has default collation");
2138        assert_eq!(default.to_bytes(), [0x09, 0x04, 0xD0, 0x00, 0x34]);
2139    }
2140
2141    #[test]
2142    fn test_parse_sql_type_max() {
2143        // Test NVARCHAR(MAX) parsing - uses 0xFFFF marker (not doubled for MAX)
2144        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)").unwrap();
2145        assert_eq!(type_id, 0xE7);
2146        assert_eq!(len, Some(0xFFFF)); // MAX marker is 0xFFFF
2147
2148        // Test VARBINARY(MAX) parsing
2149        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)").unwrap();
2150        assert_eq!(type_id, 0xA5);
2151        assert_eq!(len, Some(0xFFFF));
2152
2153        // Test VARCHAR(MAX) parsing
2154        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)").unwrap();
2155        assert_eq!(type_id, 0xA7);
2156        assert_eq!(len, Some(0xFFFF));
2157
2158        // Verify normal NVARCHAR does double the length
2159        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)").unwrap();
2160        assert_eq!(type_id, 0xE7);
2161        assert_eq!(len, Some(200)); // 100 * 2 for UTF-16
2162    }
2163}