1use abi_stable::{
19 std_types::{ROption, RResult, RString, RVec},
20 StableAbi,
21};
22use arrow::datatypes::Schema;
23use arrow::{
24 compute::SortOptions,
25 datatypes::{DataType, SchemaRef},
26};
27use arrow_schema::{Field, FieldRef};
28use datafusion::{
29 error::DataFusionError,
30 logical_expr::{
31 function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf,
32 PartitionEvaluator,
33 },
34};
35use datafusion::{
36 error::Result,
37 logical_expr::{Signature, WindowUDF, WindowUDFImpl},
38};
39use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator};
40use partition_evaluator_args::{
41 FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
42};
43use std::hash::{Hash, Hasher};
44use std::{ffi::c_void, sync::Arc};
45
46mod partition_evaluator;
47mod partition_evaluator_args;
48mod range;
49
50use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped};
51use crate::{
52 arrow_wrappers::WrappedSchema,
53 df_result, rresult, rresult_return,
54 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
55 volatility::FFI_Volatility,
56};
57
/// A stable-ABI handle to a DataFusion [`WindowUDF`] that can be passed across
/// shared-library boundaries.
///
/// A provider builds one with `From<Arc<WindowUDF>>`; a consumer wraps it in
/// [`ForeignWindowUDF`]. Behaviour is reached only through the `extern "C"`
/// function pointers below, each of which receives this struct so it can get
/// at the provider-owned state behind `private_data`.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_WindowUDF {
    /// Name of the window function (copied from `WindowUDF::name`).
    pub name: RString,

    /// Alternative names of the window function (copied from `WindowUDF::aliases`).
    pub aliases: RVec<RString>,

    /// Volatility taken from the wrapped UDF's signature.
    pub volatility: FFI_Volatility,

    /// Builds a partition evaluator for the supplied arguments.
    /// On failure returns `RErr` carrying an error message.
    pub partition_evaluator:
        unsafe extern "C" fn(
            udwf: &Self,
            args: FFI_PartitionEvaluatorArgs,
        ) -> RResult<FFI_PartitionEvaluator, RString>,

    /// Computes the output field of the window function.
    ///
    /// Each element of `input_types` encodes one input field. The result is a
    /// schema whose first field is the output field; the consumer rejects an
    /// empty schema.
    pub field: unsafe extern "C" fn(
        udwf: &Self,
        input_types: RVec<WrappedSchema>,
        display_name: RString,
    ) -> RResult<WrappedSchema, RString>,

    /// Coerces the given argument data types to the types the wrapped UDF
    /// accepts. Each `WrappedSchema` encodes one data type.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Sort options reported by the wrapped UDF, if it reports any.
    pub sort_options: ROption<FFI_SortOptions>,

    /// Produces a new, independently owned handle to the same underlying UDF.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Frees the provider-owned state behind `private_data`.
    /// `Drop` calls this automatically, so callers should not invoke it by hand.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Opaque pointer to provider-owned state. When built by this crate it is a
    /// boxed `WindowUDFPrivateData`. Only the provider's function pointers may
    /// interpret it.
    pub private_data: *mut c_void,
}

