use crate::{Tensor, TensorElement, FloatElement};
use torsh_core::error::{Result, TorshError};
use scirs2_core::numeric::{Float, Zero, One, cast::ToPrimitive};
use std::ops::{Add, Mul, Sub};
#[cfg(feature = "simd")]
use scirs2_core::simd_ops::simd_matrix_multiply;
#[cfg(feature = "profiling")]
use scirs2_core::profiling::profile_section;
impl<T: TensorElement> Tensor<T> {
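/// Matrix product of two tensors, dispatched on dimensionality:
/// (2, 2) matrix-matrix, (1, 2) vector-matrix, (2, 1) matrix-vector,
/// (1, 1) dot product, and batched multiplication for higher ranks.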
pub fn matmul(&self, other: &Self) -> Result<Self>
where
T: Add<Output = T> + Mul<Output = T> + Zero + Copy
{
#[cfg(feature = "profiling")]
let _profile = profile_section!("matmul");
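// Fast paths: sufficiently large 2-D operands are first offered to the
// GPU and SIMD kernels (the size thresholds below are heuristics). These
// paths assume the element type also satisfies the GPU/SIMD element
// bounds required by the helpers; any fast-path failure falls through
// to the portable implementations.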
#[cfg(feature = "gpu")]
{
if self.ndim() == 2 && other.ndim() == 2 {
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape[0] * self_shape[1] > 10000 && other_shape[0] * other_shape[1] > 10000 {
if let Ok(result) = self.gpu_matmul_2d(other) {
return Ok(result);
}
}
}
}
#[cfg(feature = "simd")]
{
if self.ndim() == 2 && other.ndim() == 2 {
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape[0] * self_shape[1] > 1000 && other_shape[0] * other_shape[1] > 1000 {
if let Ok(result) = self.simd_matmul_2d(other) {
return Ok(result);
}
}
}
}
match (self.ndim(), other.ndim()) {
(2, 2) => self.matmul_2d(other),
(1, 2) => self.matmul_1d_2d(other),
(2, 1) => self.matmul_2d_1d(other),
(1, 1) => self.matmul_1d_1d(other),
_ => self.batch_matmul(other),
}
}
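/// 2-D matrix multiply on the GPU: uploads both operands to device
/// buffers, launches the library's matrix-multiply kernel, and copies
/// the result back to the host.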
#[cfg(feature = "gpu")]
fn gpu_matmul_2d(&self, other: &Self) -> Result<Self>
where
T: scirs2_core::gpu::GpuElement + Add<Output = T> + Mul<Output = T> + Zero + Copy,
{
#[cfg(feature = "profiling")]
let _profile = profile_section!("gpu_matmul");
use scirs2_core::gpu::{GpuContext, GpuBuffer, GpuKernel};
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape.len() != 2 || other_shape.len() != 2 {
return Err(TorshError::Other("Both tensors must be 2D for GPU matmul".to_string()));
}
let (m, k1) = (self_shape[0], self_shape[1]);
let (k2, n) = (other_shape[0], other_shape[1]);
if k1 != k2 {
return Err(TorshError::Other(format!(
"Cannot multiply matrices with shapes [{}, {}] and [{}, {}]",
m, k1, k2, n
)));
}
let gpu_context = GpuContext::new()?;
let self_data = self.data()?;
let other_data = other.data()?;
let gpu_buffer_a = GpuBuffer::from_slice(&gpu_context, &self_data)?;
let gpu_buffer_b = GpuBuffer::from_slice(&gpu_context, &other_data)?;
let gpu_buffer_c = GpuBuffer::zeros(&gpu_context, m * n)?;
let kernel = GpuKernel::matrix_multiply(&gpu_context)?;
kernel.launch(
&gpu_buffer_a, &gpu_buffer_b, &gpu_buffer_c,
m, k1, n
)?;
let result_data = gpu_buffer_c.to_vec()?;
let result_shape = vec![m, n];
Self::from_data(result_data, result_shape, self.device())
}
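/// 2-D matrix multiply using the library's SIMD kernel on host memory.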
#[cfg(feature = "simd")]
fn simd_matmul_2d(&self, other: &Self) -> Result<Self>
where
T: scirs2_core::simd::SimdElement + Add<Output = T> + Mul<Output = T> + Zero + Copy,
{
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape.len() != 2 || other_shape.len() != 2 {
return Err(TorshError::Other("Both tensors must be 2D for SIMD matmul".to_string()));
}
let (m, k1) = (self_shape[0], self_shape[1]);
let (k2, n) = (other_shape[0], other_shape[1]);
if k1 != k2 {
return Err(TorshError::Other(format!(
"Cannot multiply matrices with shapes [{}, {}] and [{}, {}]",
m, k1, k2, n
)));
}
let self_data = self.data()?;
let other_data = other.data()?;
let result_data = simd_matrix_multiply(
&self_data, &other_data,
m, k1, n
);
let result_shape = vec![m, n];
Self::from_data(result_data, result_shape, self.device())
}
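/// Portable O(m * k * n) triple-loop kernel for 2-D matrix multiplication.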
fn matmul_2d(&self, other: &Self) -> Result<Self>
where
T: Add<Output = T> + Mul<Output = T> + Zero + Copy
{
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape.len() != 2 || other_shape.len() != 2 {
return Err(TorshError::Other("Both tensors must be 2D for 2D matmul".to_string()));
}
let (m, k1) = (self_shape[0], self_shape[1]);
let (k2, n) = (other_shape[0], other_shape[1]);
if k1 != k2 {
return Err(TorshError::Other(format!(
"Cannot multiply matrices with shapes [{}, {}] and [{}, {}]",
m, k1, k2, n
)));
}
let result_shape = vec![m, n];
let mut result = Self::zeros(&result_shape, self.device())?;
for i in 0..m {
for j in 0..n {
let mut sum = <T as TensorElement>::zero();
for k in 0..k1 {
let a_val = self.get_item(&[i, k])?;
let b_val = other.get_item(&[k, j])?;
sum = sum + a_val * b_val;
}
result.set_item(&[i, j], sum)?;
}
}
Ok(result)
}
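/// Vector-matrix product: treats `self` as a row vector of length k.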
fn matmul_1d_2d(&self, other: &Self) -> Result<Self>
where
T: Add<Output = T> + Mul<Output = T> + Zero + Copy
{
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape.len() != 1 || other_shape.len() != 2 {
return Err(TorshError::Other("Expected 1D and 2D tensors".to_string()));
}
let k1 = self_shape[0];
let (k2, n) = (other_shape[0], other_shape[1]);
if k1 != k2 {
return Err(TorshError::Other(format!(
"Cannot multiply vector of length {} with matrix of shape [{}, {}]",
k1, k2, n
)));
}
let result_shape = vec![n];
let mut result = Self::zeros(&result_shape, self.device())?;
for j in 0..n {
let mut sum = <T as TensorElement>::zero();
for k in 0..k1 {
let a_val = self.get_item(&[k])?;
let b_val = other.get_item(&[k, j])?;
sum = sum + a_val * b_val;
}
result.set_item(&[j], sum)?;
}
Ok(result)
}
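/// Matrix-vector product: treats `other` as a column vector of length k.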
fn matmul_2d_1d(&self, other: &Self) -> Result<Self>
where
T: Add<Output = T> + Mul<Output = T> + Zero + Copy
{
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape.len() != 2 || other_shape.len() != 1 {
return Err(TorshError::Other("Expected 2D and 1D tensors".to_string()));
}
let (m, k1) = (self_shape[0], self_shape[1]);
let k2 = other_shape[0];
if k1 != k2 {
return Err(TorshError::Other(format!(
"Cannot multiply matrix of shape [{}, {}] with vector of length {}",
m, k1, k2
)));
}
let result_shape = vec![m];
let mut result = Self::zeros(&result_shape, self.device())?;
for i in 0..m {
let mut sum = <T as TensorElement>::zero();
for k in 0..k1 {
let a_val = self.get_item(&[i, k])?;
let b_val = other.get_item(&[k])?;
sum = sum + a_val * b_val;
}
result.set_item(&[i], sum)?;
}
Ok(result)
}
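/// Dot product of two equal-length vectors, returned as a scalar tensor.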
fn matmul_1d_1d(&self, other: &Self) -> Result<Self>
where
T: Add<Output = T> + Mul<Output = T> + Zero + Copy
{
let self_shape = self.shape().dims();
let other_shape = other.shape().dims();
if self_shape.len() != 1 || other_shape.len() != 1 {
return Err(TorshError::Other("Both tensors must be 1D for dot product".to_string()));
}
let n1 = self_shape[0];
let n2 = other_shape[0];
if n1 != n2 {
return Err(TorshError::Other(format!(
"Cannot compute dot product of vectors with lengths {} and {}",
n1, n2
)));
}
let mut sum = <T as TensorElement>::zero();
for i in 0..n1 {
let a_val = self.get_item(&[i])?;
let b_val = other.get_item(&[i])?;
sum = sum + a_val * b_val;
}
Self::from_scalar(sum, &[], self.device())
}
fn batch_matmul(&self, other: &Self) -> Result<Self>
where
T: Add<Output = T> + Mul<Output = T> + Zero + Copy
{
let self_shape_obj = self.shape();
let self_shape = self_shape_obj.dims();
let other_shape_obj = other.shape();
let other_shape = other_shape_obj.dims();
if self_shape.len() < 2 || other_shape.len() < 2 {
return Err(TorshError::Other("Tensors must have at least 2 dimensions for batched matmul".to_string()));
}
let self_ndim = self_shape.len();
let other_ndim = other_shape.len();
let self_matrix_dims = [self_shape[self_ndim - 2], self_shape[self_ndim - 1]];
let other_matrix_dims = [other_shape[other_ndim - 2], other_shape[other_ndim - 1]];
if self_matrix_dims[1] != other_matrix_dims[0] {
return Err(TorshError::Other(format!(
"Cannot multiply matrices with inner dimensions {} and {}",
self_matrix_dims[1], other_matrix_dims[0]
)));
}
// Multiply the trailing 2-D matrices pairwise. This simple
// implementation requires the leading (batch) dimensions to match
// exactly; general broadcasting is not supported here.
if self_shape[..self_ndim - 2] != other_shape[..other_ndim - 2] {
return Err(TorshError::Other("Batch dimensions must match for batched matmul".to_string()));
}
let (m, k, n) = (self_matrix_dims[0], self_matrix_dims[1], other_matrix_dims[1]);
let mut result_shape = self_shape[..self_ndim - 2].to_vec();
result_shape.extend_from_slice(&[m, n]);
let mut result = Self::zeros(&result_shape, self.device())?;
let batch_count: usize = self_shape[..self_ndim - 2].iter().product();
let mut batch_index = vec![0usize; self_ndim - 2];
for _ in 0..batch_count {
for i in 0..m {
for j in 0..n {
let mut sum = <T as TensorElement>::zero();
for p in 0..k {
let mut a_idx = batch_index.clone();
a_idx.extend([i, p]);
let mut b_idx = batch_index.clone();
b_idx.extend([p, j]);
sum = sum + self.get_item(&a_idx)? * other.get_item(&b_idx)?;
}
let mut r_idx = batch_index.clone();
r_idx.extend([i, j]);
result.set_item(&r_idx, sum)?;
}
}
// Advance the batch multi-index in row-major order.
for d in (0..batch_index.len()).rev() {
batch_index[d] += 1;
if batch_index[d] < self_shape[d] {
break;
}
batch_index[d] = 0;
}
}
Ok(result)
}
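/// Outer product of two 1-D tensors: result[i, j] = self[i] * other[j].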
pub fn outer(&self, other: &Self) -> Result<Self>
where
T: Mul<Output = T> + Copy
{
if self.ndim() != 1 || other.ndim() != 1 {
return Err(TorshError::Other("Both tensors must be 1D for outer product".to_string()));
}
let m = self.shape().dims()[0];
let n = other.shape().dims()[0];
let result_shape = vec![m, n];
let mut result = Self::zeros(&result_shape, self.device())?;
for i in 0..m {
for j in 0..n {
let a_val = self.get_item(&[i])?;
let b_val = other.get_item(&[j])?;
result.set_item(&[i, j], a_val * b_val)?;
}
}
Ok(result)
}
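/// Cross product of two length-3 vectors.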
pub fn cross(&self, other: &Self) -> Result<Self>
where
T: Sub<Output = T> + Mul<Output = T> + Copy
{
if self.ndim() != 1 || other.ndim() != 1 {
return Err(TorshError::Other("Both tensors must be 1D for cross product".to_string()));
}
if self.shape().dims()[0] != 3 || other.shape().dims()[0] != 3 {
return Err(TorshError::Other("Cross product requires 3D vectors".to_string()));
}
let a = [self.get_item(&[0])?, self.get_item(&[1])?, self.get_item(&[2])?];
let b = [other.get_item(&[0])?, other.get_item(&[1])?, other.get_item(&[2])?];
let result_data = [
a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0],
];
let mut result = Self::zeros(&[3], self.device())?;
for i in 0..3 {
result.set_item(&[i], result_data[i])?;
}
Ok(result)
}
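/// Sum of the main-diagonal entries; for rectangular matrices the
/// diagonal has length min(rows, cols).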
pub fn trace(&self) -> Result<T>
where
T: Add<Output = T> + Zero + Copy
{
if self.ndim() != 2 {
return Err(TorshError::Other("Trace can only be computed for 2D tensors".to_string()));
}
let shape = self.shape().dims();
let min_dim = std::cmp::min(shape[0], shape[1]);
let mut trace_val = <T as TensorElement>::zero();
for i in 0..min_dim {
trace_val = trace_val + self.get_item(&[i, i])?;
}
Ok(trace_val)
}
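/// Extracts the main diagonal of a 2-D tensor as a 1-D tensor.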
pub fn diag(&self) -> Result<Self>
where
T: Copy
{
if self.ndim() != 2 {
return Err(TorshError::Other("Diagonal can only be extracted from 2D tensors".to_string()));
}
let shape = self.shape().dims();
let min_dim = std::cmp::min(shape[0], shape[1]);
let mut result = Self::zeros(&[min_dim], self.device())?;
for i in 0..min_dim {
let val = self.get_item(&[i, i])?;
result.set_item(&[i], val)?;
}
Ok(result)
}
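/// Builds an n x n matrix with the given 1-D tensor on its main
/// diagonal and zeros elsewhere.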
pub fn diag_embed(diag: &Self) -> Result<Self>
where
T: Zero + Copy
{
if diag.ndim() != 1 {
return Err(TorshError::Other("Input must be 1D tensor for diagonal embedding".to_string()));
}
let n = diag.shape().dims()[0];
let mut result = Self::zeros(&[n, n], diag.device())?;
for i in 0..n {
let val = diag.get_item(&[i])?;
result.set_item(&[i, i], val)?;
}
Ok(result)
}
}
impl<T: FloatElement> Tensor<T> {
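/// Determinant of a square matrix. Orders 0 through 3 use closed-form
/// expansions; larger matrices use LU decomposition with partial
/// pivoting, with the sign tracked via the parity of row swaps.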
pub fn det(&self) -> Result<T>
where
T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Zero + One + Copy + PartialEq,
{
if self.ndim() != 2 {
return Err(TorshError::Other("Determinant can only be computed for 2D tensors".to_string()));
}
let shape = self.shape().dims();
if shape[0] != shape[1] {
return Err(TorshError::Other("Determinant requires a square matrix".to_string()));
}
let n = shape[0];
match n {
0 => Ok(<T as TensorElement>::one()),
1 => Ok(self.get_item(&[0, 0])?),
2 => {
let a00 = self.get_item(&[0, 0])?;
let a01 = self.get_item(&[0, 1])?;
let a10 = self.get_item(&[1, 0])?;
let a11 = self.get_item(&[1, 1])?;
Ok(a00 * a11 - a01 * a10)
}
3 => {
let a00 = self.get_item(&[0, 0])?;
let a01 = self.get_item(&[0, 1])?;
let a02 = self.get_item(&[0, 2])?;
let a10 = self.get_item(&[1, 0])?;
let a11 = self.get_item(&[1, 1])?;
let a12 = self.get_item(&[1, 2])?;
let a20 = self.get_item(&[2, 0])?;
let a21 = self.get_item(&[2, 1])?;
let a22 = self.get_item(&[2, 2])?;
let pos = a00 * a11 * a22 + a01 * a12 * a20 + a02 * a10 * a21;
let neg = a02 * a11 * a20 + a01 * a10 * a22 + a00 * a12 * a21;
Ok(pos - neg)
}
_ => {
let data = self.to_vec()?;
let mut lu = data.clone();
let mut swaps = 0;
for k in 0..n {
let mut max_val = lu[k * n + k];
let mut max_row = k;
for i in (k + 1)..n {
let val = lu[i * n + k];
if val.abs() > max_val.abs() {
max_val = val;
max_row = i;
}
}
if max_row != k {
for j in 0..n {
let temp = lu[k * n + j];
lu[k * n + j] = lu[max_row * n + j];
lu[max_row * n + j] = temp;
}
swaps += 1;
}
// A pivot of exactly zero after partial pivoting means the matrix is
// singular, so the determinant is zero. (Nearly singular floating-point
// inputs will still pass this exact-zero test.)
if lu[k * n + k].abs() == <T as TensorElement>::zero() {
return Ok(<T as TensorElement>::zero());
}
for i in (k + 1)..n {
lu[i * n + k] = lu[i * n + k] / lu[k * n + k];
for j in (k + 1)..n {
lu[i * n + j] = lu[i * n + j] - lu[i * n + k] * lu[k * n + j];
}
}
}
let mut det = if swaps % 2 == 0 {
<T as TensorElement>::one()
} else {
<T as TensorElement>::zero() - <T as TensorElement>::one()
};
for i in 0..n {
det = det * lu[i * n + i];
}
Ok(det)
}
}
}
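/// Matrix inverse. The 1x1 and 2x2 cases use closed forms; larger
/// matrices use LU decomposition with partial pivoting followed by
/// forward and back substitution against each unit basis column.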
pub fn inverse(&self) -> Result<Self>
where
T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> +
std::ops::Div<Output = T> + Zero + One + Copy + PartialEq,
{
if self.ndim() != 2 {
return Err(TorshError::Other("Inverse can only be computed for 2D tensors".to_string()));
}
let shape = self.shape().dims();
if shape[0] != shape[1] {
return Err(TorshError::Other("Inverse requires a square matrix".to_string()));
}
let n = shape[0];
let det = self.det()?;
// Exact-zero test: truly singular matrices are rejected here, though
// nearly singular floating-point inputs may still slip through.
if det == <T as TensorElement>::zero() {
return Err(TorshError::Other("Matrix is singular and cannot be inverted".to_string()));
}
match n {
1 => {
let val = self.get_item(&[0, 0])?;
let inv_val = <T as TensorElement>::one() / val;
let mut result = Self::zeros(&[1, 1], self.device())?;
result.set_item(&[0, 0], inv_val)?;
Ok(result)
}
2 => {
let a = self.get_item(&[0, 0])?;
let b = self.get_item(&[0, 1])?;
let c = self.get_item(&[1, 0])?;
let d = self.get_item(&[1, 1])?;
let det_inv = <T as TensorElement>::one() / det;
let mut result = Self::zeros(&[2, 2], self.device())?;
result.set_item(&[0, 0], d * det_inv)?;
result.set_item(&[0, 1], (<T as TensorElement>::zero() - b) * det_inv)?;
result.set_item(&[1, 0], (<T as TensorElement>::zero() - c) * det_inv)?;
result.set_item(&[1, 1], a * det_inv)?;
Ok(result)
}
_ => {
let data = self.to_vec()?;
let mut lu = data.clone();
let mut perm: Vec<usize> = (0..n).collect();
for k in 0..n {
let mut max_val = lu[perm[k] * n + k];
let mut max_row = k;
for i in (k + 1)..n {
let val = lu[perm[i] * n + k];
if val.abs() > max_val.abs() {
max_val = val;
max_row = i;
}
}
if max_row != k {
perm.swap(k, max_row);
}
let pivot = lu[perm[k] * n + k];
// A pivot of exactly zero after partial pivoting means the matrix is
// singular.
if pivot.abs() == <T as TensorElement>::zero() {
return Err(TorshError::Other("Matrix is singular".to_string()));
}
for i in (k + 1)..n {
let factor = lu[perm[i] * n + k] / lu[perm[k] * n + k];
lu[perm[i] * n + k] = factor;
for j in (k + 1)..n {
lu[perm[i] * n + j] = lu[perm[i] * n + j] - factor * lu[perm[k] * n + j];
}
}
}
let mut result_data = vec![<T as TensorElement>::zero(); n * n];
for col in 0..n {
let mut b = vec![<T as TensorElement>::zero(); n];
b[col] = <T as TensorElement>::one();
let mut y = vec![<T as TensorElement>::zero(); n];
for i in 0..n {
let mut sum = b[perm[i]];
for j in 0..i {
sum = sum - lu[perm[i] * n + j] * y[j];
}
y[i] = sum;
}
let mut x = vec![<T as TensorElement>::zero(); n];
for i in (0..n).rev() {
let mut sum = y[i];
for j in (i + 1)..n {
sum = sum - lu[perm[i] * n + j] * x[j];
}
x[i] = sum / lu[perm[i] * n + i];
}
for row in 0..n {
result_data[row * n + col] = x[row];
}
}
Self::from_vec(result_data, &[n, n], self.device())
}
}
}
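/// Rank of a 2-D tensor, computed by Gaussian elimination with partial
/// pivoting: the rank is the number of nonzero pivots.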
pub fn matrix_rank(&self) -> Result<usize>
where
T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> +
std::ops::Div<Output = T> + Zero + One + Copy + PartialOrd,
{
if self.ndim() != 2 {
return Err(TorshError::Other("Rank can only be computed for 2D tensors".to_string()));
}
let shape_obj = self.shape();
let shape = shape_obj.dims();
let (rows, cols) = (shape[0], shape[1]);
// Row-reduce a working copy and count the nonzero pivots.
let mut a = self.to_vec()?;
let zero = <T as TensorElement>::zero();
let mut rank = 0;
for col in 0..cols {
if rank == rows {
break;
}
// Partial pivoting: pick the row with the largest absolute value.
let mut pivot = rank;
for r in (rank + 1)..rows {
if a[r * cols + col].abs() > a[pivot * cols + col].abs() {
pivot = r;
}
}
if a[pivot * cols + col].abs() <= zero {
continue;
}
for j in 0..cols {
a.swap(rank * cols + j, pivot * cols + j);
}
// Eliminate the entries below the pivot.
for r in (rank + 1)..rows {
let factor = a[r * cols + col] / a[rank * cols + col];
for j in col..cols {
a[r * cols + j] = a[r * cols + j] - factor * a[rank * cols + j];
}
}
rank += 1;
}
Ok(rank)
}
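/// Frobenius norm: the square root of the sum of squared entries.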
pub fn frobenius_norm(&self) -> Result<T>
where
T: Add<Output = T> + Mul<Output = T> + Zero + Copy + ToPrimitive,
{
if self.ndim() != 2 {
return Err(TorshError::Other("Frobenius norm can only be computed for 2D tensors".to_string()));
}
let shape_obj = self.shape();
let shape = shape_obj.dims();
let mut sum = <T as TensorElement>::zero();
for i in 0..shape[0] {
for j in 0..shape[1] {
let val = self.get_item(&[i, j])?;
sum = sum + val * val;
}
}
// Take the square root through f64; if T cannot round-trip through
// f64, fall back to returning the (unrooted) sum of squares.
if let Some(sum_f64) = <T as TensorElement>::to_f64(&sum) {
Ok(T::from(sum_f64.sqrt()).unwrap_or(sum))
} else {
Ok(sum)
}
}
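/// Condition number estimated as ||A||_F * ||A^-1||_F. This uses the
/// Frobenius norm rather than the spectral norm, so it is only an
/// approximation to the usual 2-norm condition number.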
pub fn cond(&self) -> Result<T>
where
T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> +
std::ops::Div<Output = T> + Zero + One + Copy + PartialEq + PartialOrd + ToPrimitive,
{
let norm_a = self.frobenius_norm()?;
let inv_a = self.inverse()?;
let norm_inv_a = inv_a.frobenius_norm()?;
Ok(norm_a * norm_inv_a)
}
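/// Solves A x = b by forming A^-1 explicitly and multiplying. Simple,
/// but less numerically stable than solving the LU-factored system
/// directly.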
pub fn solve(&self, b: &Self) -> Result<Self>
where
T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> +
std::ops::Div<Output = T> + Zero + One + Copy + PartialEq + PartialOrd,
{
if self.ndim() != 2 {
return Err(TorshError::Other("A must be a 2D tensor (matrix)".to_string()));
}
let shape_a_obj = self.shape();
let shape_a = shape_a_obj.dims();
if shape_a[0] != shape_a[1] {
return Err(TorshError::Other("A must be a square matrix".to_string()));
}
if b.ndim() != 1 && b.ndim() != 2 {
return Err(TorshError::Other("b must be a 1D or 2D tensor".to_string()));
}
let n = shape_a[0];
let shape_b_obj = b.shape();
let shape_b = shape_b_obj.dims();
if shape_b[0] != n {
return Err(TorshError::Other("Incompatible dimensions for solving Ax = b".to_string()));
}
let inv_a = self.inverse()?;
inv_a.matmul(b)
}
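/// Lower-triangular part: entries with column index j > i + diagonal
/// are zeroed, applied independently to each trailing 2-D matrix.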
pub fn tril(&self, diagonal: isize) -> Result<Self>
where
T: Zero + Copy,
{
let shape = self.shape().dims();
if shape.len() < 2 {
return Err(TorshError::InvalidArgument(
"tril requires at least 2D tensor".to_string(),
));
}
let data = self.data()?;
let mut result = vec![<T as Zero>::zero(); data.len()];
let rows = shape[shape.len() - 2];
let cols = shape[shape.len() - 1];
let matrix_size = rows * cols;
let num_matrices = data.len() / matrix_size;
for mat_idx in 0..num_matrices {
let offset = mat_idx * matrix_size;
for i in 0..rows {
for j in 0..cols {
let idx = offset + i * cols + j;
if (j as isize) <= (i as isize) + diagonal {
result[idx] = data[idx];
}
}
}
}
Self::from_data(result, shape.to_vec(), self.device.clone())
}
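/// Upper-triangular part: entries with column index j < i + diagonal
/// are zeroed, applied independently to each trailing 2-D matrix.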
pub fn triu(&self, diagonal: isize) -> Result<Self>
where
T: Zero + Copy,
{
let shape = self.shape().dims();
if shape.len() < 2 {
return Err(TorshError::InvalidArgument(
"triu requires at least 2D tensor".to_string(),
));
}
let data = self.data()?;
let mut result = vec![<T as Zero>::zero(); data.len()];
let rows = shape[shape.len() - 2];
let cols = shape[shape.len() - 1];
let matrix_size = rows * cols;
let num_matrices = data.len() / matrix_size;
for mat_idx in 0..num_matrices {
let offset = mat_idx * matrix_size;
for i in 0..rows {
for j in 0..cols {
let idx = offset + i * cols + j;
if (j as isize) >= (i as isize) + diagonal {
result[idx] = data[idx];
}
}
}
}
Self::from_data(result, shape.to_vec(), self.device.clone())
}
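/// Extracts a diagonal with the given offset. `dim1` and `dim2` select
/// the two dimensions to read from (negative values count from the end)
/// and default to the last two; only 2-D tensors are currently supported.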
pub fn diagonal(&self, offset: isize, dim1: Option<isize>, dim2: Option<isize>) -> Result<Self>
where
T: Copy,
{
let shape = self.shape().dims();
let ndim = shape.len();
if ndim < 2 {
return Err(TorshError::InvalidArgument(
"diagonal requires at least 2D tensor".to_string(),
));
}
let d1 = if let Some(d) = dim1 {
if d < 0 {
(ndim as isize + d) as usize
} else {
d as usize
}
} else {
ndim - 2
};
let d2 = if let Some(d) = dim2 {
if d < 0 {
(ndim as isize + d) as usize
} else {
d as usize
}
} else {
ndim - 1
};
if d1 >= ndim || d2 >= ndim {
return Err(TorshError::InvalidArgument(
"dimension out of range".to_string(),
));
}
if d1 == d2 {
return Err(TorshError::InvalidArgument(
"dimensions must be different".to_string(),
));
}
let data = self.data()?;
let rows = shape[d1];
let cols = shape[d2];
let diag_len = if offset >= 0 {
if offset as usize >= cols {
0
} else {
std::cmp::min(rows, cols - offset as usize)
}
} else {
let abs_offset = (-offset) as usize;
if abs_offset >= rows {
0
} else {
std::cmp::min(rows - abs_offset, cols)
}
};
if diag_len == 0 {
return Self::from_data(vec![], vec![0], self.device.clone());
}
if ndim == 2 {
let mut result = Vec::with_capacity(diag_len);
for i in 0..diag_len {
let row = if offset < 0 {
i + (-offset) as usize
} else {
i
};
let col = if offset >= 0 {
i + offset as usize
} else {
i
};
if row < rows && col < cols {
// Map the diagonal coordinates back to the row-major flat index,
// handling the case where dim1/dim2 were given in swapped order.
let idx = if d1 < d2 { row * cols + col } else { col * rows + row };
result.push(data[idx]);
}
}
return Self::from_data(result, vec![diag_len], self.device.clone());
}
Err(TorshError::NotImplemented(
"diagonal for >2D tensors not yet implemented".to_string(),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_core::device::DeviceType;
#[test]
fn test_tril_2d() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
vec![3, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.tril(0).expect("tril failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
}
#[test]
fn test_tril_diagonal_offset() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
vec![3, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.tril(1).expect("tril failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![1.0, 2.0, 0.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
}
#[test]
fn test_triu_2d() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
vec![3, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.triu(0).expect("triu failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
}
#[test]
fn test_triu_diagonal_offset() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
vec![3, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.triu(1).expect("triu failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]);
}
#[test]
fn test_tril_rectangular() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![2, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.tril(0).expect("tril failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0]);
}
#[test]
fn test_diagonal_main() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
vec![3, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.diagonal(0, None, None).expect("diagonal failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![1.0, 5.0, 9.0]);
}
#[test]
fn test_diagonal_above() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
vec![3, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.diagonal(1, None, None).expect("diagonal failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![2.0, 6.0]);
}
#[test]
fn test_diagonal_below() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
vec![3, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.diagonal(-1, None, None).expect("diagonal failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![4.0, 8.0]);
}
#[test]
fn test_diagonal_rectangular() {
let tensor = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![2, 3],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = tensor.diagonal(0, None, None).expect("diagonal failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![1.0, 5.0]);
}
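// Additional sanity checks for the kernels above; these assume the same
// row-major `from_data`/`data` conventions exercised by the existing tests.
#[test]
fn test_matmul_2d() {
let a = Tensor::from_data(
vec![1.0f32, 2.0, 3.0, 4.0],
vec![2, 2],
DeviceType::Cpu,
).expect("tensor creation failed");
let b = Tensor::from_data(
vec![5.0f32, 6.0, 7.0, 8.0],
vec![2, 2],
DeviceType::Cpu,
).expect("tensor creation failed");
let result = a.matmul(&b).expect("matmul failed");
let data = result.data().expect("data retrieval failed");
assert_eq!(data, vec![19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn test_trace_and_det_2x2() {
let tensor = Tensor::from_data(
vec![4.0f32, 7.0, 2.0, 6.0],
vec![2, 2],
DeviceType::Cpu,
).expect("tensor creation failed");
assert_eq!(tensor.trace().expect("trace failed"), 10.0);
assert_eq!(tensor.det().expect("det failed"), 10.0);
}
#[test]
fn test_inverse_2x2() {
// [[4, 7], [2, 6]] has determinant 10, so its inverse is
// [[0.6, -0.7], [-0.2, 0.4]].
let tensor = Tensor::from_data(
vec![4.0f32, 7.0, 2.0, 6.0],
vec![2, 2],
DeviceType::Cpu,
).expect("tensor creation failed");
let inv = tensor.inverse().expect("inverse failed");
let data = inv.data().expect("data retrieval failed");
for (got, want) in data.iter().zip([0.6f32, -0.7, -0.2, 0.4]) {
assert!((got - want).abs() < 1e-6);
}
}
#[test]
fn test_cross_product_basis() {
// e1 x e2 = e3.
let a = Tensor::from_data(vec![1.0f32, 0.0, 0.0], vec![3], DeviceType::Cpu)
.expect("tensor creation failed");
let b = Tensor::from_data(vec![0.0f32, 1.0, 0.0], vec![3], DeviceType::Cpu)
.expect("tensor creation failed");
let result = a.cross(&b).expect("cross failed");
assert_eq!(result.data().expect("data retrieval failed"), vec![0.0, 0.0, 1.0]);
}
#[test]
fn test_diag_embed_roundtrip() {
let d = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)
.expect("tensor creation failed");
let m = Tensor::diag_embed(&d).expect("diag_embed failed");
let back = m.diag().expect("diag failed");
assert_eq!(back.data().expect("data retrieval failed"), vec![1.0, 2.0, 3.0]);
}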
}