1use crate::SqlValue;
20
/// Column data type of a table-valued parameter (TVP).
///
/// Variants are named after the SQL Server types that
/// [`TvpColumnType::from_sql_type`] parses into them; parameterized variants
/// carry the length, precision, or scale from the declaration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TvpColumnType {
    /// `BIT`.
    Bit,
    /// `TINYINT`.
    TinyInt,
    /// `SMALLINT`.
    SmallInt,
    /// `INT` (also parsed from `INTEGER`).
    Int,
    /// `BIGINT`.
    BigInt,
    /// `REAL`.
    Real,
    /// `FLOAT`.
    Float,
    /// `DECIMAL(p, s)` (also parsed from `NUMERIC`).
    Decimal {
        /// Total number of digits.
        precision: u8,
        /// Number of digits after the decimal point.
        scale: u8,
    },
    /// `NVARCHAR(n)`.
    NVarChar {
        /// Maximum length in characters (reported as twice this many bytes by
        /// [`TvpColumnType::max_length`]); `u16::MAX` denotes `NVARCHAR(MAX)`.
        max_length: u16,
    },
    /// `VARCHAR(n)`.
    VarChar {
        /// Maximum length; `u16::MAX` denotes `VARCHAR(MAX)`.
        max_length: u16,
    },
    /// `VARBINARY(n)`.
    VarBinary {
        /// Maximum length in bytes; `u16::MAX` denotes `VARBINARY(MAX)`.
        max_length: u16,
    },
    /// `UNIQUEIDENTIFIER`.
    UniqueIdentifier,
    /// `DATE`.
    Date,
    /// `TIME(s)`.
    Time {
        /// Fractional-seconds scale.
        scale: u8,
    },
    /// `DATETIME2(s)`.
    DateTime2 {
        /// Fractional-seconds scale.
        scale: u8,
    },
    /// `DATETIMEOFFSET(s)`.
    DateTimeOffset {
        /// Fractional-seconds scale.
        scale: u8,
    },
    /// `MONEY`.
    Money,
    /// `SMALLMONEY`.
    SmallMoney,
    /// `DATETIME`.
    DateTime,
    /// `SMALLDATETIME`.
    SmallDateTime,
    /// `XML`.
    Xml,
}

impl TvpColumnType {
    /// Largest declared length accepted for `NVARCHAR(n)`; also its default.
    const NVARCHAR_MAX_LENGTH: u16 = 4000;
    /// Largest declared length accepted for `VARCHAR(n)` / `VARBINARY(n)`;
    /// also their default.
    const BYTE_MAX_LENGTH: u16 = 8000;
    /// Largest `DECIMAL` / `NUMERIC` precision.
    const MAX_PRECISION: u8 = 38;
    /// Precision used when `DECIMAL` / `NUMERIC` is declared without arguments.
    const DEFAULT_PRECISION: u8 = 18;
    /// Largest fractional-seconds scale for `TIME`, `DATETIME2` and
    /// `DATETIMEOFFSET`; also the default when no scale is given.
    const MAX_TIME_SCALE: u8 = 7;

    /// Parses a SQL Server type declaration such as `INT`, `nvarchar(100)`,
    /// `VARBINARY(MAX)`, `DECIMAL(18, 2)` or `datetime2(3)`.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace.
    /// Omitted parameters take defaults: `NVARCHAR` → 4000, `VARCHAR` /
    /// `VARBINARY` → 8000, `DECIMAL` / `NUMERIC` → `(18, 0)`, and `TIME` /
    /// `DATETIME2` / `DATETIMEOFFSET` → scale 7.
    ///
    /// Returns `None` for unknown type names and for malformed or out-of-range
    /// declarations: unbalanced or out-of-order parentheses, trailing text
    /// after `)`, non-numeric arguments, lengths outside `1..=4000`
    /// (`NVARCHAR`) or `1..=8000` (`VARCHAR`/`VARBINARY`), precision outside
    /// `1..=38`, scale greater than precision, time scale above 7, or any
    /// argument on a fixed-size type (e.g. `INT(5)`).
    #[must_use]
    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
        let declaration = sql_type.trim().to_uppercase();
        let (base, args) = Self::split_declaration(&declaration)?;

