use std::convert::{From, TryFrom};
use std::ffi::CStr;
use std::sync::Arc;
use arrow_array::ffi;
use arrow_array::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use arrow_array::{
RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray,
make_array,
};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::sync::PyOnceLock;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
// Bind pyarrow's `ArrowException` Python class so Rust code can raise it.
import_exception!(pyarrow, ArrowException);
/// Python exception type (pyarrow's `ArrowException`) used to report Arrow errors to Python.
pub type PyArrowException = ArrowException;
/// PyCapsule name expected on capsules returned by `__arrow_c_stream__`.
const ARROW_ARRAY_STREAM_CAPSULE_NAME: &CStr = c"arrow_array_stream";
/// PyCapsule name expected on schema capsules (`__arrow_c_schema__` / `__arrow_c_array__`).
const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c"arrow_schema";
/// PyCapsule name expected on array capsules returned by `__arrow_c_array__`.
const ARROW_ARRAY_CAPSULE_NAME: &CStr = c"arrow_array";
/// Converts an [`ArrowError`] into a Python `pyarrow.ArrowException` whose message is the
/// error's `Display` text.
fn to_py_err(err: ArrowError) -> PyErr {
    let message = err.to_string();
    PyArrowException::new_err(message)
}
/// Conversion from a Python object into an arrow-rs value.
pub trait FromPyArrow: Sized {
    /// Builds `Self` from the GIL-bound Python object `value`.
    ///
    /// Failures are reported as a [`PyErr`] so they surface as Python exceptions.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self>;
}
/// Conversion from a borrowed arrow-rs value into a Python (pyarrow) object.
pub trait ToPyArrow {
    /// Creates a new Python object representing `self`, bound to the GIL token `py`.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
/// Conversion from an owned arrow-rs value into a Python (pyarrow) object.
///
/// Every [`ToPyArrow`] type gets this via a blanket impl; types that can only be exported by
/// value (e.g. record batch readers) implement it directly.
pub trait IntoPyArrow {
    /// Consumes `self` and creates the corresponding Python object bound to `py`.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
}
/// Anything convertible by reference is also convertible by value; this forwards to
/// [`ToPyArrow::to_pyarrow`] on the owned value.
impl<T: ToPyArrow> IntoPyArrow for T {
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        <T as ToPyArrow>::to_pyarrow(&self, py)
    }
}
/// Succeeds when `value` is an instance of `expected`.
///
/// # Errors
/// Raises `TypeError` naming the expected class and the class actually found (each formatted
/// as `module.Name`). Propagates errors from the instance check or from reading
/// `__module__`/`__name__`.
fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> PyResult<()> {
    /// Reads a class's `__module__` and `__name__`, in that order.
    fn module_and_name<'py>(
        ty: &Bound<'py, PyType>,
    ) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
        Ok((ty.getattr("__module__")?, ty.getattr("__name__")?))
    }

    if value.is_instance(expected)? {
        return Ok(());
    }
    let (expected_module, expected_name) = module_and_name(expected)?;
    let (found_module, found_name) = module_and_name(&value.get_type())?;
    Err(PyTypeError::new_err(format!(
        "Expected instance of {expected_module}.{expected_name}, got {found_module}.{found_name}",
    )))
}
/// Checks that `capsule` carries the PyCapsule name `name` (e.g. `"arrow_schema"`), as the
/// Arrow PyCapsule Interface requires.
///
/// # Errors
/// Raises `ValueError` when the capsule has no name or when its name differs from `name`.
/// Propagates any error raised while reading the capsule name.
fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
    let Some(capsule_name) = capsule.name()? else {
        // Name the capsule kind being validated: this helper also checks array and stream
        // capsules, not only schema capsules.
        return Err(PyValueError::new_err(format!(
            "Expected {name} PyCapsule to have name set."
        )));
    };
    // SAFETY (assumed contract of `as_cstr`): the returned `CStr` is only read within this
    // function, while `capsule` is still borrowed and its name is not modified.
    let capsule_name = unsafe { capsule_name.as_cstr() };
    // Compare raw bytes, so a non-UTF-8 name is reported as a mismatch instead of a decode error.
    if capsule_name.to_bytes() != name.as_bytes() {
        return Err(PyValueError::new_err(format!(
            "Expected name '{name}' in PyCapsule, instead got '{}'",
            capsule_name.to_string_lossy()
        )));
    }
    Ok(())
}
/// Calls `value.__arrow_c_array__()` and unpacks the result into its `(schema, array)`
/// PyCapsule pair.
///
/// # Errors
/// Propagates any error raised by the method call. Raises `TypeError` if the result is not a
/// tuple, or is a tuple that does not extract as two capsules.
fn extract_arrow_c_array_capsules<'py>(
    value: &Bound<'py, PyAny>,
) -> PyResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
    const MESSAGE: &str =
        "Expected __arrow_c_array__ to return a tuple of (schema, array) capsules.";
    let returned = value.call_method0("__arrow_c_array__")?;
    let not_a_capsule_pair = || PyTypeError::new_err(MESSAGE);
    if returned.is_instance_of::<PyTuple>() {
        returned.extract().map_err(|_| not_a_capsule_pair())
    } else {
        Err(not_a_capsule_pair())
    }
}
impl FromPyArrow for DataType {
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
if value.hasattr("__arrow_c_schema__")? {
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
validate_pycapsule(&capsule, "arrow_schema")?;
let schema_ptr = capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
return unsafe { DataType::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
}
validate_class(data_type_class(value.py())?, value)?;
let mut c_schema = FFI_ArrowSchema::empty();
value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
DataType::try_from(&c_schema).map_err(to_py_err)
}
}
impl ToPyArrow for DataType {
    /// Exports this [`DataType`] through the Arrow C Data Interface and imports it with
    /// `pyarrow.DataType._import_from_c`.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // `mut` is required. Under the Arrow C Data Interface the importer takes ownership of the
        // exported struct and marks it released by writing to it. Writing through a pointer to
        // an immutable binding is undefined behavior in Rust.
        let mut c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
        let c_schema_ptr = &raw mut c_schema as Py_uintptr_t;
        data_type_class(py)?.call_method1("_import_from_c", (c_schema_ptr,))
    }
}
impl FromPyArrow for Field {
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
if value.hasattr("__arrow_c_schema__")? {
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
validate_pycapsule(&capsule, "arrow_schema")?;
let schema_ptr = capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
return unsafe { Field::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
}
validate_class(field_class(value.py())?, value)?;
let mut c_schema = FFI_ArrowSchema::empty();
value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
Field::try_from(&c_schema).map_err(to_py_err)
}
}
impl ToPyArrow for Field {
    /// Exports this [`Field`] through the Arrow C Data Interface and imports it with
    /// `pyarrow.Field._import_from_c`.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // `mut` is required. The importer takes ownership of the exported struct and marks it
        // released by writing to it. Writing through a pointer to an immutable binding is
        // undefined behavior in Rust.
        let mut c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
        let c_schema_ptr = &raw mut c_schema as Py_uintptr_t;
        field_class(py)?.call_method1("_import_from_c", (c_schema_ptr,))
    }
}
impl FromPyArrow for Schema {
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
if value.hasattr("__arrow_c_schema__")? {
let capsule = value.call_method0("__arrow_c_schema__")?.extract()?;
validate_pycapsule(&capsule, "arrow_schema")?;
let schema_ptr = capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
return unsafe { Schema::try_from(schema_ptr.as_ref()) }.map_err(to_py_err);
}
validate_class(schema_class(value.py())?, value)?;
let mut c_schema = FFI_ArrowSchema::empty();
value.call_method1("_export_to_c", (&raw mut c_schema as Py_uintptr_t,))?;
Schema::try_from(&c_schema).map_err(to_py_err)
}
}
impl ToPyArrow for Schema {
    /// Exports this [`Schema`] through the Arrow C Data Interface and imports it with
    /// `pyarrow.Schema._import_from_c`.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // `mut` is required. The importer takes ownership of the exported struct and marks it
        // released by writing to it. Writing through a pointer to an immutable binding is
        // undefined behavior in Rust.
        let mut c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
        let c_schema_ptr = &raw mut c_schema as Py_uintptr_t;
        schema_class(py)?.call_method1("_import_from_c", (c_schema_ptr,))
    }
}
impl FromPyArrow for ArrayData {
    /// Imports [`ArrayData`] from a Python object.
    ///
    /// Objects exposing `__arrow_c_array__` (Arrow PyCapsule Interface) are imported from the
    /// returned `(schema, array)` capsule pair. Anything else must be a `pyarrow.Array`, which
    /// is exported through its `_export_to_c(array_ptr, schema_ptr)` method. Arrow import
    /// failures are raised as `pyarrow.ArrowException`.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        if value.hasattr("__arrow_c_array__")? {
            let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
            validate_pycapsule(&schema_capsule, "arrow_schema")?;
            validate_pycapsule(&array_capsule, "arrow_array")?;
            // The schema is only borrowed (`as_ref` below) and stays owned by
            // `schema_capsule`, which is alive until the end of this branch.
            let schema_ptr = schema_capsule
                .pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
                .cast::<FFI_ArrowSchema>();
            // SAFETY: `pointer_checked` confirmed the capsule name, so the pointer refers to the
            // exported `ArrowArray`. `from_raw` takes ownership of it. Per its documentation it
            // moves the struct out of the capsule, so the capsule does not release it again.
            let array = unsafe {
                FFI_ArrowArray::from_raw(
                    array_capsule
                        .pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
                        .cast::<FFI_ArrowArray>()
                        .as_ptr(),
                )
            };
            // SAFETY: `array` and the schema come from the same `__arrow_c_array__` export.
            return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) }.map_err(to_py_err);
        }
        validate_class(array_class(value.py())?, value)?;
        // Legacy pyarrow path: pyarrow fills these empty structs through raw pointers, and
        // `from_ffi` then takes ownership of the exported array.
        let mut array = FFI_ArrowArray::empty();
        let mut schema = FFI_ArrowSchema::empty();
        value.call_method1(
            "_export_to_c",
            (
                &raw mut array as Py_uintptr_t,
                &raw mut schema as Py_uintptr_t,
            ),
        )?;
        unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
    }
}
impl ToPyArrow for ArrayData {
    /// Exports this [`ArrayData`], plus a schema derived from its data type, through the Arrow
    /// C Data Interface and imports them with `pyarrow.Array._import_from_c`.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // Both bindings must be `mut`. The importer takes ownership of the exported structs and
        // marks them released by writing to them. Writing through a pointer to an immutable
        // binding is undefined behavior in Rust.
        let mut array = FFI_ArrowArray::new(self);
        let mut schema = FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;
        array_class(py)?.call_method1(
            "_import_from_c",
            (
                &raw mut array as Py_uintptr_t,
                &raw mut schema as Py_uintptr_t,
            ),
        )
    }
}
impl<T: FromPyArrow> FromPyArrow for Vec<T> {
    /// Converts a Python `list` element by element, stopping at the first element that fails.
    ///
    /// Non-list inputs are rejected with the cast error.
    fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
        let items = value.cast::<PyList>()?;
        let mut converted = Vec::new();
        for item in items.iter() {
            converted.push(T::from_pyarrow_bound(&item)?);
        }
        Ok(converted)
    }
}
impl<T: ToPyArrow> ToPyArrow for Vec<T> {
    /// Converts every element and returns them as a Python `list`, failing on the first error.
    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let mut converted = Vec::new();
        for element in self {
            converted.push(element.to_pyarrow(py)?);
        }
        let list = PyList::new(py, converted)?;
        Ok(list.into_any())
    }
}
impl FromPyArrow for RecordBatch {
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
if value.hasattr("__arrow_c_array__")? {
let (schema_capsule, array_capsule) = extract_arrow_c_array_capsules(value)?;
validate_pycapsule(&schema_capsule, "arrow_schema")?;
validate_pycapsule(&array_capsule, "arrow_array")?;
let schema_ptr = schema_capsule
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
.cast::<FFI_ArrowSchema>();
let array_ptr = array_capsule
.pointer_checked(Some(ARROW_ARRAY_CAPSULE_NAME))?
.cast::<FFI_ArrowArray>();
let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_ptr.as_ptr()) };
let mut array_data =
unsafe { ffi::from_ffi(ffi_array, schema_ptr.as_ref()) }.map_err(to_py_err)?;
if !matches!(array_data.data_type(), DataType::Struct(_)) {
return Err(PyTypeError::new_err(
"Expected Struct type from __arrow_c_array.",
));
}
let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
array_data.align_buffers();
let array = StructArray::from(array_data);
let schema =
unsafe { Arc::new(Schema::try_from(schema_ptr.as_ref()).map_err(to_py_err)?) };
let (_fields, columns, nulls) = array.into_parts();
assert_eq!(
nulls.map(|n| n.null_count()).unwrap_or_default(),
0,
"Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
);
return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
}
validate_class(record_batch_class(value.py())?, value)?;
let schema = value.getattr("schema")?;
let schema = Arc::new(Schema::from_pyarrow_bound(&schema)?);
let arrays = value.getattr("columns")?;
let arrays = arrays
.cast::<PyList>()?
.iter()
.map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
.collect::<PyResult<_>>()?;
let row_count = value
.getattr("num_rows")
.ok()
.and_then(|x| x.extract().ok());
let options = RecordBatchOptions::default().with_row_count(row_count);
let batch =
RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
Ok(batch)
}
}
impl ToPyArrow for RecordBatch {
fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let reader = RecordBatchIterator::new(vec![Ok(self.clone())], self.schema());
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
let py_reader = reader.into_pyarrow(py)?;
py_reader.call_method0("read_next_batch")
}
}
impl FromPyArrow for ArrowArrayStreamReader {
fn from_pyarrow_bound(value: &Bound<PyAny>) -> PyResult<Self> {
if value.hasattr("__arrow_c_stream__")? {
let capsule = value.call_method0("__arrow_c_stream__")?.extract()?;
validate_pycapsule(&capsule, "arrow_array_stream")?;
let stream = unsafe {
FFI_ArrowArrayStream::from_raw(
capsule
.pointer_checked(Some(ARROW_ARRAY_STREAM_CAPSULE_NAME))?
.cast::<FFI_ArrowArrayStream>()
.as_ptr(),
)
};
let stream_reader = ArrowArrayStreamReader::try_new(stream)
.map_err(|err| PyValueError::new_err(err.to_string()))?;
return Ok(stream_reader);
}
validate_class(record_batch_reader_class(value.py())?, value)?;
let mut stream = FFI_ArrowArrayStream::empty();
let args = PyTuple::new(value.py(), [&raw mut stream as Py_uintptr_t])?;
value.call_method1("_export_to_c", args)?;
ArrowArrayStreamReader::try_new(stream)
.map_err(|err| PyValueError::new_err(err.to_string()))
}
}
impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
    /// Exports the reader as an Arrow C stream and imports it with
    /// `pyarrow.RecordBatchReader._import_from_c`.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        // `mut` is required. The C Stream Interface importer moves the stream out of this struct
        // and marks it released by writing to it. Writing through a pointer to an immutable
        // binding is undefined behavior in Rust.
        let mut stream = FFI_ArrowArrayStream::new(self);
        let stream_ptr = &raw mut stream as Py_uintptr_t;
        record_batch_reader_class(py)?.call_method1("_import_from_c", (stream_ptr,))
    }
}
impl IntoPyArrow for ArrowArrayStreamReader {
    /// Exports the reader to Python by boxing it as a `dyn RecordBatchReader + Send` and
    /// delegating to that implementation.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        <Box<dyn RecordBatchReader + Send> as IntoPyArrow>::into_pyarrow(Box::new(self), py)
    }
}
/// An in-memory sequence of [`RecordBatch`]es that all share a single schema.
///
/// Build it with [`Table::try_new`], which rejects batches whose schema differs. It converts
/// to a Python `pyarrow.Table` via `Table.from_batches`, and is read from Python through the
/// Arrow stream interface.
#[derive(Clone)]
pub struct Table {
    // Batches in their original order; each has schema equal to `schema` (checked in `try_new`).
    record_batches: Vec<RecordBatch>,
    // Schema shared by every batch, and still known when `record_batches` is empty.
    schema: SchemaRef,
}
impl Table {
pub fn try_new(
record_batches: Vec<RecordBatch>,
schema: SchemaRef,
) -> Result<Self, ArrowError> {
for record_batch in &record_batches {
if schema != record_batch.schema() {
return Err(ArrowError::SchemaError(format!(
"All record batches must have the same schema. \
Expected schema: {:?}, got schema: {:?}",
schema,
record_batch.schema()
)));
}
}
Ok(Self {
record_batches,
schema,
})
}
pub fn record_batches(&self) -> &[RecordBatch] {
&self.record_batches
}
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}
pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
(self.record_batches, self.schema)
}
}
impl TryFrom<Box<dyn RecordBatchReader>> for Table {
    type Error = ArrowError;

