use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use thiserror::Error;
/// Errors produced while describing or interpreting array-interface metadata.
#[derive(Debug, Error)]
pub enum ArrayProtocolError {
    /// The element type is not supported.
    #[error("unsupported dtype: {0}")]
    UnsupportedDtype(String),
    /// A type string could not be parsed (see [`parse_typestr`]).
    #[error("invalid typestr: {0}")]
    InvalidTypestr(String),
    /// A Python exception, flattened to its display text.
    #[error("python error: {0}")]
    PythonError(String),
}

/// Flattens a Python exception into [`ArrayProtocolError::PythonError`].
///
/// Only the exception's display text is kept; its Python type is discarded.
impl From<PyErr> for ArrayProtocolError {
    fn from(err: PyErr) -> Self {
        let message = format!("{err}");
        ArrayProtocolError::PythonError(message)
    }
}

/// Surfaces any [`ArrayProtocolError`] to Python as a `ValueError`.
///
/// Note that a round trip `PyErr -> ArrayProtocolError -> PyErr` therefore
/// turns every original exception type into `ValueError`.
impl From<ArrayProtocolError> for PyErr {
    fn from(err: ArrayProtocolError) -> Self {
        let message = err.to_string();
        pyo3::exceptions::PyValueError::new_err(message)
    }
}
/// Parses an array-interface type string such as `"<f8"`, `">i4"` or `"|b1"`.
///
/// A valid string consists of three parts, in order:
/// - a byte-order character (`<`, `>`, `=` or `|`);
/// - an ASCII-letter kind code;
/// - a positive decimal byte count written with ASCII digits only.
///
/// The byte-order character is validated but not returned.
///
/// # Returns
/// `(kind, byte_count)`, e.g. `('f', 8)` for `"<f8"`.
///
/// # Errors
/// Returns [`ArrayProtocolError::InvalidTypestr`] in any of these cases:
/// - the string is shorter than three bytes;
/// - the byte-order character is unknown;
/// - the kind code is not an ASCII letter;
/// - the byte count is not a plain positive decimal integer that fits in `usize`.
pub fn parse_typestr(typestr: &str) -> Result<(char, usize), ArrayProtocolError> {
    if typestr.len() < 3 {
        return Err(ArrayProtocolError::InvalidTypestr(format!(
            "too short: {typestr:?}"
        )));
    }
    let mut chars = typestr.chars();
    // Unreachable after the length check above, but keeps the parse total.
    let endian = chars.next().ok_or_else(|| {
        ArrayProtocolError::InvalidTypestr(format!("empty typestr: {typestr:?}"))
    })?;
    if !matches!(endian, '<' | '>' | '=' | '|') {
        return Err(ArrayProtocolError::InvalidTypestr(format!(
            "unknown endianness character {endian:?} in {typestr:?}"
        )));
    }
    let kind = chars.next().ok_or_else(|| {
        ArrayProtocolError::InvalidTypestr(format!("missing kind in {typestr:?}"))
    })?;
    if !kind.is_ascii_alphabetic() {
        return Err(ArrayProtocolError::InvalidTypestr(format!(
            "invalid kind character {kind:?} in {typestr:?}"
        )));
    }
    // Everything after the kind code is the byte count. `usize::from_str` also
    // accepts a leading `+` (so "<f+8" would parse as 8), which is not a valid
    // typestr. Require plain ASCII digits before parsing.
    let size_str = chars.as_str();
    let byte_count = Some(size_str)
        .filter(|digits| digits.bytes().all(|b| b.is_ascii_digit()))
        .and_then(|digits| digits.parse::<usize>().ok())
        .ok_or_else(|| {
            ArrayProtocolError::InvalidTypestr(format!(
                "invalid byte count {size_str:?} in {typestr:?}"
            ))
        })?;
    if byte_count == 0 {
        return Err(ArrayProtocolError::InvalidTypestr(format!(
            "byte count must be > 0 in {typestr:?}"
        )));
    }
    Ok((kind, byte_count))
}
/// Array-like values that can describe their memory layout through NumPy's
/// array-interface protocol.
pub trait ArrayProtocol {
    /// Builds the complete array-interface descriptor for this value.
    fn array_interface(&self) -> ArrayInterfaceDict;

