use rand::Rng;
/// Draws a Gumbel-softmax (relaxed categorical) sample over `logits`.
///
/// This is a thin wrapper that forwards every argument unchanged to
/// `kuji::gumbel_softmax`.
///
/// * `logits` – unnormalised scores, one per category.
/// * `temperature` – relaxation temperature. Presumably, lower values give a sharper
///   (closer to one-hot) sample; NOTE(review): confirm against the `kuji` docs.
/// * `scale` – forwarded to `kuji`. NOTE(review): likely scales the Gumbel noise; confirm.
/// * `rng` – source of randomness. The RNG state is advanced, so identical seeds give
///   identical samples.
///
/// The result has the same length as `logits`. This module's tests expect it to be a
/// probability vector: every entry lies in `[0, 1]` and the entries sum to 1.
#[must_use]
pub fn gumbel_softmax(
    logits: &[f64],
    temperature: f64,
    scale: f64,
    rng: &mut impl Rng,
) -> Vec<f64> {
    kuji::gumbel_softmax(logits, temperature, scale, rng)
}
/// Produces a differentiable ("relaxed") top-`k` selection mask over `scores` using
/// Gumbel noise.
///
/// This is a thin wrapper that forwards every argument unchanged to
/// `kuji::relaxed_topk_gumbel`.
///
/// * `scores` – one score per candidate.
/// * `k` – number of candidates to (softly) select.
/// * `temperature` – relaxation temperature. The tests use a small value (0.1) to get a
///   mask close to a hard top-`k`; NOTE(review): confirm exact semantics in `kuji`.
/// * `scale` – forwarded to `kuji`. NOTE(review): likely scales the Gumbel noise; confirm.
/// * `rng` – source of randomness. Its state is advanced.
///
/// This module's tests pin the following behavior:
/// * the mask has one entry per score, and every entry is non-negative (up to rounding);
/// * at low temperature the entries sum to approximately `k`;
/// * an empty `scores` slice yields an empty mask;
/// * when `k` exceeds `scores.len()`, every candidate gets weight `1.0`.
#[must_use]
pub fn relaxed_topk_gumbel(
    scores: &[f64],
    k: usize,
    temperature: f64,
    scale: f64,
    rng: &mut impl Rng,
) -> Vec<f64> {
    kuji::relaxed_topk_gumbel(scores, k, temperature, scale, rng)
}
/// Builds a soft top-`k` mask from reranker scores.
///
/// This is exactly [`relaxed_topk_gumbel`] applied to `reranker_scores` with the same `k`,
/// `temperature`, `scale` and `rng`. Given an identically seeded RNG, the two functions
/// return the same mask. The separate name documents intent at call sites.
/// Presumably, the result is used to weight attention over the `k` best reranked
/// candidates, one weight per candidate; NOTE(review): confirm against callers.
///
/// See [`relaxed_topk_gumbel`] for the meaning of each parameter and for the properties
/// of the returned mask.
#[must_use]
pub fn gumbel_attention_mask(
    reranker_scores: &[f64],
    k: usize,
    temperature: f64,
    scale: f64,
    rng: &mut impl Rng,
) -> Vec<f64> {
    relaxed_topk_gumbel(reranker_scores, k, temperature, scale, rng)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn test_gumbel_softmax() {
        let mut rng = StdRng::seed_from_u64(42);
        let logits = [0.5, 1.0, 0.3];
        let probs = gumbel_softmax(&logits, 0.5, 1.0, &mut rng);
        assert_eq!(probs.len(), logits.len());
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(probs.iter().all(|&p| (0.0..=1.0).contains(&p)));
    }

    #[test]
    fn test_gumbel_softmax_is_reproducible() {
        // All randomness must come from the supplied RNG, so equal seeds give equal samples.
        let logits = [0.5, 1.0, 0.3];
        let mut rng_a = StdRng::seed_from_u64(123);
        let mut rng_b = StdRng::seed_from_u64(123);
        assert_eq!(
            gumbel_softmax(&logits, 0.5, 1.0, &mut rng_a),
            gumbel_softmax(&logits, 0.5, 1.0, &mut rng_b)
        );
    }

    #[test]
    fn test_relaxed_topk_sums_to_k() {
        let mut rng = StdRng::seed_from_u64(42);
        let scores = [0.8, 0.6, 0.9, 0.3, 0.7];
        let k = 3;
        let mask = relaxed_topk_gumbel(&scores, k, 0.1, 1.0, &mut rng);
        assert_eq!(mask.len(), scores.len());
        assert!(mask.iter().all(|&m| m >= -1e-10));
        let sum: f64 = mask.iter().sum();
        assert!(
            (sum - k as f64).abs() < 1.5,
            "relaxed top-k mask should sum to ~{k}, got {sum}"
        );
    }

    #[test]
    fn test_attention_mask_matches_relaxed_topk() {
        // gumbel_attention_mask is a named alias of relaxed_topk_gumbel. With identically
        // seeded RNGs, both entry points must produce the same mask.
        let scores = [0.8, 0.6, 0.9, 0.3, 0.7];
        let mut rng_a = StdRng::seed_from_u64(7);
        let mut rng_b = StdRng::seed_from_u64(7);
        let mask = gumbel_attention_mask(&scores, 2, 0.5, 1.0, &mut rng_a);
        let expected = relaxed_topk_gumbel(&scores, 2, 0.5, 1.0, &mut rng_b);
        assert_eq!(mask, expected);
    }

    #[test]
    fn test_edge_cases() {
        let mut rng = StdRng::seed_from_u64(42);
        let empty: [f64; 0] = [];
        assert!(relaxed_topk_gumbel(&empty, 3, 0.5, 1.0, &mut rng).is_empty());
        // When k exceeds the number of candidates, every candidate is fully selected.
        let scores = [0.5, 0.3];
        assert_eq!(
            relaxed_topk_gumbel(&scores, 5, 0.5, 1.0, &mut rng),
            [1.0, 1.0]
        );
    }
}