use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier for this op: used as the `generator` of the `Node::Region`
/// emitted by `iht_threshold`, and passed to `crate::invalid_output_program`
/// when that builder rejects its input.
pub const OP_ID: &str = "vyre-primitives::math::iht_threshold";
/// Builds the IR program for the elementwise hard-threshold step.
///
/// For every invocation index `i < n` the program loads `z[i]`, clears its
/// top bit (`& 0x7FFF_FFFF`), compares the result against `threshold[0]` with
/// `>=`, and stores the outcome of `Expr::select(cmp, z[i], 0)` into `out[i]`.
///
/// NOTE(review): the buffers are declared as `U32`; presumably they carry
/// IEEE-754 `f32` bit patterns so that clearing bit 31 yields `|z[i]|` and the
/// select keeps the element when the comparison holds — confirm with callers.
///
/// # Parameters
/// * `z` — name of the read-only input buffer (binding 0, `n` elements).
/// * `threshold` — name of the read-only one-element buffer (binding 1).
/// * `out` — name of the read-write output buffer (binding 2, `n` elements).
/// * `n` — element count; must be non-zero.
///
/// # Returns
/// A `Program` with workgroup size `[256, 1, 1]` whose body is a single
/// region tagged with [`OP_ID`]. When `n == 0` the result of
/// `crate::invalid_output_program` is returned instead.
#[must_use]
pub fn iht_threshold(z: &str, threshold: &str, out: &str, n: u32) -> Program {
    if n == 0 {
        return crate::invalid_output_program(
            OP_ID,
            out,
            DataType::U32,
            format!("Fix: iht_threshold requires n > 0, got {n}."),
        );
    }

    let lane = Expr::InvocationId { axis: 0 };
    // The input element is needed twice: once masked for the comparison and
    // once untouched as the value that may be kept.
    let load_input = || Expr::load(z, lane.clone());

    let magnitude = Expr::bitand(load_input(), Expr::u32(0x7FFF_FFFF));
    let cutoff = Expr::load(threshold, Expr::u32(0));
    let passes = Expr::ge(magnitude, cutoff);
    let kept_or_zero = Expr::select(passes, load_input(), Expr::u32(0));

    // Guard the store so invocations at or beyond `n` do nothing.
    let in_range = Expr::lt(lane.clone(), Expr::u32(n));
    let guarded_store = Node::if_then(in_range, vec![Node::store(out, lane, kept_or_zero)]);

    let input_decl =
        BufferDecl::storage(z, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n);
    let threshold_decl =
        BufferDecl::storage(threshold, 1, BufferAccess::ReadOnly, DataType::U32).with_count(1);
    let output_decl =
        BufferDecl::storage(out, 2, BufferAccess::ReadWrite, DataType::U32).with_count(n);

    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![guarded_store]),
    };

    Program::wrapped(
        vec![input_decl, threshold_decl, output_decl],
        [256, 1, 1],
        vec![region],
    )
}
/// CPU reference for top-`k` hard thresholding.
///
/// Keeps the `k` entries of `z` with the largest absolute value, at their
/// original positions and with their signs intact, and sets every other entry
/// to `0.0`.
///
/// # Returns
/// `(sparse, threshold)` where `sparse` has the same length as `z` and
/// `threshold` is the absolute value of the last (smallest-ranked) kept entry.
///
/// Special cases, checked in this order:
/// * `k >= z.len()` — a copy of `z` and a threshold of `0.0`.
/// * `k == 0` — all zeros and a threshold of `f64::INFINITY`.
///
/// NaN entries rank below every other value. Ties are resolved in favour of
/// the lower index because the sort is stable. If fewer than `k` entries are
/// non-NaN, NaN entries are kept and the returned threshold is NaN.
#[must_use]
pub fn iht_top_k_cpu(z: &[f64], k: usize) -> (Vec<f64>, f64) {
    let len = z.len();
    if k >= len {
        return (z.to_vec(), 0.0);
    }
    if k == 0 {
        return (vec![0.0; len], f64::INFINITY);
    }

    // Pair each index with its ranking score, then order by descending score.
    // `sort_by` is stable, so equal scores stay in ascending index order.
    let mut ranked: Vec<(usize, f64)> = z
        .iter()
        .map(|&value| finite_abs_score(value))
        .enumerate()
        .collect();
    ranked.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));

    let kept = &ranked[..k];
    let mut sparse = vec![0.0; len];
    for &(idx, _) in kept {
        sparse[idx] = z[idx];
    }

    // Read the raw element rather than the score: a NaN in the last kept slot
    // must surface as NaN, not as the score's negative infinity.
    let cutoff = z[kept[k - 1].0].abs();
    (sparse, cutoff)
}

/// Ranking score for [`iht_top_k_cpu`]: the absolute value of `value`, with
/// NaN mapped to negative infinity so it sorts after every other score.
fn finite_abs_score(value: f64) -> f64 {
    if value.is_nan() {
        f64::NEG_INFINITY
    } else {
        value.abs()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Closeness check with a tolerance that scales with the operands.
    fn approx_eq(a: f64, b: f64) -> bool {
        let scale = 1.0 + a.abs() + b.abs();
        (a - b).abs() < 1e-10 * scale
    }

    #[test]
    fn cpu_top_2_keeps_largest() {
        let input: [f64; 5] = [0.1, -2.0, 0.5, 3.0, -0.05];
        let (kept, cutoff) = iht_top_k_cpu(&input, 2);
        // (index, expected value): the two largest magnitudes survive.
        for (idx, want) in [(3, 3.0), (1, -2.0), (0, 0.0), (2, 0.0), (4, 0.0)] {
            assert!(approx_eq(kept[idx], want));
        }
        assert!(approx_eq(cutoff, 2.0));
    }

    #[test]
    fn cpu_k_equals_n_returns_all() {
        let input = vec![1.0, 2.0, 3.0];
        let (kept, _) = iht_top_k_cpu(&input, 3);
        assert_eq!(kept, input);
    }

    #[test]
    fn cpu_k_zero_zeros_all() {
        let input = vec![1.0, 2.0, 3.0];
        let (kept, cutoff) = iht_top_k_cpu(&input, 0);
        assert!(kept.iter().all(|&v| approx_eq(v, 0.0)));
        assert!(cutoff.is_infinite());
    }

    #[test]
    fn cpu_preserves_signs() {
        let input: [f64; 3] = [-5.0, 3.0, -7.0];
        let (kept, _) = iht_top_k_cpu(&input, 2);
        // Negative entries that are kept must come back negative.
        for (idx, want) in [(2, -7.0), (0, -5.0), (1, 0.0)] {
            assert!(approx_eq(kept[idx], want));
        }
    }

    #[test]
    fn ir_program_buffer_layout() {
        let program = iht_threshold("z", "th", "out", 32);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|decl| decl.name()).collect();
        assert_eq!(names, vec!["z", "th", "out"]);

        // Input and output hold `n` elements; the threshold buffer holds one.
        assert_eq!(program.buffers[0].count(), 32);
        assert_eq!(program.buffers[1].count(), 1);
        assert_eq!(program.buffers[2].count(), 32);
    }

    #[test]
    fn zero_n_traps() {
        let program = iht_threshold("z", "th", "out", 0);
        assert!(program.stats().trap());
    }
}