use crate::SqlValue;
/// Data type of a single column in a table-valued parameter (TVP).
///
/// Parameterised variants carry the arguments written inside the parentheses of the SQL type
/// name, e.g. `DECIMAL(18, 2)` or `NVARCHAR(100)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TvpColumnType {
    /// `BIT`.
    Bit,
    /// `TINYINT`.
    TinyInt,
    /// `SMALLINT`.
    SmallInt,
    /// `INT` / `INTEGER`.
    Int,
    /// `BIGINT`.
    BigInt,
    /// `REAL`.
    Real,
    /// `FLOAT`.
    Float,
    /// `DECIMAL(precision, scale)` / `NUMERIC(precision, scale)`.
    Decimal {
        /// Declared precision (first argument).
        precision: u8,
        /// Declared scale (second argument, `0` when omitted).
        scale: u8,
    },
    /// `NVARCHAR(n)`; `u16::MAX` stands for `NVARCHAR(MAX)`.
    NVarChar {
        /// Declared length `n`.
        max_length: u16,
    },
    /// `VARCHAR(n)`; `u16::MAX` stands for `VARCHAR(MAX)`.
    VarChar {
        /// Declared length `n`.
        max_length: u16,
    },
    /// `VARBINARY(n)`; `u16::MAX` stands for `VARBINARY(MAX)`.
    VarBinary {
        /// Declared length `n`.
        max_length: u16,
    },
    /// `UNIQUEIDENTIFIER`.
    UniqueIdentifier,
    /// `DATE`.
    Date,
    /// `TIME(scale)`.
    Time {
        /// Declared scale.
        scale: u8,
    },
    /// `DATETIME2(scale)`.
    DateTime2 {
        /// Declared scale.
        scale: u8,
    },
    /// `DATETIMEOFFSET(scale)`.
    DateTimeOffset {
        /// Declared scale.
        scale: u8,
    },
    /// `MONEY`.
    Money,
    /// `SMALLMONEY`.
    SmallMoney,
    /// `DATETIME`.
    DateTime,
    /// `SMALLDATETIME`.
    SmallDateTime,
    /// `XML`.
    Xml,
}

impl TvpColumnType {
    /// Parses a SQL type name such as `"INT"`, `"nvarchar(100)"` or `"DECIMAL(18, 2)"`.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. Parameterised types
    /// fall back to a default when the parenthesised arguments are missing or cannot be
    /// parsed: `NVARCHAR` -> 4000, `VARCHAR` / `VARBINARY` -> 8000, `DECIMAL` / `NUMERIC` ->
    /// `(18, 0)`, `TIME` / `DATETIME2` / `DATETIMEOFFSET` -> scale 7.
    ///
    /// Returns `None` for unrecognised names. A parameterised name must end at a word
    /// boundary, so `"TIMESTAMP"` is not mistaken for `TIME`. Malformed parentheses never
    /// panic; they are treated like missing arguments.
    #[must_use]
    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
        let upper = sql_type.trim().to_uppercase();
        let upper = upper.as_str();

        if Self::is_named(upper, "NVARCHAR") {
            let max_length = Self::parse_length(upper).unwrap_or(4000);
            return Some(Self::NVarChar { max_length });
        }
        if Self::is_named(upper, "VARCHAR") {
            let max_length = Self::parse_length(upper).unwrap_or(8000);
            return Some(Self::VarChar { max_length });
        }
        if Self::is_named(upper, "VARBINARY") {
            let max_length = Self::parse_length(upper).unwrap_or(8000);
            return Some(Self::VarBinary { max_length });
        }
        if Self::is_named(upper, "DECIMAL") || Self::is_named(upper, "NUMERIC") {
            let (precision, scale) = Self::parse_precision_scale(upper).unwrap_or((18, 0));
            return Some(Self::Decimal { precision, scale });
        }
        if Self::is_named(upper, "TIME") {
            let scale = Self::parse_scale(upper).unwrap_or(7);
            return Some(Self::Time { scale });
        }
        if Self::is_named(upper, "DATETIME2") {
            let scale = Self::parse_scale(upper).unwrap_or(7);
            return Some(Self::DateTime2 { scale });
        }
        if Self::is_named(upper, "DATETIMEOFFSET") {
            let scale = Self::parse_scale(upper).unwrap_or(7);
            return Some(Self::DateTimeOffset { scale });
        }

