use pyo3::prelude::*;
use pyo3::types::PyAnyMethods;
/// Resolves an array-like Python object to something that supports `len()` and
/// integer indexing, without recursion.
///
/// Steps, each applied at most once:
/// 1. If `obj` has a non-callable `values` attribute, and that attribute's value
///    has no `values` attribute of its own, it replaces `obj`.
/// 2. If the current object has an `__array__` method that succeeds, its result
///    replaces the current object.
///
/// Errors raised while probing `values` or calling `__array__` are swallowed; the
/// current object is used as-is in that case.
fn normalize_array_like<'py>(obj: &Bound<'py, PyAny>) -> Bound<'py, PyAny> {
    // A callable `values` is a method (e.g. `SubclassArrayWrapper.values` or
    // `dict.values`), not data, so it must not replace the object.
    let base = match obj.getattr("values") {
        Ok(values) if !values.is_callable() && values.getattr("values").is_err() => values,
        _ => obj.clone(),
    };
    // Deliberately no recursion here: NumPy's `ndarray.__array__()` yields another
    // ndarray that also has `__array__`, so recursing would never terminate.
    base.call_method0("__array__").unwrap_or(base)
}

/// Converts a Python array-like object (list, tuple, array-like, or an object with
/// a `values` attribute or an `__array__` method) into a `Vec<f32>`.
///
/// The object is first normalized (see `normalize_array_like`). Then elements
/// `0..len()` are read by integer index and each is converted to `f32`.
///
/// # Errors
///
/// Propagates the Python error raised by `len()`, by indexing, or by converting
/// an element to `f32`.
#[pyfunction]
pub fn from_array_like_f32(obj: &Bound<'_, PyAny>) -> PyResult<Vec<f32>> {
    let seq = normalize_array_like(obj);
    let len = seq.len()?;
    let mut result = Vec::with_capacity(len);
    for i in 0..len {
        result.push(seq.get_item(i)?.extract::<f32>()?);
    }
    Ok(result)
}

