use super::*;
use vyre_foundation::ir::BufferAccess;
/// The `Real` semiring on a 2x2 product is ordinary integer matrix multiplication.
#[test]
fn cpu_real_2x2() {
    // [1 2]   [5 6]   [1*5+2*7  1*6+2*8]   [19 22]
    // [3 4] x [7 8] = [3*5+4*7  3*6+4*8] = [43 50]
    let lhs = vec![1, 2, 3, 4];
    let rhs = vec![5, 6, 7, 8];
    let expected = vec![19, 22, 43, 50];
    assert_eq!(semiring_gemm_cpu(&lhs, &rhs, 2, 2, 2, Semiring::Real), expected);
}
/// Right-multiplying by the 2x2 identity under `Real` returns the left operand unchanged.
#[test]
fn cpu_real_identity() {
    let identity = vec![1, 0, 0, 1];
    let m = vec![3, 5, 7, 11];
    let product = semiring_gemm_cpu(&m, &identity, 2, 2, 2, Semiring::Real);
    assert_eq!(product, m);
}
/// The `_into` variant writes into the caller's buffer in place: it keeps the same
/// allocation and drops stale elements beyond the 2x2 result.
#[test]
fn cpu_into_reuses_output_and_truncates_stale_tail() {
    let lhs = vec![3, 5, 7, 11];
    let identity = vec![1, 0, 0, 1];
    // Eight stale entries: twice what a 2x2 result needs, so the call must shrink `out`.
    let mut out = vec![99; 8];
    let original_buffer = out.as_ptr();

    try_semiring_gemm_cpu_into(&lhs, &identity, 2, 2, 2, Semiring::Real, &mut out).unwrap();

    assert_eq!(out, lhs);
    // Same pointer => no reallocation happened.
    assert_eq!(out.as_ptr(), original_buffer);
}
/// Cross-checks the CPU `Semiring::Real` path against an independent, naive
/// wrapping triple loop over 48 small shapes (m in 1..=4, n in 1..=5, k in 1..=6).
///
/// The output buffer is given spare capacity beyond `m * n`. The result must
/// still contain exactly `m * n` elements, not merely correct leading cells.
#[test]
fn generated_cpu_matches_independent_real_gemm() {
    for case in 0..48usize {
        let m = 1 + case % 4;
        let n = 1 + case % 5;
        let k = 1 + case % 6;
        let a: Vec<u32> = (0..m * k)
            .map(|idx| (idx as u32).wrapping_mul(3).wrapping_add(case as u32))
            .collect();
        let b: Vec<u32> = (0..k * n)
            .map(|idx| (idx as u32).wrapping_mul(5).wrapping_add(7))
            .collect();
        // Spare capacity on purpose; the length check below guards against stray tail elements.
        let mut c = Vec::with_capacity(m * n + 3);
        try_semiring_gemm_cpu_into(&a, &b, m as u32, n as u32, k as u32, Semiring::Real, &mut c)
            .unwrap();
        assert_eq!(c.len(), m * n, "case {case} output length");
        for i in 0..m {
            for j in 0..n {
                // Row-major dot product with wrapping u32 arithmetic.
                let expected = (0..k).fold(0u32, |acc, kk| {
                    acc.wrapping_add(a[i * k + kk].wrapping_mul(b[kk * n + j]))
                });
                assert_eq!(c[i * n + j], expected, "case {case} cell {i},{j}");
            }
        }
    }
}
/// One min-plus squaring step on the path graph 0 -5-> 1 -3-> 2.
///
/// `A (x) A` holds two-hop distances: 0 -> 2 costs 5 + 3 = 8. There is no
/// two-hop route 0 -> 1, because the diagonal is "no edge" and so there are no self-loops.
#[test]
fn cpu_min_plus_shortest_path_step() {
    // u32::MAX marks "no edge".
    let inf = u32::MAX;
    let a = vec![
        inf, 5, inf, // 0 -> 1 costs 5
        inf, inf, 3, // 1 -> 2 costs 3
        inf, inf, inf, // 2 has no outgoing edges
    ];
    let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::MinPlus);
    // Row-major accessor. This avoids literal `0 * 3 + col` indexing, which trips
    // clippy's deny-by-default `erasing_op` lint under `--all-targets`.
    let at = |row: usize, col: usize| c[row * 3 + col];
    assert_eq!(at(0, 2), 8);
    assert_eq!(at(0, 1), inf);
}
/// Combining `u32::MAX` ("infinite") weights under MinPlus must stay at
/// `u32::MAX` rather than wrapping around.
#[test]
fn cpu_min_plus_saturating_no_overflow() {
    let inf = u32::MAX;
    let all_inf = vec![inf; 4];
    let c = semiring_gemm_cpu(&all_inf, &all_inf, 2, 2, 2, Semiring::MinPlus);
    assert!(c.iter().all(|&v| v == inf));
}
/// BoolOr squaring yields two-hop reachability on the path graph 0 -> 1 -> 2.
#[test]
fn cpu_bool_or_reachability() {
    let a = vec![
        0, 1, 0, // 0 -> 1
        0, 0, 1, // 1 -> 2
        0, 0, 0, // 2 has no outgoing edges
    ];
    let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::BoolOr);
    // Row-major accessor; avoids clippy `erasing_op` on literal `0 * 3 + col`.
    let at = |row: usize, col: usize| c[row * 3 + col];
    // 2 is reachable from 0 in exactly two hops.
    assert_eq!(at(0, 2), 1);
    // 0 -> 1 is a one-hop edge only, so it is absent from A^2.
    assert_eq!(at(0, 1), 0);
}
/// Lineage squaring on the path 0 -f1-> 1 -f2-> 2: the two-hop entry 0 -> 2
/// carries the union of the fact bitmasks along the path.
#[test]
fn cpu_lineage_scallop_join() {
    let f1 = 0b01;
    let f2 = 0b10;
    let a = vec![
        0, f1, 0, // 0 -> 1 tagged f1
        0, 0, f2, // 1 -> 2 tagged f2
        0, 0, 0, // 2 has no outgoing edges
    ];
    let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::Lineage);
    // Row-major accessor; avoids clippy `erasing_op` on literal `0 * 3 + col`.
    let at = |row: usize, col: usize| c[row * 3 + col];
    assert_eq!(at(0, 2), f1 | f2, "lineage = union of facts along path");
    assert_eq!(at(0, 1), 0);
}
/// Two distinct two-hop paths 0 -> 3 (via 1 and via 2). The Lineage entry
/// 0 -> 3 must be the union of the facts from both paths.
#[test]
fn cpu_lineage_alternative_paths_union() {
    let f1 = 0b0001;
    let f2 = 0b0010;
    let f3 = 0b0100;
    let f4 = 0b1000;
    let a = vec![
        0, f1, f3, 0, // 0 -> 1 tagged f1, 0 -> 2 tagged f3
        0, 0, 0, f2, // 1 -> 3 tagged f2
        0, 0, 0, f4, // 2 -> 3 tagged f4
        0, 0, 0, 0, // 3 has no outgoing edges
    ];
    let c = semiring_gemm_cpu(&a, &a, 4, 4, 4, Semiring::Lineage);
    // Row-major accessor; avoids clippy `erasing_op` on literal `0 * 4 + col`.
    let at = |row: usize, col: usize| c[row * 4 + col];
    assert_eq!(at(0, 3), f1 | f2 | f3 | f4, "expected union over both paths");
}
/// MaxPlus squaring on the path 0 -5-> 1 -3-> 2: the best two-hop weight
/// 0 -> 2 is 5 + 3 = 8.
#[test]
fn cpu_max_plus_longest_path() {
    // 0 entries mean "no edge" here.
    let a = vec![
        0, 5, 0, // 0 -> 1 weight 5
        0, 0, 3, // 1 -> 2 weight 3
        0, 0, 0, // 2 has no outgoing edges
    ];
    let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::MaxPlus);
    // Row-major accessor; avoids clippy `erasing_op` on literal `0 * 3 + col`.
    let at = |row: usize, col: usize| c[row * 3 + col];
    assert_eq!(at(0, 2), 8);
}
/// GF(2) product: arithmetic mod 2, so products of ones cancel pairwise.
#[test]
fn cpu_gf2_xor_closure() {
    // [1 0]   [1 1]   [1 1]
    // [1 1] x [0 1] = [1 0]   (bottom-right: 1*1 + 1*1 = 0 mod 2)
    let lhs = vec![1, 0, 1, 1];
    let rhs = vec![1, 1, 0, 1];
    let expected = vec![1, 1, 1, 0];
    assert_eq!(semiring_gemm_cpu(&lhs, &rhs, 2, 2, 2, Semiring::Gf2), expected);
}
/// MaxTimes keeps the largest product over intermediate states, as in a Viterbi step.
#[test]
fn cpu_max_times_viterbi() {
    // 1x2 row vector times a 2x2 matrix:
    //   col 0: max(50*60, 50*30) = 3000
    //   col 1: max(50*40, 50*70) = 3500
    let start = vec![50, 50];
    let transitions = vec![60, 40, 30, 70];
    let best = semiring_gemm_cpu(&start, &transitions, 1, 2, 2, Semiring::MaxTimes);
    assert_eq!(best, vec![3000, 3500]);
}
/// The emitted program uses a `[256, 1, 1]` workgroup. It declares buffers A, B, C
/// in that order, sized m*k, k*n and m*n for `(m, n, k) = (4, 5, 3)`.
#[test]
fn emitted_program_buffer_layout() {
    let program = semiring_gemm("A", "B", "C", 4, 5, 3, Semiring::Real);
    assert_eq!(program.workgroup_size, [256, 1, 1]);

    let names: Vec<&str> = program.buffers.iter().map(|buf| buf.name()).collect();
    assert_eq!(names, vec!["A", "B", "C"]);

    // A: m*k
    assert_eq!(program.buffers[0].count(), 4 * 3);
    // B: k*n
    assert_eq!(program.buffers[1].count(), 3 * 5);
    // C: m*n
    assert_eq!(program.buffers[2].count(), 4 * 5);
}
/// Inputs A and B are bound read-only; only the output C is read-write.
#[test]
fn emitted_program_buffer_access_modes() {
    let program = semiring_gemm("A", "B", "C", 2, 2, 2, Semiring::MinPlus);
    let access = |idx: usize| program.buffers[idx].access();
    assert_eq!(access(0), BufferAccess::ReadOnly);
    assert_eq!(access(1), BufferAccess::ReadOnly);
    assert_eq!(access(2), BufferAccess::ReadWrite);
}
/// A zero `m` dimension makes the emitted program's stats report a trap.
#[test]
fn zero_m_traps() {
    let program = semiring_gemm("A", "B", "C", 0, 1, 1, Semiring::Real);
    let stats = program.stats();
    assert!(stats.trap());
}
/// A zero `n` dimension makes the emitted program's stats report a trap.
#[test]
fn zero_n_traps() {
    let program = semiring_gemm("A", "B", "C", 1, 0, 1, Semiring::Real);
    let stats = program.stats();
    assert!(stats.trap());
}
/// A zero `k` (inner) dimension makes the emitted program's stats report a trap.
#[test]
fn zero_k_traps() {
    let program = semiring_gemm("A", "B", "C", 1, 1, 0, Semiring::Real);
    let stats = program.stats();
    assert!(stats.trap());
}
/// Pins the value returned by `Semiring::identity()` for every variant.
#[test]
fn identity_table_matches_doc() {
    let table = [
        (Semiring::Real, 0),
        (Semiring::MinPlus, u32::MAX),
        (Semiring::MaxPlus, 0),
        (Semiring::MaxTimes, 0),
        (Semiring::BoolOr, 0),
        (Semiring::BoolAnd, u32::MAX),
        (Semiring::Gf2, 0),
        (Semiring::Lineage, 0),
    ];
    for (semiring, expected) in table {
        assert_eq!(semiring.identity(), expected);
    }
}