1use std::ffi::c_void;
19use std::sync::Arc;
20
21use abi_stable::StableAbi;
22use abi_stable::std_types::{RResult, RVec};
23use datafusion_catalog::{TableFunctionImpl, TableProvider};
24use datafusion_common::error::Result;
25use datafusion_execution::TaskContext;
26use datafusion_expr::Expr;
27use datafusion_proto::logical_plan::from_proto::parse_exprs;
28use datafusion_proto::logical_plan::to_proto::serialize_exprs;
29use datafusion_proto::logical_plan::{
30 DefaultLogicalExtensionCodec, LogicalExtensionCodec,
31};
32use datafusion_proto::protobuf::LogicalExprList;
33use prost::Message;
34use tokio::runtime::Handle;
35
36use crate::execution::FFI_TaskContextProvider;
37use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
38use crate::table_provider::FFI_TableProvider;
39use crate::util::FFIResult;
40use crate::{df_result, rresult_return};
41
/// C-layout (`#[repr(C)]`, [`StableAbi`]) handle to a [`TableFunctionImpl`].
///
/// All behaviour is reached through the `extern "C"` function pointers stored
/// in the struct, so the handle can be used by code that did not create it.
/// The field order is part of the binary layout and must not be changed.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_TableFunction {
    /// Invokes the table function. `args` is a prost-encoded
    /// [`LogicalExprList`] holding the argument expressions; the result is the
    /// produced table wrapped as an [`FFI_TableProvider`].
    pub call:
        unsafe extern "C" fn(udtf: &Self, args: RVec<u8>) -> FFIResult<FFI_TableProvider>,

    /// Codec used on both sides of the boundary to serialize and deserialize
    /// the argument expressions. Its `task_ctx_provider` also supplies the
    /// task context needed when decoding.
    pub logical_codec: FFI_LogicalExtensionCodec,

    /// Produces a new handle for the same table function. Called by the
    /// [`Clone`] implementation.
    pub clone: unsafe extern "C" fn(udtf: &Self) -> Self,

    /// Frees whatever `private_data` points to. Called by the [`Drop`]
    /// implementation.
    pub release: unsafe extern "C" fn(udtf: &mut Self),

    /// Opaque pointer to the creating side's state. Handles built by
    /// `new_with_ffi_codec` store a boxed `TableFunctionPrivateData` here.
    pub private_data: *mut c_void,

    /// Returns an identifier for the library that created this handle. It is
    /// compared with `crate::get_library_marker_id()` to decide whether the
    /// wrapped implementation can be used directly instead of going through
    /// the function pointers.
    pub library_marker_id: extern "C" fn() -> usize,
}
69
// The raw `private_data` pointer prevents `Send`/`Sync` from being derived
// automatically, so they are asserted manually here.
// NOTE(review): soundness presumably rests on the pointee (an
// `Arc<dyn TableFunctionImpl>` plus an optional runtime `Handle`) being safe to
// share across threads — confirm against the bounds of `TableFunctionImpl`.
unsafe impl Send for FFI_TableFunction {}
unsafe impl Sync for FFI_TableFunction {}
72
/// State stored behind `FFI_TableFunction::private_data` for handles created
/// by `FFI_TableFunction::new_with_ffi_codec`.
pub struct TableFunctionPrivateData {
    /// The wrapped table function implementation.
    udtf: Arc<dyn TableFunctionImpl>,
    /// Optional tokio runtime handle; it is forwarded to the
    /// `FFI_TableProvider` built for each call and to clones of the handle.
    runtime: Option<Handle>,
}
77
78impl FFI_TableFunction {
79 fn inner(&self) -> &Arc<dyn TableFunctionImpl> {
80 let private_data = self.private_data as *const TableFunctionPrivateData;
81 unsafe { &(*private_data).udtf }
82 }
83
84 fn runtime(&self) -> Option<Handle> {
85 let private_data = self.private_data as *const TableFunctionPrivateData;
86 unsafe { (*private_data).runtime.clone() }
87 }
88}
89
90unsafe extern "C" fn call_fn_wrapper(
91 udtf: &FFI_TableFunction,
92 args: RVec<u8>,
93) -> FFIResult<FFI_TableProvider> {
94 let runtime = udtf.runtime();
95 let udtf_inner = udtf.inner();
96
97 let ctx: Arc<TaskContext> =
98 rresult_return!((&udtf.logical_codec.task_ctx_provider).try_into());
99 let codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into();
100
101 let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref()));
102
103 let args = rresult_return!(parse_exprs(
104 proto_filters.expr.iter(),
105 ctx.as_ref(),
106 codec.as_ref()
107 ));
108
109 let table_provider = rresult_return!(udtf_inner.call(&args));
110 RResult::ROk(FFI_TableProvider::new_with_ffi_codec(
111 table_provider,
112 false,
113 runtime,
114 udtf.logical_codec.clone(),
115 ))
116}
117
118unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) {
119 unsafe {
120 debug_assert!(!udtf.private_data.is_null());
121 let private_data =
122 Box::from_raw(udtf.private_data as *mut TableFunctionPrivateData);
123 drop(private_data);
124 udtf.private_data = std::ptr::null_mut();
125 }
126}
127
128unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction {
129 let runtime = udtf.runtime();
130 let udtf_inner = udtf.inner();
131
132 FFI_TableFunction::new_with_ffi_codec(
133 Arc::clone(udtf_inner),
134 runtime,
135 udtf.logical_codec.clone(),
136 )
137}
138
139impl Clone for FFI_TableFunction {
140 fn clone(&self) -> Self {
141 unsafe { (self.clone)(self) }
142 }
143}
144
145impl FFI_TableFunction {
146 pub fn new(
147 udtf: Arc<dyn TableFunctionImpl>,
148 runtime: Option<Handle>,
149 task_ctx_provider: impl Into<FFI_TaskContextProvider>,
150 logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
151 ) -> Self {
152 let task_ctx_provider = task_ctx_provider.into();
153 let logical_codec =
154 logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
155 let logical_codec = FFI_LogicalExtensionCodec::new(
156 logical_codec,
157 runtime.clone(),
158 task_ctx_provider.clone(),
159 );
160
161 Self::new_with_ffi_codec(udtf, runtime, logical_codec)
162 }
163
164 pub fn new_with_ffi_codec(
165 udtf: Arc<dyn TableFunctionImpl>,
166 runtime: Option<Handle>,
167 logical_codec: FFI_LogicalExtensionCodec,
168 ) -> Self {
169 let private_data = Box::new(TableFunctionPrivateData { udtf, runtime });
170
171 Self {
172 call: call_fn_wrapper,
173 logical_codec,
174 clone: clone_fn_wrapper,
175 release: release_fn_wrapper,
176 private_data: Box::into_raw(private_data) as *mut c_void,
177 library_marker_id: crate::get_library_marker_id,
178 }
179 }
180}
181
182impl Drop for FFI_TableFunction {
183 fn drop(&mut self) {
184 unsafe { (self.release)(self) }
185 }
186}
187
/// Wrapper around an [`FFI_TableFunction`] that came from a different library.
///
/// It implements [`TableFunctionImpl`] by serializing the arguments and going
/// through the handle's `call` function pointer. The `From<FFI_TableFunction>`
/// conversion creates it when the handle's library marker differs from this
/// crate's.
#[derive(Debug)]
pub struct ForeignTableFunction(FFI_TableFunction);

