reflected_proc/
lib.rs

1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, quote};
5use syn::{
6    __private::{Span, TokenStream2},
7    Attribute, Data, DeriveInput, Fields, FieldsNamed, GenericArgument, Ident, Meta, NestedMeta,
8    PathArguments, Type, parse_macro_input,
9};
10
11use crate::field::Field;
12
13mod field;
14
15/// Data must also derive `Default`
16#[proc_macro_derive(Reflected)]
17pub fn reflected(stream: TokenStream) -> TokenStream {
18    let mut stream = parse_macro_input!(stream as DeriveInput);
19
20    let Data::Struct(data) = &mut stream.data else {
21        panic!("`db_entity` macro has to be used with structs")
22    };
23
24    let Fields::Named(struct_fields) = &mut data.fields else {
25        panic!()
26    };
27
28    let (rename, fields) = parse_fields(struct_fields);
29
30    let name = stream.ident.clone();
31
32    let name_string = if let Some(rename) = rename {
33        TokenStream2::from_str(&format!("\"{rename}\""))
34    } else {
35        TokenStream2::from_str(&format!("\"{name}\""))
36    }
37    .unwrap();
38
39    let fields_struct_name = Ident::new(&format!("{name}Fields"), Span::call_site());
40
41    let fields_struct = fields_struct(&name, &fields);
42    let fields_const_var = fields_const_var(&name, &fields);
43    let fields_reflect = fields_reflect(&name, &fields);
44    let simple_fields_reflect = simple_fields_reflect(&name, &fields);
45    let get_value = fields_get_value(&fields);
46    let set_value = fields_set_value(&fields);
47    let sqlx_bind = fields_sqlx_bind(&fields);
48
49    quote! {
50        #[derive(Debug)]
51        pub struct #fields_struct_name {
52            #fields_struct
53        }
54
55        impl #name {
56            pub const FIELDS: #fields_struct_name = #fields_struct_name {
57                #fields_const_var
58            };
59        }
60
61        impl reflected::Reflected for #name {
62            fn type_name() -> &'static str {
63                #name_string
64            }
65
66            fn fields() -> &'static [reflected::Field<Self>] {
67                &[
68                    #fields_reflect
69                ]
70            }
71
72            fn simple_fields() -> &'static [reflected::Field<Self>] {
73                &[
74                    #simple_fields_reflect
75                ]
76            }
77
78            fn get_value(&self, field: reflected::Field<Self>) -> String {
79                use std::borrow::Borrow;
80                use reflected::ToReflectedString;
81                let field = field.borrow();
82
83                if field.is_custom() {
84                    panic!("get_value method is not supported for custom types: {field:?}");
85                }
86
87                match field.name {
88                    #get_value
89                    _ => unreachable!("Invalid field name in get_value: {}", field.name),
90                }
91            }
92
93            fn set_value(&mut self, field: reflected::Field<Self>, value: Option<&str>) {
94                use reflected::ToReflectedVal;
95                use std::borrow::Borrow;
96                let field = field.borrow();
97                match field.name {
98                    #set_value
99                    _ => unreachable!("Invalid field name in set_value: {}", field.name),
100                }
101            }
102
103            fn bind_to_sqlx_query<'q, O>(self, query: sqlx::query::QueryAs<'q, sqlx::Postgres, O, <sqlx::Postgres as sqlx::Database>::Arguments<'q>>,)
104                ->  sqlx::query::QueryAs<'q, sqlx::Postgres, O, <sqlx::Postgres as sqlx::Database>::Arguments<'q>> {
105                let mut query = query;
106                #sqlx_bind
107                query
108            }
109        }
110    }
111    .into()
112}
113
114fn fields_const_var(type_name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
115    let mut res = quote!();
116
117    let type_name = TokenStream2::from_str(&format!("\"{type_name}\"")).unwrap();
118
119    for field in fields {
120        let name = &field.name;
121
122        let field_type = field.field_type();
123
124        let field_type_name = field.type_as_string();
125        let name_string = field.name_as_string();
126
127        let optional = field.optional;
128
129        let tp = if optional {
130            quote! {
131                tp: reflected::Type::#field_type.to_optional()
132            }
133        } else {
134            quote! {
135                tp: reflected::Type::#field_type
136            }
137        };
138
139        res = quote! {
140            #res
141            #name: reflected::Field {
142                name: #name_string,
143                #tp,
144                type_name: #field_type_name,
145                parent_name: #type_name,
146                optional: #optional,
147                _p: std::marker::PhantomData,
148            },
149        }
150    }
151
152    res
153}
154
155fn fields_struct(type_name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
156    let mut res = quote!();
157
158    for field in fields {
159        let name = &field.name;
160        res = quote! {
161            #res
162            pub #name: reflected::Field<#type_name>,
163        }
164    }
165
166    quote! {
167        #res
168    }
169}
170
171fn fields_reflect(name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
172    let mut res = quote!();
173
174    for field in fields {
175        let field_name = &field.name;
176        res = quote! {
177            #res
178            #name::FIELDS.#field_name,
179        }
180    }
181
182    res
183}
184
185fn simple_fields_reflect(name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
186    let mut res = quote!();
187
188    for field in fields {
189        if !field.is_simple() {
190            continue;
191        }
192        let field_name = &field.name;
193        res = quote! {
194            #res
195            #name::FIELDS.#field_name,
196        }
197    }
198
199    res
200}
201
202fn fields_get_value(fields: &Vec<Field>) -> TokenStream2 {
203    let mut res = quote!();
204
205    for field in fields {
206        if field.custom() {
207            continue;
208        }
209
210        let field_name = &field.name;
211        let name_string = field.name_as_string();
212
213        if field.is_bool() {
214            if field.optional {
215                res = quote! {
216                    #res
217                    #name_string => self.#field_name.map(|a| if a { "1" } else { "0" }.to_string()).unwrap_or("NULL".to_string()),
218                }
219            } else {
220                res = quote! {
221                    #res
222                    #name_string => if self.#field_name { "1" } else { "0" }.to_string(),
223                }
224            }
225        } else if field.optional || field.is_float() {
226            res = quote! {
227                #res
228                #name_string => self.#field_name.to_reflected_string(),
229            }
230        } else {
231            res = quote! {
232                #res
233                #name_string => self.#field_name.to_string(),
234            }
235        }
236    }
237
238    res
239}
240
241fn fields_set_value(fields: &Vec<Field>) -> TokenStream2 {
242    let mut res = quote!();
243
244    for field in fields {
245        if field.custom() {
246            continue;
247        }
248
249        let field_name = &field.name;
250        let name_string = field.name_as_string();
251
252        if field.is_bool() {
253            if field.optional {
254                res = quote! {
255                    #res
256                    #name_string =>  {
257                        self.#field_name = value.map(|a| match a {
258                            "0" => false,
259                            "1" => true,
260                            _ => unreachable!("Invalid value in bool: {value:?}")
261                        })
262                    },
263                }
264            } else {
265                res = quote! {
266                    #res
267                    #name_string =>  {
268                        self.#field_name = match value.unwrap() {
269                            "0" => false,
270                            "1" => true,
271                            _ => unreachable!("Invalid value in bool: {value:?}")
272                        }
273                    },
274                }
275            }
276        } else if field.is_date() {
277            res = quote! {
278                #res
279                #name_string => self.#field_name = chrono::NaiveDateTime::parse_from_str(&value.unwrap(), "%Y-%m-%d %H:%M:%S%.9f").unwrap(),
280            }
281        } else if field.optional {
282            res = quote! {
283                #res
284                #name_string => self.#field_name = value.map(|a| a.to_reflected_val()
285                    .expect(&format!("Failed to convert to: {} from: {}", #name_string, a))),
286            }
287        } else {
288            res = quote! {
289                #res
290                #name_string => self.#field_name = value.unwrap().to_reflected_val()
291                .expect(&format!("Failed to convert to: {} from: {}", #name_string, value.unwrap())),
292            }
293        }
294    }
295
296    res
297}
298
299fn fields_sqlx_bind(fields: &Vec<Field>) -> TokenStream2 {
300    let mut res = quote!();
301
302    for field in fields {
303        let field_name = &field.name;
304
305        if field.custom() || field.is_date() {
306            continue;
307        }
308
309        if field.tp == "Decimal" || field.tp == "usize" {
310            continue;
311        }
312
313        res = quote! {
314            #res
315            query = query.bind(self.#field_name);
316        };
317    }
318
319    res
320}
321
322fn parse_fields(fields: &FieldsNamed) -> (Option<String>, Vec<Field>) {
323    let mut rename: Option<String> = None;
324
325    let fields: Vec<Field> = fields
326        .named
327        .iter()
328        .map(|field| {
329            let name = field.ident.as_ref().unwrap().clone();
330            let mut optional = false;
331
332            let Type::Path(path) = &field.ty else {
333                unreachable!("invalid parse_fields")
334            };
335
336            let mut tp = path.path.segments.first().unwrap().ident.clone();
337
338            if tp == "Option" {
339                optional = true;
340                let args = &path.path.segments.first().unwrap().arguments;
341                if let PathArguments::AngleBracketed(args) = args {
342                    if let GenericArgument::Type(generic_tp) = args.args.first().unwrap() {
343                        let ident = generic_tp.to_token_stream().to_string();
344                        let ident = Ident::new(&ident, Span::call_site());
345                        tp = ident;
346                    } else {
347                        unreachable!()
348                    }
349                } else {
350                    unreachable!()
351                }
352            }
353
354            let _attrs: Vec<String> = field
355                .attrs
356                .iter()
357                .map(|a| {
358                    let name = get_attribute_name(a);
359                    if name == "name" {
360                        rename = get_attribute_value(a).expect("name attribute should have value").into();
361                    }
362                    name
363                })
364                .collect();
365
366            Field { name, tp, optional }
367        })
368        .collect();
369
370    (rename, fields)
371}
372
373fn get_attribute_name(attribute: &Attribute) -> String {
374    attribute.path.segments.first().unwrap().ident.to_string()
375}
376
377fn get_attribute_value(attribute: &Attribute) -> Option<String> {
378    if let Ok(Meta::List(meta_list)) = attribute.parse_meta() {
379        if let NestedMeta::Meta(Meta::Path(path)) = &meta_list.nested[0] {
380            return Some(path.segments.last()?.ident.to_string());
381        }
382    }
383    None
384}