use crate::{Device, Result, Tensor, TensorError};
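/// A reduction operation that can be lowered to a WGSL compute kernel.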
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
Sum,
Product,
Max,
Min,
Mean,
All,
Any,
}
impl ReductionOp {
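/// Returns the WGSL expression that folds two values `a` and `b` for this
/// reduction. The `_dtype` parameter is currently unused; it is kept so
/// per-dtype expressions could be added later without changing the signature.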
pub fn wgsl_op(&self, _dtype: &str) -> String {
match self {
Self::Sum => "a + b".to_string(),
Self::Product => "a * b".to_string(),
Self::Max => "max(a, b)".to_string(),
Self::Min => "min(a, b)".to_string(),
Self::Mean => "a + b".to_string(), Self::All => "a && b".to_string(),
Self::Any => "a || b".to_string(),
}
}
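/// Returns the WGSL literal for the identity element of this reduction for
/// the given dtype: 0 for Sum, 1 for Product, and the type's extreme values
/// for Max/Min.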
pub fn identity(&self, dtype: &str) -> String {
match self {
Self::Sum => match dtype {
"f32" => "0.0f".to_string(),
"f64" => "0.0".to_string(),
"i8" | "i16" | "i32" | "i64" => "0".to_string(),
"u8" | "u16" | "u32" | "u64" => "0u".to_string(),
_ => "0".to_string(),
},
Self::Product => match dtype {
"f32" => "1.0f".to_string(),
"f64" => "1.0".to_string(),
"i8" | "i16" | "i32" | "i64" => "1".to_string(),
"u8" | "u16" | "u32" | "u64" => "1u".to_string(),
_ => "1".to_string(),
},
Self::Max => match dtype {
"f32" => "-3.40282347e+38f".to_string(),
"f64" => "-1.7976931348623157e+308".to_string(),
"i8" => "-128".to_string(),
"i16" => "-32768".to_string(),
"i32" => "-2147483648".to_string(),
"i64" => "-9223372036854775808".to_string(),
"u8" | "u16" | "u32" | "u64" => "0u".to_string(),
_ => "0".to_string(),
},
Self::Min => match dtype {
"f32" => "3.40282347e+38f".to_string(),
"f64" => "1.7976931348623157e+308".to_string(),
"i8" => "127".to_string(),
"i16" => "32767".to_string(),
"i32" => "2147483647".to_string(),
"i64" => "9223372036854775807".to_string(),
"u8" => "255u".to_string(),
"u16" => "65535u".to_string(),
"u32" => "4294967295u".to_string(),
"u64" => "18446744073709551615u".to_string(),
_ => "0".to_string(),
},
Self::Mean => "0.0f".to_string(),
Self::All => "true".to_string(),
Self::Any => "false".to_string(),
}
}
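/// Short lowercase name, used in kernel and cache identifiers.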
pub fn name(&self) -> &'static str {
match self {
Self::Sum => "sum",
Self::Product => "product",
Self::Max => "max",
Self::Min => "min",
Self::Mean => "mean",
Self::All => "all",
Self::Any => "any",
}
}
}
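/// A generated WGSL reduction shader plus the metadata needed to compile
/// and dispatch it.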
#[derive(Debug, Clone)]
pub struct ReductionKernel {
pub op: ReductionOp,
pub dtype: String,
pub shader_source: String,
pub workgroup_size: u32,
}
impl ReductionKernel {
pub fn new(op: ReductionOp, dtype: &str, workgroup_size: u32) -> Self {
let shader_source = generate_reduction_shader(op, dtype, workgroup_size);
Self {
op,
dtype: dtype.to_string(),
shader_source,
workgroup_size,
}
}
pub fn id(&self) -> String {
format!("reduce_{}_{}", self.op.name(), self.dtype)
}
}
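/// Generates a WGSL compute shader implementing a two-phase reduction:
/// each thread first folds elements in a grid-stride loop, then the
/// workgroup combines the per-thread partials with a shared-memory tree
/// reduction, halving the active stride each step. One value per workgroup
/// is written to `output`; when more than one workgroup runs, a second
/// dispatch of the same shader reduces those partials to a single value.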
pub fn generate_reduction_shader(op: ReductionOp, dtype: &str, workgroup_size: u32) -> String {
// WGSL (without extensions) has no f64 or 8/16/64-bit integers, so those
// types are emulated with the closest 32-bit type.
let wgsl_type = match dtype {
"f32" => "f32",
"f64" => "f32",
"i8" | "i16" | "i32" | "i64" => "i32",
"u8" | "u16" | "u32" | "u64" => "u32",
// NOTE: `bool` is not host-shareable in WGSL; a bool reduction would need
// a u32 encoding before this shader is valid.
"bool" => "bool",
_ => "f32",
};
// Take the identity of the *storage* type rather than the logical dtype:
// e.g. the i64 Max identity (-9223372036854775808) is not a valid i32
// literal and would fail WGSL compilation after i64 narrows to i32.
let identity = op.identity(wgsl_type);
let operation = op.wgsl_op(dtype);
let op_name = op.name();
format!(
r#"
// Reduction Kernel: {op_name} for {dtype}
// Workgroup size: {workgroup_size}
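// axis_size and reduce_stride are currently unused; they leave room for
// axis-wise reductions and pad the uniform block to 16 bytes.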
struct ReductionParams {{
input_size: u32,
output_size: u32,
axis_size: u32,
reduce_stride: u32,
}}
@group(0) @binding(0) var<storage, read> input: array<{wgsl_type}>;
@group(0) @binding(1) var<storage, read_write> output: array<{wgsl_type}>;
@group(0) @binding(2) var<uniform> params: ReductionParams;
// Shared memory for workgroup reduction
var<workgroup> shared_data: array<{wgsl_type}, {workgroup_size}>;
@compute @workgroup_size({workgroup_size}, 1, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {{
let tid = local_id.x;
let gid = global_id.x;
let wg_id = workgroup_id.x;
// Each thread accumulates multiple elements
var accumulator: {wgsl_type} = {identity};
// Grid-stride loop: stride by the total thread count so that no element
// is folded by more than one thread across workgroups.
let total_threads = {workgroup_size}u * num_workgroups.x;
var idx = gid;
while (idx < params.input_size) {{
let a = accumulator;
let b = input[idx];
accumulator = {operation};
idx += total_threads;
}}
// Store in shared memory
shared_data[tid] = accumulator;
workgroupBarrier();
// Tree reduction in shared memory
var stride = {workgroup_size}u / 2u;
while (stride > 0u) {{
if (tid < stride) {{
let a = shared_data[tid];
let b = shared_data[tid + stride];
shared_data[tid] = {operation};
}}
workgroupBarrier();
stride = stride / 2u;
}}
// First thread writes workgroup result
if (tid == 0u) {{
output[wg_id] = shared_data[0];
}}
}}
"#,
op_name = op_name,
dtype = dtype,
wgsl_type = wgsl_type,
workgroup_size = workgroup_size,
identity = identity,
operation = operation,
)
}
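/// Validates `dtype` and builds a [`ReductionKernel`] with the default
/// workgroup size of 256 threads.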
pub fn create_reduction_kernel(op: ReductionOp, dtype: &str) -> Result<ReductionKernel> {
match dtype {
"f32" | "f64" | "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "bool" => {}
_ => {
return Err(TensorError::unsupported_operation_simple(format!(
"Reduction not supported for dtype: {}",
dtype
)))
}
}
let workgroup_size = 256;
Ok(ReductionKernel::new(op, dtype, workgroup_size))
}
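/// Options for a reduction: the axis to reduce over (`None` reduces all
/// elements), whether to keep the reduced dimension, and whether to use
/// Kahan summation.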
#[derive(Debug, Clone, Default)]
pub struct ReductionConfig {
pub axis: Option<usize>,
pub keepdims: bool,
pub kahan: bool,
}
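/// Computes how many workgroups the first reduction pass needs, assuming
/// each thread folds roughly 4 elements, and whether a second pass is
/// required to combine the per-workgroup partials.
///
/// For example, 10_000 elements at workgroup size 256 give 256 * 4 = 1024
/// elements per workgroup, so ceil(10_000 / 1024) = 10 workgroups and a
/// second stage; 1_000 elements fit in a single workgroup.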
pub fn calculate_reduction_workgroups(input_size: usize, workgroup_size: u32) -> (u32, bool) {
// Heuristic: assume each thread folds roughly 4 elements in its grid-stride loop.
let elements_per_thread = 4;
let elements_per_workgroup = workgroup_size * elements_per_thread;
let num_workgroups =
((input_size as u32 + elements_per_workgroup - 1) / elements_per_workgroup).max(1);
let needs_second_stage = num_workgroups > 1;
(num_workgroups, needs_second_stage)
}
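/// Runs a full reduction of `input` on the GPU and returns a scalar tensor.
///
/// The flow: upload the input, dispatch the first-stage kernel (one partial
/// result per workgroup), optionally dispatch a second stage over those
/// partials, then copy the final value through a staging buffer and map it
/// back to the host. `ReductionConfig::axis`, `keepdims`, and `kahan` are
/// not yet honored; the reduction always covers all elements.
///
/// A minimal usage sketch; `to_device` is assumed here and should be
/// replaced by this crate's actual way of placing a tensor on the GPU:
/// ```ignore
/// let data = Array::from_vec(vec![1.0f32; 1024]).into_dyn();
/// let t = Tensor::from_array(data).to_device(Device::Gpu(0))?;
/// let sum = execute_reduction_gpu(&t, ReductionOp::Sum, &ReductionConfig::default())?;
/// ```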
#[cfg(feature = "gpu")]
pub fn execute_reduction_gpu<T>(
input: &Tensor<T>,
op: ReductionOp,
_config: &ReductionConfig,
) -> Result<Tensor<T>>
where
T: scirs2_core::num_traits::Float + Default + 'static + bytemuck::Pod,
{
if !matches!(input.device(), Device::Gpu(_)) {
return Err(TensorError::invalid_argument(
"Input tensor must be on GPU device".to_string(),
));
}
// Map the element type to a shader dtype. f64 is rejected rather than
// mapped: the WGSL kernel stores data as f32 (WGSL has no f64), so binding
// an f64 slice's raw bytes to an array<f32> buffer would be misread.
let dtype = match std::any::TypeId::of::<T>() {
id if id == std::any::TypeId::of::<f32>() => "f32",
_ => {
return Err(TensorError::unsupported_operation_simple(
"Unsupported type for GPU reduction (only f32 is supported)".to_string(),
))
}
};
let kernel = create_reduction_kernel(op, dtype)?;
let gpu_ctx = crate::gpu::GpuContext::global().map_err(|_| {
TensorError::unsupported_operation_simple("GPU context not available".to_string())
})?;
let input_data = input.data();
let input_size = input_data.len();
let (num_workgroups, needs_second_stage) =
calculate_reduction_workgroups(input_size, kernel.workgroup_size);
use wgpu::util::DeviceExt;
let input_buffer = gpu_ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("reduction_input"),
contents: bytemuck::cast_slice(input_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let intermediate_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("reduction_intermediate"),
size: (num_workgroups as usize * std::mem::size_of::<T>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let output_size = if needs_second_stage {
num_workgroups as usize
} else {
1
};
let output_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("reduction_output"),
size: (output_size * std::mem::size_of::<T>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let shader_module = gpu_ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(&format!("reduction_shader_{}", kernel.id())),
source: wgpu::ShaderSource::Wgsl(kernel.shader_source.as_str().into()),
});
let bind_group_layout =
gpu_ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("reduction_bind_group_layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = gpu_ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("reduction_pipeline_layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let compute_pipeline =
gpu_ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("reduction_pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
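// Uniform layout matches ReductionParams in the shader:
// [input_size, output_size, axis_size, reduce_stride].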
let params = [input_size as u32, output_size as u32, 0u32, 0u32];
let params_buffer = gpu_ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("reduction_params"),
contents: bytemuck::cast_slice(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = gpu_ctx
.device
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("reduction_bind_group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: if needs_second_stage {
intermediate_buffer.as_entire_binding()
} else {
output_buffer.as_entire_binding()
},
},
wgpu::BindGroupEntry {
binding: 2,
resource: params_buffer.as_entire_binding(),
},
],
});
let mut encoder = gpu_ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("reduction_encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("reduction_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&compute_pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
}
if needs_second_stage {
let params2 = [num_workgroups, 1u32, 0u32, 0u32];
let params_buffer2 = gpu_ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("reduction_params_stage2"),
contents: bytemuck::cast_slice(&params2),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group2 = gpu_ctx
.device
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("reduction_bind_group_stage2"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: intermediate_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: output_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: params_buffer2.as_entire_binding(),
},
],
});
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("reduction_pass_stage2"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&compute_pipeline);
compute_pass.set_bind_group(0, &bind_group2, &[]);
compute_pass.dispatch_workgroups(1, 1, 1);
}
let staging_buffer = gpu_ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("reduction_staging"),
size: std::mem::size_of::<T>() as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
encoder.copy_buffer_to_buffer(
&output_buffer,
0,
&staging_buffer,
0,
std::mem::size_of::<T>() as u64,
);
gpu_ctx.queue.submit(std::iter::once(encoder.finish()));
let buffer_slice = staging_buffer.slice(..);
let (sender, receiver) = futures::channel::oneshot::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
sender.send(result).ok();
});
gpu_ctx
.device
.poll(wgpu::PollType::wait_indefinitely())
.ok();
if let Ok(Ok(())) = pollster::block_on(receiver) {
let data = buffer_slice.get_mapped_range();
let mut result_value: T = *bytemuck::from_bytes::<T>(&data[..std::mem::size_of::<T>()]);
drop(data);
staging_buffer.unmap();
// Mean is accumulated as a plain sum in the kernel; divide by the count here.
if op == ReductionOp::Mean {
use scirs2_core::num_traits::NumCast;
if let Some(n) = <T as NumCast>::from(input_size) {
result_value = result_value / n;
}
}
use scirs2_core::ndarray::Array;
let result_array = Array::from_elem(vec![], result_value).into_dyn();
Ok(Tensor::from_array(result_array))
} else {
Err(TensorError::gpu_error(
"GPU reduction",
"Failed to read back result from staging buffer",
Some(input.device().id()),
false,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_reduction_op_wgsl() {
assert_eq!(ReductionOp::Sum.wgsl_op("f32"), "a + b");
assert_eq!(ReductionOp::Max.wgsl_op("f32"), "max(a, b)");
assert_eq!(ReductionOp::Product.wgsl_op("f32"), "a * b");
}
#[test]
fn test_reduction_op_identity() {
assert_eq!(ReductionOp::Sum.identity("f32"), "0.0f");
assert_eq!(ReductionOp::Product.identity("f32"), "1.0f");
assert!(ReductionOp::Max.identity("f32").contains("-3.40282347e+38"));
}
#[test]
fn test_create_reduction_kernel() {
let kernel = create_reduction_kernel(ReductionOp::Sum, "f32")
.expect("test: create_reduction_kernel should succeed");
assert_eq!(kernel.op, ReductionOp::Sum);
assert_eq!(kernel.dtype, "f32");
assert_eq!(kernel.workgroup_size, 256);
assert!(!kernel.shader_source.is_empty());
}
#[test]
fn test_shader_generation() {
let shader = generate_reduction_shader(ReductionOp::Sum, "f32", 256);
assert!(shader.contains("@compute"));
assert!(shader.contains("@workgroup_size(256"));
assert!(shader.contains("shared_data"));
assert!(shader.contains("workgroupBarrier"));
assert!(shader.contains("a + b")); }
#[test]
fn test_kernel_id() {
let kernel = create_reduction_kernel(ReductionOp::Sum, "f32")
.expect("test: create_reduction_kernel should succeed");
assert_eq!(kernel.id(), "reduce_sum_f32");
let kernel2 = create_reduction_kernel(ReductionOp::Max, "i32")
.expect("test: create_reduction_kernel should succeed");
assert_eq!(kernel2.id(), "reduce_max_i32");
}
#[test]
fn test_workgroup_calculation() {
let (wg, second_stage) = calculate_reduction_workgroups(1000, 256);
assert_eq!(wg, 1);
assert!(!second_stage);
let (wg, second_stage) = calculate_reduction_workgroups(10000, 256);
assert!(wg > 1);
assert!(second_stage);
}
#[test]
fn test_unsupported_dtype() {
let result = create_reduction_kernel(ReductionOp::Sum, "string");
assert!(result.is_err());
}
#[test]
fn test_all_ops_generate_valid_shaders() {
let ops = vec![
ReductionOp::Sum,
ReductionOp::Product,
ReductionOp::Max,
ReductionOp::Min,
ReductionOp::Mean,
];
for op in ops {
let kernel = create_reduction_kernel(op, "f32")
.expect("test: create_reduction_kernel should succeed");
assert!(!kernel.shader_source.is_empty());
assert!(kernel.shader_source.contains("@compute"));
}
}
#[test]
fn test_reduction_config_default() {
let config = ReductionConfig::default();
assert_eq!(config.axis, None);
assert!(!config.keepdims);
assert!(!config.kahan);
}
#[test]
fn test_extended_integer_dtypes() {
let dtypes = vec!["i8", "i16", "i64", "u8", "u16", "u64"];
for dtype in dtypes {
let kernel = create_reduction_kernel(ReductionOp::Sum, dtype);
assert!(
kernel.is_ok(),
"Failed to create reduction kernel for dtype: {}",
dtype
);
let kernel = kernel.expect("test: operation should succeed");
assert_eq!(kernel.dtype, dtype);
assert!(!kernel.shader_source.is_empty());
assert!(kernel.shader_source.contains("@compute"));
}
}
#[test]
fn test_integer_identity_values() {
assert_eq!(ReductionOp::Sum.identity("i8"), "0");
assert_eq!(ReductionOp::Product.identity("i8"), "1");
assert_eq!(ReductionOp::Max.identity("i8"), "-128");
assert_eq!(ReductionOp::Min.identity("i8"), "127");
assert_eq!(ReductionOp::Sum.identity("i16"), "0");
assert_eq!(ReductionOp::Max.identity("i16"), "-32768");
assert_eq!(ReductionOp::Min.identity("i16"), "32767");
assert_eq!(ReductionOp::Max.identity("i64"), "-9223372036854775808");
assert_eq!(ReductionOp::Min.identity("i64"), "9223372036854775807");
assert_eq!(ReductionOp::Sum.identity("u8"), "0u");
assert_eq!(ReductionOp::Min.identity("u8"), "255u");
assert_eq!(ReductionOp::Min.identity("u16"), "65535u");
assert_eq!(ReductionOp::Min.identity("u64"), "18446744073709551615u");
}
#[test]
fn test_all_dtypes_shader_generation() {
let dtypes = vec![
"f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "bool",
];
for dtype in dtypes {
let shader = generate_reduction_shader(ReductionOp::Sum, dtype, 256);
assert!(!shader.is_empty(), "Empty shader for dtype: {}", dtype);
assert!(
shader.contains("@compute"),
"Missing @compute directive for dtype: {}",
dtype
);
assert!(
shader.contains("@workgroup_size"),
"Missing workgroup_size for dtype: {}",
dtype
);
}
}
#[test]
fn test_wgsl_type_mapping() {
let shader_i8 = generate_reduction_shader(ReductionOp::Sum, "i8", 256);
assert!(shader_i8.contains("array<i32"));
let shader_u8 = generate_reduction_shader(ReductionOp::Sum, "u8", 256);
assert!(shader_u8.contains("array<u32"));
let shader_i64 = generate_reduction_shader(ReductionOp::Sum, "i64", 256);
assert!(shader_i64.contains("array<i32"));
let shader_u64 = generate_reduction_shader(ReductionOp::Sum, "u64", 256);
assert!(shader_u64.contains("array<u32"));
}
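// Guards the grid-stride fix: generated shaders must stride by the total
// thread count, obtained through the num_workgroups builtin, so that no
// element is folded twice across workgroups.
#[test]
fn test_shader_grid_stride_uses_num_workgroups() {
let shader = generate_reduction_shader(ReductionOp::Sum, "f32", 256);
assert!(shader.contains("@builtin(num_workgroups)"));
assert!(shader.contains("total_threads"));
}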
}