use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// WGSL source text of `conv1d_f32.wgsl`, embedded at compile time and used as
/// the shader module source by `launch_conv1d`.
const CONV1D_SHADER: &str = include_str!("conv1d_f32.wgsl");
/// WGSL source text of `conv2d_f32.wgsl`, embedded at compile time and used as
/// the shader module source by `launch_conv2d`.
const CONV2D_SHADER: &str = include_str!("conv2d_f32.wgsl");
/// WGSL source text of `depthwise_conv2d_f32.wgsl`, embedded at compile time and
/// used as the shader module source by `launch_depthwise_conv2d`.
const DEPTHWISE_CONV2D_SHADER: &str = include_str!("depthwise_conv2d_f32.wgsl");
fn check_dtype_f32(dtype: DType, op: &'static str) -> Result<()> {
match dtype {
DType::F32 => Ok(()),
_ => Err(Error::UnsupportedDType { dtype, op }),
}
}
/// Records and submits a single compute pass running the `conv1d_f32` kernel.
///
/// The buffers are bound to group 0 in the order `input`, `weight`, `bias`,
/// `output`, `params_buffer`. The layout requested from the cache declares four
/// storage buffers (three of them read-only) and one uniform buffer.
/// NOTE(review): presumably the three read-only storage bindings are
/// `input`/`weight`/`bias` and the uniform binding is `params_buffer`; confirm
/// against `conv1d_f32.wgsl`.
///
/// The X dispatch size is `workgroup_count(total_output)`; Y and Z are 1.
/// The command buffer is handed to `queue`; this function does not wait for the
/// GPU work to finish.
///
/// # Errors
///
/// Returns `Error::UnsupportedDType` when `dtype` is not `DType::F32`. In that
/// case nothing is created or submitted.
pub fn launch_conv1d(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    weight: &Buffer,
    bias: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_output: usize,
    dtype: DType,
) -> Result<()> {
    // Name used for the cached module, the cached pipeline and the entry point.
    const KERNEL: &str = "conv1d_f32";
    // Debug label attached to the encoder and the compute pass.
    const LABEL: &str = "conv1d";

    check_dtype_f32(dtype, LABEL)?;

    let shader = cache.get_or_create_module(KERNEL, CONV1D_SHADER);
    let layout_key = LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 3,
    };
    let bind_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline = cache.get_or_create_pipeline(KERNEL, KERNEL, &shader, &bind_layout);

    let bindings = [input, weight, bias, output, params_buffer];
    let group = cache.create_bind_group(&bind_layout, &bindings);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut recorder = cache.device().create_command_encoder(&encoder_desc);

    let pass_desc = wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    };
    let mut compute = recorder.begin_compute_pass(&pass_desc);
    compute.set_pipeline(&compute_pipeline);
    compute.set_bind_group(0, Some(&group), &[]);
    let groups_x = workgroup_count(total_output);
    compute.dispatch_workgroups(groups_x, 1, 1);
    // The pass must end before the encoder can be finished.
    drop(compute);

    queue.submit(Some(recorder.finish()));
    Ok(())
}
/// Records and submits a single compute pass running the `conv2d_f32` kernel.
///
/// The buffers are bound to group 0 in the order `input`, `weight`, `bias`,
/// `output`, `params_buffer`. The layout requested from the cache declares four
/// storage buffers (three of them read-only) and one uniform buffer.
/// NOTE(review): presumably the three read-only storage bindings are
/// `input`/`weight`/`bias` and the uniform binding is `params_buffer`; confirm
/// against `conv2d_f32.wgsl`.
///
/// The X dispatch size is `workgroup_count(total_output)`; Y and Z are 1.
/// The command buffer is handed to `queue`; this function does not wait for the
/// GPU work to finish.
///
/// # Errors
///
/// Returns `Error::UnsupportedDType` when `dtype` is not `DType::F32`. In that
/// case nothing is created or submitted.
pub fn launch_conv2d(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    weight: &Buffer,
    bias: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_output: usize,
    dtype: DType,
) -> Result<()> {
    // Name used for the cached module, the cached pipeline and the entry point.
    const KERNEL: &str = "conv2d_f32";
    // Debug label attached to the encoder and the compute pass.
    const LABEL: &str = "conv2d";

    check_dtype_f32(dtype, LABEL)?;

    let shader = cache.get_or_create_module(KERNEL, CONV2D_SHADER);
    let layout_key = LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 3,
    };
    let bind_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline = cache.get_or_create_pipeline(KERNEL, KERNEL, &shader, &bind_layout);

    let bindings = [input, weight, bias, output, params_buffer];
    let group = cache.create_bind_group(&bind_layout, &bindings);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut recorder = cache.device().create_command_encoder(&encoder_desc);

    let pass_desc = wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    };
    let mut compute = recorder.begin_compute_pass(&pass_desc);
    compute.set_pipeline(&compute_pipeline);
    compute.set_bind_group(0, Some(&group), &[]);
    let groups_x = workgroup_count(total_output);
    compute.dispatch_workgroups(groups_x, 1, 1);
    // The pass must end before the encoder can be finished.
    drop(compute);

    queue.submit(Some(recorder.finish()));
    Ok(())
}
/// Records and submits a single compute pass running the
/// `depthwise_conv2d_f32` kernel.
///
/// The buffers are bound to group 0 in the order `input`, `weight`, `bias`,
/// `output`, `params_buffer`. The layout requested from the cache declares four
/// storage buffers (three of them read-only) and one uniform buffer.
/// NOTE(review): presumably the three read-only storage bindings are
/// `input`/`weight`/`bias` and the uniform binding is `params_buffer`; confirm
/// against `depthwise_conv2d_f32.wgsl`.
///
/// The X dispatch size is `workgroup_count(total_output)`; Y and Z are 1.
/// The command buffer is handed to `queue`; this function does not wait for the
/// GPU work to finish.
///
/// # Errors
///
/// Returns `Error::UnsupportedDType` when `dtype` is not `DType::F32`. In that
/// case nothing is created or submitted.
pub fn launch_depthwise_conv2d(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    weight: &Buffer,
    bias: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_output: usize,
    dtype: DType,
) -> Result<()> {
    // Name used for the cached module, the cached pipeline and the entry point.
    const KERNEL: &str = "depthwise_conv2d_f32";
    // Debug label attached to the encoder and the compute pass.
    const LABEL: &str = "depthwise_conv2d";

    check_dtype_f32(dtype, LABEL)?;

    let shader = cache.get_or_create_module(KERNEL, DEPTHWISE_CONV2D_SHADER);
    let layout_key = LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 3,
    };
    let bind_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline = cache.get_or_create_pipeline(KERNEL, KERNEL, &shader, &bind_layout);

    let bindings = [input, weight, bias, output, params_buffer];
    let group = cache.create_bind_group(&bind_layout, &bindings);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut recorder = cache.device().create_command_encoder(&encoder_desc);

    let pass_desc = wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    };
    let mut compute = recorder.begin_compute_pass(&pass_desc);
    compute.set_pipeline(&compute_pipeline);
    compute.set_bind_group(0, Some(&group), &[]);
    let groups_x = workgroup_count(total_output);
    compute.dispatch_workgroups(groups_x, 1, 1);
    // The pass must end before the encoder can be finished.
    drop(compute);

    queue.submit(Some(recorder.finish()));
    Ok(())
}