/// Initial checking mode for a deferrable constraint.
///
/// `format_deferrable` renders each variant as a `DEFERRABLE INITIALLY <mode>` suffix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Deferrable {
    /// Rendered as ` DEFERRABLE INITIALLY DEFERRED`.
    Deferred,
    /// Rendered as ` DEFERRABLE INITIALLY IMMEDIATE`.
    Immediate,
}
/// Common interface for table-level SQL constraints.
///
/// Every method takes `&self` and is non-generic, so the trait can be used as
/// `dyn Constraint`. The `Send + Sync` supertraits let implementors be shared across threads.
pub trait Constraint: Send + Sync {
    /// Identifier of the constraint.
    fn name(&self) -> &str;
    /// The `CONSTRAINT ...` definition fragment for `table` (implementors may ignore `table`).
    fn constraint_sql(&self, table: &str) -> String;
    /// A complete statement that adds this constraint to `table`.
    fn create_sql(&self, table: &str) -> String;
    /// A complete statement that removes this constraint from `table`.
    fn remove_sql(&self, table: &str) -> String;
    /// A human-readable, one-line summary of the constraint.
    fn describe(&self) -> String;
}
/// A `CHECK (...)` table constraint.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CheckConstraint {
    /// Constraint identifier, interpolated verbatim into the generated SQL.
    pub name: String,
    /// SQL expression placed verbatim inside `CHECK (...)`.
    pub check: String,
}
/// A `UNIQUE (...)` constraint over one or more columns.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UniqueConstraint {
    /// Constraint identifier, interpolated verbatim into the generated SQL.
    pub name: String,
    /// Columns covered by the constraint, in order.
    pub fields: Vec<String>,
    /// Optional predicate; `None` or a whitespace-only value renders no `WHERE` clause.
    pub condition: Option<String>,
    /// Optional `DEFERRABLE INITIALLY ...` mode; `None` renders nothing.
    pub deferrable: Option<Deferrable>,
    /// Extra columns for an `INCLUDE (...)` clause; an empty list renders nothing.
    pub include: Vec<String>,
    /// Operator class per field, matched by position; missing or blank entries leave the
    /// column bare.
    pub opclasses: Vec<String>,
    /// `Some(true)` / `Some(false)` render `NULLS DISTINCT` / `NULLS NOT DISTINCT`;
    /// `None` renders nothing.
    pub nulls_distinct: Option<bool>,
}
/// An `EXCLUDE USING <method> (...)` constraint.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExclusionConstraint {
    /// Constraint identifier, interpolated verbatim into the generated SQL.
    pub name: String,
    /// `(expression, operator)` pairs, each rendered as `expression WITH operator`.
    pub expressions: Vec<(String, String)>,
    /// Index method name; it is lower-cased, and a blank value falls back to `gist`.
    pub index_type: String,
    /// Optional predicate; `None` or a whitespace-only value renders no `WHERE` clause.
    pub condition: Option<String>,
}
/// Renders an ` INCLUDE (a, b, ...)` suffix for covering columns.
///
/// Returns an empty string when `columns` is empty, so the result can be concatenated
/// unconditionally.
#[must_use]
fn format_include(columns: &[String]) -> String {
    match columns {
        [] => String::new(),
        listed => format!(" INCLUDE ({})", listed.join(", ")),
    }
}
/// Renders a ` WHERE <predicate>` suffix.
///
/// A missing predicate, or one containing only whitespace, yields an empty string. A
/// non-blank predicate is inserted exactly as given (it is not trimmed).
#[must_use]
fn format_condition(condition: Option<&str>) -> String {
    match condition {
        Some(predicate) if !predicate.trim().is_empty() => format!(" WHERE {predicate}"),
        _ => String::new(),
    }
}
/// Renders the ` DEFERRABLE INITIALLY <mode>` suffix for a constraint.
///
/// `None` (not deferrable) yields an empty string.
#[must_use]
fn format_deferrable(deferrable: Option<Deferrable>) -> String {
    deferrable.map_or_else(String::new, |mode| {
        let clause = match mode {
            Deferrable::Deferred => " DEFERRABLE INITIALLY DEFERRED",
            Deferrable::Immediate => " DEFERRABLE INITIALLY IMMEDIATE",
        };
        clause.to_owned()
    })
}
/// Renders the ` NULLS DISTINCT` / ` NULLS NOT DISTINCT` modifier.
///
/// `None` leaves the database default in effect by emitting nothing.
#[must_use]
fn format_nulls_distinct(nulls_distinct: Option<bool>) -> String {
    nulls_distinct
        .map(|distinct| {
            if distinct {
                " NULLS DISTINCT"
            } else {
                " NULLS NOT DISTINCT"
            }
        })
        .unwrap_or_default()
        .to_owned()
}
/// Renders `columns` as a comma-separated list, pairing each column with the operator
/// class at the same position in `opclasses` (e.g. `title text_pattern_ops`).
///
/// A column with no corresponding entry, or whose entry is blank, is emitted bare.
/// Surplus entries in `opclasses` are ignored.
#[must_use]
fn join_columns_with_opclasses(columns: &[String], opclasses: &[String]) -> String {
    let mut rendered: Vec<String> = Vec::with_capacity(columns.len());
    for (position, column) in columns.iter().enumerate() {
        let opclass = opclasses
            .get(position)
            .filter(|candidate| !candidate.trim().is_empty());
        rendered.push(match opclass {
            Some(candidate) => format!("{column} {candidate}"),
            None => column.clone(),
        });
    }
    rendered.join(", ")
}
/// Normalizes an index method name: surrounding whitespace is removed and ASCII letters are
/// lower-cased. A blank name falls back to `gist`.
#[must_use]
fn normalize_index_type(index_type: &str) -> String {
    match index_type.trim() {
        "" => "gist".to_owned(),
        method => method.to_ascii_lowercase(),
    }
}
impl Constraint for CheckConstraint {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Renders `CONSTRAINT <name> CHECK (<check>)`; the table name is not needed.
    fn constraint_sql(&self, _table: &str) -> String {
        let Self { name, check } = self;
        format!("CONSTRAINT {name} CHECK ({check})")
    }

