byteorm 0.1.6

A lightweight ORM for Rust
Documentation
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use crate::{Model, Modifier};
use crate::rustgen::{generate_query_builder_struct, generate_update_builder, generate_upsert_builder, rust_type_from_schema, to_snake_case};

pub fn generate_model_with_query_builder(model: &Model) -> TokenStream {
    let model_struct = generate_model_struct(model);
    let query_builder_struct = generate_query_builder_struct(model);
    let update_builder = generate_update_builder(model);
    let upsert_builder = generate_upsert_builder(model);
    let model_impl = generate_model_impl(model);

    quote! {
        #model_struct
        #query_builder_struct
        #update_builder
        #upsert_builder
        #model_impl
    }
}

pub fn generate_model_struct(model: &Model) -> TokenStream {
    let name = format_ident!("{}", model.name);
    let fields = model.fields.iter().map(|field| {
        let field_name = format_ident!("{}", field.name);
        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
        let field_type = rust_type_from_schema(&field.type_name, is_nullable);

        quote! {
            pub #field_name: #field_type
        }
    });

    quote! {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        pub struct #name {
            #(#fields),*
        }
    }
}

fn generate_model_impl(model: &Model) -> TokenStream {
    let model_name = format_ident!("{}", model.name);
    let builder_name = format_ident!("{}Query", model_name);

    let pk_fields: Vec<_> = model.fields.iter()
        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
        .collect();

    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
        let field_name = format_ident!("{}", field.name);
        quote! { #field_name: row.get(#idx) }
    });

    let find_by_id_impl = if !pk_fields.is_empty() {
        if pk_fields.len() == 1 {
            let pk = &pk_fields[0];
            let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
            let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
            let pk_name = to_snake_case(&pk.name);

            quote! {
                pub async fn find_by_id(client: &PgClient, id: #pk_type)
                    -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
                {
                    let sql = format!("SELECT * FROM {} WHERE {} = $1",
                        stringify!(#model_name).to_lowercase(), #pk_name);
                    let row_opt = client.query_opt(&sql, &[&id]).await?;
                    Ok(row_opt.map(|row| #model_name {
                        #(#field_gets),*
                    }))
                }
            }
        } else {
            let pk_params = pk_fields.iter().map(|pk| {
                let param_name = format_ident!("{}", to_snake_case(&pk.name));
                let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
                let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
                quote! { #param_name: #pk_type }
            });

            let pk_conditions = pk_fields.iter().enumerate().map(|(i, pk)| {
                let pk_col = to_snake_case(&pk.name);
                let param_num = i + 1;
                format!("{} = ${}", pk_col, param_num)
            });
            let where_clause = pk_conditions.collect::<Vec<_>>().join(" AND ");

            let pk_args = pk_fields.iter().map(|pk| {
                let param_name = format_ident!("{}", to_snake_case(&pk.name));
                quote! { &#param_name }
            });

            quote! {
                pub async fn find_by_composite_pk(client: &PgClient, #(#pk_params),*)
                    -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
                {
                    let sql = format!("SELECT * FROM {} WHERE {}",
                        stringify!(#model_name).to_lowercase(), #where_clause);
                    let row_opt = client.query_opt(&sql, &[#(#pk_args),*]).await?;
                    Ok(row_opt.map(|row| #model_name {
                        #(#field_gets),*
                    }))
                }
            }
        }
    } else {
        quote! {}
    };

    quote! {
        impl #model_name {
            pub fn query() -> #builder_name {
                #builder_name::new()
            }
            #find_by_id_impl
        }
    }
}