use crate::DataType;
use thiserror::Error;
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TypeParseError {
#[error("empty type string")]
EmptyString,
#[error("unknown type: {0}")]
UnknownType(String),
#[error("invalid precision/scale for DECIMAL: {0}")]
InvalidDecimal(String),
#[error("invalid length for {type_name}: {value}")]
InvalidLength { type_name: String, value: String },
#[error("missing closing parenthesis")]
MissingCloseParen,
}
pub fn parse_type(type_str: &str) -> Result<DataType, TypeParseError> {
let type_str = type_str.trim();
if type_str.is_empty() {
return Err(TypeParseError::EmptyString);
}
let upper = type_str.to_uppercase();
if let Some(paren_pos) = upper.find('(') {
return parse_parameterized_type(&upper, paren_pos);
}
if upper.starts_with("TIMESTAMP") {
return parse_timestamp_type(&upper);
}
match upper.as_str() {
"BOOLEAN" | "BOOL" => Ok(DataType::Boolean),
"TINYINT" | "INT1" => Ok(DataType::SmallInt), "SMALLINT" | "INT2" => Ok(DataType::SmallInt),
"INT" | "INTEGER" | "INT4" => Ok(DataType::Integer),
"BIGINT" | "INT8" | "LONG" => Ok(DataType::BigInt),
"HUGEINT" | "INT16" => Ok(DataType::BigInt),
"REAL" | "FLOAT4" | "FLOAT" => Ok(DataType::Float),
"DOUBLE" | "FLOAT8" | "DOUBLE PRECISION" => Ok(DataType::Double),
"VARCHAR" | "STRING" | "TEXT" => Ok(DataType::Varchar { max_length: None }),
"CHAR" | "CHARACTER" => Ok(DataType::Char { length: 1 }),
"DATE" => Ok(DataType::Date),
"TIME" => Ok(DataType::Time),
"TIMESTAMP" => Ok(DataType::Timestamp {
with_timezone: false,
}),
"TIMESTAMPTZ" => Ok(DataType::Timestamp {
with_timezone: true,
}),
"INTERVAL" => Ok(DataType::Interval),
"BLOB" | "BYTEA" | "BINARY" | "VARBINARY" => Ok(DataType::Blob),
"NUMERIC" | "DECIMAL" => Ok(DataType::Decimal {
precision: 18,
scale: 0,
}),
_ => Err(TypeParseError::UnknownType(type_str.to_string())),
}
}
fn parse_parameterized_type(upper: &str, paren_pos: usize) -> Result<DataType, TypeParseError> {
let type_name = upper[..paren_pos].trim();
let params_str = &upper[paren_pos + 1..];
let close_pos = params_str
.find(')')
.ok_or(TypeParseError::MissingCloseParen)?;
let params = ¶ms_str[..close_pos];
match type_name {
"VARCHAR" | "VARYING" | "CHARACTER VARYING" | "STRING" => {
let length = parse_single_number(params, "VARCHAR")?;
Ok(DataType::Varchar {
max_length: Some(length),
})
}
"CHAR" | "CHARACTER" => {
let length = parse_single_number(params, "CHAR")?;
Ok(DataType::Char { length })
}
"DECIMAL" | "NUMERIC" | "DEC" => parse_decimal_params(params),
"FLOAT" => {
let precision = parse_single_number(params, "FLOAT")?;
if precision <= 24 {
Ok(DataType::Float)
} else {
Ok(DataType::Double)
}
}
"TIME" => {
Ok(DataType::Time)
}
"TIMESTAMP" => {
let suffix = ¶ms_str[close_pos + 1..].trim();
let with_tz =
suffix.starts_with("WITH TIME ZONE") || suffix.starts_with("WITH TIMEZONE");
Ok(DataType::Timestamp {
with_timezone: with_tz,
})
}
_ => Err(TypeParseError::UnknownType(type_name.to_string())),
}
}
fn parse_timestamp_type(upper: &str) -> Result<DataType, TypeParseError> {
let with_tz = upper.contains("WITH TIME ZONE")
|| upper.contains("WITH TIMEZONE")
|| upper == "TIMESTAMPTZ";
Ok(DataType::Timestamp {
with_timezone: with_tz,
})
}
fn parse_single_number(params: &str, type_name: &str) -> Result<u32, TypeParseError> {
params
.trim()
.parse::<u32>()
.map_err(|_| TypeParseError::InvalidLength {
type_name: type_name.to_string(),
value: params.to_string(),
})
}
fn parse_decimal_params(params: &str) -> Result<DataType, TypeParseError> {
let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
match parts.len() {
1 => {
let precision = parts[0]
.parse::<u8>()
.map_err(|_| TypeParseError::InvalidDecimal(params.to_string()))?;
Ok(DataType::Decimal {
precision,
scale: 0,
})
}
2 => {
let precision = parts[0]
.parse::<u8>()
.map_err(|_| TypeParseError::InvalidDecimal(params.to_string()))?;
let scale = parts[1]
.parse::<u8>()
.map_err(|_| TypeParseError::InvalidDecimal(params.to_string()))?;
Ok(DataType::Decimal { precision, scale })
}
_ => Err(TypeParseError::InvalidDecimal(params.to_string())),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_simple_types() {
assert_eq!(parse_type("INTEGER").unwrap(), DataType::Integer);
assert_eq!(parse_type("int").unwrap(), DataType::Integer);
assert_eq!(parse_type("BIGINT").unwrap(), DataType::BigInt);
assert_eq!(parse_type("BOOLEAN").unwrap(), DataType::Boolean);
assert_eq!(parse_type("bool").unwrap(), DataType::Boolean);
assert_eq!(parse_type("DATE").unwrap(), DataType::Date);
assert_eq!(
parse_type("VARCHAR").unwrap(),
DataType::Varchar { max_length: None }
);
}
#[test]
fn test_parse_varchar_with_length() {
assert_eq!(
parse_type("VARCHAR(255)").unwrap(),
DataType::Varchar {
max_length: Some(255)
}
);
assert_eq!(
parse_type("varchar(100)").unwrap(),
DataType::Varchar {
max_length: Some(100)
}
);
}
#[test]
fn test_parse_char_with_length() {
assert_eq!(
parse_type("CHAR(10)").unwrap(),
DataType::Char { length: 10 }
);
assert_eq!(parse_type("CHAR").unwrap(), DataType::Char { length: 1 });
}
#[test]
fn test_parse_decimal() {
assert_eq!(
parse_type("DECIMAL(10,2)").unwrap(),
DataType::Decimal {
precision: 10,
scale: 2
}
);
assert_eq!(
parse_type("DECIMAL(18)").unwrap(),
DataType::Decimal {
precision: 18,
scale: 0
}
);
assert_eq!(
parse_type("NUMERIC(5, 3)").unwrap(),
DataType::Decimal {
precision: 5,
scale: 3
}
);
assert_eq!(
parse_type("DECIMAL").unwrap(),
DataType::Decimal {
precision: 18,
scale: 0
}
);
}
#[test]
fn test_parse_timestamp() {
assert_eq!(
parse_type("TIMESTAMP").unwrap(),
DataType::Timestamp {
with_timezone: false
}
);
assert_eq!(
parse_type("TIMESTAMP WITH TIME ZONE").unwrap(),
DataType::Timestamp {
with_timezone: true
}
);
assert_eq!(
parse_type("TIMESTAMPTZ").unwrap(),
DataType::Timestamp {
with_timezone: true
}
);
}
#[test]
fn test_parse_float_precision() {
assert_eq!(parse_type("FLOAT").unwrap(), DataType::Float);
assert_eq!(parse_type("FLOAT(24)").unwrap(), DataType::Float);
assert_eq!(parse_type("FLOAT(53)").unwrap(), DataType::Double);
}
#[test]
fn test_parse_aliases() {
assert_eq!(parse_type("INT").unwrap(), DataType::Integer);
assert_eq!(parse_type("INT4").unwrap(), DataType::Integer);
assert_eq!(parse_type("INT8").unwrap(), DataType::BigInt);
assert_eq!(parse_type("REAL").unwrap(), DataType::Float);
assert_eq!(parse_type("DOUBLE PRECISION").unwrap(), DataType::Double);
assert_eq!(
parse_type("TEXT").unwrap(),
DataType::Varchar { max_length: None }
);
assert_eq!(
parse_type("STRING").unwrap(),
DataType::Varchar { max_length: None }
);
}
#[test]
fn test_parse_errors() {
assert!(matches!(parse_type(""), Err(TypeParseError::EmptyString)));
assert!(matches!(
parse_type("FOOBAR"),
Err(TypeParseError::UnknownType(_))
));
assert!(matches!(
parse_type("VARCHAR(abc)"),
Err(TypeParseError::InvalidLength { .. })
));
assert!(matches!(
parse_type("DECIMAL(a,b)"),
Err(TypeParseError::InvalidDecimal(_))
));
}
#[test]
fn test_case_insensitivity() {
assert_eq!(parse_type("integer").unwrap(), DataType::Integer);
assert_eq!(parse_type("INTEGER").unwrap(), DataType::Integer);
assert_eq!(parse_type("Integer").unwrap(), DataType::Integer);
assert_eq!(
parse_type("varchar(100)").unwrap(),
DataType::Varchar {
max_length: Some(100)
}
);
}
#[test]
fn test_whitespace_handling() {
assert_eq!(parse_type(" INTEGER ").unwrap(), DataType::Integer);
assert_eq!(
parse_type("DECIMAL( 10 , 2 )").unwrap(),
DataType::Decimal {
precision: 10,
scale: 2
}
);
}
}