Skip to main content

mssql_client/
bulk.rs

1//! Bulk Copy Protocol (BCP) support.
2//!
3//! This module provides first-class support for bulk insert operations using
4//! the TDS Bulk Load protocol (packet type 0x07). BCP is significantly more
5//! efficient than individual INSERT statements for loading large amounts of data.
6//!
7//! ## Performance Benefits
8//!
9//! - Minimal logging (when using simple recovery model)
10//! - Batch commits reduce lock contention
11//! - Direct data streaming without SQL parsing overhead
12//! - Optional table lock for maximum throughput
13//!
14//! ## Usage
15//!
16//! ```text
17//! use mssql_client::{BulkInsertBuilder, BulkColumn, BulkOptions};
18//!
19//! let builder = BulkInsertBuilder::new("dbo.Users")
20//!     .with_typed_columns(vec![
21//!         BulkColumn::new("id", "INT", 0)?,
22//!         BulkColumn::new("name", "NVARCHAR(100)", 1)?,
23//!         BulkColumn::new("email", "NVARCHAR(200)", 2)?,
24//!     ])
25//!     .with_options(BulkOptions {
26//!         batch_size: 1000,
27//!         check_constraints: true,
28//!         fire_triggers: false,
29//!         keep_nulls: true,
30//!         table_lock: true,
31//!         order_hint: None,
32//!     });
33//!
34//! let mut writer = client.bulk_insert(&builder).await?;
35//!
36//! // Send rows — buffered in memory, sent on finish()
37//! for user in users {
38//!     writer.send_row(&[&user.id, &user.name, &user.email])?;
39//! }
40//!
41//! let result = writer.finish().await?;
42//! println!("Inserted {} rows", result.rows_affected);
43//! ```
44//!
45//! ## Implementation Notes
46//!
47//! The bulk load protocol uses:
48//! - Packet type 0x07 (BulkLoad)
49//! - COLMETADATA token describing column structure
50//! - ROW tokens containing actual data
51//! - DONE token signaling completion
52//!
53//! Per MS-TDS specification, the row data format matches the server output format
54//! (same as SELECT results) rather than storage format.
55
56use bytes::{BufMut, BytesMut};
57use once_cell::sync::Lazy;
58use regex::Regex;
59use std::sync::Arc;
60
61use mssql_types::{SqlValue, ToSql, TypeError};
62use tds_protocol::packet::{PacketHeader, PacketStatus, PacketType};
63use tds_protocol::token::{Collation, DoneStatus, TokenType};
64
65use crate::error::Error;
66
/// Tunable knobs for a bulk insert.
///
/// Every field here ends up as (or suppresses) one hint in the `WITH (...)`
/// clause of the generated `INSERT BULK` statement, influencing
/// performance, logging, and constraint checking on the server.
#[derive(Debug, Clone)]
pub struct BulkOptions {
    /// Rows-per-batch hint forwarded to the server.
    ///
    /// A non-zero value is emitted as `ROWS_PER_BATCH = <n>` on the
    /// `INSERT BULK` statement so the server can pick a better plan.
    /// It is purely a hint: [`BulkWriter`] still buffers every row in
    /// memory and ships them as one batch from
    /// [`finish()`](BulkWriter::finish). Incremental flushing is planned
    /// alongside the response-streaming work.
    /// Default: 0 (hint omitted).
    pub batch_size: usize,

    /// Whether constraints are checked while inserting
    /// (`CHECK_CONSTRAINTS` hint).
    ///
    /// Default: true
    pub check_constraints: bool,

    /// Whether INSERT triggers on the target table fire
    /// (`FIRE_TRIGGERS` hint).
    ///
    /// Default: false (better performance)
    pub fire_triggers: bool,

    /// Whether NULL values are kept rather than replaced by column
    /// defaults (`KEEP_NULLS` hint).
    ///
    /// Default: true
    pub keep_nulls: bool,

    /// Whether a table-level lock is held for the whole bulk operation
    /// (`TABLOCK` hint).
    ///
    /// Reduces lock escalation overhead and can be markedly faster, at
    /// the cost of blocking all other access to the table.
    /// Default: false
    pub table_lock: bool,

    /// Columns the incoming data is already sorted by (`ORDER(...)` hint).
    ///
    /// When the rows are pre-sorted by the clustered index, listing the
    /// columns here lets the server skip its own sort.
    /// Default: None
    pub order_hint: Option<Vec<String>>,
}

