use num_traits::Float;
use numrs2::linalg::tensor_ops::{kron, tensordot};
use numrs2::prelude::*;
/// Checks `kron` on two 2x2 matrices against a hand-computed 4x4 result.
///
/// Block `(i, j)` of the output is `a[i][j] * b`, so the expected values are
/// `[[1*B, 2*B], [3*B, 4*B]]` with `B = [[5, 6], [7, 8]]`, listed row by row.
#[test]
fn test_kron_basic() {
    let lhs = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
    let rhs = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);

    let product = kron(&lhs, &rhs).unwrap();
    assert_eq!(product.shape(), &[4, 4]);

    let want = [
        5.0, 6.0, 10.0, 12.0, //
        7.0, 8.0, 14.0, 16.0, //
        15.0, 18.0, 20.0, 24.0, //
        21.0, 24.0, 28.0, 32.0,
    ];
    let got = product.to_vec();

    // Element-wise comparison with an absolute tolerance; the shape assert above
    // already pins the element count, so `zip` cannot drop trailing values.
    for (idx, (g, w)) in got.iter().copied().zip(want.iter().copied()).enumerate() {
        assert!(
            Float::abs(g - w) < 1e-10f64,
            "Mismatch at index {}: {} != {}",
            idx,
            g,
            w
        );
    }
}
/// Checks that `kron(I, B)` with the 2x2 identity `I` produces the
/// block-diagonal matrix `[[B, 0], [0, B]]`.
#[test]
fn test_kron_identity() {
    let a = Array::from_vec(vec![1.0, 0.0, 0.0, 1.0]).reshape(&[2, 2]);
    let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);

    let result = kron(&a, &b).unwrap();
    assert_eq!(result.shape(), &[4, 4]);

    // Expected values listed row by row: `B` in the top-left and bottom-right
    // 2x2 blocks, zeros in the off-diagonal blocks.
    let expected = [
        5.0, 6.0, 0.0, 0.0, //
        7.0, 8.0, 0.0, 0.0, //
        0.0, 0.0, 5.0, 6.0, //
        0.0, 0.0, 7.0, 8.0,
    ];
    let result_data = result.to_vec();

    // `zip` stops at the shorter side, so pin the lengths explicitly rather than
    // letting a short buffer pass unnoticed.
    assert_eq!(
        result_data.len(),
        expected.len(),
        "kron result has the wrong number of elements"
    );
    for (i, (&actual, &expected)) in result_data.iter().zip(expected.iter()).enumerate() {
        assert!(
            Float::abs(actual - expected) < 1e-10f64,
            "Mismatch at index {}: {} != {}",
            i,
            actual,
            expected
        );
    }
}
/// Cross-checks `tensordot` against `matmul` for two 2x2 matrices.
///
/// The axes argument `&[1, 0]` is expected to contract the column axis of `a`
/// with the row axis of `b`, which is exactly an ordinary matrix product.
#[test]
fn test_tensordot_matrix_multiplication() {
    let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
    let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);

    let contracted = tensordot(&a, &b, &[1, 0]).unwrap();
    let reference = a.matmul(&b).unwrap();
    assert_eq!(contracted.shape(), reference.shape());

    let got = contracted.to_vec();
    let want = reference.to_vec();

    // Shapes already match, so both buffers have the same length.
    for (idx, pair) in got.iter().zip(&want).enumerate() {
        let (&actual, &expected) = pair;
        assert!(
            Float::abs(actual - expected) < 1e-10f64,
            "Mismatch at index {}: {} != {}",
            idx,
            actual,
            expected
        );
    }
}
/// Checks that `tensordot` rejects invalid contractions.
///
/// Each case feeds exactly one fault, so a passing assertion cannot be
/// explained by a different validation failing first.
#[test]
fn test_tensordot_error_handling() {
    let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);

    // Case 1: axis 1 of `a` has length 2 and axis 0 of `b` has length 3, so the
    // requested contraction is impossible.
    let b = Array::from_vec(vec![5.0, 6.0, 7.0]).reshape(&[3, 1]);
    let result = tensordot(&a, &b, &[1, 0]);
    assert!(
        result.is_err(),
        "tensordot should reject contracted axes of different lengths"
    );

    // Case 2: a malformed three-element axes list. Use a shape-compatible
    // operand here. Reusing `b` would let the shape mismatch above produce the
    // error and hide a missing axes-list check.
    let c = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
    let result = tensordot(&a, &c, &[1, 0, 1]);
    assert!(
        result.is_err(),
        "tensordot should reject a three-element axes list"
    );
}
/// Asserts that `kron` returns an error when given a 1-D and a 2-D operand.
///
/// NOTE(review): NumPy's `kron` accepts operands of different rank (it prepends
/// unit axes to the lower-rank one); this test pins the stricter behavior of
/// this library. Confirm that the difference is intentional.
#[test]
fn test_kron_error_handling() {
    let a = Array::from_vec(vec![1.0, 2.0, 3.0]).reshape(&[3]);
    let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);

    let result = kron(&a, &b);
    assert!(
        result.is_err(),
        "kron of a 1-D and a 2-D array should return an error"
    );
}