chia_py_streamable_macro/
lib.rs1#![allow(clippy::missing_panics_doc)]
2
3use proc_macro_crate::{FoundCrate, crate_name};
4use proc_macro2::{Ident, Span};
5use quote::quote;
6use syn::{DeriveInput, FieldsNamed, FieldsUnnamed, parse_macro_input};
7
8fn maybe_upper_fields(py_uppercase: bool, fnames: Vec<Ident>) -> Vec<Ident> {
9 if py_uppercase {
10 fnames
11 .into_iter()
12 .map(|f| Ident::new(&f.to_string().to_uppercase(), Span::call_site()))
13 .collect()
14 } else {
15 fnames
16 }
17}
18
19#[proc_macro_derive(PyStreamable, attributes(py_uppercase, py_pickle))]
20pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
21 let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
22
23 let crate_name = match found_crate {
24 FoundCrate::Itself => quote!(crate),
25 FoundCrate::Name(name) => {
26 let ident = Ident::new(&name, Span::call_site());
27 quote!(#ident)
28 }
29 };
30
31 let DeriveInput {
32 ident, data, attrs, ..
33 } = parse_macro_input!(input);
34
35 let mut py_uppercase = false;
36 let mut py_pickle = false;
37 for attr in &attrs {
38 if attr.path().is_ident("py_uppercase") {
39 py_uppercase = true;
40 } else if attr.path().is_ident("py_pickle") {
41 py_pickle = true;
42 }
43 }
44
45 let fields = match data {
46 syn::Data::Struct(s) => s.fields,
47 syn::Data::Enum(_) => {
48 return quote! {
49 impl<'a, 'py> pyo3::conversion::FromPyObject<'a, 'py> for #ident {
50 type Error = pyo3::PyErr;
51
52 fn extract(obj: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
53 use pyo3::types::PyAnyMethods;
54 let v: u8 = obj.extract()?;
55 <Self as #crate_name::Streamable>::parse::<false>(&mut std::io::Cursor::<&[u8]>::new(&[v])).map_err(|e| e.into())
56 }
57 }
58
59 impl<'py> pyo3::conversion::IntoPyObject<'py> for #ident {
60 type Target = pyo3::PyAny;
61 type Output = pyo3::Bound<'py, Self::Target>;
62 type Error = std::convert::Infallible;
63
64 fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
65 Ok(pyo3::IntoPyObject::into_pyobject(self as u8, py)?
66 .clone()
67 .into_any())
68 }
69 }
70 }
71 .into();
72 }
73 syn::Data::Union(_) => {
74 panic!("Streamable only support struct");
75 }
76 };
77
78 let mut py_protocol = quote! {
79 #[pyo3::pymethods]
80 impl #ident {
81 fn __richcmp__(&self, other: pyo3::PyRef<Self>, op: pyo3::class::basic::CompareOp) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
82 use pyo3::class::basic::CompareOp;
83 use pyo3::IntoPyObjectExt;
84 let py = other.py();
85 match op {
86 CompareOp::Eq => (self == &*other).into_py_any(py),
87 CompareOp::Ne => (self != &*other).into_py_any(py),
88 _ => Ok(py.NotImplemented()),
89 }
90 }
91
92 fn __hash__(&self) -> pyo3::PyResult<isize> {
93 let mut hasher = std::collections::hash_map::DefaultHasher::new();
94 std::hash::Hash::hash(self, &mut hasher);
95 Ok(std::hash::Hasher::finish(&hasher) as isize)
96 }
97 }
98
99 impl #crate_name::ChiaToPython for #ident {
100 fn to_python<'a>(&self, py: pyo3::Python<'a>) -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
101 Ok(pyo3::Py::new(py, self.clone())?.into_bound(py).into_any())
102 }
103 }
104 };
105
106 let mut fnames = Vec::<Ident>::new();
107 let mut ftypes = Vec::<syn::Type>::new();
108
109 match fields {
110 syn::Fields::Named(FieldsNamed { named, .. }) => {
111 for f in &named {
112 fnames.push(f.ident.as_ref().unwrap().clone());
113 ftypes.push(f.ty.clone());
114 }
115
116 let fnames_maybe_upper = maybe_upper_fields(py_uppercase, fnames.clone());
117
118 py_protocol.extend(quote! {
119 #[pyo3::pymethods]
120 impl #ident {
121 #[allow(too_many_arguments)]
122 #[new]
123 #[pyo3(signature = (#(#fnames_maybe_upper),*))]
124 pub fn py_new ( #(#fnames_maybe_upper : #ftypes),* ) -> Self {
125 Self { #(#fnames: #fnames_maybe_upper),* }
126 }
127 }
128 });
129
130 if py_uppercase {
131 py_protocol.extend(quote! {
132 #[pyo3::pymethods]
133 impl #ident {
134 fn __repr__(&self) -> pyo3::PyResult<String> {
135 Ok(format!(concat!(stringify!(#ident), " {{ ", #(stringify!(#fnames_maybe_upper), ": {:?}, ",)* "}}"), #(self.#fnames,)*))
136 }
137 }
138 });
139 } else {
140 py_protocol.extend(quote! {
141 #[pyo3::pymethods]
142 impl #ident {
143 fn __repr__(&self) -> pyo3::PyResult<String> {
144 Ok(format!("{self:?}"))
145 }
146 }
147 });
148 }
149
150 if !named.is_empty() {
151 py_protocol.extend(quote! {
152 #[pyo3::pymethods]
153 impl #ident {
154 #[pyo3(signature = (**kwargs))]
155 fn replace(&self, kwargs: Option<&pyo3::Bound<pyo3::types::PyDict>>) -> pyo3::PyResult<Self> {
156 let mut ret = self.clone();
157 if let Some(kwargs) = kwargs {
158 use pyo3::prelude::PyDictMethods;
159 let iter = kwargs.iter();
160 for (field, value) in iter {
161 use pyo3::prelude::PyAnyMethods;
162 let field = field.extract::<String>()?;
163 match field.as_str() {
164 #(stringify!(#fnames_maybe_upper) => {
165 ret.#fnames = value.extract()?;
166 }),*
167 _ => { return Err(pyo3::exceptions::PyKeyError::new_err(format!("unknown field {field}"))); }
168 }
169 }
170 }
171 Ok(ret)
172 }
173 }
174 });
175 }
176 }
177 syn::Fields::Unnamed(FieldsUnnamed { .. }) => {
178 py_protocol.extend(quote! {
179 #[pyo3::pymethods]
180 impl #ident {
181 fn __repr__(&self) -> pyo3::PyResult<String> {
182 Ok(format!("{self:?}"))
183 }
184 }
185 });
186 }
187 syn::Fields::Unit => {
188 panic!("PyStreamable does not support the unit type");
189 }
190 }
191
192 py_protocol.extend(quote! {
193 #[pyo3::pymethods]
194 impl #ident {
195 #[classmethod]
196 #[pyo3(signature=(json_dict))]
197 pub fn from_json_dict(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, json_dict: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
198 use pyo3::prelude::PyAnyMethods;
199 use pyo3::IntoPyObjectExt;
200 use pyo3::Bound;
201 use pyo3::type_object::PyTypeInfo;
202 use std::borrow::Borrow;
203 let rust_obj = Bound::new(py, <Self as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(json_dict)?)?;
204
205 if rust_obj.is_exact_instance(&cls) {
206 rust_obj.into_py_any(py)
207 } else {
208 let rust_py = rust_obj.into_py_any(py)?;
209 let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
210 Ok(instance.into_any().unbind())
211 }
212 }
213
214 pub fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
215 #crate_name::to_json_dict::ToJsonDict::to_json_dict(self, py)
216 }
217 }
218 });
219
220 let streamable = quote! {
221 #[pyo3::pymethods]
222 impl #ident {
223 #[classmethod]
224 #[pyo3(name = "from_bytes")]
225 pub fn py_from_bytes(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
226 use pyo3::prelude::PyAnyMethods;
227 use pyo3::IntoPyObjectExt;
228 use pyo3::Bound;
229 use pyo3::type_object::PyTypeInfo;
230 use std::borrow::Borrow;
231
232 if !blob.is_c_contiguous() {
233 panic!("from_bytes() must be called with a contiguous buffer");
234 }
235 let slice = unsafe {
236 std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes())
237 };
238 let rust_obj = Bound::new(py, <Self as #crate_name::Streamable>::from_bytes(slice)?)?;
239
240 if rust_obj.is_exact_instance(&cls) {
241 rust_obj.into_py_any(py)
242 } else {
243 let rust_py = rust_obj.into_py_any(py)?;
244 let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
245 Ok(instance.into_any().unbind())
246 }
247 }
248
249 #[classmethod]
250 #[pyo3(name = "from_bytes_unchecked")]
251 pub fn py_from_bytes_unchecked(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, blob: pyo3::buffer::PyBuffer<u8>) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
252 use pyo3::prelude::PyAnyMethods;
253 use pyo3::IntoPyObjectExt;
254 use pyo3::Bound;
255 use pyo3::type_object::PyTypeInfo;
256 use std::borrow::Borrow;
257 if !blob.is_c_contiguous() {
258 panic!("from_bytes_unchecked() must be called with a contiguous buffer");
259 }
260 let slice = unsafe {
261 std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes())
262 };
263 let rust_obj = Bound::new(py, <Self as #crate_name::Streamable>::from_bytes_unchecked(slice).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e))?)?;
264
265 if rust_obj.is_exact_instance(&cls) {
266 rust_obj.into_py_any(py)
267 } else {
268 let rust_py = rust_obj.into_py_any(py)?;
269 let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
270 Ok(instance.into_any().unbind())
271 }
272 }
273
274 #[classmethod]
276 #[pyo3(signature= (blob, trusted=false))]
277 pub fn parse_rust<'p>(cls: &pyo3::Bound<'_, pyo3::types::PyType>, py: pyo3::Python<'_>, blob: pyo3::buffer::PyBuffer<u8>, trusted: bool) -> pyo3::PyResult<(pyo3::Py<pyo3::PyAny>, u32)> {
278 use pyo3::prelude::PyAnyMethods;
279 use pyo3::IntoPyObjectExt;
280 use pyo3::Bound;
281 use pyo3::type_object::PyTypeInfo;
282 use std::borrow::Borrow;
283 if !blob.is_c_contiguous() {
284 panic!("parse_rust() must be called with a contiguous buffer");
285 }
286 let slice = unsafe {
287 std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes())
288 };
289 let mut input = std::io::Cursor::<&[u8]>::new(slice);
290 let rust_obj = if trusted {
291 <Self as #crate_name::Streamable>::parse::<true>(&mut input).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e)).map(|v| (v, input.position() as u32))
292 } else {
293 <Self as #crate_name::Streamable>::parse::<false>(&mut input).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e)).map(|v| (v, input.position() as u32))
294 }?;
295
296 let rust_value = rust_obj.0;
300 let position = rust_obj.1;
301 let rust_bound = Bound::new(py, rust_value)?;
302
303 if rust_bound.is_exact_instance(&cls) {
304 Ok((rust_bound.into_py_any(py)?, position))
305 } else {
306 let rust_py = rust_bound.into_py_any(py)?;
307 let instance = cls.call_method1("from_parent", (rust_py.clone_ref(py),))?;
308 Ok((instance.into_any().unbind(), position))
309 }
310 }
311
312 pub fn get_hash<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyAny>> {
313 use pyo3::IntoPyObjectExt;
314 use pyo3::types::PyModule;
315 use pyo3::prelude::PyAnyMethods;
316 let mut ctx = chia_sha2::Sha256::new();
317 #crate_name::Streamable::update_digest(self, &mut ctx);
318 let bytes_module = PyModule::import(py, "chia_rs.sized_bytes")?;
319 let ty = bytes_module.getattr("bytes32")?;
320 let digest = ctx.finalize().into_py_any(py)?;
321 ty.call1((digest,))
322 }
323 #[pyo3(name = "to_bytes")]
324 pub fn py_to_bytes<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
325 let mut writer = Vec::<u8>::new();
326 #crate_name::Streamable::stream(self, &mut writer).map_err(|e| <#crate_name::chia_error::Error as Into<pyo3::PyErr>>::into(e))?;
327 Ok(pyo3::types::PyBytes::new(py, &writer))
328 }
329
330 pub fn stream_to_bytes<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
331 self.py_to_bytes(py)
332 }
333
334 pub fn __bytes__<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
335 self.py_to_bytes(py)
336 }
337
338 pub fn __deepcopy__<'p>(&self, memo: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
339 Ok(self.clone())
340 }
341
342 pub fn __copy__<'p>(&self) -> pyo3::PyResult<Self> {
343 Ok(self.clone())
344 }
345 }
346 };
347 py_protocol.extend(streamable);
348
349 if py_pickle {
350 let pickle = quote! {
351 #[pyo3::pymethods]
352 impl #ident {
353 pub fn __setstate__(
354 &mut self,
355 state: &pyo3::Bound<pyo3::types::PyBytes>,
356 ) -> pyo3::PyResult<()> {
357 use chia_traits::Streamable;
358 use pyo3::types::PyBytesMethods;
359
360 *self = Self::parse::<true>(&mut std::io::Cursor::new(state.as_bytes()))?;
361
362 Ok(())
363 }
364
365 pub fn __getstate__<'py>(
366 &self,
367 py: pyo3::Python<'py>,
368 ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::types::PyBytes>> {
369 self.py_to_bytes(py)
370 }
371
372 pub fn __getnewargs__<'py>(&self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::types::PyTuple>> {
373 let mut args = Vec::new();
374 #( args.push(#crate_name::ChiaToPython::to_python(&self.#fnames, py)?); )*
375 pyo3::types::PyTuple::new(py, args)
376 }
377 }
378 };
379 py_protocol.extend(pickle);
380 }
381
382 py_protocol.into()
383}
384
385#[proc_macro_derive(PyJsonDict, attributes(py_uppercase))]
386pub fn py_json_dict_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
387 let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
388
389 let crate_name = match found_crate {
390 FoundCrate::Itself => quote!(crate),
391 FoundCrate::Name(name) => {
392 let ident = Ident::new(&name, Span::call_site());
393 quote!(#ident)
394 }
395 };
396
397 let DeriveInput {
398 ident, data, attrs, ..
399 } = parse_macro_input!(input);
400
401 let mut py_uppercase = false;
402 for attr in &attrs {
403 if attr.path().is_ident("py_uppercase") {
404 py_uppercase = true;
405 }
406 }
407
408 let fields = match data {
409 syn::Data::Struct(s) => s.fields,
410 syn::Data::Enum(_) => {
411 return quote! {
412 impl #crate_name::to_json_dict::ToJsonDict for #ident {
413 fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
414 <u8 as #crate_name::to_json_dict::ToJsonDict>::to_json_dict(&(*self as u8), py)
415 }
416 }
417
418 impl #crate_name::from_json_dict::FromJsonDict for #ident {
419 fn from_json_dict(o: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
420 let v = <u8 as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(o)?;
421 <Self as #crate_name::Streamable>::parse::<false>(&mut std::io::Cursor::<&[u8]>::new(&[v])).map_err(|e| e.into())
422 }
423 }
424 }
425 .into();
426 }
427 syn::Data::Union(_) => {
428 panic!("PyJsonDict only support struct");
429 }
430 };
431
432 let mut py_protocol = quote! {};
433
434 match fields {
435 syn::Fields::Named(FieldsNamed { named, .. }) => {
436 let mut fnames = Vec::<Ident>::new();
437 let mut ftypes = Vec::<syn::Type>::new();
438 for f in &named {
439 fnames.push(f.ident.as_ref().unwrap().clone());
440 ftypes.push(f.ty.clone());
441 }
442
443 let fnames_maybe_upper = maybe_upper_fields(py_uppercase, fnames.clone());
444
445 py_protocol.extend( quote! {
446
447 impl #crate_name::to_json_dict::ToJsonDict for #ident {
448 fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
449 use pyo3::prelude::PyDictMethods;
450 let ret = pyo3::types::PyDict::new(py);
451 #(ret.set_item(stringify!(#fnames_maybe_upper), self.#fnames.to_json_dict(py)?)?);*;
452 Ok(ret.into_any().unbind())
453 }
454 }
455
456 impl #crate_name::from_json_dict::FromJsonDict for #ident {
457 fn from_json_dict(o: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
458 use pyo3::prelude::PyAnyMethods;
459 Ok(Self{
460 #(#fnames: <#ftypes as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(&o.get_item(stringify!(#fnames_maybe_upper))?)?,)*
461 })
462 }
463 }
464 });
465 }
466 syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
467 let ftype: syn::Type = unnamed
468 .first()
469 .expect("match arm if requires 1 item")
470 .ty
471 .clone();
472
473 py_protocol.extend( quote! {
474
475 impl #crate_name::to_json_dict::ToJsonDict for #ident {
476 fn to_json_dict(&self, py: pyo3::Python) -> pyo3::PyResult<pyo3::Py<pyo3::PyAny>> {
477 self.0.to_json_dict(py)
478 }
479 }
480
481 impl #crate_name::from_json_dict::FromJsonDict for #ident {
482 fn from_json_dict(o: &pyo3::Bound<pyo3::PyAny>) -> pyo3::PyResult<Self> {
483 Ok(Self(
484 <#ftype as #crate_name::from_json_dict::FromJsonDict>::from_json_dict(&o)?
485 ))
486 }
487 }
488 });
489 }
490 _ => {
491 panic!("PyJsonDict only supports named structs and single field unnamed structs");
492 }
493 }
494
495 py_protocol.into()
496}
497
498#[proc_macro_derive(PyGetters, attributes(py_uppercase))]
499pub fn py_getters_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
500 let DeriveInput {
501 ident, data, attrs, ..
502 } = parse_macro_input!(input);
503
504 let mut py_uppercase = false;
505 for attr in &attrs {
506 if attr.path().is_ident("py_uppercase") {
507 py_uppercase = true;
508 }
509 }
510
511 let syn::Data::Struct(s) = data else {
512 panic!("python binding only support struct");
513 };
514
515 let syn::Fields::Named(FieldsNamed { named, .. }) = s.fields else {
516 panic!("python binding only support struct");
517 };
518
519 let found_crate = crate_name("chia-traits").expect("chia-traits is present in `Cargo.toml`");
520
521 let crate_name = match found_crate {
522 FoundCrate::Itself => quote!(crate),
523 FoundCrate::Name(name) => {
524 let ident = Ident::new(&name, Span::call_site());
525 quote!(#ident)
526 }
527 };
528
529 let mut fnames = Vec::<Ident>::new();
530 let mut ftypes = Vec::<syn::Type>::new();
531 for f in named {
532 fnames.push(f.ident.unwrap());
533 ftypes.push(f.ty);
534 }
535
536 let fnames_maybe_upper = maybe_upper_fields(py_uppercase, fnames.clone());
537
538 let ret = quote! {
539 #[pyo3::pymethods]
540 impl #ident {
541 #(
542 #[getter]
543 fn #fnames_maybe_upper<'a> (&self, py: pyo3::Python<'a>) -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
544 #crate_name::ChiaToPython::to_python(&self.#fnames, py)
545 }
546 )*
547 }
548 };
549
550 ret.into()
551}