        // Non-parameterised types must match exactly.
        match upper {
            "BIT" => Some(Self::Bit),
            "TINYINT" => Some(Self::TinyInt),
            "SMALLINT" => Some(Self::SmallInt),
            "INT" | "INTEGER" => Some(Self::Int),
            "BIGINT" => Some(Self::BigInt),
            "REAL" => Some(Self::Real),
            "FLOAT" => Some(Self::Float),
            "MONEY" => Some(Self::Money),
            "SMALLMONEY" => Some(Self::SmallMoney),
            "UNIQUEIDENTIFIER" => Some(Self::UniqueIdentifier),
            "DATE" => Some(Self::Date),
            "DATETIME" => Some(Self::DateTime),
            "SMALLDATETIME" => Some(Self::SmallDateTime),
            "XML" => Some(Self::Xml),
            _ => None,
        }
    }

    /// Returns `true` when `sql_type` starts with `name` and `name` is not merely the prefix
    /// of a longer identifier (the next character, if any, is not alphanumeric or `_`).
    fn is_named(sql_type: &str, name: &str) -> bool {
        match sql_type.strip_prefix(name) {
            Some(rest) => !rest.starts_with(|c: char| c.is_alphanumeric() || c == '_'),
            None => false,
        }
    }

    /// Returns the trimmed text between the first `(` and the first `)` that follows it.
    ///
    /// Looking for `)` only after `(` keeps malformed input such as `"VARCHAR)("` from
    /// producing an inverted (panicking) slice range.
    fn paren_args(sql_type: &str) -> Option<&str> {
        let (_, after_open) = sql_type.split_once('(')?;
        let (args, _) = after_open.split_once(')')?;
        Some(args.trim())
    }

    /// Parses the length argument of a variable-length type; `MAX` maps to `u16::MAX`.
    fn parse_length(sql_type: &str) -> Option<u16> {
        let args = Self::paren_args(sql_type)?;
        if args.eq_ignore_ascii_case("MAX") {
            Some(u16::MAX)
        } else {
            args.parse().ok()
        }
    }

    /// Parses `(precision, scale)` or `(precision)`; the scale defaults to `0` when omitted.
    fn parse_precision_scale(sql_type: &str) -> Option<(u8, u8)> {
        let args = Self::paren_args(sql_type)?;
        match args.split_once(',') {
            Some((precision, scale)) => {
                let precision = precision.trim().parse().ok()?;
                let scale = scale.trim().parse().ok()?;
                Some((precision, scale))
            }
            None => Some((args.parse().ok()?, 0)),
        }
    }

    /// Parses the single numeric scale argument of a time-like type.
    fn parse_scale(sql_type: &str) -> Option<u8> {
        Self::paren_args(sql_type)?.parse().ok()
    }

    /// Returns the one-byte type identifier associated with this column type.
    ///
    /// NOTE(review): the values look like TDS nullable/variable-length type tokens
    /// (e.g. `0x26` for all integer widths) — confirm against the protocol layer.
    #[must_use]
    pub const fn type_id(&self) -> u8 {
        match self {
            Self::Bit => 0x68,
            Self::TinyInt | Self::SmallInt | Self::Int | Self::BigInt => 0x26,
            Self::Real | Self::Float => 0x6D,
            Self::Decimal { .. } => 0x6C,
            Self::NVarChar { .. } => 0xE7,
            Self::VarChar { .. } => 0xA7,
            Self::VarBinary { .. } => 0xA5,
            Self::UniqueIdentifier => 0x24,
            Self::Date => 0x28,
            Self::Time { .. } => 0x29,
            Self::DateTime2 { .. } => 0x2A,
            Self::DateTimeOffset { .. } => 0x2B,
            Self::Money | Self::SmallMoney => 0x6E,
            Self::DateTime | Self::SmallDateTime => 0x6F,
            Self::Xml => 0xF1,
        }
    }

    /// Returns the maximum length associated with this type, or `None` for `DATE`, `TIME`,
    /// `DATETIME2` and `DATETIMEOFFSET`.
    ///
    /// For `NVARCHAR` the declared length is doubled and saturates at `0xFFFF`, which is also
    /// the value reported for `NVARCHAR(MAX)`; oversized lengths therefore never overflow.
    #[must_use]
    pub const fn max_length(&self) -> Option<u16> {
        match *self {
            Self::Bit | Self::TinyInt => Some(1),
            Self::SmallInt => Some(2),
            Self::Int | Self::Real => Some(4),
            Self::BigInt | Self::Float => Some(8),
            Self::Decimal { .. } => Some(17),
            // Saturating: `u16::MAX` (the MAX marker) and any length above 32767 yield 0xFFFF
            // instead of overflowing the `u16`.
            Self::NVarChar { max_length } => Some(max_length.saturating_mul(2)),
            Self::VarChar { max_length } | Self::VarBinary { max_length } => Some(max_length),
            Self::UniqueIdentifier => Some(16),
            Self::Date
            | Self::Time { .. }
            | Self::DateTime2 { .. }
            | Self::DateTimeOffset { .. } => None,
            Self::Money | Self::DateTime => Some(8),
            Self::SmallMoney | Self::SmallDateTime => Some(4),
            Self::Xml => Some(0xFFFF),
        }
    }
}
/// Definition of one TVP column: its data type plus a nullability flag.
#[derive(Debug, Clone, PartialEq)]
pub struct TvpColumnDef {
    /// Data type of the column.
    pub column_type: TvpColumnType,
    /// Whether the column is marked as nullable.
    pub nullable: bool,
}

