rust-db-blueprint 0.1.0

A Rust code generator — reads YAML draft files and generates Axum + SQLx models, migrations, handlers, routes, requests, tests, and seeds
Documentation
use tera::{Tera, Context};
use indexmap::IndexMap;

use crate::tree::Tree;
use crate::models::{ModelDef, Column};

pub struct ModelGenerator;

impl ModelGenerator {
    pub fn generate(tree: &Tree) -> IndexMap<String, String> {
        let mut files = IndexMap::new();
        let mut tera = Tera::default();

        tera.add_raw_template("model", MODEL_TEMPLATE).unwrap();

        // Generate mod.rs for models directory
        let mut mod_lines = vec!["pub mod prelude;".to_string()];
        let mut prelude_lines = vec![];

        for model in tree.all_models() {
            let mut ctx = Context::new();
            ctx.insert("name", &model.name);
            ctx.insert("table", &model.table_name());
            ctx.insert("primary_key", &model.primary_key);
            ctx.insert("has_timestamps", &model.timestamps);
            ctx.insert("has_soft_deletes", &model.has_soft_deletes());
            ctx.insert("has_ulid", &model.traits.contains(&"HasUlid".to_string()));
            ctx.insert("has_uuid", &model.traits.contains(&"HasUuid".to_string()));

            // Generate struct fields
            let fields = Self::generate_fields(model);
            ctx.insert("fields", &fields);

            // Generate derive macros
            let derives = Self::generate_derives(model);
            ctx.insert("derives", &derives);

            // Generate relationship helpers
            let rel_helpers = Self::generate_relationship_helpers(model);
            ctx.insert("rel_helpers", &rel_helpers);

            // Generate query methods (CRUD)
            let query_methods = Self::generate_query_methods(model);
            ctx.insert("query_methods", &query_methods);

            let rendered = tera.render("model", &ctx).unwrap();
            let path = format!("src/models/{}.rs", to_snake_case(&model.name));
            files.insert(path, rendered);

            mod_lines.push(format!("pub mod {};", model.name.to_lowercase()));
            prelude_lines.push(format!("pub use super::{}::{};", model.name.to_lowercase(), model.name));
        }

        // Generate mod.rs and prelude.rs
        let mod_content = mod_lines.join("\n");
        files.insert("src/models/mod.rs".to_string(), mod_content);

        let prelude_content = format!("//! Re-exports for common use\n\n{}\n", prelude_lines.join("\n"));
        files.insert("src/models/prelude.rs".to_string(), prelude_content);

        files
    }

    fn generate_fields(model: &ModelDef) -> Vec<String> {
        let mut fields = vec![];

        // Primary key
        fields.push(format!("    pub {}: {},", model.primary_key, Self::pk_type(model)));

        for column in model.columns.values() {
            if column.name == model.primary_key {
                continue;
            }
            let rust_type = Self::column_to_rust_type(column);
            let _nullable = if column.has_modifier("nullable") || column.name == "deleted_at" {
                ""
            } else {
                ""
            };
            let actual_type = if column.has_modifier("nullable") || column.name == "deleted_at" {
                format!("Option<{}>", rust_type)
            } else {
                rust_type.to_string()
            };
            fields.push(format!("    pub {}: {},", column.name, actual_type));
        }

        if model.timestamps {
            if !model.columns.contains_key("created_at") {
                fields.push("    pub created_at: chrono::NaiveDateTime,".to_string());
            }
            if !model.columns.contains_key("updated_at") {
                fields.push("    pub updated_at: chrono::NaiveDateTime,".to_string());
            }
        }

        if model.has_soft_deletes() && !model.columns.contains_key("deleted_at") {
            fields.push("    pub deleted_at: Option<chrono::NaiveDateTime>,".to_string());
        }

        fields
    }

