use std::ptr::addr_of;
use std::{
convert::TryFrom,
ffi::CString,
os::raw::{c_char, c_int, c_void},
sync::Arc,
};
use crate::array::Array;
use crate::array::StructArray;
use crate::datatypes::{Schema, SchemaRef};
use crate::error::ArrowError;
use crate::error::Result;
use crate::ffi::*;
use crate::record_batch::{RecordBatch, RecordBatchReader};
// errno-style codes returned to C consumers by the exported `get_schema` /
// `get_next` callbacks (selected by `get_error_code`).
// NOTE(review): these are hard-coded rather than read from the platform's
// <errno.h>. ENOMEM/EIO/EINVAL match common Unix values, but 78 for ENOSYS
// appears to be the macOS/BSD value (Linux uses 38) — confirm this is intended.
const ENOMEM: i32 = 12;
const EIO: i32 = 5;
const EINVAL: i32 = 22;
const ENOSYS: i32 = 78;
/// Mirrors the `ArrowArrayStream` struct of the Arrow C stream interface.
///
/// `#[repr(C)]` fixes the field layout, so the fields must stay in this order.
/// A stream whose `release` is `None` is treated as already released.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_ArrowArrayStream {
/// Writes the stream's schema into `out`; returns 0 on success, otherwise an
/// errno-style error code.
pub get_schema: Option<
unsafe extern "C" fn(
arg1: *mut FFI_ArrowArrayStream,
out: *mut FFI_ArrowSchema,
) -> c_int,
>,
/// Writes the next batch into `out`; returns 0 on success, otherwise an
/// errno-style error code. End of stream is signalled by writing a released
/// (empty) array.
pub get_next: Option<
unsafe extern "C" fn(
arg1: *mut FFI_ArrowArrayStream,
out: *mut FFI_ArrowArray,
) -> c_int,
>,
/// Returns a NUL-terminated description of the most recent error.
pub get_last_error:
Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowArrayStream) -> *const c_char>,
/// Frees the producer's resources; set to `None` once the stream is released.
pub release: Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowArrayStream)>,
/// Opaque producer-owned state. For streams built by
/// `FFI_ArrowArrayStream::new` this is a boxed `StreamPrivateData`.
pub private_data: *mut c_void,
}
/// `release` callback installed by [`FFI_ArrowArrayStream::new`].
///
/// Clears every callback, frees the boxed [`StreamPrivateData`] (which drops
/// the wrapped reader), and finally marks the stream as released by setting
/// `release` to `None`. A null `stream` pointer is ignored.
unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {
    // `as_mut` yields `None` for a null pointer, which we treat as a no-op.
    let stream = match stream.as_mut() {
        Some(live) => live,
        None => return,
    };

    stream.get_schema = None;
    stream.get_next = None;
    stream.get_last_error = None;

    // Take back the Box leaked by `FFI_ArrowArrayStream::new` and drop it.
    drop(Box::from_raw(stream.private_data as *mut StreamPrivateData));

    stream.release = None;
}
/// Producer-side state reached through an exported stream's `private_data`
/// pointer (allocated in `FFI_ArrowArrayStream::new`, freed in `release_stream`).
struct StreamPrivateData {
/// Reader that supplies the schema and the batches.
batch_reader: Box<dyn RecordBatchReader>,
/// Message of the most recent `get_schema`/`get_next` failure; empty until one occurs.
last_error: String,
}
/// `get_schema` callback of an exported stream.
///
/// Delegates to [`ExportedArrayStream::get_schema`] and returns its status
/// (0 on success, otherwise an errno-style code).
unsafe extern "C" fn get_schema(
    stream: *mut FFI_ArrowArrayStream,
    schema: *mut FFI_ArrowSchema,
) -> c_int {
    let mut exported = ExportedArrayStream { stream };
    exported.get_schema(schema)
}
/// `get_next` callback of an exported stream.
///
/// Delegates to [`ExportedArrayStream::get_next`] and returns its status
/// (0 on success, otherwise an errno-style code).
unsafe extern "C" fn get_next(
    stream: *mut FFI_ArrowArrayStream,
    array: *mut FFI_ArrowArray,
) -> c_int {
    let mut exported = ExportedArrayStream { stream };
    exported.get_next(array)
}
/// `get_last_error` callback of an exported stream.
///
/// Returns a freshly allocated, NUL-terminated copy of the last recorded error
/// message (an empty string if none has been recorded). The allocation comes
/// from [`CString::into_raw`]; `ArrowArrayStreamReader` in this module reclaims
/// it with `CString::from_raw`.
///
/// NOTE(review): consumers outside this module will not free this string, so
/// each call from such a consumer leaks one allocation — confirm whether the
/// message should instead be cached in the stream's private data.
///
/// Error messages containing an interior NUL byte cannot be represented as a
/// C string. Instead of panicking (a panic cannot unwind out of an
/// `extern "C"` function), the message is truncated at the first NUL.
unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
    let mut ffi_stream = ExportedArrayStream { stream };
    let last_error = ffi_stream.get_last_error();
    let message = CString::new(last_error.as_str()).unwrap_or_else(|err| {
        let nul_position = err.nul_position();
        let mut bytes = err.into_vec();
        bytes.truncate(nul_position);
        // `bytes` now ends right before the first NUL, so it contains none.
        CString::new(bytes).expect("bytes were truncated before the first NUL")
    });
    message.into_raw()
}
impl Drop for FFI_ArrowArrayStream {
    /// Runs the producer's `release` callback unless the stream is already
    /// released (`release == None`).
    fn drop(&mut self) {
        if let Some(release) = self.release {
            // SAFETY: a non-`None` `release` means the stream has not been
            // released yet, and `self` points to the live stream struct.
            unsafe { release(self) }
        }
    }
}
impl FFI_ArrowArrayStream {
    /// Exports `batch_reader` as a C stream backed by the callbacks in this module.
    ///
    /// The reader is moved into a heap-allocated [`StreamPrivateData`] stored in
    /// `private_data`. It is freed by the `release` callback, which runs
    /// automatically when the stream is dropped.
    pub fn new(batch_reader: Box<dyn RecordBatchReader>) -> Self {
        let state = StreamPrivateData {
            batch_reader,
            last_error: String::new(),
        };
        let private_data = Box::into_raw(Box::new(state)) as *mut c_void;

        Self {
            get_schema: Some(get_schema),
            get_next: Some(get_next),
            get_last_error: Some(get_last_error),
            release: Some(release_stream),
            private_data,
        }
    }

