use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parser;
use syn::{parse_macro_input, DeriveInput, ItemFn};
#[proc_macro_attribute]
pub fn test_with_retries(attrs: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let fn_name = &input_fn.sig.ident;
let tries = attrs
.to_string()
.parse::<u8>()
.expect("Attr must be an int");
let expanded = quote! {
#[test]
fn #fn_name() {
#input_fn
for i in 1..=#tries {
println!("Attempt #{i}");
let result = std::panic::catch_unwind(|| { #fn_name() });
if result.is_ok() {
println!("Ok");
return;
}
if i == #tries {
std::panic::resume_unwind(result.unwrap_err());
}
};
}
};
expanded.into()
}
/// Attribute macro that extends an algorithm-options struct with the shared option fields.
///
/// When the annotated struct has named fields, three public fields are appended to it:
/// `stopping_condition: StoppingCondition`, `parallel: Option<bool>` and
/// `export_history: Option<ExportHistory>`. Tuple and unit structs are emitted without
/// any extra field.
///
/// The expansion also places `use crate::algorithms::{StoppingCondition, ExportHistory};`
/// and `use serde::{Deserialize, Serialize};` next to the struct and derives `Serialize`,
/// `Deserialize` and `Clone` on it. The attribute arguments are ignored.
///
/// # Panics
///
/// Panics (aborting compilation) when applied to an enum or a union.
#[proc_macro_attribute]
pub fn as_algorithm_args(_attrs: TokenStream, input: TokenStream) -> TokenStream {
    let mut ast = parse_macro_input!(input as DeriveInput);

    let target = match &mut ast.data {
        syn::Data::Struct(data) => data,
        _ => unimplemented!("`as_algorithm_args` can only be used on structs"),
    };

    if let syn::Fields::Named(existing) = &mut target.fields {
        // Each entry pairs the field declaration with the panic message used when the
        // declaration cannot be parsed.
        let injected = [
            (
                quote! { pub stopping_condition: StoppingCondition },
                "Cannot add `stopping_condition` field",
            ),
            (
                quote! { pub parallel: Option<bool> },
                "Cannot add `parallel` field",
            ),
            (
                quote! { pub export_history: Option<ExportHistory> },
                "Cannot add `export_history` field",
            ),
        ];
        for (declaration, failure) in injected {
            let field = syn::Field::parse_named
                .parse2(declaration)
                .expect(failure);
            existing.named.push(field);
        }
    }

    let output = quote! {
        use crate::algorithms::{StoppingCondition, ExportHistory};
        use serde::{Deserialize, Serialize};
        #[derive(Serialize, Deserialize, Clone)]
        #ast
    };
    output.into()
}
/// Attribute macro that adds the common algorithm state to a struct and implements
/// `Display` for it.
///
/// The attribute argument is the path of the algorithm's options type, e.g.
/// `#[as_algorithm(MyAlgorithmArgs)]`; it becomes the type of the injected `args` field.
///
/// When the annotated struct has named fields, the following private fields are appended:
/// `problem`, `number_of_individuals`, `population`, `generation`, `nfe`,
/// `stopping_condition`, `args`, `start_time`, `export_history` and `parallel`.
///
/// The expansion places `use` items for `std::time::Instant`, `std::sync::Arc` and
/// `crate::core::{Problem, Population}` next to the struct. `StoppingCondition`,
/// `ExportHistory`, `Display` and `Formatter` are referenced but not imported by the
/// expansion, so they have to be in scope where the macro is used. The generated
/// `Display` implementation writes the string returned by `self.name()`.
///
/// # Panics
///
/// Panics (aborting compilation) when the attribute argument is not a comma-separated
/// list of paths, when the `args` field cannot be built from it, or when the macro is
/// applied to an enum or a union.
#[proc_macro_attribute]
pub fn as_algorithm(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let mut ast = parse_macro_input!(input as DeriveInput);
    let struct_ident = &ast.ident;
    let options_ty = syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
        .parse(attrs)
        .expect("Cannot parse argument type");

    let target = match &mut ast.data {
        syn::Data::Struct(data) => data,
        _ => unimplemented!("`as_algorithm` can only be used on structs"),
    };

    if let syn::Fields::Named(existing) = &mut target.fields {
        // Each entry pairs the field declaration with the panic message used when the
        // declaration cannot be parsed; the fields are appended in this order.
        let injected = [
            (
                quote! { problem: Arc<Problem> },
                "Cannot add `problem` field",
            ),
            (
                quote! { number_of_individuals: usize },
                "Cannot add `number_of_individuals` field",
            ),
            (
                quote! { population: Population },
                "Cannot add `population` field",
            ),
            (
                quote! { generation: u32 },
                "Cannot add `generation` field",
            ),
            (
                quote! { nfe: u32 },
                "Cannot add `nfe` field",
            ),
            (
                quote! { stopping_condition: StoppingCondition },
                "Cannot add `stopping_condition` field",
            ),
            (
                quote! { args: #options_ty },
                "Cannot add `args` field",
            ),
            (
                quote! { start_time: Instant },
                "Cannot add `start_time` field",
            ),
            (
                quote! { export_history: Option<ExportHistory> },
                "Cannot add `export_history` field",
            ),
            (
                quote! { parallel: bool },
                "Cannot add `parallel` field",
            ),
        ];
        for (declaration, failure) in injected {
            let field = syn::Field::parse_named
                .parse2(declaration)
                .expect(failure);
            existing.named.push(field);
        }
    }

    let output = quote! {
        use std::time::Instant;
        use std::sync::Arc;
        use crate::core::{Problem, Population};
        #ast
        impl Display for #struct_ident {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                f.write_str(self.name().as_str())
            }
        }
    };
    output.into()
}
#[proc_macro_attribute]
pub fn impl_algorithm_trait_items(attrs: TokenStream, input: TokenStream) -> TokenStream {
let mut ast = parse_macro_input!(input as syn::ItemImpl);
let name = if let syn::Type::Path(tp) = &*ast.self_ty {
tp.path.clone()
} else {
unimplemented!("Token not supported")
};
let arg_type = syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated
.parse(attrs)
.expect("Cannot parse argument type");
let mut new_items = vec![
syn::parse::<syn::ImplItem>(
quote!(
fn stopping_condition(&self) -> &StoppingCondition {
&self.stopping_condition
}
)
.into(),
)
.expect("Failed to parse `name` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn name(&self) -> String {
stringify!(#name).to_string()
}
)
.into(),
)
.expect("Failed to parse `name` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn start_time(&self) -> &Instant {
&self.start_time
}
)
.into(),
)
.expect("Failed to parse `start_time` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn problem(&self) -> Arc<Problem> {
self.problem.clone()
}
)
.into(),
)
.expect("Failed to parse `problem` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn population(&self) -> &Population {
&self.population
}
)
.into(),
)
.expect("Failed to parse `population` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn export_history(&self) -> Option<&ExportHistory> {
self.export_history.as_ref()
}
)
.into(),
)
.expect("Failed to parse `export_history` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn generation(&self) -> u32 {
self.generation
}
)
.into(),
)
.expect("Failed to parse `generation` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn number_of_function_evaluations(&self) -> u32 {
self.nfe
}
)
.into(),
)
.expect("Failed to parse `number_of_function_evaluations` item"),
syn::parse::<syn::ImplItem>(
quote!(
fn algorithm_options(&self) -> #arg_type {
self.args.clone()
}
)
.into(),
)
.expect("Failed to parse `algorithm_options` item"),
];
ast.items.append(&mut new_items);
let expand = quote! { #ast };
expand.into()
}