1use arrow_schema::DataType;
58use std::ffi::CStr;
59use std::ptr::addr_of;
60use std::{
61 ffi::CString,
62 os::raw::{c_char, c_int, c_void},
63 sync::Arc,
64};
65
66use arrow_data::ffi::FFI_ArrowArray;
67use arrow_schema::{ArrowError, Schema, SchemaRef, ffi::FFI_ArrowSchema};
68
69use crate::RecordBatchOptions;
70use crate::array::Array;
71use crate::array::StructArray;
72use crate::ffi::from_ffi_and_data_type;
73use crate::record_batch::{RecordBatch, RecordBatchReader};
74
/// Module-local shorthand for results whose error type is [`ArrowError`].
type Result<T> = std::result::Result<T, ArrowError>;
76
// errno-style status codes returned by the stream callbacks; `get_error_code`
// selects one of them based on the `ArrowError` variant.

/// Returned for `ArrowError::IoError`.
const EIO: i32 = 5;
/// Returned for `ArrowError::MemoryError`.
const ENOMEM: i32 = 12;
/// Fallback code for every error variant without a dedicated mapping.
const EINVAL: i32 = 22;
/// Returned for `ArrowError::NotYetImplemented`.
// NOTE(review): errno numbering is platform-specific; confirm that 78 is the
// value consumers on every supported target expect for "not implemented".
const ENOSYS: i32 = 78;
81
/// `#[repr(C)]` table of callbacks plus an opaque state pointer describing a
/// stream of record batches that can be handed across a C ABI boundary.
///
/// Presumably mirrors the `ArrowArrayStream` struct of the Arrow C stream
/// interface; the field order is part of the ABI and must not be changed.
///
/// Dropping a value invokes `release` when it is still set.
#[repr(C)]
#[derive(Debug)]
#[allow(non_camel_case_types)]
pub struct FFI_ArrowArrayStream {
    /// Callback that writes the stream's schema into `out`.
    /// Returns `0` on success and a non-zero errno-style code on failure.
    pub get_schema:
        Option<unsafe extern "C" fn(arg1: *mut Self, out: *mut FFI_ArrowSchema) -> c_int>,
    /// Callback that writes the next array into `out`.
    /// Returns `0` on success; a released array written to `out` signals the
    /// end of the stream. Non-zero return values are errno-style error codes.
    pub get_next: Option<unsafe extern "C" fn(arg1: *mut Self, out: *mut FFI_ArrowArray) -> c_int>,
    /// Callback returning a NUL-terminated description of the last error, or a
    /// null pointer when no description is available.
    pub get_last_error: Option<unsafe extern "C" fn(arg1: *mut Self) -> *const c_char>,
    /// Callback that frees the resources owned by the stream. `None` marks the
    /// stream as already released.
    pub release: Option<unsafe extern "C" fn(arg1: *mut Self)>,
    /// Opaque producer-owned state; only the producer's callbacks interpret it.
    pub private_data: *mut c_void,
}
101
// Streams built by `FFI_ArrowArrayStream::new` only wrap a
// `Box<dyn RecordBatchReader + Send>` and an optional `CString`.
// NOTE(review): streams imported from foreign producers are assumed to be safe
// to move between threads as well — confirm against the producer's contract.
unsafe impl Send for FFI_ArrowArrayStream {}
103
104unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {
106 if stream.is_null() {
107 return;
108 }
109 let stream = unsafe { &mut *stream };
110
111 stream.get_schema = None;
112 stream.get_next = None;
113 stream.get_last_error = None;
114
115 let private_data = unsafe { Box::from_raw(stream.private_data as *mut StreamPrivateData) };
116 drop(private_data);
117
118 stream.release = None;
119}
120
/// Producer-side state stored behind `FFI_ArrowArrayStream::private_data` for
/// streams created by `FFI_ArrowArrayStream::new`.
struct StreamPrivateData {
    /// Source of the record batches handed out by the `get_next` callback.
    batch_reader: Box<dyn RecordBatchReader + Send>,
    /// Message of the most recent failure; owns the C string whose pointer the
    /// `get_last_error` callback returns.
    last_error: Option<CString>,
}
125
126unsafe extern "C" fn get_schema(
128 stream: *mut FFI_ArrowArrayStream,
129 schema: *mut FFI_ArrowSchema,
130) -> c_int {
131 ExportedArrayStream { stream }.get_schema(schema)
132}
133
134unsafe extern "C" fn get_next(
136 stream: *mut FFI_ArrowArrayStream,
137 array: *mut FFI_ArrowArray,
138) -> c_int {
139 ExportedArrayStream { stream }.get_next(array)
140}
141
142unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
144 let mut ffi_stream = ExportedArrayStream { stream };
145 match ffi_stream.get_last_error() {
148 Some(err_string) => err_string.as_ptr(),
149 None => std::ptr::null(),
150 }
151}
152
153impl Drop for FFI_ArrowArrayStream {
154 fn drop(&mut self) {
155 match self.release {
156 None => (),
157 Some(release) => unsafe { release(self) },
158 };
159 }
160}
161
162impl FFI_ArrowArrayStream {
163 pub fn new(batch_reader: Box<dyn RecordBatchReader + Send>) -> Self {
165 let private_data = Box::new(StreamPrivateData {
166 batch_reader,
167 last_error: None,
168 });
169
170 Self {
171 get_schema: Some(get_schema),
172 get_next: Some(get_next),
173 get_last_error: Some(get_last_error),
174 release: Some(release_stream),
175 private_data: Box::into_raw(private_data) as *mut c_void,
176 }
177 }
178
179 pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Self {
192 unsafe { std::ptr::replace(raw_stream, Self::empty()) }
193 }
194
195 pub fn empty() -> Self {
197 Self {
198 get_schema: None,
199 get_next: None,
200 get_last_error: None,
201 release: None,
202 private_data: std::ptr::null_mut(),
203 }
204 }
205}
206
/// Thin wrapper around a raw stream pointer, used by the exported callbacks to
/// reach the `StreamPrivateData` stored in the stream.
struct ExportedArrayStream {
    /// Stream handed to the callback; not owned by this wrapper.
    stream: *mut FFI_ArrowArrayStream,
}
210
211impl ExportedArrayStream {
212 fn get_private_data(&mut self) -> &mut StreamPrivateData {
213 unsafe { &mut *((*self.stream).private_data as *mut StreamPrivateData) }
214 }
215
216 pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
217 let private_data = self.get_private_data();
218 let reader = &private_data.batch_reader;
219
220 let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref());
221
222 match schema {
223 Ok(schema) => {
224 unsafe { std::ptr::copy(addr_of!(schema), out, 1) };
225 std::mem::forget(schema);
226 0
227 }
228 Err(ref err) => {
229 private_data.last_error = Some(
230 CString::new(err.to_string()).expect("Error string has a null byte in it."),
231 );
232 get_error_code(err)
233 }
234 }
235 }
236
237 pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
238 let private_data = self.get_private_data();
239 let reader = &mut private_data.batch_reader;
240
241 match reader.next() {
242 None => {
243 unsafe { std::ptr::write(out, FFI_ArrowArray::empty()) }
245 0
246 }
247 Some(next_batch) => {
248 if let Ok(batch) = next_batch {
249 let struct_array = StructArray::from(batch);
250 let array = FFI_ArrowArray::new(&struct_array.to_data());
251
252 unsafe { std::ptr::write_unaligned(out, array) };
253 0
254 } else {
255 let err = &next_batch.unwrap_err();
256 private_data.last_error = Some(
257 CString::new(err.to_string()).expect("Error string has a null byte in it."),
258 );
259 get_error_code(err)
260 }
261 }
262 }
263 }
264
265 pub fn get_last_error(&mut self) -> Option<&CString> {
266 self.get_private_data().last_error.as_ref()
267 }
268}
269
270fn get_error_code(err: &ArrowError) -> i32 {
271 match err {
272 ArrowError::NotYetImplemented(_) => ENOSYS,
273 ArrowError::MemoryError(_) => ENOMEM,
274 ArrowError::IoError(_, _) => EIO,
275 _ => EINVAL,
276 }
277}
278
/// Reader that pulls record batches out of an [`FFI_ArrowArrayStream`].
///
/// Implements `Iterator<Item = Result<RecordBatch>>` and [`RecordBatchReader`].
#[derive(Debug)]
pub struct ArrowArrayStreamReader {
    /// The wrapped stream; released when the reader is dropped.
    stream: FFI_ArrowArrayStream,
    /// Schema fetched from the stream when the reader was constructed.
    schema: SchemaRef,
}
289
290fn get_stream_schema(stream_ptr: *mut FFI_ArrowArrayStream) -> Result<SchemaRef> {
293 let mut schema = FFI_ArrowSchema::empty();
294
295 let ret_code = unsafe { (*stream_ptr).get_schema.unwrap()(stream_ptr, &mut schema) };
296
297 if ret_code == 0 {
298 let schema = Schema::try_from(&schema)?;
299 Ok(Arc::new(schema))
300 } else {
301 Err(ArrowError::CDataInterface(format!(
302 "Cannot get schema from input stream. Error code: {ret_code:?}"
303 )))
304 }
305}
306
307impl ArrowArrayStreamReader {
308 #[allow(dead_code)]
311 pub fn try_new(mut stream: FFI_ArrowArrayStream) -> Result<Self> {
312 if stream.release.is_none() {
313 return Err(ArrowError::CDataInterface(
314 "input stream is already released".to_string(),
315 ));
316 }
317
318 let schema = get_stream_schema(&mut stream)?;
319
320 Ok(Self { stream, schema })
321 }
322
323 pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Result<Self> {
334 Self::try_new(unsafe { FFI_ArrowArrayStream::from_raw(raw_stream) })
335 }
336
337 fn get_stream_last_error(&mut self) -> Option<String> {
339 let get_last_error = self.stream.get_last_error?;
340
341 let error_str = unsafe { get_last_error(&mut self.stream) };
342 if error_str.is_null() {
343 return None;
344 }
345
346 let error_str = unsafe { CStr::from_ptr(error_str) };
347 Some(error_str.to_string_lossy().to_string())
348 }
349}
350
351impl Iterator for ArrowArrayStreamReader {
352 type Item = Result<RecordBatch>;
353
354 fn next(&mut self) -> Option<Self::Item> {
355 let mut array = FFI_ArrowArray::empty();
356
357 let ret_code = unsafe { self.stream.get_next.unwrap()(&mut self.stream, &mut array) };
358
359 if ret_code == 0 {
360 if array.is_released() {
362 return None;
363 }
364
365 let result = unsafe {
366 from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
367 };
368 Some(result.and_then(|data| {
369 let len = data.len();
370 RecordBatch::try_new_with_options(
371 self.schema.clone(),
372 StructArray::from(data).into_parts().1,
373 &RecordBatchOptions::new().with_row_count(Some(len)),
374 )
375 }))
376 } else {
377 let last_error = self.get_stream_last_error();
378 let err = ArrowError::CDataInterface(last_error.unwrap());
379 Some(Err(err))
380 }
381 }
382}
383
384impl RecordBatchReader for ArrowArrayStreamReader {
385 fn schema(&self) -> SchemaRef {
386 self.schema.clone()
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use arrow_schema::Field;

    use crate::array::Int32Array;
    use crate::ffi::from_ffi;

    /// Boxed, sendable source of batches backing a [`TestRecordBatchReader`].
    type BatchIter = Box<dyn Iterator<Item = Result<RecordBatch>> + Send>;

    /// Minimal [`RecordBatchReader`] that replays a fixed iterator of batches.
    struct TestRecordBatchReader {
        schema: SchemaRef,
        iter: BatchIter,
    }

    impl TestRecordBatchReader {
        pub fn new(schema: SchemaRef, iter: BatchIter) -> Box<TestRecordBatchReader> {
            Box::new(Self { schema, iter })
        }
    }

    impl Iterator for TestRecordBatchReader {
        type Item = Result<RecordBatch>;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    impl RecordBatchReader for TestRecordBatchReader {
        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }
    }

    /// Exports two copies of `batch` and drains the stream through the raw
    /// callbacks, checking schema and batches against the inputs.
    fn _test_round_trip_export(batch: RecordBatch, schema: Arc<Schema>) -> Result<()> {
        let expected = vec![batch.clone(), batch];
        let iter = Box::new(expected.clone().into_iter().map(Ok)) as _;

        let reader = TestRecordBatchReader::new(schema.clone(), iter);

        let mut ffi_stream = FFI_ArrowArrayStream::new(reader);

        // Fetch the schema through the exported callback.
        let mut ffi_schema = FFI_ArrowSchema::empty();
        let schema_status = unsafe { get_schema(&mut ffi_stream, &mut ffi_schema) };
        assert_eq!(schema_status, 0);

        let exported_schema = Schema::try_from(&ffi_schema).unwrap();
        assert_eq!(&exported_schema, schema.as_ref());
        let exported_schema = SchemaRef::from(exported_schema);

        // Pull arrays until the stream signals its end with a released array.
        let mut produced_batches = Vec::new();
        loop {
            let mut ffi_array = FFI_ArrowArray::empty();
            let next_status = unsafe { get_next(&mut ffi_stream, &mut ffi_array) };
            assert_eq!(next_status, 0);

            if ffi_array.is_released() {
                break;
            }

            let imported = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
            let row_count = imported.len();
            let columns = StructArray::from(imported).into_parts().1;

            produced_batches.push(
                RecordBatch::try_new_with_options(
                    exported_schema.clone(),
                    columns,
                    &RecordBatchOptions::new().with_row_count(Some(row_count)),
                )
                .unwrap(),
            );
        }

        assert_eq!(produced_batches, expected);

        Ok(())
    }

    /// Exports two copies of `batch` and reads them back through
    /// [`ArrowArrayStreamReader`].
    fn _test_round_trip_import(batch: RecordBatch, schema: Arc<Schema>) -> Result<()> {
        let expected = vec![batch.clone(), batch];
        let iter = Box::new(expected.clone().into_iter().map(Ok)) as _;

        let reader = TestRecordBatchReader::new(schema.clone(), iter);

        let stream = FFI_ArrowArrayStream::new(reader);
        let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();

        assert_eq!(stream_reader.schema(), schema);

        let produced_batches: Vec<RecordBatch> =
            stream_reader.map(|imported| imported.unwrap()).collect();

        assert_eq!(produced_batches, expected);

        Ok(())
    }

    #[test]
    fn test_stream_round_trip() {
        let values: Arc<dyn Array> =
            Arc::new(Int32Array::from(vec![Some(2), None, Some(1), None]));
        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);

        let fields = vec![
            Field::new("a", values.data_type().clone(), true).with_metadata(metadata.clone()),
            Field::new("b", values.data_type().clone(), true).with_metadata(metadata.clone()),
            Field::new("c", values.data_type().clone(), true).with_metadata(metadata.clone()),
        ];
        let schema = Arc::new(Schema::new_with_metadata(fields, metadata));
        let columns = vec![values.clone(), values.clone(), values];
        let batch = RecordBatch::try_new(schema.clone(), columns).unwrap();

        _test_round_trip_export(batch.clone(), schema.clone()).unwrap();
        _test_round_trip_import(batch, schema).unwrap();
    }

    #[test]
    fn test_stream_round_trip_no_columns() {
        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);

        let schema = Arc::new(Schema::new_with_metadata(Vec::<Field>::new(), metadata));
        // Without columns the row count has to be supplied explicitly.
        let options = RecordBatchOptions::new().with_row_count(Some(10));
        let batch =
            RecordBatch::try_new_with_options(schema.clone(), Vec::<Arc<dyn Array>>::new(), &options)
                .unwrap();

        _test_round_trip_export(batch.clone(), schema.clone()).unwrap();
        _test_round_trip_import(batch, schema).unwrap();
    }

    #[test]
    fn test_error_import() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

        let iter = Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter());

        let reader = TestRecordBatchReader::new(schema.clone(), iter);

        let stream = FFI_ArrowArrayStream::new(reader);
        let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();

        assert_eq!(stream_reader.schema(), schema);

        let produced_batches: Vec<_> = stream_reader.collect();

        // The single failing item must surface as one error.
        assert_eq!(produced_batches.len(), 1);
        assert!(produced_batches[0].is_err());

        Ok(())
    }
}