mysql_connector_macros/
lib.rs

1extern crate proc_macro;
2
3mod parse;
4
5use {
6    parse::{parse_attr, parse_fields, NamedField, TypeComplexity},
7    proc_macro::TokenStream,
8    proc_macro2::Span,
9    quote::{format_ident, quote},
10    std::fmt,
11    syn::{parse_macro_input, DeriveInput, Ident, LitStr, Type},
12};
13
14struct Error(Option<syn::Error>);
15
16impl Error {
17    pub fn empty() -> Self {
18        Self(None)
19    }
20
21    pub fn add<T: fmt::Display>(&mut self, span: Span, message: T) {
22        let error = syn::Error::new(span, message);
23        match &mut self.0 {
24            Some(e) => e.combine(error),
25            None => self.0 = Some(error),
26        }
27    }
28
29    pub fn add_err(&mut self, error: syn::Error) {
30        match &mut self.0 {
31            Some(e) => e.combine(error),
32            None => self.0 = Some(error),
33        }
34    }
35
36    pub fn error(&mut self) -> Option<syn::Error> {
37        self.0.take()
38    }
39}
40
41#[proc_macro_derive(ModelData, attributes(mysql_connector))]
42pub fn derive_model_data(input: TokenStream) -> TokenStream {
43    let mut error = Error::empty();
44    let input = parse_macro_input!(input as DeriveInput);
45
46    let (attr_span, attrs, _) = parse_attr(&mut error, input.ident.span(), &input.attrs);
47    if let Some(span) = attr_span {
48        if !attrs.contains_key("table") {
49            error.add(span, "table needed (#[mysql_connector(table = \"...\")]");
50        }
51    }
52
53    if let Some(error) = error.error() {
54        return error.into_compile_error().into();
55    }
56
57    let ident = &input.ident;
58    let table = attrs.get("table").unwrap();
59    let table_with_point = table.to_owned() + ".";
60
61    quote! {
62        impl mysql_connector::model::ModelData for #ident {
63            const TABLE: &'static str = #table;
64            const TABLE_WITH_POINT: &'static str = #table_with_point;
65        }
66    }
67    .into()
68}
69
70#[proc_macro_derive(FromQueryResult)]
71pub fn derive_from_query_result(input: TokenStream) -> TokenStream {
72    let mut error = Error::empty();
73    let input = parse_macro_input!(input as DeriveInput);
74
75    let (_, _, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
76    let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
77
78    if let Some(error) = error.error() {
79        return error.into_compile_error().into();
80    }
81
82    let ident = &input.ident;
83    let visibility = &input.vis;
84    let mapping_ident = format_ident!("{ident}Mapping");
85
86    let simple_field_names: &Vec<&Ident> = &fields
87        .iter()
88        .filter(TypeComplexity::simple_ref)
89        .map(|x| &x.ident)
90        .collect();
91    let mut struct_field_names = Vec::new();
92    let mut set_struct_fields = proc_macro2::TokenStream::new();
93    for field in &fields {
94        if let TypeComplexity::Struct(r#struct) = &field.complexity {
95            let field_ident = &field.ident;
96            let struct_path = &r#struct.path;
97            let mapping_names = r#struct
98                .fields
99                .iter()
100                .map(|x| format_ident!("{}_{}", field.ident, x.1));
101            struct_field_names.extend(mapping_names.clone());
102            let struct_names = r#struct.fields.iter().map(|x| &x.0);
103            set_struct_fields = quote! {
104                #set_struct_fields
105                #field_ident: #struct_path {
106                    #(#struct_names: row[mapping.#mapping_names.ok_or(mysql_connector::error::ParseError::MissingField(
107                        concat!(stringify!(#ident), ".", stringify!(#mapping_names))
108                    ))?].take().try_into()?,)*
109                },
110            }
111        }
112    }
113    let struct_field_names = &struct_field_names;
114
115    let complex_field_names: &Vec<&Ident> = &fields
116        .iter()
117        .filter(TypeComplexity::complex_ref)
118        .map(|x: &parse::NamedField| &x.ident)
119        .collect();
120    let complex_field_types: &Vec<&Type> = &fields
121        .iter()
122        .filter(TypeComplexity::complex_ref)
123        .map(|x| &x.ty)
124        .collect();
125
126    let set_mapping = {
127        let mut set_child_mapping = proc_macro2::TokenStream::new();
128
129        for (
130            i,
131            NamedField {
132                complexity: _,
133                //vis: _,
134                ident,
135                ty: _,
136            },
137        ) in fields
138            .iter()
139            .filter(TypeComplexity::complex_ref)
140            .enumerate()
141        {
142            let name = ident.to_string();
143            let name_with_point = name.clone() + ".";
144            let len = name_with_point.as_bytes().len();
145            let maybe_else = if i == 0 { None } else { Some(quote!(else)) };
146
147            set_child_mapping = quote! {
148                #set_child_mapping
149                #maybe_else if table == #name {
150                    self.#ident.set_mapping(column, "", index);
151                } else if table.starts_with(#name_with_point) {
152                    self.#ident.set_mapping(column, &table[#len..], index);
153                }
154            };
155        }
156
157        let set_own_mapping = quote! {
158            *match column.org_name() {
159                #(stringify!(#simple_field_names) => &mut self.#simple_field_names,)*
160                #(stringify!(#struct_field_names) => &mut self.#struct_field_names,)*
161                _ => return,
162            } = Some(index);
163        };
164
165        if !fields.iter().any(TypeComplexity::complex) {
166            set_own_mapping
167        } else {
168            quote! {
169                #set_child_mapping
170                else {
171                    #set_own_mapping
172                }
173            }
174        }
175    };
176
177    quote! {
178        const _: () = {
179            #[derive(Default)]
180            #visibility struct #mapping_ident {
181                #(#simple_field_names: Option<usize>,)*
182                #(#struct_field_names: Option<usize>,)*
183                #(#complex_field_names: <#complex_field_types as mysql_connector::model::FromQueryResult>::Mapping,)*
184            }
185
186            impl mysql_connector::model::FromQueryResultMapping<#ident> for #mapping_ident {
187                fn set_mapping_inner(&mut self, column: &mysql_connector::types::Column, table: &str, index: usize) {
188                    #set_mapping
189                }
190            }
191
192            impl mysql_connector::model::FromQueryResult for #ident {
193                type Mapping = #mapping_ident;
194
195                fn from_mapping_and_row(mapping: &Self::Mapping, row: &mut std::vec::Vec<mysql_connector::types::Value>) -> std::result::Result<Self, mysql_connector::error::ParseError> {
196                    Ok(Self {
197                        #(#simple_field_names: row[mapping.#simple_field_names.ok_or(mysql_connector::error::ParseError::MissingField(
198                            concat!(stringify!(#ident), ".", stringify!(#simple_field_names))
199                        ))?].take().try_into()?,)*
200                        #set_struct_fields
201                        #(#complex_field_names: <#complex_field_types>::from_mapping_and_row(&mapping.#complex_field_names, row)?,)*
202                    })
203                }
204            }
205        };
206    }.into()
207}
208
209#[proc_macro_derive(ActiveModel)]
210pub fn derive_active_model(input: TokenStream) -> TokenStream {
211    let mut error = Error::empty();
212    let input = parse_macro_input!(input as DeriveInput);
213
214    let (attr_span, attrs, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
215    let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
216
217    let primary = match attr_span {
218        Some(span) => match attrs.get("primary") {
219            Some(primary) => match attrs.get("auto_increment") {
220                Some(ai) => Some((format_ident!("{primary}"), ai == "true")),
221                None => {
222                    error.add(
223                        span,
224                        "auto_increment needed (#[mysql_connector(auto_increment = \"...\")]",
225                    );
226                    None
227                }
228            },
229            None => None,
230        },
231        None => None,
232    };
233
234    if let Some(error) = error.error() {
235        return error.into_compile_error().into();
236    }
237
238    let mut insert_struct_fields = proc_macro2::TokenStream::new();
239    for field in &fields {
240        if let TypeComplexity::Struct(r#struct) = &field.complexity {
241            let ident = &field.ident;
242            let idents = r#struct.fields.iter().map(|(x, _)| x);
243            let names = r#struct
244                .fields
245                .iter()
246                .map(|(_, x)| format_ident!("{ident}_{x}"));
247            insert_struct_fields = quote! {
248                #insert_struct_fields
249                match self.#ident {
250                    mysql_connector::model::ActiveValue::Unset =>(),
251                    mysql_connector::model::ActiveValue::Set(value) => {
252                        #(values.push(mysql_connector::model::NamedValue(stringify!(#names), value.#idents.try_into().map_err(Into::<mysql_connector::error::SerializeError>::into)?));)*
253                    }
254                }
255            };
256        }
257    }
258
259    let simple_field_names: &Vec<&Ident> = &fields
260        .iter()
261        .filter(TypeComplexity::simple_ref)
262        .map(|x| &x.ident)
263        .collect();
264    let (simple_field_names_without_primary, set_primary) = primary
265        .as_ref()
266        .and_then(|(primary, auto_increment)| {
267            if *auto_increment {
268                let field_names = simple_field_names
269                    .iter()
270                    .filter(|x| **x != primary)
271                    .copied()
272                    .collect();
273                let set_primary = quote! {
274                    #primary: mysql_connector::model::ActiveValue::Unset,
275                };
276                Some((field_names, set_primary))
277            } else {
278                None
279            }
280        })
281        .unwrap_or_else(|| (simple_field_names.clone(), proc_macro2::TokenStream::new()));
282    let get_primary = match primary {
283        Some((primary, _)) => quote! {
284            match self.#primary {
285                mysql_connector::model::ActiveValue::Set(x) => Some(x.into()),
286                mysql_connector::model::ActiveValue::Unset => None,
287            }
288        },
289        None => quote! {None},
290    };
291
292    let simple_field_types: &Vec<&Type> = &fields
293        .iter()
294        .filter(TypeComplexity::simple_ref)
295        .map(|x| &x.ty)
296        .collect();
297    let struct_field_names: &Vec<&Ident> = &fields
298        .iter()
299        .filter(TypeComplexity::struct_ref)
300        .map(|x| &x.ident)
301        .collect();
302    let struct_field_types: &Vec<&Type> = &fields
303        .iter()
304        .filter(TypeComplexity::struct_ref)
305        .map(|x| &x.ty)
306        .collect();
307    let complex_field_names: &Vec<&Ident> = &fields
308        .iter()
309        .filter(TypeComplexity::complex_ref)
310        .map(|x| &x.ident)
311        .collect();
312    let complex_field_types: &Vec<&Type> = &fields
313        .iter()
314        .filter(TypeComplexity::complex_ref)
315        .map(|x| &x.ty)
316        .collect();
317
318    let ident = &input.ident;
319    let model_ident = format_ident!("{ident}ActiveModel");
320
321    quote! {
322        const _: () = {
323            #[derive(Debug, Default)]
324            pub struct #model_ident {
325                #(pub #simple_field_names: mysql_connector::model::ActiveValue<#simple_field_types>,)*
326                #(pub #struct_field_names: mysql_connector::model::ActiveValue<#struct_field_types>,)*
327                #(pub #complex_field_names: mysql_connector::model::ActiveReference<#complex_field_types>,)*
328            }
329
330            impl mysql_connector::model::ActiveModel<#ident> for #model_ident {
331                async fn into_values(self, conn: &mut mysql_connector::Connection) -> Result<Vec<mysql_connector::model::NamedValue>, mysql_connector::error::Error> {
332                    let mut values = Vec::new();
333                    #(self.#simple_field_names.insert_named_value(&mut values, stringify!(#simple_field_names))?;)*
334                    #insert_struct_fields
335                    #(self.#complex_field_names.insert_named_value(&mut values, stringify!(#complex_field_names), conn).await?;)*
336                    Ok(values)
337                }
338
339                fn primary(&self) -> Option<mysql_connector::types::Value> {
340                    #get_primary
341                }
342            }
343
344            impl mysql_connector::model::HasActiveModel for #ident {
345                type ActiveModel = #model_ident;
346
347                fn into_active_model(self) -> Self::ActiveModel {
348                    #model_ident {
349                        #set_primary
350                        #(#simple_field_names_without_primary: mysql_connector::model::ActiveValue::Set(self.#simple_field_names_without_primary),)*
351                        #(#struct_field_names: mysql_connector::model::ActiveValue::Set(self.#struct_field_names),)*
352                        #(#complex_field_names: mysql_connector::model::ActiveReference::Insert(<#complex_field_types as mysql_connector::model::HasActiveModel>::into_active_model(self.#complex_field_names)),)*
353                    }
354                }
355            }
356        };
357    }.into()
358}
359
360#[proc_macro_derive(IntoQuery)]
361pub fn derive_into_query(input: TokenStream) -> TokenStream {
362    let mut error = Error::empty();
363    let input = parse_macro_input!(input as DeriveInput);
364
365    let (_, _, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
366    let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
367
368    if let Some(error) = error.error() {
369        return error.into_compile_error().into();
370    }
371
372    let mut simple_field_names: Vec<LitStr> = fields
373        .iter()
374        .filter(TypeComplexity::simple_ref)
375        .map(|x| LitStr::new(&x.ident.to_string(), x.ident.span()))
376        .collect();
377    for field in &fields {
378        if let TypeComplexity::Struct(r#struct) = &field.complexity {
379            let mapping_names = r#struct
380                .fields
381                .iter()
382                .map(|x| LitStr::new(&format!("{}_{}", field.ident, x.1), field.ident.span()));
383            simple_field_names.extend(mapping_names);
384        }
385    }
386    let complex_field_names: &Vec<LitStr> = &fields
387        .iter()
388        .filter(TypeComplexity::complex_ref)
389        .map(|x| LitStr::new(&x.ident.to_string(), x.ident.span()))
390        .collect();
391    let complex_field_types: &Vec<&Type> = &fields
392        .iter()
393        .filter(TypeComplexity::complex_ref)
394        .map(|x| &x.ty)
395        .collect();
396
397    let ident = &input.ident;
398
399    quote! {
400        impl mysql_connector::model::IntoQuery for #ident {
401            const COLUMNS: &'static [mysql_connector::model::QueryColumn] = &[
402                #(mysql_connector::model::QueryColumn::Column(#simple_field_names),)*
403                #(mysql_connector::model::QueryColumn::Reference(mysql_connector::model::QueryColumnReference {
404                    column: #complex_field_names,
405                    table: <#complex_field_types as mysql_connector::model::ModelData>::TABLE,
406                    key: <#complex_field_types as mysql_connector::model::Model>::PRIMARY,
407                    columns: <#complex_field_types as mysql_connector::model::IntoQuery>::COLUMNS,
408                }),)*
409            ];
410        }
411    }.into()
412}
413
414#[proc_macro_derive(Model)]
415pub fn derive_model(input: TokenStream) -> TokenStream {
416    let mut error = Error::empty();
417    let input = parse_macro_input!(input as DeriveInput);
418
419    let (attr_span, attrs, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
420    let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
421
422    let mut primary_type = None;
423    let mut auto_increment = false;
424    if let Some(span) = attr_span {
425        match attrs.get("primary") {
426            Some(primary) => match fields.iter().find(|field| field.ident == primary) {
427                Some(field) => primary_type = Some(&field.ty),
428                None => error.add(span, "primary not found in struct"),
429            },
430            None => error.add(
431                span,
432                "primary needed (#[mysql_connector(primary = \"...\")]",
433            ),
434        }
435        match attrs.get("auto_increment") {
436            Some(ai) => auto_increment = ai == "true",
437            None => error.add(
438                span,
439                "auto_increment needed (#[mysql_connector(auto_increment = \"...\")]",
440            ),
441        }
442    }
443
444    if let Some(error) = error.error() {
445        return error.into_compile_error().into();
446    }
447
448    let primary = attrs.get("primary").unwrap();
449    let primary_type = primary_type.unwrap();
450    let primary_ident = Ident::new(primary, Span::call_site());
451    let ident = &input.ident;
452
453    quote! {
454        impl mysql_connector::model::Model for #ident {
455            const PRIMARY: &'static str = #primary;
456            const AUTO_INCREMENT: bool = #auto_increment;
457
458            type Primary = #primary_type;
459
460            fn primary(&self) -> Self::Primary {
461                self.#primary_ident
462            }
463        }
464    }
465    .into()
466}