byteorm_lib/rustgen/
upsert.rs1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{rust_type_from_schema, to_snake_case};
5
6pub fn generate_upsert_builder(model: &Model) -> TokenStream {
7 let model_name = format_ident!("{}", model.name);
8 let upsert_builder_name = format_ident!("{}Upsert", model.name);
9 let table_name = model.name.to_lowercase();
10
11 let pk_fields: Vec<_> = model.fields.iter()
12 .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
13 .collect();
14
15 if pk_fields.is_empty() {
16 return quote! {
17 pub struct #upsert_builder_name;
18
19 impl #upsert_builder_name {
20 pub fn new(_client: Arc<PgClient>) -> Self {
21 Self
22 }
23 }
24 };
25 }
26
27 let all_fields: Vec<_> = model.fields.iter().collect();
28
29 let where_methods = pk_fields.iter().map(|field| {
30 let method_name = format_ident!("where_{}", to_snake_case(&field.name));
31 let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
32 let field_type = rust_type_from_schema(&field.type_name, is_nullable);
33 let field_col = to_snake_case(&field.name);
34
35 quote! {
36 pub fn #method_name(mut self, value: #field_type) -> Self {
37 self.pk_values.insert(#field_col, Box::new(value));
38 self
39 }
40 }
41 });
42
43 let set_methods = all_fields.iter().map(|field| {
44 let method_name = format_ident!("set_{}", to_snake_case(&field.name));
45 let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
46 let field_type = rust_type_from_schema(&field.type_name, is_nullable);
47 let field_col = to_snake_case(&field.name);
48
49 quote! {
50 pub fn #method_name(mut self, value: #field_type) -> Self {
51 self.set_values.insert(#field_col, Box::new(value));
52 self
53 }
54 }
55 });
56
57 let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
58 let field_name = format_ident!("{}", field.name);
59 quote! { #field_name: row.get(#idx) }
60 });
61
62 let pk_col_names: Vec<String> = pk_fields.iter()
63 .map(|f| to_snake_case(&f.name))
64 .collect();
65 let conflict_clause = pk_col_names.join(", ");
66
67 quote! {
68 pub struct #upsert_builder_name {
69 client: Arc<PgClient>,
70 table: String,
71 pk_values: std::collections::HashMap<&'static str, Box<dyn tokio_postgres::types::ToSql + Sync>>,
72 set_values: std::collections::HashMap<&'static str, Box<dyn tokio_postgres::types::ToSql + Sync>>,
73 }
74
75 unsafe impl Send for #upsert_builder_name {}
76
77 impl #upsert_builder_name {
78 pub fn new(client: Arc<PgClient>) -> Self {
79 Self {
80 client,
81 table: #table_name.to_string(),
82 pk_values: std::collections::HashMap::new(),
83 set_values: std::collections::HashMap::new(),
84 }
85 }
86
87 #(#where_methods)*
88 #(#set_methods)*
89
90 pub async fn execute(self) -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>> {
91 let pk_columns = vec![#(#pk_col_names),*];
92 for pk_col in &pk_columns {
93 if !self.pk_values.contains_key(pk_col) && !self.set_values.contains_key(pk_col) {
94 return Err(format!("Missing primary key field: {}", pk_col).into());
95 }
96 }
97
98 let mut all_values = self.pk_values;
99 for (k, v) in self.set_values {
100 all_values.insert(k, v);
101 }
102
103 if all_values.is_empty() {
104 return Err("No fields to upsert".into());
105 }
106
107 let mut columns: Vec<&str> = all_values.keys().copied().collect();
108 columns.sort();
109
110 let columns_str = columns.join(", ");
111 let placeholders: Vec<String> = (1..=columns.len())
112 .map(|i| format!("${}", i))
113 .collect();
114 let placeholders_str = placeholders.join(", ");
115
116 let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
117 for col in &columns {
118 params.push(all_values.get(col).unwrap().as_ref());
119 }
120
121 let update_columns: Vec<&str> = columns.iter()
122 .filter(|col| !pk_columns.iter().any(|pk| pk == *col))
123 .copied()
124 .collect();
125
126 let sql = if update_columns.is_empty() {
127 format!(
128 "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO NOTHING RETURNING *",
129 self.table, columns_str, placeholders_str, #conflict_clause
130 )
131 } else {
132 let update_clauses: Vec<String> = update_columns.iter()
133 .map(|col| format!("{} = EXCLUDED.{}", col, col))
134 .collect();
135
136 format!(
137 "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {} RETURNING *",
138 self.table, columns_str, placeholders_str, #conflict_clause, update_clauses.join(", ")
139 )
140 };
141
142 let row = self.client.query_one(&sql, ¶ms[..]).await?;
143 Ok(#model_name {
144 #(#field_gets),*
145 })
146 }
147 }
148 }
149}