use std::ffi::c_void;
use std::sync::Arc;
use abi_stable::StableAbi;
use abi_stable::std_types::{RResult, RVec};
use datafusion_catalog::{TableFunctionImpl, TableProvider};
use datafusion_common::error::Result;
use datafusion_execution::TaskContext;
use datafusion_expr::Expr;
use datafusion_proto::logical_plan::from_proto::parse_exprs;
use datafusion_proto::logical_plan::to_proto::serialize_exprs;
use datafusion_proto::logical_plan::{
DefaultLogicalExtensionCodec, LogicalExtensionCodec,
};
use datafusion_proto::protobuf::LogicalExprList;
use prost::Message;
use tokio::runtime::Handle;
use crate::execution::FFI_TaskContextProvider;
use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use crate::table_provider::FFI_TableProvider;
use crate::util::FFIResult;
use crate::{df_result, rresult_return};
/// A stable-ABI (`#[repr(C)]`) view of a [`TableFunctionImpl`] that can be
/// passed across an FFI boundary.
///
/// The struct carries function pointers for every operation plus an opaque
/// `private_data` pointer that those functions interpret.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_TableFunction {
/// Invokes the table function. `args` holds the argument expressions as an
/// encoded protobuf `LogicalExprList`; the result is the produced table
/// provider wrapped for FFI, or an error.
pub call:
unsafe extern "C" fn(udtf: &Self, args: RVec<u8>) -> FFIResult<FFI_TableProvider>,
/// Codec used on both sides of the boundary to serialize and deserialize
/// the argument expressions; it is also passed on to the table providers
/// returned by `call`.
pub logical_codec: FFI_LogicalExtensionCodec,
/// Produces a new, independently releasable `FFI_TableFunction` for the same
/// underlying table function.
pub clone: unsafe extern "C" fn(udtf: &Self) -> Self,
/// Frees the resources behind `private_data`. Invoked from `Drop`.
pub release: unsafe extern "C" fn(udtf: &mut Self),
/// Opaque implementation data. For structs built by this library it points
/// to a boxed `TableFunctionPrivateData`.
pub private_data: *mut c_void,
/// Returns an identifier of the library that created this struct. It is
/// compared with the current library's identifier to decide whether the
/// wrapped implementation can be accessed directly.
pub library_marker_id: extern "C" fn() -> usize,
}
// The raw `private_data` pointer makes the struct `!Send`/`!Sync` by default,
// so thread-safety is asserted manually here.
// NOTE(review): soundness presumably relies on the boxed private data (the
// `Arc<dyn TableFunctionImpl>` and optional runtime `Handle`) being
// thread-safe — confirm that `TableFunctionImpl` requires `Send + Sync`.
unsafe impl Send for FFI_TableFunction {}
unsafe impl Sync for FFI_TableFunction {}
/// Heap-allocated state stored behind `FFI_TableFunction::private_data` for
/// structs created by this library.
pub struct TableFunctionPrivateData {
/// The wrapped table function implementation.
udtf: Arc<dyn TableFunctionImpl>,
/// Optional Tokio runtime handle; forwarded to the FFI table providers
/// produced by calling the function and to clones of the FFI struct.
runtime: Option<Handle>,
}
impl FFI_TableFunction {
    /// Reinterprets the opaque `private_data` pointer as the private data type
    /// used by this library.
    fn private_data_ref(&self) -> &TableFunctionPrivateData {
        // SAFETY: relies on `private_data` pointing at a live
        // `TableFunctionPrivateData`, which is what `new_with_ffi_codec`
        // stores there until `release` runs. Callers must only reach this for
        // structs created by the current library.
        unsafe { &*self.private_data.cast::<TableFunctionPrivateData>() }
    }

    /// Borrows the wrapped table function implementation.
    fn inner(&self) -> &Arc<dyn TableFunctionImpl> {
        &self.private_data_ref().udtf
    }

