csta_derive/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned};
3use syn::spanned::Spanned;
4use syn::*;
5
6#[proc_macro_derive(Randomizable, attributes(csta))]
7pub fn derive_randomizable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9    let name = input.ident;
10
11    let generics = add_trait_bounds(input.generics);
12    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
13    match input.data {
14        Data::Struct(data) => match data.fields {
15            Fields::Named(fields) => {
16                let (let_quotes, field_quotes) = parse_fields_named(&fields);
17                quote! {
18                    impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
19                        #[allow(unused)]
20                        fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
21                            #( #let_quotes; )*
22                            Self {
23                                #( #field_quotes, )*
24                            }
25                        }
26                    }
27                }
28            }
29            Fields::Unnamed(fields) => {
30                let random_fields = parse_fields_unnamed(&fields);
31                quote! {
32                    impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
33                        fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
34                            Self(
35                                #( #random_fields, )*
36                            )
37                        }
38                    }
39                }
40            }
41            Fields::Unit => {
42                quote! {
43                    impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
44                        fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
45                            Self
46                        }
47                    }
48                }
49            }
50        },
51        // todo: add weighted probabilities
52        Data::Enum(data) => {
53            // see if they have weighted probabilities
54            if data.variants.iter().any(enum_has_attribute) {
55                // at least one have weighted probabilities.
56                // if one have it, then all must have it as well.
57                assert!(
58                    data.variants.iter().all(enum_has_attribute),
59                    "If one variant has the weight attribute, all should.\nHint: add #[csta(weight = 0.1)] to ALL variants"
60                );
61                // I need, total_prob = SUM(weight)
62                // if total_prob == 0.0, return default or first.
63                // prob = weight / total_prob
64                // r = rng()
65                // if r < prob1 {1} else if r - prob1 < prob2 {2}
66
67                let probabilities = data.variants.iter().map(|variant| {
68                    let enum_attributes = get_parsed_enum_attributes(variant);
69                    #[allow(clippy::infallible_destructuring_match)]
70                    let weight = match &enum_attributes[0] {
71                        CstaEnumAttributes::Weighted(float) => float,
72                    };
73                    
74                    quote_spanned! {variant.span()=>
75                        #weight
76                    }
77                });
78
79                let builders = data.variants.iter().map(|variant| {
80                    let iden = &variant.ident;
81                    match &variant.fields {
82                        Fields::Named(fields) => {
83                            let (let_quotes, field_quotes) = parse_fields_named(fields);
84                            quote_spanned! {variant.span()=>
85                                {
86                                    #( #let_quotes; )*
87                                    #name::#iden { #( #field_quotes, )* }
88                                }
89                            }
90                        }
91                        Fields::Unnamed(fields) => {
92                            let random_fields = parse_fields_unnamed(fields);
93                            quote_spanned! {variant.span()=>
94                                #name::#iden( #( #random_fields, )* )
95                            }
96                        }
97                        Fields::Unit => {
98                            quote_spanned! {variant.span()=>
99                                #name::#iden
100                            }
101                        }
102                    }
103                }).collect::<Vec<_>>();
104
105                let default = &builders[0];
106                let probabilities: Vec<_> = probabilities.into_iter().zip(data.variants.iter()).scan(quote!(0.0_f64), |state, (prob, variant)| {
107                    let tmp = quote_spanned! {variant.span()=>
108                        #state + #prob
109                    };
110                    *state = tmp;
111                    Some(state.clone())
112                }).collect();
113
114                let prob_sum = probabilities.last().unwrap();
115
116                let if_builder_chain = probabilities.iter().zip(builders.iter()).map(|(prob, builder)| {
117                    quote_spanned! {prob.span()=>
118                        if r < #prob {
119                            return #builder;
120                        }
121                    }
122                });
123
124                quote! {
125                    impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
126                        #[allow(unused)]
127                        fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
128                            let total_probability = #prob_sum;
129                            if total_probability == 0.0 {
130                                return #default;
131                            }
132
133                            let mut r: f64 = rng.random::<f64>() * total_probability;
134                            #( #if_builder_chain )*
135
136                            #default
137                        }
138                    }
139                }
140            } else {
141                // if no one have weighted, just use the N-dice approach.
142                let num = data.variants.len();
143                let random_variants = data.variants.iter().enumerate().map(|(i, variant)| {
144                    let index = Index::from(i);
145                    let iden = &variant.ident;
146
147                    match &variant.fields {
148                        Fields::Named(fields) => {
149                            let (let_quotes, field_quotes) = parse_fields_named(fields);
150                            quote_spanned! {variant.span()=>
151                                #index => {
152                                    #( #let_quotes; )*
153                                    #name::#iden { #( #field_quotes, )* }
154                                }
155                            }
156                        }
157                        Fields::Unnamed(fields) => {
158                            let random_fields = parse_fields_unnamed(fields);
159                            quote_spanned! {variant.span()=>
160                                #index => #name::#iden( #( #random_fields, )* )
161                            }
162                        }
163                        Fields::Unit => {
164                            quote_spanned! {variant.span()=>
165                                #index => #name::#iden
166                            }
167                        }
168                    }
169                });
170                quote! {
171                    impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
172                        #[allow(unused)]
173                        fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
174                            let num = rng.random_range(0..#num);
175                            match num {
176                                #( #random_variants, )*
177                                _ => unreachable!("Number not in range of enum"),
178                            }
179                        }
180                    }
181                }
182            }
183        }
184        Data::Union(_) => unimplemented!(),
185    }
186    .into()
187}
188
189fn add_trait_bounds(mut generics: Generics) -> Generics {
190    for param in &mut generics.params {
191        if let GenericParam::Type(ref mut type_param) = *param {
192            type_param.bounds.push(parse_quote!(csta::Randomizable));
193        }
194    }
195    generics
196}
197
198enum CstaEnumAttributes {
199    Weighted(LitFloat),
200}
201
202fn enum_has_attribute(variant: &Variant) -> bool {
203    let mut csta_attributes = Vec::new();
204    parse_enum_attributes(&variant.attrs, &mut csta_attributes);
205    !csta_attributes.is_empty()
206}
207
208fn get_parsed_enum_attributes(variant: &Variant) -> Vec<CstaEnumAttributes> {
209    let mut csta_attributes = Vec::new();
210    parse_enum_attributes(&variant.attrs, &mut csta_attributes);
211    csta_attributes
212}
213
214fn parse_enum_attributes(
215    attributes: &Vec<Attribute>,
216    csta_attributes: &mut Vec<CstaEnumAttributes>,
217) {
218    for attr in attributes {
219        if attr.path().is_ident("csta") {
220            attr.parse_nested_meta(|meta| {
221                if meta.path.is_ident("weight") {
222                    if let Ok(value) = meta.value() {
223                        let expr: Expr = value.parse()?;
224                        if let Expr::Lit(lit) = expr {
225                            if let Lit::Float(float) = lit.lit {
226                                csta_attributes.push(CstaEnumAttributes::Weighted(float));
227                            } else {
228                                return Err(Error::new(attr.span(), "Expected a float number"));
229                            }
230                        } else {
231                            return Err(Error::new(attr.span(), "Expected a float number"));
232                        }
233                    } else {
234                        return Err(Error::new(attr.span(), "Expected a float number"));
235                    }
236                }
237                Ok(())
238            })
239            .expect("Failed to parse attribute");
240        }
241    }
242}
243
244fn parse_attribute(attr: &Attribute, csta_attribute: &mut CstaAttributes) {
245    if attr.path().is_ident("csta") {
246        attr.parse_nested_meta(|meta| {
247            if meta.path.is_ident("range") {
248                let content;
249                parenthesized!(content in meta.input);
250                let range: Expr = content.parse()?;
251                if let Expr::Range(range) = range {
252                    // Check that the range has start and end
253                    if range.start.is_none() || range.end.is_none() {
254                        return Err(Error::new(
255                            range.span(),
256                            "Expected range with start and end (either a..b or a..=b)",
257                        ));
258                    }
259                    *csta_attribute = CstaAttributes::Range(range);
260                } else {
261                    return Err(Error::new(
262                        range.span(),
263                        "Expected range (either a..b or a..=b)",
264                    ));
265                }
266            }
267            if meta.path.is_ident("len") {
268                let content;
269                parenthesized!(content in meta.input);
270                let expr: Expr = content.parse()?;
271                *csta_attribute = CstaAttributes::Len(expr);
272            }
273            if meta.path.is_ident("after") {
274                let content;
275                parenthesized!(content in meta.input);
276                let expr: Expr = content.parse()?;
277                *csta_attribute = CstaAttributes::After(expr);
278            }
279            if meta.path.is_ident("default") {
280                if let Ok(value) = meta.value() {
281                    let iden: TokenStream = value.parse()?;
282                    *csta_attribute = CstaAttributes::DefaultWith(iden);
283                } else {
284                    *csta_attribute = CstaAttributes::Default;
285                }
286            }
287            if meta.path.is_ident("mul") {
288                let value = meta.value()?;
289                csta_attribute.add_mul(Mul(value.parse()?));
290            }
291            if meta.path.is_ident("div") {
292                let value = meta.value()?;
293                csta_attribute.add_div(Div(value.parse()?));
294            }
295            if meta.path.is_ident("add") {
296                let value = meta.value()?;
297                csta_attribute.add_add(Add(value.parse()?));
298            }
299            if meta.path.is_ident("sub") {
300                let value = meta.value()?;
301                csta_attribute.add_sub(Sub(value.parse()?));
302            }
303            Ok(())
304        })
305        .expect("Failed to parse attribute");
306    }
307}
308
309enum CstaAttributes {
310    UseRandomizable,
311    Range(ExprRange),
312    Len(Expr), // used in Vec<T>
313    // TODO: Probability(TokenStream), // used in Option<T>
314    After(Expr), // for manipulations after being created via randomizable
315    Default,
316    DefaultWith(TokenStream),
317    Operation(Option<Mul>, Option<Div>, Option<Add>, Option<Sub>),
318}
319
320impl CstaAttributes {
321    pub fn set_op(&mut self) {
322        if !matches!(self, CstaAttributes::Operation(_, _, _, _)) {
323            *self = CstaAttributes::Operation(None, None, None, None);
324        }
325    }
326
327    pub fn add_mul(&mut self, value: Mul) {
328        self.set_op();
329        if let CstaAttributes::Operation(mul, _, _, _) = self {
330            *mul = Some(value);
331        }
332    }
333
334    pub fn add_div(&mut self, value: Div) {
335        self.set_op();
336        if let CstaAttributes::Operation(_, div, _, _) = self {
337            *div = Some(value);
338        }
339    }
340
341    pub fn add_add(&mut self, value: Add) {
342        self.set_op();
343        if let CstaAttributes::Operation(_, _, add, _) = self {
344            *add = Some(value);
345        }
346    }
347
348    pub fn add_sub(&mut self, value: Sub) {
349        self.set_op();
350        if let CstaAttributes::Operation(_, _, _, sub) = self {
351            *sub = Some(value);
352        }
353    }
354}
355
356struct Mul(TokenStream);
357struct Div(TokenStream);
358struct Add(TokenStream);
359struct Sub(TokenStream);
360
361// the different thing between named and unnamed is that named fields should be able to be used on other attributes
362// like #[csta(len(w*h))], with w, h being other fields.
363// in unnamed fields is imposible to do that, so its a different bussiness
364/// returns (let_quotes, fields_quotes), in that order
365fn parse_fields_named(fields: &FieldsNamed) -> (Vec<TokenStream>, Vec<TokenStream>) {
366    let mut early_let_quotes = Vec::new();
367    let mut later_let_quotes = Vec::new();
368    let mut last_let_quotes = Vec::new();
369    let mut fields_quotes = Vec::new();
370
371    for field in &fields.named {
372        let mut attribute = CstaAttributes::UseRandomizable; // w/o attributes, use randomizable as default
373        field
374            .attrs
375            .iter()
376            .for_each(|attr| parse_attribute(attr, &mut attribute));
377        let ident = &field.ident;
378        let field_type = &field.ty;
379        let value = apply_attributes(field_type, field.span(), &attribute);
380        match attribute {
381            CstaAttributes::Default => {
382                // Default::default() will get the earlier priority.
383                early_let_quotes.push(quote_spanned! {field.span()=>
384                    let #ident: #field_type = #value
385                });
386            }
387            CstaAttributes::DefaultWith(_) => {
388                // These will get the second priority, so that they can use default fields
389                later_let_quotes.push(quote_spanned! {field.span()=>
390                    let #ident: #field_type = #value
391                });
392            }
393            CstaAttributes::After(expr) => {
394                // after only works on named for now.
395                // it creates a let = T::sample, and then a let = expr;
396                later_let_quotes.push(quote_spanned! {field.span()=>
397                    let #ident: #field_type = <#field_type as ::csta::Randomizable>::sample(rng)
398                });
399                last_let_quotes.push(quote_spanned! {field.span()=>
400                    let #ident: #field_type = #expr
401                });
402            }
403            _ => {
404                // These are last prio, maybe they are in order so their prio is in written order
405                last_let_quotes.push(quote_spanned! {fields.span()=>
406                    let #ident: #field_type = #value
407                });
408            }
409        }
410        // because everything is a let w = #value, Self { w } is used.
411        fields_quotes.push(quote_spanned! {fields.span()=>
412            #ident
413        });
414    }
415    // now we merge let quotes and return everything in correct order.
416    early_let_quotes.append(&mut later_let_quotes);
417    early_let_quotes.append(&mut last_let_quotes);
418    (early_let_quotes, fields_quotes)
419}
420
421fn parse_fields_unnamed(fields: &FieldsUnnamed) -> impl Iterator<Item = TokenStream> + '_ {
422    fields.unnamed.iter().map(|field| {
423        let mut modifier = CstaAttributes::UseRandomizable;
424        field
425            .attrs
426            .iter()
427            .for_each(|attr| parse_attribute(attr, &mut modifier));
428
429        let field_type = &field.ty;
430        apply_attributes(field_type, field.span(), &modifier)
431    })
432}
433
434fn apply_attributes(field_type: &Type, span: Span, modifier: &CstaAttributes) -> TokenStream {
435    match modifier {
436        CstaAttributes::UseRandomizable => quote_spanned! {span=>
437            <#field_type as ::csta::Randomizable>::sample(rng)
438        },
439        CstaAttributes::Range(range) => quote_spanned! {span=>
440            rng.random_range(#range)
441        },
442        CstaAttributes::Default => quote_spanned! {span=>
443            Default::default()
444        },
445        CstaAttributes::DefaultWith(iden) => quote_spanned! {span=>
446            #iden
447        },
448        CstaAttributes::Len(len) => {
449            // if is_vec(field_type) {
450            //     let generics = extract_vec_inner(field_type);
451            //     if let Some(inner_type) = generics {
452            //         apply_modifier(inner_type, span, modifier)
453            //     } else {
454            //         panic!("Vec needs to have generics (Vec<T>)");
455            //     }
456            // } else {
457            //     quote_spanned! (span=>)
458            // }
459            let generics = extract_vec_inner(field_type);
460            if let Some(inner_type) = generics {
461                quote_spanned! {span=>
462                    (0..#len).map(|_| <#inner_type as ::csta::Randomizable>::sample(rng)).collect()
463                }
464            } else {
465                quote_spanned! (span=>)
466            }
467        }
468        CstaAttributes::After(expr) => quote_spanned! {span=>
469            #expr
470        },
471        CstaAttributes::Operation(mul, div, add, sub) => {
472            let mut field = quote_spanned! {span=>
473                #field_type::sample(rng)
474            };
475            if let Some(Mul(mul)) = mul {
476                field.extend(quote_spanned! {span=>
477                    * #mul
478                });
479            }
480            if let Some(Div(div)) = div {
481                field.extend(quote_spanned! {span=>
482                    / #div
483                });
484            }
485            if let Some(Add(add)) = add {
486                field.extend(quote_spanned! {span=>
487                    + #add
488                });
489            }
490            if let Some(Sub(sub)) = sub {
491                field.extend(quote_spanned! {span=>
492                    - #sub
493                });
494            }
495            field
496        }
497    }
498}
499
500fn extract_vec_inner(ty: &Type) -> Option<&Type> {
501    if let Type::Path(type_path) = ty
502        && let Some(last_segment) = type_path.path.segments.last()
503        && last_segment.ident == "Vec"
504        && let PathArguments::AngleBracketed(ref generic_args) = last_segment.arguments
505        && let Some(GenericArgument::Type(inner_ty)) = generic_args.args.first()
506    {
507        return Some(inner_ty);
508    }
509    None
510}