    /// Drains `value` into memory and builds a [`Table`] using the reader's reported schema.
    ///
    /// Fails with the first error yielded by the reader, or with the schema-mismatch error
    /// from [`Table::try_new`].
    fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
        let schema = value.schema();
        let mut batches = Vec::new();
        for batch in value {
            batches.push(batch?);
        }
        Self::try_new(batches, schema)
    }
}
impl FromPyArrow for Table {
fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
let reader: Box<dyn RecordBatchReader> =
Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
Self::try_from(reader).map_err(|err| PyValueError::new_err(err.to_string()))
}
}
impl IntoPyArrow for Table {
    /// Builds a `pyarrow.Table` by calling `pyarrow.Table.from_batches(batches, schema=schema)`.
    fn into_pyarrow<'py>(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let (record_batches, schema) = self.into_inner();
        let batches = PyList::new(py, record_batches.into_iter().map(PyArrowType))?;
        let kwargs = PyDict::new(py);
        kwargs.set_item("schema", PyArrowType(Arc::unwrap_or_clone(schema)))?;
        table_class(py)?.call_method("from_batches", (batches,), Some(&kwargs))
    }
}
fn array_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
TYPE.import(py, "pyarrow", "Array")
}
fn record_batch_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
TYPE.import(py, "pyarrow", "RecordBatch")
}
fn record_batch_reader_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
TYPE.import(py, "pyarrow", "RecordBatchReader")
}
fn data_type_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
TYPE.import(py, "pyarrow", "DataType")
}
fn field_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
TYPE.import(py, "pyarrow", "Field")
}
fn schema_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
TYPE.import(py, "pyarrow", "Schema")
}
fn table_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
TYPE.import(py, "pyarrow", "Table")
}
/// Newtype wrapper that lets any [`FromPyArrow`] / [`IntoPyArrow`] type be used directly as a
/// PyO3 argument (via `FromPyObject`) or return value (via `IntoPyObject`).
///
/// The wrapped value is public; construct it with `PyArrowType(value)` or `value.into()`.
#[derive(Debug)]
pub struct PyArrowType<T>(pub T);
impl<T: FromPyArrow> FromPyObject<'_, '_> for PyArrowType<T> {
    type Error = PyErr;

    /// Delegates extraction to [`FromPyArrow::from_pyarrow_bound`] and wraps the result.
    fn extract(value: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
        T::from_pyarrow_bound(&value).map(Self)
    }
}
impl<'py, T: IntoPyArrow> IntoPyObject<'py> for PyArrowType<T> {
type Target = PyAny;
type Output = Bound<'py, Self::Target>;
type Error = PyErr;
fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
self.0.into_pyarrow(py)
}
}
impl<T> From<T> for PyArrowType<T> {
fn from(s: T) -> Self {
Self(s)
}
}