1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use quote::ToTokens;
extern crate proc_macro;
/// Attribute macro that exposes a prost-generated message struct to Python through pyo3.
///
/// The attribute's own arguments are ignored. The annotated item is converted to
/// `proc_macro2` tokens, expanded by `pyclass_for_prost_struct_impl`, and the result is
/// converted back into compiler tokens.
#[proc_macro_attribute]
pub fn pyclass_for_prost_struct(
    _args: proc_macro::TokenStream,
    input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    pyclass_for_prost_struct_impl(input.into()).into()
}
/// Expands `#[pyclass_for_prost_struct]` on `proc_macro2` tokens, so the expansion can be
/// unit-tested without a compiler-provided `proc_macro` context.
///
/// If `input` parses as a struct, the output is:
/// * the struct itself, annotated with `#[::pyo3::prelude::pyclass]`, with
///   `#[pyo3(get, set)]` added to every named field that is not a prost `oneof` field;
/// * the `#[pymethods]` / `#[pyproto]` impls produced by `python_impls`.
///
/// Any other item (an enum, a function, ...) is returned unchanged, so the attribute can be
/// applied to every generated type without breaking the non-struct ones.
fn pyclass_for_prost_struct_impl(input: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    let mut struct_ = match syn::parse2::<syn::ItemStruct>(input.clone()) {
        Ok(struct_) => struct_,
        // Not a struct: pass the item through untouched rather than emitting an error.
        Err(_) => return input,
    };

    struct_
        .attrs
        .push(syn::parse_quote! {#[::pyo3::prelude::pyclass]});

    // Tuple and unit structs have no named fields and therefore get no accessors.
    if let syn::Fields::Named(fields_named) = &mut struct_.fields {
        for field in fields_named.named.iter_mut() {
            if !is_prost_oneof_field(field) {
                field.attrs.push(syn::parse_quote! {
                    #[pyo3(get, set)]
                });
            }
        }
    }

    // Build the impls first: `into_token_stream` consumes the struct.
    let impls = python_impls(&struct_.ident);
    let mut output = struct_.into_token_stream();
    output.extend(impls);
    output
}

/// Returns `true` when `field` has a `#[prost(...)]` attribute containing `oneof = "..."`.
///
/// Such fields are not given pyo3 accessors. Presumably this is because their type is the
/// prost-generated oneof enum, which this macro passes through without making it a
/// pyclass, so pyo3 could not convert it (confirm before relaxing this).
fn is_prost_oneof_field(field: &syn::Field) -> bool {
    field
        .attrs
        .iter()
        .filter(|attr| attr.path.is_ident("prost"))
        // Only the list form `prost(...)` can carry a `oneof = "..."` entry.
        .filter_map(|attr| match attr.parse_meta() {
            Ok(syn::Meta::List(list)) => Some(list),
            _ => None,
        })
        .any(|list| {
            list.nested.iter().any(|nested| {
                matches!(
                    nested,
                    syn::NestedMeta::Meta(syn::Meta::NameValue(nv)) if nv.path.is_ident("oneof")
                )
            })
        })
}

/// Generates the pyo3 impls that expose prost encoding/decoding of `struct_name` to Python.
///
/// Python-visible API of the generated type:
/// * `Type()` -- a `Default::default()` instance;
/// * `Type.decode(bytes)` (static) -- a new message decoded from `bytes`;
/// * `obj.decode_merge(bytes)` -- merges `bytes` into `obj` in place;
/// * `obj.encode()` -- serializes `obj` to `bytes`;
/// * `obj.clear()` -- resets every field to its default;
/// * `repr(obj)` / `str(obj)` -- compact / pretty-printed `Debug` output.
///
/// prost encode/decode errors are raised in Python as `RuntimeError`.
fn python_impls(struct_name: &syn::Ident) -> proc_macro2::TokenStream {
    quote::quote! {
        #[::pyo3::prelude::pymethods]
        impl #struct_name {
            #[new]
            pub fn new() -> Self {
                Self::default()
            }

            #[staticmethod]
            #[pyo3(name = "decode")]
            pub fn decode_py(bytes: &::pyo3::types::PyBytes) -> ::pyo3::PyResult<Self> {
                let bytes: &[u8] = ::pyo3::FromPyObject::extract(bytes)?;
                <Self as ::prost::Message>::decode(bytes).map_err(|e| {
                    ::pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                })
            }

            // pyo3 recognises `Python` arguments by type and supplies them itself, so `_py`
            // is not part of the Python signature; the underscore marks it as unused.
            pub fn decode_merge(slf: &::pyo3::pycell::PyCell<#struct_name>, _py: ::pyo3::Python, bytes: &::pyo3::types::PyBytes) -> ::pyo3::PyResult<()> {
                let bytes: &[u8] = ::pyo3::FromPyObject::extract(bytes)?;
                // Keep the mutable borrow of the cell scoped to the merge itself.
                {
                    let mut obj_mut = slf.borrow_mut();
                    <Self as ::prost::Message>::merge(::core::ops::DerefMut::deref_mut(&mut obj_mut), bytes).map_err(|e| {
                        ::pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                    })?;
                }
                Ok(())
            }

            #[pyo3(name = "encode")]
            pub fn encode_py<'a>(&self, py: ::pyo3::Python<'a>) -> ::pyo3::PyResult<&'a ::pyo3::types::PyBytes> {
                // The Python buffer is allocated with exactly `encoded_len` bytes and the
                // message is encoded straight into it.
                ::pyo3::types::PyBytes::new_with(py, ::prost::Message::encoded_len(self), |mut py_buf: &mut [u8]| {
                    ::prost::Message::encode(self, &mut py_buf).map_err(|e| {
                        ::pyo3::exceptions::PyRuntimeError::new_err(e.to_string())
                    })?;
                    Ok(())
                })
            }

            pub fn clear(&mut self) {
                *self = Default::default();
            }
        }

        #[::pyo3::prelude::pyproto]
        impl ::pyo3::class::basic::PyObjectProtocol for #struct_name {
            fn __repr__(&self) -> ::pyo3::PyResult<String> {
                Ok(format!("{:?}", self))
            }
            fn __str__(&self) -> ::pyo3::PyResult<String> {
                Ok(format!("{:#?}", self))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    /// Runs the macro implementation on `src` and parses the expansion back into items,
    /// so assertions inspect structure rather than token spacing.
    fn expand(src: &str) -> syn::File {
        let input = proc_macro2::TokenStream::from_str(src).expect("test input must tokenize");
        syn::parse2(super::pyclass_for_prost_struct_impl(input))
            .expect("expansion must parse as Rust items")
    }

    /// Returns the expanded struct, which the macro emits as the first item.
    fn expanded_struct(file: &syn::File) -> &syn::ItemStruct {
        match file.items.first() {
            Some(syn::Item::Struct(struct_)) => struct_,
            _ => panic!("expected the expansion to start with a struct"),
        }
    }

    /// Whether `attrs` contains an attribute whose path ends in `name`.
    fn has_attr(attrs: &[syn::Attribute], name: &str) -> bool {
        attrs
            .iter()
            .any(|attr| attr.path.segments.last().map_or(false, |seg| seg.ident == name))
    }

    /// `(field name, has a #[pyo3(...)] attribute)` for each field of `struct_`, in order.
    fn field_accessors(struct_: &syn::ItemStruct) -> Vec<(String, bool)> {
        struct_
            .fields
            .iter()
            .map(|field| {
                let name = field.ident.as_ref().map(ToString::to_string).unwrap_or_default();
                (name, has_attr(&field.attrs, "pyo3"))
            })
            .collect()
    }

    #[test]
    fn struct_becomes_pyclass_with_field_accessors() {
        let file = expand(
            "#[derive(Clone, PartialEq, ::prost::Message)]\npub struct MarginUpdate { a: i32, pub b:String,}",
        );
        let struct_ = expanded_struct(&file);
        assert!(has_attr(&struct_.attrs, "derive"));
        assert!(has_attr(&struct_.attrs, "pyclass"));
        assert_eq!(
            field_accessors(struct_),
            vec![("a".to_string(), true), ("b".to_string(), true)]
        );

        // The struct is followed by the `#[pymethods]` impl and then the `#[pyproto]` impl.
        let impls: Vec<&syn::ItemImpl> = file
            .items
            .iter()
            .filter_map(|item| match item {
                syn::Item::Impl(item_impl) => Some(item_impl),
                _ => None,
            })
            .collect();
        assert_eq!(impls.len(), 2);
        assert!(has_attr(&impls[0].attrs, "pymethods"));
        assert!(has_attr(&impls[1].attrs, "pyproto"));
    }

    #[test]
    fn oneof_fields_get_no_accessors() {
        let file = expand(
            "#[derive(Clone, PartialEq, ::prost::Message)]\n\
             pub struct Order {\n\
             #[prost(int32, tag = \"1\")] pub id: i32,\n\
             #[prost(oneof = \"order::Kind\", tags = \"2, 3\")] pub kind: Option<order::Kind>,\n\
             }",
        );
        assert_eq!(
            field_accessors(expanded_struct(&file)),
            vec![("id".to_string(), true), ("kind".to_string(), false)]
        );
    }

    #[test]
    fn non_struct_items_pass_through_unchanged() {
        let src = "#[derive(Clone, Copy, Debug)] pub enum Kind { A = 0, B = 1 }";
        let input = proc_macro2::TokenStream::from_str(src).expect("test input must tokenize");
        let output = super::pyclass_for_prost_struct_impl(input.clone());
        assert_eq!(output.to_string(), input.to_string());
    }
}