pgmold 0.33.6

PostgreSQL schema-as-code management tool
Documentation
mod drizzle;

use crate::model::Schema;
use crate::parser::load_schema_sources;
use crate::util::SchemaError;

pub use drizzle::load_drizzle_schema;

/// Shorthand used by the loaders and merge helpers in this module; the error type is
/// always [`SchemaError`].
type Result<T> = std::result::Result<T, SchemaError>;

/// Loads every prefixed schema source and merges the results into one [`Schema`].
///
/// Each entry is expected to start with `sql:` or `drizzle:` (see `load_single_source`).
/// Sources are loaded in the order given, and loading stops at the first failure.
///
/// # Errors
///
/// Returns [`SchemaError::ParseError`] when `sources` is empty. Otherwise propagates the
/// first error produced while loading a source, or any error from merging the loaded schemas.
pub fn load_schema_from_sources(sources: &[String]) -> Result<Schema> {
    if sources.is_empty() {
        return Err(SchemaError::ParseError("No schema sources provided".to_string()));
    }

    let mut loaded = Vec::with_capacity(sources.len());
    for source in sources {
        // `?` bails out on the first failing source; later sources are never touched.
        loaded.push(load_single_source(source)?);
    }

    merge_schemas(loaded)
}

/// Dispatches one prefixed source string to the matching loader.
///
/// `sql:<path>` goes to `load_sql_source` and `drizzle:<path>` goes to
/// `load_drizzle_schema`. Only the text after the first `:` is passed on, so paths that
/// themselves contain colons are forwarded intact. Prefix matching is case-sensitive.
///
/// # Errors
///
/// Returns [`SchemaError::ParseError`] for an unrecognised prefix (or no prefix at all).
/// Otherwise propagates whatever the selected loader returns.
fn load_single_source(source: &str) -> Result<Schema> {
    match source.split_once(':') {
        Some(("sql", sql_path)) => load_sql_source(sql_path),
        Some(("drizzle", drizzle_path)) => load_drizzle_schema(drizzle_path),
        _ => Err(SchemaError::ParseError(format!(
            "Unknown schema source prefix: {source}. \
             Use 'sql:' for SQL files/directories or 'drizzle:' for Drizzle ORM configs."
        ))),
    }
}

/// Loads one SQL file or directory by handing it to the parser's multi-source loader
/// as a single-element slice.
fn load_sql_source(path: &str) -> Result<Schema> {
    let owned_path = path.to_owned();
    load_schema_sources(std::slice::from_ref(&owned_path))
}

/// Moves every entry of `source` into `target`, refusing names that already exist.
///
/// Entries are visited in `source`'s key order. On the first name already present in
/// `target`, merging stops. Entries visited before the clash remain inserted, and the
/// clashing and later entries are dropped.
///
/// # Errors
///
/// Returns [`SchemaError::ParseError`] naming `object_type` and the duplicated key.
fn merge_collection<V>(
    target: &mut std::collections::BTreeMap<String, V>,
    source: std::collections::BTreeMap<String, V>,
    object_type: &str,
) -> Result<()> {
    use std::collections::btree_map::Entry;

    source
        .into_iter()
        .try_for_each(|(key, item)| match target.entry(key) {
            Entry::Vacant(slot) => {
                slot.insert(item);
                Ok(())
            }
            Entry::Occupied(existing) => Err(SchemaError::ParseError(format!(
                "Duplicate {object_type} \"{}\" from multiple sources",
                existing.key()
            ))),
        })
}

/// Combines independently loaded schemas into one finalized [`Schema`].
///
/// A lone schema is finalized as-is. With several inputs, each named collection
/// (tables, enums, functions, views, triggers, sequences, domains, extensions, schemas,
/// partitions) is merged into a fresh `Schema::new()`, and a name appearing in more than
/// one input is rejected. Pending policies, owners, grants and revokes are concatenated
/// in input order and handed to the closing `finalize()` call, so they may target
/// objects that come from a different input.
///
/// # Errors
///
/// Returns [`SchemaError::ParseError`] when `schemas` is empty, when two inputs define the
/// same name in any merged collection, or when `finalize()` reports a problem.
fn merge_schemas(schemas: Vec<Schema>) -> Result<Schema> {
    let finalized = |mut schema: Schema| -> Result<Schema> {
        schema.finalize().map_err(SchemaError::ParseError)?;
        Ok(schema)
    };

    match schemas.len() {
        0 => Err(SchemaError::ParseError("No schemas to merge".to_string())),
        1 => {
            let only = schemas
                .into_iter()
                .next()
                .expect("length was checked to be exactly one");
            finalized(only)
        }
        _ => {
            // NOTE(review): only the fields listed below are carried into the merged
            // schema; any other `Schema` field keeps its `Schema::new()` value — confirm
            // this list stays in sync with the model.
            let mut merged = Schema::new();

            for part in schemas {
                // Order matters: it decides which duplicate is reported first.
                merge_collection(&mut merged.tables, part.tables, "table")?;
                merge_collection(&mut merged.enums, part.enums, "enum")?;
                merge_collection(&mut merged.functions, part.functions, "function")?;
                merge_collection(&mut merged.views, part.views, "view")?;
                merge_collection(&mut merged.triggers, part.triggers, "trigger")?;
                merge_collection(&mut merged.sequences, part.sequences, "sequence")?;
                merge_collection(&mut merged.domains, part.domains, "domain")?;
                merge_collection(&mut merged.extensions, part.extensions, "extension")?;
                merge_collection(&mut merged.schemas, part.schemas, "schema")?;
                merge_collection(&mut merged.partitions, part.partitions, "partition")?;

                merged.pending_policies.extend(part.pending_policies);
                merged.pending_owners.extend(part.pending_owners);
                merged.pending_grants.extend(part.pending_grants);
                merged.pending_revokes.extend(part.pending_revokes);
            }

            finalized(merged)
        }
    }
}

