use super::{check_axis_count, diagonal, tensordot, tensordot_axes, trace, tril, triu};
use crate::{array::Array, dtype::Dtype, error::Error};
/// Shared fixture: the 3x3 `f32` matrix `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`.
fn mat3() -> Array {
    let values: [f32; 9] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    Array::from_slice(&values, &[3, 3]).unwrap()
}
/// `check_axis_count` accepts 0 and `i32::MAX`. One past `i32::MAX`, it must
/// return `CapExceeded` whose payload echoes the context, the cap and the observed value.
#[test]
fn check_axis_count_boundary() {
    let cap = i32::MAX as usize;
    for accepted in [0, cap] {
        assert!(check_axis_count("t", accepted).is_ok());
    }

    let over = cap + 1;
    let outcome = check_axis_count("ctx", over);
    if let Err(Error::CapExceeded(payload)) = &outcome {
        assert_eq!(payload.context(), "ctx");
        assert_eq!(payload.cap(), i32::MAX as u64);
        assert_eq!(payload.observed(), over as u64);
    } else {
        panic!("expected Err(CapExceeded) one past the cap, got {outcome:?}");
    }
}
/// An `i32::MIN` diagonal offset is reported as `OutOfRange` with context
/// `"diagonal: offset"`. Offsets of +1 and -1 must still succeed.
#[test]
fn diagonal_offset_i32_min_is_typed_error() {
    let matrix = mat3();

    let outcome = diagonal(&matrix, i32::MIN, 0, 1);
    if let Err(Error::OutOfRange(payload)) = &outcome {
        assert_eq!(payload.context(), "diagonal: offset");
    } else {
        panic!("expected OutOfRange for i32::MIN offset, got {outcome:?}");
    }

    for offset in [1, -1] {
        assert!(diagonal(&matrix, offset, 0, 1).is_ok());
    }
}
/// An `i32::MIN` trace offset is reported as `OutOfRange` with context
/// `"trace: offset"`. A zero offset must still succeed.
#[test]
fn trace_offset_i32_min_is_typed_error() {
    let matrix = mat3();

    let outcome = trace(&matrix, i32::MIN, 0, 1, None);
    if let Err(Error::OutOfRange(payload)) = &outcome {
        assert_eq!(payload.context(), "trace: offset");
    } else {
        panic!("expected OutOfRange for i32::MIN offset, got {outcome:?}");
    }

    assert!(trace(&matrix, 0, 0, 1, None).is_ok());
}
/// `tril` with `k = i32::MIN` yields `ArithmeticOverflow` with context `"tril: k"`.
/// `k = i32::MIN + 1` also yields `ArithmeticOverflow`; the context is not checked.
/// Ordinary values of `k` (0 and -1) succeed.
#[test]
fn tril_k_overflow_is_typed_error() {
    let matrix = mat3();

    let outcome = tril(&matrix, i32::MIN);
    if let Err(Error::ArithmeticOverflow(payload)) = &outcome {
        assert_eq!(payload.context(), "tril: k");
    } else {
        panic!("expected ArithmeticOverflow for i32::MIN k, got {outcome:?}");
    }

    let near_min = tril(&matrix, i32::MIN + 1);
    assert!(matches!(near_min, Err(Error::ArithmeticOverflow(_))));

    for k in [0, -1] {
        assert!(tril(&matrix, k).is_ok());
    }
}
/// `triu` with `k = i32::MIN` yields `ArithmeticOverflow` with context `"triu: k - 1"`.
/// `k = i32::MIN + 1` also yields `ArithmeticOverflow`; the context is not checked.
/// Ordinary values of `k` (0 and 1) succeed.
#[test]
fn triu_k_overflow_is_typed_error() {
    let matrix = mat3();

    let outcome = triu(&matrix, i32::MIN);
    if let Err(Error::ArithmeticOverflow(payload)) = &outcome {
        assert_eq!(payload.context(), "triu: k - 1");
    } else {
        panic!("expected ArithmeticOverflow for i32::MIN k, got {outcome:?}");
    }

    let near_min = triu(&matrix, i32::MIN + 1);
    assert!(matches!(near_min, Err(Error::ArithmeticOverflow(_))));

    for k in [0, 1] {
        assert!(triu(&matrix, k).is_ok());
    }
}
/// Contracting both axes of `[[1, 2], [3, 4]]` with itself gives the
/// sum of elementwise products: 1 + 4 + 9 + 16 = 30.
#[test]
fn tensordot_int_full_contraction() {
    let values = [1.0f32, 2.0, 3.0, 4.0];
    let lhs = Array::from_slice(&values, &[2, 2]).unwrap();
    let rhs = Array::from_slice(&values, &[2, 2]).unwrap();

    let mut contracted = tensordot(&lhs, &rhs, 2).unwrap();
    let flat = contracted.to_vec::<f32>().unwrap();
    assert_eq!(flat, vec![30.0]);
}
/// Contracting zero axes of `[1, 2]` and `[3, 4]` yields their 2x2 outer
/// product `[[3, 4], [6, 8]]`.
#[test]
fn tensordot_int_zero_axes_is_outer() {
    let (left, right) = (
        Array::from_slice(&[1.0f32, 2.0], &[2]).unwrap(),
        Array::from_slice(&[3.0f32, 4.0], &[2]).unwrap(),
    );

    let mut outer = tensordot(&left, &right, 0).unwrap();
    assert_eq!(outer.shape(), vec![2, 2]);

    let flat = outer.to_vec::<f32>().unwrap();
    assert_eq!(flat, vec![3.0, 4.0, 6.0, 8.0]);
}
/// Contracting a single axis reproduces the matrix product
/// `[[1, 2], [3, 4]] @ [[5, 6], [7, 8]] = [[19, 22], [43, 50]]`.
#[test]
fn tensordot_int_one_axis_is_matmul() {
    let lhs = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    let rhs = Array::from_slice(&[5.0f32, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

    let mut product = tensordot(&lhs, &rhs, 1).unwrap();
    let expected: Vec<f32> = vec![19.0, 22.0, 43.0, 50.0];
    assert_eq!(product.to_vec::<f32>().unwrap(), expected);
}
/// A negative contraction count passed to `tensordot` is rejected.
#[test]
fn tensordot_int_negative_axis_errors() {
    let values = [1.0f32, 2.0, 3.0, 4.0];
    let lhs = Array::from_slice(&values, &[2, 2]).unwrap();
    let rhs = Array::from_slice(&values, &[2, 2]).unwrap();

    let result = tensordot(&lhs, &rhs, -1);
    assert!(result.is_err());
}
/// Pairing axis 1 of the left operand with axis 0 of the right operand
/// matches the ordinary matrix product `[[19, 22], [43, 50]]`.
#[test]
fn tensordot_axes_matmul_equivalent() {
    let lhs = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    let rhs = Array::from_slice(&[5.0f32, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

    let mut product = tensordot_axes(&lhs, &rhs, &[1], &[0]).unwrap();
    let flat = product.to_vec::<f32>().unwrap();
    assert_eq!(flat, vec![19.0, 22.0, 43.0, 50.0]);
}
/// Explicitly pairing both axes of `[[1, 2], [3, 4]]` with themselves
/// collapses to the scalar 30 (1 + 4 + 9 + 16).
#[test]
fn tensordot_axes_full_contraction() {
    let values = [1.0f32, 2.0, 3.0, 4.0];
    let (lhs, rhs) = (
        Array::from_slice(&values, &[2, 2]).unwrap(),
        Array::from_slice(&values, &[2, 2]).unwrap(),
    );

    let mut scalar = tensordot_axes(&lhs, &rhs, &[0, 1], &[0, 1]).unwrap();
    assert_eq!(scalar.to_vec::<f32>().unwrap(), vec![30.0]);
}
/// Axis `-1` on the left operand of a 2x2 matrix behaves like axis 1,
/// so the result is the same matrix product as the positive-axis case.
#[test]
fn tensordot_axes_negative_axis_matches_matmul() {
    let lhs = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    let rhs = Array::from_slice(&[5.0f32, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

    let mut product = tensordot_axes(&lhs, &rhs, &[-1], &[0]).unwrap();
    let expected: Vec<f32> = vec![19.0, 22.0, 43.0, 50.0];
    assert_eq!(product.to_vec::<f32>().unwrap(), expected);
}
/// Axis lists of different lengths (2 vs 1) are rejected with a
/// `LengthMismatch` whose payload reports expected = 2 and actual = 1.
#[test]
fn tensordot_axes_length_mismatch_is_typed_error() {
    let values = [1.0f32, 2.0, 3.0, 4.0];
    let lhs = Array::from_slice(&values, &[2, 2]).unwrap();
    let rhs = Array::from_slice(&values, &[2, 2]).unwrap();

    let err = tensordot_axes(&lhs, &rhs, &[0, 1], &[0]).unwrap_err();
    if let Error::LengthMismatch(mismatch) = &err {
        assert_eq!(mismatch.expected(), 2);
        assert_eq!(mismatch.actual(), 1);
    } else {
        panic!("expected LengthMismatch, got {err:?}");
    }
}
/// The main diagonal (offset 0) of the fixture is `[1, 5, 9]`.
#[test]
fn diagonal_main() {
    let matrix = mat3();
    let mut main_diag = diagonal(&matrix, 0, 0, 1).unwrap();
    let values = main_diag.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![1.0, 5.0, 9.0]);
}
/// Offset +1 selects the diagonal just above the main one: `[2, 6]`.
#[test]
fn diagonal_positive_offset() {
    let matrix = mat3();
    let mut upper = diagonal(&matrix, 1, 0, 1).unwrap();
    let values = upper.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.0, 6.0]);
}
/// Offset -1 selects the diagonal just below the main one: `[4, 8]`.
#[test]
fn diagonal_negative_offset() {
    let matrix = mat3();
    let mut lower = diagonal(&matrix, -1, 0, 1).unwrap();
    let values = lower.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![4.0, 8.0]);
}
/// Axes `-2` and `-1` on a 2-D input produce the same main diagonal as
/// axes `0` and `1`.
#[test]
fn diagonal_negative_axes() {
    let matrix = mat3();
    let mut main_diag = diagonal(&matrix, 0, -2, -1).unwrap();
    let values = main_diag.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![1.0, 5.0, 9.0]);
}
/// The trace of the fixture is 1 + 5 + 9 = 15.
#[test]
fn trace_main() {
    let matrix = mat3();
    let mut total = trace(&matrix, 0, 0, 1, None).unwrap();
    let values = total.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![15.0]);
}
/// With offset +1 the trace sums the super-diagonal: 2 + 6 = 8.
#[test]
fn trace_positive_offset() {
    let matrix = mat3();
    let mut total = trace(&matrix, 1, 0, 1, None).unwrap();
    let values = total.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![8.0]);
}
/// With offset -1 the trace sums the sub-diagonal: 4 + 8 = 12.
#[test]
fn trace_negative_offset() {
    let matrix = mat3();
    let mut total = trace(&matrix, -1, 0, 1, None).unwrap();
    let values = total.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![12.0]);
}
/// Requesting `Dtype::F32` for the trace of an `i32` matrix produces an
/// `F32` result holding 1 + 4 = 5.
#[test]
fn trace_explicit_dtype_promotes() {
    let ints = Array::from_slice(&[1i32, 2, 3, 4], &[2, 2]).unwrap();
    let mut total = trace(&ints, 0, 0, 1, Some(Dtype::F32)).unwrap();

    assert_eq!(total.dtype().unwrap(), Dtype::F32);
    let as_f32 = total.to_vec::<f32>().unwrap();
    assert_eq!(as_f32, vec![5.0]);
}
/// Without an explicit dtype, the trace of an `i32` matrix stays `I32`
/// and holds 1 + 4 = 5.
#[test]
fn trace_default_dtype_is_input_dtype() {
    let ints = Array::from_slice(&[1i32, 2, 3, 4], &[2, 2]).unwrap();
    let mut total = trace(&ints, 0, 0, 1, None).unwrap();

    assert_eq!(total.dtype().unwrap(), Dtype::I32);
    let expected: Vec<i32> = vec![5];
    assert_eq!(total.to_vec::<i32>().unwrap(), expected);
}
/// `tril` with `k = 0` zeroes every element strictly above the main diagonal.
#[test]
fn tril_k_zero() {
    let matrix = mat3();
    let mut lower = tril(&matrix, 0).unwrap();
    let expected: Vec<f32> = vec![1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0];
    assert_eq!(lower.to_vec::<f32>().unwrap(), expected);
}
/// `tril` with `k = 1` also keeps the first super-diagonal, so only the
/// top-right corner element is zeroed.
#[test]
fn tril_k_positive() {
    let matrix = mat3();
    let mut lower = tril(&matrix, 1).unwrap();
    let expected: Vec<f32> = vec![1.0, 2.0, 0.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    assert_eq!(lower.to_vec::<f32>().unwrap(), expected);
}
/// `tril` with `k = -1` keeps only elements strictly below the main diagonal.
#[test]
fn tril_k_negative() {
    let matrix = mat3();
    let mut lower = tril(&matrix, -1).unwrap();
    let expected: Vec<f32> = vec![0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 7.0, 8.0, 0.0];
    assert_eq!(lower.to_vec::<f32>().unwrap(), expected);
}
/// `triu` with `k = 0` zeroes every element strictly below the main diagonal.
#[test]
fn triu_k_zero() {
    let matrix = mat3();
    let mut upper = triu(&matrix, 0).unwrap();
    let expected: Vec<f32> = vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0];
    assert_eq!(upper.to_vec::<f32>().unwrap(), expected);
}
/// `triu` with `k = 1` keeps only elements strictly above the main diagonal.
#[test]
fn triu_k_positive() {
    let matrix = mat3();
    let mut upper = triu(&matrix, 1).unwrap();
    let expected: Vec<f32> = vec![0.0, 2.0, 3.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0];
    assert_eq!(upper.to_vec::<f32>().unwrap(), expected);
}
/// `triu` with `k = -1` also keeps the first sub-diagonal, so only the
/// bottom-left corner element is zeroed.
#[test]
fn triu_k_negative() {
    let matrix = mat3();
    let mut upper = triu(&matrix, -1).unwrap();
    let expected: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 8.0, 9.0];
    assert_eq!(upper.to_vec::<f32>().unwrap(), expected);
}
/// `tril` rejects a 1-D input rather than returning a result.
#[test]
fn tril_requires_2d() {
    let vector = Array::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
    let result = tril(&vector, 0);
    assert!(result.is_err());
}