use nalgebra::DVector;
use slang_hal::backend::Backend;
use slang_hal::function::GpuFunction;
use slang_hal::{BufferUsages, Shader, ShaderArgs};
use stensor::tensor::GpuTensor;
/// GPU kernels implementing a multi-stage parallel prefix sum (scan),
/// loaded from the `slosh::grid::prefix_sum` slang module.
#[derive(Shader)]
#[shader(module = "slosh::grid::prefix_sum")]
pub struct WgPrefixSum<B: Backend> {
    // Scan kernel: presumably scans `data` per workgroup and writes one
    // total per workgroup into `aux` — confirm against the slang module.
    prefix_sum: GpuFunction<B>,
    // Add-back kernel: presumably adds the scanned workgroup totals in
    // `aux` back into `data` — confirm against the slang module.
    add_data_grp: GpuFunction<B>,
}
/// Argument bundle shared by both prefix-sum kernels (`prefix_sum` and
/// `add_data_grp`).
#[derive(ShaderArgs)]
struct PrefixSumArgs<'a, B: Backend> {
    // Buffer being scanned (or having totals added back) in place.
    data: &'a GpuTensor<u32, B>,
    // Auxiliary buffer holding the next stage's per-workgroup partial sums.
    aux: &'a GpuTensor<u32, B>,
}
impl<B: Backend> WgPrefixSum<B> {
    /// Workgroup thread count the shaders are compiled for.
    const THREADS: u32 = 256;

    /// Queues the GPU dispatches computing a prefix sum of `data` in place.
    ///
    /// `workspace` is (re)allocated as needed for an input of `data.len()`
    /// elements. The scan proceeds in three phases:
    /// 1. upward sweep: scan `data`, then each stage, writing per-workgroup
    ///    totals into the next stage's buffer;
    /// 2. downward sweep: add each stage's scanned totals back into the
    ///    stage below it (the topmost stage fits in a single workgroup, so
    ///    it never needs totals added back into it);
    /// 3. final pass: add the first stage's totals back into `data`.
    pub fn launch(
        &self,
        backend: &B,
        pass: &mut B::Pass,
        workspace: &mut PrefixSumWorkspace<B>,
        data: &GpuTensor<u32, B>,
    ) -> Result<(), B::Error> {
        assert_eq!(
            Self::THREADS,
            256,
            "Internal error: prefix sum assumes a thread count equal to 256"
        );
        workspace.reserve(backend, data.len() as u32)?;

        // NOTE(review): group counts below come from the stage buffers'
        // allocated lengths (capacity), not the active length implied by the
        // current `data.len()` — presumably the shaders guard out-of-range
        // workgroups; verify against the slang module.
        let ngroups0 = workspace.stages[0].buffer.len() as u32;
        let aux0 = &workspace.stages[0].buffer;
        let args0 = PrefixSumArgs { data, aux: aux0 };

        // Phase 1: upward sweep, starting from `data` itself.
        self.prefix_sum
            .launch_grid(backend, pass, &args0, [ngroups0, 1, 1])?;
        for i in 0..workspace.num_stages - 1 {
            let ngroups = workspace.stages[i + 1].buffer.len() as u32;
            let buf = &workspace.stages[i].buffer;
            let aux = &workspace.stages[i + 1].buffer;
            let args = PrefixSumArgs { data: buf, aux };
            self.prefix_sum
                .launch_grid(backend, pass, &args, [ngroups, 1, 1])?;
        }

        // Phase 2: downward sweep, from the second-topmost stage downward.
        if workspace.num_stages > 2 {
            for i in (0..workspace.num_stages - 2).rev() {
                let ngroups = workspace.stages[i + 1].buffer.len() as u32;
                let buf = &workspace.stages[i].buffer;
                let aux = &workspace.stages[i + 1].buffer;
                let args = PrefixSumArgs { data: buf, aux };
                self.add_data_grp
                    .launch_grid(backend, pass, &args, [ngroups, 1, 1])?;
            }
        }

        // Phase 3: fold the first stage's totals back into `data`.
        if workspace.num_stages > 1 {
            let args = PrefixSumArgs { data, aux: aux0 };
            self.add_data_grp
                .launch_grid(backend, pass, &args, [ngroups0, 1, 1])?;
        }
        Ok(())
    }