    /// Wraps the constraint definition in an `ALTER TABLE <table> ADD ...` statement.
    fn create_sql(&self, table: &str) -> String {
        let definition = self.constraint_sql(table);
        format!("ALTER TABLE {table} ADD {definition}")
    }

    /// Drops the constraint from `table` by name.
    fn remove_sql(&self, table: &str) -> String {
        let name = self.name();
        format!("ALTER TABLE {table} DROP CONSTRAINT {name}")
    }

    /// Summarizes the constraint as `Check constraint <name>: <check>`.
    fn describe(&self) -> String {
        let Self { name, check } = self;
        format!("Check constraint {name}: {check}")
    }
}
impl UniqueConstraint {
    /// Returns `condition` only when it contains non-whitespace text.
    ///
    /// This mirrors the filtering `format_condition` applies when rendering SQL, so
    /// `describe` never mentions a predicate that the generated SQL omits.
    fn effective_condition(&self) -> Option<&str> {
        self.condition
            .as_deref()
            .filter(|value| !value.trim().is_empty())
    }
}

impl Constraint for UniqueConstraint {
    fn name(&self) -> &str {
        &self.name
    }

    /// Renders the `CONSTRAINT <name> UNIQUE ...` fragment.
    ///
    /// Clause order follows PostgreSQL's table-constraint grammar,
    /// `UNIQUE [ NULLS [ NOT ] DISTINCT ] ( columns ) [ INCLUDE ( ... ) ]`, followed by the
    /// `DEFERRABLE INITIALLY ...` attribute. The nulls modifier must precede the column
    /// list, and the deferrability attribute must come last.
    ///
    /// NOTE(review): PostgreSQL does not accept a `WHERE` predicate or per-column operator
    /// classes on a table-level UNIQUE constraint; those are only valid on
    /// `CREATE UNIQUE INDEX`. They are still rendered here, with the predicate placed after
    /// `INCLUDE` and before `DEFERRABLE`. Confirm how callers use `condition`/`opclasses`
    /// before executing this SQL.
    fn constraint_sql(&self, _table: &str) -> String {
        let columns = join_columns_with_opclasses(&self.fields, &self.opclasses);
        format!(
            "CONSTRAINT {} UNIQUE{} ({}){}{}{}",
            self.name,
            format_nulls_distinct(self.nulls_distinct),
            columns,
            format_include(&self.include),
            format_condition(self.condition.as_deref()),
            format_deferrable(self.deferrable),
        )
    }

    fn create_sql(&self, table: &str) -> String {
        format!("ALTER TABLE {table} ADD {}", self.constraint_sql(table))
    }

    fn remove_sql(&self, table: &str) -> String {
        format!("ALTER TABLE {table} DROP CONSTRAINT {}", self.name)
    }

    /// Summarizes the constraint as `Unique constraint <name> on <fields>`.
    ///
    /// ` where <condition>` is appended only for a non-blank condition, matching the SQL
    /// output, which also drops blank conditions.
    fn describe(&self) -> String {
        let mut description = format!(
            "Unique constraint {} on {}",
            self.name,
            self.fields.join(", ")
        );
        if let Some(condition) = self.effective_condition() {
            description.push_str(" where ");
            description.push_str(condition);
        }
        description
    }
}

