use std::fmt::Write as _;
use crate::core::{FieldSchema, FieldType, ModelSchema, Relation};
use crate::sql::{Dialect, Postgres};
/// Renders a `CREATE TABLE` statement for `model` using the Postgres dialect.
///
/// Convenience shorthand for [`create_table_sql_with_dialect`] with [`Postgres`].
#[must_use]
pub fn create_table_sql(model: &ModelSchema) -> String {
    let dialect = Postgres;
    create_table_sql_with_dialect(&dialect, model)
}
/// Renders a `CREATE TABLE IF NOT EXISTS` statement for `model` using the Postgres dialect.
///
/// Convenience shorthand for [`create_table_if_not_exists_sql_with_dialect`] with [`Postgres`].
#[must_use]
pub fn create_table_if_not_exists_sql(model: &ModelSchema) -> String {
    let dialect = Postgres;
    create_table_if_not_exists_sql_with_dialect(&dialect, model)
}
/// Renders a `DROP TABLE` statement for `model` using the Postgres dialect.
///
/// Convenience shorthand for [`drop_table_sql_with_dialect`] with [`Postgres`]; `if_exists`
/// and `cascade` are forwarded unchanged.
#[must_use]
pub fn drop_table_sql(model: &ModelSchema, if_exists: bool, cascade: bool) -> String {
    let dialect = Postgres;
    drop_table_sql_with_dialect(&dialect, model, if_exists, cascade)
}
/// Renders the foreign-key `ALTER TABLE` statements for `model` using the Postgres dialect.
///
/// Convenience shorthand for [`create_constraints_sql_with_dialect`] with [`Postgres`].
#[must_use]
pub fn create_constraints_sql(model: &ModelSchema) -> Vec<String> {
    let dialect = Postgres;
    create_constraints_sql_with_dialect(&dialect, model)
}
/// Renders a `CREATE TABLE` statement for `model` in the given SQL dialect.
///
/// Each scalar field of the model becomes one column definition (rendered by
/// `write_column_def`). Definitions are separated by `", "` and wrapped in parentheses.
/// Foreign keys are not part of this statement; they come from
/// [`create_constraints_sql_with_dialect`]. No trailing semicolon is appended.
#[must_use]
pub fn create_table_sql_with_dialect(dialect: &dyn Dialect, model: &ModelSchema) -> String {
    let mut sql = format!("CREATE TABLE {} (", dialect.quote_ident(model.table));
    // Empty before the first column, ", " before every subsequent one.
    let mut separator = "";
    for field in model.scalar_fields() {
        sql.push_str(separator);
        write_column_def(&mut sql, dialect, field);
        separator = ", ";
    }
    sql.push(')');
    sql
}
/// Same as [`create_table_sql_with_dialect`], but the statement begins with
/// `CREATE TABLE IF NOT EXISTS`.
#[must_use]
pub fn create_table_if_not_exists_sql_with_dialect(
    dialect: &dyn Dialect,
    model: &ModelSchema,
) -> String {
    const PREFIX: &str = "CREATE TABLE ";
    let mut sql = create_table_sql_with_dialect(dialect, model);
    // The base statement always opens with PREFIX, so splicing the guard directly after it
    // turns `CREATE TABLE <name> (...)` into `CREATE TABLE IF NOT EXISTS <name> (...)`.
    debug_assert!(sql.starts_with(PREFIX));
    sql.insert_str(PREFIX.len(), "IF NOT EXISTS ");
    sql
}
/// Renders a `DROP TABLE` statement for `model` in the given SQL dialect.
///
/// * `if_exists` — insert `IF EXISTS` before the quoted table name.
/// * `cascade` — append `CASCADE` after the table name.
#[must_use]
pub fn drop_table_sql_with_dialect(
    dialect: &dyn Dialect,
    model: &ModelSchema,
    if_exists: bool,
    cascade: bool,
) -> String {
    let table = dialect.quote_ident(model.table);
    let existence_guard = if if_exists { "IF EXISTS " } else { "" };
    let cascade_clause = if cascade { " CASCADE" } else { "" };
    format!("DROP TABLE {existence_guard}{table}{cascade_clause}")
}
/// Renders one `ALTER TABLE … ADD CONSTRAINT … FOREIGN KEY …` statement per relation on
/// `model`, in the given SQL dialect.
///
/// Single-column relations come from scalar fields whose `relation` is set. Both
/// `Relation::Fk` and `Relation::O2O` produce the same foreign-key DDL here. Multi-column
/// relations come from `model.composite_relations`.
///
/// Constraint names follow `<table>_<suffix>_fkey`. The suffix is the column name for
/// scalar relations and the relation name for composite ones.
///
/// Statements are returned with scalar relations first (in field order), then composite
/// relations.
#[must_use]
pub fn create_constraints_sql_with_dialect(
    dialect: &dyn Dialect,
    model: &ModelSchema,
) -> Vec<String> {
    let mut out = Vec::new();
    for field in model.scalar_fields() {
        let Some(rel) = field.relation else { continue };
        let (to, on) = match rel {
            Relation::Fk { to, on } | Relation::O2O { to, on } => (to, on),
        };
        out.push(foreign_key_sql(
            dialect,
            model.table,
            &field.column,
            std::iter::once(field.column),
            to,
            std::iter::once(on),
        ));
    }
    for rel in model.composite_relations {
        out.push(foreign_key_sql(
            dialect,
            model.table,
            &rel.name,
            rel.from.iter(),
            rel.to,
            rel.on.iter(),
        ));
    }
    out
}

