use crate::array::Array;
use crate::math;
use crate::NumRs2Error;
use pyo3::exceptions::{PyIndexError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
use scirs2_numpy::{
PyArrayDyn, PyArrayMethods, PyReadonlyArrayDyn, PyUntypedArrayMethods, ToPyArray,
};
/// Wrapper type exported to Python under the class name `Array`.
///
/// It owns a crate-level `Array<f64>`; the methods in the `#[pymethods]`
/// block operate on that inner value.
#[pyclass(name = "Array")]
#[derive(Clone)]
pub struct PyArray {
// Backing `f64` array; crate-visible so sibling binding code can reach it.
pub(crate) inner: Array<f64>,
}
#[pymethods]
impl PyArray {
    /// Build an array from a NumPy `float64` array, a list, or a tuple.
    ///
    /// * NumPy input keeps its shape. It must be C-contiguous (row-major).
    /// * Lists and tuples must be flat sequences convertible to `f64` and
    ///   produce whatever `Array::from_vec` builds from a flat vector.
    ///
    /// # Errors
    /// * `ValueError` if a NumPy array is not C-contiguous.
    /// * `TypeError` if `data` is none of the supported kinds.
    /// * Whatever error the element extraction raises for a list or tuple
    ///   whose items are not convertible to `f64`.
    #[new]
    fn new(data: &Bound<'_, PyAny>) -> PyResult<Self> {
        if let Ok(np_arr) = data.extract::<PyReadonlyArrayDyn<f64>>() {
            // `as_slice` hands back the buffer in memory order and (per the
            // rust-numpy documentation) accepts any contiguous layout. Only a
            // C-contiguous buffer lines up with the row-major `reshape` below;
            // a Fortran-ordered array would be silently scrambled.
            if !np_arr.is_c_contiguous() {
                return Err(PyValueError::new_err(
                    "NumPy array must be C-contiguous (row-major); \
                     convert it with numpy.ascontiguousarray first",
                ));
            }
            let shape = np_arr.shape().to_vec();
            let data_vec: Vec<f64> = np_arr
                .as_slice()
                .map_err(|_| {
                    PyValueError::new_err("Cannot convert NumPy array to contiguous slice")
                })?
                .to_vec();
            let array = Array::from_vec(data_vec).reshape(&shape);
            return Ok(PyArray { inner: array });
        }
        if let Ok(list) = data.downcast::<PyList>() {
            let vec: Vec<f64> = list.extract()?;
            return Ok(PyArray {
                inner: Array::from_vec(vec),
            });
        }
        if let Ok(tuple) = data.downcast::<PyTuple>() {
            let vec: Vec<f64> = tuple.extract()?;
            return Ok(PyArray {
                inner: Array::from_vec(vec),
            });
        }
        Err(PyTypeError::new_err("Expected list, tuple, or NumPy array"))
    }

    /// Shape reported by the inner array.
    #[getter]
    fn shape(&self) -> Vec<usize> {
        self.inner.shape()
    }

    /// Number of dimensions reported by the inner array.
    #[getter]
    fn ndim(&self) -> usize {
        self.inner.ndim()
    }

    /// Element count reported by the inner array.
    #[getter]
    fn size(&self) -> usize {
        self.inner.size()
    }

    /// Element type name; this wrapper always holds `f64`, so always `"float64"`.
    #[getter]
    fn dtype(&self) -> &str {
        "float64"
    }

    /// Return a copy of the array with a new shape.
    ///
    /// # Errors
    /// `ValueError` if `shape` does not describe exactly `size` elements
    /// (including the case where the product of `shape` overflows `usize`).
    fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
        let available = self.inner.size();
        // Overflow-safe product: an overflowing shape can never match `size`.
        let requested = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
        if requested != Some(available) {
            return Err(PyValueError::new_err(format!(
                "cannot reshape array of size {} into shape {:?}",
                available, shape
            )));
        }
        Ok(PyArray {
            inner: self.inner.clone().reshape(&shape),
        })
    }

    /// Return the result of `Array::transpose` on the inner array.
    fn transpose(&self) -> PyResult<Self> {
        Ok(PyArray {
            inner: self.inner.transpose(),
        })
    }

    /// Return a one-dimensional copy holding all `size` elements.
    fn flatten(&self) -> PyResult<Self> {
        let size = self.inner.size();
        Ok(PyArray {
            inner: self.inner.clone().reshape(&[size]),
        })
    }

    /// Return a copy with every length-1 axis removed.
    ///
    /// If every axis has length 1 the result has shape `[1]` rather than
    /// becoming zero-dimensional.
    fn squeeze(&self) -> PyResult<Self> {
        let mut shape: Vec<usize> = self
            .inner
            .shape()
            .into_iter()
            .filter(|&dim| dim != 1)
            .collect();
        if shape.is_empty() {
            shape.push(1);
        }
        Ok(PyArray {
            inner: self.inner.clone().reshape(&shape),
        })
    }

    /// Return the elements as a flat `Vec<f64>` (no nesting by dimension).
    pub fn tolist(&self) -> Vec<f64> {
        self.inner.to_vec()
    }

    /// Copy the data into a new NumPy array with this array's shape.
    ///
    /// # Panics
    /// If NumPy refuses to reshape the flat copy to `shape`, which would mean
    /// the inner array reported a shape inconsistent with its data.
    fn to_numpy<'py>(&self, py: Python<'py>) -> Bound<'py, PyArrayDyn<f64>> {
        let vec = self.inner.to_vec();
        let shape: Vec<usize> = self.inner.shape();
        let arr = vec.to_pyarray(py);
        arr.reshape(shape)
            .expect("reshape to array shape should not fail since shape is valid")
    }

    /// Return an independent clone of this array.
    fn copy(&self) -> Self {
        self.clone()
    }

    /// Sum of all elements (`0.0` for an empty array).
    fn sum(&self) -> f64 {
        self.inner.to_vec().iter().sum()
    }

    /// Arithmetic mean of all elements.
    ///
    /// Returns `0.0` for an empty array instead of dividing by zero.
    pub fn mean(&self) -> f64 {
        let vec = self.inner.to_vec();
        if vec.is_empty() {
            0.0
        } else {
            vec.iter().sum::<f64>() / vec.len() as f64
        }
    }

    /// Smallest element; NaN if any element is NaN.
    ///
    /// # Errors
    /// `ValueError` if the array is empty.
    fn min(&self) -> PyResult<f64> {
        let mut values = self.inner.to_vec().into_iter();
        let first = values
            .next()
            .ok_or_else(|| PyValueError::new_err("Array is empty"))?;
        // Once NaN is seen it is kept, so the answer does not depend on where
        // in the data the NaN happens to sit.
        Ok(values.fold(first, |best, v| {
            if best.is_nan() {
                best
            } else if v.is_nan() || v < best {
                v
            } else {
                best
            }
        }))
    }

    /// Largest element; NaN if any element is NaN.
    ///
    /// # Errors
    /// `ValueError` if the array is empty.
    fn max(&self) -> PyResult<f64> {
        let mut values = self.inner.to_vec().into_iter();
        let first = values
            .next()
            .ok_or_else(|| PyValueError::new_err("Array is empty"))?;
        // Same NaN-propagation rule as `min`; `>=` keeps the last of equal
        // maxima, matching `Iterator::max_by`.
        Ok(values.fold(first, |best, v| {
            if best.is_nan() {
                best
            } else if v.is_nan() || v >= best {
                v
            } else {
                best
            }
        }))
    }

    /// `self + other`, delegated to `Array`'s `+` operator.
    // NOTE(review): shape-mismatch / broadcasting behaviour is decided by
    // `Array`'s operator impl, which is not visible here — confirm.
    fn __add__(&self, other: &PyArray) -> PyResult<Self> {
        Ok(PyArray {
            inner: &self.inner + &other.inner,
        })
    }

    /// `self - other`, delegated to `Array`'s `-` operator.
    fn __sub__(&self, other: &PyArray) -> PyResult<Self> {
        Ok(PyArray {
            inner: &self.inner - &other.inner,
        })
    }

    /// `self * other`, delegated to `Array`'s `*` operator.
    fn __mul__(&self, other: &PyArray) -> PyResult<Self> {
        Ok(PyArray {
            inner: &self.inner * &other.inner,
        })
    }

    /// `self / other`, delegated to `Array`'s `/` operator.
    fn __truediv__(&self, other: &PyArray) -> PyResult<Self> {
        Ok(PyArray {
            inner: &self.inner / &other.inner,
        })
    }

    /// `-self`, delegated to `Array`'s unary negation on a clone.
    fn __neg__(&self) -> PyResult<Self> {
        Ok(PyArray {
            inner: -self.inner.clone(),
        })
    }

    /// Summary string of the form `Array(shape=[..], dtype=float64, size=N)`.
    fn __repr__(&self) -> String {
        format!(
            "Array(shape={:?}, dtype={}, size={})",
            self.shape(),
            self.dtype(),
            self.size()
        )
    }

    /// Same text as `__repr__`.
    fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Length of the first axis, or `0` when the shape is empty.
    fn __len__(&self) -> usize {
        self.inner.shape().first().copied().unwrap_or(0)
    }
}
/// Module-level constructor: forwards `data` to `PyArray::new` and returns
/// its result (array or error) unchanged.
#[pyfunction]
fn array(data: &Bound<'_, PyAny>) -> PyResult<PyArray> {
    let built = PyArray::new(data)?;
    Ok(built)
}
/// Wrap the result of `Array::zeros` for the requested `shape`.
#[pyfunction]
fn zeros(shape: Vec<usize>) -> PyArray {
    let inner = Array::zeros(&shape);
    PyArray { inner }
}
/// Wrap the result of `Array::ones` for the requested `shape`.
#[pyfunction]
fn ones(shape: Vec<usize>) -> PyArray {
    let inner = Array::ones(&shape);
    PyArray { inner }
}
#[pyfunction]
fn eye(n: usize, m: Option<usize>, k: Option<isize>) -> PyArray {
let m = m.unwrap_or(n);
let k = k.unwrap_or(0);
PyArray {
inner: Array::eye(n, m, k),
}
}
/// Wrap `Array::eye(n, n, 0)`: equal size arguments and zero diagonal offset.
#[pyfunction]
fn identity(n: usize) -> PyArray {
    let main_diagonal = 0;
    let inner = Array::eye(n, n, main_diagonal);
    PyArray { inner }
}
#[pyfunction]
fn linspace(start: f64, stop: f64, num: usize, endpoint: Option<bool>) -> PyArray {
let endpoint = endpoint.unwrap_or(true);
if endpoint {
PyArray {
inner: math::linspace(start, stop, num),
}
} else {
let step = (stop - start) / num as f64;
let values: Vec<f64> = (0..num).map(|i| start + i as f64 * step).collect();
PyArray {
inner: Array::from_vec(values),
}
}
}
#[pyfunction]
fn arange(start: f64, stop: f64, step: Option<f64>) -> PyArray {
let step = step.unwrap_or(1.0);
PyArray {
inner: math::arange(start, stop, step),
}
}
/// Build an array of the given `shape` with every element set to `fill_value`.
///
/// The element count is the product of the `shape` entries; the flat buffer
/// is then reshaped to `shape`.
#[pyfunction]
fn full(shape: Vec<usize>, fill_value: f64) -> PyArray {
    let element_count = shape.iter().product::<usize>();
    let flat = vec![fill_value; element_count];
    let inner = Array::from_vec(flat).reshape(&shape);
    PyArray { inner }
}
#[pyfunction]
fn zeros_like(a: &PyArray) -> PyArray {
zeros(a.shape())
}
#[pyfunction]
fn ones_like(a: &PyArray) -> PyArray {
ones(a.shape())
}
#[pyfunction]
fn concatenate(arrays: Vec<PyArray>, axis: Option<usize>) -> PyResult<PyArray> {
let _axis = axis.unwrap_or(0);
let mut result = Vec::new();
for arr in arrays {
result.extend(arr.tolist());
}
Ok(PyArray {
inner: Array::from_vec(result),
})
}
/// Add the `Array` class and the module-level constructor functions to the
/// Python module `m`.
///
/// # Errors
/// Returns the first error raised while adding the class or wrapping/adding
/// any function; nothing after the failing call is registered.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
// The class exposed to Python as `Array`.
m.add_class::<PyArray>()?;
// Free functions defined in this file.
m.add_function(wrap_pyfunction!(array, m)?)?;
m.add_function(wrap_pyfunction!(zeros, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(eye, m)?)?;
m.add_function(wrap_pyfunction!(identity, m)?)?;
m.add_function(wrap_pyfunction!(linspace, m)?)?;
m.add_function(wrap_pyfunction!(arange, m)?)?;
m.add_function(wrap_pyfunction!(full, m)?)?;
m.add_function(wrap_pyfunction!(zeros_like, m)?)?;
m.add_function(wrap_pyfunction!(ones_like, m)?)?;
m.add_function(wrap_pyfunction!(concatenate, m)?)?;
Ok(())
}