use std::any::Any;
use std::ffi::c_void;
use std::sync::Arc;
use abi_stable::StableAbi;
use abi_stable::std_types::{ROption, RResult, RVec};
use arrow::datatypes::SchemaRef;
use async_ffi::{FfiFuture, FutureExt};
use async_trait::async_trait;
use datafusion_catalog::{Session, TableProvider};
use datafusion_common::error::{DataFusionError, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion_physical_plan::ExecutionPlan;
use datafusion_proto::logical_plan::from_proto::parse_exprs;
use datafusion_proto::logical_plan::to_proto::serialize_exprs;
use datafusion_proto::logical_plan::{
DefaultLogicalExtensionCodec, LogicalExtensionCodec,
};
use datafusion_proto::protobuf::LogicalExprList;
use prost::Message;
use tokio::runtime::Handle;
use super::execution_plan::FFI_ExecutionPlan;
use super::insert_op::FFI_InsertOp;
use crate::arrow_wrappers::WrappedSchema;
use crate::execution::FFI_TaskContextProvider;
use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use crate::session::{FFI_SessionRef, ForeignSession};
use crate::table_source::{FFI_TableProviderFilterPushDown, FFI_TableType};
use crate::util::FFIResult;
use crate::{df_result, rresult_return};
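/// A stable struct for sharing a [`TableProvider`] across FFI boundaries.
/// Every operation is exposed as an `extern "C"` function pointer, so the
/// producing and consuming libraries may be built with different Rust
/// toolchains. Expressions cross the boundary as protobuf-encoded
/// [`LogicalExprList`] bytes, since `Expr` itself is not FFI-stable.
#[allow(non_camel_case_types)]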
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_TableProvider {
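    /// Returns the schema of the underlying table.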
schema: unsafe extern "C" fn(provider: &Self) -> WrappedSchema,
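    /// Creates an [`ExecutionPlan`] for scanning the table.
    /// `filters_serialized` is a protobuf-encoded [`LogicalExprList`], where
    /// an empty buffer means no filters; an empty `projections` list encodes
    /// no projection.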
scan: unsafe extern "C" fn(
provider: &Self,
session: FFI_SessionRef,
projections: RVec<usize>,
filters_serialized: RVec<u8>,
limit: ROption<usize>,
) -> FfiFuture<FFIResult<FFI_ExecutionPlan>>,
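    /// Returns the [`TableType`] of the underlying table.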
table_type: unsafe extern "C" fn(provider: &Self) -> FFI_TableType,
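    /// Optional entry point for filter pushdown. When `None`, the consumer
    /// reports every filter as
    /// [`TableProviderFilterPushDown::Unsupported`].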
supports_filters_pushdown: Option<
unsafe extern "C" fn(
            provider: &Self,
filters_serialized: RVec<u8>,
) -> FFIResult<RVec<FFI_TableProviderFilterPushDown>>,
>,
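    /// Creates an [`ExecutionPlan`] that writes `input` into the table
    /// according to `insert_op`.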
insert_into: unsafe extern "C" fn(
provider: &Self,
session: FFI_SessionRef,
input: &FFI_ExecutionPlan,
insert_op: FFI_InsertOp,
) -> FfiFuture<FFIResult<FFI_ExecutionPlan>>,
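    /// Codec used to serialize and deserialize expressions across the
    /// boundary; it also carries the task context provider.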
pub logical_codec: FFI_LogicalExtensionCodec,
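    /// Creates a copy of this struct that shares the same underlying
    /// provider. Always executes in the producing library.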
    clone: unsafe extern "C" fn(provider: &Self) -> Self,
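    /// Releases the memory behind `private_data`. Called by the [`Drop`]
    /// implementation.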
release: unsafe extern "C" fn(arg: &mut Self),
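    /// Returns the version of the FFI interface exposed by the producing
    /// library, used to detect incompatible versions.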
pub version: unsafe extern "C" fn() -> u64,
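    /// Internal state of the producing library. Consumers must treat this
    /// pointer as opaque.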
private_data: *mut c_void,
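    /// Identifies the library that produced this struct, allowing
    /// same-library round trips to bypass the FFI machinery entirely.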
pub library_marker_id: extern "C" fn() -> usize,
}
unsafe impl Send for FFI_TableProvider {}
unsafe impl Sync for FFI_TableProvider {}
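/// Provider-side state stored behind `private_data`: the wrapped provider
/// and an optional tokio runtime handle for async work.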
struct ProviderPrivateData {
provider: Arc<dyn TableProvider + Send>,
runtime: Option<Handle>,
}
impl FFI_TableProvider {
fn inner(&self) -> &Arc<dyn TableProvider + Send> {
let private_data = self.private_data as *const ProviderPrivateData;
unsafe { &(*private_data).provider }
}
fn runtime(&self) -> &Option<Handle> {
let private_data = self.private_data as *const ProviderPrivateData;
unsafe { &(*private_data).runtime }
}
}
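// FFI entry points. Each wrapper recovers the wrapped provider from
// `private_data` and forwards the call, converting between FFI-stable and
// native DataFusion types at the boundary.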
unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema {
provider.inner().schema().into()
}
unsafe extern "C" fn table_type_fn_wrapper(
provider: &FFI_TableProvider,
) -> FFI_TableType {
provider.inner().table_type().into()
}
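// Shared body of the pushdown wrapper: decode the serialized filters back
// into `Expr`s and translate the provider's verdict into FFI types.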
fn supports_filters_pushdown_internal(
provider: &Arc<dyn TableProvider + Send>,
filters_serialized: &[u8],
task_ctx: &Arc<TaskContext>,
codec: &dyn LogicalExtensionCodec,
) -> Result<RVec<FFI_TableProviderFilterPushDown>> {
let filters = match filters_serialized.is_empty() {
true => vec![],
false => {
let proto_filters = LogicalExprList::decode(filters_serialized)
.map_err(|e| DataFusionError::Plan(e.to_string()))?;
parse_exprs(proto_filters.expr.iter(), task_ctx.as_ref(), codec)?
}
};
let filters_borrowed: Vec<&Expr> = filters.iter().collect();
let results: RVec<_> = provider
.supports_filters_pushdown(&filters_borrowed)?
.iter()
.map(|v| v.into())
.collect();
Ok(results)
}
unsafe extern "C" fn supports_filters_pushdown_fn_wrapper(
provider: &FFI_TableProvider,
filters_serialized: RVec<u8>,
) -> FFIResult<RVec<FFI_TableProviderFilterPushDown>> {
let logical_codec: Arc<dyn LogicalExtensionCodec> = (&provider.logical_codec).into();
let task_ctx = rresult_return!(<Arc<TaskContext>>::try_from(
&provider.logical_codec.task_ctx_provider
));
supports_filters_pushdown_internal(
provider.inner(),
&filters_serialized,
&task_ctx,
logical_codec.as_ref(),
)
.map_err(|e| e.to_string().into())
.into()
}
unsafe extern "C" fn scan_fn_wrapper(
provider: &FFI_TableProvider,
session: FFI_SessionRef,
projections: RVec<usize>,
filters_serialized: RVec<u8>,
limit: ROption<usize>,
) -> FfiFuture<FFIResult<FFI_ExecutionPlan>> {
let task_ctx: Result<Arc<TaskContext>, DataFusionError> =
(&provider.logical_codec.task_ctx_provider).try_into();
let runtime = provider.runtime().clone();
let logical_codec: Arc<dyn LogicalExtensionCodec> = (&provider.logical_codec).into();
let internal_provider = Arc::clone(provider.inner());
async move {
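        // Use the session directly when it originated in this library;
        // otherwise wrap it in a ForeignSession that proxies calls across
        // the boundary. `foreign_session` keeps that wrapper alive for the
        // borrow returned from the closure.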
let mut foreign_session = None;
let session = rresult_return!(
session
.as_local()
.map(Ok::<&(dyn Session + Send + Sync), DataFusionError>)
.unwrap_or_else(|| {
foreign_session = Some(ForeignSession::try_from(&session)?);
Ok(foreign_session.as_ref().unwrap())
})
);
let task_ctx = rresult_return!(task_ctx);
let filters = match filters_serialized.is_empty() {
true => vec![],
false => {
let proto_filters =
rresult_return!(LogicalExprList::decode(filters_serialized.as_ref()));
rresult_return!(parse_exprs(
proto_filters.expr.iter(),
task_ctx.as_ref(),
logical_codec.as_ref(),
))
}
};
        // The FFI signature cannot express an optional projection, so an
        // empty list encodes `None` and is mapped back here rather than
        // being passed through as a zero-column projection.
        let projections: Vec<_> = projections.into_iter().collect();
        let maybe_projections = match projections.is_empty() {
            true => None,
            false => Some(&projections),
        };
        let plan = rresult_return!(
            internal_provider
                .scan(session, maybe_projections, &filters, limit.into())
                .await
        );
        RResult::ROk(FFI_ExecutionPlan::new(plan, runtime))
}
.into_ffi()
}
unsafe extern "C" fn insert_into_fn_wrapper(
provider: &FFI_TableProvider,
session: FFI_SessionRef,
input: &FFI_ExecutionPlan,
insert_op: FFI_InsertOp,
) -> FfiFuture<FFIResult<FFI_ExecutionPlan>> {
let runtime = provider.runtime().clone();
let internal_provider = Arc::clone(provider.inner());
let input = input.clone();
async move {
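        // Resolve the session exactly as in scan_fn_wrapper above.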
let mut foreign_session = None;
let session = rresult_return!(
session
.as_local()
.map(Ok::<&(dyn Session + Send + Sync), DataFusionError>)
.unwrap_or_else(|| {
foreign_session = Some(ForeignSession::try_from(&session)?);
Ok(foreign_session.as_ref().unwrap())
})
);
let input = rresult_return!(<Arc<dyn ExecutionPlan>>::try_from(&input));
let insert_op = InsertOp::from(insert_op);
let plan = rresult_return!(
internal_provider
.insert_into(session, input, insert_op)
.await
);
        RResult::ROk(FFI_ExecutionPlan::new(plan, runtime))
}
.into_ffi()
}
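// Reclaims and drops the boxed private data. The pointer is nulled so a
// double release is caught by the debug assertion.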
unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) {
unsafe {
debug_assert!(!provider.private_data.is_null());
let private_data =
Box::from_raw(provider.private_data as *mut ProviderPrivateData);
drop(private_data);
provider.private_data = std::ptr::null_mut();
}
}
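// Creates an independent copy sharing the same underlying provider `Arc`.
// This always runs in the producing library, so the function pointers are
// rebound to this crate's wrappers.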
unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_TableProvider {
let runtime = provider.runtime().clone();
let old_provider = Arc::clone(provider.inner());
let private_data = Box::into_raw(Box::new(ProviderPrivateData {
provider: old_provider,
runtime,
})) as *mut c_void;
FFI_TableProvider {
schema: schema_fn_wrapper,
scan: scan_fn_wrapper,
table_type: table_type_fn_wrapper,
supports_filters_pushdown: provider.supports_filters_pushdown,
insert_into: provider.insert_into,
logical_codec: provider.logical_codec.clone(),
clone: clone_fn_wrapper,
release: release_fn_wrapper,
version: super::version,
private_data,
library_marker_id: crate::get_library_marker_id,
}
}
impl Drop for FFI_TableProvider {
fn drop(&mut self) {
unsafe { (self.release)(self) }
}
}
impl FFI_TableProvider {
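    /// Creates a new [`FFI_TableProvider`].
    ///
    /// `can_support_pushdown_filters` controls whether the optional
    /// `supports_filters_pushdown` entry point is exposed, `runtime` is an
    /// optional tokio [`Handle`] on which async work is executed, and
    /// `logical_codec` falls back to [`DefaultLogicalExtensionCodec`] when
    /// `None`.
    ///
    /// A minimal round-trip sketch, assuming `provider` is some
    /// `Arc<dyn TableProvider + Send>` and `ctx` is an `Arc<SessionContext>`
    /// (see the tests in this module for a complete, runnable version):
    ///
    /// ```ignore
    /// let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
    /// let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider);
    /// let ffi_provider = FFI_TableProvider::new(provider, true, None, task_ctx_provider, None);
    /// // Convert back into a TableProvider usable by DataFusion.
    /// let foreign: Arc<dyn TableProvider> = (&ffi_provider).into();
    /// ```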
pub fn new(
provider: Arc<dyn TableProvider + Send>,
can_support_pushdown_filters: bool,
runtime: Option<Handle>,
task_ctx_provider: impl Into<FFI_TaskContextProvider>,
logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
) -> Self {
let task_ctx_provider = task_ctx_provider.into();
let logical_codec =
logical_codec.unwrap_or_else(|| Arc::new(DefaultLogicalExtensionCodec {}));
let logical_codec = FFI_LogicalExtensionCodec::new(
logical_codec,
runtime.clone(),
task_ctx_provider.clone(),
);
Self::new_with_ffi_codec(
provider,
can_support_pushdown_filters,
runtime,
logical_codec,
)
}
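    /// Like [`FFI_TableProvider::new`], but accepts a pre-built
    /// [`FFI_LogicalExtensionCodec`]. If `provider` is itself a
    /// [`ForeignTableProvider`], the inner struct is cloned directly to
    /// avoid stacking FFI layers.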
pub fn new_with_ffi_codec(
provider: Arc<dyn TableProvider + Send>,
can_support_pushdown_filters: bool,
runtime: Option<Handle>,
logical_codec: FFI_LogicalExtensionCodec,
) -> Self {
if let Some(provider) = provider.as_any().downcast_ref::<ForeignTableProvider>() {
return provider.0.clone();
}
let private_data = Box::new(ProviderPrivateData { provider, runtime });
Self {
schema: schema_fn_wrapper,
scan: scan_fn_wrapper,
table_type: table_type_fn_wrapper,
supports_filters_pushdown: match can_support_pushdown_filters {
true => Some(supports_filters_pushdown_fn_wrapper),
false => None,
},
insert_into: insert_into_fn_wrapper,
logical_codec,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
version: super::version,
private_data: Box::into_raw(private_data) as *mut c_void,
library_marker_id: crate::get_library_marker_id,
}
}
}
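/// Consumer-side [`TableProvider`] implemented on top of an
/// [`FFI_TableProvider`] produced by another library. This is the type
/// DataFusion interacts with after the struct crosses the boundary.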
#[derive(Debug)]
pub struct ForeignTableProvider(pub FFI_TableProvider);
unsafe impl Send for ForeignTableProvider {}
unsafe impl Sync for ForeignTableProvider {}
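/// When the struct was produced by this same library (matching
/// `library_marker_id`), unwrap to the original provider and bypass FFI
/// entirely; otherwise wrap it in a [`ForeignTableProvider`].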
impl From<&FFI_TableProvider> for Arc<dyn TableProvider> {
fn from(provider: &FFI_TableProvider) -> Self {
if (provider.library_marker_id)() == crate::get_library_marker_id() {
Arc::clone(provider.inner()) as Arc<dyn TableProvider>
} else {
Arc::new(ForeignTableProvider(provider.clone()))
}
}
}
impl Clone for FFI_TableProvider {
fn clone(&self) -> Self {
unsafe { (self.clone)(self) }
}
}
#[async_trait]
impl TableProvider for ForeignTableProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
let wrapped_schema = unsafe { (self.0.schema)(&self.0) };
wrapped_schema.into()
}
fn table_type(&self) -> TableType {
unsafe { (self.0.table_type)(&self.0).into() }
}
async fn scan(
&self,
session: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let session = FFI_SessionRef::new(session, None, self.0.logical_codec.clone());
let projections: Option<RVec<usize>> =
projection.map(|p| p.iter().map(|v| v.to_owned()).collect());
let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
let filter_list = LogicalExprList {
expr: serialize_exprs(filters, codec.as_ref())?,
};
let filters_serialized = filter_list.encode_to_vec().into();
let plan = unsafe {
let maybe_plan = (self.0.scan)(
&self.0,
session,
projections.unwrap_or_default(),
filters_serialized,
limit.into(),
)
.await;
<Arc<dyn ExecutionPlan>>::try_from(&df_result!(maybe_plan)?)?
};
Ok(plan)
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> Result<Vec<TableProviderFilterPushDown>> {
unsafe {
let pushdown_fn = match self.0.supports_filters_pushdown {
Some(func) => func,
None => {
return Ok(vec![
TableProviderFilterPushDown::Unsupported;
filters.len()
]);
}
};
let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
let expr_list = LogicalExprList {
expr: serialize_exprs(
filters.iter().map(|f| f.to_owned()),
codec.as_ref(),
)?,
};
let serialized_filters = expr_list.encode_to_vec();
let pushdowns = df_result!(pushdown_fn(&self.0, serialized_filters.into()))?;
Ok(pushdowns.iter().map(|v| v.into()).collect())
}
}
async fn insert_into(
&self,
session: &dyn Session,
input: Arc<dyn ExecutionPlan>,
insert_op: InsertOp,
) -> Result<Arc<dyn ExecutionPlan>> {
let session = FFI_SessionRef::new(session, None, self.0.logical_codec.clone());
        // Capture the current tokio runtime, if one is active, for the plan
        // wrapper to execute on.
        let runtime = Handle::try_current().ok();
        let input = FFI_ExecutionPlan::new(input, runtime);
let insert_op: FFI_InsertOp = insert_op.into();
let plan = unsafe {
let maybe_plan =
(self.0.insert_into)(&self.0, session, &input, insert_op).await;
<Arc<dyn ExecutionPlan>>::try_from(&df_result!(maybe_plan)?)?
};
Ok(plan)
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Schema;
use datafusion::prelude::{SessionContext, col, lit};
use datafusion_execution::TaskContextProvider;
use super::*;
fn create_test_table_provider() -> Result<Arc<dyn TableProvider>> {
use arrow::datatypes::Field;
use datafusion::arrow::array::Float32Array;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
let batch1 = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
)?;
let batch2 = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Float32Array::from(vec![64.0]))],
)?;
Ok(Arc::new(MemTable::try_new(
schema,
vec![vec![batch1], vec![batch2]],
)?))
}
#[tokio::test]
async fn test_round_trip_ffi_table_provider_scan() -> Result<()> {
let provider = create_test_table_provider()?;
let ctx = Arc::new(SessionContext::new());
let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider);
let mut ffi_provider =
FFI_TableProvider::new(provider, true, None, task_ctx_provider, None);
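        // Override the marker so the provider is treated as foreign and the
        // full FFI path is exercised instead of the local bypass.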
ffi_provider.library_marker_id = crate::mock_foreign_marker_id;
let foreign_table_provider: Arc<dyn TableProvider> = (&ffi_provider).into();
ctx.register_table("t", foreign_table_provider)?;
let df = ctx.table("t").await?;
df.select(vec![col("a")])?
.filter(col("a").gt(lit(3.0)))?
.show()
.await?;
Ok(())
}
#[tokio::test]
async fn test_round_trip_ffi_table_provider_insert_into() -> Result<()> {
let provider = create_test_table_provider()?;
let ctx = Arc::new(SessionContext::new());
let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider);
let mut ffi_provider =
FFI_TableProvider::new(provider, true, None, task_ctx_provider, None);
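        // Force the foreign code path, as in the scan round-trip test.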
ffi_provider.library_marker_id = crate::mock_foreign_marker_id;
let foreign_table_provider: Arc<dyn TableProvider> = (&ffi_provider).into();
ctx.register_table("t", foreign_table_provider)?;
let result = ctx
.sql("INSERT INTO t VALUES (128.0);")
.await?
.collect()
.await?;
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
ctx.table("t")
.await?
.select(vec![col("a")])?
.filter(col("a").gt(lit(3.0)))?
.show()
.await?;
Ok(())
}
#[tokio::test]
async fn test_aggregation() -> Result<()> {
use arrow::datatypes::Field;
use datafusion::arrow::array::Float32Array;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::assert_batches_eq;
use datafusion::datasource::MemTable;
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
let batch1 = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
)?;
let ctx = Arc::new(SessionContext::new());
let task_ctx_provider = Arc::clone(&ctx) as Arc<dyn TaskContextProvider>;
let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider);
let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1]])?);
let ffi_provider =
FFI_TableProvider::new(provider, true, None, task_ctx_provider, None);
let foreign_table_provider: Arc<dyn TableProvider> = (&ffi_provider).into();
ctx.register_table("t", foreign_table_provider)?;
let result = ctx
.sql("SELECT COUNT(*) as cnt FROM t")
.await?
.collect()
.await?;
#[rustfmt::skip]
let expected = [
"+-----+",
"| cnt |",
"+-----+",
"| 3 |",
"+-----+"
];
assert_batches_eq!(expected, &result);
Ok(())
}
#[test]
fn test_ffi_table_provider_local_bypass() -> Result<()> {
let table_provider = create_test_table_provider()?;
let ctx = Arc::new(SessionContext::new()) as Arc<dyn TaskContextProvider>;
let task_ctx_provider = FFI_TaskContextProvider::from(&ctx);
let mut ffi_table =
FFI_TableProvider::new(table_provider, false, None, task_ctx_provider, None);
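        // With a matching library marker, the conversion unwraps straight
        // back to the original MemTable.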
let foreign_table: Arc<dyn TableProvider> = (&ffi_table).into();
assert!(
foreign_table
.as_any()
.downcast_ref::<datafusion::datasource::MemTable>()
.is_some()
);
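        // With a foreign marker, the conversion instead wraps the struct in
        // a ForeignTableProvider.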
ffi_table.library_marker_id = crate::mock_foreign_marker_id;
let foreign_table: Arc<dyn TableProvider> = (&ffi_table).into();
assert!(
foreign_table
.as_any()
.downcast_ref::<ForeignTableProvider>()
.is_some()
);
Ok(())
}
}