Skip to main content

datafusion_ffi/
record_batch_stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::c_void;
19use std::task::Poll;
20
21use arrow::array::{Array, RecordBatch, StructArray, make_array};
22use arrow::ffi::{from_ffi, to_ffi};
23use async_ffi::{ContextExt, FfiContext, FfiPoll};
24use datafusion_common::{DataFusionError, Result, ffi_datafusion_err, ffi_err};
25use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
26use futures::{Stream, TryStreamExt};
27
28use tokio::runtime::Handle;
29
30use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
31use crate::sresult;
32use crate::util::{FFI_Option, FFI_Result};
33
/// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries.
/// We use the async-ffi crate for handling async calls across libraries.
///
/// The struct is `#[repr(C)]` and consists only of function pointers plus an
/// opaque data pointer. Callers interact with the stream exclusively through
/// the function pointers; the pointee of `private_data` is owned by the
/// library that created the struct and is freed through `release`.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_RecordBatchStream {
    /// This mirrors the `poll_next` of [`RecordBatchStream`] but does so
    /// in a FFI safe manner.
    ///
    /// The result is an [`FfiPoll`] wrapping an optional, fallible
    /// [`WrappedArray`]: `None` signals the end of the stream, an `Err`
    /// carries the error from the underlying stream, and an `Ok` carries one
    /// record batch exported as a struct array.
    pub poll_next: unsafe extern "C" fn(
        stream: &Self,
        cx: &mut FfiContext,
    )
        -> FfiPoll<FFI_Option<FFI_Result<WrappedArray>>>,

    /// Return the schema of the record batch
    pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,

    /// Release the memory of the private data when it is no longer being used.
    /// This is invoked from the `Drop` implementation of this struct.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the plan.
    /// The foreign library should never attempt to access this data.
    pub private_data: *mut c_void,
}
57
/// Provider-side state stored behind [`FFI_RecordBatchStream::private_data`].
///
/// An instance is boxed and leaked as a raw pointer when the FFI struct is
/// constructed, and reclaimed by the `release` callback.
pub struct RecordBatchStreamPrivateData {
    /// The wrapped stream that actually produces record batches.
    pub rbs: SendableRecordBatchStream,
    /// Optional tokio runtime handle. When present, the runtime is entered
    /// for the duration of each poll of `rbs`.
    pub runtime: Option<Handle>,
}
62
63impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
64    fn from(stream: SendableRecordBatchStream) -> Self {
65        Self::new(stream, None)
66    }
67}
68
69impl FFI_RecordBatchStream {
70    pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
71        let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
72            rbs: stream,
73            runtime,
74        })) as *mut c_void;
75        FFI_RecordBatchStream {
76            poll_next: poll_next_fn_wrapper,
77            schema: schema_fn_wrapper,
78            release: release_fn_wrapper,
79            private_data,
80        }
81    }
82}
83
// SAFETY (NOTE(review)): the struct holds a raw `*mut c_void`, so `Send` is not
// derived automatically. The data behind that pointer is created in `new`
// from a `SendableRecordBatchStream` and an `Option<Handle>`; presumably both
// are `Send`, which is what makes this impl sound — confirm if the private
// data ever gains additional fields.
unsafe impl Send for FFI_RecordBatchStream {}
85
86unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
87    unsafe {
88        let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
89        let stream = &(*private_data).rbs;
90
91        (*stream).schema().into()
92    }
93}
94
95unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) {
96    unsafe {
97        debug_assert!(!provider.private_data.is_null());
98        let private_data =
99            Box::from_raw(provider.private_data as *mut RecordBatchStreamPrivateData);
100        drop(private_data);
101        provider.private_data = std::ptr::null_mut();
102    }
103}
104
105pub(crate) fn record_batch_to_wrapped_array(
106    record_batch: RecordBatch,
107) -> FFI_Result<WrappedArray> {
108    let schema = WrappedSchema::from(record_batch.schema());
109    let struct_array = StructArray::from(record_batch);
110    sresult!(
111        to_ffi(&struct_array.to_data())
112            .map(|(array, _schema)| WrappedArray { array, schema })
113    )
114}
115
116// probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result<ArrayData> {
117fn maybe_record_batch_to_wrapped_stream(
118    record_batch: Option<Result<RecordBatch>>,
119) -> FFI_Option<FFI_Result<WrappedArray>> {
120    match record_batch {
121        Some(Ok(record_batch)) => {
122            FFI_Option::Some(record_batch_to_wrapped_array(record_batch))
123        }
124        Some(Err(e)) => FFI_Option::Some(FFI_Result::Err(e.to_string().into())),
125        None => FFI_Option::None,
126    }
127}
128
129unsafe extern "C" fn poll_next_fn_wrapper(
130    stream: &FFI_RecordBatchStream,
131    cx: &mut FfiContext,
132) -> FfiPoll<FFI_Option<FFI_Result<WrappedArray>>> {
133    unsafe {
134        let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
135        let stream = &mut (*private_data).rbs;
136
137        let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
138
139        let poll_result = cx.with_context(|std_cx| {
140            (*stream)
141                .try_poll_next_unpin(std_cx)
142                .map(maybe_record_batch_to_wrapped_stream)
143        });
144
145        poll_result.into()
146    }
147}
148
149impl RecordBatchStream for FFI_RecordBatchStream {
150    fn schema(&self) -> arrow::datatypes::SchemaRef {
151        let wrapped_schema = unsafe { (self.schema)(self) };
152        wrapped_schema.into()
153    }
154}
155
156pub(crate) fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
157    let array_data =
158        unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
159    let schema: arrow::datatypes::SchemaRef = array.schema.into();
160    let array = make_array(array_data);
161    let struct_array = array
162        .as_any()
163        .downcast_ref::<StructArray>()
164        .ok_or_else(|| ffi_datafusion_err!(
165        "Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray"
166    ))?;
167
168    let rb: RecordBatch = struct_array.into();
169
170    rb.with_schema(schema).map_err(Into::into)
171}
172
173fn maybe_wrapped_array_to_record_batch(
174    array: FFI_Option<FFI_Result<WrappedArray>>,
175) -> Option<Result<RecordBatch>> {
176    let array: Option<FFI_Result<WrappedArray>> = array.into();
177    match array {
178        Some(result) => {
179            let result: std::result::Result<WrappedArray, _> = result.into();
180            match result {
181                Ok(wrapped_array) => Some(wrapped_array_to_record_batch(wrapped_array)),
182                Err(e) => Some(ffi_err!("{e}")),
183            }
184        }
185        None => None,
186    }
187}
188
189impl Stream for FFI_RecordBatchStream {
190    type Item = Result<RecordBatch>;
191
192    fn poll_next(
193        self: std::pin::Pin<&mut Self>,
194        cx: &mut std::task::Context<'_>,
195    ) -> Poll<Option<Self::Item>> {
196        let poll_result =
197            unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
198
199        match poll_result {
200            FfiPoll::Ready(array) => {
201                Poll::Ready(maybe_wrapped_array_to_record_batch(array))
202            }
203            FfiPoll::Pending => Poll::Pending,
204            FfiPoll::Panicked => Poll::Ready(Some(ffi_err!(
205                "Panic occurred during poll_next on FFI_RecordBatchStream"
206            ))),
207        }
208    }
209}
210
211impl Drop for FFI_RecordBatchStream {
212    fn drop(&mut self) {
213        unsafe { (self.release)(self) }
214    }
215}
216
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::record_batch;
    use datafusion::error::Result;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::test_util::bounded_stream;
    use futures::StreamExt;

    use super::{
        FFI_RecordBatchStream, record_batch_to_wrapped_array,
        wrapped_array_to_record_batch,
    };
    use crate::df_result;

    /// A single-batch stream wrapped in `FFI_RecordBatchStream` must report
    /// the same schema and yield the same batch, then end.
    #[tokio::test]
    async fn test_round_trip_record_batch_stream() -> Result<()> {
        let record_batch = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;

        let source = bounded_stream(record_batch.clone(), 1);
        let mut ffi_rbs: SendableRecordBatchStream =
            Box::pin(FFI_RecordBatchStream::from(source));

        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
        ]));
        assert_eq!(ffi_rbs.schema(), expected_schema);

        let first = ffi_rbs.next().await;
        assert!(first.is_some());
        let first = first.unwrap();
        assert!(first.is_ok());
        assert_eq!(first.unwrap(), record_batch);

        // There should only be one batch
        assert!(ffi_rbs.next().await.is_none());

        Ok(())
    }

    /// Schema-level metadata must survive the export/import round trip.
    #[test]
    fn round_trip_record_batch_with_metadata() -> Result<()> {
        let plain = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;

        let schema_with_metadata = Arc::new(
            plain
                .schema()
                .as_ref()
                .clone()
                .with_metadata([("some_key".to_owned(), "some_value".to_owned())].into()),
        );
        let rb = plain.with_schema(schema_with_metadata)?;

        let exported = df_result!(record_batch_to_wrapped_array(rb.clone()))?;
        let round_trip_rb = wrapped_array_to_record_batch(exported)?;

        assert_eq!(rb, round_trip_rb);
        Ok(())
    }
}