sqlx-gen 0.5.4

Generate Rust structs from database schema introspection
Documentation
use clap::Parser;
use log::{info, warn};
use sqlx::{MySqlPool, PgPool, SqlitePool};

use sqlx_gen::cli::{Cli, Command, CrudArgs, DatabaseKind, EntitiesArgs, GenerateCommand, Methods};
use sqlx_gen::codegen;
use sqlx_gen::error::Result;
use sqlx_gen::introspect;
use sqlx_gen::writer;

/// CLI entry point.
///
/// Initialises `env_logger` (defaulting to the `info` level when no filter is
/// configured in the environment), parses the command line, and hands control to
/// the handler for the selected `generate` subcommand, propagating its result.
#[tokio::main]
async fn main() -> Result<()> {
    let log_env = env_logger::Env::default().default_filter_or("info");
    env_logger::Builder::from_env(log_env).init();

    let cli = Cli::parse();
    match cli.command {
        Command::Generate {
            subcommand: GenerateCommand::Entities(args),
        } => run_entities(args).await,
        Command::Generate {
            subcommand: GenerateCommand::Crud(args),
        } => run_crud(args),
    }
}

/// Runs `generate entities`.
///
/// The steps are:
/// 1. Connect to the configured database and introspect its schema.
/// 2. Apply the table filters:
///    * the `tables` whitelist is applied to tables only;
///    * the `exclude_tables` blacklist removes matching tables *and* views.
/// 3. Generate Rust sources via `codegen::generate` and hand them to
///    `writer::write_files`.
/// 4. Unless this is a dry run, log a summary and run rustfmt over the output
///    directory.
async fn run_entities(args: EntitiesArgs) -> Result<()> {
    let db_kind = args.db.database_kind()?;
    let type_overrides = args.parse_type_overrides();

    let backend = match db_kind {
        DatabaseKind::Postgres => "PostgreSQL",
        DatabaseKind::Mysql => "MySQL",
        DatabaseKind::Sqlite => "SQLite",
    };
    info!("Connecting to {} database...", backend);

    let url: &str = &args.db.database_url;
    // Each backend opens its own pool; it is closed explicitly once introspection
    // has returned successfully.
    let mut schema_info = match db_kind {
        DatabaseKind::Postgres => {
            let pool = PgPool::connect(url).await?;
            let introspected =
                introspect::postgres::introspect(&pool, &args.db.schemas, args.views).await?;
            pool.close().await;
            introspected
        }
        DatabaseKind::Mysql => {
            let pool = MySqlPool::connect(url).await?;
            let introspected =
                introspect::mysql::introspect(&pool, &args.db.schemas, args.views).await?;
            pool.close().await;
            introspected
        }
        // The SQLite introspector is not given the `schemas` list.
        DatabaseKind::Sqlite => {
            let pool = SqlitePool::connect(url).await?;
            let introspected = introspect::sqlite::introspect(&pool, args.views).await?;
            pool.close().await;
            introspected
        }
    };

    // Whitelist: keep only the named tables (views are not filtered here).
    if let Some(allowed) = &args.tables {
        schema_info.tables.retain(|table| allowed.contains(&table.name));
    }
    // Blacklist: drop the named tables and views.
    if let Some(excluded) = &args.exclude_tables {
        schema_info.tables.retain(|table| !excluded.contains(&table.name));
        schema_info.views.retain(|view| !excluded.contains(&view.name));
    }

    info!(
        "Found {} tables, {} views, {} enums, {} composite types, {} domains",
        schema_info.tables.len(),
        schema_info.views.len(),
        schema_info.enums.len(),
        schema_info.composite_types.len(),
        schema_info.domains.len(),
    );

    let files = codegen::generate(
        &schema_info,
        db_kind,
        &args.derives,
        &type_overrides,
        args.single_file,
        args.time_crate,
    );
    writer::write_files(&files, &args.output_dir, args.single_file, args.dry_run)?;

    if args.dry_run {
        return Ok(());
    }
    info!("Done! Generated {} files.", files.len() + 1); // +1 for mod.rs
    rustfmt_dir(&args.output_dir);
    Ok(())
}

