1use crate::SqlValue;
20
/// SQL Server column type of a table-valued parameter (TVP).
///
/// Parameterized SQL types carry their parameters in the variant: a length for the
/// variable-length string/binary types, precision and scale for `DECIMAL`, and a
/// fractional-second scale for the time-bearing types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TvpColumnType {
    /// `BIT`.
    Bit,
    /// `TINYINT`.
    TinyInt,
    /// `SMALLINT`.
    SmallInt,
    /// `INT` / `INTEGER`.
    Int,
    /// `BIGINT`.
    BigInt,
    /// `REAL` (4-byte floating point).
    Real,
    /// `FLOAT` (8-byte floating point).
    Float,
    /// `DECIMAL(p, s)` / `NUMERIC(p, s)`.
    Decimal {
        /// Total number of decimal digits.
        precision: u8,
        /// Number of digits to the right of the decimal point.
        scale: u8,
    },
    /// `NVARCHAR(n)`.
    NVarChar {
        /// Length in characters; `u16::MAX` stands for `NVARCHAR(MAX)`.
        max_length: u16,
    },
    /// `VARCHAR(n)`.
    VarChar {
        /// Length in bytes; `u16::MAX` stands for `VARCHAR(MAX)`.
        max_length: u16,
    },
    /// `VARBINARY(n)`.
    VarBinary {
        /// Length in bytes; `u16::MAX` stands for `VARBINARY(MAX)`.
        max_length: u16,
    },
    /// `UNIQUEIDENTIFIER` (16-byte GUID).
    UniqueIdentifier,
    /// `DATE`.
    Date,
    /// `TIME(s)`.
    Time {
        /// Fractional-second scale.
        scale: u8,
    },
    /// `DATETIME2(s)`.
    DateTime2 {
        /// Fractional-second scale.
        scale: u8,
    },
    /// `DATETIMEOFFSET(s)`.
    DateTimeOffset {
        /// Fractional-second scale.
        scale: u8,
    },
    /// `XML`.
    Xml,
}
85
86impl TvpColumnType {
87 #[must_use]
91 pub fn from_sql_type(sql_type: &str) -> Option<Self> {
92 let sql_type = sql_type.trim().to_uppercase();
93
94 if sql_type.starts_with("NVARCHAR") {
96 let max_len = Self::parse_length(&sql_type).unwrap_or(4000);
97 return Some(Self::NVarChar {
98 max_length: max_len,
99 });
100 }
101 if sql_type.starts_with("VARCHAR") {
102 let max_len = Self::parse_length(&sql_type).unwrap_or(8000);
103 return Some(Self::VarChar {
104 max_length: max_len,
105 });
106 }
107 if sql_type.starts_with("VARBINARY") {
108 let max_len = Self::parse_length(&sql_type).unwrap_or(8000);
109 return Some(Self::VarBinary {
110 max_length: max_len,
111 });
112 }
113 if sql_type.starts_with("DECIMAL") || sql_type.starts_with("NUMERIC") {
114 let (precision, scale) = Self::parse_precision_scale(&sql_type).unwrap_or((18, 0));
115 return Some(Self::Decimal { precision, scale });
116 }
117 if sql_type.starts_with("TIME") {
118 let scale = Self::parse_scale(&sql_type).unwrap_or(7);
119 return Some(Self::Time { scale });
120 }
121 if sql_type.starts_with("DATETIME2") {
122 let scale = Self::parse_scale(&sql_type).unwrap_or(7);
123 return Some(Self::DateTime2 { scale });
124 }
125 if sql_type.starts_with("DATETIMEOFFSET") {
126 let scale = Self::parse_scale(&sql_type).unwrap_or(7);
127 return Some(Self::DateTimeOffset { scale });
128 }
129
130 match sql_type.as_str() {
132 "BIT" => Some(Self::Bit),
133 "TINYINT" => Some(Self::TinyInt),
134 "SMALLINT" => Some(Self::SmallInt),
135 "INT" | "INTEGER" => Some(Self::Int),
136 "BIGINT" => Some(Self::BigInt),
137 "REAL" => Some(Self::Real),
138 "FLOAT" => Some(Self::Float),
139 "UNIQUEIDENTIFIER" => Some(Self::UniqueIdentifier),
140 "DATE" => Some(Self::Date),
141 "XML" => Some(Self::Xml),
142 _ => None,
143 }
144 }
145
146 fn parse_length(sql_type: &str) -> Option<u16> {
148 let start = sql_type.find('(')?;
149 let end = sql_type.find(')')?;
150 let inner = sql_type[start + 1..end].trim();
151
152 if inner.eq_ignore_ascii_case("MAX") {
153 Some(u16::MAX)
154 } else {
155 inner.parse().ok()
156 }
157 }
158
159 fn parse_precision_scale(sql_type: &str) -> Option<(u8, u8)> {
161 let start = sql_type.find('(')?;
162 let end = sql_type.find(')')?;
163 let inner = sql_type[start + 1..end].trim();
164
165 if let Some(comma) = inner.find(',') {
166 let precision = inner[..comma].trim().parse().ok()?;
167 let scale = inner[comma + 1..].trim().parse().ok()?;
168 Some((precision, scale))
169 } else {
170 let precision = inner.parse().ok()?;
171 Some((precision, 0))
172 }
173 }
174
175 fn parse_scale(sql_type: &str) -> Option<u8> {
177 let start = sql_type.find('(')?;
178 let end = sql_type.find(')')?;
179 let inner = sql_type[start + 1..end].trim();
180 inner.parse().ok()
181 }
182
183 #[must_use]
185 pub const fn type_id(&self) -> u8 {
186 match self {
187 Self::Bit => 0x68, Self::TinyInt => 0x26, Self::SmallInt => 0x26, Self::Int => 0x26, Self::BigInt => 0x26, Self::Real => 0x6D, Self::Float => 0x6D, Self::Decimal { .. } => 0x6C, Self::NVarChar { .. } => 0xE7, Self::VarChar { .. } => 0xA7, Self::VarBinary { .. } => 0xA5, Self::UniqueIdentifier => 0x24, Self::Date => 0x28, Self::Time { .. } => 0x29, Self::DateTime2 { .. } => 0x2A, Self::DateTimeOffset { .. } => 0x2B, Self::Xml => 0xF1, }
205 }
206
207 #[must_use]
209 pub const fn max_length(&self) -> Option<u16> {
210 match self {
211 Self::Bit => Some(1),
212 Self::TinyInt => Some(1),
213 Self::SmallInt => Some(2),
214 Self::Int => Some(4),
215 Self::BigInt => Some(8),
216 Self::Real => Some(4),
217 Self::Float => Some(8),
218 Self::Decimal { .. } => Some(17), Self::NVarChar { max_length } => Some(if *max_length == u16::MAX {
220 0xFFFF
221 } else {
222 *max_length * 2
223 }),
224 Self::VarChar { max_length } => Some(*max_length),
225 Self::VarBinary { max_length } => Some(*max_length),
226 Self::UniqueIdentifier => Some(16),
227 Self::Date => None,
228 Self::Time { .. } => None,
229 Self::DateTime2 { .. } => None,
230 Self::DateTimeOffset { .. } => None,
231 Self::Xml => Some(0xFFFF), }
233 }
234}
235
236#[derive(Debug, Clone, PartialEq)]
238pub struct TvpColumnDef {
239 pub column_type: TvpColumnType,
241 pub nullable: bool,
243}
244
245impl TvpColumnDef {
246 #[must_use]
248 pub const fn new(column_type: TvpColumnType) -> Self {
249 Self {
250 column_type,
251 nullable: false,
252 }
253 }
254
255 #[must_use]
257 pub const fn nullable(column_type: TvpColumnType) -> Self {
258 Self {
259 column_type,
260 nullable: true,
261 }
262 }
263
264 #[must_use]
268 pub fn from_sql_type(sql_type: &str) -> Option<Self> {
269 TvpColumnType::from_sql_type(sql_type).map(Self::new)
270 }
271}
272
273#[derive(Debug, Clone, PartialEq)]
278pub struct TvpData {
279 pub schema: String,
281 pub type_name: String,
283 pub columns: Vec<TvpColumnDef>,
285 pub rows: Vec<Vec<SqlValue>>,
287}
288
289impl TvpData {
290 #[must_use]
292 pub fn new(schema: impl Into<String>, type_name: impl Into<String>) -> Self {
293 Self {
294 schema: schema.into(),
295 type_name: type_name.into(),
296 columns: Vec::new(),
297 rows: Vec::new(),
298 }
299 }
300
301 #[must_use]
303 pub fn with_column(mut self, column: TvpColumnDef) -> Self {
304 self.columns.push(column);
305 self
306 }
307
308 #[must_use]
314 pub fn with_row(mut self, values: Vec<SqlValue>) -> Self {
315 assert_eq!(
316 values.len(),
317 self.columns.len(),
318 "Row value count ({}) must match column count ({})",
319 values.len(),
320 self.columns.len()
321 );
322 self.rows.push(values);
323 self
324 }
325
326 pub fn try_add_row(&mut self, values: Vec<SqlValue>) -> Result<(), TvpError> {
330 if values.len() != self.columns.len() {
331 return Err(TvpError::ColumnCountMismatch {
332 expected: self.columns.len(),
333 actual: values.len(),
334 });
335 }
336 self.rows.push(values);
337 Ok(())
338 }
339
340 #[must_use]
342 pub fn len(&self) -> usize {
343 self.rows.len()
344 }
345
346 #[must_use]
348 pub fn is_empty(&self) -> bool {
349 self.rows.is_empty()
350 }
351
352 #[must_use]
354 pub fn column_count(&self) -> usize {
355 self.columns.len()
356 }
357}
358
359#[derive(Debug, Clone, thiserror::Error)]
361pub enum TvpError {
362 #[error("column count mismatch: expected {expected}, got {actual}")]
364 ColumnCountMismatch {
365 expected: usize,
367 actual: usize,
369 },
370 #[error("unknown SQL type: {0}")]
372 UnknownSqlType(String),
373}
374
#[cfg(test)]
// Test code may unwrap freely; a panic there is a test failure, not a production crash.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_column_type_from_sql_type() {
        let cases = [
            ("INT", TvpColumnType::Int),
            ("BIGINT", TvpColumnType::BigInt),
            ("nvarchar(100)", TvpColumnType::NVarChar { max_length: 100 }),
            ("NVARCHAR(MAX)", TvpColumnType::NVarChar { max_length: 65535 }),
            (
                "DECIMAL(18, 2)",
                TvpColumnType::Decimal {
                    precision: 18,
                    scale: 2,
                },
            ),
            ("datetime2(3)", TvpColumnType::DateTime2 { scale: 3 }),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                TvpColumnType::from_sql_type(input),
                Some(expected),
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_tvp_data_builder() {
        let ids = [1, 2, 3];
        let tvp = ids.iter().fold(
            TvpData::new("dbo", "UserIdList").with_column(TvpColumnDef::new(TvpColumnType::Int)),
            |acc, &id| acc.with_row(vec![SqlValue::Int(id)]),
        );

        assert_eq!(tvp.schema, "dbo");
        assert_eq!(tvp.type_name, "UserIdList");
        assert_eq!(tvp.column_count(), 1);
        assert_eq!(tvp.len(), ids.len());
    }

    #[test]
    #[should_panic(expected = "Row value count (2) must match column count (1)")]
    fn test_tvp_data_row_mismatch_panics() {
        let single_int_column = TvpData::new("dbo", "Test")
            .with_column(TvpColumnDef::new(TvpColumnType::Int));
        let _ = single_int_column.with_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
    }

    #[test]
    fn test_tvp_data_try_add_row_error() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));

        let outcome = tvp.try_add_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
        assert!(matches!(
            outcome,
            Err(TvpError::ColumnCountMismatch { .. })
        ));
    }
}