use sea_query::RcOrArc;
#[cfg(feature = "with-serde")]
use serde::{Deserialize, Serialize};
/// A PostgreSQL column data type.
///
/// Variants that wrap an attribute struct carry that type's optional modifiers
/// (length, precision, ...). [`Type::from_str`] always produces those attributes in
/// their `Default` state.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub enum Type {
    /// `smallint` (`int2`).
    SmallInt,
    /// `integer` (`int`, `int4`).
    Integer,
    /// `bigint` (`int8`).
    BigInt,
    /// `decimal`, with optional precision and scale.
    Decimal(ArbitraryPrecisionNumericAttr),
    /// `numeric`, with optional precision and scale.
    Numeric(ArbitraryPrecisionNumericAttr),
    /// `real` (`float4`).
    Real,
    /// `double precision` (`float8`).
    DoublePrecision,
    /// `smallserial` (`serial2`).
    SmallSerial,
    /// `serial` (`serial4`).
    Serial,
    /// `bigserial` (`serial8`).
    BigSerial,
    /// `money`.
    Money,
    /// `character varying` (`varchar`), with optional length.
    Varchar(StringAttr),
    /// `character` (`char`), with optional length.
    Char(StringAttr),
    /// `text`.
    Text,
    /// `bytea`.
    Bytea,
    /// `timestamp [without time zone]`, with optional precision.
    Timestamp(TimeAttr),
    /// `timestamp with time zone`, with optional precision.
    TimestampWithTimeZone(TimeAttr),
    /// `date`.
    Date,
    /// `time [without time zone]`, with optional precision.
    Time(TimeAttr),
    /// `time with time zone`, with optional precision.
    TimeWithTimeZone(TimeAttr),
    /// `interval`, with optional `field` and precision attributes.
    Interval(IntervalAttr),
    /// `boolean` (`bool`).
    Boolean,
    /// `point`.
    Point,
    /// `line`.
    Line,
    /// `lseg`.
    Lseg,
    /// `box`.
    Box,
    /// `path`.
    Path,
    /// `polygon`.
    Polygon,
    /// `circle`.
    Circle,
    /// `cidr`.
    Cidr,
    /// `inet`.
    Inet,
    /// `macaddr`.
    MacAddr,
    /// `macaddr8`.
    MacAddr8,
    /// `bit`, with optional length.
    Bit(BitAttr),
    /// `bit varying` (`varbit`), with optional length.
    VarBit(BitAttr),
    /// `tsvector`.
    TsVector,
    /// `tsquery`.
    TsQuery,
    /// `uuid`.
    Uuid,
    /// `xml`.
    Xml,
    /// `json`.
    Json,
    /// `jsonb`.
    JsonBinary,
    /// An array column; [`ArrayDef`] optionally records the element type.
    Array(ArrayDef),
    /// The user-defined `vector` type (presumably from the pgvector extension);
    /// only available with the `postgres-vector` feature.
    #[cfg(feature = "postgres-vector")]
    Vector(VectorDef),
    /// `int4range`.
    Int4Range,
    /// `int8range`.
    Int8Range,
    /// `numrange`.
    NumRange,
    /// `tsrange`.
    TsRange,
    /// `tstzrange`.
    TsTzRange,
    /// `daterange`.
    DateRange,
    /// `pg_lsn`.
    PgLsn,
    /// A type not modelled by this enum. Holds one of three things: the unrecognized
    /// column type name as given, the UDT name of an unrecognized user-defined type, or
    /// `"user-defined"` when no UDT name was available.
    Unknown(String),
    /// A user-defined enum type.
    Enum(EnumDef),
}
impl Type {
    /// Maps a PostgreSQL column type name to a [`Type`].
    ///
    /// `column_type` is matched case-insensitively. It accepts the standard PostgreSQL type
    /// names and their documented aliases (e.g. `int4`, `float8`, `varchar`, `timestamptz`,
    /// `timetz`). Variants that carry attributes are returned with default (empty) attributes.
    ///
    /// A `user-defined` column is resolved using the remaining arguments:
    /// - `is_enum == true` yields [`Type::Enum`] with an empty [`EnumDef`];
    /// - with the `postgres-vector` feature, `udt_name == Some("vector")` yields
    ///   `Type::Vector` with a default `VectorDef`;
    /// - any other `Some(name)` yields [`Type::Unknown`] holding that name;
    /// - `None` yields `Type::Unknown("user-defined")`.
    ///
    /// `udt_name` is compared case-sensitively. Any other unrecognized `column_type`
    /// yields [`Type::Unknown`] holding the original, non-lowercased string.
    // Not `std::str::FromStr`: parsing needs extra context (`udt_name`, `is_enum`) and
    // never fails, so the trait's signature does not fit.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(column_type: &str, udt_name: Option<&str>, is_enum: bool) -> Type {
        match column_type.to_lowercase().as_str() {
            "smallint" | "int2" => Type::SmallInt,
            "integer" | "int" | "int4" => Type::Integer,
            "bigint" | "int8" => Type::BigInt,
            "decimal" => Type::Decimal(ArbitraryPrecisionNumericAttr::default()),
            "numeric" => Type::Numeric(ArbitraryPrecisionNumericAttr::default()),
            "real" | "float4" => Type::Real,
            "double precision" | "double" | "float8" => Type::DoublePrecision,
            "smallserial" | "serial2" => Type::SmallSerial,
            "serial" | "serial4" => Type::Serial,
            "bigserial" | "serial8" => Type::BigSerial,
            "money" => Type::Money,
            "character varying" | "varchar" => Type::Varchar(StringAttr::default()),
            "character" | "char" => Type::Char(StringAttr::default()),
            "text" => Type::Text,
            "bytea" => Type::Bytea,
            "timestamp" | "timestamp without time zone" => Type::Timestamp(TimeAttr::default()),
            // `timestamptz` is PostgreSQL's documented alias for `timestamp with time zone`.
            "timestamp with time zone" | "timestamptz" => {
                Type::TimestampWithTimeZone(TimeAttr::default())
            }
            "date" => Type::Date,
            "time" | "time without time zone" => Type::Time(TimeAttr::default()),
            // `timetz` is PostgreSQL's documented alias for `time with time zone`.
            "time with time zone" | "timetz" => Type::TimeWithTimeZone(TimeAttr::default()),
            "interval" => Type::Interval(IntervalAttr::default()),
            "boolean" | "bool" => Type::Boolean,
            "point" => Type::Point,
            "line" => Type::Line,
            "lseg" => Type::Lseg,
            "box" => Type::Box,
            "path" => Type::Path,
            "polygon" => Type::Polygon,
            "circle" => Type::Circle,
            "cidr" => Type::Cidr,
            "inet" => Type::Inet,
            "macaddr" => Type::MacAddr,
            "macaddr8" => Type::MacAddr8,
            "bit" => Type::Bit(BitAttr::default()),
            "bit varying" | "varbit" => Type::VarBit(BitAttr::default()),
            "tsvector" => Type::TsVector,
            "tsquery" => Type::TsQuery,
            "uuid" => Type::Uuid,
            "xml" => Type::Xml,
            "json" => Type::Json,
            "jsonb" => Type::JsonBinary,
            "int4range" => Type::Int4Range,
            "int8range" => Type::Int8Range,
            "numrange" => Type::NumRange,
            "tsrange" => Type::TsRange,
            "tstzrange" => Type::TsTzRange,
            "daterange" => Type::DateRange,
            "pg_lsn" => Type::PgLsn,
            // The concrete type of a user-defined column is decided by the caller-supplied
            // enum flag and UDT name rather than by `column_type` itself.
            "user-defined" => match (is_enum, udt_name) {
                (true, _) => Type::Enum(EnumDef::default()),
                #[cfg(feature = "postgres-vector")]
                (false, Some("vector")) => Type::Vector(VectorDef::default()),
                (false, Some(other_name)) => Type::Unknown(other_name.to_owned()),
                _ => Type::Unknown("user-defined".to_owned()),
            },
            "array" => Type::Array(ArrayDef::default()),
            _ => Type::Unknown(column_type.to_owned()),
        }
    }
}
/// Optional precision and scale modifiers of a [`Type::Decimal`] or [`Type::Numeric`]
/// column. `None` means no value was recorded.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct ArbitraryPrecisionNumericAttr {
    /// Precision modifier, if any.
    pub precision: Option<u16>,
    /// Scale modifier, if any.
    pub scale: Option<u16>,
}
/// Optional length modifier of a [`Type::Varchar`] or [`Type::Char`] column.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct StringAttr {
    /// Declared length, if any.
    pub length: Option<u16>,
}
/// Optional precision modifier of a [`Type::Time`], [`Type::TimeWithTimeZone`],
/// [`Type::Timestamp`] or [`Type::TimestampWithTimeZone`] column.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct TimeAttr {
    /// Declared precision, if any.
    pub precision: Option<u16>,
}
/// Optional modifiers of a [`Type::Interval`] column.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct IntervalAttr {
    /// Interval field restriction, if any (presumably text such as `YEAR TO MONTH` —
    /// confirm against the code that populates it).
    pub field: Option<String>,
    /// Declared precision, if any.
    pub precision: Option<u16>,
}
/// Optional length modifier of a [`Type::Bit`] or [`Type::VarBit`] column.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct BitAttr {
    /// Declared length, if any.
    pub length: Option<u16>,
}
/// Definition of a user-defined enum type carried by [`Type::Enum`].
///
/// [`Type::from_str`] yields an empty one (no values, empty type name).
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct EnumDef {
    /// The enum's labels.
    pub values: Vec<String>,
    /// Name of the enum type.
    pub typename: String,
}
/// Definition carried by [`Type::Array`].
///
/// [`Type::from_str`] yields one with no element type (`col_type == None`).
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct ArrayDef {
    /// Element type of the array, when known.
    pub col_type: Option<RcOrArc<Type>>,
}
/// Definition carried by `Type::Vector`; only compiled with the `postgres-vector` feature.
#[cfg(feature = "postgres-vector")]
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct VectorDef {
    /// Declared length, if any.
    pub length: Option<u32>,
}
impl Type {
    /// Returns `true` for the variants that carry an [`ArbitraryPrecisionNumericAttr`]:
    /// [`Type::Decimal`] and [`Type::Numeric`].
    pub fn has_numeric_attr(&self) -> bool {
        matches!(self, Type::Decimal(_) | Type::Numeric(_))
    }