// NOTE(review): the wrapped `FFI_TableFunction` already asserts `Send` and
// `Sync`, so these look redundant but harmless — confirm before removing.
unsafe impl Send for ForeignTableFunction {}
unsafe impl Sync for ForeignTableFunction {}
199
200impl From<FFI_TableFunction> for Arc<dyn TableFunctionImpl> {
201 fn from(value: FFI_TableFunction) -> Self {
202 if (value.library_marker_id)() == crate::get_library_marker_id() {
203 Arc::clone(value.inner())
204 } else {
205 Arc::new(ForeignTableFunction(value))
206 }
207 }
208}
209
210impl TableFunctionImpl for ForeignTableFunction {
211 fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
212 let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
213 let expr_list = LogicalExprList {
214 expr: serialize_exprs(args, codec.as_ref())?,
215 };
216 let filters_serialized = expr_list.encode_to_vec().into();
217
218 let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) };
219
220 let table_provider = df_result!(table_provider)?;
221 let table_provider: Arc<dyn TableProvider> = (&table_provider).into();
222
223 Ok(table_provider)
224 }
225}
226
#[cfg(test)]
mod tests {
    use arrow::array::{
        ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, record_batch,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::catalog::MemTable;
    use datafusion::common::exec_err;
    use datafusion::logical_expr::ptr_eq::arc_ptr_eq;
    use datafusion::prelude::{SessionContext, lit};
    use datafusion::scalar::ScalarValue;
    use datafusion_execution::TaskContextProvider;

    use super::*;

    /// Table function used by the tests.
    ///
    /// The first literal argument is the row count (`u64`). Every following
    /// literal becomes one column holding that value repeated on every row.
    /// The rows are emitted as two batches: the first `num_rows / 3` rows and
    /// the remainder.
    #[derive(Debug)]
    struct TestUDTF {}

    /// Builds the nullable field `field-{idx}` and an array containing `value`
    /// repeated `num_rows` times. Only utf8, u64 and f64 scalars are accepted.
    fn repeated_column(
        idx: usize,
        value: &ScalarValue,
        num_rows: usize,
    ) -> Result<(Field, ArrayRef)> {
        let name = format!("field-{idx}");
        let column = match value {
            ScalarValue::Utf8(text) => (
                Field::new(name, DataType::Utf8, true),
                Arc::new(StringArray::from(vec![text.clone(); num_rows])) as ArrayRef,
            ),
            ScalarValue::UInt64(number) => (
                Field::new(name, DataType::UInt64, true),
                Arc::new(UInt64Array::from(vec![*number; num_rows])) as ArrayRef,
            ),
            ScalarValue::Float64(number) => (
                Field::new(name, DataType::Float64, true),
                Arc::new(Float64Array::from(vec![*number; num_rows])) as ArrayRef,
            ),
            other => {
                return exec_err!(
                    "Test case only supports utf8, u64, and f64. Found {}",
                    other.data_type()
                );
            }
        };

        Ok(column)
    }

    impl TableFunctionImpl for TestUDTF {
        fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
            // Only literal expressions are accepted; stop at the first other.
            let mut literals = Vec::with_capacity(args.len());
            for arg in args {
                match arg {
                    Expr::Literal(scalar, _) => literals.push(scalar),
                    _ => return exec_err!("Expected only literal arguments to table udf"),
                }
            }

            if literals.len() < 2 {
                return exec_err!("Expected at least two arguments to table udf");
            }

            let ScalarValue::UInt64(Some(num_rows)) = literals[0] else {
                return exec_err!(
                    "First argument must be the number of elements to create as u64"
                );
            };
            let num_rows = *num_rows as usize;

            // The first batch gets a third of the rows, the second the rest.
            let split = num_rows / 3;

            let mut fields = Vec::new();
            let mut head_arrays = Vec::new();
            let mut tail_arrays = Vec::new();

            for (idx, scalar) in literals.iter().skip(1).enumerate() {
                let (field, array) = repeated_column(idx, scalar, num_rows)?;

                fields.push(field);
                head_arrays.push(array.slice(0, split));
                tail_arrays.push(array.slice(split, num_rows - split));
            }

            let schema = Arc::new(Schema::new(fields));
            let head = RecordBatch::try_new(Arc::clone(&schema), head_arrays)?;
            let tail = RecordBatch::try_new(Arc::clone(&schema), tail_arrays)?;

            let table = MemTable::try_new(schema, vec![vec![head, tail]])?;

            Ok(Arc::new(table))
        }
    }