impl ExclusionConstraint {
    /// Joins the `(expression, operator)` pairs as `expression WITH operator, ...`.
    ///
    /// Shared by `constraint_sql` and `describe` so both render the element list identically.
    fn joined_expressions(&self) -> String {
        self.expressions
            .iter()
            .map(|(expression, operator)| format!("{expression} WITH {operator}"))
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Renders ` WHERE (<predicate>)`, or an empty string for a missing or blank condition.
    ///
    /// PostgreSQL's EXCLUDE grammar is `... [ WHERE ( predicate ) ]`, where the parentheses
    /// are mandatory. The bare ` WHERE <predicate>` produced by `format_condition` is
    /// rejected there.
    fn where_clause(&self) -> String {
        self.condition
            .as_deref()
            .filter(|value| !value.trim().is_empty())
            .map_or_else(String::new, |value| format!(" WHERE ({value})"))
    }
}

impl Constraint for ExclusionConstraint {
    fn name(&self) -> &str {
        &self.name
    }

    /// Renders `CONSTRAINT <name> EXCLUDE USING <method> (<elements>)[ WHERE (<predicate>)]`.
    ///
    /// The index method is lower-cased and falls back to `gist` when blank
    /// (see `normalize_index_type`).
    fn constraint_sql(&self, _table: &str) -> String {
        format!(
            "CONSTRAINT {} EXCLUDE USING {} ({}){}",
            self.name,
            normalize_index_type(&self.index_type),
            self.joined_expressions(),
            self.where_clause(),
        )
    }

    fn create_sql(&self, table: &str) -> String {
        format!("ALTER TABLE {table} ADD {}", self.constraint_sql(table))
    }

    fn remove_sql(&self, table: &str) -> String {
        format!("ALTER TABLE {table} DROP CONSTRAINT {}", self.name)
    }

