use crate::dispatch_buffers::{
ceil_div_u32, decode_u32_output_exact, ensure_input_slots, write_u32_slice_le_bytes,
write_zero_bytes,
};
use crate::optimizer::dispatcher::{DispatchError, OptimizerDispatcher};
use vyre_primitives::reduce::{
all::reduce_all, any::reduce_any, count_non_zero::reduce_count_non_zero,
histogram::histogram_atomic_scatter, max::reduce_max, min::reduce_min,
segment_reduce::segment_reduce_sum, sum::reduce_sum,
};
#[cfg(any(test, feature = "cpu-parity"))]
use vyre_primitives::reduce::{
all::cpu_ref as primitive_all, any::cpu_ref as primitive_any,
count_non_zero::cpu_ref as primitive_count_non_zero, histogram::cpu_ref as primitive_histogram,
max::cpu_ref as primitive_max, min::cpu_ref as primitive_min,
segment_reduce::cpu_ref as primitive_segment_reduce_sum, sum::cpu_ref as primitive_sum,
};
#[derive(Debug, Default)]
pub struct ReductionMetricsGpuScratch {
inputs: Vec<Vec<u8>>,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionMetric {
Sum,
Max,
Min,
CountNonZero,
Any,
All,
}
pub fn reduce_metric_via(
dispatcher: &dyn OptimizerDispatcher,
metric: ReductionMetric,
values: &[u32],
) -> Result<u32, DispatchError> {
let mut scratch = ReductionMetricsGpuScratch::default();
reduce_metric_via_with_scratch(dispatcher, metric, values, &mut scratch)
}
pub fn reduce_metric_via_with_scratch(
dispatcher: &dyn OptimizerDispatcher,
metric: ReductionMetric,
values: &[u32],
scratch: &mut ReductionMetricsGpuScratch,
) -> Result<u32, DispatchError> {
use crate::observability::{bump, reduction_metrics_calls};
bump(&reduction_metrics_calls);
let count = checked_len(values.len(), "reduce_metric_via")?;
let program = match metric {
ReductionMetric::Sum => reduce_sum("values", "out", count),
ReductionMetric::Max => reduce_max("values", "out", count),
ReductionMetric::Min => reduce_min("values", "out", count),
ReductionMetric::CountNonZero => reduce_count_non_zero("values", "out", count),
ReductionMetric::Any => reduce_any("values", "out", count),
ReductionMetric::All => reduce_all("values", "out", count),
};
ensure_input_slots(&mut scratch.inputs, 2);
write_u32_slice_le_bytes(&mut scratch.inputs[0], values);
write_zero_bytes(&mut scratch.inputs[1], std::mem::size_of::<u32>());
let outputs = dispatcher.dispatch(
&program,
&scratch.inputs,
Some(grid_for_metric(metric, count)),
)?;
decode_scalar(&outputs, "reduce_metric_via")
}
pub fn reduce_sum_via(
dispatcher: &dyn OptimizerDispatcher,
values: &[u32],
) -> Result<u32, DispatchError> {
reduce_metric_via(dispatcher, ReductionMetric::Sum, values)
}
pub fn reduce_max_via(
dispatcher: &dyn OptimizerDispatcher,
values: &[u32],
) -> Result<u32, DispatchError> {
reduce_metric_via(dispatcher, ReductionMetric::Max, values)
}
pub fn reduce_min_via(
dispatcher: &dyn OptimizerDispatcher,
values: &[u32],
) -> Result<u32, DispatchError> {
reduce_metric_via(dispatcher, ReductionMetric::Min, values)
}
pub fn reduce_count_non_zero_via(
dispatcher: &dyn OptimizerDispatcher,
values: &[u32],
) -> Result<u32, DispatchError> {
reduce_metric_via(dispatcher, ReductionMetric::CountNonZero, values)
}
pub fn reduce_any_via(
dispatcher: &dyn OptimizerDispatcher,
values: &[u32],
) -> Result<bool, DispatchError> {
Ok(reduce_metric_via(dispatcher, ReductionMetric::Any, values)? != 0)
}
pub fn reduce_all_via(
dispatcher: &dyn OptimizerDispatcher,
values: &[u32],
) -> Result<bool, DispatchError> {
Ok(reduce_metric_via(dispatcher, ReductionMetric::All, values)? != 0)
}
pub fn segment_reduce_sum_via(
dispatcher: &dyn OptimizerDispatcher,
input: &[u32],
segment_offsets: &[u32],
) -> Result<Vec<u32>, DispatchError> {
let mut out = Vec::new();
let mut scratch = ReductionMetricsGpuScratch::default();
segment_reduce_sum_via_with_scratch_into(
dispatcher,
input,
segment_offsets,
&mut scratch,
&mut out,
)?;
Ok(out)
}
pub fn segment_reduce_sum_via_with_scratch_into(
dispatcher: &dyn OptimizerDispatcher,
input: &[u32],
segment_offsets: &[u32],
scratch: &mut ReductionMetricsGpuScratch,
out: &mut Vec<u32>,
) -> Result<(), DispatchError> {
use crate::observability::{bump, reduction_metrics_calls};
bump(&reduction_metrics_calls);
let num_segments = validate_segment_offsets(input, segment_offsets)?;
let program = segment_reduce_sum("input", "segment_offsets", "output", num_segments);
ensure_input_slots(&mut scratch.inputs, 3);
write_u32_slice_le_bytes(&mut scratch.inputs[0], input);
write_u32_slice_le_bytes(&mut scratch.inputs[1], segment_offsets);
write_zero_bytes(
&mut scratch.inputs[2],
num_segments as usize * std::mem::size_of::<u32>(),
);
let outputs = dispatcher.dispatch(&program, &scratch.inputs, Some([1, 1, 1]))?;
decode_first_output(
&outputs,
num_segments as usize,
"segment_reduce_sum_via",
out,
)
}
pub fn histogram_atomic_scatter_via(
dispatcher: &dyn OptimizerDispatcher,
input: &[u32],
num_bins: u32,
) -> Result<Vec<u32>, DispatchError> {
use crate::observability::{bump, reduction_metrics_calls};
bump(&reduction_metrics_calls);
let count = checked_nonzero_len(input.len(), "histogram_atomic_scatter_via")?;
if num_bins == 0 {
return Err(DispatchError::BadInputs(
"Fix: histogram_atomic_scatter_via requires num_bins > 0.".to_string(),
));
}
let bin_count = num_bins as usize;
let program = histogram_atomic_scatter("input", "output", count, num_bins);
let mut scratch = ReductionMetricsGpuScratch::default();
ensure_input_slots(&mut scratch.inputs, 2);
write_u32_slice_le_bytes(&mut scratch.inputs[0], input);
write_zero_bytes(
&mut scratch.inputs[1],
bin_count * std::mem::size_of::<u32>(),
);
let outputs = dispatcher.dispatch(
&program,
&scratch.inputs,
Some([ceil_div_u32(count, 256), 1, 1]),
)?;
let mut out = Vec::new();
decode_first_output(
&outputs,
bin_count,
"histogram_atomic_scatter_via",
&mut out,
)?;
Ok(out)
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_reduce_sum(values: &[u32]) -> u32 {
primitive_sum(values)
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_reduce_max(values: &[u32]) -> u32 {
primitive_max(values)
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_reduce_min(values: &[u32]) -> u32 {
primitive_min(values)
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_reduce_count_non_zero(values: &[u32]) -> u32 {
primitive_count_non_zero(values)
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_reduce_any(values: &[u32]) -> bool {
primitive_any(values) != 0
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_reduce_all(values: &[u32]) -> bool {
primitive_all(values) != 0
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_segment_reduce_sum(input: &[u32], segment_offsets: &[u32]) -> Vec<u32> {
primitive_segment_reduce_sum(input, segment_offsets)
}
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn reference_histogram_atomic_scatter(input: &[u32], num_bins: u32) -> Vec<u32> {
primitive_histogram(input, num_bins)
}
fn grid_for_metric(metric: ReductionMetric, count: u32) -> [u32; 3] {
match metric {
ReductionMetric::Sum | ReductionMetric::Max | ReductionMetric::Min => {
[ceil_div_u32(count, 256), 1, 1]
}
ReductionMetric::CountNonZero | ReductionMetric::Any | ReductionMetric::All => [1, 1, 1],
}
}
fn checked_len(len: usize, context: &'static str) -> Result<u32, DispatchError> {
u32::try_from(len).map_err(|_| {
DispatchError::BadInputs(format!(
"Fix: {context} received {len} values, which exceeds the u32 GPU index space."
))
})
}
fn checked_nonzero_len(len: usize, context: &'static str) -> Result<u32, DispatchError> {
let count = checked_len(len, context)?;
if count == 0 {
return Err(DispatchError::BadInputs(format!(
"Fix: {context} requires count > 0."
)));
}
Ok(count)
}
fn validate_segment_offsets(input: &[u32], segment_offsets: &[u32]) -> Result<u32, DispatchError> {
if segment_offsets.len() < 2 {
return Err(DispatchError::BadInputs(
"Fix: segment_reduce_sum_via requires at least two CSR offsets.".to_string(),
));
}
let num_segments = segment_offsets.len() - 1;
if num_segments > 256 {
return Err(DispatchError::BadInputs(format!(
"Fix: segment_reduce_sum_via supports at most 256 segments per primitive dispatch, got {num_segments}."
)));
}
for (segment, pair) in segment_offsets.windows(2).enumerate() {
let start = pair[0] as usize;
let end = pair[1] as usize;
if start > end || end > input.len() {
return Err(DispatchError::BadInputs(format!(
"Fix: segment_reduce_sum_via received malformed segment {segment}: start={start}, end={end}, input_len={}.",
input.len()
)));
}
}
Ok(num_segments as u32)
}
fn decode_scalar(outputs: &[Vec<u8>], context: &'static str) -> Result<u32, DispatchError> {
let mut out = Vec::new();
decode_first_output(outputs, 1, context, &mut out)?;
Ok(out[0])
}
fn decode_first_output(
outputs: &[Vec<u8>],
words: usize,
context: &'static str,
out: &mut Vec<u32>,
) -> Result<(), DispatchError> {
if outputs.is_empty() {
return Err(DispatchError::BackendError(format!(
"Fix: {context} expected at least one output buffer, got 0."
)));
}
decode_u32_output_exact(&outputs[0], words, context, out)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dispatch_buffers::u32_slice_to_le_bytes;
use vyre_foundation::ir::Program;
struct ReduceDispatcher;
impl OptimizerDispatcher for ReduceDispatcher {
fn dispatch(
&self,
program: &Program,
inputs: &[Vec<u8>],
grid_override: Option<[u32; 3]>,
) -> Result<Vec<Vec<u8>>, DispatchError> {
let op_id = program
.entry
.iter()
.find_map(|node| match node {
vyre_foundation::ir::Node::Region { generator, .. } => Some(generator.as_str()),
_ => None,
})
.expect("Fix: reduction primitive should expose a region generator");
let values = crate::hardware::dispatch_buffers::read_u32s(&inputs[0]);
match op_id {
vyre_primitives::reduce::sum::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
scalar(primitive_sum(&values))
}
vyre_primitives::reduce::max::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
scalar(primitive_max(&values))
}
vyre_primitives::reduce::min::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
scalar(primitive_min(&values))
}
vyre_primitives::reduce::count_non_zero::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
scalar(primitive_count_non_zero(&values))
}
vyre_primitives::reduce::any::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
scalar(primitive_any(&values))
}
vyre_primitives::reduce::all::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
scalar(primitive_all(&values))
}
vyre_primitives::reduce::segment_reduce::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
let offsets = crate::hardware::dispatch_buffers::read_u32s(&inputs[1]);
Ok(vec![u32_slice_to_le_bytes(&primitive_segment_reduce_sum(
&values, &offsets,
))])
}
vyre_primitives::reduce::histogram::OP_ID => {
assert_eq!(grid_override, Some([1, 1, 1]));
let bins = (inputs[1].len() / std::mem::size_of::<u32>()) as u32;
Ok(vec![u32_slice_to_le_bytes(&primitive_histogram(
&values, bins,
))])
}
other => panic!("unexpected reduction primitive op id {other}"),
}
}
}
fn scalar(value: u32) -> Result<Vec<Vec<u8>>, DispatchError> {
Ok(vec![u32_slice_to_le_bytes(&[value])])
}
#[test]
fn reference_reductions_match_primitives_exactly() {
let values = [1u32, 0, 7, u32::MAX];
assert_eq!(reference_reduce_sum(&values), primitive_sum(&values));
assert_eq!(reference_reduce_max(&values), primitive_max(&values));
assert_eq!(reference_reduce_min(&values), primitive_min(&values));
assert_eq!(
reference_reduce_count_non_zero(&values),
primitive_count_non_zero(&values)
);
assert_eq!(reference_reduce_any(&values), primitive_any(&values) != 0);
assert_eq!(reference_reduce_all(&values), primitive_all(&values) != 0);
}
#[test]
fn scalar_reductions_dispatch_through_primitives() {
let values = [1u32, 0, 7, 3];
assert_eq!(reduce_sum_via(&ReduceDispatcher, &values).unwrap(), 11);
assert_eq!(reduce_max_via(&ReduceDispatcher, &values).unwrap(), 7);
assert_eq!(reduce_min_via(&ReduceDispatcher, &values).unwrap(), 0);
assert_eq!(
reduce_count_non_zero_via(&ReduceDispatcher, &values).unwrap(),
3
);
assert!(reduce_any_via(&ReduceDispatcher, &values).unwrap());
assert!(!reduce_all_via(&ReduceDispatcher, &values).unwrap());
}
#[test]
fn segment_and_histogram_dispatch_through_primitives() {
assert_eq!(
segment_reduce_sum_via(&ReduceDispatcher, &[1, 2, 3, 4, 5], &[0, 2, 5]).unwrap(),
vec![3, 12]
);
assert_eq!(
histogram_atomic_scatter_via(&ReduceDispatcher, &[0, 1, 2, 1, 9], 4).unwrap(),
vec![1, 2, 1, 0]
);
}
#[test]
fn scratch_path_reuses_buffers() {
let mut scratch = ReductionMetricsGpuScratch::default();
assert_eq!(
reduce_metric_via_with_scratch(
&ReduceDispatcher,
ReductionMetric::CountNonZero,
&[0, 1, 2],
&mut scratch,
)
.unwrap(),
2
);
let input_capacities = scratch.inputs.iter().map(Vec::capacity).collect::<Vec<_>>();
assert_eq!(
reduce_metric_via_with_scratch(
&ReduceDispatcher,
ReductionMetric::CountNonZero,
&[0, 1, 2],
&mut scratch,
)
.unwrap(),
2
);
assert_eq!(
scratch.inputs.iter().map(Vec::capacity).collect::<Vec<_>>(),
input_capacities
);
}
#[test]
fn invalid_segment_offsets_are_actionable() {
let err = segment_reduce_sum_via(&ReduceDispatcher, &[1, 2], &[0, 3]).unwrap_err();
assert!(err
.to_string()
.contains("Fix: segment_reduce_sum_via received malformed segment"));
}
#[test]
fn zero_bin_histogram_is_rejected_before_dispatch() {
let err = histogram_atomic_scatter_via(&ReduceDispatcher, &[1], 0).unwrap_err();
assert!(err
.to_string()
.contains("Fix: histogram_atomic_scatter_via requires num_bins > 0"));
}
}