mod common;
use common::setup_provider;
/// Mixed 0/1 mask: each output slot holds the number of set entries strictly
/// before it (exclusive scan), and the returned count is the total of set entries.
#[test]
fn test_prefix_sum_mask_simple() {
    // Tests are skipped (not failed) on machines without a usable device.
    let provider = match setup_provider() {
        Some(p) => p,
        None => {
            eprintln!("Skipping: no CUDA device");
            return;
        }
    };

    let input: Vec<u8> = vec![1, 0, 1, 1, 0, 1];
    let expected_offsets: Vec<u32> = vec![0, 1, 1, 2, 3, 3];

    let (offsets, total) = provider.prefix_sum_mask(&input).unwrap();
    assert_eq!(total, 4);
    assert_eq!(offsets, expected_offsets);
}
/// Mask with no set entries: the count is zero and every offset stays zero.
///
/// Note: despite the name, this uses a 10-element all-zero mask, not an
/// empty slice.
#[test]
fn test_prefix_sum_mask_empty() {
    // Tests are skipped (not failed) on machines without a usable device.
    let provider = match setup_provider() {
        Some(p) => p,
        None => {
            eprintln!("Skipping: no CUDA device");
            return;
        }
    };

    const LEN: usize = 10;
    let all_zero: Vec<u8> = vec![0; LEN];
    let expected_offsets: Vec<u32> = vec![0; LEN];

    let (offsets, total) = provider.prefix_sum_mask(&all_zero).unwrap();
    assert_eq!(total, 0);
    assert_eq!(offsets, expected_offsets);
}
/// Fully set mask: the exclusive scan is the identity sequence `0..len`, and
/// the count equals the length.
#[test]
fn test_prefix_sum_mask_all_ones() {
    // Tests are skipped (not failed) on machines without a usable device.
    let provider = match setup_provider() {
        Some(p) => p,
        None => {
            eprintln!("Skipping: no CUDA device");
            return;
        }
    };

    let all_set: Vec<u8> = vec![1; 5];
    let expected_offsets: Vec<u32> = (0..5).collect();

    let (offsets, total) = provider.prefix_sum_mask(&all_set).unwrap();
    assert_eq!(total, 5);
    assert_eq!(offsets, expected_offsets);
}
/// 256-element mask with set entries at the first, middle and last positions.
///
/// 256 is presumably the largest input handled in a single block by the
/// kernel; the 257-element case is covered by `test_prefix_sum_mask_over_256`.
/// The entire output is compared, not just a few sample indices, so that wrong
/// intermediate values or a wrong output length are caught.
#[test]
fn test_prefix_sum_mask_max_size() {
    let Some(provider) = setup_provider() else {
        eprintln!("Skipping: no CUDA device");
        return;
    };
    let mut mask = vec![0u8; 256];
    mask[0] = 1;
    mask[127] = 1;
    mask[255] = 1;

    // Exclusive scan: index 0 sees no earlier set entries, indices 1..=127 see
    // the one at 0, and indices 128..=255 additionally see the one at 127.
    let mut expected = vec![0u32; 256];
    expected[1..=127].fill(1);
    expected[128..].fill(2);

    let (prefix_sum, count) = provider.prefix_sum_mask(&mask).unwrap();
    assert_eq!(count, 3);
    assert_eq!(prefix_sum, expected);
}
/// 257-element all-ones mask: one element past the 256-element case, so the
/// scan must carry across that boundary.
///
/// The whole output vector is compared. The previous `take(257)` loop would
/// silently pass if the result were truncated (e.g. to 256 elements).
#[test]
fn test_prefix_sum_mask_over_256() {
    let Some(provider) = setup_provider() else {
        eprintln!("Skipping: no CUDA device");
        return;
    };
    let mask = vec![1u8; 257];
    // `expect` surfaces the provider's actual error, unlike `assert!(is_ok())`.
    let (prefix_sum, count) = provider
        .prefix_sum_mask(&mask)
        .expect("prefix_sum_mask should work with 257 elements");
    assert_eq!(count, 257);

    // For an all-ones mask the exclusive scan is 0, 1, ..., 256.
    let expected: Vec<u32> = (0..257).collect();
    assert_eq!(prefix_sum, expected);
}