1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, quote};
5use syn::{
6 __private::{Span, TokenStream2},
7 Attribute, Data, DeriveInput, Fields, FieldsNamed, GenericArgument, Ident, Meta, NestedMeta,
8 PathArguments, Type, parse_macro_input,
9};
10
11use crate::field::Field;
12
13mod field;
14
15#[proc_macro_derive(Reflected)]
17pub fn reflected(stream: TokenStream) -> TokenStream {
18 let mut stream = parse_macro_input!(stream as DeriveInput);
19
20 let Data::Struct(data) = &mut stream.data else {
21 panic!("`db_entity` macro has to be used with structs")
22 };
23
24 let Fields::Named(struct_fields) = &mut data.fields else {
25 panic!()
26 };
27
28 let (rename, fields) = parse_fields(struct_fields);
29
30 let name = stream.ident.clone();
31
32 let name_string = if let Some(rename) = rename {
33 TokenStream2::from_str(&format!("\"{rename}\""))
34 } else {
35 TokenStream2::from_str(&format!("\"{name}\""))
36 }
37 .unwrap();
38
39 let fields_struct_name = Ident::new(&format!("{name}Fields"), Span::call_site());
40
41 let fields_struct = fields_struct(&name, &fields);
42 let fields_const_var = fields_const_var(&name, &fields);
43 let fields_reflect = fields_reflect(&name, &fields);
44 let simple_fields_reflect = simple_fields_reflect(&name, &fields);
45 let get_value = fields_get_value(&fields);
46 let set_value = fields_set_value(&fields);
47 let sqlx_bind = fields_sqlx_bind(&fields);
48
49 quote! {
50 #[derive(Debug)]
51 pub struct #fields_struct_name {
52 #fields_struct
53 }
54
55 impl #name {
56 pub const FIELDS: #fields_struct_name = #fields_struct_name {
57 #fields_const_var
58 };
59 }
60
61 impl reflected::Reflected for #name {
62 fn type_name() -> &'static str {
63 #name_string
64 }
65
66 fn fields() -> &'static [reflected::Field<Self>] {
67 &[
68 #fields_reflect
69 ]
70 }
71
72 fn simple_fields() -> &'static [reflected::Field<Self>] {
73 &[
74 #simple_fields_reflect
75 ]
76 }
77
78 fn get_value(&self, field: reflected::Field<Self>) -> String {
79 use std::borrow::Borrow;
80 use reflected::ToReflectedString;
81 let field = field.borrow();
82
83 if field.is_custom() {
84 panic!("get_value method is not supported for custom types: {field:?}");
85 }
86
87 match field.name {
88 #get_value
89 _ => unreachable!("Invalid field name in get_value: {}", field.name),
90 }
91 }
92
93 fn set_value(&mut self, field: reflected::Field<Self>, value: Option<&str>) {
94 use reflected::ToReflectedVal;
95 use std::borrow::Borrow;
96 let field = field.borrow();
97 match field.name {
98 #set_value
99 _ => unreachable!("Invalid field name in set_value: {}", field.name),
100 }
101 }
102
103 fn bind_to_sqlx_query<'q, O>(self, query: sqlx::query::QueryAs<'q, sqlx::Postgres, O, <sqlx::Postgres as sqlx::Database>::Arguments<'q>>,)
104 -> sqlx::query::QueryAs<'q, sqlx::Postgres, O, <sqlx::Postgres as sqlx::Database>::Arguments<'q>> {
105 let mut query = query;
106 #sqlx_bind
107 query
108 }
109 }
110 }
111 .into()
112}
113
114fn fields_const_var(type_name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
115 let mut res = quote!();
116
117 let type_name = TokenStream2::from_str(&format!("\"{type_name}\"")).unwrap();
118
119 for field in fields {
120 let name = &field.name;
121
122 let field_type = field.field_type();
123
124 let field_type_name = field.type_as_string();
125 let name_string = field.name_as_string();
126
127 let optional = field.optional;
128
129 let tp = if optional {
130 quote! {
131 tp: reflected::Type::#field_type.to_optional()
132 }
133 } else {
134 quote! {
135 tp: reflected::Type::#field_type
136 }
137 };
138
139 res = quote! {
140 #res
141 #name: reflected::Field {
142 name: #name_string,
143 #tp,
144 type_name: #field_type_name,
145 parent_name: #type_name,
146 optional: #optional,
147 _p: std::marker::PhantomData,
148 },
149 }
150 }
151
152 res
153}
154
155fn fields_struct(type_name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
156 let mut res = quote!();
157
158 for field in fields {
159 let name = &field.name;
160 res = quote! {
161 #res
162 pub #name: reflected::Field<#type_name>,
163 }
164 }
165
166 quote! {
167 #res
168 }
169}
170
171fn fields_reflect(name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
172 let mut res = quote!();
173
174 for field in fields {
175 let field_name = &field.name;
176 res = quote! {
177 #res
178 #name::FIELDS.#field_name,
179 }
180 }
181
182 res
183}
184
185fn simple_fields_reflect(name: &Ident, fields: &Vec<Field>) -> TokenStream2 {
186 let mut res = quote!();
187
188 for field in fields {
189 if !field.is_simple() {
190 continue;
191 }
192 let field_name = &field.name;
193 res = quote! {
194 #res
195 #name::FIELDS.#field_name,
196 }
197 }
198
199 res
200}
201
202fn fields_get_value(fields: &Vec<Field>) -> TokenStream2 {
203 let mut res = quote!();
204
205 for field in fields {
206 if field.custom() {
207 continue;
208 }
209
210 let field_name = &field.name;
211 let name_string = field.name_as_string();
212
213 if field.is_bool() {
214 if field.optional {
215 res = quote! {
216 #res
217 #name_string => self.#field_name.map(|a| if a { "1" } else { "0" }.to_string()).unwrap_or("NULL".to_string()),
218 }
219 } else {
220 res = quote! {
221 #res
222 #name_string => if self.#field_name { "1" } else { "0" }.to_string(),
223 }
224 }
225 } else if field.optional || field.is_float() {
226 res = quote! {
227 #res
228 #name_string => self.#field_name.to_reflected_string(),
229 }
230 } else {
231 res = quote! {
232 #res
233 #name_string => self.#field_name.to_string(),
234 }
235 }
236 }
237
238 res
239}
240
241fn fields_set_value(fields: &Vec<Field>) -> TokenStream2 {
242 let mut res = quote!();
243
244 for field in fields {
245 if field.custom() {
246 continue;
247 }
248
249 let field_name = &field.name;
250 let name_string = field.name_as_string();
251
252 if field.is_bool() {
253 if field.optional {
254 res = quote! {
255 #res
256 #name_string => {
257 self.#field_name = value.map(|a| match a {
258 "0" => false,
259 "1" => true,
260 _ => unreachable!("Invalid value in bool: {value:?}")
261 })
262 },
263 }
264 } else {
265 res = quote! {
266 #res
267 #name_string => {
268 self.#field_name = match value.unwrap() {
269 "0" => false,
270 "1" => true,
271 _ => unreachable!("Invalid value in bool: {value:?}")
272 }
273 },
274 }
275 }
276 } else if field.is_date() {
277 res = quote! {
278 #res
279 #name_string => self.#field_name = chrono::NaiveDateTime::parse_from_str(&value.unwrap(), "%Y-%m-%d %H:%M:%S%.9f").unwrap(),
280 }
281 } else if field.optional {
282 res = quote! {
283 #res
284 #name_string => self.#field_name = value.map(|a| a.to_reflected_val()
285 .expect(&format!("Failed to convert to: {} from: {}", #name_string, a))),
286 }
287 } else {
288 res = quote! {
289 #res
290 #name_string => self.#field_name = value.unwrap().to_reflected_val()
291 .expect(&format!("Failed to convert to: {} from: {}", #name_string, value.unwrap())),
292 }
293 }
294 }
295
296 res
297}
298
299fn fields_sqlx_bind(fields: &Vec<Field>) -> TokenStream2 {
300 let mut res = quote!();
301
302 for field in fields {
303 let field_name = &field.name;
304
305 if field.custom() || field.is_date() {
306 continue;
307 }
308
309 if field.tp == "Decimal" || field.tp == "usize" {
310 continue;
311 }
312
313 res = quote! {
314 #res
315 query = query.bind(self.#field_name);
316 };
317 }
318
319 res
320}
321
322fn parse_fields(fields: &FieldsNamed) -> (Option<String>, Vec<Field>) {
323 let mut rename: Option<String> = None;
324
325 let fields: Vec<Field> = fields
326 .named
327 .iter()
328 .map(|field| {
329 let name = field.ident.as_ref().unwrap().clone();
330 let mut optional = false;
331
332 let Type::Path(path) = &field.ty else {
333 unreachable!("invalid parse_fields")
334 };
335
336 let mut tp = path.path.segments.first().unwrap().ident.clone();
337
338 if tp == "Option" {
339 optional = true;
340 let args = &path.path.segments.first().unwrap().arguments;
341 if let PathArguments::AngleBracketed(args) = args {
342 if let GenericArgument::Type(generic_tp) = args.args.first().unwrap() {
343 let ident = generic_tp.to_token_stream().to_string();
344 let ident = Ident::new(&ident, Span::call_site());
345 tp = ident;
346 } else {
347 unreachable!()
348 }
349 } else {
350 unreachable!()
351 }
352 }
353
354 let _attrs: Vec<String> = field
355 .attrs
356 .iter()
357 .map(|a| {
358 let name = get_attribute_name(a);
359 if name == "name" {
360 rename = get_attribute_value(a).expect("name attribute should have value").into();
361 }
362 name
363 })
364 .collect();
365
366 Field { name, tp, optional }
367 })
368 .collect();
369
370 (rename, fields)
371}
372
373fn get_attribute_name(attribute: &Attribute) -> String {
374 attribute.path.segments.first().unwrap().ident.to_string()
375}
376
377fn get_attribute_value(attribute: &Attribute) -> Option<String> {
378 if let Ok(Meta::List(meta_list)) = attribute.parse_meta() {
379 if let NestedMeta::Meta(Meta::Path(path)) = &meta_list.nested[0] {
380 return Some(path.segments.last()?.ident.to_string());
381 }
382 }
383 None
384}