use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Stable identifier of the semiring GEMM operation.
///
/// It names the generator of the emitted region, is handed to the
/// invalid-program constructor when the dimensions are rejected, and keys the
/// harness registration for this op.
pub const OP_ID: &str = "vyre-primitives::math::semiring_gemm";
/// Selects the pair of operators used by the generalized matrix product:
/// a "combine" step (⊗) applied to each `a[i][kk]`, `b[kk][j]` pair and an
/// "accumulate" step (⊕) folding those results into one output cell.
///
/// All variants operate on `u32` values; the accumulator's starting value is
/// given by `Semiring::identity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Semiring {
/// ⊗ is multiplication, ⊕ is addition; the accumulator starts at 0.
Real,
/// ⊗ is addition with `u32::MAX` treated as an absorbing "infinity",
/// ⊕ is `min`; the accumulator starts at `u32::MAX`.
MinPlus,
/// ⊗ is addition; the accumulator starts at 0.
/// NOTE(review): the name implies ⊕ = `max` — confirm `accumulate` agrees.
MaxPlus,
/// ⊗ is multiplication, ⊕ is `max`; the accumulator starts at 0.
MaxTimes,
/// ⊗ is bitwise AND, ⊕ is bitwise OR; the accumulator starts at 0.
BoolOr,
/// ⊗ is bitwise OR, ⊕ is bitwise AND; the accumulator starts at `u32::MAX`.
BoolAnd,
/// ⊗ is bitwise AND, ⊕ is bitwise XOR; the accumulator starts at 0.
Gf2,
/// ⊗ is the bitwise OR of both operands, or 0 when either operand is 0;
/// ⊕ is bitwise OR; the accumulator starts at 0.
Lineage,
}
impl Semiring {
    /// Builds the IR expression for the semiring product `a ⊗ b`.
    ///
    /// * `Real`, `MaxTimes` — multiplication.
    /// * `MinPlus` — addition, except that `u32::MAX` (the encoding of +∞)
    ///   absorbs: if either operand is `u32::MAX` the result is `u32::MAX`.
    /// * `MaxPlus` — addition.
    /// * `BoolOr`, `Gf2` — bitwise AND.
    /// * `BoolAnd` — bitwise OR.
    /// * `Lineage` — bitwise OR of both operands, or 0 if either operand is 0.
    #[must_use]
    pub fn combine(self, a: Expr, b: Expr) -> Expr {
        match self {
            Self::Real | Self::MaxTimes => Expr::mul(a, b),
            Self::MinPlus => {
                // A missing edge (+∞) must stay +∞ no matter what it is added to.
                // NOTE(review): the CPU reference saturates the finite sum;
                // confirm `Expr::add` cannot wrap past `u32::MAX` here.
                let infinity = Expr::u32(u32::MAX);
                let either_inf = Expr::or(
                    Expr::eq(a.clone(), infinity.clone()),
                    Expr::eq(b.clone(), infinity.clone()),
                );
                Expr::select(either_inf, infinity, Expr::add(a, b))
            }
            // NOTE(review): the CPU reference uses a saturating add for this arm.
            Self::MaxPlus => Expr::add(a, b),
            Self::BoolOr | Self::Gf2 => Expr::bitand(a, b),
            Self::BoolAnd => Expr::bitor(a, b),
            Self::Lineage => {
                // 0 means "no derivation": a path through a missing fact has
                // no lineage, so 0 annihilates instead of being OR-ed in.
                let either_zero = Expr::or(
                    Expr::eq(a.clone(), Expr::u32(0)),
                    Expr::eq(b.clone(), Expr::u32(0)),
                );
                Expr::select(either_zero, Expr::u32(0), Expr::bitor(a, b))
            }
        }
    }

    /// Builds the IR expression for the semiring sum `acc ⊕ val`.
    ///
    /// * `Real` — addition.
    /// * `MinPlus` — `min`.
    /// * `MaxPlus`, `MaxTimes` — `max`.
    /// * `BoolOr`, `Lineage` — bitwise OR.
    /// * `BoolAnd` — bitwise AND.
    /// * `Gf2` — bitwise XOR.
    #[must_use]
    pub fn accumulate(self, acc: Expr, val: Expr) -> Expr {
        match self {
            Self::Real => Expr::add(acc, val),
            Self::MinPlus => Expr::min(acc, val),
            // Max-plus keeps the best path weight; adding the candidates
            // would yield the total of all paths rather than the longest.
            Self::MaxPlus | Self::MaxTimes => Expr::max(acc, val),
            Self::BoolOr | Self::Lineage => Expr::bitor(acc, val),
            Self::BoolAnd => Expr::bitand(acc, val),
            Self::Gf2 => Expr::bitxor(acc, val),
        }
    }

