byteorm 0.1.6

A lightweight ORM for Rust
Documentation
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use crate::{Model, Modifier};
use crate::rustgen::{rust_type_from_schema, to_snake_case};

pub fn generate_upsert_builder(model: &Model) -> TokenStream {
    let model_name = format_ident!("{}", model.name);
    let upsert_builder_name = format_ident!("{}Upsert", model.name);
    let table_name = model.name.to_lowercase();

    let pk_fields: Vec<_> = model.fields.iter()
        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
        .collect();

    if pk_fields.is_empty() {
        return quote! {
            pub struct #upsert_builder_name;

            impl #upsert_builder_name {
                pub fn new(_client: Arc<PgClient>) -> Self {
                    Self
                }
            }
        };
    }

    let all_fields: Vec<_> = model.fields.iter().collect();

    let where_methods = pk_fields.iter().map(|field| {
        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
        let field_col = to_snake_case(&field.name);

        quote! {
            pub fn #method_name(mut self, value: #field_type) -> Self {
                self.pk_values.insert(#field_col, Box::new(value));
                self
            }
        }
    });

    let set_methods = all_fields.iter().map(|field| {
        let method_name = format_ident!("set_{}", to_snake_case(&field.name));
        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
        let field_col = to_snake_case(&field.name);

        quote! {
            pub fn #method_name(mut self, value: #field_type) -> Self {
                self.set_values.insert(#field_col, Box::new(value));
                self
            }
        }
    });

    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
        let field_name = format_ident!("{}", field.name);
        quote! { #field_name: row.get(#idx) }
    });

    let pk_col_names: Vec<String> = pk_fields.iter()
        .map(|f| to_snake_case(&f.name))
        .collect();
    let conflict_clause = pk_col_names.join(", ");

    quote! {
        pub struct #upsert_builder_name {
            client: Arc<PgClient>,
            table: String,
            pk_values: std::collections::HashMap<&'static str, Box<dyn tokio_postgres::types::ToSql + Sync>>,
            set_values: std::collections::HashMap<&'static str, Box<dyn tokio_postgres::types::ToSql + Sync>>,
        }

        unsafe impl Send for #upsert_builder_name {}

        impl #upsert_builder_name {
            pub fn new(client: Arc<PgClient>) -> Self {
                Self {
                    client,
                    table: #table_name.to_string(),
                    pk_values: std::collections::HashMap::new(),
                    set_values: std::collections::HashMap::new(),
                }
            }

            #(#where_methods)*
            #(#set_methods)*

            pub async fn execute(self) -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>> {
                let pk_columns = vec![#(#pk_col_names),*];
                for pk_col in &pk_columns {
                    if !self.pk_values.contains_key(pk_col) && !self.set_values.contains_key(pk_col) {
                        return Err(format!("Missing primary key field: {}", pk_col).into());
                    }
                }

                let mut all_values = self.pk_values;
                for (k, v) in self.set_values {
                    all_values.insert(k, v);
                }

                if all_values.is_empty() {
                    return Err("No fields to upsert".into());
                }

                let mut columns: Vec<&str> = all_values.keys().copied().collect();
                columns.sort();

                let columns_str = columns.join(", ");
                let placeholders: Vec<String> = (1..=columns.len())
                    .map(|i| format!("${}", i))
                    .collect();
                let placeholders_str = placeholders.join(", ");

                let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
                for col in &columns {
                    params.push(all_values.get(col).unwrap().as_ref());
                }

                let update_columns: Vec<&str> = columns.iter()
                    .filter(|col| !pk_columns.iter().any(|pk| pk == *col))
                    .copied()
                    .collect();

                let sql = if update_columns.is_empty() {
                    format!(
                        "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO NOTHING RETURNING *",
                        self.table, columns_str, placeholders_str, #conflict_clause
                    )
                } else {
                    let update_clauses: Vec<String> = update_columns.iter()
                        .map(|col| format!("{} = EXCLUDED.{}", col, col))
                        .collect();

                    format!(
                        "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {} RETURNING *",
                        self.table, columns_str, placeholders_str, #conflict_clause, update_clauses.join(", ")
                    )
                };

                let row = self.client.query_one(&sql, &params[..]).await?;
                Ok(#model_name {
                    #(#field_gets),*
                })
            }
        }
    }
}