Skip to main content

pyo3_polars_derive/
lib.rs

1mod attr;
2mod keywords;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{parse_macro_input, FnArg};
7
8fn quote_get_kwargs() -> proc_macro2::TokenStream {
9    quote!(
10    let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);
11
12    let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs)  {
13        Ok(value) => value,
14        Err(err) => {
15            let err = polars_error::polars_err!(InvalidOperation: "could not parse kwargs: '{}'\n\nCheck: registration of kwargs in the plugin.", err);
16            pyo3_polars::derive::_update_last_error(err);
17            return;
18        }
19    };
20
21    )
22}
23
24fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
25    let kwargs = quote_get_kwargs();
26    quote!(
27            // parse the kwargs and assign to `let kwargs`
28            #kwargs
29
30            // define the function
31            #ast
32
33            // call the function
34        let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, kwargs);
35
36    )
37}
38
39fn quote_call_context(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
40    quote!(
41            let context = *context;
42
43            // define the function
44            #ast
45
46            // call the function
47        let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context);
48    )
49}
50
51fn quote_call_context_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
52    quote!(
53            let context = *context;
54
55            let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);
56
57            let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs)  {
58                    Ok(value) => value,
59                    Err(err) => {
60                        pyo3_polars::derive::_update_last_error(err);
61                        return;
62                    }
63            };
64
65            // define the function
66            #ast
67
68            // call the function
69        let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context, kwargs);
70    )
71}
72
73fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
74    quote!(
75            // define the function
76            #ast
77            // call the function
78            let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs);
79    )
80}
81
82fn quote_process_results() -> proc_macro2::TokenStream {
83    quote!(match result {
84        Ok(out) => {
85            // Update return value.
86            *return_value = polars_ffi::version_0::export_series(&out);
87        },
88        Err(err) => {
89            // Set latest error, but leave return value in empty state.
90            pyo3_polars::derive::_update_last_error(err);
91        },
92    })
93}
94
95fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
96    // count how often the user define a kwargs argument.
97    let args = ast
98        .sig
99        .inputs
100        .iter()
101        .skip(1)
102        .map(|fn_arg| {
103            if let FnArg::Typed(pat) = fn_arg {
104                if let syn::Pat::Ident(pat) = pat.pat.as_ref() {
105                    pat.ident.to_string()
106                } else {
107                    panic!("expected an argument")
108                }
109            } else {
110                panic!("expected a type argument")
111            }
112        })
113        .collect::<Vec<_>>();
114
115    let fn_name = &ast.sig.ident;
116
117    // Get the tokenstream of the call logic.
118    let quote_call = match args.len() {
119        0 => quote_call_no_kwargs(&ast, fn_name),
120        1 => match args[0].as_str() {
121            "kwargs" => quote_call_kwargs(&ast, fn_name),
122            "context" => quote_call_context(&ast, fn_name),
123            a => panic!("didn't expect argument {a}"),
124        },
125        2 => match (args[0].as_str(), args[1].as_str()) {
126            ("context", "kwargs") => quote_call_context_kwargs(&ast, fn_name),
127            ("kwargs", "context") => panic!("'kwargs', 'context' order should be reversed"),
128            (a, b) => panic!("didn't expect arguments {a}, {b}"),
129        },
130        _ => panic!("didn't expect so many arguments"),
131    };
132
133    let quote_process_result = quote_process_results();
134    let fn_name = get_expression_function_name(fn_name);
135
136    quote!(
137        use ::pyo3_polars::export::*;
138
139        // create the outer public function
140        #[no_mangle]
141        pub unsafe extern "C" fn #fn_name (
142            e: *mut polars_ffi::version_0::SeriesExport,
143            input_len: usize,
144            kwargs_ptr: *const u8,
145            kwargs_len: usize,
146            return_value: *mut polars_ffi::version_0::SeriesExport,
147            context: *mut polars_ffi::version_0::CallerContext
148        )  {
149            let panic_result = std::panic::catch_unwind(move || {
150                let inputs = polars_ffi::version_0::import_series_buffer(e, input_len).unwrap();
151
152                #quote_call
153
154                #quote_process_result
155            });
156
157            if panic_result.is_err() {
158                // Set latest to panic;
159                ::pyo3_polars::derive::_set_panic();
160            }
161        }
162    )
163}
164
165fn get_field_function_name(fn_name: &syn::Ident) -> syn::Ident {
166    syn::Ident::new(&format!("_polars_plugin_field_{fn_name}"), fn_name.span())
167}
168
169fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident {
170    syn::Ident::new(&format!("_polars_plugin_{fn_name}"), fn_name.span())
171}
172
173fn quote_get_inputs() -> proc_macro2::TokenStream {
174    quote!(
175             let inputs = std::slice::from_raw_parts(field, len);
176             let inputs = inputs.iter().map(|field| {
177                 let field = polars_arrow::ffi::import_field_from_c(field).unwrap();
178                 let out = polars_core::prelude::Field::from(&field);
179                 out
180             }).collect::<Vec<_>>();
181    )
182}
183
184fn create_field_function(
185    fn_name: &syn::Ident,
186    dtype_fn_name: &syn::Ident,
187    kwargs: bool,
188) -> proc_macro2::TokenStream {
189    let map_field_name = get_field_function_name(fn_name);
190    let inputs = quote_get_inputs();
191
192    let call_fn = if kwargs {
193        let kwargs = quote_get_kwargs();
194        quote! (
195            #kwargs
196            let result = #dtype_fn_name(&inputs, kwargs);
197        )
198    } else {
199        quote!(
200            let result = #dtype_fn_name(&inputs);
201        )
202    };
203
204    quote! (
205        #[no_mangle]
206        pub unsafe extern "C" fn #map_field_name(
207            field: *mut polars_arrow::ffi::ArrowSchema,
208            len: usize,
209            return_value: *mut polars_arrow::ffi::ArrowSchema,
210            kwargs_ptr: *const u8,
211            kwargs_len: usize,
212        ) {
213            let panic_result = std::panic::catch_unwind(move || {
214                #inputs;
215
216                #call_fn;
217
218                match result {
219                    Ok(out) => {
220                        let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));
221                        *return_value = out;
222                    },
223                    Err(err) => {
224                        // Set latest error, but leave return value in empty state.
225                        pyo3_polars::derive::_update_last_error(err);
226                    }
227                }
228            });
229
230            if panic_result.is_err() {
231                // Set latest to panic;
232                pyo3_polars::derive::_set_panic();
233            }
234        }
235    )
236}
237
238fn create_field_function_from_with_dtype(
239    fn_name: &syn::Ident,
240    dtype: syn::Ident,
241) -> proc_macro2::TokenStream {
242    let map_field_name = get_field_function_name(fn_name);
243    let inputs = quote_get_inputs();
244
245    quote! (
246        #[no_mangle]
247        pub unsafe extern "C" fn #map_field_name(
248            field: *mut polars_arrow::ffi::ArrowSchema,
249            len: usize,
250            return_value: *mut polars_arrow::ffi::ArrowSchema
251        ) {
252            #inputs
253
254            let mapper = polars_plan::prelude::FieldsMapper::new(&inputs);
255            let dtype = polars_core::datatypes::DataType::#dtype;
256            let out = mapper.with_dtype(dtype).unwrap();
257            let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));
258            *return_value = out;
259        }
260    )
261}
262
263#[proc_macro_attribute]
264pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream {
265    let ast = parse_macro_input!(input as syn::ItemFn);
266
267    let options = parse_macro_input!(attr as attr::ExprsFunctionOptions);
268    let expanded_field_fn = if let Some(fn_name) = options.output_type_fn {
269        create_field_function(&ast.sig.ident, &fn_name, false)
270    } else if let Some(fn_name) = options.output_type_fn_kwargs {
271        create_field_function(&ast.sig.ident, &fn_name, true)
272    } else if let Some(dtype) = options.output_dtype {
273        create_field_function_from_with_dtype(&ast.sig.ident, dtype)
274    } else {
275        panic!("didn't understand polars_expr attribute")
276    };
277
278    let expanded_expr = create_expression_function(ast);
279    let expanded = quote!(
280        #expanded_field_fn
281
282        #expanded_expr
283    );
284    TokenStream::from(expanded)
285}