diesel-guard 0.10.0

Linter for dangerous Postgres migration patterns in Diesel and SQLx. Prevents downtime caused by unsafe schema changes.
Documentation
//! Detection for `CREATE TABLE` with SERIAL pseudo-types.
//!
//! `SERIAL` / `BIGSERIAL` / `SMALLSERIAL` are PostgreSQL pseudo-types (not
//! standard SQL). PostgreSQL 10+ provides SQL-standard identity columns
//! (`GENERATED ... AS IDENTITY`) as the preferred replacement.

use crate::checks::pg_helpers::{
    ConstrType, NodeEnum, column_has_constraint, column_type_name, is_serial_pattern,
    range_var_name,
};
use crate::checks::{Check, Config, MigrationContext};
use crate::violation::Violation;
use pg_query::protobuf::ColumnDef;

/// `ConstrType::ConstrPrimary` as a raw `i32`, the representation used when comparing
/// against a constraint's `contype` and when calling `column_has_constraint`.
const CONSTR_PRIMARY: i32 = ConstrType::ConstrPrimary as i32;

/// Lint check that flags SERIAL / BIGSERIAL / SMALLSERIAL columns declared in
/// `CREATE TABLE` statements and suggests an identity column instead (or, when the
/// configured PostgreSQL version is below 10, an explicit integer column + sequence).
pub struct CreateTableSerialCheck;

impl Check for CreateTableSerialCheck {
    /// Reports one violation per SERIAL-family column in a `CREATE TABLE` statement.
    ///
    /// Any other statement kind produces no violations. The suggested replacement for
    /// each column depends on how it participates in the table's primary key and on
    /// `config.postgres_version` (see `safe_alternative`).
    fn check(&self, node: &NodeEnum, config: &Config, _ctx: &MigrationContext) -> Vec<Violation> {
        let create_stmt = match node {
            NodeEnum::CreateStmt(stmt) => stmt,
            _ => return Vec::new(),
        };

        // A statement without a relation still gets checked; the name just renders empty.
        let table_name = create_stmt
            .relation
            .as_ref()
            .map_or_else(Default::default, range_var_name);
        let primary_key_columns = table_primary_key(create_stmt);

        let mut violations = Vec::new();
        for element in &create_stmt.table_elts {
            // Only column definitions can carry a SERIAL type; skip table constraints etc.
            let Some(NodeEnum::ColumnDef(col)) = &element.node else {
                continue;
            };
            if !is_serial_pattern(col) {
                continue;
            }

            let column_name = &col.colname;
            let serial_type = column_type_name(col);
            let context = column_constraint_context(col, &primary_key_columns);
            let suggestion =
                safe_alternative(&table_name, column_name, &serial_type, col, config, context);

            violations.push(Violation::new(
                "CREATE TABLE with SERIAL",
                format!(
                    "Column '{column_name}' in CREATE TABLE '{table_name}' uses {serial_type} which is a PostgreSQL pseudo-type (non-standard SQL). \
                     SERIAL also creates a separately-owned sequence object, which can complicate permissions, dump/restore, and replication workflows."
                ),
                suggestion,
            ));
        }

        violations
    }
}

/// How a SERIAL column participates in its table's primary key.
///
/// This decides which inline constraints the suggested replacement column carries.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ColumnConstraintContext {
    /// The column alone is the primary key: it is declared with an inline
    /// `PRIMARY KEY`, or it is the only column of a table-level `PRIMARY KEY (...)`.
    Standalone,
    /// The column is one of several columns in a table-level `PRIMARY KEY (...)`.
    Composite,
    /// The column is not part of the primary key.
    Regular,
}

/// Returns the column names listed in the first table-level `PRIMARY KEY (...)`
/// element of `create_stmt`, in the order they are written.
///
/// Key entries that are not string nodes are skipped. When no such table-level
/// constraint exists, the result is empty. An inline column `PRIMARY KEY` is
/// detected separately by `column_constraint_context`.
fn table_primary_key(create_stmt: &pg_query::protobuf::CreateStmt) -> Vec<String> {
    let primary_key = create_stmt
        .table_elts
        .iter()
        .find_map(|element| match &element.node {
            Some(NodeEnum::Constraint(constraint)) if constraint.contype == CONSTR_PRIMARY => {
                Some(constraint)
            }
            _ => None,
        });

    let Some(primary_key) = primary_key else {
        return Vec::new();
    };

    let mut columns = Vec::new();
    for key in &primary_key.keys {
        if let Some(NodeEnum::String(name)) = &key.node {
            columns.push(name.sval.clone());
        }
    }
    columns
}

