//! Differentiable structured prediction utilities: SparseMAP-style projections
//! (simplex, knapsack, permutation), Gaussian-perturbed optimizers, soft
//! sorting and ranking, and a differentiable top-k relaxation.

use std::fmt::Debug;

use scirs2_core::num_traits::{Float, FromPrimitive};

use crate::error::{OptimizeError, OptimizeResult};
/// Combinatorial structure over which the SparseMAP-style problem is solved.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StructureType {
    /// Probability simplex: non-negative entries summing to one.
    Simplex,
    /// Box-constrained knapsack: `0 <= mu_i <= 1` with `sum_i w_i * mu_i <= capacity`.
    Knapsack {
        weights: Vec<f64>,
        capacity: f64,
    },
    /// Relaxation of `dim x dim` permutation matrices (Birkhoff polytope).
    Permutation {
        dim: usize,
    },
}
impl Default for StructureType {
fn default() -> Self {
StructureType::Simplex
}
}
/// Configuration for [`sparsemap`].
#[derive(Debug, Clone)]
pub struct SparsemapConfig {
    /// Maximum number of projected-gradient iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the objective change.
    pub tol: f64,
    /// Target combinatorial structure.
    pub structure_type: StructureType,
    /// Step size for the projected-gradient updates.
    pub step_size: f64,
}
impl Default for SparsemapConfig {
fn default() -> Self {
Self {
max_iter: 1000,
tol: 1e-6,
structure_type: StructureType::default(),
step_size: 0.1,
}
}
}
/// Result of a [`sparsemap`] call.
#[derive(Debug, Clone)]
pub struct SparsemapResult<F> {
    /// Relaxed solution vector.
    pub solution: Vec<F>,
    /// Indices of solution entries above the 1e-9 support threshold.
    pub support: Vec<usize>,
    /// Dual variables; their interpretation depends on the structure.
    pub dual: Vec<F>,
    /// Number of iterations performed.
    pub n_iters: usize,
}
/// Euclidean projection of `v` onto the probability simplex using the
/// sort-and-threshold algorithm (Held et al. / Duchi et al.).
fn project_simplex<F>(v: &[F]) -> Vec<F>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    // Sort a copy in descending order to locate the threshold index rho.
    let mut u: Vec<F> = v.to_vec();
    u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
let mut cssv = F::zero();
let mut rho = 0usize;
for (j, &uj) in u.iter().enumerate() {
cssv = cssv + uj;
let j_f = F::from_usize(j + 1).unwrap_or(F::one());
let one = F::one();
if uj - (cssv - one) / j_f > F::zero() {
rho = j;
}
}
let rho_f = F::from_usize(rho + 1).unwrap_or(F::one());
let one = F::one();
let mut cssv2 = F::zero();
for uj in u.iter().take(rho + 1) {
cssv2 = cssv2 + *uj;
}
let theta = (cssv2 - one) / rho_f;
v.iter()
.map(|&vi| {
let diff = vi - theta;
if diff > F::zero() {
diff
} else {
F::zero()
}
})
.collect()
}
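/// Euclidean projection onto the knapsack polytope
/// `{mu : 0 <= mu_i <= 1, sum_i w_i * mu_i <= capacity}`.
/// Missing weights default to 1.0.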
fn project_knapsack<F>(v: &[F], weights: &[f64], capacity: f64) -> Vec<F>
where
F: Float + FromPrimitive + Debug + Clone,
{
let n = v.len();
let mut mu: Vec<F> = v
.iter()
.map(|&vi| {
if vi < F::zero() {
F::zero()
} else if vi > F::one() {
F::one()
} else {
vi
}
})
.collect();
    // If the box-clamped point already satisfies the budget, it is the
    // projection onto the intersection of the box and budget constraints.
    let total_weight: f64 = (0..n)
        .map(|i| weights.get(i).copied().unwrap_or(1.0) * mu[i].to_f64().unwrap_or(0.0))
        .sum();
    if total_weight <= capacity + 1e-12 {
        return mu;
    }
    // Otherwise the budget is active. The KKT conditions give
    // mu_i = clamp(v_i - lambda * w_i, 0, 1) for some multiplier lambda >= 0;
    // bisect on lambda until w . mu meets the capacity.
    let mut lo = 0.0_f64;
    let mut hi = 1e8_f64;
    for _ in 0..200 {
        let mid = (lo + hi) / 2.0;
        let w_total: f64 = (0..n)
            .map(|i| {
                let wi = weights.get(i).copied().unwrap_or(1.0);
                let vi = v[i].to_f64().unwrap_or(0.0);
                let mu_i = (vi - mid * wi).clamp(0.0, 1.0);
                wi * mu_i
            })
            .sum();
        if w_total > capacity {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    let lambda = (lo + hi) / 2.0;
    mu = (0..n)
        .map(|i| {
            let wi = weights.get(i).copied().unwrap_or(1.0);
            let vi = v[i].to_f64().unwrap_or(0.0);
            let val = (vi - lambda * wi).clamp(0.0, 1.0);
            F::from_f64(val).unwrap_or(F::zero())
        })
        .collect();
    mu
}
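/// SparseMAP-style solver: minimizes `0.5 * ||mu - scores||^2` over the
/// polytope selected by `config.structure_type` (simplex: exact projection;
/// knapsack: projected gradient; permutation: gradient steps plus Sinkhorn
/// normalization). A minimal usage sketch (`ignore`d because the enclosing
/// module path depends on the crate layout):
///
/// ```ignore
/// let cfg = SparsemapConfig::default(); // simplex structure
/// let res = sparsemap(&[1.0_f64, 2.0, 0.5], &cfg)?;
/// assert!((res.solution.iter().sum::<f64>() - 1.0).abs() < 1e-6);
/// ```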
pub fn sparsemap<F>(scores: &[F], config: &SparsemapConfig) -> OptimizeResult<SparsemapResult<F>>
where
F: Float + FromPrimitive + Debug + Clone,
{
if scores.is_empty() {
return Err(OptimizeError::InvalidInput(
"scores vector must be non-empty".into(),
));
}
let n = scores.len();
let tol_f = F::from_f64(config.tol).unwrap_or(F::epsilon());
let solution: Vec<F>;
let n_iters: usize;
let dual: Vec<F>;
match &config.structure_type {
StructureType::Simplex => {
            solution = project_simplex(scores);
            n_iters = 1;
            // Dual (threshold) variable theta: on the support, solution_i =
            // scores_i - theta, so theta = (sum of supported scores - 1) / |S|.
            let support_count = solution.iter().filter(|&&x| x > F::zero()).count();
            let support_sum: F = solution
                .iter()
                .zip(scores.iter())
                .filter(|&(&x, _)| x > F::zero())
                .fold(F::zero(), |acc, (_, &s)| acc + s);
            let count_f = F::from_usize(support_count).unwrap_or(F::one());
            let lambda = if support_count > 0 {
                (support_sum - F::one()) / count_f
            } else {
                F::zero()
            };
            dual = vec![lambda];
}
StructureType::Knapsack { weights, capacity } => {
let mut mu: Vec<F> = vec![F::zero(); n];
let step = F::from_f64(config.step_size).unwrap_or(F::epsilon());
let mut iter = 0usize;
let mut prev_obj = F::neg_infinity();
loop {
let grad: Vec<F> = mu.iter().zip(scores.iter()).map(|(&m, &s)| m - s).collect();
let mu_new: Vec<F> = mu
.iter()
.zip(grad.iter())
.map(|(&m, &g)| m - step * g)
.collect();
let mu_proj = project_knapsack(&mu_new, weights, *capacity);
let obj = mu_proj
.iter()
.zip(scores.iter())
.fold(F::zero(), |acc, (&m, &s)| {
let diff = m - s;
acc + diff * diff
});
let half = F::from_f64(0.5).unwrap_or(F::one());
let obj = obj * half;
let diff = (obj - prev_obj).abs();
mu = mu_proj;
prev_obj = obj;
iter += 1;
if iter >= config.max_iter || diff < tol_f {
break;
}
}
solution = mu;
n_iters = iter;
let total_w: f64 = (0..n)
.map(|i| {
weights.get(i).copied().unwrap_or(1.0) * solution[i].to_f64().unwrap_or(0.0)
})
.sum();
            let slack = *capacity - total_w;
            // Crude dual signal: -1 when the capacity constraint is
            // numerically active, 0 otherwise (not the exact KKT multiplier).
            let lambda_val = if slack.abs() < 1e-8 { -1.0 } else { 0.0 };
            dual = vec![F::from_f64(lambda_val).unwrap_or(F::zero())];
}
StructureType::Permutation { dim } => {
let d = *dim;
if scores.len() != d * d {
return Err(OptimizeError::InvalidInput(format!(
"Permutation structure requires d²={} scores but got {}",
d * d,
scores.len()
)));
}
let inv_d = F::from_f64(1.0 / d as f64).unwrap_or(F::one());
let mut mu: Vec<F> = vec![inv_d; d * d];
let step = F::from_f64(config.step_size).unwrap_or(F::epsilon());
let mut iter = 0usize;
loop {
let mu_step: Vec<F> = mu
.iter()
.zip(scores.iter())
.map(|(&m, &s)| m - step * (m - s))
.collect();
                // Sinkhorn normalization assumes nonnegative entries; clamp the
                // gradient step at zero before alternating row/column scaling.
                let mut m_sink: Vec<F> = mu_step
                    .iter()
                    .map(|&x| if x > F::zero() { x } else { F::zero() })
                    .collect();
for _ in 0..50 {
for row in 0..d {
let row_sum: F = (0..d)
.map(|col| m_sink[row * d + col])
.fold(F::zero(), |a, b| a + b);
if row_sum > F::zero() {
for col in 0..d {
m_sink[row * d + col] = m_sink[row * d + col] / row_sum;
}
}
}
for col in 0..d {
let col_sum: F = (0..d)
.map(|row| m_sink[row * d + col])
.fold(F::zero(), |a, b| a + b);
if col_sum > F::zero() {
for row in 0..d {
m_sink[row * d + col] = m_sink[row * d + col] / col_sum;
}
}
}
}
                let change: F = mu
                    .iter()
                    .zip(m_sink.iter())
                    .map(|(&a, &b)| {
                        let diff = a - b;
                        diff * diff
                    })
                    .fold(F::zero(), |a, b| a + b);
mu = m_sink;
iter += 1;
if iter >= config.max_iter || change < tol_f * tol_f {
break;
}
}
            solution = mu;
            n_iters = iter;
            // Row/column duals are not recovered by this relaxation; return
            // zero placeholders (one per row plus one per column).
            dual = vec![F::zero(); 2 * d];
        }
}
let support: Vec<usize> = solution
.iter()
.enumerate()
.filter_map(|(i, &v)| {
if v > F::from_f64(1e-9).unwrap_or(F::zero()) {
Some(i)
} else {
None
}
})
.collect();
Ok(SparsemapResult {
solution,
support,
dual,
n_iters,
})
}
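/// Backward pass for the simplex case: the projection Jacobian restricts the
/// upstream gradient to the support and centers it there,
/// `grad_i = upstream_i - mean(upstream over support)` for supported `i` and
/// zero elsewhere. A chaining sketch (`ignore`d; path depends on crate layout):
///
/// ```ignore
/// let res = sparsemap(&scores, &SparsemapConfig::default())?;
/// let grad = sparsemap_gradient(&res, &upstream); // same length as scores
/// ```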
pub fn sparsemap_gradient<F>(result: &SparsemapResult<F>, upstream_grad: &[F]) -> Vec<F>
where
F: Float + FromPrimitive + Debug + Clone,
{
    let n = result.solution.len();
    // A length mismatch yields a zero gradient rather than an error, mirroring
    // the empty-support case below.
    if upstream_grad.len() != n {
        return vec![F::zero(); n];
    }
let s = &result.support;
if s.is_empty() {
return vec![F::zero(); n];
}
let s_size = F::from_usize(s.len()).unwrap_or(F::one());
let mean_s: F = s
.iter()
.map(|&i| upstream_grad[i])
.fold(F::zero(), |a, b| a + b)
/ s_size;
let mut grad = vec![F::zero(); n];
for &i in s {
grad[i] = upstream_grad[i] - mean_s;
}
grad
}
#[derive(Debug, Clone)]
pub struct PerturbedOptimizerConfig {
pub n_samples: usize,
pub epsilon: f64,
pub seed: u64,
}
impl Default for PerturbedOptimizerConfig {
fn default() -> Self {
Self {
n_samples: 100,
epsilon: 0.1,
seed: 42,
}
}
}
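/// Differentiable argmax via Gaussian perturbations, in the spirit of
/// perturbed optimizers (Berthet et al., 2020): `forward` estimates the
/// smoothed argmax distribution, `backward` its input gradient.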
#[derive(Debug, Clone)]
pub struct PerturbedOptimizer {
config: PerturbedOptimizerConfig,
}
impl PerturbedOptimizer {
pub fn new(config: PerturbedOptimizerConfig) -> Self {
Self { config }
}
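    /// Monte Carlo estimate of the perturbed argmax distribution:
    /// `probs[i]` is the fraction of draws in which `scores[i] + epsilon * z_i`
    /// attains the maximum.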
pub fn forward<F>(&self, scores: &[F]) -> OptimizeResult<Vec<F>>
where
F: Float + FromPrimitive + Debug + Clone,
{
if scores.is_empty() {
return Err(OptimizeError::InvalidInput(
"scores must be non-empty".into(),
));
}
let n = scores.len();
let mut counts = vec![0usize; n];
let eps = self.config.epsilon;
let mut rng_state = self.config.seed;
let n_samples = self.config.n_samples;
for _ in 0..n_samples {
let mut best_idx = 0usize;
let mut best_val = F::neg_infinity();
for i in 0..n {
let z = sample_standard_normal(&mut rng_state);
let perturbed = scores[i] + F::from_f64(eps * z).unwrap_or(F::zero());
if perturbed > best_val {
best_val = perturbed;
best_idx = i;
}
}
counts[best_idx] += 1;
}
let n_samples_f = F::from_usize(n_samples).unwrap_or(F::one());
let probs: Vec<F> = counts
.iter()
.map(|&c| F::from_usize(c).unwrap_or(F::zero()) / n_samples_f)
.collect();
Ok(probs)
}
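    /// Gradient estimate via the Gaussian score-function identity,
    /// `grad ~= E[upstream(argmax(scores + eps * Z)) * Z] / eps`,
    /// averaged over `n_samples` noise draws.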
pub fn backward<F>(&self, scores: &[F], upstream: &[F]) -> OptimizeResult<Vec<F>>
where
F: Float + FromPrimitive + Debug + Clone,
{
if scores.len() != upstream.len() {
return Err(OptimizeError::InvalidInput(
"scores and upstream must have the same length".into(),
));
}
        let n = scores.len();
        let eps = self.config.epsilon;
        let n_samples = self.config.n_samples;
        let mut grad = vec![F::zero(); n];
        let mut rng_state = self.config.seed;
        // Gaussian score-function estimator: for Z ~ N(0, I),
        // d/ds_i E[f(s + eps * Z)] = E[f(s + eps * Z) * Z_i] / eps,
        // so the accumulated terms are divided by eps (not eps squared).
        let eps_f = F::from_f64(eps).unwrap_or(F::one());
        for _ in 0..n_samples {
            let noise: Vec<f64> = (0..n)
                .map(|_| sample_standard_normal(&mut rng_state))
                .collect();
            let mut best_idx = 0usize;
            let mut best_val = F::neg_infinity();
            for i in 0..n {
                let perturbed = scores[i] + F::from_f64(eps * noise[i]).unwrap_or(F::zero());
                if perturbed > best_val {
                    best_val = perturbed;
                    best_idx = i;
                }
            }
            let dot = upstream[best_idx];
            for i in 0..n {
                let zi = F::from_f64(noise[i]).unwrap_or(F::zero());
                grad[i] = grad[i] + dot * zi / eps_f;
            }
        }
let n_f = F::from_usize(n_samples).unwrap_or(F::one());
for g in &mut grad {
*g = *g / n_f;
}
Ok(grad)
}
}
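/// One step of the SplitMix64 PRNG: small, seedable, and dependency-free.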
fn splitmix64(state: &mut u64) -> u64 {
*state = state.wrapping_add(0x9e3779b97f4a7c15);
let mut z = *state;
z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
z ^ (z >> 31)
}
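/// Standard normal draw via the Box-Muller transform on two SplitMix64 outputs.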
fn sample_standard_normal(state: &mut u64) -> f64 {
let u1_raw = splitmix64(state);
let u2_raw = splitmix64(state);
let u1 = (u1_raw as f64 + 0.5) / (u64::MAX as f64 + 1.0);
let u2 = (u2_raw as f64 + 0.5) / (u64::MAX as f64 + 1.0);
let two_pi = 2.0 * std::f64::consts::PI;
(-2.0 * u1.ln()).sqrt() * (two_pi * u2).cos()
}
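/// Differentiable sorting: `temperature == 0` gives a hard ascending sort;
/// otherwise the sorted values are blended toward their mean (blend weight is
/// the temperature, clamped to [0, 1]) and re-monotonized with
/// pool-adjacent-violators. A minimal usage sketch (`ignore`d because the
/// module path depends on the crate layout):
///
/// ```ignore
/// let s = soft_sort(&[3.0_f64, 1.0, 2.0], 0.5)?; // non-decreasing output
/// ```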
pub fn soft_sort<F>(x: &[F], temperature: F) -> OptimizeResult<Vec<F>>
where
F: Float + FromPrimitive + Debug + Clone,
{
if x.is_empty() {
return Err(OptimizeError::InvalidInput(
"input vector must be non-empty".into(),
));
}
let n = x.len();
let mut sorted_x: Vec<F> = x.to_vec();
sorted_x.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
if temperature == F::zero() {
return Ok(sorted_x);
}
let mean_val =
sorted_x.iter().fold(F::zero(), |a, b| a + *b) / F::from_usize(n).unwrap_or(F::one());
    // Clamp the blend weight to [0, 1] so out-of-range temperatures still
    // yield a valid convex combination of the sorted values and their mean.
    let t_clamped = if temperature > F::one() {
        F::one()
    } else if temperature < F::zero() {
        F::zero()
    } else {
        temperature
    };
let one_minus_t = F::one() - t_clamped;
let mixed: Vec<F> = sorted_x
.iter()
.map(|&v| one_minus_t * v + t_clamped * mean_val)
.collect();
let result = pool_adjacent_violators(&mixed);
Ok(result)
}
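/// Pool-adjacent-violators: isotonic (non-decreasing) regression. Adjacent
/// blocks whose means violate monotonicity are merged and each final block is
/// replaced by its mean.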
fn pool_adjacent_violators<F>(s: &[F]) -> Vec<F>
where
F: Float + FromPrimitive + Debug + Clone,
{
let n = s.len();
let mut blocks: Vec<(F, usize)> = s.iter().map(|&v| (v, 1)).collect();
let mut changed = true;
while changed {
changed = false;
let mut i = 0usize;
let mut new_blocks: Vec<(F, usize)> = Vec::with_capacity(blocks.len());
while i < blocks.len() {
let mut sum = blocks[i].0;
let mut cnt = blocks[i].1;
while i + 1 < blocks.len() {
let next_mean =
blocks[i + 1].0 / F::from_usize(blocks[i + 1].1).unwrap_or(F::one());
let cur_mean = sum / F::from_usize(cnt).unwrap_or(F::one());
if cur_mean > next_mean {
sum = sum + blocks[i + 1].0;
cnt += blocks[i + 1].1;
i += 1;
changed = true;
} else {
break;
}
}
new_blocks.push((sum, cnt));
i += 1;
}
blocks = new_blocks;
}
let mut result = Vec::with_capacity(n);
for (sum, cnt) in blocks {
let mean = sum / F::from_usize(cnt).unwrap_or(F::one());
for _ in 0..cnt {
result.push(mean);
}
}
result
}
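/// Differentiable ranks. `temperature == 0` yields hard 1-based ranks (ties
/// share the lower rank); otherwise
/// `rank_i = 1 + sum_{j != i} sigmoid((x_i - x_j) / temperature)`,
/// blended toward the mid rank `(n + 1) / 2` as the temperature grows.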
pub fn soft_rank<F>(x: &[F], temperature: F) -> OptimizeResult<Vec<F>>
where
F: Float + FromPrimitive + Debug + Clone,
{
if x.is_empty() {
return Err(OptimizeError::InvalidInput(
"input vector must be non-empty".into(),
));
}
let n = x.len();
let one = F::one();
let n_f = F::from_usize(n).unwrap_or(one);
if temperature == F::zero() {
let ranks: Vec<F> = (0..n)
.map(|i| {
let rank = x.iter().filter(|&&v| v < x[i]).count();
F::from_usize(rank + 1).unwrap_or(one)
})
.collect();
return Ok(ranks);
}
let two = F::from_f64(2.0).unwrap_or(one);
let ranks: Vec<F> = (0..n)
.map(|i| {
            let mut soft_rank_i = one;
            for j in 0..n {
if i == j {
continue;
}
let diff = (x[i] - x[j]) / temperature;
let diff_clamped = if diff < F::from_f64(-50.0).unwrap_or(-one) {
F::from_f64(-50.0).unwrap_or(-one)
} else if diff > F::from_f64(50.0).unwrap_or(one) {
F::from_f64(50.0).unwrap_or(one)
} else {
diff
};
let sigmoid_val = one / (one + (-diff_clamped).exp());
soft_rank_i = soft_rank_i + sigmoid_val;
}
            // Blend toward the mid rank (n + 1) / 2 as the temperature grows;
            // the /10 scale is a heuristic saturation point.
            let mid = (n_f + one) / two;
            let t = if temperature > F::from_f64(10.0).unwrap_or(one) {
                one
            } else {
                temperature / F::from_f64(10.0).unwrap_or(one)
            };
            (one - t) * soft_rank_i + t * mid
})
.collect();
Ok(ranks)
}
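/// Differentiable top-k. `temperature == 0` returns the hard top-k indicator;
/// otherwise `k` times a softmax over `scores / temperature`, so the output
/// sums to `k` (individual entries may exceed 1 for peaked scores). A minimal
/// sketch (`ignore`d; module path depends on the crate layout):
///
/// ```ignore
/// let p = diff_topk(&[1.0_f64, 5.0, 2.0], 2, 0.5)?;
/// assert!((p.iter().sum::<f64>() - 2.0).abs() < 1e-6);
/// ```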
pub fn diff_topk<F>(scores: &[F], k: usize, temperature: F) -> OptimizeResult<Vec<F>>
where
F: Float + FromPrimitive + Debug + Clone,
{
let n = scores.len();
if n == 0 {
return Err(OptimizeError::InvalidInput(
"scores must be non-empty".into(),
));
}
if k == 0 || k > n {
return Err(OptimizeError::InvalidInput(format!(
"k must be in [1, {}] but got {}",
n, k
)));
}
let k_f = F::from_usize(k).unwrap_or(F::one());
if temperature == F::zero() {
let mut indexed: Vec<(usize, F)> = scores.iter().copied().enumerate().collect();
indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
let mut result = vec![F::zero(); n];
for (idx, _) in indexed.iter().take(k) {
result[*idx] = F::one();
}
return Ok(result);
}
let max_score = scores
.iter()
.copied()
.fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
let exp_scores: Vec<F> = scores
.iter()
.map(|&s| {
let scaled = (s - max_score) / temperature;
let clamped = if scaled < F::from_f64(-80.0).unwrap_or(-F::one()) {
F::from_f64(-80.0).unwrap_or(-F::one())
} else {
scaled
};
clamped.exp()
})
.collect();
let sum_exp: F = exp_scores.iter().fold(F::zero(), |a, b| a + *b);
    // Defensive: after subtracting the max, at least one exp term equals 1,
    // so sum_exp should never be zero for well-behaved float types.
    if sum_exp == F::zero() {
let uniform = k_f / F::from_usize(n).unwrap_or(F::one());
return Ok(vec![uniform; n]);
}
let result: Vec<F> = exp_scores.iter().map(|&e| k_f * e / sum_exp).collect();
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
const EPS: f64 = 1e-5;
#[test]
fn test_sparsemap_config_defaults() {
let cfg = SparsemapConfig::default();
assert_eq!(cfg.max_iter, 1000);
assert!((cfg.tol - 1e-6).abs() < 1e-12);
assert!(matches!(cfg.structure_type, StructureType::Simplex));
}
#[test]
fn test_sparsemap_simplex_sums_to_one() {
let scores = vec![1.0_f64, 2.0, 0.5, -0.3, 1.8];
let cfg = SparsemapConfig::default();
let res = sparsemap(&scores, &cfg).unwrap();
let sum: f64 = res.solution.iter().sum();
assert!((sum - 1.0).abs() < EPS, "sum = {}", sum);
}
#[test]
fn test_sparsemap_simplex_sparse_support() {
let scores = vec![10.0_f64, 0.1, 0.1, 0.1, 0.1];
let cfg = SparsemapConfig::default();
let res = sparsemap(&scores, &cfg).unwrap();
let n_nonzero = res.solution.iter().filter(|&&v| v > 1e-9).count();
        // With one dominant score, the projection puts all mass on index 0.
        assert_eq!(n_nonzero, 1, "expected one non-zero entry, got {}", n_nonzero);
assert!(!res.support.is_empty());
}
#[test]
fn test_sparsemap_simplex_nonneg() {
let scores = vec![-1.0_f64, -0.5, 0.3, 2.0, -3.0];
let cfg = SparsemapConfig::default();
let res = sparsemap(&scores, &cfg).unwrap();
for &v in &res.solution {
assert!(v >= -1e-10, "negative value {}", v);
}
}
#[test]
fn test_sparsemap_gradient_shape_matches_input() {
let scores = vec![1.0_f64, 2.0, 0.5];
let cfg = SparsemapConfig::default();
let res = sparsemap(&scores, &cfg).unwrap();
let upstream = vec![1.0_f64, 0.0, -1.0];
let grad = sparsemap_gradient(&res, &upstream);
assert_eq!(grad.len(), scores.len());
}
#[test]
fn test_sparsemap_gradient_zeros_outside_support() {
let scores = vec![5.0_f64, -5.0, -5.0];
let cfg = SparsemapConfig::default();
let res = sparsemap(&scores, &cfg).unwrap();
let upstream = vec![1.0_f64, 1.0, 1.0];
let grad = sparsemap_gradient(&res, &upstream);
for (i, &g) in grad.iter().enumerate() {
if !res.support.contains(&i) {
assert!(g.abs() < EPS, "index {} outside support has grad {}", i, g);
}
}
}
#[test]
fn test_sparsemap_knapsack_feasibility() {
let weights = vec![1.0_f64, 2.0, 3.0];
let capacity = 3.0_f64;
let cfg = SparsemapConfig {
structure_type: StructureType::Knapsack {
weights: weights.clone(),
capacity,
},
max_iter: 500,
..SparsemapConfig::default()
};
let scores = vec![3.0_f64, 2.0, 1.0];
let res = sparsemap(&scores, &cfg).unwrap();
for &v in &res.solution {
assert!(v >= -EPS && v <= 1.0 + EPS, "value {} out of [0,1]", v);
}
let used: f64 = weights
.iter()
.zip(res.solution.iter())
.map(|(&w, &v)| w * v)
.sum();
assert!(used <= capacity + EPS, "capacity exceeded: {}", used);
}
#[test]
fn test_perturbed_optimizer_config_defaults() {
let cfg = PerturbedOptimizerConfig::default();
assert_eq!(cfg.n_samples, 100);
assert!((cfg.epsilon - 0.1).abs() < 1e-12);
assert_eq!(cfg.seed, 42);
}
#[test]
fn test_perturbed_optimizer_output_sums_to_one() {
let cfg = PerturbedOptimizerConfig {
n_samples: 200,
..Default::default()
};
let opt = PerturbedOptimizer::new(cfg);
let scores = vec![1.0_f64, 2.0, 0.5, 3.0];
let probs = opt.forward(&scores).unwrap();
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 0.01, "sum = {}", sum);
}
#[test]
fn test_perturbed_optimizer_n_samples_1_deterministic() {
let cfg = PerturbedOptimizerConfig {
n_samples: 1,
seed: 7,
..Default::default()
};
let opt = PerturbedOptimizer::new(cfg.clone());
let scores = vec![1.0_f64, 2.0, 0.5];
let p1 = opt.forward(&scores).unwrap();
let opt2 = PerturbedOptimizer::new(cfg);
let p2 = opt2.forward(&scores).unwrap();
for (a, b) in p1.iter().zip(p2.iter()) {
assert_eq!(a, b, "results differ between identical seeds");
}
}
#[test]
fn test_soft_sort_nondecreasing() {
let x = vec![3.0_f64, 1.0, 4.0, 1.5, 9.0, 2.6];
let sorted = soft_sort(&x, 0.0_f64).unwrap();
for w in sorted.windows(2) {
assert!(w[0] <= w[1] + 1e-10, "not sorted: {} > {}", w[0], w[1]);
}
}
#[test]
fn test_soft_sort_nonzero_temp_nondecreasing() {
let x = vec![5.0_f64, 1.0, 3.0, 2.0];
let sorted = soft_sort(&x, 0.5_f64).unwrap();
for w in sorted.windows(2) {
assert!(
w[0] <= w[1] + 1e-9,
"soft_sort not sorted: {} > {}",
w[0],
w[1]
);
}
}
#[test]
    fn test_soft_rank_zero_temp_input_3_1_2() {
let x = vec![3.0_f64, 1.0, 2.0];
let ranks = soft_rank(&x, 0.0_f64).unwrap();
assert_eq!(ranks[0] as usize, 3, "rank of largest should be 3");
assert_eq!(ranks[1] as usize, 1, "rank of smallest should be 1");
assert_eq!(ranks[2] as usize, 2, "rank of middle should be 2");
}
#[test]
fn test_diff_topk_sums_to_k() {
let scores = vec![1.0_f64, 5.0, 2.0, 4.0, 3.0];
let k = 3;
let p = diff_topk(&scores, k, 0.5_f64).unwrap();
let sum: f64 = p.iter().sum();
assert!(
(sum - k as f64).abs() < 1e-6,
"sum = {} but expected k={}",
sum,
k
);
}
#[test]
fn test_diff_topk_zero_temp_hard_topk() {
let scores = vec![1.0_f64, 5.0, 2.0, 4.0, 3.0];
let k = 2;
let p = diff_topk(&scores, k, 0.0_f64).unwrap();
let sum: f64 = p.iter().sum();
assert!((sum - k as f64).abs() < 1e-9);
assert!((p[1] - 1.0).abs() < 1e-9, "index 1 should be selected");
assert!((p[3] - 1.0).abs() < 1e-9, "index 3 should be selected");
}
#[test]
fn test_diff_topk_all_values_nonneg() {
let scores = vec![0.1_f64, 2.3, -1.0, 5.0, 0.7];
let k = 2usize;
let p = diff_topk(&scores, k, 1.0_f64).unwrap();
for &v in &p {
assert!(v >= -1e-9, "value {} is negative", v);
assert!(v <= k as f64 + 1e-9, "value {} exceeds k={}", v, k);
}
let sum: f64 = p.iter().sum();
assert!(
(sum - k as f64).abs() < 1e-6,
"sum = {} expected k={}",
sum,
k
);
}
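    // Additional coverage sketches; the asserted properties follow directly
    // from the definitions above (Sinkhorn normalization in the permutation
    // branch, and the mean-preserving blend + PAV in soft_sort).
    #[test]
    fn test_sparsemap_permutation_rows_and_cols_near_one() {
        let d = 3usize;
        let scores: Vec<f64> = (0..d * d).map(|i| i as f64 * 0.1).collect();
        let cfg = SparsemapConfig {
            structure_type: StructureType::Permutation { dim: d },
            ..SparsemapConfig::default()
        };
        let res = sparsemap(&scores, &cfg).unwrap();
        // The last Sinkhorn sweep ends with a column normalization, so columns
        // are exact and rows are near-stochastic.
        for row in 0..d {
            let s: f64 = (0..d).map(|col| res.solution[row * d + col]).sum();
            assert!((s - 1.0).abs() < 1e-3, "row {} sums to {}", row, s);
        }
        for col in 0..d {
            let s: f64 = (0..d).map(|row| res.solution[row * d + col]).sum();
            assert!((s - 1.0).abs() < 1e-3, "col {} sums to {}", col, s);
        }
    }
    #[test]
    fn test_soft_sort_preserves_mean() {
        let x = vec![4.0_f64, 1.0, 2.0, 3.0];
        let mean: f64 = x.iter().sum::<f64>() / x.len() as f64;
        let soft = soft_sort(&x, 0.5_f64).unwrap();
        let soft_mean: f64 = soft.iter().sum::<f64>() / soft.len() as f64;
        // The blend is a convex combination with the mean, and PAV preserves
        // block sums, so the overall mean is unchanged.
        assert!(
            (mean - soft_mean).abs() < 1e-9,
            "mean changed: {} vs {}",
            mean,
            soft_mean
        );
    }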
}