    /// Returns a clone of the optional runtime handle kept in the private data.
    fn runtime(&self) -> Option<Handle> {
        self.private_data_ref().runtime.clone()
    }
}
/// `call` entry point installed into [`FFI_TableFunction`] by this library.
///
/// Decodes `args` as a protobuf `LogicalExprList`, turns it back into
/// expressions using the struct's logical codec and task context, invokes the
/// wrapped table function and wraps the resulting provider for FFI. Any
/// failure along the way is returned through the `FFIResult` error variant.
unsafe extern "C" fn call_fn_wrapper(
    udtf: &FFI_TableFunction,
    args: RVec<u8>,
) -> FFIResult<FFI_TableProvider> {
    // Resolve everything required to rebuild expressions from their
    // serialized form.
    let task_ctx: Arc<TaskContext> =
        rresult_return!((&udtf.logical_codec.task_ctx_provider).try_into());
    let logical_codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into();

    let encoded_args = rresult_return!(LogicalExprList::decode(args.as_ref()));
    let exprs = rresult_return!(parse_exprs(
        encoded_args.expr.iter(),
        task_ctx.as_ref(),
        logical_codec.as_ref()
    ));

    let provider = rresult_return!(udtf.inner().call(&exprs));

    // NOTE(review): the meaning of the `false` flag is defined by
    // `FFI_TableProvider::new_with_ffi_codec`, which is not visible here.
    RResult::ROk(FFI_TableProvider::new_with_ffi_codec(
        provider,
        false,
        udtf.runtime(),
        udtf.logical_codec.clone(),
    ))
}
/// `release` entry point installed into [`FFI_TableFunction`] by this library.
///
/// Drops the boxed `TableFunctionPrivateData` behind `private_data` and resets
/// the pointer to null. Calling it again on an already released struct is a
/// no-op, so an explicit release followed by the `Drop` implementation cannot
/// free the allocation twice or hand a null pointer to `Box::from_raw`.
unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) {
    if udtf.private_data.is_null() {
        // Already released; there is nothing left to free.
        return;
    }

    // Clear the field before dropping so the struct never exposes a dangling
    // pointer, even while the private data's destructor is running.
    let raw = udtf.private_data as *mut TableFunctionPrivateData;
    udtf.private_data = std::ptr::null_mut();

    // SAFETY: `raw` is non-null and, for structs built by this library, was
    // produced by `Box::into_raw` in `new_with_ffi_codec`; nulling the field
    // above guarantees ownership is reclaimed only once.
    drop(unsafe { Box::from_raw(raw) });
}
/// `clone` entry point installed into [`FFI_TableFunction`] by this library.
///
/// Builds a fresh FFI struct with its own private data allocation that shares
/// the same underlying table function (via `Arc`), the same optional runtime
/// handle and a clone of the logical codec.
unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction {
    let runtime_handle = udtf.runtime();
    let shared_impl = Arc::clone(udtf.inner());
    let codec = udtf.logical_codec.clone();
    FFI_TableFunction::new_with_ffi_codec(shared_impl, runtime_handle, codec)
}
impl Clone for FFI_TableFunction {
    /// Clones by delegating to the `clone` function pointer stored in the
    /// struct, so the library that created the struct performs the copy.
    fn clone(&self) -> Self {
        let clone_fn = self.clone;
        // SAFETY: presumably sound as long as the function pointer matches the
        // struct it is stored in, which holds for structs built by
        // `new_with_ffi_codec`.
        unsafe { clone_fn(self) }
    }
}
impl FFI_TableFunction {
    /// Wraps `udtf` for use across an FFI boundary, building the FFI logical
    /// codec from the supplied parts.
    ///
    /// * `udtf` - the table function implementation to expose.
    /// * `runtime` - optional Tokio runtime handle; given to the codec and
    ///   stored alongside the implementation.
    /// * `task_ctx_provider` - converted into an [`FFI_TaskContextProvider`]
    ///   and handed to the codec.
    /// * `logical_codec` - codec for the argument expressions; a
    ///   [`DefaultLogicalExtensionCodec`] is used when `None`.
    pub fn new(
        udtf: Arc<dyn TableFunctionImpl>,
        runtime: Option<Handle>,
        task_ctx_provider: impl Into<FFI_TaskContextProvider>,
        logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
    ) -> Self {
        let task_ctx_provider: FFI_TaskContextProvider = task_ctx_provider.into();
        let logical_codec =
            logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
        // The provider is moved into the codec: nothing else in this function
        // uses it, so cloning it first would only create a copy that is
        // dropped right away.
        let logical_codec = FFI_LogicalExtensionCodec::new(
            logical_codec,
            runtime.clone(),
            task_ctx_provider,
        );
        Self::new_with_ffi_codec(udtf, runtime, logical_codec)
    }

