pgmold 0.33.6

PostgreSQL schema-as-code management tool
Documentation
use crate::model::*;
use crate::util::Result;
use sqlparser::ast::{
    ColumnDef, ColumnOption, DataType, Expr, FunctionArg as SqlFunctionArg, FunctionArgExpr,
    FunctionArguments, ReferentialAction as SqlReferentialAction, TableConstraint,
};
use std::collections::BTreeMap;

use super::util::{
    extract_qualified_name, normalize_expr, parse_data_type, truncate_identifier, unquote_ident,
};

/// Output of [`parse_create_table`]: the table model plus the sequences that
/// were implicitly created for its `SERIAL`-family columns.
pub(super) struct ParsedTable {
    /// The parsed table definition (columns, keys, constraints, partitioning).
    pub(super) table: Table,
    /// One sequence per `SERIAL`-family column, in column-declaration order.
    pub(super) sequences: Vec<Sequence>,
}

pub(super) fn parse_create_table(
    schema: &str,
    name: &str,
    columns: &[ColumnDef],
    constraints: &[TableConstraint],
    partition_by: Option<&Expr>,
) -> Result<ParsedTable> {
    let mut table = Table {
        schema: schema.to_string(),
        name: name.to_string(),
        columns: BTreeMap::new(),
        indexes: Vec::new(),
        primary_key: None,
        foreign_keys: Vec::new(),
        check_constraints: Vec::new(),
        comment: None,
        row_level_security: false,
        force_row_level_security: false,
        policies: Vec::new(),
        partition_by: partition_by.and_then(parse_partition_by),
        owner: None,
        grants: Vec::new(),
    };

    let mut sequences = Vec::new();

    for col_def in columns {
        let (column, maybe_sequence) = parse_column_with_serial(schema, name, col_def)?;
        table.columns.insert(column.name.clone(), column);
        if let Some(seq) = maybe_sequence {
            sequences.push(seq);
        }
    }

    for col_def in columns {
        for option in &col_def.options {
            if matches!(option.option, ColumnOption::PrimaryKey(_)) {
                let pk_col = unquote_ident(&col_def.name.to_string()).to_string();
                table.primary_key = Some(PrimaryKey {
                    columns: vec![pk_col.clone()],
                });
                if let Some(col) = table.columns.get_mut(&pk_col) {
                    col.nullable = false;
                }
            }
        }
    }

    for constraint in constraints {
        match constraint {
            TableConstraint::PrimaryKey(pk) => {
                let pk_columns: Vec<String> = pk
                    .columns
                    .iter()
                    .map(|c| unquote_ident(&c.to_string()).to_string())
                    .collect();
                table.primary_key = Some(PrimaryKey {
                    columns: pk_columns.clone(),
                });
                for pk_col in &pk_columns {
                    if let Some(col) = table.columns.get_mut(pk_col) {
                        col.nullable = false;
                    }
                }
            }
            TableConstraint::ForeignKey(fk) => {
                let fk_name = fk
                    .name
                    .as_ref()
                    .map(|n| unquote_ident(&n.to_string()).to_string())
                    .unwrap_or_else(|| {
                        format!(
                            "{}_{}_fkey",
                            table.name,
                            unquote_ident(&fk.columns[0].to_string())
                        )
                    });

                let (ref_schema, ref_table) = extract_qualified_name(&fk.foreign_table);
                table.foreign_keys.push(ForeignKey {
                    name: truncate_identifier(&fk_name),
                    columns: fk
                        .columns
                        .iter()
                        .map(|c| unquote_ident(&c.to_string()).to_string())
                        .collect(),
                    referenced_schema: ref_schema,
                    referenced_table: ref_table,
                    referenced_columns: fk
                        .referred_columns
                        .iter()
                        .map(|c| unquote_ident(&c.to_string()).to_string())
                        .collect(),
                    on_delete: parse_referential_action(&fk.on_delete),
                    on_update: parse_referential_action(&fk.on_update),
                });
            }
            TableConstraint::Check(chk) => {
                let constraint_name = chk
                    .name
                    .as_ref()
                    .map(|n| unquote_ident(&n.to_string()).to_string())
                    .unwrap_or_else(|| format!("{}_check", table.name));

                table.check_constraints.push(CheckConstraint {
                    name: constraint_name,
                    expression: normalize_expr(&chk.expr.to_string()),
                });
            }
            TableConstraint::Unique(uniq) => {
                let constraint_name = uniq
                    .name
                    .as_ref()
                    .map(|n| unquote_ident(&n.to_string()).to_string())
                    .unwrap_or_else(|| format!("{}_unique", table.name));

                table.indexes.push(Index {
                    name: constraint_name,
                    columns: uniq
                        .columns
                        .iter()
                        .map(|c| unquote_ident(&c.column.expr.to_string()).to_string())
                        .collect(),
                    unique: true,
                    index_type: IndexType::BTree,
                    predicate: None,
                    is_constraint: true,
                });
            }
            _ => {}
        }
    }

    table.foreign_keys.sort();
    table.check_constraints.sort();
    table.indexes.sort();

    Ok(ParsedTable { table, sequences })
}

