mssql_types/
tvp.rs

1//! Table-Valued Parameter (TVP) data structures.
2//!
3//! This module provides the low-level data structures for TVP encoding.
4//! These types are used by `SqlValue::Tvp` to carry TVP data through the
5//! type system.
6//!
7//! ## Wire Format
8//!
9//! TVPs are encoded as type `0xF3` in the TDS protocol with this structure:
10//!
11//! ```text
12//! TVP_TYPE_INFO = TVPTYPE TVP_TYPENAME TVP_COLMETADATA TVP_END_TOKEN *TVP_ROW TVP_END_TOKEN
13//! ```
14//!
15//! See [MS-TDS 2.2.6.9] for the complete specification.
16//!
17//! [MS-TDS 2.2.6.9]: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/c264db71-c1ec-4fe8-b5ef-19d54b1e6566
18
19use crate::SqlValue;
20
/// TDS type descriptor for a single column of a table-valued parameter.
///
/// Each variant stands for one SQL column type and carries whatever extra
/// parameters (length, precision, scale) the TVP column metadata needs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TvpColumnType {
    /// BIT type (boolean).
    Bit,
    /// TINYINT type (u8).
    TinyInt,
    /// SMALLINT type (i16).
    SmallInt,
    /// INT type (i32).
    Int,
    /// BIGINT type (i64).
    BigInt,
    /// REAL type (f32).
    Real,
    /// FLOAT type (f64).
    Float,
    /// DECIMAL/NUMERIC type with precision and scale.
    Decimal {
        /// Maximum number of digits.
        precision: u8,
        /// Number of digits after decimal point.
        scale: u8,
    },
    /// NVARCHAR type with max length in characters.
    NVarChar {
        /// Maximum length in characters. Use `u16::MAX` for MAX.
        max_length: u16,
    },
    /// VARCHAR type with max length in bytes.
    VarChar {
        /// Maximum length in bytes. Use `u16::MAX` for MAX.
        max_length: u16,
    },
    /// VARBINARY type with max length.
    VarBinary {
        /// Maximum length in bytes. Use `u16::MAX` for MAX.
        max_length: u16,
    },
    /// UNIQUEIDENTIFIER type (UUID).
    UniqueIdentifier,
    /// DATE type.
    Date,
    /// TIME type with scale.
    Time {
        /// Fractional seconds precision (0-7).
        scale: u8,
    },
    /// DATETIME2 type with scale.
    DateTime2 {
        /// Fractional seconds precision (0-7).
        scale: u8,
    },
    /// DATETIMEOFFSET type with scale.
    DateTimeOffset {
        /// Fractional seconds precision (0-7).
        scale: u8,
    },
    /// XML type.
    Xml,
}

impl TvpColumnType {
    /// Infer the TVP column type from an SQL type name string.
    ///
    /// Parses declarations such as `"INT"`, `"NVARCHAR(100)"` or
    /// `"DECIMAL(18,2)"`. Matching is case-insensitive and ignores surrounding
    /// whitespace.
    ///
    /// Parameterized types are recognized by their *whole* leading identifier,
    /// so `"TIMESTAMP"` is not mistaken for `TIME`. When the parenthesized
    /// arguments are absent or cannot be parsed, these defaults apply:
    /// `NVARCHAR` 4000, `VARCHAR`/`VARBINARY` 8000, `DECIMAL`/`NUMERIC`
    /// `(18, 0)`, and scale 7 for `TIME`/`DATETIME2`/`DATETIMEOFFSET`.
    ///
    /// Returns `None` if the type name is not recognized. Never panics, even on
    /// malformed input such as `"VARCHAR)("`.
    #[must_use]
    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
        let sql_type = sql_type.trim().to_uppercase();

        // The base name is the leading run of identifier characters. Comparing
        // it exactly (instead of a prefix test) keeps unrelated names that
        // merely start with a known type, e.g. TIMESTAMP, from matching.
        let base_len = sql_type
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(sql_type.len());
        let base = &sql_type[..base_len];

