mod functions;
mod parse;
pub use functions::{FunctionCategory, SqlFunction};
pub use parse::{parse_type, TypeParseError};
/// A SQL data type.
///
/// Parameterised variants carry the modifiers that appear in the SQL
/// spelling (e.g. `DECIMAL(10,2)`, `VARCHAR(255)`); see [`DataType::to_sql`]
/// for the exact rendering of each variant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Rendered as `BOOLEAN`.
    Boolean,
    /// Rendered as `SMALLINT`.
    SmallInt,
    /// Rendered as `INTEGER`.
    Integer,
    /// Rendered as `BIGINT`.
    BigInt,
    /// Rendered as `FLOAT`.
    Float,
    /// Rendered as `DOUBLE`.
    Double,
    /// Rendered as `DECIMAL(precision,scale)`, or `DECIMAL(precision)` when
    /// `scale` is zero.
    Decimal { precision: u8, scale: u8 },
    /// Rendered as `VARCHAR`, or `VARCHAR(n)` when a maximum length is set.
    Varchar { max_length: Option<u32> },
    /// Rendered as `CHAR(length)`.
    Char { length: u32 },
    /// Rendered as `TEXT`.
    Text,
    /// Rendered as `BLOB`. Not classified as a string type by
    /// [`DataType::is_string`].
    Blob,
    /// Rendered as `DATE`.
    Date,
    /// Rendered as `TIME`.
    Time,
    /// Rendered as `TIMESTAMP`, or `TIMESTAMP WITH TIME ZONE` when
    /// `with_timezone` is set.
    Timestamp { with_timezone: bool },
    /// Rendered as `INTERVAL`.
    Interval,
    /// An array of the boxed element type, rendered as `<element>[]`.
    Array(Box<DataType>),
    /// Rendered as `NULL`; presumably the type of an untyped `NULL`
    /// literal — TODO confirm against the type checker.
    Null,
    /// Rendered as `UNKNOWN`; used as the type of `TypedColumn::unknown()`.
    Unknown,
}

impl DataType {
    /// Returns `true` for the integer (`SMALLINT`, `INTEGER`, `BIGINT`),
    /// floating-point (`FLOAT`, `DOUBLE`) and `DECIMAL` types.
    pub fn is_numeric(&self) -> bool {
        matches!(
            self,
            DataType::SmallInt
                | DataType::Integer
                | DataType::BigInt
                | DataType::Float
                | DataType::Double
                | DataType::Decimal { .. }
        )
    }

    /// Returns `true` for the character types `VARCHAR`, `CHAR` and `TEXT`.
    ///
    /// `BLOB` is deliberately excluded.
    pub fn is_string(&self) -> bool {
        matches!(
            self,
            DataType::Varchar { .. } | DataType::Char { .. } | DataType::Text
        )
    }

    /// Returns `true` for `DATE`, `TIME`, `TIMESTAMP` (with or without time
    /// zone) and `INTERVAL`.
    pub fn is_temporal(&self) -> bool {
        matches!(
            self,
            DataType::Date | DataType::Time | DataType::Timestamp { .. } | DataType::Interval
        )
    }

    /// Renders the type using its SQL spelling, e.g. `DECIMAL(10,2)` or
    /// `INTEGER[]`.
    ///
    /// `DECIMAL` omits the scale when it is zero, and array types append `[]`
    /// to the (recursively rendered) element type, so nested arrays become
    /// `TEXT[][]`.
    pub fn to_sql(&self) -> String {
        match self {
            DataType::Boolean => "BOOLEAN".to_string(),
            DataType::SmallInt => "SMALLINT".to_string(),
            DataType::Integer => "INTEGER".to_string(),
            DataType::BigInt => "BIGINT".to_string(),
            DataType::Float => "FLOAT".to_string(),
            DataType::Double => "DOUBLE".to_string(),
            DataType::Decimal { precision, scale } => {
                if *scale == 0 {
                    format!("DECIMAL({precision})")
                } else {
                    format!("DECIMAL({precision},{scale})")
                }
            }
            DataType::Varchar { max_length: None } => "VARCHAR".to_string(),
            DataType::Varchar {
                max_length: Some(len),
            } => format!("VARCHAR({len})"),
            DataType::Char { length } => format!("CHAR({length})"),
            DataType::Text => "TEXT".to_string(),
            DataType::Blob => "BLOB".to_string(),
            DataType::Date => "DATE".to_string(),
            DataType::Time => "TIME".to_string(),
            DataType::Timestamp { with_timezone } => {
                if *with_timezone {
                    "TIMESTAMP WITH TIME ZONE".to_string()
                } else {
                    "TIMESTAMP".to_string()
                }
            }
            DataType::Interval => "INTERVAL".to_string(),
            DataType::Array(inner) => format!("{}[]", inner.to_sql()),
            DataType::Null => "NULL".to_string(),
            DataType::Unknown => "UNKNOWN".to_string(),
        }
    }
}

