use crate::sigmoid::sigmoid;
use crate::soft_rank;
use crate::sorting_network::{relaxed_sigmoid, DiffSortNet, NetworkType, RelaxDist};
use crate::{Error, Result};
/// Soft (differentiable) top-k selection over `values`.
///
/// Every element receives an indicator `sigmoid((k + 0.5 - rank) / temperature)`, where `rank`
/// is produced by [`soft_rank`]. The first returned vector holds `values[i] * indicator[i]`,
/// the second holds the indicators themselves.
///
/// Special cases:
/// - empty input or `k == 0`: both vectors are empty;
/// - `k >= values.len()`: the values are returned unchanged with every indicator set to `1.0`;
/// - `soft_rank` returns an error: both vectors are all zeros of length `values.len()`.
///
/// NOTE(review): the `k + 0.5` cutoff selects the *largest* values only if `soft_rank` assigns
/// the smallest rank (about 1) to the largest element — confirm against `soft_rank`.
pub fn differentiable_topk(values: &[f64], k: usize, temperature: f64) -> (Vec<f64>, Vec<f64>) {
    let len = values.len();
    if len == 0 || k == 0 {
        return (Vec::new(), Vec::new());
    }
    if k >= len {
        // Everything is selected: skip ranking entirely.
        return (values.to_vec(), vec![1.0; len]);
    }

    let ranks = match soft_rank(values, temperature) {
        Ok(computed) => computed,
        // The signature cannot carry an error, so a ranking failure degrades to "select nothing".
        Err(_) => return (vec![0.0; len], vec![0.0; len]),
    };

    // Half-way between rank k and rank k + 1, so the sigmoid crosses 0.5 between them.
    let cutoff = k as f64 + 0.5;
    values
        .iter()
        .enumerate()
        .map(|(idx, &value)| {
            let gate = sigmoid((cutoff - ranks[idx]) / temperature);
            (value * gate, gate)
        })
        .unzip()
}
/// Soft (differentiable) bottom-k selection over `values`.
///
/// Every element receives an indicator `sigmoid((rank - (n - k + 0.5)) / temperature)`, where
/// `rank` is produced by [`soft_rank`] and `n = values.len()`. The first returned vector holds
/// `values[i] * indicator[i]`, the second holds the indicators themselves.
///
/// Special cases:
/// - empty input or `k == 0`: both vectors are empty;
/// - `k >= values.len()`: the values are returned unchanged with every indicator set to `1.0`;
/// - `soft_rank` returns an error: both vectors are all zeros of length `values.len()`.
///
/// NOTE(review): the cutoff selects the *smallest* values only if `soft_rank` assigns the
/// largest rank (about `n`) to the smallest element — confirm against `soft_rank`.
pub fn differentiable_bottomk(values: &[f64], k: usize, temperature: f64) -> (Vec<f64>, Vec<f64>) {
    let len = values.len();
    if len == 0 || k == 0 {
        return (Vec::new(), Vec::new());
    }
    if k >= len {
        // Everything is selected: skip ranking entirely.
        return (values.to_vec(), vec![1.0; len]);
    }

    let ranks = match soft_rank(values, temperature) {
        Ok(computed) => computed,
        // The signature cannot carry an error, so a ranking failure degrades to "select nothing".
        Err(_) => return (vec![0.0; len], vec![0.0; len]),
    };

    // Half-way between rank n - k and rank n - k + 1; ranks above it count as "bottom k".
    let cutoff = (len - k) as f64 + 0.5;
    values
        .iter()
        .enumerate()
        .map(|(idx, &value)| {
            let gate = sigmoid((ranks[idx] - cutoff) / temperature);
            (value * gate, gate)
        })
        .unzip()
}
/// Builds the relaxed top-k attribution matrix of `scores` with a differentiable sorting network.
///
/// The scores are run through the comparators of a [`DiffSortNet`]; each comparator performs a
/// soft swap weighted by `relaxed_sigmoid((v_a - v_b) * steepness, dist)`. The recorded swap
/// weights are then replayed in reverse, starting from one-hot vectors on the last `k` sorted
/// positions, which yields for every input the weight with which it ends up in each of those
/// positions.
///
/// The result has `scores.len()` rows and `min(k, scores.len())` columns. Column `j` corresponds
/// to sorted position `size - k + j` of the network, so the last column is the final sorted
/// position (the tests in this file expect the largest score there).
///
/// When the network is wider than the input, the extra slots are padded with a value far
/// *below* every real score. The padding therefore settles in the lowest sorted positions and
/// never occupies one of the `k` positions that are read out. (This relies on the network
/// sorting in ascending order, as the tests in this file expect.)
///
/// # Errors
///
/// - [`Error::EmptyInput`] if `scores` is empty.
/// - [`Error::InvalidTemperature`] if `steepness` is NaN, zero or negative.
pub fn sparse_topk_matrix(
    scores: &[f64],
    k: usize,
    steepness: f64,
    network_type: NetworkType,
    dist: RelaxDist,
) -> Result<Vec<Vec<f64>>> {
    let n = scores.len();
    if n == 0 {
        return Err(Error::EmptyInput);
    }
    // A NaN steepness would silently poison every swap weight, so reject it as well.
    if steepness.is_nan() || steepness <= 0.0 {
        return Err(Error::InvalidTemperature(steepness));
    }
    let k = k.min(n);
    if k == 0 {
        return Ok(vec![vec![]; n]);
    }

    let net = DiffSortNet::new(network_type, n, steepness, dist);
    let padded_n = net.size;

    // Pad below the minimum so padding sinks to the bottom of the sorted order and the top-k
    // positions are always held by real scores.
    let pad_val = scores.iter().copied().fold(f64::INFINITY, f64::min) - 1e6;
    let mut vals: Vec<f64> = Vec::with_capacity(padded_n);
    vals.extend_from_slice(scores);
    vals.extend(std::iter::repeat(pad_val).take(padded_n.saturating_sub(n)));

    // Forward pass: soft-sort the values and remember every comparator's swap weight.
    let comparators = net.comparator_pairs();
    let mut alphas = Vec::with_capacity(comparators.len());
    for &(a, b) in comparators {
        let va = vals[a];
        let vb = vals[b];
        let alpha = relaxed_sigmoid((va - vb) * steepness, dist);
        alphas.push(alpha);
        vals[a] = (1.0 - alpha) * va + alpha * vb;
        vals[b] = alpha * va + (1.0 - alpha) * vb;
    }

    // Backward pass: start from one-hot columns on the last k sorted positions and push them
    // back through the comparators in reverse order.
    let mut x = vec![vec![0.0; k]; padded_n];
    for j in 0..k {
        x[padded_n - k + j][j] = 1.0;
    }
    for (idx, &(a, b)) in comparators.iter().enumerate().rev() {
        let alpha = alphas[idx];
        // Rows `a` and `b` of the same Vec are updated together, so an index loop is the
        // simplest form that satisfies the borrow checker.
        #[allow(clippy::needless_range_loop)]
        for j in 0..k {
            let xa = x[a][j];
            let xb = x[b][j];
            x[a][j] = (1.0 - alpha) * xa + alpha * xb;
            x[b][j] = alpha * xa + (1.0 - alpha) * xb;
        }
    }

    // Drop the rows that belong to padding slots.
    x.truncate(n);
    Ok(x)
}
/// Convenience form of [`sparse_topk_matrix`] fixed to a bitonic network with logistic
/// relaxation.
///
/// Arguments, result and errors are exactly those of [`sparse_topk_matrix`].
pub fn sparse_topk(scores: &[f64], k: usize, steepness: f64) -> Result<Vec<Vec<f64>>> {
    let network = NetworkType::Bitonic;
    let relaxation = RelaxDist::Logistic;
    sparse_topk_matrix(scores, k, steepness, network, relaxation)
}
/// Top-k cross-entropy loss for a single sample.
///
/// `p_k[j]` is the weight given to "the target is among the top `j + 1` entries"; its length
/// is the `k` passed to [`sparse_topk_matrix`]. The probability assigned to `target` is
///
/// - `p_k[0] * softmax(logits)[target]` when `p_k[0] > 0.0`, plus
/// - for every remaining `j` (starting at 0 when the softmax term is not used),
///   `p_k[j]` times the sum of the last `j + 1` columns of the target's row in the
///   attribution matrix.
///
/// The loss is `-ln(p)` with `p` clamped to `[1e-7, 1 - 1e-7]`.
///
/// Only the target's entry of the distribution is needed for the loss, so only that entry
/// is computed.
///
/// # Errors
///
/// - [`Error::EmptyInput`] if `logits` is empty, if `p_k` is empty, or if `p_k` is longer than
///   `logits`.
/// - [`Error::LengthMismatch`] if `target` is not a valid index into `logits`.
/// - Any error returned by [`sparse_topk_matrix`] (for example an invalid `steepness`).
pub fn topk_cross_entropy_loss(
    logits: &[f64],
    target: usize,
    p_k: &[f64],
    steepness: f64,
    network_type: NetworkType,
    dist: RelaxDist,
) -> Result<f64> {
    let n = logits.len();
    if n == 0 {
        return Err(Error::EmptyInput);
    }
    if target >= n {
        return Err(Error::LengthMismatch(target, n));
    }
    let k = p_k.len();
    if k == 0 || k > n {
        return Err(Error::EmptyInput);
    }

    let attr = sparse_topk_matrix(logits, k, steepness, network_type, dist)?;
    let target_row = &attr[target];

    let mut target_prob = 0.0;

    // The top-1 component uses a plain softmax instead of the sorting-network attribution.
    let use_softmax_top1 = p_k[0] > 0.0;
    if use_softmax_top1 {
        // Subtract the maximum before exponentiating for numerical stability.
        let max_logit = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let sum_exp: f64 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
        target_prob += p_k[0] * (logits[target] - max_logit).exp() / sum_exp;
    }

    let start_j = usize::from(use_softmax_top1);
    for (j, &weight) in p_k.iter().enumerate().skip(start_j) {
        // The last j + 1 columns are the top-(j + 1) sorted positions.
        let first_col = k - (j + 1);
        let cumulative: f64 = target_row[first_col..].iter().sum();
        target_prob += weight * cumulative;
    }

    // Keep the logarithm finite when the probability saturates at 0 or 1.
    let eps = 1e-7;
    Ok(-target_prob.clamp(eps, 1.0 - eps).ln())
}
/// Convenience form of [`topk_cross_entropy_loss`] fixed to a bitonic network with logistic
/// relaxation.
///
/// Arguments, result and errors are exactly those of [`topk_cross_entropy_loss`].
pub fn topk_ce_loss(logits: &[f64], target: usize, p_k: &[f64], steepness: f64) -> Result<f64> {
    let network = NetworkType::Bitonic;
    let relaxation = RelaxDist::Logistic;
    topk_cross_entropy_loss(logits, target, p_k, steepness, network, relaxation)
}
/// Gumbel-noise helpers; compiled only when the `gumbel` feature is enabled.
#[cfg(feature = "gumbel")]
pub mod gumbel {
    use rand::Rng;