    /// CPU reference implementation: converts `v` into its exclusive prefix
    /// sum in place (`v[i]` becomes the sum of the original `v[0..i]`, with
    /// `v[0] == 0`). An empty vector is left unchanged.
    pub fn eval_cpu(&self, v: &mut DVector<u32>) {
        // Guard the empty case: `v.len() - 1` below would underflow `usize`,
        // and the final `v[0] = 0` would index out of bounds.
        if v.is_empty() {
            return;
        }
        // In-place inclusive scan.
        for i in 0..v.len() - 1 {
            v[i + 1] += v[i];
        }
        // Shift right by one and zero the head to make the scan exclusive.
        for i in (1..v.len()).rev() {
            v[i] = v[i - 1];
        }
        v[0] = 0;
    }
}
/// One level of the multi-stage scan: a GPU buffer holding per-workgroup
/// partial sums for the level below it.
struct PrefixSumStage<B: Backend> {
    // Allocated length of `buffer`, in `u32` elements.
    capacity: u32,
    // GPU buffer storing this stage's partial sums.
    buffer: GpuTensor<u32, B>,
}
/// Scratch stage buffers reused across prefix-sum launches.
#[derive(Default)]
pub struct PrefixSumWorkspace<B: Backend> {
    // Per-level partial-sum buffers, largest (closest to the data) first.
    stages: Vec<PrefixSumStage<B>>,
    // Number of leading entries of `stages` actually used for the most
    // recently reserved input length (may be fewer than `stages.len()`).
    num_stages: usize,
}
impl<B: Backend> PrefixSumWorkspace<B> {
    /// Creates an empty workspace; stage buffers are allocated lazily by
    /// [`Self::reserve`].
    pub fn new() -> Self {
        Self {
            stages: vec![],
            num_stages: 0,
        }
    }

    /// Creates a workspace with stage buffers already sized for a prefix sum
    /// over `buffer_len` elements.
    pub fn with_capacity(backend: &B, buffer_len: u32) -> Result<Self, B::Error> {
        let mut result = Self::new();
        result.reserve(backend, buffer_len)?;
        Ok(result)
    }

    /// Ensures the stage buffers can hold the partial sums for a prefix sum
    /// over `buffer_len` elements, reallocating only when the existing
    /// buffers are too small.
    pub fn reserve(&mut self, backend: &B, buffer_len: u32) -> Result<(), B::Error> {
        // Clamp to at least one workgroup: `0.div_ceil(256) == 0`, and a zero
        // stage length would make the allocation loop below spin forever
        // (`while stage_len != 1` with `stage_len` stuck at 0).
        let mut stage_len = buffer_len.div_ceil(WgPrefixSum::<B>::THREADS).max(1);

        if self.stages.is_empty() || self.stages[0].capacity < stage_len {
            // Existing buffers are too small: reallocate every stage.
            self.stages.clear();
            while stage_len != 1 {
                let buffer = GpuTensor::vector(
                    backend,
                    DVector::<u32>::zeros(stage_len as usize),
                    BufferUsages::STORAGE,
                )?;
                self.stages.push(PrefixSumStage {
                    capacity: stage_len,
                    buffer,
                });
                stage_len = stage_len.div_ceil(WgPrefixSum::<B>::THREADS);
            }
            // Final single-element stage holding the grand total.
            self.stages.push(PrefixSumStage {
                capacity: 1,
                buffer: GpuTensor::vector(
                    backend,
                    DVector::<u32>::zeros(1),
                    BufferUsages::STORAGE,
                )?,
            });
            self.num_stages = self.stages.len();
        } else {
            // Buffers are big enough; just recompute how many stages the
            // requested length actually uses. This must run unconditionally:
            // the previous `else if stages[0].buffer.len() != stage_len`
            // guard skipped the recount whenever the request matched the
            // first buffer's capacity again, leaving a stale (too small)
            // `num_stages` after a shrink followed by a grow back to full
            // capacity.
            self.num_stages = 1;
            while stage_len != 1 {
                self.num_stages += 1;
                stage_len = stage_len.div_ceil(WgPrefixSum::<B>::THREADS);
            }
        }
        Ok(())
    }
}