deserter/
lib.rs

1use std::{collections::HashMap, sync::Mutex};
2
3use once_cell::sync::Lazy;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::{quote, ToTokens};
7use syn::{
8    braced,
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    token::Brace,
12    Expr, Ident, ItemStruct, Token, Type,
13};
14
15enum FieldValue {
16    LoadStruct(StructValue),
17    Expr(Expr),
18}
19
20impl ToTokens for FieldValue {
21    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
22        match self {
23            FieldValue::LoadStruct(struct_value) => struct_value.to_tokens(tokens),
24            FieldValue::Expr(expr) => expr.to_tokens(tokens),
25        }
26    }
27}
28
29struct StructValue {
30    fields: HashMap<Ident, FieldValue>,
31}
32
33impl Parse for StructValue {
34    fn parse(input: ParseStream) -> syn::Result<Self> {
35        let content;
36        braced!(content in input);
37
38        let mut fields = HashMap::new();
39        while !content.is_empty() {
40            let name: Ident = content.parse()?;
41            content.parse::<Token![=]>()?;
42            let value = if content.peek(Brace) {
43                FieldValue::LoadStruct(StructValue::parse(&content)?)
44            } else {
45                FieldValue::Expr(content.parse()?)
46            };
47            fields.insert(name, value);
48
49            if !content.peek(Token![,]) {
50                break;
51            } else {
52                content.parse::<Token![,]>()?;
53            }
54        }
55
56        Ok(StructValue { fields })
57    }
58}
59
60impl ToTokens for StructValue {
61    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
62        let fields_tokens = self.fields.iter().map(|(field_name, field_value)| {
63            let field_tokens = match field_value {
64                FieldValue::LoadStruct(struct_value) => {
65                    let structs = STRUCTS
66                        .try_lock()
67                        .expect("Could not get struct cache while untokenizing struct value")
68                        .clone();
69                    let current_struct_name = CURRENT_STRUCT_NAME
70                        .try_lock()
71                        .expect("Could not get current struct name while untokenizing struct value")
72                        .clone();
73                    let current_struct_fields = structs.get(&current_struct_name.clone()).unwrap_or_else(|| {
74                        panic!(
75                            "The type of the field \"{}\" has not been registered with `#[loadable]`. Available types are:\n- {}",
76                            field_name,
77                            structs.keys().map(|element| element.to_owned()).collect::<Vec<_>>().join("\n- ")
78                        )
79                    });
80
81                    let mut struct_name: Option<Ident> = None;
82                    for (other_field_name, field_type) in current_struct_fields {
83                        if other_field_name == &field_name.to_string() {
84                            struct_name = Some(Ident::new(field_type, Span::call_site()));
85                        }
86                    }
87
88                    let Some(struct_name) = struct_name else {
89                        panic!(
90                            "The type of the field \"{}\" has not been registered with `#[loadable]`. Available types are:\n- {}",
91                            field_name,
92                            structs.keys().map(|element| element.to_owned()).collect::<Vec<_>>().join("\n- ")
93                        );
94                    };
95
96                    let mut current_struct_name = CURRENT_STRUCT_NAME
97                        .try_lock()
98                        .expect("Could not get current struct name while untokenizing struct value");
99                    *current_struct_name = struct_name.to_string();
100                    drop(current_struct_name);
101
102                    quote! { #struct_name #struct_value }
103                }
104                FieldValue::Expr(expression) => expression.into_token_stream(),
105            };
106            quote! { #field_name: #field_tokens }
107        });
108
109        tokens.extend(quote! { {
110                #(#fields_tokens),*
111            }
112        });
113    }
114}
115
116struct Wrapper {
117    struct_name: Ident,
118    value: StructValue,
119}
120
121impl Parse for Wrapper {
122    fn parse(input: ParseStream) -> syn::Result<Self> {
123        let struct_name = input.parse::<Ident>()?;
124        let mut current_name = CURRENT_STRUCT_NAME
125            .try_lock()
126            .expect("Cannot get current struct name while attempting to parse outer wrapper struct");
127        *current_name = struct_name.to_string();
128        let value = StructValue::parse(input)?;
129        Ok(Self { struct_name, value })
130    }
131}
132
133impl ToTokens for Wrapper {
134    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
135        let struct_name = &self.struct_name;
136        let value = &self.value;
137        tokens.extend(quote! {
138            #struct_name #value
139        });
140    }
141}
142
143#[proc_macro]
144pub fn load(input: TokenStream) -> TokenStream {
145    let wrapper = parse_macro_input!(input as Wrapper);
146    quote! {
147        #wrapper
148    }
149    .into()
150}
151
152#[proc_macro_attribute]
153pub fn loadable(_attribute: TokenStream, input: TokenStream) -> TokenStream {
154    cache_struct(syn::parse::<ItemStruct>(input.clone()).unwrap());
155    input
156}
157
158static STRUCTS: Lazy<Mutex<HashMap<String, Vec<(String, String)>>>> = Lazy::new(|| Mutex::new(HashMap::new()));
159
160static CURRENT_STRUCT_NAME: Mutex<String> = Mutex::new(String::new());
161
162/// Store a trait definition for future reference.
163fn cache_struct(item: syn::ItemStruct) {
164    STRUCTS
165        .try_lock()
166        .expect("Cannot get cached structs while attempting to cache a struct")
167        .insert(
168            item.ident.to_string(),
169            item.fields
170                .iter()
171                .filter_map(|field| {
172                    let Type::Path(field_type) = &field.ty else {
173                        return None;
174                    };
175
176                    Some((
177                        field.ident.as_ref().unwrap().to_string(),
178                        field_type.path.segments.last().unwrap().ident.to_string(),
179                    ))
180                })
181                .collect(),
182        );
183}