datafusion_ffi/
record_batch_stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{ffi::c_void, task::Poll};
19
20use abi_stable::{
21    std_types::{ROption, RResult, RString},
22    StableAbi,
23};
24use arrow::array::{Array, RecordBatch};
25use arrow::{
26    array::{make_array, StructArray},
27    ffi::{from_ffi, to_ffi},
28};
29use async_ffi::{ContextExt, FfiContext, FfiPoll};
30use datafusion::error::Result;
31use datafusion::{
32    error::DataFusionError,
33    execution::{RecordBatchStream, SendableRecordBatchStream},
34};
35use futures::{Stream, TryStreamExt};
36use tokio::runtime::Handle;
37
38use crate::{
39    arrow_wrappers::{WrappedArray, WrappedSchema},
40    rresult,
41};
42
/// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries.
/// We use the async-ffi crate for handling async calls across libraries.
///
/// The struct is a table of `extern "C"` function pointers plus an opaque
/// pointer to provider-owned state. Each callback receives the struct itself
/// so that it can reach that state through `private_data`.
///
/// NOTE(review): no release/drop callback is visible in this definition, so it
/// is not clear from here how the allocation behind `private_data` is freed
/// when the stream is dropped — confirm.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_RecordBatchStream {
    /// This mirrors the `poll_next` of [`RecordBatchStream`] but does so
    /// in a FFI safe manner.
    ///
    /// A ready result of `RNone` means the underlying stream yielded `None`;
    /// `RSome(ROk(..))` carries one record batch exported as a struct array,
    /// and `RSome(RErr(..))` carries the error rendered as a string.
    pub poll_next:
        unsafe extern "C" fn(
            stream: &Self,
            cx: &mut FfiContext,
        ) -> FfiPoll<ROption<RResult<WrappedArray, RString>>>,

    /// Return the schema of the record batch
    pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,

    /// Internal data. This is only to be accessed by the provider of the plan.
    /// The foreign library should never attempt to access this data.
    pub private_data: *mut c_void,
}
64
/// Provider-side state stored behind the `private_data` pointer of
/// [`FFI_RecordBatchStream`].
pub struct RecordBatchStreamPrivateData {
    /// The wrapped stream that the FFI callbacks poll and query for its schema.
    pub rbs: SendableRecordBatchStream,
    /// Optional Tokio runtime handle. When present, the poll callback enters
    /// it before polling `rbs`.
    pub runtime: Option<Handle>,
}
69
70impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
71    fn from(stream: SendableRecordBatchStream) -> Self {
72        Self::new(stream, None)
73    }
74}
75
76impl FFI_RecordBatchStream {
77    pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
78        let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
79            rbs: stream,
80            runtime,
81        })) as *mut c_void;
82        FFI_RecordBatchStream {
83            poll_next: poll_next_fn_wrapper,
84            schema: schema_fn_wrapper,
85            private_data,
86        }
87    }
88}
89
// SAFETY: the only field that is not automatically `Send` is the raw
// `private_data` pointer. When built through `FFI_RecordBatchStream::new` it
// points to a `RecordBatchStreamPrivateData` holding a
// `SendableRecordBatchStream` and an `Option<Handle>`.
// NOTE(review): presumably instances received from a foreign library uphold
// the same guarantee — confirm.
unsafe impl Send for FFI_RecordBatchStream {}
91
92unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
93    let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
94    let stream = &(*private_data).rbs;
95
96    (*stream).schema().into()
97}
98
99fn record_batch_to_wrapped_array(
100    record_batch: RecordBatch,
101) -> RResult<WrappedArray, RString> {
102    let struct_array = StructArray::from(record_batch);
103    rresult!(
104        to_ffi(&struct_array.to_data()).map(|(array, schema)| WrappedArray {
105            array,
106            schema: WrappedSchema(schema)
107        })
108    )
109}
110
111// probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result<ArrayData> {
112fn maybe_record_batch_to_wrapped_stream(
113    record_batch: Option<Result<RecordBatch>>,
114) -> ROption<RResult<WrappedArray, RString>> {
115    match record_batch {
116        Some(Ok(record_batch)) => {
117            ROption::RSome(record_batch_to_wrapped_array(record_batch))
118        }
119        Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())),
120        None => ROption::RNone,
121    }
122}
123
124unsafe extern "C" fn poll_next_fn_wrapper(
125    stream: &FFI_RecordBatchStream,
126    cx: &mut FfiContext,
127) -> FfiPoll<ROption<RResult<WrappedArray, RString>>> {
128    let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
129    let stream = &mut (*private_data).rbs;
130
131    let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
132
133    let poll_result = cx.with_context(|std_cx| {
134        (*stream)
135            .try_poll_next_unpin(std_cx)
136            .map(maybe_record_batch_to_wrapped_stream)
137    });
138
139    poll_result.into()
140}
141
142impl RecordBatchStream for FFI_RecordBatchStream {
143    fn schema(&self) -> arrow::datatypes::SchemaRef {
144        let wrapped_schema = unsafe { (self.schema)(self) };
145        wrapped_schema.into()
146    }
147}
148
149fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
150    let array_data =
151        unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
152    let array = make_array(array_data);
153    let struct_array = array
154        .as_any()
155        .downcast_ref::<StructArray>()
156        .ok_or(DataFusionError::Execution(
157        "Unexpected array type during record batch collection in FFI_RecordBatchStream"
158            .to_string(),
159    ))?;
160
161    Ok(struct_array.into())
162}
163
164fn maybe_wrapped_array_to_record_batch(
165    array: ROption<RResult<WrappedArray, RString>>,
166) -> Option<Result<RecordBatch>> {
167    match array {
168        ROption::RSome(RResult::ROk(wrapped_array)) => {
169            Some(wrapped_array_to_record_batch(wrapped_array))
170        }
171        ROption::RSome(RResult::RErr(e)) => {
172            Some(Err(DataFusionError::Execution(e.to_string())))
173        }
174        ROption::RNone => None,
175    }
176}
177
178impl Stream for FFI_RecordBatchStream {
179    type Item = Result<RecordBatch>;
180
181    fn poll_next(
182        self: std::pin::Pin<&mut Self>,
183        cx: &mut std::task::Context<'_>,
184    ) -> Poll<Option<Self::Item>> {
185        let poll_result =
186            unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
187
188        match poll_result {
189            FfiPoll::Ready(array) => {
190                Poll::Ready(maybe_wrapped_array_to_record_batch(array))
191            }
192            FfiPoll::Pending => Poll::Pending,
193            FfiPoll::Panicked => Poll::Ready(Some(Err(DataFusionError::Execution(
194                "Error occurred during poll_next on FFI_RecordBatchStream".to_string(),
195            )))),
196        }
197    }
198}