use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier shared by [`histogram`] and [`histogram_atomic_scatter`].
///
/// Both builders use it as the `Region` generator and pass it to
/// `crate::invalid_output_program` when their arguments are rejected; it is
/// also the first argument of the harness `OpEntry` registered in this file.
pub const OP_ID: &str = "vyre-primitives::reduce::histogram";
/// Builds a gather-style histogram program with one invocation per bin.
///
/// Invocation `t` with `t < num_bins` walks all `count` elements of `input`,
/// counts how many are equal to `t`, and stores that count into `output[t]`.
/// An element whose value is `>= num_bins` can only match an invocation that
/// is filtered out by the `t < num_bins` guard, so such elements are ignored.
/// Every bin performs a full pass over `input` (`num_bins * count` loads in
/// total) but needs no atomics. Each in-range invocation overwrites its own
/// output slot — NOTE(review): whether every bin is covered depends on how
/// many invocations the runtime dispatches, which is not visible here.
///
/// Bindings: `input` at binding 0 (read-only, `count` u32 elements) and
/// `output` at binding 1 (read-write, `num_bins` u32 elements); workgroup
/// size is `[256, 1, 1]`.
///
/// When `count == 0` or `num_bins == 0` the program returned is the one built
/// by `crate::invalid_output_program`, carrying a `Fix:` diagnostic.
#[must_use]
pub fn histogram(input: &str, output: &str, count: u32, num_bins: u32) -> Program {
    // `count` is validated before `num_bins`, so a call with both zero reports `count`.
    let rejection = if count == 0 {
        Some(format!("Fix: histogram requires count > 0, got {count}."))
    } else if num_bins == 0 {
        Some(format!("Fix: histogram requires num_bins > 0, got {num_bins}."))
    } else {
        None
    };
    if let Some(message) = rejection {
        return crate::invalid_output_program(OP_ID, output, DataType::U32, message);
    }

    let bin = Expr::InvocationId { axis: 0 };

    // Contributes 1 when input[i] equals this invocation's bin, 0 otherwise.
    let hit = Expr::select(
        Expr::eq(Expr::load(input, Expr::var("i")), bin.clone()),
        Expr::u32(1),
        Expr::u32(0),
    );
    let accumulate = Node::assign("total", Expr::add(Expr::var("total"), hit));
    let count_matches = vec![
        Node::let_bind("total", Expr::u32(0)),
        Node::loop_for("i", Expr::u32(0), Expr::u32(count), vec![accumulate]),
        Node::store(output, bin.clone(), Expr::var("total")),
    ];
    // Invocations at or past `num_bins` have no bin to fill and do nothing.
    let per_bin = Node::if_then(Expr::lt(bin, Expr::u32(num_bins)), count_matches);

    let buffers = vec![
        BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::U32).with_count(count),
        BufferDecl::storage(output, 1, BufferAccess::ReadWrite, DataType::U32)
            .with_count(num_bins),
    ];
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![per_bin]),
    };
    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// Builds a scatter-style histogram program with one invocation per input element.
///
/// Invocation `t` with `t < count` reads `bin = input[t]` and, when
/// `bin < num_bins`, performs an atomic add of 1 on `output[bin]`; values at
/// or above `num_bins` are dropped. Work scales with `count` instead of
/// `num_bins * count`, at the price of atomic traffic on popular bins.
///
/// Unlike [`histogram`], this program never stores zeros into `output`: it
/// only adds onto the existing contents. NOTE(review): callers presumably
/// have to zero-initialise `output` before dispatch — confirm with the runtime.
///
/// Bindings match [`histogram`]: `input` at binding 0 (read-only, `count`
/// u32 elements), `output` at binding 1 (read-write, `num_bins` u32 elements),
/// workgroup size `[256, 1, 1]`. When `count == 0` or `num_bins == 0` the
/// program returned is the one built by `crate::invalid_output_program`.
#[must_use]
pub fn histogram_atomic_scatter(input: &str, output: &str, count: u32, num_bins: u32) -> Program {
    // `count` is validated before `num_bins`, so a call with both zero reports `count`.
    let rejection = if count == 0 {
        Some(format!(
            "Fix: histogram_atomic_scatter requires count > 0, got {count}."
        ))
    } else if num_bins == 0 {
        Some(format!(
            "Fix: histogram_atomic_scatter requires num_bins > 0, got {num_bins}."
        ))
    } else {
        None
    };
    if let Some(message) = rejection {
        return crate::invalid_output_program(OP_ID, output, DataType::U32, message);
    }

    let lane = Expr::InvocationId { axis: 0 };

    // The atomic's result is bound to `_prev` only so it appears as a statement; it is unused.
    let bump = Node::let_bind(
        "_prev",
        Expr::atomic_add(output, Expr::var("bin"), Expr::u32(1)),
    );
    let scatter_one = vec![
        Node::let_bind("bin", Expr::load(input, lane.clone())),
        Node::if_then(Expr::lt(Expr::var("bin"), Expr::u32(num_bins)), vec![bump]),
    ];
    // Invocations with an id at or past `count` skip the load entirely, so
    // `input` is never read out of bounds.
    let guarded = Node::if_then(Expr::lt(lane, Expr::u32(count)), scatter_one);

    let buffers = vec![
        BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::U32).with_count(count),
        BufferDecl::storage(output, 1, BufferAccess::ReadWrite, DataType::U32)
            .with_count(num_bins),
    ];
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![guarded]),
    };
    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// CPU reference histogram returning a freshly allocated vector of `num_bins` counters.
