use console::style;
use heck::ToUpperCamelCase;
use sqlx::PgPool;
/// One row of `information_schema.columns`, as selected by `pull` and `diff`.
///
/// `sqlx::FromRow` maps columns to fields by name, so the field names must stay
/// identical to the column names selected in those queries.
#[derive(Debug, sqlx::FromRow)]
struct ColumnInfo {
    /// Column name as stored in the catalog.
    column_name: String,
    /// Type name as reported by `information_schema` (e.g. `integer`, `jsonb`).
    data_type: String,
    /// `information_schema` yes/no flag; this file treats `"YES"` as nullable.
    is_nullable: String,
    // Selected by the queries but not read anywhere in this file (the derived
    // `Debug` impl does not count as a use), hence the lint allowance. Kept so the
    // struct mirrors the selected row shape.
    #[allow(dead_code)]
    column_default: Option<String>,
}
/// Maps a PostgreSQL type name (as reported by `information_schema.columns.data_type`,
/// or a common alias) to the Rust type used for a generated model field.
///
/// Returns `(rust_type, annotation)`. `annotation` is empty, a field attribute
/// (`#[sqlx(json)]`), or a `// ...` comment that the generator emits above the field.
///
/// Integer and float widths follow sqlx's Postgres mapping (`INT2`→`i16`,
/// `INT4`→`i32`, `INT8`→`i64`, `FLOAT4`→`f32`, `FLOAT8`→`f64`). sqlx refuses to
/// decode e.g. an `INT4` column into an `i64` field, so widening here would
/// produce models that fail at runtime.
fn pg_type_to_rust(pg_type: &str) -> (&'static str, &'static str) {
    match pg_type {
        "text" | "character varying" | "varchar" | "char" | "character" => ("String", ""),
        "smallint" | "int2" => ("i16", ""),
        "integer" | "int" | "int4" => ("i32", ""),
        "bigint" | "int8" => ("i64", ""),
        "boolean" | "bool" => ("bool", ""),
        "real" | "float4" => ("f32", ""),
        "double precision" | "float8" => ("f64", ""),
        "numeric" | "decimal" => ("f64", ""),
        "uuid" => ("String", "// uuid"),
        "jsonb" | "json" => ("serde_json::Value", "#[sqlx(json)]"),
        "timestamp with time zone" | "timestamptz" => ("chrono::DateTime<chrono::Utc>", ""),
        "timestamp" | "timestamp without time zone" => ("chrono::NaiveDateTime", ""),
        "date" => ("chrono::NaiveDate", ""),
        "bytea" => ("Vec<u8>", ""),
        // Serial pseudo-types have the width of their underlying integer type.
        "bigserial" | "serial8" => ("i64", "// auto-increment"),
        "serial" | "serial4" => ("i32", "// auto-increment"),
        "smallserial" | "serial2" => ("i16", "// auto-increment"),
        _ => ("String", "// unknown pg type"),
    }
}
/// Introspects `table` in the `public` schema and generates a `rok_orm` model for it.
///
/// Loads `.env` if present and reads `DATABASE_URL`. It then lists the table's columns
/// from `information_schema.columns` in ordinal order. Finally it writes the generated
/// source to `output`, or prints it to stdout when `output` is `None`.
///
/// # Errors
/// Fails when `DATABASE_URL` is unset, the connection or column query fails, the table
/// has no columns in `public` (i.e. it is not there), or the output file cannot be
/// written.
pub async fn pull(table: &str, output: Option<&str>) -> anyhow::Result<()> {
    dotenvy::dotenv().ok();
    let database_url = std::env::var("DATABASE_URL")
        .map_err(|_| anyhow::anyhow!("DATABASE_URL not set — add it to .env"))?;

    println!(
        "{} Connecting to database...",
        style("db:pull").green().bold()
    );
    let pool = PgPool::connect(&database_url)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {e}"))?;

    let columns: Vec<ColumnInfo> = sqlx::query_as::<_, ColumnInfo>(
        r#"
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND table_name = $1
        ORDER BY ordinal_position
        "#,
    )
    .bind(table)
    .fetch_all(&pool)
    .await
    .map_err(|e| anyhow::anyhow!("Failed to read columns of '{table}': {e}"))?;

    if columns.is_empty() {
        anyhow::bail!("Table '{table}' not found in public schema. Check the table name.");
    }

    let pascal = singular_struct_name(table);
    println!(
        "{} Found {} columns in '{table}'",
        style("✓").green(),
        columns.len()
    );

    let model_code = generate_model_from_columns(&pascal, table, &columns);
    match output {
        Some(path) => {
            std::fs::write(path, &model_code)
                .map_err(|e| anyhow::anyhow!("Failed to write {path}: {e}"))?;
            println!("{} Model written to {path}", style("✓").green().bold());
        }
        None => {
            println!("\n{}", style("Generated model:").bold());
            println!("{}", model_code);
        }
    }
    Ok(())
}

