use super::super::*;
#[test]
fn test_matvec_basic() {
    // [[1, 2, 3], [4, 5, 6]] * [1, 2, 3] = [1 + 4 + 9, 4 + 10 + 18] = [14, 32]
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let x = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let y = m.matvec(&x).unwrap();
    let expected = [14.0, 32.0];
    assert_eq!(y.len(), expected.len());
    for (i, &want) in expected.iter().enumerate() {
        assert!((y.as_slice()[i] - want).abs() < 1e-6);
    }
}
#[test]
fn test_matvec_identity() {
    // Multiplying by the 3x3 identity must hand the vector back unchanged.
    let eye = Matrix::identity(3);
    let x = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let image = eye.matvec(&x).unwrap();
    assert_eq!(image.as_slice(), x.as_slice());
}
#[test]
fn test_matvec_dimension_mismatch() {
    // A 2x3 matrix needs a length-3 vector; a length-2 vector must be rejected.
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let too_short = Vector::from_slice(&[1.0, 2.0]);
    let outcome = m.matvec(&too_short);
    assert!(outcome.is_err());
}
#[test]
fn test_vecmat_basic() {
    // Row vector times matrix:
    // [1, 2] * [[1, 2, 3], [4, 5, 6]] = [1 + 8, 2 + 10, 3 + 12] = [9, 12, 15]
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let row = Vector::from_slice(&[1.0, 2.0]);
    let product = Matrix::vecmat(&row, &m).unwrap();
    let expected = [9.0, 12.0, 15.0];
    assert_eq!(product.len(), expected.len());
    for (i, &want) in expected.iter().enumerate() {
        assert!((product.as_slice()[i] - want).abs() < 1e-6);
    }
}
#[test]
fn test_vecmat_identity() {
    // A row vector times the 3x3 identity must come back unchanged.
    let eye = Matrix::identity(3);
    let row = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let image = Matrix::vecmat(&row, &eye).unwrap();
    assert_eq!(image.as_slice(), row.as_slice());
}
#[test]
fn test_vecmat_dimension_mismatch() {
    // The 2x3 matrix has 2 rows, so a length-3 row vector does not conform.
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let too_long = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let outcome = Matrix::vecmat(&too_long, &m);
    assert!(outcome.is_err());
}
#[test]
fn test_matvec_zero_vector() {
    // A * 0 is the zero vector, with one entry per row of A.
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let zeros = Vector::from_slice(&[0.0; 3]);
    let product = m.matvec(&zeros).unwrap();
    assert_eq!(product.as_slice(), &[0.0; 2]);
}
#[test]
fn test_vecmat_zero_vector() {
    // 0 * A is the zero vector, with one entry per column of A.
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let zeros = Vector::from_slice(&[0.0; 2]);
    let product = Matrix::vecmat(&zeros, &m).unwrap();
    assert_eq!(product.as_slice(), &[0.0; 3]);
}
#[test]
fn test_matvec_transpose_equivalence() {
    // Identity under test: A * v == v^T * A^T (compared as plain vectors).
    let m = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let v = Vector::from_slice(&[1.0, 2.0]);
    let av = m.matvec(&v).unwrap();
    let m_t = m.transpose();
    let v_mt = Matrix::vecmat(&v, &m_t).unwrap();

    // A 3x2 matrix times a length-2 vector yields one entry per row.
    assert_eq!(av.len(), 3);
    assert_eq!(av.len(), v_mt.len());
    // Compare with a tolerance, like the other numeric tests. The two code paths are
    // independent implementations and need not accumulate products in the same order,
    // so bitwise-equal floats are not guaranteed in general.
    for (i, (a, b)) in av.as_slice().iter().zip(v_mt.as_slice()).enumerate() {
        assert!((a - b).abs() < 1e-6, "mismatch at {}: {} vs {}", i, a, b);
    }
}
#[test]
fn test_convolve2d_basic_3x3() {
    // A 1x1 kernel holding 1.0 acts as an identity filter: shape and values are preserved.
    #[rustfmt::skip]
    let input = Matrix::from_vec(3, 3, vec![
        1.0, 2.0, 3.0,
        4.0, 5.0, 6.0,
        7.0, 8.0, 9.0,
    ])
    .unwrap();
    let identity_kernel = Matrix::from_vec(1, 1, vec![1.0]).unwrap();
    let filtered = input.convolve2d(&identity_kernel).unwrap();
    assert_eq!((filtered.rows(), filtered.cols()), (3, 3));
    assert_eq!(filtered.as_slice(), input.as_slice());
}
#[test]
fn test_convolve2d_edge_detection() {
    // A bright 2x2 block in the middle of a uniform 4x4 field of ones.
    #[rustfmt::skip]
    let input = Matrix::from_vec(
        4,
        4,
        vec![
            1.0, 1.0, 1.0, 1.0,
            1.0, 2.0, 2.0, 1.0,
            1.0, 2.0, 2.0, 1.0,
            1.0, 1.0, 1.0, 1.0,
        ],
    )
    .unwrap();
    // Horizontal-edge kernel. Before any kernel flip, it weights the top row of each
    // window by -1 and the bottom row by +1.
    #[rustfmt::skip]
    let kernel = Matrix::from_vec(
        3,
        3,
        vec![
            -1.0, -1.0, -1.0,
             0.0,  0.0,  0.0,
             1.0,  1.0,  1.0,
        ],
    )
    .unwrap();
    let result = input.convolve2d(&kernel).unwrap();
    // A 3x3 kernel over a 4x4 input leaves a 2x2 output.
    assert_eq!(result.rows(), 2);
    assert_eq!(result.cols(), 2);

    // Top windows (input rows 0..=2): the -1 row lands on a row-segment summing to 3,
    // and the +1 row on one summing to 5. Bottom windows (rows 1..=3) see the reverse.
    // Without a kernel flip (cross-correlation) the output is [[2, 2], [-2, -2]].
    // Flipping the kernel (true convolution) negates every entry. So assert only what
    // holds under either convention: magnitude 2, constant rows, opposite-signed rows.
    let top_left = result.get(0, 0).unwrap();
    let top_right = result.get(0, 1).unwrap();
    let bottom_left = result.get(1, 0).unwrap();
    let bottom_right = result.get(1, 1).unwrap();
    assert!(
        (top_left.abs() - 2.0).abs() < 1e-6,
        "unexpected edge magnitude {}",
        top_left
    );
    assert!((top_left - top_right).abs() < 1e-6);
    assert!((bottom_left - bottom_right).abs() < 1e-6);
    assert!(
        (top_left + bottom_left).abs() < 1e-6,
        "top and bottom edge responses should have opposite signs"
    );
}
#[test]
fn test_convolve2d_averaging_filter() {
    // 5x5 field of zeros with a single 9.0 at the centre (row 2, col 2; row-major index 12).
    let mut pixels = vec![0.0; 25];
    pixels[12] = 9.0;
    let input = Matrix::from_vec(5, 5, pixels).unwrap();
    // 3x3 box blur: every weight is 1/9. The kernel is symmetric, so the result does not
    // depend on whether convolve2d flips it.
    let kernel_val = 1.0 / 9.0;
    let kernel = Matrix::from_vec(3, 3, vec![kernel_val; 9]).unwrap();
    let result = input.convolve2d(&kernel).unwrap();
    assert_eq!(result.rows(), 3);
    assert_eq!(result.cols(), 3);
    // The window for output (r, c), with r and c in 0..=2, covers input rows r..=r+2 and
    // columns c..=c+2. Every window therefore contains the centre pixel and nothing else
    // non-zero, so every cell must be 9.0 * 1/9 = 1.0, not just the middle one.
    for r in 0..3 {
        for c in 0..3 {
            let value = result.get(r, c).unwrap();
            assert!(
                (value - 1.0).abs() < 1e-5,
                "output ({}, {}) = {}, expected 1.0",
                r,
                c,
                value
            );
        }
    }
}
#[test]
fn test_convolve2d_invalid_kernel() {
    // The 4x4 kernel is larger than the 3x3 input, so convolve2d must return an error.
    let input = Matrix::from_vec(3, 3, vec![1.0; 9]).unwrap();
    let oversized = Matrix::from_vec(4, 4, vec![1.0; 16]).unwrap();
    let outcome = input.convolve2d(&oversized);
    assert!(outcome.is_err());
}