sqlx-gen 0.5.4

Generate Rust structs from database schema introspection
Documentation
use std::path::Path;

use crate::error::Result;

use crate::codegen::GeneratedFile;

/// Header line written at the very top of every generated file: each per-table file,
/// the single-file `models.rs`, and the generated `mod.rs`.
const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit.";
/// Inner attribute emitted after the header (and optional origin line) in every
/// generated file.
/// NOTE(review): presumably silences `unused_attributes` warnings caused by
/// attributes in the generated code — confirm against the codegen output.
const INNER_ATTR: &str = "#![allow(unused_attributes)]";

pub fn write_files(
    files: &[GeneratedFile],
    output_dir: &Path,
    single_file: bool,
    dry_run: bool,
) -> Result<()> {
    if dry_run {
        for f in files {
            println!("{}", build_file_content(f));
            println!();
        }
        return Ok(());
    }

    std::fs::create_dir_all(output_dir)?;

    if single_file {
        write_single_file(files, output_dir)?;
    } else {
        write_multi_files(files, output_dir)?;
    }

    Ok(())
}

/// Renders one generated file exactly as it is written in multi-file mode (and as it
/// is printed in dry-run mode).
///
/// The layout is, in order:
/// 1. the [`COMMENT`] header line;
/// 2. an optional `// <origin>` line;
/// 3. a blank line;
/// 4. [`INNER_ATTR`];
/// 5. a blank line;
/// 6. the generated code.
///
/// The returned text always ends with a newline. This keeps the emitted file a
/// well-formed text file, matching the newline `write_single_file` appends after
/// each section.
fn build_file_content(f: &GeneratedFile) -> String {
    let origin_line = f
        .origin
        .as_ref()
        .map(|origin| format!("// {}\n", origin))
        .unwrap_or_default();
    let mut content = format!("{}\n{}\n{}\n\n{}", COMMENT, origin_line, INNER_ATTR, f.code);
    // The generated code may or may not end with a newline. Only add one when it is
    // missing, so code that already ends cleanly is not padded with a blank line.
    if !content.ends_with('\n') {
        content.push('\n');
    }
    content
}

/// Writes every generated file into a single `models.rs` inside `output_dir`.
///
/// The file starts with [`COMMENT`] and [`INNER_ATTR`]. Each entry follows, preceded
/// by a `// --- <origin> ---` separator when it has an origin, and terminated with a
/// newline.
///
/// All entries share one module, so a top-level `use` line that appears verbatim in
/// more than one entry would be declared twice. That is a hard compile error (E0252)
/// in the generated file. Only the first occurrence of each such line is kept; all
/// other code is copied unchanged.
///
/// Only exact single-line duplicates are detected. Distinct `use` lines that import
/// the same name, or multi-line `use` groups, are not merged.
///
/// # Errors
///
/// Propagates any I/O failure from writing `models.rs`.
fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
    let mut content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
    let mut seen_uses = std::collections::HashSet::new();

    for f in files {
        if let Some(origin) = &f.origin {
            content.push_str(&format!("// --- {} ---\n\n", origin));
        }
        // `split_inclusive` keeps each line's terminator, so everything that is not a
        // dropped duplicate is reproduced byte-for-byte.
        for line in f.code.split_inclusive('\n') {
            let stmt = line.trim_end();
            // Column-0 `use` statements are module-level. Indented ones live in nested
            // scopes and are left alone, as are `pub use` re-exports.
            let is_top_level_use = stmt.starts_with("use ") && stmt.ends_with(';');
            if is_top_level_use && !seen_uses.insert(stmt) {
                continue;
            }
            content.push_str(line);
        }
        content.push('\n');
    }

    let path = output_dir.join("models.rs");
    std::fs::write(&path, &content)?;
    log::info!("Wrote {}", path.display());

    Ok(())
}