// SAFETY: the only non-thread-safe field is the raw `private_data` pointer. When
// built by this crate, it points at a `WindowUDFPrivateData` that only holds an
// `Arc<WindowUDF>`, and the function pointers only read through it.
// NOTE(review): this relies on `WindowUDF` being `Send + Sync` and on foreign
// providers upholding the same contract; confirm.
unsafe impl Send for FFI_WindowUDF {}
unsafe impl Sync for FFI_WindowUDF {}
109
110pub struct WindowUDFPrivateData {
111 pub udf: Arc<WindowUDF>,
112}
113
114impl FFI_WindowUDF {
115 unsafe fn inner(&self) -> &Arc<WindowUDF> {
116 let private_data = self.private_data as *const WindowUDFPrivateData;
117 &(*private_data).udf
118 }
119}
120
121unsafe extern "C" fn partition_evaluator_fn_wrapper(
122 udwf: &FFI_WindowUDF,
123 args: FFI_PartitionEvaluatorArgs,
124) -> RResult<FFI_PartitionEvaluator, RString> {
125 let inner = udwf.inner();
126
127 let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args));
128
129 let evaluator = rresult_return!(inner.partition_evaluator_factory((&args).into()));
130
131 RResult::ROk(evaluator.into())
132}
133
134unsafe extern "C" fn field_fn_wrapper(
135 udwf: &FFI_WindowUDF,
136 input_fields: RVec<WrappedSchema>,
137 display_name: RString,
138) -> RResult<WrappedSchema, RString> {
139 let inner = udwf.inner();
140
141 let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
142
143 let field = rresult_return!(inner.field(WindowUDFFieldArgs::new(
144 &input_fields,
145 display_name.as_str()
146 )));
147
148 let schema = Arc::new(Schema::new(vec![field]));
149
150 RResult::ROk(WrappedSchema::from(schema))
151}
152
153unsafe extern "C" fn coerce_types_fn_wrapper(
154 udwf: &FFI_WindowUDF,
155 arg_types: RVec<WrappedSchema>,
156) -> RResult<RVec<WrappedSchema>, RString> {
157 let inner = udwf.inner();
158
159 let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types))
160 .into_iter()
161 .map(|dt| Field::new("f", dt, false))
162 .map(Arc::new)
163 .collect::<Vec<_>>();
164
165 let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner));
166 let return_types = return_fields
167 .into_iter()
168 .map(|f| f.data_type().to_owned())
169 .collect::<Vec<_>>();
170
171 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
172}
173
174unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
175 let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData);
176 drop(private_data);
177}
178
179unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
180 let private_data = Box::new(WindowUDFPrivateData {
187 udf: Arc::clone(udwf.inner()),
188 });
189
190 FFI_WindowUDF {
191 name: udwf.name.clone(),
192 aliases: udwf.aliases.clone(),
193 volatility: udwf.volatility.clone(),
194 partition_evaluator: partition_evaluator_fn_wrapper,
195 sort_options: udwf.sort_options.clone(),
196 coerce_types: coerce_types_fn_wrapper,
197 field: field_fn_wrapper,
198 clone: clone_fn_wrapper,
199 release: release_fn_wrapper,
200 private_data: Box::into_raw(private_data) as *mut c_void,
201 }
202}
203
204impl Clone for FFI_WindowUDF {
205 fn clone(&self) -> Self {
206 unsafe { (self.clone)(self) }
207 }
208}
209
210impl From<Arc<WindowUDF>> for FFI_WindowUDF {
211 fn from(udf: Arc<WindowUDF>) -> Self {
212 let name = udf.name().into();
213 let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
214 let volatility = udf.signature().volatility.into();
215 let sort_options = udf.sort_options().map(|v| (&v).into()).into();
216
217 let private_data = Box::new(WindowUDFPrivateData { udf });
218
219 Self {
220 name,
221 aliases,
222 volatility,
223 partition_evaluator: partition_evaluator_fn_wrapper,
224 sort_options,
225 coerce_types: coerce_types_fn_wrapper,
226 field: field_fn_wrapper,
227 clone: clone_fn_wrapper,
228 release: release_fn_wrapper,
229 private_data: Box::into_raw(private_data) as *mut c_void,
230 }
231 }
232}
233
234impl Drop for FFI_WindowUDF {
235 fn drop(&mut self) {
236 unsafe { (self.release)(self) }
237 }
238}
239
/// Consumer-side [`WindowUDFImpl`] that forwards every call through an
/// [`FFI_WindowUDF`] received from another library.
///
/// `name` and `aliases` are copied out of the FFI handle when the wrapper is
/// created. `signature` is always `Signature::user_defined` with the handle's
/// volatility. DataFusion presumably routes argument coercion for such
/// signatures through `coerce_types`, i.e. back to the provider; confirm.
#[derive(Debug)]
pub struct ForeignWindowUDF {
    name: String,
    aliases: Vec<String>,
    udf: FFI_WindowUDF,
    signature: Signature,
}

