byteorm_lib/rustgen/
update.rs1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{rust_type_from_schema, to_snake_case};
5
/// Returns `true` when the schema type name `ty` is one of the numeric column
/// types for which arithmetic update helpers (`inc_*`, `dec_*`, `mul_*`,
/// `div_*`) are generated. The comparison is exact and case-sensitive.
fn is_numeric_type(ty: &str) -> bool {
    const NUMERIC_SCHEMA_TYPES: [&str; 5] = ["BigInt", "Int", "Serial", "Float", "Real"];
    NUMERIC_SCHEMA_TYPES.contains(&ty)
}
9
10pub fn generate_update_builder(model: &Model) -> TokenStream {
11 let model_name = format_ident!("{}", model.name);
12 let update_builder_name = format_ident!("{}Update", model.name);
13 let table_name = model.name.to_lowercase();
14
15 let where_methods = model.fields.iter().map(|field| {
16 let method_name = format_ident!("where_{}", to_snake_case(&field.name));
17 let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
18 let field_type = rust_type_from_schema(&field.type_name, is_nullable);
19 let field_col = to_snake_case(&field.name);
20
21 quote! {
22 pub fn #method_name(mut self, value: #field_type) -> Self {
23 self.where_args.push(Box::new(value));
24 self.where_fragments.push((#field_col, self.where_args.len()));
25 self
26 }
27 }
28 });
29
30 let set_methods = model.fields.iter().map(|field| {
31 let method_name = format_ident!("set_{}", to_snake_case(&field.name));
32 let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
33 let field_type = rust_type_from_schema(&field.type_name, is_nullable);
34 let field_col = to_snake_case(&field.name);
35
36 quote! {
37 pub fn #method_name(mut self, value: #field_type) -> Self {
38 self.set_args.push(Box::new(value));
39 self.set_fragments.push(#field_col);
40 self
41 }
42 }
43 });
44
45 let inc_methods = model.fields.iter()
46 .filter(|f| is_numeric_type(&f.type_name))
47 .map(|field| {
48 let field_col = to_snake_case(&field.name);
49 let inc_method = format_ident!("inc_{}", field_col);
50 let dec_method = format_ident!("dec_{}", field_col);
51 let mul_method = format_ident!("mul_{}", field_col);
52 let div_method = format_ident!("div_{}", field_col);
53 quote! {
54 pub fn #inc_method(mut self, amount: i64) -> Self {
55 self.inc_ops.push((#field_col, "inc", amount));
56 self
57 }
58 pub fn #dec_method(mut self, amount: i64) -> Self {
59 self.inc_ops.push((#field_col, "dec", amount));
60 self
61 }
62 pub fn #mul_method(mut self, factor: i64) -> Self {
63 self.inc_ops.push((#field_col, "mul", factor));
64 self
65 }
66 pub fn #div_method(mut self, divisor: i64) -> Self {
67 self.inc_ops.push((#field_col, "div", divisor));
68 self
69 }
70 }
71 });
72
73 let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
74 let field_name = format_ident!("{}", field.name);
75 quote! { #field_name: row.get(#idx) }
76 });
77
78 quote! {
79 pub struct #update_builder_name {
80 client: Arc<PgClient>,
81 table: String,
82 where_fragments: Vec<(&'static str, usize)>,
83 where_args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
84 set_fragments: Vec<&'static str>,
85 set_args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
86 inc_ops: Vec<(&'static str, &'static str, i64)>,
87 }
88
89 unsafe impl Send for #update_builder_name {}
90
91 impl #update_builder_name {
92 pub fn new(client: Arc<PgClient>) -> Self {
93 Self {
94 client,
95 table: #table_name.to_string(),
96 where_fragments: vec![],
97 where_args: vec![],
98 set_fragments: vec![],
99 set_args: vec![],
100 inc_ops: vec![],
101 }
102 }
103
104 #(#where_methods)*
105 #(#set_methods)*
106 #(#inc_methods)*
107
108 pub async fn execute(self) -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>> {
109 if self.set_fragments.is_empty() && self.inc_ops.is_empty() {
110 return Err("No fields to update".into());
111 }
112
113 let mut sql = format!("UPDATE {} SET ", self.table);
114 let mut set_clauses: Vec<String> = vec![];
115 let mut param_idx = 1;
116
117 for (i, col) in self.set_fragments.iter().enumerate() {
118 set_clauses.push(format!("{} = ${}", col, param_idx));
119 param_idx += 1;
120 }
121
122 for (field, op, _) in &self.inc_ops {
123 let clause = match *op {
124 "inc" => format!("{} = {} + ${}", field, field, param_idx),
125 "dec" => format!("{} = {} - ${}", field, field, param_idx),
126 "mul" => format!("{} = {} * ${}", field, field, param_idx),
127 "div" => format!("{} = {} / ${}", field, field, param_idx),
128 _ => continue,
129 };
130 set_clauses.push(clause);
131 param_idx += 1;
132 }
133 sql.push_str(&set_clauses.join(", "));
134
135 let mut all_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
136 self.set_args.iter().map(|a| a.as_ref()).collect();
137 for (_, _, val) in &self.inc_ops {
138 all_params.push(val);
139 }
140
141 if !self.where_fragments.is_empty() {
142 let where_clauses: Vec<String> = self.where_fragments.iter()
143 .enumerate()
144 .map(|(i, &(col, _))| format!("{} = ${}", col, self.set_args.len() + self.inc_ops.len() + i + 1))
145 .collect();
146 sql.push_str(" WHERE ");
147 sql.push_str(&where_clauses.join(" AND "));
148
149 for arg in &self.where_args {
150 all_params.push(arg.as_ref());
151 }
152 }
153
154 sql.push_str(" RETURNING *");
155
156 let row = self.client.query_one(&sql, &all_params[..]).await?;
157 Ok(#model_name {
158 #(#field_gets),*
159 })
160 }
161 }
162 }
163}