1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::str::FromStr;

use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, ItemStruct};

#[proc_macro_attribute]
pub fn hyperparameters(_header: TokenStream, input_struct: TokenStream) -> TokenStream {
    let orig_struct = parse_macro_input!(input_struct as ItemStruct);
    let field_names = orig_struct.fields.iter().map(|field| field.ident.as_ref().expect("Tuple structs are not allowed."));
    let field_names1 = field_names.clone();
    let field_names3 = field_names.clone();
    let field_names4 = field_names.clone();
    let field_names5 = field_names.clone();
    let num_fields = field_names4.count();
    let field_types = orig_struct.fields.iter().map(|field| &field.ty);
    let field_types1 = orig_struct.fields.iter().map(|field| {
        let type_string = field.ty.to_token_stream().to_string();
        assert!(type_string.len() > 5, "Only Vec hyperparams are supported currently");
        assert_eq!(type_string[..5], *"Vec <", "Only Vec hyperparams are supported currently");
        let type_string = proc_macro2::TokenStream::from_str(&type_string[5..type_string.len()-1].to_owned().trim().to_string().replace('"', "")).unwrap();
        quote!{#type_string}
    });


    let indexes = 0..num_fields;
    let orig_struct_ident = orig_struct.ident.clone();
    let permutation_struct_ident = Ident::new(&format!("{}Permutations", orig_struct_ident), Span::call_site());

    TokenStream::from(quote!{
        // Original struct
        #orig_struct

        // Impl for original struct
        impl <'a>#orig_struct_ident {
            pub fn permutations(&'a self) -> #permutation_struct_ident <'a> {
                #permutation_struct_ident {
                    #(#field_names : &self.#field_names),*,

                    indexes: [0; #num_fields],
                    lens: [#(self.#field_names1.len()),*],
                    first: true,
                }
            }
        }

        // Permutation struct
        pub struct #permutation_struct_ident<'a> {
            #(
                #field_names3: &'a #field_types
            ),*,

            indexes: [usize; #num_fields],
            lens: [usize; #num_fields],
            first: bool,
        }

        // Iterator for permutation struct
        impl <'a>Iterator for #permutation_struct_ident<'a> {
            type Item = (
                #(#field_types1),*
            );

            fn next(&mut self) -> Option<Self::Item> {
                if !self.first {
                    // Iterate indexes
                    for (index, i) in self.indexes.iter_mut().enumerate().rev() {
                        if *i < self.lens[index] - 1 {
                            *i += 1;
                            break
                        } else {
                            if index == 0 {
                                self.indexes = [0; #num_fields];
                                return None;
                            }
                            *i = 0;
                        }
                    }
                } else {
                    self.first = false;
                }
                
                // Return values
                Some(
                    (#(
                        self . #field_names5 [ self.indexes[ #indexes ]]
                    ),*)
                )
            }
        }
    })
}