use mlua::prelude::*;
use crate::stats::{mean_impl, table_to_vec};
/// Assigns 1-based fractional ("average") ranks to `values`.
///
/// The returned vector is parallel to the input: `result[i]` is the rank of `values[i]`.
/// Values are ordered with [`f64::total_cmp`]. Runs of adjacent sorted values that compare
/// equal with `==` share the mean of the ranks they span, so `[3.0, 1.0, 3.0]` becomes
/// `[2.5, 1.0, 2.5]`. Because the tie test is `==`, `-0.0` and `0.0` tie, while NaN entries
/// never share a rank.
fn rank_impl(values: &[f64]) -> Vec<f64> {
    let len = values.len();

    // Stable sort of positions by value: identical permutation to stably sorting
    // `(value, original_index)` pairs by value.
    let mut order: Vec<usize> = (0..len).collect();
    order.sort_by(|&a, &b| values[a].total_cmp(&values[b]));

    let mut ranks = vec![0.0; len];
    let mut start = 0;
    while start < len {
        let pivot = values[order[start]];
        // The tie group always contains `start`; scan from the next position for the
        // first value that is not `==` to the pivot (this keeps NaN groups at size 1).
        let end = order[start + 1..]
            .iter()
            .position(|&idx| values[idx] != pivot)
            .map_or(len, |offset| start + 1 + offset);

        // Sorted positions start..end hold 1-based ranks start+1..=end; share their mean.
        let shared = (start + 1 + end) as f64 / 2.0;
        for &idx in &order[start..end] {
            ranks[idx] = shared;
        }
        start = end;
    }
    ranks
}
fn spearman_impl(xs: &[f64], ys: &[f64]) -> Result<f64, &'static str> {
if xs.len() != ys.len() {
return Err("arrays must have equal length");
}
if xs.len() < 2 {
return Err("need at least 2 values");
}
let rx = rank_impl(xs);
let ry = rank_impl(ys);
let mean_rx = mean_impl(&rx);
let mean_ry = mean_impl(&ry);
let mut cov = 0.0;
let mut var_x = 0.0;
let mut var_y = 0.0;
for (&x, &y) in rx.iter().zip(ry.iter()) {
let dx = x - mean_rx;
let dy = y - mean_ry;
cov += dx * dy;
var_x += dx * dx;
var_y += dy * dy;
}
let denom = (var_x * var_y).sqrt();
if denom == 0.0 {
return Err("zero variance in ranks");
}
Ok(cov / denom)
}
/// Kendall's tau-b rank correlation between `xs` and `ys`.
///
/// Every index pair `(i, j)` with `i < j` is classified as tied in `xs`, tied in `ys`
/// (a pair tied in both counts toward both), concordant, or discordant. The result is
/// `(concordant - discordant) / sqrt((n0 - ties_x) * (n0 - ties_y))` with
/// `n0 = n * (n - 1) / 2`. This is the tie-corrected tau-b. Runs in O(n²).
///
/// Two values are tied when they compare equal with `==`, so `-0.0` and `0.0` tie. This is
/// the same rule `rank_impl` uses. All other pairs are ordered with [`f64::total_cmp`].
///
/// # Errors
/// - `"arrays must have equal length"` when the slices differ in length
/// - `"need at least 2 values"` when fewer than two pairs are supplied
/// - `"zero variance (all values tied)"` when every pair is tied in `xs` or every pair
///   is tied in `ys`
fn kendall_tau_impl(xs: &[f64], ys: &[f64]) -> Result<f64, &'static str> {
    use std::cmp::Ordering;

    /// Orders two samples, treating numerically equal values as ties.
    fn compare(a: f64, b: f64) -> Ordering {
        // `total_cmp` alone orders -0.0 before 0.0, which would turn a genuine tie into a
        // concordant/discordant pair and disagree with the `==` tie rule of `rank_impl`.
        if a == b {
            Ordering::Equal
        } else {
            a.total_cmp(&b)
        }
    }

    if xs.len() != ys.len() {
        return Err("arrays must have equal length");
    }
    let n = xs.len();
    if n < 2 {
        return Err("need at least 2 values");
    }

    let mut concordant: i64 = 0;
    let mut discordant: i64 = 0;
    let mut ties_x: i64 = 0;
    let mut ties_y: i64 = 0;
    for (i, (&xi, &yi)) in xs.iter().zip(ys).enumerate() {
        // Compare element i against every later element j > i.
        for (&xj, &yj) in xs[i + 1..].iter().zip(&ys[i + 1..]) {
            let ox = compare(xi, xj);
            let oy = compare(yi, yj);
            match (ox, oy) {
                (Ordering::Equal, Ordering::Equal) => {
                    ties_x += 1;
                    ties_y += 1;
                }
                (Ordering::Equal, _) => ties_x += 1,
                (_, Ordering::Equal) => ties_y += 1,
                _ if ox == oy => concordant += 1,
                _ => discordant += 1,
            }
        }
    }

    let n0 = n as f64 * (n - 1) as f64 / 2.0;
    let denom = ((n0 - ties_x as f64) * (n0 - ties_y as f64)).sqrt();
    if denom == 0.0 {
        return Err("zero variance (all values tied)");
    }
    Ok((concordant - discordant) as f64 / denom)
}
/// Normalised discounted cumulative gain of the first `k` items of `relevance`.
///
/// Uses linear gain: the item at 0-based position `p` contributes `rel / log2(p + 2)`.
/// The ideal DCG comes from the same relevances sorted in descending
/// [`f64::total_cmp`] order. `k` is clamped to the slice length.
/// Returns `0.0` when `k == 0`, when `relevance` is empty, or when the ideal DCG is zero.
fn ndcg_impl(relevance: &[f64], k: usize) -> f64 {
    if k == 0 || relevance.is_empty() {
        return 0.0;
    }
    let cutoff = k.min(relevance.len());

    let discounted = |gains: &[f64]| -> f64 {
        gains
            .iter()
            .enumerate()
            .map(|(pos, &gain)| gain / (2.0 + pos as f64).log2())
            .sum()
    };

    // Elements that compare Equal under total_cmp are bit-identical, so an unstable sort
    // yields exactly the same sequence as a stable one.
    let mut ideal = relevance.to_vec();
    ideal.sort_unstable_by(|a, b| b.total_cmp(a));

    let actual_dcg = discounted(&relevance[..cutoff]);
    let ideal_dcg = discounted(&ideal[..cutoff]);
    if ideal_dcg == 0.0 {
        0.0
    } else {
        actual_dcg / ideal_dcg
    }
}
/// Mean reciprocal rank over `rankings`, where each entry is the 1-based rank of the first
/// relevant result for one query.
///
/// # Errors
/// - `"need at least 1 ranking"` for an empty slice
/// - `"rank values must be >= 1"` for the first entry that is below 1
/// - `"rank values must be integers"` for the first entry whose `fract()` is non-zero
///   (this also rejects NaN and `+inf`, whose `fract()` is NaN)
///
/// Entries are validated in order, and both checks are applied to each entry before moving
/// on to the next.
fn mrr_impl(rankings: &[f64]) -> Result<f64, &'static str> {
    if rankings.is_empty() {
        return Err("need at least 1 ranking");
    }
    rankings.iter().try_for_each(|&rank| {
        if rank < 1.0 {
            Err("rank values must be >= 1")
        } else if rank.fract() != 0.0 {
            Err("rank values must be integers")
        } else {
            Ok(())
        }
    })?;
    let reciprocal_total: f64 = rankings.iter().map(|&rank| 1.0 / rank).sum();
    Ok(reciprocal_total / rankings.len() as f64)
}
/// Installs the ranking and correlation helpers into the Lua table `t`.
///
/// Table arguments are converted with `table_to_vec`. Functions added, in this order:
/// - `rank(values)`: returns a new sequence table of 1-based average ranks
/// - `spearman_correlation(xs, ys)`: returns a number
/// - `kendall_tau(xs, ys)`: returns a number
/// - `ndcg(relevance, k)`: returns a number
/// - `mrr(ranks)`: returns a number
///
/// Validation errors from the underlying `*_impl` functions are raised as Lua runtime
/// errors prefixed with the Lua-facing function name.
pub(crate) fn register(lua: &Lua, t: &LuaTable) -> LuaResult<()> {
    let rank = lua.create_function(|lua, values: LuaTable| {
        let ranked = rank_impl(&table_to_vec(&values)?);
        let result = lua.create_table()?;
        for (slot, value) in (1usize..).zip(ranked) {
            result.raw_set(slot, value)?;
        }
        Ok(result)
    })?;
    t.set("rank", rank)?;

    let spearman = lua.create_function(|_, (lhs, rhs): (LuaTable, LuaTable)| {
        let (xs, ys) = (table_to_vec(&lhs)?, table_to_vec(&rhs)?);
        spearman_impl(&xs, &ys)
            .map_err(|e| LuaError::runtime(format!("spearman_correlation: {e}")))
    })?;
    t.set("spearman_correlation", spearman)?;

    let kendall = lua.create_function(|_, (lhs, rhs): (LuaTable, LuaTable)| {
        let (xs, ys) = (table_to_vec(&lhs)?, table_to_vec(&rhs)?);
        kendall_tau_impl(&xs, &ys).map_err(|e| LuaError::runtime(format!("kendall_tau: {e}")))
    })?;
    t.set("kendall_tau", kendall)?;

    let ndcg = lua.create_function(|_, (relevance, k): (LuaTable, usize)| {
        Ok(ndcg_impl(&table_to_vec(&relevance)?, k))
    })?;
    t.set("ndcg", ndcg)?;

    let mrr = lua.create_function(|_, ranks: LuaTable| {
        mrr_impl(&table_to_vec(&ranks)?).map_err(|e| LuaError::runtime(format!("mrr: {e}")))
    })?;
    t.set("mrr", mrr)?;

    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for the ranking and correlation helpers. They cover the happy paths, every
    //! `Err` branch, tie handling, and the documented clamping and zero cases.
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn rank_no_ties() {
        let ranks = rank_impl(&[3.0, 1.0, 2.0]);
        assert_eq!(ranks, vec![3.0, 1.0, 2.0]);
    }

    #[test]
    fn rank_with_ties() {
        let ranks = rank_impl(&[3.0, 1.0, 3.0]);
        assert_eq!(ranks, vec![2.5, 1.0, 2.5]);
    }

    #[test]
    fn rank_all_equal() {
        let ranks = rank_impl(&[5.0, 5.0, 5.0]);
        assert_eq!(ranks, vec![2.0, 2.0, 2.0]);
    }

    #[test]
    fn rank_empty() {
        assert!(rank_impl(&[]).is_empty());
    }

    #[test]
    fn rank_signed_zeros_tie() {
        // `-0.0 == 0.0`, so they share the mean of ranks 1 and 2.
        assert_eq!(rank_impl(&[0.0, -0.0, 1.0]), vec![1.5, 1.5, 3.0]);
    }

    #[test]
    fn spearman_perfect() {
        let r = spearman_impl(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0]).unwrap();
        assert!(approx(r, 1.0));
    }

    #[test]
    fn spearman_inverse() {
        let r = spearman_impl(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5.0, 4.0, 3.0, 2.0, 1.0]).unwrap();
        assert!(approx(r, -1.0));
    }

    #[test]
    fn spearman_with_ties() {
        // Ranks: x = [1, 2.5, 2.5, 4], y = [1, 2, 3, 4] -> 4.5 / sqrt(4.5 * 5).
        let r = spearman_impl(&[1.0, 2.0, 2.0, 3.0], &[1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(approx(r, 4.5 / 22.5f64.sqrt()));
    }

    #[test]
    fn spearman_errors() {
        assert_eq!(
            spearman_impl(&[1.0, 2.0], &[1.0]),
            Err("arrays must have equal length")
        );
        assert_eq!(spearman_impl(&[1.0], &[1.0]), Err("need at least 2 values"));
        assert_eq!(
            spearman_impl(&[1.0, 1.0, 1.0], &[1.0, 2.0, 3.0]),
            Err("zero variance in ranks")
        );
    }

    #[test]
    fn kendall_tau_concordant() {
        let tau = kendall_tau_impl(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]).unwrap();
        assert!(approx(tau, 1.0));
    }

    #[test]
    fn kendall_tau_discordant() {
        let tau = kendall_tau_impl(&[1.0, 2.0, 3.0], &[3.0, 2.0, 1.0]).unwrap();
        assert!(approx(tau, -1.0));
    }

    #[test]
    fn kendall_tau_with_ties() {
        // One pair tied in x, two concordant pairs: tau-b = 2 / sqrt((3 - 1) * 3).
        let tau = kendall_tau_impl(&[1.0, 1.0, 2.0], &[1.0, 2.0, 3.0]).unwrap();
        assert!(approx(tau, 2.0 / 6f64.sqrt()));
    }

    #[test]
    fn kendall_tau_errors() {
        assert_eq!(
            kendall_tau_impl(&[1.0], &[1.0, 2.0]),
            Err("arrays must have equal length")
        );
        assert_eq!(kendall_tau_impl(&[1.0], &[1.0]), Err("need at least 2 values"));
        assert_eq!(
            kendall_tau_impl(&[1.0, 1.0], &[1.0, 2.0]),
            Err("zero variance (all values tied)")
        );
    }

    #[test]
    fn ndcg_perfect() {
        let score = ndcg_impl(&[3.0, 2.0, 1.0], 3);
        assert!(approx(score, 1.0));
    }

    #[test]
    fn ndcg_imperfect() {
        let score = ndcg_impl(&[1.0, 2.0, 3.0], 3);
        assert!(score < 1.0);
        assert!(score > 0.0);
    }

    #[test]
    fn ndcg_clamps_k_to_length() {
        assert!(approx(ndcg_impl(&[3.0, 2.0, 1.0], 10), 1.0));
    }

    #[test]
    fn ndcg_known_value() {
        // DCG = 1 / log2(3); ideal DCG = 1.
        assert!(approx(ndcg_impl(&[0.0, 1.0], 2), 1.0 / 3f64.log2()));
    }

    #[test]
    fn ndcg_degenerate_inputs() {
        assert_eq!(ndcg_impl(&[1.0, 2.0], 0), 0.0);
        assert_eq!(ndcg_impl(&[], 3), 0.0);
        assert_eq!(ndcg_impl(&[0.0, 0.0], 2), 0.0);
    }

    #[test]
    fn mrr_basic() {
        let r = mrr_impl(&[1.0, 2.0, 5.0]).unwrap();
        let expected = (1.0 + 0.5 + 0.2) / 3.0;
        assert!(approx(r, expected));
    }

    #[test]
    fn mrr_errors() {
        assert_eq!(mrr_impl(&[]), Err("need at least 1 ranking"));
        assert_eq!(mrr_impl(&[0.5]), Err("rank values must be >= 1"));
        assert_eq!(mrr_impl(&[1.5]), Err("rank values must be integers"));
        // Entries are validated in order; the first offending entry decides the error.
        assert_eq!(mrr_impl(&[1.5, 0.0]), Err("rank values must be integers"));
    }
}