    /// Creates a released stream: every callback is `None` and `private_data`
    /// is null. Useful as a placeholder for a producer to overwrite.
    pub fn empty() -> Self {
        Self {
            private_data: std::ptr::null_mut(),
            release: None,
            get_last_error: None,
            get_next: None,
            get_schema: None,
        }
    }
}
/// Thin wrapper around a raw exported-stream pointer.
///
/// The `extern "C"` callbacks use it to reach the stream's [`StreamPrivateData`].
struct ExportedArrayStream {
stream: *mut FFI_ArrowArrayStream,
}
impl ExportedArrayStream {
fn get_private_data(&mut self) -> &mut StreamPrivateData {
unsafe { &mut *((*self.stream).private_data as *mut StreamPrivateData) }
}
pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
let mut private_data = self.get_private_data();
let reader = &private_data.batch_reader;
let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref());
match schema {
Ok(schema) => {
unsafe { std::ptr::copy(addr_of!(schema), out, 1) };
std::mem::forget(schema);
0
}
Err(ref err) => {
private_data.last_error = err.to_string();
get_error_code(err)
}
}
}
pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
let mut private_data = self.get_private_data();
let reader = &mut private_data.batch_reader;
match reader.next() {
None => {
unsafe { std::ptr::write(out, FFI_ArrowArray::empty()) }
0
}
Some(next_batch) => {
if let Ok(batch) = next_batch {
let struct_array = StructArray::from(batch);
let array = FFI_ArrowArray::new(&struct_array.to_data());
unsafe { std::ptr::copy(addr_of!(array), out, 1) };
std::mem::forget(array);
0
} else {
let err = &next_batch.unwrap_err();
private_data.last_error = err.to_string();
get_error_code(err)
}
}
}
}
pub fn get_last_error(&mut self) -> &String {
&self.get_private_data().last_error
}
}
/// Maps an [`ArrowError`] to the errno-style code reported through the C
/// stream interface. Variants without a closer match become `EINVAL`.
fn get_error_code(err: &ArrowError) -> i32 {
    if matches!(err, ArrowError::NotYetImplemented(_)) {
        ENOSYS
    } else if matches!(err, ArrowError::MemoryError(_)) {
        ENOMEM
    } else if matches!(err, ArrowError::IoError(_)) {
        EIO
    } else {
        EINVAL
    }
}
/// Imports a C stream ([`FFI_ArrowArrayStream`]) and exposes it as a
/// [`RecordBatchReader`].
///
/// The schema is fetched once when the reader is created and cached.
///
/// NOTE(review): `Clone` shares the same underlying stream through the `Arc`,
/// so iterating two clones pulls batches from one producer — confirm intended.
#[derive(Debug, Clone)]
pub struct ArrowArrayStreamReader {
stream: Arc<FFI_ArrowArrayStream>,
schema: SchemaRef,
}
/// Calls the stream's `get_schema` callback and imports the resulting schema.
///
/// `stream_ptr` must point to a live stream.
///
/// # Errors
/// Returns [`ArrowError::CDataInterface`] in two cases:
/// - the stream has no `get_schema` callback;
/// - the callback reports a non-zero code.
///
/// Errors from converting the exported schema into a [`Schema`] are propagated
/// instead of panicking.
fn get_stream_schema(stream_ptr: *mut FFI_ArrowArrayStream) -> Result<SchemaRef> {
    // SAFETY: callers pass a pointer to a live stream.
    let get_schema_fn = unsafe { (*stream_ptr).get_schema }.ok_or_else(|| {
        ArrowError::CDataInterface(
            "Cannot get schema from input stream: get_schema callback is missing"
                .to_string(),
        )
    })?;

    // Out-parameter for the producer. If the producer fills it in, dropping
    // `ffi_schema` at the end of this function runs its `release`.
    let mut ffi_schema = FFI_ArrowSchema::empty();
    let ret_code = unsafe { get_schema_fn(stream_ptr, &mut ffi_schema) };
    if ret_code != 0 {
        return Err(ArrowError::CDataInterface(format!(
            "Cannot get schema from input stream. Error code: {ret_code:?}"
        )));
    }

    let schema = Schema::try_from(&ffi_schema)?;
    Ok(Arc::new(schema))
}
impl ArrowArrayStreamReader {
    /// Takes ownership of `stream` and fetches its schema.
    ///
    /// # Errors
    /// Fails if `stream` is already released or its schema cannot be obtained.
    /// In the latter case `stream` is dropped, which runs its `release` callback.
    #[allow(dead_code)]
    pub fn try_new(stream: FFI_ArrowArrayStream) -> Result<Self> {
        if stream.release.is_none() {
            return Err(ArrowError::CDataInterface(
                "input stream is already released".to_string(),
            ));
        }
        // Keep the stream owned by the Arc for the whole call, so an early
        // return from `get_stream_schema` drops (and releases) it instead of
        // leaking it.
        let stream = Arc::new(stream);
        let stream_ptr = Arc::as_ptr(&stream) as *mut FFI_ArrowArrayStream;
        let schema = get_stream_schema(stream_ptr)?;
        Ok(Self { stream, schema })
    }

