datafusion_ffi/
record_batch_stream.rs1use std::ffi::c_void;
19use std::task::Poll;
20
21use arrow::array::{Array, RecordBatch, StructArray, make_array};
22use arrow::ffi::{from_ffi, to_ffi};
23use async_ffi::{ContextExt, FfiContext, FfiPoll};
24use datafusion_common::{DataFusionError, Result, ffi_datafusion_err, ffi_err};
25use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
26use futures::{Stream, TryStreamExt};
27
28use tokio::runtime::Handle;
29
30use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
31use crate::sresult;
32use crate::util::{FFI_Option, FFI_Result};
33
/// C-layout handle that exposes a record batch stream through plain function pointers.
///
/// The struct carries three `extern "C"` callbacks plus an opaque `private_data` pointer that
/// the callbacks interpret. Instances built by [`FFI_RecordBatchStream::new`] point the
/// callbacks at the wrapper functions defined in this file. The `Drop` implementation invokes
/// `release`, so the owner of the struct is responsible for freeing the underlying state.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_RecordBatchStream {
    /// Polls the stream for its next item using an FFI-compatible task context.
    ///
    /// For the wrapper installed by `new`, a ready result is `FFI_Option::None` once the
    /// underlying stream is exhausted; otherwise it holds either the exported batch or the
    /// error rendered as a string.
    pub poll_next: unsafe extern "C" fn(
        stream: &Self,
        cx: &mut FfiContext,
    )
        -> FfiPoll<FFI_Option<FFI_Result<WrappedArray>>>,

    /// Returns the schema of the batches produced by the stream.
    pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,

    /// Releases whatever `private_data` points to. Called from `Drop`.
    pub release: unsafe extern "C" fn(arg: &mut Self),

    /// Opaque state owned by whoever created this struct.
    ///
    /// For instances built by `new` this is a boxed [`RecordBatchStreamPrivateData`]; the
    /// wrapper callbacks in this file cast it back to that type.
    pub private_data: *mut c_void,
}
57
/// State stored behind `FFI_RecordBatchStream::private_data` for streams created by
/// `FFI_RecordBatchStream::new`.
pub struct RecordBatchStreamPrivateData {
    /// The wrapped stream that `poll_next` and `schema` are forwarded to.
    pub rbs: SendableRecordBatchStream,
    /// Optional Tokio runtime handle; when present it is entered for the duration of each poll.
    pub runtime: Option<Handle>,
}
62
63impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
64 fn from(stream: SendableRecordBatchStream) -> Self {
65 Self::new(stream, None)
66 }
67}
68
69impl FFI_RecordBatchStream {
70 pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
71 let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
72 rbs: stream,
73 runtime,
74 })) as *mut c_void;
75 FFI_RecordBatchStream {
76 poll_next: poll_next_fn_wrapper,
77 schema: schema_fn_wrapper,
78 release: release_fn_wrapper,
79 private_data,
80 }
81 }
82}
83
// The struct holds a raw `*mut c_void`, so `Send` is not implemented automatically and is
// asserted by hand here.
// NOTE(review): for instances built by `new` the pointee is a `SendableRecordBatchStream` plus an
// `Option<Handle>`, which are presumably both `Send` — confirm that nothing else is ever stored
// behind `private_data`.
unsafe impl Send for FFI_RecordBatchStream {}
85
86unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
87 unsafe {
88 let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
89 let stream = &(*private_data).rbs;
90
91 (*stream).schema().into()
92 }
93}
94
95unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) {
96 unsafe {
97 debug_assert!(!provider.private_data.is_null());
98 let private_data =
99 Box::from_raw(provider.private_data as *mut RecordBatchStreamPrivateData);
100 drop(private_data);
101 provider.private_data = std::ptr::null_mut();
102 }
103}
104
105pub(crate) fn record_batch_to_wrapped_array(
106 record_batch: RecordBatch,
107) -> FFI_Result<WrappedArray> {
108 let schema = WrappedSchema::from(record_batch.schema());
109 let struct_array = StructArray::from(record_batch);
110 sresult!(
111 to_ffi(&struct_array.to_data())
112 .map(|(array, _schema)| WrappedArray { array, schema })
113 )
114}
115
116fn maybe_record_batch_to_wrapped_stream(
118 record_batch: Option<Result<RecordBatch>>,
119) -> FFI_Option<FFI_Result<WrappedArray>> {
120 match record_batch {
121 Some(Ok(record_batch)) => {
122 FFI_Option::Some(record_batch_to_wrapped_array(record_batch))
123 }
124 Some(Err(e)) => FFI_Option::Some(FFI_Result::Err(e.to_string().into())),
125 None => FFI_Option::None,
126 }
127}
128
129unsafe extern "C" fn poll_next_fn_wrapper(
130 stream: &FFI_RecordBatchStream,
131 cx: &mut FfiContext,
132) -> FfiPoll<FFI_Option<FFI_Result<WrappedArray>>> {
133 unsafe {
134 let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
135 let stream = &mut (*private_data).rbs;
136
137 let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
138
139 let poll_result = cx.with_context(|std_cx| {
140 (*stream)
141 .try_poll_next_unpin(std_cx)
142 .map(maybe_record_batch_to_wrapped_stream)
143 });
144
145 poll_result.into()
146 }
147}
148
149impl RecordBatchStream for FFI_RecordBatchStream {
150 fn schema(&self) -> arrow::datatypes::SchemaRef {
151 let wrapped_schema = unsafe { (self.schema)(self) };
152 wrapped_schema.into()
153 }
154}
155
156pub(crate) fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
157 let array_data =
158 unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
159 let schema: arrow::datatypes::SchemaRef = array.schema.into();
160 let array = make_array(array_data);
161 let struct_array = array
162 .as_any()
163 .downcast_ref::<StructArray>()
164 .ok_or_else(|| ffi_datafusion_err!(
165 "Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray"
166 ))?;
167
168 let rb: RecordBatch = struct_array.into();
169
170 rb.with_schema(schema).map_err(Into::into)
171}
172
173fn maybe_wrapped_array_to_record_batch(
174 array: FFI_Option<FFI_Result<WrappedArray>>,
175) -> Option<Result<RecordBatch>> {
176 let array: Option<FFI_Result<WrappedArray>> = array.into();
177 match array {
178 Some(result) => {
179 let result: std::result::Result<WrappedArray, _> = result.into();
180 match result {
181 Ok(wrapped_array) => Some(wrapped_array_to_record_batch(wrapped_array)),
182 Err(e) => Some(ffi_err!("{e}")),
183 }
184 }
185 None => None,
186 }
187}
188
189impl Stream for FFI_RecordBatchStream {
190 type Item = Result<RecordBatch>;
191
192 fn poll_next(
193 self: std::pin::Pin<&mut Self>,
194 cx: &mut std::task::Context<'_>,
195 ) -> Poll<Option<Self::Item>> {
196 let poll_result =
197 unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
198
199 match poll_result {
200 FfiPoll::Ready(array) => {
201 Poll::Ready(maybe_wrapped_array_to_record_batch(array))
202 }
203 FfiPoll::Pending => Poll::Pending,
204 FfiPoll::Panicked => Poll::Ready(Some(ffi_err!(
205 "Panic occurred during poll_next on FFI_RecordBatchStream"
206 ))),
207 }
208 }
209}
210
211impl Drop for FFI_RecordBatchStream {
212 fn drop(&mut self) {
213 unsafe { (self.release)(self) }
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::record_batch;
    use datafusion::error::Result;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::test_util::bounded_stream;
    use futures::StreamExt;

    use super::{
        FFI_RecordBatchStream, record_batch_to_wrapped_array,
        wrapped_array_to_record_batch,
    };
    use crate::df_result;

    /// Wraps a bounded stream in the FFI handle and checks that the schema and the single
    /// batch come back unchanged, followed by end-of-stream.
    #[tokio::test]
    async fn test_round_trip_record_batch_stream() -> Result<()> {
        let expected_batch = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;
        let source_stream = bounded_stream(expected_batch.clone(), 1);

        let mut stream: SendableRecordBatchStream =
            Box::pin(FFI_RecordBatchStream::from(source_stream));

        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
        ]));
        assert_eq!(stream.schema(), expected_schema);

        let first = stream.next().await;
        assert!(first.is_some());
        assert!(first.as_ref().unwrap().is_ok());
        assert_eq!(first.unwrap().unwrap(), expected_batch);

        // The source was bounded with `1`, so a second poll must report end-of-stream.
        let second = stream.next().await;
        assert!(second.is_none());

        Ok(())
    }

    /// Checks that schema-level metadata survives the export/import round trip.
    #[test]
    fn round_trip_record_batch_with_metadata() -> Result<()> {
        let plain = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;

        let schema_with_metadata = Arc::new(
            plain
                .schema()
                .as_ref()
                .clone()
                .with_metadata([("some_key".to_owned(), "some_value".to_owned())].into()),
        );
        let rb = plain.with_schema(schema_with_metadata)?;

        let exported = df_result!(record_batch_to_wrapped_array(rb.clone()))?;
        let round_trip_rb = wrapped_array_to_record_batch(exported)?;

        assert_eq!(rb, round_trip_rb);
        Ok(())
    }
}