use std::{any::Any, ffi::c_void, sync::Arc};

use abi_stable::{
    std_types::{ROption, RResult, RString, RVec},
    StableAbi,
};
use arrow::datatypes::SchemaRef;
use async_ffi::{FfiFuture, FutureExt};
use async_trait::async_trait;
use datafusion::{
    catalog::{Session, TableProvider},
    datasource::TableType,
    error::DataFusionError,
    execution::{session_state::SessionStateBuilder, TaskContext},
    logical_expr::{logical_plan::dml::InsertOp, TableProviderFilterPushDown},
    physical_plan::ExecutionPlan,
    prelude::{Expr, SessionContext},
};
use datafusion_proto::{
    logical_plan::{
        from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec,
    },
    protobuf::LogicalExprList,
};
use prost::Message;
use tokio::runtime::Handle;

use crate::{
    arrow_wrappers::WrappedSchema,
    df_result, rresult_return,
    session_config::ForeignSessionConfig,
    table_source::{FFI_TableProviderFilterPushDown, FFI_TableType},
};

use super::{
    execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan},
    insert_op::FFI_InsertOp,
    session_config::FFI_SessionConfig,
};
use datafusion::error::Result;

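/// A stable, `repr(C)` wrapper that allows a DataFusion [`TableProvider`] to be
/// shared across a dynamic library boundary. The function pointers below operate
/// on the opaque `private_data` pointer, which holds the underlying Rust provider
/// on the side of the library that created this struct.
///
/// A minimal sketch of exporting a provider, mirroring the tests at the bottom of
/// this file (`schema` and `partitions` are assumed to be defined elsewhere, and
/// `MemTable` is only an example source):
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::datasource::MemTable;
///
/// // Wrap any TableProvider implementation; `true` advertises filter pushdown
/// // support and `None` means no dedicated tokio runtime handle is supplied.
/// let provider = Arc::new(MemTable::try_new(schema, partitions)?);
/// let ffi_provider = FFI_TableProvider::new(provider, true, None);
/// ```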
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_TableProvider {
    /// Return the schema of the wrapped table provider.
    pub schema: unsafe extern "C" fn(provider: &Self) -> WrappedSchema,

    /// Perform a scan on the table and return an FFI-safe future that resolves
    /// to an execution plan. Filters are passed as a protobuf-serialized
    /// `LogicalExprList`.
    pub scan: unsafe extern "C" fn(
        provider: &Self,
        session_config: &FFI_SessionConfig,
        projections: RVec<usize>,
        filters_serialized: RVec<u8>,
        limit: ROption<usize>,
    ) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>>,

    /// Return the type of the table (base table, view, or temporary).
    pub table_type: unsafe extern "C" fn(provider: &Self) -> FFI_TableType,

    /// Optional callback reporting, for each serialized filter expression,
    /// whether the provider supports pushing it down. When `None`, all filters
    /// are treated as unsupported.
    pub supports_filters_pushdown: Option<
        unsafe extern "C" fn(
            provider: &FFI_TableProvider,
            filters_serialized: RVec<u8>,
        )
            -> RResult<RVec<FFI_TableProviderFilterPushDown>, RString>,
    >,

    /// Insert the results of the input execution plan into the table.
    pub insert_into:
        unsafe extern "C" fn(
            provider: &Self,
            session_config: &FFI_SessionConfig,
            input: &FFI_ExecutionPlan,
            insert_op: FFI_InsertOp,
        ) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>>,

    /// Create a copy of this provider that shares the underlying table provider.
    pub clone: unsafe extern "C" fn(plan: &Self) -> Self,

    /// Release the memory of the private data when the provider is dropped.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Return the version number provided by `super::version`.
    pub version: unsafe extern "C" fn() -> u64,

    /// Internal data not to be accessed across the FFI boundary.
    pub private_data: *mut c_void,
}

unsafe impl Send for FFI_TableProvider {}
unsafe impl Sync for FFI_TableProvider {}

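/// Data stored behind `FFI_TableProvider::private_data`: the wrapped provider
/// itself and an optional tokio runtime [`Handle`] that is forwarded to the
/// execution plans it creates.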
struct ProviderPrivateData {
    provider: Arc<dyn TableProvider + Send>,
    runtime: Option<Handle>,
}

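// The `*_fn_wrapper` functions below back the function pointers in
// `FFI_TableProvider`. Each one recovers the Rust provider from `private_data`,
// forwards the call, and converts arguments and results to FFI-safe types.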
unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema {
    let private_data = provider.private_data as *const ProviderPrivateData;
    let provider = &(*private_data).provider;

    provider.schema().into()
}

unsafe extern "C" fn table_type_fn_wrapper(
    provider: &FFI_TableProvider,
) -> FFI_TableType {
    let private_data = provider.private_data as *const ProviderPrivateData;
    let provider = &(*private_data).provider;

    provider.table_type().into()
}

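/// Decode the protobuf-serialized filter expressions and ask the wrapped
/// provider which of them it can push down.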
fn supports_filters_pushdown_internal(
    provider: &Arc<dyn TableProvider + Send>,
    filters_serialized: &[u8],
) -> Result<RVec<FFI_TableProviderFilterPushDown>> {
    let default_ctx = SessionContext::new();
    let codec = DefaultLogicalExtensionCodec {};

    let filters = match filters_serialized.is_empty() {
        true => vec![],
        false => {
            let proto_filters = LogicalExprList::decode(filters_serialized)
                .map_err(|e| DataFusionError::Plan(e.to_string()))?;

            parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)?
        }
    };
    let filters_borrowed: Vec<&Expr> = filters.iter().collect();

    let results: RVec<_> = provider
        .supports_filters_pushdown(&filters_borrowed)?
        .iter()
        .map(|v| v.into())
        .collect();

    Ok(results)
}

unsafe extern "C" fn supports_filters_pushdown_fn_wrapper(
    provider: &FFI_TableProvider,
    filters_serialized: RVec<u8>,
) -> RResult<RVec<FFI_TableProviderFilterPushDown>, RString> {
    let private_data = provider.private_data as *const ProviderPrivateData;
    let provider = &(*private_data).provider;

    supports_filters_pushdown_internal(provider, &filters_serialized)
        .map_err(|e| e.to_string().into())
        .into()
}

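/// FFI entry point for `FFI_TableProvider::scan`. Rebuilds a session from the
/// passed configuration, decodes the serialized filters, and delegates to the
/// wrapped provider's `scan`, returning the result as an `FFI_ExecutionPlan`
/// inside an FFI-safe future.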
unsafe extern "C" fn scan_fn_wrapper(
    provider: &FFI_TableProvider,
    session_config: &FFI_SessionConfig,
    projections: RVec<usize>,
    filters_serialized: RVec<u8>,
    limit: ROption<usize>,
) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>> {
    let private_data = provider.private_data as *mut ProviderPrivateData;
    let internal_provider = &(*private_data).provider;
    let session_config = session_config.clone();
    let runtime = &(*private_data).runtime;

    async move {
        let config = rresult_return!(ForeignSessionConfig::try_from(&session_config));
        let session = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config.0)
            .build();
        let ctx = SessionContext::new_with_state(session);

        let filters = match filters_serialized.is_empty() {
            true => vec![],
            false => {
                let default_ctx = SessionContext::new();
                let codec = DefaultLogicalExtensionCodec {};

                let proto_filters =
                    rresult_return!(LogicalExprList::decode(filters_serialized.as_ref()));

                rresult_return!(parse_exprs(
                    proto_filters.expr.iter(),
                    &default_ctx,
                    &codec
                ))
            }
        };

        let projections: Vec<_> = projections.into_iter().collect();

        let plan = rresult_return!(
            internal_provider
                .scan(&ctx.state(), Some(&projections), &filters, limit.into())
                .await
        );

        RResult::ROk(FFI_ExecutionPlan::new(
            plan,
            ctx.task_ctx(),
            runtime.clone(),
        ))
    }
    .into_ffi()
}

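/// FFI entry point for `FFI_TableProvider::insert_into`. Converts the incoming
/// `FFI_ExecutionPlan` back into a Rust execution plan and delegates to the
/// wrapped provider's `insert_into`.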
unsafe extern "C" fn insert_into_fn_wrapper(
    provider: &FFI_TableProvider,
    session_config: &FFI_SessionConfig,
    input: &FFI_ExecutionPlan,
    insert_op: FFI_InsertOp,
) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>> {
    let private_data = provider.private_data as *mut ProviderPrivateData;
    let internal_provider = &(*private_data).provider;
    let session_config = session_config.clone();
    let input = input.clone();
    let runtime = &(*private_data).runtime;

    async move {
        let config = rresult_return!(ForeignSessionConfig::try_from(&session_config));
        let session = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config.0)
            .build();
        let ctx = SessionContext::new_with_state(session);

        let input = rresult_return!(ForeignExecutionPlan::try_from(&input).map(Arc::new));

        let insert_op = InsertOp::from(insert_op);

        let plan = rresult_return!(
            internal_provider
                .insert_into(&ctx.state(), input, insert_op)
                .await
        );

        RResult::ROk(FFI_ExecutionPlan::new(
            plan,
            ctx.task_ctx(),
            runtime.clone(),
        ))
    }
    .into_ffi()
}

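/// Reclaim and drop the boxed `ProviderPrivateData`; invoked from the `Drop`
/// implementation of `FFI_TableProvider`.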
unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) {
    let private_data = Box::from_raw(provider.private_data as *mut ProviderPrivateData);
    drop(private_data);
}

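/// Produce a new `FFI_TableProvider` that shares the same underlying provider
/// (via `Arc::clone`) and runtime handle, with its own boxed private data.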
unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_TableProvider {
    let old_private_data = provider.private_data as *const ProviderPrivateData;
    let runtime = (*old_private_data).runtime.clone();

    let private_data = Box::into_raw(Box::new(ProviderPrivateData {
        provider: Arc::clone(&(*old_private_data).provider),
        runtime,
    })) as *mut c_void;

    FFI_TableProvider {
        schema: schema_fn_wrapper,
        scan: scan_fn_wrapper,
        table_type: table_type_fn_wrapper,
        supports_filters_pushdown: provider.supports_filters_pushdown,
        insert_into: provider.insert_into,
        clone: clone_fn_wrapper,
        release: release_fn_wrapper,
        version: super::version,
        private_data,
    }
}

impl Drop for FFI_TableProvider {
    fn drop(&mut self) {
        unsafe { (self.release)(self) }
    }
}

impl FFI_TableProvider {
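    /// Create a new `FFI_TableProvider` from a [`TableProvider`].
    ///
    /// Set `can_support_pushdown_filters` to expose the optional filter
    /// pushdown callback. The optional tokio runtime `Handle` is forwarded to
    /// the execution plans this provider creates.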
    pub fn new(
        provider: Arc<dyn TableProvider + Send>,
        can_support_pushdown_filters: bool,
        runtime: Option<Handle>,
    ) -> Self {
        let private_data = Box::new(ProviderPrivateData { provider, runtime });

        Self {
            schema: schema_fn_wrapper,
            scan: scan_fn_wrapper,
            table_type: table_type_fn_wrapper,
            supports_filters_pushdown: match can_support_pushdown_filters {
                true => Some(supports_filters_pushdown_fn_wrapper),
                false => None,
            },
            insert_into: insert_into_fn_wrapper,
            clone: clone_fn_wrapper,
            release: release_fn_wrapper,
            version: super::version,
            private_data: Box::into_raw(private_data) as *mut c_void,
        }
    }
}

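/// Consumer-side wrapper around an [`FFI_TableProvider`]. It implements
/// [`TableProvider`] by calling through the FFI function pointers, so a table
/// defined in another library can be registered with a local `SessionContext`.
///
/// A minimal sketch based on the tests below (names are illustrative):
///
/// ```ignore
/// let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into();
/// ctx.register_table("t", Arc::new(foreign_table_provider))?;
/// ```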
#[derive(Debug)]
pub struct ForeignTableProvider(pub FFI_TableProvider);

unsafe impl Send for ForeignTableProvider {}
unsafe impl Sync for ForeignTableProvider {}

impl From<&FFI_TableProvider> for ForeignTableProvider {
    fn from(provider: &FFI_TableProvider) -> Self {
        Self(provider.clone())
    }
}

impl Clone for FFI_TableProvider {
    fn clone(&self) -> Self {
        unsafe { (self.clone)(self) }
    }
}

#[async_trait]
impl TableProvider for ForeignTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        let wrapped_schema = unsafe { (self.0.schema)(&self.0) };
        wrapped_schema.into()
    }

    fn table_type(&self) -> TableType {
        unsafe { (self.0.table_type)(&self.0).into() }
    }

    async fn scan(
        &self,
        session: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let session_config: FFI_SessionConfig = session.config().into();

        let projections: Option<RVec<usize>> =
            projection.map(|p| p.iter().map(|v| v.to_owned()).collect());

        let codec = DefaultLogicalExtensionCodec {};
        let filter_list = LogicalExprList {
            expr: serialize_exprs(filters, &codec)?,
        };
        let filters_serialized = filter_list.encode_to_vec().into();

        let plan = unsafe {
            let maybe_plan = (self.0.scan)(
                &self.0,
                &session_config,
                projections.unwrap_or_default(),
                filters_serialized,
                limit.into(),
            )
            .await;

            ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)?
        };

        Ok(Arc::new(plan))
    }

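    /// Ask the FFI provider which filters it can push down. If the provider
    /// did not expose a `supports_filters_pushdown` callback, every filter is
    /// reported as [`TableProviderFilterPushDown::Unsupported`].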
    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<TableProviderFilterPushDown>> {
        unsafe {
            let pushdown_fn = match self.0.supports_filters_pushdown {
                Some(func) => func,
                None => {
                    return Ok(vec![
                        TableProviderFilterPushDown::Unsupported;
                        filters.len()
                    ])
                }
            };

            let codec = DefaultLogicalExtensionCodec {};

            let expr_list = LogicalExprList {
                expr: serialize_exprs(filters.iter().map(|f| f.to_owned()), &codec)?,
            };
            let serialized_filters = expr_list.encode_to_vec();

            let pushdowns = df_result!(pushdown_fn(&self.0, serialized_filters.into()))?;

            Ok(pushdowns.iter().map(|v| v.into()).collect())
        }
    }

    async fn insert_into(
        &self,
        session: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        insert_op: InsertOp,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let session_config: FFI_SessionConfig = session.config().into();

        let rc = Handle::try_current().ok();
        let input =
            FFI_ExecutionPlan::new(input, Arc::new(TaskContext::from(session)), rc);
        let insert_op: FFI_InsertOp = insert_op.into();

        let plan = unsafe {
            let maybe_plan =
                (self.0.insert_into)(&self.0, &session_config, &input, insert_op).await;

            ForeignExecutionPlan::try_from(&df_result!(maybe_plan)?)?
        };

        Ok(Arc::new(plan))
    }
}

#[cfg(test)]
mod tests {
    use arrow::datatypes::Schema;
    use datafusion::prelude::{col, lit};

    use super::*;

    #[tokio::test]
    async fn test_round_trip_ffi_table_provider_scan() -> Result<()> {
        use arrow::datatypes::Field;
        use datafusion::arrow::{
            array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
        };
        use datafusion::datasource::MemTable;

        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));

        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
        )?;
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Float32Array::from(vec![64.0]))],
        )?;

        let ctx = SessionContext::new();

        let provider =
            Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?);

        let ffi_provider = FFI_TableProvider::new(provider, true, None);

        let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into();

        ctx.register_table("t", Arc::new(foreign_table_provider))?;

        let df = ctx.table("t").await?;

        df.select(vec![col("a")])?
            .filter(col("a").gt(lit(3.0)))?
            .show()
            .await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_round_trip_ffi_table_provider_insert_into() -> Result<()> {
        use arrow::datatypes::Field;
        use datafusion::arrow::{
            array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
        };
        use datafusion::datasource::MemTable;

        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));

        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
        )?;
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Float32Array::from(vec![64.0]))],
        )?;

        let ctx = SessionContext::new();

        let provider =
            Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?);

        let ffi_provider = FFI_TableProvider::new(provider, true, None);

        let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into();

        ctx.register_table("t", Arc::new(foreign_table_provider))?;

        let result = ctx
            .sql("INSERT INTO t VALUES (128.0);")
            .await?
            .collect()
            .await?;

        assert!(result.len() == 1 && result[0].num_rows() == 1);

        ctx.table("t")
            .await?
            .select(vec![col("a")])?
            .filter(col("a").gt(lit(3.0)))?
            .show()
            .await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_aggregation() -> Result<()> {
        use arrow::datatypes::Field;
        use datafusion::arrow::{
            array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
        };
        use datafusion::common::assert_batches_eq;
        use datafusion::datasource::MemTable;

        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));

        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
        )?;

        let ctx = SessionContext::new();

        let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1]])?);

        let ffi_provider = FFI_TableProvider::new(provider, true, None);

        let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into();

        ctx.register_table("t", Arc::new(foreign_table_provider))?;

        let result = ctx
            .sql("SELECT COUNT(*) as cnt FROM t")
            .await?
            .collect()
            .await?;
        #[rustfmt::skip]
        let expected = [
            "+-----+",
            "| cnt |",
            "+-----+",
            "| 3   |",
            "+-----+"
        ];
        assert_batches_eq!(expected, &result);
        Ok(())
    }
}