use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for this primitive.
///
/// It is used as the `generator` of the emitted IR region. It is also passed to
/// `crate::invalid_output_program` when the checked builder rejects its arguments.
pub const OP_ID: &str = "vyre-primitives::math::natural_gradient_block_apply";
#[must_use]
pub fn natural_gradient_block_apply(
m_inv_sqrt: &str,
grad: &str,
grad_nat: &str,
n: u32,
) -> Program {
match try_natural_gradient_block_apply(m_inv_sqrt, grad, grad_nat, n) {
Ok(program) => program,
Err(error) => crate::invalid_output_program(OP_ID, grad_nat, DataType::U32, error),
}
}
/// Builds an IR program that applies an `n x n` preconditioner block to a gradient vector.
///
/// Each invocation `lane` (invocation id on axis 0) with `lane < n` computes
/// `grad_nat[lane] = Σ_j fixed_mul(m_inv_sqrt[lane * n + j], grad[j])` for `j` in `0..n`. The
/// multiply is `crate::fixed_mul_16_16_expr`; presumably that is 16.16 fixed-point, per the
/// helper name, so confirm against the helper. The sum starts from `0` in a `u32` accumulator.
/// Invocations with `lane >= n` do nothing.
///
/// Buffer layout (all `DataType::U32`):
/// - binding 0: `m_inv_sqrt`, read-only, `n * n` elements (row-major block);
/// - binding 1: `grad`, read-only, `n` elements;
/// - binding 2: `grad_nat`, read-write, `n` elements.
///
/// The program uses a `[256, 1, 1]` workgroup and wraps its body in a region tagged with
/// [`OP_ID`].
///
/// # Errors
///
/// Returns an error when `n == 0`, or when `n * n` does not fit in a `u32` element count.
pub fn try_natural_gradient_block_apply(
    m_inv_sqrt: &str,
    grad: &str,
    grad_nat: &str,
    n: u32,
) -> Result<Program, String> {
    if n == 0 {
        return Err(format!(
            "Fix: natural_gradient_block_apply requires n > 0, got {n}."
        ));
    }
    let Some(matrix_cells) = n.checked_mul(n) else {
        return Err(format!(
            "natural_gradient_block_apply n={n} overflows preconditioner block cell count. Fix: shard the gradient block before GPU dispatch."
        ));
    };

    let lane = Expr::InvocationId { axis: 0 };

    // One step of the row dot product: acc += fixed_mul(M[row_base + j], grad[j]).
    let preconditioner_entry = Expr::load(
        m_inv_sqrt,
        Expr::add(Expr::var("row_base"), Expr::var("j")),
    );
    let gradient_entry = Expr::load(grad, Expr::var("j"));
    let accumulate = Node::assign(
        "acc",
        Expr::add(
            Expr::var("acc"),
            crate::fixed_mul_16_16_expr(preconditioner_entry, gradient_entry),
        ),
    );

    // `row_base = lane * n` cannot overflow: lane < n and n * n was checked above.
    let row_dot_product = vec![
        Node::let_bind("acc", Expr::u32(0)),
        Node::let_bind("row_base", Expr::mul(lane.clone(), Expr::u32(n))),
        Node::loop_for("j", Expr::u32(0), Expr::u32(n), vec![accumulate]),
        Node::store(grad_nat, lane.clone(), Expr::var("acc")),
    ];
    let in_bounds = Expr::lt(lane, Expr::u32(n));
    let body = vec![Node::if_then(in_bounds, row_dot_product)];

    let buffers = vec![
        BufferDecl::storage(m_inv_sqrt, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(matrix_cells),
        BufferDecl::storage(grad, 1, BufferAccess::ReadOnly, DataType::U32).with_count(n),
        BufferDecl::storage(grad_nat, 2, BufferAccess::ReadWrite, DataType::U32).with_count(n),
    ];
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(body),
    };
    Ok(Program::wrapped(buffers, [256, 1, 1], vec![region]))
}
/// Allocating f64 CPU reference for the natural-gradient block apply.
///
/// Returns a new vector of length `n`. Element `i` is the dot product of row `i` of the
/// row-major `n x n` matrix `m_inv_sqrt` with `grad`. Entries missing from either input
/// slice contribute zero.
///
/// # Panics
///
/// Panics with the error message of [`try_natural_gradient_block_apply_cpu_into`] when that
/// function returns an error.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn natural_gradient_block_apply_cpu(m_inv_sqrt: &[f64], grad: &[f64], n: u32) -> Vec<f64> {
    let mut result = Vec::new();
    let outcome = try_natural_gradient_block_apply_cpu_into(m_inv_sqrt, grad, n, &mut result);
    if let Err(error) = outcome {
        panic!("{error}");
    }
    result
}
/// Panicking variant of [`try_natural_gradient_block_apply_cpu_into`] that writes into `out`.
///
/// `out` is cleared and refilled with exactly `n` results. Its existing allocation is kept
/// when its capacity already covers `n`.
///
/// # Panics
///
/// Panics with the checked variant's error message when it returns an error.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn natural_gradient_block_apply_cpu_into(
    m_inv_sqrt: &[f64],
    grad: &[f64],
    n: u32,
    out: &mut Vec<f64>,
) {
    if let Err(error) = try_natural_gradient_block_apply_cpu_into(m_inv_sqrt, grad, n, out) {
        panic!("{error}");
    }
}
/// Checked f64 CPU oracle for the natural-gradient block apply, writing into a reusable buffer.
///
/// Computes `out[i] = Σ_j m_inv_sqrt[i * n + j] * grad[j]` for `i, j` in `0..n`. In other
/// words, the row-major `n x n` block `m_inv_sqrt` is applied to the first `n` gradient
/// entries. Entries missing from either slice (short or malformed inputs) contribute zero
/// instead of panicking.
///
/// `out` is cleared and resized to exactly `n` elements. It is only grown (through
/// `crate::graph::scratch::reserve_graph_items`) when its capacity is smaller than `n`.
///
/// # Errors
///
/// Returns an error when `n * n` overflows `usize`, and propagates any error returned by
/// `reserve_graph_items` while growing `out`.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn try_natural_gradient_block_apply_cpu_into(
    m_inv_sqrt: &[f64],
    grad: &[f64],
    n: u32,
    out: &mut Vec<f64>,
) -> Result<(), String> {
    // u32 -> usize is lossless on the 32/64-bit targets this oracle runs on.
    let n = n as usize;
    // Reject shapes whose flat row-major index `i * n + j` could overflow.
    n.checked_mul(n).ok_or_else(|| {
        format!(
            "natural_gradient_block_apply CPU oracle n={n} overflows preconditioner block indexing. Fix: shard the gradient block before parity evaluation."
        )
    })?;
    if n > out.capacity() {
        // len <= capacity < n, so `n - out.len()` cannot underflow.
        crate::graph::scratch::reserve_graph_items(
            out,
            n - out.len(),
            "natural-gradient CPU oracle",
            "natural_gradient_block_apply output",
        )?;
    }
    out.clear();
    out.resize(n, 0.0);
    for (i, slot) in out.iter_mut().enumerate() {
        // A row that starts past the end of a truncated matrix contributes nothing.
        let row = m_inv_sqrt.get(i * n..).unwrap_or(&[]);
        // Missing entries always form a suffix of each row and of `grad`. `zip` stops at the
        // shorter side, so it visits exactly the entries that exist, in ascending `j`.
        // Folding from +0.0 reproduces the reference accumulation order bit-for-bit.
        *slot = row
            .iter()
            .take(n)
            .zip(grad)
            .fold(0.0, |acc, (&m, &g)| acc + m * g);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative-tolerance float comparison used by the CPU oracle tests.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10 * (1.0 + a.abs() + b.abs())
    }

    #[test]
    fn cpu_identity_passthrough() {
        let i_mat = vec![1.0, 0.0, 0.0, 1.0];
        let g = vec![3.0, 5.0];
        let out = natural_gradient_block_apply_cpu(&i_mat, &g, 2);
        assert!(approx_eq(out[0], 3.0));
        assert!(approx_eq(out[1], 5.0));
    }

    #[test]
    fn cpu_diagonal_scales_each_component() {
        let m = vec![0.5, 0.0, 0.0, 2.0];
        let g = vec![4.0, 3.0];
        let out = natural_gradient_block_apply_cpu(&m, &g, 2);
        assert!(approx_eq(out[0], 2.0));
        assert!(approx_eq(out[1], 6.0));
    }

    #[test]
    fn cpu_off_diagonal_couples() {
        let m = vec![0.0, 1.0, 1.0, 0.0];
        let g = vec![7.0, 11.0];
        let out = natural_gradient_block_apply_cpu(&m, &g, 2);
        assert!(approx_eq(out[0], 11.0));
        assert!(approx_eq(out[1], 7.0));
    }

    #[test]
    fn cpu_into_reuses_output_buffer() {
        let m = vec![0.5, 0.0, 0.0, 2.0];
        let g = vec![4.0, 3.0];
        let mut out = Vec::with_capacity(8);
        out.extend_from_slice(&[99.0, 98.0, 97.0, 96.0]);
        let ptr = out.as_ptr();
        let capacity = out.capacity();
        natural_gradient_block_apply_cpu_into(&m, &g, 2, &mut out);
        assert!(approx_eq(out[0], 2.0));
        assert!(approx_eq(out[1], 6.0));
        assert_eq!(out.len(), 2);
        assert_eq!(out.as_ptr(), ptr);
        assert_eq!(out.capacity(), capacity);
        natural_gradient_block_apply_cpu_into(&[3.0], &[5.0], 1, &mut out);
        assert_eq!(out, vec![15.0]);
        assert_eq!(out.as_ptr(), ptr);
        assert_eq!(out.capacity(), capacity);
    }

    #[test]
    fn cpu_malformed_inputs_treat_missing_entries_as_zero() {
        let out = natural_gradient_block_apply_cpu(&[2.0], &[3.0], 2);
        assert_eq!(out.len(), 2);
        assert!(approx_eq(out[0], 6.0));
        assert!(approx_eq(out[1], 0.0));
    }

    #[test]
    fn cpu_truncated_matrix_row_uses_available_prefix() {
        // Row 1 only has its first entry (m[2] = 3.0); m[3] is missing and counts as zero.
        let out = natural_gradient_block_apply_cpu(&[1.0, 2.0, 3.0], &[1.0, 1.0], 2);
        assert_eq!(out.len(), 2);
        assert!(approx_eq(out[0], 3.0));
        assert!(approx_eq(out[1], 3.0));
    }

    #[test]
    fn cpu_short_gradient_treats_missing_components_as_zero() {
        let out = natural_gradient_block_apply_cpu(&[1.0, 1.0, 1.0, 1.0], &[2.0], 2);
        assert_eq!(out.len(), 2);
        assert!(approx_eq(out[0], 2.0));
        assert!(approx_eq(out[1], 2.0));
    }

    #[test]
    fn cpu_zero_n_yields_empty_output() {
        let out = natural_gradient_block_apply_cpu(&[], &[], 0);
        assert!(out.is_empty());
    }

    #[test]
    fn ir_program_buffer_layout() {
        let p = natural_gradient_block_apply("M", "g", "gn", 8);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["M", "g", "gn"]);
        assert_eq!(p.buffers[0].count(), 64);
        assert_eq!(p.buffers[1].count(), 8);
        assert_eq!(p.buffers[2].count(), 8);
    }

    #[test]
    fn zero_n_traps() {
        let p = natural_gradient_block_apply("M", "g", "gn", 0);
        assert!(p.stats().trap());
    }

    #[test]
    fn checked_builder_rejects_zero_n() {
        let error = try_natural_gradient_block_apply("M", "g", "gn", 0)
            .expect_err("checked natural-gradient builder must reject n == 0");
        assert!(
            error.contains("requires n > 0"),
            "error should explain the n > 0 requirement: {error}"
        );
    }

    #[test]
    fn checked_builder_rejects_preconditioner_cell_overflow() {
        let error = try_natural_gradient_block_apply("M", "g", "gn", u32::MAX)
            .expect_err("checked natural-gradient builder must reject n*n overflow");
        assert!(
            error.contains("overflows preconditioner block cell count"),
            "error should describe preconditioner shape overflow: {error}"
        );
    }
}