datafusion_ffi/
record_batch_stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::task::Poll;
20
21use abi_stable::StableAbi;
22use abi_stable::std_types::{ROption, RResult};
23use arrow::array::{Array, RecordBatch, StructArray, make_array};
24use arrow::ffi::{from_ffi, to_ffi};
25use async_ffi::{ContextExt, FfiContext, FfiPoll};
26use datafusion_common::{DataFusionError, Result, ffi_datafusion_err, ffi_err};
27use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
28use futures::{Stream, TryStreamExt};
29use tokio::runtime::Handle;
30
31use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
32use crate::rresult;
33use crate::util::FFIResult;
34
/// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries.
/// We use the async-ffi crate for handling async calls across libraries.
///
/// The struct is a table of `extern "C"` callbacks plus an opaque pointer to
/// provider-owned state. [`FFI_RecordBatchStream::new`] fills the table with
/// this module's wrapper functions; the `Stream`, `RecordBatchStream` and
/// `Drop` implementations on this type call through the table instead of
/// touching `private_data` directly.
#[repr(C)]
#[derive(Debug, StableAbi)]
pub struct FFI_RecordBatchStream {
    /// This mirrors the `poll_next` of [`RecordBatchStream`] but does so
    /// in a FFI safe manner.
    ///
    /// With this module's provider implementation, `RNone` marks the end of
    /// the stream, `RSome(ROk(..))` carries one record batch exported as a
    /// struct array, and `RSome(RErr(..))` carries an error rendered as a
    /// string.
    pub poll_next: unsafe extern "C" fn(
        stream: &Self,
        cx: &mut FfiContext,
    ) -> FfiPoll<ROption<FFIResult<WrappedArray>>>,

    /// Return the schema of the record batch
    pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,

