Skip to main content

rand_repr/
lib.rs

1extern crate proc_macro;
2use std::collections::HashSet;
3
4use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
5use rand::{TryRngCore, rngs::OsRng};
6
/// Builds the token stream for `compile_error!("<msg>");`.
///
/// Returning this stream from a procedural macro makes the compiler report `msg`
/// as an error. All identifiers use the call-site span.
fn compile_error(msg: &str) -> TokenStream {
    // The macro argument list: a single string literal holding the message.
    let message = TokenStream::from(TokenTree::Literal(Literal::string(msg)));

    // compile_error ! ( "msg" ) ;
    let invocation = vec![
        TokenTree::Ident(Ident::new("compile_error", Span::call_site())),
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, message)),
        TokenTree::Punct(Punct::new(';', Spacing::Alone)),
    ];

    invocation.into_iter().collect()
}
29
/// Reports whether the tokens of an item declare an `enum`.
///
/// The item may start with any number of outer attributes (`#[...]`, which is also
/// how doc comments arrive in a macro's input) and may carry a single leading
/// qualifier such as `pub`, optionally followed by a parenthesized restriction
/// (`pub(crate)`, `pub(in some::path)`), before the `enum` keyword.
///
/// Returns `false` — rather than panicking — for empty or too-short input.
fn is_enum(token_trees: &[TokenTree]) -> bool {
    // Skip leading outer attributes: each one is a `#` followed by a `[...]` group.
    let mut rest = token_trees;
    while let [TokenTree::Punct(pound), TokenTree::Group(attribute), tail @ ..] = rest {
        if pound.as_char() != '#' || attribute.delimiter() != Delimiter::Bracket {
            break;
        }
        rest = tail;
    }

    let is_enum_keyword = |token: &TokenTree| {
        matches!(token, TokenTree::Ident(ident) if ident.to_string() == "enum")
    };

    match rest {
        // `enum Name { .. }`
        [first, ..] if is_enum_keyword(first) => true,
        // `pub(crate) enum Name { .. }`
        [TokenTree::Ident(_), TokenTree::Group(restriction), keyword, ..]
            if restriction.delimiter() == Delimiter::Parenthesis =>
        {
            is_enum_keyword(keyword)
        }
        // `pub enum Name { .. }`
        [TokenTree::Ident(_), keyword, ..] => is_enum_keyword(keyword),
        _ => false,
    }
}
45
46#[allow(non_camel_case_types)] // keep enum variants consistent with types
47#[derive(Clone, Copy)]
48enum IntegralType {
49    u32,
50    i32,
51    u64,
52    i64,
53}
54
55impl IntegralType {
56    fn gen_random(&self) -> Integral {
57        match self {
58            IntegralType::u32 => Integral::u32(OsRng.try_next_u32().unwrap()),
59            IntegralType::i32 => Integral::i32(i32::from_ne_bytes(
60                OsRng.try_next_u32().unwrap().to_ne_bytes(),
61            )),
62            IntegralType::u64 => Integral::u64(OsRng.try_next_u64().unwrap()),
63            IntegralType::i64 => Integral::i64(i64::from_ne_bytes(
64                OsRng.try_next_u64().unwrap().to_ne_bytes(),
65            )),
66        }
67    }
68
69    fn gen_repr_annotation(&self) -> TokenStream {
70        let mut tree = TokenStream::new();
71
72        // Create the repr(...) content
73        let mut repr_content = TokenStream::new();
74        repr_content.extend([TokenTree::Ident(Ident::new("repr", Span::call_site()))]);
75
76        // Create the type content inside parentheses
77        let mut type_content = TokenStream::new();
78        let type_name = match self {
79            IntegralType::u32 => "u32",
80            IntegralType::i32 => "i32",
81            IntegralType::u64 => "u64",
82            IntegralType::i64 => "i64",
83        };
84        type_content.extend([TokenTree::Ident(Ident::new(type_name, Span::call_site()))]);
85
86        // repr(type)
87        repr_content.extend([TokenTree::Group(Group::new(
88            Delimiter::Parenthesis,
89            type_content,
90        ))]);
91
92        // Build #[repr(...)]
93        tree.extend([
94            TokenTree::Punct(Punct::new('#', Spacing::Alone)),
95            TokenTree::Group(Group::new(Delimiter::Bracket, repr_content)),
96        ]);
97
98        tree
99    }
100}
101
/// A concrete integer value tagged with the primitive type it belongs to.
///
/// Hashable and comparable so that already-used values can be kept in a `HashSet`.
#[allow(non_camel_case_types)] // variant names deliberately mirror the primitive type names
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Integral {
    u32(u32),
    i32(i32),
    u64(u64),
    i64(i64),
}

