use super::core::{CoreMLBuffer, DeviceState, Index2D, Index3D, MetalBuffer};
use crate::common::RusTorchResult;
use ndarray::{Array, IxDyn};
use std::ops::{Add, Div, Index, IndexMut, Mul, Neg, Sub};
use std::sync::Arc;
/// A dense `f64` tensor backed by an `ndarray` dynamic-dimensional array, with optional
/// shared accelerator (Metal / CoreML) buffers.
#[derive(Debug)]
pub struct F64Tensor {
    /// Element storage. Public, so callers may read or replace it directly.
    pub data: Array<f64, IxDyn>,
    /// Optional Metal buffer shared through an `Arc`; constructors in this file leave it `None`.
    pub metal_buffer: Option<Arc<MetalBuffer>>,
    /// Optional CoreML buffer shared through an `Arc`; constructors in this file leave it `None`.
    pub coreml_buffer: Option<Arc<CoreMLBuffer>>,
    /// Device the tensor is considered to live on; constructors in this file use `DeviceState::CPU`.
    pub device_state: DeviceState,
    /// Whether gradients should be tracked for this tensor.
    pub requires_grad: bool,
    /// Cached copy of `data.shape()`, filled in by `F64Tensor::new`.
    /// NOTE(review): `data` is public, so this cache can go stale if `data` is reassigned —
    /// consider always deriving the shape from `data`.
    shape: Vec<usize>,
}
impl Clone for F64Tensor {
fn clone(&self) -> Self {
F64Tensor {
data: self.data.clone(),
metal_buffer: self.metal_buffer.clone(),
coreml_buffer: self.coreml_buffer.clone(),
device_state: self.device_state.clone(),
requires_grad: self.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Add<F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn add(self, other: F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data + &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Add<&F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn add(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data + &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Add for &F64Tensor {
type Output = F64Tensor;
fn add(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data + &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Add<f64> for F64Tensor {
    type Output = F64Tensor;

    /// Adds `scalar` to every element, consuming the tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn add(self, scalar: f64) -> F64Tensor {
        let shifted = &self.data + scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: shifted,
        }
    }
}
impl Add<f64> for &F64Tensor {
    type Output = F64Tensor;

    /// Adds `scalar` to every element of a borrowed tensor, returning a new tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn add(self, scalar: f64) -> F64Tensor {
        let shifted = &self.data + scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: shifted,
        }
    }
}
impl Sub<F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn sub(self, other: F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data - &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Sub<&F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn sub(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data - &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Sub for &F64Tensor {
type Output = F64Tensor;
fn sub(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data - &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Sub<f64> for F64Tensor {
    type Output = F64Tensor;

    /// Subtracts `scalar` from every element, consuming the tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn sub(self, scalar: f64) -> F64Tensor {
        let offset = &self.data - scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: offset,
        }
    }
}
impl Sub<f64> for &F64Tensor {
    type Output = F64Tensor;

    /// Subtracts `scalar` from every element of a borrowed tensor, returning a new tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn sub(self, scalar: f64) -> F64Tensor {
        let offset = &self.data - scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: offset,
        }
    }
}
impl Mul<F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn mul(self, other: F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data * &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Mul<&F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn mul(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data * &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Mul for &F64Tensor {
type Output = F64Tensor;
fn mul(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data * &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Mul<f64> for F64Tensor {
    type Output = F64Tensor;

    /// Multiplies every element by `scalar`, consuming the tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn mul(self, scalar: f64) -> F64Tensor {
        let scaled = &self.data * scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: scaled,
        }
    }
}
impl Mul<f64> for &F64Tensor {
    type Output = F64Tensor;

    /// Multiplies every element of a borrowed tensor by `scalar`, returning a new tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn mul(self, scalar: f64) -> F64Tensor {
        let scaled = &self.data * scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: scaled,
        }
    }
}
impl Div<F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn div(self, other: F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data / &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Div<&F64Tensor> for F64Tensor {
type Output = F64Tensor;
fn div(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data / &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Div for &F64Tensor {
type Output = F64Tensor;
fn div(self, other: &F64Tensor) -> F64Tensor {
F64Tensor {
data: &self.data / &other.data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: self.requires_grad || other.requires_grad,
shape: self.shape.clone(),
}
}
}
impl Div<f64> for F64Tensor {
    type Output = F64Tensor;

    /// Divides every element by `scalar`, consuming the tensor.
    ///
    /// Uses IEEE-754 `f64` division, so `scalar == 0.0` yields ±inf or NaN rather than panicking.
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn div(self, scalar: f64) -> F64Tensor {
        let quotient = &self.data / scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: quotient,
        }
    }
}
impl Div<f64> for &F64Tensor {
    type Output = F64Tensor;

    /// Divides every element of a borrowed tensor by `scalar`, returning a new tensor.
    ///
    /// Uses IEEE-754 `f64` division, so `scalar == 0.0` yields ±inf or NaN rather than panicking.
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn div(self, scalar: f64) -> F64Tensor {
        let quotient = &self.data / scalar;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: quotient,
        }
    }
}
impl Neg for F64Tensor {
    type Output = F64Tensor;

    /// Negates every element, consuming the tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn neg(self) -> F64Tensor {
        let negated = -&self.data;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: negated,
        }
    }
}
impl Neg for &F64Tensor {
    type Output = F64Tensor;

    /// Negates every element of a borrowed tensor, returning a new tensor.
    ///
    /// Shape and `requires_grad` carry over; the result is a CPU tensor with no device buffers.
    fn neg(self) -> F64Tensor {
        let negated = -&self.data;
        F64Tensor {
            shape: self.shape.clone(),
            requires_grad: self.requires_grad,
            device_state: DeviceState::CPU,
            coreml_buffer: None,
            metal_buffer: None,
            data: negated,
        }
    }
}
impl F64Tensor {
pub fn new(data: Array<f64, IxDyn>) -> Self {
let shape = data.shape().to_vec();
F64Tensor {
data,
metal_buffer: None,
coreml_buffer: None,
device_state: DeviceState::CPU,
requires_grad: false,
shape,
}
}
pub fn zeros(shape: &[usize]) -> RusTorchResult<Self> {
let data = Array::zeros(shape);
Ok(F64Tensor::new(data))
}
pub fn ones(shape: &[usize]) -> RusTorchResult<Self> {
let data = Array::ones(shape);
Ok(F64Tensor::new(data))
}
pub fn randn(shape: &[usize]) -> RusTorchResult<Self> {
use ndarray_rand::rand_distr::StandardNormal;
use ndarray_rand::RandomExt;
let data = Array::random(shape, StandardNormal);
Ok(F64Tensor::new(data))
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn numel(&self) -> usize {
self.shape.iter().product()
}
pub fn dtype(&self) -> &'static str {
"f64"
}
pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
self.requires_grad = requires_grad;
self
}
pub fn reshape(&self, new_shape: &[usize]) -> RusTorchResult<Self> {
let new_data = self.data.clone().into_shape_with_order(new_shape)?;
let mut result = F64Tensor::new(new_data);
result.requires_grad = self.requires_grad;
Ok(result)
}
pub fn transpose(&self) -> RusTorchResult<Self> {
let transposed = self.data.t().to_owned();
let mut shape = self.shape.clone();
shape.reverse();
let mut result = F64Tensor::new(transposed);
result.shape = shape;
result.requires_grad = self.requires_grad;
Ok(result)
}
pub fn matmul(&self, other: &F64Tensor) -> RusTorchResult<Self> {
use ndarray::linalg::general_mat_mul;
let (m, k) = (self.shape[0], self.shape[1]);
let (k2, n) = (other.shape[0], other.shape[1]);
if k != k2 {
return Err(crate::error::RusTorchError::tensor_op(format!(
"Cannot multiply matrices with shapes {:?} and {:?}",
self.shape, other.shape
)));
}
let mut result_data = Array::zeros((m, n));
general_mat_mul(
1.0,
&self.data.view().into_dimensionality()?,
&other.data.view().into_dimensionality()?,
0.0,
&mut result_data.view_mut(),
);
let result_dyn = result_data.into_dyn();
let mut result = F64Tensor::new(result_dyn);
result.requires_grad = self.requires_grad || other.requires_grad;
Ok(result)
}
pub fn sum(&self) -> f64 {
self.data.sum()
}
pub fn mean(&self) -> f64 {
self.data.mean().unwrap_or(0.0)
}
pub fn std(&self) -> f64 {
let mean = self.mean();
let variance = self.data.mapv(|x| (x - mean).powi(2)).mean().unwrap_or(0.0);
variance.sqrt()
}
pub fn max(&self) -> f64 {
self.data.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
}
pub fn min(&self) -> f64 {
self.data.fold(f64::INFINITY, |acc, &x| acc.min(x))
}
pub fn unsqueeze(&self, dim: usize) -> RusTorchResult<Self> {
let mut new_shape = self.shape.clone();
new_shape.insert(dim, 1);
self.reshape(&new_shape)
}
pub fn expand(&self, new_shape: &[usize]) -> RusTorchResult<Self> {
let expanded_data = self
.data
.broadcast(new_shape)
.ok_or_else(|| crate::error::RusTorchError::tensor_op("Cannot broadcast to new shape"))?
.to_owned();
let mut result = F64Tensor::new(expanded_data);
result.requires_grad = self.requires_grad;
Ok(result)
}
pub fn transpose_dims(&self, dim1: usize, dim2: usize) -> RusTorchResult<Self> {
let mut permutation: Vec<usize> = (0..self.ndim()).collect();
permutation.swap(dim1, dim2);
let transposed = self.data.clone().permuted_axes(permutation);
let mut result = F64Tensor::new(transposed);
result.requires_grad = self.requires_grad;
Ok(result)
}
pub fn softmax(&self, dim: Option<usize>) -> RusTorchResult<Self> {
let axis = dim.unwrap_or(self.ndim() - 1);
let max_vals = self.data.map_axis(ndarray::Axis(axis), |lane| {
lane.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
});
let shifted = &self.data - &max_vals.insert_axis(ndarray::Axis(axis));
let exp_vals = shifted.mapv(|x| x.exp());
let sum_exp = exp_vals.sum_axis(ndarray::Axis(axis));
let result_data = exp_vals / sum_exp.insert_axis(ndarray::Axis(axis));
let mut result = F64Tensor::new(result_data);
result.requires_grad = self.requires_grad;
Ok(result)
}
}
impl Index<usize> for F64Tensor {
    type Output = f64;

