use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of this operation. It is used as the `generator` of the emitted
/// region and passed to `crate::invalid_output_program` when arguments are rejected.
pub const OP_ID: &str = "vyre-primitives::math::kfac_block_inverse";
/// Builds a program that inverts `num_blocks` dense `n x n` blocks of `f32`
/// values with Gauss-Jordan elimination.
///
/// Blocks are stored back to back in row-major order, so element `(i, j)` of
/// block `t` lives at flat index `t * n * n + i * n + j`. The body is guarded by
/// `InvocationId { axis: 0 } < num_blocks`, and each passing invocation works on
/// the block selected by its own invocation id.
///
/// # Parameters
/// * `blocks_out` - name of the read-write buffer that receives the inverses
///   (binding 0).
/// * `blocks_in` - name of the read-only buffer holding the input blocks
///   (binding 1).
/// * `scratch` - name of the read-write buffer that holds the working copy of
///   the input while it is reduced (binding 2).
/// * `num_blocks` - number of blocks; must be non-zero.
/// * `n` - side length of each square block; must be non-zero.
///
/// All three buffers are declared with `num_blocks * n * n` elements.
///
/// # Returns
/// The wrapped program. If `num_blocks` or `n` is zero, or if
/// `num_blocks * n * n` does not fit in `u32`, the result of
/// `crate::invalid_output_program` is returned instead.
///
/// # Notes
/// The elimination performs no row exchanges and emits no guard against a zero
/// diagonal pivot: every row is divided by its current diagonal element as-is.
#[must_use]
pub fn kfac_block_inverse(
    blocks_out: &str,
    blocks_in: &str,
    scratch: &str,
    num_blocks: u32,
    n: u32,
) -> Program {
    if num_blocks == 0 {
        return crate::invalid_output_program(
            OP_ID,
            blocks_out,
            DataType::F32,
            "Fix: kfac_block_inverse requires num_blocks > 0, got 0.".to_string(),
        );
    }
    if n == 0 {
        return crate::invalid_output_program(
            OP_ID,
            blocks_out,
            DataType::F32,
            "Fix: kfac_block_inverse requires n > 0, got 0.".to_string(),
        );
    }

    // Every buffer holds num_blocks * n * n elements. Plain `u32` multiplication
    // would panic in debug builds and silently wrap in release builds (declaring
    // undersized buffers), so overflow is rejected explicitly instead.
    let total_elems = match n.checked_mul(n).and_then(|sq| sq.checked_mul(num_blocks)) {
        Some(count) => count,
        None => {
            return crate::invalid_output_program(
                OP_ID,
                blocks_out,
                DataType::F32,
                format!(
                    "Fix: kfac_block_inverse requires num_blocks * n * n to fit in u32, got num_blocks={}, n={}.",
                    num_blocks, n
                ),
            );
        }
    };

    let t = Expr::InvocationId { axis: 0 };
    let n_expr = Expr::u32(n);

    // Flat row-major index of element (i, j) inside this invocation's block:
    // t * (n * n) + (i * n + j).
    let offset = |i: Expr, j: Expr| {
        Expr::add(
            Expr::mul(t.clone(), Expr::mul(n_expr.clone(), n_expr.clone())),
            Expr::add(Expr::mul(i, n_expr.clone()), j),
        )
    };

    // Phase 1: copy the input block into `scratch` and seed `blocks_out` with
    // the identity matrix.
    let copy_and_seed = Node::loop_for(
        "i",
        Expr::u32(0),
        n_expr.clone(),
        vec![Node::loop_for(
            "j",
            Expr::u32(0),
            n_expr.clone(),
            vec![
                Node::let_bind("idx", offset(Expr::var("i"), Expr::var("j"))),
                Node::store(
                    scratch,
                    Expr::var("idx"),
                    Expr::load(blocks_in, Expr::var("idx")),
                ),
                Node::store(
                    blocks_out,
                    Expr::var("idx"),
                    Expr::select(
                        Expr::eq(Expr::var("i"), Expr::var("j")),
                        Expr::f32(1.0),
                        Expr::f32(0.0),
                    ),
                ),
            ],
        )],
    );

    // Phase 2a: divide row `i` of both matrices by the pivot. The pivot is bound
    // before this loop runs, so overwriting the diagonal entry does not affect it.
    let normalize_row = Node::loop_for(
        "j",
        Expr::u32(0),
        n_expr.clone(),
        vec![
            Node::let_bind("idx_ij", offset(Expr::var("i"), Expr::var("j"))),
            Node::store(
                scratch,
                Expr::var("idx_ij"),
                Expr::div(Expr::load(scratch, Expr::var("idx_ij")), Expr::var("pivot")),
            ),
            Node::store(
                blocks_out,
                Expr::var("idx_ij"),
                Expr::div(
                    Expr::load(blocks_out, Expr::var("idx_ij")),
                    Expr::var("pivot"),
                ),
            ),
        ],
    );

    // Phase 2b: subtract `factor * row i` from every other row `k`. `factor` is
    // read once per row, before the inner loop overwrites scratch[k][i].
    let eliminate_column = Node::loop_for(
        "k",
        Expr::u32(0),
        n_expr.clone(),
        vec![Node::if_then(
            Expr::ne(Expr::var("k"), Expr::var("i")),
            vec![
                Node::let_bind(
                    "factor",
                    Expr::load(scratch, offset(Expr::var("k"), Expr::var("i"))),
                ),
                Node::loop_for(
                    "jj",
                    Expr::u32(0),
                    n_expr.clone(),
                    vec![
                        Node::let_bind("idx_kj", offset(Expr::var("k"), Expr::var("jj"))),
                        Node::let_bind("idx_ij", offset(Expr::var("i"), Expr::var("jj"))),
                        Node::store(
                            scratch,
                            Expr::var("idx_kj"),
                            Expr::sub(
                                Expr::load(scratch, Expr::var("idx_kj")),
                                Expr::mul(
                                    Expr::var("factor"),
                                    Expr::load(scratch, Expr::var("idx_ij")),
                                ),
                            ),
                        ),
                        Node::store(
                            blocks_out,
                            Expr::var("idx_kj"),
                            Expr::sub(
                                Expr::load(blocks_out, Expr::var("idx_kj")),
                                Expr::mul(
                                    Expr::var("factor"),
                                    Expr::load(blocks_out, Expr::var("idx_ij")),
                                ),
                            ),
                        ),
                    ],
                ),
            ],
        )],
    );

    // Phase 2: one Gauss-Jordan step per diagonal position `i`.
    let gauss_jordan = Node::loop_for(
        "i",
        Expr::u32(0),
        n_expr.clone(),
        vec![
            Node::let_bind(
                "pivot",
                Expr::load(scratch, offset(Expr::var("i"), Expr::var("i"))),
            ),
            normalize_row,
            eliminate_column,
        ],
    );

    // Only invocations whose id maps to an existing block do any work.
    let entry = vec![Node::if_then(
        Expr::lt(t.clone(), Expr::u32(num_blocks)),
        vec![copy_and_seed, gauss_jordan],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(blocks_out, 0, BufferAccess::ReadWrite, DataType::F32)
                .with_count(total_elems),
            BufferDecl::storage(blocks_in, 1, BufferAccess::ReadOnly, DataType::F32)
                .with_count(total_elems),
            BufferDecl::storage(scratch, 2, BufferAccess::ReadWrite, DataType::F32)
                .with_count(total_elems),
        ],
        // NOTE(review): presumably the workgroup size — confirm against `Program::wrapped`.
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(entry),
        }],
    )
}
/// CPU reference: inverts `num_blocks` consecutive row-major `n x n` blocks of
/// `blocks_in` using Gauss-Jordan elimination without row exchanges.
///
/// The returned vector has the same length as `blocks_in`; block `b` of the
/// result occupies `b * n * n .. (b + 1) * n * n`, and any elements past the
/// last processed block stay `0.0`.
///
/// Each row is divided by its diagonal element as-is, so a zero pivot yields
/// non-finite values rather than an error.
///
/// # Panics
/// Panics if `blocks_in` holds fewer than `num_blocks * n * n` elements
/// (for non-zero `n`).
#[must_use]
pub fn cpu_ref(blocks_in: &[f32], num_blocks: u32, n: u32) -> Vec<f32> {
    let dim = n as usize;
    let block_len = dim * dim;
    let mut result = vec![0.0_f32; blocks_in.len()];

    for block in 0..num_blocks as usize {
        let base = block * block_len;

        // Flat row-major working copy of the block, reduced in place.
        let mut work: Vec<f32> = blocks_in[base..base + block_len].to_vec();
        // Starts as the identity and receives the same row operations.
        let mut inverse = vec![0.0_f32; block_len];
        for d in 0..dim {
            inverse[d * dim + d] = 1.0;
        }

        for col in 0..dim {
            let pivot_row = col * dim;
            let pivot = work[pivot_row + col];

            // Scale the pivot row so its diagonal entry becomes one.
            for c in 0..dim {
                work[pivot_row + c] /= pivot;
                inverse[pivot_row + c] /= pivot;
            }

            // Clear column `col` in every other row.
            for row in (0..dim).filter(|&r| r != col) {
                let target_row = row * dim;
                let factor = work[target_row + col];
                for c in 0..dim {
                    work[target_row + c] -= factor * work[pivot_row + c];
                    inverse[target_row + c] -= factor * inverse[pivot_row + c];
                }
            }
        }

        result[base..base + block_len].copy_from_slice(&inverse);
    }

    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single 1x1 block inverts to its reciprocal.
    #[test]
    fn test_cpu_ref_1x1() {
        let inverse = cpu_ref(&[2.0], 1, 1);
        assert_eq!(inverse, vec![0.5]);
    }

    /// Two diagonal 2x2 blocks are inverted independently of each other.
    #[test]
    fn test_cpu_ref_multi_block() {
        let inverse = cpu_ref(&[2.0, 0.0, 0.0, 2.0, 4.0, 0.0, 0.0, 4.0], 2, 2);
        assert_eq!(inverse, vec![0.5, 0.0, 0.0, 0.5, 0.25, 0.0, 0.0, 0.25]);
    }

    /// A diagonal 3x3 block inverts to the reciprocals of its diagonal.
    #[test]
    fn test_cpu_ref_3x3_diag() {
        let inverse = cpu_ref(&[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 4.0], 1, 3);
        assert_eq!(inverse, vec![1.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.25]);
    }

    /// Sixty-four copies of `2 * I`; checks the first and the last output element.
    #[test]
    fn test_cpu_ref_large_blocks() {
        const N: usize = 2;
        const NUM_BLOCKS: usize = 64;
        let blocks_in = [2.0_f32, 0.0, 0.0, 2.0].repeat(NUM_BLOCKS);
        let out = cpu_ref(&blocks_in, NUM_BLOCKS as u32, N as u32);
        assert_eq!(out[0], 0.5);
        assert_eq!(out[out.len() - 1], 0.5);
    }

    /// A block with non-zero off-diagonal entries; checks the first output row.
    #[test]
    fn test_cpu_ref_asymmetric_values() {
        let out = cpu_ref(&[2.0, 1.0, 1.0, 2.0], 1, 2);
        let first_row = [0.6666667_f32, -0.3333333];
        for (pos, want) in first_row.iter().enumerate() {
            assert!((out[pos] - want).abs() < 1e-6);
        }
    }

    /// Runs the generated program through the reference interpreter and compares
    /// its first result buffer against `cpu_ref`.
    #[test]
    fn test_parity_2x2() {
        use vyre_reference::reference_eval;
        use vyre_reference::value::Value;

        let blocks_in = vec![4.0, 0.0, 0.0, 2.0];
        let program = kfac_block_inverse("bo", "bi", "s", 1, 2);
        let expected_out = cpu_ref(&blocks_in, 1, 2);

        // Serialises f32 values as little-endian bytes.
        let to_value = |data: &[f32]| {
            let mut bytes: Vec<u8> = Vec::with_capacity(data.len() * 4);
            for v in data {
                bytes.extend_from_slice(&v.to_le_bytes());
            }
            Value::Bytes(Arc::from(bytes))
        };

        // Inputs follow the binding order: output, input, scratch.
        let inputs = vec![
            to_value(&[0.0; 4]),
            to_value(&blocks_in),
            to_value(&[0.0; 4]),
        ];
        let results = reference_eval(&program, &inputs).expect("Fix: interpreter failed");

        let raw_out = results[0].to_bytes();
        let actual_out: Vec<f32> = raw_out
            .chunks_exact(4)
            .map(|chunk| {
                let word: [u8; 4] = chunk.try_into().unwrap();
                f32::from_le_bytes(word)
            })
            .collect();

        assert!(actual_out
            .iter()
            .zip(expected_out.iter())
            .all(|(got, want)| (got - want).abs() < 1e-6));
    }

    /// The program declares the output, input and scratch buffers.
    #[test]
    fn program_declares_three_buffers() {
        let program = kfac_block_inverse("bo", "bi", "s", 4, 4);
        assert_eq!(program.buffers().len(), 3);
    }

    /// `num_blocks == 0` is rejected with a trapping program.
    #[test]
    fn rejects_zero_num_blocks_with_trap() {
        let program = kfac_block_inverse("bo", "bi", "s", 0, 4);
        assert!(program.stats().trap());
    }

    /// `n == 0` is rejected with a trapping program.
    #[test]
    fn rejects_zero_n_with_trap() {
        let program = kfac_block_inverse("bo", "bi", "s", 4, 0);
        assert!(program.stats().trap());
    }
}