    /// Draws one noise sample `-ln(-ln(u))`, with `u` taken from `rng.random_range(0.0..1.0)`.
    ///
    /// `u` is clamped to `[1e-10, 1 - 1e-10]` so neither logarithm sees 0.
    pub fn gumbel_noise<R: Rng + ?Sized>(rng: &mut R) -> f64 {
        let uniform: f64 = rng.random_range(0.0..1.0);
        let bounded = uniform.clamp(1e-10, 1.0 - 1e-10);
        let neg_log = -bounded.ln();
        -neg_log.ln()
    }

    /// Returns a copy of `logits` with an independent [`gumbel_noise`] sample added to each
    /// entry, drawing the samples in input order.
    pub fn add_gumbel_noise<R: Rng + ?Sized>(logits: &[f64], rng: &mut R) -> Vec<f64> {
        let mut perturbed = Vec::with_capacity(logits.len());
        for &logit in logits {
            perturbed.push(logit + gumbel_noise(rng));
        }
        perturbed
    }

    /// Gumbel-softmax: perturbs `logits` with Gumbel noise, then applies a softmax with the
    /// given `temperature`.
    ///
    /// The maximum perturbed logit is subtracted before exponentiating. `temperature` is not
    /// validated here.
    pub fn gumbel_softmax<R: Rng + ?Sized>(
        logits: &[f64],
        temperature: f64,
        rng: &mut R,
    ) -> Vec<f64> {
        let perturbed = add_gumbel_noise(logits, rng);
        let peak = perturbed.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        let mut probs: Vec<f64> = perturbed
            .iter()
            .map(|&logit| ((logit - peak) / temperature).exp())
            .collect();
        let total: f64 = probs.iter().sum();
        for p in &mut probs {
            *p /= total;
        }
        probs
    }

