use crate::sigmoid::sigmoid;
use crate::soft_rank;
/// Soft, differentiable selection of the `k` highest-ranked elements of `values`.
///
/// Ranks are obtained from [`soft_rank`] with the given `temperature`. Each element then
/// gets a membership weight `sigmoid((k + 0.5 - rank) / temperature)`. Elements whose
/// rank is below the half-integer cutoff `k + 0.5` get weights above 0.5; the others get
/// weights below 0.5. Lower temperatures make the transition sharper.
///
/// NOTE(review): the module tests expect the largest value to receive rank ≈ 1, i.e.
/// `soft_rank` appears to rank in descending order — confirm against `soft_rank`.
///
/// # Returns
///
/// `(weighted_values, indicators)`, both of length `values.len()` in the normal case,
/// with `weighted_values[i] == values[i] * indicators[i]`.
///
/// Special cases:
/// - `values` is empty or `k == 0`: both vectors are empty.
/// - `k >= values.len()`: every element is kept. `values` is returned unchanged together
///   with all-ones indicators.
/// - `soft_rank` returns an error: both vectors are all zeros of length `values.len()`,
///   and the error itself is discarded.
pub fn differentiable_topk(values: &[f64], k: usize, temperature: f64) -> (Vec<f64>, Vec<f64>) {
    let n = values.len();
    if values.is_empty() || k == 0 {
        return (Vec::new(), Vec::new());
    }
    if k >= n {
        // Everything is in the top-k: pass values through with full membership.
        return (values.to_vec(), vec![1.0; n]);
    }

    match soft_rank(values, temperature) {
        Ok(ranks) => {
            // The cutoff sits midway between rank k (kept) and rank k + 1 (dropped).
            let cutoff = k as f64 + 0.5;
            values
                .iter()
                .enumerate()
                .map(|(i, &value)| {
                    let membership = sigmoid((cutoff - ranks[i]) / temperature);
                    (value * membership, membership)
                })
                .unzip()
        }
        // Best-effort fallback: select nothing rather than propagate the error.
        Err(_) => (vec![0.0; n], vec![0.0; n]),
    }
}
/// Soft, differentiable selection of the `k` lowest-ranked elements of `values`.
///
/// Ranks are obtained from [`soft_rank`] with the given `temperature`. Each element then
/// gets a membership weight `sigmoid((rank - (n - k + 0.5)) / temperature)`, where
/// `n = values.len()`. Elements whose rank is above the half-integer cutoff `n - k + 0.5`
/// get weights above 0.5; the others get weights below 0.5. Lower temperatures make the
/// transition sharper.
///
/// NOTE(review): the module tests expect the smallest value to receive rank ≈ n, i.e.
/// `soft_rank` appears to rank in descending order — confirm against `soft_rank`.
///
/// # Returns
///
/// `(weighted_values, indicators)`, both of length `values.len()` in the normal case,
/// with `weighted_values[i] == values[i] * indicators[i]`.
///
/// Special cases:
/// - `values` is empty or `k == 0`: both vectors are empty.
/// - `k >= values.len()`: every element is kept. `values` is returned unchanged together
///   with all-ones indicators.
/// - `soft_rank` returns an error: both vectors are all zeros of length `values.len()`,
///   and the error itself is discarded.
pub fn differentiable_bottomk(values: &[f64], k: usize, temperature: f64) -> (Vec<f64>, Vec<f64>) {
    let n = values.len();
    if values.is_empty() || k == 0 {
        return (Vec::new(), Vec::new());
    }
    if k >= n {
        // Everything is in the bottom-k: pass values through with full membership.
        return (values.to_vec(), vec![1.0; n]);
    }

    match soft_rank(values, temperature) {
        Ok(ranks) => {
            // `k < n` here, so `n - k` cannot underflow. The cutoff sits midway between
            // rank n - k (dropped) and rank n - k + 1 (kept).
            let cutoff = (n - k) as f64 + 0.5;
            values
                .iter()
                .enumerate()
                .map(|(i, &value)| {
                    let membership = sigmoid((ranks[i] - cutoff) / temperature);
                    (value * membership, membership)
                })
                .unzip()
        }
        // Best-effort fallback: select nothing rather than propagate the error.
        Err(_) => (vec![0.0; n], vec![0.0; n]),
    }
}
/// Gumbel-noise sampling helpers and the Gumbel-softmax relaxation.
///
/// Only compiled when the `gumbel` feature is enabled.
#[cfg(feature = "gumbel")]
pub mod gumbel {
    use rand::Rng;

