Skip to main content

mssql_types/
tvp.rs

1//! Table-Valued Parameter (TVP) data structures.
2//!
3//! This module provides the low-level data structures for TVP encoding.
4//! These types are used by `SqlValue::Tvp` to carry TVP data through the
5//! type system.
6//!
7//! ## Wire Format
8//!
9//! TVPs are encoded as type `0xF3` in the TDS protocol with this structure:
10//!
11//! ```text
12//! TVP_TYPE_INFO = TVPTYPE TVP_TYPENAME TVP_COLMETADATA TVP_END_TOKEN *TVP_ROW TVP_END_TOKEN
13//! ```
14//!
15//! See [MS-TDS 2.2.6.9] for the complete specification.
16//!
17//! [MS-TDS 2.2.6.9]: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/c264db71-c1ec-4fe8-b5ef-19d54b1e6566
18
19use crate::SqlValue;
20
/// Column type identifier for TVP columns.
///
/// Each variant maps a SQL Server column type to the TDS type identifier
/// ([`TvpColumnType::type_id`]) and declared length
/// ([`TvpColumnType::max_length`]) used when encoding TVP column metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TvpColumnType {
    /// BIT type (boolean).
    Bit,
    /// TINYINT type (u8).
    TinyInt,
    /// SMALLINT type (i16).
    SmallInt,
    /// INT type (i32).
    Int,
    /// BIGINT type (i64).
    BigInt,
    /// REAL type (f32).
    Real,
    /// FLOAT type (f64).
    Float,
    /// DECIMAL/NUMERIC type with precision and scale.
    Decimal {
        /// Maximum number of digits.
        precision: u8,
        /// Number of digits after decimal point.
        scale: u8,
    },
    /// NVARCHAR type with max length in characters.
    NVarChar {
        /// Maximum length in characters. Use `u16::MAX` for MAX.
        max_length: u16,
    },
    /// VARCHAR type with max length in bytes.
    VarChar {
        /// Maximum length in bytes. Use `u16::MAX` for MAX.
        max_length: u16,
    },
    /// VARBINARY type with max length.
    VarBinary {
        /// Maximum length in bytes. Use `u16::MAX` for MAX.
        max_length: u16,
    },
    /// UNIQUEIDENTIFIER type (UUID).
    UniqueIdentifier,
    /// DATE type.
    Date,
    /// TIME type with scale.
    Time {
        /// Fractional seconds precision (0-7).
        scale: u8,
    },
    /// DATETIME2 type with scale.
    DateTime2 {
        /// Fractional seconds precision (0-7).
        scale: u8,
    },
    /// DATETIMEOFFSET type with scale.
    DateTimeOffset {
        /// Fractional seconds precision (0-7).
        scale: u8,
    },
    /// XML type.
    Xml,
}

impl TvpColumnType {
    /// Infer the TVP column type from an SQL type name string.
    ///
    /// Parses declarations such as `"INT"`, `"NVARCHAR(100)"`,
    /// `"NVARCHAR(MAX)"`, `"DECIMAL(18,2)"` or `"DATETIME2(3)"`. Matching is
    /// case-insensitive and ignores surrounding whitespace.
    ///
    /// A parameterized type given without (or with an unparsable) argument
    /// list falls back to these defaults:
    /// - 4000 for `NVARCHAR`
    /// - 8000 for `VARCHAR`/`VARBINARY`
    /// - `(18, 0)` for `DECIMAL`/`NUMERIC`
    /// - scale 7 for `TIME`/`DATETIME2`/`DATETIMEOFFSET`
    ///
    /// A base type name must appear on its own or be followed by a `(`
    /// argument list. For example, `"TIMESTAMP"` is not treated as `TIME`.
    ///
    /// Returns `None` if the type is not recognized. Never panics on
    /// malformed input.
    #[must_use]
    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
        let upper = sql_type.trim().to_uppercase();
        let sql_type = upper.as_str();

        // Parameterized types.
        if Self::has_base_name(sql_type, "NVARCHAR") {
            let max_length = Self::parse_length(sql_type).unwrap_or(4000);
            return Some(Self::NVarChar { max_length });
        }
        if Self::has_base_name(sql_type, "VARCHAR") {
            let max_length = Self::parse_length(sql_type).unwrap_or(8000);
            return Some(Self::VarChar { max_length });
        }
        if Self::has_base_name(sql_type, "VARBINARY") {
            let max_length = Self::parse_length(sql_type).unwrap_or(8000);
            return Some(Self::VarBinary { max_length });
        }
        if Self::has_base_name(sql_type, "DECIMAL") || Self::has_base_name(sql_type, "NUMERIC") {
            let (precision, scale) = Self::parse_precision_scale(sql_type).unwrap_or((18, 0));
            return Some(Self::Decimal { precision, scale });
        }
        if Self::has_base_name(sql_type, "TIME") {
            let scale = Self::parse_scale(sql_type).unwrap_or(7);
            return Some(Self::Time { scale });
        }
        if Self::has_base_name(sql_type, "DATETIME2") {
            let scale = Self::parse_scale(sql_type).unwrap_or(7);
            return Some(Self::DateTime2 { scale });
        }
        if Self::has_base_name(sql_type, "DATETIMEOFFSET") {
            let scale = Self::parse_scale(sql_type).unwrap_or(7);
            return Some(Self::DateTimeOffset { scale });
        }

