chia_streamable_macro/
lib.rs

1#![allow(clippy::missing_panics_doc)]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use proc_macro_crate::{crate_name, FoundCrate};
6use quote::quote;
7use syn::token::Pub;
8use syn::{
9    parse_macro_input, Data, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, Index, Lit,
10    Type, Visibility,
11};
12
13#[proc_macro_attribute]
14pub fn streamable(attr: TokenStream, item: TokenStream) -> TokenStream {
15    let found_crate =
16        crate_name("chia-protocol").expect("chia-protocol is present in `Cargo.toml`");
17
18    let chia_protocol = match &found_crate {
19        FoundCrate::Itself => quote!(crate),
20        FoundCrate::Name(name) => {
21            let ident = Ident::new(name, Span::call_site());
22            quote!(#ident)
23        }
24    };
25
26    let is_message = &attr.to_string() == "message";
27    let is_subclass = &attr.to_string() == "subclass";
28    let no_serde = &attr.to_string() == "no_serde";
29    let no_json = &attr.to_string() == "no_json";
30    let no_streamable = &attr.to_string() == "no_streamable";
31
32    let mut input: DeriveInput = parse_macro_input!(item);
33    let name = input.ident.clone();
34    let name_ref = &name;
35
36    let mut extra_impls = Vec::new();
37
38    if let Data::Struct(data) = &mut input.data {
39        let mut field_names = Vec::new();
40        let mut field_types = Vec::new();
41
42        for (i, field) in data.fields.iter_mut().enumerate() {
43            field.vis = Visibility::Public(Pub::default());
44            field_names.push(Ident::new(
45                &field
46                    .ident
47                    .as_ref()
48                    .map(ToString::to_string)
49                    .unwrap_or(format!("field_{i}")),
50                Span::mixed_site(),
51            ));
52            field_types.push(field.ty.clone());
53        }
54
55        let init_names = field_names.clone();
56
57        let initializer = match &data.fields {
58            Fields::Named(..) => quote!( Self { #( #init_names ),* } ),
59            Fields::Unnamed(..) => quote!( Self( #( #init_names ),* ) ),
60            Fields::Unit => quote!(Self),
61        };
62
63        if field_names.is_empty() {
64            extra_impls.push(quote! {
65                impl Default for #name_ref {
66                    fn default() -> Self {
67                        Self::new()
68                    }
69                }
70            });
71        }
72
73        extra_impls.push(quote! {
74            impl #name_ref {
75                #[allow(clippy::too_many_arguments)]
76                pub fn new( #( #field_names: #field_types ),* ) -> #name_ref {
77                    #initializer
78                }
79            }
80        });
81
82        if is_message {
83            extra_impls.push(quote! {
84                impl #chia_protocol::ChiaProtocolMessage for #name_ref {
85                    fn msg_type() -> #chia_protocol::ProtocolMessageTypes {
86                        #chia_protocol::ProtocolMessageTypes::#name_ref
87                    }
88                }
89            });
90        }
91    } else {
92        panic!("only structs are supported");
93    }
94
95    let main_derives = if no_streamable {
96        quote! {
97            #[derive(Hash, Debug, Clone, Eq, PartialEq)]
98        }
99    } else {
100        quote! {
101            #[derive(chia_streamable_macro::Streamable, Hash, Debug, Clone, Eq, PartialEq)]
102        }
103    };
104
105    let class_attrs = if is_subclass {
106        quote!(frozen, subclass)
107    } else {
108        quote!(frozen)
109    };
110
111    // If you're calling the macro from `chia-protocol`, enable Python bindings and arbitrary conditionally.
112    // Otherwise, you're calling it from an external crate which doesn't have this infrastructure setup.
113    // In that case, the caller can add these macros manually if they want to.
114    let attrs = if matches!(found_crate, FoundCrate::Itself) {
115        let serde = if is_message || no_serde {
116            TokenStream2::default()
117        } else {
118            quote! {
119                #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
120            }
121        };
122
123        let json_dict = if no_json {
124            TokenStream2::default()
125        } else {
126            quote! {
127                #[cfg_attr(feature = "py-bindings", derive(chia_py_streamable_macro::PyJsonDict))]
128            }
129        };
130
131        quote! {
132            #[cfg_attr(
133                feature = "py-bindings", pyo3::pyclass(#class_attrs), derive(
134                    chia_py_streamable_macro::PyStreamable,
135                    chia_py_streamable_macro::PyGetters
136                )
137            )]
138            #json_dict
139            #main_derives
140            #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
141            #serde
142        }
143    } else {
144        main_derives
145    };
146
147    quote! {
148        #attrs
149        #input
150        #( #extra_impls )*
151    }
152    .into()
153}
154
155#[proc_macro_derive(Streamable)]
156pub fn chia_streamable_macro(input: TokenStream) -> TokenStream {
157    let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
158
159    let crate_name = match found_crate {
160        FoundCrate::Itself => quote!(crate),
161        FoundCrate::Name(name) => {
162            let ident = Ident::new(&name, Span::call_site());
163            quote!(#ident)
164        }
165    };
166
167    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
168
169    let mut fnames = Vec::<Ident>::new();
170    let mut findices = Vec::<Index>::new();
171    let mut ftypes = Vec::<Type>::new();
172    match data {
173        Data::Enum(e) => {
174            let mut names = Vec::<Ident>::new();
175            let mut values = Vec::<u8>::new();
176            for v in &e.variants {
177                names.push(v.ident.clone());
178                let Some((_, expr)) = &v.discriminant else {
179                    panic!("unsupported enum");
180                };
181                let Expr::Lit(l) = expr else {
182                    panic!("unsupported enum (no literal)");
183                };
184                let Lit::Int(i) = &l.lit else {
185                    panic!("unsupported enum (literal is not integer)");
186                };
187                values.push(
188                    i.base10_parse::<u8>()
189                        .expect("unsupported enum (value not u8)"),
190                );
191            }
192            let ret = quote! {
193                impl #crate_name::Streamable for #ident {
194                    fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
195                        <u8 as #crate_name::Streamable>::update_digest(&(*self as u8), digest);
196                    }
197                    fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
198                        <u8 as #crate_name::Streamable>::stream(&(*self as u8), out)
199                    }
200                    fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
201                        let v = <u8 as #crate_name::Streamable>::parse::<TRUSTED>(input)?;
202                        match &v {
203                            #(#values => Ok(Self::#names),)*
204                            _ => Err(#crate_name::chia_error::Error::InvalidEnum),
205                        }
206                    }
207                }
208            };
209            return ret.into();
210        }
211        Data::Union(_) => {
212            panic!("Streamable does not support Unions");
213        }
214        Data::Struct(s) => match s.fields {
215            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
216                for (index, f) in unnamed.iter().enumerate() {
217                    findices.push(Index::from(index));
218                    ftypes.push(f.ty.clone());
219                }
220            }
221            Fields::Unit => {}
222            Fields::Named(FieldsNamed { named, .. }) => {
223                for f in &named {
224                    fnames.push(f.ident.as_ref().unwrap().clone());
225                    ftypes.push(f.ty.clone());
226                }
227            }
228        },
229    }
230
231    if !fnames.is_empty() {
232        let ret = quote! {
233            impl #crate_name::Streamable for #ident {
234                fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
235                    #(self.#fnames.update_digest(digest);)*
236                }
237                fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
238                    #(self.#fnames.stream(out)?;)*
239                    Ok(())
240                }
241                fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
242                    Ok(Self { #( #fnames: <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* })
243                }
244            }
245        };
246        ret.into()
247    } else if !findices.is_empty() {
248        let ret = quote! {
249            impl #crate_name::Streamable for #ident {
250                fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
251                    #(self.#findices.update_digest(digest);)*
252                }
253                fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
254                    #(self.#findices.stream(out)?;)*
255                    Ok(())
256                }
257                fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
258                    Ok(Self( #( <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* ))
259                }
260            }
261        };
262        ret.into()
263    } else {
264        // this is an empty type (Unit)
265        let ret = quote! {
266            impl #crate_name::Streamable for #ident {
267                fn update_digest(&self, _digest: &mut chia_sha2::Sha256) {}
268                fn stream(&self, _out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
269                    Ok(())
270                }
271                fn parse<const TRUSTED: bool>(_input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
272                    Ok(Self{})
273                }
274            }
275        };
276        ret.into()
277    }
278}