use super::*;
use crate::Vector;
/// Writing through `get_mut(0, 1)` must update exactly that cell and no other.
#[test]
fn test_get_mut_valid() {
    let mut m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    // `expect` rather than `if let`: a `None` for an in-bounds index is itself the failure.
    *m.get_mut(0, 1).expect("(0, 1) is inside a 2x3 matrix") = 99.0;
    assert_eq!(m.get(0, 1), Some(&99.0));
    // Neighbouring cells must be untouched (catches a wrong row/column stride).
    assert_eq!(m.get(0, 0), Some(&1.0));
    assert_eq!(m.get(0, 2), Some(&3.0));
    assert_eq!(m.get(1, 1), Some(&5.0));
}
/// `get_mut` must reject any row >= rows or column >= cols, including the exact boundary.
#[test]
fn test_get_mut_out_of_bounds() {
    let mut m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    // Far outside on either or both axes.
    assert!(m.get_mut(5, 0).is_none());
    assert!(m.get_mut(0, 10).is_none());
    assert!(m.get_mut(10, 10).is_none());
    // Exactly one past the end on each axis: the off-by-one boundary.
    // `(0, 3)` is the subtle one: its naive flat index (0 * 3 + 3 = 3) is still inside
    // the 6-element storage, so only a per-axis bounds check rejects it.
    assert!(m.get_mut(2, 0).is_none());
    assert!(m.get_mut(0, 3).is_none());
    // The last valid cell stays reachable.
    assert!(m.get_mut(1, 2).is_some());
}
/// `zeros(rows, cols)` yields the requested shape with every element equal to 0.
#[test]
fn test_matrix_zeros_coverage() {
    let m: Matrix<f32> = Matrix::zeros(3, 4);
    assert_eq!(m.rows(), 3);
    assert_eq!(m.cols(), 4);
    // Check the element count first: an "all zero" scan passes vacuously on short storage.
    assert_eq!(m.as_slice().len(), 3 * 4);
    assert!(m.as_slice().iter().all(|&v| v == 0.0));
}
/// `identity(3)` is a 3x3 matrix with 1 on the diagonal and 0 everywhere else.
#[test]
fn test_matrix_identity_coverage() {
    let m: Matrix<f32> = Matrix::identity(3);
    assert_eq!(m.rows(), 3);
    assert_eq!(m.cols(), 3);
    // Check every cell, not just a sample of the off-diagonals.
    for r in 0..3 {
        for c in 0..3 {
            let want = if r == c { 1.0 } else { 0.0 };
            assert_eq!(m.get(r, c), Some(&want), "identity mismatch at ({}, {})", r, c);
        }
    }
}
/// `get` must return `None` for any row >= rows or column >= cols, including the exact boundary.
#[test]
fn test_get_out_of_bounds_coverage() {
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    // Far outside.
    assert!(m.get(5, 0).is_none());
    assert!(m.get(0, 10).is_none());
    // Exactly one past the end. `(0, 3)` maps to flat index 3, which is inside the
    // storage, so a flat-index-only check would wrongly return element (1, 0).
    assert!(m.get(2, 0).is_none());
    assert!(m.get(0, 3).is_none());
    // The last valid cell still resolves to the right value.
    assert_eq!(m.get(1, 2), Some(&6.0));
}
/// `pad(((top, bottom), (left, right)), value)` grows the matrix by the given margins,
/// fills them with `value`, and keeps the original data at offset `(top, left)`.
#[test]
fn test_pad_kitchen_sink() {
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

    // Symmetric one-cell zero border: 2x3 -> 4x5, data shifted by (1, 1).
    let padded = m.pad(((1, 1), (1, 1)), 0.0).unwrap();
    assert_eq!(padded.rows(), 4);
    assert_eq!(padded.cols(), 5);
    assert_eq!(padded.get(0, 0), Some(&0.0));
    assert_eq!(padded.get(1, 1), Some(&1.0));
    // Last data element and the trailing (bottom-right) border.
    assert_eq!(padded.get(2, 3), Some(&6.0));
    assert_eq!(padded.get(3, 4), Some(&0.0));

    // Asymmetric: two rows above, three columns to the right, filled with 9.
    let padded2 = m.pad(((2, 0), (0, 3)), 9.0).unwrap();
    assert_eq!(padded2.rows(), 4);
    assert_eq!(padded2.cols(), 6);
    assert_eq!(padded2.get(0, 0), Some(&9.0));
    assert_eq!(padded2.get(2, 0), Some(&1.0));
    assert_eq!(padded2.get(3, 2), Some(&6.0));
    // First right-hand padding column.
    assert_eq!(padded2.get(3, 3), Some(&9.0));

    // No padding is a no-op: same shape and identical contents.
    let padded3 = m.pad(((0, 0), (0, 0)), 0.0).unwrap();
    assert_eq!(padded3.rows(), 2);
    assert_eq!(padded3.cols(), 3);
    assert_eq!(padded3.as_slice(), m.as_slice());
}
/// `gather(indices, axis)` selects whole rows (axis 0) or whole columns (axis 1) in the
/// order given.
#[test]
fn test_gather_kitchen_sink() {
    // 3x4 grid holding 0..12 in row-major order, so element (r, c) == 4 * r + c.
    let m = Matrix::from_vec(3, 4, (0..12).map(|x| x as f32).collect()).unwrap();

    let rows = m.gather(&[0, 2], 0).unwrap();
    assert_eq!(rows.rows(), 2);
    assert_eq!(rows.cols(), 4);
    assert_eq!(rows.get(0, 0), Some(&0.0));
    assert_eq!(rows.get(1, 0), Some(&8.0));
    // Last column, to confirm whole rows are copied.
    assert_eq!(rows.get(0, 3), Some(&3.0));
    assert_eq!(rows.get(1, 3), Some(&11.0));

    let cols = m.gather(&[1, 3], 1).unwrap();
    assert_eq!(cols.rows(), 3);
    assert_eq!(cols.cols(), 2);
    // Output column j is source column [1, 3][j].
    assert_eq!(cols.get(0, 0), Some(&1.0));
    assert_eq!(cols.get(0, 1), Some(&3.0));
    assert_eq!(cols.get(2, 0), Some(&9.0));
    assert_eq!(cols.get(2, 1), Some(&11.0));
}
/// Runs `topk` over a 2x4 matrix (8 values) for k = 1, a partial k, every element,
/// and k = 0, checking how many values and indices come back.
#[test]
fn test_topk_kitchen_sink() {
    let data = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
    let m = Matrix::from_vec(2, 4, data).unwrap();

    // k = 1: a single pair, and the value is the global maximum.
    let (best_vals, best_idxs) = m.topk(1).unwrap();
    assert_eq!((best_vals.len(), best_idxs.len()), (1, 1));
    assert_eq!(best_vals[0], 9.0);

    // k = 3: a partial selection.
    let (some_vals, some_idxs) = m.topk(3).unwrap();
    assert_eq!((some_vals.len(), some_idxs.len()), (3, 3));

    // k = 8: every element of the matrix.
    let all_vals = m.topk(8).unwrap().0;
    assert_eq!(all_vals.len(), 8);

    // k = 0: both outputs are empty.
    let (no_vals, no_idxs) = m.topk(0).unwrap();
    assert_eq!((no_vals.len(), no_idxs.len()), (0, 0));
}
/// 2x2 max/avg pooling with stride 2 yields both the right output shape (partial trailing
/// windows dropped) and the right per-window values.
#[test]
fn test_pooling_kitchen_sink() {
    // 4x4 grid holding 0..16 row-major; 2x2/stride-2 windows tile it exactly.
    let m = Matrix::from_vec(4, 4, (0..16).map(|x| x as f32).collect()).unwrap();

    let max_pooled = m.max_pool2d((2, 2), (2, 2)).unwrap();
    assert_eq!(max_pooled.rows(), 2);
    assert_eq!(max_pooled.cols(), 2);
    // On an increasing grid, each window's max is its bottom-right element.
    let expected_max = [[5.0, 7.0], [13.0, 15.0]];
    for (r, row) in expected_max.iter().enumerate() {
        for (c, &want) in row.iter().enumerate() {
            assert_eq!(max_pooled.get(r, c), Some(&want), "max_pool at ({}, {})", r, c);
        }
    }

    let avg_pooled = m.avg_pool2d((2, 2), (2, 2)).unwrap();
    assert_eq!(avg_pooled.rows(), 2);
    assert_eq!(avg_pooled.cols(), 2);
    // e.g. top-left window {0, 1, 4, 5} averages to 2.5.
    let expected_avg = [[2.5, 4.5], [10.5, 12.5]];
    for (r, row) in expected_avg.iter().enumerate() {
        for (c, &want) in row.iter().enumerate() {
            let got = *avg_pooled.get(r, c).unwrap();
            assert!((got - want).abs() < 1e-5, "avg_pool at ({}, {}): {}", r, c, got);
        }
    }

    // 5x5 input: the last row/column does not fit a full window, so the output is 2x2.
    let m2 = Matrix::from_vec(5, 5, (0..25).map(|x| x as f32).collect()).unwrap();
    let max_pooled2 = m2.max_pool2d((2, 2), (2, 2)).unwrap();
    assert_eq!(max_pooled2.rows(), 2);
    assert_eq!(max_pooled2.cols(), 2);
    let expected_max2 = [[6.0, 8.0], [16.0, 18.0]];
    for (r, row) in expected_max2.iter().enumerate() {
        for (c, &want) in row.iter().enumerate() {
            assert_eq!(max_pooled2.get(r, c), Some(&want), "max_pool 5x5 at ({}, {})", r, c);
        }
    }
}
/// Row vector [1, 2, 3] times a 3x4 matrix: each output entry is the dot product of the
/// vector with one column, e.g. 1*1 + 2*5 + 3*9 = 38 for the first.
#[test]
fn test_vecmat_kitchen_sink() {
    let lhs = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let rhs_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
    let rhs = Matrix::from_vec(3, 4, rhs_data).unwrap();

    let product = Matrix::vecmat(&lhs, &rhs).unwrap();
    assert_eq!(product.len(), 4);

    let expected = [38.0, 44.0, 50.0, 56.0];
    for (&got, want) in product.as_slice().iter().zip(expected) {
        assert!((got - want).abs() < 1e-5);
    }
}
/// A 3x3 kernel with a single centre 1 reproduces the input's interior under "valid"
/// convolution. The kernel is symmetric, so flipped and unflipped conventions agree.
#[test]
fn test_convolve2d_kitchen_sink() {
    let input = Matrix::from_vec(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();
    let kernel = Matrix::from_vec(3, 3, vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]).unwrap();
    let result = input.convolve2d(&kernel).unwrap();
    assert_eq!(result.rows(), 1);
    assert_eq!(result.cols(), 1);
    assert!((result.get(0, 0).unwrap() - 5.0).abs() < 1e-5);

    // 5x5 grid 0..25 row-major: the 3x3 output is the interior, so output (r, c) equals
    // input (r + 1, c + 1) = 5 * (r + 1) + (c + 1).
    let input5 = Matrix::from_vec(5, 5, (0..25).map(|x| x as f32).collect()).unwrap();
    let result5 = input5.convolve2d(&kernel).unwrap();
    assert_eq!(result5.rows(), 3);
    assert_eq!(result5.cols(), 3);
    for r in 0..3 {
        for c in 0..3 {
            let want = ((r + 1) * 5 + (c + 1)) as f32;
            let got = *result5.get(r, c).unwrap();
            assert!((got - want).abs() < 1e-5, "conv at ({}, {}): {} != {}", r, c, got, want);
        }
    }
}
/// `embedding_lookup` returns one output row per id, copied from the matching table row.
/// The sparse variant is only checked for shape here.
#[test]
fn test_embedding_kitchen_sink() {
    // 5x4 table holding 0..20 row-major, so entry (r, c) == 4 * r + c.
    let embeddings = Matrix::from_vec(5, 4, (0..20).map(|x| x as f32).collect()).unwrap();

    let result = embeddings.embedding_lookup(&[0]).unwrap();
    assert_eq!(result.rows(), 1);
    assert_eq!(result.cols(), 4);
    assert_eq!(result.get(0, 0), Some(&0.0));
    assert_eq!(result.get(0, 3), Some(&3.0));

    let result2 = embeddings.embedding_lookup(&[0, 2, 4]).unwrap();
    assert_eq!(result2.rows(), 3);
    assert_eq!(result2.cols(), 4);
    // Output row i must be table row [0, 2, 4][i], in order.
    for (out_row, &src_row) in [0, 2, 4].iter().enumerate() {
        for c in 0..4 {
            let want = embeddings.get(src_row, c).expect("source cell is in bounds");
            assert_eq!(result2.get(out_row, c), Some(want), "row {} col {}", out_row, c);
        }
    }

    let (result3_matrix, result3_indices) = embeddings.embedding_lookup_sparse(&[0, 1, 2]).unwrap();
    assert_eq!(result3_matrix.rows(), 3);
    assert_eq!(result3_matrix.cols(), 4);
    assert_eq!(result3_indices.len(), 3);
}
/// `batched_matmul_4d` multiplies each of the `batch * heads` (m x k)·(k x n) pairs.
/// The result is checked against a naive triple loop.
#[test]
fn test_batched_matmul_4d_kitchen_sink() {
    let batch = 2;
    let heads = 2;
    let m = 3;
    let k = 4;
    let n = 5;
    let groups = batch * heads;
    // Fixture sizes are derived from the dimensions so they cannot drift out of sync.
    let a_data: Vec<f32> = (0..groups * m * k).map(|x| x as f32 * 0.1).collect();
    let b_data: Vec<f32> = (0..groups * k * n).map(|x| x as f32 * 0.1).collect();
    let result = Matrix::batched_matmul_4d(&a_data, &b_data, batch, heads, m, k, n).unwrap();
    assert_eq!(result.len(), groups * m * n);

    // Reference, assuming each (batch, head) slice is stored contiguously and row-major.
    // NOTE(review): layout inferred from the argument order; confirm against the impl.
    for g in 0..groups {
        let a = &a_data[g * m * k..(g + 1) * m * k];
        let b = &b_data[g * k * n..(g + 1) * k * n];
        for i in 0..m {
            for j in 0..n {
                let want: f32 = (0..k).map(|p| a[i * k + p] * b[p * n + j]).sum();
                let got = result[g * m * n + i * n + j];
                // Relative tolerance: the kernel may sum in a different order than this loop.
                assert!(
                    (got - want).abs() <= 1e-4 * want.abs().max(1.0),
                    "group {} ({}, {}): {} != {}",
                    g,
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }
}
/// All-ones `m x k` times `k x n` must give an `m x n` matrix filled with `k`.
/// Odd sizes force any blocked/SIMD kernel through its remainder (tail) paths.
#[test]
fn test_matmul_remainder_kitchen_sink() {
    for m in [1, 3, 5, 7, 9, 13, 17] {
        for k in [1, 3, 5, 7, 9, 15] {
            for n in [1, 3, 5, 7, 11] {
                let a = Matrix::from_vec(m, k, vec![1.0; m * k]).unwrap();
                let b = Matrix::from_vec(k, n, vec![1.0; k * n]).unwrap();
                let c = a.matmul(&b).unwrap();
                assert_eq!(c.rows(), m);
                assert_eq!(c.cols(), n);
                assert_eq!(c.as_slice().len(), m * n);
                // Check every element: tail rows/columns are where remainder bugs
                // show up, and (0, 0) is never in a tail.
                let want = k as f32;
                for (idx, &got) in c.as_slice().iter().enumerate() {
                    assert!(
                        (got - want).abs() < 1e-4,
                        "m={} k={} n={} flat index {}: {} != {}",
                        m,
                        k,
                        n,
                        idx,
                        got,
                        want
                    );
                }
            }
        }
    }
}
/// `transpose` swaps the shape and moves element (r, c) to (c, r). Covered for row
/// vectors, column vectors, a rectangular matrix, and a 1x1 matrix.
#[test]
fn test_transpose_kitchen_sink() {
    let row = Matrix::from_vec(1, 5, vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
    let col = row.transpose();
    assert_eq!(col.rows(), 5);
    assert_eq!(col.cols(), 1);
    for i in 0..5 {
        let want = row.get(0, i).expect("in bounds of 1x5");
        assert_eq!(col.get(i, 0), Some(want));
    }

    let col2 = Matrix::from_vec(5, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
    let row2 = col2.transpose();
    assert_eq!(row2.rows(), 1);
    assert_eq!(row2.cols(), 5);
    for i in 0..5 {
        let want = col2.get(i, 0).expect("in bounds of 5x1");
        assert_eq!(row2.get(0, i), Some(want));
    }

    // Non-degenerate case: vectors alone cannot catch a stride mix-up in (r, c) -> (c, r).
    let rect = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let rect_t = rect.transpose();
    assert_eq!(rect_t.rows(), 3);
    assert_eq!(rect_t.cols(), 2);
    for r in 0..2 {
        for c in 0..3 {
            let want = rect.get(r, c).expect("in bounds of 2x3");
            assert_eq!(rect_t.get(c, r), Some(want), "element ({}, {})", r, c);
        }
    }

    let single = Matrix::from_vec(1, 1, vec![42.0]).unwrap();
    let single_t = single.transpose();
    assert_eq!(single_t.get(0, 0), Some(&42.0));
}