use crate::postgres::def::{ColumnInfo, Type};
use sea_query::{Alias, ColumnDef, ColumnType, DynIden, IntoIden, PgInterval, RcOrArc, StringLen};
use std::{convert::TryFrom, fmt::Write};
impl ColumnInfo {
    /// Builds a [`ColumnDef`] describing this column, suitable for a `CREATE TABLE` statement.
    ///
    /// Integer columns are rewritten to the matching serial type and marked auto-increment in
    /// two cases:
    /// - the column is an identity column, or
    /// - its default expression starts with `nextval`. The `nextval(...)` default is then not
    ///   emitted, because the serial type implies it.
    ///
    /// Every other default expression (including a `nextval(...)` default on a column that
    /// cannot become serial) is emitted verbatim as a `DEFAULT ...` extra. A `NOT NULL`
    /// constraint is added when `not_null` is set.
    ///
    /// # Panics
    ///
    /// Panics if the column is an array whose element type is missing (see
    /// [`ColumnInfo::write_col_type`]).
    pub fn write(&self) -> ColumnDef {
        /// Whether `col_type` is one of the serial (sequence-backed integer) types.
        fn is_serial(col_type: &Type) -> bool {
            matches!(col_type, Type::SmallSerial | Type::Serial | Type::BigSerial)
        }

        let mut col_info = self.clone();
        let mut extras: Vec<String> = Vec::new();
        if let Some(default) = self.default.as_ref() {
            let is_sequence_default = default.0.starts_with("nextval");
            if is_sequence_default {
                col_info = Self::convert_to_serial(col_info);
            }
            // Only drop the `nextval(...)` expression when the column really became serial.
            // Otherwise the default must be kept (e.g. `nextval('s')::text` on a text column),
            // or it would be silently lost from the generated definition.
            if !(is_sequence_default && is_serial(&col_info.col_type)) {
                let mut string = String::new();
                write!(&mut string, "DEFAULT {}", default.0).unwrap();
                extras.push(string);
            }
        }
        if self.is_identity {
            col_info = Self::convert_to_serial(col_info);
        }

        let col_type = col_info.write_col_type();
        let mut col_def = ColumnDef::new_with_type(Alias::new(self.name.as_str()), col_type);
        if is_serial(&col_info.col_type) {
            col_def.auto_increment();
        }
        if self.not_null.is_some() {
            col_def.not_null();
        }
        if !extras.is_empty() {
            col_def.extra(extras.join(" "));
        }
        col_def
    }

    /// Replaces `smallint`/`integer`/`bigint` with `smallserial`/`serial`/`bigserial`.
    ///
    /// Any other column type is returned unchanged.
    fn convert_to_serial(mut col_info: ColumnInfo) -> ColumnInfo {
        match col_info.col_type {
            Type::SmallInt => {
                col_info.col_type = Type::SmallSerial;
            }
            Type::Integer => {
                col_info.col_type = Type::Serial;
            }
            Type::BigInt => {
                col_info.col_type = Type::BigSerial;
            }
            _ => {}
        }
        col_info
    }

