use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for this primitive. It is used as the generator name of
/// the emitted `Node::Region` and is passed to `crate::invalid_output_program`
/// when the candidate count is rejected.
pub const OP_ID: &str = "vyre-primitives::math::submodular_argmax_of_marginals";
/// Sentinel index meaning "no candidate was selected". The IR scan starts from
/// this value, and the CPU reference returns it when every candidate is masked.
pub const NO_WINNER: u32 = u32::MAX;
/// Builds the IR [`Program`] that selects the not-yet-picked candidate with the
/// largest marginal gain.
///
/// Buffer arguments are names, bound in this order:
///
/// * `gains` — binding 0, read-only, `n_candidates` `u32` elements.
/// * `picked_mask` — binding 1, read-only, `n_candidates` `u32` elements; a
///   candidate is considered only when its mask entry equals 0.
/// * `winner_idx` — binding 2, read-write, one `u32`; receives the winning
///   index, or [`NO_WINNER`] when no candidate was considered.
/// * `winner_gain` — binding 3, read-write, one `u32`; receives the winning
///   gain, or 0 when no candidate was considered.
///
/// The program is declared with a `[256, 1, 1]` workgroup size, but the scan is
/// guarded so that only the invocation whose axis-0 id equals 0 runs it. The
/// first considered candidate is always accepted; afterwards a candidate
/// replaces the current best only when its gain is strictly greater, so ties
/// keep the lowest index.
///
/// When `n_candidates` is 0 the result of `crate::invalid_output_program` is
/// returned instead of the scan program.
#[must_use]
pub fn argmax_of_marginals(
    gains: &str,
    picked_mask: &str,
    winner_idx: &str,
    winner_gain: &str,
    n_candidates: u32,
) -> Program {
    if n_candidates == 0 {
        return crate::invalid_output_program(
            OP_ID,
            winner_idx,
            DataType::U32,
            format!("Fix: argmax_of_marginals requires n_candidates > 0, got {n_candidates}."),
        );
    }

    // Guard: only the invocation with axis-0 id 0 performs the serial scan.
    let is_lead_invocation = Expr::eq(Expr::InvocationId { axis: 0 }, Expr::u32(0));

    // Candidate `c` is eligible when its mask entry is zero.
    let candidate_is_free = Expr::eq(Expr::load(picked_mask, Expr::var("c")), Expr::u32(0));

    // Accept when nothing has been chosen yet, or when the gain strictly improves.
    let beats_current_best = Expr::or(
        Expr::eq(Expr::var("best_idx"), Expr::u32(NO_WINNER)),
        Expr::gt(Expr::var("g"), Expr::var("best_gain")),
    );
    let record_candidate = vec![
        Node::assign("best_idx", Expr::var("c")),
        Node::assign("best_gain", Expr::var("g")),
    ];
    let consider_candidate = vec![
        Node::let_bind("g", Expr::load(gains, Expr::var("c"))),
        Node::if_then(beats_current_best, record_candidate),
    ];

    let scan_candidates = Node::loop_for(
        "c",
        Expr::u32(0),
        Expr::u32(n_candidates),
        vec![Node::if_then(candidate_is_free, consider_candidate)],
    );

    let lead_invocation_work = vec![
        Node::let_bind("best_idx", Expr::u32(NO_WINNER)),
        Node::let_bind("best_gain", Expr::u32(0)),
        scan_candidates,
        Node::store(winner_idx, Expr::u32(0), Expr::var("best_idx")),
        Node::store(winner_gain, Expr::u32(0), Expr::var("best_gain")),
    ];
    let body = vec![Node::if_then(is_lead_invocation, lead_invocation_work)];

    let buffers = vec![
        BufferDecl::storage(gains, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(n_candidates),
        BufferDecl::storage(picked_mask, 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(n_candidates),
        BufferDecl::storage(winner_idx, 2, BufferAccess::ReadWrite, DataType::U32).with_count(1),
        BufferDecl::storage(winner_gain, 3, BufferAccess::ReadWrite, DataType::U32).with_count(1),
    ];

    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(body),
    };

    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// CPU reference for the argmax-of-marginals selection.
///
/// Walks `gains` and `picked_mask` in lockstep and returns `(index, gain)` of
/// the highest-gain entry whose mask value is 0. Ties keep the lowest index,
/// because a later entry replaces the leader only when its gain is strictly
/// greater. If the slices differ in length, only the positions present in both
/// are examined. Returns `(NO_WINNER, 0)` when no entry is eligible.
#[must_use]
pub fn argmax_of_marginals_cpu(gains: &[u32], picked_mask: &[u32]) -> (u32, u32) {
    gains
        .iter()
        .zip(picked_mask)
        .enumerate()
        // Drop candidates that are already picked (non-zero mask).
        .filter(|&(_, (_, &taken))| taken == 0)
        .fold(None, |leader: Option<(u32, u32)>, (pos, (&gain, _))| match leader {
            // Keep the current leader unless the new gain is strictly larger.
            Some((_, top)) if gain <= top => leader,
            // Positions are reported as `u32`, matching the index width of the IR path.
            _ => Some((pos as u32, gain)),
        })
        .unwrap_or((NO_WINNER, 0))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With an all-zero mask the overall maximum wins.
    #[test]
    fn cpu_picks_global_max_when_nothing_picked() {
        let outcome = argmax_of_marginals_cpu(&[3, 7, 2, 9, 5], &[0, 0, 0, 0, 0]);
        assert_eq!(outcome, (3, 9));
    }

    /// A masked-out maximum is ignored in favour of the next best gain.
    #[test]
    fn cpu_skips_already_picked() {
        let outcome = argmax_of_marginals_cpu(&[3, 7, 2, 9, 5], &[0, 0, 0, 1, 0]);
        assert_eq!(outcome, (1, 7));
    }

    /// When every candidate is masked the sentinel pair is returned.
    #[test]
    fn cpu_all_picked_returns_no_winner() {
        let outcome = argmax_of_marginals_cpu(&[1, 2, 3], &[1, 1, 1]);
        assert_eq!(outcome, (NO_WINNER, 0));
    }

    /// Only positions present in both slices take part in the selection.
    #[test]
    fn cpu_mismatched_inputs_only_consider_complete_pairs() {
        let outcome = argmax_of_marginals_cpu(&[3, 9, 1], &[1, 0]);
        assert_eq!(outcome, (1, 9));
    }

    /// Equal gains resolve to the lowest index.
    #[test]
    fn cpu_ties_pick_first() {
        let (winner, _) = argmax_of_marginals_cpu(&[5, 5, 5], &[0, 0, 0]);
        assert_eq!(winner, 0);
    }

    /// Repeated selection with mask updates yields gains in descending order.
    #[test]
    fn cpu_simulated_greedy_loop_three_picks() {
        let gains = [1, 5, 3, 8, 2];
        let mut taken = vec![0u32; gains.len()];
        for expected in [3u32, 1, 2] {
            let (winner, _) = argmax_of_marginals_cpu(&gains, &taken);
            assert_eq!(winner, expected);
            taken[winner as usize] = 1;
        }
    }

    /// The IR program declares the four buffers in order with the expected sizes.
    #[test]
    fn ir_program_buffer_layout() {
        let program = argmax_of_marginals("g", "p", "wi", "wg", 16);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|decl| decl.name()).collect();
        assert_eq!(names, vec!["g", "p", "wi", "wg"]);

        // Input buffers hold one element per candidate; outputs hold a single element.
        assert_eq!(program.buffers[0].count(), 16);
        assert_eq!(program.buffers[1].count(), 16);
        assert_eq!(program.buffers[2].count(), 1);
        assert_eq!(program.buffers[3].count(), 1);
    }

    /// A zero candidate count produces a program whose stats report a trap.
    #[test]
    fn zero_candidates_traps() {
        let rejected = argmax_of_marginals("g", "p", "wi", "wg", 0);
        assert!(rejected.stats().trap());
    }
}