use sqlparser::ast::{
ArrayElemTypeDef, CharacterLength, DataType, Ident, ObjectName, StructBracketKind, StructField,
};
use crate::Result;
use crate::dialect::{SourceDialect, TargetDialect};
use crate::error::Error;
/// Rewrites `dt` in place so that a type written in the `source` dialect is
/// expressed in the `target` dialect's spelling.
///
/// Hive types share the Trino rewrite rules; Redshift has its own rules.
///
/// # Errors
///
/// Propagates any error from the dialect-specific rewriter (for example a
/// source type with no equivalent in the target, or a malformed `ROW` type).
pub fn rewrite_data_type(
    dt: &mut DataType,
    source: SourceDialect,
    target: TargetDialect,
) -> Result<()> {
    match (source, target) {
        (SourceDialect::Trino | SourceDialect::Hive, TargetDialect::DuckDB) => {
            rewrite_trino_type_duckdb(dt)
        }
        (SourceDialect::Trino | SourceDialect::Hive, TargetDialect::DataFusion) => {
            rewrite_trino_type_datafusion(dt)
        }
        (SourceDialect::Redshift, TargetDialect::DuckDB) => rewrite_redshift_type_duckdb(dt),
        (SourceDialect::Redshift, TargetDialect::DataFusion) => {
            rewrite_redshift_type_datafusion(dt)
        }
    }
}
/// Rewrites a Trino (or Hive) data type in place into its DuckDB spelling.
///
/// Conversions, applied recursively to nested element/field types:
/// - `ROW(name type, ...)`, which sqlparser yields as a custom `ROW` type whose
///   modifiers alternate field name and type, becomes `STRUCT(name type, ...)`.
/// - An existing `STRUCT` has its field types rewritten and is normalised to
///   DuckDB's parenthesised `STRUCT(...)` form.
/// - `ARRAY(T)` and `ARRAY<T>` become the square-bracket form `T[]`.
/// - `MAP` key and value types are rewritten.
/// - `VARBINARY` becomes `BLOB`.
/// - A modifier-less `IPADDRESS` becomes `VARCHAR`.
///
/// Every other type is left untouched.
///
/// # Errors
///
/// Returns an error if a `ROW` type's modifiers cannot be split into
/// name/type pairs.
fn rewrite_trino_type_duckdb(dt: &mut DataType) -> Result<()> {
    match dt {
        DataType::Custom(name, modifiers) if is_name(name, "row") => {
            // Rewrite the field types first so `dt` is replaced only once the
            // whole conversion has succeeded.
            let mut fields = parse_row_modifiers(modifiers)?;
            for field in &mut fields {
                rewrite_trino_type_duckdb(&mut field.field_type)?;
            }
            *dt = DataType::Struct(fields, StructBracketKind::Parentheses);
        }
        DataType::Struct(fields, bracket_kind) => {
            for field in fields.iter_mut() {
                rewrite_trino_type_duckdb(&mut field.field_type)?;
            }
            // DuckDB spells struct types as `STRUCT(...)`, matching the ROW arm above.
            *bracket_kind = StructBracketKind::Parentheses;
        }
        DataType::Array(
            ArrayElemTypeDef::Parenthesis(inner) | ArrayElemTypeDef::AngleBracket(inner),
        ) => {
            rewrite_trino_type_duckdb(inner)?;
            // Move the rewritten element type out (the placeholder is discarded
            // when `dt` is overwritten) so it can be re-wrapped without cloning.
            let inner_owned = std::mem::replace(inner.as_mut(), DataType::Unspecified);
            *dt = DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(inner_owned), None));
        }
        DataType::Array(ArrayElemTypeDef::SquareBracket(inner, _)) => {
            rewrite_trino_type_duckdb(inner)?;
        }
        DataType::Map(key, value) => {
            rewrite_trino_type_duckdb(key)?;
            rewrite_trino_type_duckdb(value)?;
        }
        DataType::Varbinary(_) => {
            *dt = DataType::Blob(None);
        }
        DataType::Custom(name, modifiers) if is_name(name, "ipaddress") && modifiers.is_empty() => {
            *dt = DataType::Varchar(None);
        }
        _ => {}
    }
    Ok(())
}
/// Rewrites a Trino (or Hive) data type in place into its DataFusion spelling.
///
/// Conversions, applied recursively to nested element/field types:
/// - `ROW(name type, ...)`, which sqlparser yields as a custom `ROW` type whose
///   modifiers alternate field name and type, becomes `STRUCT<name type, ...>`.
/// - An existing `STRUCT` has its field types rewritten and is normalised to
///   the angle-bracket `STRUCT<...>` form.
/// - `ARRAY(T)` and `T[]` become `ARRAY<T>`.
/// - `MAP` key and value types are rewritten.
/// - `VARBINARY` and `BLOB` become `BYTEA`. `BLOB` is what ROW field types
///   declared as `BLOB`/`BYTEA` parse to.
/// - A modifier-less `IPADDRESS` becomes `VARCHAR`.
///
/// Every other type is left untouched.
///
/// # Errors
///
/// Returns an error if a `ROW` type's modifiers cannot be split into
/// name/type pairs.
fn rewrite_trino_type_datafusion(dt: &mut DataType) -> Result<()> {
    match dt {
        DataType::Custom(name, modifiers) if is_name(name, "row") => {
            // Rewrite the field types first so `dt` is replaced only once the
            // whole conversion has succeeded.
            let mut fields = parse_row_modifiers(modifiers)?;
            for field in &mut fields {
                rewrite_trino_type_datafusion(&mut field.field_type)?;
            }
            *dt = DataType::Struct(fields, StructBracketKind::AngleBrackets);
        }
        DataType::Struct(fields, bracket_kind) => {
            for field in fields.iter_mut() {
                rewrite_trino_type_datafusion(&mut field.field_type)?;
            }
            // Keep struct spelling consistent with the ROW arm above.
            *bracket_kind = StructBracketKind::AngleBrackets;
        }
        DataType::Array(
            ArrayElemTypeDef::Parenthesis(inner) | ArrayElemTypeDef::SquareBracket(inner, _),
        ) => {
            rewrite_trino_type_datafusion(inner)?;
            // Move the rewritten element type out (the placeholder is discarded
            // when `dt` is overwritten) so it can be re-wrapped without cloning.
            // Any `T[N]` size is dropped, as before.
            let inner_owned = std::mem::replace(inner.as_mut(), DataType::Unspecified);
            *dt = DataType::Array(ArrayElemTypeDef::AngleBracket(Box::new(inner_owned)));
        }
        DataType::Array(ArrayElemTypeDef::AngleBracket(inner)) => {
            rewrite_trino_type_datafusion(inner)?;
        }
        DataType::Map(key, value) => {
            rewrite_trino_type_datafusion(key)?;
            rewrite_trino_type_datafusion(value)?;
        }
        DataType::Varbinary(_) | DataType::Blob(_) => {
            *dt = DataType::Bytea;
        }
        DataType::Custom(name, modifiers) if is_name(name, "ipaddress") && modifiers.is_empty() => {
            *dt = DataType::Varchar(None);
        }
        _ => {}
    }
    Ok(())
}
/// Rewrites a Redshift data type in place into its DuckDB spelling.
///
/// - Every `(MAX)`-length variable string type (`VARCHAR`, `CHARACTER VARYING`,
///   `CHAR VARYING`, `NVARCHAR`) becomes an unbounded `VARCHAR`.
/// - `SUPER` becomes `JSON`.
/// - `VARBINARY` becomes `BLOB`.
///
/// Every other type is left untouched.
///
/// # Errors
///
/// Returns `Error::Unsupported` for modifier-less `HLLSKETCH` and `GEOMETRY`,
/// which this rewriter does not translate.
fn rewrite_redshift_type_duckdb(dt: &mut DataType) -> Result<()> {
    match dt {
        DataType::Varchar(Some(CharacterLength::Max))
        | DataType::CharacterVarying(Some(CharacterLength::Max))
        | DataType::CharVarying(Some(CharacterLength::Max))
        | DataType::Nvarchar(Some(CharacterLength::Max)) => {
            *dt = DataType::Varchar(None);
        }
        DataType::Custom(name, modifiers) if is_name(name, "super") && modifiers.is_empty() => {
            *dt = DataType::JSON;
        }
        DataType::Custom(name, modifiers) if is_name(name, "hllsketch") && modifiers.is_empty() => {
            return Err(Error::Unsupported(
                "Redshift HLLSKETCH type has no DuckDB equivalent".to_string(),
            ));
        }
        DataType::Custom(name, modifiers) if is_name(name, "geometry") && modifiers.is_empty() => {
            return Err(Error::Unsupported(
                "Redshift GEOMETRY type has no direct DuckDB equivalent".to_string(),
            ));
        }
        DataType::Varbinary(_) => {
            *dt = DataType::Blob(None);
        }
        _ => {}
    }
    Ok(())
}
/// Rewrites a Redshift data type in place into its DataFusion spelling.
///
/// - Every `(MAX)`-length variable string type (`VARCHAR`, `CHARACTER VARYING`,
///   `CHAR VARYING`, `NVARCHAR`) becomes an unbounded `VARCHAR`.
/// - `SUPER` becomes `VARCHAR`.
/// - `VARBINARY` becomes `BYTEA`.
///
/// Every other type is left untouched.
///
/// # Errors
///
/// Returns `Error::Unsupported` for modifier-less `HLLSKETCH` and `GEOMETRY`,
/// which this rewriter does not translate.
fn rewrite_redshift_type_datafusion(dt: &mut DataType) -> Result<()> {
    match dt {
        DataType::Varchar(Some(CharacterLength::Max))
        | DataType::CharacterVarying(Some(CharacterLength::Max))
        | DataType::CharVarying(Some(CharacterLength::Max))
        | DataType::Nvarchar(Some(CharacterLength::Max)) => {
            *dt = DataType::Varchar(None);
        }
        DataType::Custom(name, modifiers) if is_name(name, "super") && modifiers.is_empty() => {
            *dt = DataType::Varchar(None);
        }
        DataType::Custom(name, modifiers) if is_name(name, "hllsketch") && modifiers.is_empty() => {
            return Err(Error::Unsupported(
                "Redshift HLLSKETCH type has no DataFusion equivalent".to_string(),
            ));
        }
        DataType::Custom(name, modifiers) if is_name(name, "geometry") && modifiers.is_empty() => {
            return Err(Error::Unsupported(
                "Redshift GEOMETRY type has no direct DataFusion equivalent".to_string(),
            ));
        }
        DataType::Varbinary(_) => {
            *dt = DataType::Bytea;
        }
        _ => {}
    }
    Ok(())
}
/// Returns `true` when the final part of `name` is a plain identifier equal
/// to `target`, compared ASCII-case-insensitively.
///
/// An empty name, or one whose last part is not an identifier, never matches.
fn is_name(name: &ObjectName, target: &str) -> bool {
    let last_ident = name.0.last().and_then(|part| part.as_ident());
    last_ident.is_some_and(|ident| ident.value.eq_ignore_ascii_case(target))
}
/// Converts the modifier tokens of a Trino `ROW(...)` type into struct fields.
///
/// The tokens are read strictly as alternating `name, type` pairs, so
/// `["a", "INTEGER", "b", "VARCHAR"]` yields fields `a INTEGER` and `b VARCHAR`.
/// Each type token is interpreted by `parse_type_string`. Anonymous
/// (type-only) fields are not distinguished from named ones.
///
/// # Errors
///
/// Returns `Error::Unsupported` when the number of tokens is odd, and
/// propagates any error from `parse_type_string`.
fn parse_row_modifiers(modifiers: &[String]) -> Result<Vec<StructField>> {
    if !modifiers.len().is_multiple_of(2) {
        return Err(Error::Unsupported(format!(
            "Cannot parse ROW type modifiers: {modifiers:?}"
        )));
    }
    modifiers
        .chunks_exact(2)
        .map(|pair| -> Result<StructField> {
            let (name_token, type_token) = (&pair[0], &pair[1]);
            Ok(StructField {
                field_name: Some(Ident::new(name_token)),
                field_type: parse_type_string(type_token)?,
                options: None,
            })
        })
        .collect()
}
/// Parses one type keyword taken from a `ROW` type's modifiers into a
/// sqlparser [`DataType`].
///
/// Matching is case-insensitive. Leading and trailing whitespace is ignored,
/// and internal whitespace runs are collapsed to single spaces so that
/// multi-word names such as `DOUBLE PRECISION` match. `VARBINARY` is returned
/// as [`DataType::Varbinary`] rather than a custom type, so the
/// target-specific rewriters can translate it like a top-level `VARBINARY`.
/// Unrecognised names become an upper-cased [`DataType::Custom`] with no
/// modifiers.
///
/// Currently never returns `Err`.
fn parse_type_string(s: &str) -> Result<DataType> {
    let normalized = s
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_uppercase();
    let dt = match normalized.as_str() {
        "BOOLEAN" | "BOOL" => DataType::Boolean,
        "TINYINT" | "INT1" => DataType::TinyInt(None),
        "SMALLINT" | "INT2" => DataType::SmallInt(None),
        "INTEGER" | "INT" | "INT4" => DataType::Integer(None),
        "BIGINT" | "INT8" => DataType::BigInt(None),
        "REAL" | "FLOAT4" => DataType::Real,
        "DOUBLE" | "FLOAT8" | "DOUBLE PRECISION" => {
            DataType::Double(sqlparser::ast::ExactNumberInfo::None)
        }
        "VARCHAR" | "STRING" => DataType::Varchar(None),
        "TEXT" => DataType::Text,
        "DATE" => DataType::Date,
        "TIMESTAMP" => DataType::Timestamp(None, sqlparser::ast::TimezoneInfo::None),
        "JSON" => DataType::JSON,
        "BLOB" | "BYTEA" => DataType::Blob(None),
        // Trino's binary type. Keep it as Varbinary so each target's
        // Varbinary rewrite applies to ROW fields too.
        "VARBINARY" => DataType::Varbinary(None),
        "UUID" => DataType::Uuid,
        other => DataType::Custom(ObjectName::from(vec![Ident::new(other)]), vec![]),
    };
    Ok(dt)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::TargetDialect;

    /// Builds a bare (modifier-less) custom type such as `SUPER`.
    fn custom_type(name: &str) -> DataType {
        DataType::Custom(ObjectName::from(vec![Ident::new(name)]), vec![])
    }

    /// Builds Trino's `ROW(a INTEGER, b VARCHAR)` the way sqlparser represents
    /// it: a custom `ROW` type whose modifiers alternate field name and type.
    fn sample_trino_row() -> DataType {
        let modifiers = ["a", "INTEGER", "b", "VARCHAR"]
            .iter()
            .map(|token| token.to_string())
            .collect();
        DataType::Custom(ObjectName::from(vec![Ident::new("ROW")]), modifiers)
    }

    /// Runs `rewrite_data_type` on an owned type and returns the result,
    /// panicking if the rewrite reports an error.
    fn rewritten(mut dt: DataType, source: SourceDialect, target: TargetDialect) -> DataType {
        rewrite_data_type(&mut dt, source, target).unwrap();
        dt
    }

    /// Collects struct field names in declaration order; every field must be named.
    fn names_of(fields: &[StructField]) -> Vec<&str> {
        fields
            .iter()
            .map(|field| field.field_name.as_ref().unwrap().value.as_str())
            .collect()
    }

    #[test]
    fn redshift_varchar_max() {
        let dt = rewritten(
            DataType::Varchar(Some(CharacterLength::Max)),
            SourceDialect::Redshift,
            TargetDialect::DuckDB,
        );
        assert_eq!(dt, DataType::Varchar(None));
    }

    #[test]
    fn redshift_super_to_json() {
        let dt = rewritten(custom_type("SUPER"), SourceDialect::Redshift, TargetDialect::DuckDB);
        assert_eq!(dt, DataType::JSON);
    }

    #[test]
    fn trino_varbinary_to_blob() {
        let dt = rewritten(DataType::Varbinary(None), SourceDialect::Trino, TargetDialect::DuckDB);
        assert_eq!(dt, DataType::Blob(None));
    }

    #[test]
    fn trino_row_to_struct() {
        let dt = rewritten(sample_trino_row(), SourceDialect::Trino, TargetDialect::DuckDB);
        match &dt {
            DataType::Struct(fields, StructBracketKind::Parentheses) => {
                assert_eq!(names_of(fields), ["a", "b"]);
            }
            other => panic!("Expected Struct, got {other:?}"),
        }
    }

    #[test]
    fn trino_varbinary_to_bytea_datafusion() {
        let dt = rewritten(
            DataType::Varbinary(None),
            SourceDialect::Trino,
            TargetDialect::DataFusion,
        );
        assert_eq!(dt, DataType::Bytea);
    }

    #[test]
    fn trino_row_to_struct_datafusion() {
        let dt = rewritten(sample_trino_row(), SourceDialect::Trino, TargetDialect::DataFusion);
        match &dt {
            DataType::Struct(fields, StructBracketKind::AngleBrackets) => {
                assert_eq!(names_of(fields), ["a", "b"]);
            }
            other => panic!("Expected Struct<>, got {other:?}"),
        }
    }

    #[test]
    fn redshift_super_to_varchar_datafusion() {
        let dt = rewritten(
            custom_type("SUPER"),
            SourceDialect::Redshift,
            TargetDialect::DataFusion,
        );
        assert_eq!(dt, DataType::Varchar(None));
    }
}