use super::super::*;
use crate::Result;
pub fn execute_gather<T>(
input: &GpuBuffer<T>,
indices: &GpuBuffer<u32>,
axis: usize,
input_shape: &[usize],
indices_shape: &[usize],
output_len: usize,
) -> Result<GpuBuffer<T>>
where
T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
use wgpu::util::DeviceExt;
let context = crate::gpu::GpuContext::global()?;
let device = &context.device;
let queue = &context.queue;
let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("gather_output"),
size: (output_len * std::mem::size_of::<T>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let gather_info = [
input_shape.len() as u32, output_len as u32, axis as u32, 0u32, ];
let gather_info_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gather_info"),
contents: bytemuck::cast_slice(&gather_info),
usage: wgpu::BufferUsages::UNIFORM,
});
let input_shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gather_input_shape"),
contents: bytemuck::cast_slice(&input_shape.iter().map(|&x| x as u32).collect::<Vec<_>>()),
usage: wgpu::BufferUsages::STORAGE,
});
let indices_shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gather_indices_shape"),
contents: bytemuck::cast_slice(
&indices_shape.iter().map(|&x| x as u32).collect::<Vec<_>>(),
),
usage: wgpu::BufferUsages::STORAGE,
});
let shader_source = include_str!("../shaders/manipulation_ops.wgsl");
let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("gather_shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("gather_bind_group_layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 5,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("gather_pipeline_layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("gather_pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some("gather_op"),
cache: None,
compilation_options: Default::default(),
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("gather_bind_group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: indices.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: output_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: gather_info_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: input_shape_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: indices_shape_buffer.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("gather_encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("gather_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
let workgroup_size = 64;
let num_workgroups = (output_len + workgroup_size - 1) / workgroup_size;
compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
let device_id = match input.device_enum() {
Device::Gpu(id) => id,
_ => 0, };
Ok(GpuBuffer::from_wgpu_buffer(
output_buffer,
context.device.clone(),
context.queue.clone(),
Device::Gpu(device_id),
output_len,
))
}
/// Placeholder scatter: returns a GPU-side copy of `input`.
///
/// The current implementation allocates an output buffer with the same byte
/// size as `input` and copies `input` into it. `indices`, `updates`, `axis`
/// and the three shape slices are accepted but NOT applied, so no update is
/// actually scattered into the result.
///
/// TODO: dispatch a scatter kernel that writes `updates` at `indices` along `axis`.
///
/// # Errors
///
/// Propagates the error returned by `GpuContext::global()`.
pub fn execute_scatter<T>(
    input: &GpuBuffer<T>,
    indices: &GpuBuffer<u32>,
    updates: &GpuBuffer<T>,
    axis: usize,
    input_shape: &[usize],
    indices_shape: &[usize],
    updates_shape: &[usize],
) -> Result<GpuBuffer<T>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    // Explicitly acknowledge the arguments the placeholder does not use yet,
    // so the omission is visible in the code rather than only as a lint warning.
    let _ = (
        indices,
        updates,
        axis,
        input_shape,
        indices_shape,
        updates_shape,
    );

    let context = crate::gpu::GpuContext::global()?;
    let device = &context.device;
    let queue = &context.queue;

    let byte_len = input.buffer().size();
    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("scatter_output"),
        size: byte_len,
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_SRC
            | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    // Start the output as an exact copy of the input.
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("scatter_copy_encoder"),
    });
    encoder.copy_buffer_to_buffer(input.buffer(), 0, &output_buffer, 0, byte_len);
    queue.submit(std::iter::once(encoder.finish()));

    // Non-GPU inputs fall back to device index 0.
    let device_id = match input.device_enum() {
        Device::Gpu(id) => id,
        _ => 0,
    };

    Ok(GpuBuffer::from_wgpu_buffer(
        output_buffer,
        context.device.clone(),
        context.queue.clone(),
        Device::Gpu(device_id),
        input.len(),
    ))
}
pub fn execute_where<T>(
condition: &GpuBuffer<u32>,
x: &GpuBuffer<T>,
y: &GpuBuffer<T>,
output_len: usize,
) -> Result<GpuBuffer<T>>
where
T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
use wgpu::util::DeviceExt;
let context = crate::gpu::GpuContext::global()?;
let device = &context.device;
let queue = &context.queue;
let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("where_output"),
size: (output_len * std::mem::size_of::<T>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let where_info = [
output_len as u32, 0u32, 0u32, 0u32, ];
let where_info_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("where_info"),
contents: bytemuck::cast_slice(&where_info),
usage: wgpu::BufferUsages::UNIFORM,
});
let shader_source = include_str!("../shaders/manipulation_ops.wgsl");
let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("where_shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("where_bind_group_layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("where_pipeline_layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("where_pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some("where_op"),
cache: None,
compilation_options: Default::default(),
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("where_bind_group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: condition.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: x.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: y.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: output_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: where_info_buffer.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("where_encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("where_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
let workgroup_size = 64;
let num_workgroups = (output_len + workgroup_size - 1) / workgroup_size;
compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
let device_id = match x.device_enum() {
Device::Gpu(id) => id,
_ => 0, };
Ok(GpuBuffer::from_wgpu_buffer(
output_buffer,
context.device.clone(),
context.queue.clone(),
Device::Gpu(device_id),
output_len,
))
}
pub fn execute_one_hot<T>(
indices: &GpuBuffer<u32>,
depth: usize,
on_value: T,
off_value: T,
axis: i32,
indices_shape: &[usize],
output_len: usize,
) -> Result<GpuBuffer<T>>
where
T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
use wgpu::util::DeviceExt;
let context = crate::gpu::GpuContext::global()?;
let device = &context.device;
let queue = &context.queue;
let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("one_hot_output"),
size: (output_len * std::mem::size_of::<T>()) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let on_value_bytes: [u8; 4] = bytemuck::cast(on_value);
let off_value_bytes: [u8; 4] = bytemuck::cast(off_value);
let on_value_u32 = u32::from_ne_bytes(on_value_bytes);
let off_value_u32 = u32::from_ne_bytes(off_value_bytes);
let one_hot_info = [
output_len as u32, depth as u32, on_value_u32, off_value_u32, ];
let one_hot_info_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("one_hot_info"),
contents: bytemuck::cast_slice(&one_hot_info),
usage: wgpu::BufferUsages::UNIFORM,
});
let shader_source = include_str!("../shaders/manipulation_ops.wgsl");
let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("one_hot_shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("one_hot_bind_group_layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("one_hot_pipeline_layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("one_hot_pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some("one_hot_op"),
cache: None,
compilation_options: Default::default(),
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("one_hot_bind_group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: indices.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: output_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: one_hot_info_buffer.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("one_hot_encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("one_hot_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
let workgroup_size = 64;
let num_workgroups = (output_len + workgroup_size - 1) / workgroup_size;
compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
let device_id = match indices.device_enum() {
Device::Gpu(id) => id,
_ => 0, };
Ok(GpuBuffer::from_wgpu_buffer(
output_buffer,
context.device.clone(),
context.queue.clone(),
Device::Gpu(device_id),
output_len,
))
}