use crate::python::common::{conversions, to_py_err, validation};
use crate::tensor::device::Device;
use crate::tensor::operations::zero_copy::TensorIterOps;
use crate::tensor::Tensor;
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, ToPyArray};
use pyo3::exceptions::*;
use pyo3::prelude::*;
/// Python-visible wrapper around a single-precision [`Tensor`].
///
/// Cloning the wrapper clones the wrapped tensor value.
#[pyclass]
#[derive(Clone)]
pub struct PyTensor {
    /// The wrapped `f32` tensor; accessible from within this crate only.
    pub(crate) tensor: Tensor<f32>,
}
#[pymethods]
impl PyTensor {
    /// Creates a tensor from a flat list of values and a shape.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by `validate_dimensions` for `shape`, and
    /// returns a `ValueError` when `data.len()` differs from the number of
    /// elements described by `shape` (or when that count overflows `usize`).
    #[new]
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> PyResult<Self> {
        use crate::python::common::validation::validate_dimensions;
        validate_dimensions(&shape)?;

        // Check consistency here so the caller receives a ValueError instead of
        // depending on how `Tensor::from_vec` treats inconsistent input.
        let expected = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| {
                PyValueError::new_err(format!("Shape {:?} describes too many elements", shape))
            })?;
        if data.len() != expected {
            return Err(PyValueError::new_err(format!(
                "Data length {} does not match shape {:?} ({} elements)",
                data.len(),
                shape,
                expected
            )));
        }

