1use crate::SqlValue;
20
/// Column type of a table-valued parameter (TVP).
///
/// Variants with fields carry the parameters written in the SQL type
/// declaration (lengths, precision/scale, fractional-second scale). A length
/// of `u16::MAX` stands for `MAX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TvpColumnType {
    /// `BIT`.
    Bit,
    /// `TINYINT`.
    TinyInt,
    /// `SMALLINT`.
    SmallInt,
    /// `INT` / `INTEGER`.
    Int,
    /// `BIGINT`.
    BigInt,
    /// `REAL`.
    Real,
    /// `FLOAT`.
    Float,
    /// `DECIMAL(p, s)` / `NUMERIC(p, s)`.
    Decimal {
        /// Declared precision `p`.
        precision: u8,
        /// Declared scale `s`.
        scale: u8,
    },
    /// `NVARCHAR(n)`.
    NVarChar {
        /// Declared length `n`, or `u16::MAX` for `NVARCHAR(MAX)`.
        /// [`TvpColumnType::max_length`] doubles this to a byte count.
        max_length: u16,
    },
    /// `VARCHAR(n)`.
    VarChar {
        /// Declared length `n`, or `u16::MAX` for `VARCHAR(MAX)`.
        max_length: u16,
    },
    /// `VARBINARY(n)`.
    VarBinary {
        /// Declared length `n`, or `u16::MAX` for `VARBINARY(MAX)`.
        max_length: u16,
    },
    /// `UNIQUEIDENTIFIER`.
    UniqueIdentifier,
    /// `DATE`.
    Date,
    /// `TIME(s)`.
    Time {
        /// Fractional-second scale; parsed as 7 when not declared.
        scale: u8,
    },
    /// `DATETIME2(s)`.
    DateTime2 {
        /// Fractional-second scale; parsed as 7 when not declared.
        scale: u8,
    },
    /// `DATETIMEOFFSET(s)`.
    DateTimeOffset {
        /// Fractional-second scale; parsed as 7 when not declared.
        scale: u8,
    },
    /// `XML`.
    Xml,
}

impl TvpColumnType {
    /// Parses a SQL type declaration such as `INT`, `nvarchar(100)`,
    /// `DECIMAL(18, 2)` or `datetime2(3)`.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. The
    /// base name (text before any `(`) must match exactly.
    ///
    /// Missing or unparsable parameters fall back to defaults:
    /// - `NVARCHAR` → 4000
    /// - `VARCHAR` / `VARBINARY` → 8000
    /// - `DECIMAL` / `NUMERIC` → `(18, 0)`
    /// - `TIME` / `DATETIME2` / `DATETIMEOFFSET` → scale 7
    ///
    /// Returns `None` for unrecognized names. It also returns `None` for
    /// parameterless types written with an argument list (e.g. `INT(5)`).
    #[must_use]
    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
        let sql_type = sql_type.trim().to_uppercase();

        // Compare the base name exactly. Prefix matching would take
        // `TIMESTAMP` (a rowversion synonym) for `TIME`, and `NVARCHAR2`
        // for `NVARCHAR`.
        let (base, has_args) = match sql_type.find('(') {
            Some(open) => (sql_type[..open].trim_end(), true),
            None => (sql_type.as_str(), false),
        };

