1use std::{ffi::c_void, sync::Arc};
19
20use abi_stable::{
21 std_types::{RResult, RString, RVec},
22 StableAbi,
23};
24
25use datafusion::error::Result;
26use datafusion::{
27 catalog::{TableFunctionImpl, TableProvider},
28 prelude::{Expr, SessionContext},
29};
30use datafusion_proto::{
31 logical_plan::{
32 from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec,
33 },
34 protobuf::LogicalExprList,
35};
36use prost::Message;
37use tokio::runtime::Handle;
38
39use crate::{
40 df_result, rresult_return,
41 table_provider::{FFI_TableProvider, ForeignTableProvider},
42};
43
/// A stable, C-compatible struct for sharing a [`TableFunctionImpl`] across an
/// FFI boundary.
///
/// It is a small table of `extern "C"` function pointers plus an opaque
/// `private_data` pointer. The side that builds it (via
/// `FFI_TableFunction::new` or `From<Arc<dyn TableFunctionImpl>>`) owns the
/// private data. The receiving side ([`ForeignTableFunction`]) only interacts
/// through the function pointers.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_TableFunction {
    /// Equivalent of [`TableFunctionImpl::call`]. `args` carries the call's
    /// arguments as a protobuf-encoded `LogicalExprList`. On success an
    /// [`FFI_TableProvider`] is returned; on failure, an error message.
    pub call: unsafe extern "C" fn(
        udtf: &Self,
        args: RVec<u8>,
    ) -> RResult<FFI_TableProvider, RString>,

    /// Produces a new, independently owned copy of this struct. Invoked by
    /// the `Clone` impl.
    pub clone: unsafe extern "C" fn(udtf: &Self) -> Self,

    /// Frees `private_data`. Invoked by the `Drop` impl; the struct must not
    /// be used through its other function pointers afterwards.
    pub release: unsafe extern "C" fn(udtf: &mut Self),

    /// Opaque creator-side state (a boxed `TableFunctionPrivateData`). Only
    /// the creator's own function pointers dereference it;
    /// [`ForeignTableFunction`] never reads it.
    pub private_data: *mut c_void,
}
67
// SAFETY: the raw `private_data` pointer suppresses the auto traits. It only
// ever points at a `TableFunctionPrivateData` holding an
// `Arc<dyn TableFunctionImpl>` and an `Option<Handle>`. The function pointers
// themselves are Send + Sync. NOTE(review): this relies on
// `TableFunctionImpl` requiring Send + Sync and on tokio's `Handle` being
// Send + Sync; confirm against those definitions.
unsafe impl Send for FFI_TableFunction {}
unsafe impl Sync for FFI_TableFunction {}
70
/// Creator-side state stored behind [`FFI_TableFunction::private_data`].
pub struct TableFunctionPrivateData {
    /// The table function being exposed across the FFI boundary.
    udtf: Arc<dyn TableFunctionImpl>,
    /// Optional tokio runtime handle, forwarded to every returned
    /// `FFI_TableProvider` and carried over to clones.
    runtime: Option<Handle>,
}
75
76impl FFI_TableFunction {
77 fn inner(&self) -> &Arc<dyn TableFunctionImpl> {
78 let private_data = self.private_data as *const TableFunctionPrivateData;
79 unsafe { &(*private_data).udtf }
80 }
81
82 fn runtime(&self) -> Option<Handle> {
83 let private_data = self.private_data as *const TableFunctionPrivateData;
84 unsafe { (*private_data).runtime.clone() }
85 }
86}
87
88unsafe extern "C" fn call_fn_wrapper(
89 udtf: &FFI_TableFunction,
90 args: RVec<u8>,
91) -> RResult<FFI_TableProvider, RString> {
92 let runtime = udtf.runtime();
93 let udtf = udtf.inner();
94
95 let default_ctx = SessionContext::new();
96 let codec = DefaultLogicalExtensionCodec {};
97
98 let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref()));
99
100 let args =
101 rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec));
102
103 let table_provider = rresult_return!(udtf.call(&args));
104 RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime))
105}
106
107unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) {
108 let private_data = Box::from_raw(udtf.private_data as *mut TableFunctionPrivateData);
109 drop(private_data);
110}
111
112unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction {
113 let runtime = udtf.runtime();
114 let udtf = udtf.inner();
115
116 FFI_TableFunction::new(Arc::clone(udtf), runtime)
117}
118
119impl Clone for FFI_TableFunction {
120 fn clone(&self) -> Self {
121 unsafe { (self.clone)(self) }
122 }
123}
124
125impl FFI_TableFunction {
126 pub fn new(udtf: Arc<dyn TableFunctionImpl>, runtime: Option<Handle>) -> Self {
127 let private_data = Box::new(TableFunctionPrivateData { udtf, runtime });
128
129 Self {
130 call: call_fn_wrapper,
131 clone: clone_fn_wrapper,
132 release: release_fn_wrapper,
133 private_data: Box::into_raw(private_data) as *mut c_void,
134 }
135 }
136}
137
138impl From<Arc<dyn TableFunctionImpl>> for FFI_TableFunction {
139 fn from(udtf: Arc<dyn TableFunctionImpl>) -> Self {
140 let private_data = Box::new(TableFunctionPrivateData {
141 udtf,
142 runtime: None,
143 });
144
145 Self {
146 call: call_fn_wrapper,
147 clone: clone_fn_wrapper,
148 release: release_fn_wrapper,
149 private_data: Box::into_raw(private_data) as *mut c_void,
150 }
151 }
152}
153
154impl Drop for FFI_TableFunction {
155 fn drop(&mut self) {
156 unsafe { (self.release)(self) }
157 }
158}
159
/// Consumer-side wrapper that exposes an [`FFI_TableFunction`] as a regular
/// [`TableFunctionImpl`].
///
/// It never touches the wrapped struct's `private_data`. Every interaction goes
/// through the function pointers stored in the [`FFI_TableFunction`].
#[derive(Debug)]
pub struct ForeignTableFunction(FFI_TableFunction);