    /// Returns `true` for the variants that carry a [`StringAttr`]:
    /// [`Type::Char`] and [`Type::Varchar`].
    pub fn has_string_attr(&self) -> bool {
        matches!(self, Type::Char(_) | Type::Varchar(_))
    }

    /// Returns `true` for the variants that carry a [`TimeAttr`]:
    /// [`Type::Time`], [`Type::TimeWithTimeZone`], [`Type::Timestamp`] and
    /// [`Type::TimestampWithTimeZone`].
    pub fn has_time_attr(&self) -> bool {
        matches!(
            self,
            Type::Time(_)
                | Type::TimeWithTimeZone(_)
                | Type::Timestamp(_)
                | Type::TimestampWithTimeZone(_)
        )
    }

    /// Returns `true` only for [`Type::Interval`], which carries an [`IntervalAttr`].
    pub fn has_interval_attr(&self) -> bool {
        matches!(self, Type::Interval(_))
    }

    /// Returns `true` for the variants that carry a [`BitAttr`]:
    /// [`Type::Bit`] and [`Type::VarBit`].
    pub fn has_bit_attr(&self) -> bool {
        matches!(self, Type::Bit(_) | Type::VarBit(_))
    }

    /// Returns `true` only for [`Type::Enum`], which carries an [`EnumDef`].
    pub fn has_enum_attr(&self) -> bool {
        matches!(self, Type::Enum(_))
    }