        // Simple (non-parameterized) types.
        match sql_type {
            "BIT" => Some(Self::Bit),
            "TINYINT" => Some(Self::TinyInt),
            "SMALLINT" => Some(Self::SmallInt),
            "INT" | "INTEGER" => Some(Self::Int),
            "BIGINT" => Some(Self::BigInt),
            "REAL" => Some(Self::Real),
            "FLOAT" => Some(Self::Float),
            "UNIQUEIDENTIFIER" => Some(Self::UniqueIdentifier),
            "DATE" => Some(Self::Date),
            "XML" => Some(Self::Xml),
            _ => None,
        }
    }

    /// Whether `sql_type` is exactly `base`, or `base` followed (after
    /// optional whitespace) by a parenthesized argument list.
    ///
    /// This rejects longer names that merely share a prefix, such as
    /// `"TIMESTAMP"` vs `"TIME"`.
    fn has_base_name(sql_type: &str, base: &str) -> bool {
        match sql_type.strip_prefix(base) {
            Some(rest) => {
                let rest = rest.trim_start();
                rest.is_empty() || rest.starts_with('(')
            }
            None => false,
        }
    }

    /// Return the trimmed text between the first `(` and the next `)`.
    ///
    /// Returns `None` when there is no `(`, or no `)` after it. The closing
    /// parenthesis is searched only after the opening one, so malformed
    /// input such as `"X)("` cannot produce an inverted (panicking) slice.
    fn type_args(sql_type: &str) -> Option<&str> {
        let open = sql_type.find('(')?;
        // '(' is a single ASCII byte, so `open + 1` is a char boundary.
        let after_open = &sql_type[open + 1..];
        let close = after_open.find(')')?;
        Some(after_open[..close].trim())
    }

    /// Parse length from types like "NVARCHAR(100)" or "NVARCHAR(MAX)".
    ///
    /// `MAX` maps to `u16::MAX`. Returns `None` if there is no argument list
    /// or the argument is not a valid `u16`.
    fn parse_length(sql_type: &str) -> Option<u16> {
        let inner = Self::type_args(sql_type)?;
        if inner.eq_ignore_ascii_case("MAX") {
            Some(u16::MAX)
        } else {
            inner.parse().ok()
        }
    }

    /// Parse precision and scale from types like "DECIMAL(18,2)".
    ///
    /// A single argument such as "DECIMAL(10)" yields scale 0. Returns `None`
    /// if there is no argument list or a component is not a valid `u8`.
    fn parse_precision_scale(sql_type: &str) -> Option<(u8, u8)> {
        let inner = Self::type_args(sql_type)?;
        match inner.split_once(',') {
            Some((precision, scale)) => {
                Some((precision.trim().parse().ok()?, scale.trim().parse().ok()?))
            }
            None => Some((inner.parse().ok()?, 0)),
        }
    }

    /// Parse scale from types like "TIME(3)" or "DATETIME2(7)".
    ///
    /// Returns `None` if there is no argument list or the argument is not a
    /// valid `u8`.
    fn parse_scale(sql_type: &str) -> Option<u8> {
        Self::type_args(sql_type)?.parse().ok()
    }

    /// Get the TDS type ID for this column type.
    #[must_use]
    pub const fn type_id(&self) -> u8 {
        match self {
            Self::Bit => 0x68,                   // BITNTYPE
            Self::TinyInt => 0x26,               // INTNTYPE (len 1)
            Self::SmallInt => 0x26,              // INTNTYPE (len 2)
            Self::Int => 0x26,                   // INTNTYPE (len 4)
            Self::BigInt => 0x26,                // INTNTYPE (len 8)
            Self::Real => 0x6D,                  // FLTNTYPE (len 4)
            Self::Float => 0x6D,                 // FLTNTYPE (len 8)
            Self::Decimal { .. } => 0x6C,        // NUMERICNTYPE (DECIMALNTYPE would be 0x6A)
            Self::NVarChar { .. } => 0xE7,       // NVARCHARTYPE
            Self::VarChar { .. } => 0xA7,        // BIGVARCHARTYPE
            Self::VarBinary { .. } => 0xA5,      // BIGVARBINARYTYPE
            Self::UniqueIdentifier => 0x24,      // GUIDTYPE
            Self::Date => 0x28,                  // DATENTYPE
            Self::Time { .. } => 0x29,           // TIMENTYPE
            Self::DateTime2 { .. } => 0x2A,      // DATETIME2NTYPE
            Self::DateTimeOffset { .. } => 0x2B, // DATETIMEOFFSETNTYPE
            Self::Xml => 0xF1,                   // XMLTYPE
        }
    }

    /// Get the max length field for this column type.
    ///
    /// Returns a byte length, or `None` for the date/time variants, which
    /// carry no length field here.
    ///
    /// For `NVarChar`, the character count is doubled (two bytes per
    /// character). `u16::MAX` (MAX) is passed through unchanged. Character
    /// counts too large to double saturate at `u16::MAX` instead of
    /// overflowing.
    #[must_use]
    pub const fn max_length(&self) -> Option<u16> {
        match self {
            Self::Bit => Some(1),
            Self::TinyInt => Some(1),
            Self::SmallInt => Some(2),
            Self::Int => Some(4),
            Self::BigInt => Some(8),
            Self::Real => Some(4),
            Self::Float => Some(8),
            Self::Decimal { .. } => Some(17), // Max decimal size
            Self::NVarChar { max_length } => Some(if *max_length == u16::MAX {
                u16::MAX
            } else {
                (*max_length).saturating_mul(2)
            }),
            Self::VarChar { max_length } => Some(*max_length),
            Self::VarBinary { max_length } => Some(*max_length),
            Self::UniqueIdentifier => Some(16),
            Self::Date => None,
            Self::Time { .. } => None,
            Self::DateTime2 { .. } => None,
            Self::DateTimeOffset { .. } => None,
            Self::Xml => Some(0xFFFF), // MAX
        }
    }
}
236
237/// Column definition for a table-valued parameter.
238#[derive(Debug, Clone, PartialEq)]
239pub struct TvpColumnDef {
240    /// The column type.
241    pub column_type: TvpColumnType,
242    /// Whether the column is nullable.
243    pub nullable: bool,
244}
245
246impl TvpColumnDef {
247    /// Create a new non-nullable column definition.
248    #[must_use]
249    pub const fn new(column_type: TvpColumnType) -> Self {
250        Self {
251            column_type,
252            nullable: false,
253        }
254    }
255
256    /// Create a new nullable column definition.
257    #[must_use]
258    pub const fn nullable(column_type: TvpColumnType) -> Self {
259        Self {
260            column_type,
261            nullable: true,
262        }
263    }
264
265    /// Create from an SQL type string (e.g., "INT", "NVARCHAR(100)").
266    ///
267    /// Returns `None` if the SQL type is not recognized.
268    #[must_use]
269    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
270        TvpColumnType::from_sql_type(sql_type).map(Self::new)
271    }
272}
273
274/// Raw table-valued parameter data for encoding.
275///
276/// This structure holds all the information needed to encode a TVP
277/// in the TDS wire format.
278#[derive(Debug, Clone, PartialEq)]
279pub struct TvpData {
280    /// The database schema (e.g., "dbo"). Empty for default schema.
281    pub schema: String,
282    /// The TVP type name as defined in the database.
283    pub type_name: String,
284    /// Column definitions.
285    pub columns: Vec<TvpColumnDef>,
286    /// Row data - each row is a Vec of SqlValues matching the columns.
287    pub rows: Vec<Vec<SqlValue>>,
288}
289
290impl TvpData {
291    /// Create a new empty TVP with the given schema and type name.
292    #[must_use]
293    pub fn new(schema: impl Into<String>, type_name: impl Into<String>) -> Self {
294        Self {
295            schema: schema.into(),
296            type_name: type_name.into(),
297            columns: Vec::new(),
298            rows: Vec::new(),
299        }
300    }
301
302    /// Add a column definition.
303    #[must_use]
304    pub fn with_column(mut self, column: TvpColumnDef) -> Self {
305        self.columns.push(column);
306        self
307    }
308
309    /// Add a row of values.
310    ///
311    /// # Panics
312    ///
313    /// Panics if the number of values doesn't match the number of columns.
314    #[must_use]
315    pub fn with_row(mut self, values: Vec<SqlValue>) -> Self {
316        assert_eq!(
317            values.len(),
318            self.columns.len(),
319            "Row value count ({}) must match column count ({})",
320            values.len(),
321            self.columns.len()
322        );
323        self.rows.push(values);
324        self
325    }
326
327    /// Add a row of values without panicking.
328    ///
329    /// Returns `Err` if the number of values doesn't match the number of columns.
330    pub fn try_add_row(&mut self, values: Vec<SqlValue>) -> Result<(), TvpError> {
331        if values.len() != self.columns.len() {
332            return Err(TvpError::ColumnCountMismatch {
333                expected: self.columns.len(),
334                actual: values.len(),
335            });
336        }
337        self.rows.push(values);
338        Ok(())
339    }
340
341    /// Get the number of rows.
342    #[must_use]
343    pub fn len(&self) -> usize {
344        self.rows.len()
345    }
346
347    /// Check if the TVP has no rows.
348    #[must_use]
349    pub fn is_empty(&self) -> bool {
350        self.rows.is_empty()
351    }
352
353    /// Get the number of columns.
354    #[must_use]
355    pub fn column_count(&self) -> usize {
356        self.columns.len()
357    }
358}
359
360/// Errors that can occur when working with TVPs.
361#[derive(Debug, Clone, thiserror::Error)]
362#[non_exhaustive]
363pub enum TvpError {
364    /// Column count mismatch between definition and row data.
365    #[error("column count mismatch: expected {expected}, got {actual}")]
366    ColumnCountMismatch {
367        /// Expected number of columns.
368        expected: usize,
369        /// Actual number of values in the row.
370        actual: usize,
371    },
372    /// Unknown SQL type.
373    #[error("unknown SQL type: {0}")]
374    UnknownSqlType(String),
375}
376
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_column_type_from_sql_type() {
        assert_eq!(TvpColumnType::from_sql_type("INT"), Some(TvpColumnType::Int));
        assert_eq!(
            TvpColumnType::from_sql_type("BIGINT"),
            Some(TvpColumnType::BigInt)
        );
        assert_eq!(
            TvpColumnType::from_sql_type("nvarchar(100)"),
            Some(TvpColumnType::NVarChar { max_length: 100 })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("NVARCHAR(MAX)"),
            Some(TvpColumnType::NVarChar { max_length: 65535 })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("DECIMAL(18, 2)"),
            Some(TvpColumnType::Decimal {
                precision: 18,
                scale: 2
            })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("datetime2(3)"),
            Some(TvpColumnType::DateTime2 { scale: 3 })
        );
    }

    #[test]
    fn test_column_type_defaults_and_unknown() {
        // Parameterized types without an argument list fall back to defaults.
        assert_eq!(
            TvpColumnType::from_sql_type("NVARCHAR"),
            Some(TvpColumnType::NVarChar { max_length: 4000 })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("varbinary(max)"),
            Some(TvpColumnType::VarBinary {
                max_length: u16::MAX
            })
        );
        assert_eq!(
            TvpColumnType::from_sql_type("TIME"),
            Some(TvpColumnType::Time { scale: 7 })
        );
        assert_eq!(TvpColumnType::from_sql_type("GEOGRAPHY"), None);
    }

    #[test]
    fn test_column_type_max_length() {
        // NVARCHAR lengths are characters; the wire length is bytes.
        assert_eq!(
            TvpColumnType::NVarChar { max_length: 100 }.max_length(),
            Some(200)
        );
        assert_eq!(
            TvpColumnType::NVarChar {
                max_length: u16::MAX
            }
            .max_length(),
            Some(u16::MAX)
        );
        assert_eq!(TvpColumnType::Date.max_length(), None);
    }

    #[test]
    fn test_tvp_data_builder() {
        let tvp = TvpData::new("dbo", "UserIdList")
            .with_column(TvpColumnDef::new(TvpColumnType::Int))
            .with_row(vec![SqlValue::Int(1)])
            .with_row(vec![SqlValue::Int(2)])
            .with_row(vec![SqlValue::Int(3)]);

        assert_eq!(tvp.schema, "dbo");
        assert_eq!(tvp.type_name, "UserIdList");
        assert_eq!(tvp.column_count(), 1);
        assert_eq!(tvp.len(), 3);
        assert!(!tvp.is_empty());
    }

    #[test]
    #[should_panic(expected = "Row value count (2) must match column count (1)")]
    fn test_tvp_data_row_mismatch_panics() {
        let _ = TvpData::new("dbo", "Test")
            .with_column(TvpColumnDef::new(TvpColumnType::Int))
            .with_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
    }

    #[test]
    fn test_tvp_data_try_add_row_success() {
        let mut tvp = TvpData::new("dbo", "Test")
            .with_column(TvpColumnDef::nullable(TvpColumnType::Int));
        assert!(tvp.is_empty());

        tvp.try_add_row(vec![SqlValue::Int(7)]).unwrap();
        assert_eq!(tvp.len(), 1);
        assert!(tvp.columns[0].nullable);
    }

    #[test]
    fn test_tvp_data_try_add_row_error() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));

        let result = tvp.try_add_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
        let err = result.expect_err("two values for one column must be rejected");
        assert!(matches!(
            err,
            TvpError::ColumnCountMismatch {
                expected: 1,
                actual: 2
            }
        ));
        assert_eq!(err.to_string(), "column count mismatch: expected 1, got 2");
        // A rejected row must not be stored.
        assert!(tvp.is_empty());
    }
}