1extern crate proc_macro;
2
3mod parse;
4
5use {
6 parse::{parse_attr, parse_fields, NamedField, TypeComplexity},
7 proc_macro::TokenStream,
8 proc_macro2::Span,
9 quote::{format_ident, quote},
10 std::fmt,
11 syn::{parse_macro_input, DeriveInput, Ident, LitStr, Type},
12};
13
14struct Error(Option<syn::Error>);
15
16impl Error {
17 pub fn empty() -> Self {
18 Self(None)
19 }
20
21 pub fn add<T: fmt::Display>(&mut self, span: Span, message: T) {
22 let error = syn::Error::new(span, message);
23 match &mut self.0 {
24 Some(e) => e.combine(error),
25 None => self.0 = Some(error),
26 }
27 }
28
29 pub fn add_err(&mut self, error: syn::Error) {
30 match &mut self.0 {
31 Some(e) => e.combine(error),
32 None => self.0 = Some(error),
33 }
34 }
35
36 pub fn error(&mut self) -> Option<syn::Error> {
37 self.0.take()
38 }
39}
40
41#[proc_macro_derive(ModelData, attributes(mysql_connector))]
42pub fn derive_model_data(input: TokenStream) -> TokenStream {
43 let mut error = Error::empty();
44 let input = parse_macro_input!(input as DeriveInput);
45
46 let (attr_span, attrs, _) = parse_attr(&mut error, input.ident.span(), &input.attrs);
47 if let Some(span) = attr_span {
48 if !attrs.contains_key("table") {
49 error.add(span, "table needed (#[mysql_connector(table = \"...\")]");
50 }
51 }
52
53 if let Some(error) = error.error() {
54 return error.into_compile_error().into();
55 }
56
57 let ident = &input.ident;
58 let table = attrs.get("table").unwrap();
59 let table_with_point = table.to_owned() + ".";
60
61 quote! {
62 impl mysql_connector::model::ModelData for #ident {
63 const TABLE: &'static str = #table;
64 const TABLE_WITH_POINT: &'static str = #table_with_point;
65 }
66 }
67 .into()
68}
69
70#[proc_macro_derive(FromQueryResult)]
71pub fn derive_from_query_result(input: TokenStream) -> TokenStream {
72 let mut error = Error::empty();
73 let input = parse_macro_input!(input as DeriveInput);
74
75 let (_, _, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
76 let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
77
78 if let Some(error) = error.error() {
79 return error.into_compile_error().into();
80 }
81
82 let ident = &input.ident;
83 let visibility = &input.vis;
84 let mapping_ident = format_ident!("{ident}Mapping");
85
86 let simple_field_names: &Vec<&Ident> = &fields
87 .iter()
88 .filter(TypeComplexity::simple_ref)
89 .map(|x| &x.ident)
90 .collect();
91 let mut struct_field_names = Vec::new();
92 let mut set_struct_fields = proc_macro2::TokenStream::new();
93 for field in &fields {
94 if let TypeComplexity::Struct(r#struct) = &field.complexity {
95 let field_ident = &field.ident;
96 let struct_path = &r#struct.path;
97 let mapping_names = r#struct
98 .fields
99 .iter()
100 .map(|x| format_ident!("{}_{}", field.ident, x.1));
101 struct_field_names.extend(mapping_names.clone());
102 let struct_names = r#struct.fields.iter().map(|x| &x.0);
103 set_struct_fields = quote! {
104 #set_struct_fields
105 #field_ident: #struct_path {
106 #(#struct_names: row[mapping.#mapping_names.ok_or(mysql_connector::error::ParseError::MissingField(
107 concat!(stringify!(#ident), ".", stringify!(#mapping_names))
108 ))?].take().try_into()?,)*
109 },
110 }
111 }
112 }
113 let struct_field_names = &struct_field_names;
114
115 let complex_field_names: &Vec<&Ident> = &fields
116 .iter()
117 .filter(TypeComplexity::complex_ref)
118 .map(|x: &parse::NamedField| &x.ident)
119 .collect();
120 let complex_field_types: &Vec<&Type> = &fields
121 .iter()
122 .filter(TypeComplexity::complex_ref)
123 .map(|x| &x.ty)
124 .collect();
125
126 let set_mapping = {
127 let mut set_child_mapping = proc_macro2::TokenStream::new();
128
129 for (
130 i,
131 NamedField {
132 complexity: _,
133 ident,
135 ty: _,
136 },
137 ) in fields
138 .iter()
139 .filter(TypeComplexity::complex_ref)
140 .enumerate()
141 {
142 let name = ident.to_string();
143 let name_with_point = name.clone() + ".";
144 let len = name_with_point.as_bytes().len();
145 let maybe_else = if i == 0 { None } else { Some(quote!(else)) };
146
147 set_child_mapping = quote! {
148 #set_child_mapping
149 #maybe_else if table == #name {
150 self.#ident.set_mapping(column, "", index);
151 } else if table.starts_with(#name_with_point) {
152 self.#ident.set_mapping(column, &table[#len..], index);
153 }
154 };
155 }
156
157 let set_own_mapping = quote! {
158 *match column.org_name() {
159 #(stringify!(#simple_field_names) => &mut self.#simple_field_names,)*
160 #(stringify!(#struct_field_names) => &mut self.#struct_field_names,)*
161 _ => return,
162 } = Some(index);
163 };
164
165 if !fields.iter().any(TypeComplexity::complex) {
166 set_own_mapping
167 } else {
168 quote! {
169 #set_child_mapping
170 else {
171 #set_own_mapping
172 }
173 }
174 }
175 };
176
177 quote! {
178 const _: () = {
179 #[derive(Default)]
180 #visibility struct #mapping_ident {
181 #(#simple_field_names: Option<usize>,)*
182 #(#struct_field_names: Option<usize>,)*
183 #(#complex_field_names: <#complex_field_types as mysql_connector::model::FromQueryResult>::Mapping,)*
184 }
185
186 impl mysql_connector::model::FromQueryResultMapping<#ident> for #mapping_ident {
187 fn set_mapping_inner(&mut self, column: &mysql_connector::types::Column, table: &str, index: usize) {
188 #set_mapping
189 }
190 }
191
192 impl mysql_connector::model::FromQueryResult for #ident {
193 type Mapping = #mapping_ident;
194
195 fn from_mapping_and_row(mapping: &Self::Mapping, row: &mut std::vec::Vec<mysql_connector::types::Value>) -> std::result::Result<Self, mysql_connector::error::ParseError> {
196 Ok(Self {
197 #(#simple_field_names: row[mapping.#simple_field_names.ok_or(mysql_connector::error::ParseError::MissingField(
198 concat!(stringify!(#ident), ".", stringify!(#simple_field_names))
199 ))?].take().try_into()?,)*
200 #set_struct_fields
201 #(#complex_field_names: <#complex_field_types>::from_mapping_and_row(&mapping.#complex_field_names, row)?,)*
202 })
203 }
204 }
205 };
206 }.into()
207}
208
209#[proc_macro_derive(ActiveModel)]
210pub fn derive_active_model(input: TokenStream) -> TokenStream {
211 let mut error = Error::empty();
212 let input = parse_macro_input!(input as DeriveInput);
213
214 let (attr_span, attrs, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
215 let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
216
217 let primary = match attr_span {
218 Some(span) => match attrs.get("primary") {
219 Some(primary) => match attrs.get("auto_increment") {
220 Some(ai) => Some((format_ident!("{primary}"), ai == "true")),
221 None => {
222 error.add(
223 span,
224 "auto_increment needed (#[mysql_connector(auto_increment = \"...\")]",
225 );
226 None
227 }
228 },
229 None => None,
230 },
231 None => None,
232 };
233
234 if let Some(error) = error.error() {
235 return error.into_compile_error().into();
236 }
237
238 let mut insert_struct_fields = proc_macro2::TokenStream::new();
239 for field in &fields {
240 if let TypeComplexity::Struct(r#struct) = &field.complexity {
241 let ident = &field.ident;
242 let idents = r#struct.fields.iter().map(|(x, _)| x);
243 let names = r#struct
244 .fields
245 .iter()
246 .map(|(_, x)| format_ident!("{ident}_{x}"));
247 insert_struct_fields = quote! {
248 #insert_struct_fields
249 match self.#ident {
250 mysql_connector::model::ActiveValue::Unset =>(),
251 mysql_connector::model::ActiveValue::Set(value) => {
252 #(values.push(mysql_connector::model::NamedValue(stringify!(#names), value.#idents.try_into().map_err(Into::<mysql_connector::error::SerializeError>::into)?));)*
253 }
254 }
255 };
256 }
257 }
258
259 let simple_field_names: &Vec<&Ident> = &fields
260 .iter()
261 .filter(TypeComplexity::simple_ref)
262 .map(|x| &x.ident)
263 .collect();
264 let (simple_field_names_without_primary, set_primary) = primary
265 .as_ref()
266 .and_then(|(primary, auto_increment)| {
267 if *auto_increment {
268 let field_names = simple_field_names
269 .iter()
270 .filter(|x| **x != primary)
271 .copied()
272 .collect();
273 let set_primary = quote! {
274 #primary: mysql_connector::model::ActiveValue::Unset,
275 };
276 Some((field_names, set_primary))
277 } else {
278 None
279 }
280 })
281 .unwrap_or_else(|| (simple_field_names.clone(), proc_macro2::TokenStream::new()));
282 let get_primary = match primary {
283 Some((primary, _)) => quote! {
284 match self.#primary {
285 mysql_connector::model::ActiveValue::Set(x) => Some(x.into()),
286 mysql_connector::model::ActiveValue::Unset => None,
287 }
288 },
289 None => quote! {None},
290 };
291
292 let simple_field_types: &Vec<&Type> = &fields
293 .iter()
294 .filter(TypeComplexity::simple_ref)
295 .map(|x| &x.ty)
296 .collect();
297 let struct_field_names: &Vec<&Ident> = &fields
298 .iter()
299 .filter(TypeComplexity::struct_ref)
300 .map(|x| &x.ident)
301 .collect();
302 let struct_field_types: &Vec<&Type> = &fields
303 .iter()
304 .filter(TypeComplexity::struct_ref)
305 .map(|x| &x.ty)
306 .collect();
307 let complex_field_names: &Vec<&Ident> = &fields
308 .iter()
309 .filter(TypeComplexity::complex_ref)
310 .map(|x| &x.ident)
311 .collect();
312 let complex_field_types: &Vec<&Type> = &fields
313 .iter()
314 .filter(TypeComplexity::complex_ref)
315 .map(|x| &x.ty)
316 .collect();
317
318 let ident = &input.ident;
319 let model_ident = format_ident!("{ident}ActiveModel");
320
321 quote! {
322 const _: () = {
323 #[derive(Debug, Default)]
324 pub struct #model_ident {
325 #(pub #simple_field_names: mysql_connector::model::ActiveValue<#simple_field_types>,)*
326 #(pub #struct_field_names: mysql_connector::model::ActiveValue<#struct_field_types>,)*
327 #(pub #complex_field_names: mysql_connector::model::ActiveReference<#complex_field_types>,)*
328 }
329
330 impl mysql_connector::model::ActiveModel<#ident> for #model_ident {
331 async fn into_values(self, conn: &mut mysql_connector::Connection) -> Result<Vec<mysql_connector::model::NamedValue>, mysql_connector::error::Error> {
332 let mut values = Vec::new();
333 #(self.#simple_field_names.insert_named_value(&mut values, stringify!(#simple_field_names))?;)*
334 #insert_struct_fields
335 #(self.#complex_field_names.insert_named_value(&mut values, stringify!(#complex_field_names), conn).await?;)*
336 Ok(values)
337 }
338
339 fn primary(&self) -> Option<mysql_connector::types::Value> {
340 #get_primary
341 }
342 }
343
344 impl mysql_connector::model::HasActiveModel for #ident {
345 type ActiveModel = #model_ident;
346
347 fn into_active_model(self) -> Self::ActiveModel {
348 #model_ident {
349 #set_primary
350 #(#simple_field_names_without_primary: mysql_connector::model::ActiveValue::Set(self.#simple_field_names_without_primary),)*
351 #(#struct_field_names: mysql_connector::model::ActiveValue::Set(self.#struct_field_names),)*
352 #(#complex_field_names: mysql_connector::model::ActiveReference::Insert(<#complex_field_types as mysql_connector::model::HasActiveModel>::into_active_model(self.#complex_field_names)),)*
353 }
354 }
355 }
356 };
357 }.into()
358}
359
360#[proc_macro_derive(IntoQuery)]
361pub fn derive_into_query(input: TokenStream) -> TokenStream {
362 let mut error = Error::empty();
363 let input = parse_macro_input!(input as DeriveInput);
364
365 let (_, _, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
366 let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
367
368 if let Some(error) = error.error() {
369 return error.into_compile_error().into();
370 }
371
372 let mut simple_field_names: Vec<LitStr> = fields
373 .iter()
374 .filter(TypeComplexity::simple_ref)
375 .map(|x| LitStr::new(&x.ident.to_string(), x.ident.span()))
376 .collect();
377 for field in &fields {
378 if let TypeComplexity::Struct(r#struct) = &field.complexity {
379 let mapping_names = r#struct
380 .fields
381 .iter()
382 .map(|x| LitStr::new(&format!("{}_{}", field.ident, x.1), field.ident.span()));
383 simple_field_names.extend(mapping_names);
384 }
385 }
386 let complex_field_names: &Vec<LitStr> = &fields
387 .iter()
388 .filter(TypeComplexity::complex_ref)
389 .map(|x| LitStr::new(&x.ident.to_string(), x.ident.span()))
390 .collect();
391 let complex_field_types: &Vec<&Type> = &fields
392 .iter()
393 .filter(TypeComplexity::complex_ref)
394 .map(|x| &x.ty)
395 .collect();
396
397 let ident = &input.ident;
398
399 quote! {
400 impl mysql_connector::model::IntoQuery for #ident {
401 const COLUMNS: &'static [mysql_connector::model::QueryColumn] = &[
402 #(mysql_connector::model::QueryColumn::Column(#simple_field_names),)*
403 #(mysql_connector::model::QueryColumn::Reference(mysql_connector::model::QueryColumnReference {
404 column: #complex_field_names,
405 table: <#complex_field_types as mysql_connector::model::ModelData>::TABLE,
406 key: <#complex_field_types as mysql_connector::model::Model>::PRIMARY,
407 columns: <#complex_field_types as mysql_connector::model::IntoQuery>::COLUMNS,
408 }),)*
409 ];
410 }
411 }.into()
412}
413
414#[proc_macro_derive(Model)]
415pub fn derive_model(input: TokenStream) -> TokenStream {
416 let mut error = Error::empty();
417 let input = parse_macro_input!(input as DeriveInput);
418
419 let (attr_span, attrs, types) = parse_attr(&mut error, input.ident.span(), &input.attrs);
420 let fields = parse_fields(&mut error, input.ident.span(), &input.data, &types);
421
422 let mut primary_type = None;
423 let mut auto_increment = false;
424 if let Some(span) = attr_span {
425 match attrs.get("primary") {
426 Some(primary) => match fields.iter().find(|field| field.ident == primary) {
427 Some(field) => primary_type = Some(&field.ty),
428 None => error.add(span, "primary not found in struct"),
429 },
430 None => error.add(
431 span,
432 "primary needed (#[mysql_connector(primary = \"...\")]",
433 ),
434 }
435 match attrs.get("auto_increment") {
436 Some(ai) => auto_increment = ai == "true",
437 None => error.add(
438 span,
439 "auto_increment needed (#[mysql_connector(auto_increment = \"...\")]",
440 ),
441 }
442 }
443
444 if let Some(error) = error.error() {
445 return error.into_compile_error().into();
446 }
447
448 let primary = attrs.get("primary").unwrap();
449 let primary_type = primary_type.unwrap();
450 let primary_ident = Ident::new(primary, Span::call_site());
451 let ident = &input.ident;
452
453 quote! {
454 impl mysql_connector::model::Model for #ident {
455 const PRIMARY: &'static str = #primary;
456 const AUTO_INCREMENT: bool = #auto_increment;
457
458 type Primary = #primary_type;
459
460 fn primary(&self) -> Self::Primary {
461 self.#primary_ident
462 }
463 }
464 }
465 .into()
466}