//! rok-cli 0.1.3
//!
//! Developer CLI for rok-based Axum applications
//!
//! `rok db:pull` and `rok db:diff` — live DB introspection and schema drift detection.

use console::style;
use heck::ToUpperCamelCase;
use sqlx::PgPool;

// ── Column metadata ───────────────────────────────────────────────────────────

/// One row of `information_schema.columns`, as selected by the queries in this file.
///
/// Field names match the selected column names so `sqlx::FromRow` can populate
/// the struct directly.
#[derive(Debug, sqlx::FromRow)]
struct ColumnInfo {
    /// Name of the column; used as the generated field name and for drift comparison.
    column_name: String,
    /// Type name reported by `data_type`; translated by `pg_type_to_rust`.
    data_type: String,
    /// Nullability flag as text; the generator compares it against `"YES"`.
    is_nullable: String,
    // Selected by the queries but not read anywhere in the visible code,
    // hence the lint allowance.
    #[allow(dead_code)]
    column_default: Option<String>,
}

/// Maps a Postgres type name to the Rust type used in generated models.
///
/// Returns `(rust_type, field_attr)`. `field_attr` is either `""` or a single
/// line (an attribute or a `//` note) meant to be placed above the field.
/// Names that are not recognised fall back to `String` with an explanatory note.
///
/// Integer and float widths mirror the Postgres storage width (`int2`/`int4`/`int8`,
/// `float4`/`float8`): sqlx refuses to decode a column into a Rust type of a
/// different width, so mapping every integer to `i64` produced models that fail
/// at query time.
///
/// NOTE(review): `numeric`/`decimal` still map to `f64` and `uuid` to `String`;
/// whether these decode depends on the sqlx features enabled in the target
/// project — confirm before relying on them.
fn pg_type_to_rust(pg_type: &str) -> (&'static str, &'static str) {
    match pg_type {
        "text" | "character varying" | "varchar" | "char" | "character" => ("String", ""),
        "bigint" | "int8" => ("i64", ""),
        "integer" | "int" | "int4" => ("i32", ""),
        "smallint" | "int2" => ("i16", ""),
        "boolean" | "bool" => ("bool", ""),
        "double precision" | "float8" => ("f64", ""),
        "real" | "float4" => ("f32", ""),
        "numeric" | "decimal" => ("f64", ""),
        "uuid" => ("String", "// uuid"),
        "jsonb" | "json" => ("serde_json::Value", "#[sqlx(json)]"),
        "timestamp with time zone" | "timestamptz" => ("chrono::DateTime<chrono::Utc>", ""),
        "timestamp" | "timestamp without time zone" => ("chrono::NaiveDateTime", ""),
        "date" => ("chrono::NaiveDate", ""),
        "bytea" => ("Vec<u8>", ""),
        "bigserial" | "serial8" => ("i64", "// auto-increment"),
        "serial" | "serial4" => ("i32", "// auto-increment"),
        "smallserial" | "serial2" => ("i16", "// auto-increment"),
        _ => ("String", "// unknown pg type"),
    }
}

// ── db:pull ───────────────────────────────────────────────────────────────────

/// Introspects `table` in the `public` schema and generates a Rust model for it.
///
/// Loads `.env` if present, connects to the Postgres database named by
/// `DATABASE_URL`, reads the table's columns from `information_schema.columns`
/// and renders a model struct. The code is written to `output` when a path is
/// given, otherwise it is printed to stdout.
///
/// # Errors
///
/// Fails when `DATABASE_URL` is unset, the connection or query fails, the table
/// has no columns in the `public` schema, or the output file cannot be written.
pub async fn pull(table: &str, output: Option<&str>) -> anyhow::Result<()> {
    dotenvy::dotenv().ok();
    let database_url = std::env::var("DATABASE_URL")
        .map_err(|_| anyhow::anyhow!("DATABASE_URL not set — add it to .env"))?;

    println!(
        "{} Connecting to database...",
        style("db:pull").green().bold()
    );

    let pool = PgPool::connect(&database_url)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {e}"))?;

    let columns: Vec<ColumnInfo> = sqlx::query_as::<_, ColumnInfo>(
        r#"
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default
        FROM information_schema.columns
        WHERE table_schema = 'public'
          AND table_name   = $1
        ORDER BY ordinal_position
        "#,
    )
    .bind(table)
    .fetch_all(&pool)
    .await?;

    if columns.is_empty() {
        anyhow::bail!("Table '{table}' not found in public schema. Check the table name.");
    }

    // Naive singularisation: drop at most one trailing 's', and none when the
    // name ends in "ss". `trim_end_matches('s')` removed *every* trailing 's',
    // turning "address" into "addre" and a table called "s" into an empty
    // struct name.
    let singular = match table.strip_suffix('s') {
        Some(stem) if !stem.is_empty() && !stem.ends_with('s') => stem,
        _ => table,
    };
    let pascal = singular.to_upper_camel_case();

    println!(
        "{} Found {} columns in '{table}'",
        style("").green(),
        columns.len()
    );

    let model_code = generate_model_from_columns(&pascal, table, &columns);

    match output {
        Some(path) => {
            std::fs::write(path, &model_code)?;
            println!("{} Model written to {path}", style("").green().bold());
        }
        None => {
            println!("\n{}", style("Generated model:").bold());
            println!("{}", model_code);
        }
    }

    Ok(())
}

