pywrapper_macro/
lib.rs

1//! This crate declares only the proc macro attributes, as a crate defining proc macro attributes
2//! must not contain any other public items.
3use syn::parse_macro_input;
4
5use proc_macro::{self, TokenStream};
6use quote::quote;
7#[macro_use]
8extern crate lazy_static;
9
10/// A proc macro used to expose ciphercore Rust structs as Python objects.
11#[proc_macro_attribute]
12pub fn struct_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
13    let ast = parse_macro_input!(input as syn::ItemStruct);
14    let expanded = macro_backend::build_struct(&ast.ident, &ast.attrs)
15        .unwrap_or_else(|e| e.to_compile_error());
16    quote!(#ast
17        #expanded
18    )
19    .into()
20}
21
22/// A proc macro used to expose ciphercore Rust enums as Python objects.
23#[proc_macro_attribute]
24pub fn enum_to_struct_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
25    let ast = parse_macro_input!(input as syn::ItemEnum);
26    let expanded = macro_backend::build_struct(&ast.ident, &ast.attrs)
27        .unwrap_or_else(|e| e.to_compile_error());
28    quote!(#ast
29        #expanded
30    )
31    .into()
32}
33
34/// A proc macro used to expose methods to Python.
35///
36/// This self type must have one of {struct_wrapper, enum_to_struct_wrapper} attributes.
37#[proc_macro_attribute]
38pub fn impl_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
39    let mut ast = parse_macro_input!(input as syn::ItemImpl);
40    let expanded = macro_backend::build_methods(&mut ast).unwrap_or_else(|e| e.to_compile_error());
41    quote!(#ast
42           #[allow(clippy::needless_question_mark)] #expanded
43    )
44    .into()
45}
46
47/// A proc macro used to expose single function to Python.
48#[proc_macro_attribute]
49pub fn fn_wrapper(_metadata: TokenStream, input: TokenStream) -> TokenStream {
50    let mut ast = parse_macro_input!(input as syn::ItemFn);
51    let expanded = macro_backend::build_fn(&mut ast).unwrap_or_else(|e| e.to_compile_error());
52    quote!(#ast
53           #expanded
54    )
55    .into()
56}
57
58mod macro_backend {
59    use std::collections::HashSet;
60
61    use proc_macro2::{Ident, TokenStream};
62    use quote::{quote, ToTokens};
63    use syn::{punctuated::Punctuated, token::Comma, FnArg, PatType, ReturnType};
64
65    lazy_static! {
66        static ref TYPES_TO_WRAP: HashSet<&'static str> = {
67            HashSet::from_iter(vec![
68                "Node",
69                "Graph",
70                "Context",
71                "ScalarType",
72                "Type",
73                "SliceElement",
74                "TypedValue",
75                "Value",
76                "CustomOperation",
77                "JoinType",
78                "ShardConfig",
79            ])
80        };
81    }
82
83    pub fn build_methods(ast: &mut syn::ItemImpl) -> syn::Result<TokenStream> {
84        impl_methods(&ast.self_ty, &mut ast.items)
85    }
86
87    pub fn build_fn(ast: &mut syn::ItemFn) -> syn::Result<TokenStream> {
88        let token_stream = gen_wrapper_method(&mut ast.sig, None)?;
89        let attrs = gen_attributes(&ast.attrs);
90        let name = ast.sig.ident.to_string();
91        Ok(quote!(#(#attrs)*
92    #[pyo3::pyfunction]
93    #[pyo3(name = #name)]
94    #token_stream))
95    }
96
97    pub fn build_struct(
98        t: &syn::Ident,
99        struct_attrs: &[syn::Attribute],
100    ) -> syn::Result<TokenStream> {
101        let nt = get_wrapper_ident(t);
102        let name = format!("{}", t);
103        let attrs = gen_attributes(struct_attrs);
104        Ok(quote!(
105            #(#attrs)*
106            #[pyo3::pyclass(name = #name)]
107            pub struct #nt {
108                pub inner: #t,
109            }
110        ))
111    }
112
113    fn impl_methods(ty: &syn::Type, impls: &mut [syn::ImplItem]) -> syn::Result<TokenStream> {
114        let mut methods = Vec::new();
115
116        for iimpl in impls.iter_mut() {
117            if let syn::ImplItem::Method(meth) = iimpl {
118                let token_stream = gen_wrapper_method(&mut meth.sig, Some(ty))?;
119                let attrs = gen_attributes(&meth.attrs);
120                methods.push(quote!(#(#attrs)* #token_stream));
121            }
122        }
123        let nt = get_wrapper_type_ident(ty, false);
124        Ok(quote! {
125            #[pyo3::pymethods]
126            impl #nt {
127            #(#methods)*
128            fn __str__(&self) -> String {
129                format!("{}", self.inner)
130            }
131            fn __repr__(&self) -> String {
132                self.__str__()
133            }
134            }
135        })
136    }
137
138    fn in_types_to_wrap(tt: &syn::Ident) -> bool {
139        let ii = format!("{}", tt);
140        TYPES_TO_WRAP.contains(ii.as_str())
141    }
142
143    fn get_wrapper_ident(tt: &syn::Ident) -> syn::Ident {
144        let prefix = if in_types_to_wrap(tt) {
145            "PyBinding"
146        } else {
147            ""
148        };
149        Ident::new(format!("{}{}", prefix, tt).as_str(), tt.span())
150    }
151
152    fn get_wrapper_type_ident(ty: &syn::Type, add_ref: bool) -> TokenStream {
153        match get_last_path_segment_from_type_path(ty) {
154            Some(s) => {
155                let ident = get_wrapper_ident(&s.ident);
156                if add_ref && in_types_to_wrap(&s.ident) {
157                    quote!(&#ident)
158                } else {
159                    ident.to_token_stream()
160                }
161            }
162            None => ty.to_token_stream(),
163        }
164    }
165
166    fn gen_wrapper_method(
167        sig: &mut syn::Signature,
168        class: Option<&syn::Type>,
169    ) -> syn::Result<TokenStream> {
170        let name = &sig.ident;
171        if let Some(ts) = check_in_allowlist(format!("{}", name)) {
172            return Ok(ts);
173        }
174        let input = Input::new(&sig.inputs);
175        let inner_inputs = input.get_inner_inputs();
176        let sig_inputs = input.get_sig_inputs();
177        let output = Output::new(&sig.output, class);
178        let ret = output.get_output();
179        let result = if class.is_some() {
180            if input.has_receiver {
181                output.wrap_result(quote!(self.inner.#name(#inner_inputs)))
182            } else {
183                let ts = class.to_token_stream();
184                output.wrap_result(quote!(#ts::#name(#inner_inputs)))
185            }
186        } else {
187            output.wrap_result(quote!(#name(#inner_inputs)))
188        };
189        let attr_sign = input.gen_attr_signature();
190        let staticmethod = input.mb_gen_staticmethod(class.is_none());
191        let prefix = if class.is_none() { "py_binding_" } else { "" };
192        let result_fn_name = Ident::new(format!("{}{}", prefix, name).as_str(), sig.ident.span());
193        Ok(quote!(#staticmethod #attr_sign pub fn #result_fn_name(#sig_inputs) #ret { #result }))
194    }
195
196    struct Output<'a> {
197        has_result: bool,
198        is_vector: bool,
199        inner_type: Option<&'a Ident>,
200        initial_return: &'a ReturnType,
201    }
202
203    impl<'a> Output<'a> {
204        fn new(output: &'a ReturnType, class: Option<&'a syn::Type>) -> Self {
205            let mut has_result = false;
206            let mut is_vector = false;
207            match &output {
208                ReturnType::Default => Output {
209                    has_result,
210                    is_vector,
211                    inner_type: None,
212                    initial_return: output,
213                },
214                ReturnType::Type(_, t) => {
215                    let s = match get_last_path_segment_from_type_path(t.as_ref()) {
216                        Some(tt) => tt,
217                        None => {
218                            return Output {
219                                has_result,
220                                is_vector,
221                                inner_type: None,
222                                initial_return: output,
223                            };
224                        }
225                    };
226                    let ps = if format!("{}", s.ident) == "Result" {
227                        has_result = true;
228                        get_last_path_segment_from_first_argument(s)
229                    } else {
230                        Some(s)
231                    };
232                    let inner_type = match ps {
233                        Some(p) => {
234                            if format!("{}", p.ident) == "Vec" {
235                                is_vector = true;
236                                &get_last_path_segment_from_first_argument(p).unwrap().ident
237                            } else if format!("{}", p.ident) == "Self" {
238                                match get_last_path_segment_from_type_path(class.unwrap()) {
239                                    Some(s) => &s.ident,
240                                    None => &p.ident,
241                                }
242                            } else {
243                                &p.ident
244                            }
245                        }
246                        None => {
247                            return Output {
248                                has_result,
249                                is_vector,
250                                inner_type: None,
251                                initial_return: output,
252                            };
253                        }
254                    };
255                    Output {
256                        has_result,
257                        is_vector,
258                        inner_type: if in_types_to_wrap(inner_type) {
259                            Some(inner_type)
260                        } else {
261                            None
262                        },
263                        initial_return: output,
264                    }
265                }
266            }
267        }
268        fn get_output(&self) -> TokenStream {
269            match self.inner_type {
270                Some(t) => {
271                    let name = get_wrapper_ident(t);
272                    let mb_vec = if self.is_vector {
273                        quote!(Vec<#name>)
274                    } else {
275                        name.to_token_stream()
276                    };
277                    if self.has_result {
278                        quote!(-> pyo3::PyResult<#mb_vec>)
279                    } else {
280                        quote!(-> #mb_vec)
281                    }
282                }
283                None => self.initial_return.to_token_stream(),
284            }
285        }
286        fn wrap_result(&self, result: TokenStream) -> TokenStream {
287            let return_if = if self.has_result {
288                quote!(#result?)
289            } else {
290                result
291            };
292            let wrapped = if self.is_vector {
293                match self.inner_type {
294                    Some(t) => {
295                        let name = get_wrapper_ident(t);
296                        quote!(#return_if.into_iter().map(|x| #name {inner: x}).collect())
297                    }
298                    None => return_if,
299                }
300            } else {
301                match self.inner_type {
302                    Some(t) => {
303                        let name = get_wrapper_ident(t);
304                        quote!(#name {inner: #return_if})
305                    }
306                    None => return_if,
307                }
308            };
309
310            if self.has_result {
311                quote!(Ok(#wrapped))
312            } else {
313                wrapped
314            }
315        }
316    }
317
318    fn get_last_path_segment_from_first_argument(
319        s: &syn::PathSegment,
320    ) -> Option<&syn::PathSegment> {
321        match &s.arguments {
322            syn::PathArguments::AngleBracketed(args) => match args.args.first().unwrap() {
323                syn::GenericArgument::Type(t) => match get_last_path_segment_from_type_path(t) {
324                    Some(p) => Some(p),
325                    None => None,
326                },
327                _ => None,
328            },
329            _ => None,
330        }
331    }
332
333    struct InputArgument<'a> {
334        is_vector: bool,
335        initial_type: &'a syn::Type,
336        inner_type: Option<&'a Ident>,
337        var_name: TokenStream,
338    }
339
340    fn get_last_path_segment_from_type_path(t: &syn::Type) -> Option<&syn::PathSegment> {
341        match t {
342            syn::Type::Path(p) => match p.path.segments.last() {
343                Some(s) => Some(s),
344                None => None,
345            },
346            _ => None,
347        }
348    }
349
350    impl<'a> InputArgument<'a> {
351        fn new(t: &'a PatType) -> Self {
352            let name = match t.pat.as_ref() {
353                syn::Pat::Ident(i) => &i.ident,
354                _ => unreachable!(),
355            };
356            let mut is_vector = false;
357            let s = match get_last_path_segment_from_type_path(t.ty.as_ref()) {
358                Some(s) => s,
359                None => {
360                    return InputArgument {
361                        is_vector,
362                        initial_type: &t.ty,
363                        inner_type: None,
364                        var_name: name.to_token_stream(),
365                    }
366                }
367            };
368            let inner_type = if format!("{}", s.ident) == "Vec" {
369                is_vector = true;
370                &get_last_path_segment_from_first_argument(s).unwrap().ident
371            } else if format!("{}", s.ident) == "Slice" {
372                is_vector = true;
373                &s.ident
374            } else {
375                &s.ident
376            };
377            InputArgument {
378                is_vector,
379                initial_type: &t.ty,
380                inner_type: if in_types_to_wrap(inner_type) || format!("{}", inner_type) == "Slice"
381                {
382                    Some(inner_type)
383                } else {
384                    None
385                },
386                var_name: name.to_token_stream(),
387            }
388        }
389        fn get_signature(&self) -> TokenStream {
390            match self.inner_type {
391                Some(t) => {
392                    let name = &self.var_name;
393                    // Special case for Slicing.
394                    let nt = if format!("{}", t) == "Slice" {
395                        Ident::new("PyBindingSliceElement", t.span())
396                    } else {
397                        get_wrapper_ident(t)
398                    };
399                    if self.is_vector {
400                        quote!(#name: Vec<pyo3::PyRef<#nt>>)
401                    } else {
402                        quote!(#name: &#nt)
403                    }
404                }
405                None => {
406                    let name = &self.var_name;
407                    let t = self.initial_type;
408                    quote!(#name: #t)
409                }
410            }
411        }
412        fn as_inner_argument(&self) -> TokenStream {
413            match self.inner_type {
414                Some(_) => {
415                    let name = &self.var_name;
416                    if self.is_vector {
417                        quote!(#name.into_iter().map(|x| x.inner.clone()).collect())
418                    } else {
419                        quote!(#name.inner.clone())
420                    }
421                }
422                None => {
423                    let name = &self.var_name;
424                    quote!(#name)
425                }
426            }
427        }
428    }
429
430    struct Input {
431        sig_inputs: Vec<TokenStream>,
432        inner_inputs: Vec<TokenStream>,
433        attr_sig: Vec<String>,
434        has_receiver: bool,
435    }
436
437    impl Input {
438        fn new(inputs: &Punctuated<FnArg, Comma>) -> Self {
439            let mut sig = vec![];
440            let mut inner = vec![];
441            let mut attr_sig = vec![];
442            let mut has_receiver = false;
443            for arg in inputs {
444                match arg {
445                    FnArg::Typed(t) => {
446                        let processed_argument = InputArgument::new(t);
447                        sig.push(processed_argument.get_signature());
448                        inner.push(processed_argument.as_inner_argument());
449                        attr_sig.push(processed_argument.var_name.to_string());
450                    }
451                    FnArg::Receiver(slf) => {
452                        sig.push(slf.into_token_stream());
453                        attr_sig.push("$self".to_string());
454                        has_receiver = true;
455                    }
456                }
457            }
458            attr_sig.push("/".to_string());
459            Input {
460                sig_inputs: sig,
461                inner_inputs: inner,
462                attr_sig,
463                has_receiver,
464            }
465        }
466        fn get_inner_inputs(&self) -> TokenStream {
467            let inputs = &self.inner_inputs;
468            quote!(#(#inputs),*)
469        }
470        fn get_sig_inputs(&self) -> TokenStream {
471            let inputs = &self.sig_inputs;
472            quote!(#(#inputs),*)
473        }
474        fn gen_attr_signature(&self) -> TokenStream {
475            let val = vec!["(", self.attr_sig.join(", ").as_str(), ")"].join(" ");
476            quote!(#[pyo3(text_signature = #val)])
477        }
478        fn mb_gen_staticmethod(&self, ignore: bool) -> TokenStream {
479            if self.has_receiver || ignore {
480                TokenStream::new()
481            } else {
482                quote!(#[staticmethod])
483            }
484        }
485    }
486
487    fn gen_attributes(attrs: &[syn::Attribute]) -> Vec<&syn::Attribute> {
488        let mut result = vec![];
489        let mut stop_adding_docs = false;
490        for attr in attrs {
491            if attr.path.is_ident("cfg") {
492                result.push(attr);
493            }
494            if attr.path.is_ident("doc") && !stop_adding_docs {
495                if format!("{}", attr.tokens).contains("# Example")
496                    || format!("{}", attr.tokens).contains("# Rust crates")
497                {
498                    stop_adding_docs = true;
499                } else {
500                    result.push(attr);
501                }
502            }
503        }
504        result
505    }
506
507    fn check_in_allowlist(name: String) -> Option<TokenStream> {
508        match name.as_str() {
509            "create_named_tuple" => Some(quote!(
510                pub fn create_named_tuple(
511                    &self,
512                    elements: Vec<(String, pyo3::PyRef<PyBindingNode>)>,
513                ) -> pyo3::PyResult<PyBindingNode> {
514                    Ok(PyBindingNode {
515                        inner: self.inner.create_named_tuple(
516                            elements
517                                .into_iter()
518                                .map(|x| (x.0, x.1.inner.clone()))
519                                .collect(),
520                        )?,
521                    })
522                }
523            )),
524            "constant" => Some(quote!(
525                pub fn constant(&self, tv: &PyBindingTypedValue) -> pyo3::PyResult<PyBindingNode> {
526                    Ok(PyBindingNode {
527                        inner: self
528                            .inner
529                            .constant(tv.inner.t.clone(), tv.inner.value.clone())?,
530                    })
531                }
532            )),
533            "get_operation" => Some(quote!(
534                pub fn get_operation(&self) -> pyo3::PyResult<String> {
535                    serde_json::to_string(&self.inner.get_operation())
536                        .map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))
537                }
538            )),
539            "named_tuple_type" => Some(quote!(
540                pub fn py_binding_named_tuple_type(
541                    v: Vec<(String, pyo3::PyRef<PyBindingType>)>,
542                ) -> PyBindingType {
543                    PyBindingType {
544                        inner: named_tuple_type(
545                            v.into_iter().map(|x| (x.0, x.1.inner.clone())).collect(),
546                        ),
547                    }
548                }
549            )),
550            "get_sub_values" => Some(quote!(
551                fn get_sub_values(&self) -> Option<Vec<PyBindingValue>> {
552                    match self.inner.get_sub_values() {
553                        None => None,
554                        Some(v) => {
555                            Some(v.into_iter().map(|x| PyBindingValue { inner: x }).collect())
556                        }
557                    }
558                }
559            )),
560            _ => None,
561        }
562    }
563}