use std::ffi::c_void;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use abi_stable::StableAbi;
use abi_stable::std_types::{ROption, RResult, RString, RVec};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow_schema::{Field, FieldRef};
use datafusion_common::{Result, ffi_err};
use datafusion_expr::function::WindowUDFFieldArgs;
use datafusion_expr::type_coercion::functions::fields_with_udf;
use datafusion_expr::{
LimitEffect, PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl,
};
use datafusion_physical_expr::PhysicalExpr;
use partition_evaluator::FFI_PartitionEvaluator;
use partition_evaluator_args::{
FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs,
};
mod partition_evaluator;
mod partition_evaluator_args;
mod range;
use crate::arrow_wrappers::WrappedSchema;
use crate::util::{
FFIResult, rvec_wrapped_to_vec_datatype, rvec_wrapped_to_vec_fieldref,
vec_datatype_to_rvec_wrapped, vec_fieldref_to_rvec_wrapped,
};
use crate::volatility::FFI_Volatility;
use crate::{df_result, rresult, rresult_return};
/// FFI-stable table of data and function pointers describing a [`WindowUDF`].
///
/// Instances built in this file (via `From<Arc<WindowUDF>>` or `clone`) keep the
/// wrapped UDF in `private_data` and point every function pointer at the
/// `*_fn_wrapper` functions of this module.
///
/// The struct is `#[repr(C)]`: field order is part of the binary layout and must
/// not be changed.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_WindowUDF {
/// Name of the function, copied from `WindowUDF::name` at construction.
pub name: RString,
/// Alternative names, copied from `WindowUDF::aliases` at construction.
pub aliases: RVec<RString>,
/// Volatility taken from the UDF's signature at construction.
pub volatility: FFI_Volatility,
/// Builds a partition evaluator for the given arguments.
pub partition_evaluator: unsafe extern "C" fn(
udwf: &Self,
args: FFI_PartitionEvaluatorArgs,
)
-> FFIResult<FFI_PartitionEvaluator>,
/// Computes the output field for the given inputs and display name.
/// The wrapper in this file returns it as a schema holding exactly one field.
pub field: unsafe extern "C" fn(
udwf: &Self,
input_types: RVec<WrappedSchema>,
display_name: RString,
) -> FFIResult<WrappedSchema>,
/// Coerces the argument data types and returns the resulting data types.
pub coerce_types: unsafe extern "C" fn(
udf: &Self,
arg_types: RVec<WrappedSchema>,
) -> FFIResult<RVec<WrappedSchema>>,
/// Sort options reported by the UDF at construction, if any.
pub sort_options: ROption<FFI_SortOptions>,
/// Produces a new `FFI_WindowUDF`; this is what `Clone::clone` calls.
pub clone: unsafe extern "C" fn(udf: &Self) -> Self,
/// Frees the resources behind `private_data`; invoked from `Drop`.
pub release: unsafe extern "C" fn(udf: &mut Self),
/// Opaque state owned by the side that created this struct. For instances
/// created in this file it points at a boxed `WindowUDFPrivateData`.
pub private_data: *mut c_void,
/// Returns an identifier for the library that created this struct. It is
/// compared against `crate::get_library_marker_id()` when converting back to a
/// `WindowUDFImpl` to decide whether the FFI layer can be bypassed.
pub library_marker_id: extern "C" fn() -> usize,
}
// These impls are a manual promise of thread-safety: `FFI_WindowUDF` contains a
// raw `private_data` pointer, so the compiler would not derive `Send`/`Sync`.
// NOTE(review): the justification is not visible in this file — presumably the
// pointee (`WindowUDFPrivateData`, an `Arc<WindowUDF>`) is itself thread-safe;
// confirm this also holds for foreign-provided instances.
unsafe impl Send for FFI_WindowUDF {}
unsafe impl Sync for FFI_WindowUDF {}
/// State stored behind `FFI_WindowUDF::private_data` for instances created in
/// this file. It is boxed, turned into a raw pointer at construction, and
/// reclaimed by `release_fn_wrapper`.
pub struct WindowUDFPrivateData {
/// The wrapped window UDF that the `*_fn_wrapper` functions delegate to.
pub udf: Arc<WindowUDF>,
}
impl FFI_WindowUDF {
    /// Borrows the [`WindowUDF`] stored behind `private_data`.
    ///
    /// # Safety
    ///
    /// `private_data` is reinterpreted as a `WindowUDFPrivateData` and
    /// dereferenced without any check, so it must point at a live value of that
    /// type for as long as the returned reference is used.
    unsafe fn inner(&self) -> &Arc<WindowUDF> {
        let state: *const WindowUDFPrivateData =
            self.private_data.cast::<WindowUDFPrivateData>();
        // SAFETY: the caller guarantees `private_data` points at a valid
        // `WindowUDFPrivateData` (see the `# Safety` section above).
        unsafe { &(*state).udf }
    }
}
/// FFI entry point behind `FFI_WindowUDF::partition_evaluator`.
///
/// Converts the FFI arguments into `ForeignPartitionEvaluatorArgs`, asks the
/// wrapped UDF for a partition evaluator and converts that evaluator into its
/// FFI representation. Failures of either step are handed to
/// `rresult_return!`, which is expected to return them as the error variant.
///
/// # Safety
///
/// `udwf.private_data` must satisfy the requirements of `FFI_WindowUDF::inner`.
unsafe extern "C" fn partition_evaluator_fn_wrapper(
    udwf: &FFI_WindowUDF,
    args: FFI_PartitionEvaluatorArgs,
) -> FFIResult<FFI_PartitionEvaluator> {
    // SAFETY: forwarded to the caller, see `# Safety` above.
    let window_udf = unsafe { udwf.inner() };

    let foreign_args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args));
    let native_evaluator =
        rresult_return!(window_udf.partition_evaluator_factory((&foreign_args).into()));

    let ffi_evaluator: FFI_PartitionEvaluator = native_evaluator.into();
    RResult::ROk(ffi_evaluator)
}
unsafe extern "C" fn field_fn_wrapper(
udwf: &FFI_WindowUDF,
input_fields: RVec<WrappedSchema>,
display_name: RString,
) -> FFIResult<WrappedSchema> {
unsafe {
let inner = udwf.inner();
let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields));
let field = rresult_return!(inner.field(WindowUDFFieldArgs::new(
&input_fields,
display_name.as_str()
)));
let schema = Arc::new(Schema::new(vec![field]));
RResult::ROk(WrappedSchema::from(schema))
}
}
/// FFI entry point behind `FFI_WindowUDF::coerce_types`.
///
/// Each incoming data type is placed in a non-nullable placeholder field named
/// `"f"`, the fields are run through `fields_with_udf` against the wrapped UDF,
/// and the data types of the resulting fields are wrapped and returned.
///
/// # Safety
///
/// `udwf.private_data` must satisfy the requirements of `FFI_WindowUDF::inner`.
unsafe extern "C" fn coerce_types_fn_wrapper(
    udwf: &FFI_WindowUDF,
    arg_types: RVec<WrappedSchema>,
) -> FFIResult<RVec<WrappedSchema>> {
    unsafe {
        let window_udf = udwf.inner();

        let data_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
        let placeholder_fields: Vec<_> = data_types
            .into_iter()
            .map(|data_type| Arc::new(Field::new("f", data_type, false)))
            .collect();

        let coerced_fields =
            rresult_return!(fields_with_udf(&placeholder_fields, window_udf.as_ref()));
        let coerced_types: Vec<_> = coerced_fields
            .into_iter()
            .map(|field| field.data_type().clone())
            .collect();

        rresult!(vec_datatype_to_rvec_wrapped(&coerced_types))
    }
}
/// FFI entry point behind `FFI_WindowUDF::release`; frees the boxed
/// `WindowUDFPrivateData` and leaves `private_data` null.
///
/// Releasing an already released struct is a caller bug and still trips the
/// `debug_assert!` in debug builds. In builds without debug assertions the call
/// is now a no-op instead of handing a null pointer to `Box::from_raw`, which
/// would be undefined behaviour.
///
/// # Safety
///
/// If `udwf.private_data` is non-null it must be a pointer obtained from
/// `Box::into_raw` on a `Box<WindowUDFPrivateData>` that has not been freed yet.
unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) {
    debug_assert!(!udwf.private_data.is_null());
    if udwf.private_data.is_null() {
        return;
    }

    // Clear the pointer before dropping so the struct never exposes a dangling
    // pointer, even while the drop is running.
    let raw_state = std::mem::replace(&mut udwf.private_data, std::ptr::null_mut());

    // SAFETY: `raw_state` is non-null (checked above) and, per the contract in
    // `# Safety`, originates from `Box::into_raw` on a `Box<WindowUDFPrivateData>`.
    let owned_state = unsafe { Box::from_raw(raw_state as *mut WindowUDFPrivateData) };
    drop(owned_state);
}
unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF {
unsafe {
let private_data = Box::new(WindowUDFPrivateData {
udf: Arc::clone(udwf.inner()),
});
FFI_WindowUDF {
name: udwf.name.clone(),
aliases: udwf.aliases.clone(),
volatility: udwf.volatility.clone(),
partition_evaluator: partition_evaluator_fn_wrapper,
sort_options: udwf.sort_options.clone(),
coerce_types: coerce_types_fn_wrapper,
field: field_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
private_data: Box::into_raw(private_data) as *mut c_void,
library_marker_id: crate::get_library_marker_id,
}
}
}
impl Clone for FFI_WindowUDF {
    /// Clones by calling the `clone` function pointer stored in the struct, so
    /// the copy is produced by whichever side created this instance.
    fn clone(&self) -> Self {
        let clone_fn = self.clone;
        // SAFETY: NOTE(review) relies on `clone` being a valid function pointer
        // for this instance; this is not checked here.
        unsafe { clone_fn(self) }
    }
}
impl From<Arc<WindowUDF>> for FFI_WindowUDF {
fn from(udf: Arc<WindowUDF>) -> Self {
let name = udf.name().into();
let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
let volatility = udf.signature().volatility.into();
let sort_options = udf.sort_options().map(|v| (&v).into()).into();
let private_data = Box::new(WindowUDFPrivateData { udf });
Self {
name,
aliases,
volatility,
partition_evaluator: partition_evaluator_fn_wrapper,
sort_options,
coerce_types: coerce_types_fn_wrapper,
field: field_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
private_data: Box::into_raw(private_data) as *mut c_void,
library_marker_id: crate::get_library_marker_id,
}
}
}
impl Drop for FFI_WindowUDF {
    /// Hands the struct to its own `release` function pointer so the side that
    /// created it frees the private state.
    fn drop(&mut self) {
        let release_fn = self.release;
        // SAFETY: NOTE(review) relies on `release` being a valid function
        // pointer for this instance; this is not checked here.
        unsafe { release_fn(self) }
    }
}
/// [`WindowUDFImpl`] backed by an [`FFI_WindowUDF`].
///
/// It is created by `From<&FFI_WindowUDF> for Arc<dyn WindowUDFImpl>` when the
/// FFI struct's library marker differs from this crate's. Every trait method
/// that needs the real implementation calls through the stored function
/// pointers.
#[derive(Debug)]
pub struct ForeignWindowUDF {
// Owned copy of the FFI struct's `name`.
name: String,
// Owned copies of the FFI struct's `aliases`.
aliases: Vec<String>,
// Clone of the FFI struct whose function pointers are invoked.
udf: FFI_WindowUDF,
// Built with `Signature::user_defined` from the FFI struct's volatility.
signature: Signature,
}
// Manual promise that `ForeignWindowUDF` may be moved to and shared between
// threads.
// NOTE(review): the justification is not visible in this file — presumably it
// follows from the equivalent impls on the contained `FFI_WindowUDF`; confirm.
unsafe impl Send for ForeignWindowUDF {}
unsafe impl Sync for ForeignWindowUDF {}
impl PartialEq for ForeignWindowUDF {
    /// Identity comparison: two values are equal only when they are the same
    /// object in memory. Field contents are not inspected.
    fn eq(&self, other: &Self) -> bool {
        let lhs: *const Self = self;
        let rhs: *const Self = other;
        lhs == rhs
    }
}

