use vyre_primitives::math::submodular_greedy::{argmax_of_marginals_cpu, NO_WINNER};
/// Greedily selects up to `k` of the `n` items in `gains` and returns a freshly
/// allocated `n`-entry mask in which `1` marks a retained item and `0` an
/// unselected one.
///
/// Allocating convenience wrapper over [`select_retention_set_into`]. `gains` is
/// mutated exactly as that function mutates it.
///
/// # Panics
///
/// Panics under the same conditions as [`select_retention_set_into`].
#[must_use]
pub fn select_retention_set(gains: &mut [u32], n: u32, k: u32) -> Vec<u32> {
    let mut mask = Vec::with_capacity(n as usize);
    select_retention_set_into(gains, n, k, &mut mask);
    mask
}
/// Greedily fills `picked` with a retention mask that selects up to `k` of the
/// `n` items described by `gains`.
///
/// `picked` is cleared and resized to `n` zeros, so a caller-supplied buffer is
/// reused (no reallocation when its capacity already covers `n`). Each round
/// asks `argmax_of_marginals_cpu(gains, picked)` for a winner. The winner is
/// marked `1` in `picked`, and its entry in `gains` is reset to `0`. Selection
/// stops after `k` winners, or as soon as the helper returns [`NO_WINNER`], so
/// fewer than `k` items may end up retained.
///
/// Every call bumps the `submodular_cache_eviction_calls` observability
/// counter, including calls that then fail the argument checks below.
///
/// # Panics
///
/// Panics if `gains.len() != n` or if `k > n`.
pub fn select_retention_set_into(gains: &mut [u32], n: u32, k: u32, picked: &mut Vec<u32>) {
    use crate::observability::{bump, submodular_cache_eviction_calls};
    bump(&submodular_cache_eviction_calls);
    assert_eq!(
        gains.len(),
        n as usize,
        "Fix: gains must hold exactly n entries."
    );
    assert!(k <= n, "Fix: k must not exceed n.");
    picked.clear();
    picked.resize(n as usize, 0);
    for _ in 0..k {
        let (winner, _) = argmax_of_marginals_cpu(gains, picked);
        if winner == NO_WINNER {
            // The helper reported no candidate; stop early with fewer than k picks.
            break;
        }
        let slot = winner as usize;
        picked[slot] = 1;
        // Reset the winner's gain, presumably so it cannot win a later round
        // (confirm against argmax_of_marginals_cpu's handling of `picked`).
        gains[slot] = 0;
    }
}
/// Returns the eviction mask that complements `retention`, in a newly
/// allocated vector of the same length.
///
/// Allocating convenience wrapper over [`invert_to_eviction_set_into`]: entries
/// equal to `0` map to `1`, and every other value maps to `0`.
#[must_use]
pub fn invert_to_eviction_set(retention: &[u32]) -> Vec<u32> {
    let mut evict_mask = Vec::with_capacity(retention.len());
    invert_to_eviction_set_into(retention, &mut evict_mask);
    evict_mask
}
/// Writes the complement of `retention` into `eviction`, reusing its buffer.
///
/// `eviction` is cleared first and ends up with exactly `retention.len()`
/// entries. Each entry is `1` where the matching retention entry is `0`, and
/// `0` for any other value, so non-binary retention values count as retained.
pub fn invert_to_eviction_set_into(retention: &[u32], eviction: &mut Vec<u32>) {
    let flipped = retention.iter().map(|&kept| u32::from(kept == 0));
    eviction.clear();
    eviction.reserve(retention.len());
    eviction.extend(flipped);
}
/// Scales `optimum` by 6321/10000 (≈ 0.6321), rounding down.
///
/// 0.6321 sits just below 1 − 1/e ≈ 0.63212. That value is the ratio usually
/// quoted for greedy monotone submodular maximisation under a cardinality
/// limit, presumably the guarantee this expresses (confirm). The product is
/// formed in `u64`, so it cannot overflow. The result never exceeds `optimum`,
/// so narrowing back to `u32` is lossless.
#[must_use]
pub fn greedy_quality_bound(optimum: u32) -> u32 {
    const RATIO_NUM: u64 = 6_321;
    const RATIO_DEN: u64 = 10_000;
    let scaled = u64::from(optimum) * RATIO_NUM / RATIO_DEN;
    // Lossless: scaled <= optimum <= u32::MAX.
    scaled as u32
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn picks_top_k_by_gain() {
        let mut gains = vec![3u32, 7, 2, 9, 5];
        let retention = select_retention_set(&mut gains, 5, 3);
        assert_eq!(retention, vec![0, 1, 0, 1, 1]);
        // Winners have their gains reset to zero; unpicked gains are untouched.
        assert_eq!(gains, vec![3, 0, 2, 0, 0]);
    }

    #[test]
    fn k_eq_zero_evicts_all() {
        let mut gains = vec![3u32, 7, 2, 9, 5];
        let retention = select_retention_set(&mut gains, 5, 0);
        assert_eq!(retention, vec![0; 5]);
    }

    #[test]
    fn k_eq_n_retains_all() {
        let mut gains = vec![3u32, 7, 2, 9, 5];
        let retention = select_retention_set(&mut gains, 5, 5);
        assert_eq!(retention, vec![1; 5]);
    }

    #[test]
    fn into_overwrites_stale_picked_buffer() {
        let mut gains = vec![3u32, 7, 2, 9, 5];
        // Stale, over-long contents must be discarded, not merged.
        let mut picked = vec![1u32; 7];
        select_retention_set_into(&mut gains, 5, 2, &mut picked);
        assert_eq!(picked, vec![0, 1, 0, 1, 0]);
    }

    #[test]
    fn invert_complements_retention() {
        let retention = vec![1, 0, 1, 0, 1];
        let eviction = invert_to_eviction_set(&retention);
        assert_eq!(eviction, vec![0, 1, 0, 1, 0]);
    }

    #[test]
    fn invert_into_reuses_eviction_buffer() {
        let retention = vec![1, 0, 1, 0, 1];
        let mut eviction = Vec::with_capacity(8);
        let ptr = eviction.as_ptr();
        invert_to_eviction_set_into(&retention, &mut eviction);
        assert_eq!(eviction, vec![0, 1, 0, 1, 0]);
        assert_eq!(eviction.as_ptr(), ptr);
    }

    #[test]
    fn quality_bound_is_lower_bound() {
        assert_eq!(greedy_quality_bound(100), 63);
        assert_eq!(greedy_quality_bound(1000), 632);
        for &optimum in &[0u32, 1, 7, 100, u32::MAX] {
            assert!(greedy_quality_bound(optimum) <= optimum);
        }
    }

    // `should_panic(expected = ...)` pins *which* precondition fired; the
    // previous catch_unwind form accepted any panic at all.
    #[test]
    #[should_panic(expected = "k must not exceed n")]
    fn k_larger_than_n_panics() {
        let mut gains = vec![1u32, 2, 3];
        let _ = select_retention_set(&mut gains, 3, 5);
    }
}