        Ok(PyTensor {
            tensor: Tensor::from_vec(data, shape),
        })
    }

    /// Builds a one-dimensional tensor from a one-dimensional `float32` NumPy array.
    ///
    /// The resulting shape is `[len]`, where `len` is the number of values
    /// obtained from the array.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by `validate_dimensions` for that shape.
    #[staticmethod]
    pub fn from_numpy(array: PyReadonlyArray1<f32>) -> PyResult<Self> {
        use crate::python::common::conversions::pyarray_to_vec;
        let data = pyarray_to_vec(array);
        let shape = vec![data.len()];
        crate::python::common::validation::validate_dimensions(&shape)?;
        Ok(PyTensor {
            tensor: Tensor::from_vec(data, shape),
        })
    }

    /// Creates a tensor of the given shape via `Tensor::zeros`.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by `validate_dimensions` for `shape`.
    #[staticmethod]
    pub fn zeros(shape: Vec<usize>) -> PyResult<Self> {
        use crate::python::common::validation::validate_dimensions;
        validate_dimensions(&shape)?;
        Ok(PyTensor {
            tensor: Tensor::zeros(&shape),
        })
    }

    /// Creates a tensor of the given shape via `Tensor::ones`.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by `validate_dimensions` for `shape`.
    #[staticmethod]
    pub fn ones(shape: Vec<usize>) -> PyResult<Self> {
        use crate::python::common::validation::validate_dimensions;
        validate_dimensions(&shape)?;
        Ok(PyTensor {
            tensor: Tensor::ones(&shape),
        })
    }

    /// Creates a tensor of the given shape via `Tensor::randn`.
    ///
    /// NOTE(review): presumably samples a standard normal distribution —
    /// confirm against `Tensor::randn`.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by `validate_dimensions` for `shape`.
    #[staticmethod]
    pub fn randn(shape: Vec<usize>) -> PyResult<Self> {
        use crate::python::common::validation::validate_dimensions;
        validate_dimensions(&shape)?;
        Ok(PyTensor {
            tensor: Tensor::randn(&shape),
        })
    }

    /// Creates a one-dimensional tensor holding `start + i * step` for
    /// `i = 0, 1, ..., ceil((end - start) / step) - 1`.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` when `step` is not positive, when `start` is not
    /// less than `end`, when any argument is NaN or infinite, or when the
    /// resulting element count is not finite.
    #[staticmethod]
    pub fn arange(start: f32, end: f32, step: f32) -> PyResult<Self> {
        if step <= 0.0 {
            return Err(PyValueError::new_err("Step must be positive"));
        }
        if start >= end {
            return Err(PyValueError::new_err("Start must be less than end"));
        }
        // NaN makes both comparisons above false, and an infinite bound would
        // otherwise turn into a `usize::MAX`-element allocation request below.
        if !(start.is_finite() && end.is_finite() && step.is_finite()) {
            return Err(PyValueError::new_err("arange arguments must be finite"));
        }
        let count = ((end - start) / step).ceil();
        // `end - start` or the division can still overflow to infinity.
        if !count.is_finite() {
            return Err(PyValueError::new_err(
                "arange range is too large for the given step",
            ));
        }

        // `count` is finite and non-negative here; the float-to-int cast saturates.
        let size = count as usize;
        let data: Vec<f32> = (0..size).map(|i| start + i as f32 * step).collect();
        Ok(PyTensor {
            tensor: Tensor::from_vec(data, vec![size]),
        })
    }

    /// Returns the tensor's shape as a list of dimension sizes.
    pub fn shape(&self) -> Vec<usize> {
        self.tensor.shape().to_vec()
    }

    /// Copies every element, in the order yielded by the tensor's iterator,
    /// into a new flat list.
    pub fn data(&self) -> Vec<f32> {
        self.tensor.iter().copied().collect()
    }

    /// Same as [`PyTensor::data`], wrapped in `Ok`; this method never fails.
    pub fn to_vec(&self) -> PyResult<Vec<f32>> {
        Ok(self.data())
    }

    /// Returns the elements as a one-dimensional NumPy array.
    ///
    /// The array is built from [`PyTensor::data`], so it is flat and does not
    /// carry the tensor's shape.
    pub fn numpy<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f32>> {
        use crate::python::common::conversions::vec_to_pyarray;
        vec_to_pyarray(self.data(), py)
    }

    /// Returns the number of dimensions (the length of the shape).
    pub fn ndim(&self) -> usize {
        self.tensor.shape().len()
    }

    /// Returns the total number of elements (the product of the shape).
    pub fn numel(&self) -> usize {
        self.tensor.shape().iter().product()
    }

    /// Returns a tensor with the same elements arranged in `shape`.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by `validate_dimensions` for `shape`,
    /// returns a `ValueError` when the element counts differ, and converts any
    /// error from `Tensor::reshape` with `to_py_err`.
    pub fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
        use crate::python::common::{to_py_err, validation::validate_dimensions};
        validate_dimensions(&shape)?;

        let current_elements = self.numel();
        let new_elements: usize = shape.iter().product();
        if current_elements != new_elements {
            return Err(PyValueError::new_err(format!(
                "Cannot reshape tensor with {} elements to shape with {} elements",
                current_elements, new_elements
            )));
        }

        self.tensor
            .reshape(&shape)
            .map(|tensor| PyTensor { tensor })
            .map_err(to_py_err)
    }

    /// Returns the result of `Tensor::transpose`.
    ///
    /// # Errors
    ///
    /// Converts any error from `Tensor::transpose` with `to_py_err`.
    pub fn transpose(&self) -> PyResult<Self> {
        use crate::python::common::to_py_err;
        self.tensor
            .transpose()
            .map(|tensor| PyTensor { tensor })
            .map_err(to_py_err)
    }

    /// Implements `self + other` for tensors of identical shape.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` when the shapes differ.
    pub fn __add__(&self, other: &PyTensor) -> PyResult<Self> {
        self.ensure_same_shape(other, "addition")?;
        Ok(PyTensor {
            tensor: &self.tensor + &other.tensor,
        })
    }

    /// Implements `self - other` for tensors of identical shape.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` when the shapes differ.
    pub fn __sub__(&self, other: &PyTensor) -> PyResult<Self> {
        self.ensure_same_shape(other, "subtraction")?;
        Ok(PyTensor {
            tensor: &self.tensor - &other.tensor,
        })
    }

    /// Implements `self * other` for tensors of identical shape.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` when the shapes differ.
    pub fn __mul__(&self, other: &PyTensor) -> PyResult<Self> {
        self.ensure_same_shape(other, "multiplication")?;
        Ok(PyTensor {
            tensor: &self.tensor * &other.tensor,
        })
    }

    /// Implements Python's `@` operator.
    ///
    /// Both operands must have at least two dimensions, and the last dimension
    /// of `self` must equal the second-to-last dimension of `other`.
    ///
    /// NOTE(review): after the shape checks the result is computed with the
    /// same `*` operator that `__mul__` uses, so this does not look like a true
    /// matrix product. Swap in the tensor's matrix-multiplication routine once
    /// its API has been confirmed.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` when either operand has fewer than two
    /// dimensions or when the inner dimensions differ.
    pub fn __matmul__(&self, other: &PyTensor) -> PyResult<Self> {
        let lhs_shape = self.shape();
        let rhs_shape = other.shape();
        if lhs_shape.len() < 2 || rhs_shape.len() < 2 {
            return Err(PyValueError::new_err(
                "Matrix multiplication requires at least 2D tensors",
            ));
        }

        let self_cols = lhs_shape[lhs_shape.len() - 1];
        let other_rows = rhs_shape[rhs_shape.len() - 2];
        if self_cols != other_rows {
            return Err(PyValueError::new_err(format!(
                "Matrix multiplication shape mismatch: {} vs {}",
                self_cols, other_rows
            )));
        }

        Ok(PyTensor {
            tensor: &self.tensor * &other.tensor,
        })
    }

    /// Alias for [`PyTensor::__matmul__`].
    pub fn dot(&self, other: &PyTensor) -> PyResult<Self> {
        self.__matmul__(other)
    }

    /// Returns the sum of all elements.
    pub fn sum(&self) -> f32 {
        self.tensor.iter().sum()
    }

    /// Returns the sum of all elements divided by [`PyTensor::numel`].
    pub fn mean(&self) -> f32 {
        self.sum() / self.numel() as f32
    }

    /// Returns the largest element, or negative infinity when the iterator
    /// yields no elements.
    pub fn max(&self) -> f32 {
        self.tensor
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max)
    }

    /// Returns the smallest element, or positive infinity when the iterator
    /// yields no elements.
    pub fn min(&self) -> f32 {
        self.tensor.iter().copied().fold(f32::INFINITY, f32::min)
    }

    /// Returns the three tensors produced by `Tensor::svd`, in the order that
    /// method yields them.
    ///
    /// NOTE(review): `compute_uv` is accepted but currently has no effect.
    ///
    /// # Errors
    ///
    /// Converts any error from `Tensor::svd` with `to_py_err`.
    pub fn svd(&self, compute_uv: Option<bool>) -> PyResult<(PyTensor, PyTensor, PyTensor)> {
        use crate::python::common::to_py_err;
        // Resolved to its default but not forwarded: `Tensor::svd` takes no arguments here.
        let _compute_uv = compute_uv.unwrap_or(true);
        self.tensor
            .svd()
            .map(|(u, s, vt)| {
                (
                    PyTensor { tensor: u },
                    PyTensor { tensor: s },
                    PyTensor { tensor: vt },
                )
            })
            .map_err(to_py_err)
    }

    /// Returns the two tensors produced by `Tensor::qr`, in the order that
    /// method yields them.
    ///
    /// # Errors
    ///
    /// Converts any error from `Tensor::qr` with `to_py_err`.
    pub fn qr(&self) -> PyResult<(PyTensor, PyTensor)> {
        use crate::python::common::to_py_err;
        self.tensor
            .qr()
            .map(|(q, r)| (PyTensor { tensor: q }, PyTensor { tensor: r }))
            .map_err(to_py_err)
    }

    /// Returns `(eigenvalues, eigenvectors)` as produced by `Tensor::eigh`.
    ///
    /// NOTE(review): this delegates to `eigh`, whose name suggests it targets
    /// symmetric matrices — confirm that is acceptable for a method named `eig`.
    ///
    /// # Errors
    ///
    /// Converts any error from `Tensor::eigh` with `to_py_err`.
    pub fn eig(&self) -> PyResult<(PyTensor, PyTensor)> {
        use crate::python::common::to_py_err;
        self.tensor
            .eigh()
            .map(|(eigenvalues, eigenvectors)| {
                (
                    PyTensor {
                        tensor: eigenvalues,
                    },
                    PyTensor {
                        tensor: eigenvectors,
                    },
                )
            })
            .map_err(to_py_err)
    }

    /// Returns the value computed by `Tensor::det`.
    ///
    /// # Errors
    ///
    /// Converts any error from `Tensor::det` with `to_py_err`.
    pub fn det(&self) -> PyResult<f32> {
        use crate::python::common::to_py_err;
        self.tensor.det().map_err(to_py_err)
    }

    /// Returns the tensor computed by `Tensor::inverse`.
    ///
    /// # Errors
    ///
    /// Converts any error from `Tensor::inverse` with `to_py_err`.
    pub fn inverse(&self) -> PyResult<PyTensor> {
        use crate::python::common::to_py_err;
        self.tensor
            .inverse()
            .map(|tensor| PyTensor { tensor })
            .map_err(to_py_err)
    }

    /// Returns the value computed by `Tensor::norm`; this method never fails.
    ///
    /// NOTE(review): `ord` (default `"fro"`) is accepted but currently has no
    /// effect; the same `Tensor::norm()` call is made for every value.
    pub fn norm(&self, ord: Option<String>) -> PyResult<f32> {
        // Resolved to its default but not forwarded: `Tensor::norm` takes no arguments here.
        let _ord = ord.as_deref().unwrap_or("fro");
        Ok(self.tensor.norm())
    }

    /// Returns `PyTensor(shape=[...], data=[...])`.
    ///
    /// At most the first ten elements are shown; a trailing `...` marks a
    /// truncated preview. The preview is rendered as a plain list rather than
    /// as a quoted string.
    pub fn __repr__(&self) -> String {
        const PREVIEW_LEN: usize = 10;
        // Collect only what is printed instead of copying the whole tensor.
        let preview: Vec<f32> = self.tensor.iter().take(PREVIEW_LEN).copied().collect();
        let ellipsis = if self.numel() > PREVIEW_LEN { "..." } else { "" };
        format!(
            "PyTensor(shape={:?}, data={:?}{})",
            self.shape(),
            preview,
            ellipsis
        )
    }

    /// Same text as [`PyTensor::__repr__`].
    pub fn __str__(&self) -> String {
        self.__repr__()
    }

    /// Supports `copy.copy` by cloning the wrapper.
    pub fn __copy__(&self) -> Self {
        self.clone()
    }

    /// Supports `copy.deepcopy` by cloning the wrapper; `_memo` is unused.
    pub fn __deepcopy__(&self, _memo: &Bound<'_, pyo3::types::PyDict>) -> Self {
        self.clone()
    }
}