impl Eq for ForeignWindowUDF {}
impl Hash for ForeignWindowUDF {
    /// Hashes the address of `self`, matching the identity-based `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let address: *const Self = self;
        address.hash(state);
    }
}
impl From<&FFI_WindowUDF> for Arc<dyn WindowUDFImpl> {
    /// Turns an FFI window UDF back into a [`WindowUDFImpl`].
    ///
    /// When the struct reports this crate's library marker, the original
    /// implementation is returned directly. Otherwise the struct is cloned into
    /// a [`ForeignWindowUDF`] that calls through its function pointers.
    fn from(udf: &FFI_WindowUDF) -> Self {
        let created_locally = (udf.library_marker_id)() == crate::get_library_marker_id();
        if created_locally {
            // SAFETY: NOTE(review) relies on a matching marker implying that
            // `private_data` is this crate's `WindowUDFPrivateData`.
            return Arc::clone(unsafe { udf.inner().inner() });
        }

        Arc::new(ForeignWindowUDF {
            name: udf.name.to_string(),
            aliases: udf.aliases.iter().map(|alias| alias.to_string()).collect(),
            udf: udf.clone(),
            signature: Signature::user_defined((&udf.volatility).into()),
        })
    }
}
impl WindowUDFImpl for ForeignWindowUDF {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
unsafe {
let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
}
}
fn partition_evaluator(
&self,
args: datafusion_expr::function::PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
let evaluator = unsafe {
let args = FFI_PartitionEvaluatorArgs::try_from(args)?;
(self.udf.partition_evaluator)(&self.udf, args)
};
df_result!(evaluator).map(<Box<dyn PartitionEvaluator>>::from)
}
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
unsafe {
let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?;
let schema = df_result!((self.udf.field)(
&self.udf,
input_types,
field_args.name().into()
))?;
let schema: SchemaRef = schema.into();
match schema.fields().is_empty() {
true => ffi_err!(
"Unable to retrieve field in WindowUDF via FFI - schema has no fields"
),
false => Ok(schema.field(0).to_owned().into()),
}
}
}
fn sort_options(&self) -> Option<SortOptions> {
let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into();
options.map(|s| s.into())
}
fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
LimitEffect::Unknown
}
}
/// FFI-stable counterpart of arrow's [`SortOptions`]; the `From` impls in this
/// file copy both fields across unchanged in either direction.
///
/// The struct is `#[repr(C)]`: field order is part of the binary layout.
#[repr(C)]
#[derive(Debug, StableAbi, Clone)]
pub struct FFI_SortOptions {
/// Mirrors `SortOptions::descending`.
pub descending: bool,
/// Mirrors `SortOptions::nulls_first`.
pub nulls_first: bool,
}
impl From<&SortOptions> for FFI_SortOptions {
    /// Copies both flags from arrow's sort options into the FFI struct.
    fn from(value: &SortOptions) -> Self {
        let (descending, nulls_first) = (value.descending, value.nulls_first);
        Self {
            descending,
            nulls_first,
        }
    }
}
impl From<&FFI_SortOptions> for SortOptions {
fn from(value: &FFI_SortOptions) -> Self {
Self {
descending: value.descending,
nulls_first: value.nulls_first,
}
}
}
#[cfg(test)]
#[cfg(feature = "integration-tests")]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, create_array};
    use datafusion::functions_window::lead_lag::{WindowShift, lag_udwf};
    use datafusion::logical_expr::expr::Sort;
    use datafusion::logical_expr::{ExprFunctionExt, WindowUDF, WindowUDFImpl, col};
    use datafusion::prelude::SessionContext;

    use crate::tests::create_record_batch;
    use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF};

    /// Wraps `original_udwf` in an `FFI_WindowUDF`, swaps in the mock foreign
    /// library marker so the local bypass is not taken, and converts it back
    /// into a `WindowUDF`.
    fn create_test_foreign_udwf(
        original_udwf: impl WindowUDFImpl + 'static,
    ) -> datafusion::common::Result<WindowUDF> {
        let shared_udwf = Arc::new(WindowUDF::from(original_udwf));

        let mut ffi_udwf = FFI_WindowUDF::from(Arc::clone(&shared_udwf));
        ffi_udwf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        Ok(WindowUDF::new_from_shared_impl(foreign_impl))
    }

    /// The name must survive a trip through the FFI representation.
    #[test]
    fn test_round_trip_udwf() -> datafusion::common::Result<()> {
        let source_udwf = lag_udwf();
        let expected_name = source_udwf.name().to_owned();

        let mut ffi_udwf = FFI_WindowUDF::from(Arc::clone(&source_udwf));
        ffi_udwf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        let round_tripped = WindowUDF::new_from_shared_impl(foreign_impl);

        assert_eq!(expected_name, round_tripped.name());
        Ok(())
    }

    /// Runs `lag` through the foreign wrapper inside a real query.
    #[tokio::test]
    async fn test_lag_udwf() -> datafusion::common::Result<()> {
        let udwf = create_test_foreign_udwf(WindowShift::lag())?;

        let ctx = SessionContext::default();
        let input = ctx.read_batch(create_record_batch(-5, 5))?;

        let lag_expr = udwf
            .call(vec![col("a")])
            .order_by(vec![Sort::new(col("a"), true, true)])
            .build()
            .unwrap()
            .alias("lag_a");
        let df = input.select(vec![col("a"), lag_expr])?;

        df.clone().show().await?;
        let batches = df.collect().await?;

        let expected: ArrayRef =
            create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)]);

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].column(1), &expected);
        Ok(())
    }

    /// With the local marker the original implementation comes back; with a
    /// foreign marker a `ForeignWindowUDF` wrapper is produced instead.
    #[test]
    fn test_ffi_udwf_local_bypass() -> datafusion_common::Result<()> {
        let source_udwf = Arc::new(WindowUDF::from(WindowShift::lag()));
        let mut ffi_udwf = FFI_WindowUDF::from(source_udwf);

        // Same library marker: the conversion hands back the wrapped impl.
        let local_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        assert!(local_impl.as_any().is::<WindowShift>());

        // Foreign library marker: the conversion has to go through the wrapper.
        ffi_udwf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_impl: Arc<dyn WindowUDFImpl> = (&ffi_udwf).into();
        assert!(foreign_impl.as_any().is::<ForeignWindowUDF>());

        Ok(())
    }
}