use vyre_primitives::math::natural_gradient::{
natural_gradient_block_apply_cpu, natural_gradient_block_apply_cpu_into,
};
/// Applies the block preconditioner `m_inv_sqrt` to `grad` and returns the result
/// as a freshly allocated vector.
///
/// The call is first reported through `crate::observability::bump` using the
/// `natural_gradient_autotuner_calls` item, then the work is delegated unchanged to
/// `natural_gradient_block_apply_cpu` from `vyre_primitives`.
///
/// NOTE(review): the primitive's exact contract is not visible from this file. The
/// unit tests below are consistent with `m_inv_sqrt` being a row-major `n x n` matrix
/// that is multiplied by the length-`n` vector `grad` — confirm against
/// `vyre_primitives::math::natural_gradient`.
#[must_use]
pub fn precondition_autotune_gradient(m_inv_sqrt: &[f64], grad: &[f64], n: u32) -> Vec<f64> {
use crate::observability::{bump, natural_gradient_autotuner_calls};
// Presumably a call counter (TODO confirm); reported before the work is delegated.
bump(&natural_gradient_autotuner_calls);
natural_gradient_block_apply_cpu(m_inv_sqrt, grad, n)
}
/// Caller-provided-buffer variant of the preconditioning step: instead of returning
/// a new vector, the result is produced in `out`.
///
/// The call is first reported through `crate::observability::bump` using the
/// `natural_gradient_autotuner_calls` item, then all arguments (including `out`) are
/// forwarded unchanged to `natural_gradient_block_apply_cpu_into` from
/// `vyre_primitives`.
///
/// NOTE(review): how the primitive treats the previous contents and capacity of `out`
/// is not visible from this file. The `autotune_step_into_reuses_output` test below
/// suggests an already-large-enough allocation is kept — confirm against
/// `vyre_primitives::math::natural_gradient`.
pub fn precondition_autotune_gradient_into(
m_inv_sqrt: &[f64],
grad: &[f64],
n: u32,
out: &mut Vec<f64>,
) {
use crate::observability::{bump, natural_gradient_autotuner_calls};
// Presumably a call counter (TODO confirm); reported before the work is delegated.
bump(&natural_gradient_autotuner_calls);
natural_gradient_block_apply_cpu_into(m_inv_sqrt, grad, n, out);
}
/// Computes one autotuner update step and returns it as a new vector.
///
/// Convenience wrapper around [`autotune_step_into`]: it starts from an empty
/// vector, lets the `_into` variant fill it with the preconditioned gradient
/// scaled by `-learning_rate`, and hands the vector back to the caller.
#[must_use]
pub fn autotune_step(m_inv_sqrt: &[f64], grad: &[f64], n: u32, learning_rate: f64) -> Vec<f64> {
    let mut step: Vec<f64> = Vec::new();
    autotune_step_into(m_inv_sqrt, grad, n, learning_rate, &mut step);
    step
}
/// Computes one autotuner update step into the caller-provided buffer `out`.
///
/// `out` is first filled by [`precondition_autotune_gradient_into`]; every element
/// is then multiplied in place by `-learning_rate`, so the step points against the
/// preconditioned gradient.
pub fn autotune_step_into(
    m_inv_sqrt: &[f64],
    grad: &[f64],
    n: u32,
    learning_rate: f64,
    out: &mut Vec<f64>,
) {
    precondition_autotune_gradient_into(m_inv_sqrt, grad, n, out);

    // Descent direction: negate once, then scale every component by the same factor.
    let scale = -learning_rate;
    out.iter_mut().for_each(|component| *component *= scale);
}
/// Returns a freshly allocated row-major `n x n` identity matrix.
///
/// Convenience wrapper around [`identity_fisher_block_into`] for callers that do
/// not have a buffer to reuse.
#[must_use]
pub fn identity_fisher_block(n: u32) -> Vec<f64> {
    let mut block: Vec<f64> = Vec::new();
    identity_fisher_block_into(n, &mut block);
    block
}
/// Overwrites `out` with the row-major `n x n` identity matrix.
///
/// Any previous contents of `out` are discarded. The vector ends up with exactly
/// `n * n` elements: `1.0` on the diagonal and `0.0` everywhere else. For `n == 0`
/// the vector is left empty. An existing allocation is kept when its capacity is
/// already sufficient.
///
/// # Panics
///
/// Panics if `n * n` does not fit in `usize`. This can only happen on targets where
/// `usize` is narrower than 64 bits; previously the product could silently wrap in
/// release builds and yield an undersized buffer.
pub fn identity_fisher_block_into(n: u32, out: &mut Vec<f64>) {
    let dim = n as usize;
    let len = dim
        .checked_mul(dim)
        .expect("identity_fisher_block_into: n * n overflows usize");

    out.clear();
    out.resize(len, 0.0);

    // In row-major layout consecutive diagonal entries sit `dim + 1` slots apart.
    // `dim + 1` cannot overflow here: `dim == usize::MAX` is rejected by the checked
    // multiplication above.
    for diagonal in out.iter_mut().step_by(dim + 1) {
        *diagonal = 1.0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance check for floating-point results: the allowed error grows with the
    /// magnitude of the operands and never drops below `1e-9`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9 * (1.0 + a.abs() + b.abs())
    }

    #[test]
    fn identity_fisher_recovers_plain_gradient() {
        let id = identity_fisher_block(3);
        let grad = vec![1.0, -2.0, 0.5];
        let g_nat = precondition_autotune_gradient(&id, &grad, 3);
        // `zip` stops at the shorter side, so without this check a truncated or
        // empty result would make the loop below pass vacuously.
        assert_eq!(g_nat.len(), grad.len());
        for (a, b) in grad.iter().zip(g_nat.iter()) {
            assert!(approx_eq(*a, *b));
        }
    }

    #[test]
    fn autotune_step_negates_gradient() {
        let id = identity_fisher_block(2);
        let grad = vec![1.0, 2.0];
        let step = autotune_step(&id, &grad, 2, 0.1);
        assert!(approx_eq(step[0], -0.1));
        assert!(approx_eq(step[1], -0.2));
    }

    #[test]
    fn autotune_step_zero_lr_no_motion() {
        let id = identity_fisher_block(3);
        let grad = vec![1.0, 2.0, 3.0];
        let step = autotune_step(&id, &grad, 3, 0.0);
        // Guard against a vacuous pass: an empty step would skip the loop entirely.
        assert_eq!(step.len(), grad.len());
        for v in step {
            assert!(approx_eq(v, 0.0));
        }
    }

    #[test]
    fn autotune_step_into_reuses_output() {
        let id = identity_fisher_block(2);
        let grad = vec![1.0, 2.0];
        let mut step = Vec::with_capacity(8);
        let ptr = step.as_ptr();
        autotune_step_into(&id, &grad, 2, 0.1, &mut step);
        assert!(approx_eq(step[0], -0.1));
        assert!(approx_eq(step[1], -0.2));
        // Same pointer means the pre-reserved allocation was kept.
        assert_eq!(step.as_ptr(), ptr);
    }

    #[test]
    fn diagonal_fisher_scales_per_axis() {
        let m_inv_sqrt = vec![1.0, 0.0, 0.0, 0.5];
        let grad = vec![10.0, 10.0];
        let g_nat = precondition_autotune_gradient(&m_inv_sqrt, &grad, 2);
        assert!(approx_eq(g_nat[0], 10.0));
        assert!(approx_eq(g_nat[1], 5.0));
    }

    #[test]
    fn identity_fisher_block_is_diagonal_of_ones() {
        let id = identity_fisher_block(4);
        assert_eq!(id.len(), 16);
        for i in 0..4 {
            assert!(approx_eq(id[i * 4 + i], 1.0));
            for j in 0..4 {
                if i != j {
                    assert!(approx_eq(id[i * 4 + j], 0.0));
                }
            }
        }
    }

    #[test]
    fn identity_fisher_block_into_discards_stale_contents() {
        let mut buf = vec![7.0; 5];
        identity_fisher_block_into(2, &mut buf);
        assert_eq!(buf, vec![1.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn identity_fisher_block_zero_dimension_is_empty() {
        assert!(identity_fisher_block(0).is_empty());
    }
}