use core::any::TypeId;
use core::marker::PhantomData;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use bytemuck::{Pod, Zeroable};
use wgpu::util::DeviceExt;
use crate::element::NumericElement;
use crate::kernel::{Kernel, MAX_WORKGROUPS, WORKGROUP_SIZE};
use crate::{Buffer, Context};
pub(crate) mod sum;
/// Host-side copy of the `Params` uniform block declared in the reduction
/// shaders (same four `u32` fields, in the same order).
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` allow the value to be uploaded as raw
/// bytes; `execute` fills it in and binds it at `@binding(6)`.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
struct Params {
/// Number of axes; the shader loops `0..rank` over the dims/strides/mask arrays.
rank: u32,
/// Number of output elements; workgroups whose index is `>= len` exit immediately.
len: u32,
/// Number of input elements folded into each output element.
reduction_len: u32,
/// Unused by the shader; `execute` always writes 0.
/// NOTE(review): presumably present to round the block up to 16 bytes — confirm.
_pad: u32,
}
/// Generates a reduction kernel type and its [`Kernel`] implementation.
///
/// Arguments:
/// * `$kernel` – name of the generated marker struct (generic over the element type).
/// * `$label`  – string used for `Kernel::LABEL`.
/// * `$init`   – name of the `NumericElement` associated function whose result
///   seeds every accumulator (it is spliced into the shader as the initial value).
/// * `$op`     – name of the two-argument WGSL function used to combine values.
///
/// The generated shader runs one workgroup per output element. Each invocation
/// folds a strided slice of the reduced elements into a private accumulator,
/// the partial results are combined through workgroup memory, and invocation 0
/// writes the final value.
///
/// The shader keeps per-axis coordinates in a fixed 32-entry array, so it only
/// supports `params.rank <= 32`.
macro_rules! define_kernel {
    ($kernel:ident, $label:literal, $init:ident, $op:literal) => {
        /// Marker type selecting this reduction kernel for element type `T`.
        pub(crate) struct $kernel<T>(PhantomData<T>);

        impl<T: NumericElement> Kernel for $kernel<T> {
            const LABEL: &'static str = $label;
            type Output = T;

            /// Builds the WGSL source for this kernel, specialised for `T`.
            fn wgsl() -> String {
                let ty = T::wgsl_type();
                let init = T::$init();
                format!(
                    r"
const WG_SIZE: u32 = {WORKGROUP_SIZE}u;

struct Params {{
    rank: u32,
    len: u32,
    reduction_len: u32,
    _pad: u32,
}}

@group(0) @binding(0) var<storage, read> x: array<{ty}>;
@group(0) @binding(1) var<storage, read_write> y: array<{ty}>;
@group(0) @binding(2) var<storage, read> x_dims: array<u32>;
@group(0) @binding(3) var<storage, read> x_strides: array<u32>;
@group(0) @binding(4) var<storage, read> y_strides: array<u32>;
@group(0) @binding(5) var<storage, read> reduce_mask: array<u32>;
@group(0) @binding(6) var<uniform> params: Params;

var<workgroup> sdata: array<{ty}, WG_SIZE>;

@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>
) {{
    let tid = lid.x;
    // One workgroup per output element.
    let y_idx = wid.x;
    if y_idx >= params.len {{
        return;
    }}

    // Split the flat output index into per-axis coordinates via the output strides.
    var base_coords: array<u32, 32>;
    var remaining = y_idx;
    for (var i = 0u; i < params.rank; i++) {{
        let stride = y_strides[i];
        if stride > 0u {{
            base_coords[i] = remaining / stride;
            remaining = remaining % stride;
        }} else {{
            base_coords[i] = 0u;
        }}
    }}

    // Each invocation folds elements tid, tid + WG_SIZE, ... of the reduced slice.
    var acc: {ty} = {init};
    var reduction_idx = tid;
    while reduction_idx < params.reduction_len {{
        var input_idx = 0u;
        var red_remaining = reduction_idx;
        for (var i = 0u; i < params.rank; i++) {{
            var coord: u32;
            if reduce_mask[i] != 0u {{
                // Stride of axis i inside the reduced sub-space: the product
                // of the sizes of all later reduced axes.
                var red_stride = 1u;
                for (var j = i + 1u; j < params.rank; j++) {{
                    if reduce_mask[j] != 0u {{
                        red_stride *= x_dims[j];
                    }}
                }}
                coord = red_remaining / red_stride;
                red_remaining = red_remaining % red_stride;
            }} else {{
                coord = base_coords[i];
            }}
            input_idx += coord * x_strides[i];
        }}
        acc = {op}(acc, x[input_idx]);
        reduction_idx += WG_SIZE;
    }}

    sdata[tid] = acc;
    workgroupBarrier();

    // Pairwise tree reduction of the per-invocation partial results.
    // The live range is halved (rounding up) on every step, so the loop is
    // derived from WG_SIZE and is correct for any workgroup size.
    var live_count = WG_SIZE;
    while live_count > 1u {{
        let half_count = (live_count + 1u) / 2u;
        if tid < live_count / 2u {{
            sdata[tid] = {op}(sdata[tid], sdata[tid + half_count]);
        }}
        workgroupBarrier();
        live_count = half_count;
    }}

    if tid == 0u {{
        y[y_idx] = sdata[0];
    }}
}}
",
                    op = $op
                )
            }
        }
    };
}
// The init/op pairing below is deliberately "crossed": the third argument names
// the `NumericElement` function that seeds each accumulator and the fourth is
// the WGSL function that folds values into it. A max-reduction therefore starts
// from `wgsl_min()` and a min-reduction from `wgsl_max()`.
// NOTE(review): assumes `wgsl_min`/`wgsl_max` yield the smallest/largest value
// of the element type (the identities for `max`/`min`) — confirm in `NumericElement`.
define_kernel!(MaxReduce, "max_reduce", wgsl_min, "max");
define_kernel!(MinReduce, "min_reduce", wgsl_max, "min");
#[allow(clippy::too_many_lines)]
pub(crate) fn execute<K: Kernel + 'static, T: NumericElement>(
ctx: &Context,
x: &Buffer<T>,
y: &Buffer<T>,
x_dimensions: &[usize],
x_strides: &[usize],
y_strides: &[usize],
axes: &[usize],
) {
let rank = u32::try_from(y_strides.len()).expect("output rank exceeds max size");
let len = u32::try_from(y.len()).expect("output length exceeds max size");
let reduction_len = u32::try_from(axes.iter().map(|&a| x_dimensions[a]).product::<usize>())
.expect("reduction length exceeds max size");
if len == 0 || reduction_len == 0 {
return;
}
assert!(
len <= MAX_WORKGROUPS,
"output length exceeds maximum workgroups"
);
let pipeline = ctx.get_or_create_pipeline(TypeId::of::<K>(), K::wgsl, K::LABEL);
let x_dimensions = crate::kernel::convert_strides(x_dimensions);
let x_strides = crate::kernel::convert_strides(x_strides);
let y_strides = crate::kernel::convert_strides(y_strides);
let reduce_mask: Vec<u32> = (0..rank as usize)
.map(|i| u32::from(axes.contains(&i)))
.collect();
let params = Params {
rank,
len,
reduction_len,
_pad: 0,
};
let x_dimensions = ctx
.device()
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(&x_dimensions),
usage: wgpu::BufferUsages::STORAGE,
});
let x_strides = ctx
.device()
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(&x_strides),
usage: wgpu::BufferUsages::STORAGE,
});
let y_strides = ctx
.device()
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(&y_strides),
usage: wgpu::BufferUsages::STORAGE,
});
let reduce_mask = ctx
.device()
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(&reduce_mask),
usage: wgpu::BufferUsages::STORAGE,
});
let params = ctx.create_uniform_buffer(¶ms);
let bind_group = ctx.device().create_bind_group(&wgpu::BindGroupDescriptor {
label: Some(K::LABEL),
layout: &pipeline.get_bind_group_layout(0),
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: x.inner().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: y.inner().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: x_dimensions.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: x_strides.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: y_strides.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: reduce_mask.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 6,
resource: params.as_entire_binding(),
},
],
});
let mut encoder = ctx
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some(K::LABEL),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(K::LABEL),
..Default::default()
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(len, 1, 1);
}
ctx.queue().submit(Some(encoder.finish()));
}