use numrs::{Array, ops};
use numrs::backend::get_dispatch_table;
use anyhow::Result;
use std::time::Instant;
fn main() -> Result<()> {
println!("\n๐งช NumRs Operations Comprehensive Validation\n");
println!("{}", "=".repeat(60));
let table = get_dispatch_table();
println!("\n๐ Dispatch Table Configuration:");
println!(" - Elementwise: {}", table.elementwise_backend);
println!(" - Reduction: {}", table.reduction_backend);
println!(" - MatMul: {}", table.matmul_backend);
println!();
println!("{}", "=".repeat(60));
println!("\nโ
Test 1: Elementwise Operations\n");
test_elementwise()?;
println!("{}", "=".repeat(60));
println!("\nโ
Test 2: Reduction Operations\n");
test_reductions()?;
println!("{}", "=".repeat(60));
println!("\nโ
Test 3: MatMul Operations\n");
test_matmul()?;
println!("{}", "=".repeat(60));
println!("\nโ
Test 4: Size Scalability\n");
test_sizes()?;
println!("{}", "=".repeat(60));
println!("\nโ
Test 5: Edge Cases\n");
test_edge_cases()?;
println!("{}", "=".repeat(60));
println!("\nโ
All tests passed! ops validation complete.\n");
Ok(())
}
/// Validates the binary elementwise ops `add`, `mul`, `sub` and `div` on a
/// fixed 4-element input.
///
/// Exact `assert_eq!` comparisons are safe here: every input and expected
/// value (including 0.5 and 1.5) is exactly representable in binary floating
/// point, so no rounding occurs.
///
/// # Errors
/// Propagates any error returned by the `ops` calls.
///
/// # Panics
/// Panics if any result differs from its expected vector.
fn test_elementwise() -> Result<()> {
    let a = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
    let b = Array::new(vec![4], vec![2.0, 2.0, 2.0, 2.0]);

    let c = ops::add(&a, &b)?;
    assert_eq!(c.data, vec![3.0, 4.0, 5.0, 6.0]);
    println!(" ✓ add([1,2,3,4], [2,2,2,2]) = [3,4,5,6]");

    let c = ops::mul(&a, &b)?;
    assert_eq!(c.data, vec![2.0, 4.0, 6.0, 8.0]);
    println!(" ✓ mul([1,2,3,4], [2,2,2,2]) = [2,4,6,8]");

    let c = ops::sub(&a, &b)?;
    assert_eq!(c.data, vec![-1.0, 0.0, 1.0, 2.0]);
    println!(" ✓ sub([1,2,3,4], [2,2,2,2]) = [-1,0,1,2]");

    let c = ops::div(&a, &b)?;
    assert_eq!(c.data, vec![0.5, 1.0, 1.5, 2.0]);
    println!(" ✓ div([1,2,3,4], [2,2,2,2]) = [0.5,1,1.5,2]");

    Ok(())
}
/// Validates the full reduction `ops::sum(_, None)` on small and medium inputs.
///
/// The exact comparisons are order-independent. Every partial sum of these
/// inputs (integers up to 100, multiples of 0.5 up to 500) is exactly
/// representable, so any backend summation order gives the same result.
///
/// # Errors
/// Propagates any error returned by `ops::sum`.
///
/// # Panics
/// Panics if a sum differs from its expected value.
fn test_reductions() -> Result<()> {
    let a = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
    let s = ops::sum(&a, None)?;
    assert_eq!(s.data, vec![10.0]);
    println!(" ✓ sum([1,2,3,4]) = 10");

    let b = Array::new(vec![100], vec![1.0; 100]);
    let s = ops::sum(&b, None)?;
    assert_eq!(s.data[0], 100.0);
    println!(" ✓ sum(100 ones) = 100");

    let c = Array::new(vec![1000], vec![0.5; 1000]);
    let s = ops::sum(&c, None)?;
    assert_eq!(s.data[0], 500.0);
    println!(" ✓ sum(1000 halves) = 500");

    Ok(())
}
/// Validates `ops::matmul` on three cases:
/// - a hand-computed 2x2 product;
/// - multiplication by the 3x3 identity;
/// - a 10x10 all-ones product, where every entry should equal 10.
///
/// Data vectors are laid out row-major. The 2x2 expected values are only
/// correct under that interpretation.
///
/// # Errors
/// Propagates any error returned by `ops::matmul`.
///
/// # Panics
/// Panics if a result has an unexpected shape or values.
fn test_matmul() -> Result<()> {
    // [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]]
    let a = Array::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    let b = Array::new(vec![2, 2], vec![5.0, 6.0, 7.0, 8.0]);
    let c = ops::matmul(&a, &b)?;
    assert_eq!(c.shape, vec![2, 2]);
    assert_eq!(c.data, vec![19.0, 22.0, 43.0, 50.0]);
    println!(" ✓ matmul(2x2) = correct");

    let a = Array::new(vec![3, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    let identity = Array::new(vec![3, 3], vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    let c = ops::matmul(&a, &identity)?;
    assert_eq!(c.shape, vec![3, 3]);
    assert_eq!(c.data, a.data);
    println!(" ✓ matmul(A, I) = A");

    let n = 10;
    let a = Array::new(vec![n, n], vec![1.0; (n * n) as usize]);
    let b = Array::new(vec![n, n], vec![1.0; (n * n) as usize]);
    let c = ops::matmul(&a, &b)?;
    assert_eq!(c.shape, vec![n, n]);
    // Each output entry is a dot product of n ones with n ones, i.e. n = 10.
    assert!(c.data.iter().all(|&x| (x - 10.0).abs() < 1e-5));
    println!(" ✓ matmul(10x10 ones) = all 10s");

    Ok(())
}
/// Exercises `add` + `sum` and `matmul` across increasing sizes, verifying
/// results and printing single-run wall-clock timings.
///
/// Only the ops under test are timed. Array construction and result
/// verification fall outside the measured span. Timings come from one run
/// each, so treat them as rough indicators, not benchmarks.
///
/// # Errors
/// Propagates any error returned by the `ops` calls.
///
/// # Panics
/// Panics if any result has the wrong length or values.
fn test_sizes() -> Result<()> {
    for size in [4, 10, 100, 1000, 10_000] {
        let a = Array::new(vec![size], vec![1.0; size as usize]);
        let b = Array::new(vec![size], vec![2.0; size as usize]);

        let start = Instant::now();
        let c = ops::add(&a, &b)?;
        let s = ops::sum(&a, None)?;
        let elapsed = start.elapsed();

        assert_eq!(c.data.len(), size as usize);
        assert!(c.data.iter().all(|&x| (x - 3.0).abs() < 1e-5));
        assert!((s.data[0] - size as f32).abs() < 1e-3);
        println!(" ✓ Size {}: {} µs (add + sum)", size, elapsed.as_micros());
    }

    for n in [2, 10, 50, 100] {
        let a = Array::new(vec![n, n], vec![1.0; (n * n) as usize]);
        let b = Array::new(vec![n, n], vec![1.0; (n * n) as usize]);

        let start = Instant::now();
        let c = ops::matmul(&a, &b)?;
        let elapsed = start.elapsed();

        assert_eq!(c.data.len(), (n * n) as usize);
        // ones(n,n) x ones(n,n) has every entry equal to n. Compare in f64 so
        // the check does not depend on the element type of `c.data`.
        assert!(c.data.iter().all(|&x| (x as f64 - n as f64).abs() < 1e-3));

        // Named `flops` rather than `ops` so it does not shadow the `ops` module.
        let flops = 2.0 * (n as f64).powi(3);
        let secs = elapsed.as_secs_f64();
        if secs > 0.0 {
            let gflops = flops / (secs * 1e9);
            println!(
                " ✓ MatMul {}x{}: {} µs ({:.2} GFLOPS)",
                n,
                n,
                elapsed.as_micros(),
                gflops
            );
        } else {
            // A zero reading would otherwise produce an `inf` GFLOPS figure.
            println!(
                " ✓ MatMul {}x{}: {} µs (too fast to estimate GFLOPS)",
                n,
                n,
                elapsed.as_micros()
            );
        }
    }

    Ok(())
}
/// Validates boundary inputs for `add`, `sum` and `matmul`:
/// - single-element arrays;
/// - all-zero input;
/// - values that cancel to zero;
/// - large and small magnitudes (checked with relative tolerances);
/// - 1x1 and non-square matrix products.
///
/// # Errors
/// Propagates any error returned by the `ops` calls.
///
/// # Panics
/// Panics if any result has an unexpected shape or values.
fn test_edge_cases() -> Result<()> {
    let a = Array::new(vec![1], vec![5.0]);
    let b = Array::new(vec![1], vec![3.0]);
    let c = ops::add(&a, &b)?;
    assert_eq!(c.data, vec![8.0]);
    println!(" ✓ Single element: [5] + [3] = [8]");

    let a = Array::new(vec![10], vec![0.0; 10]);
    let s = ops::sum(&a, None)?;
    assert_eq!(s.data[0], 0.0);
    println!(" ✓ Sum of zeros = 0");

    let a = Array::new(vec![4], vec![-1.0, -2.0, -3.0, -4.0]);
    let b = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
    let c = ops::add(&a, &b)?;
    assert!(c.data.iter().all(|&x| x.abs() < 1e-5));
    println!(" ✓ Negative + positive = zeros");

    // Relative tolerance: absolute error scales with the magnitude of the sum.
    let a = Array::new(vec![4], vec![1e6, 2e6, 3e6, 4e6]);
    let s = ops::sum(&a, None)?;
    assert!((s.data[0] - 1e7).abs() / 1e7 < 1e-5);
    println!(" ✓ Large values: sum = 10M");

    let a = Array::new(vec![4], vec![1e-6, 2e-6, 3e-6, 4e-6]);
    let s = ops::sum(&a, None)?;
    assert!((s.data[0] - 1e-5).abs() / 1e-5 < 1e-3);
    println!(" ✓ Small values: sum = 1e-5");

    let a = Array::new(vec![1, 1], vec![7.0]);
    let b = Array::new(vec![1, 1], vec![3.0]);
    let c = ops::matmul(&a, &b)?;
    assert_eq!(c.shape, vec![1, 1]);
    assert_eq!(c.data, vec![21.0]);
    println!(" ✓ 1x1 matmul: [[7]] * [[3]] = [[21]]");

    // (2x3) x (3x2) -> 2x2. Row-major data: [[1,2,3],[4,5,6]] x [[7,8],[9,10],[11,12]].
    let a = Array::new(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    let b = Array::new(vec![3, 2], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
    let c = ops::matmul(&a, &b)?;
    assert_eq!(c.shape, vec![2, 2]);
    assert_eq!(c.data, vec![58.0, 64.0, 139.0, 154.0]);
    println!(" ✓ Non-square matmul: 2x3 * 3x2 = 2x2");

    Ok(())
}