1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned};
3use syn::spanned::Spanned;
4use syn::*;
5
6#[proc_macro_derive(Randomizable, attributes(csta))]
7pub fn derive_randomizable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
8 let input = parse_macro_input!(input as DeriveInput);
9 let name = input.ident;
10
11 let generics = add_trait_bounds(input.generics);
12 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
13 match input.data {
14 Data::Struct(data) => match data.fields {
15 Fields::Named(fields) => {
16 let (let_quotes, field_quotes) = parse_fields_named(&fields);
17 quote! {
18 impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
19 #[allow(unused)]
20 fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
21 #( #let_quotes; )*
22 Self {
23 #( #field_quotes, )*
24 }
25 }
26 }
27 }
28 }
29 Fields::Unnamed(fields) => {
30 let random_fields = parse_fields_unnamed(&fields);
31 quote! {
32 impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
33 fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
34 Self(
35 #( #random_fields, )*
36 )
37 }
38 }
39 }
40 }
41 Fields::Unit => {
42 quote! {
43 impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
44 fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
45 Self
46 }
47 }
48 }
49 }
50 },
51 Data::Enum(data) => {
53 if data.variants.iter().any(enum_has_attribute) {
55 assert!(
58 data.variants.iter().all(enum_has_attribute),
59 "If one variant has the weight attribute, all should.\nHint: add #[csta(weight = 0.1)] to ALL variants"
60 );
61 let probabilities = data.variants.iter().map(|variant| {
68 let enum_attributes = get_parsed_enum_attributes(variant);
69 #[allow(clippy::infallible_destructuring_match)]
70 let weight = match &enum_attributes[0] {
71 CstaEnumAttributes::Weighted(float) => float,
72 };
73
74 quote_spanned! {variant.span()=>
75 #weight
76 }
77 });
78
79 let builders = data.variants.iter().map(|variant| {
80 let iden = &variant.ident;
81 match &variant.fields {
82 Fields::Named(fields) => {
83 let (let_quotes, field_quotes) = parse_fields_named(fields);
84 quote_spanned! {variant.span()=>
85 {
86 #( #let_quotes; )*
87 #name::#iden { #( #field_quotes, )* }
88 }
89 }
90 }
91 Fields::Unnamed(fields) => {
92 let random_fields = parse_fields_unnamed(fields);
93 quote_spanned! {variant.span()=>
94 #name::#iden( #( #random_fields, )* )
95 }
96 }
97 Fields::Unit => {
98 quote_spanned! {variant.span()=>
99 #name::#iden
100 }
101 }
102 }
103 }).collect::<Vec<_>>();
104
105 let default = &builders[0];
106 let probabilities: Vec<_> = probabilities.into_iter().zip(data.variants.iter()).scan(quote!(0.0_f64), |state, (prob, variant)| {
107 let tmp = quote_spanned! {variant.span()=>
108 #state + #prob
109 };
110 *state = tmp;
111 Some(state.clone())
112 }).collect();
113
114 let prob_sum = probabilities.last().unwrap();
115
116 let if_builder_chain = probabilities.iter().zip(builders.iter()).map(|(prob, builder)| {
117 quote_spanned! {prob.span()=>
118 if r < #prob {
119 return #builder;
120 }
121 }
122 });
123
124 quote! {
125 impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
126 #[allow(unused)]
127 fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
128 let total_probability = #prob_sum;
129 if total_probability == 0.0 {
130 return #default;
131 }
132
133 let mut r: f64 = rng.random::<f64>() * total_probability;
134 #( #if_builder_chain )*
135
136 #default
137 }
138 }
139 }
140 } else {
141 let num = data.variants.len();
143 let random_variants = data.variants.iter().enumerate().map(|(i, variant)| {
144 let index = Index::from(i);
145 let iden = &variant.ident;
146
147 match &variant.fields {
148 Fields::Named(fields) => {
149 let (let_quotes, field_quotes) = parse_fields_named(fields);
150 quote_spanned! {variant.span()=>
151 #index => {
152 #( #let_quotes; )*
153 #name::#iden { #( #field_quotes, )* }
154 }
155 }
156 }
157 Fields::Unnamed(fields) => {
158 let random_fields = parse_fields_unnamed(fields);
159 quote_spanned! {variant.span()=>
160 #index => #name::#iden( #( #random_fields, )* )
161 }
162 }
163 Fields::Unit => {
164 quote_spanned! {variant.span()=>
165 #index => #name::#iden
166 }
167 }
168 }
169 });
170 quote! {
171 impl #impl_generics csta::Randomizable for #name #ty_generics #where_clause {
172 #[allow(unused)]
173 fn sample<R: rand::Rng + ?Sized>(rng: &mut R) -> Self {
174 let num = rng.random_range(0..#num);
175 match num {
176 #( #random_variants, )*
177 _ => unreachable!("Number not in range of enum"),
178 }
179 }
180 }
181 }
182 }
183 }
184 Data::Union(_) => unimplemented!(),
185 }
186 .into()
187}
188
189fn add_trait_bounds(mut generics: Generics) -> Generics {
190 for param in &mut generics.params {
191 if let GenericParam::Type(ref mut type_param) = *param {
192 type_param.bounds.push(parse_quote!(csta::Randomizable));
193 }
194 }
195 generics
196}
197
198enum CstaEnumAttributes {
199 Weighted(LitFloat),
200}
201
202fn enum_has_attribute(variant: &Variant) -> bool {
203 let mut csta_attributes = Vec::new();
204 parse_enum_attributes(&variant.attrs, &mut csta_attributes);
205 !csta_attributes.is_empty()
206}
207
208fn get_parsed_enum_attributes(variant: &Variant) -> Vec<CstaEnumAttributes> {
209 let mut csta_attributes = Vec::new();
210 parse_enum_attributes(&variant.attrs, &mut csta_attributes);
211 csta_attributes
212}
213
214fn parse_enum_attributes(
215 attributes: &Vec<Attribute>,
216 csta_attributes: &mut Vec<CstaEnumAttributes>,
217) {
218 for attr in attributes {
219 if attr.path().is_ident("csta") {
220 attr.parse_nested_meta(|meta| {
221 if meta.path.is_ident("weight") {
222 if let Ok(value) = meta.value() {
223 let expr: Expr = value.parse()?;
224 if let Expr::Lit(lit) = expr {
225 if let Lit::Float(float) = lit.lit {
226 csta_attributes.push(CstaEnumAttributes::Weighted(float));
227 } else {
228 return Err(Error::new(attr.span(), "Expected a float number"));
229 }
230 } else {
231 return Err(Error::new(attr.span(), "Expected a float number"));
232 }
233 } else {
234 return Err(Error::new(attr.span(), "Expected a float number"));
235 }
236 }
237 Ok(())
238 })
239 .expect("Failed to parse attribute");
240 }
241 }
242}
243
244fn parse_attribute(attr: &Attribute, csta_attribute: &mut CstaAttributes) {
245 if attr.path().is_ident("csta") {
246 attr.parse_nested_meta(|meta| {
247 if meta.path.is_ident("range") {
248 let content;
249 parenthesized!(content in meta.input);
250 let range: Expr = content.parse()?;
251 if let Expr::Range(range) = range {
252 if range.start.is_none() || range.end.is_none() {
254 return Err(Error::new(
255 range.span(),
256 "Expected range with start and end (either a..b or a..=b)",
257 ));
258 }
259 *csta_attribute = CstaAttributes::Range(range);
260 } else {
261 return Err(Error::new(
262 range.span(),
263 "Expected range (either a..b or a..=b)",
264 ));
265 }
266 }
267 if meta.path.is_ident("len") {
268 let content;
269 parenthesized!(content in meta.input);
270 let expr: Expr = content.parse()?;
271 *csta_attribute = CstaAttributes::Len(expr);
272 }
273 if meta.path.is_ident("after") {
274 let content;
275 parenthesized!(content in meta.input);
276 let expr: Expr = content.parse()?;
277 *csta_attribute = CstaAttributes::After(expr);
278 }
279 if meta.path.is_ident("default") {
280 if let Ok(value) = meta.value() {
281 let iden: TokenStream = value.parse()?;
282 *csta_attribute = CstaAttributes::DefaultWith(iden);
283 } else {
284 *csta_attribute = CstaAttributes::Default;
285 }
286 }
287 if meta.path.is_ident("mul") {
288 let value = meta.value()?;
289 csta_attribute.add_mul(Mul(value.parse()?));
290 }
291 if meta.path.is_ident("div") {
292 let value = meta.value()?;
293 csta_attribute.add_div(Div(value.parse()?));
294 }
295 if meta.path.is_ident("add") {
296 let value = meta.value()?;
297 csta_attribute.add_add(Add(value.parse()?));
298 }
299 if meta.path.is_ident("sub") {
300 let value = meta.value()?;
301 csta_attribute.add_sub(Sub(value.parse()?));
302 }
303 Ok(())
304 })
305 .expect("Failed to parse attribute");
306 }
307}
308
309enum CstaAttributes {
310 UseRandomizable,
311 Range(ExprRange),
312 Len(Expr), After(Expr), Default,
316 DefaultWith(TokenStream),
317 Operation(Option<Mul>, Option<Div>, Option<Add>, Option<Sub>),
318}
319
320impl CstaAttributes {
321 pub fn set_op(&mut self) {
322 if !matches!(self, CstaAttributes::Operation(_, _, _, _)) {
323 *self = CstaAttributes::Operation(None, None, None, None);
324 }
325 }
326
327 pub fn add_mul(&mut self, value: Mul) {
328 self.set_op();
329 if let CstaAttributes::Operation(mul, _, _, _) = self {
330 *mul = Some(value);
331 }
332 }
333
334 pub fn add_div(&mut self, value: Div) {
335 self.set_op();
336 if let CstaAttributes::Operation(_, div, _, _) = self {
337 *div = Some(value);
338 }
339 }
340
341 pub fn add_add(&mut self, value: Add) {
342 self.set_op();
343 if let CstaAttributes::Operation(_, _, add, _) = self {
344 *add = Some(value);
345 }
346 }
347
348 pub fn add_sub(&mut self, value: Sub) {
349 self.set_op();
350 if let CstaAttributes::Operation(_, _, _, sub) = self {
351 *sub = Some(value);
352 }
353 }
354}
355
356struct Mul(TokenStream);
357struct Div(TokenStream);
358struct Add(TokenStream);
359struct Sub(TokenStream);
360
361fn parse_fields_named(fields: &FieldsNamed) -> (Vec<TokenStream>, Vec<TokenStream>) {
366 let mut early_let_quotes = Vec::new();
367 let mut later_let_quotes = Vec::new();
368 let mut last_let_quotes = Vec::new();
369 let mut fields_quotes = Vec::new();
370
371 for field in &fields.named {
372 let mut attribute = CstaAttributes::UseRandomizable; field
374 .attrs
375 .iter()
376 .for_each(|attr| parse_attribute(attr, &mut attribute));
377 let ident = &field.ident;
378 let field_type = &field.ty;
379 let value = apply_attributes(field_type, field.span(), &attribute);
380 match attribute {
381 CstaAttributes::Default => {
382 early_let_quotes.push(quote_spanned! {field.span()=>
384 let #ident: #field_type = #value
385 });
386 }
387 CstaAttributes::DefaultWith(_) => {
388 later_let_quotes.push(quote_spanned! {field.span()=>
390 let #ident: #field_type = #value
391 });
392 }
393 CstaAttributes::After(expr) => {
394 later_let_quotes.push(quote_spanned! {field.span()=>
397 let #ident: #field_type = <#field_type as ::csta::Randomizable>::sample(rng)
398 });
399 last_let_quotes.push(quote_spanned! {field.span()=>
400 let #ident: #field_type = #expr
401 });
402 }
403 _ => {
404 last_let_quotes.push(quote_spanned! {fields.span()=>
406 let #ident: #field_type = #value
407 });
408 }
409 }
410 fields_quotes.push(quote_spanned! {fields.span()=>
412 #ident
413 });
414 }
415 early_let_quotes.append(&mut later_let_quotes);
417 early_let_quotes.append(&mut last_let_quotes);
418 (early_let_quotes, fields_quotes)
419}
420
421fn parse_fields_unnamed(fields: &FieldsUnnamed) -> impl Iterator<Item = TokenStream> + '_ {
422 fields.unnamed.iter().map(|field| {
423 let mut modifier = CstaAttributes::UseRandomizable;
424 field
425 .attrs
426 .iter()
427 .for_each(|attr| parse_attribute(attr, &mut modifier));
428
429 let field_type = &field.ty;
430 apply_attributes(field_type, field.span(), &modifier)
431 })
432}
433
434fn apply_attributes(field_type: &Type, span: Span, modifier: &CstaAttributes) -> TokenStream {
435 match modifier {
436 CstaAttributes::UseRandomizable => quote_spanned! {span=>
437 <#field_type as ::csta::Randomizable>::sample(rng)
438 },
439 CstaAttributes::Range(range) => quote_spanned! {span=>
440 rng.random_range(#range)
441 },
442 CstaAttributes::Default => quote_spanned! {span=>
443 Default::default()
444 },
445 CstaAttributes::DefaultWith(iden) => quote_spanned! {span=>
446 #iden
447 },
448 CstaAttributes::Len(len) => {
449 let generics = extract_vec_inner(field_type);
460 if let Some(inner_type) = generics {
461 quote_spanned! {span=>
462 (0..#len).map(|_| <#inner_type as ::csta::Randomizable>::sample(rng)).collect()
463 }
464 } else {
465 quote_spanned! (span=>)
466 }
467 }
468 CstaAttributes::After(expr) => quote_spanned! {span=>
469 #expr
470 },
471 CstaAttributes::Operation(mul, div, add, sub) => {
472 let mut field = quote_spanned! {span=>
473 #field_type::sample(rng)
474 };
475 if let Some(Mul(mul)) = mul {
476 field.extend(quote_spanned! {span=>
477 * #mul
478 });
479 }
480 if let Some(Div(div)) = div {
481 field.extend(quote_spanned! {span=>
482 / #div
483 });
484 }
485 if let Some(Add(add)) = add {
486 field.extend(quote_spanned! {span=>
487 + #add
488 });
489 }
490 if let Some(Sub(sub)) = sub {
491 field.extend(quote_spanned! {span=>
492 - #sub
493 });
494 }
495 field
496 }
497 }
498}
499
500fn extract_vec_inner(ty: &Type) -> Option<&Type> {
501 if let Type::Path(type_path) = ty
502 && let Some(last_segment) = type_path.path.segments.last()
503 && last_segment.ident == "Vec"
504 && let PathArguments::AngleBracketed(ref generic_args) = last_segment.arguments
505 && let Some(GenericArgument::Type(inner_ty)) = generic_args.args.first()
506 {
507 return Some(inner_ty);
508 }
509 None
510}