/// Renders the source text of a model struct named `pascal` for `table`.
///
/// Each entry in `columns` becomes one `pub` field, wrapped in `Option<…>` when
/// the column reports `is_nullable == "YES"`. Any attribute/note returned by
/// `pg_type_to_rust` is placed on its own line above the field. `chrono` and
/// `serde_json` imports are added only when at least one column needs them.
fn generate_model_from_columns(pascal: &str, table: &str, columns: &[ColumnInfo]) -> String {
    let mut wants_chrono = false;
    let mut wants_json = false;
    let mut body = String::new();

    // Single pass: record which optional imports are required while emitting fields.
    for column in columns {
        let (qualified, annotation) = pg_type_to_rust(&column.data_type);
        wants_chrono |= qualified.contains("chrono");
        wants_json |= qualified.contains("Value");

        // The generated file imports these names, so fields use the short form.
        let short = match qualified {
            "chrono::DateTime<chrono::Utc>" => "DateTime<Utc>",
            "chrono::NaiveDate" => "NaiveDate",
            "chrono::NaiveDateTime" => "NaiveDateTime",
            "serde_json::Value" => "Value",
            unchanged => unchanged,
        };
        let declared = if column.is_nullable == "YES" {
            format!("Option<{short}>")
        } else {
            short.to_owned()
        };

        if !annotation.is_empty() {
            body.push_str("    ");
            body.push_str(annotation);
            body.push('\n');
        }
        body.push_str(&format!("    pub {}: {},\n", column.column_name, declared));
    }

    let mut header = vec![
        "use rok_orm::Model;",
        "use serde::{Deserialize, Serialize};",
    ];
    if wants_chrono {
        header.push("use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};");
    }
    if wants_json {
        header.push("use serde_json::Value;");
    }
    let header = header.join("\n");

    format!(
        "{header}\n\n#[derive(Debug, Clone, Model, Serialize, Deserialize, sqlx::FromRow)]\n#[rok_orm(table = \"{table}\")]\npub struct {pascal} {{\n{body}}}\n"
    )
}

// ── db:diff ───────────────────────────────────────────────────────────────────

/// Compares the fields of the model in `model_file` with the live table's columns.
///
/// The table name is taken from `table` when given, otherwise from the first
/// `table = "…"` found in the file, otherwise from the file stem with an `s`
/// appended. Field names are obtained with a line-based scan of the source, not
/// a real parse. Differences are printed; drift does not make the function
/// return an error.
///
/// # Errors
///
/// Fails when `DATABASE_URL` is unset, the model file is missing or unreadable,
/// the connection or query fails, or the table has no columns in the `public`
/// schema.
pub async fn diff(model_file: &str, table: Option<&str>) -> anyhow::Result<()> {
    dotenvy::dotenv().ok();
    let database_url =
        std::env::var("DATABASE_URL").map_err(|_| anyhow::anyhow!("DATABASE_URL not set"))?;

    if !std::path::Path::new(model_file).exists() {
        anyhow::bail!("Model file not found: {model_file}");
    }

    let source = std::fs::read_to_string(model_file)?;

    // Extract table name from #[rok_orm(table = "...")] or infer from file name
    let table_name = table
        .map(|t| t.to_string())
        .or_else(|| extract_table_name(&source))
        .unwrap_or_else(|| {
            std::path::Path::new(model_file)
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or("unknown")
                .to_string()
                + "s"
        });

    // Parse struct fields from source (simple text scan, no AST)
    let model_fields = parse_struct_fields(&source);

    println!(
        "{} Connecting to database...",
        style("db:diff").green().bold()
    );
    let pool = PgPool::connect(&database_url)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to connect: {e}"))?;

    let db_columns: Vec<ColumnInfo> = sqlx::query_as::<_, ColumnInfo>(
        r#"
        SELECT column_name, data_type, is_nullable, column_default
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = $1
        ORDER BY ordinal_position
        "#,
    )
    .bind(&table_name)
    .fetch_all(&pool)
    .await?;

    // The table name is often inferred heuristically; without this check a wrong
    // guess would be reported as "every field is missing in DB". Same behaviour
    // and message as `pull`.
    if db_columns.is_empty() {
        anyhow::bail!("Table '{table_name}' not found in public schema. Check the table name.");
    }

    // Sets are used for membership tests only; they borrow instead of cloning.
    let db_col_names: std::collections::HashSet<&str> =
        db_columns.iter().map(|c| c.column_name.as_str()).collect();

    let model_field_names: std::collections::HashSet<&str> =
        model_fields.iter().map(String::as_str).collect();

    println!(
        "\n{} Schema diff for '{table_name}':",
        style("db:diff").bold()
    );

    let mut has_diff = false;

    // Fields in model but not in DB. Iterate the Vec (declaration order) rather
    // than the HashSet so the report is identical from run to run.
    for name in &model_fields {
        if !db_col_names.contains(name.as_str()) {
            println!(
                "  {} {name} — in model, missing in DB (add a migration)",
                style("+").green()
            );
            has_diff = true;
        }
    }

    // Columns in DB but not in model (already in ordinal order from the query)
    for col in &db_columns {
        if !model_field_names.contains(col.column_name.as_str()) {
            println!(
                "  {} {} — in DB, missing in model (add field or #[rok_orm(skip)])",
                style("-").red(),
                col.column_name
            );
            has_diff = true;
        }
    }

    if !has_diff {
        println!(
            "  {} No drift detected — model and DB are in sync.",
            style("").green()
        );
    }

    Ok(())
}