/// Derives a model struct name from a (usually plural) table name by dropping at
/// most ONE trailing plural `s` (`users` → `User`).
///
/// Names whose stem would still end in `s` (`address`, `boss`) are kept whole, and
/// so is a name that would otherwise become empty. The previous
/// `trim_end_matches('s')` removed every trailing `s` (`address` → `Addre`).
fn singular_struct_name(table: &str) -> String {
    let singular = match table.strip_suffix('s') {
        Some(stem) if !stem.is_empty() && !stem.ends_with('s') => stem,
        _ => table,
    };
    singular.to_upper_camel_case()
}
/// Renders the Rust source of a `rok_orm` model named `pascal` for `table`.
///
/// Each entry of `columns` becomes one `pub` field, in the given order. Field types
/// come from `pg_type_to_rust`, and columns whose `is_nullable` is `"YES"` are wrapped
/// in `Option<_>`. Any non-empty annotation from `pg_type_to_rust` is emitted on the
/// line above its field.
///
/// Column names that are Rust keywords are emitted as raw identifiers (`r#type`) so
/// the output compiles. `self`, `Self`, `super` and `crate` cannot be raw identifiers,
/// so those get a trailing `_` plus `#[sqlx(rename = "...")]`.
///
/// Only the `chrono` / `serde_json` names some field actually uses are imported, so
/// the generated file has no unused-import warnings.
fn generate_model_from_columns(pascal: &str, table: &str, columns: &[ColumnInfo]) -> String {
    use std::collections::BTreeSet;
    use std::fmt::Write as _;

    // BTreeSet keeps the generated `use chrono::...` line sorted and deduplicated.
    let mut chrono_names: BTreeSet<&str> = BTreeSet::new();
    let mut needs_json = false;
    let mut fields = String::new();

    for col in columns {
        let (full_type, attr) = pg_type_to_rust(&col.data_type);
        // Shorten fully-qualified types to the names imported at the top of the file,
        // recording which imports are needed along the way.
        let short_type = match full_type {
            "chrono::DateTime<chrono::Utc>" => {
                chrono_names.extend(["DateTime", "Utc"]);
                "DateTime<Utc>"
            }
            "chrono::NaiveDate" => {
                chrono_names.insert("NaiveDate");
                "NaiveDate"
            }
            "chrono::NaiveDateTime" => {
                chrono_names.insert("NaiveDateTime");
                "NaiveDateTime"
            }
            "serde_json::Value" => {
                needs_json = true;
                "Value"
            }
            other => other,
        };

        let field_type = if col.is_nullable == "YES" {
            format!("Option<{short_type}>")
        } else {
            short_type.to_string()
        };

        // Writing into a String cannot fail, so the fmt::Result is ignored below.
        if !attr.is_empty() {
            let _ = writeln!(fields, "    {attr}");
        }
        let name = col.column_name.as_str();
        if matches!(name, "self" | "Self" | "super" | "crate") {
            let _ = writeln!(fields, "    #[sqlx(rename = \"{name}\")]");
            let _ = writeln!(fields, "    pub {name}_: {field_type},");
        } else if is_rust_keyword(name) {
            let _ = writeln!(fields, "    pub r#{name}: {field_type},");
        } else {
            let _ = writeln!(fields, "    pub {name}: {field_type},");
        }
    }

    let mut uses = vec![
        "use rok_orm::Model;".to_string(),
        "use serde::{Deserialize, Serialize};".to_string(),
    ];
    if !chrono_names.is_empty() {
        let list = chrono_names.iter().copied().collect::<Vec<_>>().join(", ");
        uses.push(if chrono_names.len() == 1 {
            format!("use chrono::{list};")
        } else {
            format!("use chrono::{{{list}}};")
        });
    }
    if needs_json {
        uses.push("use serde_json::Value;".to_string());
    }

    format!(
        "{uses}\n\n#[derive(Debug, Clone, Model, Serialize, Deserialize, sqlx::FromRow)]\n#[rok_orm(table = \"{table}\")]\npub struct {pascal} {{\n{fields}}}\n",
        uses = uses.join("\n")
    )
}

