use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache};
use crate::dtype::DType;
use crate::error::{Error, Result};
const NORM_SHADER: &str = include_str!("norm.wgsl");
/// Early-returns `Error::UnsupportedDType` from the enclosing function when
/// the given dtype is anything other than `DType::F32`. `$op` names the
/// operation for the error message.
macro_rules! check_dtype_f32 {
    ($dtype:expr, $op:expr) => {
        match $dtype {
            DType::F32 => {}
            other => {
                return Err(Error::UnsupportedDType {
                    dtype: other,
                    op: $op,
                });
            }
        }
    };
}
/// Dispatches the `rms_norm_f32` compute kernel, one workgroup per batch row.
///
/// Bind group 0 holds, in order: `input`, `weight`, `output` (storage
/// buffers, per the 3-storage/1-uniform layout key) and `params_buffer`
/// (uniform). The encoded pass is submitted to `queue` before returning.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] if `dtype` is not [`DType::F32`].
///
/// # Panics
/// Panics if `batch_size` does not fit in `u32`; the previous `as` cast would
/// have silently truncated the dispatch width instead.
pub fn launch_rms_norm(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    weight: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "rms_norm");
    // Convert up front so an oversized batch fails loudly rather than
    // truncating and silently normalizing only part of the data.
    let workgroups = u32::try_from(batch_size).expect("batch_size exceeds u32 dispatch range");
    let module = cache.get_or_create_module("norm", NORM_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline("norm", "rms_norm_f32", &module, &layout);
    let bind_group = cache.create_bind_group(&layout, &[input, weight, output, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("rms_norm"),
        });
    {
        // Scope the pass so it is ended (dropped) before the encoder is finished.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("rms_norm"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroups, 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Dispatches the `layer_norm_f32` compute kernel, one workgroup per batch row.
///
/// Bind group 0 holds, in order: `input`, `weight`, `bias`, `output` (storage
/// buffers, per the 4-storage/1-uniform layout key) and `params_buffer`
/// (uniform). The encoded pass is submitted to `queue` before returning.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] if `dtype` is not [`DType::F32`].
///
/// # Panics
/// Panics if `batch_size` does not fit in `u32`; the previous `as` cast would
/// have silently truncated the dispatch width instead.
pub fn launch_layer_norm(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    weight: &Buffer,
    bias: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "layer_norm");
    // Convert up front so an oversized batch fails loudly rather than
    // truncating and silently normalizing only part of the data.
    let workgroups = u32::try_from(batch_size).expect("batch_size exceeds u32 dispatch range");
    let module = cache.get_or_create_module("norm", NORM_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline("norm", "layer_norm_f32", &module, &layout);
    let bind_group =
        cache.create_bind_group(&layout, &[input, weight, bias, output, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("layer_norm"),
        });
    {
        // Scope the pass so it is ended (dropped) before the encoder is finished.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("layer_norm"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroups, 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Dispatches the `layer_norm_no_bias_f32` compute kernel, one workgroup per
/// batch row. Identical to [`launch_layer_norm`] but without a bias buffer.
///
/// Bind group 0 holds, in order: `input`, `weight`, `output` (storage
/// buffers, per the 3-storage/1-uniform layout key) and `params_buffer`
/// (uniform). The encoded pass is submitted to `queue` before returning.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] if `dtype` is not [`DType::F32`].
///
/// # Panics
/// Panics if `batch_size` does not fit in `u32`; the previous `as` cast would
/// have silently truncated the dispatch width instead.
pub fn launch_layer_norm_no_bias(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    weight: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "layer_norm_no_bias");
    // Convert up front so an oversized batch fails loudly rather than
    // truncating and silently normalizing only part of the data.
    let workgroups = u32::try_from(batch_size).expect("batch_size exceeds u32 dispatch range");
    let module = cache.get_or_create_module("norm", NORM_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline("norm", "layer_norm_no_bias_f32", &module, &layout);
    let bind_group = cache.create_bind_group(&layout, &[input, weight, output, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("layer_norm_no_bias"),
        });
    {
        // Scope the pass so it is ended (dropped) before the encoder is finished.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("layer_norm_no_bias"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroups, 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Dispatches the `group_norm_f32` compute kernel, one workgroup per
/// (batch, group) pair — `batch_size * num_groups` workgroups in total.
///
/// Bind group 0 holds, in order: `input`, `weight`, `bias`, `output` (storage
/// buffers, per the 4-storage/1-uniform layout key) and `params_buffer`
/// (uniform). The encoded pass is submitted to `queue` before returning.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] if `dtype` is not [`DType::F32`].
///
/// # Panics
/// Panics if `batch_size * num_groups` overflows `usize` or does not fit in
/// `u32`; the previous unchecked multiply-and-cast could wrap/truncate and
/// silently dispatch the wrong workgroup count.
pub fn launch_group_norm(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    weight: &Buffer,
    bias: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    num_groups: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "group_norm");
    // Checked multiply + narrowing so an oversized dispatch fails loudly
    // instead of wrapping (release builds) or truncating in the cast.
    let workgroups = batch_size
        .checked_mul(num_groups)
        .and_then(|n| u32::try_from(n).ok())
        .expect("batch_size * num_groups exceeds u32 dispatch range");
    let module = cache.get_or_create_module("norm", NORM_SHADER);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline("norm", "group_norm_f32", &module, &layout);
    let bind_group =
        cache.create_bind_group(&layout, &[input, weight, bias, output, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("group_norm"),
        });
    {
        // Scope the pass so it is ended (dropped) before the encoder is finished.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("group_norm"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroups, 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}