/// Integration-style tests that write SQL fixtures to temporary directories and load them
/// through `load_schema_from_sources`, covering source validation, cross-source
/// resolution of policies/owners/grants/revokes, and duplicate-name rejection.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Writes `content` to `directory/filename` and returns the resulting path.
    fn write_sql_file(directory: &TempDir, filename: &str, content: &[u8]) -> PathBuf {
        let path = directory.path().join(filename);
        std::fs::write(&path, content).unwrap();
        path
    }

    /// Builds a `sql:`-prefixed source string for `path`.
    fn sql_source(path: &Path) -> String {
        format!("sql:{}", path.display())
    }

    #[test]
    fn unknown_prefix_error() {
        let result = load_schema_from_sources(&["unknown:foo.sql".to_string()]);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Unknown schema source prefix"));
    }

    #[test]
    fn empty_sources_error() {
        let result = load_schema_from_sources(&[]);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("No schema sources provided"));
    }

    #[test]
    fn orphan_policy_errors_single_source() {
        let dir = TempDir::new().unwrap();
        let file = write_sql_file(
            &dir,
            "orphan.sql",
            b"CREATE POLICY orphan_policy ON nonexistent_table FOR ALL USING (true);",
        );

        let result = load_schema_from_sources(&[sql_source(&file)]);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("nonexistent_table"));
    }

    #[test]
    fn orphan_policy_errors_at_provider_level() {
        let dir1 = TempDir::new().unwrap();
        let dir2 = TempDir::new().unwrap();

        let table_file = write_sql_file(
            &dir1,
            "tables.sql",
            b"CREATE TABLE public.users (id serial PRIMARY KEY);",
        );
        let policy_file = write_sql_file(
            &dir2,
            "policies.sql",
            b"CREATE POLICY orphan_policy ON nonexistent_table FOR ALL USING (true);",
        );

        let result = load_schema_from_sources(&[sql_source(&table_file), sql_source(&policy_file)]);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("nonexistent_table"));
    }

    #[test]
    fn ownership_from_secondary_source_applied() {
        let dir1 = TempDir::new().unwrap();
        let dir2 = TempDir::new().unwrap();

        let table_file = write_sql_file(
            &dir1,
            "tables.sql",
            b"CREATE TABLE public.users (id serial PRIMARY KEY);",
        );
        let ownership_file = write_sql_file(
            &dir2,
            "ownership.sql",
            b"ALTER TABLE public.users OWNER TO app_user;",
        );

        let merged =
            load_schema_from_sources(&[sql_source(&table_file), sql_source(&ownership_file)])
                .unwrap();
        assert_eq!(
            merged.tables["public.users"].owner,
            Some("app_user".to_string())
        );
    }

    #[test]
    fn grant_from_secondary_source_applied() {
        let dir1 = TempDir::new().unwrap();
        let dir2 = TempDir::new().unwrap();

        let table_file = write_sql_file(
            &dir1,
            "tables.sql",
            b"CREATE TABLE public.users (id serial PRIMARY KEY);",
        );
        let grant_file = write_sql_file(
            &dir2,
            "grants.sql",
            b"GRANT SELECT, INSERT ON TABLE public.users TO readonly_user;",
        );

        let merged =
            load_schema_from_sources(&[sql_source(&table_file), sql_source(&grant_file)]).unwrap();
        let grants = &merged.tables["public.users"].grants;
        assert_eq!(grants.len(), 1);
        assert_eq!(grants[0].grantee, "readonly_user");
        assert_eq!(
            grants[0].privileges,
            std::collections::BTreeSet::from([
                crate::model::Privilege::Select,
                crate::model::Privilege::Insert,
            ])
        );
    }

    #[test]
    fn revoke_from_secondary_source_applied() {
        let dir1 = TempDir::new().unwrap();
        let dir2 = TempDir::new().unwrap();

        let table_file = write_sql_file(
            &dir1,
            "tables.sql",
            b"CREATE TABLE public.users (id serial PRIMARY KEY);\n\
              GRANT SELECT, INSERT ON TABLE public.users TO app_user;",
        );
        let revoke_file = write_sql_file(
            &dir2,
            "revokes.sql",
            b"REVOKE INSERT ON TABLE public.users FROM app_user;",
        );

        let merged =
            load_schema_from_sources(&[sql_source(&table_file), sql_source(&revoke_file)]).unwrap();
        let grants = &merged.tables["public.users"].grants;
        assert_eq!(grants.len(), 1);
        assert_eq!(grants[0].grantee, "app_user");
        assert_eq!(
            grants[0].privileges,
            std::collections::BTreeSet::from([crate::model::Privilege::Select])
        );
    }

    /// The same table defined in two different sources must be rejected by the
    /// duplicate check in `merge_collection` rather than silently overwritten.
    #[test]
    fn duplicate_table_across_sources_errors() {
        let dir1 = TempDir::new().unwrap();
        let dir2 = TempDir::new().unwrap();

        let first_file = write_sql_file(
            &dir1,
            "first.sql",
            b"CREATE TABLE public.users (id serial PRIMARY KEY);",
        );
        let second_file = write_sql_file(
            &dir2,
            "second.sql",
            b"CREATE TABLE public.users (id serial PRIMARY KEY);",
        );

        let result = load_schema_from_sources(&[sql_source(&first_file), sql_source(&second_file)]);
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Duplicate table"));
        assert!(err.contains("public.users"));
    }
}