// SAFETY: the fields are owned `String`/`Vec<String>`/`Signature` values plus
// an `FFI_WindowUDF`, which is itself declared `Send + Sync` above.
// NOTE(review): these impls look redundant given the auto-traits of the
// fields; kept to make the thread-safety claim explicit.
unsafe impl Send for ForeignWindowUDF {}
unsafe impl Sync for ForeignWindowUDF {}
256
257impl PartialEq for ForeignWindowUDF {
258 fn eq(&self, other: &Self) -> bool {
259 std::ptr::eq(self, other)
261 }
262}
263impl Eq for ForeignWindowUDF {}
264impl Hash for ForeignWindowUDF {
265 fn hash<H: Hasher>(&self, state: &mut H) {
266 std::ptr::hash(self, state)
267 }
268}
269
270impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF {
271 type Error = DataFusionError;
272
273 fn try_from(udf: &FFI_WindowUDF) -> Result<Self, Self::Error> {
274 let name = udf.name.to_owned().into();
275 let signature = Signature::user_defined((&udf.volatility).into());
276
277 let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
278
279 Ok(Self {
280 name,
281 udf: udf.clone(),
282 aliases,
283 signature,
284 })
285 }
286}
287
288impl WindowUDFImpl for ForeignWindowUDF {
289 fn as_any(&self) -> &dyn std::any::Any {
290 self
291 }
292
293 fn name(&self) -> &str {
294 &self.name
295 }
296
297 fn signature(&self) -> &Signature {
298 &self.signature
299 }
300
301 fn aliases(&self) -> &[String] {
302 &self.aliases
303 }
304
305 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
306 unsafe {
307 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
308 let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
309 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
310 }
311 }
312
313 fn partition_evaluator(
314 &self,
315 args: datafusion::logical_expr::function::PartitionEvaluatorArgs,
316 ) -> Result<Box<dyn PartitionEvaluator>> {
317 let evaluator = unsafe {
318 let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
319 (self.udf.partition_evaluator)(&self.udf, args)
320 };
321
322 df_result!(evaluator).map(|evaluator| {
323 Box::new(ForeignPartitionEvaluator::from(evaluator))
324 as Box<dyn PartitionEvaluator>
325 })
326 }
327
328 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
329 unsafe {
330 let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
331 let schema = df_result!((self.udf.field)(
332 &self.udf,
333 input_types,
334 field_args.name().into()
335 ))?;
336 let schema: SchemaRef = schema.into();
337
338 match schema.fields().is_empty() {
339 true => Err(DataFusionError::Execution(
340 "Unable to retrieve field in WindowUDF via FFI".to_string(),
341 )),
342 false => Ok(schema.field(0).to_owned().into()),
343 }
344 }
345 }
346
347 fn sort_options(&self) -> Option<SortOptions> {
348 let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into();
349 options.map(|s| s.into())
350 }
351}
352
/// Stable-ABI mirror of Arrow's [`SortOptions`], used to report a window
/// function's sort options across the FFI boundary.
#[repr(C)]
#[derive(Debug, StableAbi, Clone)]
#[allow(non_camel_case_types)]
pub struct FFI_SortOptions {
    /// Mirrors `SortOptions::descending`.
    pub descending: bool,
    /// Mirrors `SortOptions::nulls_first`.
    pub nulls_first: bool,
}
360
361impl From<&SortOptions> for FFI_SortOptions {
362 fn from(value: &SortOptions) -> Self {
363 Self {
364 descending: value.descending,
365 nulls_first: value.nulls_first,
366 }
367 }
368}
369
370impl From<&FFI_SortOptions> for SortOptions {
371 fn from(value: &FFI_SortOptions) -> Self {
372 Self {
373 descending: value.descending,
374 nulls_first: value.nulls_first,
375 }
376 }
377}
378
#[cfg(test)]
#[cfg(feature = "integration-tests")]
mod tests {
    use crate::tests::create_record_batch;
    use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};
    use arrow::array::{create_array, ArrayRef};
    use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift};
    use datafusion::logical_expr::expr::Sort;
    use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl};
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Exports `original_udwf` as an `FFI_WindowUDF` and imports it back as a
    /// `ForeignWindowUDF`-backed `WindowUDF`, simulating a cross-library round trip.
    fn create_test_foreign_udwf(
        original_udwf: impl WindowUDFImpl + 'static,
    ) -> datafusion::common::Result<WindowUDF> {
        let original_udwf = Arc::new(WindowUDF::from(original_udwf));

        let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();

        let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?;
        Ok(foreign_udwf.into())
    }

    /// The metadata snapshotted into the FFI handle must survive the round trip.
    #[test]
    fn test_round_trip_udwf() -> datafusion::common::Result<()> {
        let original_udwf = lag_udwf();
        let original_name = original_udwf.name().to_owned();
        let original_aliases = original_udwf.aliases().to_vec();
        let original_sort_options = original_udwf.sort_options();

        let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into();
        let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?;
        let foreign_udwf: WindowUDF = foreign_udwf.into();

        assert_eq!(original_name, foreign_udwf.name());
        assert_eq!(original_aliases, foreign_udwf.aliases());
        assert_eq!(original_sort_options, foreign_udwf.sort_options());
        Ok(())
    }

    /// Executes `lag` through the FFI wrapper and checks the shifted column.
    #[tokio::test]
    async fn test_lag_udwf() -> datafusion::common::Result<()> {
        let udwf = create_test_foreign_udwf(WindowShift::lag())?;

        let ctx = SessionContext::default();
        let df = ctx.read_batch(create_record_batch(-5, 5))?;

        let df = df.select(vec![
            col("a"),
            udwf.call(vec![col("a")])
                .order_by(vec![Sort::new(col("a"), true, true)])
                .build()?
                .alias("lag_a"),
        ])?;

        df.clone().show().await?;

        let result = df.collect().await?;
        let expected =
            create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)])
                as ArrayRef;

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].column(1), &expected);

        Ok(())
    }
}