/// Runs `generate crud` for a single entity source file.
///
/// The entity file is parsed and repository code is generated for the requested
/// methods. Then, depending on the mode:
/// * with `dry_run`, the formatted code is printed to stdout;
/// * otherwise it is written to `<output_dir>/<stem>_repository.rs`, formatted
///   with rustfmt, and registered in `<output_dir>/mod.rs`.
///
/// # Errors
/// Returns a `Config` error in two cases:
/// * the resolved entities module looks like a file path (contains `/` or `\`,
///   or ends in `.rs`);
/// * `methods` is empty, or `Methods::from_list` rejects it.
///
/// Errors from parsing and file I/O are propagated.
fn run_crud(args: CrudArgs) -> Result<()> {
    let db_kind = args.database_kind()?;
    let entities_module = args.resolve_entities_module()?;

    // The value is handed to the generator as a Rust module path, so anything that
    // looks like a filesystem path is rejected up front.
    let looks_like_file_path = entities_module.contains('/')
        || entities_module.contains('\\')
        || entities_module.ends_with(".rs");
    if looks_like_file_path {
        return Err(sqlx_gen::error::Error::Config(format!(
            "--entities-module must be a Rust module path (e.g. \"crate::models::users\"), got \"{}\"",
            entities_module
        )));
    }

    let entity = codegen::entity_parser::parse_entity_file(&args.entity_file)?;
    info!(
        "Parsed entity '{}' from {}",
        entity.struct_name,
        args.entity_file.display()
    );
    if entity.is_view {
        info!("  (detected as view - write methods will be skipped)");
    }

    if args.methods.is_empty() {
        return Err(sqlx_gen::error::Error::Config(
            "--methods is required. Use -m '*' for all, or specify methods: get_all,paginate,get,insert,insert_many,update,overwrite,delete".to_string()
        ));
    }
    let methods = Methods::from_list(&args.methods).map_err(sqlx_gen::error::Error::Config)?;

    let (tokens, imports) = codegen::crud_gen::generate_crud_from_parsed(
        &entity,
        db_kind,
        &entities_module,
        &methods,
        args.query_macro,
        args.pool_visibility,
    );
    let code = codegen::format_tokens_with_imports_and_tab_spaces(
        &tokens,
        &imports,
        codegen::detect_tab_spaces(&args.output_dir),
    );

    if args.dry_run {
        println!("{}", code);
        return Ok(());
    }

    std::fs::create_dir_all(&args.output_dir)?;

    // Prefer the entity file's stem for naming; fall back to the parsed table name.
    let stem = args
        .entity_file
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or(&entity.table_name);
    let mod_name = format!("{}_repository", codegen::normalize_module_name(stem));
    let file_path = args.output_dir.join(format!("{}.rs", mod_name));

    std::fs::write(
        &file_path,
        format!("// Auto-generated by sqlx-gen. Do not edit.\n\n{}", code),
    )?;
    info!("Wrote {}", file_path.display());
    rustfmt_file(&file_path, &detect_edition(&args.output_dir));

    update_mod_rs(&args.output_dir, &mod_name)
}

/// Ensures `<dir>/mod.rs` contains the line `pub mod <mod_name>;`.
///
/// * If `mod.rs` does not exist, it is created with just that declaration.
/// * If a line already matches the declaration (ignoring surrounding whitespace),
///   nothing is written.
/// * Otherwise the declaration is appended after the existing content, with
///   trailing whitespace trimmed and a single trailing newline.
fn update_mod_rs(dir: &std::path::Path, mod_name: &str) -> Result<()> {
    let mod_path = dir.join("mod.rs");
    let declaration = format!("pub mod {};", mod_name);

    if !mod_path.is_file() {
        std::fs::write(&mod_path, format!("{}\n", declaration))?;
        info!("Created {}", mod_path.display());
        return Ok(());
    }

    let existing = std::fs::read_to_string(&mod_path)?;
    let already_declared = existing.lines().map(str::trim).any(|l| l == declaration);
    if already_declared {
        return Ok(());
    }

    let body = existing.trim_end();
    let updated = if body.is_empty() {
        format!("{}\n", declaration)
    } else {
        format!("{}\n{}\n", body, declaration)
    };
    std::fs::write(&mod_path, updated)?;
    info!("Updated {}", mod_path.display());
    Ok(())
}

