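//! byteorm 0.1.0
//!
//! A lightweight ORM for Rust.
//!
//! Documentation: this module turns a parsed schema into Rust source code
//! (model structs, query builders and audited update methods).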
use quote::{quote, format_ident};
use proc_macro2::TokenStream;
use crate::{Schema, Model, Field, Modifier};

/// Generates the complete Rust source for `schema` as a pretty-printed string.
///
/// The output contains the shared imports, the private `calculate_json_diff`
/// helper used by audited updates, and, for every model, its struct, its query
/// builder struct and the two `impl` blocks.
///
/// # Panics
///
/// Panics if the generated tokens do not form a valid Rust file. That indicates
/// either a generator bug or a schema identifier that is not usable as a Rust
/// identifier.
pub fn generate_rust_code(schema: &Schema) -> String {
    let structs_and_impls = schema
        .models
        .iter()
        .map(|model| generate_model_with_query_builder(model));

    // The shared imports and the diff helper are emitted unconditionally, but
    // a given schema may not use them (no timestamp fields, no audited model).
    // The allows keep the generated file warning-free in those cases.
    let code = quote! {
        #[allow(unused_imports)]
        use serde::{Deserialize, Serialize};
        #[allow(unused_imports)]
        use chrono::{DateTime, Utc};
        #[allow(unused_imports)]
        use tokio_postgres::Client;

        #[allow(dead_code)]
        fn calculate_json_diff(before: &serde_json::Value, after: &serde_json::Value) -> serde_json::Value {
            let mut diff = serde_json::Map::new();

            if let (Some(before_obj), Some(after_obj)) = (before.as_object(), after.as_object()) {
                for (key, after_val) in after_obj {
                    if let Some(before_val) = before_obj.get(key) {
                        if before_val != after_val {
                            diff.insert(
                                key.clone(),
                                serde_json::json!({
                                    "from": before_val,
                                    "to": after_val
                                })
                            );
                        }
                    } else {
                        diff.insert(key.clone(), serde_json::json!({ "added": after_val }));
                    }
                }

                for (key, before_val) in before_obj {
                    if !after_obj.contains_key(key) {
                        diff.insert(key.clone(), serde_json::json!({ "removed": before_val }));
                    }
                }
            }

            serde_json::Value::Object(diff)
        }

        #(#structs_and_impls)*
    };

    let file: syn::File =
        syn::parse2(code).expect("generated code must parse as a Rust source file");
    prettyplease::unparse(&file)
}

fn generate_model_with_query_builder(model: &Model) -> TokenStream {
    let model_struct = generate_model_struct(model);
    let query_builder_struct = generate_query_builder_struct(model);
    let query_builder_impl = generate_query_builder_impl(model);
    let model_impl = generate_model_impl(model);

    quote! {
        #model_struct

        #query_builder_struct

        #model_impl

        #query_builder_impl
    }
}

/// Emits the plain data struct for `model`.
///
/// The struct has one `pub` field per schema field, named exactly as in the
/// schema and typed via `rust_type_from_schema`. It derives `Debug`, `Clone`,
/// `Serialize` and `Deserialize`.
fn generate_model_struct(model: &Model) -> TokenStream {
    let struct_ident = format_ident!("{}", model.name);

    let mut field_defs = Vec::new();
    for field in model.fields.iter() {
        let ident = format_ident!("{}", field.name);
        let ty = rust_type_from_schema(&field.type_name);
        field_defs.push(quote! { pub #ident: #ty });
    }

    quote! {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        pub struct #struct_ident {
            #(#field_defs),*
        }
    }
}

fn generate_model_impl(model: &Model) -> TokenStream {
    let model_name = format_ident!("{}", model.name);
    let builder_name = format_ident!("{}Query", model.name);

    let pk_field = model.fields.iter()
        .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)));

    let find_by_id_impl = if let Some(pk) = pk_field {
        let pk_name = format_ident!("{}", to_snake_case(&pk.name));
        let pk_type = rust_type_from_schema(&pk.type_name);

        quote! {
            pub fn find_by_id(id: #pk_type) -> String {
                format!("SELECT * FROM {} WHERE {} = {}",
                    stringify!(#pk_name).replace("_", ""),
                    stringify!(#pk_name),
                    id
                )
            }
        }
    } else {
        quote! {}
    };

    quote! {
        impl #model_name {
            pub fn query() -> #builder_name {
                #builder_name::new()
            }

            #find_by_id_impl
        }
    }
}

/// Emits the `<Model>Query` builder struct.
///
/// The builder holds:
/// - the target table;
/// - pre-rendered `WHERE` conditions, joined with `AND` when the query is built;
/// - optional `LIMIT` and `OFFSET` values;
/// - `ORDER BY` terms as `(column, direction)` pairs.
///
/// `Clone` is derived because `first()` clones the builder before narrowing it.
/// The derive produces the same field-by-field clone the hand-written impl did.
/// `Debug` is derived so that queries can be logged.
fn generate_query_builder_struct(model: &Model) -> TokenStream {
    let builder_name = format_ident!("{}Query", model.name);

    quote! {
        #[derive(Debug, Clone)]
        pub struct #builder_name {
            table: String,
            conditions: Vec<String>,
            limit: Option<usize>,
            offset: Option<usize>,
            order_by: Vec<(String, String)>,
        }
    }
}

/// Emits the inherent `impl` (plus a `Default` impl) for `<Model>Query`.
///
/// Generated API:
/// - `new()` / `Default`, targeting table `model.name.to_lowercase()`;
/// - one `where_<field>(value)` filter per field; filters are ANDed together;
/// - the `limit`, `offset` and `order_by` modifiers;
/// - async `select`, `first` and `count`, executed through `tokio_postgres::Client`;
/// - `update`, when some field has an audit model (see `generate_update_with_audit`).
///
/// Column names are the snake_case form of the schema field names.
fn generate_query_builder_impl(model: &Model) -> TokenStream {
    let builder_name = format_ident!("{}Query", model.name);
    let model_name = format_ident!("{}", model.name);
    let table_name = model.name.to_lowercase();

    // Explicit column list in struct-field order. `select` decodes rows by
    // position, but `SELECT *` follows the table's physical column order. That
    // order need not match the struct (e.g. after `ALTER TABLE ... ADD COLUMN`).
    let column_list = model
        .fields
        .iter()
        .map(|field| to_snake_case(&field.name))
        .collect::<Vec<_>>()
        .join(", ");

    let where_methods = model.fields.iter().map(|field| {
        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
        let column = to_snake_case(&field.name);
        let field_type = rust_type_from_schema(&field.type_name);
        let type_name: &str = &field.type_name;

        // Numbers and booleans are valid SQL literals as rendered by `Display`.
        // Everything else becomes a single-quoted SQL string literal with
        // embedded quotes doubled. The former `{:?}` rendering produced
        // `"..."`, which Postgres reads as an identifier, and did no escaping.
        let literal = if matches!(type_name, "BigInt" | "Int" | "Float" | "Boolean") {
            quote! { value }
        } else {
            quote! { format!("'{}'", value.to_string().replace('\'', "''")) }
        };

        quote! {
            pub fn #method_name(mut self, value: #field_type) -> Self {
                self.conditions.push(format!("{} = {}", #column, #literal));
                self
            }
        }
    });

    // `try_get` reports type/NULL mismatches as errors instead of panicking.
    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
        let field_name = format_ident!("{}", field.name);

        quote! {
            #field_name: row.try_get(#idx)?
        }
    });

    let has_audit = model.fields.iter().any(|f| f.get_audit_model().is_some());

    let update_method = if has_audit {
        generate_update_with_audit(model)
    } else {
        quote! {}
    };

    quote! {
        impl Default for #builder_name {
            fn default() -> Self {
                Self::new()
            }
        }

        impl #builder_name {
            pub fn new() -> Self {
                Self {
                    table: #table_name.to_string(),
                    conditions: vec![],
                    limit: None,
                    offset: None,
                    order_by: vec![],
                }
            }

            #(#where_methods)*

            pub fn limit(mut self, limit: usize) -> Self {
                self.limit = Some(limit);
                self
            }

            pub fn offset(mut self, offset: usize) -> Self {
                self.offset = Some(offset);
                self
            }

            /// Appends an `ORDER BY` term. `column` and `direction` are inserted
            /// into the SQL verbatim, so they must not come from untrusted input.
            pub fn order_by(mut self, column: &str, direction: &str) -> Self {
                self.order_by.push((column.to_string(), direction.to_string()));
                self
            }

            pub async fn select(&self, client: &Client)
                -> Result<Vec<#model_name>, Box<dyn std::error::Error>>
            {
                let sql = self.build_select();

                let rows = client.query(&sql, &[]).await?;
                let mut results = Vec::with_capacity(rows.len());

                for row in &rows {
                    results.push(#model_name {
                        #(#field_gets),*
                    });
                }

                Ok(results)
            }

            pub async fn first(&self, client: &Client)
                -> Result<Option<#model_name>, Box<dyn std::error::Error>>
            {
                let mut query = self.clone();
                query.limit = Some(1);
                let results = query.select(client).await?;
                Ok(results.into_iter().next())
            }

            pub async fn count(&self, client: &Client)
                -> Result<i64, Box<dyn std::error::Error>>
            {
                let sql = format!("SELECT COUNT(*) FROM {}{}", self.table, self.build_where());

                let row = client.query_one(&sql, &[]).await?;
                let total: i64 = row.try_get(0)?;
                Ok(total)
            }

            #update_method

            fn build_where(&self) -> String {
                if self.conditions.is_empty() {
                    String::new()
                } else {
                    format!(" WHERE {}", self.conditions.join(" AND "))
                }
            }

            fn build_select(&self) -> String {
                let mut sql = format!(
                    "SELECT {} FROM {}{}",
                    #column_list,
                    self.table,
                    self.build_where()
                );

                if !self.order_by.is_empty() {
                    sql.push_str(" ORDER BY ");
                    let order_clauses: Vec<String> = self.order_by.iter()
                        .map(|(col, dir)| format!("{} {}", col, dir))
                        .collect();
                    sql.push_str(&order_clauses.join(", "));
                }

                if let Some(limit) = self.limit {
                    sql.push_str(&format!(" LIMIT {}", limit));
                }

                if let Some(offset) = self.offset {
                    sql.push_str(&format!(" OFFSET {}", offset));
                }

                sql
            }
        }
    }
}

fn generate_update_with_audit(model: &Model) -> TokenStream {
    let table_name = model.name.to_lowercase();

    let audit_field = model.fields.iter()
        .find(|f| f.get_audit_model().is_some())
        .expect("Called generate_update_with_audit without audit field");

    let audit_model_name = audit_field.get_audit_model().unwrap();
    let audit_table = audit_model_name.to_lowercase();
    let audited_field_name = &audit_field.name;

    let pk_field = model.fields.iter()
        .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
        .expect("Model must have primary key for audit");

    let pk_column = to_snake_case(&pk_field.name);

    quote! {
        pub async fn update(
            &self,
            client: &Client,
            new_value: serde_json::Value,
            who: i64,
        ) -> Result<(), Box<dyn std::error::Error>> {
            let transaction = client.transaction().await?;

            if self.conditions.is_empty() {
                return Err("No WHERE condition specified for update".into());
            }

            let pk_value: i64 = self.conditions[0]
                .split('=')
                .nth(1)
                .and_then(|s| s.trim().parse().ok())
                .ok_or("Failed to parse primary key value")?;

            let before_sql = format!(
                "SELECT {} FROM {} WHERE {} = $1",
                #audited_field_name,
                #table_name,
                #pk_column
            );

            let before_row = transaction.query_one(&before_sql, &[&pk_value]).await?;
            let before: serde_json::Value = before_row.get(0);

            let update_sql = format!(
                "UPDATE {} SET {} = $1, updated_at = now() WHERE {} = $2",
                #table_name,
                #audited_field_name,
                #pk_column
            );

            transaction.execute(&update_sql, &[&new_value, &pk_value]).await?;

            let diff = calculate_json_diff(&before, &new_value);

            let audit_sql = format!(
                "INSERT INTO {} ({}, who, changed_at, before, after, diff) VALUES ($1, $2, now(), $3, $4, $5)",
                #audit_table,
                #pk_column
            );

            transaction.execute(
                &audit_sql,
                &[&pk_value, &who, &before, &new_value, &diff]
            ).await?;

            transaction.commit().await?;

            Ok(())
        }
    }
}


/// Maps a schema type name to the Rust type used in generated structs and
/// method signatures.
///
/// `"String"` and any unrecognized type name both map to `String`.
/// NOTE(review): the timestamp type is matched as the exact spelling
/// `TimestamptZ`. Confirm that this is what the schema parser produces.
fn rust_type_from_schema(type_name: &str) -> TokenStream {
    match type_name {
        // Numeric and boolean scalars.
        "Int" => quote! { i32 },
        "BigInt" => quote! { i64 },
        "Float" => quote! { f64 },
        "Boolean" => quote! { bool },
        // Structured and temporal values.
        "JsonB" => quote! { serde_json::Value },
        "TimestamptZ" => quote! { DateTime<Utc> },
        // "String" and every unrecognized name.
        _ => quote! { String },
    }
}

/// Converts a camelCase or PascalCase identifier to snake_case.
///
/// An underscore is inserted before each uppercase character, except at the
/// start of the string or directly after an existing underscore. Every
/// character is lowercased with the full Unicode mapping, which may yield more
/// than one char: `'İ'` becomes `"i\u{307}"`. Runs of capitals are split
/// letter by letter (`"userID"` → `"user_i_d"`).
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for ch in s.chars() {
        // `result` is empty only before the first char, because lowercasing
        // never yields an empty sequence. Checking for an existing trailing
        // '_' avoids turning "user_Id" into "user__id".
        if ch.is_uppercase() && !result.is_empty() && !result.ends_with('_') {
            result.push('_');
        }
        // Keep every char of the lowercase mapping, not just the first.
        result.extend(ch.to_lowercase());
    }
    result
}