    /// Returns the neutral element of `accumulate`, used as the starting
    /// accumulator value: `u32::MAX` for `MinPlus` and `BoolAnd`, 0 otherwise.
    #[must_use]
    pub fn identity(self) -> u32 {
        match self {
            Self::MinPlus | Self::BoolAnd => u32::MAX,
            Self::Real
            | Self::MaxPlus
            | Self::MaxTimes
            | Self::BoolOr
            | Self::Gf2
            | Self::Lineage => 0,
        }
    }
}
/// Builds the IR program computing `C = A ⊗⊕ B` over the chosen semiring.
///
/// `a`, `b` and `c` are the buffer names of the row-major `m × k`, `k × n`
/// and `m × n` matrices; they are bound at slots 0, 1 and 2 as `u32` storage
/// buffers (`a`, `b` read-only, `c` read-write). Each invocation `t` along
/// axis 0 with `t < m * n` owns output cell `(t / n, t % n)`: it starts an
/// accumulator at `semiring.identity()`, folds `combine(a[i*k+kk], b[kk*n+j])`
/// into it for every `kk` of the loop from 0 to `k`, and stores the result at
/// `c[t]`. The workgroup size is `[256, 1, 1]`.
///
/// If any dimension is zero, or if any of the element counts `m * k`,
/// `k * n`, `m * n` does not fit in `u32`, the result of
/// `crate::invalid_output_program` is returned instead, carrying a message
/// that names the offending input.
#[must_use]
pub fn semiring_gemm(
    a: &str,
    b: &str,
    c: &str,
    m: u32,
    n: u32,
    k: u32,
    semiring: Semiring,
) -> Program {
    // Reject empty dimensions first, in m, n, k order.
    for (name, value) in [("m", m), ("n", n), ("k", k)] {
        if value == 0 {
            return crate::invalid_output_program(
                OP_ID,
                c,
                DataType::U32,
                format!("Fix: semiring_gemm requires {name} > 0, got {value}."),
            );
        }
    }

    // Buffer sizes and in-kernel indices are u32; an unchecked product would
    // panic in debug builds and wrap to an undersized buffer in release.
    let (a_count, b_count, cell_count) =
        match (m.checked_mul(k), k.checked_mul(n), m.checked_mul(n)) {
            (Some(a_count), Some(b_count), Some(cell_count)) => (a_count, b_count, cell_count),
            _ => {
                return crate::invalid_output_program(
                    OP_ID,
                    c,
                    DataType::U32,
                    format!(
                        "Fix: semiring_gemm requires m*k, k*n and m*n to fit in u32, \
                         got m={m}, n={n}, k={k}."
                    ),
                );
            }
        };

    // One invocation per output cell; `t` is the flat row-major cell index.
    let t = Expr::InvocationId { axis: 0 };
    let row = Expr::div(t.clone(), Expr::u32(n));
    let col = Expr::rem(t.clone(), Expr::u32(n));

    // Row-major addressing of A (m × k) and B (k × n) inside the kk loop.
    let a_idx = Expr::add(Expr::mul(Expr::var("i"), Expr::u32(k)), Expr::var("kk"));
    let b_idx = Expr::add(Expr::mul(Expr::var("kk"), Expr::u32(n)), Expr::var("j"));
    let product = semiring.combine(Expr::load(a, a_idx), Expr::load(b, b_idx));
    let step = semiring.accumulate(Expr::var("acc"), product);

    let body = vec![Node::if_then(
        // Invocations beyond the last cell do nothing.
        Expr::lt(t.clone(), Expr::u32(cell_count)),
        vec![
            Node::let_bind("i", row),
            Node::let_bind("j", col),
            Node::let_bind("acc", Expr::u32(semiring.identity())),
            Node::loop_for(
                "kk",
                Expr::u32(0),
                Expr::u32(k),
                vec![Node::assign("acc", step)],
            ),
            Node::store(c, t, Expr::var("acc")),
        ],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(a, 0, BufferAccess::ReadOnly, DataType::U32).with_count(a_count),
            BufferDecl::storage(b, 1, BufferAccess::ReadOnly, DataType::U32).with_count(b_count),
            BufferDecl::storage(c, 2, BufferAccess::ReadWrite, DataType::U32)
                .with_count(cell_count),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for the semiring matrix product, returning a fresh vector.
///
/// Convenience wrapper around [`semiring_gemm_cpu_into`]: allocates an empty
/// output vector, lets that function fill it with the row-major `m × n`
/// result, and returns it.
#[must_use]
pub fn semiring_gemm_cpu(
    a: &[u32],
    b: &[u32],
    m: u32,
    n: u32,
    k: u32,
    semiring: Semiring,
) -> Vec<u32> {
    let mut product = Vec::new();
    semiring_gemm_cpu_into(a, b, m, n, k, semiring, &mut product);
    product
}
/// CPU reference for the semiring matrix product, writing into `c`.
///
/// `a` is read as a row-major `m × k` matrix and `b` as a row-major `k × n`
/// matrix. `c` is cleared and resized to `m * n` cells, and cell `(i, j)`
/// receives the fold, starting from `semiring.identity()`, of
/// `accumulate(acc, combine(a[i*k + kk], b[kk*n + j]))` over `kk` in `0..k`.
///
/// Elements that fall outside `a` or `b` (slices shorter than their nominal
/// size) are read as `semiring.identity()` instead of panicking. With `k == 0`
/// every cell is `semiring.identity()`.
pub fn semiring_gemm_cpu_into(
    a: &[u32],
    b: &[u32],
    m: u32,
    n: u32,
    k: u32,
    semiring: Semiring,
    c: &mut Vec<u32>,
) {
    // Do all index math in `usize`: `m * n`, `i * k` and `kk * n` can exceed
    // `u32::MAX` for legal `u32` dimensions, which would panic in debug builds
    // and silently address the wrong cell in release builds.
    let (rows, cols, inner) = (m as usize, n as usize, k as usize);
    let identity = semiring.identity();

    c.clear();
    c.resize(rows * cols, identity);

    // `cols == 0` implies an empty `c`, so the division below never runs
    // with a zero divisor.
    for (cell, out) in c.iter_mut().enumerate() {
        let (i, j) = (cell / cols, cell % cols);
        *out = (0..inner).fold(identity, |acc, kk| {
            let a_v = a.get(i * inner + kk).copied().unwrap_or(identity);
            let b_v = b.get(kk * cols + j).copied().unwrap_or(identity);
            semiring_accumulate_cpu(semiring, acc, semiring_combine_cpu(semiring, a_v, b_v))
        });
    }
}
/// CPU counterpart of the semiring product `a ⊗ b` for semiring `s`.
///
/// * `Real`, `MaxTimes` — wrapping multiplication.
/// * `MinPlus` — `u32::MAX` if either operand is `u32::MAX`, otherwise a
///   saturating addition.
/// * `MaxPlus` — saturating addition.
/// * `BoolOr`, `Gf2` — bitwise AND.
/// * `BoolAnd` — bitwise OR.
/// * `Lineage` — 0 if either operand is 0, otherwise the bitwise OR.
#[inline]
fn semiring_combine_cpu(s: Semiring, a: u32, b: u32) -> u32 {
    // `u32::MAX` stands for +∞ in the min-plus semiring.
    const INFINITY: u32 = u32::MAX;
    match s {
        Semiring::Real | Semiring::MaxTimes => a.wrapping_mul(b),
        Semiring::MinPlus if a == INFINITY || b == INFINITY => INFINITY,
        Semiring::MinPlus | Semiring::MaxPlus => a.saturating_add(b),
        Semiring::BoolAnd => a | b,
        Semiring::BoolOr | Semiring::Gf2 => a & b,
        // A zero operand means "no derivation" and annihilates the product.
        Semiring::Lineage if a == 0 || b == 0 => 0,
        Semiring::Lineage => a | b,
    }
}
/// CPU counterpart of the semiring sum `acc ⊕ val` for semiring `s`.
///
/// * `Real` — wrapping addition.
/// * `MinPlus` — `min`.
/// * `MaxPlus`, `MaxTimes` — `max`.
/// * `BoolOr`, `Lineage` — bitwise OR.
/// * `BoolAnd` — bitwise AND.
/// * `Gf2` — bitwise XOR.
#[inline]
fn semiring_accumulate_cpu(s: Semiring, acc: u32, val: u32) -> u32 {
    match s {
        Semiring::Real => acc.wrapping_add(val),
        Semiring::MinPlus => acc.min(val),
        // Max-plus keeps the largest candidate; summing the candidates would
        // report the total weight of all paths instead of the longest one.
        Semiring::MaxPlus | Semiring::MaxTimes => acc.max(val),
        Semiring::BoolOr | Semiring::Lineage => acc | val,
        Semiring::BoolAnd => acc & val,
        Semiring::Gf2 => acc ^ val,
    }
}
/// Serializes `words` into a byte vector, each word as four little-endian
/// bytes in input order.
#[cfg(feature = "inventory-registry")]
fn fixture_u32(words: &[u32]) -> Vec<u8> {
    // Each u32 contributes exactly four bytes.
    let mut bytes = Vec::with_capacity(words.len() * 4);
    for word in words {
        bytes.extend_from_slice(&word.to_le_bytes());
    }
    bytes
}
// Registers this op with the crate harness when the `inventory-registry`
// feature is enabled. The entry carries the op id, a builder for the
// 2 × 2 × 2 `Real` program over buffers "a", "b", "c", and two fixture
// closures.
// NOTE(review): presumably the first closure supplies the input buffers and
// the second the expected outputs — confirm against `crate::harness::OpEntry`.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
|| semiring_gemm("a", "b", "c", 2, 2, 2, Semiring::Real),
// Buffers in binding order: A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]],
// and a zero-filled C.
Some(|| vec![vec![
fixture_u32(&[1, 2, 3, 4]),
fixture_u32(&[5, 6, 7, 8]),
fixture_u32(&[0, 0, 0, 0]),
]]),
// Ordinary integer product A·B = [[19, 22], [43, 50]].
Some(|| vec![vec![fixture_u32(&[19, 22, 43, 50])]]),
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `u32::MAX` doubles as +∞ in the min-plus tests.
    const INF: u32 = u32::MAX;

    /// Reads entry `(row, col)` of a row-major matrix with `cols` columns.
    ///
    /// Keeps the index arithmetic out of the assertions, where literal forms
    /// such as `0 * 3 + 2` are rejected by clippy's `erasing_op` lint.
    fn entry(matrix: &[u32], cols: usize, row: usize, col: usize) -> u32 {
        matrix[row * cols + col]
    }

    #[test]
    fn cpu_real_2x2() {
        let a = vec![1, 2, 3, 4];
        let b = vec![5, 6, 7, 8];
        let c = semiring_gemm_cpu(&a, &b, 2, 2, 2, Semiring::Real);
        assert_eq!(c, vec![19, 22, 43, 50]);
    }

    #[test]
    fn cpu_real_identity() {
        let a = vec![3, 5, 7, 11];
        let eye = vec![1, 0, 0, 1];
        let c = semiring_gemm_cpu(&a, &eye, 2, 2, 2, Semiring::Real);
        assert_eq!(c, a);
    }

    #[test]
    fn cpu_min_plus_shortest_path_step() {
        // Edges 0 -> 1 (weight 5) and 1 -> 2 (weight 3); everything else is +∞.
        #[rustfmt::skip]
        let a = vec![
            INF, 5,   INF,
            INF, INF, 3,
            INF, INF, INF,
        ];
        let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::MinPlus);
        assert_eq!(entry(&c, 3, 0, 2), 8);
        assert_eq!(entry(&c, 3, 0, 1), INF);
    }

    #[test]
    fn cpu_min_plus_saturating_no_overflow() {
        let a = vec![INF, INF, INF, INF];
        let b = vec![INF, INF, INF, INF];
        let c = semiring_gemm_cpu(&a, &b, 2, 2, 2, Semiring::MinPlus);
        for v in c {
            assert_eq!(v, INF);
        }
    }

    #[test]
    fn cpu_bool_or_reachability() {
        // Edges 0 -> 1 and 1 -> 2.
        #[rustfmt::skip]
        let a = vec![
            0, 1, 0,
            0, 0, 1,
            0, 0, 0,
        ];
        let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::BoolOr);
        assert_eq!(entry(&c, 3, 0, 2), 1);
        assert_eq!(entry(&c, 3, 0, 1), 0);
    }

    #[test]
    fn cpu_lineage_scallop_join() {
        let f1: u32 = 0b01;
        let f2: u32 = 0b10;
        #[rustfmt::skip]
        let a = vec![
            0, f1, 0,
            0, 0,  f2,
            0, 0,  0,
        ];
        let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::Lineage);
        assert_eq!(
            entry(&c, 3, 0, 2),
            f1 | f2,
            "lineage = union of facts along path"
        );
        assert_eq!(entry(&c, 3, 0, 1), 0);
    }

    #[test]
    fn cpu_lineage_alternative_paths_union() {
        let f1: u32 = 0b0001;
        let f2: u32 = 0b0010;
        let f3: u32 = 0b0100;
        let f4: u32 = 0b1000;
        // Two routes from 0 to 3: via 1 (f1, f2) and via 2 (f3, f4).
        #[rustfmt::skip]
        let a = vec![
            0, f1, f3, 0,
            0, 0,  0,  f2,
            0, 0,  0,  f4,
            0, 0,  0,  0,
        ];
        let c = semiring_gemm_cpu(&a, &a, 4, 4, 4, Semiring::Lineage);
        assert_eq!(
            entry(&c, 4, 0, 3),
            f1 | f2 | f3 | f4,
            "expected union over both paths"
        );
    }

    #[test]
    fn cpu_max_plus_longest_path() {
        #[rustfmt::skip]
        let a = vec![
            0, 5, 0,
            0, 0, 3,
            0, 0, 0,
        ];
        let c = semiring_gemm_cpu(&a, &a, 3, 3, 3, Semiring::MaxPlus);
        assert_eq!(entry(&c, 3, 0, 2), 8);
    }

    #[test]
    fn cpu_gf2_xor_closure() {
        let a = vec![1, 0, 1, 1];
        let b = vec![1, 1, 0, 1];
        let c = semiring_gemm_cpu(&a, &b, 2, 2, 2, Semiring::Gf2);
        assert_eq!(c, vec![1, 1, 1, 0]);
    }

    #[test]
    fn cpu_max_times_viterbi() {
        let a = vec![50, 50];
        let b = vec![60, 40, 30, 70];
        let c = semiring_gemm_cpu(&a, &b, 1, 2, 2, Semiring::MaxTimes);
        assert_eq!(c, vec![3000, 3500]);
    }

    #[test]
    fn emitted_program_buffer_layout() {
        let p = semiring_gemm("A", "B", "C", 4, 5, 3, Semiring::Real);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["A", "B", "C"]);
        // A is m × k, B is k × n, C is m × n.
        assert_eq!(p.buffers[0].count(), 4 * 3);
        assert_eq!(p.buffers[1].count(), 3 * 5);
        assert_eq!(p.buffers[2].count(), 4 * 5);
    }

    #[test]
    fn emitted_program_buffer_access_modes() {
        let p = semiring_gemm("A", "B", "C", 2, 2, 2, Semiring::MinPlus);
        assert_eq!(p.buffers[0].access(), BufferAccess::ReadOnly);
        assert_eq!(p.buffers[1].access(), BufferAccess::ReadOnly);
        assert_eq!(p.buffers[2].access(), BufferAccess::ReadWrite);
    }

    #[test]
    fn zero_m_traps() {
        let p = semiring_gemm("A", "B", "C", 0, 1, 1, Semiring::Real);
        assert!(p.stats().trap());
    }

    #[test]
    fn zero_n_traps() {
        let p = semiring_gemm("A", "B", "C", 1, 0, 1, Semiring::Real);
        assert!(p.stats().trap());
    }

    #[test]
    fn zero_k_traps() {
        let p = semiring_gemm("A", "B", "C", 1, 1, 0, Semiring::Real);
        assert!(p.stats().trap());
    }

    #[test]
    fn identity_table_matches_doc() {
        assert_eq!(Semiring::Real.identity(), 0);
        assert_eq!(Semiring::MinPlus.identity(), u32::MAX);
        assert_eq!(Semiring::MaxPlus.identity(), 0);
        assert_eq!(Semiring::MaxTimes.identity(), 0);
        assert_eq!(Semiring::BoolOr.identity(), 0);
        assert_eq!(Semiring::BoolAnd.identity(), u32::MAX);
        assert_eq!(Semiring::Gf2.identity(), 0);
        assert_eq!(Semiring::Lineage.identity(), 0);
    }
}