Skip to main content

scirs2_numpy/
array_subclass.rs

//! Array subclass support for duck-typed Python objects.
//!
//! Provides utilities to accept any Python object that looks like an array:
//! - NumPy ndarrays (via the `__array__` protocol)
//! - pandas Series (has a `.values` attribute returning an ndarray)
//! - Any list-like object supporting `__len__` and `__getitem__`
//!
//! Also exposes [`SubclassArrayWrapper`], a `#[pyclass]` that wraps a flat
//! `f64` buffer with shape metadata and looks enough like a NumPy array for
//! downstream code that uses duck-typing.
11
12use pyo3::prelude::*;
13use pyo3::types::PyAnyMethods;
14
15// ──────────────────────────────────────────────────────────────────────────────
16// Free-standing extraction helpers
17// ──────────────────────────────────────────────────────────────────────────────
18
19/// Extract `f32` values from any Python array-like object.
20///
21/// Attempts the following strategies in order:
22/// 1. `.values` attribute (pandas Series / masked array).
23/// 2. `.__array__()` method (NumPy array protocol).
24/// 3. Direct iteration via `__len__` + `__getitem__`.
25///
26/// # Errors
27/// Returns a [`PyErr`] if none of the strategies succeeds or if an element
28/// cannot be converted to `f32`.
29#[pyfunction]
30pub fn from_array_like_f32(obj: &Bound<'_, PyAny>) -> PyResult<Vec<f32>> {
31    // Strategy 1: .values attribute (pandas Series / masked array).
32    // Guard: only recurse if the attribute is NOT already a plain ndarray
33    // (to avoid infinite recursion on ndarray objects that happen to lack .values).
34    if let Ok(values) = obj.getattr("values") {
35        // If `values` itself has a .values attribute we stop (avoid deep recursion).
36        if values.getattr("values").is_err() {
37            return from_array_like_f32(&values);
38        }
39    }
40
41    // Strategy 2: __array__ protocol.
42    if let Ok(arr) = obj.call_method0("__array__") {
43        // Recursing here is safe: the resulting ndarray has no .values attribute.
44        return from_array_like_f32(&arr);
45    }
46
47    // Strategy 3: direct iteration.
48    let len = obj.len()?;
49    let mut result = Vec::with_capacity(len);
50    for i in 0..len {
51        let item = obj.get_item(i)?;
52        let val: f32 = item.extract()?;
53        result.push(val);
54    }
55    Ok(result)
56}
57
58/// Extract `f64` values from any Python array-like object.
59///
60/// Attempts the following strategies in order:
61/// 1. `.values` attribute (pandas Series / masked array).
62/// 2. `.__array__()` method (NumPy array protocol).
63/// 3. Direct iteration via `__len__` + `__getitem__`.
64///
65/// # Errors
66/// Returns a [`PyErr`] if none of the strategies succeeds or if an element
67/// cannot be converted to `f64`.
68#[pyfunction]
69pub fn from_array_like_f64(obj: &Bound<'_, PyAny>) -> PyResult<Vec<f64>> {
70    // Strategy 1: .values attribute — with depth guard.
71    if let Ok(values) = obj.getattr("values") {
72        if values.getattr("values").is_err() {
73            return from_array_like_f64(&values);
74        }
75    }
76
77    // Strategy 2: __array__ protocol.
78    if let Ok(arr) = obj.call_method0("__array__") {
79        return from_array_like_f64(&arr);
80    }
81
82    // Strategy 3: direct iteration.
83    let len = obj.len()?;
84    let mut result = Vec::with_capacity(len);
85    for i in 0..len {
86        let item = obj.get_item(i)?;
87        let val: f64 = item.extract()?;
88        result.push(val);
89    }
90    Ok(result)
91}
92
93// ──────────────────────────────────────────────────────────────────────────────
94// SubclassArrayWrapper
95// ──────────────────────────────────────────────────────────────────────────────
96
97/// A Python-visible wrapper around a flat `f64` data buffer with shape metadata.
98///
99/// Implements enough of the NumPy duck-typing surface to be accepted by code
100/// that inspects `.shape`, `.dtype`, `.__len__`, and `.__getitem__`.
101#[pyclass(name = "SubclassArrayWrapper")]
102pub struct SubclassArrayWrapper {
103    /// Flat data buffer in C (row-major) order.
104    data: Vec<f64>,
105    /// Logical shape of the array.
106    shape: Vec<usize>,
107    /// NumPy-compatible dtype string (e.g. `"float64"`).
108    dtype: String,
109}
110
111#[pymethods]
112impl SubclassArrayWrapper {
113    /// Construct a new wrapper.
114    ///
115    /// # Arguments
116    /// * `data`  – flat element buffer; length must equal the product of `shape`.
117    /// * `shape` – logical shape; `[]` is interpreted as a 0-d scalar.
118    /// * `dtype` – NumPy-compatible dtype string such as `"float64"`.
119    #[new]
120    #[pyo3(signature = (data, shape, dtype = "float64".to_string()))]
121    pub fn new(data: Vec<f64>, shape: Vec<usize>, dtype: String) -> PyResult<Self> {
122        let n: usize = shape.iter().product::<usize>().max(1);
123        if shape.is_empty() {
124            // 0-d: exactly one element.
125            if data.len() != 1 {
126                return Err(pyo3::exceptions::PyValueError::new_err(
127                    "0-d SubclassArrayWrapper requires exactly one element",
128                ));
129            }
130        } else if data.len() != n {
131            return Err(pyo3::exceptions::PyValueError::new_err(format!(
132                "data length {} does not match shape product {}",
133                data.len(),
134                n
135            )));
136        }
137        Ok(Self { data, shape, dtype })
138    }
139
140    /// Number of elements (flat).
141    pub fn __len__(&self) -> usize {
142        self.data.len()
143    }
144
145    /// Return element at flat index `idx`.
146    ///
147    /// # Errors
148    /// Returns `IndexError` if `idx` is out of bounds.
149    pub fn __getitem__(&self, idx: usize) -> PyResult<f64> {
150        self.data.get(idx).copied().ok_or_else(|| {
151            pyo3::exceptions::PyIndexError::new_err(format!(
152                "index {} out of bounds for array of length {}",
153                idx,
154                self.data.len()
155            ))
156        })
157    }
158
159    /// Return `self` — mirrors pandas `.values` which returns the underlying array.
160    ///
161    /// This makes `SubclassArrayWrapper` itself accepted by [`from_array_like_f64`].
162    pub fn values(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
163        slf
164    }
165
166    /// Logical shape tuple.
167    pub fn shape(&self) -> Vec<usize> {
168        self.shape.clone()
169    }
170
171    /// NumPy-compatible dtype string.
172    pub fn dtype(&self) -> &str {
173        &self.dtype
174    }
175
176    /// Flat copy of the data as a Python list.
177    pub fn to_list(&self) -> Vec<f64> {
178        self.data.clone()
179    }
180}
181
182// ──────────────────────────────────────────────────────────────────────────────
183// Module registration
184// ──────────────────────────────────────────────────────────────────────────────
185
186/// Register array-subclass helpers and [`SubclassArrayWrapper`] into a PyO3 module.
187pub fn register_array_subclass_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
188    m.add_function(wrap_pyfunction!(from_array_like_f32, m)?)?;
189    m.add_function(wrap_pyfunction!(from_array_like_f64, m)?)?;
190    m.add_class::<SubclassArrayWrapper>()?;
191    Ok(())
192}
193
194// ──────────────────────────────────────────────────────────────────────────────
195// Tests
196// ──────────────────────────────────────────────────────────────────────────────
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a wrapper labelled `"float64"`, panicking if construction is rejected.
    fn float64_wrapper(data: Vec<f64>, shape: Vec<usize>) -> SubclassArrayWrapper {
        SubclassArrayWrapper::new(data, shape, "float64".to_string()).expect("construction failed")
    }

    /// A plain Python list of floats is read element by element.
    #[test]
    fn array_like_extracts_from_list() {
        Python::attach(|py| {
            let source = pyo3::ffi::c_str!("[1.0, 2.0, 3.0]");
            let list = py.eval(source, None, None).expect("eval failed");
            let extracted = from_array_like_f64(&list).expect("extraction failed");
            assert_eq!(extracted, vec![1.0, 2.0, 3.0]);
        });
    }

    /// `__len__` reports the number of stored elements.
    #[test]
    fn array_like_wrapper_len_correct() {
        let wrapper = float64_wrapper(vec![1.0, 2.0, 3.0], vec![3]);
        assert_eq!(wrapper.__len__(), 3);
    }

    /// An in-range flat index yields the stored value.
    #[test]
    fn subclass_wrapper_getitem_correct() {
        let wrapper = float64_wrapper(vec![10.0, 20.0, 30.0], vec![3]);
        let second = wrapper.__getitem__(1).expect("index valid");
        assert!((second - 20.0).abs() < f64::EPSILON);
    }

    /// An out-of-range flat index is reported as an error.
    #[test]
    fn subclass_wrapper_getitem_oob() {
        let wrapper = float64_wrapper(vec![1.0], vec![1]);
        assert!(wrapper.__getitem__(99).is_err());
    }

    /// Shape and dtype are returned as supplied at construction.
    #[test]
    fn subclass_wrapper_shape_and_dtype() {
        let wrapper = float64_wrapper(vec![1.0, 2.0], vec![2]);
        assert_eq!(wrapper.shape(), vec![2usize]);
        assert_eq!(wrapper.dtype(), "float64");
    }
}