// SAFETY: the only field is an `FFI_TableFunction`, which is declared
// Send + Sync in this file. These impls are therefore already implied by the
// auto-trait rules and are kept explicit for clarity.
unsafe impl Send for ForeignTableFunction {}
unsafe impl Sync for ForeignTableFunction {}
171
172impl From<FFI_TableFunction> for ForeignTableFunction {
173 fn from(value: FFI_TableFunction) -> Self {
174 Self(value)
175 }
176}
177
178impl TableFunctionImpl for ForeignTableFunction {
179 fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
180 let codec = DefaultLogicalExtensionCodec {};
181 let expr_list = LogicalExprList {
182 expr: serialize_exprs(args, &codec)?,
183 };
184 let filters_serialized = expr_list.encode_to_vec().into();
185
186 let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) };
187
188 let table_provider = df_result!(table_provider)?;
189 let table_provider: ForeignTableProvider = (&table_provider).into();
190
191 Ok(Arc::new(table_provider))
192 }
193}
194
#[cfg(test)]
mod tests {
    use arrow::{
        array::{
            record_batch, ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array,
        },
        datatypes::{DataType, Field, Schema},
    };
    use datafusion::{
        catalog::MemTable, common::exec_err, prelude::lit, scalar::ScalarValue,
    };

    use super::*;

    /// Test table function.
    ///
    /// The first argument is a `UInt64` row count. Each remaining literal
    /// (Utf8, UInt64 or Float64) becomes a nullable column `field-{idx}` filled
    /// with that value. The rows are split into two batches at `num_rows / 3`.
    #[derive(Debug)]
    struct TestUDTF {}

    impl TableFunctionImpl for TestUDTF {
        fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
            // Every argument must be a literal; keep only the scalar values.
            let args = args
                .iter()
                .map(|arg| {
                    if let Expr::Literal(scalar, _) = arg {
                        Ok(scalar)
                    } else {
                        exec_err!("Expected only literal arguments to table udf")
                    }
                })
                .collect::<Result<Vec<_>>>()?;

