1use std::any::Any;
19use std::ffi::c_void;
20use std::sync::Arc;
21
22use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
23use datafusion_common::DataFusionError;
24use datafusion_common::error::Result;
25use datafusion_execution::TaskContext;
26use datafusion_proto::logical_plan::from_proto::parse_exprs;
27use datafusion_proto::logical_plan::to_proto::serialize_exprs;
28use datafusion_proto::logical_plan::{
29 DefaultLogicalExtensionCodec, LogicalExtensionCodec,
30};
31use datafusion_proto::protobuf::LogicalExprList;
32use datafusion_session::Session;
33use prost::Message;
34use stabby::vec::Vec as SVec;
35use tokio::runtime::Handle;
36
37use crate::execution::FFI_TaskContextProvider;
38use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
39use crate::session::{FFI_SessionRef, ForeignSession};
40use crate::table_provider::FFI_TableProvider;
41use crate::util::FFI_Result;
42use crate::{df_result, sresult_return};
43
/// A `#[repr(C)]` bundle of function pointers and state that exposes a
/// [`TableFunctionImpl`] (a user defined table function) across an FFI
/// boundary.
///
/// Instances are created with [`FFI_TableFunction::new`] or
/// [`FFI_TableFunction::new_with_ffi_codec`] and are converted back into an
/// `Arc<dyn TableFunctionImpl>` through the `From` implementation in this
/// file.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_TableFunction {
    /// Deprecated entry point that evaluates the table function from its
    /// arguments alone. The wrapper installed by this file decodes `args` as
    /// a protobuf encoded `LogicalExprList`.
    #[deprecated(
        since = "53.0.0",
        note = "See TableFunctionImpl::call deprecation note"
    )]
    pub call: unsafe extern "C" fn(
        udtf: &Self,
        args: SVec<u8>,
    ) -> FFI_Result<FFI_TableProvider>,

    /// Evaluates the table function from its arguments and a session. As with
    /// `call`, the wrapper installed by this file decodes `args` as a protobuf
    /// encoded `LogicalExprList`.
    call_with_args: unsafe extern "C" fn(
        udtf: &Self,
        args: SVec<u8>,
        session: FFI_SessionRef,
    ) -> FFI_Result<FFI_TableProvider>,

    /// Codec used to encode and decode the argument expressions; it is also
    /// cloned into every `FFI_TableProvider` returned by the wrappers.
    pub logical_codec: FFI_LogicalExtensionCodec,

    /// Produces a copy of this struct. Used by the `Clone` implementation.
    pub clone: unsafe extern "C" fn(udtf: &Self) -> Self,

    /// Releases the resources behind `private_data`. Called by the `Drop`
    /// implementation.
    pub release: unsafe extern "C" fn(udtf: &mut Self),

    /// Opaque implementation state. For instances built in this file it points
    /// at a boxed `TableFunctionPrivateData`.
    pub private_data: *mut c_void,

    /// Returns an identifier that is compared against
    /// `crate::get_library_marker_id()` to decide whether this instance can be
    /// unwrapped directly or must be treated as foreign.
    pub library_marker_id: extern "C" fn() -> usize,
}
84
// The raw `private_data` pointer prevents `Send`/`Sync` from being derived
// automatically, so both are asserted manually here.
// NOTE(review): presumably sound because locally created instances only store
// an `Arc<dyn TableFunctionImpl>` and an `Option<Handle>` behind that pointer;
// confirm the same holds for instances received from other libraries.
unsafe impl Send for FFI_TableFunction {}
unsafe impl Sync for FFI_TableFunction {}
87
/// State stored behind `FFI_TableFunction::private_data` for instances built
/// by `FFI_TableFunction::new_with_ffi_codec`.
pub struct TableFunctionPrivateData {
    /// The wrapped table function implementation.
    udtf: Arc<dyn TableFunctionImpl>,
    /// Optional Tokio runtime handle that is forwarded to the
    /// `FFI_TableProvider`s created by the call wrappers.
    runtime: Option<Handle>,
}
92
93impl FFI_TableFunction {
94 fn inner(&self) -> &Arc<dyn TableFunctionImpl> {
95 let private_data = self.private_data as *const TableFunctionPrivateData;
96 unsafe { &(*private_data).udtf }
97 }
98
99 fn runtime(&self) -> Option<Handle> {
100 let private_data = self.private_data as *const TableFunctionPrivateData;
101 unsafe { (*private_data).runtime.clone() }
102 }
103}
104
105unsafe extern "C" fn call_fn_wrapper(
106 udtf: &FFI_TableFunction,
107 args: SVec<u8>,
108) -> FFI_Result<FFI_TableProvider> {
109 let runtime = udtf.runtime();
110 let udtf_inner = udtf.inner();
111
112 let ctx: Arc<TaskContext> =
113 sresult_return!((&udtf.logical_codec.task_ctx_provider).try_into());
114 let codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into();
115
116 let proto_filters = sresult_return!(LogicalExprList::decode(args.as_ref()));
117
118 let args = sresult_return!(parse_exprs(
119 proto_filters.expr.iter(),
120 ctx.as_ref(),
121 codec.as_ref()
122 ));
123
124 #[expect(deprecated)]
125 let table_provider = sresult_return!(udtf_inner.call(&args));
126 FFI_Result::Ok(FFI_TableProvider::new_with_ffi_codec(
127 table_provider,
128 false,
129 runtime,
130 udtf.logical_codec.clone(),
131 ))
132}
133
134unsafe extern "C" fn call_with_args_wrapper(
135 udtf: &FFI_TableFunction,
136 args: SVec<u8>,
137 session: FFI_SessionRef,
138) -> FFI_Result<FFI_TableProvider> {
139 let runtime = udtf.runtime();
140 let udtf_inner = udtf.inner();
141
142 let ctx: Arc<TaskContext> =
143 sresult_return!((&udtf.logical_codec.task_ctx_provider).try_into());
144 let codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into();
145
146 let proto_filters = sresult_return!(LogicalExprList::decode(args.as_ref()));
147
148 let args = sresult_return!(parse_exprs(
149 proto_filters.expr.iter(),
150 ctx.as_ref(),
151 codec.as_ref()
152 ));
153
154 let mut foreign_session = None;
155 let session = sresult_return!(
156 session
157 .as_local()
158 .map(Ok::<&(dyn Session + Send + Sync), DataFusionError>)
159 .unwrap_or_else(|| {
160 foreign_session = Some(ForeignSession::try_from(&session)?);
161 Ok(foreign_session.as_ref().unwrap())
162 })
163 );
164 let table_provider = sresult_return!(
165 udtf_inner.call_with_args(TableFunctionArgs::new(&args, session))
166 );
167 FFI_Result::Ok(FFI_TableProvider::new_with_ffi_codec(
168 table_provider,
169 false,
170 runtime,
171 udtf.logical_codec.clone(),
172 ))
173}
174
175unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) {
176 unsafe {
177 debug_assert!(!udtf.private_data.is_null());
178 let private_data =
179 Box::from_raw(udtf.private_data as *mut TableFunctionPrivateData);
180 drop(private_data);
181 udtf.private_data = std::ptr::null_mut();
182 }
183}
184
185unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction {
186 let runtime = udtf.runtime();
187 let udtf_inner = udtf.inner();
188
189 FFI_TableFunction::new_with_ffi_codec(
190 Arc::clone(udtf_inner),
191 runtime,
192 udtf.logical_codec.clone(),
193 )
194}
195
196impl Clone for FFI_TableFunction {
197 fn clone(&self) -> Self {
198 unsafe { (self.clone)(self) }
199 }
200}
201
202impl FFI_TableFunction {
203 pub fn new(
204 udtf: Arc<dyn TableFunctionImpl>,
205 runtime: Option<Handle>,
206 task_ctx_provider: impl Into<FFI_TaskContextProvider>,
207 logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
208 ) -> Self {
209 let task_ctx_provider = task_ctx_provider.into();
210 let logical_codec =
211 logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
212 let logical_codec = FFI_LogicalExtensionCodec::new(
213 logical_codec,
214 runtime.clone(),
215 task_ctx_provider.clone(),
216 );
217
218 Self::new_with_ffi_codec(udtf, runtime, logical_codec)
219 }
220
221 pub fn new_with_ffi_codec(
222 udtf: Arc<dyn TableFunctionImpl>,
223 runtime: Option<Handle>,
224 logical_codec: FFI_LogicalExtensionCodec,
225 ) -> Self {
226 if let Some(udtf) =
227 (Arc::clone(&udtf) as Arc<dyn Any>).downcast_ref::<ForeignTableFunction>()
228 {
229 return udtf.0.clone();
230 }
231
232 let private_data = Box::new(TableFunctionPrivateData { udtf, runtime });
233
234 Self {
235 #[expect(deprecated)]
236 call: call_fn_wrapper,
237 call_with_args: call_with_args_wrapper,
238 logical_codec,
239 clone: clone_fn_wrapper,
240 release: release_fn_wrapper,
241 private_data: Box::into_raw(private_data) as *mut c_void,
242 library_marker_id: crate::get_library_marker_id,
243 }
244 }
245}
246
247impl Drop for FFI_TableFunction {
248 fn drop(&mut self) {
249 unsafe { (self.release)(self) }
250 }
251}
252
/// A [`TableFunctionImpl`] backed by an [`FFI_TableFunction`].
///
/// The `From<FFI_TableFunction>` conversion in this file creates it when the
/// wrapped value's library marker differs from this library's marker. Calls
/// are forwarded through the function pointers of the wrapped struct.
#[derive(Debug)]
pub struct ForeignTableFunction(FFI_TableFunction);
261
// Asserted manually for the wrapper around `FFI_TableFunction`.
// NOTE(review): soundness presumably follows from the same reasoning as the
// `Send`/`Sync` implementations for `FFI_TableFunction`; confirm.
unsafe impl Send for ForeignTableFunction {}
unsafe impl Sync for ForeignTableFunction {}
264
265impl From<FFI_TableFunction> for Arc<dyn TableFunctionImpl> {
266 fn from(value: FFI_TableFunction) -> Self {
267 if (value.library_marker_id)() == crate::get_library_marker_id() {
268 Arc::clone(value.inner())
269 } else {
270 Arc::new(ForeignTableFunction(value))
271 }
272 }
273}
274
275impl TableFunctionImpl for ForeignTableFunction {
276 fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
277 let session = FFI_SessionRef::new(
278 args.session(),
279 self.0.runtime(),
280 self.0.logical_codec.clone(),
281 );
282 let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
283 let expr_list = LogicalExprList {
284 expr: serialize_exprs(args.exprs(), codec.as_ref())?,
285 };
286 let filters_serialized = expr_list.encode_to_vec().into_iter().collect();
287
288 let table_provider =
289 unsafe { (self.0.call_with_args)(&self.0, filters_serialized, session) };
290
291 let table_provider = df_result!(table_provider)?;
292 let table_provider: Arc<dyn TableProvider> = (&table_provider).into();
293
294 Ok(table_provider)
295 }
296
297 fn call(&self, args: &[datafusion_expr::Expr]) -> Result<Arc<dyn TableProvider>> {
298 let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
299 let expr_list = LogicalExprList {
300 expr: serialize_exprs(args, codec.as_ref())?,
301 };
302 let filters_serialized = expr_list.encode_to_vec().into_iter().collect();
303
304 #[expect(deprecated)]
305 let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) };
306
307 let table_provider = df_result!(table_provider)?;
308 let table_provider: Arc<dyn TableProvider> = (&table_provider).into();
309
310 Ok(table_provider)
311 }
312}
313
// Tests for the FFI table function round trip and the local bypass.
#[cfg(test)]
mod tests {
    use arrow::array::{
        ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, record_batch,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::catalog::MemTable;
    use datafusion::common::exec_err;
    use datafusion::logical_expr::ptr_eq::arc_ptr_eq;
    use datafusion::prelude::{SessionContext, lit};
    use datafusion::scalar::ScalarValue;
    use datafusion_catalog::TableFunctionArgs;
    use datafusion_execution::TaskContextProvider;
    use datafusion_expr::Expr;

    use super::*;

    /// Test table function.
    ///
    /// The first argument is the number of rows to produce (a `u64` literal).
    /// Every following literal argument becomes one column named
    /// `field-{idx}` whose rows all hold that literal. The rows are returned
    /// as two batches: the first `num_rows / 3` rows and the remainder.
    #[derive(Debug)]
    struct TestUDTF {}

    impl TableFunctionImpl for TestUDTF {
        fn call_with_args(
            &self,
            args: TableFunctionArgs,
        ) -> Result<Arc<dyn TableProvider>> {
            // Only literal arguments are accepted.
            let args = args
                .exprs()
                .iter()
                .map(|arg| {
                    if let Expr::Literal(scalar, _) = arg {
                        Ok(scalar)
                    } else {
                        exec_err!("Expected only literal arguments to table udf")
                    }
                })
                .collect::<Result<Vec<_>>>()?;

            if args.len() < 2 {
                exec_err!("Expected at least two arguments to table udf")?
            }

            let ScalarValue::UInt64(Some(num_rows)) = args[0].to_owned() else {
                exec_err!(
                    "First argument must be the number of elements to create as u64"
                )?
            };
            let num_rows = num_rows as usize;

            let mut fields = Vec::default();
            let mut arrays1 = Vec::default();
            let mut arrays2 = Vec::default();

            // Rows before `split` go to the first batch, the rest to the second.
            let split = num_rows / 3;
            for (idx, arg) in args[1..].iter().enumerate() {
                // Build a column that repeats the literal `num_rows` times.
                let (field, array) = match arg {
                    ScalarValue::Utf8(s) => {
                        let s_vec = vec![s.to_owned(); num_rows];
                        (
                            Field::new(format!("field-{idx}"), DataType::Utf8, true),
                            Arc::new(StringArray::from(s_vec)) as ArrayRef,
                        )
                    }
                    ScalarValue::UInt64(v) => {
                        let v_vec = vec![v.to_owned(); num_rows];
                        (
                            Field::new(format!("field-{idx}"), DataType::UInt64, true),
                            Arc::new(UInt64Array::from(v_vec)) as ArrayRef,
                        )
                    }
                    ScalarValue::Float64(v) => {
                        let v_vec = vec![v.to_owned(); num_rows];
                        (
                            Field::new(format!("field-{idx}"), DataType::Float64, true),
                            Arc::new(Float64Array::from(v_vec)) as ArrayRef,
                        )
                    }
                    _ => exec_err!(
                        "Test case only supports utf8, u64, and f64. Found {}",
                        arg.data_type()
                    )?,
                };

                fields.push(field);
                arrays1.push(array.slice(0, split));
                arrays2.push(array.slice(split, num_rows - split));
            }

            let schema = Arc::new(Schema::new(fields));
            let batches = vec![
                RecordBatch::try_new(Arc::clone(&schema), arrays1)?,
                RecordBatch::try_new(Arc::clone(&schema), arrays2)?,
            ];

            let table_provider = MemTable::try_new(schema, vec![batches])?;

            Ok(Arc::new(table_provider))
        }
    }

    /// Calls the test function through the FFI wrapper and checks the
    /// batches produced by the returned table.
    #[tokio::test]
    async fn test_round_trip_udtf() -> Result<()> {
        let original_udtf = Arc::new(TestUDTF {}) as Arc<dyn TableFunctionImpl>;
        let ctx = Arc::new(SessionContext::default());
        let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
        let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider);

        let mut local_udtf: FFI_TableFunction = FFI_TableFunction::new(
            Arc::clone(&original_udtf),
            None,
            task_ctx_provider,
            None,
        );
        // Swap in the mock marker so the conversion below takes the
        // `ForeignTableFunction` path instead of unwrapping the local value.
        local_udtf.library_marker_id = crate::mock_foreign_marker_id;

        let foreign_udf: Arc<dyn TableFunctionImpl> = local_udtf.into();

        // Six rows with three columns.
        let table = foreign_udf.call_with_args(TableFunctionArgs::new(
            &[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)],
            &ctx.state(),
        ))?;

