use console::style;
use heck::ToUpperCamelCase;
use sqlx::PgPool;
/// One row selected from `information_schema.columns` for the table being inspected.
///
/// All metadata is kept as the raw strings returned by the query; interpretation
/// (type mapping, nullability) happens in the functions that consume this struct.
#[derive(Debug, sqlx::FromRow)]
struct ColumnInfo {
/// Name of the column.
column_name: String,
/// Type name as reported in the `data_type` column (e.g. `"bigint"`).
data_type: String,
/// Nullability flag as reported by the query; the rest of this file compares it to `"YES"`.
is_nullable: String,
/// Default expression, if any. Selected by the queries but not read by the code
/// visible in this file, hence the lint allowance.
#[allow(dead_code)]
column_default: Option<String>,
}
/// Maps a PostgreSQL type name to `(rust_type, builder_method)`.
///
/// * `rust_type` is the type written into generated model structs.
/// * `builder_method` is the schema-builder method name used in generated migrations;
///   the sentinel `"column"` means "no dedicated method, emit the raw type name".
///
/// `_col_name` is accepted for call-site symmetry but is currently unused.
/// Unknown types fall back to `("String", "column")`.
fn pg_type_info(pg_type: &str, _col_name: &str) -> (&'static str, &'static str) {
    match pg_type {
        "text" | "character varying" | "varchar" | "char" | "character" => ("String", "string"),
        "bigint" | "int8" => ("i64", "big_integer"),
        "bigserial" | "serial8" => ("i64", "big_integer"),
        // `integer` is a 32-bit type, exactly like `serial`/`serial4` below, so it must map
        // to `i32` rather than `i64`.
        "integer" | "int" | "int4" => ("i32", "integer"),
        "serial" | "serial4" => ("i32", "integer"),
        "smallint" | "int2" => ("i16", "small_integer"),
        "boolean" | "bool" => ("bool", "boolean"),
        // NOTE(review): `real`/`float4` is single precision; `f64` + `double` widens it.
        // Left as-is because no single-precision builder method is visible here — confirm.
        "double precision" | "float8" | "real" | "float4" => ("f64", "double"),
        // NOTE(review): `numeric` is arbitrary precision; `f64` is lossy — confirm intent.
        "numeric" | "decimal" => ("f64", "decimal"),
        "uuid" => ("String", "uuid"),
        "jsonb" | "json" => ("serde_json::Value", "json"),
        "timestamp with time zone" | "timestamptz" => {
            ("chrono::DateTime<chrono::Utc>", "timestamp")
        }
        "timestamp" | "timestamp without time zone" => ("chrono::NaiveDateTime", "timestamp"),
        "date" => ("chrono::NaiveDate", "date"),
        "bytea" => ("Vec<u8>", "binary"),
        "inet" => ("String", "ip_address"),
        "macaddr" => ("String", "mac_address"),
        _ => ("String", "column"),
    }
}
/// Renders the schema-builder statement for one column of a generated migration.
///
/// Special cases, checked in this order:
/// 1. a primary-key column named `id` with an integer or `uuid` type becomes
///    `t.id();` / `t.uuid_id();`
/// 2. `created_at` / `updated_at` yield a `//` comment line (the caller decides what to emit)
/// 3. `deleted_at` becomes `t.soft_deletes();`
///
/// Every other column is rendered from the `pg_type_info` builder method, followed by
/// `.not_null()` when the column is not nullable and `.unique()` when `is_unique` is set.
fn builder_call(col: &ColumnInfo, is_pk: bool, is_unique: bool) -> String {
    let column = col.column_name.as_str();
    let pg_type = col.data_type.as_str();

    if is_pk && column == "id" {
        match pg_type {
            "bigint" | "int8" | "bigserial" | "serial8" | "integer" | "int4" | "serial" => {
                return String::from("t.id();");
            }
            "uuid" => return String::from("t.uuid_id();"),
            // Any other id type falls through to the generic rendering below.
            _ => {}
        }
    }

    match column {
        "created_at" | "updated_at" => {
            return format!("// timestamps → use t.timestamps(); instead of {column}");
        }
        "deleted_at" => return String::from("t.soft_deletes();"),
        _ => {}
    }

    let (_, method) = pg_type_info(pg_type, column);
    let mut statement = match method {
        // No dedicated builder method: pass the raw database type through.
        "column" => format!("t.column(\"{column}\", \"{pg_type}\")"),
        "decimal" => format!("t.decimal(\"{column}\", 18, 6)"),
        "uuid" => format!("t.column(\"{column}\", \"UUID\")"),
        other => format!("t.{other}(\"{column}\")"),
    };
    if col.is_nullable != "YES" {
        statement.push_str(".not_null()");
    }
    if is_unique {
        statement.push_str(".unique()");
    }
    statement.push(';');
    statement
}
/// Introspects `table` in the `public` schema and generates a model and a migration for it.
///
/// Loads `.env`, connects using `DATABASE_URL`, reads the table's columns, primary-key
/// columns and unique constraints from `information_schema`, then:
///
/// * with `output = Some(path)`: writes the generated model source to `path`;
/// * with `output = None`: prints the migration and model to stdout together with the
///   paths they would conventionally live at.
///
/// # Errors
///
/// Fails when `DATABASE_URL` is unset, the connection or any query fails, the table has no
/// columns in the `public` schema, or the model file cannot be written.
pub async fn pull(table: &str, output: Option<&str>) -> anyhow::Result<()> {
    use chrono::Local;

    dotenvy::dotenv().ok();
    let database_url = std::env::var("DATABASE_URL")
        .map_err(|_| anyhow::anyhow!("DATABASE_URL not set — add it to .env"))?;
    println!(
        "{} Connecting to database...",
        style("db:pull").green().bold()
    );
    let pool = PgPool::connect(&database_url)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {e}"))?;

    let columns: Vec<ColumnInfo> = sqlx::query_as::<_, ColumnInfo>(
        r#"
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position
"#,
    )
    .bind(table)
    .fetch_all(&pool)
    .await?;
    if columns.is_empty() {
        anyhow::bail!("Table '{table}' not found in public schema.");
    }

    let pk_cols: Vec<(String,)> = sqlx::query_as(
        r#"
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = 'public'
AND tc.table_name = $1
"#,
    )
    .bind(table)
    .fetch_all(&pool)
    .await?;
    let pk_set: std::collections::HashSet<String> = pk_cols.into_iter().map(|(c,)| c).collect();

    // Fetch the constraint name alongside each column so composite constraints can be told
    // apart from single-column ones.
    let unique_rows: Vec<(String, String)> = sqlx::query_as(
        r#"
SELECT tc.constraint_name, kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'UNIQUE'
AND tc.table_schema = 'public'
AND tc.table_name = $1
"#,
    )
    .bind(table)
    .fetch_all(&pool)
    .await?;
    let mut columns_by_constraint: std::collections::HashMap<String, Vec<String>> =
        std::collections::HashMap::new();
    for (constraint, column) in unique_rows {
        columns_by_constraint
            .entry(constraint)
            .or_default()
            .push(column);
    }
    // Only a single-column constraint makes that column unique on its own. Marking every
    // member of a composite UNIQUE(a, b) as `.unique()` would generate a stricter schema
    // than the one being pulled, so composite constraints are left out.
    let unique_set: std::collections::HashSet<String> = columns_by_constraint
        .into_values()
        .filter(|cols| cols.len() == 1)
        .flatten()
        .collect();

    println!(
        "{} Found {} column(s) in '{table}'",
        style("✓").green(),
        columns.len()
    );

    // Naive singularisation: drop exactly one trailing 's'. `trim_end_matches` would strip
    // all of them ("glass" -> "gla"), which also corrupts the migration struct name.
    let pascal = table
        .strip_suffix('s')
        .unwrap_or(table)
        .to_upper_camel_case();
    let model_code = generate_model(&pascal, table, &columns);
    let ts = Local::now().format("%Y_%m_%d_%H%M%S").to_string();
    let migration_code = generate_migration(&pascal, table, &ts, &columns, &pk_set, &unique_set);
    let migration_struct = format!("Create{}sTable", pascal);
    let migration_dir = if std::path::Path::new("database/migrations").exists() {
        "database/migrations"
    } else {
        "migrations"
    };

    match output {
        Some(path) => {
            std::fs::write(path, &model_code)?;
            println!("{} Model written to {path}", style("✓").green().bold());
        }
        None => {
            let migration_path = format!("{migration_dir}/{ts}_create_{table}_table.rs");
            let model_path = format!("src/app/models/{table}.rs");
            println!(
                "\n{}",
                style("─── Migration ──────────────────────────────").dim()
            );
            println!("{}", &migration_code);
            println!(
                "{}",
                style("─── Model ──────────────────────────────────").dim()
            );
            println!("{}", &model_code);
            println!(
                "{}",
                style("─── Write files? ───────────────────────────").dim()
            );
            println!(" Migration → {migration_path}");
            println!(" Model → {model_path}");
            println!();
            // NOTE(review): the hint below says "both files", but the `Some(path)` branch
            // above only writes the model — confirm which behaviour is intended.
            println!(
                " Run with {} to write both files:",
                style("--output").yellow()
            );
            println!(" rok db:pull {table} --output {model_path}",);
            println!();
            println!(" Then register the migration:");
            println!(" MigrationRunner::new(pool).migration({migration_struct}).run().await?;");
        }
    }
    Ok(())
}
fn generate_migration(
pascal: &str,
table: &str,
timestamp: &str,
columns: &[ColumnInfo],
pk_set: &std::collections::HashSet<String>,
unique_set: &std::collections::HashSet<String>,
) -> String {
let has_timestamps = columns.iter().any(|c| c.column_name == "created_at")
&& columns.iter().any(|c| c.column_name == "updated_at");
let mut col_calls = Vec::new();
let mut emitted_timestamps = false;
for col in columns {
let is_pk = pk_set.contains(&col.column_name);
let is_unique = unique_set.contains(&col.column_name);
let call = builder_call(col, is_pk, is_unique);
if col.column_name == "created_at" || col.column_name == "updated_at" {
if has_timestamps && !emitted_timestamps {
col_calls.push("t.timestamps();".to_string());
emitted_timestamps = true;
}
continue;
}
if call.starts_with("//") {
continue;
}
col_calls.push(call);
}
let body = col_calls
.iter()
.map(|c| format!(" {c}"))
.collect::<Vec<_>>()
.join("\n");
format!(
r#"use rok_orm_migrate::{{Migration, SchemaExecutor}};
use async_trait::async_trait;
pub struct Create{pascal}sTable;
#[async_trait]
impl Migration for Create{pascal}sTable {{
fn name(&self) -> &str {{
"{timestamp}_create_{table}_table"
}}
async fn up(&self, schema: &SchemaExecutor) -> anyhow::Result<()> {{
schema.create("{table}", |t| {{
{body}
}}).await
}}
async fn down(&self, schema: &SchemaExecutor) -> anyhow::Result<()> {{
schema.drop_table_if_exists("{table}").await
}}
}}
"#
)
}
/// Generates the Rust source of a model struct named `pascal` bound to `table`.
///
/// Each column becomes a `pub` field whose type comes from `pg_type_info`, wrapped in
/// `Option<…>` when the column is nullable. `chrono` and `serde_json` names are imported
/// only when a field actually uses them, so the generated file has no unused imports.
/// Column names that collide with Rust keywords are emitted as raw identifiers
/// (`r#type`) so the generated struct still parses.
fn generate_model(pascal: &str, table: &str, columns: &[ColumnInfo]) -> String {
    // Strict and reserved keywords that are legal as raw identifiers. `self`, `Self`,
    // `super` and `crate` are deliberately absent: they cannot be written as `r#…`.
    const RAW_KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "do",
        "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl", "in",
        "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub", "ref",
        "return", "static", "struct", "trait", "true", "try", "type", "typeof", "unsafe",
        "unsized", "use", "virtual", "where", "while", "yield",
    ];

    // Sorted set so the import list is deterministic and alphabetical.
    let mut chrono_names = std::collections::BTreeSet::new();
    let mut needs_json = false;
    let mut fields = String::new();

    for col in columns {
        let (full_type, _) = pg_type_info(&col.data_type, &col.column_name);
        // Shorten fully-qualified paths to the imported name and record the import.
        let rust_t = match full_type {
            "chrono::DateTime<chrono::Utc>" => {
                chrono_names.insert("DateTime");
                chrono_names.insert("Utc");
                "DateTime<Utc>"
            }
            "chrono::NaiveDate" => {
                chrono_names.insert("NaiveDate");
                "NaiveDate"
            }
            "chrono::NaiveDateTime" => {
                chrono_names.insert("NaiveDateTime");
                "NaiveDateTime"
            }
            "serde_json::Value" => {
                needs_json = true;
                "Value"
            }
            other => other,
        };
        let field_type = if col.is_nullable == "YES" {
            format!("Option<{rust_t}>")
        } else {
            rust_t.to_string()
        };
        let raw_prefix = if RAW_KEYWORDS.contains(&col.column_name.as_str()) {
            "r#"
        } else {
            ""
        };
        fields.push_str(&format!(
            " pub {raw_prefix}{}: {field_type},\n",
            col.column_name
        ));
    }

    let mut uses = vec![
        "use rok_orm::Model;".to_string(),
        "use serde::{Deserialize, Serialize};".to_string(),
    ];
    let chrono_names: Vec<&str> = chrono_names.into_iter().collect();
    match chrono_names.as_slice() {
        [] => {}
        [only] => uses.push(format!("use chrono::{only};")),
        many => uses.push(format!("use chrono::{{{}}};", many.join(", "))),
    }
    if needs_json {
        uses.push("use serde_json::Value;".to_string());
    }
    let uses_str = uses.join("\n");

    format!(
        "{uses_str}\n\n#[derive(Debug, Clone, Model, Serialize, Deserialize, sqlx::FromRow)]\n#[rok_orm(table = \"{table}\")]\npub struct {pascal} {{\n{fields}}}\n"
    )
}
/// Compares the struct in `model_file` with the live database table and prints the drift.
///
/// The table name is taken from `table` if given, otherwise from a `table = "…"` marker in
/// the model source, otherwise from the file stem with an `s` appended. Three kinds of
/// drift are reported: fields missing in the database, columns missing in the model, and
/// fields whose Rust type category disagrees with the column's type category. Fields are
/// reported in declaration order and columns in ordinal order, so the output is stable
/// between runs.
///
/// # Errors
///
/// Fails when `DATABASE_URL` is unset, the model file is missing or unreadable, the
/// connection or query fails, or the table has no columns in the `public` schema.
/// Detected drift is only printed; it does not produce an error.
pub async fn diff(model_file: &str, table: Option<&str>) -> anyhow::Result<()> {
    dotenvy::dotenv().ok();
    let database_url =
        std::env::var("DATABASE_URL").map_err(|_| anyhow::anyhow!("DATABASE_URL not set"))?;
    if !std::path::Path::new(model_file).exists() {
        anyhow::bail!("Model file not found: {model_file}");
    }
    let source = std::fs::read_to_string(model_file)?;
    let table_name = table
        .map(|t| t.to_string())
        .or_else(|| extract_table_name(&source))
        .unwrap_or_else(|| {
            std::path::Path::new(model_file)
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or("unknown")
                .to_string()
                + "s"
        });

    // Keep the parsed fields as an ordered list for reporting; the map is only for lookups.
    let model_fields = parse_struct_fields(&source);
    let model_field_types: std::collections::HashMap<&str, &str> = model_fields
        .iter()
        .map(|(name, ty)| (name.as_str(), ty.as_str()))
        .collect();

    println!(
        "{} Connecting to database...",
        style("db:diff").green().bold()
    );
    let pool = PgPool::connect(&database_url)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {e}"))?;
    let db_columns: Vec<ColumnInfo> = sqlx::query_as::<_, ColumnInfo>(
        r#"
SELECT column_name, data_type, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position
"#,
    )
    .bind(&table_name)
    .fetch_all(&pool)
    .await?;
    if db_columns.is_empty() {
        anyhow::bail!(
            "Table '{table_name}' not found in DB. Run 'rok db:migrate' first or pass --table."
        );
    }
    let db_col_names: std::collections::HashSet<&str> =
        db_columns.iter().map(|c| c.column_name.as_str()).collect();

    println!(
        "\n{} Schema diff for '{table_name}':",
        style("db:diff").bold()
    );
    let mut has_diff = false;

    // Iterate the ordered field list, not a hash set, so lines come out in source order.
    for (name, _) in &model_fields {
        if !db_col_names.contains(name.as_str()) {
            println!(
                " {} {name:<30} in model, missing in DB → add a migration",
                style("+").green().bold()
            );
            has_diff = true;
        }
    }

    for col in &db_columns {
        if !model_field_types.contains_key(col.column_name.as_str()) {
            println!(
                " {} {:<30} in DB, missing in model → add field or #[rok_orm(skip)]",
                style("-").red().bold(),
                col.column_name,
            );
            has_diff = true;
        }
    }

    for col in &db_columns {
        if let Some(rust_type) = model_field_types.get(col.column_name.as_str()) {
            let pg_cat = pg_type_category(&col.data_type);
            let rs_cat = rust_type_category(rust_type);
            // "other" means "unclassified" on either side; only compare known categories.
            if pg_cat != "other" && rs_cat != "other" && pg_cat != rs_cat {
                let (suggested, _) = pg_type_info(&col.data_type, &col.column_name);
                println!(
                    " {} {:<30} type mismatch: model `{}`, DB `{}` → expected `{suggested}`",
                    style("~").yellow().bold(),
                    col.column_name,
                    rust_type,
                    col.data_type,
                );
                has_diff = true;
            }
        }
    }

    if !has_diff {
        println!(
            " {} No drift detected — model and database are in sync.",
            style("✓").green()
        );
    }
    Ok(())
}
/// Returns the value of the first `table = "…"` marker found in `source`.
///
/// Yields `None` when the marker is absent or its closing quote is missing.
fn extract_table_name(source: &str) -> Option<String> {
    let (_, after_marker) = source.split_once("table = \"")?;
    let (value, _) = after_marker.split_once('"')?;
    Some(value.to_owned())
}
/// Extracts `(field_name, field_type)` pairs from the first `pub struct` in `source`.
///
/// This is a line-based scan, not a real parser: it starts after the first line beginning
/// with `pub struct `, stops at the first line that is exactly `}`, and only recognises
/// fields written as `pub name: Type,` on a single line. Attribute and comment lines are
/// skipped. Trailing `// …` comments are removed from the type, and a raw-identifier
/// prefix (`r#type`) is removed from the name so it compares equal to the column name.
fn parse_struct_fields(source: &str) -> Vec<(String, String)> {
    let struct_body = source
        .lines()
        .map(str::trim)
        .skip_while(|line| !line.starts_with("pub struct "))
        .skip(1)
        .take_while(|line| *line != "}");

    let mut fields = Vec::new();
    for line in struct_body {
        if line.starts_with('#') || line.starts_with("//") || line.starts_with("pub struct ") {
            continue;
        }
        if let Some(declaration) = line.strip_prefix("pub ") {
            // Drop a trailing line comment so it does not end up inside the type text.
            let declaration = declaration.split("//").next().unwrap_or(declaration);
            // Split at the first ':' only; later colons belong to paths in the type.
            if let Some((raw_name, raw_type)) = declaration.split_once(':') {
                let name = raw_name.trim();
                let name = name.strip_prefix("r#").unwrap_or(name);
                let rust_type = raw_type.trim().trim_end_matches(',').trim_end();
                if !name.is_empty() {
                    fields.push((name.to_string(), rust_type.to_string()));
                }
            }
        }
    }
    fields
}
/// Classifies a PostgreSQL type name into a coarse category used for drift detection.
///
/// Returns one of `"text"`, `"integer"`, `"boolean"`, `"float"`, `"json"`, `"binary"`,
/// `"datetime"`, `"date"`, or `"other"` for anything not listed.
fn pg_type_category(pg_type: &str) -> &'static str {
    // Each entry pairs a category with every type name that belongs to it.
    const CATEGORIES: &[(&str, &[&str])] = &[
        (
            "text",
            &[
                "text",
                "character varying",
                "varchar",
                "char",
                "character",
                "uuid",
                "inet",
                "cidr",
                "macaddr",
            ],
        ),
        (
            "integer",
            &[
                "bigint",
                "int8",
                "bigserial",
                "serial8",
                "integer",
                "int",
                "int4",
                "serial",
                "serial4",
                "smallint",
                "int2",
            ],
        ),
        ("boolean", &["boolean", "bool"]),
        (
            "float",
            &[
                "double precision",
                "float8",
                "real",
                "float4",
                "numeric",
                "decimal",
            ],
        ),
        ("json", &["jsonb", "json"]),
        ("binary", &["bytea"]),
        (
            "datetime",
            &[
                "timestamp with time zone",
                "timestamptz",
                "timestamp",
                "timestamp without time zone",
            ],
        ),
        ("date", &["date"]),
    ];

    CATEGORIES
        .iter()
        .find(|(_, type_names)| type_names.contains(&pg_type))
        .map_or("other", |&(category, _)| category)
}
/// Classifies a Rust type, as written in a model field, into the same coarse categories
/// that `pg_type_category` produces.
///
/// Surrounding whitespace and a single outer `Option<…>` wrapper are ignored. Types that
/// are not recognised yield `"other"`.
fn rust_type_category(rust_type: &str) -> &'static str {
    let declared = rust_type.trim();
    // Unwrap one `Option<…>` layer when both the prefix and the closing '>' are present.
    let core = declared
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
        .map_or(declared, str::trim);

    if matches!(core, "String" | "&str" | "&'static str") {
        return "text";
    }
    if matches!(
        core,
        "i64" | "i32" | "i16" | "i8" | "u64" | "u32" | "u16" | "u8" | "usize" | "isize"
    ) {
        return "integer";
    }
    if matches!(core, "f64" | "f32") {
        return "float";
    }
    if core == "bool" {
        return "boolean";
    }
    if matches!(core, "serde_json::Value" | "Value") {
        return "json";
    }
    if core == "Vec<u8>" {
        return "binary";
    }
    // Any spelling that mentions `DateTime` counts, which also covers `NaiveDateTime`.
    if core.contains("DateTime") {
        return "datetime";
    }
    if matches!(core, "NaiveDate" | "chrono::NaiveDate") {
        return "date";
    }
    "other"
}