    fn pk_type(model: &ModelDef) -> &'static str {
        if model.traits.contains(&"HasUlid".to_string()) {
            "String"
        } else if model.traits.contains(&"HasUuid".to_string()) {
            "uuid::Uuid"
        } else {
            "i32"
        }
    }

    fn column_to_rust_type(column: &Column) -> &'static str {
        match column.data_type.as_str() {
            "string" | "char" | "text" | "longtext" | "mediumtext" => "String",
            "integer" | "bigInteger" | "unsignedInteger" | "unsignedBigInteger" => "i32",
            "smallInteger" | "tinyInteger" => "i16",
            "boolean" => "bool",
            "float" | "double" => "f64",
            "decimal" => "rust_decimal::Decimal",
            "date" => "chrono::NaiveDate",
            "datetime" | "timestamp" | "datetimeTz" | "timestampTz" => "chrono::NaiveDateTime",
            "json" | "jsonb" => "serde_json::Value",
            "uuid" => "uuid::Uuid",
            "ulid" => "String",
            "ipAddress" => "std::net::IpAddr",
            "macAddress" => "String",
            "enum" | "set" => "String",
            "rememberToken" => "String",
            _ => "String",
        }
    }

    fn generate_derives(model: &ModelDef) -> Vec<String> {
        let mut derives = vec![
            "Debug".to_string(),
            "Clone".to_string(),
            "serde::Serialize".to_string(),
            "serde::Deserialize".to_string(),
            "sqlx::FromRow".to_string(),
        ];
        derives
    }

    fn generate_relationship_helpers(model: &ModelDef) -> Vec<String> {
        let mut helpers = vec![];
        let mut added = std::collections::HashSet::new();

        // From explicit relationships
        for rel in &model.relationships {
            let method_name = to_snake_case(&rel.model);
            let table = pluralize(&rel.model.to_lowercase());
            match rel.type_.as_str() {
                "belongsTo" => {
                    let fk = format!("{}_id", method_name);
                    helpers.push(format!(
                        "    pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<super::{}, sqlx::Error> {{\n        sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE id = $1\")\n            .bind(self.{})\n            .fetch_one(pool)\n            .await\n    }}",
                        method_name, rel.model, rel.model, table, fk
                    ));
                    added.insert(method_name.clone());
                }
                "hasMany" => {
                    helpers.push(format!(
                        "    pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<Vec<super::{}>, sqlx::Error> {{\n        sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE {}_id = $1\")\n            .bind(self.id)\n            .fetch_all(pool)\n            .await\n    }}",
                        method_name, rel.model, rel.model, table, to_snake_case(&model.name)
                    ));
                    added.insert(method_name.clone());
                }
                "hasOne" => {
                    helpers.push(format!(
                        "    pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<Option<super::{}>, sqlx::Error> {{\n        sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE {}_id = $1\")\n            .bind(self.id)\n            .fetch_optional(pool)\n            .await\n    }}",
                        method_name, rel.model, rel.model, table, to_snake_case(&model.name)
                    ));
                    added.insert(method_name.clone());
                }
                "belongsToMany" => {
                    let pivot = if rel.model < model.name {
                        format!("{}_{}", to_snake_case(&rel.model), to_snake_case(&model.name))
                    } else {
                        format!("{}_{}", to_snake_case(&model.name), to_snake_case(&rel.model))
                    };
                    helpers.push(format!(
                        "    pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<Vec<super::{}>, sqlx::Error> {{\n        sqlx::query_as::<_, super::{}>(\"SELECT {}.* FROM {} INNER JOIN {} ON {}.id = {}.{}_id WHERE {}.{}_id = $1\")\n            .bind(self.id)\n            .fetch_all(pool)\n            .await\n    }}",
                        method_name, rel.model, rel.model, table, table, pivot, table, pivot, to_snake_case(&model.name), pivot, to_snake_case(&model.name)
                    ));
                    added.insert(method_name.clone());
                }
                _ => {}
            }
        }

        // From column _id detection (belongsTo)
        for col in model.columns.values() {
            if col.is_relationship {
                if let Some(ref rel_model) = col.relationship_model {
                    let method_name = col.name.strip_suffix("_id").unwrap_or(&col.name).to_string();
                    if added.contains(&method_name) { continue; }
                    helpers.push(format!(
                        "    pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<super::{}, sqlx::Error> {{\n        sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE id = $1\")\n            .bind(self.{})\n            .fetch_one(pool)\n            .await\n    }}",
                        method_name, rel_model, rel_model, pluralize(&rel_model.to_lowercase()), col.name
                    ));
                    added.insert(method_name);
                }
            }
        }
        helpers
    }

    fn generate_query_methods(model: &ModelDef) -> Vec<String> {
        let table = model.table_name();
        let name = &model.name;
        let lower = name.to_lowercase();

        vec![
            format!(
                "    pub async fn find_by_id(id: i32, pool: &sqlx::PgPool) -> Result<Self, sqlx::Error> {{\n        sqlx::query_as::<_, Self>(\"SELECT * FROM {} WHERE id = $1\")\n            .bind(id)\n            .fetch_one(pool)\n            .await\n    }}", table
            ),
            format!(
                "    pub async fn find_all(pool: &sqlx::PgPool) -> Result<Vec<Self>, sqlx::Error> {{\n        sqlx::query_as::<_, Self>(\"SELECT * FROM {}\")\n            .fetch_all(pool)\n            .await\n    }}", table
            ),
            format!(
                "    pub async fn delete(self, pool: &sqlx::PgPool) -> Result<_, sqlx::Error> {{\n        sqlx::query(\"DELETE FROM {} WHERE id = $1\")\n            .bind(self.{})\n            .execute(pool)\n            .await\n    }}", table, model.primary_key
            ),
        ]
    }
}

