Skip to main content

datafusion_ffi/
udtf.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::ffi::c_void;
20use std::sync::Arc;
21
22use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
23use datafusion_common::DataFusionError;
24use datafusion_common::error::Result;
25use datafusion_execution::TaskContext;
26use datafusion_proto::logical_plan::from_proto::parse_exprs;
27use datafusion_proto::logical_plan::to_proto::serialize_exprs;
28use datafusion_proto::logical_plan::{
29    DefaultLogicalExtensionCodec, LogicalExtensionCodec,
30};
31use datafusion_proto::protobuf::LogicalExprList;
32use datafusion_session::Session;
33use prost::Message;
34use stabby::vec::Vec as SVec;
35use tokio::runtime::Handle;
36
37use crate::execution::FFI_TaskContextProvider;
38use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
39use crate::session::{FFI_SessionRef, ForeignSession};
40use crate::table_provider::FFI_TableProvider;
41use crate::util::FFI_Result;
42use crate::{df_result, sresult_return};
43
44/// A stable struct for sharing a [`TableFunctionImpl`] across FFI boundaries.
45#[repr(C)]
46#[derive(Debug)]
47pub struct FFI_TableFunction {
48    /// Equivalent to the [`TableFunctionImpl::call`].
49    /// The arguments are Expr passed as protobuf encoded bytes.
50    #[deprecated(
51        since = "53.0.0",
52        note = "See TableFunctionImpl::call deprecation note"
53    )]
54    pub call: unsafe extern "C" fn(
55        udtf: &Self,
56        args: SVec<u8>,
57    ) -> FFI_Result<FFI_TableProvider>,
58
59    /// Equivalent to the [`TableFunctionImpl::call_with_args`].
60    call_with_args: unsafe extern "C" fn(
61        udtf: &Self,
62        args: SVec<u8>,
63        session: FFI_SessionRef,
64    ) -> FFI_Result<FFI_TableProvider>,
65
66    pub logical_codec: FFI_LogicalExtensionCodec,
67
68    /// Used to create a clone on the provider of the udtf. This should
69    /// only need to be called by the receiver of the udtf.
70    pub clone: unsafe extern "C" fn(udtf: &Self) -> Self,
71
72    /// Release the memory of the private data when it is no longer being used.
73    pub release: unsafe extern "C" fn(udtf: &mut Self),
74
75    /// Internal data. This is only to be accessed by the provider of the udtf.
76    /// A [`ForeignTableFunction`] should never attempt to access this data.
77    pub private_data: *mut c_void,
78
79    /// Utility to identify when FFI objects are accessed locally through
80    /// the foreign interface. See [`crate::get_library_marker_id`] and
81    /// the crate's `README.md` for more information.
82    pub library_marker_id: extern "C" fn() -> usize,
83}
84
85unsafe impl Send for FFI_TableFunction {}
86unsafe impl Sync for FFI_TableFunction {}
87
88pub struct TableFunctionPrivateData {
89    udtf: Arc<dyn TableFunctionImpl>,
90    runtime: Option<Handle>,
91}
92
93impl FFI_TableFunction {
94    fn inner(&self) -> &Arc<dyn TableFunctionImpl> {
95        let private_data = self.private_data as *const TableFunctionPrivateData;
96        unsafe { &(*private_data).udtf }
97    }
98
99    fn runtime(&self) -> Option<Handle> {
100        let private_data = self.private_data as *const TableFunctionPrivateData;
101        unsafe { (*private_data).runtime.clone() }
102    }
103}
104
105unsafe extern "C" fn call_fn_wrapper(
106    udtf: &FFI_TableFunction,
107    args: SVec<u8>,
108) -> FFI_Result<FFI_TableProvider> {
109    let runtime = udtf.runtime();
110    let udtf_inner = udtf.inner();
111
112    let ctx: Arc<TaskContext> =
113        sresult_return!((&udtf.logical_codec.task_ctx_provider).try_into());
114    let codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into();
115
116    let proto_filters = sresult_return!(LogicalExprList::decode(args.as_ref()));
117
118    let args = sresult_return!(parse_exprs(
119        proto_filters.expr.iter(),
120        ctx.as_ref(),
121        codec.as_ref()
122    ));
123
124    #[expect(deprecated)]
125    let table_provider = sresult_return!(udtf_inner.call(&args));
126    FFI_Result::Ok(FFI_TableProvider::new_with_ffi_codec(
127        table_provider,
128        false,
129        runtime,
130        udtf.logical_codec.clone(),
131    ))
132}
133
134unsafe extern "C" fn call_with_args_wrapper(
135    udtf: &FFI_TableFunction,
136    args: SVec<u8>,
137    session: FFI_SessionRef,
138) -> FFI_Result<FFI_TableProvider> {
139    let runtime = udtf.runtime();
140    let udtf_inner = udtf.inner();
141
142    let ctx: Arc<TaskContext> =
143        sresult_return!((&udtf.logical_codec.task_ctx_provider).try_into());
144    let codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into();
145
146    let proto_filters = sresult_return!(LogicalExprList::decode(args.as_ref()));
147
148    let args = sresult_return!(parse_exprs(
149        proto_filters.expr.iter(),
150        ctx.as_ref(),
151        codec.as_ref()
152    ));
153
154    let mut foreign_session = None;
155    let session = sresult_return!(
156        session
157            .as_local()
158            .map(Ok::<&(dyn Session + Send + Sync), DataFusionError>)
159            .unwrap_or_else(|| {
160                foreign_session = Some(ForeignSession::try_from(&session)?);
161                Ok(foreign_session.as_ref().unwrap())
162            })
163    );
164    let table_provider = sresult_return!(
165        udtf_inner.call_with_args(TableFunctionArgs::new(&args, session))
166    );
167    FFI_Result::Ok(FFI_TableProvider::new_with_ffi_codec(
168        table_provider,
169        false,
170        runtime,
171        udtf.logical_codec.clone(),
172    ))
173}
174
175unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) {
176    unsafe {
177        debug_assert!(!udtf.private_data.is_null());
178        let private_data =
179            Box::from_raw(udtf.private_data as *mut TableFunctionPrivateData);
180        drop(private_data);
181        udtf.private_data = std::ptr::null_mut();
182    }
183}
184
185unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction {
186    let runtime = udtf.runtime();
187    let udtf_inner = udtf.inner();
188
189    FFI_TableFunction::new_with_ffi_codec(
190        Arc::clone(udtf_inner),
191        runtime,
192        udtf.logical_codec.clone(),
193    )
194}
195
196impl Clone for FFI_TableFunction {
197    fn clone(&self) -> Self {
198        unsafe { (self.clone)(self) }
199    }
200}
201
202impl FFI_TableFunction {
203    pub fn new(
204        udtf: Arc<dyn TableFunctionImpl>,
205        runtime: Option<Handle>,
206        task_ctx_provider: impl Into<FFI_TaskContextProvider>,
207        logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
208    ) -> Self {
209        let task_ctx_provider = task_ctx_provider.into();
210        let logical_codec =
211            logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
212        let logical_codec = FFI_LogicalExtensionCodec::new(
213            logical_codec,
214            runtime.clone(),
215            task_ctx_provider.clone(),
216        );
217
218        Self::new_with_ffi_codec(udtf, runtime, logical_codec)
219    }
220
221    pub fn new_with_ffi_codec(
222        udtf: Arc<dyn TableFunctionImpl>,
223        runtime: Option<Handle>,
224        logical_codec: FFI_LogicalExtensionCodec,
225    ) -> Self {
226        if let Some(udtf) =
227            (Arc::clone(&udtf) as Arc<dyn Any>).downcast_ref::<ForeignTableFunction>()
228        {
229            return udtf.0.clone();
230        }
231
232        let private_data = Box::new(TableFunctionPrivateData { udtf, runtime });
233
234        Self {
235            #[expect(deprecated)]
236            call: call_fn_wrapper,
237            call_with_args: call_with_args_wrapper,
238            logical_codec,
239            clone: clone_fn_wrapper,
240            release: release_fn_wrapper,
241            private_data: Box::into_raw(private_data) as *mut c_void,
242            library_marker_id: crate::get_library_marker_id,
243        }
244    }
245}
246
247impl Drop for FFI_TableFunction {
248    fn drop(&mut self) {
249        unsafe { (self.release)(self) }
250    }
251}
252
253/// This struct is used to access an UDTF provided by a foreign
254/// library across a FFI boundary.
255///
256/// The ForeignTableFunction is to be used by the caller of the UDTF, so it has
257/// no knowledge or access to the private data. All interaction with the UDTF
258/// must occur through the functions defined in FFI_TableFunction.
259#[derive(Debug)]
260pub struct ForeignTableFunction(FFI_TableFunction);
261
262unsafe impl Send for ForeignTableFunction {}
263unsafe impl Sync for ForeignTableFunction {}
264
265impl From<FFI_TableFunction> for Arc<dyn TableFunctionImpl> {
266    fn from(value: FFI_TableFunction) -> Self {
267        if (value.library_marker_id)() == crate::get_library_marker_id() {
268            Arc::clone(value.inner())
269        } else {
270            Arc::new(ForeignTableFunction(value))
271        }
272    }
273}
274
275impl TableFunctionImpl for ForeignTableFunction {
276    fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
277        let session = FFI_SessionRef::new(
278            args.session(),
279            self.0.runtime(),
280            self.0.logical_codec.clone(),
281        );
282        let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
283        let expr_list = LogicalExprList {
284            expr: serialize_exprs(args.exprs(), codec.as_ref())?,
285        };
286        let filters_serialized = expr_list.encode_to_vec().into_iter().collect();
287
288        let table_provider =
289            unsafe { (self.0.call_with_args)(&self.0, filters_serialized, session) };
290
291        let table_provider = df_result!(table_provider)?;
292        let table_provider: Arc<dyn TableProvider> = (&table_provider).into();
293
294        Ok(table_provider)
295    }
296
297    fn call(&self, args: &[datafusion_expr::Expr]) -> Result<Arc<dyn TableProvider>> {
298        let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
299        let expr_list = LogicalExprList {
300            expr: serialize_exprs(args, codec.as_ref())?,
301        };
302        let filters_serialized = expr_list.encode_to_vec().into_iter().collect();
303
304        #[expect(deprecated)]
305        let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) };
306
307        let table_provider = df_result!(table_provider)?;
308        let table_provider: Arc<dyn TableProvider> = (&table_provider).into();
309
310        Ok(table_provider)
311    }
312}
313
#[cfg(test)]
mod tests {
    use arrow::array::{
        ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, record_batch,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::catalog::MemTable;
    use datafusion::common::exec_err;
    use datafusion::logical_expr::ptr_eq::arc_ptr_eq;
    use datafusion::prelude::{SessionContext, lit};
    use datafusion::scalar::ScalarValue;
    use datafusion_catalog::TableFunctionArgs;
    use datafusion_execution::TaskContextProvider;
    use datafusion_expr::Expr;

    use super::*;

    /// Test table function `udtf(num_rows, value, ...)`.
    ///
    /// Builds an in-memory table with one column per `value`, each repeated
    /// `num_rows` times. The rows are split into two batches: the first holds
    /// `num_rows / 3` rows and the second the remainder.
    #[derive(Debug)]
    struct TestUDTF {}

    impl TableFunctionImpl for TestUDTF {
        fn call_with_args(
            &self,
            args: TableFunctionArgs,
        ) -> Result<Arc<dyn TableProvider>> {
            let mut scalars = Vec::new();
            for expr in args.exprs() {
                let Expr::Literal(scalar, _) = expr else {
                    return exec_err!("Expected only literal arguments to table udf");
                };
                scalars.push(scalar);
            }

            if scalars.len() < 2 {
                return exec_err!("Expected at least two arguments to table udf");
            }

            let ScalarValue::UInt64(Some(num_rows)) = scalars[0] else {
                return exec_err!(
                    "First argument must be the number of elements to create as u64"
                );
            };
            let num_rows = *num_rows as usize;
            let split = num_rows / 3;

            let mut fields = Vec::new();
            let mut first_columns = Vec::new();
            let mut second_columns = Vec::new();

            for (idx, scalar) in scalars.iter().skip(1).enumerate() {
                let (data_type, column) = match scalar {
                    ScalarValue::Utf8(s) => (
                        DataType::Utf8,
                        Arc::new(StringArray::from(vec![s.clone(); num_rows])) as ArrayRef,
                    ),
                    ScalarValue::UInt64(v) => (
                        DataType::UInt64,
                        Arc::new(UInt64Array::from(vec![*v; num_rows])) as ArrayRef,
                    ),
                    ScalarValue::Float64(v) => (
                        DataType::Float64,
                        Arc::new(Float64Array::from(vec![*v; num_rows])) as ArrayRef,
                    ),
                    other => {
                        return exec_err!(
                            "Test case only supports utf8, u64, and f64. Found {}",
                            other.data_type()
                        );
                    }
                };

                fields.push(Field::new(format!("field-{idx}"), data_type, true));
                first_columns.push(column.slice(0, split));
                second_columns.push(column.slice(split, num_rows - split));
            }

            let schema = Arc::new(Schema::new(fields));
            let partition = vec![
                RecordBatch::try_new(Arc::clone(&schema), first_columns)?,
                RecordBatch::try_new(Arc::clone(&schema), second_columns)?,
            ];

            Ok(Arc::new(MemTable::try_new(schema, vec![partition])?))
        }
    }

    /// Exports `udtf` as an [`FFI_TableFunction`] with no runtime and the
    /// default logical codec. The caller keeps `task_ctx_provider` alive.
    fn export_test_udtf(
        udtf: &Arc<dyn TableFunctionImpl>,
        task_ctx_provider: &Arc<dyn TaskContextProvider>,
    ) -> FFI_TableFunction {
        FFI_TableFunction::new(
            Arc::clone(udtf),
            None,
            FFI_TaskContextProvider::from(task_ctx_provider),
            None,
        )
    }

    #[tokio::test]
    async fn test_round_trip_udtf() -> Result<()> {
        let original_udtf = Arc::new(TestUDTF {}) as Arc<dyn TableFunctionImpl>;
        let ctx = Arc::new(SessionContext::default());
        let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;

        let mut exported = export_test_udtf(&original_udtf, &task_ctx_provider);
        // Pretend the struct came from another library so the FFI path is used.
        exported.library_marker_id = crate::mock_foreign_marker_id;
        let foreign_udtf: Arc<dyn TableFunctionImpl> = exported.into();

        let table = foreign_udtf.call_with_args(TableFunctionArgs::new(
            &[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)],
            &ctx.state(),
        ))?;
        let _ = ctx.register_table("test-table", table)?;

        let batches = ctx.table("test-table").await?.collect().await?;

        // 6 rows are split as 6 / 3 = 2 in the first batch and 4 in the second.
        assert_eq!(batches.len(), 2);
        assert_eq!(
            batches[0],
            record_batch!(
                ("field-0", Utf8, ["one", "one"]),
                ("field-1", Float64, [2.0, 2.0]),
                ("field-2", UInt64, [3, 3])
            )?
        );
        assert_eq!(
            batches[1],
            record_batch!(
                ("field-0", Utf8, ["one", "one", "one", "one"]),
                ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]),
                ("field-2", UInt64, [3, 3, 3, 3])
            )?
        );

        Ok(())
    }

    #[test]
    fn test_ffi_udtf_local_bypass() -> Result<()> {
        let original_udtf = Arc::new(TestUDTF {}) as Arc<dyn TableFunctionImpl>;
        let ctx = Arc::new(SessionContext::default()) as Arc<dyn TaskContextProvider>;
        let mut exported = export_test_udtf(&original_udtf, &ctx);

        // A struct from this library converts back to the very same Arc.
        let round_tripped: Arc<dyn TableFunctionImpl> = exported.clone().into();
        assert!(arc_ptr_eq(&original_udtf, &round_tripped));

        // A different library marker forces the ForeignTableFunction wrapper.
        exported.library_marker_id = crate::mock_foreign_marker_id;
        let wrapped: Arc<dyn TableFunctionImpl> = exported.into();
        assert!(!arc_ptr_eq(&original_udtf, &wrapped));

        Ok(())
    }
}