1use syn::parse_macro_input;
4
5use proc_macro::{self, TokenStream};
6use quote::quote;
7#[macro_use]
8extern crate lazy_static;
9
10#[proc_macro_attribute]
12pub fn struct_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
13 let ast = parse_macro_input!(input as syn::ItemStruct);
14 let expanded = macro_backend::build_struct(&ast.ident, &ast.attrs)
15 .unwrap_or_else(|e| e.to_compile_error());
16 quote!(#ast
17 #expanded
18 )
19 .into()
20}
21
22#[proc_macro_attribute]
24pub fn enum_to_struct_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
25 let ast = parse_macro_input!(input as syn::ItemEnum);
26 let expanded = macro_backend::build_struct(&ast.ident, &ast.attrs)
27 .unwrap_or_else(|e| e.to_compile_error());
28 quote!(#ast
29 #expanded
30 )
31 .into()
32}
33
34#[proc_macro_attribute]
38pub fn impl_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
39 let mut ast = parse_macro_input!(input as syn::ItemImpl);
40 let expanded = macro_backend::build_methods(&mut ast).unwrap_or_else(|e| e.to_compile_error());
41 quote!(#ast
42 #[allow(clippy::needless_question_mark)] #expanded
43 )
44 .into()
45}
46
47#[proc_macro_attribute]
49pub fn fn_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
50 let mut ast = parse_macro_input!(input as syn::ItemFn);
51 let expanded = macro_backend::build_fn(&mut ast).unwrap_or_else(|e| e.to_compile_error());
52 quote!(#ast
53 #expanded
54 )
55 .into()
56}
57
58mod macro_backend {
59 use std::collections::HashSet;
60
61 use proc_macro2::{Ident, TokenStream};
62 use quote::{quote, ToTokens};
63 use syn::{punctuated::Punctuated, token::Comma, FnArg, PatType, ReturnType};
64
65 lazy_static! {
66 static ref TYPES_TO_WRAP: HashSet<&'static str> = {
67 HashSet::from_iter(vec![
68 "Node",
69 "Graph",
70 "Context",
71 "ScalarType",
72 "Type",
73 "SliceElement",
74 "TypedValue",
75 "Value",
76 "CustomOperation",
77 "JoinType",
78 "ShardConfig",
79 ])
80 };
81 }
82
83 pub fn build_methods(ast: &mut syn::ItemImpl) -> syn::Result<TokenStream> {
84 impl_methods(&ast.self_ty, &mut ast.items)
85 }
86
87 pub fn build_fn(ast: &mut syn::ItemFn) -> syn::Result<TokenStream> {
88 let token_stream = gen_wrapper_method(&mut ast.sig, None)?;
89 let attrs = gen_attributes(&ast.attrs);
90 let name = ast.sig.ident.to_string();
91 Ok(quote!(#(#attrs)*
92 #[pyo3::pyfunction]
93 #[pyo3(name = #name)]
94 #token_stream))
95 }
96
97 pub fn build_struct(
98 t: &syn::Ident,
99 struct_attrs: &[syn::Attribute],
100 ) -> syn::Result<TokenStream> {
101 let nt = get_wrapper_ident(t);
102 let name = format!("{}", t);
103 let attrs = gen_attributes(struct_attrs);
104 Ok(quote!(
105 #(#attrs)*
106 #[pyo3::pyclass(name = #name)]
107 pub struct #nt {
108 pub inner: #t,
109 }
110 ))
111 }
112
113 fn impl_methods(ty: &syn::Type, impls: &mut [syn::ImplItem]) -> syn::Result<TokenStream> {
114 let mut methods = Vec::new();
115
116 for iimpl in impls.iter_mut() {
117 if let syn::ImplItem::Method(meth) = iimpl {
118 let token_stream = gen_wrapper_method(&mut meth.sig, Some(ty))?;
119 let attrs = gen_attributes(&meth.attrs);
120 methods.push(quote!(#(#attrs)* #token_stream));
121 }
122 }
123 let nt = get_wrapper_type_ident(ty, false);
124 Ok(quote! {
125 #[pyo3::pymethods]
126 impl #nt {
127 #(#methods)*
128 fn __str__(&self) -> String {
129 format!("{}", self.inner)
130 }
131 fn __repr__(&self) -> String {
132 self.__str__()
133 }
134 }
135 })
136 }
137
138 fn in_types_to_wrap(tt: &syn::Ident) -> bool {
139 let ii = format!("{}", tt);
140 TYPES_TO_WRAP.contains(ii.as_str())
141 }
142
143 fn get_wrapper_ident(tt: &syn::Ident) -> syn::Ident {
144 let prefix = if in_types_to_wrap(tt) {
145 "PyBinding"
146 } else {
147 ""
148 };
149 Ident::new(format!("{}{}", prefix, tt).as_str(), tt.span())
150 }
151
152 fn get_wrapper_type_ident(ty: &syn::Type, add_ref: bool) -> TokenStream {
153 match get_last_path_segment_from_type_path(ty) {
154 Some(s) => {
155 let ident = get_wrapper_ident(&s.ident);
156 if add_ref && in_types_to_wrap(&s.ident) {
157 quote!(&#ident)
158 } else {
159 ident.to_token_stream()
160 }
161 }
162 None => ty.to_token_stream(),
163 }
164 }
165
166 fn gen_wrapper_method(
167 sig: &mut syn::Signature,
168 class: Option<&syn::Type>,
169 ) -> syn::Result<TokenStream> {
170 let name = &sig.ident;
171 if let Some(ts) = check_in_allowlist(format!("{}", name)) {
172 return Ok(ts);
173 }
174 let input = Input::new(&sig.inputs);
175 let inner_inputs = input.get_inner_inputs();
176 let sig_inputs = input.get_sig_inputs();
177 let output = Output::new(&sig.output, class);
178 let ret = output.get_output();
179 let result = if class.is_some() {
180 if input.has_receiver {
181 output.wrap_result(quote!(self.inner.#name(#inner_inputs)))
182 } else {
183 let ts = class.to_token_stream();
184 output.wrap_result(quote!(#ts::#name(#inner_inputs)))
185 }
186 } else {
187 output.wrap_result(quote!(#name(#inner_inputs)))
188 };
189 let attr_sign = input.gen_attr_signature();
190 let staticmethod = input.mb_gen_staticmethod(class.is_none());
191 let prefix = if class.is_none() { "py_binding_" } else { "" };
192 let result_fn_name = Ident::new(format!("{}{}", prefix, name).as_str(), sig.ident.span());
193 Ok(quote!(#staticmethod #attr_sign pub fn #result_fn_name(#sig_inputs) #ret { #result }))
194 }
195
196 struct Output<'a> {
197 has_result: bool,
198 is_vector: bool,
199 inner_type: Option<&'a Ident>,
200 initial_return: &'a ReturnType,
201 }
202
203 impl<'a> Output<'a> {
204 fn new(output: &'a ReturnType, class: Option<&'a syn::Type>) -> Self {
205 let mut has_result = false;
206 let mut is_vector = false;
207 match &output {
208 ReturnType::Default => Output {
209 has_result,
210 is_vector,
211 inner_type: None,
212 initial_return: output,
213 },
214 ReturnType::Type(_, t) => {
215 let s = match get_last_path_segment_from_type_path(t.as_ref()) {
216 Some(tt) => tt,
217 None => {
218 return Output {
219 has_result,
220 is_vector,
221 inner_type: None,
222 initial_return: output,
223 };
224 }
225 };
226 let ps = if format!("{}", s.ident) == "Result" {
227 has_result = true;
228 get_last_path_segment_from_first_argument(s)
229 } else {
230 Some(s)
231 };
232 let inner_type = match ps {
233 Some(p) => {
234 if format!("{}", p.ident) == "Vec" {
235 is_vector = true;
236 &get_last_path_segment_from_first_argument(p).unwrap().ident
237 } else if format!("{}", p.ident) == "Self" {
238 match get_last_path_segment_from_type_path(class.unwrap()) {
239 Some(s) => &s.ident,
240 None => &p.ident,
241 }
242 } else {
243 &p.ident
244 }
245 }
246 None => {
247 return Output {
248 has_result,
249 is_vector,
250 inner_type: None,
251 initial_return: output,
252 };
253 }
254 };
255 Output {
256 has_result,
257 is_vector,
258 inner_type: if in_types_to_wrap(inner_type) {
259 Some(inner_type)
260 } else {
261 None
262 },
263 initial_return: output,
264 }
265 }
266 }
267 }
268 fn get_output(&self) -> TokenStream {
269 match self.inner_type {
270 Some(t) => {
271 let name = get_wrapper_ident(t);
272 let mb_vec = if self.is_vector {
273 quote!(Vec<#name>)
274 } else {
275 name.to_token_stream()
276 };
277 if self.has_result {
278 quote!(-> pyo3::PyResult<#mb_vec>)
279 } else {
280 quote!(-> #mb_vec)
281 }
282 }
283 None => self.initial_return.to_token_stream(),
284 }
285 }
286 fn wrap_result(&self, result: TokenStream) -> TokenStream {
287 let return_if = if self.has_result {
288 quote!(#result?)
289 } else {
290 result
291 };
292 let wrapped = if self.is_vector {
293 match self.inner_type {
294 Some(t) => {
295 let name = get_wrapper_ident(t);
296 quote!(#return_if.into_iter().map(|x| #name {inner: x}).collect())
297 }
298 None => return_if,
299 }
300 } else {
301 match self.inner_type {
302 Some(t) => {
303 let name = get_wrapper_ident(t);
304 quote!(#name {inner: #return_if})
305 }
306 None => return_if,
307 }
308 };
309
310 if self.has_result {
311 quote!(Ok(#wrapped))
312 } else {
313 wrapped
314 }
315 }
316 }
317
318 fn get_last_path_segment_from_first_argument(
319 s: &syn::PathSegment,
320 ) -> Option<&syn::PathSegment> {
321 match &s.arguments {
322 syn::PathArguments::AngleBracketed(args) => match args.args.first().unwrap() {
323 syn::GenericArgument::Type(t) => match get_last_path_segment_from_type_path(t) {
324 Some(p) => Some(p),
325 None => None,
326 },
327 _ => None,
328 },
329 _ => None,
330 }
331 }
332
333 struct InputArgument<'a> {
334 is_vector: bool,
335 initial_type: &'a syn::Type,
336 inner_type: Option<&'a Ident>,
337 var_name: TokenStream,
338 }
339
340 fn get_last_path_segment_from_type_path(t: &syn::Type) -> Option<&syn::PathSegment> {
341 match t {
342 syn::Type::Path(p) => match p.path.segments.last() {
343 Some(s) => Some(s),
344 None => None,
345 },
346 _ => None,
347 }
348 }
349
350 impl<'a> InputArgument<'a> {
351 fn new(t: &'a PatType) -> Self {
352 let name = match t.pat.as_ref() {
353 syn::Pat::Ident(i) => &i.ident,
354 _ => unreachable!(),
355 };
356 let mut is_vector = false;
357 let s = match get_last_path_segment_from_type_path(t.ty.as_ref()) {
358 Some(s) => s,
359 None => {
360 return InputArgument {
361 is_vector,
362 initial_type: &t.ty,
363 inner_type: None,
364 var_name: name.to_token_stream(),
365 }
366 }
367 };
368 let inner_type = if format!("{}", s.ident) == "Vec" {
369 is_vector = true;
370 &get_last_path_segment_from_first_argument(s).unwrap().ident
371 } else if format!("{}", s.ident) == "Slice" {
372 is_vector = true;
373 &s.ident
374 } else {
375 &s.ident
376 };
377 InputArgument {
378 is_vector,
379 initial_type: &t.ty,
380 inner_type: if in_types_to_wrap(inner_type) || format!("{}", inner_type) == "Slice"
381 {
382 Some(inner_type)
383 } else {
384 None
385 },
386 var_name: name.to_token_stream(),
387 }
388 }
389 fn get_signature(&self) -> TokenStream {
390 match self.inner_type {
391 Some(t) => {
392 let name = &self.var_name;
393 let nt = if format!("{}", t) == "Slice" {
395 Ident::new("PyBindingSliceElement", t.span())
396 } else {
397 get_wrapper_ident(t)
398 };
399 if self.is_vector {
400 quote!(#name: Vec<pyo3::PyRef<#nt>>)
401 } else {
402 quote!(#name: &#nt)
403 }
404 }
405 None => {
406 let name = &self.var_name;
407 let t = self.initial_type;
408 quote!(#name: #t)
409 }
410 }
411 }
412 fn as_inner_argument(&self) -> TokenStream {
413 match self.inner_type {
414 Some(_) => {
415 let name = &self.var_name;
416 if self.is_vector {
417 quote!(#name.into_iter().map(|x| x.inner.clone()).collect())
418 } else {
419 quote!(#name.inner.clone())
420 }
421 }
422 None => {
423 let name = &self.var_name;
424 quote!(#name)
425 }
426 }
427 }
428 }
429
430 struct Input {
431 sig_inputs: Vec<TokenStream>,
432 inner_inputs: Vec<TokenStream>,
433 attr_sig: Vec<String>,
434 has_receiver: bool,
435 }
436
437 impl Input {
438 fn new(inputs: &Punctuated<FnArg, Comma>) -> Self {
439 let mut sig = vec![];
440 let mut inner = vec![];
441 let mut attr_sig = vec![];
442 let mut has_receiver = false;
443 for arg in inputs {
444 match arg {
445 FnArg::Typed(t) => {
446 let processed_argument = InputArgument::new(t);
447 sig.push(processed_argument.get_signature());
448 inner.push(processed_argument.as_inner_argument());
449 attr_sig.push(processed_argument.var_name.to_string());
450 }
451 FnArg::Receiver(slf) => {
452 sig.push(slf.into_token_stream());
453 attr_sig.push("$self".to_string());
454 has_receiver = true;
455 }
456 }
457 }
458 attr_sig.push("/".to_string());
459 Input {
460 sig_inputs: sig,
461 inner_inputs: inner,
462 attr_sig,
463 has_receiver,
464 }
465 }
466 fn get_inner_inputs(&self) -> TokenStream {
467 let inputs = &self.inner_inputs;
468 quote!(#(#inputs),*)
469 }
470 fn get_sig_inputs(&self) -> TokenStream {
471 let inputs = &self.sig_inputs;
472 quote!(#(#inputs),*)
473 }
474 fn gen_attr_signature(&self) -> TokenStream {
475 let val = vec!["(", self.attr_sig.join(", ").as_str(), ")"].join(" ");
476 quote!(#[pyo3(text_signature = #val)])
477 }
478 fn mb_gen_staticmethod(&self, ignore: bool) -> TokenStream {
479 if self.has_receiver || ignore {
480 TokenStream::new()
481 } else {
482 quote!(#[staticmethod])
483 }
484 }
485 }
486
487 fn gen_attributes(attrs: &[syn::Attribute]) -> Vec<&syn::Attribute> {
488 let mut result = vec![];
489 let mut stop_adding_docs = false;
490 for attr in attrs {
491 if attr.path.is_ident("cfg") {
492 result.push(attr);
493 }
494 if attr.path.is_ident("doc") && !stop_adding_docs {
495 if format!("{}", attr.tokens).contains("# Example")
496 || format!("{}", attr.tokens).contains("# Rust crates")
497 {
498 stop_adding_docs = true;
499 } else {
500 result.push(attr);
501 }
502 }
503 }
504 result
505 }
506
    /// Returns a hand-written wrapper for the listed function names, or `None` for
    /// every other name (which then goes through the generic generator).
    ///
    /// Matching uses the bare function name only. The same template is returned
    /// whether the name comes from an `impl` block or from a free function.
    fn check_in_allowlist(name: String) -> Option<TokenStream> {
        match name.as_str() {
            // Argument is `Vec<(String, Node)>`: the tuple element is not a plain path,
            // so the pairs are unpacked and their `inner` values cloned by hand.
            "create_named_tuple" => Some(quote!(
                pub fn create_named_tuple(
                    &self,
                    elements: Vec<(String, pyo3::PyRef<PyBindingNode>)>,
                ) -> pyo3::PyResult<PyBindingNode> {
                    Ok(PyBindingNode {
                        inner: self.inner.create_named_tuple(
                            elements
                                .into_iter()
                                .map(|x| (x.0, x.1.inner.clone()))
                                .collect(),
                        )?,
                    })
                }
            )),
            // Accepts one wrapped `TypedValue` and forwards its `t` and `value` fields
            // as two separate arguments.
            "constant" => Some(quote!(
                pub fn constant(&self, tv: &PyBindingTypedValue) -> pyo3::PyResult<PyBindingNode> {
                    Ok(PyBindingNode {
                        inner: self
                            .inner
                            .constant(tv.inner.t.clone(), tv.inner.value.clone())?,
                    })
                }
            )),
            // Exposes the operation as a JSON string (via `serde_json`). Serialization
            // errors are mapped to Python `RuntimeError`.
            "get_operation" => Some(quote!(
                pub fn get_operation(&self) -> pyo3::PyResult<String> {
                    serde_json::to_string(&self.inner.get_operation())
                        .map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))
                }
            )),
            // Free-function variant (hence the `py_binding_` prefix); the argument is a
            // `Vec` of `(String, Type)` tuples, unpacked by hand as above.
            "named_tuple_type" => Some(quote!(
                pub fn py_binding_named_tuple_type(
                    v: Vec<(String, pyo3::PyRef<PyBindingType>)>,
                ) -> PyBindingType {
                    PyBindingType {
                        inner: named_tuple_type(
                            v.into_iter().map(|x| (x.0, x.1.inner.clone())).collect(),
                        ),
                    }
                }
            )),
            // Returns `Option<Vec<Value>>`; `Output` only recognizes `Result` and `Vec`
            // wrappers, so the `Option` layer is mapped here.
            "get_sub_values" => Some(quote!(
                fn get_sub_values(&self) -> Option<Vec<PyBindingValue>> {
                    match self.inner.get_sub_values() {
                        None => None,
                        Some(v) => {
                            Some(v.into_iter().map(|x| PyBindingValue { inner: x }).collect())
                        }
                    }
                }
            )),
            _ => None,
        }
    }
563}