        let _ = ctx.register_table("test-table", table)?;

        let returned_batches = ctx.table("test-table").await?.collect().await?;

        // 6 / 3 = 2 rows in the first batch, the remaining 4 in the second.
        assert_eq!(returned_batches.len(), 2);
        let expected_batch_0 = record_batch!(
            ("field-0", Utf8, ["one", "one"]),
            ("field-1", Float64, [2.0, 2.0]),
            ("field-2", UInt64, [3, 3])
        )?;
        assert_eq!(returned_batches[0], expected_batch_0);

        let expected_batch_1 = record_batch!(
            ("field-0", Utf8, ["one", "one", "one", "one"]),
            ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]),
            ("field-2", UInt64, [3, 3, 3, 3])
        )?;
        assert_eq!(returned_batches[1], expected_batch_1);

        Ok(())
    }

    /// Checks that converting a locally created FFI table function returns
    /// the original `Arc`, while one carrying a different library marker does
    /// not.
    #[test]
    fn test_ffi_udtf_local_bypass() -> Result<()> {
        let original_udtf = Arc::new(TestUDTF {}) as Arc<dyn TableFunctionImpl>;

        let ctx = Arc::new(SessionContext::default()) as Arc<dyn TaskContextProvider>;
        let task_ctx_provider = FFI_TaskContextProvider::from(&ctx);
        let mut ffi_udtf = FFI_TableFunction::new(
            Arc::clone(&original_udtf),
            None,
            task_ctx_provider,
            None,
        );

        // Same library marker: the conversion hands back the original Arc.
        let foreign_udtf: Arc<dyn TableFunctionImpl> = ffi_udtf.clone().into();
        assert!(arc_ptr_eq(&original_udtf, &foreign_udtf));

        // Different library marker: the conversion wraps the value instead.
        ffi_udtf.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_udtf: Arc<dyn TableFunctionImpl> = ffi_udtf.into();
        assert!(!arc_ptr_eq(&original_udtf, &foreign_udtf));

        Ok(())
    }
}