impl Integral {
    /// Converts the value into an integer literal token that carries its type suffix.
    fn to_literal(self) -> Literal {
        match self {
            Self::u32(value) => Literal::u32_suffixed(value),
            Self::i32(value) => Literal::i32_suffixed(value),
            Self::u64(value) => Literal::u64_suffixed(value),
            Self::i64(value) => Literal::i64_suffixed(value),
        }
    }
}
121
122fn parse_attrs(attrs: TokenStream) -> Result<IntegralType, &'static str> {
123    if attrs.is_empty() {
124        return Err(
125            "this macro must be provided with a representation of the form `u32`, `i32`, `u64`, `i64`",
126        );
127    }
128    let token_trees: Vec<TokenTree> = attrs.into_iter().collect();
129    if token_trees.len() != 1 {
130        return Err("this macro must be provided only one of `u32`, `i32`, `u64`, `i64`");
131    }
132
133    match &token_trees[0] {
134        TokenTree::Ident(ident) => match ident.to_string().as_ref() {
135            "u32" => Ok(IntegralType::u32),
136            "i32" => Ok(IntegralType::i32),
137            "u64" => Ok(IntegralType::u64),
138            "i64" => Ok(IntegralType::i64),
139            _ => Err("this macro must be provided one of `u32`, `i32`, `u64`, `i64`"),
140        },
141        _ => Err("this macro must be provided only `u32`, `i32`, `u64`, `i64`"),
142    }
143}
144
145fn generate_unique_repr(
146    generated_reprs: &mut HashSet<Integral>,
147    integral_type: IntegralType,
148) -> Integral {
149    // TODO: Replace this with something like a block cipher to generate a random permutation deterministically instead of relying on retries
150    loop {
151        let integral = integral_type.gen_random();
152        if generated_reprs.insert(integral) {
153            return integral;
154        }
155    }
156}
157
158fn transform_token_tree(
159    token_stream: TokenStream,
160    generated_reprs: &mut HashSet<Integral>,
161    integral_type: IntegralType,
162) -> TokenStream {
163    let mut result_token_stream = TokenStream::new();
164
165    let mut last_token_tree: Option<TokenTree> = None;
166
167    for child_token_tree in token_stream {
168        if let TokenTree::Group(ref group) = child_token_tree {
169            // recurse
170            result_token_stream.extend([TokenTree::Group(Group::new(group.delimiter(), transform_token_tree(
171                group.stream(),
172                generated_reprs,
173                integral_type,
174            )))]);
175            continue;
176        }
177
178        if let TokenTree::Punct(ref punct) = child_token_tree
179            && punct.as_char() == ','
180            && last_token_tree.is_some_and(|last_tree| {
181                matches!(last_tree, TokenTree::Ident(_) | TokenTree::Group(_))
182            })
183        {
184            // insert an = num here
185            result_token_stream.extend([
186                TokenTree::Punct(Punct::new('=', Spacing::Alone)),
187                TokenTree::Literal(
188                    generate_unique_repr(generated_reprs, integral_type).to_literal(),
189                ),
190            ]);
191        }
192        result_token_stream.extend([child_token_tree.clone()]);
193        last_token_tree = Some(child_token_tree.clone());
194    }
195    result_token_stream
196}
197
198#[proc_macro_attribute]
199/// Randomizes the representation of a macro. Must be called with a representation to use
200/// ```
201/// #[randomize_repr(u32)]
202/// enum Status {
203///     NoLogin,
204///     LoggedIn,
205///     SuperUser
206/// }
207/// ```
208pub fn randomize_repr(attrs: TokenStream, item: TokenStream) -> TokenStream {
209    let original_token_trees: Vec<TokenTree> = item.clone().into_iter().collect();
210
211    if !is_enum(&original_token_trees) {
212        return compile_error("this macro must be called on an enum");
213    }
214
215    let attr_result = parse_attrs(attrs);
216
217    if let Err(err) = attr_result {
218        return compile_error(err);
219    }
220
221    let integral_type = attr_result.unwrap();
222
223    let mut result_token_stream = TokenStream::new();
224
225    result_token_stream.extend(integral_type.gen_repr_annotation());
226
227    let mut generated_reprs: HashSet<Integral> = HashSet::new();
228
229    result_token_stream.extend([transform_token_tree(item, &mut generated_reprs, integral_type)]);
230
231    result_token_stream
232}
233
#[cfg(test)]
mod tests {
    use super::{Integral, IntegralType};

    /// Each `IntegralType` must generate the `Integral` variant of the same name.
    #[test]
    fn always_correct_types() {
        let from_u32 = IntegralType::u32.gen_random();
        assert!(matches!(from_u32, Integral::u32(_)));

        let from_u64 = IntegralType::u64.gen_random();
        assert!(matches!(from_u64, Integral::u64(_)));

        let from_i32 = IntegralType::i32.gen_random();
        assert!(matches!(from_i32, Integral::i32(_)));

        let from_i64 = IntegralType::i64.gen_random();
        assert!(matches!(from_i64, Integral::i64(_)));
    }
}
245}