use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::cuda::{CudaDevice, CudaRuntime};
use crate::tensor::Tensor;
use cudarc::driver::{CudaContext, CudaStream, PushKernelArg};
use std::sync::Arc;
use super::loader::{get_kernel_function, get_or_load_module, kernel_names, launch_config};
/// Threads per block for every scan kernel launch; inputs of at most this
/// many elements are handled by the single-block kernel path.
const SCAN_BLOCK_SIZE: u32 = 512;
/// Upper bound on recursive block-sum scans. Each level shrinks the problem
/// by a factor of SCAN_BLOCK_SIZE, so hitting this depth indicates a bug or
/// an impossibly large input rather than a legitimate workload.
const MAX_SCAN_RECURSION_DEPTH: usize = 10;
pub unsafe fn exclusive_scan_i32_gpu(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
device: &CudaDevice,
input: &Tensor<CudaRuntime>,
) -> Result<(Tensor<CudaRuntime>, usize)> {
let n = input.numel();
if input.dtype() != DType::I32 {
return Err(Error::Internal(format!(
"exclusive_scan_i32_gpu expects I32 input, got {:?}",
input.dtype()
)));
}
let output = Tensor::<CudaRuntime>::zeros(&[n + 1], DType::I32, device);
let input_ptr = input.ptr();
let output_ptr = output.ptr();
if n <= SCAN_BLOCK_SIZE as usize {
unsafe {
launch_scan_single_block_i32(
context,
stream,
device_index,
input_ptr,
output_ptr,
n as u32,
)?;
}
} else {
unsafe {
launch_scan_multi_block_i32(
context,
stream,
device_index,
device,
input_ptr,
output_ptr,
n as u32,
0, )?;
}
}
stream
.synchronize()
.map_err(|e| Error::Internal(format!("Failed to synchronize after scan: {:?}", e)))?;
let mut total_i32: i32 = 0;
let offset_bytes = n * std::mem::size_of::<i32>();
unsafe {
cudarc::driver::sys::cuMemcpyDtoH_v2(
&mut total_i32 as *mut i32 as *mut std::ffi::c_void,
output.ptr() + offset_bytes as u64,
std::mem::size_of::<i32>(),
);
}
let total = total_i32 as usize;
Ok((output, total))
}
/// Launches the single-block exclusive-scan kernel over `n` i32 elements.
///
/// `input_ptr` / `output_ptr` are raw device addresses; the output buffer
/// must hold `n + 1` values. Caller is responsible for synchronizing.
unsafe fn launch_scan_single_block_i32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    input_ptr: u64,
    output_ptr: u64,
    n: u32,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::SCAN_MODULE)?;
    let kernel = get_kernel_function(&module, "exclusive_scan_i32")?;
    // A single block of SCAN_BLOCK_SIZE threads covers the whole input.
    let cfg = launch_config((1, 1, 1), (SCAN_BLOCK_SIZE, 1, 1), 0);
    let mut args = stream.launch_builder(&kernel);
    args.arg(&input_ptr);
    args.arg(&output_ptr);
    args.arg(&n);
    unsafe { args.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA scan single-block kernel launch failed: {:?}",
            e
        ))
    })
}
unsafe fn launch_scan_multi_block_i32(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
device: &CudaDevice,
input_ptr: u64,
output_ptr: u64,
n: u32,
depth: usize,
) -> Result<()> {
if depth >= MAX_SCAN_RECURSION_DEPTH {
return Err(Error::Internal(format!(
"Scan recursion depth {} exceeds maximum {}. \
This indicates an algorithmic error or impossibly large input.",
depth, MAX_SCAN_RECURSION_DEPTH
)));
}
let module = get_or_load_module(context, device_index, kernel_names::SCAN_MODULE)?;
let num_blocks = (n + SCAN_BLOCK_SIZE - 1) / SCAN_BLOCK_SIZE;
let block_sums = Tensor::<CudaRuntime>::zeros(&[num_blocks as usize], DType::I32, device);
let block_sums_ptr = block_sums.ptr();
let func_step1 = get_kernel_function(&module, "scan_blocks_i32_step1")?;
let grid = (num_blocks, 1, 1);
let block = (SCAN_BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func_step1);
builder.arg(&input_ptr);
builder.arg(&output_ptr);
builder.arg(&block_sums_ptr);
builder.arg(&n);
unsafe { builder.launch(cfg) }
.map_err(|e| Error::Internal(format!("CUDA scan step 1 kernel launch failed: {:?}", e)))?;
stream.synchronize().map_err(|e| {
Error::Internal(format!("Failed to synchronize after scan step 1: {:?}", e))
})?;
let scanned_block_sums =
Tensor::<CudaRuntime>::zeros(&[num_blocks as usize + 1], DType::I32, device);
let scanned_block_sums_ptr = scanned_block_sums.ptr();
if num_blocks <= SCAN_BLOCK_SIZE {
let func_scan = get_kernel_function(&module, "exclusive_scan_i32")?;
let grid = (1, 1, 1);
let block = (SCAN_BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func_scan);
builder.arg(&block_sums_ptr);
builder.arg(&scanned_block_sums_ptr);
builder.arg(&num_blocks);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Internal(format!(
"CUDA scan block sums kernel launch failed: {:?}",
e
))
})?;
stream.synchronize().map_err(|e| {
Error::Internal(format!("Failed to synchronize after scan step 2: {:?}", e))
})?;
} else {
unsafe {
launch_scan_multi_block_i32(
context,
stream,
device_index,
device,
block_sums_ptr,
scanned_block_sums_ptr,
num_blocks,
depth + 1, )?;
}
}
let func_step3 = get_kernel_function(&module, "add_block_offsets_i32_step3")?;
let grid = (num_blocks, 1, 1);
let block = (SCAN_BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func_step3);
builder.arg(&output_ptr);
builder.arg(&scanned_block_sums_ptr);
builder.arg(&n);
unsafe { builder.launch(cfg) }
.map_err(|e| Error::Internal(format!("CUDA scan step 3 kernel launch failed: {:?}", e)))?;
stream.synchronize().map_err(|e| {
Error::Internal(format!("Failed to synchronize after scan step 3: {:?}", e))
})?;
let func_total = get_kernel_function(&module, "write_total_i32")?;
let cfg = launch_config((1, 1, 1), (1, 1, 1), 0);
let mut builder = stream.launch_builder(&func_total);
builder.arg(&output_ptr);
builder.arg(&scanned_block_sums_ptr);
builder.arg(&n);
builder.arg(&num_blocks);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Internal(format!(
"CUDA scan total write kernel launch failed: {:?}",
e
))
})?;
Ok(())
}
pub unsafe fn exclusive_scan_i64_gpu(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
device: &CudaDevice,
input: &Tensor<CudaRuntime>,
) -> Result<(Tensor<CudaRuntime>, usize)> {
let n = input.numel();
if input.dtype() != DType::I64 {
return Err(Error::Internal(format!(
"exclusive_scan_i64_gpu expects I64 input, got {:?}",
input.dtype()
)));
}
let output = Tensor::<CudaRuntime>::zeros(&[n + 1], DType::I64, device);
let input_ptr = input.ptr();
let output_ptr = output.ptr();
if n <= SCAN_BLOCK_SIZE as usize {
unsafe {
launch_scan_single_block_i64(
context,
stream,
device_index,
input_ptr,
output_ptr,
n as u32,
)?;
}
} else {
unsafe {
launch_scan_multi_block_i64(
context,
stream,
device_index,
device,
input_ptr,
output_ptr,
n as u32,
0, )?;
}
}
stream
.synchronize()
.map_err(|e| Error::Internal(format!("Failed to synchronize after scan: {:?}", e)))?;
let mut total_i64: i64 = 0;
let offset_bytes = n * std::mem::size_of::<i64>();
unsafe {
cudarc::driver::sys::cuMemcpyDtoH_v2(
&mut total_i64 as *mut i64 as *mut std::ffi::c_void,
output.ptr() + offset_bytes as u64,
std::mem::size_of::<i64>(),
);
}
let total = total_i64 as usize;
Ok((output, total))
}
/// Launches the single-block exclusive-scan kernel over `n` i64 elements.
///
/// `input_ptr` / `output_ptr` are raw device addresses; the output buffer
/// must hold `n + 1` values. Caller is responsible for synchronizing.
unsafe fn launch_scan_single_block_i64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    input_ptr: u64,
    output_ptr: u64,
    n: u32,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::SCAN_MODULE)?;
    let kernel = get_kernel_function(&module, "exclusive_scan_i64")?;
    // A single block of SCAN_BLOCK_SIZE threads covers the whole input.
    let cfg = launch_config((1, 1, 1), (SCAN_BLOCK_SIZE, 1, 1), 0);
    let mut args = stream.launch_builder(&kernel);
    args.arg(&input_ptr);
    args.arg(&output_ptr);
    args.arg(&n);
    unsafe { args.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA scan single-block i64 kernel launch failed: {:?}",
            e
        ))
    })
}
unsafe fn launch_scan_multi_block_i64(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
device: &CudaDevice,
input_ptr: u64,
output_ptr: u64,
n: u32,
depth: usize,
) -> Result<()> {
if depth >= MAX_SCAN_RECURSION_DEPTH {
return Err(Error::Internal(format!(
"Scan recursion depth {} exceeds maximum {}. \
This indicates an algorithmic error or impossibly large input.",
depth, MAX_SCAN_RECURSION_DEPTH
)));
}
let module = get_or_load_module(context, device_index, kernel_names::SCAN_MODULE)?;
let num_blocks = (n + SCAN_BLOCK_SIZE - 1) / SCAN_BLOCK_SIZE;
let block_sums = Tensor::<CudaRuntime>::zeros(&[num_blocks as usize], DType::I64, device);
let block_sums_ptr = block_sums.ptr();
let func_step1 = get_kernel_function(&module, "scan_blocks_i64_step1")?;
let grid = (num_blocks, 1, 1);
let block = (SCAN_BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func_step1);
builder.arg(&input_ptr);
builder.arg(&output_ptr);
builder.arg(&block_sums_ptr);
builder.arg(&n);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Internal(format!(
"CUDA scan i64 step 1 kernel launch failed: {:?}",
e
))
})?;
stream.synchronize().map_err(|e| {
Error::Internal(format!(
"Failed to synchronize after scan i64 step 1: {:?}",
e
))
})?;
let scanned_block_sums =
Tensor::<CudaRuntime>::zeros(&[num_blocks as usize + 1], DType::I64, device);
let scanned_block_sums_ptr = scanned_block_sums.ptr();
if num_blocks <= SCAN_BLOCK_SIZE {
let func_scan = get_kernel_function(&module, "exclusive_scan_i64")?;
let grid = (1, 1, 1);
let block = (SCAN_BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func_scan);
builder.arg(&block_sums_ptr);
builder.arg(&scanned_block_sums_ptr);
builder.arg(&num_blocks);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Internal(format!(
"CUDA scan i64 block sums kernel launch failed: {:?}",
e
))
})?;
stream.synchronize().map_err(|e| {
Error::Internal(format!(
"Failed to synchronize after scan i64 step 2: {:?}",
e
))
})?;
} else {
unsafe {
launch_scan_multi_block_i64(
context,
stream,
device_index,
device,
block_sums_ptr,
scanned_block_sums_ptr,
num_blocks,
depth + 1, )?;
}
}
let func_step3 = get_kernel_function(&module, "add_block_offsets_i64_step3")?;
let grid = (num_blocks, 1, 1);
let block = (SCAN_BLOCK_SIZE, 1, 1);
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func_step3);
builder.arg(&output_ptr);
builder.arg(&scanned_block_sums_ptr);
builder.arg(&n);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Internal(format!(
"CUDA scan i64 step 3 kernel launch failed: {:?}",
e
))
})?;
stream.synchronize().map_err(|e| {
Error::Internal(format!(
"Failed to synchronize after scan i64 step 3: {:?}",
e
))
})?;
let func_total = get_kernel_function(&module, "write_total_i64")?;
let cfg = launch_config((1, 1, 1), (1, 1, 1), 0);
let mut builder = stream.launch_builder(&func_total);
builder.arg(&output_ptr);
builder.arg(&scanned_block_sums_ptr);
builder.arg(&n);
builder.arg(&num_blocks);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Internal(format!(
"CUDA scan i64 total write kernel launch failed: {:?}",
e
))
})?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Runtime;

    /// Uploads `values`, runs the i32 exclusive scan end-to-end on device 0,
    /// and downloads the result. `expect_msg` is the panic message used if
    /// the scan itself fails.
    #[cfg(feature = "cuda")]
    fn scan_i32_on_gpu(values: &[i32], expect_msg: &str) -> (Vec<i32>, usize) {
        let device = CudaDevice::new(0);
        let client = CudaRuntime::default_client(&device);
        let input = Tensor::<CudaRuntime>::from_slice(values, &[values.len()], &device);
        let (output, total) = unsafe {
            exclusive_scan_i32_gpu(
                &client.context,
                &client.stream,
                device.index,
                &device,
                &input,
            )
        }
        .expect(expect_msg);
        (output.to_vec(), total)
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_exclusive_scan_small() {
        if !crate::runtime::cuda::is_cuda_available() {
            return;
        }
        let (out, total) = scan_i32_on_gpu(&[3, 1, 4, 1, 5], "scan failed");
        assert_eq!(out, vec![0, 3, 4, 8, 9, 14]);
        assert_eq!(total, 14);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_exclusive_scan_large() {
        if !crate::runtime::cuda::is_cuda_available() {
            return;
        }
        // 1024 ones: the scan of index i is exactly i.
        let (out, total) = scan_i32_on_gpu(&vec![1; 1024], "scan failed");
        assert_eq!(out.len(), 1025);
        assert_eq!(out[0], 0);
        assert_eq!(out[1024], 1024);
        assert_eq!(total, 1024);
        for (i, &v) in out.iter().take(1024).enumerate() {
            assert_eq!(v, i as i32);
        }
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_exclusive_scan_zeros() {
        if !crate::runtime::cuda::is_cuda_available() {
            return;
        }
        let (out, total) = scan_i32_on_gpu(&[0, 0, 0, 0], "scan failed");
        assert_eq!(out, vec![0, 0, 0, 0, 0]);
        assert_eq!(total, 0);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_exclusive_scan_single_element() {
        if !crate::runtime::cuda::is_cuda_available() {
            return;
        }
        let (out, total) = scan_i32_on_gpu(&[42], "scan failed");
        assert_eq!(out, vec![0, 42]);
        assert_eq!(total, 42);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_exclusive_scan_very_large() {
        if !crate::runtime::cuda::is_cuda_available() {
            return;
        }
        // Large enough to force the recursive multi-block path.
        let n = 500_000;
        let (out, total) = scan_i32_on_gpu(&vec![1; n], "large scan failed");
        assert_eq!(out.len(), n + 1);
        assert_eq!(out[0], 0);
        assert_eq!(out[n], n as i32);
        assert_eq!(total, n);
        for idx in [1000usize, 100_000, 250_000] {
            assert_eq!(out[idx], idx as i32);
        }
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_exclusive_scan_boundary_size() {
        if !crate::runtime::cuda::is_cuda_available() {
            return;
        }
        // One element past a multiple of the block size.
        let n = 262_145;
        let (out, total) = scan_i32_on_gpu(&vec![1; n], "boundary scan failed");
        assert_eq!(out.len(), n + 1);
        assert_eq!(out[0], 0);
        assert_eq!(out[n], n as i32);
        assert_eq!(total, n);
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_exclusive_scan_i64_very_large() {
        if !crate::runtime::cuda::is_cuda_available() {
            return;
        }
        let device = CudaDevice::new(0);
        let client = CudaRuntime::default_client(&device);
        // Totals here (5e9) overflow i32, exercising the i64 path.
        let n = 100_000;
        let values = vec![50_000i64; n];
        let input = Tensor::<CudaRuntime>::from_slice(&values, &[n], &device);
        let (output, total) = unsafe {
            exclusive_scan_i64_gpu(
                &client.context,
                &client.stream,
                device.index,
                &device,
                &input,
            )
        }
        .expect("i64 scan failed");
        let expected_total: i64 = 50_000 * 100_000;
        assert_eq!(total, expected_total as usize);
        let host: Vec<i64> = output.to_vec();
        assert_eq!(host[0], 0);
        assert_eq!(host[n], expected_total);
    }
}