use crate::Tensor;
/// Dot product of two 1-D tensors yields a scalar holding the sum of the
/// element-wise products: 1*4 + 2*5 + 3*6 = 32.
#[test]
fn vv_dot_basic() {
    let (lhs, rhs) = (
        Tensor::from_vec(vec![1.0, 2.0, 3.0]),
        Tensor::from_vec(vec![4.0, 5.0, 6.0]),
    );
    let product = lhs.dot(&rhs);
    assert!(product.is_scalar());
    assert_eq!(product.as_slice(), &[32.0]);
}
/// Orthogonal unit vectors (x-axis and y-axis) have a dot product of zero.
#[test]
fn vv_dot_orthogonal() {
    let x_axis = Tensor::from_vec(vec![1.0, 0.0, 0.0]);
    let y_axis = Tensor::from_vec(vec![0.0, 1.0, 0.0]);
    let product = x_axis.dot(&y_axis);
    assert_eq!(product.as_slice(), &[0.0]);
}
/// `dot` on 1-D tensors of different lengths must panic; the `expected`
/// substring pins the wording of the panic message.
#[test]
#[should_panic(expected = "lengths must match")]
fn vv_dot_length_mismatch_panics() {
    let short = Tensor::from_vec(vec![1.0, 2.0]);
    let long = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
    // Only the panic matters; the result is intentionally discarded.
    let _ = short.dot(&long);
}
/// A (2, 3) matrix times a length-3 vector produces a length-2 vector.
/// With v = [1, 0, 1], each output entry is the sum of the first and last
/// element of the corresponding row: [1 + 3, 4 + 6] = [4, 10].
#[test]
fn matrix_vector_mul() {
    let matrix = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    let selector = Tensor::from_vec(vec![1.0, 0.0, 1.0]);
    let result = matrix.matmul(&selector);
    assert_eq!(result.shape(), &[2]);
    assert_eq!(result.as_slice(), &[4.0, 10.0]);
}
/// A length-2 vector times a (2, 3) matrix produces a length-3 vector,
/// i.e. a weighted sum of the matrix rows: 1*[1, 2, 3] + 2*[4, 5, 6].
#[test]
fn vector_matrix_mul() {
    let weights = Tensor::from_vec(vec![1.0, 2.0]);
    let matrix = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    let combined = weights.matmul(&matrix);
    assert_eq!(combined.shape(), &[3]);
    assert_eq!(combined.as_slice(), &[9.0, 12.0, 15.0]);
}
/// Standard 2x2 product:
/// [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
#[test]
fn matrix_matrix_mul() {
    let (lhs, rhs) = (
        Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]),
        Tensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]),
    );
    let product = lhs.matmul(&rhs);
    assert_eq!(product.shape(), &[2, 2]);
    assert_eq!(product.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
}
/// Non-square product: (2, 3) x (3, 4) -> (2, 4).
///
/// Every output element is checked, not just the first two. Mistakes that
/// only affect later columns or the second row (e.g. a wrong stride on the
/// right-hand operand) are therefore caught. All values are small integers
/// and exactly representable in f64, so exact equality is safe.
#[test]
fn matmul_non_square() {
    let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    // b = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    let b = Tensor::new((1..=12).map(|x| x as f64).collect(), &[3, 4]);
    let c = a.matmul(&b);
    assert_eq!(c.shape(), &[2, 4]);
    // Row 0 = [1, 2, 3] . each column of b; row 1 = [4, 5, 6] . each column of b.
    assert_eq!(
        c.as_slice(),
        &[38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0]
    );
}
/// For 2-D operands, `dot` and `matmul` must produce equal tensors.
#[test]
fn dot_and_matmul_are_aliases() {
    let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let rhs = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);
    let via_dot = lhs.dot(&rhs);
    let via_matmul = lhs.matmul(&rhs);
    assert_eq!(via_dot, via_matmul);
}
/// Multiplying (2, 3) by (4, 2) is invalid because the inner dimensions
/// (3 vs 4) differ; the panic message must mention "left columns".
#[test]
#[should_panic(expected = "left columns")]
fn matmul_dimension_mismatch_panics() {
    let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    // Same data as the literal [1.0, ..., 8.0].
    let rhs = Tensor::new((1..=8).map(|x| x as f64).collect(), &[4, 2]);
    // Only the panic matters; the result is intentionally discarded.
    let _ = lhs.matmul(&rhs);
}
/// Rank-3 operands are rejected by `matmul` with an "unsupported rank" panic.
#[test]
#[should_panic(expected = "unsupported rank")]
fn matmul_rank3_panics() {
    let (lhs, rhs) = (Tensor::zeros(&[2, 2, 2]), Tensor::zeros(&[2, 2, 2]));
    // Only the panic matters; the result is intentionally discarded.
    let _ = lhs.matmul(&rhs);
}
/// The `*` operator on tensor references multiplies element-wise; matrix
/// multiplication must be requested explicitly through `matmul`.
#[test]
fn star_is_still_element_wise_not_matmul() {
    let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let rhs = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);

    // Element-wise product: [1*5, 2*6, 3*7, 4*8].
    let elementwise = &lhs * &rhs;
    assert_eq!(elementwise.as_slice(), &[5.0, 12.0, 21.0, 32.0]);

    // The true matrix product differs, confirming `*` is not an alias for it.
    let matrix_product = lhs.matmul(&rhs);
    assert_eq!(matrix_product.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
}