/// Natural gradient for a categorical distribution parametrized by softmax.
///
/// The Fisher information matrix is `F = diag(p) - p p^T`, which is singular:
/// its null space is spanned by the all-ones vector, reflecting softmax's
/// invariance to a constant shift of the scores. For a gradient `g`, the
/// closed form `x_i = g_i / p_i - sum_j g_j` satisfies
/// `F x = g - p * sum_j g_j`, so it solves `F x = g` exactly whenever `g`
/// sums to zero, as gradients taken through a softmax do. The `grad_sum`
/// subtraction only shifts the result along the null space, which the
/// softmax ignores. Probabilities below `1e-30` are clamped to a zero
/// update to avoid dividing by (near-)zero.
pub fn natural_gradient_softmax(grad: &[f64], softmax_probs: &[f64]) -> Vec<f64> {
    assert_eq!(
        grad.len(),
        softmax_probs.len(),
        "grad and softmax_probs must have the same length"
    );
    let n = grad.len();
    if n == 0 {
        return vec![];
    }
    let grad_sum: f64 = grad.iter().sum();
    grad.iter()
        .zip(softmax_probs.iter())
        .map(|(&g, &p)| {
            // Skip entries with vanishing probability: g / p would blow up.
            if p < 1e-30 {
                0.0
            } else {
                g / p - grad_sum
            }
        })
        .collect()
}
/// Dense Fisher information matrix `F = diag(p) - p p^T` for a categorical
/// distribution, stored row-major as a flat `n * n` vector.
pub fn fisher_information_softmax(softmax_probs: &[f64]) -> Vec<f64> {
    let n = softmax_probs.len();
    let mut fisher = vec![0.0; n * n];
    for i in 0..n {
        for j in 0..n {
            let val = if i == j {
                softmax_probs[i] * (1.0 - softmax_probs[i])
            } else {
                -softmax_probs[i] * softmax_probs[j]
            };
            fisher[i * n + j] = val;
        }
    }
    fisher
}
/// Convenience wrapper: computes softmax probabilities from `scores`, obtains
/// the loss gradient with respect to the scores from `loss_grad_fn`, and
/// converts it into a natural gradient.
pub fn with_natural_gradient<F>(loss_grad_fn: F, scores: &[f64]) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let n = scores.len();
    if n == 0 {
        return vec![];
    }
    let probs = softmax(scores);
    let grad = loss_grad_fn(scores);
    natural_gradient_softmax(&grad, &probs)
}
/// Numerically stable softmax: subtracting the maximum score before
/// exponentiating prevents overflow without changing the result.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let n = scores.len();
    if n == 0 {
        return vec![];
    }
    let max_s = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    let exps: Vec<f64> = scores.iter().map(|&s| (s - max_s).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn uniform_probs_scales_gradient() {
        let n = 4;
        let probs = vec![0.25; n];
        let grad = vec![0.3, -0.1, 0.2, -0.4];
        let nat_grad = natural_gradient_softmax(&grad, &probs);
        let grad_sum: f64 = grad.iter().sum();
        for i in 0..n {
            let expected = grad[i] / probs[i] - grad_sum;
            assert!(
                (nat_grad[i] - expected).abs() < 1e-10,
                "i={i}: got {}, expected {expected}",
                nat_grad[i]
            );
        }
    }
    #[test]
    fn fisher_is_symmetric() {
        let probs = vec![0.1, 0.3, 0.4, 0.2];
        let n = probs.len();
        let fisher = fisher_information_softmax(&probs);
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (fisher[i * n + j] - fisher[j * n + i]).abs() < 1e-15,
                    "Fisher not symmetric at [{i}][{j}]"
                );
            }
        }
    }
    #[test]
    fn fisher_is_psd() {
        let probs = vec![0.2, 0.5, 0.3];
        let n = probs.len();
        let fisher = fisher_information_softmax(&probs);
        let test_vecs: Vec<Vec<f64>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 1.0],
            vec![1.0, -1.0, 0.0],
            vec![1.0, -0.5, -0.5],
            vec![-2.0, 1.0, 1.0],
        ];
        for v in &test_vecs {
            let mut vtfv = 0.0;
            for i in 0..n {
                for j in 0..n {
                    vtfv += v[i] * fisher[i * n + j] * v[j];
                }
            }
            assert!(vtfv >= -1e-15, "v^T F v = {vtfv} < 0 for v = {v:?}");
        }
    }
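
    // An added identity check, a sketch not in the original suite: the
    // quadratic form v^T F v equals the variance of v under p, since
    // F = diag(p) - p p^T gives v^T F v = E_p[v^2] - (E_p[v])^2. The test
    // name, probabilities, vector, and tolerance are choices made here.
    #[test]
    fn quadratic_form_is_variance() {
        let probs = vec![0.2, 0.5, 0.3];
        let v = vec![1.0, -2.0, 0.5];
        let n = probs.len();
        let fisher = fisher_information_softmax(&probs);
        let mut vtfv = 0.0;
        for i in 0..n {
            for j in 0..n {
                vtfv += v[i] * fisher[i * n + j] * v[j];
            }
        }
        // Variance of v under p, computed directly from the definition.
        let mean: f64 = probs.iter().zip(&v).map(|(&p, &x)| p * x).sum();
        let second: f64 = probs.iter().zip(&v).map(|(&p, &x)| p * x * x).sum();
        let variance = second - mean * mean;
        assert!(
            (vtfv - variance).abs() < 1e-12,
            "v^T F v = {vtfv}, Var_p(v) = {variance}"
        );
    }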
    #[test]
    fn fisher_null_space_is_ones() {
        let probs = vec![0.15, 0.35, 0.25, 0.25];
        let n = probs.len();
        let fisher = fisher_information_softmax(&probs);
        let ones = vec![1.0; n];
        for i in 0..n {
            let mut row_dot = 0.0;
            for j in 0..n {
                row_dot += fisher[i * n + j] * ones[j];
            }
            assert!(
                row_dot.abs() < 1e-15,
                "F * ones is not zero at row {i}: got {row_dot}"
            );
        }
    }
    #[test]
    fn rare_items_get_larger_updates() {
        let probs = vec![0.05, 0.45, 0.50];
        let grad = vec![1.0, 1.0, 1.0];
        let nat_grad = natural_gradient_softmax(&grad, &probs);
        assert!(
            nat_grad[0].abs() > nat_grad[2].abs(),
            "Rare item grad {} should exceed common item grad {}",
            nat_grad[0].abs(),
            nat_grad[2].abs()
        );
    }
    #[test]
    fn three_item_concrete() {
        let scores = vec![2.0, 1.0, 0.0];
        let probs = softmax(&scores);
        let grad = vec![-0.5, 0.2, 0.3];
        let nat_grad = natural_gradient_softmax(&grad, &probs);
        let grad_sum: f64 = grad.iter().sum();
        for i in 0..3 {
            let expected = grad[i] / probs[i] - grad_sum;
            assert!(
                (nat_grad[i] - expected).abs() < 1e-10,
                "i={i}: got {}, expected {expected}",
                nat_grad[i]
            );
        }
    }
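
    // Added as a consistency sketch, not part of the original suite: for a
    // gradient that sums to zero (as gradients taken through a softmax do),
    // the closed form from `natural_gradient_softmax` should satisfy the
    // defining equation F * nat_grad = grad, with F the dense matrix from
    // `fisher_information_softmax`. The test name, inputs, and tolerance are
    // assumptions made here.
    #[test]
    fn closed_form_solves_fisher_system() {
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        // A zero-sum gradient, i.e. an element of range(F).
        let grad = vec![0.6, -0.1, -0.2, -0.3];
        let n = probs.len();
        let nat_grad = natural_gradient_softmax(&grad, &probs);
        let fisher = fisher_information_softmax(&probs);
        for i in 0..n {
            let mut fx = 0.0;
            for j in 0..n {
                fx += fisher[i * n + j] * nat_grad[j];
            }
            assert!(
                (fx - grad[i]).abs() < 1e-12,
                "(F * nat_grad)[{i}] = {fx}, expected grad[{i}] = {}",
                grad[i]
            );
        }
    }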
    #[test]
    fn with_natural_gradient_integration() {
        let scores = vec![1.0, 2.0, 3.0, 0.5];
        let result = with_natural_gradient(|s| s.iter().map(|&x| 2.0 * x).collect(), &scores);
        assert_eq!(result.len(), scores.len());
        assert!(result.iter().all(|&v| v.is_finite()));
    }
    #[test]
    fn empty_input() {
        assert!(natural_gradient_softmax(&[], &[]).is_empty());
        assert!(fisher_information_softmax(&[]).is_empty());
        assert!(with_natural_gradient(|_| vec![], &[]).is_empty());
    }
    #[test]
    fn single_element() {
        let probs = vec![1.0];
        let grad = vec![0.5];
        let nat_grad = natural_gradient_softmax(&grad, &probs);
        assert!(nat_grad[0].abs() < 1e-15);
        let fisher = fisher_information_softmax(&probs);
        assert!(fisher[0].abs() < 1e-15);
    }
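
    // A usage sketch, added here rather than taken from the original suite:
    // for cross-entropy loss against a one-hot target, the gradient with
    // respect to the scores is `softmax(scores) - onehot`, which sums to
    // zero. The natural gradient should then be finite, and negative at the
    // target index, so a descent step raises the target's logit. The test
    // name, scores, and target index are assumptions made here.
    #[test]
    fn cross_entropy_usage_sketch() {
        let scores = vec![0.5, 2.0, -1.0];
        let target = 0;
        let nat_grad = with_natural_gradient(
            |s| {
                // Cross-entropy gradient w.r.t. scores: p - onehot(target).
                let p = softmax(s);
                p.iter()
                    .enumerate()
                    .map(|(i, &pi)| if i == target { pi - 1.0 } else { pi })
                    .collect()
            },
            &scores,
        );
        assert!(nat_grad.iter().all(|v| v.is_finite()));
        assert!(
            nat_grad[target] < 0.0,
            "descent direction should increase the target's logit"
        );
    }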
}