#![cfg(test)]
mod common;
use common::{bytes_u32, u32_bytes, with_live_backend};
use vyre::DispatchConfig;
use vyre_primitives::reduce::gather::{cpu_ref as gather_cpu, gather};
use vyre_primitives::reduce::histogram::{cpu_ref as hist_cpu, histogram};
use vyre_primitives::reduce::radix_sort::{cpu_ref as radix_cpu, radix_sort};
use vyre_primitives::reduce::scatter::{cpu_ref as scatter_cpu, scatter};
use vyre_primitives::reduce::segment_reduce::{cpu_ref as seg_cpu, segment_reduce_sum};
fn run_gather(src: &[u32], indices: &[u32]) -> Vec<u32> {
let count = src.len() as u32;
let program = gather("src", "indices", "dst", count);
let inputs: Vec<Vec<u8>> = vec![
u32_bytes(src),
u32_bytes(indices),
vec![0u8; count as usize * 4],
];
let mut config = DispatchConfig::default();
let workgroup_x = 256u32;
let grid_x = ((count + workgroup_x - 1) / workgroup_x).max(1);
config.grid_override = Some([grid_x, 1, 1]);
let outputs = with_live_backend("gather primitive", |backend| {
backend
.dispatch(&program, &inputs, &config)
.unwrap_or_else(|error| panic!("Fix: CUDA gather dispatch failed: {error}"))
});
let mut out = bytes_u32(&outputs[0]);
out.truncate(count as usize);
out
}
/// Gathering with indices `0, 1, 2, 3` must match the CPU reference and
/// reproduce the source values unchanged.
#[test]
fn cuda_gather_identity() {
    let values: Vec<u32> = vec![10, 20, 30, 40];
    let order: Vec<u32> = vec![0, 1, 2, 3];

    let reference = gather_cpu(&values, &order);
    let actual = run_gather(&values, &order);

    assert_eq!(actual, reference);
    assert_eq!(actual, values);
}
/// Gathering with descending indices must match the CPU reference and yield
/// the source values in reverse order.
#[test]
fn cuda_gather_reverse() {
    let values: Vec<u32> = vec![10, 20, 30, 40];
    let order: Vec<u32> = vec![3, 2, 1, 0];

    let reference = gather_cpu(&values, &order);
    let actual = run_gather(&values, &order);
    assert_eq!(actual, reference);

    let reversed: Vec<u32> = vec![40, 30, 20, 10];
    assert_eq!(actual, reversed);
}
fn run_scatter(src: &[u32], indices: &[u32]) -> Vec<u32> {
let count = src.len() as u32;
let program = scatter("src", "indices", "dst", count);
let inputs: Vec<Vec<u8>> = vec![
u32_bytes(src),
u32_bytes(indices),
vec![0u8; count as usize * 4],
];
let mut config = DispatchConfig::default();
let workgroup_x = 256u32;
let grid_x = ((count + workgroup_x - 1) / workgroup_x).max(1);
config.grid_override = Some([grid_x, 1, 1]);
let outputs = with_live_backend("scatter primitive", |backend| {
backend
.dispatch(&program, &inputs, &config)
.unwrap_or_else(|error| panic!("Fix: CUDA scatter dispatch failed: {error}"))
});
let mut out = bytes_u32(&outputs[0]);
out.truncate(count as usize);
out
}
/// Scattering with descending indices must match the CPU reference and place
/// the source values in reverse order.
#[test]
fn cuda_scatter_inverse_of_gather() {
    let values: Vec<u32> = vec![10, 20, 30, 40];
    let targets: Vec<u32> = vec![3, 2, 1, 0];

    let reference = scatter_cpu(&values, &targets, values.len());
    let actual = run_scatter(&values, &targets);
    assert_eq!(actual, reference);

    let reversed: Vec<u32> = vec![40, 30, 20, 10];
    assert_eq!(actual, reversed);
}
/// Scattering with indices `0, 1, 2, 3` must match the CPU reference and
/// reproduce the source values unchanged.
#[test]
fn cuda_scatter_identity() {
    let values: Vec<u32> = vec![5, 6, 7, 8];
    let targets: Vec<u32> = vec![0, 1, 2, 3];

    let reference = scatter_cpu(&values, &targets, values.len());
    let actual = run_scatter(&values, &targets);

    assert_eq!(actual, reference);
    assert_eq!(actual, values);
}
fn run_histogram(input: &[u32], num_bins: u32) -> Vec<u32> {
let count = input.len() as u32;
let program = histogram("input", "output", count, num_bins);
let inputs: Vec<Vec<u8>> = vec![u32_bytes(input), vec![0u8; num_bins as usize * 4]];
let mut config = DispatchConfig::default();
let workgroup_x = 256u32;
let grid_x = ((count + workgroup_x - 1) / workgroup_x).max(1);
config.grid_override = Some([grid_x, 1, 1]);
let outputs = with_live_backend("histogram primitive", |backend| {
backend
.dispatch(&program, &inputs, &config)
.unwrap_or_else(|error| panic!("Fix: CUDA histogram dispatch failed: {error}"))
});
let mut out = bytes_u32(&outputs[0]);
out.truncate(num_bins as usize);
out
}
/// With every value inside the bin range, the result must match the CPU
/// reference and the hand-computed per-bin counts.
#[test]
fn cuda_histogram_simple() {
    let samples: Vec<u32> = vec![0, 1, 0, 2, 1, 0];
    let bins: u32 = 4;

    let reference = hist_cpu(&samples, bins);
    let actual = run_histogram(&samples, bins);
    assert_eq!(actual, reference);

    let counts: Vec<u32> = vec![3, 2, 1, 0];
    assert_eq!(actual, counts);
}
/// The value `5` lies outside the four bins; the expected counts contain no
/// contribution from it, and the result must also match the CPU reference.
#[test]
fn cuda_histogram_skips_out_of_range() {
    let samples: Vec<u32> = vec![0, 5, 1, 5, 2];
    let bins: u32 = 4;

    let reference = hist_cpu(&samples, bins);
    let actual = run_histogram(&samples, bins);
    assert_eq!(actual, reference);

    let counts: Vec<u32> = vec![1, 1, 1, 0];
    assert_eq!(actual, counts);
}
fn run_radix_sort(input: &[u32], bits: u32) -> Vec<u32> {
let count = input.len() as u32;
let program = radix_sort("input", "output", count, bits);
let inputs: Vec<Vec<u8>> = vec![u32_bytes(input), vec![0u8; count as usize * 4]];
let mut config = DispatchConfig::default();
let workgroup_x = 256u32;
let grid_x = ((count + workgroup_x - 1) / workgroup_x).max(1);
config.grid_override = Some([grid_x, 1, 1]);
let outputs = with_live_backend("radix sort primitive", |backend| {
backend
.dispatch(&program, &inputs, &config)
.unwrap_or_else(|error| panic!("Fix: CUDA radix-sort dispatch failed: {error}"))
});
let mut out = bytes_u32(&outputs[0]);
out.truncate(count as usize);
out
}
/// Ascending input sorted with `bits = 8` must match the CPU reference and
/// come back unchanged.
#[test]
fn cuda_radix_sort_already_sorted() {
    let keys: Vec<u32> = vec![1, 2, 3, 4, 5];

    let reference = radix_cpu(&keys, 8);
    let actual = run_radix_sort(&keys, 8);

    assert_eq!(actual, reference);
    assert_eq!(actual, keys);
}
/// Descending input sorted with `bits = 8` must match the CPU reference and
/// come back in ascending order.
#[test]
fn cuda_radix_sort_reverse() {
    let keys: Vec<u32> = vec![5, 4, 3, 2, 1];

    let reference = radix_cpu(&keys, 8);
    let actual = run_radix_sort(&keys, 8);
    assert_eq!(actual, reference);

    let ascending: Vec<u32> = vec![1, 2, 3, 4, 5];
    assert_eq!(actual, ascending);
}
/// Input containing repeated keys, sorted with `bits = 8`, must match both the
/// CPU reference and the standard library's `sort_unstable` result.
#[test]
fn cuda_radix_sort_with_duplicates() {
    let keys: Vec<u32> = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];

    let reference = radix_cpu(&keys, 8);
    let actual = run_radix_sort(&keys, 8);
    assert_eq!(actual, reference);

    // Independent oracle: sort a copy with the standard library.
    let std_sorted = {
        let mut copy = keys.clone();
        copy.sort_unstable();
        copy
    };
    assert_eq!(actual, std_sorted);
}
fn run_segment_reduce(input: &[u32], segment_offsets: &[u32]) -> Vec<u32> {
let num_segments = (segment_offsets.len() - 1) as u32;
let program = segment_reduce_sum("input", "segments", "output", num_segments);
let inputs: Vec<Vec<u8>> = vec![
u32_bytes(input),
u32_bytes(segment_offsets),
vec![0u8; num_segments as usize * 4],
];
let mut config = DispatchConfig::default();
let workgroup_x = 256u32;
let grid_x = ((num_segments + workgroup_x - 1) / workgroup_x).max(1);
config.grid_override = Some([grid_x, 1, 1]);
let outputs = with_live_backend("segment reduce primitive", |backend| {
backend
.dispatch(&program, &inputs, &config)
.unwrap_or_else(|error| panic!("Fix: CUDA segment-reduce dispatch failed: {error}"))
});
let mut out = bytes_u32(&outputs[0]);
out.truncate(num_segments as usize);
out
}
/// Three segments of two elements each: the result must match the CPU
/// reference and the hand-computed sums.
#[test]
fn cuda_segment_reduce_uniform_segments() {
    let values: Vec<u32> = vec![1, 2, 3, 4, 5, 6];
    let offsets: Vec<u32> = vec![0, 2, 4, 6];

    let reference = seg_cpu(&values, &offsets);
    let actual = run_segment_reduce(&values, &offsets);
    assert_eq!(actual, reference);

    let sums: Vec<u32> = vec![3, 7, 11];
    assert_eq!(actual, sums);
}
/// Segments of lengths one, three, and one: the result must match the CPU
/// reference and the hand-computed sums.
#[test]
fn cuda_segment_reduce_uneven_segments() {
    let values: Vec<u32> = vec![10, 20, 30, 40, 50];
    let offsets: Vec<u32> = vec![0, 1, 4, 5];

    let reference = seg_cpu(&values, &offsets);
    let actual = run_segment_reduce(&values, &offsets);
    assert_eq!(actual, reference);

    let sums: Vec<u32> = vec![10, 90, 50];
    assert_eq!(actual, sums);
}