///
/// Counting is delegated to [`try_cpu_ref_into`], so values `>= num_bins` are
/// skipped. If that call reports an error, the error is written to stderr and
/// an empty vector is returned; this wrapper never panics.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn cpu_ref(input: &[u32], num_bins: u32) -> Vec<u32> {
    let mut bins = Vec::new();
    if let Err(error) = try_cpu_ref_into(input, num_bins, &mut bins) {
        eprintln!("vyre-primitives histogram CPU reference failed: {error}");
        return Vec::new();
    }
    bins
}
/// Writes the CPU reference histogram into `out`, reusing its allocation.
///
/// Infallible counterpart of [`try_cpu_ref_into`]: if that call reports an
/// error, the error is written to stderr and `out` is left empty.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn cpu_ref_into(input: &[u32], num_bins: u32, out: &mut Vec<u32>) {
    if let Some(error) = try_cpu_ref_into(input, num_bins, out).err() {
        eprintln!("vyre-primitives histogram CPU reference failed: {error}");
        out.clear();
    }
}
/// Fallible CPU reference histogram that reuses the allocation behind `out`.
///
/// On success `out` holds exactly `num_bins` counters, where `out[b]` is the
/// number of elements of `input` equal to `b`. Elements `>= num_bins` are
/// skipped, and counters are incremented with wrapping addition. Any previous
/// contents of `out` are discarded, but its capacity is kept.
///
/// # Errors
///
/// Returns a message (before `out` is cleared) if `num_bins` does not fit in
/// `usize`, or if growing `out` to hold `num_bins` elements fails to allocate.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn try_cpu_ref_into(input: &[u32], num_bins: u32, out: &mut Vec<u32>) -> Result<(), String> {
    let num_bins = usize::try_from(num_bins)
        .map_err(|_| format!("histogram bin count {num_bins} does not fit host usize"))?;
    // `try_reserve_exact(n)` guarantees `capacity >= len + n`, i.e. it is relative
    // to the current *length*, not the capacity. Requesting `num_bins - len`
    // therefore guarantees room for `num_bins` elements (and is a no-op when the
    // capacity already suffices), so the `resize` below never has to fall back
    // to an infallible allocation of its own.
    out.try_reserve_exact(num_bins.saturating_sub(out.len()))
        .map_err(|err| {
            format!("histogram CPU reference could not reserve {num_bins} bins: {err}")
        })?;
    out.clear();
    out.resize(num_bins, 0);
    // Values that do not fit in `usize` or fall outside `0..num_bins` are ignored.
    for index in input.iter().filter_map(|&bin| usize::try_from(bin).ok()) {
        if let Some(slot) = out.get_mut(index) {
            *slot = slot.wrapping_add(1);
        }
    }
    Ok(())
}
// Registers the gather-style `histogram` builder (8 input elements, 4 bins)
// with the op harness; compiled only with the `inventory-registry` feature.
// The first `Some` closure packs two u32 buffers via `crate::wire::pack_u32_slice`
// — presumably the initial `input` and zeroed `output` contents — and the second
// packs `[2, 2, 2, 2]`, presumably the expected `output`: each value 0..=3 occurs
// twice in the input. NOTE(review): confirm the argument roles against `OpEntry::new`.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
|| histogram("input", "output", 8, 4),
Some(|| {
let to_bytes = |w: &[u32]| crate::wire::pack_u32_slice(w);
vec![vec![
to_bytes(&[0, 1, 2, 3, 0, 1, 2, 3]),
to_bytes(&[0, 0, 0, 0]),
]]
}),
Some(|| {
let to_bytes = |w: &[u32]| crate::wire::pack_u32_slice(w);
vec![vec![to_bytes(&[2, 2, 2, 2])]]
}),
)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the histogram CPU reference and the IR program builders.
    use super::*;

    #[test]
    fn basic_histogram() {
        let input = &[0u32, 1, 2, 3, 0, 1, 2, 3];
        assert_eq!(cpu_ref(input, 4), vec![2, 2, 2, 2]);
    }

    #[test]
    fn empty_input() {
        assert_eq!(cpu_ref(&[], 4), vec![0, 0, 0, 0]);
    }

    #[test]
    fn all_same_bin() {
        let input = &[2u32, 2, 2, 2, 2];
        assert_eq!(cpu_ref(input, 4), vec![0, 0, 5, 0]);
    }

    #[test]
    fn out_of_bounds_ignored() {
        let input = &[0u32, 1, 99, 2, 3, 100];
        assert_eq!(cpu_ref(input, 4), vec![1, 1, 1, 1]);
    }

    #[test]
    fn try_cpu_ref_into_reuses_output_and_clears_stale_tail() {
        let input = &[0u32, 1, 99, 2, 3, 100];
        let mut out = Vec::with_capacity(16);
        out.extend_from_slice(&[u32::MAX; 16]);
        let ptr = out.as_ptr();
        try_cpu_ref_into(input, 4, &mut out).unwrap();
        assert_eq!(out, vec![1, 1, 1, 1]);
        assert_eq!(out.as_ptr(), ptr);
    }

    #[test]
    fn compatibility_wrappers_match_fallible_reference() {
        let input = &[0u32, 1, 99, 2, 3, 100];
        let mut compat = Vec::with_capacity(16);
        let mut fallible = Vec::with_capacity(16);
        cpu_ref_into(input, 4, &mut compat);
        try_cpu_ref_into(input, 4, &mut fallible)
            .expect("Fix: small histogram CPU reference must reserve");
        assert_eq!(cpu_ref(input, 4), fallible);
        assert_eq!(compat, fallible);
    }

    #[test]
    fn cpu_ref_into_overwrites_previous_contents() {
        let mut out = vec![9u32; 3];
        cpu_ref_into(&[1, 1, 0], 2, &mut out);
        assert_eq!(out, vec![1, 2]);
    }

    #[test]
    fn production_cpu_ref_wrappers_have_no_raw_panic_path() {
        let production = include_str!("histogram.rs")
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: histogram.rs must contain production section");
        assert!(
            !production.contains(".expect(") && !production.contains(".unwrap("),
            "Fix: histogram CPU parity wrappers must not panic in production."
        );
    }

    // The two wrapping tests pin the u32 wrapping semantics that the CPU
    // reference relies on (`slot.wrapping_add(1)`); they do not call into the
    // module because overflowing a bin would need more than u32::MAX inputs.
    #[test]
    fn wrapping_on_overflow() {
        let mut base = u32::MAX - 1;
        base = base.wrapping_add(1);
        base = base.wrapping_add(1);
        assert_eq!(base, 0);
    }

    #[test]
    fn wrapping_overflow_correct() {
        let base = u32::MAX - 1;
        let after_three = base.wrapping_add(3);
        assert_eq!(after_three, 1);
    }

    #[test]
    fn many_bins() {
        let input: Vec<u32> = (0..100).collect();
        let out = cpu_ref(&input, 100);
        assert_eq!(out.len(), 100);
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v, 1, "bin {i} should have count 1");
        }
    }

    #[test]
    fn sparse_bins() {
        let input = &[0u32, 50, 50, 99];
        let mut expected = vec![0u32; 100];
        expected[0] = 1;
        expected[50] = 2;
        expected[99] = 1;
        assert_eq!(cpu_ref(input, 100), expected);
    }

    #[test]
    fn program_has_expected_buffers() {
        let p = histogram("in", "out", 1024, 16);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["in", "out"]);
    }

    #[test]
    fn program_buffer_counts() {
        let p = histogram("in", "out", 1024, 16);
        assert_eq!(p.buffers[0].count(), 1024);
        assert_eq!(p.buffers[1].count(), 16);
    }

    #[test]
    fn zero_bins_traps() {
        let p = histogram("in", "out", 10, 0);
        assert!(p.stats().trap());
    }

    #[test]
    fn zero_count_traps() {
        let p = histogram("in", "out", 0, 4);
        assert!(p.stats().trap());
    }

    #[test]
    fn atomic_scatter_program_has_expected_buffers() {
        let p = histogram_atomic_scatter("in", "out", 1024, 16);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["in", "out"]);
        assert_eq!(p.buffers[0].count(), 1024);
        assert_eq!(p.buffers[1].count(), 16);
    }

    #[test]
    fn atomic_scatter_zero_bins_traps() {
        let p = histogram_atomic_scatter("in", "out", 10, 0);
        assert!(p.stats().trap());
    }

    #[test]
    fn atomic_scatter_zero_count_traps() {
        let p = histogram_atomic_scatter("in", "out", 0, 4);
        assert!(p.stats().trap());
    }

    // Despite the name, this exercises the sequential CPU reference with a
    // single heavily contended bin; it does not spawn threads.
    #[test]
    fn concurrent_access_cpu_simulation() {
        let input = vec![7u32; 10_000];
        let out = cpu_ref(&input, 16);
        assert_eq!(out[7], 10_000);
        for (i, &v) in out.iter().enumerate() {
            if i != 7 {
                assert_eq!(v, 0);
            }
        }
    }

    #[test]
    fn adversarial_all_out_of_bounds() {
        let input = &[100u32, 200, 300];
        assert_eq!(cpu_ref(input, 2), vec![0, 0]);
    }

    #[test]
    fn adversarial_max_u32_index() {
        let input = &[u32::MAX];
        assert_eq!(cpu_ref(input, 4), vec![0, 0, 0, 0]);
    }
}