use crate::metrics::evaluation::EvaluationMetric;
use serde::{Deserialize, Serialize};
/// Scheme used to convert a relevance label into a DCG gain.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, Eq, PartialEq)]
pub enum GainScheme {
/// Linear gain: the relevance label is used as-is.
Jarvelin,
/// Exponential gain: `2^relevance - 1`, emphasizing highly relevant items.
Burges,
}
pub struct NDCGMetric {
k: Option<u64>,
gain: GainScheme,
}
impl NDCGMetric {
pub fn new(k: Option<u64>, gain: GainScheme) -> Self {
Self { k, gain }
}
}
impl EvaluationMetric for NDCGMetric {
fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64], group: &[u64], _alpha: Option<f32>) -> f64 {
let metric = NDCGMetric {
k: None,
gain: GainScheme::Burges,
};
metric.calculate_metric_with_params(y, yhat, sample_weight, group, None, &GainScheme::Burges)
}
fn maximize() -> bool {
true
}
}
impl NDCGMetric {
/// Computes the weighted mean NDCG using this instance's configured
/// cutoff `self.k` and gain scheme `self.gain`.
///
/// `group[i]` gives the number of consecutive rows of `y` / `yhat` /
/// `sample_weight` that belong to query group `i`.
///
/// NOTE(review): `_alpha` and `_default_gain` are accepted for signature
/// compatibility with other metrics but are ignored here — `self.k` and
/// `self.gain` always take precedence.
pub fn calculate_metric_with_params(
&self,
y: &[f64],
yhat: &[f64],
sample_weight: &[f64],
group: &[u64],
_alpha: Option<f32>,
_default_gain: &GainScheme,
) -> f64 {
ndcg_at_k_metric(y, yhat, sample_weight, group, self.k, &self.gain)
}
}
/// DCG position discount for a zero-based `rank`: `1 / log2(rank + 2)`.
/// The top-ranked item (rank 0) receives a discount of exactly 1.0.
#[inline]
fn compute_discount(rank: usize) -> f64 {
    ((rank + 2) as f64).log2().recip()
}
/// Converts a relevance label into a gain under the given scheme:
/// linear for Jarvelin, `2^relevance - 1` for Burges.
#[inline]
fn compute_gain(relevance: f64, scheme: &GainScheme) -> f64 {
    if matches!(scheme, GainScheme::Jarvelin) {
        relevance
    } else {
        2_f64.powf(relevance) - 1.0
    }
}
/// Weighted DCG over `relevance_scores`, truncated at rank `k`
/// (or full depth when `k` is `None`).
///
/// `weights[i]` scales the gain of the i-th ranked item. Iteration stops at
/// the shorter of the two slices, mirroring `zip` semantics.
fn compute_group_dcg(relevance_scores: &[f64], k: Option<u64>, weights: &[f64], scheme: &GainScheme) -> f64 {
    let cutoff = match k {
        Some(k) => (k as usize).min(relevance_scores.len()),
        None => relevance_scores.len(),
    };
    let mut total = 0.0;
    for (rank, (&rel, &w)) in relevance_scores.iter().zip(weights).take(cutoff).enumerate() {
        total += compute_gain(rel, scheme) * w * compute_discount(rank);
    }
    total
}
/// NDCG for a single query group: the DCG of the model's ordering divided by
/// the ideal DCG (items sorted by true relevance, descending).
///
/// Returns 0.0 for an empty group or when the ideal DCG is zero
/// (e.g. all relevance labels are zero).
fn compute_group_ndcg(
    y_group: &[f64],
    yhat_group: &[f64],
    weights_group: &[f64],
    k: Option<u64>,
    scheme: &GainScheme,
) -> f64 {
    if y_group.is_empty() {
        return 0.0;
    }
    // Work with a permutation of indices rather than a vector of tuples.
    // Truncate to the shortest slice, matching zip's truncation behavior.
    let n = y_group.len().min(yhat_group.len()).min(weights_group.len());
    let mut order: Vec<usize> = (0..n).collect();

    // Rank by model score, best first (stable sort: ties keep input order).
    order.sort_by(|&a, &b| yhat_group[b].total_cmp(&yhat_group[a]));
    let ranked_rel: Vec<f64> = order.iter().map(|&i| y_group[i]).collect();
    let ranked_w: Vec<f64> = order.iter().map(|&i| weights_group[i]).collect();
    let dcg = compute_group_dcg(&ranked_rel, k, &ranked_w, scheme);

    // Re-sort by true label for the ideal ordering. The sort is stable, so
    // label ties retain the score-sorted order — same as the original.
    order.sort_by(|&a, &b| y_group[b].total_cmp(&y_group[a]));
    let ideal_rel: Vec<f64> = order.iter().map(|&i| y_group[i]).collect();
    let ideal_w: Vec<f64> = order.iter().map(|&i| weights_group[i]).collect();
    let idcg = compute_group_dcg(&ideal_rel, k, &ideal_w, scheme);

    if idcg > 0.0 { dcg / idcg } else { 0.0 }
}
/// Weighted mean NDCG@k across query groups.
///
/// `group[i]` is the number of consecutive rows of `y` / `yhat` /
/// `sample_weight` that belong to group `i`. A group whose end would run past
/// `y.len()` stops the loop (malformed group vector). Each group's NDCG is
/// weighted by the sum of its sample weights. Returns 0.0 for empty input or
/// when the total weight is zero.
pub fn ndcg_at_k_metric(
    y: &[f64],
    yhat: &[f64],
    sample_weight: &[f64],
    group: &[u64],
    k: Option<u64>,
    scheme: &GainScheme,
) -> f64 {
    if y.is_empty() {
        return 0.0;
    }
    let mut offset = 0usize;
    let mut weighted_sum = 0.0;
    let mut weight_sum = 0.0;
    for &size in group {
        let end = offset + size as usize;
        if end > y.len() {
            // Group boundary exceeds the data: bail out instead of panicking.
            break;
        }
        let group_weights = &sample_weight[offset..end];
        let score = compute_group_ndcg(&y[offset..end], &yhat[offset..end], group_weights, k, scheme);
        let group_weight: f64 = group_weights.iter().sum();
        weighted_sum += score * group_weight;
        weight_sum += group_weight;
        offset = end;
    }
    if weight_sum > 0.0 {
        weighted_sum / weight_sum
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_compute_discount() {
        // Rank 0 gets full weight; rank 1 is discounted by 1/log2(3).
        assert!((compute_discount(0) - 1.0).abs() < EPS);
        assert!((compute_discount(1) - 1.0 / 3.0f64.log2()).abs() < EPS);
    }

    #[test]
    fn test_compute_gain() {
        // Jarvelin is linear; Burges is 2^rel - 1.
        assert_eq!(compute_gain(3.0, &GainScheme::Jarvelin), 3.0);
        assert_eq!(compute_gain(3.0, &GainScheme::Burges), 7.0);
    }

    #[test]
    fn test_compute_group_dcg() {
        let relevance = vec![3.0, 2.0, 1.0];
        let weights = vec![1.0, 1.0, 1.0];

        // Full-depth DCG with linear gains.
        let dcg = compute_group_dcg(&relevance, None, &weights, &GainScheme::Jarvelin);
        let expected = 3.0 / 2.0f64.log2() + 2.0 / 3.0f64.log2() + 1.0 / 4.0f64.log2();
        assert!((dcg - expected).abs() < EPS);

        // k = 1 keeps only the top item, whose discount is exactly 1.0.
        let dcg_k1 = compute_group_dcg(&relevance, Some(1), &weights, &GainScheme::Jarvelin);
        assert!((dcg_k1 - 3.0).abs() < EPS);
    }

    #[test]
    fn test_compute_group_ndcg() {
        let y = vec![1.0, 3.0, 2.0];
        let weights = vec![1.0, 1.0, 1.0];

        // Predictions ordered exactly like the labels => perfect NDCG of 1.
        let yhat = vec![0.1, 0.3, 0.2];
        let ndcg = compute_group_ndcg(&y, &yhat, &weights, None, &GainScheme::Jarvelin);
        assert!((ndcg - 1.0).abs() < EPS);

        // A mis-ordered prediction must score strictly below the ideal.
        let yhat_bad = vec![0.3, 0.1, 0.2];
        let ndcg_bad = compute_group_ndcg(&y, &yhat_bad, &weights, None, &GainScheme::Jarvelin);
        assert!(ndcg_bad < 1.0);
    }

    #[test]
    fn test_ndcg_at_k_metric() {
        let y = vec![3.0, 2.0, 3.0, 0.0];
        let yhat = vec![0.5, 0.4, 0.5, 0.1];
        let weights = vec![1.0, 1.0, 1.0, 1.0];
        let groups = vec![2, 2];
        let ndcg = ndcg_at_k_metric(&y, &yhat, &weights, &groups, None, &GainScheme::Jarvelin);
        assert!(ndcg > 0.0 && ndcg <= 1.0);

        // Empty input degenerates to 0 rather than panicking.
        assert_eq!(ndcg_at_k_metric(&[], &[], &[], &[], None, &GainScheme::Jarvelin), 0.0);
    }

    #[test]
    fn test_ndcg_metric_struct() {
        let y = vec![3.0, 2.0];
        let yhat = vec![0.5, 0.4];
        let weights = vec![1.0, 1.0];
        let groups = vec![2];
        let val = NDCGMetric::calculate_metric(&y, &yhat, &weights, &groups, None);
        assert!(val >= 0.0);
    }
}