/// Converts a Python array-like object (list, tuple, array-like, or an object with
/// a `values` attribute or an `__array__` method) into a `Vec<f64>`.
///
/// The object is first normalized (see `normalize_array_like`). Then elements
/// `0..len()` are read by integer index and each is converted to `f64`.
///
/// # Errors
///
/// Propagates the Python error raised by `len()`, by indexing, or by converting
/// an element to `f64`.
#[pyfunction]
pub fn from_array_like_f64(obj: &Bound<'_, PyAny>) -> PyResult<Vec<f64>> {
    let seq = normalize_array_like(obj);
    let len = seq.len()?;
    let mut result = Vec::with_capacity(len);
    for i in 0..len {
        result.push(seq.get_item(i)?.extract::<f64>()?);
    }
    Ok(result)
}
/// Flat `f64` buffer carrying a logical `shape` and a `dtype` label, exposed to
/// Python as `SubclassArrayWrapper`.
#[pyclass(name = "SubclassArrayWrapper")]
#[derive(Debug)]
pub struct SubclassArrayWrapper {
    /// Elements in flat order; the constructor checks the length against `shape`.
    data: Vec<f64>,
    /// Logical dimensions; empty for a 0-d value.
    shape: Vec<usize>,
    /// Dtype label exactly as supplied by the caller (stored, not validated).
    dtype: String,
}
#[pymethods]
impl SubclassArrayWrapper {
    /// Creates a wrapper around flat `data` with the given `shape` and `dtype` label.
    ///
    /// `data.len()` must equal the product of `shape`. An empty `shape` describes a
    /// 0-d value and requires exactly one element. A shape containing a zero-size
    /// dimension (e.g. `[0]` or `[3, 0]`) requires empty `data`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` if the shape's element count overflows `usize`, or if it
    /// does not match `data.len()`.
    #[new]
    #[pyo3(signature = (data, shape, dtype = "float64".to_string()))]
    pub fn new(data: Vec<f64>, shape: Vec<usize>, dtype: String) -> PyResult<Self> {
        // `shape` arrives from Python, so an oversized shape must become a
        // ValueError instead of an overflow panic (debug) or silent wrap (release).
        // No lower clamp is applied: a zero-size dimension means zero elements,
        // and the empty-shape product is already 1.
        let n = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(format!(
                    "shape {:?} has too many elements",
                    shape
                ))
            })?;
        if shape.is_empty() {
            if data.len() != 1 {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "0-d SubclassArrayWrapper requires exactly one element",
                ));
            }
        } else if data.len() != n {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "data length {} does not match shape product {}",
                data.len(),
                n
            )));
        }
        Ok(Self { data, shape, dtype })
    }

    /// Number of stored elements (the flat length, independent of `shape`).
    pub fn __len__(&self) -> usize {
        self.data.len()
    }

    /// Returns the element at flat position `idx`.
    ///
    /// `idx` is a `usize`, so negative indices are not accepted.
    ///
    /// # Errors
    ///
    /// Raises `IndexError` when `idx >= len`.
    pub fn __getitem__(&self, idx: usize) -> PyResult<f64> {
        self.data.get(idx).copied().ok_or_else(|| {
            pyo3::exceptions::PyIndexError::new_err(format!(
                "index {} out of bounds for array of length {}",
                idx,
                self.data.len()
            ))
        })
    }

    /// Returns this same wrapper object (an identity accessor). This presumably
    /// mirrors the `values` attribute probed by `from_array_like_*`; note that it
    /// is a method, not a property.
    pub fn values(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Returns a copy of the logical shape.
    pub fn shape(&self) -> Vec<usize> {
        self.shape.clone()
    }

    /// Returns the dtype label given at construction.
    pub fn dtype(&self) -> &str {
        &self.dtype
    }

    /// Returns a copy of the flat data.
    pub fn to_list(&self) -> Vec<f64> {
        self.data.clone()
    }
}
/// Adds `from_array_like_f32`, `from_array_like_f64` and the
/// `SubclassArrayWrapper` class to module `m`, in that order.
///
/// # Errors
///
/// Returns the first error raised while wrapping or registering an item.
pub fn register_array_subclass_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let to_f32 = wrap_pyfunction!(from_array_like_f32, m)?;
    m.add_function(to_f32)?;
    let to_f64 = wrap_pyfunction!(from_array_like_f64, m)?;
    m.add_function(to_f64)?;
    m.add_class::<SubclassArrayWrapper>()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a "float64" wrapper, failing the test if the constructor rejects it.
    fn wrapper_of(data: Vec<f64>, shape: Vec<usize>) -> SubclassArrayWrapper {
        SubclassArrayWrapper::new(data, shape, "float64".to_string())
            .expect("construction failed")
    }

    #[test]
    fn array_like_extracts_from_list() {
        Python::attach(|py| {
            let code = pyo3::ffi::c_str!("[1.0, 2.0, 3.0]");
            let list = py.eval(code, None, None).expect("eval failed");
            let extracted = from_array_like_f64(&list).expect("extraction failed");
            assert_eq!(extracted, vec![1.0, 2.0, 3.0]);
        });
    }

    #[test]
    fn array_like_wrapper_len_correct() {
        let wrapper = wrapper_of(vec![1.0, 2.0, 3.0], vec![3]);
        assert_eq!(wrapper.__len__(), 3);
    }

    #[test]
    fn subclass_wrapper_getitem_correct() {
        let wrapper = wrapper_of(vec![10.0, 20.0, 30.0], vec![3]);
        let middle = wrapper.__getitem__(1).expect("index valid");
        assert!((middle - 20.0).abs() < f64::EPSILON);
    }

    #[test]
    fn subclass_wrapper_getitem_oob() {
        let wrapper = wrapper_of(vec![1.0], vec![1]);
        assert!(wrapper.__getitem__(99).is_err());
    }

    #[test]
    fn subclass_wrapper_shape_and_dtype() {
        let wrapper = wrapper_of(vec![1.0, 2.0], vec![2]);
        assert_eq!(wrapper.shape(), vec![2usize]);
        assert_eq!(wrapper.dtype(), "float64");
    }
}