use std::fmt::{self, Display, Formatter};
use sqlx_core::type_info::TypeInfo;
use crate::protocol::type_info as protocol;
/// Driver-level classification of a SQL Server column or parameter type.
///
/// Several wire-level data types can collapse into one variant; for example, every
/// non-Unicode character type is reported as [`MssqlType::VarChar`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MssqlType {
    /// The `NULL` type.
    Null,
    /// `BIT` (also produced for the protocol's `BitN`).
    Bit,
    /// `TINYINT`.
    TinyInt,
    /// `SMALLINT`.
    SmallInt,
    /// `INT`.
    Int,
    /// `BIGINT`.
    BigInt,
    /// `REAL`.
    Real,
    /// `FLOAT`.
    Float,
    /// Unicode text: `NVARCHAR`, also used for `NCHAR`.
    NVarChar,
    /// Non-Unicode text: `VARCHAR`, also used for `CHAR` and their "big" protocol forms.
    VarChar,
    /// Binary data: `VARBINARY`, also used for `BINARY` and their "big" protocol forms.
    VarBinary,
    /// Any type without a dedicated variant, identified by its name.
    Other(String),
}
/// Type information for a SQL Server value as tracked by this driver.
///
/// Equality compares every field, including size and any retained protocol metadata.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MssqlTypeInfo {
    /// The driver-level classification of the type.
    kind: MssqlType,
    /// `true` when built by `with_size`, or when the protocol metadata reports the type as
    /// nullable or variable-length; `false` for values built by `new`.
    variable_length: bool,
    /// Size of the type when known and representable as a `u16`.
    size: Option<u16>,
    /// A copy of the wire-level metadata this value was built from, if any.
    protocol_type_info: Option<protocol::TypeInfo>,
}
impl MssqlTypeInfo {
pub const fn new(kind: MssqlType) -> Self {
Self {
kind,
variable_length: false,
size: None,
protocol_type_info: None,
}
}
pub(crate) const fn with_size(kind: MssqlType, size: u16) -> Self {
Self {
kind,
variable_length: true,
size: Some(size),
protocol_type_info: None,
}
}
pub fn kind(&self) -> &MssqlType {
&self.kind
}
pub(crate) const fn size(&self) -> Option<u16> {
self.size
}
pub(crate) const fn protocol_type_info(&self) -> Option<&protocol::TypeInfo> {
self.protocol_type_info.as_ref()
}
pub const NULL: Self = Self::new(MssqlType::Null);
pub const BIT: Self = Self::new(MssqlType::Bit);
pub const TINYINT: Self = Self::new(MssqlType::TinyInt);
pub const SMALLINT: Self = Self::new(MssqlType::SmallInt);
pub const INT: Self = Self::new(MssqlType::Int);
pub const BIGINT: Self = Self::new(MssqlType::BigInt);
pub const REAL: Self = Self::new(MssqlType::Real);
pub const FLOAT: Self = Self::new(MssqlType::Float);
pub const NVARCHAR: Self = Self::new(MssqlType::NVarChar);
pub const VARCHAR: Self = Self::new(MssqlType::VarChar);
pub const VARBINARY: Self = Self::new(MssqlType::VarBinary);
pub(crate) fn from_protocol(type_info: &protocol::TypeInfo) -> Self {
let kind = match type_info.ty {
protocol::DataType::Null => MssqlType::Null,
protocol::DataType::Bit | protocol::DataType::BitN => MssqlType::Bit,
protocol::DataType::TinyInt => MssqlType::TinyInt,
protocol::DataType::SmallInt => MssqlType::SmallInt,
protocol::DataType::Int => MssqlType::Int,
protocol::DataType::BigInt => MssqlType::BigInt,
protocol::DataType::Real => MssqlType::Real,
protocol::DataType::Float => MssqlType::Float,
protocol::DataType::IntN => match type_info.size {
1 => MssqlType::TinyInt,
2 => MssqlType::SmallInt,
4 => MssqlType::Int,
8 => MssqlType::BigInt,
_ => MssqlType::Other(type_info.name().to_owned()),
},
protocol::DataType::FloatN => match type_info.size {
4 => MssqlType::Real,
8 => MssqlType::Float,
_ => MssqlType::Other(type_info.name().to_owned()),
},
protocol::DataType::NVarChar | protocol::DataType::NChar => MssqlType::NVarChar,
protocol::DataType::VarChar
| protocol::DataType::Char
| protocol::DataType::BigVarChar
| protocol::DataType::BigChar => MssqlType::VarChar,
protocol::DataType::VarBinary
| protocol::DataType::Binary
| protocol::DataType::BigVarBinary
| protocol::DataType::BigBinary => MssqlType::VarBinary,
_ => MssqlType::Other(type_info.name().to_owned()),
};
Self {
kind,
variable_length: type_info.is_nullable_or_variable_length(),
size: u16::try_from(type_info.size).ok(),
protocol_type_info: Some(type_info.clone()),
}
}
}
impl TypeInfo for MssqlTypeInfo {
    /// A type is null exactly when its kind is [`MssqlType::Null`].
    fn is_null(&self) -> bool {
        self.kind == MssqlType::Null
    }

