use crate::sigmoid::sigmoid_derivative;
use crate::{soft_rank, Error, Result};
/// Computes the per-element curvature term used to precondition soft-rank gradients.
///
/// Entry `i` of the returned vector is
/// `(sum over j != i of sigmoid_derivative((x[j] - x[i]) / temperature)) / temperature`.
/// The output always has the same length as `x`; a single-element input yields `[0.0]`
/// because there are no other elements to compare against.
///
/// # Errors
///
/// * [`Error::EmptyInput`] if `x` is empty.
/// * [`Error::InvalidTemperature`] if `temperature` is NaN, zero, or negative.
pub fn soft_rank_hessian_diag(x: &[f64], temperature: f64) -> Result<Vec<f64>> {
    if x.is_empty() {
        return Err(Error::EmptyInput);
    }
    // `NaN <= 0.0` evaluates to false, so NaN must be rejected explicitly; otherwise it
    // would slip through and turn every output entry into NaN.
    if temperature.is_nan() || temperature <= 0.0 {
        return Err(Error::InvalidTemperature(temperature));
    }

    let hessian = x
        .iter()
        .enumerate()
        .map(|(i, &xi)| {
            // Accumulate from 0.0 in index order so the result is bit-for-bit stable.
            let curvature = x
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .fold(0.0_f64, |acc, (_, &xj)| {
                    acc + sigmoid_derivative((xj - xi) / temperature)
                });
            curvature / temperature
        })
        .collect();

    Ok(hessian)
}
/// Rescales each gradient component by the inverse of its damped curvature.
///
/// Element `k` of the result is `gradient[k] / (hessian_diag[k] + damping)`.
/// When the two slices differ in length, the output has as many elements as the
/// shorter one; surplus entries of the longer slice are ignored.
pub fn damped_newton_gradient(gradient: &[f64], hessian_diag: &[f64], damping: f64) -> Vec<f64> {
    let mut scaled = Vec::with_capacity(gradient.len().min(hessian_diag.len()));
    for (g, h) in gradient.iter().zip(hessian_diag) {
        // Adding `damping` keeps the divisor away from zero when the curvature vanishes.
        let denominator = h + damping;
        scaled.push(g / denominator);
    }
    scaled
}
/// Mean-squared soft-rank loss together with a diagonally preconditioned gradient.
///
/// Both `predictions` and `targets` are converted to soft ranks with [`soft_rank`] at the
/// given `temperature`. The loss is the mean of the squared differences between the two
/// rank vectors. The gradient of that loss with respect to `predictions` is then divided,
/// component by component, by `soft_rank_hessian_diag(predictions)[k] + 1e-8` via
/// [`damped_newton_gradient`].
///
/// Returns `(loss, preconditioned_gradient)`; the gradient has one entry per prediction.
///
/// # Panics
///
/// Panics if the slices differ in length, if fewer than two elements are supplied, if
/// `temperature` is not strictly positive (NaN included), or if [`soft_rank`] or
/// [`soft_rank_hessian_diag`] reports an error for the given input.
pub fn newton_soft_rank_loss(
    predictions: &[f64],
    targets: &[f64],
    temperature: f64,
) -> (f64, Vec<f64>) {
    let n = predictions.len();
    assert_eq!(n, targets.len(), "length mismatch");
    assert!(n >= 2, "need at least 2 elements");
    assert!(temperature > 0.0, "temperature must be positive");

    let pred_ranks = soft_rank(predictions, temperature).expect("valid input");
    let target_ranks = soft_rank(targets, temperature).expect("valid input");
    let count = n as f64;

    let residuals: Vec<f64> = (0..n).map(|i| pred_ranks[i] - target_ranks[i]).collect();
    let loss = residuals.iter().fold(0.0_f64, |acc, r| acc + r * r) / count;

    // The diagonal Jacobian entry for index k is the same sum of sigmoid derivatives,
    // divided by the temperature, that `soft_rank_hessian_diag` produces for k. Compute
    // it once here instead of re-deriving it inside the gradient loop.
    let hessian_diag = soft_rank_hessian_diag(predictions, temperature).expect("valid input");

    let raw_gradient: Vec<f64> = (0..n)
        .map(|k| {
            // Chain rule: sum over i of residual_i * d(rank_i)/d(prediction_k).
            let grad_k = (0..n).fold(0.0_f64, |acc, i| {
                let jacobian_ik = if i == k {
                    hessian_diag[k]
                } else {
                    -sigmoid_derivative((predictions[k] - predictions[i]) / temperature)
                        / temperature
                };
                acc + residuals[i] * jacobian_ik
            });
            2.0 * grad_k / count
        })
        .collect();

    // A tiny damping term keeps the division finite where the curvature entry is zero.
    let damping = 1e-8;
    let newton_grad = damped_newton_gradient(&raw_gradient, &hessian_diag, damping);
    (loss, newton_grad)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f64]) -> f64 {
        let mut sum_of_squares = 0.0;
        for &component in v {
            sum_of_squares += component * component;
        }
        sum_of_squares.sqrt()
    }

    /// Gradient of the mean-squared soft-rank loss with respect to `predictions`,
    /// computed independently of `newton_soft_rank_loss` and without any
    /// curvature-based rescaling.
    fn unscaled_gradient(predictions: &[f64], targets: &[f64], temperature: f64) -> Vec<f64> {
        let n = predictions.len();
        let pred_ranks = soft_rank(predictions, temperature).unwrap();
        let target_ranks = soft_rank(targets, temperature).unwrap();
        let residuals: Vec<f64> = (0..n).map(|i| pred_ranks[i] - target_ranks[i]).collect();

        (0..n)
            .map(|k| {
                let mut grad_k = 0.0;
                for (i, &residual) in residuals.iter().enumerate() {
                    let jacobian_ik = if i == k {
                        let s = predictions
                            .iter()
                            .enumerate()
                            .filter(|&(j, _)| j != i)
                            .fold(0.0_f64, |acc, (_, &pj)| {
                                acc + sigmoid_derivative((pj - predictions[i]) / temperature)
                            });
                        s / temperature
                    } else {
                        -sigmoid_derivative((predictions[k] - predictions[i]) / temperature)
                            / temperature
                    };
                    grad_k += residual * jacobian_ik;
                }
                2.0 * grad_k / n as f64
            })
            .collect()
    }

    /// At a low temperature the preconditioned gradient should come out shorter
    /// than the unscaled one.
    #[test]
    fn newton_gradients_smaller_than_raw_when_curvature_high() {
        let predictions = [0.5, 0.2, 0.8, 0.1, 0.9];
        let targets = [1.0, 2.0, 3.0, 4.0, 5.0];
        let temperature = 0.1;

        let raw_norm = l2_norm(&unscaled_gradient(&predictions, &targets, temperature));
        let (_, newton_grad) = newton_soft_rank_loss(&predictions, &targets, temperature);
        let newton_norm = l2_norm(&newton_grad);

        assert!(
            newton_norm < raw_norm,
            "Newton norm ({newton_norm:.6}) should be < raw norm ({raw_norm:.6})"
        );
    }

    /// With zero curvature and unit damping the gradient passes through unchanged.
    #[test]
    fn damping_prevents_division_by_zero() {
        let gradient = [1.0, 2.0, 3.0];
        let hessian_diag = [0.0; 3];
        let result = damped_newton_gradient(&gradient, &hessian_diag, 1.0);
        for (r, g) in result.iter().zip(gradient.iter()) {
            assert!((r - g).abs() < 1e-10);
        }
    }

    /// At a very high temperature every entry should sit close to (n-1)*0.25/tau.
    #[test]
    fn high_temperature_hessian_approximately_constant() {
        let x = [0.1, 0.5, 0.9, 0.3, 0.7];
        let temperature = 100.0;
        let expected = (x.len() - 1) as f64 * 0.25 / temperature;

        let hessian = soft_rank_hessian_diag(&x, temperature).unwrap();
        let mean_h: f64 = hessian.iter().sum::<f64>() / hessian.len() as f64;

        for (i, &h) in hessian.iter().enumerate() {
            let rel_diff = (h - mean_h).abs() / mean_h;
            assert!(
                rel_diff < 0.01,
                "Hessian entry {i} ({h:.6}) deviates from mean ({mean_h:.6})"
            );
        }
        assert!(
            (mean_h - expected).abs() / expected < 0.01,
            "mean Hessian ({mean_h:.6}) should be ~ (n-1)*0.25/tau ({expected:.6})"
        );
    }

    /// Loss and every gradient component must be finite for ordinary input.
    #[test]
    fn soft_rank_loss_returns_finite() {
        let predictions = [0.3, 0.7, 0.1, 0.9];
        let targets = [1.0, 2.0, 3.0, 4.0];
        let (loss, grad) = newton_soft_rank_loss(&predictions, &targets, 0.5);
        assert!(loss.is_finite());
        assert!(loss >= 0.0);
        for g in grad.iter() {
            assert!(g.is_finite());
        }
    }

    /// Predictions ordered the same way as the targets should give (almost) no loss.
    #[test]
    fn perfect_ranking_has_zero_loss() {
        let predictions = [1.0, 2.0, 3.0, 4.0];
        let targets = [10.0, 20.0, 30.0, 40.0];
        let (loss, _) = newton_soft_rank_loss(&predictions, &targets, 0.1);
        assert!(loss < 1e-6, "loss should be near zero, got {loss}");
    }
}