datafusion_ffi/
udtf.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::{ffi::c_void, sync::Arc};
19
20use abi_stable::{
21    std_types::{RResult, RString, RVec},
22    StableAbi,
23};
24
25use datafusion::error::Result;
26use datafusion::{
27    catalog::{TableFunctionImpl, TableProvider},
28    prelude::{Expr, SessionContext},
29};
30use datafusion_proto::{
31    logical_plan::{
32        from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec,
33    },
34    protobuf::LogicalExprList,
35};
36use prost::Message;
37use tokio::runtime::Handle;
38
39use crate::{
40    df_result, rresult_return,
41    table_provider::{FFI_TableProvider, ForeignTableProvider},
42};
43
/// A stable struct for sharing a [`TableFunctionImpl`] across FFI boundaries.
///
/// The struct is `#[repr(C)]` and derives `StableAbi`, so its field order and
/// field types define the layout seen by the other side of the boundary; do not
/// reorder or retype fields.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_TableFunction {
    /// Equivalent to the `call` function of the [`TableFunctionImpl`].
    /// The arguments are [`Expr`]s passed as protobuf-encoded bytes.
    pub call: unsafe extern "C" fn(
        udtf: &Self,
        args: RVec<u8>,
    ) -> RResult<FFI_TableProvider, RString>,

    /// Used to create a clone on the provider of the udtf. This should
    /// only need to be called by the receiver of the udtf.
    pub clone: unsafe extern "C" fn(udtf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    pub release: unsafe extern "C" fn(udtf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udtf.
    /// A [`ForeignTableFunction`] should never attempt to access this data.
    pub private_data: *mut c_void,
}

// `private_data` is a raw pointer, which is neither `Send` nor `Sync`, so these
// impls must be asserted manually.
// NOTE(review): soundness relies on the pointed-to data being thread-safe. In
// this file it is a `TableFunctionPrivateData` (an `Arc<dyn TableFunctionImpl>`
// plus an `Option<Handle>`); confirm the same holds for any foreign provider.
unsafe impl Send for FFI_TableFunction {}
unsafe impl Sync for FFI_TableFunction {}
70
/// Provider-side state behind [`FFI_TableFunction::private_data`].
///
/// A boxed instance is turned into the raw `private_data` pointer when an
/// `FFI_TableFunction` is constructed in this file, and it is reclaimed by
/// `release_fn_wrapper`.
// Field order also determines drop order (`udtf` is dropped before `runtime`).
pub struct TableFunctionPrivateData {
    /// The table function implementation exposed over FFI.
    udtf: Arc<dyn TableFunctionImpl>,
    /// Optional tokio runtime handle, passed on to each `FFI_TableProvider`
    /// created by `call_fn_wrapper`.
    runtime: Option<Handle>,
}
75
76impl FFI_TableFunction {
77    fn inner(&self) -> &Arc<dyn TableFunctionImpl> {
78        let private_data = self.private_data as *const TableFunctionPrivateData;
79        unsafe { &(*private_data).udtf }
80    }
81
82    fn runtime(&self) -> Option<Handle> {
83        let private_data = self.private_data as *const TableFunctionPrivateData;
84        unsafe { (*private_data).runtime.clone() }
85    }
86}
87
88unsafe extern "C" fn call_fn_wrapper(
89    udtf: &FFI_TableFunction,
90    args: RVec<u8>,
91) -> RResult<FFI_TableProvider, RString> {
92    let runtime = udtf.runtime();
93    let udtf = udtf.inner();
94
95    let default_ctx = SessionContext::new();
96    let codec = DefaultLogicalExtensionCodec {};
97
98    let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref()));
99
100    let args =
101        rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec));
102
103    let table_provider = rresult_return!(udtf.call(&args));
104    RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime))
105}
106
107unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) {
108    let private_data = Box::from_raw(udtf.private_data as *mut TableFunctionPrivateData);
109    drop(private_data);
110}
111
112unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction {
113    let runtime = udtf.runtime();
114    let udtf = udtf.inner();
115
116    FFI_TableFunction::new(Arc::clone(udtf), runtime)
117}
118
119impl Clone for FFI_TableFunction {
120    fn clone(&self) -> Self {
121        unsafe { (self.clone)(self) }
122    }
123}
124
125impl FFI_TableFunction {
126    pub fn new(udtf: Arc<dyn TableFunctionImpl>, runtime: Option<Handle>) -> Self {
127        let private_data = Box::new(TableFunctionPrivateData { udtf, runtime });
128
129        Self {
130            call: call_fn_wrapper,
131            clone: clone_fn_wrapper,
132            release: release_fn_wrapper,
133            private_data: Box::into_raw(private_data) as *mut c_void,
134        }
135    }
136}
137
138impl From<Arc<dyn TableFunctionImpl>> for FFI_TableFunction {
139    fn from(udtf: Arc<dyn TableFunctionImpl>) -> Self {
140        let private_data = Box::new(TableFunctionPrivateData {
141            udtf,
142            runtime: None,
143        });
144
145        Self {
146            call: call_fn_wrapper,
147            clone: clone_fn_wrapper,
148            release: release_fn_wrapper,
149            private_data: Box::into_raw(private_data) as *mut c_void,
150        }
151    }
152}
153
154impl Drop for FFI_TableFunction {
155    fn drop(&mut self) {
156        unsafe { (self.release)(self) }
157    }
158}
159
160/// This struct is used to access an UDTF provided by a foreign
161/// library across a FFI boundary.
162///
163/// The ForeignTableFunction is to be used by the caller of the UDTF, so it has
164/// no knowledge or access to the private data. All interaction with the UDTF
165/// must occur through the functions defined in FFI_TableFunction.
166#[derive(Debug)]
167pub struct ForeignTableFunction(FFI_TableFunction);
168
169unsafe impl Send for ForeignTableFunction {}
170unsafe impl Sync for ForeignTableFunction {}
171
172impl From<FFI_TableFunction> for ForeignTableFunction {
173    fn from(value: FFI_TableFunction) -> Self {
174        Self(value)
175    }
176}
177
178impl TableFunctionImpl for ForeignTableFunction {
179    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
180        let codec = DefaultLogicalExtensionCodec {};
181        let expr_list = LogicalExprList {
182            expr: serialize_exprs(args, &codec)?,
183        };
184        let filters_serialized = expr_list.encode_to_vec().into();
185
186        let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) };
187
188        let table_provider = df_result!(table_provider)?;
189        let table_provider: ForeignTableProvider = (&table_provider).into();
190
191        Ok(Arc::new(table_provider))
192    }
193}
194
#[cfg(test)]
mod tests {
    use arrow::{
        array::{
            record_batch, ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array,
        },
        datatypes::{DataType, Field, Schema},
    };
    use datafusion::{
        catalog::MemTable, common::exec_err, prelude::lit, scalar::ScalarValue,
    };

