1use proc_macro::TokenStream;
2
3use quote::quote;
4use syn::parse::Parser;
5use syn::{parse_macro_input, DeriveInput, ItemFn};
6
7#[proc_macro_attribute]
10pub fn test_with_retries(attrs: TokenStream, item: TokenStream) -> TokenStream {
11 let input_fn = parse_macro_input!(item as ItemFn);
12 let fn_name = &input_fn.sig.ident;
13 let tries = attrs
14 .to_string()
15 .parse::<u8>()
16 .expect("Attr must be an int");
17
18 let expanded = quote! {
19 #[test]
20 fn #fn_name() {
21 #input_fn
22 for i in 1..=#tries {
23 println!("Attempt #{i}");
24 let result = std::panic::catch_unwind(|| { #fn_name() });
25
26 if result.is_ok() {
27 println!("Ok");
28 return;
29 }
30
31 if i == #tries {
32 std::panic::resume_unwind(result.unwrap_err());
33 }
34 };
35 }
36 };
37 expanded.into()
38}
39
40#[proc_macro_attribute]
46pub fn as_algorithm_args(_attrs: TokenStream, input: TokenStream) -> TokenStream {
47 let mut ast = parse_macro_input!(input as DeriveInput);
48 match &mut ast.data {
49 syn::Data::Struct(ref mut struct_data) => {
50 if let syn::Fields::Named(fields) = &mut struct_data.fields {
51 fields.named.push(
52 syn::Field::parse_named
53 .parse2(quote! {
54 pub stopping_condition: StoppingCondition
56 })
57 .expect("Cannot add `stopping_condition` field"),
58 );
59 fields.named.push(
60 syn::Field::parse_named
61 .parse2(quote! {
62 pub parallel: Option<bool>
66 })
67 .expect("Cannot add `parallel` field"),
68 );
69 fields.named.push(
70 syn::Field::parse_named
71 .parse2(quote! {
72 pub export_history: Option<ExportHistory>
76 })
77 .expect("Cannot add `export_history` field"),
78 );
79 }
80
81 let expand = quote! {
82 use crate::algorithms::{StoppingCondition, ExportHistory};
83 use serde::{Deserialize, Serialize};
84
85 #[derive(Serialize, Deserialize, Clone)]
86 #ast
87 };
88 expand.into()
89 }
90 _ => unimplemented!("`as_algorithm_args` can only be used on structs"),
91 }
92}
93
94#[proc_macro_attribute]
101pub fn as_algorithm(attrs: TokenStream, input: TokenStream) -> TokenStream {
102 let mut ast = parse_macro_input!(input as DeriveInput);
103 let name = &ast.ident;
104
105 let arg_type = syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
106 .parse(attrs)
107 .expect("Cannot parse argument type");
108
109 match &mut ast.data {
110 syn::Data::Struct(ref mut struct_data) => {
111 if let syn::Fields::Named(fields) = &mut struct_data.fields {
112 fields.named.push(
113 syn::Field::parse_named
114 .parse2(quote! {
115 problem: Arc<Problem>
117 })
118 .expect("Cannot add `problem` field"),
119 );
120 fields.named.push(
121 syn::Field::parse_named
122 .parse2(quote! {
123 number_of_individuals: usize
125 })
126 .expect("Cannot add `number_of_individuals` field"),
127 );
128 fields.named.push(
129 syn::Field::parse_named
130 .parse2(quote! {
131 population: Population
133 })
134 .expect("Cannot add `population` field"),
135 );
136 fields.named.push(
137 syn::Field::parse_named
138 .parse2(quote! {
139 generation: u32
141 })
142 .expect("Cannot add `generation` field"),
143 );
144 fields.named.push(
145 syn::Field::parse_named
146 .parse2(quote! {
147 nfe: u32
149 })
150 .expect("Cannot add `nfe` field"),
151 );
152 fields.named.push(
153 syn::Field::parse_named
154 .parse2(quote! {
155 stopping_condition: StoppingCondition
157 })
158 .expect("Cannot add `stopping_condition` field"),
159 );
160 fields.named.push(
161 syn::Field::parse_named
162 .parse2(quote! {
163 args: #arg_type
165 })
166 .expect("Cannot add `args` field"),
167 );
168 fields.named.push(
169 syn::Field::parse_named
170 .parse2(quote! {
171 start_time: Instant
173 })
174 .expect("Cannot add `start_time` field"),
175 );
176 fields.named.push(
177 syn::Field::parse_named
178 .parse2(quote! {
179 export_history: Option<ExportHistory>
181 })
182 .expect("Cannot add `export_history` field"),
183 );
184 fields.named.push(
185 syn::Field::parse_named
186 .parse2(quote! {
187 parallel: bool
189 })
190 .expect("Cannot add `parallel` field"),
191 );
192 }
193
194 let expand = quote! {
195 use std::time::Instant;
196 use std::sync::Arc;
197 use crate::core::{Problem, Population};
198
199 #ast
200
201 impl Display for #name {
202 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
203 f.write_str(self.name().as_str())
204 }
205 }
206 };
207 expand.into()
208 }
209 _ => unimplemented!("`as_algorithm` can only be used on structs"),
210 }
211}
212
213#[proc_macro_attribute]
219pub fn impl_algorithm_trait_items(attrs: TokenStream, input: TokenStream) -> TokenStream {
220 let mut ast = parse_macro_input!(input as syn::ItemImpl);
221 let name = if let syn::Type::Path(tp) = &*ast.self_ty {
222 tp.path.clone()
223 } else {
224 unimplemented!("Token not supported")
225 };
226 let arg_type = syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
227 .parse(attrs)
228 .expect("Cannot parse argument type");
229
230 let mut new_items = vec![
231 syn::parse::<syn::ImplItem>(
232 quote!(
233 fn stopping_condition(&self) -> &StoppingCondition {
234 &self.stopping_condition
235 }
236 )
237 .into(),
238 )
239 .expect("Failed to parse `name` item"),
240 syn::parse::<syn::ImplItem>(
241 quote!(
242 fn name(&self) -> String {
243 stringify!(#name).to_string()
244 }
245 )
246 .into(),
247 )
248 .expect("Failed to parse `name` item"),
249 syn::parse::<syn::ImplItem>(
250 quote!(
251 fn start_time(&self) -> &Instant {
252 &self.start_time
253 }
254 )
255 .into(),
256 )
257 .expect("Failed to parse `start_time` item"),
258 syn::parse::<syn::ImplItem>(
259 quote!(
260 fn problem(&self) -> Arc<Problem> {
261 self.problem.clone()
262 }
263 )
264 .into(),
265 )
266 .expect("Failed to parse `problem` item"),
267 syn::parse::<syn::ImplItem>(
268 quote!(
269 fn population(&self) -> &Population {
270 &self.population
271 }
272 )
273 .into(),
274 )
275 .expect("Failed to parse `population` item"),
276 syn::parse::<syn::ImplItem>(
277 quote!(
278 fn export_history(&self) -> Option<&ExportHistory> {
279 self.export_history.as_ref()
280 }
281 )
282 .into(),
283 )
284 .expect("Failed to parse `export_history` item"),
285 syn::parse::<syn::ImplItem>(
286 quote!(
287 fn generation(&self) -> u32 {
288 self.generation
289 }
290 )
291 .into(),
292 )
293 .expect("Failed to parse `generation` item"),
294 syn::parse::<syn::ImplItem>(
295 quote!(
296 fn number_of_function_evaluations(&self) -> u32 {
297 self.nfe
298 }
299 )
300 .into(),
301 )
302 .expect("Failed to parse `number_of_function_evaluations` item"),
303 syn::parse::<syn::ImplItem>(
304 quote!(
305 fn algorithm_options(&self) -> #arg_type {
306 self.args.clone()
307 }
308 )
309 .into(),
310 )
311 .expect("Failed to parse `algorithm_options` item"),
312 ];
313
314 ast.items.append(&mut new_items);
315 let expand = quote! { #ast };
316 expand.into()
317}