/// Returns `true` if `name` is a strict or reserved Rust keyword. Such a name needs
/// the `r#` prefix to be used as a field identifier.
fn is_rust_keyword(name: &str) -> bool {
    matches!(
        name,
        "as" | "async" | "await" | "break" | "const" | "continue" | "crate" | "dyn"
            | "else" | "enum" | "extern" | "false" | "fn" | "for" | "if" | "impl" | "in"
            | "let" | "loop" | "match" | "mod" | "move" | "mut" | "pub" | "ref"
            | "return" | "self" | "Self" | "static" | "struct" | "super" | "trait"
            | "true" | "type" | "unsafe" | "use" | "where" | "while" | "abstract"
            | "become" | "box" | "do" | "final" | "gen" | "macro" | "override" | "priv"
            | "try" | "typeof" | "unsized" | "virtual" | "yield"
    )
}
/// Compares the fields of the model struct in `model_file` with its table's columns.
///
/// The table is chosen in this order:
/// 1. `table`, if given.
/// 2. The `table = "..."` value found in the model source.
/// 3. The file stem with an `s` appended (`user.rs` → `users`).
///
/// Model fields absent from the DB are printed in declaration order, then DB columns
/// absent from the model in ordinal order. Drift is only reported; it is not an error.
///
/// # Errors
/// Fails when `DATABASE_URL` is unset, the model file is missing or unreadable, the
/// connection or column query fails, or the table has no columns in `public`.
pub async fn diff(model_file: &str, table: Option<&str>) -> anyhow::Result<()> {
    use std::collections::HashSet;

    dotenvy::dotenv().ok();
    let database_url =
        std::env::var("DATABASE_URL").map_err(|_| anyhow::anyhow!("DATABASE_URL not set"))?;

    if !std::path::Path::new(model_file).exists() {
        anyhow::bail!("Model file not found: {model_file}");
    }
    let source = std::fs::read_to_string(model_file)
        .map_err(|e| anyhow::anyhow!("Failed to read {model_file}: {e}"))?;

    let table_name = table
        .map(str::to_string)
        .or_else(|| extract_table_name(&source))
        .unwrap_or_else(|| {
            let stem = std::path::Path::new(model_file)
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or("unknown");
            format!("{stem}s")
        });
    let model_fields = parse_struct_fields(&source);

    println!(
        "{} Connecting to database...",
        style("db:diff").green().bold()
    );
    let pool = PgPool::connect(&database_url)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {e}"))?;

    let db_columns: Vec<ColumnInfo> = sqlx::query_as::<_, ColumnInfo>(
        r#"
        SELECT column_name, data_type, is_nullable, column_default
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = $1
        ORDER BY ordinal_position
        "#,
    )
    .bind(&table_name)
    .fetch_all(&pool)
    .await
    .map_err(|e| anyhow::anyhow!("Failed to read columns of '{table_name}': {e}"))?;

    // No rows means the table is not in `public`. Without this check every model
    // field would be misreported as "missing in DB".
    if db_columns.is_empty() {
        anyhow::bail!("Table '{table_name}' not found in public schema. Check the table name.");
    }

    let db_col_names: HashSet<&str> = db_columns.iter().map(|c| c.column_name.as_str()).collect();
    let model_field_names: HashSet<&str> = model_fields.iter().map(String::as_str).collect();

    println!(
        "\n{} Schema diff for '{table_name}':",
        style("db:diff").bold()
    );

    let mut has_diff = false;
    // Walk the Vec (declaration order) rather than the HashSet so output is stable;
    // `reported` suppresses duplicates.
    let mut reported: HashSet<&str> = HashSet::new();
    for name in &model_fields {
        if db_col_names.contains(name.as_str()) || !reported.insert(name.as_str()) {
            continue;
        }
        println!(
            " {} {name} — in model, missing in DB (add a migration)",
            style("+").green()
        );
        has_diff = true;
    }
    for col in &db_columns {
        if !model_field_names.contains(col.column_name.as_str()) {
            println!(
                " {} {} — in DB, missing in model (add field or #[rok_orm(skip)])",
                style("-").red(),
                col.column_name
            );
            has_diff = true;
        }
    }

    if !has_diff {
        println!(
            " {} No drift detected — model and DB are in sync.",
            style("✓").green()
        );
    }
    Ok(())
}
/// Extracts the value of the first `table = "..."` key in `source`, e.g. from
/// `#[rok_orm(table = "users")]`.
///
/// `table` must be a whole word, so `my_table = "x"` is ignored. Whitespace around
/// `=` is optional, and lines that are `//` comments (including `///` docs) are
/// skipped. Returns `None` when no complete, quoted value is found on a single line.
fn extract_table_name(source: &str) -> Option<String> {
    source
        .lines()
        .filter(|line| !line.trim_start().starts_with("//"))
        .find_map(table_value_in_line)
}