    /// Forwards to `kuji::relaxed_topk_gumbel` with the arguments unchanged.
    pub fn relaxed_topk_gumbel<R: Rng + ?Sized>(
        scores: &[f64],
        k: usize,
        temperature: f64,
        scale: f64,
        rng: &mut R,
    ) -> Vec<f64> {
        kuji::relaxed_topk_gumbel(scores, k, temperature, scale, rng)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sorting_network::{NetworkType, RelaxDist};

    /// Scores shared by most of the sparse top-k tests.
    const SCORES: [f64; 4] = [3.0, 1.0, 4.0, 2.0];

    /// Values shared by the soft top-k / bottom-k indicator tests.
    const VALUES: [f64; 5] = [0.1, 0.9, 0.5, 0.8, 0.2];

    /// Sum of the binary entropies of `indicators`; entries at exactly 0 or 1 contribute nothing.
    fn total_binary_entropy(indicators: &[f64]) -> f64 {
        indicators
            .iter()
            .map(|&p| {
                if p > 0.0 && p < 1.0 {
                    -p * p.ln() - (1.0 - p) * (1.0 - p).ln()
                } else {
                    0.0
                }
            })
            .sum()
    }

    /// Runs `sparse_topk_matrix` with a bitonic network and logistic relaxation, unwrapping.
    fn bitonic_logistic(scores: &[f64], k: usize, steepness: f64) -> Vec<Vec<f64>> {
        sparse_topk_matrix(scores, k, steepness, NetworkType::Bitonic, RelaxDist::Logistic)
            .unwrap()
    }

    // ---- differentiable_topk / differentiable_bottomk ----

    /// The two largest values get indicators above 0.5, the rest below.
    #[test]
    fn test_topk_basic() {
        let (weighted, indicators) = differentiable_topk(&VALUES, 2, 0.1);
        assert_eq!(weighted.len(), 5);
        assert_eq!(indicators.len(), 5);
        assert!(indicators[1] > 0.5, "0.9 should be in top-2: {}", indicators[1]);
        assert!(indicators[3] > 0.5, "0.8 should be in top-2: {}", indicators[3]);
        assert!(indicators[0] < 0.5, "0.1 should not be in top-2: {}", indicators[0]);
        assert!(indicators[2] < 0.5, "0.5 should not be in top-2: {}", indicators[2]);
        assert!(indicators[4] < 0.5, "0.2 should not be in top-2: {}", indicators[4]);
    }

    /// The two smallest values get indicators above 0.5, the rest below.
    #[test]
    fn test_bottomk_basic() {
        let (_, indicators) = differentiable_bottomk(&VALUES, 2, 0.1);
        assert!(indicators[0] > 0.5, "0.1 should be in bottom-2: {}", indicators[0]);
        assert!(indicators[4] > 0.5, "0.2 should be in bottom-2: {}", indicators[4]);
        assert!(indicators[1] < 0.5);
        assert!(indicators[2] < 0.5);
        assert!(indicators[3] < 0.5);
    }

    /// Empty input yields two empty vectors.
    #[test]
    fn test_topk_empty() {
        let (weighted, indicators) = differentiable_topk(&[], 2, 0.1);
        assert!(weighted.is_empty());
        assert!(indicators.is_empty());
    }

    /// `k == 0` yields two empty vectors even for non-empty input.
    #[test]
    fn test_topk_k_zero() {
        let (weighted, indicators) = differentiable_topk(&[1.0, 2.0, 3.0], 0, 0.1);
        assert!(weighted.is_empty());
        assert!(indicators.is_empty());
    }

    /// `k >= n` returns the values untouched with all indicators equal to one.
    #[test]
    fn test_topk_k_geq_n() {
        let values = [1.0, 2.0, 3.0];
        let (weighted, indicators) = differentiable_topk(&values, 5, 0.1);
        assert_eq!(weighted, values);
        for &indicator in &indicators {
            assert_eq!(indicator, 1.0);
        }
    }

    /// A lower temperature produces sharper (lower-entropy) indicators.
    #[test]
    fn test_temperature_effect() {
        let values = [0.1, 0.9, 0.5];
        let (_, sharp) = differentiable_topk(&values, 1, 0.01);
        let (_, smooth) = differentiable_topk(&values, 1, 1.0);
        let sharp_entropy = total_binary_entropy(&sharp);
        let smooth_entropy = total_binary_entropy(&smooth);
        assert!(
            sharp_entropy < smooth_entropy,
            "sharp should have lower entropy: {} vs {}",
            sharp_entropy,
            smooth_entropy
        );
    }

    // ---- sparse_topk_matrix / sparse_topk ----

    /// Bitonic network: the result is n x k.
    #[test]
    fn sparse_topk_correct_shape() {
        let matrix = bitonic_logistic(&SCORES, 2, 10.0);
        assert_eq!(matrix.len(), 4, "should have n=4 rows");
        for row in &matrix {
            assert_eq!(row.len(), 2, "should have k=2 columns");
        }
    }

    /// Odd-even network: the result is n x k.
    #[test]
    fn sparse_topk_correct_shape_odd_even() {
        let scores = [5.0, 2.0, 7.0, 1.0, 3.0];
        let matrix =
            sparse_topk_matrix(&scores, 3, 10.0, NetworkType::OddEven, RelaxDist::Logistic)
                .unwrap();
        assert_eq!(matrix.len(), 5);
        for row in &matrix {
            assert_eq!(row.len(), 3);
        }
    }

    /// No input carries more than (approximately) unit mass across the top-k columns.
    #[test]
    fn sparse_topk_row_sums_le_one() {
        let matrix = bitonic_logistic(&SCORES, 2, 20.0);
        for (i, row) in matrix.iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(
                sum <= 1.0 + 0.05,
                "row {} sum should be <= 1: got {}",
                i,
                sum
            );
        }
    }