/// Finds the first `table = "…"` assignment in `source` and returns its value.
///
/// Whitespace around `=` is optional, so `table="users"` is accepted as well as
/// `table = "users"`. The key must be the whole identifier: `join_table = "…"`
/// and `table_name = "…"` are skipped. This is a plain text scan; it does not
/// check that the match sits inside a `#[rok_orm(...)]` attribute.
///
/// Returns `None` when no such assignment with a closing quote exists.
fn extract_table_name(source: &str) -> Option<String> {
    const KEY: &str = "table";
    let mut offset = 0;

    while let Some(found) = source[offset..].find(KEY) {
        let key_start = offset + found;
        let key_end = key_start + KEY.len();
        // Resume after this occurrence if it turns out not to be an assignment.
        offset = key_end;

        // Reject matches that are the tail of a longer identifier.
        let inside_identifier = source[..key_start]
            .chars()
            .next_back()
            .map_or(false, |c| c.is_alphanumeric() || c == '_');
        if inside_identifier {
            continue;
        }

        // Expect: optional whitespace, '=', optional whitespace, opening quote.
        let value = source[key_end..]
            .trim_start()
            .strip_prefix('=')
            .map(str::trim_start)
            .and_then(|rest| rest.strip_prefix('"'));

        if let Some(value) = value {
            if let Some(end) = value.find('"') {
                return Some(value[..end].to_string());
            }
        }
    }
    None
}

/// Collects the names of the `pub` fields of the first braced `pub struct` in `source`.
///
/// This is a line-based text scan, not a parser: it expects one field per line
/// and the closing `}` on a line of its own. Fields without a `pub` visibility
/// are not reported. Unit/tuple structs (`pub struct X;`) and one-line bodies
/// (`pub struct X {}`) are skipped so that code following them is not mistaken
/// for fields.
fn parse_struct_fields(source: &str) -> Vec<String> {
    let mut in_struct = false;
    let mut fields = Vec::new();

    for line in source.lines() {
        let trimmed = line.trim();

        if !in_struct {
            if trimmed.starts_with("pub struct ") {
                // A header ending in `;` or `}` has no multi-line body to scan.
                in_struct = !(trimmed.ends_with(';') || trimmed.ends_with('}'));
            }
            continue;
        }

        if trimmed == "}" {
            break;
        }
        if let Some(name) = struct_field_name(trimmed) {
            fields.push(name.to_string());
        }
    }
    fields
}

/// Extracts the field name from one trimmed `pub name: Type` line, if it is one.
///
/// Accepts restricted visibility (`pub(crate)`, `pub(super)`, …), strips the
/// `r#` prefix of raw identifiers so the name matches the column name, and
/// rejects anything before the `:` that is not a plain identifier.
fn struct_field_name(line: &str) -> Option<&str> {
    let rest = line.strip_prefix("pub")?;

    // Skip a visibility restriction such as `(crate)`.
    let rest = match rest.strip_prefix('(') {
        Some(scoped) => &scoped[scoped.find(')')? + 1..],
        None => rest,
    };

    // Whitespace must separate visibility from name; otherwise a private field
    // like `publish: bool` would be read as `pub` + `lish`.
    if !rest.starts_with(char::is_whitespace) {
        return None;
    }

    let (name, _ty) = rest.split_once(':')?;
    let name = name.trim();
    let name = name.strip_prefix("r#").unwrap_or(name);

    let mut chars = name.chars();
    let first = chars.next()?;
    let is_identifier = (first.is_alphabetic() || first == '_')
        && chars.all(|c| c.is_alphanumeric() || c == '_');

    if is_identifier {
        Some(name)
    } else {
        None
    }
}