    /// Returns the SQL Server spelling of the type, or the stored name for `Other`.
    fn name(&self) -> &str {
        match self.kind {
            MssqlType::Null => "NULL",
            MssqlType::Bit => "BIT",
            MssqlType::TinyInt => "TINYINT",
            MssqlType::SmallInt => "SMALLINT",
            MssqlType::Int => "INT",
            MssqlType::BigInt => "BIGINT",
            MssqlType::Real => "REAL",
            MssqlType::Float => "FLOAT",
            MssqlType::NVarChar => "NVARCHAR",
            MssqlType::VarChar => "VARCHAR",
            MssqlType::VarBinary => "VARBINARY",
            MssqlType::Other(ref name) => name.as_str(),
        }
    }

    /// `NULL` is compatible with everything; otherwise the kinds must match exactly.
    /// Size and protocol metadata are ignored.
    fn type_compatible(&self, other: &Self) -> bool {
        if self.is_null() || other.is_null() {
            return true;
        }
        self.kind == other.kind
    }
}
impl Display for MssqlTypeInfo {
    /// Writes the type's SQL Server name, as returned by [`TypeInfo::name`].
    ///
    /// `Formatter::pad` is used instead of `write_str` so that width, fill, alignment and
    /// precision flags (e.g. `{:<12}`) are honored exactly as they are for a `&str`;
    /// plain `{}` output is unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `type_info` through `from_protocol` and returns only the resulting kind.
    fn kind_of(type_info: protocol::TypeInfo) -> MssqlType {
        MssqlTypeInfo::from_protocol(&type_info).kind
    }

    #[test]
    fn exposes_sql_server_type_names() {
        assert_eq!("INT", MssqlTypeInfo::INT.name());
        assert_eq!("NVARCHAR", MssqlTypeInfo::NVARCHAR.to_string());
        assert_eq!("VARCHAR", MssqlTypeInfo::VARCHAR.to_string());
    }

    #[test]
    fn null_is_compatible_with_known_types() {
        assert!(MssqlTypeInfo::NULL.type_compatible(&MssqlTypeInfo::INT));
        assert!(MssqlTypeInfo::NVARCHAR.type_compatible(&MssqlTypeInfo::NULL));
        assert!(!MssqlTypeInfo::INT.type_compatible(&MssqlTypeInfo::BIGINT));
    }

    #[test]
    fn other_types_use_their_own_name_and_compare_by_name() {
        let xml = MssqlTypeInfo::new(MssqlType::Other("XML".to_owned()));
        assert_eq!("XML", xml.name());
        assert!(xml.type_compatible(&MssqlTypeInfo::new(MssqlType::Other("XML".to_owned()))));
        assert!(!xml.type_compatible(&MssqlTypeInfo::INT));
        assert!(!xml.is_null());
    }

    #[test]
    fn maps_unicode_and_non_unicode_protocol_text_separately() {
        assert_eq!(
            MssqlType::NVarChar,
            kind_of(protocol::TypeInfo::new(protocol::DataType::NVarChar, 8))
        );
        assert_eq!(
            MssqlType::VarChar,
            kind_of(protocol::TypeInfo::new(protocol::DataType::VarChar, 8))
        );
        assert_eq!(
            MssqlType::VarChar,
            kind_of(protocol::TypeInfo::new(protocol::DataType::BigVarChar, 8))
        );
    }

    #[test]
    fn maps_binary_family_to_varbinary() {
        assert_eq!(
            MssqlType::VarBinary,
            kind_of(protocol::TypeInfo::new(protocol::DataType::Binary, 8))
        );
        assert_eq!(
            MssqlType::VarBinary,
            kind_of(protocol::TypeInfo::new(protocol::DataType::BigVarBinary, 8))
        );
    }

    #[test]
    fn resolves_nullable_numeric_types_by_size() {
        assert_eq!(
            MssqlType::TinyInt,
            kind_of(protocol::TypeInfo::new(protocol::DataType::IntN, 1))
        );
        assert_eq!(
            MssqlType::SmallInt,
            kind_of(protocol::TypeInfo::new(protocol::DataType::IntN, 2))
        );
        assert_eq!(
            MssqlType::Int,
            kind_of(protocol::TypeInfo::new(protocol::DataType::IntN, 4))
        );
        assert_eq!(
            MssqlType::BigInt,
            kind_of(protocol::TypeInfo::new(protocol::DataType::IntN, 8))
        );
        assert_eq!(
            MssqlType::Real,
            kind_of(protocol::TypeInfo::new(protocol::DataType::FloatN, 4))
        );
        assert_eq!(
            MssqlType::Float,
            kind_of(protocol::TypeInfo::new(protocol::DataType::FloatN, 8))
        );
        assert_eq!(
            MssqlType::Bit,
            kind_of(protocol::TypeInfo::new(protocol::DataType::BitN, 1))
        );
    }

    #[test]
    fn maps_protocol_null() {
        assert_eq!(
            MssqlType::Null,
            kind_of(protocol::TypeInfo::new(protocol::DataType::Null, 0))
        );
    }

    #[test]
    fn records_size_and_protocol_metadata() {
        let info =
            MssqlTypeInfo::from_protocol(&protocol::TypeInfo::new(protocol::DataType::IntN, 4));
        assert_eq!(Some(4), info.size());
        assert!(info.protocol_type_info().is_some());

        assert_eq!(Some(100), MssqlTypeInfo::with_size(MssqlType::NVarChar, 100).size());
        assert_eq!(None, MssqlTypeInfo::INT.size());
        assert!(MssqlTypeInfo::INT.protocol_type_info().is_none());
    }
}