/// Classifies how `col` relates to the table's primary key.
///
/// `table_primary_key` holds the column names of the table-level `PRIMARY KEY (...)`
/// (empty when there is none). An inline `PRIMARY KEY` on the column, or being the
/// sole table-level key column, yields `Standalone`. Being one of several table-level
/// key columns yields `Composite`. Anything else is `Regular`.
fn column_constraint_context(
    col: &ColumnDef,
    table_primary_key: &[String],
) -> ColumnConstraintContext {
    let has_inline_primary_key = column_has_constraint(col, CONSTR_PRIMARY);
    let listed_in_table_key = table_primary_key.contains(&col.colname);

    match (has_inline_primary_key, listed_in_table_key) {
        (true, _) => ColumnConstraintContext::Standalone,
        (false, false) => ColumnConstraintContext::Regular,
        (false, true) if table_primary_key.len() == 1 => ColumnConstraintContext::Standalone,
        (false, true) => ColumnConstraintContext::Composite,
    }
}

/// Maps a SERIAL-family pseudo-type name to the integer type that backs it.
///
/// Accepts the long names (`smallserial`, `serial`, `bigserial`) as well as the
/// PostgreSQL aliases `serial2`, `serial4` and `serial8`, compared ASCII
/// case-insensitively. Without the aliases, `serial8` would be suggested as a 32-bit
/// `INTEGER`, silently narrowing a 64-bit column. Any other name falls back to
/// `INTEGER`, the type behind plain `SERIAL` / `serial4`.
fn identity_sql_type(serial_type: &str) -> &'static str {
    match serial_type.to_ascii_lowercase().as_str() {
        "smallserial" | "serial2" => "SMALLINT",
        "bigserial" | "serial8" => "BIGINT",
        _ => "INTEGER",
    }
}

fn safe_alternative(
    table_name: &str,
    column_name: &str,
    serial_type: &str,
    col: &ColumnDef,
    config: &Config,
    constraint_context: ColumnConstraintContext,
) -> String {
    let data_type = identity_sql_type(serial_type);
    let suffix = inline_constraint_suffix(col, constraint_context);

    if config.postgres_version.is_some_and(|version| version < 10) {
        let version = config.postgres_version.unwrap_or_default();

        return format!(
            "Identity columns require PostgreSQL 10+.\n\
If you must support PostgreSQL {version}, use an explicit integer column and sequence instead:\n\
1. CREATE TABLE {table_name} ({column_name} {data_type}{suffix});\n\
2. CREATE SEQUENCE {table_name}_{column_name}_seq;\n\
3. ALTER SEQUENCE {table_name}_{column_name}_seq OWNED BY {table_name}.{column_name};\n\
4. ALTER TABLE {table_name} ALTER COLUMN {column_name} SET DEFAULT nextval('{table_name}_{column_name}_seq');"
        );
    }

    format!(
        "Use SQL-standard identity columns instead:\n   CREATE TABLE {table_name} ({column_name} {data_type} GENERATED BY DEFAULT AS IDENTITY{suffix});"
    )
}

/// Renders the inline constraints to repeat on the suggested replacement column.
///
/// A `Standalone` primary-key column gets only ` PRIMARY KEY`. PostgreSQL primary keys
/// already imply NOT NULL and UNIQUE, so those are not repeated. Other columns keep
/// ` NOT NULL` and/or ` UNIQUE` if declared inline, in that order. The result is
/// either empty or starts with a single space, ready to append after the type.
fn inline_constraint_suffix(
    col: &ColumnDef,
    constraint_context: ColumnConstraintContext,
) -> String {
    match constraint_context {
        ColumnConstraintContext::Standalone => String::from(" PRIMARY KEY"),
        ColumnConstraintContext::Composite | ColumnConstraintContext::Regular => {
            let mut rendered = String::new();
            if column_has_constraint(col, ConstrType::ConstrNotnull as i32) {
                rendered.push_str(" NOT NULL");
            }
            if column_has_constraint(col, ConstrType::ConstrUnique as i32) {
                rendered.push_str(" UNIQUE");
            }
            rendered
        }
    }
}

#[cfg(test)]
mod tests {
    // Statements are parsed with `parse_sql` and fed straight to the check, so these
    // tests exercise detection and the generated safe-alternative text end to end.
    use super::*;
    use crate::checks::test_utils::parse_sql;
    use crate::{assert_allows, assert_detects_violation};