    /// Builds a reader from a raw pointer to a C stream, taking ownership of it.
    ///
    /// The pointed-to struct is replaced by [`FFI_ArrowArrayStream::empty`]
    /// (a released stream), so dropping the caller's copy does not release the
    /// stream a second time.
    ///
    /// # Safety
    /// `raw_stream` must be non-null, properly aligned, and point to a valid,
    /// initialized `FFI_ArrowArrayStream`.
    pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Result<Self> {
        let stream_data = std::ptr::replace(raw_stream, FFI_ArrowArrayStream::empty());
        Self::try_new(stream_data)
    }

    /// Fetches the producer's last error message.
    ///
    /// Returns `None` in two cases:
    /// - the stream has no `get_last_error` callback;
    /// - the callback returned a null pointer.
    ///
    /// NOTE(review): the returned string is reclaimed with `CString::from_raw`.
    /// That matches this module's exporter, which hands out
    /// `CString::into_raw` allocations. For producers that keep ownership of
    /// the string, this frees memory the producer still owns — confirm which
    /// producers this reader must support.
    fn get_stream_last_error(&self) -> Option<String> {
        let get_last_error_fn = self.stream.get_last_error?;
        let stream_ptr = Arc::as_ptr(&self.stream) as *mut FFI_ArrowArrayStream;
        let c_str = unsafe { get_last_error_fn(stream_ptr) } as *mut c_char;
        if c_str.is_null() {
            // `CString::from_raw` on a null pointer is undefined behavior.
            return None;
        }
        let message = unsafe { CString::from_raw(c_str) }
            .into_string()
            .unwrap_or_else(|err| err.to_string());
        Some(message)
    }
}
impl Iterator for ArrowArrayStreamReader {
    type Item = Result<RecordBatch>;