/// Converts one column definition into a [`Column`], expanding PostgreSQL's
/// `SERIAL`-family pseudo-types into an integer column plus its backing
/// [`Sequence`].
///
/// * `table_schema`, `table_name` – the owning table; used to name the implicit
///   sequence (`{table_name}_{column}_seq`) and as its `OWNED BY` target.
/// * `col_def` – the parsed column definition.
///
/// Returns the column together with `Some(sequence)` for serial columns and
/// `None` otherwise.
///
/// A serial column is always NOT NULL (as in PostgreSQL) and defaults to
/// `nextval('<seq>'::regclass)`. The sequence reference is schema-qualified
/// unless the table lives in `public`. `NULL` / `DEFAULT` options written on a
/// serial column are ignored.
///
/// # Errors
/// Propagates errors from `parse_data_type` for non-serial columns.
pub(super) fn parse_column_with_serial(
    table_schema: &str,
    table_name: &str,
    col_def: &ColumnDef,
) -> Result<(Column, Option<Sequence>)> {
    let mut nullable = true;
    let mut default = None;

    for option in &col_def.options {
        match &option.option {
            ColumnOption::NotNull => nullable = false,
            ColumnOption::Null => nullable = true,
            ColumnOption::Default(expr) => default = Some(normalize_expr(&expr.to_string())),
            _ => {}
        }
    }

    let col_name = unquote_ident(&col_def.name.to_string()).to_string();

    let Some(seq_data_type) = detect_serial_type(&col_def.data_type) else {
        let column = Column {
            name: col_name,
            data_type: parse_data_type(&col_def.data_type)?,
            nullable,
            default,
            comment: None,
        };
        return Ok((column, None));
    };

    let seq_name = format!("{table_name}_{col_name}_seq");

    let (pg_type, max_value) = match seq_data_type {
        SequenceDataType::SmallInt => (PgType::SmallInt, 32767),
        SequenceDataType::Integer => (PgType::Integer, 2147483647),
        SequenceDataType::BigInt => (PgType::BigInt, 9223372036854775807),
    };

    // Objects in `public` are rendered unqualified in the stored default.
    let nextval_ref = if table_schema == "public" {
        seq_name.clone()
    } else {
        qualified_name(table_schema, &seq_name)
    };

    let column = Column {
        name: col_name.clone(),
        data_type: pg_type,
        // SERIAL implies NOT NULL in PostgreSQL, even when not written out.
        nullable: false,
        default: Some(format!("nextval('{nextval_ref}'::regclass)")),
        comment: None,
    };

    let sequence = Sequence {
        name: seq_name,
        schema: table_schema.to_string(),
        data_type: seq_data_type,
        start: Some(1),
        increment: Some(1),
        min_value: Some(1),
        max_value: Some(max_value),
        cycle: false,
        cache: Some(1),
        owned_by: Some(SequenceOwner {
            table_schema: table_schema.to_string(),
            table_name: table_name.to_string(),
            column_name: col_name,
        }),
        owner: None,
        grants: Vec::new(),
        comment: None,
    };

    Ok((column, Some(sequence)))
}

/// Recognizes PostgreSQL's serial pseudo-types and returns the integer width
/// of the sequence backing them, or `None` for any other type.
///
/// Accepts `smallserial`/`serial2`, `serial`/`serial4` and
/// `bigserial`/`serial8`. Matching is case-insensitive on the rendered
/// type name.
pub(super) fn detect_serial_type(dt: &DataType) -> Option<SequenceDataType> {
    let DataType::Custom(name, _) = dt else {
        return None;
    };
    match name.to_string().to_lowercase().as_str() {
        "smallserial" | "serial2" => Some(SequenceDataType::SmallInt),
        "serial" | "serial4" => Some(SequenceDataType::Integer),
        "bigserial" | "serial8" => Some(SequenceDataType::BigInt),
        _ => None,
    }
}

/// Translates sqlparser's optional `ON DELETE` / `ON UPDATE` action into the
/// model's [`ReferentialAction`]. An omitted action is treated as `NO ACTION`.
pub(super) fn parse_referential_action(action: &Option<SqlReferentialAction>) -> ReferentialAction {
    let Some(sql_action) = action else {
        return ReferentialAction::NoAction;
    };
    match sql_action {
        SqlReferentialAction::Cascade => ReferentialAction::Cascade,
        SqlReferentialAction::Restrict => ReferentialAction::Restrict,
        SqlReferentialAction::SetNull => ReferentialAction::SetNull,
        SqlReferentialAction::SetDefault => ReferentialAction::SetDefault,
        SqlReferentialAction::NoAction => ReferentialAction::NoAction,
    }
}

/// Interprets a `PARTITION BY` clause as a [`PartitionKey`].
///
/// The clause arrives as a function-call expression such as `RANGE (a, b)`.
/// The function name selects the strategy (`RANGE`/`LIST`/`HASH`,
/// case-insensitive). Anything else yields `None`. Each unnamed argument
/// becomes one entry in `columns`: bare identifiers contribute their unquoted
/// value, and other expressions contribute their SQL text.
///
/// NOTE(review): non-identifier keys are stored as text in `columns` and
/// `expressions` is always left empty — confirm this is what consumers expect.
fn parse_partition_by(expr: &Expr) -> Option<PartitionKey> {
    let Expr::Function(func) = expr else {
        return None;
    };

    let strategy = match func.name.to_string().to_uppercase().as_str() {
        "RANGE" => PartitionStrategy::Range,
        "LIST" => PartitionStrategy::List,
        "HASH" => PartitionStrategy::Hash,
        _ => return None,
    };

    let mut columns = Vec::new();
    if let FunctionArguments::List(arg_list) = &func.args {
        for arg in &arg_list.args {
            if let SqlFunctionArg::Unnamed(FunctionArgExpr::Expr(key)) = arg {
                let key_text = match key {
                    Expr::Identifier(ident) => ident.value.clone(),
                    other => other.to_string(),
                };
                columns.push(key_text);
            }
        }
    }

    Some(PartitionKey {
        strategy,
        columns,
        expressions: Vec::new(),
    })
}