impl PyTensor {
    /// Returns a `ValueError` naming `operation` unless both tensors have the same shape.
    fn ensure_same_shape(&self, other: &PyTensor, operation: &str) -> PyResult<()> {
        if self.shape() != other.shape() {
            return Err(PyValueError::new_err(format!(
                "Tensor shapes must match for {}",
                operation
            )));
        }
        Ok(())
    }
}
/// Python-visible wrapper around a [`Device`] value.
#[pyclass]
pub struct PyDevice {
    /// The wrapped device; accessible from within this crate only.
    pub(crate) device: Device,
}
#[pymethods]
impl PyDevice {
    /// Returns a wrapper for `Device::Cpu`.
    #[staticmethod]
    pub fn cpu() -> Self {
        Self {
            device: Device::Cpu,
        }
    }

    /// Returns a wrapper for `Device::Cuda(index)`, using index `0` when none
    /// is given. No availability check is performed; this always returns `Ok`.
    #[staticmethod]
    pub fn cuda(index: Option<usize>) -> PyResult<Self> {
        Ok(Self {
            device: Device::Cuda(index.unwrap_or(0)),
        })
    }

    /// Returns a wrapper for `Device::Mps`. No availability check is
    /// performed; this always returns `Ok`.
    #[staticmethod]
    pub fn metal() -> PyResult<Self> {
        Ok(Self {
            device: Device::Mps,
        })
    }