        // Match the base name exactly (not by prefix) so that e.g. `TIMESTAMP`
        // is not mistaken for `TIME`.
        match base {
            "NVARCHAR" => Self::parse_length(args, Self::NVARCHAR_MAX_LENGTH)
                .map(|max_length| Self::NVarChar { max_length }),
            "VARCHAR" => Self::parse_length(args, Self::BYTE_MAX_LENGTH)
                .map(|max_length| Self::VarChar { max_length }),
            "VARBINARY" => Self::parse_length(args, Self::BYTE_MAX_LENGTH)
                .map(|max_length| Self::VarBinary { max_length }),
            "DECIMAL" | "NUMERIC" => Self::parse_precision_scale(args)
                .map(|(precision, scale)| Self::Decimal { precision, scale }),
            "TIME" => Self::parse_scale(args).map(|scale| Self::Time { scale }),
            "DATETIME2" => Self::parse_scale(args).map(|scale| Self::DateTime2 { scale }),
            "DATETIMEOFFSET" => {
                Self::parse_scale(args).map(|scale| Self::DateTimeOffset { scale })
            }
            // Every remaining supported type is fixed-size and takes no arguments.
            _ if args.is_some() => None,
            "BIT" => Some(Self::Bit),
            "TINYINT" => Some(Self::TinyInt),
            "SMALLINT" => Some(Self::SmallInt),
            "INT" | "INTEGER" => Some(Self::Int),
            "BIGINT" => Some(Self::BigInt),
            "REAL" => Some(Self::Real),
            "FLOAT" => Some(Self::Float),
            "MONEY" => Some(Self::Money),
            "SMALLMONEY" => Some(Self::SmallMoney),
            "UNIQUEIDENTIFIER" => Some(Self::UniqueIdentifier),
            "DATE" => Some(Self::Date),
            "DATETIME" => Some(Self::DateTime),
            "SMALLDATETIME" => Some(Self::SmallDateTime),
            "XML" => Some(Self::Xml),
            _ => None,
        }
    }

    /// Splits a declaration into its trimmed base type name and the trimmed
    /// text between its parentheses, if it has any.
    ///
    /// Returns `None` when the parentheses are unbalanced, out of order, or
    /// followed by trailing text. The closing parenthesis is only searched for
    /// after the opening one, so an input like `"NVARCHAR)("` is rejected
    /// rather than producing an inverted (panicking) slice range.
    fn split_declaration(declaration: &str) -> Option<(&str, Option<&str>)> {
        match declaration.split_once('(') {
            None if declaration.contains(')') => None,
            None => Some((declaration.trim(), None)),
            Some((base, rest)) => {
                let (args, trailing) = rest.split_once(')')?;
                if trailing.trim().is_empty() {
                    Some((base.trim(), Some(args.trim())))
                } else {
                    None
                }
            }
        }
    }

    /// Interprets an optional length argument: absent → `max`, `MAX`
    /// → `u16::MAX`, otherwise a number that must lie in `1..=max`.
    fn parse_length(args: Option<&str>, max: u16) -> Option<u16> {
        match args {
            None => Some(max),
            Some(text) if text.eq_ignore_ascii_case("MAX") => Some(u16::MAX),
            Some(text) => text
                .parse::<u16>()
                .ok()
                .filter(|len| (1..=max).contains(len)),
        }
    }

    /// Interprets an optional `p` or `p, s` argument list. Absent → `(18, 0)`;
    /// otherwise precision must be in `1..=38` and scale must not exceed it.
    fn parse_precision_scale(args: Option<&str>) -> Option<(u8, u8)> {
        let (precision, scale) = match args {
            None => return Some((Self::DEFAULT_PRECISION, 0)),
            Some(text) => match text.split_once(',') {
                Some((p, s)) => (p.trim().parse::<u8>().ok()?, s.trim().parse::<u8>().ok()?),
                None => (text.parse::<u8>().ok()?, 0),
            },
        };
        Some((precision, scale))
            .filter(|&(p, s)| (1..=Self::MAX_PRECISION).contains(&p) && s <= p)
    }

    /// Interprets an optional fractional-seconds scale: absent → 7,
    /// otherwise a number in `0..=7`.
    fn parse_scale(args: Option<&str>) -> Option<u8> {
        match args {
            None => Some(Self::MAX_TIME_SCALE),
            Some(text) => text
                .parse::<u8>()
                .ok()
                .filter(|&scale| scale <= Self::MAX_TIME_SCALE),
        }
    }

    /// Returns the TDS type byte for this column type.
    ///
    /// Families that share a code (the integers, `REAL`/`FLOAT`, the money
    /// types, `DATETIME`/`SMALLDATETIME`) differ only in the byte length
    /// reported by [`TvpColumnType::max_length`].
    #[must_use]
    pub const fn type_id(&self) -> u8 {
        match self {
            Self::Bit => 0x68,                                                  // BITNTYPE
            Self::TinyInt | Self::SmallInt | Self::Int | Self::BigInt => 0x26, // INTNTYPE
            Self::Real | Self::Float => 0x6D,                                   // FLTNTYPE
            Self::Decimal { .. } => 0x6C,                                       // NUMERICNTYPE
            Self::NVarChar { .. } => 0xE7,                                      // NVARCHARTYPE
            Self::VarChar { .. } => 0xA7,                                       // BIGVARCHRTYPE
            Self::VarBinary { .. } => 0xA5,                                     // BIGVARBINTYPE
            Self::UniqueIdentifier => 0x24,                                     // GUIDTYPE
            Self::Date => 0x28,                                                 // DATENTYPE
            Self::Time { .. } => 0x29,                                          // TIMENTYPE
            Self::DateTime2 { .. } => 0x2A,                                     // DATETIME2NTYPE
            Self::DateTimeOffset { .. } => 0x2B, // DATETIMEOFFSETNTYPE
            Self::Money | Self::SmallMoney => 0x6E, // MONEYNTYPE
            Self::DateTime | Self::SmallDateTime => 0x6F, // DATETIMNTYPE
            Self::Xml => 0xF1,                   // XMLTYPE
        }
    }

    /// Returns the maximum length in bytes for this type, or `None` for
    /// `DATE`, `TIME`, `DATETIME2` and `DATETIMEOFFSET`.
    ///
    /// `NVARCHAR` lengths are doubled (two bytes per character) and saturate
    /// at `0xFFFF`, which is also the value reported for the `MAX` sentinel.
    #[must_use]
    pub const fn max_length(&self) -> Option<u16> {
        match self {
            Self::Bit | Self::TinyInt => Some(1),
            Self::SmallInt => Some(2),
            Self::Int | Self::Real | Self::SmallMoney | Self::SmallDateTime => Some(4),
            Self::BigInt | Self::Float | Self::Money | Self::DateTime => Some(8),
            Self::UniqueIdentifier => Some(16),
            Self::Decimal { .. } => Some(17),
            // Saturating keeps u16::MAX (the MAX sentinel) unchanged and stops a
            // hand-built length >= 32768 from overflowing (panic in debug,
            // silent wrap in release).
            Self::NVarChar { max_length } => Some((*max_length).saturating_mul(2)),
            Self::VarChar { max_length } | Self::VarBinary { max_length } => Some(*max_length),
            Self::Date
            | Self::Time { .. }
            | Self::DateTime2 { .. }
            | Self::DateTimeOffset { .. } => None,
            Self::Xml => Some(0xFFFF),
        }
    }
}
252
253#[derive(Debug, Clone, PartialEq)]
255pub struct TvpColumnDef {
256 pub column_type: TvpColumnType,
258 pub nullable: bool,
260}
261
262impl TvpColumnDef {
263 #[must_use]
265 pub const fn new(column_type: TvpColumnType) -> Self {
266 Self {
267 column_type,
268 nullable: false,
269 }
270 }
271
272 #[must_use]
274 pub const fn nullable(column_type: TvpColumnType) -> Self {
275 Self {
276 column_type,
277 nullable: true,
278 }
279 }
280
281 #[must_use]
285 pub fn from_sql_type(sql_type: &str) -> Option<Self> {
286 TvpColumnType::from_sql_type(sql_type).map(Self::new)
287 }
288}
289
290#[derive(Debug, Clone, PartialEq)]
295pub struct TvpData {
296 pub schema: String,
298 pub type_name: String,
300 pub columns: Vec<TvpColumnDef>,
302 pub rows: Vec<Vec<SqlValue>>,
304}
305
306impl TvpData {
307 #[must_use]
309 pub fn new(schema: impl Into<String>, type_name: impl Into<String>) -> Self {
310 Self {
311 schema: schema.into(),
312 type_name: type_name.into(),
313 columns: Vec::new(),
314 rows: Vec::new(),
315 }
316 }
317
318 #[must_use]
320 pub fn with_column(mut self, column: TvpColumnDef) -> Self {
321 self.columns.push(column);
322 self
323 }
324
325 #[must_use]
331 pub fn with_row(mut self, values: Vec<SqlValue>) -> Self {
332 assert_eq!(
333 values.len(),
334 self.columns.len(),
335 "Row value count ({}) must match column count ({})",
336 values.len(),
337 self.columns.len()
338 );
339 self.rows.push(values);
340 self
341 }
342
343 pub fn try_add_row(&mut self, values: Vec<SqlValue>) -> Result<(), TvpError> {
347 if values.len() != self.columns.len() {
348 return Err(TvpError::ColumnCountMismatch {
349 expected: self.columns.len(),
350 actual: values.len(),
351 });
352 }
353 self.rows.push(values);
354 Ok(())
355 }
356
357 #[must_use]
359 pub fn len(&self) -> usize {
360 self.rows.len()
361 }
362
363 #[must_use]
365 pub fn is_empty(&self) -> bool {
366 self.rows.is_empty()
367 }
368
369 #[must_use]
371 pub fn column_count(&self) -> usize {
372 self.columns.len()
373 }
374}
375
376#[derive(Debug, Clone, thiserror::Error)]
378#[non_exhaustive]
379pub enum TvpError {
380 #[error("column count mismatch: expected {expected}, got {actual}")]
382 ColumnCountMismatch {
383 expected: usize,
385 actual: usize,
387 },
388 #[error("unknown SQL type: {0}")]
390 UnknownSqlType(String),
391}
392
#[cfg(test)]
// Test code may unwrap/expect freely; a failure there is a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Table-driven check of `TvpColumnType::from_sql_type`; `assert_eq!`
    /// reports the offending input and actual value on failure.
    #[test]
    fn test_column_type_from_sql_type() {
        let cases = [
            ("INT", TvpColumnType::Int),
            ("BIGINT", TvpColumnType::BigInt),
            ("nvarchar(100)", TvpColumnType::NVarChar { max_length: 100 }),
            ("NVARCHAR(MAX)", TvpColumnType::NVarChar { max_length: 65535 }),
            (
                "DECIMAL(18, 2)",
                TvpColumnType::Decimal {
                    precision: 18,
                    scale: 2,
                },
            ),
            ("datetime2(3)", TvpColumnType::DateTime2 { scale: 3 }),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                TvpColumnType::from_sql_type(input),
                Some(expected),
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_column_def_constructors() {
        assert!(!TvpColumnDef::new(TvpColumnType::Int).nullable);
        assert!(TvpColumnDef::nullable(TvpColumnType::Int).nullable);
        assert_eq!(
            TvpColumnDef::from_sql_type("bit"),
            Some(TvpColumnDef::new(TvpColumnType::Bit))
        );
        assert_eq!(TvpColumnDef::from_sql_type("not a type"), None);
    }

    #[test]
    fn test_tvp_data_builder() {
        let tvp = TvpData::new("dbo", "UserIdList")
            .with_column(TvpColumnDef::new(TvpColumnType::Int))
            .with_row(vec![SqlValue::Int(1)])
            .with_row(vec![SqlValue::Int(2)])
            .with_row(vec![SqlValue::Int(3)]);

        assert_eq!(tvp.schema, "dbo");
        assert_eq!(tvp.type_name, "UserIdList");
        assert_eq!(tvp.column_count(), 1);
        assert_eq!(tvp.len(), 3);
        assert!(!tvp.is_empty());
    }

    #[test]
    #[should_panic(expected = "Row value count (2) must match column count (1)")]
    fn test_tvp_data_row_mismatch_panics() {
        let _ = TvpData::new("dbo", "Test")
            .with_column(TvpColumnDef::new(TvpColumnType::Int))
            .with_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
    }

    #[test]
    fn test_tvp_data_try_add_row_success() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));
        assert!(tvp.is_empty());

        assert!(tvp.try_add_row(vec![SqlValue::Int(7)]).is_ok());
        assert_eq!(tvp.len(), 1);
    }

    #[test]
    fn test_tvp_data_try_add_row_error() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));

        let result = tvp.try_add_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
        assert!(matches!(
            result,
            Err(TvpError::ColumnCountMismatch {
                expected: 1,
                actual: 2
            })
        ));
        // A rejected row must not be stored.
        assert!(tvp.is_empty());
    }

    #[test]
    fn test_tvp_error_display() {
        let err = TvpError::ColumnCountMismatch {
            expected: 1,
            actual: 2,
        };
        assert_eq!(err.to_string(), "column count mismatch: expected 1, got 2");
    }
}