use tera::{Tera, Context};
use indexmap::IndexMap;
use crate::tree::Tree;
use crate::models::{ModelDef, Column};
/// Stateless entry point for model code generation.
///
/// The type carries no data; everything is exposed through associated
/// functions, with `generate` as the public entry point.
pub struct ModelGenerator;
impl ModelGenerator {
/// Renders one Rust source file per model in `tree`, plus the accompanying
/// `mod.rs` and `prelude.rs`.
///
/// Returns a map from output path to file contents. Entries are inserted in
/// this order: one `src/models/<module>.rs` per model (in the order yielded by
/// `tree.all_models()`), then `src/models/mod.rs`, then
/// `src/models/prelude.rs`.
///
/// # Panics
///
/// Panics if the built-in template fails to parse or if rendering a model
/// fails; both indicate a bug in the generator rather than bad user input.
pub fn generate(tree: &Tree) -> IndexMap<String, String> {
    let mut files = IndexMap::new();
    let mut tera = Tera::default();
    tera.add_raw_template("model", MODEL_TEMPLATE)
        .expect("built-in model template must be valid Tera syntax");

    let mut mod_lines = vec!["pub mod prelude;".to_string()];
    let mut prelude_lines = vec![];

    for model in tree.all_models() {
        let mut ctx = Context::new();
        ctx.insert("name", &model.name);
        ctx.insert("table", &model.table_name());
        ctx.insert("primary_key", &model.primary_key);
        ctx.insert("has_timestamps", &model.timestamps);
        ctx.insert("has_soft_deletes", &model.has_soft_deletes());
        ctx.insert("has_ulid", &model.traits.contains(&"HasUlid".to_string()));
        ctx.insert("has_uuid", &model.traits.contains(&"HasUuid".to_string()));
        ctx.insert("fields", &Self::generate_fields(model));
        ctx.insert("derives", &Self::generate_derives(model));
        ctx.insert("rel_helpers", &Self::generate_relationship_helpers(model));
        ctx.insert("query_methods", &Self::generate_query_methods(model));

        let rendered = tera
            .render("model", &ctx)
            .unwrap_or_else(|e| panic!("failed to render model `{}`: {}", model.name, e));

        // The module name must match the file stem exactly, otherwise the
        // `pub mod` / `pub use` lines below point at a file that does not
        // exist (e.g. `blog_post.rs` vs `pub mod blogpost;`).
        let module = to_snake_case(&model.name);
        files.insert(format!("src/models/{}.rs", module), rendered);
        mod_lines.push(format!("pub mod {};", module));
        prelude_lines.push(format!("pub use super::{}::{};", module, model.name));
    }

    files.insert("src/models/mod.rs".to_string(), mod_lines.join("\n"));

    let prelude_content = format!(
        "//! Re-exports for common use\n\n{}\n",
        prelude_lines.join("\n")
    );
    files.insert("src/models/prelude.rs".to_string(), prelude_content);

    files
}
/// Builds the struct field declarations (one source line each) for `model`.
///
/// Order of the returned lines:
/// 1. the primary key, typed by `pk_type`;
/// 2. every column in `model.columns` except the primary key, wrapped in
///    `Option<..>` when the column has the `nullable` modifier or is named
///    `deleted_at`;
/// 3. `created_at` / `updated_at` when `model.timestamps` is set and the
///    columns are not already declared;
/// 4. `deleted_at` when the model uses soft deletes and the column is not
///    already declared.
fn generate_fields(model: &ModelDef) -> Vec<String> {
    let mut fields = vec![format!(
        " pub {}: {},",
        model.primary_key,
        Self::pk_type(model)
    )];

    for column in model.columns.values() {
        // The primary key was emitted above with its dedicated type.
        if column.name == model.primary_key {
            continue;
        }

        let rust_type = Self::column_to_rust_type(column);
        let is_optional = column.has_modifier("nullable") || column.name == "deleted_at";
        let actual_type = if is_optional {
            format!("Option<{}>", rust_type)
        } else {
            rust_type.to_string()
        };
        fields.push(format!(" pub {}: {},", column.name, actual_type));
    }

    if model.timestamps {
        if !model.columns.contains_key("created_at") {
            fields.push(" pub created_at: chrono::NaiveDateTime,".to_string());
        }
        if !model.columns.contains_key("updated_at") {
            fields.push(" pub updated_at: chrono::NaiveDateTime,".to_string());
        }
    }

    if model.has_soft_deletes() && !model.columns.contains_key("deleted_at") {
        fields.push(" pub deleted_at: Option<chrono::NaiveDateTime>,".to_string());
    }

    fields
}
/// Picks the Rust type used for the model's primary key field.
///
/// `HasUlid` takes precedence over `HasUuid`; a model with neither trait
/// gets an `i32` key.
fn pk_type(model: &ModelDef) -> &'static str {
    let uses = |trait_name: &str| model.traits.contains(&String::from(trait_name));

