1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::str::FromStr;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, ItemStruct};
#[proc_macro_attribute]
pub fn hyperparameters(_header: TokenStream, input_struct: TokenStream) -> TokenStream {
let orig_struct = parse_macro_input!(input_struct as ItemStruct);
let field_names = orig_struct.fields.iter().map(|field| field.ident.as_ref().expect("Tuple structs are not allowed."));
let field_names1 = field_names.clone();
let field_names3 = field_names.clone();
let field_names4 = field_names.clone();
let field_names5 = field_names.clone();
let num_fields = field_names4.count();
let field_types = orig_struct.fields.iter().map(|field| &field.ty);
let field_types1 = orig_struct.fields.iter().map(|field| {
let type_string = field.ty.to_token_stream().to_string();
assert!(type_string.len() > 5, "Only Vec hyperparams are supported currently");
assert_eq!(type_string[..5], *"Vec <", "Only Vec hyperparams are supported currently");
let type_string = proc_macro2::TokenStream::from_str(&type_string[5..type_string.len()-1].to_owned().trim().to_string().replace('"', "")).unwrap();
quote!{#type_string}
});
let indexes = 0..num_fields;
let orig_struct_ident = orig_struct.ident.clone();
let permutation_struct_ident = Ident::new(&format!("{}Permutations", orig_struct_ident), Span::call_site());
TokenStream::from(quote!{
#orig_struct
impl <'a>#orig_struct_ident {
pub fn permutations(&'a self) -> #permutation_struct_ident <'a> {
#permutation_struct_ident {
#(#field_names : &self.#field_names),*,
indexes: [0; #num_fields],
lens: [#(self.#field_names1.len()),*],
first: true,
}
}
}
pub struct #permutation_struct_ident<'a> {
#(
#field_names3: &'a #field_types
),*,
indexes: [usize; #num_fields],
lens: [usize; #num_fields],
first: bool,
}
impl <'a>Iterator for #permutation_struct_ident<'a> {
type Item = (
#(#field_types1),*
);
fn next(&mut self) -> Option<Self::Item> {
if !self.first {
for (index, i) in self.indexes.iter_mut().enumerate().rev() {
if *i < self.lens[index] - 1 {
*i += 1;
break
} else {
if index == 0 {
self.indexes = [0; #num_fields];
return None;
}
*i = 0;
}
}
} else {
self.first = false;
}
Some(
(#(
self . #field_names5 [ self.indexes[ #indexes ]]
),*)
)
}
}
})
}