    #[test]
    fn test_detects_create_table_with_serial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id SERIAL PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    #[test]
    fn test_detects_create_table_with_bigserial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE events (id BIGSERIAL PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    #[test]
    fn test_detects_create_table_with_smallserial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id SMALLSERIAL PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    #[test]
    fn test_allows_create_table_with_identity() {
        assert_allows!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY);"
        );
    }

    #[test]
    fn test_allows_alter_table_add_serial() {
        assert_allows!(
            CreateTableSerialCheck,
            "ALTER TABLE users ADD COLUMN id SERIAL;"
        );
    }

    #[test]
    fn test_detects_create_table_with_non_primary_key_serial() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (event_no SERIAL, id BIGINT PRIMARY KEY);",
            "CREATE TABLE with SERIAL"
        );
    }

    #[test]
    fn test_ignores_other_statements() {
        assert_allows!(
            CreateTableSerialCheck,
            "CREATE INDEX idx_users_email ON users(email);"
        );
    }

    #[test]
    fn test_detects_create_table_with_serial_in_table_level_primary_key() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (id SERIAL, name TEXT, PRIMARY KEY (id));",
            "CREATE TABLE with SERIAL"
        );
    }

    #[test]
    fn test_detects_create_table_with_serial_on_non_primary_key_table_level_column() {
        assert_detects_violation!(
            CreateTableSerialCheck,
            "CREATE TABLE users (event_no SERIAL, id BIGINT, PRIMARY KEY (id));",
            "CREATE TABLE with SERIAL"
        );
    }

    #[test]
    fn test_safe_alternative_omits_primary_key_for_composite_primary_key() {
        let stmt = parse_sql(
            "CREATE TABLE users (tenant_id BIGINT, id SERIAL, PRIMARY KEY (tenant_id, id));",
        );
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());

        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert_eq!(violations[0].operation, "CREATE TABLE with SERIAL");
        assert!(
            violations[0]
                .safe_alternative
                .contains("GENERATED BY DEFAULT AS IDENTITY")
        );
        assert!(!violations[0].safe_alternative.contains("PRIMARY KEY"));
    }

    #[test]
    fn test_safe_alternative_preserves_primary_key() {
        let stmt = parse_sql("CREATE TABLE events (id BIGSERIAL PRIMARY KEY);");
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());

        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(violations[0].safe_alternative.contains("BIGINT"));
        assert!(violations[0].safe_alternative.contains("PRIMARY KEY"));
    }

    #[test]
    fn test_safe_alternative_maps_smallserial_to_smallint() {
        let stmt = parse_sql("CREATE TABLE counters (id SMALLSERIAL PRIMARY KEY);");
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());

        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("id SMALLINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY")
        );
    }

    #[test]
    fn test_safe_alternative_keeps_single_column_table_level_primary_key() {
        let stmt = parse_sql("CREATE TABLE users (id SERIAL, name TEXT, PRIMARY KEY (id));");
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());

        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("id INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY")
        );
    }

    #[test]
    fn test_reports_one_violation_per_serial_column() {
        let stmt = parse_sql("CREATE TABLE events (a SERIAL, b BIGSERIAL, note TEXT);");
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());

        assert_eq!(violations.len(), 2, "Expected one violation per SERIAL column");
        assert!(
            violations[0]
                .safe_alternative
                .contains("(a INTEGER GENERATED BY DEFAULT AS IDENTITY)")
        );
        assert!(
            violations[1]
                .safe_alternative
                .contains("(b BIGINT GENERATED BY DEFAULT AS IDENTITY)")
        );
    }

    #[test]
    fn test_non_primary_key_safe_alternative_preserves_inline_constraints() {
        let stmt = parse_sql(
            "CREATE TABLE users (external_id SERIAL NOT NULL UNIQUE, id BIGINT PRIMARY KEY);",
        );
        let violations =
            CreateTableSerialCheck.check(&stmt, &Config::default(), &MigrationContext::default());

        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE")
        );
    }

    #[test]
    fn test_pre_pg10_safe_alternative_uses_explicit_sequence() {
        let stmt = parse_sql("CREATE TABLE users (id SERIAL PRIMARY KEY);");
        let violations = CreateTableSerialCheck.check(
            &stmt,
            &Config {
                postgres_version: Some(9),
                ..Config::default()
            },
            &MigrationContext::default(),
        );

        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("Identity columns require PostgreSQL 10+")
        );
        assert!(
            violations[0]
                .safe_alternative
                .contains("CREATE SEQUENCE users_id_seq")
        );
        assert!(violations[0].safe_alternative.contains("PRIMARY KEY"));
    }

    #[test]
    fn test_pre_pg10_regular_serial_preserves_inline_constraints() {
        let stmt = parse_sql(
            "CREATE TABLE users (external_id SERIAL NOT NULL UNIQUE, id BIGINT PRIMARY KEY);",
        );
        let violations = CreateTableSerialCheck.check(
            &stmt,
            &Config {
                postgres_version: Some(9),
                ..Config::default()
            },
            &MigrationContext::default(),
        );

        assert_eq!(violations.len(), 1, "Expected exactly 1 violation");
        assert!(
            violations[0]
                .safe_alternative
                .contains("CREATE TABLE users (external_id INTEGER NOT NULL UNIQUE);")
        );
        assert!(!violations[0].safe_alternative.contains("PRIMARY KEY"));
    }
}