Skip to main content

azalea_registry_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Attribute, Ident, LitStr, Token, braced,
5    parse::{Parse, ParseStream, Result},
6    parse_macro_input,
7    punctuated::Punctuated,
8};
9
10struct Registry {
11    name: Ident,
12    items: Vec<RegistryItem>,
13    attrs: Vec<Attribute>,
14}
15
16struct RegistryItem {
17    attrs: Vec<Attribute>,
18    name: Ident,
19    id: String,
20}
21
22impl Parse for RegistryItem {
23    // Air => "minecraft:air"
24    fn parse(input: ParseStream) -> Result<Self> {
25        // parse annotations like #[default]
26        let attrs = input.call(Attribute::parse_outer).unwrap_or_default();
27
28        let name = input.parse()?;
29        input.parse::<Token![=>]>()?;
30        let id = input.parse::<LitStr>()?.value();
31        Ok(RegistryItem { attrs, name, id })
32    }
33}
34
35impl Parse for Registry {
36    fn parse(input: ParseStream) -> Result<Self> {
37        // enum BlockKind {
38        //     Air => "minecraft:air",
39        //     Stone => "minecraft:stone"
40        // }
41
42        // this also includes docs
43        let attrs = input.call(Attribute::parse_outer).unwrap_or_default();
44
45        input.parse::<Token![enum]>()?;
46        let name = input.parse()?;
47        let content;
48        braced!(content in input);
49        let items: Punctuated<RegistryItem, _> =
50            content.parse_terminated(RegistryItem::parse, Token![,])?;
51
52        Ok(Registry {
53            name,
54            items: items.into_iter().collect(),
55            attrs,
56        })
57    }
58}
59
60#[proc_macro]
61pub fn registry(input: TokenStream) -> TokenStream {
62    let input = parse_macro_input!(input as Registry);
63    let name = input.name;
64    let mut generated = quote! {};
65
66    // enum BlockKind {
67    //     Air = 0,
68    //     Stone,
69    // }
70    let mut enum_items = quote! {};
71    for (i, item) in input.items.iter().enumerate() {
72        let attrs = &item.attrs;
73        let name = &item.name;
74        let protocol_id = i as u32;
75        enum_items.extend(quote! {
76            #(#attrs)*
77            #name = #protocol_id,
78        });
79    }
80    let attributes = input.attrs;
81    generated.extend(quote! {
82        #(#attributes)*
83        #[derive(azalea_buf::AzBuf, Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
84        #[repr(u32)]
85        pub enum #name {
86            #enum_items
87        }
88    });
89
90    let max_id = input.items.len() as u32 - 1;
91
92    let doc_0 = format!("Transmutes a u32 to a {name}.");
93    let doc_1 = format!("The `id` should be at most {max_id}.");
94
95    generated.extend(quote! {
96        impl #name {
97            #[doc = #doc_0]
98            ///
99            /// # Safety
100            #[doc = #doc_1]
101            #[inline]
102            pub unsafe fn from_u32_unchecked(id: u32) -> Self {
103                std::mem::transmute::<u32, #name>(id)
104            }
105
106            #[inline]
107            pub fn is_valid_id(id: u32) -> bool {
108                id <= #max_id
109            }
110        }
111        impl crate::Registry for #name {
112            fn from_u32(value: u32) -> Option<Self> {
113                if Self::is_valid_id(value) {
114                    Some(unsafe { Self::from_u32_unchecked(value) })
115                } else {
116                    None
117                }
118            }
119            fn to_u32(&self) -> u32 {
120                *self as u32
121            }
122        }
123    });
124
125    let doc_0 = format!("Safely transmutes a u32 to a {name}.");
126
127    generated.extend(quote! {
128        impl TryFrom<u32> for #name {
129            type Error = ();
130
131            #[doc = #doc_0]
132            fn try_from(id: u32) -> Result<Self, Self::Error> {
133                if let Some(value) = crate::Registry::from_u32(id) {
134                    Ok(value)
135                } else {
136                    Err(())
137                }
138            }
139        }
140    });
141
142    // Display that uses registry ids
143    let mut display_items = quote! {};
144    let mut from_str_items = quote! {};
145    for item in &input.items {
146        let name = &item.name;
147        let id = &item.id;
148        display_items.extend(quote! {
149            Self::#name => write!(f, concat!("minecraft:", #id)),
150        });
151        from_str_items.extend(quote! {
152            #id => Ok(Self::#name),
153        });
154    }
155    generated.extend(quote! {
156        impl std::fmt::Display for #name {
157            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158                match self {
159                    #display_items
160                }
161            }
162        }
163        impl<'a> TryFrom<&'a crate::Identifier> for #name {
164            type Error = ();
165            fn try_from(ident: &'a crate::Identifier) -> Result<Self, Self::Error> {
166                if ident.namespace() != "minecraft" { return Err(()) }
167                match ident.path() {
168                    #from_str_items
169                    _ => return Err(()),
170                }
171            }
172        }
173        impl std::str::FromStr for #name {
174            type Err = ();
175
176            fn from_str(s: &str) -> Result<Self, Self::Err> {
177                Self::try_from(&crate::Identifier::new(s))
178            }
179        }
180
181        #[cfg(feature = "serde")]
182        impl serde::Serialize for #name {
183            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
184            where
185                S: serde::Serializer,
186            {
187                serializer.serialize_str(&self.to_string())
188            }
189        }
190        #[cfg(feature = "serde")]
191        impl<'de> serde::Deserialize<'de> for #name {
192            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
193            where
194                D: serde::Deserializer<'de>,
195            {
196                let s = String::deserialize(deserializer)?;
197                s.parse().map_err(|_| serde::de::Error::custom(
198                    format!("{s:?} is not a valid {name}", s = s, name = stringify!(#name))
199                ))
200            }
201        }
202
203        impl simdnbt::FromNbtTag for #name {
204            fn from_nbt_tag(tag: simdnbt::borrow::NbtTag) -> Option<Self> {
205                let v = tag.string()?;
206                std::str::FromStr::from_str(&v.to_str()).ok()
207            }
208        }
209        impl simdnbt::ToNbtTag for #name {
210            fn to_nbt_tag(self) -> simdnbt::owned::NbtTag {
211                simdnbt::owned::NbtTag::String(self.to_string().into())
212            }
213        }
214    });
215
216    generated.into()
217}