use std::fmt::Display;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
use rayon::prelude::*;
use super::errors;
use super::sparse::Tensor as TensorSparse;
use super::tensor_trait::TensorTrait;
use crate::math::scalar::{Scalar, ScalarCastError};
/// Dense, contiguous N-dimensional tensor over a scalar type `T`.
///
/// Storage is a flat `Vec` in row-major order (the last axis is
/// contiguous — see `TensorTrait::index`, where the innermost stride
/// is 1).
#[derive(Debug, Clone)]
pub struct Tensor<T: Scalar> {
    // Size of each axis, outermost first.
    pub(crate) shape: Vec<usize>,
    // Flat element storage; length should equal the product of `shape`
    // (enforced by `from_vec`/`empty`, but NOT by
    // `from_parts_unchecked`).
    pub(crate) data: Vec<T>,
}
impl<T: Scalar> Tensor<T> {
    /// Total number of stored elements.
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when the tensor holds no elements.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Assembles a tensor without validating that `data.len()` equals
    /// the product of `shape`; callers must uphold that invariant.
    #[inline]
    pub(crate) fn from_parts_unchecked(shape: Vec<usize>, data: Vec<T>) -> Self {
        Self { shape, data }
    }

    /// Builds a tensor from an explicit shape and a flat data vector.
    ///
    /// # Panics
    /// Panics when `data.len()` differs from the element count implied
    /// by `shape`.
    #[inline]
    pub(crate) fn from_vec(shape: &[usize], data: Vec<T>) -> Self {
        let expected = checked_num_elements(shape, "dense tensor from vector");
        assert_eq!(
            data.len(),
            expected,
            "dense tensor data length mismatch: expected {expected}, got {}",
            data.len()
        );
        // Length verified above, so the unchecked constructor is fine.
        Self::from_parts_unchecked(shape.to_vec(), data)
    }

    /// Read-only view of the flat storage.
    #[inline]
    pub(crate) fn data(&self) -> &[T] {
        self.data.as_slice()
    }