    /// Returns `true` only for [`Type::Array`], which carries an [`ArrayDef`].
    pub fn has_array_attr(&self) -> bool {
        matches!(self, Type::Array(_))
    }

    /// Returns `true` only for `Type::Vector` (available with the `postgres-vector` feature).
    #[cfg(feature = "postgres-vector")]
    pub fn has_vector_attr(&self) -> bool {
        matches!(self, Type::Vector(_))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a `user-defined` column with the given UDT name and enum flag.
    fn user_defined(udt_name: Option<&str>, is_enum: bool) -> Type {
        Type::from_str("user-defined", udt_name, is_enum)
    }

    /// `user-defined` columns resolve to an enum, a (feature-gated) vector, or an
    /// `Unknown` carrying the UDT name.
    #[test]
    fn parses_user_defined_enum_and_vector() {
        let enum_column = user_defined(None, true);
        assert_eq!(enum_column, Type::Enum(EnumDef::default()));

        let vector_column = user_defined(Some("vector"), false);
        #[cfg(feature = "postgres-vector")]
        assert_eq!(vector_column, Type::Vector(VectorDef::default()));
        #[cfg(not(feature = "postgres-vector"))]
        assert_eq!(vector_column, Type::Unknown("vector".into()));

        let other_column = user_defined(Some("foo_bar"), false);
        assert_eq!(other_column, Type::Unknown("foo_bar".into()));
    }
}