    /// Every top-k column carries approximately unit mass.
    #[test]
    fn sparse_topk_column_sums_approx_one() {
        let k = 2;
        let matrix = bitonic_logistic(&SCORES, k, 20.0);
        let n = SCORES.len();
        // The column index itself is reported in the failure message, so iterate by index.
        #[allow(clippy::needless_range_loop)]
        for j in 0..k {
            let col_sum: f64 = matrix.iter().take(n).map(|row| row[j]).sum();
            assert!(
                (col_sum - 1.0).abs() < 0.1,
                "column {} sum should be ~1.0, got {}",
                j,
                col_sum
            );
        }
    }

    /// With a very steep relaxation the matrix approaches the hard top-k assignment.
    #[test]
    fn sparse_topk_high_steepness_approaches_hard() {
        let k = 2;
        let matrix = bitonic_logistic(&SCORES, k, 100.0);
        assert!(
            matrix[2][k - 1] > 0.8,
            "score 4.0 should be rank 1 (col {}): got {}",
            k - 1,
            matrix[2][k - 1]
        );
        assert!(
            matrix[0][0] > 0.8,
            "score 3.0 should be rank 2 (col 0): got {}",
            matrix[0][0]
        );
        let sum_1: f64 = matrix[1].iter().sum();
        assert!(
            sum_1 < 0.2,
            "score 1.0 should not be in top-2: row sum = {}",
            sum_1
        );
        let sum_3: f64 = matrix[3].iter().sum();
        assert!(
            sum_3 < 0.2,
            "score 2.0 should not be in top-2: row sum = {}",
            sum_3
        );
    }