impl Default for BulkOptions {
    /// Conservative defaults: constraints checked, NULLs kept, and no
    /// triggers, table lock, batch hint, or order hint.
    fn default() -> Self {
        Self {
            // Hints that are off / absent by default.
            batch_size: 0,
            order_hint: None,
            fire_triggers: false,
            table_lock: false,
            // Hints that are on by default.
            check_constraints: true,
            keep_nulls: true,
        }
    }
}
126
/// Column definition for bulk insert.
///
/// Construct with [`BulkColumn::new`], which derives the private wire-level
/// fields (`type_id`, `max_length`, `precision`, `scale`) from the
/// `sql_type` string.
#[derive(Debug, Clone)]
pub struct BulkColumn {
    /// Column name.
    pub name: String,
    /// SQL Server type (e.g., "INT", "NVARCHAR(100)").
    pub sql_type: String,
    /// Whether the column allows NULL values.
    pub nullable: bool,
    /// Column ordinal (0-based).
    pub ordinal: usize,
    /// TDS type ID (the nullable variant, as returned by `parse_sql_type`).
    type_id: u8,
    /// Maximum length for variable-length types.
    ///
    /// `0xFFFF` is the marker used for `MAX` types.
    max_length: Option<u32>,
    /// Precision for decimal types.
    precision: Option<u8>,
    /// Scale for decimal types (also used for the scaled temporal types).
    scale: Option<u8>,
    /// Collation for VARCHAR/CHAR columns.
    ///
    /// Populated automatically from the server's COLMETADATA when
    /// [`Client::bulk_insert`](crate::Client::bulk_insert) is used. Can be set
    /// manually via [`with_collation`](Self::with_collation) for the
    /// schema-discovery-free path. When `None`, VARCHAR values fall back to
    /// the default Latin1_General_CI_AS collation (Windows-1252).
    collation: Option<Collation>,
}
155
156impl BulkColumn {
157    /// Create a new bulk column definition.
158    ///
159    /// # Errors
160    ///
161    /// Returns [`TypeError::UnsupportedType`] when `sql_type` names a deprecated
162    /// large object type (`TEXT`, `NTEXT`, `IMAGE`). Use `VARCHAR(MAX)` /
163    /// `NVARCHAR(MAX)` / `VARBINARY(MAX)` instead — Microsoft deprecated
164    /// `TEXT` / `NTEXT` / `IMAGE` in SQL Server 2005 and recommends the `MAX`
165    /// types for all new development.
166    pub fn new<S: Into<String>>(name: S, sql_type: S, ordinal: usize) -> Result<Self, TypeError> {
167        let sql_type_str: String = sql_type.into();
168        reject_unsupported_bulk_type(&sql_type_str)?;
169        let (type_id, max_length, precision, scale) = parse_sql_type(&sql_type_str);
170
171        Ok(Self {
172            name: name.into(),
173            sql_type: sql_type_str,
174            nullable: true,
175            ordinal,
176            type_id,
177            max_length,
178            precision,
179            scale,
180            collation: None,
181        })
182    }
183
184    /// Set whether this column allows NULL values.
185    #[must_use]
186    pub fn with_nullable(mut self, nullable: bool) -> Self {
187        self.nullable = nullable;
188        self
189    }
190
191    /// Set the collation used for VARCHAR/CHAR columns.
192    ///
193    /// Required when [`Client::bulk_insert_without_schema_discovery`](crate::Client::bulk_insert_without_schema_discovery)
194    /// targets VARCHAR columns on a server whose default collation is not
195    /// Latin1_General_CI_AS and the target column uses a different code page.
196    /// Ignored for NVARCHAR/NCHAR columns (always UTF-16).
197    #[must_use]
198    pub fn with_collation(mut self, collation: Collation) -> Self {
199        self.collation = Some(collation);
200        self
201    }
202}
203
/// Parse a SQL type string into TDS type information.
///
/// Returns `(type_id, max_length, precision, scale)`.
///
/// Parsing is case-insensitive and tolerant of whitespace around the base
/// type name and inside the parentheses, so `"nvarchar (100)"`,
/// `" INT "` and `"VARCHAR( MAX )"` are all understood. (The statement
/// validator in this file accepts spaces in type strings, so they must not
/// silently change the detected type here.)
///
/// Type parameters (e.g., the "100" in `VARCHAR(100)`) that are missing or
/// malformed fall back to the type's SQL Server default length (e.g., 8000
/// for VARCHAR, 4000 characters for NVARCHAR). This is intentional:
/// bulk-insert column definitions come from user code, and defaulting to
/// max length is safer than rejecting the operation when the base type is
/// valid. Unrecognized base types are treated as NVARCHAR(4000).
fn parse_sql_type(sql_type: &str) -> (u8, Option<u32>, Option<u8>, Option<u8>) {
    let upper = sql_type.trim().to_uppercase();

    // Split "BASE(params)" into its base name and raw parameter text,
    // trimming both so stray spaces do not defeat the matches below.
    let (base, params) = match upper.split_once('(') {
        Some((base, rest)) => (
            base.trim_end(),
            Some(rest.trim_end_matches(')').trim()),
        ),
        None => (upper.as_str(), None),
    };

    // Fractional-seconds scale for TIME / DATETIME2 / DATETIMEOFFSET.
    let scale_or = |default: u8| -> u8 {
        params
            .and_then(|p| p.parse::<u8>().ok())
            .unwrap_or(default)
    };

    // Byte length for (VAR)CHAR / (VAR)BINARY; `MAX` maps to the 0xFFFF marker.
    let byte_len_or = |default: u32| -> u32 {
        match params {
            Some("MAX") => 0xFFFF,
            Some(p) => p.parse::<u32>().unwrap_or(default),
            None => default,
        }
    };

    // This returns the nullable type variant ID. `write_colmetadata` switches
    // to the fixed-width variant (e.g. 0x26 INTN → 0x38 Int4) when the target
    // column is NOT NULL, since SQL Server's BulkLoad rejects nullable type IDs
    // for NOT NULL columns with error 4816.
    match base {
        "BIT" => (0x68, Some(1), None, None),      // BITN
        "TINYINT" => (0x26, Some(1), None, None),  // INTN(1)
        "SMALLINT" => (0x26, Some(2), None, None), // INTN(2)
        "INT" => (0x26, Some(4), None, None),      // INTN(4)
        "BIGINT" => (0x26, Some(8), None, None),   // INTN(8)
        "REAL" => (0x6D, Some(4), None, None),     // FLTN(4)
        "FLOAT" => (0x6D, Some(8), None, None),    // FLTN(8)
        "DATE" => (0x28, None, None, None),
        "TIME" => (0x29, None, None, Some(scale_or(7))),
        "DATETIME" => (0x6F, Some(8), None, None), // DATETIMEN(8)
        "DATETIME2" => (0x2A, None, None, Some(scale_or(7))),
        "DATETIMEOFFSET" => (0x2B, None, None, Some(scale_or(7))),
        "SMALLDATETIME" => (0x6F, Some(4), None, None), // DATETIMEN(4)
        "UNIQUEIDENTIFIER" => (0x24, Some(16), None, None),
        "VARCHAR" | "CHAR" => (0xA7, Some(byte_len_or(8000)), None, None),
        "NVARCHAR" | "NCHAR" => {
            if params == Some("MAX") {
                // MAX types use the 0xFFFF marker (not doubled).
                (0xE7, Some(0xFFFF), None, None)
            } else {
                // Declared lengths are in characters; the wire length is in
                // UTF-16 bytes. Saturate instead of overflowing on absurd
                // user-supplied lengths.
                let chars: u32 = params
                    .and_then(|p| p.parse::<u32>().ok())
                    .unwrap_or(4000);
                (0xE7, Some(chars.saturating_mul(2)), None, None)
            }
        }
        "VARBINARY" | "BINARY" => (0xA5, Some(byte_len_or(8000)), None, None),
        "DECIMAL" | "NUMERIC" => {
            // "p, s" — either part may be missing or malformed.
            let mut parts = params.unwrap_or("").split(',').map(str::trim);
            let precision: u8 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(18);
            let scale: u8 = parts.next().and_then(|s| s.parse().ok()).unwrap_or(0);
            (0x6C, None, Some(precision), Some(scale))
        }
        "MONEY" => (0x6E, Some(8), None, None), // MONEYN(8)
        "SMALLMONEY" => (0x6E, Some(4), None, None), // MONEYN(4)
        "XML" => (0xF1, Some(0xFFFF), None, None),
        // Unknown type: default to NVARCHAR(4000), i.e. 8000 bytes.
        _ => (0xE7, Some(8000), None, None),
    }
}
305
306/// Reject deprecated large object types that this driver does not support in
307/// bulk insert. `TEXT` / `NTEXT` / `IMAGE` have been deprecated since SQL
308/// Server 2005 and use a legacy TEXTPTR wire format. Users should use
309/// `VARCHAR(MAX)` / `NVARCHAR(MAX)` / `VARBINARY(MAX)` which the driver
310/// supports end-to-end.
311fn reject_unsupported_bulk_type(sql_type: &str) -> Result<(), TypeError> {
312    let base = sql_type
313        .split('(')
314        .next()
315        .unwrap_or("")
316        .trim()
317        .to_uppercase();
318    match base.as_str() {
319        "TEXT" | "NTEXT" => Err(TypeError::UnsupportedType {
320            sql_type: base,
321            reason: "TEXT/NTEXT are not supported. Use VARCHAR(MAX) / \
322                     NVARCHAR(MAX) instead (Microsoft deprecated TEXT/NTEXT in \
323                     SQL Server 2005)."
324                .to_string(),
325        }),
326        "IMAGE" => Err(TypeError::UnsupportedType {
327            sql_type: base,
328            reason: "IMAGE is not supported. Use VARBINARY(MAX) instead \
329                     (Microsoft deprecated IMAGE in SQL Server 2005)."
330                .to_string(),
331        }),
332        _ => Ok(()),
333    }
334}
335
/// Result of a bulk insert operation.
///
/// Plain summary data: row count, batch count, and an error flag.
#[derive(Debug, Clone)]
pub struct BulkInsertResult {
    /// Total number of rows inserted.
    pub rows_affected: u64,
    /// Number of batches committed.
    ///
    /// [`BulkWriter`] always sends a single batch, so this is currently 1.
    pub batches_committed: u32,
    /// Whether any errors were encountered.
    pub has_errors: bool,
}
348
/// Builder for configuring a bulk insert operation.
///
/// Collects the target table, the column definitions, and the
/// [`BulkOptions`] from which the `INSERT BULK` statement is generated.
#[derive(Debug)]
pub struct BulkInsertBuilder {
    /// Target table name; may be schema-qualified (e.g. `dbo.Users`).
    table_name: String,
    /// Column definitions; empty until set via one of the `with_*columns` methods.
    columns: Vec<BulkColumn>,
    /// Hints applied to the generated `INSERT BULK` statement.
    options: BulkOptions,
}
356
357impl BulkInsertBuilder {
358    /// Create a new bulk insert builder for the specified table.
359    pub fn new<S: Into<String>>(table_name: S) -> Self {
360        Self {
361            table_name: table_name.into(),
362            columns: Vec::new(),
363            options: BulkOptions::default(),
364        }
365    }
366
367    /// Specify the columns to insert.
368    ///
369    /// Columns will be queried from the server if not specified,
370    /// but providing them explicitly is more efficient.
371    #[must_use]
372    #[allow(clippy::expect_used)] // NVARCHAR(MAX) is always a supported bulk type
373    pub fn with_columns(mut self, column_names: &[&str]) -> Self {
374        self.columns = column_names
375            .iter()
376            .enumerate()
377            .map(|(i, name)| {
378                BulkColumn::new(*name, "NVARCHAR(MAX)", i)
379                    .expect("NVARCHAR(MAX) is always a supported type")
380            })
381            .collect();
382        self
383    }
384
385    /// Specify columns with full type information.
386    #[must_use]
387    pub fn with_typed_columns(mut self, columns: Vec<BulkColumn>) -> Self {
388        self.columns = columns;
389        self
390    }
391
392    /// Set bulk insert options.
393    #[must_use]
394    pub fn with_options(mut self, options: BulkOptions) -> Self {
395        self.options = options;
396        self
397    }
398
399    /// Set the batch size.
400    #[must_use]
401    pub fn batch_size(mut self, size: usize) -> Self {
402        self.options.batch_size = size;
403        self
404    }
405
406    /// Enable or disable table lock.
407    #[must_use]
408    pub fn table_lock(mut self, enabled: bool) -> Self {
409        self.options.table_lock = enabled;
410        self
411    }
412
413    /// Enable or disable trigger firing.
414    #[must_use]
415    pub fn fire_triggers(mut self, enabled: bool) -> Self {
416        self.options.fire_triggers = enabled;
417        self
418    }
419
420    /// Get the table name.
421    pub fn table_name(&self) -> &str {
422        &self.table_name
423    }
424
425    /// Get the columns.
426    pub fn columns(&self) -> &[BulkColumn] {
427        &self.columns
428    }
429
430    /// Get the options.
431    pub fn options(&self) -> &BulkOptions {
432        &self.options
433    }
434
435    /// Build the INSERT BULK SQL statement.
436    ///
437    /// # Errors
438    ///
439    /// Returns an error if the table name or any column name fails identifier
440    /// validation, preventing SQL injection.
441    pub fn build_insert_bulk_statement(&self) -> Result<String, Error> {
442        // Validate table name (may be schema-qualified: dbo.Users, catalog.schema.table)
443        crate::validation::validate_qualified_identifier(&self.table_name)?;
444
445        // Validate column names
446        for col in &self.columns {
447            crate::validation::validate_identifier(&col.name)?;
448        }
449
450        let mut sql = format!("INSERT BULK {}", self.table_name);
451
452        // Add column definitions
453        if !self.columns.is_empty() {
454            sql.push_str(" (");
455            let cols: Vec<String> = self
456                .columns
457                .iter()
458                .map(|c| {
459                    // Validate sql_type to prevent SQL injection: only allow
460                    // alphanumerics, parentheses (for length/precision), commas,
461                    // spaces, and the MAX keyword — which covers all valid T-SQL
462                    // type specifiers like "NVARCHAR(100)", "DECIMAL(18, 2)",
463                    // "VARBINARY(MAX)", etc.
464                    validate_sql_type(&c.sql_type)?;
465                    Ok(format!("{} {}", c.name, c.sql_type))
466                })
467                .collect::<Result<Vec<_>, Error>>()?;
468            sql.push_str(&cols.join(", "));
469            sql.push(')');
470        }
471
472        // Add WITH clause for options
473        let mut hints: Vec<String> = Vec::new();
474
475        if self.options.check_constraints {
476            hints.push("CHECK_CONSTRAINTS".to_string());
477        }
478        if self.options.fire_triggers {
479            hints.push("FIRE_TRIGGERS".to_string());
480        }
481        if self.options.keep_nulls {
482            hints.push("KEEP_NULLS".to_string());
483        }
484        if self.options.table_lock {
485            hints.push("TABLOCK".to_string());
486        }
487        if self.options.batch_size > 0 {
488            hints.push(format!("ROWS_PER_BATCH = {}", self.options.batch_size));
489        }
490
491        if let Some(ref order) = self.options.order_hint {
492            // Validate order hint column names
493            for col_name in order {
494                crate::validation::validate_identifier(col_name)?;
495            }
496            hints.push(format!("ORDER({})", order.join(", ")));
497        }
498
499        if !hints.is_empty() {
500            sql.push_str(" WITH (");
501            sql.push_str(&hints.join(", "));
502            sql.push(')');
503        }
504
505        Ok(sql)
506    }
507}
508
509/// Validate a SQL type specifier to prevent SQL injection.
510///
511/// Allows only characters that can appear in valid T-SQL type declarations:
512/// letters, digits, parentheses, commas, spaces, and periods.
513/// Examples: "INT", "NVARCHAR(100)", "DECIMAL(18, 2)", "VARBINARY(MAX)".
514fn validate_sql_type(type_str: &str) -> Result<(), Error> {
515    #[allow(clippy::expect_used)] // Static regex compilation with known-valid pattern
516    static SQL_TYPE_RE: Lazy<Regex> =
517        Lazy::new(|| Regex::new(r"^[a-zA-Z][a-zA-Z0-9_ ()\.,]{0,127}$").expect("valid regex"));
518
519    if type_str.is_empty() {
520        return Err(Error::Config("SQL type cannot be empty".into()));
521    }
522
523    if !SQL_TYPE_RE.is_match(type_str) {
524        return Err(Error::Config(format!(
525            "invalid SQL type '{type_str}': contains disallowed characters"
526        )));
527    }
528
529    Ok(())
530}
531
/// Active bulk insert operation.
///
/// This struct manages the streaming of row data to the server.
/// Call `send_row()` for each row, then `finish()` to complete.
///
/// The COLMETADATA token is written into `buffer` at construction time;
/// each accepted row appends one ROW token after it.
pub struct BulkInsert {
    /// Column metadata.
    columns: Arc<[BulkColumn]>,
    /// Whether each column uses a fixed-length type on the wire.
    /// When true, row values for that column are written without a length prefix.
    /// Indexed in parallel with `columns`.
    fixed_len: Arc<[bool]>,
    /// Buffer for accumulating rows.
    buffer: BytesMut,
    /// Rows in current batch.
    rows_in_batch: usize,
    /// Total rows sent.
    total_rows: u64,
    /// Batch size (0 = single batch).
    batch_size: usize,
    /// Number of batches committed.
    batches_committed: u32,
    /// Packet ID counter (starts at 1).
    packet_id: u8,
}
555
556impl BulkInsert {
557    /// Create a new bulk insert operation.
558    pub fn new(columns: Vec<BulkColumn>, batch_size: usize) -> Self {
559        Self::new_with_server_metadata(columns, batch_size, None, None)
560    }
561
562    /// Create a new bulk insert operation using server metadata.
563    ///
564    /// When `raw_colmetadata` is provided, it is written directly into the
565    /// BulkLoad buffer, ensuring the COLMETADATA matches the server's exact
566    /// encoding. `server_columns` provides per-column type info so row values
567    /// are encoded correctly (fixed-length types have no length prefix).
568    ///
569    /// This follows the pattern used by Tiberius: the server's own metadata
570    /// from `SELECT TOP 0` is echoed back rather than constructing it from
571    /// user-specified types.
572    pub(crate) fn new_with_server_metadata(
573        mut columns: Vec<BulkColumn>,
574        batch_size: usize,
575        raw_colmetadata: Option<bytes::Bytes>,
576        server_columns: Option<&[tds_protocol::token::ColumnData]>,
577    ) -> Self {
578        // Determine which columns use fixed-length types on the wire.
579        // Fixed-length types omit the per-row length prefix.
580        let fixed_len: Vec<bool> = if let Some(srv_cols) = server_columns {
581            // Propagate collation from server metadata for VARCHAR/CHAR columns.
582            // The user's BulkColumn is constructed from type strings alone and
583            // has no collation until we see the server's COLMETADATA — falling
584            // back to the default Latin1 on NON-Latin servers would silently
585            // corrupt extended characters.
586            for (col, srv) in columns.iter_mut().zip(srv_cols.iter()) {
587                if col.collation.is_none() {
588                    col.collation = srv.type_info.collation;
589                }
590            }
591            srv_cols
592                .iter()
593                .map(|c| c.type_id.is_fixed_length())
594                .collect()
595        } else {
596            // Without server metadata, NOT NULL columns of fixed-width types
597            // must use the fixed type ID variant (e.g. INT NOT NULL uses 0x38
598            // Int4, not 0x26 INTN). SQL Server rejects nullable type IDs for
599            // NOT NULL target columns with error 4816.
600            columns
601                .iter()
602                .map(|c| !c.nullable && nullable_to_fixed_type(c.type_id, c.max_length).is_some())
603                .collect()
604        };
605
606        let mut bulk = Self {
607            columns: columns.into(),
608            fixed_len: fixed_len.into(),
609            buffer: BytesMut::with_capacity(64 * 1024),
610            rows_in_batch: 0,
611            total_rows: 0,
612            batch_size,
613            batches_committed: 0,
614            packet_id: 1,
615        };
616
617        if let Some(raw) = raw_colmetadata {
618            bulk.buffer.extend_from_slice(&raw);
619        } else {
620            bulk.write_colmetadata();
621        }
622
623        bulk
624    }
625
626    /// Write the COLMETADATA token to the buffer.
627    fn write_colmetadata(&mut self) {
628        let buf = &mut self.buffer;
629
630        // Token type
631        buf.put_u8(TokenType::ColMetaData as u8);
632
633        // Column count
634        buf.put_u16_le(self.columns.len() as u16);
635
636        for col in self.columns.iter() {
637            // User type (always 0 for basic types)
638            buf.put_u32_le(0);
639
640            // For NOT NULL columns with a fixed-width type, use the fixed type ID
641            // variant (e.g. INT NOT NULL → 0x38 Int4 instead of 0x26 INTN).
642            // SQL Server's BCP rejects nullable type IDs for NOT NULL columns.
643            let effective_type_id = if !col.nullable {
644                nullable_to_fixed_type(col.type_id, col.max_length).unwrap_or(col.type_id)
645            } else {
646                col.type_id
647            };
648            let is_fixed_variant = effective_type_id != col.type_id;
649
650            // Flags: Nullable (bit 0) | Updateable (bit 3)
651            // BulkLoad columns must have Updateable set to indicate they accept writes.
652            let mut flags: u16 = 0x0008; // Updateable
653            if col.nullable {
654                flags |= 0x0001; // Nullable
655            }
656            buf.put_u16_le(flags);
657
658            // Type info
659            buf.put_u8(effective_type_id);
660
661            // Fixed-width types have no additional TYPE_INFO bytes — skip straight
662            // to the column name.
663            if is_fixed_variant {
664                let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
665                buf.put_u8(name_utf16.len() as u8);
666                for code_unit in name_utf16 {
667                    buf.put_u16_le(code_unit);
668                }
669                continue;
670            }
671
672            // Type-specific length/precision/scale
673            match col.type_id {
674                // Nullable fixed-length types — 1-byte max-length follows type ID
675                // INTN(0x26), BITN(0x68), FLTN(0x6D), MONEYN(0x6E), DATETIMEN(0x6F)
676                0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
677                    buf.put_u8(col.max_length.unwrap_or(4) as u8);
678                }
679
680                // DATE has no length byte (fixed 3-byte value)
681                0x28 => {}
682
683                // Variable-length string/binary types
684                0xE7 | 0xA7 | 0xA5 | 0xAD => {
685                    // Max length (2 bytes for normal, 4 bytes for MAX)
686                    let max_len = col.max_length.unwrap_or(8000);
687                    if max_len == 0xFFFF {
688                        buf.put_u16_le(0xFFFF);
689                    } else {
690                        buf.put_u16_le(max_len as u16);
691                    }
692
693                    // Collation for string types (5 bytes). Use the caller-
694                    // supplied collation when present (via `with_collation()`),
695                    // otherwise fall back to Latin1_General_CI_AS.
696                    if col.type_id == 0xE7 || col.type_id == 0xA7 {
697                        if let Some(coll) = col.collation.as_ref() {
698                            buf.put_slice(&coll.to_bytes());
699                        } else {
700                            // Default collation: Latin1_General_CI_AS
701                            // Bytes: LCID(0x0409) + flags(0xD000) + SortId(0x34)
702                            buf.put_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]);
703                        }
704                    }
705                }
706
707                // Decimal/Numeric
708                0x6C | 0x6A => {
709                    // Length (calculated from precision)
710                    let precision = col.precision.unwrap_or(18);
711                    let len = decimal_byte_length(precision);
712                    buf.put_u8(len);
713                    buf.put_u8(precision);
714                    buf.put_u8(col.scale.unwrap_or(0));
715                }
716
717                // Time-based with scale
718                0x29..=0x2B => {
719                    buf.put_u8(col.scale.unwrap_or(7));
720                }
721
722                // GUID
723                0x24 => {
724                    buf.put_u8(16);
725                }
726
727                // Other types - write max length if present
728                _ => {
729                    if let Some(len) = col.max_length {
730                        if len <= 0xFFFF {
731                            buf.put_u16_le(len as u16);
732                        }
733                    }
734                }
735            }
736
737            // Column name (B_VARCHAR format: 1-byte length prefix)
738            let name_utf16: Vec<u16> = col.name.encode_utf16().collect();
739            buf.put_u8(name_utf16.len() as u8);
740            for code_unit in name_utf16 {
741                buf.put_u16_le(code_unit);
742            }
743        }
744    }
745
746    /// Send a row of data.
747    ///
748    /// The values must match the column order and types specified
749    /// when creating the bulk insert.
750    ///
751    /// # Errors
752    ///
753    /// Returns an error if:
754    /// - Wrong number of values provided
755    /// - A value cannot be converted to the expected type
756    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
757        if values.len() != self.columns.len() {
758            return Err(Error::Config(format!(
759                "expected {} values, got {}",
760                self.columns.len(),
761                values.len()
762            )));
763        }
764
765        // Convert all values to SqlValue
766        let sql_values: Result<Vec<SqlValue>, TypeError> =
767            values.iter().map(|v| v.to_sql()).collect();
768        let sql_values = sql_values.map_err(Error::from)?;
769
770        self.write_row(&sql_values)?;
771
772        self.rows_in_batch += 1;
773        self.total_rows += 1;
774
775        Ok(())
776    }
777
778    /// Send a row of pre-converted SQL values.
779    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
780        if values.len() != self.columns.len() {
781            return Err(Error::Config(format!(
782                "expected {} values, got {}",
783                self.columns.len(),
784                values.len()
785            )));
786        }
787
788        self.write_row(values)?;
789
790        self.rows_in_batch += 1;
791        self.total_rows += 1;
792
793        Ok(())
794    }
795
796    /// Write a ROW token to the buffer.
797    fn write_row(&mut self, values: &[SqlValue]) -> Result<(), Error> {
798        // ROW token type
799        self.buffer.put_u8(TokenType::Row as u8);
800
801        // Collect column info needed for encoding to avoid borrow conflict
802        let columns: Vec<_> = self.columns.iter().cloned().collect();
803        let fixed_len = self.fixed_len.clone();
804
805        // Write each column value
806        for (i, (col, value)) in columns.iter().zip(values.iter()).enumerate() {
807            let is_fixed = *fixed_len.get(i).unwrap_or(&false);
808            self.encode_column_value(col, value, is_fixed)
809                .map_err(|e| Error::Config(format!("failed to encode column {i}: {e}")))?;
810        }
811
812        Ok(())
813    }
814
815    /// Encode a column value according to its type.
816    ///
817    /// When `is_fixed` is true, the column uses a fixed-length type on the wire
818    /// and values are written without a length prefix. When false, values include
819    /// a length prefix (1 byte for numeric nullable types, 2 bytes for strings).
820    fn encode_column_value(
821        &mut self,
822        col: &BulkColumn,
823        value: &SqlValue,
824        is_fixed: bool,
825    ) -> Result<(), TypeError> {
826        let buf = &mut self.buffer;
827
828        // Check if this column uses PLP (Partially Length-Prefixed) encoding
829        // MAX types (max_length == 0xFFFF) use PLP format
830        let is_plp_type =
831            col.max_length == Some(0xFFFF) && matches!(col.type_id, 0xE7 | 0xA7 | 0xA5 | 0xAD);
832
833        match value {
834            SqlValue::Null => {
835                // NULL encoding depends on type
836                match col.type_id {
837                    // Variable-length types
838                    0xE7 | 0xA7 | 0xA5 | 0xAD => {
839                        if is_plp_type {
840                            // PLP NULL: 0xFFFFFFFFFFFFFFFF
841                            buf.put_u64_le(0xFFFF_FFFF_FFFF_FFFF);
842                        } else {
843                            // Standard NULL: 0xFFFF length marker
844                            buf.put_u16_le(0xFFFF);
845                        }
846                    }
847                    // Nullable fixed types use 0 length
848                    // INTN, BITN, FLTN, MONEYN, DATETIMEN, Decimal, GUID, temporal
849                    0x26 | 0x68 | 0x6D | 0x6E | 0x6F | 0x6C | 0x6A | 0x24 | 0x28 | 0x29 | 0x2A
850                    | 0x2B => {
851                        buf.put_u8(0);
852                    }
853                    // Fixed types without nullable variant
854                    _ => {
855                        if col.nullable {
856                            buf.put_u8(0);
857                        } else {
858                            return Err(TypeError::UnexpectedNull);
859                        }
860                    }
861                }
862            }
863
864            SqlValue::Bool(v) => {
865                if !is_fixed {
866                    buf.put_u8(1);
867                }
868                buf.put_u8(if *v { 1 } else { 0 });
869            }
870
871            SqlValue::TinyInt(v) => {
872                if !is_fixed {
873                    buf.put_u8(1);
874                }
875                buf.put_u8(*v);
876            }
877
878            SqlValue::SmallInt(v) => {
879                if !is_fixed {
880                    buf.put_u8(2);
881                }
882                buf.put_i16_le(*v);
883            }
884
885            SqlValue::Int(v) => {
886                if !is_fixed {
887                    buf.put_u8(4);
888                }
889                buf.put_i32_le(*v);
890            }
891
892            SqlValue::BigInt(v) => {
893                if !is_fixed {
894                    buf.put_u8(8);
895                }
896                buf.put_i64_le(*v);
897            }
898
899            SqlValue::Float(v) => {
900                if !is_fixed {
901                    buf.put_u8(4);
902                }
903                buf.put_f32_le(*v);
904            }
905
906            SqlValue::Double(v) => {
907                if !is_fixed {
908                    buf.put_u8(8);
909                }
910                buf.put_f64_le(*v);
911            }
912
913            SqlValue::String(s) => {
914                // NVARCHAR/NCHAR columns (0xE7/0xEF) use UTF-16LE on the wire.
915                // VARCHAR/CHAR/BIGCHAR columns (0xA7/0x2F/0xAF) use the
916                // collation's code page for single-byte encoding — writing UTF-16
917                // into a VARCHAR column lands each surrogate half in its own
918                // single-byte slot and silently corrupts the data.
919                let is_varchar = matches!(col.type_id, 0xA7 | 0x2F | 0xAF);
920
921                if is_varchar {
922                    let encoded = encode_varchar_for_collation(s, col.collation.as_ref());
923                    let byte_len = encoded.len();
924
925                    if is_plp_type {
926                        encode_plp_binary(&encoded, buf);
927                    } else if byte_len > 0xFFFF {
928                        return Err(TypeError::BufferTooSmall {
929                            needed: byte_len,
930                            available: 0xFFFF,
931                        });
932                    } else {
933                        buf.put_u16_le(byte_len as u16);
934                        buf.put_slice(&encoded);
935                    }
936                } else {
937                    // UTF-16LE encoding for NVARCHAR
938                    let utf16: Vec<u16> = s.encode_utf16().collect();
939                    let byte_len = utf16.len() * 2;
940
941                    if is_plp_type {
942                        // PLP format for MAX types - supports unlimited size
943                        // Send as a single chunk for simplicity
944                        encode_plp_string(&utf16, buf);
945                    } else if byte_len > 0xFFFF {
946                        // Non-MAX column can't hold this much data
947                        return Err(TypeError::BufferTooSmall {
948                            needed: byte_len,
949                            available: 0xFFFF,
950                        });
951                    } else {
952                        // Standard encoding with 2-byte length prefix
953                        buf.put_u16_le(byte_len as u16);
954                        for code_unit in utf16 {
955                            buf.put_u16_le(code_unit);
956                        }
957                    }
958                }
959            }
960
961            SqlValue::Binary(b) => {
962                if is_plp_type {
963                    // PLP format for MAX types - supports unlimited size
964                    encode_plp_binary(b, buf);
965                } else if b.len() > 0xFFFF {
966                    // Non-MAX column can't hold this much data
967                    return Err(TypeError::BufferTooSmall {
968                        needed: b.len(),
969                        available: 0xFFFF,
970                    });
971                } else {
972                    // Standard encoding with 2-byte length prefix
973                    buf.put_u16_le(b.len() as u16);
974                    buf.put_slice(b);
975                }
976            }
977
978            // Feature-gated types - use mssql_types::encode module
979            #[cfg(feature = "decimal")]
980            SqlValue::Decimal(d) => {
981                if col.type_id == 0x6E {
982                    // MONEY / SMALLMONEY — fixed-point scaled by 10_000, not DECIMAL format.
983                    encode_money_value(*d, col, buf, is_fixed)?;
984                } else {
985                    let precision = col.precision.unwrap_or(18);
986                    let len = decimal_byte_length(precision);
987                    buf.put_u8(len);
988
989                    // Sign: 0 = negative, 1 = positive
990                    buf.put_u8(if d.is_sign_negative() { 0 } else { 1 });
991
992                    // Mantissa as unsigned 128-bit integer
993                    let mantissa = d.mantissa().unsigned_abs();
994                    let mantissa_bytes = mantissa.to_le_bytes();
995                    buf.put_slice(&mantissa_bytes[..((len - 1) as usize)]);
996                }
997            }
998
999            #[cfg(feature = "uuid")]
1000            SqlValue::Uuid(u) => {
1001                buf.put_u8(16); // Length
1002                // Use mssql_types encode function
1003                mssql_types::__private::encode_uuid(*u, buf);
1004            }
1005
1006            #[cfg(feature = "chrono")]
1007            SqlValue::Date(d) => {
1008                buf.put_u8(3); // Length
1009                mssql_types::__private::encode_date(*d, buf)?;
1010            }
1011
1012            #[cfg(feature = "chrono")]
1013            SqlValue::Time(t) => {
1014                let scale = col.scale.unwrap_or(7);
1015                let len = time_byte_length(scale);
1016                buf.put_u8(len);
1017                // Encode time with proper scale handling
1018                encode_time_with_scale(*t, scale, buf);
1019            }
1020
1021            #[cfg(feature = "chrono")]
1022            SqlValue::DateTime(dt) => {
1023                // Type 0x6F is DATETIMEN — legacy DATETIME (8 bytes) or
1024                // SMALLDATETIME (4 bytes) format selected by max_length. The
1025                // wire format differs from DATETIME2 (type 0x2A), which uses a
1026                // scale-aware time-then-date layout.
1027                if col.type_id == 0x6F {
1028                    let total_len = col.max_length.unwrap_or(8) as u8;
1029                    if !is_fixed {
1030                        buf.put_u8(total_len);
1031                    }
1032                    match total_len {
1033                        8 => mssql_types::__private::encode_datetime_legacy(*dt, buf),
1034                        4 => mssql_types::__private::encode_smalldatetime(*dt, buf)?,
1035                        _ => {
1036                            return Err(TypeError::InvalidDateTime(format!(
1037                                "DATETIMEN max_length must be 4 or 8, got {total_len}"
1038                            )));
1039                        }
1040                    }
1041                } else {
1042                    let scale = col.scale.unwrap_or(7);
1043                    let time_len = time_byte_length(scale);
1044                    let total_len = time_len + 3;
1045                    buf.put_u8(total_len);
1046                    // Encode time then date
1047                    encode_time_with_scale(dt.time(), scale, buf);
1048                    mssql_types::__private::encode_date(dt.date(), buf)?;
1049                }
1050            }
1051            #[cfg(feature = "chrono")]
1052            SqlValue::SmallDateTime(dt) => {
1053                // Explicit SMALLDATETIME variant — always 4-byte days+minutes,
1054                // regardless of column metadata.
1055                if !is_fixed {
1056                    buf.put_u8(4);
1057                }
1058                mssql_types::__private::encode_smalldatetime(*dt, buf)?;
1059            }
1060            #[cfg(feature = "decimal")]
1061            SqlValue::Money(d) => {
1062                // Force 8-byte MONEY encoding regardless of column metadata.
1063                if !is_fixed {
1064                    buf.put_u8(8);
1065                }
1066                mssql_types::__private::encode_money(*d, buf)?;
1067            }
1068            #[cfg(feature = "decimal")]
1069            SqlValue::SmallMoney(d) => {
1070                if !is_fixed {
1071                    buf.put_u8(4);
1072                }
1073                mssql_types::__private::encode_smallmoney(*d, buf)?;
1074            }
1075
1076            #[cfg(feature = "chrono")]
1077            SqlValue::DateTimeOffset(dto) => {
1078                let scale = col.scale.unwrap_or(7);
1079                let time_len = time_byte_length(scale);
1080                let total_len = time_len + 3 + 2;
1081                buf.put_u8(total_len);
1082                // The wire date/time portion is UTC per MS-TDS §2.2.5.5.1.9,
1083                // not the local wall-clock.
1084                let utc = dto.naive_utc();
1085                encode_time_with_scale(utc.time(), scale, buf);
1086                mssql_types::__private::encode_date(utc.date(), buf)?;
1087                // Timezone offset in minutes
1088                use chrono::Offset;
1089                let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1090                buf.put_i16_le(offset_minutes);
1091            }
1092
1093            #[cfg(feature = "json")]
1094            SqlValue::Json(j) => {
1095                let s = j.to_string();
1096                encode_nvarchar_value(&s, buf)?;
1097            }
1098
1099            SqlValue::Xml(x) => {
1100                encode_nvarchar_value(x, buf)?;
1101            }
1102
1103            SqlValue::Tvp(_) => {
1104                // TVPs are not valid in bulk copy operations - they're for RPC parameters only
1105                return Err(TypeError::UnsupportedConversion {
1106                    from: "TVP".to_string(),
1107                    to: "bulk copy value",
1108                });
1109            }
1110            // Handle future SqlValue variants
1111            _ => {
1112                return Err(TypeError::UnsupportedConversion {
1113                    from: value.type_name().to_string(),
1114                    to: "bulk copy value",
1115                });
1116            }
1117        }
1118
1119        Ok(())
1120    }
1121}
1122
/// Encode a MONEY or SMALLMONEY column value with the appropriate length prefix.
///
/// The column's `max_length` selects the format: 4 bytes for SMALLMONEY,
/// 8 bytes for MONEY (the default when no length is recorded). When
/// `is_fixed` is true (fixed type ID 0x3C or 0x7A), no length byte precedes
/// the payload. Otherwise a 1-byte length prefix is written (matching the
/// MONEYN nullable variant).
///
/// # Errors
///
/// Returns `TypeError::InvalidDecimal` when `max_length` is neither 4 nor 8,
/// or propagates the encoder's error. In both cases `buf` is left exactly as
/// it was on entry.
#[cfg(feature = "decimal")]
fn encode_money_value(
    value: rust_decimal::Decimal,
    col: &BulkColumn,
    buf: &mut BytesMut,
    is_fixed: bool,
) -> Result<(), TypeError> {
    // Validate at full width: narrowing first would let an out-of-range
    // length wrap around into 4 or 8 and be accepted.
    let money_bytes = col.max_length.unwrap_or(8);
    if !matches!(money_bytes, 4 | 8) {
        return Err(TypeError::InvalidDecimal(format!(
            "MONEY column has invalid max_length: {money_bytes}"
        )));
    }

    let start = buf.len();
    if !is_fixed {
        buf.put_u8(money_bytes as u8);
    }

    let encoded = if money_bytes == 4 {
        mssql_types::__private::encode_smallmoney(value, buf)
    } else {
        mssql_types::__private::encode_money(value, buf)
    };

    // Do not leave a dangling length prefix behind a failed payload.
    if encoded.is_err() {
        buf.truncate(start);
    }
    encoded
}
1147
1148/// Encode a string as NVARCHAR with length prefix.
1149fn encode_nvarchar_value(s: &str, buf: &mut BytesMut) -> Result<(), TypeError> {
1150    let utf16: Vec<u16> = s.encode_utf16().collect();
1151    let byte_len = utf16.len() * 2;
1152
1153    if byte_len > 0xFFFF {
1154        return Err(TypeError::BufferTooSmall {
1155            needed: byte_len,
1156            available: 0xFFFF,
1157        });
1158    }
1159
1160    buf.put_u16_le(byte_len as u16);
1161    for code_unit in utf16 {
1162        buf.put_u16_le(code_unit);
1163    }
1164    Ok(())
1165}
1166
/// PLP total-length marker meaning "unknown" (MS-TDS 2.2.5.2.3).
///
/// A client that does not know, or does not want to compute, the total size
/// up front writes this value into the 8-byte ULONGLONGLEN field; the server
/// then finds the end of the value from the chunk framing and the 4-byte
/// terminator.
const PLP_UNKNOWN_LEN: u64 = 0xFFFF_FFFF_FFFF_FFFE;
1172
1173/// Encode a UTF-16 string using PLP (Partially Length-Prefixed) format.
1174///
1175/// PLP format (per MS-TDS 2.2.5.2.3):
1176/// - 8 bytes: ULONGLONGLEN — PLP_UNKNOWN_LEN or actual total byte count
1177/// - One or more chunks: 4-byte chunk length + chunk bytes
1178/// - Terminator: 4-byte zero
1179///
1180/// We emit `PLP_UNKNOWN_LEN` for compatibility with SQL Server's BulkLoad
1181/// parser. Empirically, some server versions reject a concrete total length
1182/// in the BulkLoad (0x07) path even though the token-stream spec allows it
1183/// ("premature end-of-message" errors for NVARCHAR(MAX) bulk inserts).
1184/// Tiberius uses the same approach.
1185fn encode_plp_string(utf16: &[u16], buf: &mut BytesMut) {
1186    let byte_len = utf16.len() * 2;
1187
1188    buf.put_u64_le(PLP_UNKNOWN_LEN);
1189
1190    if byte_len > 0 {
1191        buf.put_u32_le(byte_len as u32);
1192        for code_unit in utf16 {
1193            buf.put_u16_le(*code_unit);
1194        }
1195    }
1196
1197    buf.put_u32_le(0);
1198}
1199
1200/// Encode binary data using PLP (Partially Length-Prefixed) format.
1201/// See [`encode_plp_string`] for the format specification.
1202fn encode_plp_binary(data: &[u8], buf: &mut BytesMut) {
1203    buf.put_u64_le(PLP_UNKNOWN_LEN);
1204
1205    if !data.is_empty() {
1206        buf.put_u32_le(data.len() as u32);
1207        buf.put_slice(data);
1208    }
1209
1210    buf.put_u32_le(0);
1211}
1212
1213/// Encode a Rust string into single-byte VARCHAR bytes using the column's collation.
1214///
1215/// Delegates to [`tds_protocol::__private::encode_str_for_collation`] so the
1216/// RPC parameter path and the bulk insert path share one implementation.
1217fn encode_varchar_for_collation(value: &str, collation: Option<&Collation>) -> Vec<u8> {
1218    tds_protocol::__private::encode_str_for_collation(value, collation)
1219}
1220
/// Write a time-of-day at the given fractional-second scale (for bulk copy).
///
/// The time is converted to a count of scale-sized intervals since midnight
/// and written little-endian, using as many bytes as `time_byte_length`
/// reports for the scale.
#[cfg(feature = "chrono")]
fn encode_time_with_scale(time: chrono::NaiveTime, scale: u8, buf: &mut BytesMut) {
    use chrono::Timelike;

    let whole_seconds = time.num_seconds_from_midnight() as u64;
    let sub_second = time.nanosecond() as u64;
    let ticks = (whole_seconds * 1_000_000_000 + sub_second) / time_scale_divisor(scale);

    // Only the low-order bytes are sent; the width depends on the scale.
    let width = time_byte_length(scale) as usize;
    buf.put_slice(&ticks.to_le_bytes()[..width]);
}
1234
1235impl BulkInsert {
1236    /// Write the DONE token signaling completion.
1237    fn write_done(&mut self) {
1238        let buf = &mut self.buffer;
1239
1240        buf.put_u8(TokenType::Done as u8);
1241
1242        // Status: FINAL (0x00) | COUNT (0x10)
1243        let status = DoneStatus::from_bits(0x0010); // DONE_COUNT
1244        buf.put_u16_le(status.to_bits());
1245
1246        // Current command (0 for bulk load)
1247        buf.put_u16_le(0);
1248
1249        // Row count
1250        buf.put_u64_le(self.total_rows);
1251    }
1252
1253    /// Get the buffered data as packets ready to send.
1254    ///
1255    /// Returns a vector of complete TDS packets with BulkLoad packet type (0x07).
1256    pub fn take_packets(&mut self) -> Vec<BytesMut> {
1257        const MAX_PACKET_SIZE: usize = 4096;
1258        const HEADER_SIZE: usize = 8;
1259        const MAX_PAYLOAD: usize = MAX_PACKET_SIZE - HEADER_SIZE;
1260
1261        let data = self.buffer.split();
1262        let mut packets = Vec::new();
1263        let mut offset = 0;
1264
1265        while offset < data.len() {
1266            let remaining = data.len() - offset;
1267            let payload_size = remaining.min(MAX_PAYLOAD);
1268            let is_last = offset + payload_size >= data.len();
1269
1270            let mut packet = BytesMut::with_capacity(MAX_PACKET_SIZE);
1271
1272            // Write packet header
1273            let header = PacketHeader {
1274                packet_type: PacketType::BulkLoad,
1275                status: if is_last {
1276                    PacketStatus::END_OF_MESSAGE
1277                } else {
1278                    PacketStatus::NORMAL
1279                },
1280                length: (HEADER_SIZE + payload_size) as u16,
1281                spid: 0,
1282                packet_id: self.packet_id,
1283                window: 0,
1284            };
1285
1286            header.encode(&mut packet);
1287
1288            // Write payload
1289            packet.put_slice(&data[offset..offset + payload_size]);
1290
1291            packets.push(packet);
1292            offset += payload_size;
1293            self.packet_id = self.packet_id.wrapping_add(1);
1294        }
1295
1296        packets
1297    }
1298
1299    /// Get total rows sent so far.
1300    pub fn total_rows(&self) -> u64 {
1301        self.total_rows
1302    }
1303
1304    /// Get rows in current batch.
1305    pub fn rows_in_batch(&self) -> usize {
1306        self.rows_in_batch
1307    }
1308
1309    /// Check if a batch flush is needed.
1310    ///
1311    /// Note: [`BulkWriter`] does not consult this — it buffers every row and
1312    /// sends one batch on finish. This is a helper for callers driving the
1313    /// lower-level [`BulkInsert`] packet API manually.
1314    pub fn should_flush(&self) -> bool {
1315        self.batch_size > 0 && self.rows_in_batch >= self.batch_size
1316    }
1317
1318    /// Prepare for finishing the bulk operation.
1319    /// Writes the DONE token and returns final packets.
1320    pub fn finish_packets(&mut self) -> Vec<BytesMut> {
1321        self.write_done();
1322        self.take_packets()
1323    }
1324
1325    /// Create a result from the current state.
1326    pub fn result(&self) -> BulkInsertResult {
1327        BulkInsertResult {
1328            rows_affected: self.total_rows,
1329            batches_committed: self.batches_committed,
1330            has_errors: false,
1331        }
1332    }
1333}
1334
1335/// Active streaming writer for bulk insert operations.
1336///
1337/// Created via [`crate::client::Client::bulk_insert()`]. Rows are buffered in
1338/// memory as they are added with [`send_row()`](BulkWriter::send_row), then
1339/// transmitted to the server when [`finish()`](BulkWriter::finish) is called.
1340///
1341/// The writer holds a mutable reference to the [`crate::Client`], preventing
1342/// other operations on the connection while the bulk insert is in progress.
1343///
1344/// # Example
1345///
1346/// ```rust,no_run
1347/// # use mssql_client::{BulkInsertBuilder, BulkColumn, SqlValue};
1348/// # async fn ex(client: &mut mssql_client::Client<mssql_client::Ready>) -> Result<(), mssql_client::Error> {
1349/// let builder = BulkInsertBuilder::new("dbo.Users")
1350///     .with_typed_columns(vec![
1351///         BulkColumn::new("id", "INT", 0)?,
1352///         BulkColumn::new("name", "NVARCHAR(100)", 1)?,
1353///     ]);
1354///
1355/// let mut writer = client.bulk_insert(&builder).await?;
1356/// writer.send_row_values(&[SqlValue::Int(1), SqlValue::String("Alice".into())])?;
1357/// writer.send_row_values(&[SqlValue::Int(2), SqlValue::String("Bob".into())])?;
1358/// let result = writer.finish().await?;
1359/// # let _ = result;
1360/// # Ok(())
1361/// # }
1362/// ```
1363pub struct BulkWriter<'a, S: crate::state::ConnectionState> {
1364    client: &'a mut crate::client::Client<S>,
1365    bulk: BulkInsert,
1366}
1367
1368impl<'a, S: crate::state::ConnectionState> BulkWriter<'a, S> {
1369    /// Create a new bulk writer.
1370    pub(crate) fn new(client: &'a mut crate::client::Client<S>, bulk: BulkInsert) -> Self {
1371        Self { client, bulk }
1372    }
1373
1374    /// Add a row to the bulk insert buffer.
1375    ///
1376    /// Values are encoded immediately but not sent to the server until
1377    /// [`finish()`](BulkWriter::finish) is called. The number of values must
1378    /// match the number of columns defined for this bulk insert.
1379    pub fn send_row<T: ToSql>(&mut self, values: &[T]) -> Result<(), Error> {
1380        self.bulk.send_row(values)
1381    }
1382
1383    /// Add a row of pre-converted SQL values to the buffer.
1384    pub fn send_row_values(&mut self, values: &[SqlValue]) -> Result<(), Error> {
1385        self.bulk.send_row_values(values)
1386    }
1387
1388    /// Get the number of rows buffered so far.
1389    pub fn total_rows(&self) -> u64 {
1390        self.bulk.total_rows()
1391    }
1392
1393    /// Finish the bulk insert operation and send all buffered data to the server.
1394    ///
1395    /// Writes the DONE token, sends the accumulated row data as a BulkLoad
1396    /// (0x07) message, and reads the server's response.
1397    ///
1398    /// The transfer runs under
1399    /// [`command_timeout`](crate::Config::command_timeout). On expiry it
1400    /// returns [`Error::CommandTimeout`] and
1401    /// the connection is abandoned mid-request (a cancel cannot be
1402    /// interleaved into a partially-sent BulkLoad message), so the pool
1403    /// discards it — the same semantics as `SqlBulkCopy.BulkCopyTimeout`.
1404    pub async fn finish(mut self) -> Result<BulkInsertResult, Error> {
1405        let deadline = self.client.command_deadline();
1406        let total_rows = self.bulk.total_rows();
1407        tracing::debug!(total_rows = total_rows, "finishing bulk insert");
1408
1409        // Write DONE token and freeze the payload
1410        self.bulk.write_done();
1411        let payload = self.bulk.buffer.split().freeze();
1412
1413        // Send BulkLoad data and read server response.
1414        //
1415        // The command timeout here is a hard abandon, not an ATTENTION
1416        // cancel: an Attention packet cannot be interleaved into a
1417        // partially-sent BulkLoad message without corrupting the framing.
1418        // On expiry the connection is left mid-request (in_flight stays
1419        // set), so the pool discards it instead of reusing it —
1420        // SqlBulkCopy's BulkCopyTimeout behaves the same way.
1421        let send_and_read = self.client.send_and_read_bulk_load(payload);
1422        let rows_affected = match deadline {
1423            Some(d) => tokio::time::timeout(d, send_and_read)
1424                .await
1425                .map_err(|_| Error::CommandTimeout)??,
1426            None => send_and_read.await?,
1427        };
1428
1429        Ok(BulkInsertResult {
1430            rows_affected,
1431            batches_committed: 1,
1432            has_errors: false,
1433        })
1434    }
1435}
1436
/// Translate a nullable type ID into the matching fixed-width type ID.
///
/// SQL Server's BulkLoad protocol rejects nullable type IDs (INTN, BITN, etc.)
/// when the target column is NOT NULL. For those columns, the fixed type ID
/// variant must be sent instead — with no max_length and no per-row length
/// prefix.
///
/// BITN maps regardless of `max_length`; every other family needs a
/// recognised width. Returns `None` for types that have no fixed-width
/// variant (e.g. NVARCHAR, VARBINARY, DECIMAL, temporal types other than
/// DATETIME/SMALLDATETIME) and for unrecognised widths.
fn nullable_to_fixed_type(type_id: u8, max_length: Option<u32>) -> Option<u8> {
    let fixed = match type_id {
        // BITN → Bit
        0x68 => 0x32,
        // INTN → Int1 / Int2 / Int4 / Int8
        0x26 => match max_length? {
            1 => 0x30,
            2 => 0x34,
            4 => 0x38,
            8 => 0x7F,
            _ => return None,
        },
        // FLTN → Float4 (REAL) / Float8 (FLOAT)
        0x6D => match max_length? {
            4 => 0x3B,
            8 => 0x3E,
            _ => return None,
        },
        // MONEYN → Money4 (SMALLMONEY) / Money (MONEY)
        0x6E => match max_length? {
            4 => 0x7A,
            8 => 0x3C,
            _ => return None,
        },
        // DATETIMEN → DateTime4 (SMALLDATETIME) / DateTime (DATETIME)
        0x6F => match max_length? {
            4 => 0x3A,
            8 => 0x3D,
            _ => return None,
        },
        _ => return None,
    };
    Some(fixed)
}
1462
/// Wire length in bytes of a DECIMAL value for the given precision.
///
/// The result is one sign byte plus a 4-, 8-, 12- or 16-byte magnitude.
/// Any precision outside 1..=28 (including 0 and values above 38) gets the
/// largest size.
fn decimal_byte_length(precision: u8) -> u8 {
    let magnitude_bytes = match precision {
        1..=9 => 4,
        10..=19 => 8,
        20..=28 => 12,
        _ => 16,
    };
    magnitude_bytes + 1
}
1473
/// Number of bytes used to store a time value at the given scale.
///
/// Scales 0–2 take 3 bytes, 3–4 take 4 bytes, and everything else
/// (including out-of-range scales) takes 5 bytes.
#[cfg(feature = "chrono")]
fn time_byte_length(scale: u8) -> u8 {
    match scale {
        0..=2 => 3,
        3 | 4 => 4,
        _ => 5,
    }
}
1484
/// Number of nanoseconds in one time interval at the given scale.
///
/// Scale 0 means whole seconds (1_000_000_000 ns) and each further step
/// divides by ten, down to 100 ns at scale 7. Scales above 7 are treated
/// as 7.
#[cfg(feature = "chrono")]
fn time_scale_divisor(scale: u8) -> u64 {
    let effective_scale = scale.min(7);
    10u64.pow(u32::from(9 - effective_scale))
}
1500
1501#[cfg(test)]
1502#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
1503mod tests {
1504    use super::*;
1505
1506    #[test]
1507    fn test_bulk_options_default() {
1508        let opts = BulkOptions::default();
1509        assert_eq!(opts.batch_size, 0);
1510        assert!(opts.check_constraints);
1511        assert!(!opts.fire_triggers);
1512        assert!(opts.keep_nulls);
1513        assert!(!opts.table_lock);
1514    }
1515
1516    #[test]
1517    fn test_bulk_column_creation() {
1518        let col = BulkColumn::new("id", "INT", 0).unwrap();
1519        assert_eq!(col.name, "id");
1520        assert_eq!(col.type_id, 0x26); // INTN
1521        assert_eq!(col.max_length, Some(4));
1522        assert!(col.nullable);
1523    }
1524
1525    #[test]
1526    fn test_bulk_column_rejects_text() {
1527        let err = BulkColumn::new("body", "TEXT", 0).unwrap_err();
1528        match err {
1529            TypeError::UnsupportedType { sql_type, reason } => {
1530                assert_eq!(sql_type, "TEXT");
1531                assert!(
1532                    reason.contains("VARCHAR(MAX)"),
1533                    "error should redirect to VARCHAR(MAX), got: {reason}"
1534                );
1535                assert!(
1536                    reason.contains("deprecated"),
1537                    "error should mention deprecation, got: {reason}"
1538                );
1539            }
1540            other => panic!("expected UnsupportedType, got {other:?}"),
1541        }
1542    }
1543
1544    #[test]
1545    fn test_bulk_column_rejects_ntext() {
1546        let err = BulkColumn::new("body", "NTEXT", 0).unwrap_err();
1547        match err {
1548            TypeError::UnsupportedType { sql_type, reason } => {
1549                assert_eq!(sql_type, "NTEXT");
1550                assert!(
1551                    reason.contains("NVARCHAR(MAX)"),
1552                    "error should redirect to NVARCHAR(MAX), got: {reason}"
1553                );
1554                assert!(
1555                    reason.contains("deprecated"),
1556                    "error should mention deprecation, got: {reason}"
1557                );
1558            }
1559            other => panic!("expected UnsupportedType, got {other:?}"),
1560        }
1561    }
1562
1563    #[test]
1564    fn test_bulk_column_rejects_text_case_insensitive() {
1565        assert!(matches!(
1566            BulkColumn::new("body", "text", 0),
1567            Err(TypeError::UnsupportedType { .. })
1568        ));
1569        assert!(matches!(
1570            BulkColumn::new("body", "Ntext", 0),
1571            Err(TypeError::UnsupportedType { .. })
1572        ));
1573    }
1574
1575    #[test]
1576    fn test_bulk_column_rejects_image() {
1577        let err = BulkColumn::new("blob", "IMAGE", 0).unwrap_err();
1578        match err {
1579            TypeError::UnsupportedType { sql_type, reason } => {
1580                assert_eq!(sql_type, "IMAGE");
1581                assert!(
1582                    reason.contains("VARBINARY(MAX)"),
1583                    "error should redirect to VARBINARY(MAX), got: {reason}"
1584                );
1585                assert!(
1586                    reason.contains("deprecated"),
1587                    "error should mention deprecation, got: {reason}"
1588                );
1589            }
1590            other => panic!("expected UnsupportedType, got {other:?}"),
1591        }
1592    }
1593
1594    #[test]
1595    fn test_bulk_column_rejects_image_case_insensitive() {
1596        assert!(matches!(
1597            BulkColumn::new("blob", "image", 0),
1598            Err(TypeError::UnsupportedType { .. })
1599        ));
1600        assert!(matches!(
1601            BulkColumn::new("blob", "Image", 0),
1602            Err(TypeError::UnsupportedType { .. })
1603        ));
1604    }
1605
1606    #[test]
1607    fn test_parse_sql_type() {
1608        // Integer types → INTN (0x26) with appropriate length
1609        let (type_id, len, _prec, _scale) = parse_sql_type("INT");
1610        assert_eq!(type_id, 0x26);
1611        assert_eq!(len, Some(4));
1612
1613        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
1614        assert_eq!(type_id, 0xE7);
1615        assert_eq!(len, Some(200)); // UTF-16 doubles
1616
1617        let (type_id, _, prec, scale) = parse_sql_type("DECIMAL(10,2)");
1618        assert_eq!(type_id, 0x6C);
1619        assert_eq!(prec, Some(10));
1620        assert_eq!(scale, Some(2));
1621
1622        // SMALLDATETIME/DATETIME → DATETIMEN (0x6F)
1623        let (type_id, len, _, _) = parse_sql_type("SMALLDATETIME");
1624        assert_eq!(type_id, 0x6F);
1625        assert_eq!(len, Some(4));
1626
1627        let (type_id, len, _, _) = parse_sql_type("DATETIME");
1628        assert_eq!(type_id, 0x6F);
1629        assert_eq!(len, Some(8));
1630    }
1631
1632    #[test]
1633    fn test_insert_bulk_statement() {
1634        let builder = BulkInsertBuilder::new("dbo.Users")
1635            .with_typed_columns(vec![
1636                BulkColumn::new("id", "INT", 0).unwrap(),
1637                BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1638            ])
1639            .table_lock(true);
1640
1641        let sql = builder.build_insert_bulk_statement().unwrap();
1642        assert!(sql.contains("INSERT BULK dbo.Users"));
1643        assert!(sql.contains("TABLOCK"));
1644    }
1645
1646    #[test]
1647    fn test_bulk_insert_rejects_injection() {
1648        let builder = BulkInsertBuilder::new("table;DROP TABLE users")
1649            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1650
1651        assert!(builder.build_insert_bulk_statement().is_err());
1652    }
1653
1654    #[test]
1655    fn test_bulk_insert_validates_column_names() {
1656        let builder = BulkInsertBuilder::new("Users")
1657            .with_typed_columns(vec![BulkColumn::new("col;DROP TABLE x", "INT", 0).unwrap()]);
1658
1659        assert!(builder.build_insert_bulk_statement().is_err());
1660    }
1661
1662    #[test]
1663    fn test_bulk_insert_accepts_qualified_names() {
1664        let builder = BulkInsertBuilder::new("catalog.dbo.Users")
1665            .with_typed_columns(vec![BulkColumn::new("id", "INT", 0).unwrap()]);
1666
1667        assert!(builder.build_insert_bulk_statement().is_ok());
1668    }
1669
1670    #[test]
1671    fn test_bulk_insert_creation() {
1672        let columns = vec![
1673            BulkColumn::new("id", "INT", 0).unwrap(),
1674            BulkColumn::new("name", "NVARCHAR(100)", 1).unwrap(),
1675        ];
1676
1677        let bulk = BulkInsert::new(columns, 1000);
1678        assert_eq!(bulk.total_rows(), 0);
1679        assert_eq!(bulk.rows_in_batch(), 0);
1680        assert!(!bulk.should_flush());
1681    }
1682
1683    #[test]
1684    fn test_decimal_byte_length() {
1685        assert_eq!(decimal_byte_length(5), 5);
1686        assert_eq!(decimal_byte_length(15), 9);
1687        assert_eq!(decimal_byte_length(25), 13);
1688        assert_eq!(decimal_byte_length(35), 17);
1689    }
1690
1691    #[test]
1692    #[cfg(feature = "chrono")]
1693    fn test_time_byte_length() {
1694        assert_eq!(time_byte_length(0), 3);
1695        assert_eq!(time_byte_length(3), 4);
1696        assert_eq!(time_byte_length(7), 5);
1697    }
1698
1699    #[test]
1700    fn test_plp_string_encoding() {
1701        let mut buf = BytesMut::new();
1702        let text = "Hello";
1703        let utf16: Vec<u16> = text.encode_utf16().collect();
1704
1705        encode_plp_string(&utf16, &mut buf);
1706
1707        // Verify structure:
1708        // - 8 bytes PLP_UNKNOWN_LEN marker
1709        // - 4 bytes chunk length
1710        // - data (5 chars * 2 bytes = 10 bytes)
1711        // - 4 bytes terminator (0)
1712        assert_eq!(buf.len(), 8 + 4 + 10 + 4);
1713
1714        // Check total length marker (PLP_UNKNOWN_LEN)
1715        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1716
1717        // Check chunk length
1718        assert_eq!(&buf[8..12], &10u32.to_le_bytes());
1719
1720        // Check terminator
1721        assert_eq!(&buf[22..26], &0u32.to_le_bytes());
1722    }
1723
1724    #[test]
1725    fn test_plp_binary_encoding() {
1726        let mut buf = BytesMut::new();
1727        let data = b"test binary data";
1728
1729        encode_plp_binary(data, &mut buf);
1730
1731        // Verify structure:
1732        // - 8 bytes PLP_UNKNOWN_LEN marker
1733        // - 4 bytes chunk length
1734        // - data (16 bytes)
1735        // - 4 bytes terminator (0)
1736        assert_eq!(buf.len(), 8 + 4 + 16 + 4);
1737
1738        // Check total length marker
1739        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1740
1741        // Check chunk length
1742        assert_eq!(&buf[8..12], &16u32.to_le_bytes());
1743
1744        // Check data
1745        assert_eq!(&buf[12..28], data);
1746
1747        // Check terminator
1748        assert_eq!(&buf[28..32], &0u32.to_le_bytes());
1749    }
1750
1751    #[test]
1752    fn test_plp_empty_string() {
1753        let mut buf = BytesMut::new();
1754        let utf16: Vec<u16> = "".encode_utf16().collect();
1755
1756        encode_plp_string(&utf16, &mut buf);
1757
1758        // Empty string: PLP_UNKNOWN_LEN (8) + terminator (4)
1759        assert_eq!(buf.len(), 8 + 4);
1760
1761        // Check total length marker
1762        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1763
1764        // Check terminator
1765        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1766    }
1767
1768    #[test]
1769    fn test_plp_empty_binary() {
1770        let mut buf = BytesMut::new();
1771
1772        encode_plp_binary(&[], &mut buf);
1773
1774        // Empty binary: PLP_UNKNOWN_LEN (8) + terminator (4)
1775        assert_eq!(buf.len(), 8 + 4);
1776
1777        // Check total length marker
1778        assert_eq!(&buf[0..8], &PLP_UNKNOWN_LEN.to_le_bytes());
1779
1780        // Check terminator
1781        assert_eq!(&buf[8..12], &0u32.to_le_bytes());
1782    }
1783
1784    /// Verify that write_colmetadata() produces bytes that the TDS parser can
1785    /// decode correctly for all supported column types (nullable variants).
1786    #[test]
1787    fn test_write_colmetadata_roundtrip() {
1788        use tds_protocol::token::ColMetaData;
1789
1790        let columns = vec![
1791            BulkColumn::new("id", "INT", 0).unwrap(),
1792            BulkColumn::new("tiny", "TINYINT", 1).unwrap(),
1793            BulkColumn::new("small", "SMALLINT", 2).unwrap(),
1794            BulkColumn::new("big", "BIGINT", 3).unwrap(),
1795            BulkColumn::new("flag", "BIT", 4).unwrap(),
1796            BulkColumn::new("r", "REAL", 5).unwrap(),
1797            BulkColumn::new("f", "FLOAT", 6).unwrap(),
1798            BulkColumn::new("name", "NVARCHAR(100)", 7).unwrap(),
1799            BulkColumn::new("code", "VARCHAR(50)", 8).unwrap(),
1800            BulkColumn::new("data", "VARBINARY(200)", 9).unwrap(),
1801            BulkColumn::new("d", "DATE", 10).unwrap(),
1802            BulkColumn::new("t", "TIME(3)", 11).unwrap(),
1803            BulkColumn::new("dt", "DATETIME", 12).unwrap(),
1804            BulkColumn::new("dt2", "DATETIME2(7)", 13).unwrap(),
1805            BulkColumn::new("dto", "DATETIMEOFFSET(7)", 14).unwrap(),
1806            BulkColumn::new("sdt", "SMALLDATETIME", 15).unwrap(),
1807            BulkColumn::new("uid", "UNIQUEIDENTIFIER", 16).unwrap(),
1808            BulkColumn::new("amt", "DECIMAL(18,2)", 17).unwrap(),
1809            BulkColumn::new("price", "MONEY", 18).unwrap(),
1810            BulkColumn::new("smoney", "SMALLMONEY", 19).unwrap(),
1811            BulkColumn::new("nmax", "NVARCHAR(MAX)", 20).unwrap(),
1812            BulkColumn::new("vmax", "VARCHAR(MAX)", 21).unwrap(),
1813            BulkColumn::new("bmax", "VARBINARY(MAX)", 22).unwrap(),
1814        ];
1815
1816        let bulk = BulkInsert::new(columns.clone(), 0);
1817
1818        // Extract COLMETADATA bytes (skip the 0x81 token type byte)
1819        let buf = &bulk.buffer[1..];
1820        let mut cursor = bytes::Bytes::copy_from_slice(buf);
1821        let meta = ColMetaData::decode(&mut cursor)
1822            .expect("write_colmetadata output should be parseable by TDS decoder");
1823
1824        assert_eq!(meta.columns.len(), columns.len());
1825
1826        // Verify each column parsed correctly
1827        for (i, (parsed, original)) in meta.columns.iter().zip(columns.iter()).enumerate() {
1828            assert_eq!(parsed.name, original.name, "column {i} name mismatch");
1829            assert_eq!(
1830                parsed.col_type, original.type_id,
1831                "column {i} ({}) type mismatch",
1832                original.name
1833            );
1834
1835            // Verify type-specific metadata
1836            match original.type_id {
1837                // INTN — max_length should match
1838                0x26 => {
1839                    assert_eq!(
1840                        parsed.type_info.max_length, original.max_length,
1841                        "column {i} ({}) INTN max_length",
1842                        original.name
1843                    );
1844                }
1845                // BITN
1846                0x68 => {
1847                    assert_eq!(parsed.type_info.max_length, Some(1));
1848                }
1849                // FLTN
1850                0x6D => {
1851                    assert_eq!(
1852                        parsed.type_info.max_length, original.max_length,
1853                        "column {i} ({}) FLTN max_length",
1854                        original.name
1855                    );
1856                }
1857                // MONEYN
1858                0x6E => {
1859                    assert_eq!(
1860                        parsed.type_info.max_length, original.max_length,
1861                        "column {i} ({}) MONEYN max_length",
1862                        original.name
1863                    );
1864                }
1865                // DATETIMEN
1866                0x6F => {
1867                    assert_eq!(
1868                        parsed.type_info.max_length, original.max_length,
1869                        "column {i} ({}) DATETIMEN max_length",
1870                        original.name
1871                    );
1872                }
1873                // GUID
1874                0x24 => {
1875                    assert_eq!(parsed.type_info.max_length, Some(16));
1876                }
1877                // DATE — no extra metadata
1878                0x28 => {}
1879                // TIME/DATETIME2/DATETIMEOFFSET — scale
1880                0x29..=0x2B => {
1881                    assert_eq!(
1882                        parsed.type_info.scale, original.scale,
1883                        "column {i} ({}) scale",
1884                        original.name
1885                    );
1886                }
1887                // NVARCHAR/VARCHAR — max_length + collation
1888                0xE7 | 0xA7 => {
1889                    assert_eq!(
1890                        parsed.type_info.max_length, original.max_length,
1891                        "column {i} ({}) string max_length",
1892                        original.name
1893                    );
1894                    assert!(
1895                        parsed.type_info.collation.is_some(),
1896                        "column {i} ({}) should have collation",
1897                        original.name
1898                    );
1899                }
1900                // VARBINARY — max_length, no collation
1901                0xA5 => {
1902                    assert_eq!(
1903                        parsed.type_info.max_length, original.max_length,
1904                        "column {i} ({}) binary max_length",
1905                        original.name
1906                    );
1907                    assert!(
1908                        parsed.type_info.collation.is_none(),
1909                        "column {i} ({}) should not have collation",
1910                        original.name
1911                    );
1912                }
1913                // DECIMAL
1914                0x6C => {
1915                    assert_eq!(
1916                        parsed.type_info.precision, original.precision,
1917                        "column {i} ({}) precision",
1918                        original.name
1919                    );
1920                    assert_eq!(
1921                        parsed.type_info.scale, original.scale,
1922                        "column {i} ({}) scale",
1923                        original.name
1924                    );
1925                }
1926                _ => {}
1927            }
1928        }
1929    }
1930
1931    /// Verify that NOT NULL columns use fixed-width type IDs (0x38 Int4,
1932    /// 0x32 Bit, etc.) rather than nullable type IDs (0x26 INTN, 0x68 BITN).
1933    /// SQL Server's BulkLoad rejects nullable IDs for NOT NULL columns.
1934    #[test]
1935    fn test_write_colmetadata_not_null_uses_fixed_types() {
1936        use tds_protocol::token::ColMetaData;
1937        use tds_protocol::types::TypeId;
1938
1939        let columns = vec![
1940            BulkColumn::new("id", "INT", 0)
1941                .unwrap()
1942                .with_nullable(false),
1943            BulkColumn::new("tiny", "TINYINT", 1)
1944                .unwrap()
1945                .with_nullable(false),
1946            BulkColumn::new("small", "SMALLINT", 2)
1947                .unwrap()
1948                .with_nullable(false),
1949            BulkColumn::new("big", "BIGINT", 3)
1950                .unwrap()
1951                .with_nullable(false),
1952            BulkColumn::new("flag", "BIT", 4)
1953                .unwrap()
1954                .with_nullable(false),
1955            BulkColumn::new("r", "REAL", 5)
1956                .unwrap()
1957                .with_nullable(false),
1958            BulkColumn::new("f", "FLOAT", 6)
1959                .unwrap()
1960                .with_nullable(false),
1961            BulkColumn::new("dt", "DATETIME", 7)
1962                .unwrap()
1963                .with_nullable(false),
1964            BulkColumn::new("sdt", "SMALLDATETIME", 8)
1965                .unwrap()
1966                .with_nullable(false),
1967            BulkColumn::new("mny", "MONEY", 9)
1968                .unwrap()
1969                .with_nullable(false),
1970            BulkColumn::new("smny", "SMALLMONEY", 10)
1971                .unwrap()
1972                .with_nullable(false),
1973        ];
1974
1975        let bulk = BulkInsert::new(columns.clone(), 0);
1976
1977        // Every NOT NULL fixed-width column should have fixed_len=true
1978        for (i, fixed) in bulk.fixed_len.iter().enumerate() {
1979            assert!(
1980                *fixed,
1981                "column {i} ({}) should be fixed_len",
1982                columns[i].name
1983            );
1984        }
1985
1986        // Parse the generated COLMETADATA
1987        let buf = &bulk.buffer[1..]; // skip token type byte
1988        let mut cursor = bytes::Bytes::copy_from_slice(buf);
1989        let meta = ColMetaData::decode(&mut cursor).expect("parseable");
1990
1991        // Verify each column has the expected fixed type ID and no Nullable flag
1992        let expected: &[(&str, TypeId)] = &[
1993            ("id", TypeId::Int4),
1994            ("tiny", TypeId::Int1),
1995            ("small", TypeId::Int2),
1996            ("big", TypeId::Int8),
1997            ("flag", TypeId::Bit),
1998            ("r", TypeId::Float4),
1999            ("f", TypeId::Float8),
2000            ("dt", TypeId::DateTime),
2001            ("sdt", TypeId::DateTime4),
2002            ("mny", TypeId::Money),
2003            ("smny", TypeId::Money4),
2004        ];
2005
2006        for (i, (name, ty)) in expected.iter().enumerate() {
2007            assert_eq!(meta.columns[i].name, *name, "column {i} name");
2008            assert_eq!(meta.columns[i].type_id, *ty, "column {i} ({name}) type");
2009            assert_eq!(
2010                meta.columns[i].flags & 0x0001,
2011                0,
2012                "column {i} ({name}) should not have Nullable flag set"
2013            );
2014        }
2015    }
2016
2017    /// Verify that `with_collation()` on a VARCHAR column propagates into
2018    /// the COLMETADATA token — the hand-crafted path previously hardcoded
2019    /// Latin1_General_CI_AS regardless of the caller-supplied collation.
2020    #[test]
2021    fn test_write_colmetadata_uses_caller_collation() {
2022        use tds_protocol::token::{ColMetaData, Collation};
2023
2024        // Chinese_PRC_CI_AS: LCID 0x0804, sort_id 0x52 (just a non-default pair)
2025        let chinese = Collation {
2026            lcid: 0x0804,
2027            sort_id: 0x52,
2028        };
2029
2030        let columns = vec![
2031            BulkColumn::new("s", "VARCHAR(50)", 0)
2032                .unwrap()
2033                .with_collation(chinese),
2034            // NVARCHAR also writes 5 collation bytes — should honor caller too
2035            BulkColumn::new("n", "NVARCHAR(50)", 1)
2036                .unwrap()
2037                .with_collation(chinese),
2038            // VARCHAR without with_collation should keep the Latin1 default
2039            BulkColumn::new("d", "VARCHAR(10)", 2).unwrap(),
2040        ];
2041        let bulk = BulkInsert::new(columns, 0);
2042
2043        let buf = &bulk.buffer[1..];
2044        let mut cursor = bytes::Bytes::copy_from_slice(buf);
2045        let meta = ColMetaData::decode(&mut cursor).expect("parseable");
2046
2047        let c0 = meta.columns[0]
2048            .type_info
2049            .collation
2050            .as_ref()
2051            .expect("VARCHAR has collation");
2052        assert_eq!(c0.lcid, chinese.lcid, "VARCHAR caller LCID");
2053        assert_eq!(c0.sort_id, chinese.sort_id, "VARCHAR caller sort_id");
2054
2055        let c1 = meta.columns[1]
2056            .type_info
2057            .collation
2058            .as_ref()
2059            .expect("NVARCHAR has collation");
2060        assert_eq!(c1.lcid, chinese.lcid, "NVARCHAR caller LCID");
2061        assert_eq!(c1.sort_id, chinese.sort_id, "NVARCHAR caller sort_id");
2062
2063        // Default collation: Latin1_General_CI_AS wire bytes
2064        // [0x09, 0x04, 0xD0, 0x00, 0x34] → lcid u32 LE = 0x00D0_0409, sort_id = 0x34
2065        let default = meta.columns[2]
2066            .type_info
2067            .collation
2068            .as_ref()
2069            .expect("VARCHAR has default collation");
2070        assert_eq!(default.to_bytes(), [0x09, 0x04, 0xD0, 0x00, 0x34]);
2071    }
2072
2073    #[test]
2074    fn test_parse_sql_type_max() {
2075        // Test NVARCHAR(MAX) parsing - uses 0xFFFF marker (not doubled for MAX)
2076        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(MAX)");
2077        assert_eq!(type_id, 0xE7);
2078        assert_eq!(len, Some(0xFFFF)); // MAX marker is 0xFFFF
2079
2080        // Test VARBINARY(MAX) parsing
2081        let (type_id, len, _, _) = parse_sql_type("VARBINARY(MAX)");
2082        assert_eq!(type_id, 0xA5);
2083        assert_eq!(len, Some(0xFFFF));
2084
2085        // Test VARCHAR(MAX) parsing
2086        let (type_id, len, _, _) = parse_sql_type("VARCHAR(MAX)");
2087        assert_eq!(type_id, 0xA7);
2088        assert_eq!(len, Some(0xFFFF));
2089
2090        // Verify normal NVARCHAR does double the length
2091        let (type_id, len, _, _) = parse_sql_type("NVARCHAR(100)");
2092        assert_eq!(type_id, 0xE7);
2093        assert_eq!(len, Some(200)); // 100 * 2 for UTF-16
2094    }
2095}