datafusion_ffi/
record_batch_stream.rs1use std::{ffi::c_void, task::Poll};
19
20use abi_stable::{
21 std_types::{ROption, RResult, RString},
22 StableAbi,
23};
24use arrow::array::{Array, RecordBatch};
25use arrow::{
26 array::{make_array, StructArray},
27 ffi::{from_ffi, to_ffi},
28};
29use async_ffi::{ContextExt, FfiContext, FfiPoll};
30use datafusion::error::Result;
31use datafusion::{
32 error::DataFusionError,
33 execution::{RecordBatchStream, SendableRecordBatchStream},
34};
35use futures::{Stream, TryStreamExt};
36use tokio::runtime::Handle;
37
38use crate::{
39 arrow_wrappers::{WrappedArray, WrappedSchema},
40 rresult,
41};
42
43#[repr(C)]
46#[derive(Debug, StableAbi)]
47#[allow(non_camel_case_types)]
48pub struct FFI_RecordBatchStream {
49 pub poll_next:
52 unsafe extern "C" fn(
53 stream: &Self,
54 cx: &mut FfiContext,
55 ) -> FfiPoll<ROption<RResult<WrappedArray, RString>>>,
56
57 pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema,
59
60 pub private_data: *mut c_void,
63}
64
65pub struct RecordBatchStreamPrivateData {
66 pub rbs: SendableRecordBatchStream,
67 pub runtime: Option<Handle>,
68}
69
70impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
71 fn from(stream: SendableRecordBatchStream) -> Self {
72 Self::new(stream, None)
73 }
74}
75
76impl FFI_RecordBatchStream {
77 pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
78 let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
79 rbs: stream,
80 runtime,
81 })) as *mut c_void;
82 FFI_RecordBatchStream {
83 poll_next: poll_next_fn_wrapper,
84 schema: schema_fn_wrapper,
85 private_data,
86 }
87 }
88}
89
90unsafe impl Send for FFI_RecordBatchStream {}
91
92unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema {
93 let private_data = stream.private_data as *const RecordBatchStreamPrivateData;
94 let stream = &(*private_data).rbs;
95
96 (*stream).schema().into()
97}
98
99fn record_batch_to_wrapped_array(
100 record_batch: RecordBatch,
101) -> RResult<WrappedArray, RString> {
102 let struct_array = StructArray::from(record_batch);
103 rresult!(
104 to_ffi(&struct_array.to_data()).map(|(array, schema)| WrappedArray {
105 array,
106 schema: WrappedSchema(schema)
107 })
108 )
109}
110
111fn maybe_record_batch_to_wrapped_stream(
113 record_batch: Option<Result<RecordBatch>>,
114) -> ROption<RResult<WrappedArray, RString>> {
115 match record_batch {
116 Some(Ok(record_batch)) => {
117 ROption::RSome(record_batch_to_wrapped_array(record_batch))
118 }
119 Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())),
120 None => ROption::RNone,
121 }
122}
123
124unsafe extern "C" fn poll_next_fn_wrapper(
125 stream: &FFI_RecordBatchStream,
126 cx: &mut FfiContext,
127) -> FfiPoll<ROption<RResult<WrappedArray, RString>>> {
128 let private_data = stream.private_data as *mut RecordBatchStreamPrivateData;
129 let stream = &mut (*private_data).rbs;
130
131 let _guard = (*private_data).runtime.as_ref().map(|rt| rt.enter());
132
133 let poll_result = cx.with_context(|std_cx| {
134 (*stream)
135 .try_poll_next_unpin(std_cx)
136 .map(maybe_record_batch_to_wrapped_stream)
137 });
138
139 poll_result.into()
140}
141
142impl RecordBatchStream for FFI_RecordBatchStream {
143 fn schema(&self) -> arrow::datatypes::SchemaRef {
144 let wrapped_schema = unsafe { (self.schema)(self) };
145 wrapped_schema.into()
146 }
147}
148
149fn wrapped_array_to_record_batch(array: WrappedArray) -> Result<RecordBatch> {
150 let array_data =
151 unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? };
152 let array = make_array(array_data);
153 let struct_array = array
154 .as_any()
155 .downcast_ref::<StructArray>()
156 .ok_or(DataFusionError::Execution(
157 "Unexpected array type during record batch collection in FFI_RecordBatchStream"
158 .to_string(),
159 ))?;
160
161 Ok(struct_array.into())
162}
163
164fn maybe_wrapped_array_to_record_batch(
165 array: ROption<RResult<WrappedArray, RString>>,
166) -> Option<Result<RecordBatch>> {
167 match array {
168 ROption::RSome(RResult::ROk(wrapped_array)) => {
169 Some(wrapped_array_to_record_batch(wrapped_array))
170 }
171 ROption::RSome(RResult::RErr(e)) => {
172 Some(Err(DataFusionError::Execution(e.to_string())))
173 }
174 ROption::RNone => None,
175 }
176}
177
178impl Stream for FFI_RecordBatchStream {
179 type Item = Result<RecordBatch>;
180
181 fn poll_next(
182 self: std::pin::Pin<&mut Self>,
183 cx: &mut std::task::Context<'_>,
184 ) -> Poll<Option<Self::Item>> {
185 let poll_result =
186 unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) };
187
188 match poll_result {
189 FfiPoll::Ready(array) => {
190 Poll::Ready(maybe_wrapped_array_to_record_batch(array))
191 }
192 FfiPoll::Pending => Poll::Pending,
193 FfiPoll::Panicked => Poll::Ready(Some(Err(DataFusionError::Execution(
194 "Error occurred during poll_next on FFI_RecordBatchStream".to_string(),
195 )))),
196 }
197 }
198}