optirustic_macros/
lib.rs

1use proc_macro::TokenStream;
2
3use quote::quote;
4use syn::parse::Parser;
5use syn::{parse_macro_input, DeriveInput, ItemFn};
6
7/// An attribute macro to repeat a test `n` times until the test passes. The test passes if it does
8/// not panic at least once, it fails if it panics `n` times.
9#[proc_macro_attribute]
10pub fn test_with_retries(attrs: TokenStream, item: TokenStream) -> TokenStream {
11    let input_fn = parse_macro_input!(item as ItemFn);
12    let fn_name = &input_fn.sig.ident;
13    let tries = attrs
14        .to_string()
15        .parse::<u8>()
16        .expect("Attr must be an int");
17
18    let expanded = quote! {
19        #[test]
20        fn #fn_name() {
21            #input_fn
22            for i in 1..=#tries {
23                println!("Attempt #{i}");
24                let result = std::panic::catch_unwind(|| { #fn_name() });
25
26                if result.is_ok() {
27                    println!("Ok");
28                    return;
29                }
30
31                if i == #tries {
32                    std::panic::resume_unwind(result.unwrap_err());
33                }
34            };
35        }
36    };
37    expanded.into()
38}
39
40/// Register new fields on a struct that contains algorithm options. This macro adds:
41///  - the Serialize, Deserialize, Clone traits to the structure to make it serialisable and
42///    de-serialisable.
43///  - add the following fields: stopping_condition (`StoppingCondition`), parallel (`bool`)
44///    and export_history (`Option<ExportHistory>`).
45#[proc_macro_attribute]
46pub fn as_algorithm_args(_attrs: TokenStream, input: TokenStream) -> TokenStream {
47    let mut ast = parse_macro_input!(input as DeriveInput);
48    match &mut ast.data {
49        syn::Data::Struct(ref mut struct_data) => {
50            if let syn::Fields::Named(fields) = &mut struct_data.fields {
51                fields.named.push(
52                    syn::Field::parse_named
53                        .parse2(quote! {
54                            /// The condition to use when to terminate the algorithm.
55                            pub stopping_condition: StoppingCondition
56                        })
57                        .expect("Cannot add `stopping_condition` field"),
58                );
59                fields.named.push(
60                    syn::Field::parse_named
61                        .parse2(quote! {
62                            /// Whether the objective and constraint evaluation in [`Problem::evaluator`] should run
63                            /// using threads. If the evaluation function takes a long time to run and return the updated
64                            /// values, it is advisable to set this to `true`. This defaults to `true`.
65                            pub parallel: Option<bool>
66                        })
67                        .expect("Cannot add `parallel` field"),
68                );
69                fields.named.push(
70                    syn::Field::parse_named
71                        .parse2(quote! {
72                            /// The options to configure the individual's history export. When provided, the algorithm will
73                            /// save objectives, constraints and solutions to a file each time the generation increases by
74                            /// a given step. This is useful to track convergence and inspect an algorithm evolution.
75                            pub export_history: Option<ExportHistory>
76                        })
77                        .expect("Cannot add `export_history` field"),
78                );
79            }
80
81            let expand = quote! {
82                use crate::algorithms::{StoppingCondition, ExportHistory};
83                use serde::{Deserialize, Serialize};
84
85                #[derive(Serialize, Deserialize, Clone)]
86                #ast
87            };
88            expand.into()
89        }
90        _ => unimplemented!("`as_algorithm_args` can only be used on structs"),
91    }
92}
93
94/// This macro adds the following private fields to the struct defining an algorithm:
95/// `problem`, `number_of_individuals`, `population`, `generation`,`stopping_condition`,
96/// `number_of_function_evaluations`, `start_time`, `export_history` and `parallel`.
97///
98/// It also implements the `Display` trait.
99///
100#[proc_macro_attribute]
101pub fn as_algorithm(attrs: TokenStream, input: TokenStream) -> TokenStream {
102    let mut ast = parse_macro_input!(input as DeriveInput);
103    let name = &ast.ident;
104
105    let arg_type = syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
106        .parse(attrs)
107        .expect("Cannot parse argument type");
108
109    match &mut ast.data {
110        syn::Data::Struct(ref mut struct_data) => {
111            if let syn::Fields::Named(fields) = &mut struct_data.fields {
112                fields.named.push(
113                    syn::Field::parse_named
114                        .parse2(quote! {
115                            /// The problem being solved.
116                            problem: Arc<Problem>
117                        })
118                        .expect("Cannot add `problem` field"),
119                );
120                fields.named.push(
121                    syn::Field::parse_named
122                        .parse2(quote! {
123                            /// The number of individuals to use in the population.
124                            number_of_individuals: usize
125                        })
126                        .expect("Cannot add `number_of_individuals` field"),
127                );
128                fields.named.push(
129                    syn::Field::parse_named
130                        .parse2(quote! {
131                            /// The population with the solutions.
132                            population: Population
133                        })
134                        .expect("Cannot add `population` field"),
135                );
136                fields.named.push(
137                    syn::Field::parse_named
138                        .parse2(quote! {
139                            /// The evolution step.
140                            generation: u32
141                        })
142                        .expect("Cannot add `generation` field"),
143                );
144                fields.named.push(
145                    syn::Field::parse_named
146                        .parse2(quote! {
147                            /// The number of function evaluations.
148                            nfe: u32
149                        })
150                        .expect("Cannot add `nfe` field"),
151                );
152                fields.named.push(
153                    syn::Field::parse_named
154                        .parse2(quote! {
155                             /// The stopping condition.
156                            stopping_condition: StoppingCondition
157                        })
158                        .expect("Cannot add `stopping_condition` field"),
159                );
160                fields.named.push(
161                    syn::Field::parse_named
162                        .parse2(quote! {
163                            /// The algorithm options
164                            args: #arg_type
165                        })
166                        .expect("Cannot add `args` field"),
167                );
168                fields.named.push(
169                    syn::Field::parse_named
170                        .parse2(quote! {
171                            /// The time when the algorithm started.
172                            start_time: Instant
173                        })
174                        .expect("Cannot add `start_time` field"),
175                );
176                fields.named.push(
177                    syn::Field::parse_named
178                        .parse2(quote! {
179                            /// The configuration struct to export the algorithm history.
180                            export_history: Option<ExportHistory>
181                        })
182                        .expect("Cannot add `export_history` field"),
183                );
184                fields.named.push(
185                    syn::Field::parse_named
186                        .parse2(quote! {
187                            /// Whether the evaluation should run using threads
188                            parallel: bool
189                        })
190                        .expect("Cannot add `parallel` field"),
191                );
192            }
193
194            let expand = quote! {
195                use std::time::Instant;
196                use std::sync::Arc;
197                use crate::core::{Problem, Population};
198
199                #ast
200
201                impl Display for #name {
202                    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
203                        f.write_str(self.name().as_str())
204                    }
205                }
206            };
207            expand.into()
208        }
209        _ => unimplemented!("`as_algorithm` can only be used on structs"),
210    }
211}
212
213/// This macro adds common items when the `Algorithm` trait is implemented for a new algorithm
214/// struct. This adds the following items: `Algorithm::name()`, `Algorithm::stopping_condition()`
215/// `Algorithm::start_time()`, `Algorithm::problem()`,  `Algorithm::population()`,
216/// `Algorithm::generation()`, `Algorithm::number_of_function_evaluations()` and `Algorithm::export_history()`.
217///
218#[proc_macro_attribute]
219pub fn impl_algorithm_trait_items(attrs: TokenStream, input: TokenStream) -> TokenStream {
220    let mut ast = parse_macro_input!(input as syn::ItemImpl);
221    let name = if let syn::Type::Path(tp) = &*ast.self_ty {
222        tp.path.clone()
223    } else {
224        unimplemented!("Token not supported")
225    };
226    let arg_type = syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
227        .parse(attrs)
228        .expect("Cannot parse argument type");
229
230    let mut new_items = vec![
231        syn::parse::<syn::ImplItem>(
232            quote!(
233                fn stopping_condition(&self) -> &StoppingCondition {
234                    &self.stopping_condition
235                }
236            )
237            .into(),
238        )
239        .expect("Failed to parse `name` item"),
240        syn::parse::<syn::ImplItem>(
241            quote!(
242                fn name(&self) -> String {
243                    stringify!(#name).to_string()
244                }
245            )
246            .into(),
247        )
248        .expect("Failed to parse `name` item"),
249        syn::parse::<syn::ImplItem>(
250            quote!(
251                fn start_time(&self) -> &Instant {
252                    &self.start_time
253                }
254            )
255            .into(),
256        )
257        .expect("Failed to parse `start_time` item"),
258        syn::parse::<syn::ImplItem>(
259            quote!(
260                fn problem(&self) -> Arc<Problem> {
261                    self.problem.clone()
262                }
263            )
264            .into(),
265        )
266        .expect("Failed to parse `problem` item"),
267        syn::parse::<syn::ImplItem>(
268            quote!(
269                fn population(&self) -> &Population {
270                    &self.population
271                }
272            )
273            .into(),
274        )
275        .expect("Failed to parse `population` item"),
276        syn::parse::<syn::ImplItem>(
277            quote!(
278                fn export_history(&self) -> Option<&ExportHistory> {
279                    self.export_history.as_ref()
280                }
281            )
282            .into(),
283        )
284        .expect("Failed to parse `export_history` item"),
285        syn::parse::<syn::ImplItem>(
286            quote!(
287                fn generation(&self) -> u32 {
288                    self.generation
289                }
290            )
291            .into(),
292        )
293        .expect("Failed to parse `generation` item"),
294        syn::parse::<syn::ImplItem>(
295            quote!(
296                fn number_of_function_evaluations(&self) -> u32 {
297                    self.nfe
298                }
299            )
300            .into(),
301        )
302        .expect("Failed to parse `number_of_function_evaluations` item"),
303        syn::parse::<syn::ImplItem>(
304            quote!(
305                fn algorithm_options(&self) -> #arg_type {
306                    self.args.clone()
307                }
308            )
309            .into(),
310        )
311        .expect("Failed to parse `algorithm_options` item"),
312    ];
313
314    ast.items.append(&mut new_items);
315    let expand = quote! { #ast };
316    expand.into()
317}