use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
const OP_ID: &str = "vyre-libs::math::linalg::matmul_strassen_2x2";
#[must_use]
pub fn matmul_strassen_2x2(a: &str, b: &str, c: &str) -> Program {
let body = vec![
Node::let_bind("a00", Expr::load(a, Expr::u32(0))),
Node::let_bind("a01", Expr::load(a, Expr::u32(1))),
Node::let_bind("a10", Expr::load(a, Expr::u32(2))),
Node::let_bind("a11", Expr::load(a, Expr::u32(3))),
Node::let_bind("b00", Expr::load(b, Expr::u32(0))),
Node::let_bind("b01", Expr::load(b, Expr::u32(1))),
Node::let_bind("b10", Expr::load(b, Expr::u32(2))),
Node::let_bind("b11", Expr::load(b, Expr::u32(3))),
Node::let_bind(
"m1",
Expr::mul(
Expr::add(Expr::var("a00"), Expr::var("a11")),
Expr::add(Expr::var("b00"), Expr::var("b11")),
),
),
Node::let_bind(
"m2",
Expr::mul(
Expr::add(Expr::var("a10"), Expr::var("a11")),
Expr::var("b00"),
),
),
Node::let_bind(
"m3",
Expr::mul(
Expr::var("a00"),
Expr::sub(Expr::var("b01"), Expr::var("b11")),
),
),
Node::let_bind(
"m4",
Expr::mul(
Expr::var("a11"),
Expr::sub(Expr::var("b10"), Expr::var("b00")),
),
),
Node::let_bind(
"m5",
Expr::mul(
Expr::add(Expr::var("a00"), Expr::var("a01")),
Expr::var("b11"),
),
),
Node::let_bind(
"m6",
Expr::mul(
Expr::sub(Expr::var("a10"), Expr::var("a00")),
Expr::add(Expr::var("b00"), Expr::var("b01")),
),
),
Node::let_bind(
"m7",
Expr::mul(
Expr::sub(Expr::var("a01"), Expr::var("a11")),
Expr::add(Expr::var("b10"), Expr::var("b11")),
),
),
Node::Store {
buffer: c.into(),
index: Expr::u32(0),
value: Expr::add(
Expr::sub(Expr::add(Expr::var("m1"), Expr::var("m4")), Expr::var("m5")),
Expr::var("m7"),
),
},
Node::Store {
buffer: c.into(),
index: Expr::u32(1),
value: Expr::add(Expr::var("m3"), Expr::var("m5")),
},
Node::Store {
buffer: c.into(),
index: Expr::u32(2),
value: Expr::add(Expr::var("m2"), Expr::var("m4")),
},
Node::Store {
buffer: c.into(),
index: Expr::u32(3),
value: Expr::add(
Expr::add(Expr::sub(Expr::var("m1"), Expr::var("m2")), Expr::var("m3")),
Expr::var("m6"),
),
},
];
Program::wrapped(
vec![
BufferDecl::storage(a, 0, BufferAccess::ReadOnly, DataType::F32).with_count(4),
BufferDecl::storage(b, 1, BufferAccess::ReadOnly, DataType::F32).with_count(4),
BufferDecl::output(c, 2, DataType::F32).with_count(4),
],
[1, 1, 1],
vec![wrap_anonymous(OP_ID, body)],
)
}
// Registers the 2x2 Strassen op with the crate harness, together with one
// input/output fixture.
inventory::submit! {
crate::harness::OpEntry {
id: OP_ID,
// The harness builds the program with the fixed buffer names "a", "b", "c".
build: || matmul_strassen_2x2("a", "b", "c"),
test_inputs: Some(|| {
// Fixture matrices, read row-major: A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]].
let a = crate::test_support::byte_pack::f32_bytes(&[1.0, 2.0, 3.0, 4.0]);
let b = crate::test_support::byte_pack::f32_bytes(&[5.0, 6.0, 7.0, 8.0]);
vec![vec![a, b]]
}),
expected_output: Some(|| {
// A * B = [[1*5 + 2*7, 1*6 + 2*8], [3*5 + 4*7, 3*6 + 4*8]] = [[19, 22], [43, 50]].
vec![vec![crate::test_support::byte_pack::f32_bytes(&[19.0, 22.0, 43.0, 50.0])]]
}),
category: Some("math"),
}
}
pub fn matmul_strassen_one_level(a: &str, b: &str, c: &str, n: u32) -> Result<Program, String> {
if n == 0 {
return Err("Fix: matmul_strassen_one_level n=0 is invalid".to_string());
}
if n % 2 != 0 {
return Err(format!(
"Fix: matmul_strassen_one_level requires even n; got n={n}. Use matmul or pad."
));
}
let half = n / 2;
let total = n
.checked_mul(n)
.ok_or_else(|| "Fix: matmul_strassen_one_level n*n overflows u32; reduce n.".to_string())?;
let body = vec![
Node::let_bind("flat", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(Expr::var("flat"), Expr::u32(total)),
vec![
Node::let_bind("row", Expr::div(Expr::var("flat"), Expr::u32(n))),
Node::let_bind("col", Expr::rem(Expr::var("flat"), Expr::u32(n))),
Node::let_bind("q_row", Expr::div(Expr::var("row"), Expr::u32(half))),
Node::let_bind("q_col", Expr::div(Expr::var("col"), Expr::u32(half))),
Node::let_bind("sr", Expr::rem(Expr::var("row"), Expr::u32(half))),
Node::let_bind("sc", Expr::rem(Expr::var("col"), Expr::u32(half))),
Node::let_bind("c_val", Expr::f32(0.0)),
Node::let_bind("m1", Expr::f32(0.0)),
Node::let_bind("m2", Expr::f32(0.0)),
Node::let_bind("m3", Expr::f32(0.0)),
Node::let_bind("m4", Expr::f32(0.0)),
Node::let_bind("m5", Expr::f32(0.0)),
Node::let_bind("m6", Expr::f32(0.0)),
Node::let_bind("m7", Expr::f32(0.0)),
Node::loop_for(
"k",
Expr::u32(0),
Expr::u32(half),
vec![
Node::let_bind(
"a11",
Expr::load(
a,
Expr::add(Expr::mul(Expr::var("sr"), Expr::u32(n)), Expr::var("k")),
),
),
Node::let_bind(
"a12",
Expr::load(
a,
Expr::add(
Expr::mul(Expr::var("sr"), Expr::u32(n)),
Expr::add(Expr::u32(half), Expr::var("k")),
),
),
),
Node::let_bind(
"a21",
Expr::load(
a,
Expr::add(
Expr::mul(
Expr::add(Expr::u32(half), Expr::var("sr")),
Expr::u32(n),
),
Expr::var("k"),
),
),
),
Node::let_bind(
"a22",
Expr::load(
a,
Expr::add(
Expr::mul(
Expr::add(Expr::u32(half), Expr::var("sr")),
Expr::u32(n),
),
Expr::add(Expr::u32(half), Expr::var("k")),
),
),
),
Node::let_bind(
"b11",
Expr::load(
b,
Expr::add(Expr::mul(Expr::var("k"), Expr::u32(n)), Expr::var("sc")),
),
),
Node::let_bind(
"b12",
Expr::load(
b,
Expr::add(
Expr::mul(Expr::var("k"), Expr::u32(n)),
Expr::add(Expr::u32(half), Expr::var("sc")),
),
),
),
Node::let_bind(
"b21",
Expr::load(
b,
Expr::add(
Expr::mul(
Expr::add(Expr::u32(half), Expr::var("k")),
Expr::u32(n),
),
Expr::var("sc"),
),
),
),
Node::let_bind(
"b22",
Expr::load(
b,
Expr::add(
Expr::mul(
Expr::add(Expr::u32(half), Expr::var("k")),
Expr::u32(n),
),
Expr::add(Expr::u32(half), Expr::var("sc")),
),
),
),
Node::assign(
"m1",
Expr::add(
Expr::var("m1"),
Expr::mul(
Expr::add(Expr::var("a11"), Expr::var("a22")),
Expr::add(Expr::var("b11"), Expr::var("b22")),
),
),
),
Node::assign(
"m2",
Expr::add(
Expr::var("m2"),
Expr::mul(
Expr::add(Expr::var("a21"), Expr::var("a22")),
Expr::var("b11"),
),
),
),
Node::assign(
"m3",
Expr::add(
Expr::var("m3"),
Expr::mul(
Expr::var("a11"),
Expr::sub(Expr::var("b12"), Expr::var("b22")),
),
),
),
Node::assign(
"m4",
Expr::add(
Expr::var("m4"),
Expr::mul(
Expr::var("a22"),
Expr::sub(Expr::var("b21"), Expr::var("b11")),
),
),
),
Node::assign(
"m5",
Expr::add(
Expr::var("m5"),
Expr::mul(
Expr::add(Expr::var("a11"), Expr::var("a12")),
Expr::var("b22"),
),
),
),
Node::assign(
"m6",
Expr::add(
Expr::var("m6"),
Expr::mul(
Expr::sub(Expr::var("a21"), Expr::var("a11")),
Expr::add(Expr::var("b11"), Expr::var("b12")),
),
),
),
Node::assign(
"m7",
Expr::add(
Expr::var("m7"),
Expr::mul(
Expr::sub(Expr::var("a12"), Expr::var("a22")),
Expr::add(Expr::var("b21"), Expr::var("b22")),
),
),
),
],
),
Node::assign(
"c_val",
Expr::select(
Expr::and(
Expr::eq(Expr::var("q_row"), Expr::u32(0)),
Expr::eq(Expr::var("q_col"), Expr::u32(0)),
),
Expr::add(
Expr::sub(Expr::add(Expr::var("m1"), Expr::var("m4")), Expr::var("m5")),
Expr::var("m7"),
),
Expr::select(
Expr::and(
Expr::eq(Expr::var("q_row"), Expr::u32(0)),
Expr::eq(Expr::var("q_col"), Expr::u32(1)),
),
Expr::add(Expr::var("m3"), Expr::var("m5")),
Expr::select(
Expr::and(
Expr::eq(Expr::var("q_row"), Expr::u32(1)),
Expr::eq(Expr::var("q_col"), Expr::u32(0)),
),
Expr::add(Expr::var("m2"), Expr::var("m4")),
Expr::add(
Expr::add(
Expr::sub(Expr::var("m1"), Expr::var("m2")),
Expr::var("m3"),
),
Expr::var("m6"),
),
),
),
),
),
Node::store(c, Expr::var("flat"), Expr::var("c_val")),
],
),
];
Ok(Program::wrapped(
vec![
BufferDecl::storage(a, 0, BufferAccess::ReadOnly, DataType::F32).with_count(total),
BufferDecl::storage(b, 1, BufferAccess::ReadOnly, DataType::F32).with_count(total),
BufferDecl::output(c, 2, DataType::F32).with_count(total),
],
[64, 1, 1],
vec![wrap_anonymous(
"vyre-libs::math::linalg::matmul_strassen_one_level",
body,
)],
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::byte_pack::f32_bytes;
    use vyre_reference::value::Value;

    /// Reinterprets little-endian bytes as `f32` lanes. Trailing bytes that do
    /// not fill a whole lane are ignored.
    fn decode(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| {
                f32::from_le_bytes(
                    c.try_into()
                        .expect("chunks_exact(4) always yields 4-byte chunks"),
                )
            })
            .collect()
    }

    /// Asserts that both slices have the same length and agree lane-by-lane
    /// within `tol`.
    ///
    /// The explicit length check matters: a bare `zip` stops at the shorter
    /// side, so an empty or truncated output would otherwise pass vacuously.
    fn assert_lanes_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "lane count mismatch: actual={actual:?} expected={expected:?}"
        );
        for (i, (lhs, rhs)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (lhs - rhs).abs() <= tol,
                "lane {i}: actual={lhs} expected={rhs} diff={}",
                (lhs - rhs).abs()
            );
        }
    }

    /// Schoolbook 2x2 row-major product used as the oracle.
    fn naive_2x2(a: &[f32], b: &[f32]) -> [f32; 4] {
        [
            a[0] * b[0] + a[1] * b[2],
            a[0] * b[1] + a[1] * b[3],
            a[2] * b[0] + a[3] * b[2],
            a[2] * b[1] + a[3] * b[3],
        ]
    }

    /// Runs `matmul_strassen_2x2` in the reference interpreter and decodes the
    /// first output.
    fn run_strassen(a: &[f32], b: &[f32]) -> Vec<f32> {
        let prog = matmul_strassen_2x2("a", "b", "c");
        let outputs = vyre_reference::reference_eval(
            &prog,
            &[
                Value::from(f32_bytes(a)),
                Value::from(f32_bytes(b)),
                Value::from(vec![0u8; 16]),
            ],
        )
        .expect("Fix: matmul_strassen_2x2 must execute in the reference interpreter.");
        decode(&outputs[0].to_bytes())
    }

    #[test]
    fn strassen_matches_naive_canonical_fixture() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let actual = run_strassen(&a, &b);
        let expected = naive_2x2(&a, &b);
        assert_eq!(expected, [19.0, 22.0, 43.0, 50.0]);
        assert_lanes_close(&actual, &expected, 1.0e-5);
    }

    #[test]
    fn strassen_identity_returns_a() {
        let a = [1.5_f32, -2.25, 3.75, -0.5];
        let identity = [1.0_f32, 0.0, 0.0, 1.0];
        let actual = run_strassen(&a, &identity);
        assert_lanes_close(&actual, &a, 1.0e-5);
    }

    #[test]
    fn strassen_zero_returns_zero() {
        let a = [1.5_f32, -2.25, 3.75, -0.5];
        let zero = [0.0_f32; 4];
        let actual = run_strassen(&a, &zero);
        // Guard against an empty output making the loop below vacuous.
        assert_eq!(actual.len(), 4);
        for v in actual {
            assert_eq!(v, 0.0);
        }
    }

    /// Schoolbook `n`x`n` row-major product used as the oracle.
    fn naive_nxn(a: &[f32], b: &[f32], n: usize) -> Vec<f32> {
        let mut c = vec![0.0_f32; n * n];
        for i in 0..n {
            for j in 0..n {
                let mut acc = 0.0_f32;
                for k in 0..n {
                    acc += a[i * n + k] * b[k * n + j];
                }
                c[i * n + j] = acc;
            }
        }
        c
    }

    /// Runs `matmul_strassen_one_level` in the reference interpreter and
    /// decodes the first output.
    fn run_one_level(a: &[f32], b: &[f32], n: u32) -> Vec<f32> {
        let prog = matmul_strassen_one_level("a", "b", "c", n).expect("Fix: build");
        let outputs = vyre_reference::reference_eval(
            &prog,
            &[
                Value::from(f32_bytes(a)),
                Value::from(f32_bytes(b)),
                Value::from(vec![0u8; (n as usize) * (n as usize) * 4]),
            ],
        )
        .expect("Fix: matmul_strassen_one_level must execute in the reference interpreter.");
        decode(&outputs[0].to_bytes())
    }

    #[test]
    fn strassen_one_level_matches_naive_at_n4() {
        let a: Vec<f32> = (0..16).map(|i| (i as f32) * 0.5 - 4.0).collect();
        let b: Vec<f32> = (0..16).map(|i| (i as f32) * 0.25 + 1.0).collect();
        let actual = run_one_level(&a, &b, 4);
        let expected = naive_nxn(&a, &b, 4);
        assert_eq!(actual.len(), 16);
        assert_lanes_close(&actual, &expected, 1.0e-4);
    }

    #[test]
    fn strassen_one_level_n2_matches_strassen_2x2() {
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let b = [5.0_f32, 6.0, 7.0, 8.0];
        let level1 = run_one_level(&a, &b, 2);
        let flat = run_strassen(&a, &b);
        assert_lanes_close(&level1, &flat, 1.0e-5);
    }

    #[test]
    fn strassen_one_level_rejects_odd_n() {
        let err = matmul_strassen_one_level("a", "b", "c", 3).expect_err("odd n must error");
        assert!(err.contains("even"));
    }

    #[test]
    fn strassen_one_level_rejects_zero_n() {
        let err = matmul_strassen_one_level("a", "b", "c", 0).expect_err("n=0 must error");
        assert!(err.contains("n=0"));
    }

    #[test]
    fn strassen_one_level_rejects_overflowing_n() {
        // 65536 is even and non-zero, but 65536 * 65536 == 2^32 does not fit in u32.
        let err = matmul_strassen_one_level("a", "b", "c", 65_536)
            .expect_err("n*n overflow must error");
        assert!(err.contains("overflows"));
    }

    #[test]
    fn strassen_matches_naive_on_random_fuzz() {
        // Deterministic LCG so failures are reproducible.
        let mut state = 0x12345678_u64;
        let mut next = || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Take the top 32 bits and scale by 2^31 so samples span [-1, 1].
            // A 33-bit shift would leave only 31 bits and confine every sample
            // to [-1, 0], never exercising positive or mixed-sign inputs.
            ((state >> 32) as f32 / (u32::MAX as f32 / 2.0)) - 1.0
        };
        for _ in 0..100 {
            let a = [next(), next(), next(), next()];
            let b = [next(), next(), next(), next()];
            let actual = run_strassen(&a, &b);
            let expected = naive_2x2(&a, &b);
            assert_lanes_close(&actual, &expected, 1.0e-4);
        }
    }
}