use rand::Rng;
/// Draws one sample of Gumbel noise, computed as `-ln(-ln(u))`.
///
/// `u` is taken from `rng.random_range(0.0..1.0)` and then clamped to
/// `[1e-10, 1 - 1e-10]`, so neither logarithm is ever evaluated at zero.
pub fn gumbel_noise(rng: &mut impl Rng) -> f64 {
    const EPS: f64 = 1e-10;
    let uniform: f64 = rng.random_range(0.0..1.0);
    let bounded = uniform.clamp(EPS, 1.0 - EPS);
    // `bounded` is strictly inside (0, 1), so `-ln(bounded)` is positive and finite.
    let neg_log = -bounded.ln();
    -neg_log.ln()
}
/// Normalizes `logits` into weights that sum to one.
///
/// The largest logit is subtracted before exponentiating so large inputs do
/// not overflow `exp`. Returns an empty vector for empty input. When the sum
/// of the exponentials is not greater than `1e-10` (which includes a NaN
/// sum), every entry is set to `1 / n` instead.
fn softmax(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let count = logits.len();
    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut weights: Vec<f64> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total = weights.iter().sum::<f64>();

    // Evaluated as "greater than" on purpose: a NaN total compares false and
    // therefore takes the uniform fallback as well.
    let normalizable = total > 1e-10;
    if !normalizable {
        return vec![1.0 / count as f64; count];
    }

    for weight in &mut weights {
        *weight /= total;
    }
    weights
}
/// Applies softmax to Gumbel-perturbed, temperature-scaled logits.
///
/// Each logit is turned into `(g + scale * logit) / temperature`, where `g`
/// is a fresh [`gumbel_noise`] draw, and the results are passed to `softmax`.
///
/// * `logits` - input values; the output has the same length.
/// * `temperature` - divisor applied to every perturbed logit. It is not
///   validated here.
/// * `scale` - multiplier applied to each logit before the noise is added.
/// * `rng` - random source; one noise value is drawn per logit, in order.
///
/// Returns an empty vector for empty input and `[1.0]` for a single logit;
/// neither of those cases draws from `rng`.
pub fn gumbel_softmax(
    logits: &[f64],
    temperature: f64,
    scale: f64,
    rng: &mut impl Rng,
) -> Vec<f64> {
    match logits.len() {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let perturbed: Vec<f64> = logits
                .iter()
                .map(|&logit| {
                    let noise = gumbel_noise(&mut *rng);
                    (noise + scale * logit) / temperature
                })
                .collect();
            softmax(&perturbed)
        }
    }
}
/// Builds a soft selection mask from `k` independent [`gumbel_softmax`] draws.
///
/// The mask starts at zero and each entry keeps the largest value it received
/// across the `k` samples (element-wise maximum).
///
/// * `scores` - values passed to [`gumbel_softmax`] as logits.
/// * `k` - number of samples to draw.
/// * `temperature`, `scale` - forwarded unchanged to [`gumbel_softmax`].
/// * `rng` - random source shared by all draws.
///
/// Special cases, none of which touch `rng`:
/// * empty `scores` or `k == 0` returns an empty vector (note: for `k == 0`
///   this is *not* a zero mask of length `scores.len()`);
/// * `k >= scores.len()` returns a mask of all `1.0`.
pub fn relaxed_topk_gumbel(
    scores: &[f64],
    k: usize,
    temperature: f64,
    scale: f64,
    rng: &mut impl Rng,
) -> Vec<f64> {
    let count = scores.len();
    if count == 0 || k == 0 {
        return Vec::new();
    }
    if k >= count {
        return vec![1.0; count];
    }

    let mut best = vec![0.0_f64; count];
    for _ in 0..k {
        let sample = gumbel_softmax(scores, temperature, scale, rng);
        for (slot, &weight) in best.iter_mut().zip(sample.iter()) {
            if weight > *slot {
                *slot = weight;
            }
        }
    }
    best
}
/// Produces a mask from reranker scores.
///
/// This is a direct pass-through to [`relaxed_topk_gumbel`]: every argument
/// is forwarded unchanged and its result is returned as-is.
pub fn gumbel_attention_mask(
    reranker_scores: &[f64],
    k: usize,
    temperature: f64,
    scale: f64,
    rng: &mut impl Rng,
) -> Vec<f64> {
    relaxed_topk_gumbel(
        reranker_scores,
        k,
        temperature,
        scale,
        rng,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Fixed-seed generator so every test run sees the same noise sequence.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Output has one entry per logit, sums to one, and stays within [0, 1].
    #[test]
    fn test_gumbel_softmax() {
        let mut rng = seeded_rng();
        let logits = [0.5, 1.0, 0.3];

        let probs = gumbel_softmax(&logits, 0.5, 1.0, &mut rng);

        assert_eq!(probs.len(), logits.len());
        let total = probs.iter().sum::<f64>();
        assert!((total - 1.0).abs() < 1e-6);
        for p in &probs {
            assert!((0.0..=1.0).contains(p));
        }
    }

    /// Mask has one entry per score and every entry stays within [0, 1].
    #[test]
    fn test_relaxed_topk_gumbel() {
        let mut rng = seeded_rng();
        let scores = [0.8, 0.6, 0.9, 0.3, 0.7];

        let mask = relaxed_topk_gumbel(&scores, 3, 0.5, 1.0, &mut rng);

        assert_eq!(mask.len(), scores.len());
        for m in &mask {
            assert!((0.0..=1.0).contains(m));
        }
    }

    /// Empty input, a single logit, and `k` larger than the input length.
    #[test]
    fn test_edge_cases() {
        let mut rng = seeded_rng();

        let no_logits: [f64; 0] = [];
        assert!(gumbel_softmax(&no_logits, 0.5, 1.0, &mut rng).is_empty());

        let one_logit = [1.0];
        assert_eq!(gumbel_softmax(&one_logit, 0.5, 1.0, &mut rng), vec![1.0]);

        let two_scores = [0.5, 0.3];
        let full_mask = relaxed_topk_gumbel(&two_scores, 5, 0.5, 1.0, &mut rng);
        assert_eq!(full_mask, vec![1.0, 1.0]);
    }
}