use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayD, ArrayView, ArrayViewD, Dimension};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::Sum;
use std::sync::{Arc, Mutex};
pub mod cp;
pub mod tensor_network;
pub mod tensor_train;
pub mod tucker;
/// Contracts two tensors along the specified axis pairs (generalized tensordot).
///
/// Axis `axes_a[i]` of `a` is summed against axis `axes_b[i]` of `b`. The
/// output axes are the free (non-contracted) axes of `a`, in order, followed
/// by the free axes of `b`.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the two axis lists differ in
/// length, an axis is out of bounds, or a paired axis has mismatched sizes.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn contract<A, D1, D2>(
    a: &ArrayView<A, D1>,
    b: &ArrayView<A, D2>,
    axes_a: &[usize],
    axes_b: &[usize],
) -> LinalgResult<ArrayD<A>>
where
    A: Clone + Float + NumAssign + Zero + Send + Sync + Sum + Debug + 'static,
    D1: Dimension,
    D2: Dimension,
{
    // --- Validation -------------------------------------------------------
    if axes_a.len() != axes_b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Number of contraction axes must match: got {} and {}",
            axes_a.len(),
            axes_b.len()
        )));
    }
    for (i, (&ax_a, &ax_b)) in axes_a.iter().zip(axes_b.iter()).enumerate() {
        if ax_a >= a.ndim() {
            return Err(LinalgError::ShapeError(format!(
                "Axis {} out of bounds for first input with dimension {}",
                ax_a,
                a.ndim()
            )));
        }
        if ax_b >= b.ndim() {
            return Err(LinalgError::ShapeError(format!(
                "Axis {} out of bounds for second input with dimension {}",
                ax_b,
                b.ndim()
            )));
        }
        if a.shape()[ax_a] != b.shape()[ax_b] {
            return Err(LinalgError::ShapeError(format!(
                "Dimension mismatch at index {}: {} != {}",
                i,
                a.shape()[ax_a],
                b.shape()[ax_b]
            )));
        }
    }

    // Free (non-contracted) axes of each operand; they define the output shape.
    let free_axes_a: Vec<usize> = (0..a.ndim()).filter(|&ax| !axes_a.contains(&ax)).collect();
    let free_axes_b: Vec<usize> = (0..b.ndim()).filter(|&ax| !axes_b.contains(&ax)).collect();
    let free_dims_a: Vec<usize> = free_axes_a.iter().map(|&ax| a.shape()[ax]).collect();
    let free_dims_b: Vec<usize> = free_axes_b.iter().map(|&ax| b.shape()[ax]).collect();
    let mut resultshape = Vec::with_capacity(free_dims_a.len() + free_dims_b.len());
    resultshape.extend_from_slice(&free_dims_a);
    resultshape.extend_from_slice(&free_dims_b);

    let a_dyn = a.view().into_dyn();
    let b_dyn = b.view().into_dyn();

    /// Enumerates every combination of free indices (one per output element).
    fn generate_indices(
        free_dims: &[usize],
        current: Vec<usize>,
        depth: usize,
        all_indices: &mut Vec<Vec<usize>>,
    ) {
        if depth == free_dims.len() {
            all_indices.push(current);
            return;
        }
        let mut current = current;
        for i in 0..free_dims[depth] {
            current.push(i);
            generate_indices(free_dims, current.clone(), depth + 1, all_indices);
            current.pop();
        }
    }
    let total_combinations: usize = free_dims_a.iter().chain(free_dims_b.iter()).product();
    let mut all_free_indices = Vec::with_capacity(total_combinations);
    let mut combined_dims = free_dims_a.clone();
    combined_dims.extend(free_dims_b.iter());
    generate_indices(&combined_dims, Vec::new(), 0, &mut all_free_indices);

    use scirs2_core::parallel_ops::*;
    // Each output element is independent, so compute them in parallel and
    // write the collected results sequentially afterwards. No shared mutable
    // state is needed, so the output tensor requires no Mutex.
    let results: Vec<_> = all_free_indices
        .par_iter()
        .map(|free_idx| {
            let free_idx_a = &free_idx[0..free_dims_a.len()];
            let free_idx_b = &free_idx[free_dims_a.len()..];
            // Full index vectors into a and b; contracted positions are
            // filled in by the recursive accumulator below.
            let mut a_idx = vec![0; a.ndim()];
            let mut b_idx = vec![0; b.ndim()];
            for (i, &ax) in free_axes_a.iter().enumerate() {
                a_idx[ax] = free_idx_a[i];
            }
            for (i, &ax) in free_axes_b.iter().enumerate() {
                b_idx[ax] = free_idx_b[i];
            }
            /// Recursively sums the product over all contracted-index combinations.
            fn accumulate_sum<A>(
                a: &ArrayViewD<A>,
                b: &ArrayViewD<A>,
                a_idx: &mut Vec<usize>,
                b_idx: &mut Vec<usize>,
                axes_a: &[usize],
                axes_b: &[usize],
                depth: usize,
                sum: &mut A,
            ) where
                A: Clone + Float + NumAssign + Zero,
            {
                if depth == axes_a.len() {
                    *sum += a[a_idx.as_slice()] * b[b_idx.as_slice()];
                    return;
                }
                let ax_a = axes_a[depth];
                let ax_b = axes_b[depth];
                let dim = a.shape()[ax_a];
                for i in 0..dim {
                    a_idx[ax_a] = i;
                    b_idx[ax_b] = i;
                    accumulate_sum(a, b, a_idx, b_idx, axes_a, axes_b, depth + 1, sum);
                }
            }
            let mut sum = A::zero();
            accumulate_sum(
                &a_dyn, &b_dyn, &mut a_idx, &mut b_idx, axes_a, axes_b, 0, &mut sum,
            );
            (free_idx.clone(), sum)
        })
        .collect();

    let mut result = ArrayD::zeros(resultshape);
    for (idx, sum) in results {
        result[idx.as_slice()] = sum;
    }
    Ok(result)
}
/// Batched matrix multiplication over the leading `batch_dims` dimensions.
///
/// `a` has shape `[batch..., m, k]` and `b` has shape `[batch..., k, n]`;
/// the result has shape `[batch..., m, n]`.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when either tensor does not have exactly
/// `batch_dims + 2` dimensions, when a batch dimension differs between the
/// inputs, or when the inner (`k`) dimensions do not match.
#[allow(dead_code)]
pub fn batch_matmul<A, D1, D2>(
    a: &ArrayView<A, D1>,
    b: &ArrayView<A, D2>,
    batch_dims: usize,
) -> LinalgResult<ArrayD<A>>
where
    A: Clone + Float + NumAssign + Zero + Send + Sync + Sum + Debug + 'static,
    D1: Dimension,
    D2: Dimension,
{
    // Indexing below uses exactly `batch_dims + 2` coordinates per tensor, so
    // extra trailing dimensions would panic at runtime; reject them up front.
    if a.ndim() != batch_dims + 2 || b.ndim() != batch_dims + 2 {
        return Err(LinalgError::ShapeError(format!(
            "Both tensors must have exactly batch_dims + 2 dimensions, got {} and {}",
            a.ndim(),
            b.ndim()
        )));
    }
    for i in 0..batch_dims {
        if a.shape()[i] != b.shape()[i] {
            return Err(LinalgError::ShapeError(format!(
                "Batch dimensions must match: {} != {} at index {}",
                a.shape()[i],
                b.shape()[i],
                i
            )));
        }
    }
    if a.shape()[batch_dims + 1] != b.shape()[batch_dims] {
        return Err(LinalgError::ShapeError(format!(
            "Inner dimensions for matrix multiplication must match: {} != {}",
            a.shape()[batch_dims + 1],
            b.shape()[batch_dims]
        )));
    }

    // Output shape: [batch..., m, n].
    let mut outshape = Vec::with_capacity(batch_dims + 2);
    for i in 0..batch_dims {
        outshape.push(a.shape()[i]);
    }
    outshape.push(a.shape()[batch_dims]);
    outshape.push(b.shape()[batch_dims + 1]);

    let a_dyn = a.view().into_dyn();
    let b_dyn = b.view().into_dyn();
    let batchsize: usize = outshape.iter().take(batch_dims).product();
    let m = a.shape()[batch_dims];
    let k = a.shape()[batch_dims + 1];
    let n = b.shape()[batch_dims + 1];

    /// Enumerates every combination of batch indices.
    fn generate_batch_indices(
        shape: &[usize],
        current: Vec<usize>,
        depth: usize,
        max_depth: usize,
        all_indices: &mut Vec<Vec<usize>>,
    ) {
        if depth == max_depth {
            all_indices.push(current);
            return;
        }
        let mut current = current;
        for i in 0..shape[depth] {
            current.push(i);
            generate_batch_indices(shape, current.clone(), depth + 1, max_depth, all_indices);
            current.pop();
        }
    }
    let mut all_batch_indices = Vec::with_capacity(batchsize);
    generate_batch_indices(&outshape, Vec::new(), 0, batch_dims, &mut all_batch_indices);

    use scirs2_core::parallel_ops::*;
    // Each batch's m x n product is independent: compute in parallel, then
    // copy the collected per-batch blocks into the output sequentially
    // (no Mutex around the output tensor is needed).
    let results: Vec<_> = all_batch_indices
        .par_iter()
        .map(|batch_idx| {
            let mut result_batch = Array2::zeros((m, n));
            for i in 0..m {
                for j in 0..n {
                    let mut sum = A::zero();
                    for p in 0..k {
                        let mut a_idx = batch_idx.clone();
                        a_idx.push(i);
                        a_idx.push(p);
                        let mut b_idx = batch_idx.clone();
                        b_idx.push(p);
                        b_idx.push(j);
                        sum += a_dyn[a_idx.as_slice()] * b_dyn[b_idx.as_slice()];
                    }
                    result_batch[[i, j]] = sum;
                }
            }
            (batch_idx.clone(), result_batch)
        })
        .collect();

    let mut result = ArrayD::zeros(outshape);
    for (batch_idx, result_batch) in results {
        for i in 0..m {
            for j in 0..n {
                let mut result_idx = batch_idx.clone();
                result_idx.push(i);
                result_idx.push(j);
                result[result_idx.as_slice()] = result_batch[[i, j]];
            }
        }
    }
    Ok(result)
}
/// Mode-`n` product of a tensor with a matrix.
///
/// Multiplies `matrix` (shape `[r, d]`) along mode `mode` of `tensor`
/// (whose dimension along that mode is `d`), producing a tensor whose
/// `mode` dimension becomes `r` while all other dimensions are unchanged.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when `mode` is out of bounds, `matrix`
/// is not 2-D, or the matrix column count does not match the tensor's `mode`
/// dimension. Returns `LinalgError::ComputationError` if the matrix view
/// cannot be converted to 2-D.
#[allow(dead_code)]
pub fn mode_n_product<A, D1, D2>(
    tensor: &ArrayView<A, D1>,
    matrix: &ArrayView<A, D2>,
    mode: usize,
) -> LinalgResult<ArrayD<A>>
where
    A: Clone + Float + NumAssign + Zero + Send + Sync + Debug + 'static,
    D1: Dimension,
    D2: Dimension,
{
    if mode >= tensor.ndim() {
        return Err(LinalgError::ShapeError(format!(
            "Mode {} is out of bounds for tensor with {} dimensions",
            mode,
            tensor.ndim()
        )));
    }
    if matrix.ndim() != 2 {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be 2-dimensional, got {} dimensions",
            matrix.ndim()
        )));
    }
    if matrix.shape()[1] != tensor.shape()[mode] {
        return Err(LinalgError::ShapeError(format!(
            "Matrix columns ({}) must match tensor dimension along mode {} ({})",
            matrix.shape()[1],
            mode,
            tensor.shape()[mode]
        )));
    }

    // Output shape: same as the input except along `mode`.
    let mut outshape = tensor.shape().to_vec();
    outshape[mode] = matrix.shape()[0];

    let tensor_dyn = tensor.view().into_dyn();
    let matrix_view = match matrix
        .view()
        .into_dimensionality::<scirs2_core::ndarray::Ix2>()
    {
        Ok(view) => view,
        Err(_) => {
            return Err(LinalgError::ComputationError(
                "Failed to convert matrix to 2D view".to_string(),
            ))
        }
    };

    /// Enumerates index combinations over every axis except `mode`.
    fn generate_indices_without_mode(
        shape: &[usize],
        current: Vec<usize>,
        depth: usize,
        mode: usize,
        _mode_dim: usize,
        all_indices: &mut Vec<Vec<usize>>,
    ) {
        if depth == shape.len() {
            all_indices.push(current);
            return;
        }
        if depth == mode {
            // Skip the mode axis; it is iterated during the product itself.
            generate_indices_without_mode(shape, current, depth + 1, mode, _mode_dim, all_indices);
            return;
        }
        let dimsize = shape[depth];
        let mut current = current;
        for i in 0..dimsize {
            current.push(i);
            generate_indices_without_mode(
                shape,
                current.clone(),
                depth + 1,
                mode,
                _mode_dim,
                all_indices,
            );
            current.pop();
        }
    }
    let mut shape_without_mode = tensor.shape().to_vec();
    shape_without_mode.remove(mode);
    let total_combinations: usize = shape_without_mode.iter().product();
    let mut all_indices = Vec::with_capacity(total_combinations);
    generate_indices_without_mode(
        tensor.shape(),
        Vec::new(),
        0,
        mode,
        tensor.shape()[mode],
        &mut all_indices,
    );

    use scirs2_core::parallel_ops::*;
    // Each fiber along `mode` is processed independently: for every fixed
    // combination of the other indices, compute all output entries along the
    // new mode dimension. Results are written sequentially afterwards, so no
    // Mutex around the output tensor is needed.
    let all_results: Vec<Vec<_>> = all_indices
        .par_iter()
        .map(|idx| {
            let mut tensor_idx = idx.clone();
            tensor_idx.insert(mode, 0);
            let mut results = Vec::new();
            for j in 0..matrix.shape()[0] {
                let mut sum = A::zero();
                for k in 0..tensor.shape()[mode] {
                    tensor_idx[mode] = k;
                    sum += tensor_dyn[tensor_idx.as_slice()] * matrix_view[[j, k]];
                }
                let mut result_idx = idx.clone();
                result_idx.insert(mode, j);
                results.push((result_idx, sum));
            }
            results
        })
        .collect();

    let mut result = ArrayD::zeros(outshape);
    for batch_results in all_results {
        for (idx, sum) in batch_results {
            result[idx.as_slice()] = sum;
        }
    }
    Ok(result)
}
/// Evaluates an Einstein-summation expression over the given tensors.
///
/// `einsum_str` must have the explicit form `"inputs->output"`, e.g.
/// `"ij,jk->ik"` for matrix multiplication or `"i,i->"` for an inner
/// product. Indices appearing in the inputs but not in the output are
/// summed over (contracted).
///
/// # Errors
///
/// Returns `LinalgError::ValueError` for malformed notation, a tensor-count
/// mismatch, or an output index absent from all inputs, and
/// `LinalgError::ShapeError` when a tensor's rank or an index's dimension is
/// inconsistent with the notation.
#[allow(clippy::type_complexity)]
#[allow(dead_code)]
pub fn einsum<'a, A>(
    einsum_str: &str,
    tensors: &'a [&'a ArrayViewD<'a, A>],
) -> LinalgResult<ArrayD<A>>
where
    A: Clone + Float + NumAssign + Zero + Send + Sync + Sum + Debug + 'static,
{
    /// Splits `"ab,bc->ac"` into per-input index lists and the output list.
    fn parse_einsum_notation(einsumstr: &str) -> LinalgResult<(Vec<Vec<char>>, Vec<char>)> {
        let parts: Vec<&str> = einsumstr.split("->").collect();
        if parts.len() != 2 {
            return Err(LinalgError::ValueError(
                "Einsum string must contain exactly one '->'".to_string(),
            ));
        }
        let inputs: Vec<&str> = parts[0].split(',').collect();
        let mut input_indices = Vec::with_capacity(inputs.len());
        for input in inputs {
            let indices: Vec<char> = input.trim().chars().collect();
            input_indices.push(indices);
        }
        let output_indices: Vec<char> = parts[1].trim().chars().collect();
        Ok((input_indices, output_indices))
    }

    let (input_indices, output_indices) = parse_einsum_notation(einsum_str)?;
    if tensors.len() != input_indices.len() {
        return Err(LinalgError::ValueError(format!(
            "Number of tensors ({}) doesn't match number of index groups ({})",
            tensors.len(),
            input_indices.len()
        )));
    }
    for (i, (tensor, indices)) in tensors.iter().zip(input_indices.iter()).enumerate() {
        if tensor.ndim() != indices.len() {
            return Err(LinalgError::ShapeError(format!(
                "Tensor {} has {} dimensions, but {} indices were provided",
                i,
                tensor.ndim(),
                indices.len()
            )));
        }
    }

    // Resolve each index letter to a concrete dimension and check consistency.
    let mut index_to_dim: HashMap<char, usize> = HashMap::new();
    for (tensor, indices) in tensors.iter().zip(input_indices.iter()) {
        for (&dimsize, &idx) in tensor.shape().iter().zip(indices.iter()) {
            if let Some(&existing_dim) = index_to_dim.get(&idx) {
                if existing_dim != dimsize {
                    return Err(LinalgError::ShapeError(format!(
                        "Inconsistent dimensions for index '{}': {} and {}",
                        idx, existing_dim, dimsize
                    )));
                }
            } else {
                index_to_dim.insert(idx, dimsize);
            }
        }
    }
    for &idx in &output_indices {
        if !index_to_dim.contains_key(&idx) {
            return Err(LinalgError::ValueError(format!(
                "Output index '{}' not found in any input indices",
                idx
            )));
        }
    }

    let mut outputshape = Vec::with_capacity(output_indices.len());
    for &idx in &output_indices {
        outputshape.push(index_to_dim[&idx]);
    }

    // Indices present in the inputs but not the output are contracted (summed).
    let mut contracted_indices: Vec<char> = Vec::new();
    for indices in input_indices.iter() {
        for &idx in indices {
            if !output_indices.contains(&idx) && !contracted_indices.contains(&idx) {
                contracted_indices.push(idx);
            }
        }
    }

    /// Enumerates every output-coordinate combination.
    fn generate_output_indices(
        shape: &[usize],
        current: Vec<usize>,
        depth: usize,
        all_indices: &mut Vec<Vec<usize>>,
    ) {
        if depth == shape.len() {
            all_indices.push(current);
            return;
        }
        let mut current = current;
        for i in 0..shape[depth] {
            current.push(i);
            generate_output_indices(shape, current.clone(), depth + 1, all_indices);
            current.pop();
        }
    }
    let total_output_combinations: usize = outputshape.iter().product();
    let mut all_output_indices = Vec::with_capacity(total_output_combinations);
    generate_output_indices(&outputshape, Vec::new(), 0, &mut all_output_indices);

    use scirs2_core::parallel_ops::*;
    // Each output element is independent: compute them in parallel and write
    // the collected results sequentially (no Mutex around the output needed).
    let results: Vec<_> = all_output_indices
        .par_iter()
        .map(|output_idx| {
            // Bind the output index letters to this element's coordinates.
            let mut index_values = HashMap::new();
            for (i, &idx) in output_indices.iter().enumerate() {
                index_values.insert(idx, output_idx[i]);
            }
            /// Recursively sums the product of all input entries over every
            /// combination of contracted-index values.
            fn compute_sum_recursive<A>(
                tensors: &[&ArrayViewD<A>],
                input_indices: &[Vec<char>],
                contracted_indices: &[char],
                index_values: &mut HashMap<char, usize>,
                index_to_dim: &HashMap<char, usize>,
                depth: usize,
            ) -> A
            where
                A: Clone + Float + NumAssign + Zero + One,
            {
                if depth == contracted_indices.len() {
                    // All index letters are bound: multiply the addressed
                    // entry of every input tensor together.
                    let mut product = A::one();
                    for (tensor, indices) in tensors.iter().zip(input_indices.iter()) {
                        let tensor_indices: Vec<usize> =
                            indices.iter().map(|&idx| index_values[&idx]).collect();
                        product *= tensor[tensor_indices.as_slice()];
                    }
                    return product;
                }
                let idx = contracted_indices[depth];
                let dim = index_to_dim[&idx];
                let mut sum = A::zero();
                for i in 0..dim {
                    index_values.insert(idx, i);
                    sum += compute_sum_recursive(
                        tensors,
                        input_indices,
                        contracted_indices,
                        index_values,
                        index_to_dim,
                        depth + 1,
                    );
                }
                sum
            }
            let sum = compute_sum_recursive(
                tensors,
                &input_indices,
                &contracted_indices,
                &mut index_values,
                &index_to_dim,
                0,
            );
            (output_idx.clone(), sum)
        })
        .collect();

    let mut result = ArrayD::zeros(outputshape);
    for (idx, sum) in results {
        result[idx.as_slice()] = sum;
    }
    Ok(result)
}
/// Higher-order SVD (Tucker decomposition via truncated mode-wise SVDs).
///
/// Returns the core tensor together with one factor matrix per mode, where
/// `rank[i]` is the number of singular vectors kept for mode `i`. The input
/// is (approximately) reconstructed as the core multiplied by each factor
/// along its mode.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when `rank` does not have one entry per
/// tensor mode or when a requested rank exceeds that mode's dimension, and
/// propagates any error from the underlying unfolding/SVD computations.
#[allow(dead_code)]
pub fn hosvd<A, D>(
    tensor: &ArrayView<A, D>,
    rank: &[usize],
) -> LinalgResult<(ArrayD<A>, Vec<Array2<A>>)>
where
    A: Clone
        + Float
        + NumAssign
        + Zero
        + Send
        + Sync
        + Sum
        + Debug
        + 'static
        + scirs2_core::ndarray::ScalarOperand,
    D: Dimension,
{
    if rank.len() != tensor.ndim() {
        return Err(LinalgError::ShapeError(format!(
            "Rank vector length ({}) must match tensor dimensions ({})",
            rank.len(),
            tensor.ndim()
        )));
    }
    for (i, &r) in rank.iter().enumerate() {
        if r > tensor.shape()[i] {
            return Err(LinalgError::ShapeError(format!(
                "Rank for mode {} ({}) cannot exceed the mode dimension ({})",
                i,
                r,
                tensor.shape()[i]
            )));
        }
    }

    let tensor_dyn = tensor.to_owned().into_dyn();
    use scirs2_core::parallel_ops::*;
    let modes: Vec<usize> = (0..tensor.ndim()).collect();
    // Compute each mode's factor matrix in parallel. Errors are collected and
    // propagated instead of panicking inside the worker closures.
    let factor_results: Vec<LinalgResult<Array2<A>>> = modes
        .par_iter()
        .map(|&mode| {
            let unfolded = unfold(&tensor_dyn, mode)?;
            let (u, _, _) = svd_truncated(&unfolded, rank[mode])?;
            Ok(u)
        })
        .collect();
    let mut factors = Vec::with_capacity(factor_results.len());
    for factor in factor_results {
        factors.push(factor?);
    }

    // Core tensor: project the input onto each factor's transpose, mode by mode.
    let mut core = tensor_dyn.to_owned();
    for (mode, factor) in factors.iter().enumerate() {
        let factor_t = factor.t().to_owned();
        core = mode_n_product(&core.view(), &factor_t.view(), mode)?;
    }
    Ok((core, factors))
}
/// Unfolds (matricizes) a tensor along `mode`.
///
/// Row `i` of the result contains every tensor element whose `mode` index is
/// `i`; the remaining indices are flattened into the column index in
/// row-major order over the non-mode axes.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when `mode` is out of bounds.
#[allow(dead_code)]
fn unfold<A>(tensor: &ArrayD<A>, mode: usize) -> LinalgResult<Array2<A>>
where
    A: Clone + Float + Debug + Send + Sync,
{
    if mode >= tensor.ndim() {
        return Err(LinalgError::ShapeError(format!(
            "Mode {} is out of bounds for tensor with {} dimensions",
            mode,
            tensor.ndim()
        )));
    }
    let shape = tensor.shape();
    let mode_dim = shape[mode];
    // Column count is the product of all non-mode dimensions.
    let other_dims_prod: usize = shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != mode)
        .map(|(_, &dim)| dim)
        .product();
    let mut result = Array2::zeros((mode_dim, other_dims_prod));

    /// Flattens all non-mode coordinates of `idx` into a single column index,
    /// treating the trailing axis as fastest-varying.
    fn calc_col_idx(idx: &[usize], shape: &[usize], mode: usize) -> usize {
        let mut col_idx = 0;
        let mut stride = 1;
        for dim in (0..shape.len()).rev() {
            if dim != mode {
                col_idx += idx[dim] * stride;
                stride *= shape[dim];
            }
        }
        col_idx
    }

    /// Enumerates every index combination of the tensor.
    fn generate_tensor_indices(
        shape: &[usize],
        current: Vec<usize>,
        depth: usize,
        all_indices: &mut Vec<Vec<usize>>,
    ) {
        if depth == shape.len() {
            all_indices.push(current);
            return;
        }
        let mut current = current;
        for i in 0..shape[depth] {
            current.push(i);
            generate_tensor_indices(shape, current.clone(), depth + 1, all_indices);
            current.pop();
        }
    }
    let tensorshape = tensor.shape().to_vec();
    let total_elements: usize = tensorshape.iter().product();
    let mut all_indices = Vec::with_capacity(total_elements);
    generate_tensor_indices(&tensorshape, Vec::new(), 0, &mut all_indices);

    use scirs2_core::parallel_ops::*;
    // Map each element to its (row, column) destination in parallel, then
    // scatter into the result sequentially.
    let results: Vec<_> = all_indices
        .par_iter()
        .map(|idx| {
            let mode_idx = idx[mode];
            let col_idx = calc_col_idx(idx, &tensorshape, mode);
            let val = tensor[idx.as_slice()];
            (mode_idx, col_idx, val)
        })
        .collect();
    for (mode_idx, col_idx, val) in results {
        result[[mode_idx, col_idx]] = val;
    }
    Ok(result)
}
#[allow(dead_code)]
pub fn svd_truncated<A>(
matrix: &Array2<A>,
rank: usize,
) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)>
where
A: Clone
+ Float
+ NumAssign
+ Zero
+ Send
+ Sync
+ Sum
+ std::fmt::Debug
+ 'static
+ scirs2_core::ndarray::ScalarOperand,
{
use crate::decomposition::svd;
let matrix_view = matrix.view();
let (u, s, vt) = svd(&matrix_view, false, None)?;
let u_trunc = u.slice(scirs2_core::ndarray::s![.., ..rank]).to_owned();
let s_trunc = Array2::from_diag(&s.slice(scirs2_core::ndarray::s![..rank]));
let vt_trunc = vt.slice(scirs2_core::ndarray::s![..rank, ..]).to_owned();
Ok((u_trunc, s_trunc, vt_trunc))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // `contract` over axes (1, 0) of two 2-D tensors is plain matrix
    // multiplication; compare against the hand-computed product.
    #[test]
    fn testmatrix_multiplication() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let result = contract(&a.view(), &b.view(), &[1], &[0]).expect("Operation failed");
        let expected = array![[58.0, 64.0], [139.0, 154.0]];
        assert_eq!(result.shape(), &[2, 2]);
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(result[[i, j]], expected[[i, j]], epsilon = 1e-10);
            }
        }
    }

    // Two stacked 2x3 . 3x2 products with one batch dimension; each batch's
    // expected result is hand-computed.
    #[test]
    fn test_batch_matmul() {
        let a = array![
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
            [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]
        ];
        let b = array![
            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
            [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
        ];
        let result = batch_matmul(&a.view(), &b.view(), 1).expect("Operation failed");
        let expected_batch0 = array![[22.0, 28.0], [49.0, 64.0]];
        let expected_batch1 = array![[220.0, 244.0], [301.0, 334.0]];
        assert_eq!(result.shape(), &[2, 2, 2]);
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(result[[0, i, j]], expected_batch0[[i, j]], epsilon = 1e-10);
                assert_abs_diff_eq!(result[[1, i, j]], expected_batch1[[i, j]], epsilon = 1e-10);
            }
        }
    }

    // "ij,jk->ik" is matrix multiplication; expected values match the
    // `contract` test above for the same operands.
    #[test]
    fn test_einsummatrix_multiplication() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let a_view = a.view().into_dyn();
        let b_view = b.view().into_dyn();
        let result = einsum("ij,jk->ik", &[&a_view, &b_view]).expect("Operation failed");
        let expected = array![[58.0, 64.0], [139.0, 154.0]];
        assert_eq!(result.shape(), &[2, 2]);
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(result[[i, j]], expected[[i, j]], epsilon = 1e-10);
            }
        }
    }

    // "i,i->" contracts everything: the result is a 0-dimensional (scalar)
    // tensor holding the dot product.
    #[test]
    fn test_einsum_inner_product() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let a_view = a.view().into_dyn();
        let b_view = b.view().into_dyn();
        let result = einsum("i,i->", &[&a_view, &b_view]).expect("Operation failed");
        let expected = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0;
        // A 0-d tensor has an empty shape; its single element is the scalar.
        assert_eq!(result.shape(), &[] as &[usize]);
        assert_abs_diff_eq!(
            result.iter().next().expect("Operation failed"),
            &expected,
            epsilon = 1e-10
        );
    }

    // Mode-0 product of a 2x3x2 tensor with a 4x2 matrix: mode 0 grows from
    // 2 to 4. Only two entries are spot-checked against hand computation.
    #[test]
    #[ignore = "Needs investigation - possibly SVD-related issue"]
    fn test_mode_n_product() {
        let tensor = array![
            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
            [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
        ];
        let matrix = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let result = mode_n_product(&tensor.view(), &matrix.view(), 0).expect("Operation failed");
        assert_eq!(result.shape(), &[4, 3, 2]);
        assert_abs_diff_eq!(result[[0, 0, 0]], 1.0 * 1.0 + 2.0 * 7.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 0, 1]], 1.0 * 2.0 + 2.0 * 8.0, epsilon = 1e-10);
    }

    // Full-rank HOSVD of a 2x2x2 tensor: shapes are checked, and
    // reconstructing via mode-n products of the core with each factor must
    // recover the original tensor (up to numerical tolerance).
    #[test]
    fn test_hosvd_basic() {
        let tensor = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let (core, factors) = hosvd(&tensor.view(), &[2, 2, 2]).expect("Operation failed");
        assert_eq!(core.shape(), &[2, 2, 2]);
        assert_eq!(factors.len(), 3);
        assert_eq!(factors[0].shape(), &[2, 2]);
        assert_eq!(factors[1].shape(), &[2, 2]);
        assert_eq!(factors[2].shape(), &[2, 2]);
        let mut reconstructed = core.clone();
        for (mode, factor) in factors.iter().enumerate() {
            reconstructed = mode_n_product(&reconstructed.view(), &factor.view(), mode)
                .expect("Operation failed");
        }
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    assert_abs_diff_eq!(
                        reconstructed[[i, j, k]],
                        tensor[[i, j, k]],
                        epsilon = 1e-5
                    );
                }
            }
        }
    }
}