impl TvpColumnDef {
    /// Creates a non-nullable column definition of the given type.
    #[must_use]
    pub const fn new(column_type: TvpColumnType) -> Self {
        Self::with_nullability(column_type, false)
    }

    /// Creates a nullable column definition of the given type.
    #[must_use]
    pub const fn nullable(column_type: TvpColumnType) -> Self {
        Self::with_nullability(column_type, true)
    }

    /// Builds a non-nullable column definition from a SQL type name.
    ///
    /// Returns `None` when [`TvpColumnType::from_sql_type`] does not recognise the name.
    #[must_use]
    pub fn from_sql_type(sql_type: &str) -> Option<Self> {
        let column_type = TvpColumnType::from_sql_type(sql_type)?;
        Some(Self::new(column_type))
    }

    /// Shared constructor behind [`Self::new`] and [`Self::nullable`].
    const fn with_nullability(column_type: TvpColumnType, nullable: bool) -> Self {
        Self {
            column_type,
            nullable,
        }
    }
}
/// A table-valued parameter: the schema-qualified type name, its column definitions and the
/// rows of values to send.
#[derive(Debug, Clone, PartialEq)]
pub struct TvpData {
    /// Schema of the table type (e.g. `"dbo"`).
    pub schema: String,
    /// Name of the table type.
    pub type_name: String,
    /// Column definitions, in order.
    pub columns: Vec<TvpColumnDef>,
    /// Row values; each row holds one value per column.
    pub rows: Vec<Vec<SqlValue>>,
}

impl TvpData {
    /// Creates an empty TVP for the given schema and type name, with no columns and no rows.
    #[must_use]
    pub fn new(schema: impl Into<String>, type_name: impl Into<String>) -> Self {
        let schema = schema.into();
        let type_name = type_name.into();
        Self {
            schema,
            type_name,
            columns: Vec::new(),
            rows: Vec::new(),
        }
    }

    /// Appends a column definition and returns the updated value (builder style).
    #[must_use]
    pub fn with_column(mut self, column: TvpColumnDef) -> Self {
        self.columns.push(column);
        self
    }

