datafusion_ffi/
record_batch_stream.rs1use std::{ffi::c_void, task::Poll};
19
20use abi_stable::{
21 std_types::{ROption, RResult, RString},
22 StableAbi,
23};
24use arrow::array::{Array, RecordBatch};
25use arrow::{
26 array::{make_array, StructArray},
27 ffi::{from_ffi, to_ffi},
28};
29use async_ffi::{ContextExt, FfiContext, FfiPoll};
30use datafusion::error::Result;
31use datafusion::{
32 error::DataFusionError,
33 execution::{RecordBatchStream, SendableRecordBatchStream},
34};
35use datafusion_common::{exec_datafusion_err, exec_err};
36use futures::{Stream, TryStreamExt};
37use tokio::runtime::Handle;
38
39use crate::{
40 arrow_wrappers::{WrappedArray, WrappedSchema},
41 rresult,
42};
43
44#[repr(C)]
47#[derive(Debug, StableAbi)]
48#[allow(non_camel_case_types)]
49pub struct FFI_RecordBatchStream {
50 pub poll_next:
53 unsafe extern "C" fn(
54 stream: &Self,
55 cx: &mut FfiContext,
56 ) -> FfiPoll<ROption<RResult<WrappedArray, RString>>>,
57
58 pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,
60
61 pub release: unsafe extern "C" fn(arg: &mut Self),
63
64 pub private_data: *mut c_void,
67}
68
69pub struct RecordBatchStreamPrivateData {
70 pub rbs: SendableRecordBatchStream,
71 pub runtime: Option<Handle>,
72}
73
74impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
75 fn from(stream: SendableRecordBatchStream) -> Self {
76 Self::new(stream, None)
77 }
78}
79
80impl FFI_RecordBatchStream {
81 pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
82 let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
83 rbs: stream,
84 runtime,
85 })) as *mut c_void;
86 FFI_RecordBatchStream {
87 poll_next: poll_next_fn_wrapper,
88 schema: schema_fn_wrapper,
89 release: release_fn_wrapper,
90 private_data,
91 }
92 }
93}
94
95unsafe impl Send for FFI_RecordBatchStream {}
96
97unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
98 let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
99 let stream = &(*private_data).rbs;
100
101 (*stream).schema().into()
102}
103
104unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) {
105 let private_data =
106 Box::from_raw(provider.private_data as *mut RecordBatchStreamPrivateData);
107 drop(private_data);
108}
109
110fn record_batch_to_wrapped_array(
111 record_batch: RecordBatch,
112) -> RResult<WrappedArray, RString> {
113 let struct_array = StructArray::from(record_batch);
114 rresult!(
115 to_ffi(&struct_array.to_data()).map(|(array, schema)| WrappedArray {
116 array,
117 schema: WrappedSchema(schema)
118 })
119 )
120}
121
122fn maybe_record_batch_to_wrapped_stream(
124 record_batch: Option<Result<RecordBatch>>,
125) -> ROption<RResult<WrappedArray, RString>> {
126 match record_batch {
127 Some(Ok(record_batch)) => {
128 ROption::RSome(record_batch_to_wrapped_array(record_batch))
129 }
130 Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())),
131 None => ROption::RNone,
132 }
133}
134
135unsafe extern "C" fn poll_next_fn_wrapper(
136 stream: &FFI_RecordBatchStream,
137 cx: &mut FfiContext,
138) -> FfiPoll<ROption<RResult<WrappedArray, RString>>> {
139 let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
140 let stream = &mut (*private_data).rbs;
141
142 let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
143
144 let poll_result = cx.with_context(|std_cx| {
145 (*stream)
146 .try_poll_next_unpin(std_cx)
147 .map(maybe_record_batch_to_wrapped_stream)
148 });
149
150 poll_result.into()
151}
152
153impl RecordBatchStream for FFI_RecordBatchStream {
154 fn schema(&self) -> arrow::datatypes::SchemaRef {
155 let wrapped_schema = unsafe { (self.schema)(self) };
156 wrapped_schema.into()
157 }
158}
159
160fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
161 let array_data =
162 unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
163 let array = make_array(array_data);
164 let struct_array = array
165 .as_any()
166 .downcast_ref::<StructArray>()
167 .ok_or_else(|| exec_datafusion_err!(
168 "Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray"
169 ))?;
170
171 Ok(struct_array.into())
172}
173
174fn maybe_wrapped_array_to_record_batch(
175 array: ROption<RResult<WrappedArray, RString>>,
176) -> Option<Result<RecordBatch>> {
177 match array {
178 ROption::RSome(RResult::ROk(wrapped_array)) => {
179 Some(wrapped_array_to_record_batch(wrapped_array))
180 }
181 ROption::RSome(RResult::RErr(e)) => Some(exec_err!("FFI error: {e}")),
182 ROption::RNone => None,
183 }
184}
185
186impl Stream for FFI_RecordBatchStream {
187 type Item = Result<RecordBatch>;
188
189 fn poll_next(
190 self: std::pin::Pin<&mut Self>,
191 cx: &mut std::task::Context<'_>,
192 ) -> Poll<Option<Self::Item>> {
193 let poll_result =
194 unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
195
196 match poll_result {
197 FfiPoll::Ready(array) => {
198 Poll::Ready(maybe_wrapped_array_to_record_batch(array))
199 }
200 FfiPoll::Pending => Poll::Pending,
201 FfiPoll::Panicked => Poll::Ready(Some(exec_err!(
202 "Panic occurred during poll_next on FFI_RecordBatchStream"
203 ))),
204 }
205 }
206}
207
208impl Drop for FFI_RecordBatchStream {
209 fn drop(&mut self) {
210 unsafe { (self.release)(self) }
211 }
212}
213
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::{
        common::record_batch, error::Result, execution::SendableRecordBatchStream,
        test_util::bounded_stream,
    };

    use super::FFI_RecordBatchStream;
    use futures::StreamExt;

    /// Wraps a one-batch stream in the FFI handle and checks that the schema
    /// and the single batch come back unchanged, followed by end-of-stream.
    #[tokio::test]
    async fn test_round_trip_record_batch_stream() -> Result<()> {
        let expected_batch = record_batch!(
            ("a", Int32, vec![1, 2, 3]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0)])
        )?;
        let source_stream = bounded_stream(expected_batch.clone(), 1);

        let mut round_tripped: SendableRecordBatchStream =
            Box::pin(FFI_RecordBatchStream::from(source_stream));

        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
        ]));
        assert_eq!(round_tripped.schema(), expected_schema);

        let first = round_tripped.next().await;
        assert!(first.is_some());
        assert!(first.as_ref().unwrap().is_ok());
        assert_eq!(first.unwrap().unwrap(), expected_batch);

        // The bounded stream held exactly one batch, so it must now be done.
        let after_end = round_tripped.next().await;
        assert!(after_end.is_none());

        Ok(())
    }
}