    /// Returns a reference to element `index` of the underlying array.
    ///
    /// Delegates to ndarray's single-`usize` indexing, which addresses a 1-D array. ndarray
    /// panics on an out-of-bounds index. NOTE(review): use on tensors of other ranks also
    /// appears to panic — confirm 1-D is the intended use.
    fn index(&self, index: usize) -> &Self::Output {
        let elements = &self.data;
        &elements[index]
    }
}
impl IndexMut<usize> for F64Tensor {
    /// Returns a mutable reference to element `index` of the underlying array.
    ///
    /// Delegates to ndarray's single-`usize` indexing, which panics on an out-of-bounds index.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let elements = &mut self.data;
        &mut elements[index]
    }
}
impl Index<Index2D> for F64Tensor {
    type Output = f64;

    /// Returns a reference to the element at coordinates `(index.0, index.1)`.
    ///
    /// ndarray panics if a coordinate is out of bounds.
    fn index(&self, index: Index2D) -> &Self::Output {
        let (row, col) = (index.0, index.1);
        &self.data[[row, col]]
    }
}
impl IndexMut<Index2D> for F64Tensor {
    /// Returns a mutable reference to the element at coordinates `(index.0, index.1)`.
    ///
    /// ndarray panics if a coordinate is out of bounds.
    fn index_mut(&mut self, index: Index2D) -> &mut Self::Output {
        let (row, col) = (index.0, index.1);
        &mut self.data[[row, col]]
    }
}
impl Index<Index3D> for F64Tensor {
    type Output = f64;

    /// Returns a reference to the element at coordinates `(index.0, index.1, index.2)`.
    ///
    /// ndarray panics if a coordinate is out of bounds.
    fn index(&self, index: Index3D) -> &Self::Output {
        let (i, j, k) = (index.0, index.1, index.2);
        &self.data[[i, j, k]]
    }
}
impl IndexMut<Index3D> for F64Tensor {
    /// Returns a mutable reference to the element at coordinates `(index.0, index.1, index.2)`.
    ///
    /// ndarray panics if a coordinate is out of bounds.
    fn index_mut(&mut self, index: Index3D) -> &mut Self::Output {
        let (i, j, k) = (index.0, index.1, index.2);
        &mut self.data[[i, j, k]]
    }
}