    /// Appends a row and returns the updated value (builder style).
    ///
    /// # Panics
    ///
    /// Panics when `values` does not contain exactly one value per defined column. Use
    /// [`Self::try_add_row`] for a non-panicking alternative.
    #[must_use]
    pub fn with_row(mut self, values: Vec<SqlValue>) -> Self {
        let value_count = values.len();
        let column_count = self.columns.len();
        assert_eq!(
            value_count, column_count,
            "Row value count ({}) must match column count ({})",
            value_count, column_count
        );
        self.rows.push(values);
        self
    }

    /// Appends a row in place.
    ///
    /// # Errors
    ///
    /// Returns [`TvpError::ColumnCountMismatch`] when `values` does not contain exactly one
    /// value per defined column; the row is not added in that case.
    pub fn try_add_row(&mut self, values: Vec<SqlValue>) -> Result<(), TvpError> {
        let expected = self.columns.len();
        let actual = values.len();
        if actual == expected {
            self.rows.push(values);
            Ok(())
        } else {
            Err(TvpError::ColumnCountMismatch { expected, actual })
        }
    }

    /// Number of rows currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.rows.len()
    }

    /// Returns `true` when no rows have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }

    /// Number of columns currently defined.
    #[must_use]
    pub fn column_count(&self) -> usize {
        self.columns.len()
    }
}
/// Errors reported while building a table-valued parameter.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum TvpError {
/// A row was supplied whose value count differs from the number of defined columns.
#[error("column count mismatch: expected {expected}, got {actual}")]
ColumnCountMismatch {
/// Number of columns defined on the TVP.
expected: usize,
/// Number of values in the rejected row.
actual: usize,
},
/// A SQL type name was not recognised; the payload is the offending name.
// NOTE(review): nothing in the visible part of this file constructs this variant.
#[error("unknown SQL type: {0}")]
UnknownSqlType(String),
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Representative SQL type names parse into the expected column types.
    #[test]
    fn test_column_type_from_sql_type() {
        let cases = [
            ("INT", TvpColumnType::Int),
            ("BIGINT", TvpColumnType::BigInt),
            ("nvarchar(100)", TvpColumnType::NVarChar { max_length: 100 }),
            ("NVARCHAR(MAX)", TvpColumnType::NVarChar { max_length: 65535 }),
            (
                "DECIMAL(18, 2)",
                TvpColumnType::Decimal {
                    precision: 18,
                    scale: 2,
                },
            ),
            ("datetime2(3)", TvpColumnType::DateTime2 { scale: 3 }),
        ];

        for &(sql_type, expected) in &cases {
            assert_eq!(
                TvpColumnType::from_sql_type(sql_type),
                Some(expected),
                "parsing {:?}",
                sql_type
            );
        }
    }

    /// The builder records schema, type name, columns and rows.
    #[test]
    fn test_tvp_data_builder() {
        let single_int_column =
            TvpData::new("dbo", "UserIdList").with_column(TvpColumnDef::new(TvpColumnType::Int));
        let tvp = (1..=3).fold(single_int_column, |acc, id| {
            acc.with_row(vec![SqlValue::Int(id)])
        });

        assert_eq!(tvp.schema, "dbo");
        assert_eq!(tvp.type_name, "UserIdList");
        assert_eq!(tvp.column_count(), 1);
        assert_eq!(tvp.len(), 3);
    }

    /// `with_row` panics when the row width does not match the column count.
    #[test]
    #[should_panic(expected = "Row value count (2) must match column count (1)")]
    fn test_tvp_data_row_mismatch_panics() {
        let one_column =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));
        let _ = one_column.with_row(vec![SqlValue::Int(1), SqlValue::Int(2)]);
    }

    /// `try_add_row` reports a mismatch as an error instead of panicking.
    #[test]
    fn test_tvp_data_try_add_row_error() {
        let mut tvp =
            TvpData::new("dbo", "Test").with_column(TvpColumnDef::new(TvpColumnType::Int));
        let too_wide = vec![SqlValue::Int(1), SqlValue::Int(2)];

        let outcome = tvp.try_add_row(too_wide);

        assert!(matches!(
            outcome,
            Err(TvpError::ColumnCountMismatch { .. })
        ));
    }
}