use std::ffi::c_void;
use std::task::Poll;
use abi_stable::StableAbi;
use abi_stable::std_types::{ROption, RResult};
use arrow::array::{Array, RecordBatch, StructArray, make_array};
use arrow::ffi::{from_ffi, to_ffi};
use async_ffi::{ContextExt, FfiContext, FfiPoll};
use datafusion_common::{DataFusionError, Result, ffi_datafusion_err, ffi_err};
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
use futures::{Stream, TryStreamExt};
use tokio::runtime::Handle;
use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
use crate::rresult;
use crate::util::FFIResult;
/// A `#[repr(C)]`, `StableAbi` handle to a record batch stream that can be
/// passed across a library boundary.
///
/// The struct is a table of `extern "C"` function pointers plus an opaque
/// `private_data` pointer that only those functions interpret. On the Rust
/// side it implements `Stream` and `RecordBatchStream` by calling through the
/// function pointers, and `Drop` by calling `release`.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_RecordBatchStream {
    /// Polls the underlying stream for its next item.
    ///
    /// `RNone` signals end of stream. Each batch is exported as a
    /// `WrappedArray` (a struct array plus its schema). Failures travel as the
    /// error variant of `FFIResult`.
    pub poll_next: unsafe extern "C" fn(
        stream: &Self,
        cx: &mut FfiContext,
    ) -> FfiPoll<ROption<FFIResult<WrappedArray>>>,

    /// Returns the schema of the batches produced by this stream.
    pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,

    /// Frees `private_data`; invoked from this type's `Drop` impl.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Opaque state interpreted only by the function pointers above. When
    /// built via `FFI_RecordBatchStream::new`, this is a boxed
    /// `RecordBatchStreamPrivateData`.
    pub private_data: *mut c_void,
}
/// Rust-side state stored behind `FFI_RecordBatchStream::private_data`.
pub struct RecordBatchStreamPrivateData {
    /// The stream being exported across the FFI boundary.
    pub rbs: SendableRecordBatchStream,
    /// Optional tokio runtime handle, entered for the duration of every
    /// `poll_next` call on the exported stream.
    pub runtime: Option<Handle>,
}
impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
fn from(stream: SendableRecordBatchStream) -> Self {
Self::new(stream, None)
}
}
impl FFI_RecordBatchStream {
pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
rbs: stream,
runtime,
})) as *mut c_void;
FFI_RecordBatchStream {
poll_next: poll_next_fn_wrapper,
schema: schema_fn_wrapper,
release: release_fn_wrapper,
private_data,
}
}
}
// SAFETY: the only state reachable through the raw `private_data` pointer, when
// built via `FFI_RecordBatchStream::new`, is a boxed
// `RecordBatchStreamPrivateData`. Its fields are a `SendableRecordBatchStream`
// and an `Option<tokio::runtime::Handle>`, both assumed to be `Send`.
// NOTE(review): an instance constructed by foreign code must uphold the same
// guarantee for its own `private_data`.
unsafe impl Send for FFI_RecordBatchStream {}
unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
unsafe {
let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
let stream = &(*private_data).rbs;
(*stream).schema().into()
}
}
/// `release` entry point: drops the boxed `RecordBatchStreamPrivateData` and
/// nulls out `private_data` so the freed pointer is not left dangling.
unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) {
    debug_assert!(!provider.private_data.is_null());

    // SAFETY: `private_data` was created with `Box::into_raw` in
    // `FFI_RecordBatchStream::new` and has not been released yet (checked above
    // in debug builds), so reclaiming the box here is the unique deallocation.
    let state = unsafe {
        Box::from_raw(provider.private_data.cast::<RecordBatchStreamPrivateData>())
    };
    drop(state);

    provider.private_data = std::ptr::null_mut();
}
pub(crate) fn record_batch_to_wrapped_array(
record_batch: RecordBatch,
) -> FFIResult<WrappedArray> {
let schema = WrappedSchema::from(record_batch.schema());
let struct_array = StructArray::from(record_batch);
rresult!(
to_ffi(&struct_array.to_data())
.map(|(array, _schema)| WrappedArray { array, schema })
)
}
/// Converts one item of a Rust batch stream into its FFI-safe form.
///
/// Stream exhaustion maps to `RNone`. Otherwise the batch is exported, or the
/// `DataFusionError` is flattened to its string form.
fn maybe_record_batch_to_wrapped_stream(
    record_batch: Option<Result<RecordBatch>>,
) -> ROption<FFIResult<WrappedArray>> {
    let Some(item) = record_batch else {
        return ROption::RNone;
    };

    let converted = match item {
        Ok(batch) => record_batch_to_wrapped_array(batch),
        Err(err) => RResult::RErr(err.to_string().into()),
    };
    ROption::RSome(converted)
}
unsafe extern "C" fn poll_next_fn_wrapper(
stream: &FFI_RecordBatchStream,
cx: &mut FfiContext,
) -> FfiPoll<ROption<FFIResult<WrappedArray>>> {
unsafe {
let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
let stream = &mut (*private_data).rbs;
let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
let poll_result = cx.with_context(|std_cx| {
(*stream)
.try_poll_next_unpin(std_cx)
.map(maybe_record_batch_to_wrapped_stream)
});
poll_result.into()
}
}
impl RecordBatchStream for FFI_RecordBatchStream {
fn schema(&self) -> arrow::datatypes::SchemaRef {
let wrapped_schema = unsafe { (self.schema)(self) };
wrapped_schema.into()
}
}
pub(crate) fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
let array_data =
unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
let schema: arrow::datatypes::SchemaRef = array.schema.into();
let array = make_array(array_data);
let struct_array = array
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| ffi_datafusion_err!(
"Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray"
))?;
let rb: RecordBatch = struct_array.into();
rb.with_schema(schema).map_err(Into::into)
}
/// Converts one FFI stream item back into a Rust stream item.
///
/// `RNone` means the stream is exhausted. An error string from the producer is
/// re-raised as an FFI `DataFusionError`.
fn maybe_wrapped_array_to_record_batch(
    array: ROption<FFIResult<WrappedArray>>,
) -> Option<Result<RecordBatch>> {
    let ROption::RSome(result) = array else {
        return None;
    };

    Some(match result {
        RResult::ROk(wrapped_array) => wrapped_array_to_record_batch(wrapped_array),
        RResult::RErr(e) => ffi_err!("{e}"),
    })
}
impl Stream for FFI_RecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Polls the foreign stream through the `poll_next` function pointer.
    ///
    /// The waker from `cx` is handed across as an `FfiContext`. A panic
    /// reported by the producer surfaces as an error item rather than ending
    /// the stream.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this: &Self = &self;
        let poll_fn = this.poll_next;

        // SAFETY: `poll_fn` belongs to `this` and is called with it.
        let outcome = unsafe { cx.with_ffi_context(|ffi_cx| poll_fn(this, ffi_cx)) };

        match outcome {
            FfiPoll::Pending => Poll::Pending,
            FfiPoll::Ready(array) => Poll::Ready(maybe_wrapped_array_to_record_batch(array)),
            FfiPoll::Panicked => Poll::Ready(Some(ffi_err!(
                "Panic occurred during poll_next on FFI_RecordBatchStream"
            ))),
        }
    }
}
impl Drop for FFI_RecordBatchStream {
    /// Hands the private state back to whichever side created it, via `release`.
    fn drop(&mut self) {
        let release = self.release;
        // SAFETY: `release` belongs to this struct; `drop` runs at most once.
        unsafe { release(self) };
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::record_batch;
    use datafusion::error::Result;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::test_util::bounded_stream;
    use futures::StreamExt;

    use super::{
        FFI_RecordBatchStream, record_batch_to_wrapped_array,
        wrapped_array_to_record_batch,
    };
    use crate::df_result;

    /// A single batch pushed through `FFI_RecordBatchStream` must come back
    /// unchanged with the expected schema, followed by end-of-stream.
    #[tokio::test]
    async fn test_round_trip_record_batch_stream() -> Result<()> {
        let expected = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;

        let exported = FFI_RecordBatchStream::from(bounded_stream(expected.clone(), 1));
        let mut stream: SendableRecordBatchStream = Box::pin(exported);

        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
        ]));
        assert_eq!(stream.schema(), expected_schema);

        let first = stream.next().await;
        assert!(first.is_some());
        let first = first.unwrap();
        assert!(first.is_ok());
        assert_eq!(first.unwrap(), expected);

        let after_end = stream.next().await;
        assert!(after_end.is_none());
        Ok(())
    }

    /// Schema-level metadata must survive export to and import from the C data
    /// interface.
    #[test]
    fn round_trip_record_batch_with_metadata() -> Result<()> {
        let plain = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;
        let tagged_schema = plain
            .schema()
            .as_ref()
            .clone()
            .with_metadata([("some_key".to_owned(), "some_value".to_owned())].into())
            .into();
        let original = plain.with_schema(tagged_schema)?;

        let exported = df_result!(record_batch_to_wrapped_array(original.clone()))?;
        let imported = wrapped_array_to_record_batch(exported)?;

        assert_eq!(original, imported);
        Ok(())
    }
}