    /// Mutable view of the flat storage.
    #[inline]
    pub(crate) fn data_mut(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }
}
/// Maps a possibly-negative or out-of-range axis index into `[0, dim)`
/// using Euclidean (Python-style) wrapping, e.g. `-1` maps to `dim - 1`.
///
/// Replaces the hand-rolled `%` + conditional-add with the stdlib
/// `isize::rem_euclid`, which computes exactly this.
#[inline(always)]
fn wrap_axis_index(idx: isize, dim: usize) -> usize {
    debug_assert!(dim > 0);
    // rem_euclid returns a value in [0, d) for any idx when d > 0.
    idx.rem_euclid(dim as isize) as usize
}
/// Computes the total element count implied by `shape`, panicking with
/// the given `context` prefix when the shape is rejected (e.g. the
/// product overflows — exact conditions live in `errors`).
pub(crate) fn checked_num_elements(shape: &[usize], context: &str) -> usize {
    match errors::checked_num_elements(shape) {
        Ok(count) => count,
        Err(error) => panic!("{context}: {error}"),
    }
}
impl<T> TensorTrait<T> for Tensor<T>
where
    T: Scalar,
{
    type Repr<U: Scalar> = Tensor<U>;

    /// Allocates a tensor of `shape` filled with `T::default()`.
    ///
    /// # Panics
    /// Panics if any dimension is zero or the element count is invalid
    /// per `checked_num_elements`.
    #[inline]
    fn empty(shape: &[usize]) -> Self {
        assert!(
            shape.iter().all(|&d| d > 0),
            "All dimensions must be > 0; got {shape:?}"
        );
        let size = checked_num_elements(shape, "dense tensor");
        Self {
            shape: shape.to_vec(),
            data: vec![T::default(); size],
        }
    }

    /// Parallel sum of all elements (zero for an empty tensor).
    fn get_sum(&self) -> T {
        self.data
            .par_iter()
            .cloned()
            .reduce(|| T::zero(), |a, b| a + b)
    }

    #[inline]
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Flattens a multi-dimensional index into a row-major offset.
    /// Each axis index is wrapped into `[0, dim)` (negative values
    /// count from the end, Python-style), so the result is always
    /// `< product(shape)`.
    #[inline(always)]
    fn index(&self, indices: &[isize]) -> usize {
        assert_eq!(indices.len(), self.shape.len(), "Index rank mismatch");
        let mut flat = 0usize;
        let mut stride = 1usize;
        // Walk axes innermost-first so the stride accumulates naturally.
        for (&dim, &raw_a) in self.shape.iter().rev().zip(indices.iter().rev()) {
            let a = wrap_axis_index(raw_a, dim);
            flat += a * stride;
            stride *= dim;
        }
        flat
    }

    #[inline(always)]
    fn get(&self, indices: &[isize]) -> T {
        let k = self.index(indices);
        // SAFETY: `index` wraps every axis into [0, dim), so k <
        // product(shape), which equals data.len() under the invariant
        // maintained by the checked constructors.
        unsafe { *self.data.get_unchecked(k) }
    }

    #[inline(always)]
    fn get_mut(&mut self, indices: &[isize]) -> &mut T {
        let k = self.index(indices);
        // SAFETY: see `get` — `index` keeps k within data.len().
        unsafe { self.data.get_unchecked_mut(k) }
    }

    #[inline(always)]
    fn set(&mut self, indices: &[isize], val: T) {
        let k = self.index(indices);
        // SAFETY: see `get` — `index` keeps k within data.len().
        unsafe { *self.data.get_unchecked_mut(k) = val }
    }

    /// Sets every element to `value` in parallel.
    #[inline]
    fn par_fill(&mut self, value: T)
    where
        T: Copy + Send + Sync,
    {
        self.data.par_iter_mut().for_each(|x| *x = value);
    }

    /// Applies `f` to every element in place, in parallel.
    #[inline]
    fn par_map_in_place<F>(&mut self, f: F)
    where
        T: Copy + Send + Sync,
        F: Fn(T) -> T + Sync + Send,
    {
        self.data.par_iter_mut().for_each(|x| *x = f(*x));
    }

    /// Combines this tensor element-wise with `other` via `f`, storing
    /// results in `self`. `other` may be any `TensorTrait` impl, so
    /// elements are fetched through its multi-index `get`.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    #[inline]
    fn par_zip_with_inplace<F, Rhs>(&mut self, other: &Rhs, f: F)
    where
        Rhs: TensorTrait<T>,
        T: Copy + Send + Sync,
        F: Fn(T, T) -> T + Sync + Send,
    {
        assert_eq!(self.shape(), other.shape(), "Tensor shape mismatch");
        let rank = self.shape.len();
        let dims = self.shape.clone();
        // for_each_init reuses one scratch index buffer per rayon
        // split instead of heap-allocating a fresh Vec for EVERY
        // element, which the previous version did.
        self.data.par_iter_mut().enumerate().for_each_init(
            || vec![0isize; rank],
            |idx, (k, a)| {
                // Decompose the flat offset k into per-axis indices,
                // innermost axis last (row-major).
                let mut rem = k;
                for ax in (0..rank).rev() {
                    let d = dims[ax];
                    idx[ax] = (rem % d) as isize;
                    rem /= d;
                }
                let b = other.get(idx.as_slice());
                *a = f(*a, b);
            },
        );
    }

    /// Element-wise checked cast to scalar type `U`.
    #[inline]
    fn try_cast_to<U: Scalar>(&self) -> Result<Self::Repr<U>, ScalarCastError>
    where
        T: Copy + Send + Sync,
    {
        Tensor::<T>::try_cast_to::<U>(self)
    }

    fn print(&self) {
        Tensor::<T>::print(self);
    }
}
macro_rules! impl_tensor_ref_binop {
($trait:ident, $method:ident, $op:tt) => {
impl<'a, T> $trait<&'a Tensor<T>> for &'a Tensor<T>
where
T: Scalar + Copy + Send + Sync + core::ops::$trait<Output = T>,
{
type Output = Tensor<T>;
#[inline]
fn $method(self, rhs: &'a Tensor<T>) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Tensor shape mismatch");
let mut out = self.clone(); out.data
.par_iter_mut()
.zip(rhs.data.par_iter())
.for_each(|(a, &b)| { *a = *a $op b; });
out
}
}
};
}
impl_tensor_ref_binop!(Add, add, +);
impl_tensor_ref_binop!(Sub, sub, -);
impl_tensor_ref_binop!(Mul, mul, *);
impl_tensor_ref_binop!(Div, div, /);
macro_rules! impl_tensor_ref_assign {
($trait:ident, $method:ident, $op:tt) => {
impl<'a, T> $trait<&'a Tensor<T>> for Tensor<T>
where
T: Scalar + Copy + Send + Sync + core::ops::$trait<T>,
{
#[inline]
fn $method(&mut self, rhs: &'a Tensor<T>) {
assert_eq!(self.shape, rhs.shape, "Tensor shape mismatch");
self.data
.par_iter_mut()
.zip(rhs.data.par_iter())
.for_each(|(a, &b)| { *a = (*a) $op b; });
}
}
};
}
impl_tensor_ref_assign!(AddAssign, add_assign, +);
impl_tensor_ref_assign!(SubAssign, sub_assign, -);
impl_tensor_ref_assign!(MulAssign, mul_assign, *);
impl_tensor_ref_assign!(DivAssign, div_assign, /);
macro_rules! impl_tensor_ref_scalar_binop {
($trait:ident, $method:ident, $op:tt) => {
impl<'a, T> $trait<T> for &'a Tensor<T>
where
T: Scalar + Copy + Send + Sync + core::ops::$trait<Output = T>,
{
type Output = Tensor<T>;
#[inline]
fn $method(self, rhs: T) -> Self::Output {
let mut out = self.clone();
out.data.par_iter_mut().for_each(|a| *a = *a $op rhs);
out
}
}
};
}
impl_tensor_ref_scalar_binop!(Add, add, +);
impl_tensor_ref_scalar_binop!(Sub, sub, -);
impl_tensor_ref_scalar_binop!(Mul, mul, *);
impl_tensor_ref_scalar_binop!(Div, div, /);
/// Generates `tensor op= scalar`, applying the scalar to every element
/// of the tensor in place, in parallel.
macro_rules! impl_tensor_scalar_assign {
    ($trait:ident, $method:ident, $op:tt) => {
        impl<T> $trait<T> for Tensor<T>
        where
            T: Scalar + Copy + Send + Sync + core::ops::$trait<T>,
        {
            #[inline]
            fn $method(&mut self, rhs: T) {
                self.data
                    .par_iter_mut()
                    .for_each(|dst| *dst = *dst $op rhs);
            }
        }
    };
}

