use numrs::array::Array;
use numrs::backend::{get_dispatch_table, validate_backends};
use numrs::llo::reduction::ReductionKind;
#[test]
fn test_dispatch_table_initialization() {
    // The global dispatch table must report a non-empty backend name for
    // every operation family once it has been initialized.
    let dispatch = get_dispatch_table();
    for backend_name in [
        &dispatch.elementwise_backend,
        &dispatch.reduction_backend,
        &dispatch.matmul_backend,
    ] {
        assert!(!backend_name.is_empty());
    }
}
#[test]
fn test_backend_validation() {
    // At least one compute backend must be usable on every supported target.
    let report = validate_backends();
    let any_backend = report.simd_available
        || report.blas_available
        || report.webgpu_available
        || report.gpu_available;
    assert!(any_backend);
    // When the binary itself is compiled for x86_64 with AVX2 enabled,
    // the SIMD backend is mandatory, not optional.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    assert!(report.simd_available);
}
#[test]
fn test_dispatch_matmul() {
    // Multiplying by the 2x2 identity must return the left operand unchanged.
    let dispatch = get_dispatch_table();
    let lhs = Array::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    let identity = Array::new(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]);
    let product = (dispatch.matmul)(&lhs, &identity).expect("matmul should succeed");
    assert_eq!(product.data, vec![1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_dispatch_elementwise() {
    use numrs::llo::ElementwiseKind;
    let dispatch = get_dispatch_table();
    let lhs = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
    let ones = Array::new(vec![4], vec![1.0, 1.0, 1.0, 1.0]);
    // Adding a vector of ones shifts every element up by one.
    let sum = (dispatch.elementwise)(&lhs, &ones, ElementwiseKind::Add)
        .expect("elementwise add should succeed");
    assert_eq!(sum.data, vec![2.0, 3.0, 4.0, 5.0]);
    // Multiplying by ones leaves the left operand unchanged.
    let product = (dispatch.elementwise)(&lhs, &ones, ElementwiseKind::Mul)
        .expect("elementwise mul should succeed");
    assert_eq!(product.data, vec![1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_dispatch_reduction() {
    // Full-array sum of 1..=4 must be 10, within float tolerance.
    let dispatch = get_dispatch_table();
    let input = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
    let reduced = (dispatch.reduction)(&input, None, ReductionKind::Sum)
        .expect("reduction should succeed");
    let error = (reduced.data[0] - 10.0).abs();
    assert!(error < 0.001);
}
#[test]
fn test_fast_path_api() {
    use numrs::ops;
    let lhs = Array::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    // Element-wise addition through the fast-path `ops` facade.
    let twos = Array::new(vec![2, 2], vec![2.0, 2.0, 2.0, 2.0]);
    let sum = ops::add(&lhs, &twos).expect("ops add should work");
    assert_eq!(sum.to_f32().data, vec![3.0, 4.0, 5.0, 6.0]);
    // Identity matmul through the same facade must round-trip the input.
    let identity = Array::new(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]);
    let product = ops::matmul(&lhs, &identity).expect("ops matmul should work");
    assert_eq!(product.to_f32().data, vec![1.0, 2.0, 3.0, 4.0]);
}
// Only compiled when the build script detected a linked BLAS provider
// (custom cfg flag emitted at build time).
#[test]
#[cfg(numrs_has_blas)]
fn test_blas_validation() {
    let report = validate_backends();
    assert!(report.blas_available, "BLAS should be available when compiled with MKL/BLIS/Accelerate");
    assert!(report.blas_validated, "BLAS should validate successfully");
}
#[test]
fn test_dispatch_zero_cost() {
    use numrs::llo::ElementwiseKind;
    use std::time::Instant;
    let table = get_dispatch_table();
    let a = Array::new(vec![100], vec![1.0; 100]);
    let b = Array::new(vec![100], vec![2.0; 100]);
    // Warm-up: trigger any lazy initialization before timing. Unlike the
    // original `let _ = ...`, we surface backend failures instead of
    // silently discarding the Result — a failing backend would otherwise
    // make the timing below meaningless while the test still "passed".
    for _ in 0..10 {
        (table.elementwise)(&a, &b, ElementwiseKind::Add)
            .expect("elementwise add should succeed during warm-up");
    }
    let start = Instant::now();
    for _ in 0..1000 {
        (table.elementwise)(&a, &b, ElementwiseKind::Add)
            .expect("elementwise add should succeed while timing");
    }
    let dispatch_time = start.elapsed();
    // Generous bound: 1000 adds over 100 elements should land far below
    // 100ms when dispatch is a plain indirect call with no per-call
    // lookup or allocation overhead.
    assert!(
        dispatch_time.as_millis() < 100,
        "Dispatch should be fast (zero-cost), took {:?}",
        dispatch_time
    );
}