// r2d2_mysql_batis_macros/lib.rs

use std::ops::Add;

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Fields, Type};

6#[proc_macro_attribute]
7pub fn entity_option_mapping(attr: TokenStream, input: TokenStream) -> TokenStream {
8    let attr = attr.to_string();
9    let mut attr_split = attr.split(",");
10    let mut table_name = "";
11    if let Some(table_name_option) = attr_split.next() {
12        table_name = table_name_option.trim();
13    }
14    let mut primary_key = "";
15    if let Some(primary_key_option) = attr_split.next() {
16        primary_key = primary_key_option.trim();
17    }
18    let input_tokens: proc_macro2::TokenStream = input.clone().into();
19    let derive_input = parse_macro_input!(input as DeriveInput);
20    let struct_name = &derive_input.ident;
21    let mut primary_key_name = &derive_input.ident;
22
23    if let syn::Data::Struct(data) = &derive_input.data {
24        if let Fields::Named(named_fields) = &data.fields {
25            let mut assignment = vec![];
26            let mut stmts = vec![];
27            let mut params = vec![];
28            let mut update_sql = vec![];
29            let mut type_names = vec![];
30
31            for field in &named_fields.named {
32                if let Some(ident) = &field.ident {
33                    assignment.push(quote! {#ident: None,});
34                    let field_name = ident.to_string();
35                    if field_name.eq(primary_key) {
36                        primary_key_name = &ident;
37                    }
38                    type_names.push(field_name.clone());
39                    update_sql.push(quote!{
40                        if !self.#ident.eq(&None) {
41                            update_sql = update_sql.add(" `");
42                            update_sql = update_sql.add(#field_name);
43                            update_sql = update_sql.add("` = :");
44                            update_sql = update_sql.add(#field_name);
45                            update_sql = update_sql.add(",");
46                        }
47                    });
48                    let field_type = &field.ty;
49                    if let Type::Path(type_path) = field_type {
50                        // let type_name = &type_path.path.segments.last().unwrap().ident;
51                        // let type_name = type_name.to_string();
52                        let arguments = &type_path.path.segments.last().unwrap().arguments;
53                        if let syn::PathArguments::AngleBracketed(args) = arguments {
54                            if let Some(inner_ty) = args.args.first() {
55                                let inner_ty_str = quote!(#inner_ty).to_string();
56                                params.push(quote! {
57                                    if self.#ident.eq(&None) {
58                                        map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::NULL);
59                                    }
60                                });
61                                if inner_ty_str.eq("String") {
62                                    stmts.push(quote! {
63                                        if columns.contains(&&*#field_name) {
64                                            let value = &row[#field_name];
65                                            match value {
66                                                r2d2_mysql::mysql::Value::NULL => {}
67                                                r2d2_mysql::mysql::Value::Date(year, month, day, hour, minutes, seconds, micro_seconds) => {
68                                                    vo.#ident = Some(format!("{}-{}-{} {}:{}:{}.{}", year, month, day, hour, minutes, seconds, micro_seconds));
69                                                }
70                                                r2d2_mysql::mysql::Value::Time(_is_negative, days, hours, minutes, seconds, micro_seconds) => {
71                                                    vo.#ident = Some(format!("{} {}:{}:{}.{}", days, hours, minutes, seconds, micro_seconds));
72                                                }
73                                                _ => {
74                                                    vo.#ident = row.get(#field_name);
75                                                }
76                                            }
77                                        }
78                                    });
79                                    params.push(quote! {
80                                        else {
81                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Bytes(self.#ident.clone().unwrap().as_bytes().to_vec()));
82                                        }
83                                    });
84                                    continue;
85                                } else if inner_ty_str.eq("i8") {
86                                    params.push(quote! {
87                                        else {
88                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Int(self.#ident.unwrap().into()));
89                                        }
90                                    });
91                                } else if inner_ty_str.eq("u8") {
92                                    params.push(quote! {
93                                        else {
94                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::UInt(self.#ident.unwrap()));
95                                        }
96                                    });
97                                } else if inner_ty_str.eq("i16") {
98                                    params.push(quote! {
99                                        else {
100                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Int(self.#ident.unwrap()));
101                                        }
102                                    });
103                                } else if inner_ty_str.eq("u16") {
104                                    params.push(quote! {
105                                        else {
106                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::UInt(self.#ident.unwrap()));
107                                        }
108                                    });
109                                } else if inner_ty_str.eq("i32") {
110                                    params.push(quote! {
111                                        else {
112                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Int(self.#ident.unwrap()));
113                                        }
114                                    });
115                                } else if inner_ty_str.eq("u32") {
116                                    params.push(quote! {
117                                        else {
118                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::UInt(self.#ident.unwrap()));
119                                        }
120                                    });
121                                } else if inner_ty_str.eq("i64") {
122                                    params.push(quote! {
123                                        else {
124                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Int(self.#ident.unwrap()));
125                                        }
126                                    });
127                                } else if inner_ty_str.eq("u64") {
128                                    params.push(quote! {
129                                        else {
130                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::UInt(self.#ident.unwrap()));
131                                        }
132                                    });
133                                } else if inner_ty_str.eq("isize") {
134                                    params.push(quote! {
135                                        else {
136                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Int(self.#ident.unwrap()));
137                                        }
138                                    });
139                                } else if inner_ty_str.eq("usize") {
140                                    params.push(quote! {
141                                        else {
142                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::UInt(self.#ident.unwrap()));
143                                        }
144                                    });
145                                } else if inner_ty_str.eq("f32") {
146                                    params.push(quote! {
147                                        else {
148                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Float(self.#ident.unwrap()));
149                                        }
150                                    });
151                                } else if inner_ty_str.eq("f64") {
152                                    params.push(quote! {
153                                        else {
154                                            map.insert(#field_name.as_bytes().to_vec(), r2d2_mysql::mysql::Value::Double(self.#ident.unwrap()));
155                                        }
156                                    });
157                                }
158                                stmts.push(quote! {
159                                    if columns.contains(&&*#field_name) {
160                                        let value = &row[#field_name];
161                                        if !r2d2_mysql::mysql::Value::NULL.eq(value) {
162                                            vo.#ident = row.get(#field_name);
163                                        }
164                                    }
165                                });
166                            }
167                        }
168                    }
169                }
170            }
171
172            let mut insert_sql = String::from("INSERT INTO `");
173            insert_sql = insert_sql.add(table_name);
174            insert_sql = insert_sql.add("` (");
175            for type_name in type_names.clone() {
176                if type_name.eq(primary_key) {
177                    continue;
178                }
179                insert_sql = insert_sql.add("`");
180                insert_sql = insert_sql.add(&*type_name);
181                insert_sql = insert_sql.add("`,");
182            }
183            insert_sql.remove(insert_sql.len() - 1);
184            insert_sql = insert_sql.add(") VALUES (");
185            for type_name in type_names {
186                if type_name.eq(primary_key) {
187                    continue;
188                }
189                insert_sql = insert_sql.add(":");
190                insert_sql = insert_sql.add(&*type_name);
191                insert_sql = insert_sql.add(",");
192            }
193            insert_sql.remove(insert_sql.len() - 1);
194            insert_sql = insert_sql.add(")");
195
196            let impl_fn = quote! {
197                use std::ops::Add;
198                #input_tokens
199
200                impl r2d2_mysql_batis::entity::Entity for #struct_name {
201                    fn table_name() -> &'static str {
202                        #table_name
203                    }
204
205                    fn primary_key() -> &'static str {
206                        #primary_key
207                    }
208
209                    fn insert_sql() -> &'static str {
210                        #insert_sql
211                    }
212
213                    fn update_by_id_sql(&self) -> String {
214                        let mut update_sql = String::from("UPDATE `");
215                        update_sql = update_sql.add(#table_name);
216                        update_sql = update_sql.add("` SET");
217                        #(#update_sql)*
218                        update_sql.remove(update_sql.len() - 1);
219                        update_sql = update_sql.add(" WHERE `");
220                        update_sql = update_sql.add(#primary_key);
221                        update_sql = update_sql.add("` = :user_id LIMIT 1");
222                        update_sql
223                    }
224
225                    fn params(&self) -> r2d2_mysql::mysql::Params {
226                        let mut map = std::collections::HashMap::new();
227                        #(#params)*
228                        r2d2_mysql::mysql::Params::Named(map)
229                    }
230                }
231
232                impl r2d2_mysql_batis::service::Service for #struct_name {
233                    fn set_primary_key(&mut self, id: u64) {
234                        self.#primary_key_name = Some(id);
235                    }
236                }
237
238                impl r2d2_mysql::mysql::prelude::FromRow for #struct_name {
239                    fn from_row(row: r2d2_mysql::mysql::Row) -> Self where Self: Sized {
240                        let mut vo = Self {
241                            #(#assignment)*
242                        };
243                        let mut columns = vec![];
244                        for column in row.columns_ref() {
245                            columns.push(std::str::from_utf8(column.name_ref()).unwrap());
246                        }
247                        #(#stmts)*
248                        vo
249                    }
250
251                    fn from_row_opt(row: r2d2_mysql::mysql::Row) -> Result<Self, r2d2_mysql::mysql::FromRowError> where Self: Sized {
252                        let mut vo = Self {
253                            #(#assignment)*
254                        };
255                        let mut columns = vec![];
256                        for column in row.columns_ref() {
257                            columns.push(std::str::from_utf8(column.name_ref()).unwrap());
258                        }
259                        #(#stmts)*
260                        Ok(vo)
261                    }
262                }
263            };
264            return impl_fn.into();
265        }
266    }
267    TokenStream::default()
268}