use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// WGSL source shared by every kernel in this file; compiled once and cached
/// under the module name "fused_add_norm".
const FUSED_ADD_NORM_SHADER: &str = include_str!("fused_add_norm.wgsl");
/// Guard macro: early-returns `Error::UnsupportedDType` from the enclosing
/// function when `$dtype` is not `DType::F32` — the only dtype the kernels in
/// this module currently support.
///
/// `$dtype` is bound to a local first so the argument expression is evaluated
/// exactly once, even though its value is used both for the comparison and
/// for the error payload (the original expanded `$dtype` twice).
macro_rules! check_dtype_f32 {
    ($dtype:expr, $op:expr) => {{
        let dtype = $dtype;
        if dtype != DType::F32 {
            return Err(Error::UnsupportedDType { dtype, op: $op });
        }
    }};
}
/// Records and submits the fused residual-add + RMS-norm forward kernel.
/// One workgroup is dispatched per batch row (`batch_size` total).
/// Only `DType::F32` is supported; other dtypes return `UnsupportedDType`.
pub fn launch_fused_add_rms_norm(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    residual: &Buffer,
    weight: &Buffer,
    output: &Buffer,
    pre_norm: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "fused_add_rms_norm");

    // Module, layout, and pipeline are all memoized inside the cache.
    let shader_module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute_pipeline = cache.get_or_create_pipeline(
        "fused_add_norm",
        "fused_add_rms_norm_f32",
        &shader_module,
        &bind_layout,
    );
    // Binding order here must stay in sync with the @binding indices declared
    // in fused_add_norm.wgsl.
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[input, residual, weight, output, pre_norm, params_buffer],
    );

    let mut cmd_encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fused_add_rms_norm"),
        });
    let mut compute_pass = cmd_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("fused_add_rms_norm"),
        timestamp_writes: None,
    });
    compute_pass.set_pipeline(&compute_pipeline);
    compute_pass.set_bind_group(0, Some(&bindings), &[]);
    compute_pass.dispatch_workgroups(batch_size as u32, 1, 1);
    // End the pass so its borrow of the encoder is released before finish().
    drop(compute_pass);

    queue.submit([cmd_encoder.finish()]);
    Ok(())
}
/// Records and submits the fused residual-add + layer-norm forward kernel.
/// One workgroup is dispatched per batch row (`batch_size` total).
/// Only `DType::F32` is supported; other dtypes return `UnsupportedDType`.
pub fn launch_fused_add_layer_norm(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    residual: &Buffer,
    weight: &Buffer,
    bias: &Buffer,
    output: &Buffer,
    pre_norm: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "fused_add_layer_norm");

    let shader_module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER);
    // Layer norm binds one more storage buffer (bias) than RMS norm.
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute_pipeline = cache.get_or_create_pipeline(
        "fused_add_norm",
        "fused_add_layer_norm_f32",
        &shader_module,
        &bind_layout,
    );
    // Binding order here must stay in sync with the @binding indices declared
    // in fused_add_norm.wgsl.
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[
            input,
            residual,
            weight,
            bias,
            output,
            pre_norm,
            params_buffer,
        ],
    );

    let mut cmd_encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fused_add_layer_norm"),
        });
    let mut compute_pass = cmd_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("fused_add_layer_norm"),
        timestamp_writes: None,
    });
    compute_pass.set_pipeline(&compute_pipeline);
    compute_pass.set_bind_group(0, Some(&bindings), &[]);
    compute_pass.dispatch_workgroups(batch_size as u32, 1, 1);
    // End the pass so its borrow of the encoder is released before finish().
    drop(compute_pass);

    queue.submit([cmd_encoder.finish()]);
    Ok(())
}
/// Records and submits the backward kernel for fused add + RMS norm.
/// One workgroup per batch row; `d_weight_scratch` receives per-row partial
/// weight gradients (presumably reduced later via `launch_reduce_sum_rows` —
/// confirm against callers).
/// Only `DType::F32` is supported; other dtypes return `UnsupportedDType`.
pub fn launch_fused_add_rms_norm_bwd(
    cache: &PipelineCache,
    queue: &Queue,
    grad: &Buffer,
    pre_norm: &Buffer,
    weight: &Buffer,
    d_input_residual: &Buffer,
    d_weight_scratch: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "fused_add_rms_norm_bwd");

    let shader_module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER);
    // Same 5-storage/1-uniform shape as the forward pass, so the cached
    // layout is shared between the two entry points.
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute_pipeline = cache.get_or_create_pipeline(
        "fused_add_norm",
        "fused_add_rms_norm_bwd_f32",
        &shader_module,
        &bind_layout,
    );
    // Binding order here must stay in sync with the @binding indices declared
    // in fused_add_norm.wgsl.
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[
            grad,
            pre_norm,
            weight,
            d_input_residual,
            d_weight_scratch,
            params_buffer,
        ],
    );

    let mut cmd_encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fused_add_rms_norm_bwd"),
        });
    let mut compute_pass = cmd_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("fused_add_rms_norm_bwd"),
        timestamp_writes: None,
    });
    compute_pass.set_pipeline(&compute_pipeline);
    compute_pass.set_bind_group(0, Some(&bindings), &[]);
    compute_pass.dispatch_workgroups(batch_size as u32, 1, 1);
    // End the pass so its borrow of the encoder is released before finish().
    drop(compute_pass);

    queue.submit([cmd_encoder.finish()]);
    Ok(())
}
/// Records and submits the backward kernel for fused add + layer norm.
/// One workgroup per batch row; `d_weight_scratch`/`d_bias_scratch` receive
/// per-row partial gradients (presumably reduced later via
/// `launch_reduce_sum_rows` — confirm against callers).
/// Only `DType::F32` is supported; other dtypes return `UnsupportedDType`.
pub fn launch_fused_add_layer_norm_bwd(
    cache: &PipelineCache,
    queue: &Queue,
    grad: &Buffer,
    pre_norm: &Buffer,
    weight: &Buffer,
    bias: &Buffer,
    d_input_residual: &Buffer,
    d_weight_scratch: &Buffer,
    d_bias_scratch: &Buffer,
    params_buffer: &Buffer,
    batch_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "fused_add_layer_norm_bwd");

    let shader_module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 7,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute_pipeline = cache.get_or_create_pipeline(
        "fused_add_norm",
        "fused_add_layer_norm_bwd_f32",
        &shader_module,
        &bind_layout,
    );
    // Binding order here must stay in sync with the @binding indices declared
    // in fused_add_norm.wgsl.
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[
            grad,
            pre_norm,
            weight,
            bias,
            d_input_residual,
            d_weight_scratch,
            d_bias_scratch,
            params_buffer,
        ],
    );

    let mut cmd_encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fused_add_layer_norm_bwd"),
        });
    let mut compute_pass = cmd_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("fused_add_layer_norm_bwd"),
        timestamp_writes: None,
    });
    compute_pass.set_pipeline(&compute_pipeline);
    compute_pass.set_bind_group(0, Some(&bindings), &[]);
    compute_pass.dispatch_workgroups(batch_size as u32, 1, 1);
    // End the pass so its borrow of the encoder is released before finish().
    drop(compute_pass);

    queue.submit([cmd_encoder.finish()]);
    Ok(())
}
/// Records and submits the row-sum reduction kernel (`reduce_sum_rows_f32`).
/// Dispatch is sized over `hidden_size` columns rather than batch rows;
/// the exact reduction axis semantics live in fused_add_norm.wgsl.
/// Only `DType::F32` is supported; other dtypes return `UnsupportedDType`.
pub fn launch_reduce_sum_rows(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    hidden_size: usize,
    dtype: DType,
) -> Result<()> {
    check_dtype_f32!(dtype, "reduce_sum_rows");

    let shader_module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute_pipeline = cache.get_or_create_pipeline(
        "fused_add_norm",
        "reduce_sum_rows_f32",
        &shader_module,
        &bind_layout,
    );
    let bindings = cache.create_bind_group(&bind_layout, &[input, output, params_buffer]);

    // Ceiling division: enough workgroups to cover hidden_size, assuming a
    // 256-wide workgroup in the shader — keep in sync with fused_add_norm.wgsl.
    let workgroups = (hidden_size as u32 + 255) / 256;

    let mut cmd_encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("reduce_sum_rows"),
        });
    let mut compute_pass = cmd_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("reduce_sum_rows"),
        timestamp_writes: None,
    });
    compute_pass.set_pipeline(&compute_pipeline);
    compute_pass.set_bind_group(0, Some(&bindings), &[]);
    compute_pass.dispatch_workgroups(workgroups, 1, 1);
    // End the pass so its borrow of the encoder is released before finish().
    drop(compute_pass);

    queue.submit([cmd_encoder.finish()]);
    Ok(())
}