    use super::*;

    /// Test table function. The first argument is a `u64` row count. Every
    /// following literal becomes a constant column repeated that many times,
    /// returned as two batches split at `num_rows / 3`.
    #[derive(Debug)]
    struct TestUDTF {}

    impl TableFunctionImpl for TestUDTF {
        fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
            // Every argument must be a literal; keep borrowed scalars.
            let mut scalars = Vec::with_capacity(args.len());
            for arg in args {
                let Expr::Literal(scalar, _) = arg else {
                    return exec_err!("Expected only literal arguments to table udf");
                };
                scalars.push(scalar);
            }

            if scalars.len() < 2 {
                return exec_err!("Expected at least two arguments to table udf");
            }

            let num_rows = match scalars[0] {
                ScalarValue::UInt64(Some(n)) => *n as usize,
                _ => {
                    return exec_err!(
                        "First argument must be the number of elements to create as u64"
                    );
                }
            };

            let split = num_rows / 3;
            let mut fields = Vec::new();
            let mut first_half = Vec::new();
            let mut second_half = Vec::new();

            for (idx, scalar) in scalars.iter().skip(1).enumerate() {
                let name = format!("field-{idx}");
                let (data_type, array) = match scalar {
                    ScalarValue::Utf8(s) => (
                        DataType::Utf8,
                        Arc::new(StringArray::from(vec![s.clone(); num_rows])) as ArrayRef,
                    ),
                    ScalarValue::UInt64(v) => (
                        DataType::UInt64,
                        Arc::new(UInt64Array::from(vec![*v; num_rows])) as ArrayRef,
                    ),
                    ScalarValue::Float64(v) => (
                        DataType::Float64,
                        Arc::new(Float64Array::from(vec![*v; num_rows])) as ArrayRef,
                    ),
                    other => {
                        return exec_err!(
                            "Test case only supports utf8, u64, and f64. Found {}",
                            other.data_type()
                        );
                    }
                };

                fields.push(Field::new(name, data_type, true));
                first_half.push(array.slice(0, split));
                second_half.push(array.slice(split, num_rows - split));
            }

            let schema = Arc::new(Schema::new(fields));
            let batches = vec![
                RecordBatch::try_new(Arc::clone(&schema), first_half)?,
                RecordBatch::try_new(Arc::clone(&schema), second_half)?,
            ];

            Ok(Arc::new(MemTable::try_new(schema, vec![batches])?))
        }
    }

    #[tokio::test]
    async fn test_round_trip_udtf() -> Result<()> {
        let original_udtf: Arc<dyn TableFunctionImpl> = Arc::new(TestUDTF {});

        // Wrap for FFI, then consume it through the "foreign" side.
        let ffi_udtf = FFI_TableFunction::new(Arc::clone(&original_udtf), None);
        let foreign_udtf = ForeignTableFunction::from(ffi_udtf);

        let table = foreign_udtf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?;

        let ctx = SessionContext::default();
        let _ = ctx.register_table("test-table", table)?;

        let returned_batches = ctx.table("test-table").await?.collect().await?;
        assert_eq!(returned_batches.len(), 2);

        // 6 rows split at 6 / 3 = 2: the first batch has 2 rows, the second 4.
        let expected_first = record_batch!(
            ("field-0", Utf8, ["one", "one"]),
            ("field-1", Float64, [2.0, 2.0]),
            ("field-2", UInt64, [3, 3])
        )?;
        assert_eq!(returned_batches[0], expected_first);

        let expected_second = record_batch!(
            ("field-0", Utf8, ["one", "one", "one", "one"]),
            ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]),
            ("field-2", UInt64, [3, 3, 3, 3])
        )?;
        assert_eq!(returned_batches[1], expected_second);

        Ok(())
    }
}
320}