rssql_macro/
lib.rs

1use proc_macro::TokenStream;
2
3use quote::{quote, ToTokens};
4// no need to import a specific crate for TokenStream
5use syn::{Expr, ExprLit, Lit};
6use syn::{Meta, Path, punctuated::Punctuated, token::Comma};
7use syn::Data::Struct;
8use syn::DataStruct;
9use syn::Fields::Named;
10use syn::FieldsNamed;
11
12use crate::utils::{
13    parse_table_name,
14    extract_type_from_option,
15};
16
17mod utils;
18
19// use proc_macro2::{Span, TokenTree};
20// use proc_macro2::TokenTree::Group;
21// use proc_macro2::Ident;
22// use syn::{parse_macro_input, DeriveInput};
23// use syn::token::Token;
24
25
26#[proc_macro_derive(ORM, attributes(rssql))]
27pub fn show_streams(tokens: TokenStream) -> TokenStream {
28    // println!("attr: \"{}\"", attr.to_string());
29    println!("item: \"{}\"", tokens.to_string());
30    // let t: proc_macro2::TokenStream = tokens.clone().into();
31    let ast: syn::DeriveInput = syn::parse(tokens).unwrap();
32    // let struct_name = ast.ident;
33    let table_name = parse_table_name(&ast.attrs);
34    let struct_name = ast.ident;
35
36    // println!("attrs: {:?}", ast.attrs);
37    // println!("data: {:?}", ast.data);
38
39    let fields = match ast.data {
40        Struct(DataStruct { fields: Named(FieldsNamed { ref named, .. }), .. }) => named,
41        _ => unimplemented!()
42    };
43
44    let builder_types = fields.iter().map(|f| {
45        let mn = f.clone().ident.unwrap().to_string();
46        let ty = &f.ty.to_token_stream().to_string();
47        quote! {
48            #mn => #ty
49        }
50    });
51
52    let builder_fields_mapping = fields.iter().map(|f| f.clone().ident.unwrap().to_string());
53
54    let builder_row_func = fields.iter().map(|f| {
55        let mn = f.clone().ident.unwrap().to_string();
56        let field_name = format!("{}.{}", &table_name, &mn);
57        let ty = &f.ty;
58        let ty = match extract_type_from_option(ty) {
59            Some(value) => value,
60            None => ty
61        };
62        let type_name = ty.to_token_stream().to_string();
63        return match type_name.as_str() {
64            "String" => {
65                quote! {
66                    map.insert(#mn.to_string(), row.get::<&str, &str>(#field_name).into())
67                }
68            }
69            "NaiveDateTime" => {
70                quote! {
71                    map.insert(#mn.to_string(), row.get::<#ty, &str>(#field_name).unwrap().to_string().into())
72                }
73            }
74            _ => {
75                quote! {
76                    map.insert(#mn.to_string(), row.get::<#ty, &str>(#field_name).into())
77                }
78            }
79        };
80    });
81
82    // for pushing elements to TokenRow which used in bulk insert
83    let builder_insert_rows = fields.iter().map(|f| {
84        let field = f.clone().ident.unwrap();
85        return quote! {
86            row.push(item.#field.into_sql())
87        };
88    });
89
90    // for insert one
91    let builder_insert_fields = fields.iter()
92        .map(|f| { f.clone().ident.unwrap().to_string() })
93        .reduce(|cur: String, next: String| format!("{},{}", cur, &next)).unwrap();
94    let mut fields_count = 0;
95    let builder_insert_params = fields.iter()
96        .map(|_| {
97            fields_count += 1;
98            return format!("@p{}", fields_count);
99        })
100        .reduce(|cur: String, next: String| format!("{},{}", cur, &next)).unwrap();
101    let builder_insert_data = fields.iter().map(|f|
102        f.clone().ident.unwrap()
103    )
104        // .filter(|x| { *x.to_string() != "id".to_string() })
105        .map(|f| return quote! {&self.#f});
106
107    // for update one
108    fields_count = 0;
109    let builder_update_fields = fields.iter()
110        .map(|f| {
111            fields_count += 1;
112            return format!(" {} = @p{}", f.clone().ident.unwrap().to_string(), fields_count);
113        })
114        .reduce(|cur: String, next: String| format!("{},{}", cur, &next)).unwrap();
115    let builder_update_data = builder_insert_data.clone();
116
117
118    #[cfg(feature = "polars")]
119        let builder_new_vecs = fields.iter().map(|f| {
120        let field = f.clone().ident.unwrap();
121        let ty = &f.ty;
122        quote! {
123            let mut #field : Vec<#ty> = vec![]
124        }
125    });
126
127    #[cfg(feature = "polars")]
128        let builder_insert_to_df = fields.iter().map(|f| {
129        let field = f.clone().ident.unwrap();
130        quote! {
131            #field.push(Phant_Name1.#field)
132        }
133    });
134
135    #[cfg(feature = "polars")]
136        let builder_df = fields.iter().map(|f| {
137        let field = f.clone().ident.unwrap();
138        let mn = field.to_string();
139        quote! {
140            #mn => #field
141        }
142    });
143
144    // for getting vectors of self struct
145    let builder_row_to_self_func = fields.iter().map(|f| {
146        let mn = f.clone().ident.unwrap();
147        let field_name = format!("{}.{}", &table_name, &mn.to_string());
148        let ty = &f.ty;
149        return match extract_type_from_option(ty) {
150            Some(value) => {
151                let type_name = value.to_token_stream().to_string();
152                match type_name.as_str() {
153                    "String" => {
154                        quote! {
155                            #mn: row.get::<&str, &str>(#field_name).map(|i| i.to_string())
156                        }
157                    }
158                    _ => {
159                        quote! {
160                            #mn: row.get::<#value, &str>(#field_name)
161                        }
162                    }
163                }
164            }
165            None => {
166                let type_name = ty.to_token_stream().to_string();
167                match type_name.as_str() {
168                    "String" => {
169                        quote! {
170                            #mn: row.get::<&str, &str>(#field_name).unwrap().to_string()
171                        }
172                    }
173                    _ => {
174                        quote! {
175                            #mn: row.get::<#ty, &str>(#field_name).unwrap()
176                        }
177                    }
178                }
179            }
180        };
181    });
182
183    let mut result = quote! {
184    };
185
186    let mut relations: Vec<String> = vec![];
187    let mut tables: Vec<String> = vec![];
188    let mut primary_key = None;
189    for field in fields.iter() {
190        for attr in field.attrs.iter() {
191            if let Some(ident) = attr.path().get_ident() {
192                if ident == "rssql" {
193                    if let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated) {
194                        for meta in list.iter() {
195                            if let Meta::Path(path) = meta {
196                                let Path { ref segments, .. } = path;
197                                for rssql_tags in segments.iter() {
198                                    if rssql_tags.ident == "primary_key" {
199                                        primary_key = Some(field.clone());
200                                    }
201                                }
202                            }
203
204                            if let Meta::NameValue(named_v) = meta {
205                                let Path { ref segments, .. } = &named_v.path;
206                                for rssql_tags in segments.iter() {
207                                    if rssql_tags.ident == "foreign_key" {
208                                        if let Expr::Lit(ExprLit { lit, .. }) = &named_v.value {
209                                            if let Lit::Str(v) = lit {
210                                                let field_name = field.ident.as_ref().unwrap().to_string();
211                                                relations.push(format!("{}.{} = {}", &table_name, field_name, v.value()));
212                                                tables.push(v.value()[..v.value().rfind('.').unwrap()].to_string());
213                                            }
214                                        }
215                                        // if let Expr::Path(p_v) = &named_v.value {
216                                        //     dbg!(&p_v);
217                                        //     for seg in p_v.path.segments.iter() {
218                                        //         let i = &seg.ident;
219                                        //         result.extend(quote! {
220                                        //                 impl #struct_name {
221                                        //                     fn #i() -> String{
222                                        //                         "asdf".to_string()
223                                        //                     }
224                                        //                 }
225                                        //             })
226                                        //     }
227                                        // }
228                                        // if let Expr::Lit{, ..} = &named_v.value {
229                                        //     dbg!(&p_v);
230                                        //     for seg in p_v.path.segments.iter() {
231                                        //         let i = &seg.ident;
232                                        //         result.extend(quote! {
233                                        //                 impl #struct_name {
234                                        //                     fn #i() -> String{
235                                        //                         "asdf".to_string()
236                                        //                     }
237                                        //                 }
238                                        //             })
239                                        //     }
240                                        // }
241                                    }
242                                }
243                            }
244                        }
245                    }
246                }
247            }
248        }
249    }
250
251    let builder_fields = relations.iter().zip(tables.iter()).map(|(rel, tb)| {
252        quote! { #tb => {
253            concat!(" ", #tb, " ON ", #rel)
254            // format!("JOIN {} ON {}", #tb, #rel)
255        }}
256    });
257
258    let pk = if let Some(f) = primary_key {
259        let field_name = f.ident.as_ref().unwrap().to_string();
260        let mn = f.ident.unwrap();
261        quote! {
262            impl #struct_name {
263                fn primary_key(&self) -> (&'static str, ColumnData) {
264                    (#field_name, self.#mn.to_sql())
265                }
266            }
267        }
268    } else {
269        quote! {
270            impl #struct_name {
271                fn primary_key(&self) -> (&'static str, ColumnData) {
272                    unimplemented!("Primary key not set");
273                }
274            }
275        }
276    };
277    result.extend(pk);
278
279    result.extend(quote! {
280        #[async_trait(?Send)]
281        impl RssqlMarker for #struct_name {
282            fn table_name() -> &'static str {
283                #table_name
284            }
285
286            fn fields() -> Vec<&'static str> {
287                vec![#(#builder_fields_mapping,)*]
288            }
289
290            fn row_to_json(row:&Row) -> Map<String, Value> {
291                let mut map = Map::new();
292                #(#builder_row_func;)*
293                map
294            }
295
296            fn row_to_struct(row:&Row) -> Self {
297                Self{
298                    #(#builder_row_to_self_func,)*
299                }
300            }
301
302            async fn insert_many(iter: impl IntoIterator<Item = #struct_name> , conn: &mut Client<Compat<TcpStream>>) -> RssqlResult<u64>
303            // where I:  impl Iterator<Item = #struct_name>
304            {
305                let mut req = conn.bulk_insert(#table_name).await?;
306                for item in iter{
307                    let mut row = TokenRow::new();
308                    #(#builder_insert_rows;)*
309                    req.send(row).await?;
310                }
311                let res = req.finalize().await?;
312                Ok(res.total())
313            }
314
315            async fn insert(self, conn: &mut Client<Compat<TcpStream>>) -> RssqlResult<()> {
316                let sql = format!("INSERT INTO {} ({}) values({})", #table_name, #builder_insert_fields, #builder_insert_params);
317                conn.execute(sql, &[#(#builder_insert_data,)*]).await?;
318                Ok(())
319            }
320
321            async fn delete(self, conn: &mut Client<Compat<TcpStream>>) -> RssqlResult<()> {
322                let (pk, dt) = self.primary_key();
323                QueryBuilder::<#struct_name>::delete(&dt, #table_name, pk, conn).await?;
324                Ok(())
325            }
326
327            async fn update(&self, conn: &mut Client<Compat<TcpStream>>) -> RssqlResult<()> {
328                let (pk, dt) = self.primary_key();
329                let sql = format!("UPDATE {} SET {} WHERE {} {}", #table_name, #builder_update_fields, pk, QueryBuilder::<#struct_name>::process_pk_condition(&dt));
330                conn.execute(sql, &[#(#builder_update_data,)*]).await?;
331                Ok(())
332            }
333
334        }
335        impl #struct_name {
336
337            fn relationship(input: &str) -> &'static str {
338                match input {
339                    #(#builder_fields,)*
340                    _ =>  unimplemented!("relationship not found"),
341                }
342            }
343
344            fn column_type(input: &str) -> &'static str{
345                match input {
346                    #(#builder_types,)*
347                    _ =>  unimplemented!("column_type not found"),
348                }
349            }
350
351            pub fn query() -> QueryBuilder<#struct_name> {
352                QueryBuilder::<#struct_name>::new(
353                    (#table_name, #struct_name::fields()),
354                    #struct_name::relationship)
355            }
356
357        }
358    });
359
360
361
362    #[cfg(feature = "polars")]
363    result.extend(quote! {
364        impl PolarsHelper for #struct_name {
365            fn dataframe(vec: Vec<Self>) -> PolarsResult<DataFrame> {
366                #(#builder_new_vecs;)*
367                #[allow(non_snake_case)]
368                for Phant_Name1 in vec {
369                    #(#builder_insert_to_df;)*
370                }
371                df!(
372                    #(#builder_df,)*
373                )
374            }
375        }
376    });
377
378    result.into()
379}
380