/// Converts a PascalCase / camelCase identifier to snake_case.
///
/// The conversion works as follows:
/// - An underscore is inserted before every uppercase character except a leading one.
/// - Each uppercase character is lowercased, so `BlogPost` becomes `blog_post`.
/// - Consecutive capitals are not grouped: `HTTPServer` becomes `h_t_t_p_server`.
/// - A character whose lowercase form spans several code points (e.g. `'İ'`) is expanded in
///   full rather than truncated to its first code point.
pub fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}

/// Naive English pluralisation used for table names.
///
/// The rules are applied in order:
/// - A consonant + `y` ending becomes `ies` (`category` -> `categories`).
/// - A sibilant ending (`s`, `x`, `z`, `ch`, `sh`) gets `es` (`box` -> `boxes`).
/// - Everything else gets a plain `s`.
fn pluralize(s: &str) -> String {
    const VOWEL_Y: [&str; 4] = ["ay", "ey", "oy", "uy"];
    const SIBILANT: [&str; 5] = ["s", "x", "z", "ch", "sh"];

    if let Some(stem) = s.strip_suffix('y') {
        if !VOWEL_Y.iter().any(|&end| s.ends_with(end)) {
            return format!("{}ies", stem);
        }
    }

    let suffix = if SIBILANT.iter().any(|&end| s.ends_with(end)) {
        "es"
    } else {
        "s"
    };
    format!("{}{}", s, suffix)
}

/// Tera template for a single generated model file.
///
/// Expected context keys:
/// - `table`
/// - `derives` (list of strings)
/// - `name`
/// - `fields` (pre-indented field lines)
/// - `query_methods`
/// - `rel_helpers` (pre-indented method sources)
///
/// Whitespace control (`{%-`) keeps the rendered code rustfmt-shaped:
/// - no blank line after an opening brace or before a closing one;
/// - fields on consecutive lines;
/// - exactly one blank line between methods.
const MODEL_TEMPLATE: &str = r#"use serde::{Deserialize, Serialize};
use sqlx::FromRow;

/// Represents a single row in the `{{ table }}` table.
#[derive({{ derives | join(sep=", ") }})]
pub struct {{ name }} {
{%- for field in fields %}
{{ field }}
{%- endfor %}
}

impl {{ name }} {
{%- for method in query_methods %}
{%- if not loop.first %}
{% endif %}
{{ method }}
{%- endfor %}
{%- for helper in rel_helpers %}
{%- if query_methods or not loop.first %}
{% endif %}
{{ helper }}
{%- endfor %}
}
"#;