tensorite_math_tensor 0.2.2

Tensorite Math Tensor | Tensor Implementation
Documentation
use std::fmt::{Display};
use std::ops::{Add, Sub, Mul, Div};

/// A dense tensor of fixed rank `RANK` whose elements are stored flat in `data`.
///
/// `shape` holds one length per axis. `Tensor::new` checks that `data.len()`
/// equals the product of `shape`; because `data` is public, that relationship
/// can still be broken by mutating `data` directly.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor<T, const RANK: usize>
where
    T: Copy,
{
    /// Flat element storage.
    pub data: Vec<T>,
    /// Length of each axis.
    shape: [usize; RANK],
}

impl<T: Copy, const RANK: usize> Tensor<T, RANK> {
    /// Creates a tensor from flat `data` and a per-axis `shape`.
    ///
    /// # Panics
    /// Panics if `data.len()` is not the product of the entries of `shape`.
    pub fn new(data: Vec<T>, shape: [usize; RANK]) -> Self {
        assert_eq!(data.len(), shape.iter().product::<usize>(), "Data length must match shape product");
        Tensor { data, shape }
    }

    /// Creates a rank-1 tensor from a copy of `data`, with shape `[data.len()]`.
    ///
    /// The result is always rank 1 whatever `RANK` is, so `RANK` cannot be
    /// inferred at the call site and must be written out, e.g.
    /// `Tensor::<f64, 1>::new_vec(&values)`.
    // `&Vec<T>` is kept (rather than `&[T]`) to leave the public signature unchanged.
    #[allow(clippy::ptr_arg)]
    pub fn new_vec(data: &Vec<T>) -> Tensor<T, 1> {
        // Only the length is needed for the shape, so no second copy is made.
        Tensor::new(data.clone(), [data.len()])
    }

    /// Creates a rank-0 tensor with no elements.
    ///
    /// Note: an empty shape has an element-count product of 1, so this value
    /// (with empty `data`) would be rejected by [`Tensor::new`].
    pub fn null() -> Tensor<T, 0> {
        Tensor {
            data: vec![],
            shape: [],
        }
    }

    /// Creates a tensor with the given `shape` but no element data.
    ///
    /// Unlike [`Tensor::new`], no length check is performed: `data` is empty
    /// even when `shape` describes a non-zero number of elements, so the
    /// tensor does not satisfy `len() == shape.iter().product()` until the
    /// caller fills `data`.
    pub fn empty(shape: &[usize; RANK]) -> Tensor<T, RANK> {
        Tensor {
            data: vec![],
            shape: *shape,
        }
    }

    /// Borrows the flat element storage.
    pub fn as_vec(&self) -> &Vec<T> {
        &self.data
    }

    /// Returns the length of every axis.
    pub fn shape(&self) -> &[usize; RANK] {
        &self.shape
    }

    /// Returns the number of stored elements (same as [`Tensor::size`]).
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the tensor stores no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Reshapes the tensor to `new_shape`, which may have a different rank.
    ///
    /// The flat data is moved into the result unchanged; only the shape is replaced.
    ///
    /// # Panics
    /// Panics if the product of `new_shape` differs from the number of stored elements.
    pub fn reshape<const NEW_RANK: usize>(self, new_shape: [usize; NEW_RANK]) -> Tensor<T, NEW_RANK> {
        assert_eq!(self.data.len(), new_shape.iter().product::<usize>(), "Data length must match new shape product");
        Tensor { data: self.data, shape: new_shape }
    }

    /// Returns the number of stored elements (same as [`Tensor::len`]).
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Returns an owned copy of the flat element storage.
    pub fn to_vec(&self) -> Vec<T> {
        self.data.clone()
    }
}

impl<T: Copy, const RANK: usize> Tensor<T, RANK> {
    pub fn filled(value: T, shape: [usize; RANK]) -> Self {
        Tensor::new(vec![value; shape.iter().product()], shape)
    }
}

