Skip to main content

mssql_types/
tvp.rs

1//! Table-Valued Parameter (TVP) data structures.
2//!
3//! This module provides the low-level data structures for TVP encoding.
4//! These types are used by `SqlValue::Tvp` to carry TVP data through the
5//! type system.
6//!
7//! ## Wire Format
8//!
9//! TVPs are encoded as type `0xF3` in the TDS protocol with this structure:
10//!
11//! ```text
12//! TVP_TYPE_INFO = TVPTYPE TVP_TYPENAME TVP_COLMETADATA TVP_END_TOKEN *TVP_ROW TVP_END_TOKEN
13//! ```
14//!
15//! See [MS-TDS 2.2.6.9] for the complete specification.
16//!
17//! [MS-TDS 2.2.6.9]: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/c264db71-c1ec-4fe8-b5ef-19d54b1e6566
18
19use crate::SqlValue;
20
/// SQL Server column type of a table-valued parameter column.
///
/// Each variant names a SQL type; the `type_id` and `max_length` methods map it
/// to the TDS type identifier and length emitted in the TVP column metadata.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum TvpColumnType {
    /// `BIT` (boolean).
    Bit,
    /// `TINYINT` (`u8`).
    TinyInt,
    /// `SMALLINT` (`i16`).
    SmallInt,
    /// `INT` (`i32`).
    Int,
    /// `BIGINT` (`i64`).
    BigInt,
    /// `REAL` (`f32`).
    Real,
    /// `FLOAT` (`f64`).
    Float,
    /// `DECIMAL` / `NUMERIC` with explicit precision and scale.
    Decimal {
        /// Total number of digits that can be stored.
        precision: u8,
        /// Number of those digits to the right of the decimal point.
        scale: u8,
    },
    /// `NVARCHAR(n)`; the length is counted in characters.
    NVarChar {
        /// Length in characters; `u16::MAX` stands for `NVARCHAR(MAX)`.
        max_length: u16,
    },
    /// `VARCHAR(n)`; the length is counted in bytes.
    VarChar {
        /// Length in bytes; `u16::MAX` stands for `VARCHAR(MAX)`.
        max_length: u16,
    },
    /// `VARBINARY(n)`.
    VarBinary {
        /// Length in bytes; `u16::MAX` stands for `VARBINARY(MAX)`.
        max_length: u16,
    },
    /// `UNIQUEIDENTIFIER` (a UUID).
    UniqueIdentifier,
    /// `DATE`.
    Date,
    /// `TIME(scale)`.
    Time {
        /// Fractional-seconds precision, 0 through 7.
        scale: u8,
    },
    /// `DATETIME2(scale)`.
    DateTime2 {
        /// Fractional-seconds precision, 0 through 7.
        scale: u8,
    },
    /// `DATETIMEOFFSET(scale)`.
    DateTimeOffset {
        /// Fractional-seconds precision, 0 through 7.
        scale: u8,
    },
    /// `MONEY`: an 8-byte scaled integer.
    Money,
    /// `SMALLMONEY`: a 4-byte scaled integer.
    SmallMoney,
    /// Legacy `DATETIME`: 8 bytes holding a day count plus 1/300-second ticks.
    DateTime,
    /// `SMALLDATETIME`: 4 bytes holding a day count plus minutes.
    SmallDateTime,
    /// `XML`.
    Xml,
}
94
95impl TvpColumnType {
96    /// Infer the TVP column type from an SQL type name string.
97    ///
98    /// This parses SQL type declarations like "INT", "NVARCHAR(100)", "DECIMAL(18,2)".
99    #[must_use]
100    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
101        let sql_type = sql_type.trim().to_uppercase();
102
103        // Handle parameterized types
104        if sql_type.starts_with("NVARCHAR") {
105            let max_len = Self::parse_length(&sql_type).unwrap_or(4000);
106            return Some(Self::NVarChar {
107                max_length: max_len,
108            });
109        }
110        if sql_type.starts_with("VARCHAR") {
111            let max_len = Self::parse_length(&sql_type).unwrap_or(8000);
112            return Some(Self::VarChar {
113                max_length: max_len,
114            });
115        }
116        if sql_type.starts_with("VARBINARY") {
117            let max_len = Self::parse_length(&sql_type).unwrap_or(8000);
118            return Some(Self::VarBinary {
119                max_length: max_len,
120            });
121        }
122        if sql_type.starts_with("DECIMAL") || sql_type.starts_with("NUMERIC") {
123            let (precision, scale) = Self::parse_precision_scale(&sql_type).unwrap_or((18, 0));
124            return Some(Self::Decimal { precision, scale });
125        }
126        if sql_type.starts_with("TIME") {
127            let scale = Self::parse_scale(&sql_type).unwrap_or(7);
128            return Some(Self::Time { scale });
129        }
130        if sql_type.starts_with("DATETIME2") {
131            let scale = Self::parse_scale(&sql_type).unwrap_or(7);
132            return Some(Self::DateTime2 { scale });
133        }
134        if sql_type.starts_with("DATETIMEOFFSET") {
135            let scale = Self::parse_scale(&sql_type).unwrap_or(7);
136            return Some(Self::DateTimeOffset { scale });
137        }
138
139        // Handle simple types
140        match sql_type.as_str() {
141            "BIT" => Some(Self::Bit),
142            "TINYINT" => Some(Self::TinyInt),
143            "SMALLINT" => Some(Self::SmallInt),
144            "INT" | "INTEGER" => Some(Self::Int),
145            "BIGINT" => Some(Self::BigInt),
146            "REAL" => Some(Self::Real),
147            "FLOAT" => Some(Self::Float),
148            "MONEY" => Some(Self::Money),
149            "SMALLMONEY" => Some(Self::SmallMoney),
150            "UNIQUEIDENTIFIER" => Some(Self::UniqueIdentifier),
151            "DATE" => Some(Self::Date),
152            "DATETIME" => Some(Self::DateTime),
153            "SMALLDATETIME" => Some(Self::SmallDateTime),
154            "XML" => Some(Self::Xml),
155            _ => None,
156        }
157    }
158
159    /// Parse length from types like "NVARCHAR(100)" or "NVARCHAR(MAX)".
160    fn parse_length(sql_type: &str) -> Option<u16> {
161        let start = sql_type.find('(')?;
162        let end = sql_type.find(')')?;
163        let inner = sql_type[start + 1..end].trim();
164
165        if inner.eq_ignore_ascii_case("MAX") {
166            Some(u16::MAX)
167        } else {
168            inner.parse().ok()
169        }
170    }
171
172    /// Parse precision and scale from types like "DECIMAL(18,2)".
173    fn parse_precision_scale(sql_type: &str) -> Option<(u8, u8)> {
174        let start = sql_type.find('(')?;
175        let end = sql_type.find(')')?;
176        let inner = sql_type[start + 1..end].trim();
177
178        if let Some(comma) = inner.find(',') {
179            let precision = inner[..comma].trim().parse().ok()?;
180            let scale = inner[comma + 1..].trim().parse().ok()?;
181            Some((precision, scale))
182        } else {
183            let precision = inner.parse().ok()?;
184            Some((precision, 0))
185        }
186    }
187
188    /// Parse scale from types like "TIME(3)" or "DATETIME2(7)".
189    fn parse_scale(sql_type: &str) -> Option<u8> {
190        let start = sql_type.find('(')?;
191        let end = sql_type.find(')')?;
192        let inner = sql_type[start + 1..end].trim();
193        inner.parse().ok()
194    }
195
196    /// Get the TDS type ID for this column type.
197    #[must_use]
198    pub const fn type_id(&self) -> u8 {
199        match self {
200            Self::Bit => 0x68,                            // BITNTYPE
201            Self::TinyInt => 0x26,                        // INTNTYPE (len 1)
202            Self::SmallInt => 0x26,                       // INTNTYPE (len 2)
203            Self::Int => 0x26,                            // INTNTYPE (len 4)
204            Self::BigInt => 0x26,                         // INTNTYPE (len 8)
205            Self::Real => 0x6D,                           // FLTNTYPE (len 4)
206            Self::Float => 0x6D,                          // FLTNTYPE (len 8)
207            Self::Decimal { .. } => 0x6C,                 // DECIMALNTYPE
208            Self::NVarChar { .. } => 0xE7,                // NVARCHARTYPE
209            Self::VarChar { .. } => 0xA7,                 // BIGVARCHARTYPE
210            Self::VarBinary { .. } => 0xA5,               // BIGVARBINTYPE
211            Self::UniqueIdentifier => 0x24,               // GUIDTYPE
212            Self::Date => 0x28,                           // DATETYPE
213            Self::Time { .. } => 0x29,                    // TIMETYPE
214            Self::DateTime2 { .. } => 0x2A,               // DATETIME2TYPE
215            Self::DateTimeOffset { .. } => 0x2B,          // DATETIMEOFFSETTYPE
216            Self::Money | Self::SmallMoney => 0x6E,       // MONEYNTYPE
217            Self::DateTime | Self::SmallDateTime => 0x6F, // DATETIMENTYPE
218            Self::Xml => 0xF1,                            // XMLTYPE
219        }
220    }
221
222    /// Get the max length field for this column type.
223    #[must_use]
224    pub const fn max_length(&self) -> Option<u16> {
225        match self {
226            Self::Bit => Some(1),
227            Self::TinyInt => Some(1),
228            Self::SmallInt => Some(2),
229            Self::Int => Some(4),
230            Self::BigInt => Some(8),
231            Self::Real => Some(4),
232            Self::Float => Some(8),
233            Self::Decimal { .. } => Some(17), // Max decimal size
234            Self::NVarChar { max_length } => Some(if *max_length == u16::MAX {
235                0xFFFF
236            } else {
237                *max_length * 2
238            }),
239            Self::VarChar { max_length } => Some(*max_length),
240            Self::VarBinary { max_length } => Some(*max_length),
241            Self::UniqueIdentifier => Some(16),
242            Self::Date => None,
243            Self::Time { .. } => None,
244            Self::DateTime2 { .. } => None,
245            Self::DateTimeOffset { .. } => None,
246            Self::Money | Self::DateTime => Some(8),
247            Self::SmallMoney | Self::SmallDateTime => Some(4),
248            Self::Xml => Some(0xFFFF), // MAX
249        }
250    }
251}
252
253/// Column definition for a table-valued parameter.
254#[derive(Debug, Clone, PartialEq)]
255pub struct TvpColumnDef {
256    /// The column type.
257    pub column_type: TvpColumnType,
258    /// Whether the column is nullable.
259    pub nullable: bool,
260}
261
262impl TvpColumnDef {
263    /// Create a new non-nullable column definition.
264    #[must_use]
265    pub const fn new(column_type: TvpColumnType) -> Self {
266        Self {
267            column_type,
268            nullable: false,
269        }
270    }
271
272    /// Create a new nullable column definition.
273    #[must_use]
274    pub const fn nullable(column_type: TvpColumnType) -> Self {
275        Self {
276            column_type,
277            nullable: true,
278        }
279    }
280
281    /// Create from an SQL type string (e.g., "INT", "NVARCHAR(100)").
282    ///
283    /// Returns `None` if the SQL type is not recognized.
284    #[must_use]
285    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
286        TvpColumnType::from_sql_type(sql_type).map(Self::new)
287    }
288}
289
290/// Raw table-valued parameter data for encoding.
291///
292/// This structure holds all the information needed to encode a TVP
293/// in the TDS wire format.
294#[derive(Debug, Clone, PartialEq)]
295pub struct TvpData {
296    /// The database schema (e.g., "dbo"). Empty for default schema.
297    pub schema: String,
298    /// The TVP type name as defined in the database.
299    pub type_name: String,
300    /// Column definitions.
301    pub columns: Vec<TvpColumnDef>,
302    /// Row data - each row is a Vec of SqlValues matching the columns.
303    pub rows: Vec<Vec<SqlValue>>,
304}
305
306impl TvpData {
307    /// Create a new empty TVP with the given schema and type name.
308    #[must_use]
309    pub fn new(schema: impl Into<String>, type_name: impl Into<String>) -> Self {
310        Self {
311            schema: schema.into(),
312            type_name: type_name.into(),
313            columns: Vec::new(),
314            rows: Vec::new(),
315        }
316    }
317
318    /// Add a column definition.
319    #[must_use]
320    pub fn with_column(mut self, column: TvpColumnDef) -> Self {
321        self.columns.push(column);
322        self
323    }
324
325    /// Add a row of values.
326    ///
327    /// # Panics
328    ///
329    /// Panics if the number of values doesn't match the number of columns.
330    #[must_use]
331    pub fn with_row(mut self, values: Vec<SqlValue>) -> Self {
332        assert_eq!(
333            values.len(),
334            self.columns.len(),
335            "Row value count ({}) must match column count ({})",
336            values.len(),
337            self.columns.len()
338        );
339        self.rows.push(values);
340        self
341    }
342
343    /// Add a row of values without panicking.
344    ///
345    /// Returns `Err` if the number of values doesn't match the number of columns.
346    pub fn try_add_row(&mut self, values: Vec<SqlValue>) -> Result<(), TvpError> {
347        if values.len() != self.columns.len() {
348            return Err(TvpError::ColumnCountMismatch {
349                expected: self.columns.len(),
350                actual: values.len(),
351            });
352        }
353        self.rows.push(values);
354        Ok(())
355    }
356
357    /// Get the number of rows.
358    #[must_use]
359    pub fn len(&self) -> usize {
360        self.rows.len()
361    }
362
363    /// Check if the TVP has no rows.
364    #[must_use]
365    pub fn is_empty(&self) -> bool {
366        self.rows.is_empty()
367    }
368
369    /// Get the number of columns.
370    #[must_use]
371    pub fn column_count(&self) -> usize {
372        self.columns.len()
373    }
374}
375
376/// Errors that can occur when working with TVPs.
377#[derive(Debug, Clone, thiserror::Error)]
378#[non_exhaustive]
379pub enum TvpError {
380    /// Column count mismatch between definition and row data.
381    #[error("column count mismatch: expected {expected}, got {actual}")]
382    ColumnCountMismatch {
383        /// Expected number of columns.
384        expected: usize,
385        /// Actual number of values in the row.
386        actual: usize,
387    },
388    /// Unknown SQL type.
389    #[error("unknown SQL type: {0}")]
390    UnknownSqlType(String),
391}
392
#[cfg(test)]
// Unwrapping/expecting is fine in tests: a panic is how a test reports failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_column_type_from_sql_type() {
        let cases = [
            ("INT", TvpColumnType::Int),
            ("BIGINT", TvpColumnType::BigInt),
            ("nvarchar(100)", TvpColumnType::NVarChar { max_length: 100 }),
            ("NVARCHAR(MAX)", TvpColumnType::NVarChar { max_length: 65535 }),
            (
                "DECIMAL(18, 2)",
                TvpColumnType::Decimal {
                    precision: 18,
                    scale: 2,
                },
            ),
            ("datetime2(3)", TvpColumnType::DateTime2 { scale: 3 }),
        ];

        for &(sql, expected) in &cases {
            assert_eq!(
                TvpColumnType::from_sql_type(sql),
                Some(expected),
                "unexpected parse result for {:?}",
                sql
            );
        }
    }

    #[test]
    fn test_tvp_data_builder() {
        let empty =
            TvpData::new("dbo", "UserIdList").with_column(TvpColumnDef::new(TvpColumnType::Int));
        let tvp = (1..=3).fold(empty, |acc, id| acc.with_row(vec![SqlValue::Int(id)]));

        assert_eq!(tvp.schema, "dbo");
        assert_eq!(tvp.type_name, "UserIdList");
        assert_eq!(tvp.column_count(), 1);
        assert_eq!(tvp.len(), 3);
    }

    #[test]
    #[should_panic(expected = "Row value count (2) must match column count (1)")]
    fn test_tvp_data_row_mismatch_panics() {
        let single_column =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));
        let _ = single_column.with_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
    }

    #[test]
    fn test_tvp_data_try_add_row_error() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));

        let err = tvp
            .try_add_row(vec![SqlValue::Int(1), SqlValue::Int(2)])
            .expect_err("a two-value row must be rejected by a one-column TVP");
        assert!(matches!(err, TvpError::ColumnCountMismatch { .. }));
    }
}