use super::super::*;
use super::operation_types::PoolingOp;
use crate::Result;
pub fn execute_pooling_op<T>(
input: &GpuBuffer<T>,
op: PoolingOp,
kernel_size: &[usize],
stride: &[usize],
padding: &[usize],
input_shape: &[usize],
output_len: usize,
) -> Result<GpuBuffer<T>>
where
T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
use wgpu::util::DeviceExt;
let context = crate::gpu::GpuContext::global()?;
let device = &context.device;
let queue = &context.queue;
let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("pooling_output"),
size: (output_len * std::mem::size_of::<T>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut metadata = Vec::new();
metadata.push(input_shape.len() as u32); metadata.extend(input_shape.iter().map(|&x| x as u32)); metadata.extend(kernel_size.iter().map(|&x| x as u32)); metadata.extend(stride.iter().map(|&x| x as u32)); metadata.extend(padding.iter().map(|&x| x as u32)); metadata.push(output_len as u32);
let metadata_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("pooling_metadata"),
contents: bytemuck::cast_slice(&metadata),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let shader_entry_point = match op {
PoolingOp::MaxPool2D => "max_pool2d_kernel",
PoolingOp::AvgPool2D => "avg_pool2d_kernel",
PoolingOp::GlobalAvgPool => "global_avg_pool2d_kernel",
PoolingOp::GlobalMaxPool => "global_max_pool2d_kernel",
_ => {
return Err(crate::TensorError::unsupported_operation_simple(format!(
"Pooling operation {:?} not implemented for GPU",
op
)))
}
};
let shader_source = include_str!("../shaders/pooling_ops.wgsl");
let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("pooling_shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("pooling_bind_group_layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("pooling_pipeline_layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("pooling_pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some(shader_entry_point),
cache: None,
compilation_options: Default::default(),
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("pooling_bind_group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: output_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: metadata_buffer.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("pooling_encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("pooling_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
match op {
PoolingOp::GlobalAvgPool | PoolingOp::GlobalMaxPool => {
let num_workgroups = (output_len + 255) / 256;
compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
}
_ => {
let output_height = input_shape.len().saturating_sub(2).max(1);
let output_width = input_shape.len().saturating_sub(1).max(1);
let batch_channels =
input_shape.first().unwrap_or(&1) * input_shape.get(1).unwrap_or(&1);
let workgroups_x = (output_width + 7) / 8;
let workgroups_y = (output_height + 7) / 8;
let workgroups_z = batch_channels as u32;
compute_pass.dispatch_workgroups(
workgroups_x as u32,
workgroups_y as u32,
workgroups_z,
);
}
}
}
queue.submit(std::iter::once(encoder.finish()));
let device_id = match input.device_enum() {
Device::Gpu(id) => id,
_ => 0, };
Ok(GpuBuffer::from_wgpu_buffer(
output_buffer,
context.device.clone(),
context.queue.clone(),
Device::Gpu(device_id),
output_len,
))
}
pub fn execute_fractional_pooling_op<T>(
input: &GpuBuffer<T>,
op: PoolingOp,
pooling_ratio: &[f32],
pseudo_random: bool,
overlapping: bool,
input_shape: &[usize],
output_len: usize,
) -> Result<GpuBuffer<T>>
where
T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
use wgpu::util::DeviceExt;
let context = crate::gpu::GpuContext::global()?;
let device = &context.device;
let queue = &context.queue;
let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("fractional_pooling_output"),
size: (output_len * std::mem::size_of::<T>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let output_height = (input_shape[2] as f32 / pooling_ratio[0]) as u32;
let output_width = (input_shape[3] as f32 / pooling_ratio[1]) as u32;
let pooling_params = [
input_shape[0] as u32, input_shape[1] as u32, input_shape[2] as u32, input_shape[3] as u32, output_height, output_width, 0u32,
0u32, 0u32,
0u32, 0u32,
0u32, ];
let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("fractional_pooling_params"),
contents: bytemuck::cast_slice(&pooling_params),
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
});
let shader_entry_point = match op {
PoolingOp::MaxPool2D | PoolingOp::GlobalMaxPool => "fractional_max_pool2d_kernel",
PoolingOp::AvgPool2D | PoolingOp::GlobalAvgPool => "fractional_adaptive_pool2d",
_ => {
return Err(crate::TensorError::unsupported_operation_simple(format!(
"Fractional pooling operation {:?} not implemented for GPU",
op
)))
}
};
let shader_source = include_str!("../shaders/pooling_ops.wgsl");
let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("fractional_pooling_shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("fractional_pooling_bind_group_layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("fractional_pooling_pipeline_layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("fractional_pooling_pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some(shader_entry_point),
cache: None,
compilation_options: Default::default(),
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("fractional_pooling_bind_group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: output_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: params_buffer.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("fractional_pooling_encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("fractional_pooling_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
let batch_channels = input_shape[0] * input_shape[1];
let workgroups_x = (output_width + 7) / 8;
let workgroups_y = (output_height + 7) / 8;
let workgroups_z = batch_channels as u32;
compute_pass.dispatch_workgroups(workgroups_x, workgroups_y, workgroups_z);
}
queue.submit(std::iter::once(encoder.finish()));
let device_id = match input.device_enum() {
Device::Gpu(id) => id,
_ => 0, };
Ok(GpuBuffer::from_wgpu_buffer(
output_buffer,
context.device.clone(),
context.queue.clone(),
Device::Gpu(device_id),
output_len,
))
}