chia_streamable_macro/
lib.rs

1#![allow(clippy::missing_panics_doc)]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use proc_macro_crate::{crate_name, FoundCrate};
6use quote::quote;
7use syn::token::Pub;
8use syn::{
9    parse_macro_input, Data, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, Index, Lit,
10    Type, Visibility,
11};
12
13#[proc_macro_attribute]
14pub fn streamable(attr: TokenStream, item: TokenStream) -> TokenStream {
15    let found_crate =
16        crate_name("chia-protocol").expect("chia-protocol is present in `Cargo.toml`");
17
18    let chia_protocol = match &found_crate {
19        FoundCrate::Itself => quote!(crate),
20        FoundCrate::Name(name) => {
21            let ident = Ident::new(name, Span::call_site());
22            quote!(#ident)
23        }
24    };
25
26    let is_message = &attr.to_string() == "message";
27    let is_subclass = &attr.to_string() == "subclass";
28    let no_serde = &attr.to_string() == "no_serde";
29    let no_json = &attr.to_string() == "no_json";
30
31    let mut input: DeriveInput = parse_macro_input!(item);
32    let name = input.ident.clone();
33    let name_ref = &name;
34
35    let mut extra_impls = Vec::new();
36
37    if let Data::Struct(data) = &mut input.data {
38        let mut field_names = Vec::new();
39        let mut field_types = Vec::new();
40
41        for (i, field) in data.fields.iter_mut().enumerate() {
42            field.vis = Visibility::Public(Pub::default());
43            field_names.push(Ident::new(
44                &field
45                    .ident
46                    .as_ref()
47                    .map(ToString::to_string)
48                    .unwrap_or(format!("field_{i}")),
49                Span::mixed_site(),
50            ));
51            field_types.push(field.ty.clone());
52        }
53
54        let init_names = field_names.clone();
55
56        let initializer = match &data.fields {
57            Fields::Named(..) => quote!( Self { #( #init_names ),* } ),
58            Fields::Unnamed(..) => quote!( Self( #( #init_names ),* ) ),
59            Fields::Unit => quote!(Self),
60        };
61
62        if field_names.is_empty() {
63            extra_impls.push(quote! {
64                impl Default for #name_ref {
65                    fn default() -> Self {
66                        Self::new()
67                    }
68                }
69            });
70        }
71
72        extra_impls.push(quote! {
73            impl #name_ref {
74                #[allow(clippy::too_many_arguments)]
75                pub fn new( #( #field_names: #field_types ),* ) -> #name_ref {
76                    #initializer
77                }
78            }
79        });
80
81        if is_message {
82            extra_impls.push(quote! {
83                impl #chia_protocol::ChiaProtocolMessage for #name_ref {
84                    fn msg_type() -> #chia_protocol::ProtocolMessageTypes {
85                        #chia_protocol::ProtocolMessageTypes::#name_ref
86                    }
87                }
88            });
89        }
90    } else {
91        panic!("only structs are supported");
92    }
93
94    let main_derives = quote! {
95        #[derive(chia_streamable_macro::Streamable, Hash, Debug, Clone, Eq, PartialEq)]
96    };
97
98    let class_attrs = if is_subclass {
99        quote!(frozen, subclass)
100    } else {
101        quote!(frozen)
102    };
103
104    // If you're calling the macro from `chia-protocol`, enable Python bindings and arbitrary conditionally.
105    // Otherwise, you're calling it from an external crate which doesn't have this infrastructure setup.
106    // In that case, the caller can add these macros manually if they want to.
107    let attrs = if matches!(found_crate, FoundCrate::Itself) {
108        let serde = if is_message || no_serde {
109            TokenStream2::default()
110        } else {
111            quote! {
112                #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
113            }
114        };
115
116        let json_dict = if no_json {
117            TokenStream2::default()
118        } else {
119            quote! {
120                #[cfg_attr(feature = "py-bindings", derive(chia_py_streamable_macro::PyJsonDict))]
121            }
122        };
123
124        quote! {
125            #[cfg_attr(
126                feature = "py-bindings", pyo3::pyclass(#class_attrs), derive(
127                    chia_py_streamable_macro::PyStreamable,
128                    chia_py_streamable_macro::PyGetters
129                )
130            )]
131            #json_dict
132            #main_derives
133            #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
134            #serde
135        }
136    } else {
137        main_derives
138    };
139
140    quote! {
141        #attrs
142        #input
143        #( #extra_impls )*
144    }
145    .into()
146}
147
148#[proc_macro_derive(Streamable)]
149pub fn chia_streamable_macro(input: TokenStream) -> TokenStream {
150    let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
151
152    let crate_name = match found_crate {
153        FoundCrate::Itself => quote!(crate),
154        FoundCrate::Name(name) => {
155            let ident = Ident::new(&name, Span::call_site());
156            quote!(#ident)
157        }
158    };
159
160    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
161
162    let mut fnames = Vec::<Ident>::new();
163    let mut findices = Vec::<Index>::new();
164    let mut ftypes = Vec::<Type>::new();
165    match data {
166        Data::Enum(e) => {
167            let mut names = Vec::<Ident>::new();
168            let mut values = Vec::<u8>::new();
169            for v in &e.variants {
170                names.push(v.ident.clone());
171                let Some((_, expr)) = &v.discriminant else {
172                    panic!("unsupported enum");
173                };
174                let Expr::Lit(l) = expr else {
175                    panic!("unsupported enum (no literal)");
176                };
177                let Lit::Int(i) = &l.lit else {
178                    panic!("unsupported enum (literal is not integer)");
179                };
180                values.push(
181                    i.base10_parse::<u8>()
182                        .expect("unsupported enum (value not u8)"),
183                );
184            }
185            let ret = quote! {
186                impl #crate_name::Streamable for #ident {
187                    fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
188                        <u8 as #crate_name::Streamable>::update_digest(&(*self as u8), digest);
189                    }
190                    fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
191                        <u8 as #crate_name::Streamable>::stream(&(*self as u8), out)
192                    }
193                    fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
194                        let v = <u8 as #crate_name::Streamable>::parse::<TRUSTED>(input)?;
195                        match &v {
196                            #(#values => Ok(Self::#names),)*
197                            _ => Err(#crate_name::chia_error::Error::InvalidEnum),
198                        }
199                    }
200                }
201            };
202            return ret.into();
203        }
204        Data::Union(_) => {
205            panic!("Streamable does not support Unions");
206        }
207        Data::Struct(s) => match s.fields {
208            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
209                for (index, f) in unnamed.iter().enumerate() {
210                    findices.push(Index::from(index));
211                    ftypes.push(f.ty.clone());
212                }
213            }
214            Fields::Unit => {}
215            Fields::Named(FieldsNamed { named, .. }) => {
216                for f in &named {
217                    fnames.push(f.ident.as_ref().unwrap().clone());
218                    ftypes.push(f.ty.clone());
219                }
220            }
221        },
222    }
223
224    if !fnames.is_empty() {
225        let ret = quote! {
226            impl #crate_name::Streamable for #ident {
227                fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
228                    #(self.#fnames.update_digest(digest);)*
229                }
230                fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
231                    #(self.#fnames.stream(out)?;)*
232                    Ok(())
233                }
234                fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
235                    Ok(Self { #( #fnames: <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* })
236                }
237            }
238        };
239        ret.into()
240    } else if !findices.is_empty() {
241        let ret = quote! {
242            impl #crate_name::Streamable for #ident {
243                fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
244                    #(self.#findices.update_digest(digest);)*
245                }
246                fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
247                    #(self.#findices.stream(out)?;)*
248                    Ok(())
249                }
250                fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
251                    Ok(Self( #( <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* ))
252                }
253            }
254        };
255        ret.into()
256    } else {
257        // this is an empty type (Unit)
258        let ret = quote! {
259            impl #crate_name::Streamable for #ident {
260                fn update_digest(&self, _digest: &mut chia_sha2::Sha256) {}
261                fn stream(&self, _out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
262                    Ok(())
263                }
264                fn parse<const TRUSTED: bool>(_input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
265                    Ok(Self{})
266                }
267            }
268        };
269        ret.into()
270    }
271}