use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::GemmActivation;
// WGSL source for the fused GEMM + bias + activation kernels (f32 only).
const GEMM_EPILOGUE_SHADER: &str = include_str!("gemm_epilogue_f32.wgsl");
// WGSL source for the fused GEMM + bias + residual-add kernels (f32 only).
const GEMM_EPILOGUE_RESIDUAL_SHADER: &str = include_str!("gemm_epilogue_residual_f32.wgsl");
// Output tile edge length per workgroup; must match the workgroup size
// declared in the WGSL shaders above — confirm if either changes.
const TILE_SIZE: u32 = 16;
fn activation_to_u32(act: GemmActivation) -> u32 {
match act {
GemmActivation::None => 0,
GemmActivation::ReLU => 1,
GemmActivation::GELU => 2,
GemmActivation::SiLU => 3,
GemmActivation::Sigmoid => 4,
GemmActivation::Tanh => 5,
}
}
/// Uniform parameters for the fused GEMM + bias + activation kernels.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` allow the struct to be byte-copied
/// into a wgpu uniform buffer with `bytemuck::bytes_of`. Field order and
/// padding presumably mirror the WGSL-side uniform struct — confirm against
/// `gemm_epilogue_f32.wgsl` before changing the layout.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct GemmEpilogueParams {
    // Rows of A / C.
    pub m: u32,
    // Shared (inner) dimension of A and B.
    pub k: u32,
    // Columns of B / C.
    pub n: u32,
    // Number of stacked matrices for batched launches (1 for non-batched).
    pub batch_size: u32,
    // Activation code as produced by `activation_to_u32`.
    pub activation_type: u32,
    // Explicit padding to a 32-byte struct size.
    pub _pad0: u32,
    pub _pad1: u32,
    pub _pad2: u32,
}
/// Uniform parameters for the fused GEMM + bias + residual-add kernels.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` allow byte-copying into a wgpu uniform
/// buffer. At 16 bytes (four u32s) no explicit padding is needed; the layout
/// presumably matches the WGSL-side struct — confirm against
/// `gemm_epilogue_residual_f32.wgsl`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct GemmResidualParams {
    // Rows of A / C.
    pub m: u32,
    // Shared (inner) dimension of A and B.
    pub k: u32,
    // Columns of B / C.
    pub n: u32,
    // Number of stacked matrices for batched launches (1 for non-batched).
    pub batch_size: u32,
}
fn check_f32(dtype: DType, op: &'static str) -> Result<()> {
if dtype != DType::F32 {
return Err(Error::UnsupportedDType { dtype, op });
}
Ok(())
}
/// Dispatches the fused GEMM + bias + activation kernel:
/// `C = act(A · B + bias)` for a single (non-batched) matmul.
///
/// `params_buffer` must hold a `GemmEpilogueParams` uniform describing the
/// problem; `m` and `n` here only size the dispatch grid. F32 buffers only —
/// any other dtype returns `Error::UnsupportedDType`.
#[allow(clippy::too_many_arguments)]
pub fn launch_gemm_bias_act(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    b: &Buffer,
    bias: &Buffer,
    c: &Buffer,
    params_buffer: &Buffer,
    m: usize,
    n: usize,
    dtype: DType,
) -> Result<()> {
    check_f32(dtype, "gemm_bias_act")?;
    let module = cache.get_or_create_module("gemm_epilogue_f32", GEMM_EPILOGUE_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline =
        cache.get_or_create_pipeline("gemm_bias_act_f32", "gemm_bias_act_f32", &module, &layout);
    // Binding order must match the shader's @binding indices: a, b, bias, c, params.
    let bind_group = cache.create_bind_group(&layout, &[a, b, bias, c, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("gemm_bias_act"),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gemm_bias_act"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        // One workgroup per TILE_SIZE x TILE_SIZE output tile; div_ceil covers
        // ragged edges. NOTE(review): `as u32` silently truncates if m or n
        // exceed u32::MAX — assumed unreachable for realistic sizes; confirm
        // callers validate dimensions upstream.
        let gx = (n as u32).div_ceil(TILE_SIZE);
        let gy = (m as u32).div_ceil(TILE_SIZE);
        pass.dispatch_workgroups(gx, gy, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Dispatches the batched fused GEMM + bias + activation kernel:
/// `C[i] = act(A[i] · B[i] + bias)` across `batch_size` stacked matmuls
/// (one dispatch, batch mapped to workgroup z).
///
/// `params_buffer` must hold a `GemmEpilogueParams` uniform whose
/// `batch_size` matches the argument here. F32 buffers only.
#[allow(clippy::too_many_arguments)]
pub fn launch_gemm_bias_act_batched(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    b: &Buffer,
    bias: &Buffer,
    c: &Buffer,
    params_buffer: &Buffer,
    m: usize,
    n: usize,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_f32(dtype, "gemm_bias_act_batched")?;
    let module = cache.get_or_create_module("gemm_epilogue_f32", GEMM_EPILOGUE_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(
        "gemm_bias_act_batched_f32",
        "gemm_bias_act_batched_f32",
        &module,
        &layout,
    );
    // Binding order must match the shader's @binding indices: a, b, bias, c, params.
    let bind_group = cache.create_bind_group(&layout, &[a, b, bias, c, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("gemm_bias_act_batched"),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gemm_bias_act_batched"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        // One workgroup per TILE_SIZE x TILE_SIZE output tile, one z-layer per
        // batch element. NOTE(review): wgpu caps workgroup counts per
        // dimension (65535 by default) — large batch sizes would exceed the z
        // limit; confirm callers stay within device limits.
        let gx = (n as u32).div_ceil(TILE_SIZE);
        let gy = (m as u32).div_ceil(TILE_SIZE);
        pass.dispatch_workgroups(gx, gy, batch_size as u32);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Dispatches the fused GEMM + bias + residual-add kernel:
/// `C = A · B + bias + residual` for a single (non-batched) matmul.
///
/// `params_buffer` must hold a `GemmResidualParams` uniform describing the
/// problem; `m` and `n` here only size the dispatch grid. F32 buffers only.
#[allow(clippy::too_many_arguments)]
pub fn launch_gemm_bias_residual(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    b: &Buffer,
    bias: &Buffer,
    residual: &Buffer,
    c: &Buffer,
    params_buffer: &Buffer,
    m: usize,
    n: usize,
    dtype: DType,
) -> Result<()> {
    check_f32(dtype, "gemm_bias_residual")?;
    let module =
        cache.get_or_create_module("gemm_epilogue_residual_f32", GEMM_EPILOGUE_RESIDUAL_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(
        "gemm_bias_residual_f32",
        "gemm_bias_residual_f32",
        &module,
        &layout,
    );
    // Binding order must match the shader: a, b, bias, residual, c, params.
    let bind_group = cache.create_bind_group(&layout, &[a, b, bias, residual, c, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("gemm_bias_residual"),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gemm_bias_residual"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        // One workgroup per TILE_SIZE x TILE_SIZE output tile; div_ceil covers
        // ragged edges. NOTE(review): `as u32` silently truncates if m or n
        // exceed u32::MAX — assumed unreachable; confirm upstream validation.
        let gx = (n as u32).div_ceil(TILE_SIZE);
        let gy = (m as u32).div_ceil(TILE_SIZE);
        pass.dispatch_workgroups(gx, gy, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Dispatches the batched fused GEMM + bias + residual-add kernel:
/// `C[i] = A[i] · B[i] + bias + residual[i]` across `batch_size` stacked
/// matmuls (one dispatch, batch mapped to workgroup z).
///
/// `params_buffer` must hold a `GemmResidualParams` uniform whose
/// `batch_size` matches the argument here. F32 buffers only.
#[allow(clippy::too_many_arguments)]
pub fn launch_gemm_bias_residual_batched(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    b: &Buffer,
    bias: &Buffer,
    residual: &Buffer,
    c: &Buffer,
    params_buffer: &Buffer,
    m: usize,
    n: usize,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_f32(dtype, "gemm_bias_residual_batched")?;
    let module =
        cache.get_or_create_module("gemm_epilogue_residual_f32", GEMM_EPILOGUE_RESIDUAL_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(
        "gemm_bias_residual_batched_f32",
        "gemm_bias_residual_batched_f32",
        &module,
        &layout,
    );
    // Binding order must match the shader: a, b, bias, residual, c, params.
    let bind_group = cache.create_bind_group(&layout, &[a, b, bias, residual, c, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("gemm_bias_residual_batched"),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gemm_bias_residual_batched"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        // One workgroup per TILE_SIZE x TILE_SIZE output tile, one z-layer per
        // batch element. NOTE(review): wgpu caps workgroup counts per
        // dimension (65535 by default) — confirm batch_size stays within
        // device limits.
        let gx = (n as u32).div_ceil(TILE_SIZE);
        let gy = (m as u32).div_ceil(TILE_SIZE);
        pass.dispatch_workgroups(gx, gy, batch_size as u32);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Builds a `UNIFORM | COPY_DST` buffer initialized with a
/// `GemmEpilogueParams` describing an (m, k, n, batch) problem and the
/// requested activation, ready to bind to the epilogue kernels.
pub fn create_epilogue_params_buffer(
    cache: &PipelineCache,
    m: u32,
    k: u32,
    n: u32,
    batch_size: u32,
    activation: GemmActivation,
) -> Buffer {
    use wgpu::util::DeviceExt;

    let activation_type = activation_to_u32(activation);
    let params = GemmEpilogueParams {
        m,
        k,
        n,
        batch_size,
        activation_type,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some("gemm_epilogue_params"),
        contents: bytemuck::bytes_of(&params),
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
    };
    cache.device().create_buffer_init(&descriptor)
}
/// Builds a `UNIFORM | COPY_DST` buffer initialized with a
/// `GemmResidualParams` describing an (m, k, n, batch) problem, ready to
/// bind to the residual epilogue kernels.
pub fn create_residual_params_buffer(
    cache: &PipelineCache,
    m: u32,
    k: u32,
    n: u32,
    batch_size: u32,
) -> Buffer {
    use wgpu::util::DeviceExt;

    let params = GemmResidualParams {
        m,
        k,
        n,
        batch_size,
    };
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some("gemm_residual_params"),
        contents: bytemuck::bytes_of(&params),
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
    };
    cache.device().create_buffer_init(&descriptor)
}