    /// With `k == n` every row carries approximately unit mass.
    #[test]
    fn sparse_topk_k_equals_n_recovers_full_perm() {
        let n = SCORES.len();
        let matrix = bitonic_logistic(&SCORES, n, 20.0);
        assert_eq!(matrix.len(), n);
        for row in &matrix {
            assert_eq!(row.len(), n);
        }
        for (i, row) in matrix.iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 0.05,
                "row {} sum = {}, expected ~1.0",
                i,
                sum
            );
        }
    }

    /// The Cauchy relaxation keeps the shape and the relative ordering of 4.0 vs 1.0.
    #[test]
    fn sparse_topk_with_cauchy() {
        let matrix =
            sparse_topk_matrix(&SCORES, 2, 10.0, NetworkType::Bitonic, RelaxDist::Cauchy)
                .unwrap();
        assert_eq!(matrix.len(), 4);
        assert_eq!(matrix[0].len(), 2);
        assert!(matrix[2][0] > matrix[1][0], "4.0 should rank higher than 1.0");
    }

    /// The Gaussian relaxation keeps the shape.
    #[test]
    fn sparse_topk_with_gaussian() {
        let matrix =
            sparse_topk_matrix(&SCORES, 2, 10.0, NetworkType::Bitonic, RelaxDist::Gaussian)
                .unwrap();
        assert_eq!(matrix.len(), 4);
        assert_eq!(matrix[0].len(), 2);
    }

    /// Rows for padding slots are not exposed when n is not a power of two.
    #[test]
    fn sparse_topk_non_power_of_two() {
        let matrix = bitonic_logistic(&[5.0, 2.0, 7.0], 2, 10.0);
        assert_eq!(matrix.len(), 3, "should have n=3 rows despite padding");
        for row in &matrix {
            assert_eq!(row.len(), 2);
        }
    }

    /// `sparse_topk` produces an n x k matrix as well.
    #[test]
    fn sparse_topk_convenience_wrapper() {
        let matrix = sparse_topk(&SCORES, 2, 10.0).unwrap();
        assert_eq!(matrix.len(), 4);
        assert_eq!(matrix[0].len(), 2);
    }

    /// Empty scores are rejected.
    #[test]
    fn sparse_topk_empty_input() {
        assert!(sparse_topk(&[], 2, 10.0).is_err());
    }

    /// Zero and negative steepness are rejected.
    #[test]
    fn sparse_topk_invalid_steepness() {
        let scores = [1.0, 2.0];
        assert!(sparse_topk(&scores, 1, 0.0).is_err());
        assert!(sparse_topk(&scores, 1, -1.0).is_err());
    }

    // ---- topk_cross_entropy_loss / topk_ce_loss ----

    /// The loss is never negative.
    #[test]
    fn topk_ce_loss_nonnegative() {
        let logits = [2.0, 0.5, 1.0, 0.1];
        let p_k = [0.5, 0.0, 0.0, 0.5];
        let loss = topk_ce_loss(&logits, 0, &p_k, 10.0).unwrap();
        assert!(loss >= 0.0, "loss should be non-negative: {}", loss);
    }

    /// A target with a clearly dominant logit gets a small loss.
    #[test]
    fn topk_ce_loss_correct_class_low_loss() {
        let logits = [5.0, 0.1, 0.2, 0.3];
        let p_k = [0.5, 0.0, 0.0, 0.5];
        let loss = topk_ce_loss(&logits, 0, &p_k, 10.0).unwrap();
        assert!(
            loss < 1.0,
            "correct class at rank 1 should have low loss: {}",
            loss
        );
    }

    /// Under pure top-1 weighting, a low-logit target costs more than the top-logit target.
    #[test]
    fn topk_ce_loss_wrong_class_high_loss() {
        let logits = [0.1, 5.0, 4.0, 3.0];
        let p_k = [1.0];
        let loss_wrong = topk_ce_loss(&logits, 0, &p_k, 10.0).unwrap();
        let loss_right = topk_ce_loss(&logits, 1, &p_k, 10.0).unwrap();
        assert!(
            loss_wrong > loss_right,
            "wrong class should have higher loss: {} vs {}",
            loss_wrong,
            loss_right
        );
    }

    /// Uniform weights over all k produce a finite loss.
    #[test]
    fn topk_ce_loss_uniform_pk() {
        let logits = [2.0, 0.5, 1.0, 0.1];
        let p_k = [0.25; 4];
        let loss = topk_ce_loss(&logits, 0, &p_k, 10.0).unwrap();
        assert!(loss.is_finite());
    }

    /// A target with the second-largest logit is cheaper under top-2 than under top-1 weighting.
    #[test]
    fn topk_ce_loss_decreases_when_in_topk() {
        let logits = [2.0, 5.0, 1.0, 0.1];
        let target = 0;
        let p_k_topk = [0.0, 1.0];
        let p_k_top1 = [1.0, 0.0];
        let loss_topk = topk_ce_loss(&logits, target, &p_k_topk, 10.0).unwrap();
        let loss_top1 = topk_ce_loss(&logits, target, &p_k_top1, 10.0).unwrap();
        assert!(
            loss_topk < loss_top1,
            "top-k loss should be lower than top-1 loss when target is rank 2: {} vs {}",
            loss_topk,
            loss_top1
        );
    }

    /// An out-of-range target index is rejected.
    #[test]
    fn topk_ce_loss_invalid_target() {
        let logits = [1.0, 2.0, 3.0];
        let p_k = [1.0];
        assert!(topk_ce_loss(&logits, 5, &p_k, 10.0).is_err());
    }
}