samplify_rs/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields};

#[proc_macro_derive(Sampleable)]
pub fn sampleable_derive(input: TokenStream) -> TokenStream {
    // Parse the input tokens into a syntax tree
    let input = parse_macro_input!(input as DeriveInput);

    // Get the struct name
    let struct_name = input.ident.clone();

    // Match only on structs with named fields
    let fields = if let Data::Struct(data_struct) = input.data {
        if let Fields::Named(fields_named) = data_struct.fields {
            fields_named.named
        } else {
            unimplemented!("Sampleable can only be derived for structs with named fields");
        }
    } else {
        unimplemented!("Sampleable can only be derived for structs");
    };

    // Generate sample code for each field
    let mut sample_fields = Vec::new();

    for field in fields.iter() {
        let field_name = &field.ident;
        let field_name_str = field_name.as_ref().unwrap().to_string();
        let field_type = &field.ty;

        // Generate sample code based on the field type
        let sample_code = match field_type {
            syn::Type::Path(type_path) => {
                let type_ident = &type_path.path.segments.last().unwrap().ident;
                let type_ident_str = type_ident.to_string();

                if ["f64", "f32"].contains(&type_ident_str.as_str()) {
                    // For floating-point numbers
                    quote! {
                        {
                            if let Some(config_value) = config.get(#field_name_str) {
                                if let Some(range_array) = config_value.as_array() {
                                    if range_array.len() == 2 {
                                        if let (Some(start), Some(end)) = (range_array[0].as_f64(), range_array[1].as_f64()) {
                                            rand::thread_rng().gen_range(start..end)
                                        } else {
                                            return Err(format!("Invalid range values for field '{}'", #field_name_str));
                                        }
                                    } else {
                                        return Err(format!("Range array for field '{}' must have exactly two elements", #field_name_str));
                                    }
                                } else {
                                    return Err(format!("Configuration for field '{}' must be an array", #field_name_str));
                                }
                            } else {
                                return Err(format!("Configuration for field '{}' is missing", #field_name_str));
                            }
                        }
                    }
                } else if ["i32", "i64", "usize", "u32", "u64"].contains(&type_ident_str.as_str()) {
                    // For integer numbers
                    quote! {
                        {
                            if let Some(config_value) = config.get(#field_name_str) {
                                if let Some(range_array) = config_value.as_array() {
                                    if range_array.len() == 2 {
                                        if let (Some(start), Some(end)) = (range_array[0].as_i64(), range_array[1].as_i64()) {
                                            rand::thread_rng().gen_range(start..end) as #field_type
                                        } else {
                                            return Err(format!("Invalid range values for field '{}'", #field_name_str));
                                        }
                                    } else {
                                        return Err(format!("Range array for field '{}' must have exactly two elements", #field_name_str));
                                    }
                                } else {
                                    return Err(format!("Configuration for field '{}' must be an array", #field_name_str));
                                }
                            } else {
                                return Err(format!("Configuration for field '{}' is missing", #field_name_str));
                            }
                        }
                    }
                } else if type_ident_str == "String" {
                    // For strings
                    quote! {
                        {
                            if let Some(config_value) = config.get(#field_name_str) {
                                if let Some(values_array) = config_value.as_array() {
                                    let values: Vec<String> = values_array.iter()
                                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
                                        .collect();
                                    if !values.is_empty() {
                                        values.choose(&mut rand::thread_rng()).unwrap().clone()
                                    } else {
                                        return Err(format!("Values array for field '{}' is empty", #field_name_str));
                                    }
                                } else {
                                    return Err(format!("Configuration for field '{}' must be an array", #field_name_str));
                                }
                            } else {
                                return Err(format!("Configuration for field '{}' is missing", #field_name_str));
                            }
                        }
                    }
                } else {
                    // Unsupported types default to Default::default()
                    quote! {
                        Default::default()
                    }
                }
            }
            _ => {
                // Unsupported types default to Default::default()
                quote! {
                    Default::default()
                }
            }
        };

        sample_fields.push(quote! {
            #field_name: #sample_code,
        });
    }

    // Generate the sample_with_config method
    let sample_method = quote! {
        impl #struct_name {
            pub fn sample_with_config(config: &serde_json::Map<String, serde_json::Value>) -> Result<Self, String> {
                use rand::Rng;
                use rand::seq::SliceRandom;

                Ok(Self {
                    #(#sample_fields)*
                })
            }
        }
    };

    // Return the generated code
    TokenStream::from(sample_method)
}