/// Builds a single `ALTER TABLE <table> ADD CONSTRAINT <table>_<name_suffix>_fkey
/// FOREIGN KEY (<from…>) REFERENCES <to> (<on…>)` statement, quoting every identifier
/// through `dialect`.
fn foreign_key_sql<F, O>(
    dialect: &dyn Dialect,
    table: &str,
    name_suffix: &dyn std::fmt::Display,
    from: F,
    to: &str,
    on: O,
) -> String
where
    F: IntoIterator,
    F::Item: AsRef<str>,
    O: IntoIterator,
    O::Item: AsRef<str>,
{
    let mut sql = String::from("ALTER TABLE ");
    sql.push_str(&dialect.quote_ident(table));
    sql.push_str(" ADD CONSTRAINT ");
    sql.push_str(&dialect.quote_ident(&format!("{table}_{name_suffix}_fkey")));
    sql.push_str(" FOREIGN KEY (");
    push_ident_list(&mut sql, dialect, from);
    sql.push_str(") REFERENCES ");
    sql.push_str(&dialect.quote_ident(to));
    sql.push_str(" (");
    push_ident_list(&mut sql, dialect, on);
    sql.push(')');
    sql
}

/// Appends `idents` to `s` as a `", "`-separated list of dialect-quoted identifiers.
fn push_ident_list<I>(s: &mut String, dialect: &dyn Dialect, idents: I)
where
    I: IntoIterator,
    I::Item: AsRef<str>,
{
    for (i, ident) in idents.into_iter().enumerate() {
        if i > 0 {
            s.push_str(", ");
        }
        s.push_str(&dialect.quote_ident(ident.as_ref()));
    }
}
/// Appends one column definition for `field` to `s`.
///
/// Clause order:
/// 1. the quoted column name and its SQL type (from `sql_type`);
/// 2. an optional `DEFAULT <expr>`, with the expression inserted verbatim;
/// 3. `NOT NULL` for non-nullable fields;
/// 4. `PRIMARY KEY`;
/// 5. `UNIQUE`, skipped when the field is already the primary key;
/// 6. an optional `CHECK (...)` from `write_check_constraint`.
///
/// NOTE(review): an auto-increment integer field that also has a `default` gets both a
/// serial type and an explicit `DEFAULT`. Postgres is expected to reject that combination
/// ("multiple default values specified") — confirm callers never set both.
fn write_column_def(s: &mut String, dialect: &dyn Dialect, field: &FieldSchema) {
    let _ = write!(s, "{} {}", dialect.quote_ident(field.column), sql_type(dialect, field));
    if let Some(default_expr) = field.default {
        let _ = write!(s, " DEFAULT {default_expr}");
    }
    // (condition, clause) pairs, emitted in this exact order.
    let modifiers = [
        (!field.nullable, " NOT NULL"),
        (field.primary_key, " PRIMARY KEY"),
        (field.unique && !field.primary_key, " UNIQUE"),
    ];
    for &(enabled, clause) in &modifiers {
        if enabled {
            s.push_str(clause);
        }
    }
    write_check_constraint(s, dialect, field);
}
/// Appends ` CHECK (<col> >= <min> AND <col> <= <max>)` to `s` for whichever of
/// `field.min` / `field.max` are set. Appends nothing when neither bound is present.
fn write_check_constraint(s: &mut String, dialect: &dyn Dialect, field: &FieldSchema) {
    if let (None, None) = (field.min, field.max) {
        return;
    }
    let column = dialect.quote_ident(field.column);
    let mut conditions: Vec<String> = Vec::with_capacity(2);
    if let Some(lower) = field.min {
        conditions.push(format!("{column} >= {lower}"));
    }
    if let Some(upper) = field.max {
        conditions.push(format!("{column} <= {upper}"));
    }
    let _ = write!(s, " CHECK ({})", conditions.join(" AND "));
}
/// Chooses the SQL column type for `field`.
///
/// Auto-generated integer fields (`I16`, `I32`, `I64`) use the dialect's serial type. Every
/// other field, including auto `DateTime`/`Uuid` fields, uses the regular column type,
/// which also takes the optional `max_length`.
fn sql_type(dialect: &dyn Dialect, field: &FieldSchema) -> String {
    let integer_ty = matches!(field.ty, FieldType::I16 | FieldType::I32 | FieldType::I64);
    if field.auto && integer_ty {
        dialect.serial_type(field.ty).to_owned()
    } else {
        dialect.column_type(field.ty, field.max_length)
    }
}
#[cfg(test)]
mod tests {
    //! Regression tests for column-type selection (`sql_type`) and single-column DDL
    //! rendering (`write_column_def`) against the Postgres dialect.