            if args.len() < 2 {
                exec_err!("Expected at least two arguments to table udf")?
            }

            let ScalarValue::UInt64(Some(num_rows)) = args[0].to_owned() else {
                exec_err!(
                    "First argument must be the number of elements to create as u64"
                )?
            };
            let num_rows = num_rows as usize;

            let mut fields = Vec::default();
            let mut arrays1 = Vec::default();
            let mut arrays2 = Vec::default();

            // The first batch gets `split` rows; the second gets the remainder.
            let split = num_rows / 3;
            for (idx, arg) in args[1..].iter().enumerate() {
                // Build a column of `num_rows` copies of this literal.
                let (field, array) = match arg {
                    ScalarValue::Utf8(s) => {
                        let s_vec = vec![s.to_owned(); num_rows];
                        (
                            Field::new(format!("field-{idx}"), DataType::Utf8, true),
                            Arc::new(StringArray::from(s_vec)) as ArrayRef,
                        )
                    }
                    ScalarValue::UInt64(v) => {
                        let v_vec = vec![v.to_owned(); num_rows];
                        (
                            Field::new(format!("field-{idx}"), DataType::UInt64, true),
                            Arc::new(UInt64Array::from(v_vec)) as ArrayRef,
                        )
                    }
                    ScalarValue::Float64(v) => {
                        let v_vec = vec![v.to_owned(); num_rows];
                        (
                            Field::new(format!("field-{idx}"), DataType::Float64, true),
                            Arc::new(Float64Array::from(v_vec)) as ArrayRef,
                        )
                    }
                    _ => exec_err!(
                        "Test case only supports utf8, u64, and f64. Found {}",
                        arg.data_type()
                    )?,
                };

                fields.push(field);
                arrays1.push(array.slice(0, split));
                arrays2.push(array.slice(split, num_rows - split));
            }

            let schema = Arc::new(Schema::new(fields));
            let batches = vec![
                RecordBatch::try_new(Arc::clone(&schema), arrays1)?,
                RecordBatch::try_new(Arc::clone(&schema), arrays2)?,
            ];

            // `vec![batches]`: both batches go into a single partition.
            let table_provider = MemTable::try_new(schema, vec![batches])?;

            Ok(Arc::new(table_provider))
        }
    }

    /// Round-trips `TestUDTF` through `FFI_TableFunction` and
    /// `ForeignTableFunction` in one process, then checks the table contents.
    #[tokio::test]
    async fn test_round_trip_udtf() -> Result<()> {
        let original_udtf = Arc::new(TestUDTF {}) as Arc<dyn TableFunctionImpl>;

        let local_udtf: FFI_TableFunction =
            FFI_TableFunction::new(Arc::clone(&original_udtf), None);

        let foreign_udf: ForeignTableFunction = local_udtf.into();

        // 6 rows; the remaining literals become field-0 (Utf8), field-1
        // (Float64) and field-2 (UInt64). The call goes through the FFI
        // function pointers, so the arguments are serialized and decoded.
        let table = foreign_udf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?;

        let ctx = SessionContext::default();
        let _ = ctx.register_table("test-table", table)?;

        let returned_batches = ctx.table("test-table").await?.collect().await?;

        // Split at 6 / 3 = 2: the first batch has 2 rows, the second 4.
        assert_eq!(returned_batches.len(), 2);
        let expected_batch_0 = record_batch!(
            ("field-0", Utf8, ["one", "one"]),
            ("field-1", Float64, [2.0, 2.0]),
            ("field-2", UInt64, [3, 3])
        )?;
        assert_eq!(returned_batches[0], expected_batch_0);

        let expected_batch_1 = record_batch!(
            ("field-0", Utf8, ["one", "one", "one", "one"]),
            ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]),
            ("field-2", UInt64, [3, 3, 3, 3])
        )?;
        assert_eq!(returned_batches[1], expected_batch_1);

        Ok(())
    }
}