/// Returns the quoted value of the first whole-word `table = "..."` key in `line`.
fn table_value_in_line(line: &str) -> Option<String> {
    const KEY: &str = "table";
    let mut search_from = 0;
    while let Some(offset) = line[search_from..].find(KEY) {
        let key_start = search_from + offset;
        let key_end = key_start + KEY.len();
        search_from = key_end;

        // Reject matches that are the tail of a longer identifier (`my_table`).
        let glued_to_ident = line[..key_start]
            .chars()
            .next_back()
            .map_or(false, |c| c.is_alphanumeric() || c == '_');
        if glued_to_ident {
            continue;
        }

        // Expect `=` then an opening quote, allowing whitespace. `==` fails here
        // because the char after the first `=` is not a quote.
        let after_eq = match line[key_end..].trim_start().strip_prefix('=') {
            Some(rest) => rest,
            None => continue,
        };
        let value_and_rest = match after_eq.trim_start().strip_prefix('"') {
            Some(rest) => rest,
            None => continue,
        };
        if let Some(end) = value_and_rest.find('"') {
            return Some(value_and_rest[..end].to_string());
        }
    }
    None
}
/// Returns the field names of the first braced `pub struct` in `source`, in
/// declaration order.
///
/// Handling of struct headers and field lines:
/// * Unit and tuple structs (header ending in `;`, or `{}` on one line) are skipped,
///   and the next `pub struct` is tried.
/// * Fields of any visibility are collected: `pub`, `pub(crate)`, `pub(in ...)` or private.
/// * Raw identifiers are unprefixed (`r#type` → `type`) so they match column names.
/// * Fields annotated with `#[sqlx(skip)]` or `#[rok_orm(skip)]` are omitted, because
///   they are not read from a table column.
/// * Comment lines, attributes, and type continuation lines (e.g. `std::...`) are ignored.
///
/// Parsing stops at the first line starting with `}`.
fn parse_struct_fields(source: &str) -> Vec<String> {
    let mut fields = Vec::new();
    let mut in_struct = false;
    let mut skip_next_field = false;

    for line in source.lines() {
        let trimmed = line.trim();
        if !in_struct {
            // Only a header that opens a multi-line braced body starts the scan.
            in_struct = trimmed.starts_with("pub struct ")
                && !trimmed.ends_with(';')
                && !trimmed.ends_with('}');
            continue;
        }
        if trimmed.starts_with('}') {
            break;
        }
        if trimmed.is_empty() || trimmed.starts_with("//") {
            continue;
        }
        if trimmed.starts_with("#[") {
            if is_db_skip_attribute(trimmed) {
                skip_next_field = true;
            }
            continue;
        }
        if let Some(name) = field_name(trimmed) {
            // `take` resets the flag so the skip applies to this field only.
            if !std::mem::take(&mut skip_next_field) {
                fields.push(name);
            }
        }
    }
    fields
}

/// Parses `[vis] name: Type,` and returns `name` without any `r#` prefix. Returns
/// `None` for anything that is not a field declaration.
fn field_name(line: &str) -> Option<String> {
    let rest = if let Some(after) = line.strip_prefix("pub(") {
        &after[after.find(')')? + 1..]
    } else if let Some(after) = line.strip_prefix("pub ") {
        after
    } else {
        line
    };
    let (name, after_colon) = rest.trim_start().split_once(':')?;
    // `::` means a path segment on a type continuation line, not `name: Type`.
    if after_colon.starts_with(':') {
        return None;
    }
    let name = name.trim();
    let name = name.strip_prefix("r#").unwrap_or(name);
    let is_ident = !name.is_empty()
        && !name.starts_with(|c: char| c.is_ascii_digit())
        && name.chars().all(|c| c.is_alphanumeric() || c == '_');
    if is_ident {
        Some(name.to_string())
    } else {
        None
    }
}

/// True for `#[sqlx(...)]` / `#[rok_orm(...)]` attributes containing a `skip` token.
fn is_db_skip_attribute(attr: &str) -> bool {
    (attr.starts_with("#[sqlx(") || attr.starts_with("#[rok_orm("))
        && attr
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .any(|token| token == "skip")
}