use crate::Tensor;
use crate::test::helpers::*;
use ndarray::Array2;
use svod_dtype::DType;
// Codegen regression tests for matmul patterns whose matrix operand is either an
// already-realized buffer or a lazily built expression (unary ops, outer products,
// DFT-style cos/sin weights). Each test is run per backend `config` by
// `codegen_tests!`.
//
// Every matmul below multiplies an n×n matrix by an all-ones column vector, so each
// output element is simply the sum of one matrix row. That gives a cheap host-side
// reference to check numerical results against, not just the output length.
crate::codegen_tests! {
#[test_case(2 ; "N=2")]
#[test_case(4 ; "N=4")]
fn test_realized_matrix_matmul(config, n: usize) {
    test_setup();
    // Row-major matrix with M[i][j] = (i*n + j) * 0.1, backed by a concrete buffer.
    let data: Vec<f32> = (0..n * n).map(|i| i as f32 * 0.1).collect();
    // Reference: M · ones = per-row sums (computed before `data` is moved).
    let expected: Vec<f32> = data.chunks(n).map(|row| row.iter().sum::<f32>()).collect();
    let matrix = Tensor::from_ndarray(&Array2::from_shape_vec((n, n), data).unwrap());
    let x = Tensor::from_ndarray(&Array2::from_shape_vec((n, 1), vec![1.0f32; n]).unwrap());
    let mut out = matrix.dot(&x).unwrap().try_reshape([n as isize]).unwrap();
    out.realize_with(&config).expect("realized matrix matmul");
    let result = out.as_vec::<f32>().unwrap();
    assert_eq!(result.len(), n);
    assert_close_f32(&result, &expected, 1e-4);
}
#[test_case(2 ; "N=2")]
#[test_case(4 ; "N=4")]
fn test_unary_on_buffer_rooted_matmul(config, n: usize) {
    test_setup();
    let data: Vec<f32> = (0..n * n).map(|i| i as f32 * 0.1).collect();
    // Reference: per-row sums of cos(M[i][j]), matching the lazy `.cos()` below.
    let expected: Vec<f32> = data
        .chunks(n)
        .map(|row| row.iter().map(|v| v.cos()).sum::<f32>())
        .collect();
    // The unary op sits directly on top of a buffer-backed tensor, feeding the matmul.
    let matrix = Tensor::from_ndarray(&Array2::from_shape_vec((n, n), data).unwrap()).cos().unwrap();
    let x = Tensor::from_ndarray(&Array2::from_shape_vec((n, 1), vec![1.0f32; n]).unwrap());
    let mut out = matrix.dot(&x).unwrap().try_reshape([n as isize]).unwrap();
    out.realize_with(&config).expect("unary on buffer-rooted matmul");
    let result = out.as_vec::<f32>().unwrap();
    assert_eq!(result.len(), n);
    assert_close_f32(&result, &expected, 1e-4);
}
fn test_diamond_elementwise_no_matmul(config) {
    test_setup();
    let input = [1.0f32, 2.0, 3.0, 4.0];
    let t = Tensor::from_ndarray(&Array2::from_shape_vec((2, 2), input.to_vec()).unwrap());
    // Diamond: `t` fans out into two unary branches that are re-joined by an add.
    let mut out = t.cos().unwrap().try_add(&t.sin().unwrap()).unwrap();
    out.realize_with(&config).expect("diamond elementwise");
    let expected: Vec<f32> = input.iter().map(|x| x.cos() + x.sin()).collect();
    assert_close_f32(&out.as_vec::<f32>().unwrap(), &expected, 1e-5);
}
#[test_case(2 ; "N=2")]
#[test_case(4 ; "N=4")]
fn test_lazy_outer_product_matmul(config, n: usize) {
    test_setup();
    // NOTE: reference values assume `arange(n, None, None)` yields 0, 1, ..., n-1.
    let indices = Tensor::arange(n as i64, None, None).unwrap().cast(DType::Float32).unwrap();
    let k = indices.try_reshape([n as isize, 1]).unwrap();
    let j = indices.try_reshape([1, n as isize]).unwrap();
    // Broadcast multiply builds the lazy outer product M[k][j] = k * j.
    let matrix = k.try_mul(&j).unwrap();
    let x = Tensor::from_ndarray(&Array2::from_shape_vec((n, 1), vec![1.0f32; n]).unwrap());
    let mut out = matrix.dot(&x).unwrap().try_reshape([n as isize]).unwrap();
    out.realize_with(&config).expect("lazy outer product matmul");
    let expected: Vec<f32> = (0..n)
        .map(|kk| (0..n).map(|jj| (kk * jj) as f32).sum::<f32>())
        .collect();
    let result = out.as_vec::<f32>().unwrap();
    assert_eq!(result.len(), n);
    assert_close_f32(&result, &expected, 1e-5);
}
#[test_case(2 ; "N=2")]
#[test_case(4 ; "N=4")]
fn test_lazy_outer_product_unary_matmul(config, n: usize) {
    test_setup();
    // NOTE: reference values assume `arange(n, None, None)` yields 0, 1, ..., n-1.
    let indices = Tensor::arange(n as i64, None, None).unwrap().cast(DType::Float32).unwrap();
    let k = indices.try_reshape([n as isize, 1]).unwrap();
    let j = indices.try_reshape([1, n as isize]).unwrap();
    // Lazy outer product followed by a unary op: M[k][j] = cos(k * j).
    let matrix = k.try_mul(&j).unwrap().cos().unwrap();
    let x = Tensor::from_ndarray(&Array2::from_shape_vec((n, 1), vec![1.0f32; n]).unwrap());
    let mut out = matrix.dot(&x).unwrap().try_reshape([n as isize]).unwrap();
    out.realize_with(&config).expect("lazy outer product unary matmul");
    let expected: Vec<f32> = (0..n)
        .map(|kk| (0..n).map(|jj| ((kk * jj) as f32).cos()).sum::<f32>())
        .collect();
    let result = out.as_vec::<f32>().unwrap();
    assert_eq!(result.len(), n);
    assert_close_f32(&result, &expected, 1e-4);
}
#[test_case(2 ; "N=2")]
#[test_case(4 ; "N=4")]
fn test_dft_pattern(config, n: usize) {
    test_setup();
    // NOTE: reference values assume `arange(n, None, None)` yields 0, 1, ..., n-1.
    let indices = Tensor::arange(n as i64, None, None).unwrap().cast(DType::Float32).unwrap();
    let k = indices.try_reshape([n as isize, 1]).unwrap();
    let j = indices.try_reshape([1, n as isize]).unwrap();
    // DFT-like weights: angle[k][j] = -0.5 * k * j, shared by a cos and a sin branch.
    let angles = k.try_mul(&j).unwrap().try_mul(&Tensor::from_slice([-0.5f32])).unwrap();
    let cos_w = angles.cos().unwrap();
    let sin_w = angles.sin().unwrap();
    let x = Tensor::from_ndarray(&Array2::from_shape_vec((n, 1), vec![1.0f32; n]).unwrap());
    // Two matmuls reuse the same `x` and are then summed elementwise.
    let cos_proj = cos_w.dot(&x).unwrap();
    let sin_proj = sin_w.dot(&x).unwrap();
    let mut out = cos_proj.try_add(&sin_proj).unwrap().try_reshape([n as isize]).unwrap();
    out.realize_with(&config).expect("DFT pattern");
    let expected: Vec<f32> = (0..n)
        .map(|kk| {
            (0..n)
                .map(|jj| {
                    let angle = (kk * jj) as f32 * -0.5f32;
                    angle.cos() + angle.sin()
                })
                .sum::<f32>()
        })
        .collect();
    let result = out.as_vec::<f32>().unwrap();
    assert_eq!(result.len(), n);
    assert_close_f32(&result, &expected, 1e-4);
}
}