/// Temperature-smoothed maximum of two values: `t * ln(exp(a / t) + exp(b / t))`.
///
/// The result is always `>= max(a, b)` and approaches `max(a, b)` as `temperature -> 0`.
/// It is evaluated in the shifted form `m + t * ln(1 + exp(-|a - b| / t))` with
/// `m = max(a, b)`, so large inputs cannot overflow. `ln_1p` keeps precision when the
/// smaller input contributes only a tiny amount.
///
/// An infinite maximum is returned as-is:
/// - `-inf` means both inputs are `-inf` (nothing reachable).
/// - `+inf` dominates any other input. Without this guard, `inf - inf` would produce NaN.
///
/// `temperature` is expected to be positive; the public entry points assert this.
#[inline]
fn smooth_max(a: f32, b: f32, temperature: f32) -> f32 {
    let m = a.max(b);
    if m.is_infinite() {
        return m;
    }
    // One of the two shifted exponentials is exactly exp(0) = 1; only the other one varies.
    m + temperature * (-(a - b).abs() * temperature.recip()).exp().ln_1p()
}
/// Temperature-smoothed maximum of a slice: `t * ln(Σ exp(x / t))`.
///
/// The sum is taken relative to the largest element, so the exponentials cannot overflow.
///
/// Special cases:
/// - An empty slice, or one whose elements are all `-inf`, returns `-inf`.
/// - If any element is `+inf`, the result is `+inf`, matching `smooth_max`.
/// - `NaN` elements are skipped when locating the maximum (`f32::max` ignores them) but
///   still propagate through the sum.
#[inline]
fn log_sum_exp(xs: &[f32], temperature: f32) -> f32 {
    let m = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if m.is_infinite() {
        // -inf: nothing reachable; +inf: dominates, and `inf - inf` below would be NaN.
        return m;
    }
    let inv_t = temperature.recip();
    let sum: f32 = xs.iter().map(|&x| ((x - m) * inv_t).exp()).sum();
    m + temperature * sum.ln()
}
/// Builds the forward soft-DP table for smoothed top-`k`.
///
/// The table has `scores.len() + 1` rows and `k + 1` columns. `table[i][j]` is the
/// temperature-smoothed maximum (via `smooth_max`) over every way of choosing exactly
/// `j` of the first `i` scores, summing the chosen scores. `table[0][0]` is `0.0`, and
/// cells that cannot be reached (such as `j > i`) stay `-inf`.
fn forward_dp(scores: &[f32], k: usize, temperature: f32) -> Vec<Vec<f32>> {
    let mut table = vec![vec![f32::NEG_INFINITY; k + 1]; scores.len() + 1];
    table[0][0] = 0.0;
    for (idx, &score) in scores.iter().enumerate() {
        // Borrow the previous row immutably and the row being filled mutably.
        let (done, rest) = table.split_at_mut(idx + 1);
        let prev = &done[idx];
        let cur = &mut rest[0];
        for j in 0..=k.min(idx + 1) {
            let leave = prev[j];
            // Taking this item needs one fewer pick among the earlier items.
            let take = j
                .checked_sub(1)
                .map_or(f32::NEG_INFINITY, |before| prev[before] + score);
            cur[j] = smooth_max(leave, take, temperature);
        }
    }
    table
}
/// Converts the forward table `v` (from `forward_dp`) into per-item selection values.
///
/// A backward table `back[i][j]` is filled with the smoothed maximum over the ways to
/// finish picking `k` items from items `i+1..=n`, given that `j` were picked among the
/// first `i`. For each item, the forward value before it, its score, and the backward
/// value after it are combined over every pick count `j`. The item's value is then
/// `exp((log_pick - v[n][k]) / temperature)`, clamped to `[0, 1]`.
///
/// Returns all zeros when `v[n][k]` is `-inf`, i.e. when no selection is reachable.
fn backward_selection(scores: &[f32], k: usize, temperature: f32, v: &[Vec<f32>]) -> Vec<f32> {
    let n = scores.len();
    let total = v[n][k];
    if total == f32::NEG_INFINITY {
        return vec![0.0; n];
    }

    let mut back = vec![vec![f32::NEG_INFINITY; k + 1]; n + 1];
    back[n][k] = 0.0;
    for (idx, &score) in scores.iter().enumerate().rev() {
        let i = idx + 1;
        for j in 0..=k.min(i) {
            let here = back[i][j];
            if here == f32::NEG_INFINITY {
                continue;
            }
            // Item `i` skipped: the pick count before it is still `j`.
            back[idx][j] = smooth_max(back[idx][j], here, temperature);
            // Item `i` picked: one fewer pick happened before it.
            if let Some(before) = j.checked_sub(1) {
                back[idx][before] = smooth_max(back[idx][before], here + score, temperature);
            }
        }
    }

    let inv_t = temperature.recip();
    scores
        .iter()
        .enumerate()
        .map(|(idx, &score)| {
            let i = idx + 1;
            let terms: Vec<f32> = (1..=k.min(i))
                .filter_map(|j| {
                    let before = v[idx][j - 1];
                    let after = back[i][j];
                    (before > f32::NEG_INFINITY && after > f32::NEG_INFINITY)
                        .then_some(before + score + after)
                })
                .collect();
            if terms.is_empty() {
                0.0
            } else {
                let log_pick = log_sum_exp(&terms, temperature);
                ((log_pick - total) * inv_t).exp().clamp(0.0, 1.0)
            }
        })
        .collect()
}
/// Differentiable ("soft") top-`k` selection.
///
/// Weights every subset of exactly `k` items by `exp(sum of its scores / temperature)`.
/// Returns, for each item, the share of that total weight carried by subsets containing
/// it, clamped to `[0, 1]`. Low temperatures push the result toward the hard top-`k`
/// indicator; high temperatures spread it out.
///
/// `k == 0` (which includes empty input) yields exactly `0.0` everywhere. `k == n`
/// yields exactly `1.0` everywhere.
///
/// # Panics
///
/// Panics if `k > scores.len()`, or if `temperature` is not strictly positive (this
/// includes `NaN`).
pub fn dp_topk(scores: &[f32], k: usize, temperature: f32) -> Vec<f32> {
    let n = scores.len();
    assert!(k <= n, "k ({k}) > n ({})", scores.len());
    assert!(
        temperature > 0.0,
        "temperature must be positive, got {temperature}"
    );
    match k {
        // Nothing to pick; since `k <= n`, this also covers `n == 0`.
        0 => vec![0.0; n],
        // Every item must be picked, so membership is certain.
        _ if k == n => vec![1.0; n],
        _ => backward_selection(scores, k, temperature, &forward_dp(scores, k, temperature)),
    }
}
/// Soft top-`k` selection together with a gradient vector.
///
/// Both returned vectors hold the same values as `dp_topk(scores, k, temperature)`.
/// For this log-sum-exp recursion, the per-item selection values are the partial
/// derivatives of the smoothed optimum `v[n][k]` with respect to each score. That is
/// presumably why they are returned as the gradient; confirm the intended use with callers.
pub fn dp_topk_with_grad(scores: &[f32], k: usize, temperature: f32) -> (Vec<f32>, Vec<f32>) {
    let selection = dp_topk(scores, k, temperature);
    (selection.clone(), selection)
}
/// Differentiable ("soft") 0/1 knapsack selection.
///
/// Every subset `S` with `Σ_{i∈S} weights[i] <= capacity` is weighted by
/// `exp(Σ_{i∈S} scores[i] / temperature)`. For each item, the result is the share of the
/// total feasible weight carried by subsets containing it (its selection probability),
/// clamped to `[0, 1]`. Lower temperatures concentrate the mass on the best feasible
/// subsets.
///
/// Items heavier than `capacity` get exactly `0.0`. Zero-weight items always fit, even
/// when `capacity == 0`.
///
/// # Panics
///
/// Panics if `scores` and `weights` differ in length, or if `temperature` is not strictly
/// positive.
pub fn dp_knapsack(
    scores: &[f32],
    weights: &[usize],
    capacity: usize,
    temperature: f32,
) -> Vec<f32> {
    let n = scores.len();
    assert_eq!(n, weights.len(), "scores and weights must have same length");
    assert!(
        temperature > 0.0,
        "temperature must be positive, got {temperature}"
    );
    if n == 0 {
        return Vec::new();
    }

    // fwd[i][c]: smoothed max over subsets of the first `i` items whose weights sum to
    // exactly `c`.
    let mut fwd = vec![vec![f32::NEG_INFINITY; capacity + 1]; n + 1];
    fwd[0][0] = 0.0;
    for i in 1..=n {
        let s = scores[i - 1];
        let wt = weights[i - 1];
        for c in 0..=capacity {
            let skip = fwd[i - 1][c];
            // `checked_sub` rejects items heavier than `c` without underflow.
            let pick = match c.checked_sub(wt) {
                Some(prev) if fwd[i - 1][prev] > f32::NEG_INFINITY => fwd[i - 1][prev] + s,
                _ => f32::NEG_INFINITY,
            };
            fwd[i][c] = smooth_max(skip, pick, temperature);
        }
    }

    // Log-partition over all feasible subsets (any total weight up to `capacity`).
    let reachable: Vec<f32> = fwd[n]
        .iter()
        .copied()
        .filter(|&x| x > f32::NEG_INFINITY)
        .collect();
    if reachable.is_empty() {
        return vec![0.0; n];
    }
    let opt = log_sum_exp(&reachable, temperature);

    // bwd[i][c]: smoothed max over the subsets of items `i+1..=n` that can be added to a
    // state that has already used weight `c`, keeping the total within `capacity`.
    // Every final weight is accepted, hence `bwd[n][..] = 0`.
    let mut bwd = vec![vec![f32::NEG_INFINITY; capacity + 1]; n + 1];
    bwd[n].fill(0.0);
    for i in (1..=n).rev() {
        let s = scores[i - 1];
        let wt = weights[i - 1];
        for c in 0..=capacity {
            let skip = bwd[i][c];
            // Picking item `i` from used weight `c` lands on `c + wt`, which must still fit.
            // `checked_add` keeps huge weights from overflowing.
            let pick = match c.checked_add(wt) {
                Some(next) if next <= capacity => bwd[i][next] + s,
                _ => f32::NEG_INFINITY,
            };
            bwd[i - 1][c] = smooth_max(skip, pick, temperature);
        }
    }

    let inv_t = temperature.recip();
    let mut terms = Vec::with_capacity(capacity + 1);
    let mut selection = vec![0.0_f32; n];
    for (i, sel) in selection.iter_mut().enumerate() {
        let s = scores[i];
        let wt = weights[i];
        terms.clear();
        // Item taken from used weight `c - wt`, landing on `c`.
        // The range is empty when `wt > capacity`.
        for c in wt..=capacity {
            let before = fwd[i][c - wt];
            let after = bwd[i + 1][c];
            if before > f32::NEG_INFINITY && after > f32::NEG_INFINITY {
                terms.push(before + s + after);
            }
        }
        if !terms.is_empty() {
            let log_pick = log_sum_exp(&terms, temperature);
            *sel = ((log_pick - opt) * inv_t).exp().clamp(0.0, 1.0);
        }
    }
    selection
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn smooth_max_approaches_hard_max() {
        let a = 3.0_f32;
        let b = 5.0;
        let result = smooth_max(a, b, 0.01);
        assert!((result - 5.0).abs() < 0.05, "got {result}");
        let soft = smooth_max(a, b, 10.0);
        assert!(soft > 5.0, "soft max should exceed hard max, got {soft}");
    }
    #[test]
    fn returns_correct_length() {
        let scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let sel = dp_topk(&scores, 3, 0.5);
        assert_eq!(sel.len(), 5);
    }
    #[test]
    fn values_in_unit_interval() {
        let scores = vec![3.0, 1.0, 4.0, 1.5, 9.0, 2.6];
        let sel = dp_topk(&scores, 3, 0.5);
        for (i, &v) in sel.iter().enumerate() {
            assert!((0.0..=1.0).contains(&v), "sel[{i}] = {v} out of [0, 1]");
        }
    }
    #[test]
    fn sum_approximately_k() {
        let scores = vec![3.0, 1.0, 4.0, 1.5, 9.0, 2.6];
        for k in 1..=5 {
            let sel = dp_topk(&scores, k, 0.5);
            let sum: f32 = sel.iter().sum();
            assert!(
                (sum - k as f32).abs() < 1.0,
                "k={k}, sum={sum}, expected ~{k}"
            );
        }
    }
    #[test]
    fn low_temperature_matches_hard_topk() {
        let scores = vec![3.0, 1.0, 4.0, 1.5, 9.0, 2.6];
        let k = 2;
        let sel = dp_topk(&scores, k, 0.01);
        assert!(sel[4] > 0.9, "top item sel={}", sel[4]);
        assert!(sel[2] > 0.9, "second item sel={}", sel[2]);
        assert!(sel[1] < 0.1, "non-top item sel={}", sel[1]);
        assert!(sel[3] < 0.1, "non-top item sel={}", sel[3]);
    }
    #[test]
    fn monotonicity_higher_scores_higher_selection() {
        let scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let sel = dp_topk(&scores, 3, 0.5);
        assert!(
            sel[4] > sel[0],
            "sel[4]={} should > sel[0]={}",
            sel[4],
            sel[0]
        );
        for i in 0..4 {
            assert!(
                sel[i] <= sel[i + 1] + 0.05,
                "sel[{i}]={} > sel[{}]={} (non-monotone)",
                sel[i],
                i + 1,
                sel[i + 1]
            );
        }
    }
    #[test]
    fn gradient_nonzero_for_all_scores() {
        let scores = vec![3.0, 1.0, 4.0, 1.5, 2.0];
        let (_, grad) = dp_topk_with_grad(&scores, 2, 1.0);
        for (i, &g) in grad.iter().enumerate() {
            assert!(g > 0.0, "grad[{i}] = {g}, expected > 0");
        }
    }
    #[test]
    fn edge_case_k_zero() {
        let scores = vec![1.0, 2.0, 3.0];
        let sel = dp_topk(&scores, 0, 1.0);
        assert_eq!(sel, vec![0.0, 0.0, 0.0]);
    }
    #[test]
    fn edge_case_k_equals_n() {
        let scores = vec![1.0, 2.0, 3.0];
        let sel = dp_topk(&scores, 3, 1.0);
        assert_eq!(sel, vec![1.0, 1.0, 1.0]);
    }
    #[test]
    fn edge_case_empty() {
        let scores: Vec<f32> = vec![];
        let sel = dp_topk(&scores, 0, 1.0);
        assert!(sel.is_empty());
    }
    #[test]
    fn edge_case_single_item() {
        let scores = vec![42.0];
        let sel = dp_topk(&scores, 1, 0.5);
        assert_eq!(sel, vec![1.0]);
    }
    #[test]
    fn knapsack_basic() {
        // {0, 3} (score 10, weight 4) is the unique best subset within capacity 4.
        // The runner-up {1, 2} scores 9, so the optimum is not tied.
        let scores = vec![7.0, 5.0, 4.0, 3.0];
        let weights = vec![3, 2, 2, 1];
        let capacity = 4;
        let sel = dp_knapsack(&scores, &weights, capacity, 0.1);
        assert!(sel[0] > 0.5, "item 0 should be selected, got {}", sel[0]);
        assert!(sel[3] > 0.5, "item 3 should be selected, got {}", sel[3]);
        assert!(sel[1] < 0.5, "item 1 should not be selected, got {}", sel[1]);
        assert!(sel[2] < 0.5, "item 2 should not be selected, got {}", sel[2]);
    }
    #[test]
    fn knapsack_respects_capacity() {
        let scores = vec![10.0, 8.0, 6.0];
        let weights = vec![5, 5, 5];
        let capacity = 7;
        let sel = dp_knapsack(&scores, &weights, capacity, 0.01);
        assert!(sel[0] > 0.9, "best item should be selected, got {}", sel[0]);
        assert!(
            sel[2] < 0.1,
            "worst item should not be selected, got {}",
            sel[2]
        );
    }
    #[test]
    fn knapsack_values_in_unit_interval() {
        let scores = vec![3.0, 1.0, 4.0, 1.5];
        let weights = vec![2, 1, 3, 1];
        let sel = dp_knapsack(&scores, &weights, 4, 0.5);
        for (i, &v) in sel.iter().enumerate() {
            assert!(
                (0.0..=1.0 + 1e-6).contains(&v),
                "sel[{i}] = {v} out of [0, 1]"
            );
        }
    }
    #[test]
    fn knapsack_matches_brute_force() {
        let scores = [1.0_f32, 0.5, -0.3, 0.8];
        let weights = [2_usize, 1, 1, 3];
        let capacity = 4;
        let temperature = 0.7_f32;
        let n = scores.len();
        let mut partition = 0.0_f64;
        let mut marginals = vec![0.0_f64; n];
        for mask in 0_u32..(1 << n) {
            let members: Vec<usize> = (0..n).filter(|&i| mask & (1 << i) != 0).collect();
            let weight: usize = members.iter().map(|&i| weights[i]).sum();
            if weight > capacity {
                continue;
            }
            let score: f64 = members.iter().map(|&i| f64::from(scores[i])).sum();
            let mass = (score / f64::from(temperature)).exp();
            partition += mass;
            for &i in &members {
                marginals[i] += mass;
            }
        }
        let sel = dp_knapsack(&scores, &weights, capacity, temperature);
        for (i, (&got, &m)) in sel.iter().zip(&marginals).enumerate() {
            let expected = (m / partition) as f32;
            assert!(
                (got - expected).abs() < 1e-4,
                "sel[{i}] = {got}, expected {expected}"
            );
        }
    }
    #[test]
    fn knapsack_zero_weight_item_fits_zero_capacity() {
        let sel = dp_knapsack(&[5.0, 3.0], &[0, 2], 0, 1.0);
        let expected = 1.0 / (1.0 + (-5.0_f32).exp());
        assert!((sel[0] - expected).abs() < 1e-4, "got {}", sel[0]);
        assert_eq!(sel[1], 0.0);
    }
    #[test]
    fn knapsack_huge_weight_is_never_selected() {
        let sel = dp_knapsack(&[1.0, 2.0], &[usize::MAX, 1], 3, 0.5);
        assert_eq!(sel[0], 0.0);
        let expected = 1.0 / (1.0 + (-2.0_f32 / 0.5).exp());
        assert!((sel[1] - expected).abs() < 1e-4, "got {}", sel[1]);
    }
    #[test]
    fn knapsack_empty() {
        let sel = dp_knapsack(&[], &[], 10, 1.0);
        assert!(sel.is_empty());
    }
    #[test]
    fn knapsack_zero_capacity() {
        let sel = dp_knapsack(&[5.0, 3.0], &[1, 2], 0, 1.0);
        assert_eq!(sel, vec![0.0, 0.0]);
    }
    #[test]
    fn temperature_effect() {
        let scores = vec![3.0, 1.0, 4.0, 1.5, 2.0];
        let sharp = dp_topk(&scores, 2, 0.01);
        let smooth = dp_topk(&scores, 2, 5.0);
        let mean_sharp = sharp.iter().sum::<f32>() / sharp.len() as f32;
        let var_sharp: f32 = sharp.iter().map(|&p| (p - mean_sharp).powi(2)).sum();
        let mean_smooth = smooth.iter().sum::<f32>() / smooth.len() as f32;
        let var_smooth: f32 = smooth.iter().map(|&p| (p - mean_smooth).powi(2)).sum();
        assert!(
            var_sharp > var_smooth,
            "sharp variance {var_sharp} should > smooth variance {var_smooth}"
        );
    }
    #[test]
    #[should_panic(expected = "k (4) > n (3)")]
    fn panics_k_exceeds_n() {
        dp_topk(&[1.0, 2.0, 3.0], 4, 1.0);
    }
    #[test]
    #[should_panic(expected = "temperature must be positive")]
    fn panics_negative_temperature() {
        dp_topk(&[1.0, 2.0], 1, -0.5);
    }
    #[test]
    fn equal_scores_uniform_selection() {
        let scores = vec![1.0; 5];
        let sel = dp_topk(&scores, 2, 1.0);
        let mean = sel.iter().sum::<f32>() / sel.len() as f32;
        for (i, &v) in sel.iter().enumerate() {
            assert!(
                (v - mean).abs() < 0.15,
                "sel[{i}]={v} deviates from mean={mean}"
            );
        }
    }
}