chia_py_streamable_macro/
lib.rs

1#![allow(clippy::missing_panics_doc)]
2
3use proc_macro_crate::{FoundCrate, crate_name};
4use proc_macro2::{Ident, Span};
5use quote::quote;
6use syn::{DeriveInput, FieldsNamed, FieldsUnnamed, parse_macro_input};
7
8fn maybe_upper_fields(py_uppercase: bool, fnames: Vec<Ident>) -> Vec<Ident> {
9    if py_uppercase {
10        fnames
11            .into_iter()
12            .map(|f| Ident::new(&f.to_string().to_uppercase(), Span::call_site()))
13            .collect()
14    } else {
15        fnames
16    }
17}
18
19#[proc_macro_derive(PyStreamable, attributes(py_uppercase, py_pickle))]
20pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
21    let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
22
23    let crate_name = match found_crate {
24        FoundCrate::Itself => quote!(crate),
25        FoundCrate::Name(name) => {
26            let ident = Ident::new(&name, Span::call_site());
27            quote!(#ident)
28        }
29    };
30
31    let DeriveInput {
32        ident, data, attrs, ..
33    } = parse_macro_input!(input);
34
35    let mut py_uppercase = false;
36    let mut py_pickle = false;
37    for attr in &attrs {
38        if attr.path().is_ident("py_uppercase") {
39            py_uppercase = true;
40        } else if attr.path().is_ident("py_pickle") {
41            py_pickle = true;
42        }
43    }
44
45    let fields = match data {
46        syn::Data::Struct(s) => s.fields,
47        syn::Data::Enum(_) => {
48            return quote! {
49                impl<'a, 'py> pyo3::conversion::FromPyObject<'a, 'py> for #ident {
50                    type Error = pyo3::PyErr;
51
52                    fn extract(obj: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
53                        use pyo3::types::PyAnyMethods;
54                        let v: u8 = obj.extract()?;
55                        <Self as #crate_name::Streamable>::parse::<false>(&mut std::io::Cursor::<&[u8]>::new(&[v])).map_err(|e| e.into())
56                    }
57                }
58
59                impl<'py> pyo3::conversion::IntoPyObject<'py> for #ident {
60                    type Target = pyo3::PyAny;
61                    type Output = pyo3::Bound<'py, Self::Target>;
62                    type Error = std::convert::Infallible;
63
64                    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
65                        Ok(pyo3::IntoPyObject::into_pyobject(self as u8, py)?
66                            .clone()
67                            .into_any())
68                    }
69                }
70            }
71            .into();
72        }
73        syn::Data::Union(_) => {
74            panic!("Streamable only support struct");
75        }
76    };
77
78    let mut py_protocol = quote! {
79        #[pyo3::pymethods]
80        impl #ident {
81            fn __richcmp__(&self, other: pyo3::PyRef<Self>, op: pyo3::class::basic::CompareOp) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
82                use pyo3::class::basic::CompareOp;
83                use pyo3::IntoPyObjectExt;
84                let py = other.py();
85                match op {
86                    CompareOp::Eq => (self == &*other).into_py_any(py),
87                    CompareOp::Ne => (self != &*other).into_py_any(py),
88                    _ => Ok(py.NotImplemented()),
89                }
90            }
91
92            fn __hash__(&self) -> pyo3::PyResult<isize> {
93                let mut hasher = std::collections::hash_map::DefaultHasher::new();
94                std::hash::Hash::hash(self, &mut hasher);
95                Ok(std::hash::Hasher::finish(&hasher) as isize)
96            }
97        }
98
99        impl #crate_name::ChiaToPython for #ident {
100            fn to_python<'a>(&self, py: pyo3::Python<'a>) -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
101                Ok(pyo3::Py::new(py, self.clone())?.into_bound(py).into_any())
102            }
103        }
104    };
105
106    let mut fnames = Vec::<Ident>::new();
107    let mut ftypes = Vec::<syn::Type>::new();
108
109    match fields {
110        syn::Fields::Named(FieldsNamed { named, .. }) => {
111            for f in &named {
112                fnames.push(f.ident.as_ref().unwrap().clone());
113                ftypes.push(f.ty.clone());
114            }
115
116            let fnames_maybe_upper = maybe_upper_fields(py_uppercase, fnames.clone());
117
118            py_protocol.extend(quote! {
119                #[pyo3::pymethods]
120                impl #ident {
121                    #[allow(too_many_arguments)]
122                    #[new]
123                    #[pyo3(signature = (#(#fnames_maybe_upper),*))]
124                    pub fn py_new ( #(#fnames_maybe_upper : #ftypes),* ) -> Self {
125                        Self { #(#fnames: #fnames_maybe_upper),* }
126                    }
127                }
128            });
129
130            if py_uppercase {
131                py_protocol.extend(quote! {
132                    #[pyo3::pymethods]
133                    impl #ident {
134                        fn __repr__(&self) -> pyo3::PyResult<String> {
135                            Ok(format!(concat!(stringify!(#ident), " {{ ", #(stringify!(#fnames_maybe_upper), ": {:?}, ",)* "}}"), #(self.#fnames,)*))
136                        }
137                    }
138                });
139            } else {
140                py_protocol.extend(quote! {
141                    #[pyo3::pymethods]
142                    impl #ident {
143                        fn __repr__(&self) -> pyo3::PyResult<String> {
144                            Ok(format!("{self:?}"))
145                        }
146                    }
147                });
148            }
149
150            if !named.is_empty() {
151                py_protocol.extend(quote! {
152                    #[pyo3::pymethods]
153                    impl #ident {
154                        #[pyo3(signature = (**kwargs))]
155                        fn replace(&self, kwargs: Option<&pyo3::Bound<pyo3::types::PyDict>>) -> pyo3::PyResult<Self> {
156                            let mut ret = self.clone();
157                            if let Some(kwargs) = kwargs {
158                                use pyo3::prelude::PyDictMethods;
159                                let iter = kwargs.iter();
160                                for (field, value) in iter {
161                                    use pyo3::prelude::PyAnyMethods;
162                                    let field = field.extract::<String>()?;
163                                    match field.as_str() {
164                                        #(stringify!(#fnames_maybe_upper) => {
165                                            ret.#fnames = value.extract()?;
166                                        }),*
167                                        _ => { return Err(pyo3::exceptions::PyKeyError::new_err(format!("unknown field {field}"))); }
168                                    }
169                                }
170                            }
171                            Ok(ret)
172                        }
173                    }
174                });
175            }
176        }
177        syn::Fields::Unnamed(FieldsUnnamed { .. }) => {
178            py_protocol.extend(quote! {
179                #[pyo3::pymethods]
180                impl #ident {
181                    fn __repr__(&self) -> pyo3::PyResult<String> {
182                        Ok(format!("{self:?}"))
183                    }
184                }
185            });
186        }
187        syn::Fields::Unit => {
188            panic!("PyStreamable does not support the unit type");
189        }
190    }
191
192    py_protocol.extend(quote! {
193        #[pyo3::pymethods]
194        impl #ident {
195            #[classmethod]
196            #[pyo3(signature=(json_dict))]
197            pub fn from_json_dict(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, json_dict: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
198                use pyo3::prelude::PyAnyMethods;
199                use pyo3::IntoPyObjectExt;
200                use pyo3::Bound;
201                use pyo3::type_object::PyTypeInfo;
202                use std::borrow::Borrow;
203                let rust_obj = Bound::new(py, <Self as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(json_dict)?)?;
204
205                if rust_obj.is_exact_instance(&cls) {
206                    rust_obj.into_py_any(py)
207                } else {
208                    let rust_py = rust_obj.into_py_any(py)?;
209                    let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
210                    Ok(instance.into_any().unbind())
211                }
212            }
213
214            pub fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
215                #crate_name::to_json_dict::ToJsonDict::to_json_dict(self, py)
216            }
217        }
218    });
219
220    let streamable = quote! {
221        #[pyo3::pymethods]
222        impl #ident {
223            #[classmethod]
224            #[pyo3(name = "from_bytes")]
225            pub fn py_from_bytes(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
226                use pyo3::prelude::PyAnyMethods;
227                use pyo3::IntoPyObjectExt;
228                use pyo3::Bound;
229                use pyo3::type_object::PyTypeInfo;
230                use std::borrow::Borrow;
231
232                if !blob.is_c_contiguous() {
233                    panic!("from_bytes() must be called with a contiguous buffer");
234                }
235                let slice = unsafe {
236                    std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes())
237                };
238                let rust_obj = Bound::new(py, <Self as #crate_name::Streamable>::from_bytes(slice)?)?;
239
240                if rust_obj.is_exact_instance(&cls) {
241                    rust_obj.into_py_any(py)
242                } else {
243                    let rust_py = rust_obj.into_py_any(py)?;
244                    let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
245                    Ok(instance.into_any().unbind())
246                }
247            }
248
249            #[classmethod]
250            #[pyo3(name = "from_bytes_unchecked")]
251            pub fn py_from_bytes_unchecked(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
252                use pyo3::prelude::PyAnyMethods;
253                use pyo3::IntoPyObjectExt;
254                use pyo3::Bound;
255                use pyo3::type_object::PyTypeInfo;
256                use std::borrow::Borrow;
257                if !blob.is_c_contiguous() {
258                    panic!("from_bytes_unchecked() must be called with a contiguous buffer");
259                }
260                let slice = unsafe {
261                    std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes())
262                };
263                let rust_obj = Bound::new(py, <Self as #crate_name::Streamable>::from_bytes_unchecked(slice).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e))?)?;
264
265                if rust_obj.is_exact_instance(&cls) {
266                    rust_obj.into_py_any(py)
267                } else {
268                    let rust_py = rust_obj.into_py_any(py)?;
269                    let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
270                    Ok(instance.into_any().unbind())
271                }
272            }
273
274            // returns the type as well as the number of bytes read from the buffer
275            #[classmethod]
276            #[pyo3(signature= (blob, trusted=false))]
277            pub fn parse_rust<'p>(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, blob: pyo3::buffer::PyBuffer<u8>, trusted: bool) -> pyo3::PyResult<(pyo3::Py<pyo3::PyAny>, u32)> {
278                use pyo3::prelude::PyAnyMethods;
279                use pyo3::IntoPyObjectExt;
280                use pyo3::Bound;
281                use pyo3::type_object::PyTypeInfo;
282                use std::borrow::Borrow;
283                if !blob.is_c_contiguous() {
284                    panic!("parse_rust() must be called with a contiguous buffer");
285                }
286                let slice = unsafe {
287                    std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes())
288                };
289                let mut input = std::io::Cursor::<&[u8]>::new(slice);
290                let rust_obj = if trusted {
291                    <Self as #crate_name::Streamable>::parse::<true>(&mut input).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e)).map(|v| (v, input.position() as u32))
292                } else {
293                    <Self as #crate_name::Streamable>::parse::<false>(&mut input).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e)).map(|v| (v, input.position() as u32))
294                }?;
295
296                // Check if python class is different from rust class (in case of child classes)
297                // if so call the python class's conversion code
298
299                let rust_value = rust_obj.0;
300                let position = rust_obj.1;
301                let rust_bound = Bound::new(py, rust_value)?;
302
303                if rust_bound.is_exact_instance(&cls) {
304                    Ok((rust_bound.into_py_any(py)?, position))
305                } else {
306                    let rust_py = rust_bound.into_py_any(py)?;
307                    let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
308                    Ok((instance.into_any().unbind(), position))
309                }
310            }
311
312            pub fn get_hash<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyAny>> {
313                use pyo3::IntoPyObjectExt;
314                use pyo3::types::PyModule;
315                use pyo3::prelude::PyAnyMethods;
316                let mut ctx = chia_sha2::Sha256::new();
317                #crate_name::Streamable::update_digest(self, &mut ctx);
318                let bytes_module = PyModule::import(py, "chia_rs.sized_bytes")?;
319                let ty = bytes_module.getattr("bytes32")?;
320                let digest = ctx.finalize().into_py_any(py)?;
321                ty.call1((digest,))
322            }
323            #[pyo3(name = "to_bytes")]
324            pub fn py_to_bytes<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
325                let mut writer = Vec::<u8>::new();
326                #crate_name::Streamable::stream(self, &mut writer).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e))?;
327                Ok(pyo3::types::PyBytes::new(py, &writer))
328            }
329
330            pub fn stream_to_bytes<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
331                self.py_to_bytes(py)
332            }
333
334            pub fn __bytes__<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
335                self.py_to_bytes(py)
336            }
337
338            pub fn __deepcopy__<'p>(&self, memo: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
339                Ok(self.clone())
340            }
341
342            pub fn __copy__<'p>(&self) -> pyo3::PyResult<Self> {
343                Ok(self.clone())
344            }
345        }
346    };
347    py_protocol.extend(streamable);
348
349    if py_pickle {
350        let pickle = quote! {
351            #[pyo3::pymethods]
352            impl #ident {
353                pub fn __setstate__(
354                    &mut self,
355                    state: &pyo3::Bound<pyo3::types::PyBytes>,
356                ) -> pyo3::PyResult<()> {
357                    use chia_traits::Streamable;
358                    use pyo3::types::PyBytesMethods;
359
360                    *self = Self::parse::<true>(&mut std::io::Cursor::new(state.as_bytes()))?;
361
362                    Ok(())
363                }
364
365                pub fn __getstate__<'py>(
366                    &self,
367                    py: pyo3::Python<'py>,
368                ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::types::PyBytes>> {
369                    self.py_to_bytes(py)
370                }
371
372                pub fn __getnewargs__<'py>(&self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::types::PyTuple>> {
373                    let mut args = Vec::new();
374                    #( args.push(#crate_name::ChiaToPython::to_python(&self.#fnames, py)?); )*
375                    pyo3::types::PyTuple::new(py, args)
376                }
377            }
378        };
379        py_protocol.extend(pickle);
380    }
381
382    py_protocol.into()
383}
384
385#[proc_macro_derive(PyJsonDict, attributes(py_uppercase))]
386pub fn py_json_dict_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
387    let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
388
389    let crate_name = match found_crate {
390        FoundCrate::Itself => quote!(crate),
391        FoundCrate::Name(name) => {
392            let ident = Ident::new(&name, Span::call_site());
393            quote!(#ident)
394        }
395    };
396
397    let DeriveInput {
398        ident, data, attrs, ..
399    } = parse_macro_input!(input);
400
401    let mut py_uppercase = false;
402    for attr in &attrs {
403        if attr.path().is_ident("py_uppercase") {
404            py_uppercase = true;
405        }
406    }
407
408    let fields = match data {
409        syn::Data::Struct(s) => s.fields,
410        syn::Data::Enum(_) => {
411            return quote! {
412                impl #crate_name::to_json_dict::ToJsonDict for #ident {
413                    fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
414                        <u8 as #crate_name::to_json_dict::ToJsonDict>::to_json_dict(&(*self as u8), py)
415                    }
416                }
417
418                impl #crate_name::from_json_dict::FromJsonDict for #ident {
419                    fn from_json_dict(o: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
420                        let v = <u8 as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(o)?;
421                        <Self as #crate_name::Streamable>::parse::<false>(&mut std::io::Cursor::<&[u8]>::new(&[v])).map_err(|e| e.into())
422                    }
423                }
424            }
425            .into();
426        }
427        syn::Data::Union(_) => {
428            panic!("PyJsonDict only support struct");
429        }
430    };
431
432    let mut py_protocol = quote! {};
433
434    match fields {
435        syn::Fields::Named(FieldsNamed { named, .. }) => {
436            let mut fnames = Vec::<Ident>::new();
437            let mut ftypes = Vec::<syn::Type>::new();
438            for f in &named {
439                fnames.push(f.ident.as_ref().unwrap().clone());
440                ftypes.push(f.ty.clone());
441            }
442
443            let fnames_maybe_upper = maybe_upper_fields(py_uppercase, fnames.clone());
444
445            py_protocol.extend( quote! {
446
447                impl #crate_name::to_json_dict::ToJsonDict for #ident {
448                    fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
449                        use pyo3::prelude::PyDictMethods;
450                        let ret = pyo3::types::PyDict::new(py);
451                        #(ret.set_item(stringify!(#fnames_maybe_upper), self.#fnames.to_json_dict(py)?)?);*;
452                        Ok(ret.into_any().unbind())
453                    }
454                }
455
456                impl #crate_name::from_json_dict::FromJsonDict for #ident {
457                    fn from_json_dict(o: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
458                        use pyo3::prelude::PyAnyMethods;
459                        Ok(Self{
460                            #(#fnames: <#ftypes as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(&o.get_item(stringify!(#fnames_maybe_upper))?)?,)*
461                        })
462                    }
463                }
464            });
465        }
466        syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
467            let ftype: syn::Type = unnamed
468                .first()
469                .expect("match arm if requires 1 item")
470                .ty
471                .clone();
472
473            py_protocol.extend( quote! {
474
475                impl #crate_name::to_json_dict::ToJsonDict for #ident {
476                    fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
477                        self.0.to_json_dict(py)
478                    }
479                }
480
481                impl #crate_name::from_json_dict::FromJsonDict for #ident {
482                    fn from_json_dict(o: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
483                        Ok(Self(
484                            <#ftype as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(&o)?
485                        ))
486                    }
487                }
488            });
489        }
490        _ => {
491            panic!("PyJsonDict only supports named structs and single field unnamed structs");
492        }
493    }
494
495    py_protocol.into()
496}
497
498#[proc_macro_derive(PyGetters, attributes(py_uppercase))]
499pub fn py_getters_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
500    let DeriveInput {
501        ident, data, attrs, ..
502    } = parse_macro_input!(input);
503
504    let mut py_uppercase = false;
505    for attr in &attrs {
506        if attr.path().is_ident("py_uppercase") {
507            py_uppercase = true;
508        }
509    }
510
511    let syn::Data::Struct(s) = data else {
512        panic!("python binding only support struct");
513    };
514
515    let syn::Fields::Named(FieldsNamed { named, .. }) = s.fields else {
516        panic!("python binding only support struct");
517    };
518
519    let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
520
521    let crate_name = match found_crate {
522        FoundCrate::Itself => quote!(crate),
523        FoundCrate::Name(name) => {
524            let ident = Ident::new(&name, Span::call_site());
525            quote!(#ident)
526        }
527    };
528
529    let mut fnames = Vec::<Ident>::new();
530    let mut ftypes = Vec::<syn::Type>::new();
531    for f in named {
532        fnames.push(f.ident.unwrap());
533        ftypes.push(f.ty);
534    }
535
536    let fnames_maybe_upper = maybe_upper_fields(py_uppercase, fnames.clone());
537
538    let ret = quote! {
539        #[pyo3::pymethods]
540        impl #ident {
541            #(
542            #[getter]
543            fn #fnames_maybe_upper<'a> (&self, py: pyo3::Python<'a>) -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
544                #crate_name::ChiaToPython::to_python(&self.#fnames, py)
545            }
546            )*
547        }
548    };
549
550    ret.into()
551}