    /// Reports whether the device is considered available.
    ///
    /// CPU is always available and CUDA is always reported as unavailable;
    /// MPS and WASM depend on the compilation target (macOS and `wasm32`
    /// respectively).
    pub fn is_available(&self) -> bool {
        match self.device {
            Device::Cpu => true,
            Device::Cuda(_) => false,
            Device::Mps => cfg!(target_os = "macos"),
            Device::Wasm => cfg!(target_arch = "wasm32"),
        }
    }

    /// Returns the device kind as a lowercase name: `cpu`, `cuda`, `mps` or `wasm`.
    pub fn type_(&self) -> String {
        let kind = match self.device {
            Device::Cpu => "cpu",
            Device::Cuda(_) => "cuda",
            Device::Mps => "mps",
            Device::Wasm => "wasm",
        };
        kind.to_string()
    }

    /// Returns `device(type='<kind>')`, with an extra `index=<n>` field for CUDA.
    pub fn __repr__(&self) -> String {
        match self.device {
            Device::Cuda(index) => format!("device(type='cuda', index={})", index),
            // Every other variant prints only its kind name.
            Device::Cpu | Device::Mps | Device::Wasm => {
                format!("device(type='{}')", self.type_())
            }
        }
    }
}
/// Builds a [`PyTensor`] from flat data.
///
/// When `shape` is omitted the tensor is one-dimensional with length
/// `data.len()`. Construction and validation are delegated to
/// [`PyTensor::new`], whose errors are returned unchanged.
#[pyfunction]
pub fn tensor(data: Vec<f32>, shape: Option<Vec<usize>>) -> PyResult<PyTensor> {
    let dims = match shape {
        Some(dims) => dims,
        None => vec![data.len()],
    };
    PyTensor::new(data, dims)
}
#[pyfunction]
pub fn eye(n: usize) -> PyResult<PyTensor> {
let mut data = vec![0.0; n * n];
for i in 0..n {
data[i * n + i] = 1.0;
}
PyTensor::new(data, vec![n, n])
}