/// Runs rustfmt on every `.rs` file directly inside `dir` (not recursive).
///
/// The edition comes from `detect_edition(dir)`. This is best-effort: an
/// unreadable directory, or entries that fail to read, are silently skipped.
fn rustfmt_dir(dir: &std::path::Path) {
    let edition = detect_edition(dir);
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    entries
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        .filter(|path| path.extension().is_some_and(|ext| ext == "rs"))
        .for_each(|path| rustfmt_file(&path, &edition));
}

/// Formats one file by running `rustfmt --edition <edition> <path>`.
///
/// Formatting is best-effort: every failure is logged as a warning and never
/// aborts the caller.
/// * A non-zero exit status is reported with that status.
/// * A missing `rustfmt` binary is reported as "not found".
/// * Any other spawn failure (e.g. permission denied) is reported with the
///   underlying I/O error rather than being mislabelled as "not found".
fn rustfmt_file(path: &std::path::Path, edition: &str) {
    let outcome = std::process::Command::new("rustfmt")
        .arg("--edition")
        .arg(edition)
        .arg(path)
        .status();

    match outcome {
        Ok(status) if status.success() => {}
        Ok(status) => warn!("rustfmt exited with {}", status),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            warn!("rustfmt not found, skipping formatting")
        }
        // The binary exists (or the lookup failed for another reason); surface the
        // real cause instead of claiming rustfmt is missing.
        Err(err) => warn!("failed to run rustfmt on {}: {}", path.display(), err),
    }
}

/// Picks the Rust edition to pass to rustfmt for files under `start_dir`.
///
/// The search starts at `start_dir`, or at its parent when it is a file, after
/// canonicalizing the path. It then visits that directory and each ancestor up
/// to the filesystem root. The first `Cargo.toml` that is readable and yields an
/// edition via `parse_edition_from_cargo_toml` wins.
///
/// Returns `"2021"` when canonicalization fails or no manifest along the way
/// yields an edition.
fn detect_edition(start_dir: &std::path::Path) -> String {
    const FALLBACK_EDITION: &str = "2021";

    let base = Some(start_dir)
        .filter(|p| p.is_file())
        .and_then(|p| p.parent())
        .unwrap_or(start_dir);

    let Ok(canonical) = base.canonicalize() else {
        return FALLBACK_EDITION.to_string();
    };

    // `ancestors()` yields the directory itself first, then each parent in turn;
    // the chain is lazy, so it stops at the first manifest that yields an edition.
    canonical
        .ancestors()
        .map(|dir| dir.join("Cargo.toml"))
        .filter(|manifest| manifest.is_file())
        .filter_map(|manifest| std::fs::read_to_string(manifest).ok())
        .find_map(|content| parse_edition_from_cargo_toml(&content))
        .unwrap_or_else(|| FALLBACK_EDITION.to_string())
}

/// Extracts the `edition` value from the text of a `Cargo.toml` manifest.
///
/// This is a line-based scan, not a full TOML parser. It looks at lines of the
/// form `edition = <value>` (in any table), and the first one that yields a
/// non-empty, usable value wins:
/// * Quoted values (`"2021"` or `'2021'`) are read up to the closing quote, so a
///   trailing comment such as `edition = "2021" # note` is ignored.
/// * Bare values are read up to an optional `#` comment. They are accepted only
///   when made solely of ASCII letters/digits, which rejects inline tables such
///   as `edition = { workspace = true }`.
/// * Dotted keys such as `edition.workspace = true` and commented-out lines
///   never match.
///
/// Returns `None` when no usable edition is found.
fn parse_edition_from_cargo_toml(content: &str) -> Option<String> {
    content.lines().find_map(|line| {
        let rest = line.trim().strip_prefix("edition")?;
        let value = rest.trim_start().strip_prefix('=')?.trim();

        let edition = match value.chars().next()? {
            quote @ ('"' | '\'') => {
                // The opening quote is a single ASCII byte, so slicing at 1 is safe.
                // Stop at the matching closing quote; tolerate an unterminated string.
                let inner = &value[1..];
                inner.find(quote).map_or(inner, |end| &inner[..end]).trim()
            }
            _ => {
                let bare = value.split_once('#').map_or(value, |(v, _)| v).trim();
                if !bare.chars().all(|c| c.is_ascii_alphanumeric()) {
                    return None;
                }
                bare
            }
        };

        (!edition.is_empty()).then(|| edition.to_string())
    })
}