1extern crate proc_macro;
2use std::collections::HashSet;
3
4use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
5use rand::{TryRngCore, rngs::OsRng};
6
/// Builds the token stream for `compile_error!("<msg>");`.
///
/// Returning this stream from the macro makes the compiler report `msg` as an
/// error at the macro call site.
fn compile_error(msg: &str) -> TokenStream {
    // The macro argument: a single string literal holding the message.
    let message = TokenStream::from(TokenTree::Literal(Literal::string(msg)));

    let invocation = vec![
        TokenTree::Ident(Ident::new("compile_error", Span::call_site())),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, message)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ];

    invocation.into_iter().collect()
}
29
/// Reports whether the item described by `token_trees` is an `enum`.
///
/// Leading outer attributes (`#[...]`, which is also how doc comments arrive)
/// and an optional visibility (`pub`, `pub(crate)`, `pub(in path)`, ...) are
/// skipped; the first remaining token must be the `enum` keyword.
///
/// Returns `false` instead of panicking when the slice is empty or ends
/// before the keyword is reached.
fn is_enum(token_trees: &[TokenTree]) -> bool {
    let mut tokens = token_trees.iter().peekable();

    loop {
        match tokens.next() {
            // Outer attribute: `#` must be followed by a bracketed group.
            Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => match tokens.next() {
                Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Bracket => {}
                _ => return false,
            },
            Some(TokenTree::Ident(ident)) => {
                let keyword = ident.to_string();
                if keyword == "enum" {
                    return true;
                }
                if keyword != "pub" {
                    return false;
                }
                // A visibility may carry a restriction such as `(crate)`;
                // consume it so the next token is the item keyword.
                let restricted = matches!(
                    tokens.peek(),
                    Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis
                );
                if restricted {
                    tokens.next();
                }
            }
            // Any other token, or running out of tokens, means "not an enum".
            _ => return false,
        }
    }
}
45
/// The integer representation requested through the macro argument.
///
/// `gen_repr_annotation` turns the selected variant into a `#[repr(..)]`
/// attribute and `gen_random` draws values of the matching width.
// Variants are spelled exactly like the primitive types they stand for,
// hence the lint exemption.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy)]
enum IntegralType {
    /// Unsigned 32-bit representation.
    u32,
    /// Signed 32-bit representation.
    i32,
    /// Unsigned 64-bit representation.
    u64,
    /// Signed 64-bit representation.
    i64,
}
54
55impl IntegralType {
56 fn gen_random(&self) -> Integral {
57 match self {
58 IntegralType::u32 => Integral::u32(OsRng.try_next_u32().unwrap()),
59 IntegralType::i32 => Integral::i32(i32::from_ne_bytes(
60 OsRng.try_next_u32().unwrap().to_ne_bytes(),
61 )),
62 IntegralType::u64 => Integral::u64(OsRng.try_next_u64().unwrap()),
63 IntegralType::i64 => Integral::i64(i64::from_ne_bytes(
64 OsRng.try_next_u64().unwrap().to_ne_bytes(),
65 )),
66 }
67 }
68
69 fn gen_repr_annotation(&self) -> TokenStream {
70 let mut tree = TokenStream::new();
71
72 let mut repr_content = TokenStream::new();
74 repr_content.extend([TokenTree::Ident(Ident::new("repr", Span::call_site()))]);
75
76 let mut type_content = TokenStream::new();
78 let type_name = match self {
79 IntegralType::u32 => "u32",
80 IntegralType::i32 => "i32",
81 IntegralType::u64 => "u64",
82 IntegralType::i64 => "i64",
83 };
84 type_content.extend([TokenTree::Ident(Ident::new(type_name, Span::call_site()))]);
85
86 repr_content.extend([TokenTree::Group(Group::new(
88 Delimiter::Parenthesis,
89 type_content,
90 ))]);
91
92 tree.extend([
94 TokenTree::Punct(Punct::new('#', Spacing::Alone)),
95 TokenTree::Group(Group::new(Delimiter::Bracket, repr_content)),
96 ]);
97
98 tree
99 }
100}
101
/// A concrete integer value tagged with its type.
///
/// It is hashable and comparable so that values already handed out can be
/// tracked in a `HashSet`.
// Variants are spelled exactly like the primitive types they wrap,
// hence the lint exemption.
#[allow(non_camel_case_types)]
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum Integral {
    /// An unsigned 32-bit value.
    u32(u32),
    /// A signed 32-bit value.
    i32(i32),
    /// An unsigned 64-bit value.
    u64(u64),
    /// A signed 64-bit value.
    i64(i64),
}
110
111impl Integral {
112 fn to_literal(self) -> Literal {
113 match self {
114 Integral::u32(i) => Literal::u32_suffixed(i),
115 Integral::i32(i) => Literal::i32_suffixed(i),
116 Integral::u64(i) => Literal::u64_suffixed(i),
117 Integral::i64(i) => Literal::i64_suffixed(i),
118 }
119 }
120}
121
122fn parse_attrs(attrs: TokenStream) -> Result<IntegralType, &'static str> {
123 if attrs.is_empty() {
124 return Err(
125 "this macro must be provided with a representation of the form `u32`, `i32`, `u64`, `i64`",
126 );
127 }
128 let token_trees: Vec<TokenTree> = attrs.into_iter().collect();
129 if token_trees.len() != 1 {
130 return Err("this macro must be provided only one of `u32`, `i32`, `u64`, `i64`");
131 }
132
133 match &token_trees[0] {
134 TokenTree::Ident(ident) => match ident.to_string().as_ref() {
135 "u32" => Ok(IntegralType::u32),
136 "i32" => Ok(IntegralType::i32),
137 "u64" => Ok(IntegralType::u64),
138 "i64" => Ok(IntegralType::i64),
139 _ => Err("this macro must be provided one of `u32`, `i32`, `u64`, `i64`"),
140 },
141 _ => Err("this macro must be provided only `u32`, `i32`, `u64`, `i64`"),
142 }
143}
144
145fn generate_unique_repr(
146 generated_reprs: &mut HashSet<Integral>,
147 integral_type: IntegralType,
148) -> Integral {
149 loop {
151 let integral = integral_type.gen_random();
152 if generated_reprs.insert(integral) {
153 return integral;
154 }
155 }
156}
157
158fn transform_token_tree(
159 token_stream: TokenStream,
160 generated_reprs: &mut HashSet<Integral>,
161 integral_type: IntegralType,
162) -> TokenStream {
163 let mut result_token_stream = TokenStream::new();
164
165 let mut last_token_tree: Option<TokenTree> = None;
166
167 for child_token_tree in token_stream {
168 if let TokenTree::Group(ref group) = child_token_tree {
169 result_token_stream.extend([TokenTree::Group(Group::new(group.delimiter(), transform_token_tree(
171 group.stream(),
172 generated_reprs,
173 integral_type,
174 )))]);
175 continue;
176 }
177
178 if let TokenTree::Punct(ref punct) = child_token_tree
179 && punct.as_char() == ','
180 && last_token_tree.is_some_and(|last_tree| {
181 matches!(last_tree, TokenTree::Ident(_) | TokenTree::Group(_))
182 })
183 {
184 result_token_stream.extend([
186 TokenTree::Punct(Punct::new('=', Spacing::Alone)),
187 TokenTree::Literal(
188 generate_unique_repr(generated_reprs, integral_type).to_literal(),
189 ),
190 ]);
191 }
192 result_token_stream.extend([child_token_tree.clone()]);
193 last_token_tree = Some(child_token_tree.clone());
194 }
195 result_token_stream
196}
197
198#[proc_macro_attribute]
199pub fn randomize_repr(attrs: TokenStream, item: TokenStream) -> TokenStream {
209 let original_token_trees: Vec<TokenTree> = item.clone().into_iter().collect();
210
211 if !is_enum(&original_token_trees) {
212 return compile_error("this macro must be called on an enum");
213 }
214
215 let attr_result = parse_attrs(attrs);
216
217 if let Err(err) = attr_result {
218 return compile_error(err);
219 }
220
221 let integral_type = attr_result.unwrap();
222
223 let mut result_token_stream = TokenStream::new();
224
225 result_token_stream.extend(integral_type.gen_repr_annotation());
226
227 let mut generated_reprs: HashSet<Integral> = HashSet::new();
228
229 result_token_stream.extend([transform_token_tree(item, &mut generated_reprs, integral_type)]);
230
231 result_token_stream
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `IntegralType` must produce the `Integral` variant of the same name.
    #[test]
    fn always_correct_types() {
        let unsigned_32 = IntegralType::u32.gen_random();
        assert!(matches!(unsigned_32, Integral::u32(_)));

        let unsigned_64 = IntegralType::u64.gen_random();
        assert!(matches!(unsigned_64, Integral::u64(_)));

        let signed_32 = IntegralType::i32.gen_random();
        assert!(matches!(signed_32, Integral::i32(_)));

        let signed_64 = IntegralType::i64.gen_random();
        assert!(matches!(signed_64, Integral::i64(_)));
    }
}