mod functions;
mod parse;
pub use functions::{FunctionCategory, SqlFunction};
pub use parse::{parse_type, TypeParseError};
/// A SQL type as modelled by this crate.
///
/// Every variant has a canonical SQL spelling produced by `DataType::to_sql`,
/// which is also what the `Display` impl prints.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    /// `BOOLEAN`.
    Boolean,
    /// `SMALLINT`.
    SmallInt,
    /// `INTEGER`.
    Integer,
    /// `BIGINT`.
    BigInt,
    /// `FLOAT`.
    Float,
    /// `DOUBLE`.
    Double,
    /// Fixed-point number, rendered as `DECIMAL(precision,scale)`, or as
    /// `DECIMAL(precision)` when `scale` is 0.
    Decimal { precision: u8, scale: u8 },
    /// Variable-length string. `None` renders as plain `VARCHAR`;
    /// `Some(n)` renders as `VARCHAR(n)`.
    Varchar { max_length: Option<u32> },
    /// Fixed-length string, rendered as `CHAR(length)`.
    Char { length: u32 },
    /// `TEXT`. `DataType::normalize` rewrites it to `Varchar { max_length: None }`.
    Text,
    /// `BLOB`.
    Blob,
    /// `DATE`.
    Date,
    /// `TIME`.
    Time,
    /// `TIMESTAMP`, or `TIMESTAMP WITH TIME ZONE` when `with_timezone` is true.
    Timestamp { with_timezone: bool },
    /// `INTERVAL`.
    Interval,
    /// Array of a single boxed element type, rendered as `<element>[]`.
    Array(Box<DataType>),
    /// Ordered `(field name, field type)` pairs, rendered as
    /// `STRUCT(name type, ...)`.
    Struct(Vec<(String, DataType)>),
    /// Map from the first (key) type to the second (value) type, rendered as
    /// `MAP(key, value)`.
    Map(Box<DataType>, Box<DataType>),
    /// Rendered as `NULL`. NOTE(review): presumably the type of an untyped
    /// NULL literal — confirm against callers.
    Null,
    /// Rendered as `UNKNOWN`. `TypedColumn::unknown` uses it as a placeholder;
    /// presumably for types that have not been determined yet.
    Unknown,
}
impl DataType {
pub fn is_numeric(&self) -> bool {
matches!(
self,
DataType::SmallInt
| DataType::Integer
| DataType::BigInt
| DataType::Float
| DataType::Double
| DataType::Decimal { .. }
)
}
pub fn is_string(&self) -> bool {
matches!(
self,
DataType::Varchar { .. } | DataType::Char { .. } | DataType::Text
)
}
pub fn is_complex(&self) -> bool {
matches!(
self,
DataType::Array(_) | DataType::Struct(_) | DataType::Map(_, _)
)
}
pub fn is_temporal(&self) -> bool {
matches!(
self,
DataType::Date | DataType::Time | DataType::Timestamp { .. } | DataType::Interval
)
}
pub fn normalize(&self) -> DataType {
match self {
DataType::Text => DataType::Varchar { max_length: None },
DataType::Array(inner) => DataType::Array(Box::new(inner.normalize())),
DataType::Struct(fields) => DataType::Struct(
fields
.iter()
.map(|(name, dt)| (name.clone(), dt.normalize()))
.collect(),
),
DataType::Map(k, v) => DataType::Map(Box::new(k.normalize()), Box::new(v.normalize())),
other => other.clone(),
}
}
pub fn to_backend_sql(&self) -> String {
match self {
DataType::Text => "VARCHAR".to_string(),
other => other.to_sql(),
}
}
pub fn to_sql(&self) -> String {
match self {
DataType::Boolean => "BOOLEAN".to_string(),
DataType::SmallInt => "SMALLINT".to_string(),
DataType::Integer => "INTEGER".to_string(),
DataType::BigInt => "BIGINT".to_string(),
DataType::Float => "FLOAT".to_string(),
DataType::Double => "DOUBLE".to_string(),
DataType::Decimal { precision, scale } => {
if *scale == 0 {
format!("DECIMAL({precision})")
} else {
format!("DECIMAL({precision},{scale})")
}
}
DataType::Varchar { max_length: None } => "VARCHAR".to_string(),
DataType::Varchar {
max_length: Some(len),
} => format!("VARCHAR({len})"),
DataType::Char { length } => format!("CHAR({length})"),
DataType::Text => "TEXT".to_string(),
DataType::Blob => "BLOB".to_string(),
DataType::Date => "DATE".to_string(),
DataType::Time => "TIME".to_string(),
DataType::Timestamp { with_timezone } => {
if *with_timezone {
"TIMESTAMP WITH TIME ZONE".to_string()
} else {
"TIMESTAMP".to_string()
}
}
DataType::Interval => "INTERVAL".to_string(),
DataType::Array(inner) => format!("{}[]", inner.to_sql()),
DataType::Struct(fields) => {
let field_strs: Vec<String> = fields
.iter()
.map(|(name, dt)| format!("{} {}", name, dt.to_sql()))
.collect();
format!("STRUCT({})", field_strs.join(", "))
}
DataType::Map(key, value) => {
format!("MAP({}, {})", key.to_sql(), value.to_sql())
}
DataType::Null => "NULL".to_string(),
DataType::Unknown => "UNKNOWN".to_string(),
}
}
}
impl std::fmt::Display for DataType {
    /// Prints the canonical SQL spelling, exactly as `DataType::to_sql` returns it.
    ///
    /// Width/alignment flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sql = self.to_sql();
        f.write_str(&sql)
    }
}
/// A column's SQL type paired with its nullability.
///
/// Its `Display` impl prints the type, followed by ` NOT NULL` when
/// `nullable` is false.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedColumn {
    /// The column's SQL type.
    pub data_type: DataType,
    /// `true` if the column may contain NULL values.
    pub nullable: bool,
}
impl TypedColumn {
pub fn new(data_type: DataType, nullable: bool) -> Self {
Self {
data_type,
nullable,
}
}
pub fn nullable(data_type: DataType) -> Self {
Self::new(data_type, true)
}
pub fn not_null(data_type: DataType) -> Self {
Self::new(data_type, false)
}
pub fn unknown() -> Self {
Self::nullable(DataType::Unknown)
}
}
impl std::fmt::Display for TypedColumn {
    /// Prints the column type, followed by ` NOT NULL` for non-nullable columns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let constraint = match self.nullable {
            true => "",
            false => " NOT NULL",
        };
        write!(f, "{}{}", self.data_type, constraint)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unbounded `VARCHAR`, the canonical form of `TEXT`.
    fn varchar() -> DataType {
        DataType::Varchar { max_length: None }
    }

    /// Shorthand for a `DECIMAL(precision, scale)` type.
    fn decimal(precision: u8, scale: u8) -> DataType {
        DataType::Decimal { precision, scale }
    }

    /// Shorthand for a named struct field.
    fn field(name: &str, dt: DataType) -> (String, DataType) {
        (name.to_string(), dt)
    }

    #[test]
    fn test_data_type_display() {
        let cases = [
            (DataType::Integer, "INTEGER"),
            (decimal(10, 2), "DECIMAL(10,2)"),
            (varchar(), "VARCHAR"),
            (
                DataType::Varchar {
                    max_length: Some(255),
                },
                "VARCHAR(255)",
            ),
            (
                DataType::Timestamp {
                    with_timezone: true,
                },
                "TIMESTAMP WITH TIME ZONE",
            ),
            (DataType::Array(Box::new(DataType::Integer)), "INTEGER[]"),
        ];
        for (dt, expected) in cases {
            assert_eq!(dt.to_string(), expected, "display of {dt:?}");
        }
    }

    #[test]
    fn test_to_backend_sql_text_becomes_varchar() {
        let cases = [
            (DataType::Text, "VARCHAR"),
            (DataType::Integer, "INTEGER"),
            (varchar(), "VARCHAR"),
        ];
        for (dt, expected) in cases {
            assert_eq!(dt.to_backend_sql(), expected, "backend sql of {dt:?}");
        }
    }

    #[test]
    fn test_is_numeric() {
        for dt in [
            DataType::Integer,
            DataType::BigInt,
            DataType::Double,
            decimal(10, 2),
        ] {
            assert!(dt.is_numeric(), "{dt:?} should be numeric");
        }
        for dt in [varchar(), DataType::Date] {
            assert!(!dt.is_numeric(), "{dt:?} should not be numeric");
        }
    }

    #[test]
    fn test_is_complex() {
        let complex = [
            DataType::Array(Box::new(DataType::Integer)),
            DataType::Struct(vec![field("a", DataType::Integer)]),
            DataType::Map(Box::new(varchar()), Box::new(DataType::Integer)),
        ];
        for dt in complex {
            assert!(dt.is_complex(), "{dt:?} should be complex");
        }
        for dt in [DataType::Integer, varchar(), DataType::Boolean] {
            assert!(!dt.is_complex(), "{dt:?} should not be complex");
        }
    }

    #[test]
    fn test_map_to_sql() {
        let map = DataType::Map(Box::new(varchar()), Box::new(DataType::Integer));
        assert_eq!(map.to_sql(), "MAP(VARCHAR, INTEGER)");
    }

    #[test]
    fn test_normalize_text_to_varchar() {
        assert_eq!(DataType::Text.normalize(), varchar());
    }

    #[test]
    fn test_normalize_scalar_unchanged() {
        for dt in [
            DataType::Integer,
            DataType::BigInt,
            DataType::Boolean,
            varchar(),
            decimal(10, 2),
        ] {
            assert_eq!(dt.normalize(), dt);
        }
    }

    #[test]
    fn test_normalize_array_recursive() {
        let array_of = |inner: DataType| DataType::Array(Box::new(inner));
        assert_eq!(array_of(DataType::Text).normalize(), array_of(varchar()));
        assert_eq!(
            array_of(DataType::Integer).normalize(),
            array_of(DataType::Integer)
        );
    }

    #[test]
    fn test_normalize_struct_recursive() {
        let input = DataType::Struct(vec![
            field("a", DataType::Text),
            field("b", DataType::Integer),
        ]);
        let expected = DataType::Struct(vec![field("a", varchar()), field("b", DataType::Integer)]);
        assert_eq!(input.normalize(), expected);
    }

    #[test]
    fn test_normalize_map_recursive() {
        let input = DataType::Map(Box::new(DataType::Text), Box::new(DataType::Text));
        let expected = DataType::Map(Box::new(varchar()), Box::new(varchar()));
        assert_eq!(input.normalize(), expected);
    }

    #[test]
    fn test_normalize_deeply_nested() {
        let nested = |leaf: DataType| {
            DataType::Struct(vec![field("a", DataType::Struct(vec![field("x", leaf)]))])
        };
        assert_eq!(nested(DataType::Text).normalize(), nested(varchar()));
    }

    #[test]
    fn test_typed_column_display() {
        let strict = TypedColumn::not_null(DataType::Integer);
        assert_eq!(strict.to_string(), "INTEGER NOT NULL");
        let loose = TypedColumn::nullable(DataType::Varchar {
            max_length: Some(100),
        });
        assert_eq!(loose.to_string(), "VARCHAR(100)");
    }
}