/// # Element-wise addition
/// Adds two tensors of the same shape element by element.
///
/// ```
/// # use tensorite_math_tensor::Tensor;
/// let a = Tensor::new(vec![1, 2, 3, 4], [4]);
/// let b = Tensor::new(vec![5, 6, 7, 8], [4]);
/// let c = a + b;
/// assert_eq!(c.as_vec(), &vec![6, 8, 10, 12]);
/// ```
///
/// # Panics
/// Panics with "Shapes must match for addition" if the shapes differ.
/// Tensors of different shapes have to be broadcast to a common shape first
/// (see [`Broadcast`]).
impl<T: Add<Output = T> + Copy, const RANK: usize> Add for Tensor<T, RANK> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for addition");
        let shape = self.shape;
        let mut sums = Vec::with_capacity(self.data.len());
        for (&lhs, &r) in self.data.iter().zip(&rhs.data) {
            sums.push(lhs + r);
        }
        Tensor::new(sums, shape)
    }
}

/// # Element-wise subtraction
/// Perform element-wise subtraction of two (2) tensors of the same shape
///
/// ```
/// # use tensorite_math_tensor::Tensor;
/// # let shape = [4];
/// let a = Tensor::new(vec![1, 2, 3, 4], shape);
/// let b = Tensor::new(vec![5, 6, 7, 8], shape);
/// let c = a - b; // c = [-4, -4, -4, -4]
/// ```
///
/// If the shapes of the tensors that you want to use are not of the same shape,
/// consider broadcasting them:
impl<T: Sub<Output = T> + Copy, const RANK: usize> Sub for Tensor<T, RANK> {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for subtraction");
        let data = self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a - b).collect();
        Tensor::new(data, *self.shape())
    }
}

/// # Element-wise multiplication
/// Multiplies two tensors of the same shape element by element
/// (an element-wise product, not a matrix product).
///
/// ```
/// # use tensorite_math_tensor::Tensor;
/// # let shape = [4];
/// let a = Tensor::new(vec![1, 2, 3, 4], shape);
/// let b = Tensor::new(vec![5, 6, 7, 8], shape);
/// let c = a * b;
/// assert_eq!(c.as_vec(), &vec![5, 12, 21, 32]);
/// ```
///
/// # Panics
/// Panics with "Shapes must match for multiplication" if the shapes differ.
/// Tensors of different shapes have to be broadcast to a common shape first
/// (see [`Broadcast`]).
impl<T: Mul<Output = T> + Copy, const RANK: usize> Mul for Tensor<T, RANK> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for multiplication");
        let data = self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a * b).collect();
        Tensor::new(data, *self.shape())
    }
}

/// # Element-wise division
/// Divides `self` by `rhs` element by element; both tensors must have the
/// same shape.
///
/// ```
/// # use tensorite_math_tensor::Tensor;
/// # let shape = [4];
/// let a = Tensor::new(vec![1.0, 2.0, 3.5, 4.0], shape);
/// let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], shape);
/// let c = a / b; // c = [0.2, 0.333..., 0.5, 0.5]
/// assert_eq!(c.as_vec(), &vec![1.0 / 5.0, 2.0 / 6.0, 0.5, 0.5]);
/// ```
///
/// # Panics
/// Panics with "Shapes must match for division" if the shapes differ. For
/// primitive integer `T`, a zero divisor panics as ordinary integer division does.
/// Tensors of different shapes have to be broadcast to a common shape first
/// (see [`Broadcast`]).
impl<T: Div<Output = T> + Copy, const RANK: usize> Div for Tensor<T, RANK> {
    type Output = Self;
    fn div(self, rhs: Self) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for division");
        let data = self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a / b).collect();
        Tensor::new(data, *self.shape())
    }
}

