samplify_rs/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
4
5#[proc_macro_derive(Sampleable)]
6pub fn sampleable_derive(input: TokenStream) -> TokenStream {
7    // Parse the input tokens into a syntax tree.
8    let input = parse_macro_input!(input as DeriveInput);
9
10    // Get the name of the struct or enum.
11    let name = input.ident.clone();
12
13    // Match on the data type: struct or enum
14    match input.data {
15        Data::Struct(data_struct) => {
16            // Handle structs
17            expand_struct(name, data_struct)
18        },
19        Data::Enum(data_enum) => {
20            // Handle enums
21            expand_enum(name, data_enum)
22        },
23        _ => {
24            unimplemented!("Sampleable can only be derived for structs and enums");
25        }
26    }
27}
28
29fn expand_struct(name: syn::Ident, data_struct: syn::DataStruct) -> TokenStream {
30    // Extract the fields from the struct.
31    let fields = match data_struct.fields {
32        Fields::Named(fields_named) => fields_named.named,
33        _ => unimplemented!("Sampleable can only be derived for structs with named fields"),
34    };
35
36    // Generate code for each field.
37    let field_samples = fields.iter().map(|field| {
38        let field_name = field.ident.as_ref().unwrap();
39        let field_name_str = field_name.to_string();
40        let field_type = &field.ty;
41
42        let sample_code = generate_sample_code(field_type, &field_name_str, &quote!(config));
43
44        quote! {
45            #field_name: #sample_code
46        }
47    });
48
49    // Generate the sample_with_config method.
50    let expanded = quote! {
51        impl #name {
52            pub fn sample_with_config(config: &serde_json::Map<String, serde_json::Value>) -> Result<Self, String> {
53                use rand::Rng;
54                use rand::seq::SliceRandom;
55
56                Ok(Self {
57                    #(#field_samples),*
58                })
59            }
60        }
61    };
62
63    // Return the generated code.
64    TokenStream::from(expanded)
65}
66
67fn expand_enum(name: syn::Ident, data_enum: syn::DataEnum) -> TokenStream {
68    // Get the variants
69    let variants = data_enum.variants;
70
71    // Generate code to randomly select a variant
72    let variant_names = variants.iter().map(|v| v.ident.clone());
73
74    let variant_sample_cases = variants.iter().map(|variant| {
75        let variant_name = &variant.ident;
76        let variant_name_str = variant_name.to_string();
77
78        match &variant.fields {
79            Fields::Unit => {
80                // Unit variant, no fields
81                quote! {
82                    #variant_name_str => {
83                        #name::#variant_name
84                    }
85                }
86            },
87            Fields::Named(fields_named) => {
88                // Struct variant
89                let field_samples = fields_named.named.iter().map(|field| {
90                    let field_name = &field.ident;
91                    let field_name_str = field_name.as_ref().unwrap().to_string();
92                    let field_type = &field.ty;
93
94                    let sample_code = generate_sample_code(field_type, &field_name_str, &quote!(variant_data));
95
96                    quote! {
97                        #field_name: #sample_code
98                    }
99                });
100
101                quote! {
102                    #variant_name_str => {
103                        if let Some(serde_json::Value::Object(variant_data)) = variant_config.get(#variant_name_str) {
104                            #name::#variant_name {
105                                #(#field_samples),*
106                            }
107                        } else {
108                            return Err(format!("Configuration for variant '{}' is missing or invalid", #variant_name_str));
109                        }
110                    }
111                }
112            },
113            Fields::Unnamed(fields_unnamed) => {
114                // Tuple variant
115                let field_samples = fields_unnamed.unnamed.iter().enumerate().map(|(i, field)| {
116                    let field_name_str = format!("field{}", i);
117                    let field_type = &field.ty;
118
119                    let sample_code = generate_sample_code(field_type, &field_name_str, &quote!(variant_data));
120
121                    quote! {
122                        #sample_code
123                    }
124                });
125
126                quote! {
127                    #variant_name_str => {
128                        if let Some(serde_json::Value::Object(variant_data)) = variant_config.get(#variant_name_str) {
129                            #name::#variant_name(
130                                #(#field_samples),*
131                            )
132                        } else {
133                            return Err(format!("Configuration for variant '{}' is missing or invalid", #variant_name_str));
134                        }
135                    }
136                }
137            },
138        }
139    });
140
141    // Generate the sample_with_config method for the enum
142    let expanded = quote! {
143        impl #name {
144            pub fn sample_with_config(config: &serde_json::Map<String, serde_json::Value>) -> Result<Self, String> {
145                use rand::Rng;
146                use rand::seq::SliceRandom;
147
148                // Get the list of allowed variants from the config
149                let variants: Vec<String> = if let Some(serde_json::Value::Array(variant_array)) = config.get("variants") {
150                    variant_array.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()
151                } else {
152                    {
153                        let mut vec = Vec::new();
154                        #(
155                            vec.push(String::from(stringify!(#variant_names)));
156                        )*
157                        vec
158                    }
159                };
160
161                if variants.is_empty() {
162                    return Err("No variants specified for enum sampling".to_string());
163                }
164
165                let selected_variant = variants.choose(&mut rand::thread_rng()).unwrap();
166
167                // Get the 'variant_data' from the config
168                let variant_config = if let Some(serde_json::Value::Object(map)) = config.get("variant_data") {
169                    map
170                } else {
171                    &serde_json::Map::new()
172                };
173
174                let result = match selected_variant.as_str() {
175                    #(#variant_sample_cases),*,
176                    _ => return Err(format!("Variant '{}' is not recognized", selected_variant)),
177                };
178
179                Ok(result)
180            }
181        }
182    };
183
184    TokenStream::from(expanded)
185}
186
187// Helper function to generate sample code based on the field type.
188fn generate_sample_code(field_type: &Type, field_name_str: &str, config_var: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
189    if is_option(field_type) {
190        let inner_type = get_inner_type(field_type);
191        let inner_sample_code = generate_sample_code(&inner_type, field_name_str, config_var);
192
193        quote! {
194            {
195                if let Some(config_value) = #config_var.get(#field_name_str) {
196                    if config_value.is_null() {
197                        None
198                    } else {
199                        Some(#inner_sample_code)
200                    }
201                } else {
202                    None
203                }
204            }
205        }
206    } else if is_vec(field_type) {
207        let inner_type = get_inner_type(field_type);
208        let inner_sample_code = generate_sample_code_for_vec_elements(&inner_type, field_name_str, config_var);
209
210        quote! {
211            {
212                #inner_sample_code
213            }
214        }
215    } else if is_box(field_type) {
216        let inner_type = get_inner_type(field_type);
217        let inner_sample_code = generate_sample_code(&inner_type, field_name_str, config_var);
218        
219        quote! {
220            Box::new(#inner_sample_code)
221        }
222    } else if is_primitive(field_type) {
223        generate_primitive_sample_code(field_type, field_name_str, config_var)
224    } else {
225        // Assume it's a nested struct or enum that implements Sampleable.
226        quote! {
227            {
228                if let Some(serde_json::Value::Object(map)) = #config_var.get(#field_name_str) {
229                    <#field_type>::sample_with_config(map)?
230                } else {
231                    return Err(format!("Configuration for '{}' must be an object", #field_name_str));
232                }
233            }
234        }
235    }
236}
237
238fn generate_sample_code_for_vec_elements(element_type: &Type, field_name_str: &str, config_var: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
239    if is_primitive(element_type) {
240        // For Vec of primitive types, pick random elements
241        let element_type_str = match element_type {
242            Type::Path(type_path) => {
243                type_path.path.segments.last().unwrap().ident.to_string()
244            },
245            _ => "".to_string(),
246        };
247        let parse_value = match element_type_str.as_str() {
248            "String" => quote! {
249                v.as_str().map(|s| s.to_string())
250            },
251            "i32" | "i64" | "u32" | "u64" | "usize" | "isize" => quote! {
252                v.as_i64().map(|n| n as #element_type)
253            },
254            "f32" | "f64" => quote! {
255                v.as_f64().map(|n| n as #element_type)
256            },
257            "bool" => quote! {
258                v.as_bool()
259            },
260            _ => quote! {
261                None
262            },
263        };
264
265        quote! {
266            {
267                if let Some(config_value) = #config_var.get(#field_name_str) {
268                    if let serde_json::Value::Array(values_array) = config_value {
269                        let values: Vec<#element_type> = values_array.iter()
270                            .filter_map(|v| #parse_value)
271                            .collect();
272                        if values.is_empty() {
273                            return Err(format!("Values array for field '{}' is empty or contains invalid types", #field_name_str));
274                        }
275                        let mut rng = rand::thread_rng();
276                        let sample_size = rng.gen_range(1..=values.len());
277                        let samples = values.choose_multiple(&mut rng, sample_size)
278                            .cloned()
279                            .collect::<Vec<#element_type>>();
280                        samples
281                    } else {
282                        return Err(format!("Configuration for '{}' must be an array", #field_name_str));
283                    }
284                } else {
285                    Vec::<#element_type>::new()
286                }
287            }
288        }
289    } else {
290        // For Vec of complex types
291        quote! {
292            {
293                if let Some(config_value) = #config_var.get(#field_name_str) {
294                    if let serde_json::Value::Array(array) = config_value {
295                        let mut vec = Vec::new();
296                        for item in array {
297                            if let serde_json::Value::Object(item_config) = item {
298                                vec.push(<#element_type>::sample_with_config(&item_config)?);
299                            } else {
300                                return Err(format!("Each item in '{}' must be an object", #field_name_str));
301                            }
302                        }
303                        vec
304                    } else {
305                        return Err(format!("Configuration for '{}' must be an array", #field_name_str));
306                    }
307                } else {
308                    Vec::<#element_type>::new()
309                }
310            }
311        }
312    }
313}
314
315// Helper functions to identify types.
316
317fn is_option(ty: &Type) -> bool {
318    match ty {
319        Type::Path(type_path) => type_path.path.segments.last().unwrap().ident == "Option",
320        _ => false,
321    }
322}
323
324fn is_vec(ty: &Type) -> bool {
325    match ty {
326        Type::Path(type_path) => type_path.path.segments.last().unwrap().ident == "Vec",
327        _ => false,
328    }
329}
330
331fn get_inner_type(ty: &Type) -> Type {
332    match ty {
333        Type::Path(type_path) => {
334            if let syn::PathArguments::AngleBracketed(args) = &type_path.path.segments.last().unwrap().arguments {
335                if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() {
336                    inner_type.clone()
337                } else {
338                    panic!("Expected a type argument");
339                }
340            } else {
341                panic!("Expected angle bracketed arguments");
342            }
343        }
344        _ => panic!("Expected a type path"),
345    }
346}
347
348fn is_primitive(ty: &Type) -> bool {
349    match ty {
350        Type::Path(type_path) => {
351            let ident = &type_path.path.segments.last().unwrap().ident;
352            ["f64", "f32", "i32", "i64", "u32", "u64", "usize", "isize", "String", "bool"].contains(&ident.to_string().as_str())
353        }
354        _ => false,
355    }
356}
357
358fn is_box(ty: &Type) -> bool {
359    match ty {
360        Type::Path(type_path) => type_path.path.segments.last().unwrap().ident == "Box",
361        _ => false,
362    }
363}
364
365fn generate_primitive_sample_code(field_type: &Type, field_name_str: &str, config_var: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
366    let type_ident = match field_type {
367        Type::Path(type_path) => &type_path.path.segments.last().unwrap().ident,
368        _ => panic!("Expected a type path"),
369    };
370    let type_ident_str = type_ident.to_string();
371
372    if ["f64", "f32"].contains(&type_ident_str.as_str()) {
373        // Floating-point numbers
374        quote! {
375            {
376                if let Some(config_value) = #config_var.get(#field_name_str) {
377                    if let Some(range_array) = config_value.as_array() {
378                        if range_array.len() == 2 {
379                            if let (Some(start), Some(end)) = (range_array[0].as_f64(), range_array[1].as_f64()) {
380                                rand::thread_rng().gen_range(start..end)
381                            } else {
382                                return Err(format!("Invalid range values for field '{}'", #field_name_str));
383                            }
384                        } else {
385                            return Err(format!("Range array for field '{}' must have exactly two elements", #field_name_str));
386                        }
387                    } else {
388                        return Err(format!("Configuration for field '{}' must be an array", #field_name_str));
389                    }
390                } else {
391                    return Err(format!("Configuration for '{}' is missing", #field_name_str));
392                }
393            }
394        }
395    } else if ["i32", "i64", "u32", "u64", "usize", "isize"].contains(&type_ident_str.as_str()) {
396        // Integer numbers
397        quote! {
398            {
399                if let Some(config_value) = #config_var.get(#field_name_str) {
400                    if let Some(range_array) = config_value.as_array() {
401                        if range_array.len() == 2 {
402                            if let (Some(start), Some(end)) = (range_array[0].as_i64(), range_array[1].as_i64()) {
403                                rand::thread_rng().gen_range(start..end) as #field_type
404                            } else {
405                                return Err(format!("Invalid range values for field '{}'", #field_name_str));
406                            }
407                        } else {
408                            return Err(format!("Range array for field '{}' must have exactly two elements", #field_name_str));
409                        }
410                    } else {
411                        return Err(format!("Configuration for field '{}' must be an array", #field_name_str));
412                    }
413                } else {
414                    return Err(format!("Configuration for '{}' is missing", #field_name_str));
415                }
416            }
417        }
418    } else if type_ident_str == "String" {
419        // Strings
420        quote! {
421            {
422                if let Some(config_value) = #config_var.get(#field_name_str) {
423                    if let Some(values_array) = config_value.as_array() {
424                        let values: Vec<String> = values_array.iter()
425                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
426                            .collect();
427                        if !values.is_empty() {
428                            values.choose(&mut rand::thread_rng()).unwrap().clone()
429                        } else {
430                            return Err(format!("Values array for field '{}' is empty", #field_name_str));
431                        }
432                    } else if let Some(value_str) = config_value.as_str() {
433                        value_str.to_string()
434                    } else {
435                        return Err(format!("Configuration for '{}' must be an array or string", #field_name_str));
436                    }
437                } else {
438                    return Err(format!("Configuration for '{}' is missing", #field_name_str));
439                }
440            }
441        }
442    } else if type_ident_str == "bool" {
443        // Booleans
444        quote! {
445            {
446                if let Some(config_value) = #config_var.get(#field_name_str) {
447                    if let Some(value_bool) = config_value.as_bool() {
448                        value_bool
449                    } else {
450                        return Err(format!("Configuration for '{}' must be a boolean", #field_name_str));
451                    }
452                } else {
453                    return Err(format!("Configuration for '{}' is missing", #field_name_str));
454                }
455            }
456        }
457    } else {
458        // Unsupported primitive type
459        quote! {
460            return Err(format!("Unsupported type for field '{}'", #field_name_str));
461        }
462    }
463}