    if uses("HasUlid") {
        return "String";
    }
    if uses("HasUuid") {
        return "uuid::Uuid";
    }
    "i32"
}
/// Maps a schema column type name to the Rust type used in the generated
/// struct.
///
/// The generated code targets `sqlx::PgPool`, so the mapping follows sqlx's
/// PostgreSQL type table. Unrecognised type names fall back to `String`.
fn column_to_rust_type(column: &Column) -> &'static str {
    match column.data_type.as_str() {
        "string" | "char" | "text" | "longtext" | "mediumtext" => "String",
        // PostgreSQL has no unsigned integers; a plain/unsigned integer is INT4.
        "integer" | "unsignedInteger" => "i32",
        // BIGINT is 64-bit: `i32` would truncate values and sqlx refuses to
        // decode INT8 into it.
        "bigInteger" | "unsignedBigInteger" => "i64",
        "smallInteger" | "tinyInteger" => "i16",
        "boolean" => "bool",
        "float" | "double" => "f64",
        "decimal" => "rust_decimal::Decimal",
        "date" => "chrono::NaiveDate",
        "datetime" | "timestamp" => "chrono::NaiveDateTime",
        // TIMESTAMPTZ decodes to a zone-aware value in sqlx, not NaiveDateTime.
        "datetimeTz" | "timestampTz" => "chrono::DateTime<chrono::Utc>",
        "json" | "jsonb" => "serde_json::Value",
        "uuid" => "uuid::Uuid",
        "ulid" => "String",
        "ipAddress" => "std::net::IpAddr",
        "macAddress" => "String",
        "enum" | "set" => "String",
        "rememberToken" => "String",
        _ => "String",
    }
}
/// Returns the derive paths placed on the generated struct.
///
/// The list is currently the same for every model; `model` is accepted so
/// the signature matches the other per-model generators.
fn generate_derives(model: &ModelDef) -> Vec<String> {
    // Not consulted yet; bound explicitly to avoid an unused-variable warning.
    let _ = model;

    [
        "Debug",
        "Clone",
        "serde::Serialize",
        "serde::Deserialize",
        "sqlx::FromRow",
    ]
    .iter()
    .map(|derive| derive.to_string())
    .collect()
}
fn generate_relationship_helpers(model: &ModelDef) -> Vec<String> {
let mut helpers = vec![];
let mut added = std::collections::HashSet::new();
for rel in &model.relationships {
let method_name = to_snake_case(&rel.model);
let table = pluralize(&rel.model.to_lowercase());
match rel.type_.as_str() {
"belongsTo" => {
let fk = format!("{}_id", method_name);
helpers.push(format!(
" pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<super::{}, sqlx::Error> {{\n sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE id = $1\")\n .bind(self.{})\n .fetch_one(pool)\n .await\n }}",
method_name, rel.model, rel.model, table, fk
));
added.insert(method_name.clone());
}
"hasMany" => {
helpers.push(format!(
" pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<Vec<super::{}>, sqlx::Error> {{\n sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE {}_id = $1\")\n .bind(self.id)\n .fetch_all(pool)\n .await\n }}",
method_name, rel.model, rel.model, table, to_snake_case(&model.name)
));
added.insert(method_name.clone());
}
"hasOne" => {
helpers.push(format!(
" pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<Option<super::{}>, sqlx::Error> {{\n sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE {}_id = $1\")\n .bind(self.id)\n .fetch_optional(pool)\n .await\n }}",
method_name, rel.model, rel.model, table, to_snake_case(&model.name)
));
added.insert(method_name.clone());
}
"belongsToMany" => {
let pivot = if rel.model < model.name {
format!("{}_{}", to_snake_case(&rel.model), to_snake_case(&model.name))
} else {
format!("{}_{}", to_snake_case(&model.name), to_snake_case(&rel.model))
};
helpers.push(format!(
" pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<Vec<super::{}>, sqlx::Error> {{\n sqlx::query_as::<_, super::{}>(\"SELECT {}.* FROM {} INNER JOIN {} ON {}.id = {}.{}_id WHERE {}.{}_id = $1\")\n .bind(self.id)\n .fetch_all(pool)\n .await\n }}",
method_name, rel.model, rel.model, table, table, pivot, table, pivot, to_snake_case(&model.name), pivot, to_snake_case(&model.name)
));
added.insert(method_name.clone());
}
_ => {}
}
}
for col in model.columns.values() {
if col.is_relationship {
if let Some(ref rel_model) = col.relationship_model {
let method_name = col.name.strip_suffix("_id").unwrap_or(&col.name).to_string();
if added.contains(&method_name) { continue; }
helpers.push(format!(
" pub async fn fetch_{}(&self, pool: &sqlx::PgPool) -> Result<super::{}, sqlx::Error> {{\n sqlx::query_as::<_, super::{}>(\"SELECT * FROM {} WHERE id = $1\")\n .bind(self.{})\n .fetch_one(pool)\n .await\n }}",
method_name, rel_model, rel_model, pluralize(&rel_model.to_lowercase()), col.name
));
added.insert(method_name);
}
}
}
helpers
}
fn generate_query_methods(model: &ModelDef) -> Vec<String> {
let table = model.table_name();
let name = &model.name;
let lower = name.to_lowercase();
vec![
format!(
" pub async fn find_by_id(id: i32, pool: &sqlx::PgPool) -> Result<Self, sqlx::Error> {{\n sqlx::query_as::<_, Self>(\"SELECT * FROM {} WHERE id = $1\")\n .bind(id)\n .fetch_one(pool)\n .await\n }}", table
),
format!(
" pub async fn find_all(pool: &sqlx::PgPool) -> Result<Vec<Self>, sqlx::Error> {{\n sqlx::query_as::<_, Self>(\"SELECT * FROM {}\")\n .fetch_all(pool)\n .await\n }}", table
),
format!(
" pub async fn delete(self, pool: &sqlx::PgPool) -> Result<_, sqlx::Error> {{\n sqlx::query(\"DELETE FROM {} WHERE id = $1\")\n .bind(self.{})\n .execute(pool)\n .await\n }}", table, model.primary_key
),
]
}
}
/// Converts a `PascalCase` / `camelCase` identifier to `snake_case`.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by an underscore. Consecutive capitals are therefore
/// split individually (`"HTTPServer"` becomes `"h_t_t_p_server"`). All other
/// characters are copied unchanged.
pub fn to_snake_case(s: &str) -> String {
    // A few spare bytes for the inserted underscores.
    let mut result = String::with_capacity(s.len() + 4);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // A single char may lowercase to several chars; keep all of them.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
/// Naive English pluraliser.
///
/// * a trailing `y` not preceded by `a`, `e`, `o` or `u` becomes `ies`;
/// * a trailing `s`, `x`, `z`, `ch` or `sh` gets `es` appended;
/// * anything else gets `s` appended.
fn pluralize(s: &str) -> String {
    if let Some(stem) = s.strip_suffix('y') {
        let vowel_before_y = stem.ends_with(|c: char| matches!(c, 'a' | 'e' | 'o' | 'u'));
        if !vowel_before_y {
            return format!("{}ies", stem);
        }
    }

    let takes_es = ["s", "x", "z", "ch", "sh"]
        .iter()
        .any(|suffix| s.ends_with(*suffix));
    let ending = if takes_es { "es" } else { "s" };
    format!("{}{}", s, ending)
}
/// Tera template for one generated model source file.
///
/// `ModelGenerator::generate` registers it under the name `"model"` and
/// renders it once per model. It reads these context keys:
///
/// * `table` - table name, used in the struct's doc comment;
/// * `derives` - list of derive paths, joined with `", "`;
/// * `name` - struct / impl name;
/// * `fields` - pre-rendered field lines, emitted one per line;
/// * `query_methods`, `rel_helpers` - pre-rendered method bodies.
///
/// Other keys the generator inserts (`primary_key`, `has_timestamps`,
/// `has_soft_deletes`, `has_ulid`, `has_uuid`) are not referenced here.
///
/// NOTE(review): the two leading `use` lines look unused in the output,
/// since `generate_derives` emits fully qualified derive paths - confirm
/// before removing, as that changes every generated file.
const MODEL_TEMPLATE: &str = r#"use serde::{Deserialize, Serialize};
use sqlx::FromRow;
/// Represents a single row in the `{{ table }}` table.
#[derive({{ derives | join(sep=", ") }})]
pub struct {{ name }} {
{% for field in fields %}
{{ field }}
{% endfor %}
}
impl {{ name }} {
{% for method in query_methods %}
{{ method }}
{% endfor %}
{% for helper in rel_helpers %}
{{ helper }}
{% endfor %}
}
"#;