    use super::*;
    use crate::core::FieldType;

    /// Dialect under test.
    fn pg() -> Postgres {
        Postgres
    }

    /// Builds a NOT NULL, non-key, non-unique field with no length or range limits, whose
    /// column name equals `name`.
    fn fld(
        name: &'static str,
        ty: FieldType,
        auto: bool,
        default: Option<&'static str>,
    ) -> FieldSchema {
        FieldSchema {
            name,
            column: name,
            ty,
            nullable: false,
            primary_key: false,
            relation: None,
            max_length: None,
            min: None,
            max: None,
            default,
            auto,
            unique: false,
        }
    }

    /// SQL type chosen for `field` under Postgres.
    fn type_of(field: &FieldSchema) -> String {
        sql_type(&pg(), field)
    }

    /// Renders the column definition of `field` into a fresh string.
    fn render_column(field: &FieldSchema) -> String {
        let mut col_def = String::new();
        write_column_def(&mut col_def, &pg(), field);
        col_def
    }

    #[test]
    fn auto_i32_emits_serial() {
        assert_eq!(type_of(&fld("id", FieldType::I32, true, None)), "SERIAL");
    }

    #[test]
    fn auto_i64_emits_bigserial() {
        assert_eq!(type_of(&fld("id", FieldType::I64, true, None)), "BIGSERIAL");
    }

    #[test]
    fn auto_datetime_emits_timestamptz_not_bigserial() {
        let field = fld("created_at", FieldType::DateTime, true, Some("now()"));
        assert_eq!(type_of(&field), "TIMESTAMPTZ");
    }

    #[test]
    fn auto_uuid_emits_uuid_not_bigserial() {
        let field = fld("id", FieldType::Uuid, true, Some("gen_random_uuid()"));
        assert_eq!(type_of(&field), "UUID");
    }

    #[test]
    fn full_create_table_has_single_default_per_column() {
        let col_def =
            render_column(&fld("created_at", FieldType::DateTime, true, Some("now()")));
        let n_defaults = col_def.matches(" DEFAULT ").count();
        assert_eq!(
            n_defaults, 1,
            "expected exactly one DEFAULT clause, got {n_defaults} in: {col_def}"
        );
        assert!(col_def.contains("TIMESTAMPTZ"), "got: {col_def}");
        assert!(col_def.contains("DEFAULT now()"), "got: {col_def}");
        assert!(!col_def.contains("BIGSERIAL"), "must not emit BIGSERIAL: {col_def}");
    }

    #[test]
    fn full_create_table_uuid_auto_has_single_default() {
        let col_def =
            render_column(&fld("id", FieldType::Uuid, true, Some("gen_random_uuid()")));
        let n_defaults = col_def.matches(" DEFAULT ").count();
        assert_eq!(n_defaults, 1, "got: {col_def}");
        assert!(col_def.contains("UUID"));
        assert!(col_def.contains("DEFAULT gen_random_uuid()"));
    }

    #[test]
    fn auto_i64_default_clause_passthrough() {
        let col_def = render_column(&fld("id", FieldType::I64, true, None));
        assert!(col_def.contains("BIGSERIAL"), "got: {col_def}");
        assert!(
            !col_def.contains(" DEFAULT "),
            "BIGSERIAL must not get an explicit DEFAULT: {col_def}"
        );
    }
}