rizzle_macros/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{quote, quote_spanned, ToTokens};
6use syn::{
7    parse::Parse, parse_macro_input, spanned::Spanned, Data, DeriveInput, Expr, ExprAssign,
8    ExprLit, ExprPath, Field, Ident, Lit, LitStr, PathSegment, Result, Type, TypePath,
9};
10
11#[proc_macro_derive(Table, attributes(rizzle))]
12pub fn table(s: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(s as DeriveInput);
14    match table_macro(input) {
15        Ok(s) => s.to_token_stream().into(),
16        Err(e) => e.to_compile_error().into(),
17    }
18}
19
20enum Rel {
21    One(LitStr),
22    Many(LitStr),
23}
24
25#[derive(Default)]
26struct RizzleAttr {
27    table_name: Option<LitStr>,
28    primary_key: bool,
29    not_null: bool,
30    default_value: Option<LitStr>,
31    columns: Option<LitStr>,
32    references: Option<LitStr>,
33    from: Option<LitStr>,
34    to: Option<LitStr>,
35    rel: Option<Rel>,
36}
37
38impl Parse for RizzleAttr {
39    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
40        let mut rizzle_attr = RizzleAttr::default();
41        let args_parsed =
42            syn::punctuated::Punctuated::<Expr, syn::Token![,]>::parse_terminated(input)?;
43        for expr in args_parsed.iter() {
44            match expr {
45                Expr::Assign(ExprAssign { left, right, .. }) => match (&**left, &**right) {
46                    (Expr::Path(ExprPath { path, .. }), Expr::Lit(ExprLit { lit, .. })) => {
47                        if let (Some(PathSegment { ident, .. }), Lit::Str(lit_str)) =
48                            (path.segments.last(), lit)
49                        {
50                            match ident.to_string().as_ref() {
51                                "table" => {
52                                    rizzle_attr.table_name = Some(lit_str.clone());
53                                }
54                                "r#default" => {
55                                    rizzle_attr.default_value = Some(lit_str.clone());
56                                }
57                                "columns" => {
58                                    rizzle_attr.columns = Some(lit_str.clone());
59                                }
60                                "references" => {
61                                    rizzle_attr.references = Some(lit_str.clone());
62                                }
63                                "many" => {
64                                    rizzle_attr.rel = Some(Rel::Many(lit_str.clone()));
65                                }
66                                "from" => {
67                                    rizzle_attr.from = Some(lit_str.clone());
68                                }
69                                "to" => {
70                                    rizzle_attr.to = Some(lit_str.clone());
71                                }
72                                "one" => {
73                                    rizzle_attr.rel = Some(Rel::One(lit_str.clone()));
74                                }
75                                _ => unimplemented!(),
76                            }
77                        }
78                    }
79                    _ => unimplemented!(),
80                },
81                Expr::Path(path) => match path.path.segments.len() {
82                    1 => match path
83                        .path
84                        .segments
85                        .first()
86                        .unwrap()
87                        .ident
88                        .to_string()
89                        .as_ref()
90                    {
91                        "not_null" => rizzle_attr.not_null = true,
92                        "primary_key" => rizzle_attr.primary_key = true,
93                        _ => {}
94                    },
95                    _ => {}
96                },
97                _ => {}
98            }
99        }
100
101        Ok(rizzle_attr)
102    }
103}
104
105struct RizzleField {
106    ident_name: String,
107    ident: Ident,
108    field: Field,
109    attrs: Vec<RizzleAttr>,
110    type_string: String,
111}
112
113fn table_macro(input: DeriveInput) -> Result<TokenStream2> {
114    let table_str = input
115        .attrs
116        .iter()
117        .filter_map(|attr| attr.parse_args::<RizzleAttr>().ok())
118        .last()
119        .expect("define #![rizzle(table = \"your table name here\")] on struct")
120        .table_name
121        .unwrap();
122    let struct_name = input.ident;
123    let table_name = table_str.value();
124    let rizzle_fields = match input.data {
125        syn::Data::Struct(ref data) => data
126            .fields
127            .iter()
128            .map(|field| {
129                let ident = field
130                    .ident
131                    .as_ref()
132                    .expect("Struct fields should have names");
133                RizzleField {
134                    ident: ident.clone(),
135                    ident_name: ident.to_string(),
136                    field: field.clone(),
137                    attrs: field
138                        .attrs
139                        .iter()
140                        .filter_map(|attr| attr.parse_args::<RizzleAttr>().ok())
141                        .collect::<Vec<_>>(),
142                    type_string: string_type_from_field(&field),
143                }
144            })
145            .collect::<Vec<_>>(),
146        _ => unimplemented!(),
147    };
148    let columns = columns(&table_name, &rizzle_fields);
149    let attrs = struct_attrs(&table_name, &rizzle_fields);
150    let indexes = indexes(&table_name, &rizzle_fields);
151    let references = references(&table_name, &rizzle_fields);
152
153    Ok(quote! {
154        impl Table for #struct_name {
155            fn new() -> Self {
156                Self { #(#attrs,)* }
157            }
158
159            fn name(&self) -> String {
160                String::from(#table_str)
161            }
162
163            fn columns(&self) -> Vec<Column> {
164                vec![#(#columns,)*]
165            }
166
167            fn indexes(&self) -> Vec<Index> {
168                vec![#(#indexes,)*]
169            }
170
171            fn references(&self) -> Vec<Reference> {
172                vec![#(#references,)*]
173            }
174
175            fn create_sql(&self) -> String {
176                let columns_sql = self.columns()
177                    .iter()
178                    .map(|c| c.definition_sql())
179                    .collect::<Vec<_>>()
180                    .join(", ");
181                format!("create table {} ({})", self.name(), columns_sql)
182            }
183        }
184    })
185}
186
187fn references(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
188    fields
189        .iter()
190        .filter(|field| match field.type_string.as_str() {
191            "Real" | "Integer" | "Text" | "Blob" | "Many" => true,
192            _ => false,
193        })
194        .filter(|field| match field.attrs.last() {
195            Some(attr) => attr.references.is_some(),
196            None => false,
197        })
198        .map(|field| {
199            let RizzleAttr { references, .. } = field.attrs.last().unwrap();
200            let many = field.type_string == "Many";
201            quote! {
202                Reference {
203                    table: #table_name.to_owned(),
204                    clause: #references.to_owned(),
205                    many: #many,
206                    ..Default::default()
207                }
208            }
209        })
210        .collect()
211}
212
213fn string_type_from_field(field: &Field) -> String {
214    match &field.ty {
215        syn::Type::Path(TypePath { path, .. }) => match path.segments.last() {
216            Some(PathSegment { ident, .. }) => ident.to_string(),
217            None => unimplemented!(),
218        },
219        _ => unimplemented!(),
220    }
221}
222
223fn struct_attrs(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
224    fields
225        .iter()
226        .map(|f| {
227            let ident = &f
228                .field
229                .ident
230                .as_ref()
231                .expect("Struct fields should have names");
232            let value = format!("{}.{}", table_name, ident.to_string());
233            quote! {
234                #ident: #value
235            }
236        })
237        .collect::<Vec<_>>()
238}
239
240fn data_type(string_type: &String) -> TokenStream2 {
241    match string_type.as_str() {
242        "Real" => quote! { sqlite::DataType::Real },
243        "Integer" => quote! { sqlite::DataType::Integer },
244        "Text" => quote! { sqlite::DataType::Text },
245        _ => quote! { sqlite::DataType::Blob },
246    }
247}
248
249fn columns(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
250    fields
251        .iter()
252        .filter(|field| match field.type_string.as_ref() {
253            "Real" | "Integer" | "Text" | "Blob" => true,
254            _ => false,
255        })
256        .map(|field| {
257            let ty = data_type(&field.type_string);
258            let ident = &field.ident_name;
259            if let Some(RizzleAttr {
260                primary_key,
261                not_null,
262                default_value,
263                references,
264                ..
265            }) = field.attrs.last()
266            {
267                let default_value = match default_value {
268                    Some(default) => quote! { Some(#default.to_owned()) },
269                    None => quote! { None },
270                };
271                let references = match references {
272                    Some(references) => quote! { Some(#references.to_owned()) },
273                    None => quote! { None },
274                };
275                quote! {
276                    Column {
277                        table_name: #table_name.to_string(),
278                        name: #ident.to_string(),
279                        data_type: #ty,
280                        primary_key: #primary_key,
281                        not_null: #not_null,
282                        default_value: #default_value,
283                        references: #references,
284                        ..Default::default()
285                    }
286                }
287            } else {
288                quote! {
289                    Column {
290                        table_name: #table_name.to_string(),
291                        name: #ident.to_string(),
292                        data_type: #ty,
293                        ..Default::default()
294                    }
295                }
296            }
297        })
298        .collect::<Vec<_>>()
299}
300
301fn indexes(table_name: &String, fields: &Vec<RizzleField>) -> Vec<TokenStream2> {
302    let column_names = fields
303        .iter()
304        .filter(|f| match f.type_string.as_ref() {
305            "Text" | "Integer" | "Real" | "Blob" => true,
306            _ => false,
307        })
308        .map(|f| f.ident_name.clone())
309        .collect::<HashSet<_>>();
310    fields
311        .iter()
312        .filter(|f| match f.type_string.as_ref() {
313            "Index" | "UniqueIndex" => true,
314            _ => false,
315        })
316        .filter(|f| match f.attrs.last() {
317            Some(attr) => attr.columns.is_some(),
318            None => false,
319        })
320        .map(|f| {
321            let name = &f.ident_name;
322            let attr = f.attrs.last().unwrap();
323            let RizzleAttr { columns, .. } = attr;
324            let attr_column_names = match columns {
325                Some(lit_str) => lit_str.value(),
326                None => String::default(),
327            };
328            let names = attr_column_names
329                .split(",")
330                .map(|x| x.to_owned())
331                .collect::<HashSet<_>>();
332            let diff = &names.difference(&column_names).collect::<Vec<_>>();
333            let column_names_list = column_names
334                .iter()
335                .map(|x| format!("- {}", x))
336                .collect::<Vec<_>>()
337                .join("\n");
338            let compiler_error = format!(
339                "index {:?} on {:?} in table {:?} which only declares \n{}",
340                name, diff, table_name, column_names_list
341            );
342            if diff.len() != 0 {
343                quote_spanned! {
344                    columns.span() => compile_error!(#compiler_error)
345                }
346            } else {
347                let column_names = match columns {
348                    Some(lit_str) => quote! { #lit_str.to_string() },
349                    None => quote! { "".to_string() },
350                };
351                let ty = match f.type_string.as_ref() {
352                    "Index" => quote! { sqlite::IndexType::Plain },
353                    "UniqueIndex" => quote! { sqlite::IndexType::Unique },
354                    _ => unimplemented!(),
355                };
356                quote! {
357                    Index {
358                        table_name: #table_name.to_string(),
359                        name: #name.to_string(),
360                        index_type: #ty,
361                        column_names: #column_names
362                    }
363                }
364            }
365        })
366        .collect::<Vec<_>>()
367}
368
369#[proc_macro_derive(Select, attributes(rizzle))]
370pub fn select(s: TokenStream) -> TokenStream {
371    let input = parse_macro_input!(s as DeriveInput);
372    match select_macro(input) {
373        Ok(s) => s.to_token_stream().into(),
374        Err(e) => e.to_compile_error().into(),
375    }
376}
377
378fn select_macro(input: DeriveInput) -> Result<TokenStream2> {
379    let struct_name = input.ident;
380    let rizzle_fields = rizzle_fields(&input.data);
381    let column_names = rizzle_fields
382        .iter()
383        .filter(|RizzleField { attrs, .. }| attrs.is_empty())
384        .map(|RizzleField { ident, .. }| ident.to_string())
385        .collect::<Vec<_>>()
386        .join(", ");
387    let fk_column_streams = rizzle_fields
388        .iter()
389        .filter(|RizzleField { attrs, .. }| !attrs.is_empty())
390        .map(|RizzleField { attrs, field,  .. }| {
391            let ty = last_segment_from_type(&field.ty).expect("");
392            let attr = attrs.last().unwrap();
393            let ty_name = &ty.to_string();
394            match &attr.rel {
395                Some(Rel::One(table_name)) => {
396                    quote! {
397                        #ty::column_names_sql().split(",").map(|col| format!("{}.{} as '{}_{}'", #table_name, col.trim(), #ty_name, col.trim())).collect::<Vec<_>>().join(", ")
398                    }
399                },
400                Some(Rel::Many(_)) => todo!(),
401                None => todo!(),
402            }
403        })
404        .collect::<Vec<_>>();
405    let sets = &rizzle_fields
406        .iter()
407        .map(
408            |RizzleField {
409                 ident,
410                 attrs,
411                 field,
412                 ..
413             }| {
414                if let (Some(ty_ident), Some(attr)) =
415                    (last_segment_from_type(&field.ty), attrs.last())
416                {
417                    match &attr.rel {
418                        Some(Rel::One(_)) => quote! {
419                            #ident: #ty_ident::from_row(row)?
420                        },
421                        Some(Rel::Many(_)) => todo!(),
422                        _ => todo!(),
423                    }
424                } else {
425                    let lit_str = ident.to_string();
426                    let struct_name_string = struct_name.to_string();
427                    let fk_name = format!("{}_{}", struct_name_string, lit_str);
428                    quote! {
429                        #ident: match row.try_get(#lit_str) {
430                            Ok(val) => val,
431                            Err(_) => row.try_get(#fk_name)?
432                        }
433                    }
434                }
435            },
436        )
437        .collect::<Vec<_>>();
438    Ok(quote! {
439        impl Select for #struct_name {
440            fn column_names_sql() -> String {
441                let prefixed_vec: Vec<String> = vec![#(#fk_column_streams,)*];
442                let prefixed: String = prefixed_vec.join(", ");
443                if prefixed.is_empty() {
444                    format!("{}", #column_names)
445                } else {
446                    format!("{}, {}", #column_names, prefixed)
447                }
448            }
449
450            fn columns_sql(&self) -> String {
451                Self::column_names_sql()
452            }
453        }
454
455        impl<'r> FromRow<'r, sqlite::SqliteRow> for #struct_name {
456            fn from_row(row: &'r sqlite::SqliteRow) -> Result<Self, SqlxError> {
457                Ok(#struct_name {
458                    #(#sets,)*
459                })
460            }
461        }
462    })
463}
464
465#[proc_macro_derive(Insert, attributes(rizzle))]
466pub fn insert(s: TokenStream) -> TokenStream {
467    let input = parse_macro_input!(s as DeriveInput);
468    match insert_macro(input) {
469        Ok(s) => s.to_token_stream().into(),
470        Err(e) => e.to_compile_error().into(),
471    }
472}
473
474fn insert_macro(input: DeriveInput) -> Result<TokenStream2> {
475    let struct_name = input.ident;
476    let fields = match input.data {
477        syn::Data::Struct(ref data) => data
478            .fields
479            .iter()
480            .filter(|field| field.attrs.is_empty())
481            .map(|field| {
482                let ident = field
483                    .ident
484                    .as_ref()
485                    .expect("Struct fields should have names");
486                let ty = &field.ty;
487                (ident, ty)
488            })
489            .collect::<Vec<_>>(),
490        _ => unimplemented!(),
491    };
492    let column_names = fields
493        .iter()
494        .map(|(ident, _)| ident.to_string())
495        .collect::<Vec<_>>()
496        .join(", ");
497    let sql_placeholders = &fields.iter().map(|_| "?").collect::<Vec<_>>();
498    let placeholders = &fields.iter().map(|_| "{}").collect::<Vec<_>>().join(", ");
499    let data_values = &fields
500        .iter()
501        .map(|(ident, _)| quote! { self.#ident.clone().into() })
502        .collect::<Vec<_>>();
503    Ok(quote! {
504        impl Insert for #struct_name {
505            fn insert_values(&self) -> Vec<DataValue> {
506                vec![
507                    #(#data_values,)*
508                ]
509            }
510
511            fn insert_sql(&self) -> String {
512                let values_sql = format!(#placeholders, #(#sql_placeholders,)*);
513                format!("({}) values ({})", #column_names, values_sql)
514            }
515        }
516    })
517}
518
519#[proc_macro_derive(New, attributes(rizzle))]
520pub fn new(s: TokenStream) -> TokenStream {
521    let input = parse_macro_input!(s as DeriveInput);
522    match new_macro(input) {
523        Ok(s) => s.to_token_stream().into(),
524        Err(e) => e.to_compile_error().into(),
525    }
526}
527
528fn new_macro(input: DeriveInput) -> Result<TokenStream2> {
529    let fields = match input.data {
530        syn::Data::Struct(ref data) => data
531            .fields
532            .iter()
533            .map(|field| {
534                let ident = field
535                    .ident
536                    .as_ref()
537                    .expect("Struct fields should have names");
538                let ty = &field.ty;
539                (ident, ty)
540            })
541            .collect::<Vec<_>>(),
542        _ => unimplemented!(),
543    };
544    let attrs = &fields
545        .iter()
546        .map(|(ident, ty)| {
547            quote! {
548                #ident: #ty::default()
549            }
550        })
551        .collect::<Vec<_>>();
552    let struct_name = input.ident;
553    Ok(quote! {
554        impl New for #struct_name {
555           fn new() -> Self {
556               Self { #(#attrs,)* }
557           }
558        }
559    })
560}
561
562#[proc_macro_derive(Update, attributes(rizzle))]
563pub fn update(s: TokenStream) -> TokenStream {
564    let input = parse_macro_input!(s as DeriveInput);
565    match update_macro(input) {
566        Ok(s) => s.to_token_stream().into(),
567        Err(e) => e.to_compile_error().into(),
568    }
569}
570
571fn update_macro(input: DeriveInput) -> Result<TokenStream2> {
572    let struct_name = input.ident;
573    let fields = match input.data {
574        syn::Data::Struct(ref data) => data
575            .fields
576            .iter()
577            .filter(|field| field.attrs.is_empty())
578            .map(|field| {
579                let ident = field
580                    .ident
581                    .as_ref()
582                    .expect("Struct fields should have names");
583                let ty = &field.ty;
584                (ident, ty)
585            })
586            .collect::<Vec<_>>(),
587        _ => unimplemented!(),
588    };
589    let placeholders = &fields
590        .iter()
591        .map(|(ident, _)| format!("{} = ?", ident))
592        .collect::<Vec<_>>()
593        .join(", ");
594    let data_values = &fields
595        .iter()
596        .map(|(ident, _)| quote! { self.#ident.clone().into() })
597        .collect::<Vec<_>>();
598    Ok(quote! {
599        impl Update for #struct_name {
600            fn update_values(&self) -> Vec<DataValue> {
601                vec![
602                    #(#data_values,)*
603                ]
604            }
605
606            fn update_sql(&self) -> String {
607                format!("set {}", #placeholders)
608            }
609        }
610    })
611}
612
613#[proc_macro_derive(Row, attributes(rizzle))]
614pub fn row(s: TokenStream) -> TokenStream {
615    let input = parse_macro_input!(s as DeriveInput);
616    match row_macro(input) {
617        Ok(s) => s.to_token_stream().into(),
618        Err(e) => e.to_compile_error().into(),
619    }
620}
621
622fn row_macro(input: DeriveInput) -> Result<TokenStream2> {
623    let insert_token_stream = insert_macro(input.clone())?;
624    let update_token_stream = update_macro(input.clone())?;
625    let select_token_stream = select_macro(input.clone())?;
626    Ok(quote! {
627        #insert_token_stream
628        #update_token_stream
629        #select_token_stream
630    })
631}
632
633fn rizzle_fields(data: &Data) -> Vec<RizzleField> {
634    match data {
635        syn::Data::Struct(ref data) => data
636            .fields
637            .iter()
638            .map(|field| {
639                let ident = field
640                    .ident
641                    .as_ref()
642                    .expect("Struct fields should have names");
643                RizzleField {
644                    ident: ident.clone(),
645                    ident_name: ident.to_string(),
646                    field: field.clone(),
647                    attrs: field
648                        .attrs
649                        .iter()
650                        .filter_map(|attr| attr.parse_args::<RizzleAttr>().ok())
651                        .collect::<Vec<_>>(),
652                    type_string: string_type_from_field(&field),
653                }
654            })
655            .collect::<Vec<_>>(),
656        _ => unimplemented!(),
657    }
658}
659
660fn last_segment_from_type(ty: &Type) -> Option<Ident> {
661    match ty {
662        Type::Path(TypePath { path, .. }) => Some(path.segments.last()?.ident.clone()),
663        _ => None,
664    }
665}
666
667#[proc_macro_derive(RizzleSchema, attributes(rizzle))]
668pub fn rizzle_schema(s: TokenStream) -> TokenStream {
669    let input = parse_macro_input!(s as DeriveInput);
670    match rizzle_schema_macro(input) {
671        Ok(s) => s.to_token_stream().into(),
672        Err(e) => e.to_compile_error().into(),
673    }
674}
675
676fn rizzle_schema_macro(input: DeriveInput) -> Result<TokenStream2> {
677    let struct_name = input.ident;
678    let struct_fields = rizzle_fields(&input.data);
679    let new_fields = struct_fields
680        .iter()
681        .map(|field| {
682            let ident = field
683                .field
684                .ident
685                .as_ref()
686                .expect("Struct fields should have ");
687            let ty = &field.field.ty;
688            quote! { #ident: #ty::new() }
689        })
690        .collect::<Vec<_>>();
691    let tables = struct_fields
692        .iter()
693        .map(|field| {
694            let ident = &field.ident;
695            quote! { &self.#ident }
696        })
697        .collect::<Vec<_>>();
698    Ok(quote! {
699        impl RizzleSchema for #struct_name {
700            fn new() -> Self {
701                Self { #(#new_fields,)* }
702            }
703
704            fn sql(&self) -> String {
705                "".to_owned()
706            }
707
708            fn tables<'a>(&'a self) -> Vec<&'a dyn Table> {
709                vec![#(#tables,)*]
710            }
711        }
712    })
713}