        // Parameterized types.
        match base {
            "NVARCHAR" => {
                let max_length = Self::parse_length(&sql_type).unwrap_or(4000);
                return Some(Self::NVarChar { max_length });
            }
            "VARCHAR" => {
                let max_length = Self::parse_length(&sql_type).unwrap_or(8000);
                return Some(Self::VarChar { max_length });
            }
            "VARBINARY" => {
                let max_length = Self::parse_length(&sql_type).unwrap_or(8000);
                return Some(Self::VarBinary { max_length });
            }
            "DECIMAL" | "NUMERIC" => {
                let (precision, scale) =
                    Self::parse_precision_scale(&sql_type).unwrap_or((18, 0));
                return Some(Self::Decimal { precision, scale });
            }
            "TIME" => {
                let scale = Self::parse_scale(&sql_type).unwrap_or(7);
                return Some(Self::Time { scale });
            }
            "DATETIME2" => {
                let scale = Self::parse_scale(&sql_type).unwrap_or(7);
                return Some(Self::DateTime2 { scale });
            }
            "DATETIMEOFFSET" => {
                let scale = Self::parse_scale(&sql_type).unwrap_or(7);
                return Some(Self::DateTimeOffset { scale });
            }
            _ => {}
        }

        // Simple types must match the whole (trimmed, upper-cased) string.
        match sql_type.as_str() {
            "BIT" => Some(Self::Bit),
            "TINYINT" => Some(Self::TinyInt),
            "SMALLINT" => Some(Self::SmallInt),
            "INT" | "INTEGER" => Some(Self::Int),
            "BIGINT" => Some(Self::BigInt),
            "REAL" => Some(Self::Real),
            "FLOAT" => Some(Self::Float),
            "UNIQUEIDENTIFIER" => Some(Self::UniqueIdentifier),
            "DATE" => Some(Self::Date),
            "XML" => Some(Self::Xml),
            _ => None,
        }
    }

    /// Return the trimmed text between the first `(` and the first `)` that
    /// follows it, or `None` when no such pair exists.
    ///
    /// Searching for `)` only *after* the `(` is what makes malformed input
    /// like `")("` yield `None` instead of an out-of-order slice panic.
    fn paren_args(sql_type: &str) -> Option<&str> {
        let (_, after_open) = sql_type.split_once('(')?;
        let (inner, _) = after_open.split_once(')')?;
        Some(inner.trim())
    }

    /// Parse length from types like "NVARCHAR(100)" or "NVARCHAR(MAX)".
    ///
    /// `MAX` (any ASCII case) maps to `u16::MAX`.
    fn parse_length(sql_type: &str) -> Option<u16> {
        let inner = Self::paren_args(sql_type)?;

        if inner.eq_ignore_ascii_case("MAX") {
            Some(u16::MAX)
        } else {
            inner.parse().ok()
        }
    }

    /// Parse precision and scale from types like "DECIMAL(18,2)".
    ///
    /// A single argument is taken as the precision with a scale of 0.
    fn parse_precision_scale(sql_type: &str) -> Option<(u8, u8)> {
        let inner = Self::paren_args(sql_type)?;

        match inner.split_once(',') {
            Some((precision, scale)) => {
                let precision = precision.trim().parse().ok()?;
                let scale = scale.trim().parse().ok()?;
                Some((precision, scale))
            }
            None => {
                let precision = inner.parse().ok()?;
                Some((precision, 0))
            }
        }
    }

    /// Parse scale from types like "TIME(3)" or "DATETIME2(7)".
    fn parse_scale(sql_type: &str) -> Option<u8> {
        Self::paren_args(sql_type)?.parse().ok()
    }

    /// Get the TDS type ID for this column type.
    #[must_use]
    pub const fn type_id(&self) -> u8 {
        match self {
            Self::Bit => 0x68, // BITNTYPE
            // INTNTYPE; the integer width is conveyed by the length (1/2/4/8).
            Self::TinyInt | Self::SmallInt | Self::Int | Self::BigInt => 0x26,
            // FLTNTYPE; the float width is conveyed by the length (4/8).
            Self::Real | Self::Float => 0x6D,
            Self::Decimal { .. } => 0x6C,        // DECIMALNTYPE
            Self::NVarChar { .. } => 0xE7,       // NVARCHARTYPE
            Self::VarChar { .. } => 0xA7,        // BIGVARCHARTYPE
            Self::VarBinary { .. } => 0xA5,      // BIGVARBINTYPE
            Self::UniqueIdentifier => 0x24,      // GUIDTYPE
            Self::Date => 0x28,                  // DATETYPE
            Self::Time { .. } => 0x29,           // TIMETYPE
            Self::DateTime2 { .. } => 0x2A,      // DATETIME2TYPE
            Self::DateTimeOffset { .. } => 0x2B, // DATETIMEOFFSETTYPE
            Self::Xml => 0xF1,                   // XMLTYPE
        }
    }

    /// Get the max length field for this column type.
    ///
    /// Returns `None` for the date/time types. For `NVarChar` the character
    /// count is doubled to obtain a byte length; if that does not fit in a
    /// `u16` (which includes the `u16::MAX` "MAX" marker) the result is
    /// `0xFFFF`.
    #[must_use]
    pub const fn max_length(&self) -> Option<u16> {
        match self {
            Self::Bit => Some(1),
            Self::TinyInt => Some(1),
            Self::SmallInt => Some(2),
            Self::Int => Some(4),
            Self::BigInt => Some(8),
            Self::Real => Some(4),
            Self::Float => Some(8),
            Self::Decimal { .. } => Some(17), // Max decimal size
            // Checked multiply: a plain `* 2` would overflow for lengths above
            // 32767 (panic in debug builds, silent wrap in release).
            Self::NVarChar { max_length } => Some(match (*max_length).checked_mul(2) {
                Some(byte_len) => byte_len,
                None => 0xFFFF,
            }),
            Self::VarChar { max_length } => Some(*max_length),
            Self::VarBinary { max_length } => Some(*max_length),
            Self::UniqueIdentifier => Some(16),
            Self::Date => None,
            Self::Time { .. } => None,
            Self::DateTime2 { .. } => None,
            Self::DateTimeOffset { .. } => None,
            Self::Xml => Some(0xFFFF), // MAX
        }
    }
}
235
236/// Column definition for a table-valued parameter.
237#[derive(Debug, Clone, PartialEq)]
238pub struct TvpColumnDef {
239    /// The column type.
240    pub column_type: TvpColumnType,
241    /// Whether the column is nullable.
242    pub nullable: bool,
243}
244
245impl TvpColumnDef {
246    /// Create a new non-nullable column definition.
247    #[must_use]
248    pub const fn new(column_type: TvpColumnType) -> Self {
249        Self {
250            column_type,
251            nullable: false,
252        }
253    }
254
255    /// Create a new nullable column definition.
256    #[must_use]
257    pub const fn nullable(column_type: TvpColumnType) -> Self {
258        Self {
259            column_type,
260            nullable: true,
261        }
262    }
263
264    /// Create from an SQL type string (e.g., "INT", "NVARCHAR(100)").
265    ///
266    /// Returns `None` if the SQL type is not recognized.
267    #[must_use]
268    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
269        TvpColumnType::from_sql_type(sql_type).map(Self::new)
270    }
271}
272
273/// Raw table-valued parameter data for encoding.
274///
275/// This structure holds all the information needed to encode a TVP
276/// in the TDS wire format.
277#[derive(Debug, Clone, PartialEq)]
278pub struct TvpData {
279    /// The database schema (e.g., "dbo"). Empty for default schema.
280    pub schema: String,
281    /// The TVP type name as defined in the database.
282    pub type_name: String,
283    /// Column definitions.
284    pub columns: Vec<TvpColumnDef>,
285    /// Row data - each row is a Vec of SqlValues matching the columns.
286    pub rows: Vec<Vec<SqlValue>>,
287}
288
289impl TvpData {
290    /// Create a new empty TVP with the given schema and type name.
291    #[must_use]
292    pub fn new(schema: impl Into<String>, type_name: impl Into<String>) -> Self {
293        Self {
294            schema: schema.into(),
295            type_name: type_name.into(),
296            columns: Vec::new(),
297            rows: Vec::new(),
298        }
299    }
300
301    /// Add a column definition.
302    #[must_use]
303    pub fn with_column(mut self, column: TvpColumnDef) -> Self {
304        self.columns.push(column);
305        self
306    }
307
308    /// Add a row of values.
309    ///
310    /// # Panics
311    ///
312    /// Panics if the number of values doesn't match the number of columns.
313    #[must_use]
314    pub fn with_row(mut self, values: Vec<SqlValue>) -> Self {
315        assert_eq!(
316            values.len(),
317            self.columns.len(),
318            "Row value count ({}) must match column count ({})",
319            values.len(),
320            self.columns.len()
321        );
322        self.rows.push(values);
323        self
324    }
325
326    /// Add a row of values without panicking.
327    ///
328    /// Returns `Err` if the number of values doesn't match the number of columns.
329    pub fn try_add_row(&mut self, values: Vec<SqlValue>) -> Result<(), TvpError> {
330        if values.len() != self.columns.len() {
331            return Err(TvpError::ColumnCountMismatch {
332                expected: self.columns.len(),
333                actual: values.len(),
334            });
335        }
336        self.rows.push(values);
337        Ok(())
338    }
339
340    /// Get the number of rows.
341    #[must_use]
342    pub fn len(&self) -> usize {
343        self.rows.len()
344    }
345
346    /// Check if the TVP has no rows.
347    #[must_use]
348    pub fn is_empty(&self) -> bool {
349        self.rows.is_empty()
350    }
351
352    /// Get the number of columns.
353    #[must_use]
354    pub fn column_count(&self) -> usize {
355        self.columns.len()
356    }
357}
358
359/// Errors that can occur when working with TVPs.
360#[derive(Debug, Clone, thiserror::Error)]
361pub enum TvpError {
362    /// Column count mismatch between definition and row data.
363    #[error("column count mismatch: expected {expected}, got {actual}")]
364    ColumnCountMismatch {
365        /// Expected number of columns.
366        expected: usize,
367        /// Actual number of values in the row.
368        actual: usize,
369    },
370    /// Unknown SQL type.
371    #[error("unknown SQL type: {0}")]
372    UnknownSqlType(String),
373}
374
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Type strings are parsed case-insensitively, including their arguments.
    #[test]
    fn test_column_type_from_sql_type() {
        assert_eq!(
            TvpColumnType::from_sql_type("INT"),
            Some(TvpColumnType::Int)
        );
        assert_eq!(
            TvpColumnType::from_sql_type("BIGINT"),
            Some(TvpColumnType::BigInt)
        );
        assert_eq!(
            TvpColumnType::from_sql_type("nvarchar(100)"),
            Some(TvpColumnType::NVarChar { max_length: 100 })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("NVARCHAR(MAX)"),
            Some(TvpColumnType::NVarChar { max_length: 65535 })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("DECIMAL(18, 2)"),
            Some(TvpColumnType::Decimal {
                precision: 18,
                scale: 2
            })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("datetime2(3)"),
            Some(TvpColumnType::DateTime2 { scale: 3 })
        );
    }

    /// The builder records the schema, type name, columns and rows it is given.
    #[test]
    fn test_tvp_data_builder() {
        let mut tvp = TvpData::new("dbo", "UserIdList")
            .with_column(TvpColumnDef::new(TvpColumnType::Int));
        for id in 1..=3 {
            tvp = tvp.with_row(vec![SqlValue::Int(id)]);
        }

        assert_eq!(tvp.schema, "dbo");
        assert_eq!(tvp.type_name, "UserIdList");
        assert_eq!(tvp.column_count(), 1);
        assert_eq!(tvp.len(), 3);
    }

    /// `with_row` panics when the row width differs from the column count.
    #[test]
    #[should_panic(expected = "Row value count (2) must match column count (1)")]
    fn test_tvp_data_row_mismatch_panics() {
        let one_column =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));
        let _ = one_column.with_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
    }

    /// `try_add_row` reports the same mismatch as an error instead of panicking.
    #[test]
    fn test_tvp_data_try_add_row_error() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));

        let too_wide = vec![SqlValue::Int(1), SqlValue::Int(2)];
        let result = tvp.try_add_row(too_wide);
        assert!(matches!(result, Err(TvpError::ColumnCountMismatch { .. })));
    }
}
441}