use crate::{Error, Result};
/// Builds a soft permutation matrix for `scores` from a Laplace kernel.
///
/// The result is a row-major `n × n` matrix (`n = scores.len()`) flattened into a
/// `Vec`: entry `(i, j)` lives at index `i * n + j`. Row `i` corresponds to the
/// `i`-th smallest score (ascending, ordered by [`f64::total_cmp`]) and column `j`
/// to `scores[j]`. Each entry starts as `exp(-|s_(i) - scores[j]| / temperature)`,
/// and every row is then rescaled to sum to one. Columns are *not* rescaled, so the
/// matrix is row-stochastic but in general not doubly stochastic.
///
/// Small temperatures concentrate each row on the element(s) equal to the `i`-th
/// smallest score, approaching a hard permutation. Large temperatures spread the
/// mass more evenly.
///
/// A score is always at distance zero from an equal score, including `±∞` paired
/// with itself, so each row keeps a kernel weight of `1` on its own element. Two
/// cases still produce NaN:
/// * An infinite score combined with `temperature = +∞` produces NaN entries.
/// * If any score is NaN, every row contains NaN and no row is normalized.
///
/// Filter such inputs out beforehand.
///
/// # Errors
///
/// * [`Error::EmptyInput`] if `scores` is empty.
/// * [`Error::InvalidTemperature`] if `temperature` is NaN or not strictly positive.
pub fn lapsum_permutation(scores: &[f64], temperature: f64) -> Result<Vec<f64>> {
    if scores.is_empty() {
        return Err(Error::EmptyInput);
    }
    // `temperature <= 0.0` alone is false for NaN, which would otherwise slip
    // through and turn every entry of the matrix into NaN.
    if temperature.is_nan() || temperature <= 0.0 {
        return Err(Error::InvalidTemperature(temperature));
    }
    let n = scores.len();
    // Only the sorted values are needed: every row is compared against all
    // original scores, so no index bookkeeping is required.
    let mut sorted = scores.to_vec();
    sorted.sort_by(f64::total_cmp);

    let mut perm = vec![0.0_f64; n * n];
    for (row, &sorted_val) in perm.chunks_exact_mut(n).zip(&sorted) {
        let mut row_sum = 0.0_f64;
        for (cell, &score) in row.iter_mut().zip(scores) {
            // Test equality first so an infinite score paired with itself has
            // distance 0 rather than `∞ - ∞ = NaN`.
            let dist = if sorted_val == score {
                0.0
            } else {
                (sorted_val - score).abs()
            };
            // Divide instead of multiplying by `1 / temperature`: for subnormal
            // temperatures the reciprocal overflows to `∞`, and `0 · ∞` is NaN.
            let kernel = (-dist / temperature).exp();
            *cell = kernel;
            row_sum += kernel;
        }
        // The self-match contributes 1, so `row_sum >= 1` whenever no kernel is
        // NaN; this guard only skips rows whose sum is NaN.
        if row_sum > 0.0 {
            let inv_sum = 1.0 / row_sum;
            for cell in row.iter_mut() {
                *cell *= inv_sum;
            }
        }
    }
    Ok(perm)
}
pub fn lapsum_sort(scores: &[f64], values: &[f64], temperature: f64) -> Result<Vec<f64>> {
if scores.len() != values.len() {
return Err(Error::LengthMismatch(scores.len(), values.len()));
}
let perm = lapsum_permutation(scores, temperature)?;
let n = scores.len();
let mut result = vec![0.0; n];
for (i, res) in result.iter_mut().enumerate().take(n) {
let row_start = i * n;
for j in 0..n {
*res += perm[row_start + j] * values[j];
}
}
Ok(result)
}
/// Computes soft, 1-based ranks of `scores`.
///
/// The rank of `scores[j]` is `Σ_i P[i, j] · (i + 1)`, where `P` is the row-major
/// matrix returned by [`lapsum_permutation`]. This is the sorted positions
/// weighted by column `j`. As `temperature` shrinks, the ranks approach the hard
/// ascending ranks. When every row of `P` sums to one, the ranks sum to
/// `n (n + 1) / 2` up to rounding.
///
/// # Errors
///
/// Any error returned by [`lapsum_permutation`].
pub fn lapsum_rank(scores: &[f64], temperature: f64) -> Result<Vec<f64>> {
    let weights = lapsum_permutation(scores, temperature)?;
    let len = scores.len();
    let mut soft_ranks = vec![0.0; len];
    // Walk rows in sorted order, crediting each column with that row's position.
    for (position, row) in (1..=len).zip(weights.chunks_exact(len)) {
        let position = position as f64;
        for (rank, &w) in soft_ranks.iter_mut().zip(row) {
            *rank += w * position;
        }
    }
    Ok(soft_ranks)
}
/// Computes soft top-`k` membership weights for `scores`.
///
/// Weight `j` is the sum of column `j` over the last `k` rows of the row-major
/// matrix from [`lapsum_permutation`], i.e. the rows for the `k` largest sorted
/// positions. When every row of that matrix sums to one, the weights sum to `k`.
/// As `temperature` shrinks they approach a 0/1 indicator of the top `k` elements.
///
/// # Errors
///
/// * [`Error::EmptyInput`] if `k == 0` or `k > scores.len()`. This also covers
///   empty `scores`.
///   NOTE(review): an out-of-range `k` reuses `EmptyInput`. A dedicated variant
///   would be clearer, but none is visible in this file.
/// * Any error returned by [`lapsum_permutation`].
pub fn lapsum_topk(scores: &[f64], k: usize, temperature: f64) -> Result<Vec<f64>> {
    let len = scores.len();
    if !(1..=len).contains(&k) {
        return Err(Error::EmptyInput);
    }
    let weights = lapsum_permutation(scores, temperature)?;
    let mut membership = vec![0.0; len];
    for row in weights.chunks_exact(len).skip(len - k) {
        for (slot, &w) in membership.iter_mut().zip(row) {
            *slot += w;
        }
    }
    Ok(membership)
}
#[cfg(test)]
mod tests {
    //! Property tests for the soft sort / rank / top-k relaxations.
    //!
    //! Note: the soft permutation is normalized per row only, so properties that
    //! would require column normalization (e.g. `Σ sorted == Σ values`) do not
    //! hold in general and are intentionally not asserted.
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn permutation_rows_sum_to_one() {
        let scores = [3.0, 1.0, 4.0, 1.5, 9.0];
        let perm = lapsum_permutation(&scores, 0.5).unwrap();
        let n = scores.len();
        for i in 0..n {
            let row_sum: f64 = (0..n).map(|j| perm[i * n + j]).sum();
            assert_relative_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn permutation_entries_nonnegative() {
        let scores = [3.0, 1.0, 4.0, 1.5, 9.0];
        let perm = lapsum_permutation(&scores, 0.5).unwrap();
        for &val in &perm {
            assert!(val >= 0.0);
        }
    }

    #[test]
    fn low_sigma_approaches_hard_sort() {
        let scores = [3.0, 1.0, 4.0, 2.0];
        let sorted = lapsum_sort(&scores, &scores, 0.001).unwrap();
        assert_relative_eq!(sorted[0], 1.0, epsilon = 0.1);
        assert_relative_eq!(sorted[1], 2.0, epsilon = 0.1);
        assert_relative_eq!(sorted[2], 3.0, epsilon = 0.1);
        assert_relative_eq!(sorted[3], 4.0, epsilon = 0.1);
    }

    #[test]
    fn low_sigma_ranks_match_hard_ranks() {
        let scores = [0.5, 0.2, 0.8, 0.1];
        let ranks = lapsum_rank(&scores, 0.001).unwrap();
        assert_relative_eq!(ranks[0], 3.0, epsilon = 0.1);
        assert_relative_eq!(ranks[1], 2.0, epsilon = 0.1);
        assert_relative_eq!(ranks[2], 4.0, epsilon = 0.1);
        assert_relative_eq!(ranks[3], 1.0, epsilon = 0.1);
    }

    #[test]
    fn rank_sum_is_triangular_number() {
        // Rows sum to one, so Σ_j rank_j = Σ_i (i + 1) = n (n + 1) / 2.
        let scores = [0.5, 0.2, 0.8, 0.1, 0.9];
        let ranks = lapsum_rank(&scores, 0.5).unwrap();
        let n = scores.len() as f64;
        let total: f64 = ranks.iter().sum();
        assert_relative_eq!(total, n * (n + 1.0) / 2.0, epsilon = 1e-9);
    }

    #[test]
    fn topk_weights_sum_to_k() {
        // Each of the k selected rows sums to one, so the total is k up to rounding.
        let scores = [0.5, 0.2, 0.8, 0.1, 0.9];
        let k = 2;
        let weights = lapsum_topk(&scores, k, 0.5).unwrap();
        let total: f64 = weights.iter().sum();
        assert_relative_eq!(total, k as f64, epsilon = 1e-10);
    }

    #[test]
    fn topk_low_sigma_selects_top_elements() {
        let scores = [0.5, 0.2, 0.8, 0.1, 0.9];
        let weights = lapsum_topk(&scores, 2, 0.001).unwrap();
        assert!(weights[2] > 0.9, "weight[2] = {}", weights[2]);
        assert!(weights[4] > 0.9, "weight[4] = {}", weights[4]);
        assert!(weights[0] < 0.1, "weight[0] = {}", weights[0]);
        assert!(weights[1] < 0.1, "weight[1] = {}", weights[1]);
        assert!(weights[3] < 0.1, "weight[3] = {}", weights[3]);
    }

    #[test]
    fn symmetry_swapping_scores_swaps_columns() {
        let scores_a = [1.0, 3.0, 2.0];
        let scores_b = [3.0, 1.0, 2.0];
        let sigma = 0.5;
        let perm_a = lapsum_permutation(&scores_a, sigma).unwrap();
        let perm_b = lapsum_permutation(&scores_b, sigma).unwrap();
        let n = 3;
        for i in 0..n {
            assert_relative_eq!(perm_a[i * n], perm_b[i * n + 1], epsilon = 1e-10);
            assert_relative_eq!(perm_a[i * n + 1], perm_b[i * n], epsilon = 1e-10);
            assert_relative_eq!(perm_a[i * n + 2], perm_b[i * n + 2], epsilon = 1e-10);
        }
    }

    #[test]
    fn sort_of_constant_values_is_constant() {
        // Every row is a probability vector, so averaging a constant yields it back.
        let scores = [3.0, 1.0, 4.0, 2.0];
        let values = [7.0; 4];
        let sorted = lapsum_sort(&scores, &values, 0.5).unwrap();
        for &v in &sorted {
            assert_relative_eq!(v, 7.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn sort_outputs_are_convex_combinations_of_values() {
        // Rows sum to one but columns do not, so the total of the output is not
        // preserved in general. Each output is still a weighted average of
        // `values` and must stay within their range.
        let scores = [3.0, 1.0, 4.0, 2.0];
        let values = [10.0, 20.0, 30.0, 40.0];
        let sorted = lapsum_sort(&scores, &values, 0.5).unwrap();
        for &v in &sorted {
            assert!((10.0 - 1e-9..=40.0 + 1e-9).contains(&v), "out of range: {}", v);
        }
    }

    #[test]
    fn errors_on_empty_input() {
        assert!(lapsum_permutation(&[], 1.0).is_err());
        assert!(lapsum_sort(&[], &[], 1.0).is_err());
        assert!(lapsum_rank(&[], 1.0).is_err());
        assert!(lapsum_topk(&[], 1, 1.0).is_err());
    }

    #[test]
    fn errors_on_invalid_temperature() {
        let s = [1.0, 2.0];
        assert!(lapsum_permutation(&s, 0.0).is_err());
        assert!(lapsum_permutation(&s, -1.0).is_err());
    }

    #[test]
    fn errors_on_length_mismatch() {
        assert!(lapsum_sort(&[1.0, 2.0], &[1.0], 0.5).is_err());
    }

    #[test]
    fn errors_on_invalid_k() {
        let s = [1.0, 2.0, 3.0];
        assert!(lapsum_topk(&s, 0, 0.5).is_err());
        assert!(lapsum_topk(&s, 4, 0.5).is_err());
    }

    #[test]
    fn single_element() {
        let perm = lapsum_permutation(&[5.0], 1.0).unwrap();
        assert_relative_eq!(perm[0], 1.0, epsilon = 1e-10);
        let sorted = lapsum_sort(&[5.0], &[5.0], 1.0).unwrap();
        assert_relative_eq!(sorted[0], 5.0, epsilon = 1e-10);
        let ranks = lapsum_rank(&[5.0], 1.0).unwrap();
        assert_relative_eq!(ranks[0], 1.0, epsilon = 1e-10);
    }
}