    /// Type string of the elements, e.g. `"<f8"`.
    fn dtype_str(&self) -> &'static str;

    /// Extent of each dimension.
    fn shape(&self) -> Vec<usize>;

    /// Per-dimension strides (the implementation in this file reports bytes).
    fn strides(&self) -> Vec<usize>;

    /// Pointer to the first byte of the element storage.
    fn data_ptr(&self) -> *const u8;

    /// Total size of the element storage, in bytes.
    fn nbytes(&self) -> usize;
}
/// Plain-Rust description of an array for NumPy's `__array_interface__` protocol.
///
/// Convert it to the Python-side `dict` with [`ArrayInterfaceDict::to_py_dict`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArrayInterfaceDict {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Type string such as `"<f8"` (byte order, kind code, item size).
    pub typestr: String,
    /// Address of the first element, as an integer.
    pub data_ptr: usize,
    /// Whether consumers must treat the data area as read-only.
    pub readonly: bool,
    /// Per-dimension strides (expected in bytes, as the protocol defines them);
    /// `None` omits the `strides` key.
    pub strides: Option<Vec<usize>>,
    /// Protocol version emitted under the `version` key.
    pub version: u8,
}

impl ArrayInterfaceDict {
    /// Converts the descriptor into the Python `dict` form of `__array_interface__`.
    ///
    /// Keys written, in insertion order:
    /// - `shape`: a tuple;
    /// - `typestr`;
    /// - `data`: an `(address, readonly)` tuple;
    /// - `version`;
    /// - `strides`: a tuple, written only when present.
    ///
    /// # Errors
    /// Propagates any Python error raised while building tuples or inserting items.
    pub fn to_py_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let dict = PyDict::new(py);
        dict.set_item("shape", PyTuple::new(py, self.shape.iter().copied())?)?;
        dict.set_item("typestr", &self.typestr)?;
        // The protocol defines `data` as (integer address, read-only flag). Emit a
        // real `bool` rather than 0/1 so consumers comparing against True/False see
        // the documented type.
        dict.set_item("data", (self.data_ptr, self.readonly))?;
        dict.set_item("version", self.version)?;
        if let Some(strides) = &self.strides {
            dict.set_item("strides", PyTuple::new(py, strides.iter().copied())?)?;
        }
        Ok(dict)
    }
}
/// N-dimensional `f64` array exposed to Python as the class `NdArrayWrapper`.
///
/// Elements are stored flattened, in the C (row-major) order assumed by the
/// strides computed in its constructor.
#[pyclass(name = "NdArrayWrapper")]
pub struct NdArrayWrapper {
    /// Flattened elements; the constructor checks the length equals the product of `shape`.
    data: Vec<f64>,
    /// Extent of each dimension.
    shape: Vec<usize>,
    /// C-order stride of each dimension, in bytes.
    strides: Vec<usize>,
    /// Array-interface type string of the elements (set to `"<f8"` by the constructor).
    dtype: String,
}
#[pymethods]
impl NdArrayWrapper {
#[new]
pub fn new(data: Vec<f64>, shape: Vec<usize>) -> PyResult<Self> {
let n: usize = shape.iter().product();
if data.len() != n {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"data length {} does not match shape product {}",
data.len(),
n
)));
}
let strides = compute_c_strides_bytes(&shape, std::mem::size_of::<f64>());
Ok(Self {
data,
shape,
strides,
dtype: "<f8".to_owned(),
})
}
#[pyo3(name = "__array__")]
pub fn array_method(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let np = py.import("numpy").map_err(|e| {
pyo3::exceptions::PyImportError::new_err(format!("numpy not available: {e}"))
})?;
let flat_list = PyList::new(py, &self.data)?;
let kwargs = PyDict::new(py);
kwargs.set_item("dtype", "f8")?;
let arr = np.call_method("array", (flat_list,), Some(&kwargs))?;
let shape_tuple = PyTuple::new(py, self.shape.iter().copied())?;
let reshaped = arr.call_method1("reshape", (shape_tuple,))?;
Ok(reshaped.unbind())
}
#[getter]
pub fn array_interface(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
let desc = ArrayInterfaceDict {
shape: self.shape.clone(),
typestr: self.dtype.clone(),
data_ptr: self.data.as_ptr() as usize,
readonly: true,
strides: Some(self.strides.clone()),
version: 3,
};
let dict = desc.to_py_dict(py)?;
Ok(dict.into_any().unbind())
}
pub fn shape_tuple(&self, py: Python<'_>) -> Py<PyAny> {
PyTuple::new(py, self.shape.iter().copied())
.map(|t| t.into_any().unbind())
.unwrap_or_else(|_| py.None())
}
pub fn dtype_str(&self) -> &str {
&self.dtype
}
pub fn data(&self) -> Vec<f64> {
self.data.clone()
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
}
impl ArrayProtocol for NdArrayWrapper {
    /// Read-only, version-3 descriptor that includes explicit byte strides.
    fn array_interface(&self) -> ArrayInterfaceDict {
        let shape = self.shape.clone();
        let strides = Some(self.strides.clone());
        ArrayInterfaceDict {
            version: 3,
            readonly: true,
            data_ptr: self.data.as_ptr() as usize,
            typestr: self.dtype.clone(),
            shape,
            strides,
        }
    }