        let column_type = match base {
            "NVARCHAR" => Self::NVarChar {
                max_length: Self::parse_length(&sql_type).unwrap_or(4000),
            },
            "VARCHAR" => Self::VarChar {
                max_length: Self::parse_length(&sql_type).unwrap_or(8000),
            },
            "VARBINARY" => Self::VarBinary {
                max_length: Self::parse_length(&sql_type).unwrap_or(8000),
            },
            "DECIMAL" | "NUMERIC" => {
                let (precision, scale) =
                    Self::parse_precision_scale(&sql_type).unwrap_or((18, 0));
                Self::Decimal { precision, scale }
            }
            "TIME" => Self::Time {
                scale: Self::parse_scale(&sql_type).unwrap_or(7),
            },
            "DATETIME2" => Self::DateTime2 {
                scale: Self::parse_scale(&sql_type).unwrap_or(7),
            },
            "DATETIMEOFFSET" => Self::DateTimeOffset {
                scale: Self::parse_scale(&sql_type).unwrap_or(7),
            },
            // The types below take no parameters, so an argument list is
            // rejected (e.g. `FLOAT(24)` would really denote `REAL`).
            _ if has_args => return None,
            "BIT" => Self::Bit,
            "TINYINT" => Self::TinyInt,
            "SMALLINT" => Self::SmallInt,
            "INT" | "INTEGER" => Self::Int,
            "BIGINT" => Self::BigInt,
            "REAL" => Self::Real,
            "FLOAT" => Self::Float,
            "UNIQUEIDENTIFIER" => Self::UniqueIdentifier,
            "DATE" => Self::Date,
            "XML" => Self::Xml,
            _ => return None,
        };
        Some(column_type)
    }

    /// Returns the trimmed text between the first `(` and the first `)` that
    /// follows it, or `None` if either delimiter is missing.
    fn parenthesized_args(sql_type: &str) -> Option<&str> {
        // Only search for `)` after `(`. Otherwise input like `X)(` would
        // produce an out-of-order slice and panic.
        let (_, after_open) = sql_type.split_once('(')?;
        let (inner, _) = after_open.split_once(')')?;
        Some(inner.trim())
    }

    /// Parses `(n)` or `(MAX)`. `MAX` maps to `u16::MAX`.
    fn parse_length(sql_type: &str) -> Option<u16> {
        let inner = Self::parenthesized_args(sql_type)?;
        if inner.eq_ignore_ascii_case("MAX") {
            Some(u16::MAX)
        } else {
            inner.parse().ok()
        }
    }

    /// Parses `(p, s)` or `(p)`. A missing scale becomes 0.
    fn parse_precision_scale(sql_type: &str) -> Option<(u8, u8)> {
        let inner = Self::parenthesized_args(sql_type)?;
        match inner.split_once(',') {
            Some((precision, scale)) => Some((
                precision.trim().parse().ok()?,
                scale.trim().parse().ok()?,
            )),
            None => Some((inner.parse().ok()?, 0)),
        }
    }

    /// Parses a single `(s)` scale argument.
    fn parse_scale(sql_type: &str) -> Option<u8> {
        Self::parenthesized_args(sql_type)?.parse().ok()
    }

    /// Returns the TDS data-type byte for this column type.
    ///
    /// All integer widths share `0x26` (INTN), and `REAL`/`FLOAT` share
    /// `0x6D` (FLTN). The byte width is presumably conveyed separately
    /// through [`TvpColumnType::max_length`].
    #[must_use]
    pub const fn type_id(&self) -> u8 {
        match self {
            Self::Bit => 0x68,
            Self::TinyInt | Self::SmallInt | Self::Int | Self::BigInt => 0x26,
            Self::Real | Self::Float => 0x6D,
            Self::Decimal { .. } => 0x6C,
            Self::NVarChar { .. } => 0xE7,
            Self::VarChar { .. } => 0xA7,
            Self::VarBinary { .. } => 0xA5,
            Self::UniqueIdentifier => 0x24,
            Self::Date => 0x28,
            Self::Time { .. } => 0x29,
            Self::DateTime2 { .. } => 0x2A,
            Self::DateTimeOffset { .. } => 0x2B,
            Self::Xml => 0xF1,
        }
    }

    /// Returns the maximum length in bytes, or `None` for the date/time types.
    ///
    /// `NVARCHAR` lengths are doubled (two bytes per declared unit). `0xFFFF`
    /// marks `MAX`.
    #[must_use]
    pub const fn max_length(&self) -> Option<u16> {
        match self {
            Self::Bit | Self::TinyInt => Some(1),
            Self::SmallInt => Some(2),
            Self::Int | Self::Real => Some(4),
            Self::BigInt | Self::Float => Some(8),
            Self::Decimal { .. } => Some(17),
            // Saturating keeps `u16::MAX` (MAX) at 0xFFFF. It also maps
            // oversized declarations (> 32767) to 0xFFFF instead of
            // overflowing, which would panic in debug or wrap in release.
            Self::NVarChar { max_length } => Some((*max_length).saturating_mul(2)),
            Self::VarChar { max_length } | Self::VarBinary { max_length } => Some(*max_length),
            Self::UniqueIdentifier => Some(16),
            Self::Date | Self::Time { .. } | Self::DateTime2 { .. } | Self::DateTimeOffset { .. } => {
                None
            }
            Self::Xml => Some(0xFFFF),
        }
    }
}
236
237#[derive(Debug, Clone, PartialEq)]
239pub struct TvpColumnDef {
240 pub column_type: TvpColumnType,
242 pub nullable: bool,
244}
245
246impl TvpColumnDef {
247 #[must_use]
249 pub const fn new(column_type: TvpColumnType) -> Self {
250 Self {
251 column_type,
252 nullable: false,
253 }
254 }
255
256 #[must_use]
258 pub const fn nullable(column_type: TvpColumnType) -> Self {
259 Self {
260 column_type,
261 nullable: true,
262 }
263 }
264
265 #[must_use]
269 pub fn from_sql_type(sql_type: &str) -> Option<Self> {
270 TvpColumnType::from_sql_type(sql_type).map(Self::new)
271 }
272}
273
274#[derive(Debug, Clone, PartialEq)]
279pub struct TvpData {
280 pub schema: String,
282 pub type_name: String,
284 pub columns: Vec<TvpColumnDef>,
286 pub rows: Vec<Vec<SqlValue>>,
288}
289
290impl TvpData {
291 #[must_use]
293 pub fn new(schema: impl Into<String>, type_name: impl Into<String>) -> Self {
294 Self {
295 schema: schema.into(),
296 type_name: type_name.into(),
297 columns: Vec::new(),
298 rows: Vec::new(),
299 }
300 }
301
302 #[must_use]
304 pub fn with_column(mut self, column: TvpColumnDef) -> Self {
305 self.columns.push(column);
306 self
307 }
308
309 #[must_use]
315 pub fn with_row(mut self, values: Vec<SqlValue>) -> Self {
316 assert_eq!(
317 values.len(),
318 self.columns.len(),
319 "Row value count ({}) must match column count ({})",
320 values.len(),
321 self.columns.len()
322 );
323 self.rows.push(values);
324 self
325 }
326
327 pub fn try_add_row(&mut self, values: Vec<SqlValue>) -> Result<(), TvpError> {
331 if values.len() != self.columns.len() {
332 return Err(TvpError::ColumnCountMismatch {
333 expected: self.columns.len(),
334 actual: values.len(),
335 });
336 }
337 self.rows.push(values);
338 Ok(())
339 }
340
341 #[must_use]
343 pub fn len(&self) -> usize {
344 self.rows.len()
345 }
346
347 #[must_use]
349 pub fn is_empty(&self) -> bool {
350 self.rows.is_empty()
351 }
352
353 #[must_use]
355 pub fn column_count(&self) -> usize {
356 self.columns.len()
357 }
358}
359
360#[derive(Debug, Clone, thiserror::Error)]
362#[non_exhaustive]
363pub enum TvpError {
364 #[error("column count mismatch: expected {expected}, got {actual}")]
366 ColumnCountMismatch {
367 expected: usize,
369 actual: usize,
371 },
372 #[error("unknown SQL type: {0}")]
374 UnknownSqlType(String),
375}
376
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Table-driven check that common SQL type spellings (mixed case,
    /// parameters, `MAX`) parse to the expected column type.
    #[test]
    fn test_column_type_from_sql_type() {
        let cases = [
            ("INT", TvpColumnType::Int),
            ("BIGINT", TvpColumnType::BigInt),
            ("nvarchar(100)", TvpColumnType::NVarChar { max_length: 100 }),
            ("NVARCHAR(MAX)", TvpColumnType::NVarChar { max_length: 65535 }),
            (
                "DECIMAL(18, 2)",
                TvpColumnType::Decimal {
                    precision: 18,
                    scale: 2,
                },
            ),
            ("datetime2(3)", TvpColumnType::DateTime2 { scale: 3 }),
        ];
        for (input, expected) in cases {
            assert_eq!(TvpColumnType::from_sql_type(input), Some(expected));
        }
    }

    /// Builder chain: one INT column and three rows.
    #[test]
    fn test_tvp_data_builder() {
        let seed = TvpData::new("dbo", "UserIdList")
            .with_column(TvpColumnDef::new(TvpColumnType::Int));
        let tvp = (1..=3).fold(seed, |acc, id| acc.with_row(vec![SqlValue::Int(id)]));

        assert_eq!(tvp.schema, "dbo");
        assert_eq!(tvp.type_name, "UserIdList");
        assert_eq!(tvp.column_count(), 1);
        assert_eq!(tvp.len(), 3);
    }

    /// `with_row` must panic when the row width differs from the column count.
    #[test]
    #[should_panic(expected = "Row value count (2) must match column count (1)")]
    fn test_tvp_data_row_mismatch_panics() {
        let single_column =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));
        let _ = single_column.with_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
    }

    /// `try_add_row` reports a width mismatch as an error instead of panicking.
    #[test]
    fn test_tvp_data_try_add_row_error() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));

        let err = tvp
            .try_add_row(vec![SqlValue::Int(1), SqlValue::Int(2)])
            .expect_err("two values for one column must be rejected");
        assert!(matches!(err, TvpError::ColumnCountMismatch { .. }));
    }
}