    /// Release the memory of the private data when it is no longer being used.
    ///
    /// Invoked from this type's `Drop` implementation.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the plan.
    /// The foreign library should never attempt to access this data.
    ///
    /// When built through [`FFI_RecordBatchStream::new`] this points at a
    /// heap-allocated [`RecordBatchStreamPrivateData`].
    pub private_data: *mut c_void,
}
57
/// Provider-side state stored behind `FFI_RecordBatchStream::private_data`.
///
/// An instance is boxed and leaked into a raw pointer by
/// [`FFI_RecordBatchStream::new`] and reclaimed by the `release` callback.
pub struct RecordBatchStreamPrivateData {
    /// The wrapped stream that the `poll_next` and `schema` callbacks
    /// forward to.
    pub rbs: SendableRecordBatchStream,
    /// Optional tokio runtime handle. When present, the `poll_next` callback
    /// enters it before polling `rbs`.
    pub runtime: Option<Handle>,
}
62
63impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
64    fn from(stream: SendableRecordBatchStream) -> Self {
65        Self::new(stream, None)
66    }
67}
68
69impl FFI_RecordBatchStream {
70    pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
71        let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
72            rbs: stream,
73            runtime,
74        })) as *mut c_void;
75        FFI_RecordBatchStream {
76            poll_next: poll_next_fn_wrapper,
77            schema: schema_fn_wrapper,
78            release: release_fn_wrapper,
79            private_data,
80        }
81    }
82}
83
// `private_data` is a raw pointer, so `Send` is not derived automatically and
// is asserted manually here.
// NOTE(review): soundness presumably relies on the pointee built by `new`
// (a `SendableRecordBatchStream` plus an optional tokio `Handle`) being
// `Send`, and on foreign providers upholding the same guarantee — confirm.
unsafe impl Send for FFI_RecordBatchStream {}
85
86unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
87    unsafe {
88        let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
89        let stream = &(*private_data).rbs;
90
91        (*stream).schema().into()
92    }
93}
94
95unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) {
96    unsafe {
97        debug_assert!(!provider.private_data.is_null());
98        let private_data =
99            Box::from_raw(provider.private_data as *mut RecordBatchStreamPrivateData);
100        drop(private_data);
101        provider.private_data = std::ptr::null_mut();
102    }
103}
104
105pub(crate) fn record_batch_to_wrapped_array(
106    record_batch: RecordBatch,
107) -> FFIResult<WrappedArray> {
108    let schema = WrappedSchema::from(record_batch.schema());
109    let struct_array = StructArray::from(record_batch);
110    rresult!(
111        to_ffi(&struct_array.to_data())
112            .map(|(array, _schema)| WrappedArray { array, schema })
113    )
114}
115
116// probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result<ArrayData> {
117fn maybe_record_batch_to_wrapped_stream(
118    record_batch: Option<Result<RecordBatch>>,
119) -> ROption<FFIResult<WrappedArray>> {
120    match record_batch {
121        Some(Ok(record_batch)) => {
122            ROption::RSome(record_batch_to_wrapped_array(record_batch))
123        }
124        Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())),
125        None => ROption::RNone,
126    }
127}
128
129unsafe extern "C" fn poll_next_fn_wrapper(
130    stream: &FFI_RecordBatchStream,
131    cx: &mut FfiContext,
132) -> FfiPoll<ROption<FFIResult<WrappedArray>>> {
133    unsafe {
134        let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
135        let stream = &mut (*private_data).rbs;
136
137        let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
138
139        let poll_result = cx.with_context(|std_cx| {
140            (*stream)
141                .try_poll_next_unpin(std_cx)
142                .map(maybe_record_batch_to_wrapped_stream)
143        });
144
145        poll_result.into()
146    }
147}
148
149impl RecordBatchStream for FFI_RecordBatchStream {
150    fn schema(&self) -> arrow::datatypes::SchemaRef {
151        let wrapped_schema = unsafe { (self.schema)(self) };
152        wrapped_schema.into()
153    }
154}
155
156pub(crate) fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
157    let array_data =
158        unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
159    let schema: arrow::datatypes::SchemaRef = array.schema.into();
160    let array = make_array(array_data);
161    let struct_array = array
162        .as_any()
163        .downcast_ref::<StructArray>()
164        .ok_or_else(|| ffi_datafusion_err!(
165        "Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray"
166    ))?;
167
168    let rb: RecordBatch = struct_array.into();
169
170    rb.with_schema(schema).map_err(Into::into)
171}
172
173fn maybe_wrapped_array_to_record_batch(
174    array: ROption<FFIResult<WrappedArray>>,
175) -> Option<Result<RecordBatch>> {
176    match array {
177        ROption::RSome(RResult::ROk(wrapped_array)) => {
178            Some(wrapped_array_to_record_batch(wrapped_array))
179        }
180        ROption::RSome(RResult::RErr(e)) => Some(ffi_err!("{e}")),
181        ROption::RNone => None,
182    }
183}
184
185impl Stream for FFI_RecordBatchStream {
186    type Item = Result<RecordBatch>;
187
188    fn poll_next(
189        self: std::pin::Pin<&mut Self>,
190        cx: &mut std::task::Context<'_>,
191    ) -> Poll<Option<Self::Item>> {
192        let poll_result =
193            unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
194
195        match poll_result {
196            FfiPoll::Ready(array) => {
197                Poll::Ready(maybe_wrapped_array_to_record_batch(array))
198            }
199            FfiPoll::Pending => Poll::Pending,
200            FfiPoll::Panicked => Poll::Ready(Some(ffi_err!(
201                "Panic occurred during poll_next on FFI_RecordBatchStream"
202            ))),
203        }
204    }
205}
206
207impl Drop for FFI_RecordBatchStream {
208    fn drop(&mut self) {
209        unsafe { (self.release)(self) }
210    }
211}
212
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::record_batch;
    use datafusion::error::Result;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::test_util::bounded_stream;
    use futures::StreamExt;

    use super::{
        FFI_RecordBatchStream, record_batch_to_wrapped_array,
        wrapped_array_to_record_batch,
    };
    use crate::df_result;

    /// A single-batch stream wrapped in `FFI_RecordBatchStream` must report
    /// the same schema, yield the same batch, and then end.
    #[tokio::test]
    async fn test_round_trip_record_batch_stream() -> Result<()> {
        let expected_batch = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;
        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
        ]));

        let source = bounded_stream(expected_batch.clone(), 1);
        let mut ffi_stream: SendableRecordBatchStream =
            Box::pin(FFI_RecordBatchStream::from(source));

        assert_eq!(ffi_stream.schema(), expected_schema);

        let first = ffi_stream.next().await;
        assert!(first.is_some());
        assert!(first.as_ref().unwrap().is_ok());
        assert_eq!(first.unwrap().unwrap(), expected_batch);

        // There should only be one batch
        let second = ffi_stream.next().await;
        assert!(second.is_none());

        Ok(())
    }

    /// Schema-level metadata must survive a trip through the wrapped-array
    /// conversion helpers.
    #[test]
    fn round_trip_record_batch_with_metadata() -> Result<()> {
        let plain_batch = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;

        let metadata = [("some_key".to_owned(), "some_value".to_owned())].into();
        let schema_with_metadata = plain_batch
            .schema()
            .as_ref()
            .clone()
            .with_metadata(metadata)
            .into();
        let original = plain_batch.with_schema(schema_with_metadata)?;

        let exported = df_result!(record_batch_to_wrapped_array(original.clone()))?;
        let imported = wrapped_array_to_record_batch(exported)?;

        assert_eq!(original, imported);
        Ok(())
    }
}