    /// Always little-endian 8-byte float.
    fn dtype_str(&self) -> &'static str {
        "<f8"
    }

    fn shape(&self) -> Vec<usize> {
        self.shape.clone()
    }

    fn strides(&self) -> Vec<usize> {
        self.strides.clone()
    }

    fn data_ptr(&self) -> *const u8 {
        self.data.as_ptr().cast::<u8>()
    }

    /// Byte size of the whole element buffer.
    fn nbytes(&self) -> usize {
        std::mem::size_of_val(self.data.as_slice())
    }
}
/// Registers this module's Python classes (currently [`NdArrayWrapper`]) on `m`.
///
/// # Errors
/// Returns any error PyO3 raises while adding the class to the module.
pub fn register_array_protocol_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<NdArrayWrapper>()
}
/// Returns C-order (row-major) strides for `shape`, in bytes, given `elem_size`-byte elements.
///
/// The innermost stride is `elem_size`. Each outer stride is the next inner
/// stride times the next inner extent. An empty `shape` yields no strides.
fn compute_c_strides_bytes(shape: &[usize], elem_size: usize) -> Vec<usize> {
    if shape.is_empty() {
        return Vec::new();
    }
    let mut running = elem_size;
    let mut reversed = Vec::with_capacity(shape.len());
    reversed.push(running);
    // The outermost extent never contributes to any stride, hence `skip(1)`.
    for &extent in shape.iter().skip(1).rev() {
        running *= extent;
        reversed.push(running);
    }
    reversed.reverse();
    reversed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a wrapper, failing the test with a uniform message on error.
    fn make(data: Vec<f64>, shape: Vec<usize>) -> NdArrayWrapper {
        NdArrayWrapper::new(data, shape).expect("NdArrayWrapper::new failed")
    }

    /// Asserts that `typestr` parses to exactly `(kind, bytes)`.
    fn assert_parses(typestr: &str, kind: char, bytes: usize) {
        let parsed = parse_typestr(typestr).expect("parse_typestr failed");
        assert_eq!(parsed, (kind, bytes), "unexpected parse of {typestr:?}");
    }

    #[test]
    fn test_parse_typestr_f64_le() {
        assert_parses("<f8", 'f', 8);
    }

    #[test]
    fn test_parse_typestr_i32_be() {
        assert_parses(">i4", 'i', 4);
    }

    #[test]
    fn test_parse_typestr_u16_native() {
        assert_parses("=u2", 'u', 2);
    }

    #[test]
    fn test_parse_typestr_bool_noendian() {
        assert_parses("|b1", 'b', 1);
    }

    #[test]
    fn test_parse_typestr_error_too_short() {
        for bad in ["<f", "", "<"] {
            assert!(parse_typestr(bad).is_err(), "expected error for {bad:?}");
        }
    }

    #[test]
    fn test_parse_typestr_error_bad_endian() {
        assert!(parse_typestr("?f8").is_err());
    }

    #[test]
    fn test_parse_typestr_error_zero_bytes() {
        assert!(parse_typestr("<f0").is_err());
    }

    #[test]
    fn test_array_interface_dict_version() {
        let wrapper = make(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let version = ArrayProtocol::array_interface(&wrapper).version;
        assert_eq!(version, 3, "version must be 3");
    }

    #[test]
    fn test_array_interface_dict_shape() {
        let wrapper = make(vec![1.0; 6], vec![2, 3]);
        assert_eq!(ArrayProtocol::array_interface(&wrapper).shape, vec![2, 3]);
    }

    #[test]
    fn test_array_interface_dict_typestr() {
        let wrapper = make(vec![0.0; 4], vec![4]);
        assert_eq!(ArrayProtocol::array_interface(&wrapper).typestr, "<f8");
    }

    #[test]
    fn test_array_interface_dict_data_ptr_nonzero() {
        let wrapper = make(vec![1.0, 2.0, 3.0], vec![3]);
        let ptr = ArrayProtocol::array_interface(&wrapper).data_ptr;
        assert_ne!(ptr, 0, "data pointer must be non-null");
    }

    #[test]
    fn test_ndarray_wrapper_shape_mismatch() {
        assert!(NdArrayWrapper::new(vec![1.0; 4], vec![2, 3]).is_err());
    }

    #[test]
    fn test_ndarray_wrapper_scalar() {
        let wrapper = NdArrayWrapper::new(vec![42.0], vec![1]).expect("scalar failed");
        assert_eq!(wrapper.ndim(), 1);
        assert_eq!(wrapper.data(), vec![42.0]);
    }

    #[test]
    fn test_ndarray_wrapper_strides_c_order() {
        let wrapper = make(vec![0.0; 12], vec![3, 4]);
        assert_eq!(ArrayProtocol::strides(&wrapper), vec![32, 8]);
    }

    #[test]
    fn test_array_interface_py_dict_keys() {
        Python::attach(|py| {
            let wrapper = make(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
            let dict = ArrayProtocol::array_interface(&wrapper)
                .to_py_dict(py)
                .expect("to_py_dict failed");
            for key in ["shape", "typestr", "data", "version"] {
                let found = dict.get_item(key).expect("lookup failed");
                assert!(found.is_some(), "missing key {key:?}");
            }
        });
    }

    #[test]
    fn test_array_interface_py_dict_shape_values() {
        Python::attach(|py| {
            let wrapper = make(vec![0.0; 6], vec![2, 3]);
            let dict = ArrayProtocol::array_interface(&wrapper)
                .to_py_dict(py)
                .expect("to_py_dict failed");
            let shape_obj = dict
                .get_item("shape")
                .expect("shape lookup failed")
                .expect("shape missing");
            let tuple = shape_obj.cast::<PyTuple>().expect("shape is not a tuple");
            assert_eq!(tuple.len(), 2);
            let dim = |i: usize| -> usize {
                tuple
                    .get_item(i)
                    .expect("tuple item")
                    .extract()
                    .expect("extract dim")
            };
            assert_eq!((dim(0), dim(1)), (2, 3));
        });
    }
}