datafusion_ffi/
record_batch_stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{ffi::c_void, task::Poll};
19
20use abi_stable::{
21    std_types::{ROption, RResult, RString},
22    StableAbi,
23};
24use arrow::array::{Array, RecordBatch};
25use arrow::{
26    array::{make_array, StructArray},
27    ffi::{from_ffi, to_ffi},
28};
29use async_ffi::{ContextExt, FfiContext, FfiPoll};
30use datafusion::error::Result;
31use datafusion::{
32    error::DataFusionError,
33    execution::{RecordBatchStream, SendableRecordBatchStream},
34};
35use datafusion_common::{exec_datafusion_err, exec_err};
36use futures::{Stream, TryStreamExt};
37use tokio::runtime::Handle;
38
39use crate::{
40    arrow_wrappers::{WrappedArray, WrappedSchema},
41    rresult,
42};
43
44/// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries.
45/// We use the async-ffi crate for handling async calls across libraries.
46#[repr(C)]
47#[derive(Debug, StableAbi)]
48#[allow(non_camel_case_types)]
49pub struct FFI_RecordBatchStream {
50    /// This mirrors the `poll_next` of [`RecordBatchStream`] but does so
51    /// in a FFI safe manner.
52    pub poll_next:
53        unsafe extern "C" fn(
54            stream: &Self,
55            cx: &mut FfiContext,
56        ) -> FfiPoll<ROption<RResult<WrappedArray, RString>>>,
57
58    /// Return the schema of the record batch
59    pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,
60
61    /// Release the memory of the private data when it is no longer being used.
62    pub release: unsafe extern "C" fn(arg: &mut Self),
63
64    /// Internal data. This is only to be accessed by the provider of the plan.
65    /// The foreign library should never attempt to access this data.
66    pub private_data: *mut c_void,
67}
68
69pub struct RecordBatchStreamPrivateData {
70    pub rbs: SendableRecordBatchStream,
71    pub runtime: Option<Handle>,
72}
73
74impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
75    fn from(stream: SendableRecordBatchStream) -> Self {
76        Self::new(stream, None)
77    }
78}
79
80impl FFI_RecordBatchStream {
81    pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
82        let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
83            rbs: stream,
84            runtime,
85        })) as *mut c_void;
86        FFI_RecordBatchStream {
87            poll_next: poll_next_fn_wrapper,
88            schema: schema_fn_wrapper,
89            release: release_fn_wrapper,
90            private_data,
91        }
92    }
93}
94
95unsafe impl Send for FFI_RecordBatchStream {}
96
97unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
98    let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
99    let stream = &(*private_data).rbs;
100
101    (*stream).schema().into()
102}
103
104unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) {
105    let private_data =
106        Box::from_raw(provider.private_data as *mut RecordBatchStreamPrivateData);
107    drop(private_data);
108}
109
110fn record_batch_to_wrapped_array(
111    record_batch: RecordBatch,
112) -> RResult<WrappedArray, RString> {
113    let struct_array = StructArray::from(record_batch);
114    rresult!(
115        to_ffi(&struct_array.to_data()).map(|(array, schema)| WrappedArray {
116            array,
117            schema: WrappedSchema(schema)
118        })
119    )
120}
121
122// probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result<ArrayData> {
123fn maybe_record_batch_to_wrapped_stream(
124    record_batch: Option<Result<RecordBatch>>,
125) -> ROption<RResult<WrappedArray, RString>> {
126    match record_batch {
127        Some(Ok(record_batch)) => {
128            ROption::RSome(record_batch_to_wrapped_array(record_batch))
129        }
130        Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())),
131        None => ROption::RNone,
132    }
133}
134
135unsafe extern "C" fn poll_next_fn_wrapper(
136    stream: &FFI_RecordBatchStream,
137    cx: &mut FfiContext,
138) -> FfiPoll<ROption<RResult<WrappedArray, RString>>> {
139    let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
140    let stream = &mut (*private_data).rbs;
141
142    let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
143
144    let poll_result = cx.with_context(|std_cx| {
145        (*stream)
146            .try_poll_next_unpin(std_cx)
147            .map(maybe_record_batch_to_wrapped_stream)
148    });
149
150    poll_result.into()
151}
152
153impl RecordBatchStream for FFI_RecordBatchStream {
154    fn schema(&self) -> arrow::datatypes::SchemaRef {
155        let wrapped_schema = unsafe { (self.schema)(self) };
156        wrapped_schema.into()
157    }
158}
159
160fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
161    let array_data =
162        unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
163    let array = make_array(array_data);
164    let struct_array = array
165        .as_any()
166        .downcast_ref::<StructArray>()
167        .ok_or_else(|| exec_datafusion_err!(
168        "Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray"
169    ))?;
170
171    Ok(struct_array.into())
172}
173
174fn maybe_wrapped_array_to_record_batch(
175    array: ROption<RResult<WrappedArray, RString>>,
176) -> Option<Result<RecordBatch>> {
177    match array {
178        ROption::RSome(RResult::ROk(wrapped_array)) => {
179            Some(wrapped_array_to_record_batch(wrapped_array))
180        }
181        ROption::RSome(RResult::RErr(e)) => Some(exec_err!("FFI error: {e}")),
182        ROption::RNone => None,
183    }
184}
185
186impl Stream for FFI_RecordBatchStream {
187    type Item = Result<RecordBatch>;
188
189    fn poll_next(
190        self: std::pin::Pin<&mut Self>,
191        cx: &mut std::task::Context<'_>,
192    ) -> Poll<Option<Self::Item>> {
193        let poll_result =
194            unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
195
196        match poll_result {
197            FfiPoll::Ready(array) => {
198                Poll::Ready(maybe_wrapped_array_to_record_batch(array))
199            }
200            FfiPoll::Pending => Poll::Pending,
201            FfiPoll::Panicked => Poll::Ready(Some(exec_err!(
202                "Panic occurred during poll_next on FFI_RecordBatchStream"
203            ))),
204        }
205    }
206}
207
208impl Drop for FFI_RecordBatchStream {
209    fn drop(&mut self) {
210        unsafe { (self.release)(self) }
211    }
212}
213
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::{
        common::record_batch, error::Result, execution::SendableRecordBatchStream,
        test_util::bounded_stream,
    };

    use super::FFI_RecordBatchStream;
    use futures::StreamExt;

    /// Wraps a single-batch stream in `FFI_RecordBatchStream`, then checks
    /// that the schema and the batch come back unchanged and that the stream
    /// ends afterwards.
    #[tokio::test]
    async fn test_round_trip_record_batch_stream() -> Result<()> {
        let expected_batch = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;
        let source_stream = bounded_stream(expected_batch.clone(), 1);

        // Go through the FFI wrapper and use it as an ordinary stream again.
        let wrapped = FFI_RecordBatchStream::from(source_stream);
        let mut round_tripped: SendableRecordBatchStream = Box::pin(wrapped);

        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
        ]));
        assert_eq!(round_tripped.schema(), expected_schema);

        let first = round_tripped.next().await;
        assert!(first.is_some());
        assert!(first.as_ref().unwrap().is_ok());
        assert_eq!(first.unwrap().unwrap(), expected_batch);

        // There should only be one batch
        let second = round_tripped.next().await;
        assert!(second.is_none());

        Ok(())
    }
}