// byteorm_lib/rustgen/upsert.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{rust_type_from_schema, to_snake_case};
5
6pub fn generate_upsert_builder(model: &Model) -> TokenStream {
7    let model_name = format_ident!("{}", model.name);
8    let upsert_builder_name = format_ident!("{}Upsert", model.name);
9    let table_name = model.name.to_lowercase();
10
11    let pk_fields: Vec<_> = model.fields.iter()
12        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
13        .collect();
14
15    if pk_fields.is_empty() {
16        return quote! {
17            pub struct #upsert_builder_name;
18
19            impl #upsert_builder_name {
20                pub fn new(_client: Arc<PgClient>) -> Self {
21                    Self
22                }
23            }
24        };
25    }
26
27    let all_fields: Vec<_> = model.fields.iter().collect();
28
29    let where_methods = pk_fields.iter().map(|field| {
30        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
31        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
32        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
33        let field_col = to_snake_case(&field.name);
34
35        quote! {
36            pub fn #method_name(mut self, value: #field_type) -> Self {
37                self.pk_values.insert(#field_col, Box::new(value));
38                self
39            }
40        }
41    });
42
43    let set_methods = all_fields.iter().map(|field| {
44        let method_name = format_ident!("set_{}", to_snake_case(&field.name));
45        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
46        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
47        let field_col = to_snake_case(&field.name);
48
49        quote! {
50            pub fn #method_name(mut self, value: #field_type) -> Self {
51                self.set_values.insert(#field_col, Box::new(value));
52                self
53            }
54        }
55    });
56
57    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
58        let field_name = format_ident!("{}", field.name);
59        quote! { #field_name: row.get(#idx) }
60    });
61
62    let pk_col_names: Vec<String> = pk_fields.iter()
63        .map(|f| to_snake_case(&f.name))
64        .collect();
65    let conflict_clause = pk_col_names.join(", ");
66
67    quote! {
68        pub struct #upsert_builder_name {
69            client: Arc<PgClient>,
70            table: String,
71            pk_values: std::collections::HashMap<&'static str, Box<dyn tokio_postgres::types::ToSql + Sync>>,
72            set_values: std::collections::HashMap<&'static str, Box<dyn tokio_postgres::types::ToSql + Sync>>,
73        }
74
75        unsafe impl Send for #upsert_builder_name {}
76
77        impl #upsert_builder_name {
78            pub fn new(client: Arc<PgClient>) -> Self {
79                Self {
80                    client,
81                    table: #table_name.to_string(),
82                    pk_values: std::collections::HashMap::new(),
83                    set_values: std::collections::HashMap::new(),
84                }
85            }
86
87            #(#where_methods)*
88            #(#set_methods)*
89
90            pub async fn execute(self) -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>> {
91                let pk_columns = vec![#(#pk_col_names),*];
92                for pk_col in &pk_columns {
93                    if !self.pk_values.contains_key(pk_col) && !self.set_values.contains_key(pk_col) {
94                        return Err(format!("Missing primary key field: {}", pk_col).into());
95                    }
96                }
97
98                let mut all_values = self.pk_values;
99                for (k, v) in self.set_values {
100                    all_values.insert(k, v);
101                }
102
103                if all_values.is_empty() {
104                    return Err("No fields to upsert".into());
105                }
106
107                let mut columns: Vec<&str> = all_values.keys().copied().collect();
108                columns.sort();
109
110                let columns_str = columns.join(", ");
111                let placeholders: Vec<String> = (1..=columns.len())
112                    .map(|i| format!("${}", i))
113                    .collect();
114                let placeholders_str = placeholders.join(", ");
115
116                let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
117                for col in &columns {
118                    params.push(all_values.get(col).unwrap().as_ref());
119                }
120
121                let update_columns: Vec<&str> = columns.iter()
122                    .filter(|col| !pk_columns.iter().any(|pk| pk == *col))
123                    .copied()
124                    .collect();
125
126                let sql = if update_columns.is_empty() {
127                    format!(
128                        "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO NOTHING RETURNING *",
129                        self.table, columns_str, placeholders_str, #conflict_clause
130                    )
131                } else {
132                    let update_clauses: Vec<String> = update_columns.iter()
133                        .map(|col| format!("{} = EXCLUDED.{}", col, col))
134                        .collect();
135
136                    format!(
137                        "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {} RETURNING *",
138                        self.table, columns_str, placeholders_str, #conflict_clause, update_clauses.join(", ")
139                    )
140                };
141
142                let row = self.client.query_one(&sql, &params[..]).await?;
143                Ok(#model_name {
144                    #(#field_gets),*
145                })
146            }
147        }
148    }
149}