use std::sync::Arc;
use crate::expr::Expr;
use crate::expr::write_expr;
use crate::types::Iden;
use crate::writer::SqlWriter;
/// Definition of a single table column, assembled with chained builder calls
/// such as `ColumnDef::new(name).int().not_null()`.
///
/// Every builder method consumes the definition and returns the updated one,
/// so a value that is not used is silently discarded.
#[derive(Debug, Clone)]
#[non_exhaustive]
#[must_use = "`ColumnDef` builder methods return the updated definition instead of modifying it in place"]
pub struct ColumnDef {
    /// Name of the column.
    pub(crate) name: Iden,
    /// Data type of the column; `None` until one of the type builders is called.
    pub(crate) ty: Option<ColumnType>,
    /// Nullability, default value, generation expression, and key/uniqueness
    /// flags for the column.
    pub(crate) spec: ColumnSpec,
}
impl ColumnDef {
pub fn new<N>(name: N) -> Self
where
N: Into<Iden>,
{
Self {
name: name.into(),
ty: None,
spec: ColumnSpec::default(),
}
}
pub fn default(mut self, expr: Expr) -> Self {
if self.spec.generated.is_some() {
panic!("A generated column cannot have a default value.");
}
self.spec.default = Some(expr);
self
}
pub fn not_null(mut self) -> Self {
self.spec.nullable = Some(false);
self
}
pub fn null(mut self) -> Self {
self.spec.nullable = Some(true);
self
}
pub fn char(mut self, size: u32) -> Self {
if size == 0 {
panic!("Character type size must be greater than zero.");
}
self.ty = Some(ColumnType::Char(size));
self
}
pub fn varchar(mut self, size: u32) -> Self {
if size == 0 {
panic!("Character type size must be greater than zero.");
}
self.ty = Some(ColumnType::Varchar(size));
self
}
pub fn text(mut self) -> Self {
self.ty = Some(ColumnType::Text);
self
}
pub fn bytea(mut self) -> Self {
self.ty = Some(ColumnType::Bytea);
self
}
pub fn smallint(mut self) -> Self {
self.ty = Some(ColumnType::SmallInt);
self
}
pub fn int(mut self) -> Self {
self.ty = Some(ColumnType::Int);
self
}
pub fn bigint(mut self) -> Self {
self.ty = Some(ColumnType::BigInt);
self
}
pub fn float(mut self) -> Self {
self.ty = Some(ColumnType::Float);
self
}
pub fn double(mut self) -> Self {
self.ty = Some(ColumnType::Double);
self
}
pub fn numeric(mut self, precision: i32, scale: i32) -> Self {
if scale > precision {
panic!("Numeric scale cannot be greater than precision.");
}
if precision <= 0 {
panic!("Numeric precision must be greater than zero.");
}
if precision > 1000 {
panic!("Numeric precision cannot be greater than 1000.");
}
self.ty = Some(ColumnType::Numeric(Some((precision, scale))));
self
}
pub fn numeric_unbounded(mut self) -> Self {
self.ty = Some(ColumnType::Numeric(None));
self
}
pub fn smallserial(mut self) -> Self {
self.ty = Some(ColumnType::SmallSerial);
self
}
pub fn serial(mut self) -> Self {
self.ty = Some(ColumnType::Serial);
self
}
pub fn bigserial(mut self) -> Self {
self.ty = Some(ColumnType::BigSerial);
self
}
pub fn int4_range(mut self) -> Self {
self.ty = Some(ColumnType::Int4Range);
self
}
pub fn int8_range(mut self) -> Self {
self.ty = Some(ColumnType::Int8Range);
self
}
pub fn num_range(mut self) -> Self {
self.ty = Some(ColumnType::NumRange);
self
}
pub fn ts_range(mut self) -> Self {
self.ty = Some(ColumnType::TsRange);
self
}
pub fn ts_tz_range(mut self) -> Self {
self.ty = Some(ColumnType::TsTzRange);
self
}
pub fn date_range(mut self) -> Self {
self.ty = Some(ColumnType::DateRange);
self
}
pub fn date_time(mut self) -> Self {
self.ty = Some(ColumnType::DateTime);
self
}
pub fn timestamp(mut self) -> Self {
self.ty = Some(ColumnType::Timestamp);
self
}
pub fn timestamp_with_time_zone(mut self) -> Self {
self.ty = Some(ColumnType::TimestampWithTimeZone);
self
}
pub fn time(mut self) -> Self {
self.ty = Some(ColumnType::Time);
self
}
pub fn date(mut self) -> Self {
self.ty = Some(ColumnType::Date);
self
}
pub fn boolean(mut self) -> Self {
self.ty = Some(ColumnType::Boolean);
self
}
pub fn json(mut self) -> Self {
self.ty = Some(ColumnType::Json);
self
}
pub fn json_binary(mut self) -> Self {
self.ty = Some(ColumnType::JsonBinary);
self
}
pub fn uuid(mut self) -> Self {
self.ty = Some(ColumnType::Uuid);
self
}
pub fn array_of(mut self, ty: ColumnType) -> Self {
self.ty = Some(ColumnType::Array(Arc::new(ty)));
self
}
pub fn generated_as_stored<E>(mut self, expr: E) -> Self
where
E: Into<Expr>,
{
if self.spec.default.is_some() {
panic!("A generated column cannot have a default value.");
}
self.spec.generated = Some(GeneratedColumn {
expr: expr.into(),
kind: GeneratedColumnKind::Stored,
});
self
}
pub fn generated_as_virtual<E>(mut self, expr: E) -> Self
where
E: Into<Expr>,
{
if self.spec.default.is_some() {
panic!("A generated column cannot have a default value.");
}
self.spec.generated = Some(GeneratedColumn {
expr: expr.into(),
kind: GeneratedColumnKind::Virtual,
});
self
}
}
/// A column data type.
///
/// The doc on each variant gives the SQL type name it is written as
/// (PostgreSQL spellings such as `bytea`, `jsonb` and `tstzrange`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ColumnType {
    /// `char(n)` with the given length `n`.
    Char(u32),
    /// `varchar(n)` with the given length `n`.
    Varchar(u32),
    /// `text`.
    Text,
    /// `bytea`.
    Bytea,
    /// `smallint`.
    SmallInt,
    /// `integer`.
    Int,
    /// `bigint`.
    BigInt,
    /// `real`.
    Float,
    /// `double precision`.
    Double,
    /// `numeric(precision, scale)` when `Some((precision, scale))`, or bare
    /// `numeric` when `None`.
    Numeric(Option<(i32, i32)>),
    /// `int4range`.
    Int4Range,
    /// `int8range`.
    Int8Range,
    /// `numrange`.
    NumRange,
    /// `tsrange`.
    TsRange,
    /// `tstzrange`.
    TsTzRange,
    /// `daterange`.
    DateRange,
    /// `smallserial`.
    SmallSerial,
    /// `serial`.
    Serial,
    /// `bigserial`.
    BigSerial,
    /// `timestamp without time zone`.
    DateTime,
    /// `timestamp`.
    Timestamp,
    /// `timestamp with time zone`.
    TimestampWithTimeZone,
    /// `time`.
    Time,
    /// `date`.
    Date,
    /// `bool`.
    Boolean,
    /// `json`.
    Json,
    /// `jsonb`.
    JsonBinary,
    /// `uuid`.
    Uuid,
    /// Array of the contained element type, written as the element type
    /// followed by `[]` (nested arrays repeat the suffix).
    Array(Arc<ColumnType>),
}
/// Per-column settings written after the column type.
///
/// The [`Default`] value writes nothing: no nullability clause, no default,
/// no generation expression, and neither `UNIQUE` nor `PRIMARY KEY`.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct ColumnSpec {
    /// `Some(true)` writes ` NULL`, `Some(false)` writes ` NOT NULL`, and `None`
    /// writes no nullability clause.
    pub nullable: Option<bool>,
    /// Expression written after ` DEFAULT `; anything other than a plain value
    /// or keyword is wrapped in parentheses. The [`ColumnDef`] builders refuse
    /// to set this together with [`generated`](Self::generated).
    pub default: Option<Expr>,
    /// Generation clause, written as ` GENERATED ALWAYS AS (expr)` followed by
    /// ` STORED` or ` VIRTUAL`.
    pub generated: Option<GeneratedColumn>,
    /// When `true`, ` UNIQUE` is written.
    pub unique: bool,
    /// When `true`, ` PRIMARY KEY` is written.
    pub primary_key: bool,
}
/// A generated-column clause: `GENERATED ALWAYS AS (expr)` followed by the
/// keyword selected by [`kind`](Self::kind).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct GeneratedColumn {
    /// Expression written inside `GENERATED ALWAYS AS (...)`.
    pub expr: Expr,
    /// Selects whether the clause ends in `STORED` or `VIRTUAL`.
    pub kind: GeneratedColumnKind,
}
/// The keyword that ends a `GENERATED ALWAYS AS (...)` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum GeneratedColumnKind {
    /// Written as `STORED`.
    Stored,
    /// Written as `VIRTUAL`.
    ///
    /// NOTE(review): `VIRTUAL` generated columns are only accepted by newer
    /// PostgreSQL servers (18+, as far as known); confirm the target version.
    Virtual,
}
/// Appends the SQL spelling of `column_type` to `w`.
///
/// Length- and precision-qualified types write their arguments in parentheses
/// (`char(n)`, `varchar(n)`, `numeric(p, s)`); an array type is written as its
/// element type followed by `[]`, recursing for nested arrays.
pub(crate) fn write_column_type<W: SqlWriter>(w: &mut W, column_type: &ColumnType) {
    // Argument-free types resolve to a fixed keyword that is written once at
    // the end; types with arguments (and arrays) write themselves and return.
    let keyword = match column_type {
        ColumnType::Char(len) => {
            w.push_str("char(");
            w.push_str(&len.to_string());
            w.push_str(")");
            return;
        }
        ColumnType::Varchar(len) => {
            w.push_str("varchar(");
            w.push_str(&len.to_string());
            w.push_str(")");
            return;
        }
        ColumnType::Numeric(Some((precision, scale))) => {
            w.push_str("numeric(");
            w.push_str(&precision.to_string());
            w.push_str(", ");
            w.push_str(&scale.to_string());
            w.push_str(")");
            return;
        }
        ColumnType::Array(element) => {
            write_column_type(w, element);
            w.push_str("[]");
            return;
        }
        ColumnType::Numeric(None) => "numeric",
        ColumnType::Text => "text",
        ColumnType::Bytea => "bytea",
        ColumnType::SmallInt => "smallint",
        ColumnType::Int => "integer",
        ColumnType::BigInt => "bigint",
        ColumnType::Float => "real",
        ColumnType::Double => "double precision",
        ColumnType::SmallSerial => "smallserial",
        ColumnType::Serial => "serial",
        ColumnType::BigSerial => "bigserial",
        ColumnType::Int4Range => "int4range",
        ColumnType::Int8Range => "int8range",
        ColumnType::NumRange => "numrange",
        ColumnType::TsRange => "tsrange",
        ColumnType::TsTzRange => "tstzrange",
        ColumnType::DateRange => "daterange",
        ColumnType::DateTime => "timestamp without time zone",
        ColumnType::Timestamp => "timestamp",
        ColumnType::TimestampWithTimeZone => "timestamp with time zone",
        ColumnType::Time => "time",
        ColumnType::Date => "date",
        ColumnType::Boolean => "bool",
        ColumnType::Json => "json",
        ColumnType::JsonBinary => "jsonb",
        ColumnType::Uuid => "uuid",
    };
    w.push_str(keyword);
}
/// Appends the clauses described by `column_spec` to `w`, each preceded by a
/// single space.
///
/// Clauses appear in a fixed order: nullability, `DEFAULT`,
/// `GENERATED ALWAYS AS (...)`, `PRIMARY KEY`, then `UNIQUE`. Unset options
/// produce no output.
pub(crate) fn write_column_spec<W: SqlWriter>(w: &mut W, column_spec: &ColumnSpec) {
    // Destructure exhaustively so that adding a field to `ColumnSpec` forces
    // this writer to be revisited.
    let ColumnSpec {
        nullable: nullability,
        default: default_expr,
        generated: generation,
        unique: is_unique,
        primary_key: is_primary_key,
    } = column_spec;

    match nullability {
        Some(true) => w.push_str(" NULL"),
        Some(false) => w.push_str(" NOT NULL"),
        None => {}
    }

    if let Some(expr) = default_expr {
        w.push_str(" DEFAULT ");
        // Plain values and keywords are written bare; any other expression is
        // parenthesised.
        let bare = matches!(expr, Expr::Value(_) | Expr::Keyword(_));
        if !bare {
            w.push_str("(");
        }
        write_expr(w, expr);
        if !bare {
            w.push_str(")");
        }
    }

    if let Some(clause) = generation {
        let storage = match clause.kind {
            GeneratedColumnKind::Stored => " STORED",
            GeneratedColumnKind::Virtual => " VIRTUAL",
        };
        w.push_str(" GENERATED ALWAYS AS (");
        write_expr(w, &clause.expr);
        w.push_str(")");
        w.push_str(storage);
    }

    if *is_primary_key {
        w.push_str(" PRIMARY KEY");
    }
    if *is_unique {
        w.push_str(" UNIQUE");
    }
}