use scirs2_core::ndarray::{Array, IxDyn};
use scirs2_core::Complex64;
use quantrs2_sim::tensor_network::tensor::Tensor;
/// Builds a [`Tensor`] with the given `shape` from `values`, which are laid out
/// in standard (row-major) order as accepted by `Array::from_shape_vec`.
///
/// # Panics
/// Panics if `values.len()` does not match the number of elements implied by `shape`.
fn make_tensor(shape: &[usize], values: Vec<Complex64>) -> Tensor {
    let array = Array::from_shape_vec(IxDyn(shape), values);
    Tensor::new(array.expect("shape / values length mismatch in make_tensor"))
}
/// Shorthand that lifts a real number into a [`Complex64`] with zero imaginary part.
const fn c(re: f64) -> Complex64 {
    const ZERO_IM: f64 = 0.0;
    Complex64::new(re, ZERO_IM)
}
/// Contracting two rank-1 tensors over their only axis is a dot product and must
/// yield a rank-0 (scalar) tensor: 1·4 + 2·5 + 3·6 = 32.
#[test]
fn test_contract_vector_dot_product() {
    let lhs = make_tensor(&[3], [1.0, 2.0, 3.0].iter().copied().map(c).collect());
    let rhs = make_tensor(&[3], [4.0, 5.0, 6.0].iter().copied().map(c).collect());

    let scalar = lhs.contract(&rhs, 0, 0).expect("contraction failed");
    assert_eq!(scalar.rank, 0, "dot product should be scalar (rank 0)");
    assert!(scalar.dimensions.is_empty(), "scalar has no dimensions");

    // A rank-0 array is indexed with the empty index.
    let val = scalar.data[IxDyn(&[])];
    assert!((val - c(32.0)).norm() < 1e-10, "expected 32, got {val}");
}
/// Contracting A (2×3) on its column axis with B (3×4) on its row axis is
/// ordinary matrix multiplication. B is the 3×3 identity padded with a zero
/// fourth column, so A·B is A with a trailing zero column.
#[test]
fn test_contract_matrix_multiplication() {
    let to_complex =
        |xs: &[i32]| -> Vec<Complex64> { xs.iter().map(|&x| c(f64::from(x))).collect() };
    let a = make_tensor(&[2, 3], to_complex(&[1, 2, 3, 4, 5, 6]));
    let b = make_tensor(&[3, 4], to_complex(&[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]));

    let result = a
        .contract(&b, 1, 0)
        .expect("matrix-multiply contraction failed");
    assert_eq!(
        result.rank, 2,
        "result of 2D × 2D contraction should be rank 2"
    );
    assert_eq!(result.dimensions, vec![2, 4], "result shape should be 2×4");

    let expected = [
        [c(1.0), c(2.0), c(3.0), c(0.0)],
        [c(4.0), c(5.0), c(6.0), c(0.0)],
    ];
    let (rows, cols) = (expected.len(), expected[0].len());
    for (i, j) in (0..rows).flat_map(|i| (0..cols).map(move |j| (i, j))) {
        let exp = expected[i][j];
        let got = result.data[IxDyn(&[i, j])];
        assert!(
            (got - exp).norm() < 1e-10,
            "mismatch at [{i},{j}]: expected {exp}, got {got}"
        );
    }
}
/// Contracting the last axis of one all-ones 2×2×2 tensor with the first axis of
/// another leaves a 2×2×2×2 tensor. Each entry sums two products 1·1, so every
/// entry is 2.
#[test]
fn test_contract_rank3_tensors() {
    let ones = || make_tensor(&[2, 2, 2], vec![c(1.0); 8]);
    let (a, b) = (ones(), ones());

    let result = a.contract(&b, 2, 0).expect("rank-3 contraction failed");
    assert_eq!(result.rank, 4, "result should be rank 4");
    assert_eq!(result.dimensions, vec![2, 2, 2, 2]);

    result.data.iter().for_each(|val| {
        assert!((val - c(2.0)).norm() < 1e-10, "expected 2.0, got {val}");
    });
}
/// Contracts `a`'s axis 0 with `b`'s axis 1, i.e. result[i, j] = Σ_k a[k, i]·b[j, k].
/// The expected values pin the output axis order: `a`'s free axis first, then `b`'s.
#[test]
fn test_contract_non_trivial_axis_selection() {
    let a = make_tensor(&[2, 2], [1.0, 2.0, 3.0, 4.0].iter().copied().map(c).collect());
    let b = make_tensor(&[2, 2], [5.0, 6.0, 7.0, 8.0].iter().copied().map(c).collect());

    let result = a
        .contract(&b, 0, 1)
        .expect("non-trivial axis contraction failed");
    assert_eq!(result.rank, 2);
    assert_eq!(result.dimensions, vec![2, 2]);

    let expected = [[c(23.0), c(31.0)], [c(34.0), c(46.0)]];
    let (rows, cols) = (expected.len(), expected[0].len());
    for (i, j) in (0..rows).flat_map(|i| (0..cols).map(move |j| (i, j))) {
        let exp = expected[i][j];
        let got = result.data[IxDyn(&[i, j])];
        assert!(
            (got - exp).norm() < 1e-10,
            "mismatch at [{i},{j}]: expected {exp}, got {got}"
        );
    }
}
/// Contracting axes of different lengths (3 on `a`, 4 on `b`) must return an error.
#[test]
fn test_contract_dimension_mismatch_returns_error() {
    let a = make_tensor(&[2, 3], vec![c(0.0); 2 * 3]);
    let b = make_tensor(&[4, 3], vec![c(0.0); 4 * 3]);
    assert!(
        a.contract(&b, 1, 0).is_err(),
        "expected error for dimension mismatch"
    );
}
/// An axis index beyond the tensor's rank, on either operand, must return an error.
#[test]
fn test_contract_out_of_range_axis_returns_error() {
    let zeros = || make_tensor(&[2, 2], vec![c(0.0); 4]);
    let (a, b) = (zeros(), zeros());

    // (self_axis, other_axis) pairs with one out-of-range index each.
    let cases = [
        ((5, 0), "expected error for out-of-range self_axis"),
        ((0, 5), "expected error for out-of-range other_axis"),
    ];
    for ((self_axis, other_axis), msg) in cases {
        assert!(a.contract(&b, self_axis, other_axis).is_err(), "{}", msg);
    }
}
/// A full-rank SVD (bond dimension 4 = matrix size, so nothing is truncated) of a
/// symmetric Toeplitz 4×4 matrix must reproduce the input when the left and
/// right factors are contracted back together over the bond axis.
#[test]
fn test_svd_reconstruction() {
    #[rustfmt::skip]
    let vals: Vec<Complex64> = [
        4.0, 3.0, 2.0, 1.0,
        3.0, 4.0, 3.0, 2.0,
        2.0, 3.0, 4.0, 3.0,
        1.0, 2.0, 3.0, 4.0,
    ]
    .iter()
    .copied()
    .map(c)
    .collect();
    let t = make_tensor(&[4, 4], vals);

    let (left, right) = t.svd(&[0], &[1], 4).expect("SVD failed");
    // The bond index is the last axis of `left` and the first axis of `right`.
    let bond_axis = left.rank - 1;
    let reconstructed = left
        .contract(&right, bond_axis, 0)
        .expect("SVD reconstruction contraction failed");
    assert_eq!(reconstructed.rank, 2);
    assert_eq!(reconstructed.dimensions, vec![4, 4]);

    // Shapes are equal (asserted above), and both iterators walk logical
    // row-major order, so zipping pairs up elements at the same index.
    let pairs = t.data.indexed_iter().zip(reconstructed.data.iter());
    for ((idx, orig_val), rec_val) in pairs {
        assert!(
            (orig_val - rec_val).norm() < 1e-8,
            "SVD reconstruction mismatch at {idx:?}: original {orig_val}, reconstructed {rec_val}"
        );
    }
}
/// Truncating the SVD of diag(10, 3, 0.1, 0.01) to a maximum bond dimension of 2
/// must produce bond-2 factors whose product keeps the two dominant diagonal
/// entries (≈10 at [0,0], ≈3 at [1,1]).
#[test]
fn test_svd_truncation_reduces_bond_dimension() {
    let diag_vals = [10.0_f64, 3.0, 0.1, 0.01];
    let n = diag_vals.len();
    let vals: Vec<Complex64> = (0..n * n)
        .map(|flat| {
            let (row, col) = (flat / n, flat % n);
            if row == col {
                c(diag_vals[row])
            } else {
                c(0.0)
            }
        })
        .collect();
    let t = make_tensor(&[n, n], vals);

    let (left, right) = t.svd(&[0], &[1], 2).expect("truncated SVD failed");
    let left_bond = *left.dimensions.last().expect("left tensor has dimensions");
    let right_bond = right.dimensions[0];
    assert_eq!(left_bond, 2, "left tensor bond dim should be 2");
    assert_eq!(right_bond, 2, "right tensor bond dim should be 2");

    let reconstructed = left
        .contract(&right, left.rank - 1, 0)
        .expect("reconstruction contraction failed");
    let (v00, v11) = (
        reconstructed.data[IxDyn(&[0, 0])],
        reconstructed.data[IxDyn(&[1, 1])],
    );
    assert!(
        (v00 - c(10.0)).norm() < 0.5,
        "expected ~10 at [0,0], got {v00}"
    );
    assert!(
        (v11 - c(3.0)).norm() < 0.5,
        "expected ~3 at [1,1], got {v11}"
    );
}
/// SVD of the non-square 2×4 matrix [[1,2,3,4],[5,6,7,8]].
///
/// Checks that:
/// - the bond dimension is at most min(m, n) = 2, even though `max_bond_dim` allows 8;
/// - the factors have matching bond dimensions and the expected outer dimensions;
/// - contracting the factors back together reproduces the input. The matrix has
///   rank 2 and the bond is at most 2, so no nonzero singular value is discarded.
#[test]
fn test_svd_rectangular_matrix() {
    let vals: Vec<Complex64> = (1..=8).map(|x| c(x as f64)).collect();
    let t = make_tensor(&[2, 4], vals);
    let (left, right) = t
        .svd(&[0], &[1], 8)
        .expect("SVD of rectangular matrix failed");
    let bond = *left.dimensions.last().expect("left has dims");
    assert!(bond <= 2, "bond dim should be <= min(m,n) = 2, got {bond}");
    assert_eq!(right.dimensions[0], bond, "right bond dim should match");
    assert_eq!(left.dimensions[0], 2, "left outer dim should be 2");
    assert_eq!(right.dimensions[1], 4, "right outer dim should be 4");

    // Correct shapes alone would not catch a factorisation with wrong values on
    // non-square input (e.g. a mis-reshaped right factor), so verify left·right == t.
    let reconstructed = left
        .contract(&right, left.rank - 1, 0)
        .expect("rectangular SVD reconstruction contraction failed");
    assert_eq!(
        reconstructed.dimensions,
        vec![2, 4],
        "reconstruction shape should be 2×4"
    );
    // Shapes match (asserted above); both iterators walk logical row-major order.
    for ((idx, orig_val), rec_val) in t.data.indexed_iter().zip(reconstructed.data.iter()) {
        assert!(
            (orig_val - rec_val).norm() < 1e-8,
            "rectangular SVD reconstruction mismatch at {idx:?}: original {orig_val}, reconstructed {rec_val}"
        );
    }
}
/// SVD of a rank-3 tensor (2×2×4), splitting axes {0, 1} from axis {2}.
///
/// Checks that:
/// - the left factor keeps the two grouped outer axes plus a bond axis;
/// - the right factor is [bond, 4] and the bond dimensions agree;
/// - contracting the factors reproduces the original [2, 2, 4] tensor. The 4×4
///   unfolding has rank 2 (each row is 2k·[1,1,1,1] + 0.5·[1,2,3,4]), and
///   `max_bond_dim` = 4 is enough to keep every nonzero singular value.
#[test]
fn test_svd_rank3_tensor() {
    let vals: Vec<Complex64> = (0..16).map(|i| c((i as f64 + 1.0) * 0.5)).collect();
    let t = make_tensor(&[2, 2, 4], vals);
    let (left, right) = t.svd(&[0, 1], &[2], 4).expect("rank-3 SVD failed");
    assert_eq!(left.rank, 3, "left should be rank 3");
    assert_eq!(right.rank, 2, "right should be rank 2");
    assert_eq!(left.dimensions[0], 2);
    assert_eq!(left.dimensions[1], 2);
    assert_eq!(right.dimensions[1], 4);
    assert_eq!(
        *left.dimensions.last().expect("left has dims"),
        right.dimensions[0],
        "bond dims must match"
    );

    // Shape checks alone would pass a factorisation with wrong values on the
    // multi-axis (reshaped) path, so verify that left·right reproduces t.
    let reconstructed = left
        .contract(&right, left.rank - 1, 0)
        .expect("rank-3 SVD reconstruction contraction failed");
    assert_eq!(
        reconstructed.dimensions,
        vec![2, 2, 4],
        "reconstruction shape should be 2×2×4"
    );
    // Shapes match (asserted above); both iterators walk logical row-major order.
    for ((idx, orig_val), rec_val) in t.data.indexed_iter().zip(reconstructed.data.iter()) {
        assert!(
            (orig_val - rec_val).norm() < 1e-8,
            "rank-3 SVD reconstruction mismatch at {idx:?}: original {orig_val}, reconstructed {rec_val}"
        );
    }
}
/// `svd` must reject an axis split whose left and right axes do not add up to
/// the tensor's rank, and one that names an axis beyond the rank.
#[test]
fn test_svd_invalid_axes_returns_error() {
    let zero_matrix = make_tensor(&[2, 2], vec![c(0.0); 4]);

    // Only one of the two axes is assigned to a side.
    let missing_axis = zero_matrix.svd(&[0], &[], 2);
    assert!(missing_axis.is_err(), "expected error: wrong total axes count");

    // Axis 5 does not exist on a rank-2 tensor.
    let out_of_range = zero_matrix.svd(&[0], &[5], 2);
    assert!(out_of_range.is_err(), "expected error: axis out of range");
}
/// A maximum bond dimension of zero must be rejected with an error.
#[test]
fn test_svd_zero_bond_dim_returns_error() {
    let t = make_tensor(&[2, 2], vec![c(0.0); 4]);
    let max_bond_dim = 0;
    assert!(
        t.svd(&[0], &[1], max_bond_dim).is_err(),
        "expected error for max_bond_dim = 0"
    );
}