pyo3_polars_derive/
lib.rs1mod attr;
2mod keywords;
3
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{parse_macro_input, FnArg};
7
8fn quote_get_kwargs() -> proc_macro2::TokenStream {
9 quote!(
10 let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);
11
12 let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
13 Ok(value) => value,
14 Err(err) => {
15 let err = polars_error::polars_err!(InvalidOperation: "could not parse kwargs: '{}'\n\nCheck: registration of kwargs in the plugin.", err);
16 pyo3_polars::derive::_update_last_error(err);
17 return;
18 }
19 };
20
21 )
22}
23
24fn quote_call_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
25 let kwargs = quote_get_kwargs();
26 quote!(
27 #kwargs
29
30 #ast
32
33 let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, kwargs);
35
36 )
37}
38
39fn quote_call_context(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
40 quote!(
41 let context = *context;
42
43 #ast
45
46 let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context);
48 )
49}
50
51fn quote_call_context_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
52 quote!(
53 let context = *context;
54
55 let kwargs = std::slice::from_raw_parts(kwargs_ptr, kwargs_len);
56
57 let kwargs = match pyo3_polars::derive::_parse_kwargs(kwargs) {
58 Ok(value) => value,
59 Err(err) => {
60 pyo3_polars::derive::_update_last_error(err);
61 return;
62 }
63 };
64
65 #ast
67
68 let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs, context, kwargs);
70 )
71}
72
73fn quote_call_no_kwargs(ast: &syn::ItemFn, fn_name: &syn::Ident) -> proc_macro2::TokenStream {
74 quote!(
75 #ast
77 let result: polars_error::PolarsResult<polars_core::prelude::Series> = #fn_name(&inputs);
79 )
80}
81
82fn quote_process_results() -> proc_macro2::TokenStream {
83 quote!(match result {
84 Ok(out) => {
85 *return_value = polars_ffi::version_0::export_series(&out);
87 },
88 Err(err) => {
89 pyo3_polars::derive::_update_last_error(err);
91 },
92 })
93}
94
95fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
96 let args = ast
98 .sig
99 .inputs
100 .iter()
101 .skip(1)
102 .map(|fn_arg| {
103 if let FnArg::Typed(pat) = fn_arg {
104 if let syn::Pat::Ident(pat) = pat.pat.as_ref() {
105 pat.ident.to_string()
106 } else {
107 panic!("expected an argument")
108 }
109 } else {
110 panic!("expected a type argument")
111 }
112 })
113 .collect::<Vec<_>>();
114
115 let fn_name = &ast.sig.ident;
116
117 let quote_call = match args.len() {
119 0 => quote_call_no_kwargs(&ast, fn_name),
120 1 => match args[0].as_str() {
121 "kwargs" => quote_call_kwargs(&ast, fn_name),
122 "context" => quote_call_context(&ast, fn_name),
123 a => panic!("didn't expect argument {a}"),
124 },
125 2 => match (args[0].as_str(), args[1].as_str()) {
126 ("context", "kwargs") => quote_call_context_kwargs(&ast, fn_name),
127 ("kwargs", "context") => panic!("'kwargs', 'context' order should be reversed"),
128 (a, b) => panic!("didn't expect arguments {a}, {b}"),
129 },
130 _ => panic!("didn't expect so many arguments"),
131 };
132
133 let quote_process_result = quote_process_results();
134 let fn_name = get_expression_function_name(fn_name);
135
136 quote!(
137 use ::pyo3_polars::export::*;
138
139 #[no_mangle]
141 pub unsafe extern "C" fn #fn_name (
142 e: *mut polars_ffi::version_0::SeriesExport,
143 input_len: usize,
144 kwargs_ptr: *const u8,
145 kwargs_len: usize,
146 return_value: *mut polars_ffi::version_0::SeriesExport,
147 context: *mut polars_ffi::version_0::CallerContext
148 ) {
149 let panic_result = std::panic::catch_unwind(move || {
150 let inputs = polars_ffi::version_0::import_series_buffer(e, input_len).unwrap();
151
152 #quote_call
153
154 #quote_process_result
155 });
156
157 if panic_result.is_err() {
158 ::pyo3_polars::derive::_set_panic();
160 }
161 }
162 )
163}
164
165fn get_field_function_name(fn_name: &syn::Ident) -> syn::Ident {
166 syn::Ident::new(&format!("_polars_plugin_field_{fn_name}"), fn_name.span())
167}
168
169fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident {
170 syn::Ident::new(&format!("_polars_plugin_{fn_name}"), fn_name.span())
171}
172
173fn quote_get_inputs() -> proc_macro2::TokenStream {
174 quote!(
175 let inputs = std::slice::from_raw_parts(field, len);
176 let inputs = inputs.iter().map(|field| {
177 let field = polars_arrow::ffi::import_field_from_c(field).unwrap();
178 let out = polars_core::prelude::Field::from(&field);
179 out
180 }).collect::<Vec<_>>();
181 )
182}
183
184fn create_field_function(
185 fn_name: &syn::Ident,
186 dtype_fn_name: &syn::Ident,
187 kwargs: bool,
188) -> proc_macro2::TokenStream {
189 let map_field_name = get_field_function_name(fn_name);
190 let inputs = quote_get_inputs();
191
192 let call_fn = if kwargs {
193 let kwargs = quote_get_kwargs();
194 quote! (
195 #kwargs
196 let result = #dtype_fn_name(&inputs, kwargs);
197 )
198 } else {
199 quote!(
200 let result = #dtype_fn_name(&inputs);
201 )
202 };
203
204 quote! (
205 #[no_mangle]
206 pub unsafe extern "C" fn #map_field_name(
207 field: *mut polars_arrow::ffi::ArrowSchema,
208 len: usize,
209 return_value: *mut polars_arrow::ffi::ArrowSchema,
210 kwargs_ptr: *const u8,
211 kwargs_len: usize,
212 ) {
213 let panic_result = std::panic::catch_unwind(move || {
214 #inputs;
215
216 #call_fn;
217
218 match result {
219 Ok(out) => {
220 let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));
221 *return_value = out;
222 },
223 Err(err) => {
224 pyo3_polars::derive::_update_last_error(err);
226 }
227 }
228 });
229
230 if panic_result.is_err() {
231 pyo3_polars::derive::_set_panic();
233 }
234 }
235 )
236}
237
238fn create_field_function_from_with_dtype(
239 fn_name: &syn::Ident,
240 dtype: syn::Ident,
241) -> proc_macro2::TokenStream {
242 let map_field_name = get_field_function_name(fn_name);
243 let inputs = quote_get_inputs();
244
245 quote! (
246 #[no_mangle]
247 pub unsafe extern "C" fn #map_field_name(
248 field: *mut polars_arrow::ffi::ArrowSchema,
249 len: usize,
250 return_value: *mut polars_arrow::ffi::ArrowSchema
251 ) {
252 #inputs
253
254 let mapper = polars_plan::prelude::FieldsMapper::new(&inputs);
255 let dtype = polars_core::datatypes::DataType::#dtype;
256 let out = mapper.with_dtype(dtype).unwrap();
257 let out = polars_arrow::ffi::export_field_to_c(&out.to_arrow(polars_core::datatypes::CompatLevel::newest()));
258 *return_value = out;
259 }
260 )
261}
262
263#[proc_macro_attribute]
264pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream {
265 let ast = parse_macro_input!(input as syn::ItemFn);
266
267 let options = parse_macro_input!(attr as attr::ExprsFunctionOptions);
268 let expanded_field_fn = if let Some(fn_name) = options.output_type_fn {
269 create_field_function(&ast.sig.ident, &fn_name, false)
270 } else if let Some(fn_name) = options.output_type_fn_kwargs {
271 create_field_function(&ast.sig.ident, &fn_name, true)
272 } else if let Some(dtype) = options.output_dtype {
273 create_field_function_from_with_dtype(&ast.sig.ident, dtype)
274 } else {
275 panic!("didn't understand polars_expr attribute")
276 };
277
278 let expanded_expr = create_expression_function(ast);
279 let expanded = quote!(
280 #expanded_field_fn
281
282 #expanded_expr
283 );
284 TokenStream::from(expanded)
285}