chia_streamable_macro/
lib.rs1#![allow(clippy::missing_panics_doc)]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use proc_macro_crate::{crate_name, FoundCrate};
6use quote::quote;
7use syn::token::Pub;
8use syn::{
9 parse_macro_input, Data, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, Index, Lit,
10 Type, Visibility,
11};
12
13#[proc_macro_attribute]
14pub fn streamable(attr: TokenStream, item: TokenStream) -> TokenStream {
15 let found_crate =
16 crate_name("chia-protocol").expect("chia-protocol is present in `Cargo.toml`");
17
18 let chia_protocol = match &found_crate {
19 FoundCrate::Itself => quote!(crate),
20 FoundCrate::Name(name) => {
21 let ident = Ident::new(name, Span::call_site());
22 quote!(#ident)
23 }
24 };
25
26 let is_message = &attr.to_string() == "message";
27 let is_subclass = &attr.to_string() == "subclass";
28 let no_serde = &attr.to_string() == "no_serde";
29 let no_json = &attr.to_string() == "no_json";
30
31 let mut input: DeriveInput = parse_macro_input!(item);
32 let name = input.ident.clone();
33 let name_ref = &name;
34
35 let mut extra_impls = Vec::new();
36
37 if let Data::Struct(data) = &mut input.data {
38 let mut field_names = Vec::new();
39 let mut field_types = Vec::new();
40
41 for (i, field) in data.fields.iter_mut().enumerate() {
42 field.vis = Visibility::Public(Pub::default());
43 field_names.push(Ident::new(
44 &field
45 .ident
46 .as_ref()
47 .map(ToString::to_string)
48 .unwrap_or(format!("field_{i}")),
49 Span::mixed_site(),
50 ));
51 field_types.push(field.ty.clone());
52 }
53
54 let init_names = field_names.clone();
55
56 let initializer = match &data.fields {
57 Fields::Named(..) => quote!( Self { #( #init_names ),* } ),
58 Fields::Unnamed(..) => quote!( Self( #( #init_names ),* ) ),
59 Fields::Unit => quote!(Self),
60 };
61
62 if field_names.is_empty() {
63 extra_impls.push(quote! {
64 impl Default for #name_ref {
65 fn default() -> Self {
66 Self::new()
67 }
68 }
69 });
70 }
71
72 extra_impls.push(quote! {
73 impl #name_ref {
74 #[allow(clippy::too_many_arguments)]
75 pub fn new( #( #field_names: #field_types ),* ) -> #name_ref {
76 #initializer
77 }
78 }
79 });
80
81 if is_message {
82 extra_impls.push(quote! {
83 impl #chia_protocol::ChiaProtocolMessage for #name_ref {
84 fn msg_type() -> #chia_protocol::ProtocolMessageTypes {
85 #chia_protocol::ProtocolMessageTypes::#name_ref
86 }
87 }
88 });
89 }
90 } else {
91 panic!("only structs are supported");
92 }
93
94 let main_derives = quote! {
95 #[derive(chia_streamable_macro::Streamable, Hash, Debug, Clone, Eq, PartialEq)]
96 };
97
98 let class_attrs = if is_subclass {
99 quote!(frozen, subclass)
100 } else {
101 quote!(frozen)
102 };
103
104 let attrs = if matches!(found_crate, FoundCrate::Itself) {
108 let serde = if is_message || no_serde {
109 TokenStream2::default()
110 } else {
111 quote! {
112 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
113 }
114 };
115
116 let json_dict = if no_json {
117 TokenStream2::default()
118 } else {
119 quote! {
120 #[cfg_attr(feature = "py-bindings", derive(chia_py_streamable_macro::PyJsonDict))]
121 }
122 };
123
124 quote! {
125 #[cfg_attr(
126 feature = "py-bindings", pyo3::pyclass(#class_attrs), derive(
127 chia_py_streamable_macro::PyStreamable,
128 chia_py_streamable_macro::PyGetters
129 )
130 )]
131 #json_dict
132 #main_derives
133 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
134 #serde
135 }
136 } else {
137 main_derives
138 };
139
140 quote! {
141 #attrs
142 #input
143 #( #extra_impls )*
144 }
145 .into()
146}
147
148#[proc_macro_derive(Streamable)]
149pub fn chia_streamable_macro(input: TokenStream) -> TokenStream {
150 let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
151
152 let crate_name = match found_crate {
153 FoundCrate::Itself => quote!(crate),
154 FoundCrate::Name(name) => {
155 let ident = Ident::new(&name, Span::call_site());
156 quote!(#ident)
157 }
158 };
159
160 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
161
162 let mut fnames = Vec::<Ident>::new();
163 let mut findices = Vec::<Index>::new();
164 let mut ftypes = Vec::<Type>::new();
165 match data {
166 Data::Enum(e) => {
167 let mut names = Vec::<Ident>::new();
168 let mut values = Vec::<u8>::new();
169 for v in &e.variants {
170 names.push(v.ident.clone());
171 let Some((_, expr)) = &v.discriminant else {
172 panic!("unsupported enum");
173 };
174 let Expr::Lit(l) = expr else {
175 panic!("unsupported enum (no literal)");
176 };
177 let Lit::Int(i) = &l.lit else {
178 panic!("unsupported enum (literal is not integer)");
179 };
180 values.push(
181 i.base10_parse::<u8>()
182 .expect("unsupported enum (value not u8)"),
183 );
184 }
185 let ret = quote! {
186 impl #crate_name::Streamable for #ident {
187 fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
188 <u8 as #crate_name::Streamable>::update_digest(&(*self as u8), digest);
189 }
190 fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
191 <u8 as #crate_name::Streamable>::stream(&(*self as u8), out)
192 }
193 fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
194 let v = <u8 as #crate_name::Streamable>::parse::<TRUSTED>(input)?;
195 match &v {
196 #(#values => Ok(Self::#names),)*
197 _ => Err(#crate_name::chia_error::Error::InvalidEnum),
198 }
199 }
200 }
201 };
202 return ret.into();
203 }
204 Data::Union(_) => {
205 panic!("Streamable does not support Unions");
206 }
207 Data::Struct(s) => match s.fields {
208 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
209 for (index, f) in unnamed.iter().enumerate() {
210 findices.push(Index::from(index));
211 ftypes.push(f.ty.clone());
212 }
213 }
214 Fields::Unit => {}
215 Fields::Named(FieldsNamed { named, .. }) => {
216 for f in &named {
217 fnames.push(f.ident.as_ref().unwrap().clone());
218 ftypes.push(f.ty.clone());
219 }
220 }
221 },
222 }
223
224 if !fnames.is_empty() {
225 let ret = quote! {
226 impl #crate_name::Streamable for #ident {
227 fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
228 #(self.#fnames.update_digest(digest);)*
229 }
230 fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
231 #(self.#fnames.stream(out)?;)*
232 Ok(())
233 }
234 fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
235 Ok(Self { #( #fnames: <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* })
236 }
237 }
238 };
239 ret.into()
240 } else if !findices.is_empty() {
241 let ret = quote! {
242 impl #crate_name::Streamable for #ident {
243 fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
244 #(self.#findices.update_digest(digest);)*
245 }
246 fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
247 #(self.#findices.stream(out)?;)*
248 Ok(())
249 }
250 fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
251 Ok(Self( #( <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* ))
252 }
253 }
254 };
255 ret.into()
256 } else {
257 let ret = quote! {
259 impl #crate_name::Streamable for #ident {
260 fn update_digest(&self, _digest: &mut chia_sha2::Sha256) {}
261 fn stream(&self, _out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
262 Ok(())
263 }
264 fn parse<const TRUSTED: bool>(_input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
265 Ok(Self{})
266 }
267 }
268 };
269 ret.into()
270 }
271}