/// Writes each generated file to `output_dir/<filename>`, then writes a `mod.rs`.
///
/// `mod.rs` declares one `pub mod` per written file, in input order. Each module name
/// is the filename with any trailing `.rs` removed.
///
/// Writing stops at the first I/O failure, in which case `mod.rs` is not written.
///
/// # Errors
///
/// Propagates any I/O failure from writing a module file or `mod.rs`.
fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
    let module_names = files
        .iter()
        .map(|file| -> Result<String> {
            let target = output_dir.join(&file.filename);
            std::fs::write(&target, build_file_content(file))?;
            log::info!("Wrote {}", target.display());
            Ok(file
                .filename
                .strip_suffix(".rs")
                .unwrap_or(&file.filename)
                .to_string())
        })
        .collect::<Result<Vec<_>>>()?;

    // Generate mod.rs
    let declarations: String = module_names
        .iter()
        .map(|name| format!("pub mod {};\n", name))
        .collect();
    let mod_content = format!("{}\n\n{}\n\n{}", COMMENT, INNER_ATTR, declarations);

    let mod_path = output_dir.join("mod.rs");
    std::fs::write(&mod_path, &mod_content)?;
    log::info!("Wrote {}", mod_path.display());

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    /// Builds a `GeneratedFile` fixture from borrowed string parts.
    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
        let origin = origin.map(str::to_owned);
        GeneratedFile {
            filename: filename.to_owned(),
            origin,
            code: code.to_owned(),
        }
    }

    /// Performs a real (non-dry-run) write into a fresh temp dir.
    ///
    /// The `TempDir` guard is returned so the directory outlives the caller's
    /// assertions. Dropping it deletes the directory.
    fn write_into_tempdir(files: &[GeneratedFile], single_file: bool) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        write_files(files, dir.path(), single_file, false).unwrap();
        dir
    }

    /// Reads a generated file back from `dir` as text.
    fn read_output(dir: &tempfile::TempDir, name: &str) -> String {
        std::fs::read_to_string(dir.path().join(name)).unwrap()
    }

    /// Asserts that every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                haystack.contains(needle),
                "expected output to contain {:?}",
                needle
            );
        }
    }

    // ========== build_file_content ==========

    #[test]
    fn test_build_content_with_origin() {
        let f = make_file("users.rs", "pub struct Users {}", Some("Table: public.users"));
        assert_contains_all(
            &build_file_content(&f),
            &[COMMENT, INNER_ATTR, "// Table: public.users", "pub struct Users {}"],
        );
    }

    #[test]
    fn test_build_content_without_origin() {
        let content = build_file_content(&make_file("types.rs", "pub enum Status {}", None));
        assert_contains_all(&content, &[COMMENT, INNER_ATTR, "pub enum Status {}"]);
        assert!(!content.contains("// Table:"));
    }

    #[test]
    fn test_build_content_header_value() {
        let content = build_file_content(&make_file("x.rs", "", None));
        assert!(content.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let code = "use chrono::NaiveDateTime;\n\npub struct Foo {\n    pub x: i32,\n}";
        let content = build_file_content(&make_file("foo.rs", code, None));
        assert!(content.contains(code));
    }

    #[test]
    fn test_build_content_origin_format() {
        let f = make_file("x.rs", "code", Some("Table: public.users"));
        assert!(build_file_content(&f).contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let content = build_file_content(&make_file("x.rs", "", Some("Table: public.x")));
        assert_contains_all(&content, &[COMMENT, INNER_ATTR, "// Table: public.x"]);
    }

    // ========== write_files dry_run ==========

    #[test]
    fn test_dry_run_returns_ok() {
        let files = [make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, true).is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let files = [make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("output");
        // Only the absence of filesystem side effects matters here, not the result.
        let _ = write_files(&files, &sub, false, true);
        // Output dir should NOT be created in dry_run mode
        assert!(!sub.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&[], dir.path(), false, true).is_ok());
    }

    // ========== write_multi_files ==========

    #[test]
    fn test_multi_creates_files_and_mod() {
        let files = [
            make_file("users.rs", "pub struct Users {}", Some("Table: public.users")),
            make_file("posts.rs", "pub struct Posts {}", Some("Table: public.posts")),
        ];
        let dir = write_into_tempdir(&files, false);
        for name in ["users.rs", "posts.rs", "mod.rs"] {
            assert!(dir.path().join(name).exists(), "missing {}", name);
        }
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let files = [
            make_file("users.rs", "code", Some("origin")),
            make_file("types.rs", "code", None),
        ];
        let dir = write_into_tempdir(&files, false);
        let mod_content = read_output(&dir, "mod.rs");
        assert_contains_all(&mod_content, &["pub mod users;", "pub mod types;"]);
    }

    #[test]
    fn test_multi_file_has_header() {
        let files = [make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = write_into_tempdir(&files, false);
        assert!(read_output(&dir, "users.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let files = [make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = write_into_tempdir(&files, false);
        assert!(read_output(&dir, "users.rs").contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("nested").join("output");
        let files = [make_file("users.rs", "code", Some("origin"))];
        write_files(&files, &sub, false, false).unwrap();
        assert!(sub.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let files = [make_file("types.rs", "code", None)];
        let dir = write_into_tempdir(&files, false);
        let content = read_output(&dir, "types.rs");
        assert_contains_all(&content, &[COMMENT, INNER_ATTR]);
        // No origin line
        assert!(!content.contains("// Table:"));
    }

    // ========== write_single_file ==========

    #[test]
    fn test_single_creates_models_rs() {
        let files = [make_file("users.rs", "pub struct Users {}", Some("Table: public.users"))];
        let dir = write_into_tempdir(&files, true);
        assert!(dir.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_starts_with_header() {
        let files = [make_file("users.rs", "code", Some("origin"))];
        let dir = write_into_tempdir(&files, true);
        assert!(read_output(&dir, "models.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_single_has_section_separator() {
        let files = [make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = write_into_tempdir(&files, true);
        assert!(read_output(&dir, "models.rs").contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let files = [
            make_file("users.rs", "struct Users;", Some("Table: public.users")),
            make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
        ];
        let dir = write_into_tempdir(&files, true);
        assert_contains_all(
            &read_output(&dir, "models.rs"),
            &["struct Users;", "struct Posts;"],
        );
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let files = [make_file("types.rs", "code", None)];
        let dir = write_into_tempdir(&files, true);
        assert!(!read_output(&dir, "models.rs").contains("// ---"));
    }
}