impl std::fmt::Display for DataType {
    /// Writes the SQL spelling produced by [`DataType::to_sql`].
    ///
    /// Goes through `Formatter::pad` rather than `write!` so that width,
    /// fill, alignment and precision flags (e.g. `{:<12}` when printing a
    /// schema table) are honoured exactly as they are for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.to_sql())
    }
}
/// A column's data type together with whether it may hold `NULL`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedColumn {
    /// The column's SQL data type.
    pub data_type: DataType,
    /// `true` when the column may contain `NULL`.
    pub nullable: bool,
}

impl TypedColumn {
    /// Builds a column from an explicit type and nullability flag.
    pub fn new(data_type: DataType, nullable: bool) -> Self {
        TypedColumn {
            data_type,
            nullable,
        }
    }

    /// Builds a column of `data_type` that admits `NULL`.
    pub fn nullable(data_type: DataType) -> Self {
        TypedColumn {
            data_type,
            nullable: true,
        }
    }

    /// Builds a column of `data_type` that rejects `NULL`.
    pub fn not_null(data_type: DataType) -> Self {
        TypedColumn {
            data_type,
            nullable: false,
        }
    }

    /// Builds a nullable column whose type is `DataType::Unknown`.
    pub fn unknown() -> Self {
        TypedColumn {
            data_type: DataType::Unknown,
            nullable: true,
        }
    }
}

impl std::fmt::Display for TypedColumn {
    /// Writes the data type, followed by ` NOT NULL` for non-nullable
    /// columns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let constraint = if self.nullable { "" } else { " NOT NULL" };
        write!(f, "{}{}", self.data_type, constraint)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// SQL rendering of each parameterised variant, including the
    /// scale-0 `DECIMAL`, `CHAR(n)`, plain `TIMESTAMP` and nested arrays.
    #[test]
    fn test_data_type_display() {
        assert_eq!(DataType::Integer.to_string(), "INTEGER");
        assert_eq!(
            DataType::Decimal {
                precision: 10,
                scale: 2
            }
            .to_string(),
            "DECIMAL(10,2)"
        );
        assert_eq!(
            DataType::Decimal {
                precision: 10,
                scale: 0
            }
            .to_string(),
            "DECIMAL(10)"
        );
        assert_eq!(
            DataType::Varchar { max_length: None }.to_string(),
            "VARCHAR"
        );
        assert_eq!(
            DataType::Varchar {
                max_length: Some(255)
            }
            .to_string(),
            "VARCHAR(255)"
        );
        assert_eq!(DataType::Char { length: 8 }.to_string(), "CHAR(8)");
        assert_eq!(
            DataType::Timestamp {
                with_timezone: true
            }
            .to_string(),
            "TIMESTAMP WITH TIME ZONE"
        );
        assert_eq!(
            DataType::Timestamp {
                with_timezone: false
            }
            .to_string(),
            "TIMESTAMP"
        );
        assert_eq!(
            DataType::Array(Box::new(DataType::Integer)).to_string(),
            "INTEGER[]"
        );
        assert_eq!(
            DataType::Array(Box::new(DataType::Array(Box::new(DataType::Text)))).to_string(),
            "TEXT[][]"
        );
    }

    #[test]
    fn test_is_numeric() {
        assert!(DataType::Integer.is_numeric());
        assert!(DataType::BigInt.is_numeric());
        assert!(DataType::Double.is_numeric());
        assert!(DataType::Decimal {
            precision: 10,
            scale: 2
        }
        .is_numeric());
        assert!(!DataType::Varchar { max_length: None }.is_numeric());
        assert!(!DataType::Date.is_numeric());
    }

    /// `BLOB` is not a string type; `INTERVAL` counts as temporal.
    #[test]
    fn test_is_string_and_is_temporal() {
        assert!(DataType::Text.is_string());
        assert!(DataType::Char { length: 1 }.is_string());
        assert!(DataType::Varchar { max_length: Some(10) }.is_string());
        assert!(!DataType::Blob.is_string());
        assert!(!DataType::Integer.is_string());

        assert!(DataType::Date.is_temporal());
        assert!(DataType::Time.is_temporal());
        assert!(DataType::Timestamp {
            with_timezone: false
        }
        .is_temporal());
        assert!(DataType::Interval.is_temporal());
        assert!(!DataType::Text.is_temporal());
    }

    #[test]
    fn test_typed_column_display() {
        let col = TypedColumn::not_null(DataType::Integer);
        assert_eq!(col.to_string(), "INTEGER NOT NULL");
        let col = TypedColumn::nullable(DataType::Varchar {
            max_length: Some(100),
        });
        assert_eq!(col.to_string(), "VARCHAR(100)");
    }

    #[test]
    fn test_typed_column_unknown() {
        let col = TypedColumn::unknown();
        assert_eq!(col, TypedColumn::new(DataType::Unknown, true));
        assert_eq!(col.to_string(), "UNKNOWN");
    }
}