byteorm_lib/rustgen/
update.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{rust_type_from_schema, to_snake_case};
5
/// Reports whether the schema type name `ty` is one of the types for which
/// arithmetic update helpers are generated.
///
/// The comparison is exact and case-sensitive; any other name yields `false`.
fn is_numeric_type(ty: &str) -> bool {
    const NUMERIC_SCHEMA_TYPES: [&str; 5] = ["BigInt", "Int", "Serial", "Float", "Real"];
    NUMERIC_SCHEMA_TYPES.iter().any(|&known| known == ty)
}
9
10pub fn generate_update_builder(model: &Model) -> TokenStream {
11    let model_name = format_ident!("{}", model.name);
12    let update_builder_name = format_ident!("{}Update", model.name);
13    let table_name = model.name.to_lowercase();
14
15    let where_methods = model.fields.iter().map(|field| {
16        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
17        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
18        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
19        let field_col = to_snake_case(&field.name);
20
21        quote! {
22            pub fn #method_name(mut self, value: #field_type) -> Self {
23                self.where_args.push(Box::new(value));
24                self.where_fragments.push((#field_col, self.where_args.len()));
25                self
26            }
27        }
28    });
29
30    let set_methods = model.fields.iter().map(|field| {
31        let method_name = format_ident!("set_{}", to_snake_case(&field.name));
32        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
33        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
34        let field_col = to_snake_case(&field.name);
35
36        quote! {
37            pub fn #method_name(mut self, value: #field_type) -> Self {
38                self.set_args.push(Box::new(value));
39                self.set_fragments.push(#field_col);
40                self
41            }
42        }
43    });
44
45    let inc_methods = model.fields.iter()
46        .filter(|f| is_numeric_type(&f.type_name))
47        .map(|field| {
48            let field_col = to_snake_case(&field.name);
49            let inc_method = format_ident!("inc_{}", field_col);
50            let dec_method = format_ident!("dec_{}", field_col);
51            let mul_method = format_ident!("mul_{}", field_col);
52            let div_method = format_ident!("div_{}", field_col);
53            quote! {
54                pub fn #inc_method(mut self, amount: i64) -> Self {
55                    self.inc_ops.push((#field_col, "inc", amount));
56                    self
57                }
58                pub fn #dec_method(mut self, amount: i64) -> Self {
59                    self.inc_ops.push((#field_col, "dec", amount));
60                    self
61                }
62                pub fn #mul_method(mut self, factor: i64) -> Self {
63                    self.inc_ops.push((#field_col, "mul", factor));
64                    self
65                }
66                pub fn #div_method(mut self, divisor: i64) -> Self {
67                    self.inc_ops.push((#field_col, "div", divisor));
68                    self
69                }
70            }
71        });
72
73    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
74        let field_name = format_ident!("{}", field.name);
75        quote! { #field_name: row.get(#idx) }
76    });
77
78    quote! {
79        pub struct #update_builder_name {
80            client: Arc<PgClient>,
81            table: String,
82            where_fragments: Vec<(&'static str, usize)>,
83            where_args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
84            set_fragments: Vec<&'static str>,
85            set_args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
86            inc_ops: Vec<(&'static str, &'static str, i64)>,
87        }
88
89        unsafe impl Send for #update_builder_name {}
90
91        impl #update_builder_name {
92            pub fn new(client: Arc<PgClient>) -> Self {
93                Self {
94                    client,
95                    table: #table_name.to_string(),
96                    where_fragments: vec![],
97                    where_args: vec![],
98                    set_fragments: vec![],
99                    set_args: vec![],
100                    inc_ops: vec![],
101                }
102            }
103
104            #(#where_methods)*
105            #(#set_methods)*
106            #(#inc_methods)*
107
108            pub async fn execute(self) -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>> {
109                if self.set_fragments.is_empty() && self.inc_ops.is_empty() {
110                    return Err("No fields to update".into());
111                }
112
113                let mut sql = format!("UPDATE {} SET ", self.table);
114                let mut set_clauses: Vec<String> = vec![];
115                let mut param_idx = 1;
116
117                for (i, col) in self.set_fragments.iter().enumerate() {
118                    set_clauses.push(format!("{} = ${}", col, param_idx));
119                    param_idx += 1;
120                }
121
122                for (field, op, _) in &self.inc_ops {
123                    let clause = match *op {
124                        "inc" => format!("{} = {} + ${}", field, field, param_idx),
125                        "dec" => format!("{} = {} - ${}", field, field, param_idx),
126                        "mul" => format!("{} = {} * ${}", field, field, param_idx),
127                        "div" => format!("{} = {} / ${}", field, field, param_idx),
128                        _ => continue,
129                    };
130                    set_clauses.push(clause);
131                    param_idx += 1;
132                }
133                sql.push_str(&set_clauses.join(", "));
134
135                let mut all_params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
136                    self.set_args.iter().map(|a| a.as_ref()).collect();
137                for (_, _, val) in &self.inc_ops {
138                    all_params.push(val);
139                }
140
141                if !self.where_fragments.is_empty() {
142                    let where_clauses: Vec<String> = self.where_fragments.iter()
143                        .enumerate()
144                        .map(|(i, &(col, _))| format!("{} = ${}", col, self.set_args.len() + self.inc_ops.len() + i + 1))
145                        .collect();
146                    sql.push_str(" WHERE ");
147                    sql.push_str(&where_clauses.join(" AND "));
148
149                    for arg in &self.where_args {
150                        all_params.push(arg.as_ref());
151                    }
152                }
153
154                sql.push_str(" RETURNING *");
155
156                let row = self.client.query_one(&sql, &all_params[..]).await?;
157                Ok(#model_name {
158                    #(#field_gets),*
159                })
160            }
161        }
162    }
163}