    /// Forces the foreign code path and checks that a call returns the
    /// expected batches.
    #[tokio::test]
    async fn test_round_trip_udtf() -> Result<()> {
        let local_impl: Arc<dyn TableFunctionImpl> = Arc::new(TestUDTF {});
        let session = Arc::new(SessionContext::default());
        let ctx_provider = Arc::clone(&session) as Arc<dyn TaskContextProvider>;
        let ctx_provider = FFI_TaskContextProvider::from(&ctx_provider);

        let mut ffi_udtf =
            FFI_TableFunction::new(Arc::clone(&local_impl), None, ctx_provider, None);
        // Pretend the handle came from another library so it is not bypassed.
        ffi_udtf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udtf: Arc<dyn TableFunctionImpl> = ffi_udtf.into();

        let table =
            foreign_udtf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?;

        let _ = session.register_table("test-table", table)?;

        let batches = session.table("test-table").await?.collect().await?;
        assert_eq!(batches.len(), 2);

        // 6 rows split at 6 / 3 = 2: two rows first, then the remaining four.
        let expected_head = record_batch!(
            ("field-0", Utf8, ["one", "one"]),
            ("field-1", Float64, [2.0, 2.0]),
            ("field-2", UInt64, [3, 3])
        )?;
        assert_eq!(batches[0], expected_head);

        let expected_tail = record_batch!(
            ("field-0", Utf8, ["one", "one", "one", "one"]),
            ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]),
            ("field-2", UInt64, [3, 3, 3, 3])
        )?;
        assert_eq!(batches[1], expected_tail);

        Ok(())
    }

    /// Checks that a handle from this library yields the original `Arc`,
    /// while a handle marked as foreign yields a distinct wrapper.
    #[test]
    fn test_ffi_udtf_local_bypass() -> Result<()> {
        let local_impl: Arc<dyn TableFunctionImpl> = Arc::new(TestUDTF {});

        let session: Arc<dyn TaskContextProvider> = Arc::new(SessionContext::default());
        let mut ffi_udtf = FFI_TableFunction::new(
            Arc::clone(&local_impl),
            None,
            FFI_TaskContextProvider::from(&session),
            None,
        );

        // Same library marker: the conversion returns the original pointer.
        let same_library: Arc<dyn TableFunctionImpl> = ffi_udtf.clone().into();
        assert!(arc_ptr_eq(&local_impl, &same_library));

        // Foreign marker: the conversion wraps the handle instead.
        ffi_udtf.library_marker_id = crate::mock_foreign_marker_id;
        let other_library: Arc<dyn TableFunctionImpl> = ffi_udtf.into();
        assert!(!arc_ptr_eq(&local_impl, &other_library));

        Ok(())
    }
}