use crate::metrics::evaluation::Metric;
use crate::metrics::ranking::GainScheme;
use crate::objective::ObjectiveFunction;
use serde::{Deserialize, Serialize};
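/// ListNet loss for learning to rank (top-one approximation).
///
/// Within each query group, labels and predictions are mapped to
/// "top-one" probability distributions via a softmax, and the group
/// loss is the cross-entropy between the two:
///
/// `L = -sum_i softmax(y)_i * ln(softmax(yhat)_i)`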
#[derive(Default, Debug, Deserialize, Serialize, Clone)]
pub struct ListNetLoss {}
/// Loss reported for degenerate inputs that cannot form a ranking
/// (fewer than two documents).
const LOSS_FOR_SINGLE_GROUP: f32 = f32::INFINITY;
/// Floor applied to predicted probabilities before taking the
/// logarithm, so `ln` never sees zero.
const EPSILON: f32 = 1e-15;
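/// Computes a numerically stable softmax of `input` into `output`.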
#[inline]
fn compute_softmax_inplace(input: &[f64], output: &mut [f32]) {
debug_assert_eq!(input.len(), output.len());
    // Shift by the maximum so the largest exponent is zero (numerical stability).
    let max_val = input.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let mut sum = 0.0f32;
for (i, &val) in input.iter().enumerate() {
let exp_val = ((val - max_val) as f32).exp();
output[i] = exp_val;
sum += exp_val;
}
    // The maximum element contributes exp(0) = 1, so for non-empty,
    // NaN-free input the sum is at least 1; this guard is defensive.
    if sum > 0.0 {
let inv_sum = 1.0 / sum;
for val in output.iter_mut() {
*val *= inv_sum;
}
}
}
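/// Cross-entropy between the target distribution `softmax_y` and the
/// predicted distribution `softmax_yhat`, with each term optionally
/// scaled by its sample weight. Predicted probabilities are clamped to
/// `EPSILON` so the logarithm stays finite, and targets whose
/// probability underflowed to zero contribute nothing.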
#[inline]
fn compute_listnet_loss(softmax_y: &[f32], softmax_yhat: &[f32], weights: Option<&[f64]>) -> f32 {
match weights {
Some(w) => softmax_y
.iter()
.zip(softmax_yhat)
.zip(w)
.map(|((p_y, p_yhat), weight)| {
if *p_y > 0.0 {
-p_y * p_yhat.max(EPSILON).ln() * (*weight as f32)
} else {
0.0
}
})
.sum(),
None => softmax_y
.iter()
.zip(softmax_yhat)
.map(|(p_y, p_yhat)| {
if *p_y > 0.0 {
-p_y * p_yhat.max(EPSILON).ln()
} else {
0.0
}
})
.sum(),
}
}
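/// First-order gradient of the cross-entropy with respect to the raw
/// prediction scores. For a softmax followed by cross-entropy this
/// reduces to `softmax(yhat)_i - softmax(y)_i`; each component is
/// scaled by its sample weight when weights are given.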
#[inline]
fn compute_group_gradients(softmax_y: &[f32], softmax_yhat: &[f32], weights: Option<&[f64]>, output: &mut [f32]) {
match weights {
Some(w) => {
for (i, ((p_yhat, p_y), weight)) in softmax_yhat.iter().zip(softmax_y).zip(w).enumerate() {
output[i] = (p_yhat - p_y) * (*weight as f32);
}
}
None => {
for (i, (p_yhat, p_y)) in softmax_yhat.iter().zip(softmax_y).enumerate() {
output[i] = p_yhat - p_y;
}
}
}
}
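/// Diagonal of the Hessian with respect to the raw prediction scores:
/// `softmax(yhat)_i * (1 - softmax(yhat)_i)`, scaled by the sample
/// weight when one is given.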
#[inline]
fn compute_group_hessian(softmax_yhat: &[f32], weights: Option<&[f64]>, output: &mut [f32]) {
match weights {
Some(w) => {
for (i, (p_yhat, weight)) in softmax_yhat.iter().zip(w).enumerate() {
output[i] = p_yhat * (1.0 - p_yhat) * (*weight as f32);
}
}
None => {
for (i, p_yhat) in softmax_yhat.iter().enumerate() {
output[i] = p_yhat * (1.0 - p_yhat);
}
}
}
}
/// Panics if the group sizes do not sum to the number of samples.
#[inline]
fn validate_group_sizes(group: Option<&[u64]>, n_samples: usize) {
    if let Some(group_sizes) = group {
        let total: u64 = group_sizes.iter().sum();
        assert!(
            total == n_samples as u64,
            "Sum of group sizes ({}) does not match number of samples ({}).",
            total,
            n_samples
        );
    }
}
impl ObjectiveFunction for ListNetLoss {
#[inline]
fn loss(&self, y: &[f64], yhat: &[f64], sample_weight: Option<&[f64]>, group: Option<&[u64]>) -> Vec<f32> {
        if y.len() < 2 {
            // Fewer than two documents cannot form a ranking.
            return vec![LOSS_FOR_SINGLE_GROUP; y.len()];
        }
        validate_group_sizes(group, y.len());
let mut losses = vec![0.0f32; y.len()];
if let Some(group_sizes) = group {
let mut start = 0;
for &group_size in group_sizes {
let end = start + group_size as usize;
let group_len = group_size as usize;
let y_group = &y[start..end];
let yhat_group = &yhat[start..end];
let weight_group = sample_weight.map(|w| &w[start..end]);
let mut softmax_y = vec![0.0f32; group_len];
let mut softmax_yhat = vec![0.0f32; group_len];
compute_softmax_inplace(y_group, &mut softmax_y);
compute_softmax_inplace(yhat_group, &mut softmax_yhat);
                let group_loss = compute_listnet_loss(&softmax_y, &softmax_yhat, weight_group);
                // Spread the group's loss evenly across its documents.
                let per_sample_loss = group_loss / (group_size as f32);
                losses[start..end].fill(per_sample_loss);
start = end;
}
} else {
let mut softmax_y = vec![0.0f32; y.len()];
let mut softmax_yhat = vec![0.0f32; y.len()];
compute_softmax_inplace(y, &mut softmax_y);
compute_softmax_inplace(yhat, &mut softmax_yhat);
let total_loss = compute_listnet_loss(&softmax_y, &softmax_yhat, sample_weight);
let per_sample_loss = total_loss / (y.len() as f32);
losses.fill(per_sample_loss);
}
losses
}
#[inline]
fn gradient(
&self,
y: &[f64],
yhat: &[f64],
sample_weight: Option<&[f64]>,
group: Option<&[u64]>,
) -> (Vec<f32>, Option<Vec<f32>>) {
        if y.len() < 2 {
            // Nothing to rank: zero gradients and no Hessian.
            return (vec![0.0f32; y.len()], None);
        }
        validate_group_sizes(group, y.len());
let mut gradients = vec![0.0f32; y.len()];
let mut hessians = vec![0.0f32; y.len()];
if let Some(group_sizes) = group {
let mut start = 0;
for &group_size in group_sizes {
let end = start + group_size as usize;
let group_len = group_size as usize;
let y_group = &y[start..end];
let yhat_group = &yhat[start..end];
let weight_group = sample_weight.map(|w| &w[start..end]);
let mut softmax_y = vec![0.0f32; group_len];
let mut softmax_yhat = vec![0.0f32; group_len];
compute_softmax_inplace(y_group, &mut softmax_y);
compute_softmax_inplace(yhat_group, &mut softmax_yhat);
compute_group_gradients(&softmax_y, &softmax_yhat, weight_group, &mut gradients[start..end]);
compute_group_hessian(&softmax_yhat, weight_group, &mut hessians[start..end]);
start = end;
}
} else {
let mut softmax_y = vec![0.0f32; y.len()];
let mut softmax_yhat = vec![0.0f32; y.len()];
compute_softmax_inplace(y, &mut softmax_y);
compute_softmax_inplace(yhat, &mut softmax_yhat);
compute_group_gradients(&softmax_y, &softmax_yhat, sample_weight, &mut gradients);
compute_group_hessian(&softmax_yhat, sample_weight, &mut hessians);
}
(gradients, Some(hessians))
}
#[inline]
fn initial_value(&self, _y: &[f64], _sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> f64 {
0.0
}
fn default_metric(&self) -> Metric {
Metric::NDCG {
k: None,
gain: GainScheme::Burges,
}
}
fn gradient_and_loss(
&self,
y: &[f64],
yhat: &[f64],
sample_weight: Option<&[f64]>,
group: Option<&[u64]>,
) -> (Vec<f32>, Option<Vec<f32>>, Vec<f32>) {
        if y.len() < 2 {
            // Nothing to rank: zero gradients, no Hessian, infinite loss.
            return (vec![0.0f32; y.len()], None, vec![LOSS_FOR_SINGLE_GROUP; y.len()]);
        }
        validate_group_sizes(group, y.len());
let mut gradients = vec![0.0f32; y.len()];
let mut hessians = vec![0.0f32; y.len()];
let mut losses = vec![0.0f32; y.len()];
if let Some(group_sizes) = group {
let mut start = 0;
for &group_size in group_sizes {
let end = start + group_size as usize;
let group_len = group_size as usize;
let y_group = &y[start..end];
let yhat_group = &yhat[start..end];
let weight_group = sample_weight.map(|w| &w[start..end]);
let mut softmax_y = vec![0.0f32; group_len];
let mut softmax_yhat = vec![0.0f32; group_len];
compute_softmax_inplace(y_group, &mut softmax_y);
compute_softmax_inplace(yhat_group, &mut softmax_yhat);
compute_group_gradients(&softmax_y, &softmax_yhat, weight_group, &mut gradients[start..end]);
compute_group_hessian(&softmax_yhat, weight_group, &mut hessians[start..end]);
let group_loss = compute_listnet_loss(&softmax_y, &softmax_yhat, weight_group);
let per_sample_loss = group_loss / (group_size as f32);
losses[start..end].fill(per_sample_loss);
start = end;
}
} else {
let mut softmax_y = vec![0.0f32; y.len()];
let mut softmax_yhat = vec![0.0f32; y.len()];
compute_softmax_inplace(y, &mut softmax_y);
compute_softmax_inplace(yhat, &mut softmax_yhat);
compute_group_gradients(&softmax_y, &softmax_yhat, sample_weight, &mut gradients);
compute_group_hessian(&softmax_yhat, sample_weight, &mut hessians);
let total_loss = compute_listnet_loss(&softmax_y, &softmax_yhat, sample_weight);
let per_sample_loss = total_loss / (y.len() as f32);
losses.fill(per_sample_loss);
}
(gradients, Some(hessians), losses)
}
}
impl ListNetLoss {
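    /// A single document cannot form a ranking, so the loss is always
    /// [`LOSS_FOR_SINGLE_GROUP`] (infinity).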
#[inline]
pub fn loss_single(&self, _y: f64, _yhat: f64, _sample_weight: Option<f64>) -> f32 {
LOSS_FOR_SINGLE_GROUP
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compute_softmax() {
let input = vec![1.0, 2.0, 3.0];
let mut output = vec![0.0f32; 3];
compute_softmax_inplace(&input, &mut output);
let sum: f32 = output.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
assert!(output[2] > output[1]);
assert!(output[1] > output[0]);
}
#[test]
fn test_listnet_loss() {
let y = vec![3.0, 1.0, 0.0];
let yhat = vec![2.0, 1.0, 0.5];
let loss_fn = ListNetLoss::default();
let l = loss_fn.loss(&y, &yhat, None, None);
assert_eq!(l.len(), 3);
assert_eq!(l[0], l[1]);
assert_eq!(l[1], l[2]);
let (g, h) = loss_fn.gradient(&y, &yhat, None, None);
let h = h.unwrap();
assert_eq!(g.len(), 3);
assert_eq!(h.len(), 3);
let g_sum: f32 = g.iter().sum();
assert!(g_sum.abs() < 1e-6);
}
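    #[test]
    fn test_listnet_gradients_sum_to_zero_per_group() {
        // Sanity check: both softmax distributions sum to one, so the
        // unweighted gradients (p_yhat - p_y) must sum to ~0 per group.
        let y = vec![3.0, 1.0, 0.0, 2.0, 1.0];
        let yhat = vec![2.0, 1.0, 0.5, 1.5, 0.5];
        let group = vec![3u64, 2];
        let loss_fn = ListNetLoss::default();
        let (g, _h) = loss_fn.gradient(&y, &yhat, None, Some(&group));
        let first: f32 = g[..3].iter().sum();
        let second: f32 = g[3..].iter().sum();
        assert!(first.abs() < 1e-6);
        assert!(second.abs() < 1e-6);
    }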
#[test]
fn test_listnet_loss_weighted() {
let y = vec![3.0, 1.0, 0.0];
let yhat = vec![2.0, 1.0, 0.5];
let w = vec![2.0, 1.0, 1.0];
let loss_fn = ListNetLoss::default();
let l = loss_fn.loss(&y, &yhat, Some(&w), None);
assert_eq!(l.len(), 3);
let (g, h) = loss_fn.gradient(&y, &yhat, Some(&w), None);
let h = h.unwrap();
assert_eq!(g.len(), 3);
assert_eq!(h.len(), 3);
}
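    #[test]
    fn test_listnet_uniform_weights_scale_loss() {
        // Sanity check: every cross-entropy term is multiplied by its
        // weight, so a uniform weight of 2.0 should double the loss.
        let y = vec![3.0, 1.0, 0.0];
        let yhat = vec![2.0, 1.0, 0.5];
        let w = vec![2.0, 2.0, 2.0];
        let loss_fn = ListNetLoss::default();
        let l_unweighted = loss_fn.loss(&y, &yhat, None, None);
        let l_weighted = loss_fn.loss(&y, &yhat, Some(&w), None);
        for (lw, lu) in l_weighted.iter().zip(&l_unweighted) {
            assert!((lw - 2.0 * lu).abs() < 1e-5);
        }
    }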
#[test]
fn test_listnet_loss_grouped() {
let y = vec![3.0, 1.0, 0.0, 2.0, 1.0];
let yhat = vec![2.0, 1.0, 0.5, 1.5, 0.5];
let group = vec![3u64, 2];
let loss_fn = ListNetLoss::default();
let l = loss_fn.loss(&y, &yhat, None, Some(&group));
assert_eq!(l.len(), 5);
assert_eq!(l[0], l[1]);
assert_eq!(l[0], l[2]);
assert_eq!(l[3], l[4]);
let (g, h) = loss_fn.gradient(&y, &yhat, None, Some(&group));
let h = h.unwrap();
assert_eq!(g.len(), 5);
assert_eq!(h.len(), 5);
}
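    #[test]
    fn test_listnet_hessian_bounds() {
        // Sanity check: p * (1 - p) for p in [0, 1] lies in [0, 0.25].
        let y = vec![3.0, 1.0, 0.0, 2.0, 1.0];
        let yhat = vec![2.0, 1.0, 0.5, 1.5, 0.5];
        let group = vec![3u64, 2];
        let loss_fn = ListNetLoss::default();
        let (_g, h) = loss_fn.gradient(&y, &yhat, None, Some(&group));
        for v in h.unwrap() {
            assert!((0.0..=0.25).contains(&v));
        }
    }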
#[test]
fn test_listnet_loss_grouped_weighted() {
let y = vec![3.0, 1.0, 0.0, 2.0, 1.0];
let yhat = vec![2.0, 1.0, 0.5, 1.5, 0.5];
let w = vec![2.0, 1.0, 1.0, 1.5, 0.5];
let group = vec![3u64, 2];
let loss_fn = ListNetLoss::default();
let l = loss_fn.loss(&y, &yhat, Some(&w), Some(&group));
assert_eq!(l.len(), 5);
let (g, h) = loss_fn.gradient(&y, &yhat, Some(&w), Some(&group));
assert_eq!(g.len(), 5);
assert!(h.is_some());
}
#[test]
fn test_listnet_gradient_and_loss_no_group() {
let y = vec![3.0, 1.0, 0.0];
let yhat = vec![2.0, 1.0, 0.5];
let loss_fn = ListNetLoss::default();
let (g, h, l) = loss_fn.gradient_and_loss(&y, &yhat, None, None);
assert_eq!(g.len(), 3);
assert!(h.is_some());
assert_eq!(l.len(), 3);
}
#[test]
fn test_listnet_gradient_and_loss_grouped() {
let y = vec![3.0, 1.0, 0.0, 2.0, 1.0];
let yhat = vec![2.0, 1.0, 0.5, 1.5, 0.5];
let group = vec![3u64, 2];
let loss_fn = ListNetLoss::default();
let (g, h, l) = loss_fn.gradient_and_loss(&y, &yhat, None, Some(&group));
assert_eq!(g.len(), 5);
assert!(h.is_some());
assert_eq!(l.len(), 5);
}
#[test]
fn test_listnet_gradient_and_loss_weighted_grouped() {
let y = vec![3.0, 1.0, 0.0, 2.0, 1.0];
let yhat = vec![2.0, 1.0, 0.5, 1.5, 0.5];
let w = vec![2.0, 1.0, 1.0, 1.5, 0.5];
let group = vec![3u64, 2];
let loss_fn = ListNetLoss::default();
let (g, h, l) = loss_fn.gradient_and_loss(&y, &yhat, Some(&w), Some(&group));
assert_eq!(g.len(), 5);
assert!(h.is_some());
assert_eq!(l.len(), 5);
}
#[test]
fn test_listnet_loss_single() {
let loss_fn = ListNetLoss::default();
let l = loss_fn.loss_single(1.0, 2.0, None);
assert_eq!(l, f32::INFINITY);
}
#[test]
fn test_listnet_initial_value() {
let loss_fn = ListNetLoss::default();
assert_eq!(loss_fn.initial_value(&[1.0, 2.0], None, None), 0.0);
}
#[test]
fn test_listnet_small_input() {
let loss_fn = ListNetLoss::default();
let l = loss_fn.loss(&[1.0], &[2.0], None, None);
assert_eq!(l.len(), 1);
assert_eq!(l[0], f32::INFINITY);
let (g, _h) = loss_fn.gradient(&[1.0], &[2.0], None, None);
assert_eq!(g[0], 0.0);
}
#[test]
fn test_listnet_gradient_and_loss_small() {
let loss_fn = ListNetLoss::default();
let (g, h, l) = loss_fn.gradient_and_loss(&[1.0], &[2.0], None, None);
assert_eq!(g.len(), 1);
assert!(h.is_none());
assert_eq!(l[0], f32::INFINITY);
}
#[test]
fn test_compute_listnet_loss_weighted() {
let softmax_y = vec![0.5, 0.3, 0.2];
let softmax_yhat = vec![0.6, 0.3, 0.1];
let w = vec![2.0, 1.0, 1.0];
let l_w = compute_listnet_loss(&softmax_y, &softmax_yhat, Some(&w));
let l_nw = compute_listnet_loss(&softmax_y, &softmax_yhat, None);
assert!(l_w > 0.0);
assert!(l_nw > 0.0);
}
#[test]
fn test_compute_group_gradients_weighted() {
let softmax_y = vec![0.5, 0.3, 0.2];
let softmax_yhat = vec![0.6, 0.3, 0.1];
let w = vec![2.0, 1.0, 1.0];
let mut output = vec![0.0f32; 3];
compute_group_gradients(&softmax_y, &softmax_yhat, Some(&w), &mut output);
assert_eq!(output.len(), 3);
}
#[test]
fn test_compute_group_hessian_weighted() {
let softmax_yhat = vec![0.6, 0.3, 0.1];
let w = vec![2.0, 1.0, 1.0];
let mut output = vec![0.0f32; 3];
compute_group_hessian(&softmax_yhat, Some(&w), &mut output);
assert_eq!(output.len(), 3);
assert!(output[0] > 0.0);
}
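    #[test]
    fn test_listnet_loss_minimized_at_target() {
        // Sanity check via Gibbs' inequality: cross-entropy H(p, q) is
        // minimized at q = p, so scoring with the labels themselves
        // should not lose to a reversed score vector.
        let y = vec![3.0, 1.0, 0.0];
        let loss_fn = ListNetLoss::default();
        let l_perfect = loss_fn.loss(&y, &y, None, None)[0];
        let l_reversed = loss_fn.loss(&y, &[0.0, 1.0, 3.0], None, None)[0];
        assert!(l_perfect <= l_reversed);
    }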
}