use std::ffi::CStr;
use std::os::raw;
use std::os::raw::c_int;
use std::ptr::NonNull;
use std::sync::Arc;
use arrow_array::builder::BooleanBuilder;
use arrow_array::{
ArrayRef, FixedSizeListArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow_buffer::{Buffer, ScalarBuffer};
use arrow_schema::Field;
use pyo3::buffer::{ElementType, PyBuffer};
use pyo3::exceptions::PyValueError;
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use crate::error::{PyArrowError, PyArrowResult};
use crate::PyArray;
// Newtype around an Arrow `Buffer` that is exported to Python as the class `Buffer`
// in module `arro3.core._core`; the class is declared subclassable and frozen.
// NOTE(review): a `///` doc comment here would presumably surface as the Python class
// docstring, so a plain comment is used to leave the Python-visible class unchanged.
#[pyclass(module = "arro3.core._core", name = "Buffer", subclass, frozen)]
pub struct PyArrowBuffer(Buffer);
impl AsRef<Buffer> for PyArrowBuffer {
fn as_ref(&self) -> &Buffer {
&self.0
}
}
impl AsRef<[u8]> for PyArrowBuffer {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl PyArrowBuffer {
pub fn new(buffer: Buffer) -> Self {
Self(buffer)
}
pub fn into_inner(self) -> Buffer {
self.0
}
}
#[pymethods]
impl PyArrowBuffer {
#[new]
fn py_new(buf: PyArrowBuffer) -> Self {
buf
}
fn to_bytes<'py>(&'py self, py: Python<'py>) -> Bound<'py, PyBytes> {
PyBytes::new(py, &self.0)
}
fn __len__(&self) -> usize {
self.0.len()
}
unsafe fn __getbuffer__(
slf: PyRef<Self>,
view: *mut ffi::Py_buffer,
flags: c_int,
) -> PyResult<()> {
let bytes = slf.0.as_slice();
let ret = ffi::PyBuffer_FillInfo(
view,
slf.as_ptr() as *mut _,
bytes.as_ptr() as *mut _,
bytes.len().try_into().unwrap(),
1, flags,
);
if ret == -1 {
return Err(PyErr::fetch(slf.py()));
}
Ok(())
}
unsafe fn __releasebuffer__(&self, _view: *mut ffi::Py_buffer) {}
}
/// Build a [`PyArrowBuffer`] from any Python object exposing a `u8` buffer.
///
/// The object is first extracted as an [`AnyBufferProtocol`]; anything that does not
/// come back as the `UInt8` variant is rejected with a `ValueError`.
impl<'py> FromPyObject<'_, 'py> for PyArrowBuffer {
    type Error = PyErr;

    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
        match obj.extract::<AnyBufferProtocol>()? {
            protocol @ AnyBufferProtocol::UInt8(_) => {
                let arrow_buffer = protocol.into_arrow_buffer()?;
                Ok(Self(arrow_buffer))
            }
            _ => Err(PyValueError::new_err("Expected u8 buffer protocol object")),
        }
    }
}
/// A Python buffer-protocol view, tagged by the Rust element type it was extracted as.
///
/// Each variant owns the typed `PyBuffer` for one supported element type.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AnyBufferProtocol {
/// Buffer extracted as `u8` elements.
UInt8(PyBuffer<u8>),
/// Buffer extracted as `u16` elements.
UInt16(PyBuffer<u16>),
/// Buffer extracted as `u32` elements.
UInt32(PyBuffer<u32>),
/// Buffer extracted as `u64` elements.
UInt64(PyBuffer<u64>),
/// Buffer extracted as `i8` elements.
Int8(PyBuffer<i8>),
/// Buffer extracted as `i16` elements.
Int16(PyBuffer<i16>),
/// Buffer extracted as `i32` elements.
Int32(PyBuffer<i32>),
/// Buffer extracted as `i64` elements.
Int64(PyBuffer<i64>),
/// Buffer extracted as `f32` elements.
Float32(PyBuffer<f32>),
/// Buffer extracted as `f64` elements.
Float64(PyBuffer<f64>),
}
impl<'py> FromPyObject<'_, 'py> for AnyBufferProtocol {
type Error = PyErr;
fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
if let Ok(buf) = obj.extract() {
Ok(Self::UInt8(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::UInt16(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::UInt32(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::UInt64(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::Int8(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::Int16(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::Int32(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::Int64(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::Float32(buf))
} else if let Ok(buf) = obj.extract() {
Ok(Self::Float64(buf))
} else {
Err(PyValueError::new_err("Not a buffer protocol object"))
}
}
}
impl AnyBufferProtocol {
fn buf_ptr(&self) -> PyResult<*mut raw::c_void> {
let out = match self {
Self::UInt8(buf) => buf.buf_ptr(),
Self::UInt16(buf) => buf.buf_ptr(),
Self::UInt32(buf) => buf.buf_ptr(),
Self::UInt64(buf) => buf.buf_ptr(),
Self::Int8(buf) => buf.buf_ptr(),
Self::Int16(buf) => buf.buf_ptr(),
Self::Int32(buf) => buf.buf_ptr(),
Self::Int64(buf) => buf.buf_ptr(),
Self::Float32(buf) => buf.buf_ptr(),
Self::Float64(buf) => buf.buf_ptr(),
};
Ok(out)
}
#[allow(dead_code)]
fn dimensions(&self) -> PyResult<usize> {
let out = match self {
Self::UInt8(buf) => buf.dimensions(),
Self::UInt16(buf) => buf.dimensions(),
Self::UInt32(buf) => buf.dimensions(),
Self::UInt64(buf) => buf.dimensions(),
Self::Int8(buf) => buf.dimensions(),
Self::Int16(buf) => buf.dimensions(),
Self::Int32(buf) => buf.dimensions(),
Self::Int64(buf) => buf.dimensions(),
Self::Float32(buf) => buf.dimensions(),
Self::Float64(buf) => buf.dimensions(),
};
Ok(out)
}
fn format(&self) -> PyResult<&CStr> {
let out = match self {
Self::UInt8(buf) => buf.format(),
Self::UInt16(buf) => buf.format(),
Self::UInt32(buf) => buf.format(),
Self::UInt64(buf) => buf.format(),
Self::Int8(buf) => buf.format(),
Self::Int16(buf) => buf.format(),
Self::Int32(buf) => buf.format(),
Self::Int64(buf) => buf.format(),
Self::Float32(buf) => buf.format(),
Self::Float64(buf) => buf.format(),
};
Ok(out)
}
pub fn into_arrow_array(self) -> PyArrowResult<ArrayRef> {
self.validate_buffer()?;
let shape = self.shape()?.to_vec();
if shape.len() == 1 {
self.into_arrow_values()
} else {
assert!(shape.len() > 1, "shape cannot be 0");
let mut values = self.into_arrow_values()?;
for size in shape[1..].iter().rev() {
let field = Arc::new(Field::new("item", values.data_type().clone(), false));
let x = FixedSizeListArray::new(field, (*size).try_into().unwrap(), values, None);
values = Arc::new(x);
}
Ok(values)
}
}
fn into_arrow_values(self) -> PyArrowResult<ArrayRef> {
let len = self.item_count()?;
let len_bytes = self.len_bytes()?;
let ptr = NonNull::new(self.buf_ptr()? as _)
.ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
let element_type = ElementType::from_format(self.format()?);
match self {
Self::UInt8(buf) => match element_type {
ElementType::Bool => {
let slice = NonNull::slice_from_raw_parts(ptr, len);
let slice = unsafe { slice.as_ref() };
let mut builder = BooleanBuilder::with_capacity(len);
for val in slice {
builder.append_value(*val > 0);
}
Ok(Arc::new(builder.finish()))
}
ElementType::UnsignedInteger { bytes } => {
if bytes != 1 {
return Err(PyValueError::new_err(format!(
"Expected 1 byte element type, got {}",
bytes
))
.into());
}
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(UInt8Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
_ => Err(PyValueError::new_err(format!(
"Unexpected element type {:?}",
element_type
))
.into()),
},
Self::UInt16(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(UInt16Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::UInt32(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(UInt32Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::UInt64(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(UInt64Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::Int8(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(Int8Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::Int16(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(Int16Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::Int32(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(Int32Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::Int64(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(Int64Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::Float32(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(Float32Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
Self::Float64(buf) => {
let owner = Arc::new(buf);
let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
Ok(Arc::new(Float64Array::new(
ScalarBuffer::new(buffer, 0, len),
None,
)))
}
}
}
pub fn into_arrow_buffer(self) -> PyArrowResult<Buffer> {
let len_bytes = self.len_bytes()?;
let ptr = NonNull::new(self.buf_ptr()? as _)
.ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
let buffer = match self {
Self::UInt8(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::UInt16(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::UInt32(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::UInt64(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::Int8(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::Int16(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::Int32(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::Int64(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::Float32(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
Self::Float64(buf) => {
let owner = Arc::new(buf);
unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
}
};
Ok(buffer)
}
fn item_count(&self) -> PyResult<usize> {
let out = match self {
Self::UInt8(buf) => buf.item_count(),
Self::UInt16(buf) => buf.item_count(),
Self::UInt32(buf) => buf.item_count(),
Self::UInt64(buf) => buf.item_count(),
Self::Int8(buf) => buf.item_count(),
Self::Int16(buf) => buf.item_count(),
Self::Int32(buf) => buf.item_count(),
Self::Int64(buf) => buf.item_count(),
Self::Float32(buf) => buf.item_count(),
Self::Float64(buf) => buf.item_count(),
};
Ok(out)
}
fn is_c_contiguous(&self) -> PyResult<bool> {
let out = match self {
Self::UInt8(buf) => buf.is_c_contiguous(),
Self::UInt16(buf) => buf.is_c_contiguous(),
Self::UInt32(buf) => buf.is_c_contiguous(),
Self::UInt64(buf) => buf.is_c_contiguous(),
Self::Int8(buf) => buf.is_c_contiguous(),
Self::Int16(buf) => buf.is_c_contiguous(),
Self::Int32(buf) => buf.is_c_contiguous(),
Self::Int64(buf) => buf.is_c_contiguous(),
Self::Float32(buf) => buf.is_c_contiguous(),
Self::Float64(buf) => buf.is_c_contiguous(),
};
Ok(out)
}
fn len_bytes(&self) -> PyResult<usize> {
let out = match self {
Self::UInt8(buf) => buf.len_bytes(),
Self::UInt16(buf) => buf.len_bytes(),
Self::UInt32(buf) => buf.len_bytes(),
Self::UInt64(buf) => buf.len_bytes(),
Self::Int8(buf) => buf.len_bytes(),
Self::Int16(buf) => buf.len_bytes(),
Self::Int32(buf) => buf.len_bytes(),
Self::Int64(buf) => buf.len_bytes(),
Self::Float32(buf) => buf.len_bytes(),
Self::Float64(buf) => buf.len_bytes(),
};
Ok(out)
}
fn shape(&self) -> PyResult<&[usize]> {
let out = match self {
Self::UInt8(buf) => buf.shape(),
Self::UInt16(buf) => buf.shape(),
Self::UInt32(buf) => buf.shape(),
Self::UInt64(buf) => buf.shape(),
Self::Int8(buf) => buf.shape(),
Self::Int16(buf) => buf.shape(),
Self::Int32(buf) => buf.shape(),
Self::Int64(buf) => buf.shape(),
Self::Float32(buf) => buf.shape(),
Self::Float64(buf) => buf.shape(),
};
Ok(out)
}
fn validate_buffer(&self) -> PyArrowResult<()> {
if !self.is_c_contiguous()? {
return Err(PyValueError::new_err("Buffer is not C contiguous").into());
}
if self.shape()?.contains(&0) {
return Err(
PyValueError::new_err("0-length dimension not currently supported.").into(),
);
}
Ok(())
}
}
/// Convert a Python buffer-protocol view into a [`PyArray`] by first turning it into an
/// Arrow array; any conversion failure is returned as the error.
impl TryFrom<AnyBufferProtocol> for PyArray {
    type Error = PyArrowError;

    fn try_from(value: AnyBufferProtocol) -> Result<Self, Self::Error> {
        value.into_arrow_array().map(Self::from_array_ref)
    }
}