use clap::Parser;
use log::{info, warn};
use sqlx::{MySqlPool, PgPool, SqlitePool};
use sqlx_gen::cli::{Cli, Command, CrudArgs, DatabaseKind, EntitiesArgs, GenerateCommand, Methods};
use sqlx_gen::codegen;
use sqlx_gen::error::Result;
use sqlx_gen::introspect;
use sqlx_gen::writer;
/// Program entry point: sets up logging, parses the command line and
/// dispatches to the selected `generate` subcommand.
///
/// Logging defaults to the `info` filter unless the environment overrides it.
#[tokio::main]
async fn main() -> Result<()> {
    let log_env = env_logger::Env::default().default_filter_or("info");
    env_logger::Builder::from_env(log_env).init();

    let cli = Cli::parse();

    // Each arm yields the handler's `Result`, which becomes the exit status.
    match cli.command {
        Command::Generate {
            subcommand: GenerateCommand::Entities(args),
        } => run_entities(args).await,
        Command::Generate {
            subcommand: GenerateCommand::Crud(args),
        } => run_crud(args),
    }
}
/// Handles `generate entities`: introspects the database, applies the table
/// include/exclude filters, generates code and writes it to the output
/// directory.
///
/// # Errors
///
/// Propagates failures from resolving the database kind, connecting,
/// introspecting, and writing the generated files.
async fn run_entities(args: EntitiesArgs) -> Result<()> {
    let db_kind = args.db.database_kind()?;
    let type_overrides = args.parse_type_overrides();

    let db_label = match db_kind {
        DatabaseKind::Postgres => "PostgreSQL",
        DatabaseKind::Mysql => "MySQL",
        DatabaseKind::Sqlite => "SQLite",
    };
    info!("Connecting to {} database...", db_label);

    // Each backend opens its own pool, introspects, then closes the pool
    // before code generation starts.
    let mut schema_info = match db_kind {
        DatabaseKind::Postgres => {
            let pool = PgPool::connect(&args.db.database_url).await?;
            let introspected =
                introspect::postgres::introspect(&pool, &args.db.schemas, args.views).await?;
            pool.close().await;
            introspected
        }
        DatabaseKind::Mysql => {
            let pool = MySqlPool::connect(&args.db.database_url).await?;
            let introspected =
                introspect::mysql::introspect(&pool, &args.db.schemas, args.views).await?;
            pool.close().await;
            introspected
        }
        DatabaseKind::Sqlite => {
            // The SQLite introspector is not given a schema list.
            let pool = SqlitePool::connect(&args.db.database_url).await?;
            let introspected = introspect::sqlite::introspect(&pool, args.views).await?;
            pool.close().await;
            introspected
        }
    };

    // The include filter narrows tables only; the exclude filter removes
    // matching names from both tables and views.
    if let Some(wanted) = &args.tables {
        schema_info.tables.retain(|table| wanted.contains(&table.name));
    }
    if let Some(unwanted) = &args.exclude_tables {
        schema_info
            .tables
            .retain(|table| !unwanted.contains(&table.name));
        schema_info
            .views
            .retain(|view| !unwanted.contains(&view.name));
    }

    info!(
        "Found {} tables, {} views, {} enums, {} composite types, {} domains",
        schema_info.tables.len(),
        schema_info.views.len(),
        schema_info.enums.len(),
        schema_info.composite_types.len(),
        schema_info.domains.len(),
    );

    let files = codegen::generate(
        &schema_info,
        db_kind,
        &args.derives,
        &type_overrides,
        args.single_file,
        args.time_crate,
    );
    writer::write_files(&files, &args.output_dir, args.single_file, args.dry_run)?;

    if args.dry_run {
        return Ok(());
    }

    // NOTE(review): the reported count is `files.len() + 1`; presumably the
    // extra one is a file the writer emits beyond `files` — confirm,
    // especially for single-file mode.
    info!("Done! Generated {} files.", files.len() + 1);
    rustfmt_dir(&args.output_dir);
    Ok(())
}
fn run_crud(args: CrudArgs) -> Result<()> {
let db_kind = args.database_kind()?;
let entities_module = args.resolve_entities_module()?;
if entities_module.contains('/') || entities_module.contains('\\') || entities_module.ends_with(".rs") {
return Err(sqlx_gen::error::Error::Config(format!(
"--entities-module must be a Rust module path (e.g. \"crate::models::users\"), got \"{}\"",
entities_module
)));
}
let entity = codegen::entity_parser::parse_entity_file(&args.entity_file)?;
info!(
"Parsed entity '{}' from {}",
entity.struct_name,
args.entity_file.display()
);
if entity.is_view {
info!(" (detected as view - write methods will be skipped)");
}
if args.methods.is_empty() {
return Err(sqlx_gen::error::Error::Config(
"--methods is required. Use -m '*' for all, or specify methods: get_all,paginate,get,insert,insert_many,update,overwrite,delete".to_string()
));
}
let methods = Methods::from_list(&args.methods).map_err(sqlx_gen::error::Error::Config)?;
let (tokens, imports) = codegen::crud_gen::generate_crud_from_parsed(
&entity,
db_kind,
&entities_module,
&methods,
args.query_macro,
args.pool_visibility,
);
let tab_spaces = codegen::detect_tab_spaces(&args.output_dir);
let code = codegen::format_tokens_with_imports_and_tab_spaces(&tokens, &imports, tab_spaces);
if args.dry_run {
println!("{}", code);
} else {
std::fs::create_dir_all(&args.output_dir)?;
let entity_stem = args
.entity_file
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or(&entity.table_name);
let normalized = codegen::normalize_module_name(entity_stem);
let filename = format!("{}_repository.rs", normalized);
let file_path = args.output_dir.join(&filename);
let content = format!(
"// Auto-generated by sqlx-gen. Do not edit.\n\n{}",
code
);
std::fs::write(&file_path, &content)?;
info!("Wrote {}", file_path.display());
let edition = detect_edition(&args.output_dir);
rustfmt_file(&file_path, &edition);
let mod_name = format!("{}_repository", normalized);
update_mod_rs(&args.output_dir, &mod_name)?;
}
Ok(())
}
/// Ensures `dir/mod.rs` contains a `pub mod <mod_name>;` declaration.
///
/// Creates the file when it does not exist, appends the declaration when it
/// is missing, and leaves the file untouched when a line already matches
/// after trimming whitespace.
///
/// # Errors
///
/// Propagates I/O errors from reading or writing `mod.rs`.
fn update_mod_rs(dir: &std::path::Path, mod_name: &str) -> Result<()> {
    let mod_path = dir.join("mod.rs");
    let declaration = format!("pub mod {};", mod_name);

    if !mod_path.is_file() {
        std::fs::write(&mod_path, format!("{}\n", declaration))?;
        info!("Created {}", mod_path.display());
        return Ok(());
    }

    let existing = std::fs::read_to_string(&mod_path)?;
    let already_declared = existing
        .lines()
        .map(str::trim)
        .any(|line| line == declaration);
    if already_declared {
        return Ok(());
    }

    // Drop trailing whitespace so the result ends with exactly one newline
    // after the appended declaration.
    let kept = existing.trim_end();
    let mut updated = String::with_capacity(kept.len() + declaration.len() + 2);
    if !kept.is_empty() {
        updated.push_str(kept);
        updated.push('\n');
    }
    updated.push_str(&declaration);
    updated.push('\n');

    std::fs::write(&mod_path, updated)?;
    info!("Updated {}", mod_path.display());
    Ok(())
}
/// Runs `rustfmt` on every `.rs` file directly inside `dir` (no recursion).
///
/// Best-effort: an unreadable directory or unreadable entries are skipped
/// silently.
fn rustfmt_dir(dir: &std::path::Path) {
    let edition = detect_edition(dir);
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.extension().is_some_and(|ext| ext == "rs"))
        .for_each(|path| rustfmt_file(&path, &edition));
}
/// Formats a single Rust source file in place by invoking `rustfmt`.
///
/// `edition` is passed through as the `--edition` argument.
///
/// Formatting is best-effort: every failure is logged as a warning and never
/// propagated, so a missing or failing `rustfmt` does not abort generation.
fn rustfmt_file(path: &std::path::Path, edition: &str) {
    let outcome = std::process::Command::new("rustfmt")
        .arg("--edition")
        .arg(edition)
        .arg(path)
        .status();

    match outcome {
        Ok(status) if status.success() => {}
        Ok(status) => warn!("rustfmt exited with {} for {}", status, path.display()),
        // Only a missing executable means "not installed"; other spawn
        // failures (e.g. permission denied) are reported with their cause.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            warn!("rustfmt not found, skipping formatting")
        }
        Err(err) => warn!("failed to run rustfmt on {}: {}", path.display(), err),
    }
}
/// Determines the Rust edition to pass to `rustfmt` for code under
/// `start_dir`.
///
/// Walks from `start_dir` (or its containing directory when it is a file) up
/// through its ancestors and returns the edition from the first `Cargo.toml`
/// that declares one. Falls back to `"2021"` when the starting directory
/// cannot be canonicalized or no manifest declares an edition.
fn detect_edition(start_dir: &std::path::Path) -> String {
    const DEFAULT_EDITION: &str = "2021";

    let dir = if start_dir.is_file() {
        match start_dir.parent() {
            // A bare relative file name such as `foo.rs` has an empty parent,
            // which `canonicalize` rejects; it means the current directory.
            Some(parent) if parent.as_os_str().is_empty() => std::path::Path::new("."),
            Some(parent) => parent,
            None => start_dir,
        }
    } else {
        start_dir
    };

    let Ok(mut current) = dir.canonicalize() else {
        return DEFAULT_EDITION.to_string();
    };

    loop {
        // A missing or unreadable manifest simply moves the search upward.
        let manifest = current.join("Cargo.toml");
        let declared = std::fs::read_to_string(&manifest)
            .ok()
            .and_then(|text| parse_edition_from_cargo_toml(&text));
        if let Some(edition) = declared {
            return edition;
        }
        if !current.pop() {
            break;
        }
    }

    DEFAULT_EDITION.to_string()
}
/// Extracts the Rust edition from the text of a `Cargo.toml`.
///
/// Scans line by line for an `edition = <value>` assignment and returns the
/// first value that looks like an edition year (ASCII digits only), with
/// surrounding quotes and any trailing `# comment` removed.
///
/// Lines that do not carry a literal edition are skipped rather than
/// returned, e.g. `edition.workspace = true` or
/// `edition = { workspace = true }`. Returns `None` when no line matches.
fn parse_edition_from_cargo_toml(content: &str) -> Option<String> {
    content.lines().find_map(|line| {
        let rest = line.trim().strip_prefix("edition")?;
        // Requiring `=` next rejects keys that merely start with "edition".
        let assigned = rest.trim_start().strip_prefix('=')?;
        // Edition values never contain `#`, so anything after it is a comment.
        let without_comment = assigned.split('#').next().unwrap_or(assigned);
        let value = without_comment
            .trim()
            .trim_matches('"')
            .trim_matches('\'');
        let looks_like_edition = !value.is_empty() && value.bytes().all(|b| b.is_ascii_digit());
        looks_like_edition.then(|| value.to_string())
    })
}