cpy_binder/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{
6    bracketed, parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated,
7    Attribute, Ident, ItemFn, Lit, Meta, Result, Token,
8};
9
10/// Macro used to export enums
11///
12/// Example
13/// ```no_run
14/// #[cpy_enum]
15/// #[comment = "Material types"]
16/// enum Material {
17///     Plastic,
18///     Rubber,
19/// }
20/// ```
21#[proc_macro_attribute]
22pub fn cpy_enum(_attributes: TokenStream, item: TokenStream) -> TokenStream {
23    let input = parse_macro_input!(item as syn::ItemEnum);
24    let name = &input.ident;
25
26    let comment = get_comment(&input.attrs);
27
28    let variants: Vec<_> = input.variants.iter().map(|v| &v.ident).collect();
29    let expanded = quote! {
30        #[doc = #comment]
31        #[derive(Clone, Debug)]
32        #[repr(C)]
33        #[cfg_attr(feature = "python", pyo3::prelude::pyclass)]
34        pub enum #name {
35            #(#variants),*
36        }
37    };
38
39    expanded.into()
40}
41
42/// Macro used to export structures
43///
44/// Example
45/// ```no_run
46/// #[cpy_struct]
47/// #[comment = "2D Size"]
48/// struct Size2D {
49///     width: f64,
50///     height: f64,
51/// }
52///
53/// #[cpy_struct]
54/// #[comment = "Tire structure"]
55/// struct Tire {
56///     material: Material,
57///     pressure: f64,
58///     size: Size2D,
59/// }
60/// ```
61#[proc_macro_attribute]
62pub fn cpy_struct(_attributes: TokenStream, item: TokenStream) -> TokenStream {
63    let input = parse_macro_input!(item as syn::ItemStruct);
64    let name = &input.ident;
65
66    let comment = get_comment(&input.attrs);
67
68    let fields: Vec<_> = input
69        .fields
70        .iter()
71        .map(|f| {
72            let fname = &f.ident;
73            let ftype = &f.ty;
74            quote! { #fname: #ftype }
75        })
76        .collect();
77
78    let expanded = quote! {
79        #[doc = #comment]
80        #[derive(Clone, Debug)]
81        #[repr(C)]
82        #[cfg_attr(feature = "python", pyo3::prelude::pyclass(get_all, set_all))]
83        pub struct #name {
84            #(#fields),*
85        }
86    };
87
88    expanded.into()
89}
90
91/// Macro used to export functions for both C/C++ and Python
92///
93/// Example
94/// ```no_run
95/// #[cpy_fn] // You can also use `#[comment = "Something"]` to document both languages at once
96/// #[comment_c = "@brief Calculates the aspect ratio of a wheel based on its height and width.\n
97///     @param height Height of the wheel.\n
98///     @param width Width of the wheel.\n
99///     @return float Aspect ratio of the wheel.\n"]
100/// #[comment_py = "Calculates the aspect ratio of a wheel based on its height and width.\n
101///     Args:\n
102///         height (float): Height of the wheel.\n
103///         width (float): Width of the wheel.\n
104///     Returns:\n
105///         float: Aspect ratio of the wheel.\n"]
106/// fn wheel_size_aspect(height: f32, width: f32) -> f32 {
107///     (height / width) * 100.0
108/// }
109/// ```
110#[proc_macro_attribute]
111pub fn cpy_fn(_attributes: TokenStream, item: TokenStream) -> TokenStream {
112    let mut input = parse_macro_input!(item as ItemFn);
113
114    let (comment_c, comment_py) = get_comments(&input.attrs);
115    input.attrs.retain(|attr| {
116        !attr.path.is_ident("comment")
117            && !attr.path.is_ident("comment_c")
118            && !attr.path.is_ident("comment_py")
119    });
120
121    let fn_name = &input.sig.ident;
122    let inputs = &input.sig.inputs;
123    let output = &input.sig.output;
124    let block = &input.block;
125
126    let expanded = quote! {
127        #[cfg_attr(not(feature = "python"), doc = #comment_c)]
128        #[cfg_attr(feature = "python", doc = #comment_py)]
129        #[no_mangle]
130        #[cfg_attr(feature = "python", pyo3::prelude::pyfunction)]
131        pub extern "C" fn #fn_name(#inputs) #output #block
132    };
133
134    expanded.into()
135}
136
137/// Macro used to export exclusive C++ functions
138///
139/// Example
140/// ```no_run
141/// #[cpy_fn_c]
142/// #[comment = "Format size of wheels for C ABI"]
143/// fn format_size_of_wheels_c(sizes: *const u8, length: usize) {
144///     let values = unsafe {
145///         assert!(!sizes.is_null());
146///         std::slice::from_raw_parts(sizes, length)
147///     };
148///     println!("Wheel sizes: {:?}", values);
149/// }
150/// ```
151#[proc_macro_attribute]
152pub fn cpy_fn_c(_attributes: TokenStream, item: TokenStream) -> TokenStream {
153    let mut input = parse_macro_input!(item as ItemFn);
154
155    let comment = get_comment(&input.attrs);
156    input.attrs.retain(|attr| !attr.path.is_ident("comment"));
157
158    let mut fn_name = input.sig.ident;
159    if fn_name.to_string().ends_with("_c") {
160        fn_name = format_ident!("{}", &fn_name.to_string().trim_end_matches("_c"));
161    }
162
163    let inputs = &input.sig.inputs;
164    let output = &input.sig.output;
165    let block = &input.block;
166
167    let expanded = quote! {
168        #[doc = #comment]
169        #[no_mangle]
170        #[cfg(not(feature = "python"))]
171        pub extern "C" fn #fn_name(#inputs) #output #block
172    };
173
174    expanded.into()
175}
176
177/// Macro used to export exclusive python functions
178///
179/// Example
180/// ```no_run
181/// #[cpy_fn_py]
182/// #[comment = "Format size of wheels for Python"]
183/// fn format_size_of_wheels_py(sizes: Vec<u8>) {
184///     println!("Wheel sizes: {:?}", sizes);
185/// }
186/// ```
187#[proc_macro_attribute]
188pub fn cpy_fn_py(_attributes: TokenStream, item: TokenStream) -> TokenStream {
189    let mut input = parse_macro_input!(item as ItemFn);
190
191    let comment = get_comment(&input.attrs);
192    input.attrs.retain(|attr| !attr.path.is_ident("comment"));
193
194    let mut fn_name = input.sig.ident;
195    if fn_name.to_string().ends_with("_py") {
196        fn_name = format_ident!("{}", &fn_name.to_string().trim_end_matches("_py"));
197    }
198    let inputs = &input.sig.inputs;
199    let output = &input.sig.output;
200    let block = &input.block;
201
202    let expanded = quote! {
203        #[doc = #comment]
204        #[cfg(feature = "python")]
205        #[pyo3::prelude::pyfunction]
206        pub fn #fn_name(#inputs) #output #block
207    };
208
209    expanded.into()
210}
211
212fn get_comment(attributes: &[Attribute]) -> String {
213    for attribute in attributes {
214        if let Ok(Meta::NameValue(meta_name_value)) = attribute.parse_meta() {
215            if meta_name_value.path.is_ident("comment") {
216                if let Lit::Str(lit_str) = meta_name_value.lit {
217                    return lit_str.value();
218                }
219            }
220        }
221    }
222    "No documentation".to_string()
223}
224
225fn get_comments(attributes: &[Attribute]) -> (Option<String>, Option<String>) {
226    let mut comment_c: Option<String> = None;
227    let mut comment_py: Option<String> = None;
228    let mut comment: Option<String> = None;
229
230    for attribute in attributes {
231        if let Ok(Meta::NameValue(meta_name_value)) = attribute.parse_meta() {
232            if let Some(ident) = meta_name_value.path.get_ident() {
233                match ident.to_string().as_str() {
234                    "comment_c" => {
235                        if let Lit::Str(lit_str) = meta_name_value.lit {
236                            comment_c = Some(lit_str.value());
237                        }
238                    }
239                    "comment_py" => {
240                        if let Lit::Str(lit_str) = meta_name_value.lit {
241                            comment_py = Some(lit_str.value());
242                        }
243                    }
244                    "comment" => {
245                        if let Lit::Str(lit_str) = meta_name_value.lit {
246                            comment = Some(lit_str.value());
247                        }
248                    }
249                    _ => {}
250                }
251            }
252        }
253    }
254
255    if let Some(documentation) = comment {
256        comment_c = comment_c.or(Some(documentation.to_string()));
257        comment_py = comment_py.or(Some(documentation.to_string()));
258    } else {
259        comment_c = comment_c.or(Some("No documentation".to_string()));
260        comment_py = comment_py.or(Some("No documentation".to_string()));
261    }
262
263    (comment_c, comment_py)
264}
265
266struct CpyModuleInput {
267    name: Ident,
268    types: Punctuated<Ident, Token![,]>,
269    functions: Punctuated<Ident, Token![,]>,
270}
271
272impl Parse for CpyModuleInput {
273    fn parse(input: ParseStream) -> Result<Self> {
274        // Name
275        let _: Ident = input.parse()?;
276        input.parse::<Token![=]>()?;
277        let name: Ident = input.parse()?;
278        input.parse::<Token![,]>()?;
279
280        // Types
281        let _: Ident = input.parse()?;
282        input.parse::<Token![=]>()?;
283        let types_content;
284        bracketed!(types_content in input);
285        let types: Punctuated<Ident, Token![,]> = types_content.parse_terminated(Ident::parse)?;
286        input.parse::<Token![,]>()?;
287
288        // Functions
289        let _: Ident = input.parse()?;
290        input.parse::<Token![=]>()?;
291        let functions_content;
292        bracketed!(functions_content in input);
293        let functions: Punctuated<Ident, Token![,]> =
294            functions_content.parse_terminated(Ident::parse)?;
295
296        Ok(CpyModuleInput {
297            name,
298            types,
299            functions,
300        })
301    }
302}
303
304/// Macro used to export the python module
305///
306/// Example
307/// ```no_run
308/// cpy_module!(
309///     name = example, // Module name
310///     types = [Material, Size2D, Tire], // Structures and Enums to be exported
311///     functions = [ // Functions to be accessed from python
312///         create_random_tire,
313///         format_wheel_identifier,
314///         format_size_of_wheels,
315///         func_with_no_return,
316///         wheel_size_aspect
317///     ]
318/// );
319/// ```
320#[proc_macro]
321pub fn cpy_module(input: TokenStream) -> TokenStream {
322    let input = parse_macro_input!(input as CpyModuleInput);
323
324    let type_additions: Vec<_> = input
325        .types
326        .iter()
327        .map(|item| {
328            quote! {
329                m.add_class::<#item>()?;
330            }
331        })
332        .collect();
333
334    let function_additions: Vec<_> = input
335        .functions
336        .iter()
337        .map(|item| {
338            quote! {
339                m.add_function(pyo3::wrap_pyfunction!(#item, m)?)?;
340            }
341        })
342        .collect();
343
344    let module_name = &input.name;
345
346    let expanded = quote! {
347        #[cfg(feature = "python")]
348        #[pyo3::pymodule]
349        fn #module_name(py: pyo3::prelude::Python, m: &pyo3::prelude::PyModule) -> pyo3::prelude::PyResult<()> {
350            #(#type_additions)*
351            #(#function_additions)*
352            Ok(())
353        }
354    };
355
356    expanded.into()
357}