datafusion_ffi/
record_batch_stream.rs1use std::ffi::c_void;
19use std::task::Poll;
20
21use abi_stable::StableAbi;
22use abi_stable::std_types::{ROption, RResult};
23use arrow::array::{Array, RecordBatch, StructArray, make_array};
24use arrow::ffi::{from_ffi, to_ffi};
25use async_ffi::{ContextExt, FfiContext, FfiPoll};
26use datafusion_common::{DataFusionError, Result, ffi_datafusion_err, ffi_err};
27use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
28use futures::{Stream, TryStreamExt};
29use tokio::runtime::Handle;
30
31use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
32use crate::rresult;
33use crate::util::FFIResult;
34
35#[repr(C)]
38#[derive(Debug, StableAbi)]
39pub struct FFI_RecordBatchStream {
40 pub poll_next: unsafe extern "C" fn(
43 stream: &Self,
44 cx: &mut FfiContext,
45 ) -> FfiPoll<ROption<FFIResult<WrappedArray>>>,
46
47 pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,
49
50 pub release: unsafe extern "C" fn(arg: &mut Self),
52
53 pub private_data: *mut c_void,
56}
57
58pub struct RecordBatchStreamPrivateData {
59 pub rbs: SendableRecordBatchStream,
60 pub runtime: Option<Handle>,
61}
62
63impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
64 fn from(stream: SendableRecordBatchStream) -> Self {
65 Self::new(stream, None)
66 }
67}
68
69impl FFI_RecordBatchStream {
70 pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
71 let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
72 rbs: stream,
73 runtime,
74 })) as *mut c_void;
75 FFI_RecordBatchStream {
76 poll_next: poll_next_fn_wrapper,
77 schema: schema_fn_wrapper,
78 release: release_fn_wrapper,
79 private_data,
80 }
81 }
82}
83
84unsafe impl Send for FFI_RecordBatchStream {}
85
86unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
87 unsafe {
88 let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
89 let stream = &(*private_data).rbs;
90
91 (*stream).schema().into()
92 }
93}
94
95unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) {
96 unsafe {
97 debug_assert!(!provider.private_data.is_null());
98 let private_data =
99 Box::from_raw(provider.private_data as *mut RecordBatchStreamPrivateData);
100 drop(private_data);
101 provider.private_data = std::ptr::null_mut();
102 }
103}
104
105pub(crate) fn record_batch_to_wrapped_array(
106 record_batch: RecordBatch,
107) -> FFIResult<WrappedArray> {
108 let schema = WrappedSchema::from(record_batch.schema());
109 let struct_array = StructArray::from(record_batch);
110 rresult!(
111 to_ffi(&struct_array.to_data())
112 .map(|(array, _schema)| WrappedArray { array, schema })
113 )
114}
115
116fn maybe_record_batch_to_wrapped_stream(
118 record_batch: Option<Result<RecordBatch>>,
119) -> ROption<FFIResult<WrappedArray>> {
120 match record_batch {
121 Some(Ok(record_batch)) => {
122 ROption::RSome(record_batch_to_wrapped_array(record_batch))
123 }
124 Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())),
125 None => ROption::RNone,
126 }
127}
128
129unsafe extern "C" fn poll_next_fn_wrapper(
130 stream: &FFI_RecordBatchStream,
131 cx: &mut FfiContext,
132) -> FfiPoll<ROption<FFIResult<WrappedArray>>> {
133 unsafe {
134 let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
135 let stream = &mut (*private_data).rbs;
136
137 let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
138
139 let poll_result = cx.with_context(|std_cx| {
140 (*stream)
141 .try_poll_next_unpin(std_cx)
142 .map(maybe_record_batch_to_wrapped_stream)
143 });
144
145 poll_result.into()
146 }
147}
148
149impl RecordBatchStream for FFI_RecordBatchStream {
150 fn schema(&self) -> arrow::datatypes::SchemaRef {
151 let wrapped_schema = unsafe { (self.schema)(self) };
152 wrapped_schema.into()
153 }
154}
155
156pub(crate) fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
157 let array_data =
158 unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
159 let schema: arrow::datatypes::SchemaRef = array.schema.into();
160 let array = make_array(array_data);
161 let struct_array = array
162 .as_any()
163 .downcast_ref::<StructArray>()
164 .ok_or_else(|| ffi_datafusion_err!(
165 "Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray"
166 ))?;
167
168 let rb: RecordBatch = struct_array.into();
169
170 rb.with_schema(schema).map_err(Into::into)
171}
172
173fn maybe_wrapped_array_to_record_batch(
174 array: ROption<FFIResult<WrappedArray>>,
175) -> Option<Result<RecordBatch>> {
176 match array {
177 ROption::RSome(RResult::ROk(wrapped_array)) => {
178 Some(wrapped_array_to_record_batch(wrapped_array))
179 }
180 ROption::RSome(RResult::RErr(e)) => Some(ffi_err!("{e}")),
181 ROption::RNone => None,
182 }
183}
184
185impl Stream for FFI_RecordBatchStream {
186 type Item = Result<RecordBatch>;
187
188 fn poll_next(
189 self: std::pin::Pin<&mut Self>,
190 cx: &mut std::task::Context<'_>,
191 ) -> Poll<Option<Self::Item>> {
192 let poll_result =
193 unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
194
195 match poll_result {
196 FfiPoll::Ready(array) => {
197 Poll::Ready(maybe_wrapped_array_to_record_batch(array))
198 }
199 FfiPoll::Pending => Poll::Pending,
200 FfiPoll::Panicked => Poll::Ready(Some(ffi_err!(
201 "Panic occurred during poll_next on FFI_RecordBatchStream"
202 ))),
203 }
204 }
205}
206
207impl Drop for FFI_RecordBatchStream {
208 fn drop(&mut self) {
209 unsafe { (self.release)(self) }
210 }
211}
212
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::record_batch;
    use datafusion::error::Result;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::test_util::bounded_stream;
    use futures::StreamExt;

    use super::{
        FFI_RecordBatchStream, record_batch_to_wrapped_array,
        wrapped_array_to_record_batch,
    };
    use crate::df_result;

    /// A single-batch stream wrapped in `FFI_RecordBatchStream` reports the
    /// expected schema, yields the same batch once, and then ends.
    #[tokio::test]
    async fn test_round_trip_record_batch_stream() -> Result<()> {
        let expected_batch = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;

        let source = bounded_stream(expected_batch.clone(), 1);
        let mut ffi_rbs: SendableRecordBatchStream =
            Box::pin(FFI_RecordBatchStream::from(source));

        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
        ]));
        let schema = ffi_rbs.schema();
        assert_eq!(schema, expected_schema);

        let first = ffi_rbs.next().await;
        assert!(first.is_some());
        let first = first.unwrap();
        assert!(first.is_ok());
        assert_eq!(first.unwrap(), expected_batch);

        // The source stream was bounded to a single batch.
        let second = ffi_rbs.next().await;
        assert!(second.is_none());

        Ok(())
    }

    /// Schema-level metadata survives the record batch -> wrapped array ->
    /// record batch round trip.
    #[test]
    fn round_trip_record_batch_with_metadata() -> Result<()> {
        let plain = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;

        let schema_with_metadata = Arc::new(
            plain
                .schema()
                .as_ref()
                .clone()
                .with_metadata([("some_key".to_owned(), "some_value".to_owned())].into()),
        );
        let rb = plain.with_schema(schema_with_metadata)?;

        let ffi_rb = df_result!(record_batch_to_wrapped_array(rb.clone()))?;
        let round_trip_rb = wrapped_array_to_record_batch(ffi_rb)?;

        assert_eq!(rb, round_trip_rb);
        Ok(())
    }
}