    /// Maps this column's PostgreSQL [`Type`] to the corresponding sea-query [`ColumnType`].
    ///
    /// Serial types map to their underlying integer types; auto-increment is applied
    /// separately by [`ColumnInfo::write`]. Types without a dedicated sea-query variant are
    /// emitted as [`ColumnType::Custom`] carrying their PostgreSQL type name.
    ///
    /// # Panics
    ///
    /// Panics if an array type (at any nesting level) has no element type.
    pub fn write_col_type(&self) -> ColumnType {
        fn write_type(col_type: &Type) -> ColumnType {
            match col_type {
                Type::SmallInt => ColumnType::SmallInteger,
                Type::Integer => ColumnType::Integer,
                Type::BigInt => ColumnType::BigInteger,
                Type::Decimal(num_attr) | Type::Numeric(num_attr) => {
                    match (num_attr.precision, num_attr.scale) {
                        // Unconstrained numeric.
                        (None, None) => ColumnType::Decimal(None),
                        // `NUMERIC(p)` means a scale of 0, hence the zero fallback.
                        (precision, scale) => ColumnType::Decimal(Some((
                            precision.unwrap_or(0).into(),
                            scale.unwrap_or(0).into(),
                        ))),
                    }
                }
                Type::Real => ColumnType::Float,
                Type::DoublePrecision => ColumnType::Double,
                Type::SmallSerial => ColumnType::SmallInteger,
                Type::Serial => ColumnType::Integer,
                Type::BigSerial => ColumnType::BigInteger,
                Type::Money => ColumnType::Money(None),
                Type::Varchar(string_attr) => match string_attr.length {
                    Some(length) => ColumnType::String(StringLen::N(length.into())),
                    None => ColumnType::String(StringLen::None),
                },
                Type::Char(string_attr) => ColumnType::Char(string_attr.length.map(Into::into)),
                Type::Text => ColumnType::Text,
                Type::Bytea => ColumnType::VarBinary(StringLen::None),
                Type::Timestamp(_) => ColumnType::DateTime,
                Type::TimestampWithTimeZone(_) => ColumnType::TimestampWithTimeZone,
                Type::Date => ColumnType::Date,
                Type::Time(_) => ColumnType::Time,
                // NOTE(review): this maps to plain `Time`, dropping the time zone. Confirm
                // whether a `timetz`-preserving mapping is wanted by downstream consumers.
                Type::TimeWithTimeZone(_) => ColumnType::Time,
                Type::Interval(interval_attr) => {
                    // Field names that sea-query does not recognise are dropped.
                    let field = interval_attr
                        .field
                        .as_ref()
                        .and_then(|field| PgInterval::try_from(field).ok());
                    let precision = interval_attr.precision.map(Into::into);
                    ColumnType::Interval(field, precision)
                }
                Type::Boolean => ColumnType::Boolean,
                Type::Point => ColumnType::Custom("point".into_iden()),
                Type::Line => ColumnType::Custom("line".into_iden()),
                Type::Lseg => ColumnType::Custom("lseg".into_iden()),
                Type::Box => ColumnType::Custom("box".into_iden()),
                Type::Path => ColumnType::Custom("path".into_iden()),
                Type::Polygon => ColumnType::Custom("polygon".into_iden()),
                Type::Circle => ColumnType::Custom("circle".into_iden()),
                Type::Cidr => ColumnType::Custom("cidr".into_iden()),
                Type::Inet => ColumnType::Custom("inet".into_iden()),
                Type::MacAddr => ColumnType::Custom("macaddr".into_iden()),
                Type::MacAddr8 => ColumnType::Custom("macaddr8".into_iden()),
                Type::Bit(bit_attr) => ColumnType::Bit(bit_attr.length.map(Into::into)),
                Type::VarBit(bit_attr) => match bit_attr.length {
                    Some(length) => ColumnType::VarBit(length.into()),
                    // `bit varying` without a length is unlimited in PostgreSQL. Emitting
                    // `varbit(1)` here would reject any value longer than one bit.
                    None => ColumnType::Custom("varbit".into_iden()),
                },
                Type::TsVector => ColumnType::Custom("tsvector".into_iden()),
                Type::TsQuery => ColumnType::Custom("tsquery".into_iden()),
                Type::Uuid => ColumnType::Uuid,
                Type::Xml => ColumnType::Custom("xml".into_iden()),
                Type::Json => ColumnType::Json,
                Type::JsonBinary => ColumnType::JsonBinary,
                Type::Int4Range => ColumnType::Custom("int4range".into_iden()),
                Type::Int8Range => ColumnType::Custom("int8range".into_iden()),
                Type::NumRange => ColumnType::Custom("numrange".into_iden()),
                Type::TsRange => ColumnType::Custom("tsrange".into_iden()),
                Type::TsTzRange => ColumnType::Custom("tstzrange".into_iden()),
                Type::DateRange => ColumnType::Custom("daterange".into_iden()),
                Type::PgLsn => ColumnType::Custom("pg_lsn".into_iden()),
                #[cfg(feature = "postgres-vector")]
                Type::Vector(vector_attr) => ColumnType::Vector(vector_attr.length),
                Type::Unknown(s) => ColumnType::Custom(Alias::new(s).into_iden()),
                Type::Enum(enum_def) => {
                    let name = Alias::new(&enum_def.typename).into_iden();
                    let variants: Vec<DynIden> = enum_def
                        .values
                        .iter()
                        .map(|variant| Alias::new(variant).into_iden())
                        .collect();
                    ColumnType::Enum { name, variants }
                }
                // Recurse into the element type. A missing element type is treated as an
                // invariant violation of the parsed schema.
                Type::Array(array_def) => ColumnType::Array(RcOrArc::new(write_type(
                    array_def.col_type.as_ref().expect("Array type not defined"),
                ))),
            }
        }
        write_type(&self.col_type)
    }
}