impl_tensor_scalar_assign!(AddAssign, add_assign, +);
impl_tensor_scalar_assign!(SubAssign, sub_assign, -);
impl_tensor_scalar_assign!(MulAssign, mul_assign, *);
impl_tensor_scalar_assign!(DivAssign, div_assign, /);
impl<T: Scalar> Tensor<T> {
    /// Attempts an element-wise cast of every entry to scalar type `U`,
    /// returning the first conversion error if any value cannot be
    /// represented. Runs in parallel.
    pub fn try_cast_to<U: Scalar>(&self) -> Result<Tensor<U>, ScalarCastError> {
        let data = self
            .data
            .par_iter()
            .map(|&x| x.try_cast::<U>())
            .collect::<Result<Vec<U>, _>>()?;
        Ok(Tensor {
            shape: self.shape.clone(),
            data,
        })
    }
}
impl<T: Scalar> Tensor<T> {
    /// Converts this dense tensor into its sparse representation.
    #[inline]
    pub fn to_sparse(&self) -> TensorSparse<T> {
        TensorSparse::from_dense(self)
    }

    /// Builds a dense tensor from `sparse`, filling unspecified entries
    /// with `T::zero()`.
    ///
    /// # Panics
    /// Panics if a sparse key lies outside the dense element range.
    /// (The previous `get_unchecked_mut` version had no bounds proof
    /// for keys coming from the sparse map, so a corrupt key was
    /// undefined behavior; safe indexing turns that into a panic.)
    #[inline]
    pub fn from_sparse(sparse: &TensorSparse<T>) -> Self {
        let shape = sparse.shape().to_vec();
        let size = checked_num_elements(&shape, "dense tensor from sparse");
        let mut data = vec![T::zero(); size];
        for (&k, &v) in sparse.iter() {
            data[k] = v;
        }
        Self { shape, data }
    }
}
impl<T: Scalar + Display + Copy> Tensor<T> {
    /// Pretty-prints the tensor to stdout: a single line for rank 1, a
    /// row/column grid for rank 2, and a shape summary plus a raw
    /// storage dump for higher ranks.
    pub fn print(&self) {
        match self.shape.len() {
            1 => {
                // Rank 1: storage order equals index order, so iterate
                // the flat data directly.
                for &v in &self.data {
                    print!("{:<8} ", v);
                }
                println!();
            }
            2 => {
                let (rows, cols) = (self.shape[0], self.shape[1]);
                for r in 0..rows {
                    for c in 0..cols {
                        print!("{:<8} ", self.get(&[r as isize, c as isize]));
                    }
                    println!();
                }
            }
            _ => {
                println!(
                    "Tensor shape {:?}, {} elements",
                    self.shape,
                    self.data.len()
                );
                println!("{}", crate::math::io::string::format_dense_storage(self));
            }
        }
    }
}