    /// Wraps `udtf` using an already constructed FFI logical codec.
    ///
    /// The implementation and runtime handle are boxed and stored behind
    /// `private_data`; the function pointers and library marker are set to
    /// this library's implementations.
    pub fn new_with_ffi_codec(
        udtf: Arc<dyn TableFunctionImpl>,
        runtime: Option<Handle>,
        logical_codec: FFI_LogicalExtensionCodec,
    ) -> Self {
        let private_data = Box::new(TableFunctionPrivateData { udtf, runtime });
        Self {
            call: call_fn_wrapper,
            logical_codec,
            clone: clone_fn_wrapper,
            release: release_fn_wrapper,
            // Ownership of the box is handed to the raw pointer; it is
            // reclaimed by `release_fn_wrapper`.
            private_data: Box::into_raw(private_data) as *mut c_void,
            library_marker_id: crate::get_library_marker_id,
        }
    }
}
impl Drop for FFI_TableFunction {
    /// Releases the struct through its own `release` function pointer so the
    /// library that allocated the private data is the one that frees it.
    fn drop(&mut self) {
        let release_fn = self.release;
        // SAFETY: presumably sound as long as the function pointer matches the
        // struct it is stored in, which holds for structs built by
        // `new_with_ffi_codec`.
        unsafe { release_fn(self) }
    }
}
/// A [`TableFunctionImpl`] backed by an [`FFI_TableFunction`] that was created
/// by a different library.
///
/// Calls are forwarded through the FFI struct's `call` function pointer with
/// the arguments serialized to protobuf.
#[derive(Debug)]
pub struct ForeignTableFunction(FFI_TableFunction);
// The wrapped `FFI_TableFunction` is itself declared `Send + Sync`; these
// impls state the same guarantee explicitly for the wrapper.
unsafe impl Send for ForeignTableFunction {}
unsafe impl Sync for ForeignTableFunction {}
impl From<FFI_TableFunction> for Arc<dyn TableFunctionImpl> {
    /// Converts an FFI table function into a usable [`TableFunctionImpl`].
    ///
    /// If the struct was created by the current library (matching library
    /// marker), the original implementation is returned directly and the FFI
    /// struct is dropped. Otherwise the struct is wrapped in a
    /// [`ForeignTableFunction`] that forwards calls across the boundary.
    fn from(value: FFI_TableFunction) -> Self {
        let created_here = (value.library_marker_id)() == crate::get_library_marker_id();
        if !created_here {
            return Arc::new(ForeignTableFunction(value));
        }
        // Same library: `private_data` has the layout this crate expects, so
        // the inner `Arc` can be shared without going through FFI.
        Arc::clone(value.inner())
    }
}
impl TableFunctionImpl for ForeignTableFunction {
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
let expr_list = LogicalExprList {
expr: serialize_exprs(args, codec.as_ref())?,
};
let filters_serialized = expr_list.encode_to_vec().into();
let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) };
let table_provider = df_result!(table_provider)?;
let table_provider: Arc<dyn TableProvider> = (&table_provider).into();
Ok(table_provider)
}
}
#[cfg(test)]
mod tests {
use arrow::array::{
ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, record_batch,
};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::catalog::MemTable;
use datafusion::common::exec_err;
use datafusion::logical_expr::ptr_eq::arc_ptr_eq;
use datafusion::prelude::{SessionContext, lit};
use datafusion::scalar::ScalarValue;
use datafusion_execution::TaskContextProvider;
use super::*;
/// Table function used by the tests below.
///
/// Its first argument is a `u64` row count; every following literal argument
/// becomes one column (`field-0`, `field-1`, ...) in which each row holds that
/// literal. The rows are returned as two record batches: the first holds
/// `num_rows / 3` rows, the second holds the rest.
#[derive(Debug)]
struct TestUDTF {}

impl TableFunctionImpl for TestUDTF {
    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // Only literal expressions are accepted.
        let mut literals = Vec::with_capacity(args.len());
        for arg in args {
            let Expr::Literal(scalar, _) = arg else {
                return exec_err!("Expected only literal arguments to table udf");
            };
            literals.push(scalar);
        }

        if literals.len() < 2 {
            return exec_err!("Expected at least two arguments to table udf");
        }

        let num_rows = match literals[0] {
            ScalarValue::UInt64(Some(count)) => *count as usize,
            _ => {
                return exec_err!(
                    "First argument must be the number of elements to create as u64"
                );
            }
        };

        // Rows before `split` go to the first batch, the remainder to the second.
        let split = num_rows / 3;
        let column_count = literals.len() - 1;
        let mut fields = Vec::with_capacity(column_count);
        let mut head_columns = Vec::with_capacity(column_count);
        let mut tail_columns = Vec::with_capacity(column_count);