/// # Dot product
/// Computes the dot product of `self` and `rhs`.
///
/// For rank-1 tensors this is the sum of the element-wise products, and the
/// two tensors must have the same shape.
///
/// ```
/// # use tensorite_math_tensor::{DotProduct, Tensor};
/// let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], [4]);
/// let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], [4]);
/// assert_eq!(a.dot(&b), 70.0); // 5 + 12 + 21 + 32
/// ```
pub trait DotProduct<Rhs = Self> {
    /// Type of the value produced by the product.
    type Output;
    /// Returns the dot product of `self` and `rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}

/// Dot product of two rank-1 tensors, e.g. `a.dot(&b)`.
///
/// The element-wise products are accumulated left to right, starting from
/// `T::default()`.
///
/// # Panics
/// Panics with "Shapes must match for dot product" if the shapes differ.
impl<T: Copy + Mul<Output = T> + Add<Output = T> + Default> DotProduct for Tensor<T, 1> {
    type Output = T;

    fn dot(&self, rhs: &Tensor<T, 1>) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for dot product");
        let mut total = T::default();
        for (&x, &y) in self.data.iter().zip(&rhs.data) {
            total = total + x * y;
        }
        total
    }
}

/// Error returned when two tensor shapes cannot be broadcast together, i.e.
/// some axis differs in length and neither side has length 1 on that axis.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TensorBroadcastError;