    /// Draws one sample of standard Gumbel noise, `-ln(-ln(u))`, from a uniform `u`.
    ///
    /// `u` is drawn from `[0, 1)` and then clamped to `[1e-10, 1 - 1e-10]`. Neither
    /// logarithm ever sees 0 or 1, so the result stays finite (roughly within
    /// `[-3.14, 23.03]`).
    pub fn gumbel_noise<R: Rng + ?Sized>(rng: &mut R) -> f64 {
        let sample: f64 = rng.gen_range(0.0..1.0);
        let u = sample.clamp(1e-10, 1.0 - 1e-10);
        let neg_log_u = -u.ln();
        -neg_log_u.ln()
    }

    /// Returns a copy of `logits` with independent Gumbel noise added to each entry.
    ///
    /// Draws one noise sample from `rng` per logit, in input order.
    pub fn add_gumbel_noise<R: Rng + ?Sized>(logits: &[f64], rng: &mut R) -> Vec<f64> {
        let mut perturbed = Vec::with_capacity(logits.len());
        for &logit in logits {
            perturbed.push(logit + gumbel_noise(rng));
        }
        perturbed
    }

    /// Gumbel-softmax: perturbs `logits` with Gumbel noise, then applies a
    /// temperature-scaled softmax.
    ///
    /// Returns a vector of the same length as `logits` whose entries sum to 1 (up to
    /// rounding). An empty `logits` yields an empty vector.
    ///
    /// NOTE(review): `temperature` is not validated. A value of 0 makes the
    /// maximal entry compute `0.0 / 0.0`, which turns every output into NaN. Negative
    /// values invert the preference and may overflow `exp`. Callers presumably pass
    /// `temperature > 0`.
    pub fn gumbel_softmax<R: Rng + ?Sized>(
        logits: &[f64],
        temperature: f64,
        rng: &mut R,
    ) -> Vec<f64> {
        // Reuse the noisy buffer for the exponentials and the final probabilities.
        let mut probs = add_gumbel_noise(logits, rng);

        // Shift by the maximum so the largest exponent is exp(0) = 1 (overflow guard).
        let peak = probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        for p in probs.iter_mut() {
            *p = ((*p - peak) / temperature).exp();
        }

        let total: f64 = probs.iter().sum();
        for p in probs.iter_mut() {
            *p /= total;
        }
        probs
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sums the binary entropy `-p ln p - (1 - p) ln(1 - p)` (in nats) over `probs`.
    ///
    /// Entries outside the open interval `(0, 1)` contribute 0. This covers exact 0/1
    /// and NaN, and avoids evaluating `0 * ln 0`.
    fn total_binary_entropy(probs: &[f64]) -> f64 {
        probs
            .iter()
            .map(|&p| {
                if p > 0.0 && p < 1.0 {
                    -p * p.ln() - (1.0 - p) * (1.0 - p).ln()
                } else {
                    0.0
                }
            })
            .sum()
    }

    #[test]
    fn test_topk_basic() {
        let values = [0.1, 0.9, 0.5, 0.8, 0.2];
        let (weighted, indicators) = differentiable_topk(&values, 2, 0.1);
        assert_eq!(weighted.len(), 5);
        assert_eq!(indicators.len(), 5);
        assert!(
            indicators[1] > 0.5,
            "0.9 should be in top-2: {}",
            indicators[1]
        );
        assert!(
            indicators[3] > 0.5,
            "0.8 should be in top-2: {}",
            indicators[3]
        );
        assert!(
            indicators[0] < 0.5,
            "0.1 should not be in top-2: {}",
            indicators[0]
        );
        assert!(
            indicators[2] < 0.5,
            "0.5 should not be in top-2: {}",
            indicators[2]
        );
        assert!(
            indicators[4] < 0.5,
            "0.2 should not be in top-2: {}",
            indicators[4]
        );
    }

    #[test]
    fn test_bottomk_basic() {
        let values = [0.1, 0.9, 0.5, 0.8, 0.2];
        let (_, indicators) = differentiable_bottomk(&values, 2, 0.1);
        assert!(
            indicators[0] > 0.5,
            "0.1 should be in bottom-2: {}",
            indicators[0]
        );
        assert!(
            indicators[4] > 0.5,
            "0.2 should be in bottom-2: {}",
            indicators[4]
        );
        assert!(indicators[1] < 0.5);
        assert!(indicators[2] < 0.5);
        assert!(indicators[3] < 0.5);
    }

    #[test]
    fn test_topk_empty() {
        let (w, i) = differentiable_topk(&[], 2, 0.1);
        assert!(w.is_empty());
        assert!(i.is_empty());
    }

    #[test]
    fn test_topk_k_zero() {
        let values = [1.0, 2.0, 3.0];
        let (w, i) = differentiable_topk(&values, 0, 0.1);
        assert!(w.is_empty());
        assert!(i.is_empty());
    }

    #[test]
    fn test_topk_k_geq_n() {
        let values = [1.0, 2.0, 3.0];
        let (w, indicators) = differentiable_topk(&values, 5, 0.1);
        assert_eq!(w, values);
        for &i in &indicators {
            assert_eq!(i, 1.0);
        }
    }

    #[test]
    fn test_bottomk_empty() {
        let (w, i) = differentiable_bottomk(&[], 2, 0.1);
        assert!(w.is_empty());
        assert!(i.is_empty());
    }

    #[test]
    fn test_bottomk_k_zero() {
        let values = [1.0, 2.0, 3.0];
        let (w, i) = differentiable_bottomk(&values, 0, 0.1);
        assert!(w.is_empty());
        assert!(i.is_empty());
    }

    #[test]
    fn test_bottomk_k_geq_n() {
        let values = [1.0, 2.0, 3.0];
        let (w, indicators) = differentiable_bottomk(&values, 3, 0.1);
        assert_eq!(w, values);
        assert!(indicators.iter().all(|&i| i == 1.0));
    }

    #[test]
    fn test_weighted_values_are_value_times_indicator() {
        let values = [0.1, 0.9, 0.5, 0.8, 0.2];
        for (weighted, indicators) in [
            differentiable_topk(&values, 2, 0.1),
            differentiable_bottomk(&values, 2, 0.1),
        ] {
            assert_eq!(weighted.len(), values.len());
            for ((&w, &ind), &v) in weighted.iter().zip(&indicators).zip(&values) {
                assert_eq!(w, v * ind);
            }
        }
    }

    #[test]
    fn test_topk_and_bottomk_are_complementary() {
        // top-k and bottom-(n-k) share the same cutoff (k + 0.5) with mirrored
        // sigmoid arguments. Assuming `sigmoid` is the logistic function, their
        // indicators must sum to 1 for every element.
        let values = [0.1, 0.9, 0.5, 0.8, 0.2];
        let k = 2;
        let (_, top) = differentiable_topk(&values, k, 0.1);
        let (_, bottom) = differentiable_bottomk(&values, values.len() - k, 0.1);
        for (t, b) in top.iter().zip(&bottom) {
            assert!((t + b - 1.0).abs() < 1e-9, "{} + {} != 1", t, b);
        }
    }

    #[test]
    fn test_temperature_effect() {
        let values = [0.1, 0.9, 0.5];
        let (_, indicators_sharp) = differentiable_topk(&values, 1, 0.01);
        let (_, indicators_smooth) = differentiable_topk(&values, 1, 1.0);
        let sharp_entropy = total_binary_entropy(&indicators_sharp);
        let smooth_entropy = total_binary_entropy(&indicators_smooth);
        assert!(
            sharp_entropy < smooth_entropy,
            "sharp should have lower entropy: {} vs {}",
            sharp_entropy,
            smooth_entropy
        );
    }
}