use super::super::*;
use super::operation_types::ReductionOp;
use crate::Result;
/// Reduces the whole GPU buffer to a single element using `op`.
///
/// The buffer is treated as a flat 1-D tensor of length `input.len()`;
/// any `axes` given are forwarded (converted to `i32`) to the axis-aware
/// implementation, and the result always has exactly one element.
pub fn execute_reduction_op<T>(
    input: &GpuBuffer<T>,
    op: ReductionOp,
    axes: Option<&[usize]>,
) -> Result<GpuBuffer<T>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    // Convert caller-supplied unsigned axes into the signed form the
    // axis-aware kernel expects.
    let converted: Option<Vec<i32>> =
        axes.map(|list| list.iter().map(|&axis| axis as i32).collect());
    // Flat 1-D view: shape is [len], output collapses to one element,
    // no kept dimensions.
    execute_axis_reduction_op(input, op, &[input.len()], converted.as_deref(), false, 1)
}
/// Executes a reduction (`Sum`, `Mean`, `Max`, ...) over `input` on the GPU.
///
/// `input_shape` describes the logical tensor layout of the flat buffer.
/// When `axes` is `Some`, the reduction is axis-aware (negative axes count
/// from the end, NumPy-style) and two extra shape buffers are bound for the
/// shader; when it is `None` the whole buffer collapses to one element.
/// The caller supplies `output_len` (element count of the result);
/// `keep_dims` only affects the output-shape metadata uploaded to the shader.
///
/// # Errors
/// Returns an error when the GPU context is unavailable, when an axis is out
/// of range for `input_shape`, or when `op` has no generic shader (`TopK`).
pub fn execute_axis_reduction_op<T>(
    input: &GpuBuffer<T>,
    op: ReductionOp,
    input_shape: &[usize],
    axes: Option<&[i32]>,
    keep_dims: bool,
    output_len: usize,
) -> Result<GpuBuffer<T>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    use wgpu::util::DeviceExt;
    let context = crate::gpu::GpuContext::global()?;
    let device = &context.device;
    let queue = &context.queue;

    let rank = input_shape.len();

    // Normalize axes once (negative -> positive) and validate them, instead
    // of normalizing separately for the shape math and the metadata upload.
    // Out-of-range axes previously panicked inside `Vec::remove`; they now
    // return an error to the caller.
    let normalized_axes: Option<Vec<usize>> = match axes {
        Some(axes_slice) => {
            let mut normalized = Vec::with_capacity(axes_slice.len());
            for &a in axes_slice {
                let idx = if a < 0 { rank as i64 + a as i64 } else { a as i64 };
                if idx < 0 || idx as usize >= rank {
                    return Err(crate::TensorError::unsupported_operation_simple(
                        format!("reduction axis {} out of range for rank {}", a, rank),
                    ));
                }
                normalized.push(idx as usize);
            }
            Some(normalized)
        }
        None => None,
    };

    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("axis_reduction_output"),
        size: (output_len * std::mem::size_of::<T>()) as u64,
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_SRC
            | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    let input_len = input.len();

    // Shape of the result tensor, uploaded to the shader for axis reductions.
    let output_shape: Vec<usize> = match &normalized_axes {
        Some(norm) => {
            let mut shape = input_shape.to_vec();
            // Remove (or flatten to 1) reduced dimensions from highest index
            // to lowest so earlier removals don't shift later indices. Dedup
            // guards against double-removal when the same axis is listed
            // twice (e.g. `0` and `-rank` both normalize to 0).
            let mut reduce_axes = norm.clone();
            reduce_axes.sort_unstable_by(|a, b| b.cmp(a));
            reduce_axes.dedup();
            for &axis in &reduce_axes {
                if keep_dims {
                    shape[axis] = 1;
                } else {
                    shape.remove(axis);
                }
            }
            // A full reduction yields a scalar; represent it as shape [1].
            if shape.is_empty() { vec![1] } else { shape }
        }
        None => vec![1],
    };

    // Metadata layout consumed by the shader:
    //   [0] input length, [1] output length, [2] rank, [3] axis count,
    //   [4..] up to four normalized axes.
    // NOTE(review): only the first four axes are uploaded while [3] reports
    // the full count — confirm the shader never reduces over more than four
    // axes at once.
    let mut metadata = vec![input_len as u32, output_len as u32, rank as u32, 0u32];
    if let Some(norm) = &normalized_axes {
        metadata[3] = norm.len() as u32;
        metadata.extend(norm.iter().take(4).map(|&a| a as u32));
    }
    let metadata_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("axis_reduction_metadata"),
        contents: bytemuck::cast_slice(&metadata),
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
    });

    let input_shape_u32: Vec<u32> = input_shape.iter().map(|&x| x as u32).collect();
    let output_shape_u32: Vec<u32> = output_shape.iter().map(|&x| x as u32).collect();
    let input_shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("input_shape"),
        contents: bytemuck::cast_slice(&input_shape_u32),
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
    });
    let output_shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("output_shape"),
        contents: bytemuck::cast_slice(&output_shape_u32),
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
    });

    let is_axis_reduction = axes.is_some();
    // NOTE(review): Product/ArgMax/ArgMin/All/Any/InfNanDetection/Variance
    // always select the full-buffer entry point even when `axes` is Some,
    // yet the 5-binding layout is still created below — confirm the shader
    // side tolerates this combination.
    let shader_entry_point = match op {
        ReductionOp::Sum => {
            if is_axis_reduction {
                "sum_axis_reduction"
            } else {
                "sum_reduction"
            }
        }
        ReductionOp::Mean => {
            if is_axis_reduction {
                "mean_axis_reduction"
            } else {
                "mean_reduction"
            }
        }
        ReductionOp::Max => {
            if is_axis_reduction {
                "max_axis_reduction"
            } else {
                "max_reduction"
            }
        }
        ReductionOp::Min => {
            if is_axis_reduction {
                "min_axis_reduction"
            } else {
                "min_reduction"
            }
        }
        ReductionOp::Product | ReductionOp::Prod => "product_reduction",
        ReductionOp::ArgMax => "argmax_reduction",
        ReductionOp::ArgMin => "argmin_reduction",
        ReductionOp::All => "all_reduction",
        ReductionOp::Any => "any_reduction",
        ReductionOp::InfNanDetection => "inf_nan_detection",
        ReductionOp::Variance => "variance_reduction",
        ReductionOp::TopK => {
            return Err(crate::TensorError::unsupported_operation_simple(
                "TopK reduction requires specialized implementation".to_string(),
            ))
        }
    };

    let shader_source = include_str!("../shaders/reduction_ops.wgsl");
    let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("axis_reduction_shader"),
        source: wgpu::ShaderSource::Wgsl(shader_source.into()),
    });

    // Bindings 0-2 (input, output, metadata) are always present; axis
    // reductions additionally bind the input/output shapes at 3 and 4.
    let mut bind_group_entries = vec![
        wgpu::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        },
        wgpu::BindGroupLayoutEntry {
            binding: 1,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        },
        wgpu::BindGroupLayoutEntry {
            binding: 2,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        },
    ];
    if is_axis_reduction {
        bind_group_entries.push(wgpu::BindGroupLayoutEntry {
            binding: 3,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        });
        bind_group_entries.push(wgpu::BindGroupLayoutEntry {
            binding: 4,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        });
    }

    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("axis_reduction_bind_group_layout"),
        entries: &bind_group_entries,
    });
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("axis_reduction_pipeline_layout"),
        bind_group_layouts: &[Some(&bind_group_layout)],
        immediate_size: 0,
    });
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("axis_reduction_pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some(shader_entry_point),
        cache: None,
        compilation_options: Default::default(),
    });

    let mut bind_group_bind_entries = vec![
        wgpu::BindGroupEntry {
            binding: 0,
            resource: input.buffer().as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: output_buffer.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: metadata_buffer.as_entire_binding(),
        },
    ];
    if is_axis_reduction {
        bind_group_bind_entries.push(wgpu::BindGroupEntry {
            binding: 3,
            resource: input_shape_buffer.as_entire_binding(),
        });
        bind_group_bind_entries.push(wgpu::BindGroupEntry {
            binding: 4,
            resource: output_shape_buffer.as_entire_binding(),
        });
    }
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("axis_reduction_bind_group"),
        layout: &bind_group_layout,
        entries: &bind_group_bind_entries,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("axis_reduction_encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("axis_reduction_pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        // One thread per output element; must match @workgroup_size in
        // reduction_ops.wgsl.
        let workgroup_size: usize = 256;
        let num_workgroups = output_len.div_ceil(workgroup_size);
        compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));

    // Preserve the source buffer's GPU id; fall back to device 0 when the
    // input is not tagged as a GPU buffer.
    let device_id = match input.device_enum() {
        Device::Gpu(id) => id,
        _ => 0,
    };
    Ok(GpuBuffer::from_wgpu_buffer(
        output_buffer,
        context.device.clone(),
        context.queue.clone(),
        Device::Gpu(device_id),
        output_len,
    ))
}