graphile_worker_migrations_macros 0.1.1

Proc macros for graphile_worker_migrations
Documentation
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use syn::{parse_macro_input, LitStr};

const STATEMENT_DELIMITER: &str = "-- graphile-worker-rs:statement";
const BREAKING_MARKER: &str = "-- graphile-worker-rs:breaking";
const METADATA_PREFIX: &str = "-- graphile-worker-rs:";

#[proc_macro]
pub fn include_migrations(input: TokenStream) -> TokenStream {
    let relative_dir = parse_macro_input!(input as LitStr);

    match expand_include_migrations(&relative_dir.value()) {
        Ok(tokens) => tokens.into(),
        Err(error) => error.to_compile_error().into(),
    }
}

fn expand_include_migrations(relative_dir: &str) -> Result<proc_macro2::TokenStream, syn::Error> {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR")
        .map_err(|_| error("CARGO_MANIFEST_DIR is not set for include_migrations!"))?;
    let migration_dir = Path::new(&manifest_dir).join(relative_dir);
    let migrations = load_migrations(&migration_dir)?;

    let tracked_sources = migrations.iter().map(|migration| {
        let path = LitStr::new(
            &format!(
                "{}/{}.sql",
                relative_dir.trim_end_matches('/'),
                migration.name
            ),
            Span::call_site(),
        );

        quote! {
            include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/", #path))
        }
    });

    let migrations = migrations.iter().map(|migration| {
        let name = LitStr::new(&migration.name, Span::call_site());
        let is_breaking = migration.is_breaking;
        let statements = migration
            .statements
            .iter()
            .map(|statement| LitStr::new(statement, Span::call_site()));

        quote! {
            ::graphile_worker_migrations_core::GraphileWorkerMigration {
                name: #name,
                is_breaking: #is_breaking,
                stmts: &[#(#statements),*],
            }
        }
    });

    Ok(quote! {
        {
            const _GRAPHILE_WORKER_MIGRATION_SQL_FILES: &[&str] = &[#(#tracked_sources),*];
            &[#(#migrations),*]
        }
    })
}

fn load_migrations(migration_dir: &Path) -> Result<Vec<Migration>, syn::Error> {
    let entries = fs::read_dir(migration_dir).map_err(|source| {
        error(format!(
            "failed to read migration directory `{}`: {source}",
            migration_dir.display()
        ))
    })?;

    let mut migrations = Vec::new();
    for entry in entries {
        let entry = entry.map_err(|source| {
            error(format!(
                "failed to read an entry in migration directory `{}`: {source}",
                migration_dir.display()
            ))
        })?;
        let path = entry.path();

        if !path.is_file()
            || path.extension().and_then(|extension| extension.to_str()) != Some("sql")
        {
            continue;
        }

        migrations.push(load_migration_file(path)?);
    }

    sort_and_validate_migrations(migrations)
}

fn load_migration_file(path: PathBuf) -> Result<Migration, syn::Error> {
    let file_name = path
        .file_name()
        .and_then(|file_name| file_name.to_str())
        .ok_or_else(|| {
            error(format!(
                "migration path `{}` has no UTF-8 filename",
                path.display()
            ))
        })?;
    let number = parse_migration_number(file_name)?;
    let script = fs::read_to_string(&path).map_err(|source| {
        error(format!(
            "failed to read migration `{}`: {source}",
            path.display()
        ))
    })?;
    let parsed = parse_migration_script(&script);

    if parsed.statements.is_empty() {
        return Err(error(format!(
            "migration `{}` does not contain any SQL statements",
            path.display()
        )));
    }

    Ok(Migration {
        number,
        name: format!("m{number:06}"),
        is_breaking: parsed.is_breaking,
        statements: parsed.statements,
    })
}

fn parse_migration_number(file_name: &str) -> Result<u32, syn::Error> {
    if file_name.len() != "m000001.sql".len()
        || !file_name.starts_with('m')
        || !file_name.ends_with(".sql")
    {
        return Err(error(format!(
            "migration filename `{file_name}` must match `mNNNNNN.sql`"
        )));
    }

    let number = &file_name[1..7];
    if !number.as_bytes().iter().all(u8::is_ascii_digit) {
        return Err(error(format!(
            "migration filename `{file_name}` must contain six digits"
        )));
    }

    number.parse().map_err(|source| {
        error(format!(
            "failed to parse migration number from `{file_name}`: {source}"
        ))
    })
}

fn parse_migration_script(script: &str) -> ParsedMigrationScript {
    let is_breaking = script.lines().any(|line| line.trim() == BREAKING_MARKER);
    let statements = script
        .split(STATEMENT_DELIMITER)
        .filter_map(|section| {
            let statement = strip_metadata_lines(section);
            let statement = statement.trim();

            if statement.is_empty() {
                return None;
            }

            Some(statement.to_owned())
        })
        .collect();

    ParsedMigrationScript {
        is_breaking,
        statements,
    }
}

fn strip_metadata_lines(section: &str) -> String {
    section
        .lines()
        .filter(|line| !line.trim().starts_with(METADATA_PREFIX))
        .collect::<Vec<_>>()
        .join("\n")
}

fn sort_and_validate_migrations(
    mut migrations: Vec<Migration>,
) -> Result<Vec<Migration>, syn::Error> {
    if migrations.is_empty() {
        return Err(error("no migration SQL files found"));
    }

    migrations.sort_by_key(|migration| migration.number);

    for (index, migration) in migrations.iter().enumerate() {
        if index > 0 && migrations[index - 1].number == migration.number {
            return Err(error(format!(
                "duplicate migration number m{:06}",
                migration.number
            )));
        }

        let expected = index as u32 + 1;
        if migration.number != expected {
            return Err(error(format!(
                "migration numbers must be contiguous from m000001; expected m{expected:06}, found m{:06}",
                migration.number
            )));
        }
    }

    Ok(migrations)
}

fn error(message: impl Into<String>) -> syn::Error {
    syn::Error::new(Span::call_site(), message.into())
}

#[derive(Debug)]
struct Migration {
    number: u32,
    name: String,
    is_breaking: bool,
    statements: Vec<String>,
}

struct ParsedMigrationScript {
    is_breaking: bool,
    statements: Vec<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn migration(number: u32) -> Migration {
        Migration {
            number,
            name: format!("m{number:06}"),
            is_breaking: false,
            statements: vec!["select 1".to_owned()],
        }
    }

    #[test]
    fn parses_valid_migration_filenames() {
        assert_eq!(parse_migration_number("m000001.sql").unwrap(), 1);
        assert_eq!(parse_migration_number("m123456.sql").unwrap(), 123456);
    }

    #[test]
    fn rejects_invalid_migration_filenames() {
        assert!(parse_migration_number("000001.sql").is_err());
        assert!(parse_migration_number("m00001.sql").is_err());
        assert!(parse_migration_number("m000001_extra.sql").is_err());
        assert!(parse_migration_number("m00000x.sql").is_err());
    }

    #[test]
    fn sorts_migrations_by_number() {
        let migrations = sort_and_validate_migrations(vec![migration(2), migration(1)]).unwrap();

        assert_eq!(migrations[0].number, 1);
        assert_eq!(migrations[1].number, 2);
    }

    #[test]
    fn rejects_non_contiguous_migrations() {
        let error = sort_and_validate_migrations(vec![migration(1), migration(3)]).unwrap_err();

        assert!(error
            .to_string()
            .contains("migration numbers must be contiguous"));
    }

    #[test]
    fn detects_breaking_marker_and_removes_metadata() {
        let parsed = parse_migration_script(
            "\n-- graphile-worker-rs:breaking\n\nselect 1;\n-- normal comment\n",
        );

        assert!(parsed.is_breaking);
        assert_eq!(parsed.statements, vec!["select 1;\n-- normal comment"]);
    }

    #[test]
    fn splits_statements_and_skips_empty_sections() {
        let parsed = parse_migration_script(
            "\nselect 1;\n-- graphile-worker-rs:statement\n\n-- graphile-worker-rs:statement\nselect 2;\n",
        );

        assert_eq!(parsed.statements, vec!["select 1;", "select 2;"]);
    }

    #[test]
    fn empty_script_has_no_statements() {
        let parsed = parse_migration_script("\n-- graphile-worker-rs:breaking\n\n");

        assert!(parsed.statements.is_empty());
    }
}