impl std::fmt::Display for TensorBroadcastError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tensor shapes cannot be broadcast together")
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` and similar error types.
impl std::error::Error for TensorBroadcastError {}

/// Broadcasting operations between tensors of the same rank.
pub trait Broadcast<T: Copy, const RANK: usize> {
    /// Computes the shape that `self` and `other` broadcast to.
    ///
    /// # Errors
    /// Returns [`TensorBroadcastError`] when the two shapes cannot be
    /// broadcast together.
    fn get_broadcast_shape(&self, other: &Tensor<T, RANK>) -> Result<[usize; RANK], TensorBroadcastError>;

    /// Returns a new tensor obtained by broadcasting `self` against `other`.
    ///
    /// # Errors
    /// Returns [`TensorBroadcastError`] when the two shapes cannot be
    /// broadcast together.
    fn broadcast(&self, other: &Tensor<T, RANK>) -> Result<Tensor<T, RANK>, TensorBroadcastError>;

    /// Returns a new tensor with the given `shape`, built from this tensor's data.
    fn broadcast_shape(&self, shape: [usize; RANK]) -> Tensor<T, RANK>;

    /// Changes this tensor's shape in place to `new_shape`, rewriting its data
    /// where needed.
    ///
    /// # Panics
    /// Implementations may panic when the data cannot be mapped onto `new_shape`.
    fn broadcast_inplace(&mut self, new_shape: [usize; RANK]);
}

impl<T: Copy, const RANK: usize> Broadcast<T, RANK> for Tensor<T, RANK> {
    /// Broadcasts `self` against `other`, returning a new tensor whose shape is
    /// the common shape computed by [`Broadcast::get_broadcast_shape`].
    ///
    /// # Errors
    /// Returns [`TensorBroadcastError`] if the shapes are not broadcast-compatible.
    ///
    /// # Panics
    /// Panics if `self.data.len()` does not match the product of `self`'s shape.
    fn broadcast(&self, other: &Tensor<T, RANK>) -> Result<Tensor<T, RANK>, TensorBroadcastError> {
        // Check compatibility before copying any data.
        let shape = self.get_broadcast_shape(other)?;
        let mut tensor = Tensor::new(self.data.clone(), self.shape);
        // `shape` was derived from `self.shape`, so every axis is either equal
        // or expands from 1; `broadcast_inplace` cannot panic here.
        tensor.broadcast_inplace(shape);
        Ok(tensor)
    }

    /// Changes the shape in place to `new_shape`, rewriting the data as needed.
    ///
    /// Data is interpreted in row-major order (last axis varies fastest).
    /// Accepted cases:
    /// * `new_shape` has as many elements as are stored: only the shape changes
    ///   and the data is reinterpreted as-is.
    /// * the tensor stores a single element: it is repeated to fill `new_shape`.
    /// * every axis of the current shape either equals the matching axis of
    ///   `new_shape` or has length 1: length-1 axes are repeated along `new_shape`.
    ///
    /// # Panics
    /// Panics for any other `new_shape`.
    fn broadcast_inplace(&mut self, new_shape: [usize; RANK]) {
        let new_size = new_shape.iter().product::<usize>();
        if new_size != self.data.len() {
            if self.data.len() == 1 {
                // Expand a single value to the whole new shape.
                let value = self.data[0];
                self.data = vec![value; new_size];
            } else if self.data.len() == self.shape.iter().product::<usize>()
                && self
                    .shape
                    .iter()
                    .zip(new_shape.iter())
                    .all(|(&from, &to)| from == to || from == 1)
            {
                self.data = expand_broadcast_data(&self.data, &self.shape, &new_shape);
            } else {
                panic!("Cannot resize tensor: new shape product must match data length or be a broadcastable shape");
            }
        }
        self.shape = new_shape;
    }

    /// Returns a copy of this tensor brought to `shape` under the same rules
    /// as [`Broadcast::broadcast_inplace`].
    ///
    /// # Panics
    /// Panics if `shape` is not reachable under those rules.
    fn broadcast_shape(&self, shape: [usize; RANK]) -> Tensor<T, RANK> {
        let mut tensor = self.clone();
        tensor.broadcast_inplace(shape);
        tensor
    }

    /// Computes the common shape of `self` and `other` axis by axis: equal
    /// lengths are kept, and a length of 1 on either side takes the other
    /// side's length.
    ///
    /// # Errors
    /// Returns [`TensorBroadcastError`] if some axis differs and neither side is 1.
    fn get_broadcast_shape(&self, other: &Tensor<T, RANK>) -> Result<[usize; RANK], TensorBroadcastError> {
        let mut result_shape = [0; RANK];

        for i in 0..RANK {
            if self.shape[i] == other.shape[i] {
                result_shape[i] = self.shape[i];
            } else if self.shape[i] == 1 {
                result_shape[i] = other.shape[i];
            } else if other.shape[i] == 1 {
                result_shape[i] = self.shape[i];
            } else {
                return Err(TensorBroadcastError); // Cannot broadcast
            }
        }

        Ok(result_shape)
    }
}

/// Materialises `data` (row-major, shape `from`) at shape `to`, repeating every
/// length-1 axis of `from` along the matching axis of `to`.
///
/// Callers must ensure `data.len()` equals the product of `from`, that `from`
/// and `to` have the same length, and that each `from[axis]` is `to[axis]` or 1.
fn expand_broadcast_data<T: Copy>(data: &[T], from: &[usize], to: &[usize]) -> Vec<T> {
    // Row-major source strides; a length-1 axis gets stride 0 so its single
    // slice is reused for every index along the target axis.
    let mut strides = vec![0usize; from.len()];
    let mut step = 1usize;
    for (stride, &dim) in strides.iter_mut().zip(from).rev() {
        *stride = if dim == 1 { 0 } else { step };
        step *= dim;
    }

    // When any target axis is 0 the range is empty, so the divisions below
    // only ever run with non-zero axis lengths.
    let total: usize = to.iter().product();
    (0..total)
        .map(|flat| {
            // Decompose the flat output index (last axis fastest) and
            // accumulate the matching source offset.
            let mut rest = flat;
            let mut offset = 0;
            for (&dim, &stride) in to.iter().zip(&strides).rev() {
                offset += (rest % dim) * stride;
                rest /= dim;
            }
            data[offset]
        })
        .collect()
}

impl<T: Copy, const RANK: usize> Display for Tensor<T, RANK> {
    /// Writes only the tensor's shape in `Debug` form (e.g. `[2, 3]`); the
    /// element values are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dims = &self.shape;
        write!(f, "{:?}", dims)
    }
}