chia_streamable_macro/
lib.rs1#![allow(clippy::missing_panics_doc)]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use proc_macro_crate::{crate_name, FoundCrate};
6use quote::quote;
7use syn::token::Pub;
8use syn::{
9 parse_macro_input, Data, DeriveInput, Expr, Fields, FieldsNamed, FieldsUnnamed, Index, Lit,
10 Type, Visibility,
11};
12
13#[proc_macro_attribute]
14pub fn streamable(attr: TokenStream, item: TokenStream) -> TokenStream {
15 let found_crate =
16 crate_name("chia-protocol").expect("chia-protocol is present in `Cargo.toml`");
17
18 let chia_protocol = match &found_crate {
19 FoundCrate::Itself => quote!(crate),
20 FoundCrate::Name(name) => {
21 let ident = Ident::new(name, Span::call_site());
22 quote!(#ident)
23 }
24 };
25
26 let is_message = &attr.to_string() == "message";
27 let is_subclass = &attr.to_string() == "subclass";
28 let no_serde = &attr.to_string() == "no_serde";
29 let no_json = &attr.to_string() == "no_json";
30 let no_streamable = &attr.to_string() == "no_streamable";
31
32 let mut input: DeriveInput = parse_macro_input!(item);
33 let name = input.ident.clone();
34 let name_ref = &name;
35
36 let mut extra_impls = Vec::new();
37
38 if let Data::Struct(data) = &mut input.data {
39 let mut field_names = Vec::new();
40 let mut field_types = Vec::new();
41
42 for (i, field) in data.fields.iter_mut().enumerate() {
43 field.vis = Visibility::Public(Pub::default());
44 field_names.push(Ident::new(
45 &field
46 .ident
47 .as_ref()
48 .map(ToString::to_string)
49 .unwrap_or(format!("field_{i}")),
50 Span::mixed_site(),
51 ));
52 field_types.push(field.ty.clone());
53 }
54
55 let init_names = field_names.clone();
56
57 let initializer = match &data.fields {
58 Fields::Named(..) => quote!( Self { #( #init_names ),* } ),
59 Fields::Unnamed(..) => quote!( Self( #( #init_names ),* ) ),
60 Fields::Unit => quote!(Self),
61 };
62
63 if field_names.is_empty() {
64 extra_impls.push(quote! {
65 impl Default for #name_ref {
66 fn default() -> Self {
67 Self::new()
68 }
69 }
70 });
71 }
72
73 extra_impls.push(quote! {
74 impl #name_ref {
75 #[allow(clippy::too_many_arguments)]
76 pub fn new( #( #field_names: #field_types ),* ) -> #name_ref {
77 #initializer
78 }
79 }
80 });
81
82 if is_message {
83 extra_impls.push(quote! {
84 impl #chia_protocol::ChiaProtocolMessage for #name_ref {
85 fn msg_type() -> #chia_protocol::ProtocolMessageTypes {
86 #chia_protocol::ProtocolMessageTypes::#name_ref
87 }
88 }
89 });
90 }
91 } else {
92 panic!("only structs are supported");
93 }
94
95 let main_derives = if no_streamable {
96 quote! {
97 #[derive(Hash, Debug, Clone, Eq, PartialEq)]
98 }
99 } else {
100 quote! {
101 #[derive(chia_streamable_macro::Streamable, Hash, Debug, Clone, Eq, PartialEq)]
102 }
103 };
104
105 let class_attrs = if is_subclass {
106 quote!(frozen, subclass)
107 } else {
108 quote!(frozen)
109 };
110
111 let attrs = if matches!(found_crate, FoundCrate::Itself) {
115 let serde = if is_message || no_serde {
116 TokenStream2::default()
117 } else {
118 quote! {
119 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
120 }
121 };
122
123 let json_dict = if no_json {
124 TokenStream2::default()
125 } else {
126 quote! {
127 #[cfg_attr(feature = "py-bindings", derive(chia_py_streamable_macro::PyJsonDict))]
128 }
129 };
130
131 quote! {
132 #[cfg_attr(
133 feature = "py-bindings", pyo3::pyclass(#class_attrs), derive(
134 chia_py_streamable_macro::PyStreamable,
135 chia_py_streamable_macro::PyGetters
136 )
137 )]
138 #json_dict
139 #main_derives
140 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
141 #serde
142 }
143 } else {
144 main_derives
145 };
146
147 quote! {
148 #attrs
149 #input
150 #( #extra_impls )*
151 }
152 .into()
153}
154
155#[proc_macro_derive(Streamable)]
156pub fn chia_streamable_macro(input: TokenStream) -> TokenStream {
157 let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
158
159 let crate_name = match found_crate {
160 FoundCrate::Itself => quote!(crate),
161 FoundCrate::Name(name) => {
162 let ident = Ident::new(&name, Span::call_site());
163 quote!(#ident)
164 }
165 };
166
167 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
168
169 let mut fnames = Vec::<Ident>::new();
170 let mut findices = Vec::<Index>::new();
171 let mut ftypes = Vec::<Type>::new();
172 match data {
173 Data::Enum(e) => {
174 let mut names = Vec::<Ident>::new();
175 let mut values = Vec::<u8>::new();
176 for v in &e.variants {
177 names.push(v.ident.clone());
178 let Some((_, expr)) = &v.discriminant else {
179 panic!("unsupported enum");
180 };
181 let Expr::Lit(l) = expr else {
182 panic!("unsupported enum (no literal)");
183 };
184 let Lit::Int(i) = &l.lit else {
185 panic!("unsupported enum (literal is not integer)");
186 };
187 values.push(
188 i.base10_parse::<u8>()
189 .expect("unsupported enum (value not u8)"),
190 );
191 }
192 let ret = quote! {
193 impl #crate_name::Streamable for #ident {
194 fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
195 <u8 as #crate_name::Streamable>::update_digest(&(*self as u8), digest);
196 }
197 fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
198 <u8 as #crate_name::Streamable>::stream(&(*self as u8), out)
199 }
200 fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
201 let v = <u8 as #crate_name::Streamable>::parse::<TRUSTED>(input)?;
202 match &v {
203 #(#values => Ok(Self::#names),)*
204 _ => Err(#crate_name::chia_error::Error::InvalidEnum),
205 }
206 }
207 }
208 };
209 return ret.into();
210 }
211 Data::Union(_) => {
212 panic!("Streamable does not support Unions");
213 }
214 Data::Struct(s) => match s.fields {
215 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
216 for (index, f) in unnamed.iter().enumerate() {
217 findices.push(Index::from(index));
218 ftypes.push(f.ty.clone());
219 }
220 }
221 Fields::Unit => {}
222 Fields::Named(FieldsNamed { named, .. }) => {
223 for f in &named {
224 fnames.push(f.ident.as_ref().unwrap().clone());
225 ftypes.push(f.ty.clone());
226 }
227 }
228 },
229 }
230
231 if !fnames.is_empty() {
232 let ret = quote! {
233 impl #crate_name::Streamable for #ident {
234 fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
235 #(self.#fnames.update_digest(digest);)*
236 }
237 fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
238 #(self.#fnames.stream(out)?;)*
239 Ok(())
240 }
241 fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
242 Ok(Self { #( #fnames: <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* })
243 }
244 }
245 };
246 ret.into()
247 } else if !findices.is_empty() {
248 let ret = quote! {
249 impl #crate_name::Streamable for #ident {
250 fn update_digest(&self, digest: &mut chia_sha2::Sha256) {
251 #(self.#findices.update_digest(digest);)*
252 }
253 fn stream(&self, out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
254 #(self.#findices.stream(out)?;)*
255 Ok(())
256 }
257 fn parse<const TRUSTED: bool>(input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
258 Ok(Self( #( <#ftypes as #crate_name::Streamable>::parse::<TRUSTED>(input)?, )* ))
259 }
260 }
261 };
262 ret.into()
263 } else {
264 let ret = quote! {
266 impl #crate_name::Streamable for #ident {
267 fn update_digest(&self, _digest: &mut chia_sha2::Sha256) {}
268 fn stream(&self, _out: &mut Vec<u8>) -> #crate_name::chia_error::Result<()> {
269 Ok(())
270 }
271 fn parse<const TRUSTED: bool>(_input: &mut std::io::Cursor<&[u8]>) -> #crate_name::chia_error::Result<Self> {
272 Ok(Self{})
273 }
274 }
275 };
276 ret.into()
277 }
278}