    /// Summarizes the constraint as `Exclusion constraint <name> using <method> on <elements>`.
    fn describe(&self) -> String {
        format!(
            "Exclusion constraint {} using {} on {}",
            self.name,
            normalize_index_type(&self.index_type),
            self.joined_expressions(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::{CheckConstraint, Constraint, Deferrable, ExclusionConstraint, UniqueConstraint};

    /// Builds a plain unique constraint. Tests override individual fields with
    /// struct-update syntax.
    fn unique(name: &str, fields: &[&str]) -> UniqueConstraint {
        UniqueConstraint {
            name: name.to_owned(),
            fields: fields.iter().map(|&field| field.to_owned()).collect(),
            condition: None,
            deferrable: None,
            include: Vec::new(),
            opclasses: Vec::new(),
            nulls_distinct: None,
        }
    }

    #[test]
    fn test_check_constraint_sql() {
        let constraint = CheckConstraint {
            name: "age_gte_18".into(),
            check: "age >= 18".into(),
        };
        assert_eq!(
            constraint.constraint_sql("users"),
            "CONSTRAINT age_gte_18 CHECK (age >= 18)"
        );
    }

    #[test]
    fn test_check_constraint_create_remove() {
        let constraint = CheckConstraint {
            name: "age_gte_18".into(),
            check: "age >= 18".into(),
        };
        assert_eq!(
            constraint.create_sql("users"),
            "ALTER TABLE users ADD CONSTRAINT age_gte_18 CHECK (age >= 18)"
        );
        assert_eq!(
            constraint.remove_sql("users"),
            "ALTER TABLE users DROP CONSTRAINT age_gte_18"
        );
    }

    #[test]
    fn test_unique_constraint_basic() {
        let constraint = UniqueConstraint {
            name: "uq_book_title".into(),
            fields: vec!["title".into()],
            condition: None,
            deferrable: None,
            include: Vec::new(),
            opclasses: Vec::new(),
            nulls_distinct: None,
        };
        assert_eq!(
            constraint.constraint_sql("books"),
            "CONSTRAINT uq_book_title UNIQUE (title)"
        );
    }

    #[test]
    fn test_unique_constraint_with_condition() {
        let constraint = UniqueConstraint {
            name: "uq_live_slug".into(),
            fields: vec!["slug".into()],
            condition: Some("deleted_at IS NULL".into()),
            deferrable: None,
            include: Vec::new(),
            opclasses: Vec::new(),
            nulls_distinct: None,
        };
        assert_eq!(
            constraint.constraint_sql("posts"),
            "CONSTRAINT uq_live_slug UNIQUE (slug) WHERE deleted_at IS NULL"
        );
    }

    #[test]
    fn test_unique_constraint_deferrable() {
        let constraint = UniqueConstraint {
            name: "uq_booking_slot".into(),
            fields: vec!["room_id".into(), "starts_at".into()],
            condition: None,
            deferrable: Some(Deferrable::Deferred),
            include: Vec::new(),
            opclasses: Vec::new(),
            nulls_distinct: None,
        };
        assert_eq!(
            constraint.constraint_sql("bookings"),
            "CONSTRAINT uq_booking_slot UNIQUE (room_id, starts_at) DEFERRABLE INITIALLY DEFERRED"
        );
    }

    #[test]
    fn unique_constraint_supports_include_opclasses_and_nulls_distinct() {
        let constraint = UniqueConstraint {
            name: "uq_title_pattern".into(),
            fields: vec!["title".into(), "subtitle".into()],
            condition: Some("published = TRUE".into()),
            deferrable: Some(Deferrable::Immediate),
            include: vec!["id".into()],
            opclasses: vec!["text_pattern_ops".into()],
            nulls_distinct: Some(false),
        };
        assert_eq!(
            constraint.constraint_sql("books"),
            "CONSTRAINT uq_title_pattern UNIQUE NULLS NOT DISTINCT (title text_pattern_ops, subtitle) INCLUDE (id) WHERE published = TRUE DEFERRABLE INITIALLY IMMEDIATE"
        );
    }

    #[test]
    fn unique_constraint_nulls_clause_precedes_columns() {
        let constraint = UniqueConstraint {
            nulls_distinct: Some(true),
            ..unique("uq_email", &["email"])
        };
        assert_eq!(
            constraint.constraint_sql("users"),
            "CONSTRAINT uq_email UNIQUE NULLS DISTINCT (email)"
        );
    }

    #[test]
    fn unique_constraint_include_precedes_deferrable() {
        let constraint = UniqueConstraint {
            include: vec!["id".into()],
            deferrable: Some(Deferrable::Deferred),
            ..unique("uq_code", &["code"])
        };
        assert_eq!(
            constraint.create_sql("items"),
            "ALTER TABLE items ADD CONSTRAINT uq_code UNIQUE (code) INCLUDE (id) DEFERRABLE INITIALLY DEFERRED"
        );
        assert_eq!(
            constraint.remove_sql("items"),
            "ALTER TABLE items DROP CONSTRAINT uq_code"
        );
    }

    #[test]
    fn unique_constraint_describe_skips_blank_condition() {
        let blank = UniqueConstraint {
            condition: Some("   ".into()),
            ..unique("uq_pair", &["a", "b"])
        };
        assert_eq!(blank.describe(), "Unique constraint uq_pair on a, b");
        let partial = UniqueConstraint {
            condition: Some("a > 0".into()),
            ..unique("uq_pair", &["a", "b"])
        };
        assert_eq!(
            partial.describe(),
            "Unique constraint uq_pair on a, b where a > 0"
        );
    }

    #[test]
    fn test_exclusion_constraint_sql() {
        let constraint = ExclusionConstraint {
            name: "exclude_room_overlap".into(),
            expressions: vec![
                ("room_id".into(), "=".into()),
                ("daterange(starts_at, ends_at)".into(), "&&".into()),
            ],
            index_type: "GIST".into(),
            condition: Some("cancelled = FALSE".into()),
        };
        assert_eq!(
            constraint.constraint_sql("bookings"),
            "CONSTRAINT exclude_room_overlap EXCLUDE USING gist (room_id WITH =, daterange(starts_at, ends_at) WITH &&) WHERE (cancelled = FALSE)"
        );
    }

    #[test]
    fn exclusion_constraint_defaults_to_gist_when_type_missing() {
        let constraint = ExclusionConstraint {
            name: "exclude_room_overlap".into(),
            expressions: vec![("room_id".into(), "=".into())],
            index_type: String::new(),
            condition: None,
        };
        assert_eq!(
            constraint.constraint_sql("bookings"),
            "CONSTRAINT exclude_room_overlap EXCLUDE USING gist (room_id WITH =)"
        );
    }

    #[test]
    fn exclusion_constraint_blank_condition_describe_and_ddl() {
        let constraint = ExclusionConstraint {
            name: "exclude_room_overlap".into(),
            expressions: vec![("room_id".into(), "=".into())],
            index_type: "gist".into(),
            condition: Some("  ".into()),
        };
        assert_eq!(
            constraint.create_sql("bookings"),
            "ALTER TABLE bookings ADD CONSTRAINT exclude_room_overlap EXCLUDE USING gist (room_id WITH =)"
        );
        assert_eq!(
            constraint.remove_sql("bookings"),
            "ALTER TABLE bookings DROP CONSTRAINT exclude_room_overlap"
        );
        assert_eq!(
            constraint.describe(),
            "Exclusion constraint exclude_room_overlap using gist on room_id WITH ="
        );
    }
}