use crate::checks::pg_helpers::{
ConstrType, NodeEnum, column_has_constraint, column_type_name, is_serial_pattern,
range_var_name,
};
use crate::checks::{Check, Config, MigrationContext};
use crate::violation::Violation;
use pg_query::protobuf::ColumnDef;
/// `ConstrType::ConstrPrimary` as the raw `i32` discriminant, which is the form compared
/// against a constraint's `contype` field and passed to `column_has_constraint`.
const CONSTR_PRIMARY: i32 = ConstrType::ConstrPrimary as i32;
/// Check that flags `SERIAL`-family columns in `CREATE TABLE` statements and suggests a
/// SQL-standard identity column (or, for PostgreSQL < 10, an explicit sequence) instead.
pub struct CreateTableSerialCheck;
impl Check for CreateTableSerialCheck {
    /// Emits one violation for every `CREATE TABLE` column that `is_serial_pattern` matches.
    /// Every other statement kind yields an empty list.
    ///
    /// Each violation carries a suggested rewrite from `safe_alternative`. The rewrite depends
    /// on how the column participates in the primary key:
    /// - it is the sole primary key column,
    /// - it is one column of a composite primary key, or
    /// - it is not part of the primary key at all.
    fn check(&self, node: &NodeEnum, config: &Config, _ctx: &MigrationContext) -> Vec<Violation> {
        let create_stmt = match node {
            NodeEnum::CreateStmt(stmt) => stmt,
            _ => return Vec::new(),
        };

        // A statement without a relation is still checked; the table name is just left empty.
        let table_name = create_stmt
            .relation
            .as_ref()
            .map_or_else(Default::default, range_var_name);
        // Table-level `PRIMARY KEY (...)` columns, computed once and reused for every column.
        let pk_columns = table_primary_key(create_stmt);

        let mut violations = Vec::new();
        for elt in &create_stmt.table_elts {
            let Some(NodeEnum::ColumnDef(col)) = &elt.node else {
                continue;
            };
            if !is_serial_pattern(col) {
                continue;
            }

            let column_name = &col.colname;
            let serial_type = column_type_name(col);
            let context = column_constraint_context(col, &pk_columns);
            let suggestion =
                safe_alternative(&table_name, column_name, &serial_type, col, config, context);
            let message = format!(
                "Column '{column_name}' in CREATE TABLE '{table_name}' uses {serial_type} which is a PostgreSQL pseudo-type (non-standard SQL). \
                 SERIAL also creates a separately-owned sequence object, which can complicate permissions, dump/restore, and replication workflows."
            );
            violations.push(Violation::new("CREATE TABLE with SERIAL", message, suggestion));
        }
        violations
    }
}
/// How a SERIAL column participates in its table's primary key.
///
/// This decides which inline constraints the suggested rewrite keeps for the column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ColumnConstraintContext {
    /// The column alone is the primary key. This covers an inline `PRIMARY KEY` and a
    /// table-level `PRIMARY KEY (...)` that lists only this column.
    Standalone,
    /// The column is one of several columns in a table-level `PRIMARY KEY (...)`.
    Composite,
    /// The column is not part of the primary key.
    Regular,
}
/// Returns the column names listed in a table-level `PRIMARY KEY (...)` constraint.
///
/// Only constraints that appear as standalone table elements are examined. A column-level
/// `PRIMARY KEY` lives inside its `ColumnDef` and is not reported here.
///
/// - The first primary-key constraint found wins.
/// - Key entries that are not plain string nodes are skipped.
/// - An empty `Vec` means no table-level primary key was found.
fn table_primary_key(create_stmt: &pg_query::protobuf::CreateStmt) -> Vec<String> {
    for elt in &create_stmt.table_elts {
        let Some(NodeEnum::Constraint(constraint)) = &elt.node else {
            continue;
        };
        if constraint.contype != CONSTR_PRIMARY {
            continue;
        }

        let mut columns = Vec::new();
        for key in &constraint.keys {
            if let Some(NodeEnum::String(s)) = &key.node {
                columns.push(s.sval.clone());
            }
        }
        return columns;
    }
    Vec::new()
}
/// Classifies how `col` participates in the table's primary key.
///
/// - An inline `PRIMARY KEY` on the column itself gives `Standalone`.
/// - Being the only column of the table-level primary key gives `Standalone`.
/// - Being one of several table-level primary-key columns gives `Composite`.
/// - Anything else gives `Regular`.
///
/// `table_primary_key` is expected to be the (possibly empty) column list returned by the
/// `table_primary_key` helper.
fn column_constraint_context(
    col: &ColumnDef,
    table_primary_key: &[String],
) -> ColumnConstraintContext {
    if column_has_constraint(col, CONSTR_PRIMARY) {
        return ColumnConstraintContext::Standalone;
    }
    let listed_in_pk = table_primary_key.contains(&col.colname);
    match (listed_in_pk, table_primary_key.len()) {
        (false, _) => ColumnConstraintContext::Regular,
        (true, 1) => ColumnConstraintContext::Standalone,
        (true, _) => ColumnConstraintContext::Composite,
    }
}
/// Maps a SERIAL-family type name to the integer type of the equivalent identity column.
///
/// PostgreSQL accepts `serial2` / `serial4` / `serial8` as aliases of `smallserial` /
/// `serial` / `bigserial`, so the aliases are mapped as well. Otherwise a `serial8` column
/// would be suggested as a narrower `INTEGER`. Matching is ASCII case-insensitive.
/// Anything unrecognised falls back to `INTEGER`, the width of plain `serial`.
fn identity_sql_type(serial_type: &str) -> &'static str {
    match serial_type.to_ascii_lowercase().as_str() {
        "smallserial" | "serial2" => "SMALLINT",
        "bigserial" | "serial8" => "BIGINT",
        _ => "INTEGER",
    }
}
fn safe_alternative(
table_name: &str,
column_name: &str,
serial_type: &str,
col: &ColumnDef,
config: &Config,
constraint_context: ColumnConstraintContext,
) -> String {
let data_type = identity_sql_type(serial_type);
let suffix = inline_constraint_suffix(col, constraint_context);
if config.postgres_version.is_some_and(|version| version < 10) {
let version = config.postgres_version.unwrap_or_default();
return format!(
"Identity columns require PostgreSQL 10+.\n\
If you must support PostgreSQL {version}, use an explicit integer column and sequence instead:\n\
1. CREATE TABLE {table_name} ({column_name} {data_type}{suffix});\n\
2. CREATE SEQUENCE {table_name}_{column_name}_seq;\n\
3. ALTER SEQUENCE {table_name}_{column_name}_seq OWNED BY {table_name}.{column_name};\n\
4. ALTER TABLE {table_name} ALTER COLUMN {column_name} SET DEFAULT nextval('{table_name}_{column_name}_seq');"
);
}
format!(
"Use SQL-standard identity columns instead:\n CREATE TABLE {table_name} ({column_name} {data_type} GENERATED BY DEFAULT AS IDENTITY{suffix});"
)
}
/// Renders the inline constraints to append after the suggested column type.
///
/// - A `Standalone` primary-key column gets ` PRIMARY KEY`.
/// - Any other column echoes its own inline `NOT NULL` and `UNIQUE` constraints, in that
///   order.
///
/// The result is either empty or begins with a single space, so it can be appended
/// directly to the type name.
fn inline_constraint_suffix(
    col: &ColumnDef,
    constraint_context: ColumnConstraintContext,
) -> String {
    let keywords: Vec<&str> = match constraint_context {
        ColumnConstraintContext::Standalone => vec!["PRIMARY KEY"],
        ColumnConstraintContext::Composite | ColumnConstraintContext::Regular => [
            (ConstrType::ConstrNotnull as i32, "NOT NULL"),
            (ConstrType::ConstrUnique as i32, "UNIQUE"),
        ]
        .iter()
        .filter(|(code, _)| column_has_constraint(col, *code))
        .map(|&(_, keyword)| keyword)
        .collect(),
    };
    keywords
        .iter()
        .map(|keyword| format!(" {keyword}"))
        .collect()
}
// Unit tests for `CreateTableSerialCheck`.
//
// `assert_detects_violation!` / `assert_allows!` are project macros imported from the crate
// root. They presumably assert that the check does / does not report the given operation;
// confirm against their definitions. The remaining tests call `check` directly so they can
// inspect the generated `safe_alternative` text.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checks::test_utils::parse_sql;
    use crate::{assert_allows, assert_detects_violation};

    // Plain SERIAL with an inline primary key is flagged.
    #[test]
    fn test_detects_create_table_with_serial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id SERIAL PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    // BIGSERIAL is flagged like SERIAL.
    #[test]
    fn test_detects_create_table_with_bigserial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE events (id BIGSERIAL PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    // SMALLSERIAL is flagged like SERIAL.
    #[test]
    fn test_detects_create_table_with_smallserial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id SMALLSERIAL PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    // The recommended replacement itself must not be flagged.
    #[test]
    fn test_allows_create_table_with_identity() {
        assert_allows!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY);"
        );
    }

    // Only CREATE TABLE is in scope; ALTER TABLE ... ADD COLUMN SERIAL is ignored here.
    #[test]
    fn test_allows_alter_table_add_serial() {
        assert_allows!(
            CreateTableSerialCheck,
            "ALTER TABLE users ADD COLUMN id SERIAL;"
        );
    }

    // SERIAL on a column that is not the primary key is still flagged.
    #[test]
    fn test_detects_create_table_with_non_primary_key_serial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (event_no SERIAL, id BIGINT PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    // Statements other than CREATE TABLE yield no violations.
    #[test]
    fn test_ignores_other_statements() {
        assert_allows!(
            CreateTableSerialCheck,
            "CREATE INDEX idx_users_email ON users(email);"
        );
    }

    // SERIAL column made the primary key via a table-level PRIMARY KEY (...) clause.
    #[test]
    fn test_detects_create_table_with_serial_in_table_level_primary_key() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id SERIAL, name TEXT, PRIMARY KEY (id));",
            "CREATE TABLE with SERIAL"
        );
    }

    // SERIAL column that is not listed in the table-level primary key.
    #[test]
    fn test_detects_create_table_with_serial_on_non_primary_key_table_level_column() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (event_no SERIAL, id BIGINT, PRIMARY KEY (id));",
            "CREATE TABLE with SERIAL"
        );
    }

    // For a composite primary key the suggestion must not put PRIMARY KEY on the single
    // column, since that would change the table's key.
    #[test]
    fn test_safe_alternative_omits_primary_key_for_composite_primary_key() {
        let stmt = parse_sql(
            "CREATE TABLE users (tenant_id BIGINT, id SERIAL, PRIMARY KEY (tenant_id, id));",
        );
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());
        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert_eq!(violations[0].operation, "CREATE TABLE with SERIAL");
        assert!(
            violations[0]
                .safe_alternative
                .contains("GENERATED BY DEFAULT AS IDENTITY")
        );
        assert!(!violations[0].safe_alternative.contains("PRIMARY KEY"));
    }

    // BIGSERIAL maps to BIGINT and keeps its inline PRIMARY KEY in the suggestion.
    #[test]
    fn test_safe_alternative_preserves_primary_key() {
        let stmt = parse_sql("CREATE TABLE events (id BIGSERIAL PRIMARY KEY);");
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());
        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(violations[0].safe_alternative.contains("BIGINT"));
        assert!(violations[0].safe_alternative.contains("PRIMARY KEY"));
    }

    // Explicit inline NOT NULL / UNIQUE on a non-key column are echoed after IDENTITY.
    #[test]
    fn test_non_primary_key_safe_alternative_preserves_inline_constraints() {
        let stmt = parse_sql(
            "CREATE TABLE users (external_id SERIAL NOT NULL UNIQUE, id BIGINT PRIMARY KEY);",
        );
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());
        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE")
        );
    }

    // With postgres_version < 10 the suggestion switches to an explicit sequence recipe.
    #[test]
    fn test_pre_pg10_safe_alternative_uses_explicit_sequence() {
        let stmt = parse_sql("CREATE TABLE users (id SERIAL PRIMARY KEY);");
        let violations = CreateTableSerialCheck.check(
            &stmt,
            &Config {
                postgres_version: Some(9),
                ..Config::default()
            },
            &MigrationContext::default(),
        );
        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("Identity columns require PostgreSQL 10+")
        );
        assert!(
            violations[0]
                .safe_alternative
                .contains("CREATE SEQUENCE users_id_seq")
        );
        assert!(violations[0].safe_alternative.contains("PRIMARY KEY"));
    }

    // The pre-PG10 recipe keeps a non-key column's inline constraints and adds no PRIMARY KEY.
    #[test]
    fn test_pre_pg10_regular_serial_preserves_inline_constraints() {
        let stmt = parse_sql(
            "CREATE TABLE users (external_id SERIAL NOT NULL UNIQUE, id BIGINT PRIMARY KEY);",
        );
        let violations = CreateTableSerialCheck.check(
            &stmt,
            &Config {
                postgres_version: Some(9),
                ..Config::default()
            },
            &MigrationContext::default(),
        );
        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("CREATE TABLE users (external_id INTEGER NOT NULL UNIQUE);")
        );
        assert!(!violations[0].safe_alternative.contains("PRIMARY KEY"));
    }
}