    /// Pulls the next batch from the producer.
    ///
    /// Returns:
    /// - `None` at end of stream (the producer writes a released array);
    /// - `Some(Err(..))` if the producer reports an error or the exported
    ///   array cannot be imported;
    /// - `Some(Ok(batch))` otherwise.
    fn next(&mut self) -> Option<Self::Item> {
        let get_next_fn = match self.stream.get_next {
            Some(callback) => callback,
            None => {
                return Some(Err(ArrowError::CDataInterface(
                    "input stream has no get_next callback".to_string(),
                )));
            }
        };
        let stream_ptr = Arc::as_ptr(&self.stream) as *mut FFI_ArrowArrayStream;

        // Out-parameter for the producer; dropped (and released, if filled)
        // on every path that does not hand it on to `ArrowArray`.
        let mut ffi_array = FFI_ArrowArray::empty();
        let ret_code = unsafe { get_next_fn(stream_ptr, &mut ffi_array) };
        if ret_code != 0 {
            // Fall back to the raw code when the producer offers no message.
            let message = self.get_stream_last_error().unwrap_or_else(|| {
                format!("Cannot get next batch from input stream. Error code: {ret_code:?}")
            });
            return Some(Err(ArrowError::CDataInterface(message)));
        }
        if ffi_array.is_released() {
            return None;
        }

        // Import failures are reported as errors rather than ending iteration.
        let ffi_schema = match FFI_ArrowSchema::try_from(self.schema.as_ref()) {
            Ok(ffi_schema) => ffi_schema,
            Err(err) => return Some(Err(err)),
        };
        let data = ArrowArray {
            array: Arc::new(ffi_array),
            schema: Arc::new(ffi_schema),
        }
        .to_data();
        Some(data.map(|data| RecordBatch::from(StructArray::from(data))))
    }
}
impl RecordBatchReader for ArrowArrayStreamReader {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
/// Exports `reader` as an Arrow C stream and writes it to `out_stream`.
///
/// Whatever `*out_stream` held before is overwritten without being dropped or
/// released.
///
/// # Safety
/// `out_stream` must be non-null and valid for writes of an
/// `FFI_ArrowArrayStream`. It does not need to be aligned.
pub unsafe fn export_reader_into_raw(
    reader: Box<dyn RecordBatchReader>,
    out_stream: *mut FFI_ArrowArrayStream,
) {
    std::ptr::write_unaligned(out_stream, FFI_ArrowArrayStream::new(reader));
}
// Round-trip tests for exporting a `RecordBatchReader` through the C stream
// callbacks and importing it back with `ArrowArrayStreamReader`.
#[cfg(test)]
mod tests {
use super::*;
use crate::array::Int32Array;
use crate::datatypes::{Field, Schema};
// Minimal in-memory `RecordBatchReader` that replays a fixed iterator.
struct TestRecordBatchReader {
schema: SchemaRef,
iter: Box<dyn Iterator<Item = Result<RecordBatch>>>,
}
impl TestRecordBatchReader {
pub fn new(
schema: SchemaRef,
iter: Box<dyn Iterator<Item = Result<RecordBatch>>>,
) -> Box<TestRecordBatchReader> {
Box::new(TestRecordBatchReader { schema, iter })
}
}
impl Iterator for TestRecordBatchReader {
type Item = Result<RecordBatch>;
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
impl RecordBatchReader for TestRecordBatchReader {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
// Exports a two-batch reader via `export_reader_into_raw`, then drives the raw
// `get_schema`/`get_next` callbacks directly and checks the schema and batches
// that come out. Expects exactly three input arrays (columns a, b, c).
fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", arrays[0].data_type().clone(), true),
Field::new("b", arrays[1].data_type().clone(), true),
Field::new("c", arrays[2].data_type().clone(), true),
]));
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
let reader = TestRecordBatchReader::new(schema.clone(), iter);
// The stream is leaked to a raw pointer here and reclaimed at the end of the test.
let stream = Arc::new(FFI_ArrowArrayStream::empty());
let stream_ptr = Arc::into_raw(stream) as *mut FFI_ArrowArrayStream;
unsafe { export_reader_into_raw(reader, stream_ptr) };
let empty_schema = Arc::new(FFI_ArrowSchema::empty());
let schema_ptr = Arc::into_raw(empty_schema) as *mut FFI_ArrowSchema;
let ret_code = unsafe { get_schema(stream_ptr, schema_ptr) };
assert_eq!(ret_code, 0);
let ffi_schema = unsafe { Arc::from_raw(schema_ptr) };
let exported_schema = Schema::try_from(ffi_schema.as_ref()).unwrap();
assert_eq!(&exported_schema, schema.as_ref());
let mut produced_batches = vec![];
// Pull batches until the exporter signals end of stream with a released array.
loop {
let empty_array = Arc::new(FFI_ArrowArray::empty());
// NOTE(review): this clones the Arc, so `empty_array` and `ffi_array` share the
// allocation the callee writes into; passing `empty_array` itself would avoid
// the shared alias.
let array_ptr = Arc::into_raw(empty_array.clone()) as *mut FFI_ArrowArray;
let ret_code = unsafe { get_next(stream_ptr, array_ptr) };
assert_eq!(ret_code, 0);
let ffi_array = unsafe { Arc::from_raw(array_ptr) };
if ffi_array.is_released() {
break;
}
let array = ArrowArray {
array: ffi_array,
schema: ffi_schema.clone(),
}
.to_data()
.unwrap();
let record_batch = RecordBatch::from(StructArray::from(array));
produced_batches.push(record_batch);
}
assert_eq!(produced_batches, vec![batch.clone(), batch]);
// Reclaim the leaked stream; dropping it runs the `release` callback.
unsafe { Arc::from_raw(stream_ptr) };
Ok(())
}
// Builds a stream with `FFI_ArrowArrayStream::new`, imports it with
// `ArrowArrayStreamReader::from_raw`, and checks the schema and batches that
// the reader yields. Expects exactly three input arrays (columns a, b, c).
fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", arrays[0].data_type().clone(), true),
Field::new("b", arrays[1].data_type().clone(), true),
Field::new("c", arrays[2].data_type().clone(), true),
]));
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
let reader = TestRecordBatchReader::new(schema.clone(), iter);
let stream = Arc::new(FFI_ArrowArrayStream::new(reader));
let stream_ptr = Arc::into_raw(stream) as *mut FFI_ArrowArrayStream;
let stream_reader =
unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() };
let imported_schema = stream_reader.schema();
assert_eq!(imported_schema, schema);
let mut produced_batches = vec![];
for batch in stream_reader {
produced_batches.push(batch.unwrap());
}
assert_eq!(produced_batches, vec![batch.clone(), batch]);
// `from_raw` left a released (empty) stream behind; this only frees the Arc allocation.
unsafe { Arc::from_raw(stream_ptr) };
Ok(())
}
#[test]
fn test_stream_round_trip_export() -> Result<()> {
let array = Int32Array::from(vec![Some(2), None, Some(1), None]);
let array: Arc<dyn Array> = Arc::new(array);
_test_round_trip_export(vec![array.clone(), array.clone(), array])
}
#[test]
fn test_stream_round_trip_import() -> Result<()> {
let array = Int32Array::from(vec![Some(2), None, Some(1), None]);
let array: Arc<dyn Array> = Arc::new(array);
_test_round_trip_import(vec![array.clone(), array.clone(), array])
}
}