        for (idx, literal) in literals.iter().skip(1).enumerate() {
            let (data_type, column) = match literal {
                ScalarValue::Utf8(value) => (
                    DataType::Utf8,
                    Arc::new(StringArray::from(vec![value.clone(); num_rows])) as ArrayRef,
                ),
                ScalarValue::UInt64(value) => (
                    DataType::UInt64,
                    Arc::new(UInt64Array::from(vec![value.clone(); num_rows])) as ArrayRef,
                ),
                ScalarValue::Float64(value) => (
                    DataType::Float64,
                    Arc::new(Float64Array::from(vec![value.clone(); num_rows])) as ArrayRef,
                ),
                other => {
                    return exec_err!(
                        "Test case only supports utf8, u64, and f64. Found {}",
                        other.data_type()
                    );
                }
            };

            fields.push(Field::new(format!("field-{idx}"), data_type, true));
            head_columns.push(column.slice(0, split));
            tail_columns.push(column.slice(split, num_rows - split));
        }

        let schema = Arc::new(Schema::new(fields));
        let head = RecordBatch::try_new(Arc::clone(&schema), head_columns)?;
        let tail = RecordBatch::try_new(Arc::clone(&schema), tail_columns)?;

        let table = MemTable::try_new(schema, vec![vec![head, tail]])?;
        Ok(Arc::new(table))
    }
}
/// Sends a table function through the FFI wrapper as if it came from another
/// library, calls it, and checks the record batches read back from the
/// resulting table.
#[tokio::test]
async fn test_round_trip_udtf() -> Result<()> {
    let original_udtf: Arc<dyn TableFunctionImpl> = Arc::new(TestUDTF {});

    let ctx = Arc::new(SessionContext::default());
    let provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
    let ffi_provider = FFI_TaskContextProvider::from(&provider);

    let mut ffi_udtf =
        FFI_TableFunction::new(Arc::clone(&original_udtf), None, ffi_provider, None);
    // Swap the library marker so the conversion below takes the foreign path
    // instead of handing back the original implementation.
    ffi_udtf.library_marker_id = crate::mock_foreign_marker_id;
    let foreign_udtf: Arc<dyn TableFunctionImpl> = ffi_udtf.into();

    let call_args = [lit(6_u64), lit("one"), lit(2.0), lit(3_u64)];
    let table = foreign_udtf.call(&call_args)?;

    let _ = ctx.register_table("test-table", table)?;
    let returned_batches = ctx.table("test-table").await?.collect().await?;

    // Six rows split at 6 / 3: two rows in the first batch, four in the second.
    let expected_batches = [
        record_batch!(
            ("field-0", Utf8, ["one", "one"]),
            ("field-1", Float64, [2.0, 2.0]),
            ("field-2", UInt64, [3, 3])
        )?,
        record_batch!(
            ("field-0", Utf8, ["one", "one", "one", "one"]),
            ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]),
            ("field-2", UInt64, [3, 3, 3, 3])
        )?,
    ];

    assert_eq!(returned_batches.len(), 2);
    for (actual, expected) in returned_batches.iter().zip(expected_batches.iter()) {
        assert_eq!(actual, expected);
    }

    Ok(())
}
/// Checks that converting an FFI struct created by this library hands back
/// the original `Arc`, while a struct carrying a foreign library marker is
/// wrapped in a new object.
#[test]
fn test_ffi_udtf_local_bypass() -> Result<()> {
    let original_udtf: Arc<dyn TableFunctionImpl> = Arc::new(TestUDTF {});

    let ctx: Arc<dyn TaskContextProvider> = Arc::new(SessionContext::default());
    let ffi_provider = FFI_TaskContextProvider::from(&ctx);

    let mut ffi_udtf =
        FFI_TableFunction::new(Arc::clone(&original_udtf), None, ffi_provider, None);

    // Same library marker: the conversion returns the very same allocation.
    let same_library: Arc<dyn TableFunctionImpl> = ffi_udtf.clone().into();
    assert!(arc_ptr_eq(&original_udtf, &same_library));

    // Foreign library marker: the conversion must produce a distinct wrapper.
    ffi_udtf.library_marker_id = crate::mock_foreign_marker_id;
    let other_library: Arc<dyn TableFunctionImpl> = ffi_udtf.into();
    assert!(!arc_ptr_eq(&original_udtf, &other_library));

    Ok(())
}
}