1use abi_stable::{
19 std_types::{ROption, RResult, RString, RVec},
20 StableAbi,
21};
22use arrow::datatypes::Schema;
23use arrow::{
24 compute::SortOptions,
25 datatypes::{DataType, SchemaRef},
26};
27use arrow_schema::{Field, FieldRef};
28use datafusion::logical_expr::LimitEffect;
29use datafusion::physical_expr::PhysicalExpr;
30use datafusion::{
31 error::DataFusionError,
32 logical_expr::{
33 function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf,
34 PartitionEvaluator,
35 },
36};
37use datafusion::{
38 error::Result,
39 logical_expr::{Signature, WindowUDF, WindowUDFImpl},
40};
41use datafusion_common::exec_err;
42use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator};
43use partition_evaluator_args::{
44 FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
45};
46use std::hash::{Hash, Hasher};
47use std::{ffi::c_void, sync::Arc};
48
49mod partition_evaluator;
50mod partition_evaluator_args;
51mod range;
52
53use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
54use crate::{
55 arrow_wrappers::WrappedSchema,
56 df_result, rresult, rresult_return,
57 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
58 volatility::FFI_Volatility,
59};
60
/// A stable-ABI (`#[repr(C)]`, `StableAbi`) description of a window UDF that
/// can be handed across an FFI boundary.
///
/// Plain metadata (name, aliases, volatility, sort options) is stored directly
/// in the struct. Behavior is reached through the `extern "C"` function
/// pointers, which operate on the producer-side state behind `private_data`.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_WindowUDF {
    /// Name of the UDF (copied from `WindowUDF::name` when built by this crate).
    pub name: RString,

    /// Aliases of the UDF (copied from `WindowUDF::aliases` when built by this crate).
    pub aliases: RVec<RString>,

    /// Volatility taken from the UDF's signature.
    pub volatility: FFI_Volatility,

    /// Creates a partition evaluator for the given FFI-encoded arguments.
    /// Errors are returned as strings.
    pub partition_evaluator:
        unsafe extern "C" fn(
            udwf: &Self,
            args: FFI_PartitionEvaluatorArgs,
        ) -> RResult<FFI_PartitionEvaluator, RString>,

    /// Computes the UDF's output field.
    ///
    /// Despite the parameter name `input_types`, this crate's implementation
    /// decodes the argument as input *fields*. The output field is returned as
    /// the first field of a one-field schema.
    pub field: unsafe extern "C" fn(
        udwf: &Self,
        input_types: RVec<WrappedSchema>,
        display_name: RString,
    ) -> RResult<WrappedSchema, RString>,

    /// Coerces argument data types for the UDF.
    ///
    /// Both the input and the output are lists of data types encoded with
    /// `vec_datatype_to_rvec_wrapped` / `rvec_wrapped_to_vec_datatype`.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Sort options reported by the UDF, if any.
    pub sort_options: ROption<FFI_SortOptions>,

    /// Produces a new, independently releasable copy of this struct. In this
    /// crate's implementation, the copy gets its own `private_data` box that
    /// refers to the same underlying `WindowUDF`.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Frees the state behind `private_data`. Invoked from `Drop`.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Opaque pointer to the producer's internal state. For values built by
    /// this crate it is a boxed `WindowUDFPrivateData`. Consumers should treat
    /// it as opaque and only pass it back through the function pointers above.
    pub private_data: *mut c_void,
}
109
// SAFETY (review note): the raw `private_data` pointer is what keeps the
// compiler from deriving `Send`/`Sync` automatically. The other members appear
// to be owned stable-ABI values and function pointers. For values created by
// this crate, `private_data` points to a `WindowUDFPrivateData` holding an
// `Arc<WindowUDF>`. These impls are presumably sound as long as `WindowUDF`
// is itself `Send + Sync` — TODO confirm. Foreign producers must uphold the
// same guarantee for whatever state they store behind the pointer.
unsafe impl Send for FFI_WindowUDF {}
unsafe impl Sync for FFI_WindowUDF {}
112
/// Rust-side state that an [`FFI_WindowUDF`] created by this crate owns
/// through its `private_data` pointer.
pub struct WindowUDFPrivateData {
    /// The wrapped UDF. Each `FFI_WindowUDF`, including every clone, holds its
    /// own `Arc` reference to it.
    pub udf: Arc<WindowUDF>,
}
116
117impl FFI_WindowUDF {
118 unsafe fn inner(&self) -> &Arc<WindowUDF> {
119 let private_data = self.private_data as *const WindowUDFPrivateData;
120 &(*private_data).udf
121 }
122}
123
124unsafe extern "C" fn partition_evaluator_fn_wrapper(
125 udwf: &FFI_WindowUDF,
126 args: FFI_PartitionEvaluatorArgs,
127) -> RResult<FFI_PartitionEvaluator, RString> {
128 let inner = udwf.inner();
129
130 let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args));
131
132 let evaluator = rresult_return!(inner.partition_evaluator_factory((&args).into()));
133
134 RResult::ROk(evaluator.into())
135}
136
137unsafe extern "C" fn field_fn_wrapper(
138 udwf: &FFI_WindowUDF,
139 input_fields: RVec<WrappedSchema>,
140 display_name: RString,
141) -> RResult<WrappedSchema, RString> {
142 let inner = udwf.inner();
143
144 let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
145
146 let field = rresult_return!(inner.field(WindowUDFFieldArgs::new(
147 &input_fields,
148 display_name.as_str()
149 )));
150
151 let schema = Arc::new(Schema::new(vec![field]));
152
153 RResult::ROk(WrappedSchema::from(schema))
154}
155
156unsafe extern "C" fn coerce_types_fn_wrapper(
157 udwf: &FFI_WindowUDF,
158 arg_types: RVec<WrappedSchema>,
159) -> RResult<RVec<WrappedSchema>, RString> {
160 let inner = udwf.inner();
161
162 let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types))
163 .into_iter()
164 .map(|dt| Field::new("f", dt, false))
165 .map(Arc::new)
166 .collect::<Vec<_>>();
167
168 let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner));
169 let return_types = return_fields
170 .into_iter()
171 .map(|f| f.data_type().to_owned())
172 .collect::<Vec<_>>();
173
174 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
175}
176
177unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
178 let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData);
179 drop(private_data);
180}
181
182unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
183 let private_data = Box::new(WindowUDFPrivateData {
190 udf: Arc::clone(udwf.inner()),
191 });
192
193 FFI_WindowUDF {
194 name: udwf.name.clone(),
195 aliases: udwf.aliases.clone(),
196 volatility: udwf.volatility.clone(),
197 partition_evaluator: partition_evaluator_fn_wrapper,
198 sort_options: udwf.sort_options.clone(),
199 coerce_types: coerce_types_fn_wrapper,
200 field: field_fn_wrapper,
201 clone: clone_fn_wrapper,
202 release: release_fn_wrapper,
203 private_data: Box::into_raw(private_data) as *mut c_void,
204 }
205}
206
207impl Clone for FFI_WindowUDF {
208 fn clone(&self) -> Self {
209 unsafe { (self.clone)(self) }
210 }
211}
212
213impl From<Arc<WindowUDF>> for FFI_WindowUDF {
214 fn from(udf: Arc<WindowUDF>) -> Self {
215 let name = udf.name().into();
216 let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
217 let volatility = udf.signature().volatility.into();
218 let sort_options = udf.sort_options().map(|v| (&v).into()).into();
219
220 let private_data = Box::new(WindowUDFPrivateData { udf });
221
222 Self {
223 name,
224 aliases,
225 volatility,
226 partition_evaluator: partition_evaluator_fn_wrapper,
227 sort_options,
228 coerce_types: coerce_types_fn_wrapper,
229 field: field_fn_wrapper,
230 clone: clone_fn_wrapper,
231 release: release_fn_wrapper,
232 private_data: Box::into_raw(private_data) as *mut c_void,
233 }
234 }
235}
236
237impl Drop for FFI_WindowUDF {
238 fn drop(&mut self) {
239 unsafe { (self.release)(self) }
240 }
241}
242
/// A DataFusion `WindowUDFImpl` backed by an [`FFI_WindowUDF`] received from
/// another library.
///
/// The name, aliases and signature are computed once when the value is built
/// (see its `TryFrom<&FFI_WindowUDF>` impl). All other behavior is forwarded
/// through the FFI struct's function pointers.
#[derive(Debug)]
pub struct ForeignWindowUDF {
    // Cached copy of the FFI struct's `name`.
    name: String,
    // Cached copy of the FFI struct's `aliases`.
    aliases: Vec<String>,
    // Owned clone of the FFI struct; released when this value is dropped.
    udf: FFI_WindowUDF,
    // `Signature::user_defined` built from the FFI struct's volatility.
    signature: Signature,
}
256
// SAFETY (review note): `ForeignWindowUDF` holds an `FFI_WindowUDF`, which
// this module already declares `Send + Sync`, plus ordinary owned Rust values.
// These explicit impls therefore look redundant with the auto-derived ones —
// TODO confirm whether `Signature` alone would prevent the auto impls before
// removing them.
unsafe impl Send for ForeignWindowUDF {}
unsafe impl Sync for ForeignWindowUDF {}
259
260impl PartialEq for ForeignWindowUDF {
261 fn eq(&self, other: &Self) -> bool {
262 std::ptr::eq(self, other)
264 }
265}
266impl Eq for ForeignWindowUDF {}
267impl Hash for ForeignWindowUDF {
268 fn hash<H: Hasher>(&self, state: &mut H) {
269 std::ptr::hash(self, state)
270 }
271}
272
273impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF {
274 type Error = DataFusionError;
275
276 fn try_from(udf: &FFI_WindowUDF) -> Result<Self, Self::Error> {
277 let name = udf.name.to_owned().into();
278 let signature = Signature::user_defined((&udf.volatility).into());
279
280 let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
281
282 Ok(Self {
283 name,
284 udf: udf.clone(),
285 aliases,
286 signature,
287 })
288 }
289}
290
291impl WindowUDFImpl for ForeignWindowUDF {
292 fn as_any(&self) -> &dyn std::any::Any {
293 self
294 }
295
296 fn name(&self) -> &str {
297 &self.name
298 }
299
300 fn signature(&self) -> &Signature {
301 &self.signature
302 }
303
304 fn aliases(&self) -> &[String] {
305 &self.aliases
306 }
307
308 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
309 unsafe {
310 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
311 let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
312 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
313 }
314 }
315
316 fn partition_evaluator(
317 &self,
318 args: datafusion::logical_expr::function::PartitionEvaluatorArgs,
319 ) -> Result<Box<dyn PartitionEvaluator>> {
320 let evaluator = unsafe {
321 let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
322 (self.udf.partition_evaluator)(&self.udf, args)
323 };
324
325 df_result!(evaluator).map(|evaluator| {
326 Box::new(ForeignPartitionEvaluator::from(evaluator))
327 as Box<dyn PartitionEvaluator>
328 })
329 }
330
331 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
332 unsafe {
333 let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
334 let schema = df_result!((self.udf.field)(
335 &self.udf,
336 input_types,
337 field_args.name().into()
338 ))?;
339 let schema: SchemaRef = schema.into();
340
341 match schema.fields().is_empty() {
342 true => exec_err!(
343 "Unable to retrieve field in WindowUDF via FFI - schema has no fields"
344 ),
345 false => Ok(schema.field(0).to_owned().into()),
346 }
347 }
348 }
349
350 fn sort_options(&self) -> Option<SortOptions> {
351 let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into();
352 options.map(|s| s.into())
353 }
354
355 fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
356 LimitEffect::Unknown
357 }
358}
359
/// Stable-ABI mirror of arrow's [`SortOptions`], used to carry a window UDF's
/// sort options across the FFI boundary.
#[repr(C)]
#[derive(Debug, StableAbi, Clone)]
#[allow(non_camel_case_types)]
pub struct FFI_SortOptions {
    /// Mirrors `SortOptions::descending`.
    pub descending: bool,
    /// Mirrors `SortOptions::nulls_first`.
    pub nulls_first: bool,
}
367
368impl From<&SortOptions> for FFI_SortOptions {
369 fn from(value: &SortOptions) -> Self {
370 Self {
371 descending: value.descending,
372 nulls_first: value.nulls_first,
373 }
374 }
375}
376
377impl From<&FFI_SortOptions> for SortOptions {
378 fn from(value: &FFI_SortOptions) -> Self {
379 Self {
380 descending: value.descending,
381 nulls_first: value.nulls_first,
382 }
383 }
384}
385
#[cfg(test)]
#[cfg(feature = "integration-tests")]
mod tests {
    use crate::tests::create_record_batch;
    use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};
    use arrow::array::{create_array, ArrayRef};
    use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift};
    use datafusion::logical_expr::expr::Sort;
    use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl};
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Exports `original_udwf` as an `FFI_WindowUDF` and imports it back as a
    /// `ForeignWindowUDF`-backed `WindowUDF`, simulating a trip across the FFI
    /// boundary within a single process.
    fn create_test_foreign_udwf(
        original_udwf: impl WindowUDFImpl + 'static,
    ) -> datafusion::common::Result<WindowUDF> {
        let original_udwf = Arc::new(WindowUDF::from(original_udwf));

        let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();

        let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?;
        Ok(foreign_udwf.into())
    }

    /// The UDF name must survive the FFI round trip.
    #[test]
    fn test_round_trip_udwf() -> datafusion::common::Result<()> {
        let original_udwf = lag_udwf();
        let original_name = original_udwf.name().to_owned();

        let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();

        let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?;
        let foreign_udwf: WindowUDF = foreign_udwf.into();

        assert_eq!(original_name, foreign_udwf.name());
        Ok(())
    }

    /// Running `lag` through the FFI wrapper yields the expected shifted column.
    #[tokio::test]
    async fn test_lag_udwf() -> datafusion::common::Result<()> {
        let udwf = create_test_foreign_udwf(WindowShift::lag())?;

        let ctx = SessionContext::default();
        let df = ctx.read_batch(create_record_batch(-5, 5))?;

        // Propagate builder errors with `?` rather than panicking via `unwrap`.
        let lag_expr = udwf
            .call(vec![col("a")])
            .order_by(vec![Sort::new(col("a"), true, true)])
            .build()?
            .alias("lag_a");
        let df = df.select(vec![col("a"), lag_expr])?;

        df.clone().show().await?;

        let result = df.collect().await?;
        let expected =
            create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)]) as ArrayRef;

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].column(1), &expected);

        Ok(())
    }
}