use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
// WGSL source for the three-input fused element-wise kernels (a, b, c -> out).
const TERNARY_SHADER: &str = include_str!("fused_elementwise.wgsl");
// WGSL source for the scalar fused multiply-add kernel (a * scale + bias -> out).
const SCALAR_SHADER: &str = include_str!("fused_elementwise_scalar.wgsl");
/// Uniform-buffer parameters for the ternary fused element-wise kernels.
///
/// `#[repr(C)]` + `Pod`/`Zeroable` so the struct bytes can be uploaded
/// directly with `bytemuck::bytes_of` and match the WGSL uniform layout.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct TernaryParams {
    // Total element count; the shader uses it to bounds-check its global id.
    numel: u32,
}
/// Uniform-buffer parameters for the scalar fused multiply-add kernel.
///
/// `#[repr(C)]` + `Pod`/`Zeroable` so the struct bytes can be uploaded
/// directly with `bytemuck::bytes_of` and match the WGSL uniform layout.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ScalarFmaParams {
    // Total element count; the shader uses it to bounds-check its global id.
    numel: u32,
    scale: f32,
    bias: f32,
    // Explicit padding to a 16-byte size — presumably matching the WGSL
    // uniform struct's alignment; confirm against fused_elementwise_scalar.wgsl.
    _pad: u32,
}
/// Dispatches one of the fused ternary element-wise kernels
/// (`out[i] = f(a[i], b[i], c[i])`) over `numel` elements.
///
/// `entry_point` selects the WGSL entry function inside `TERNARY_SHADER`;
/// `op_name` is used for error reporting and GPU debug labels.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] when `dtype` is not `F32` — the
/// bundled shader is compiled for f32 only.
fn launch_ternary(
    cache: &PipelineCache,
    queue: &Queue,
    entry_point: &'static str,
    op_name: &'static str,
    a: &Buffer,
    b: &Buffer,
    c: &Buffer,
    out: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType { dtype, op: op_name });
    }
    let params = TernaryParams {
        numel: numel as u32,
    };
    // Small uniform buffer carrying the element count for shader-side bounds checks.
    let params_buf = cache.device().create_buffer(&wgpu::BufferDescriptor {
        label: Some("fused_elem_params"),
        size: std::mem::size_of::<TernaryParams>() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // Fixed: these two arguments were corrupted to `¶ms_buf` / `¶ms`
    // (mojibake for `&params_buf` / `&params`), which does not compile.
    queue.write_buffer(&params_buf, 0, bytemuck::bytes_of(&params));
    let module = cache.get_or_create_module("fused_elementwise_f32", TERNARY_SHADER);
    // Layout: 3 read-write storage inputs + 1 output, 1 uniform params buffer.
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline =
        cache.get_or_create_pipeline("fused_elementwise_f32", entry_point, &module, &layout);
    // Binding order must match the shader's @binding declarations: a, b, c, out, params.
    let bind_group = cache.create_bind_group(&layout, &[a, b, c, out, &params_buf]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some(op_name),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(op_name),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Launches the fused multiply-add kernel: `out[i] = a[i] * b[i] + c[i]`.
///
/// # Errors
/// Propagates [`Error::UnsupportedDType`] from the shared launcher when
/// `dtype` is not `F32`.
pub fn launch_fused_mul_add(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    b: &Buffer,
    c: &Buffer,
    out: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    // Thin wrapper: pick the f32 entry point and delegate to the shared launcher.
    launch_ternary(
        cache, queue, "fused_mul_add_f32", "fused_mul_add", a, b, c, out, numel, dtype,
    )
}
/// Launches the fused add-multiply kernel (the ternary companion to
/// [`launch_fused_mul_add`], using the `fused_add_mul_f32` entry point).
///
/// # Errors
/// Propagates [`Error::UnsupportedDType`] from the shared launcher when
/// `dtype` is not `F32`.
pub fn launch_fused_add_mul(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    b: &Buffer,
    c: &Buffer,
    out: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    // Thin wrapper: pick the f32 entry point and delegate to the shared launcher.
    launch_ternary(
        cache, queue, "fused_add_mul_f32", "fused_add_mul", a, b, c, out, numel, dtype,
    )
}
/// Launches the scalar fused multiply-add kernel:
/// `out[i] = a[i] * scale + bias` over `numel` elements.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] when `dtype` is not `F32` — the
/// bundled shader is compiled for f32 only.
pub fn launch_fused_mul_add_scalar(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    out: &Buffer,
    numel: usize,
    dtype: DType,
    scale: f32,
    bias: f32,
) -> Result<()> {
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "fused_mul_add_scalar",
        });
    }
    let params = ScalarFmaParams {
        numel: numel as u32,
        scale,
        bias,
        _pad: 0,
    };
    // Small uniform buffer carrying numel/scale/bias for the shader.
    let params_buf = cache.device().create_buffer(&wgpu::BufferDescriptor {
        label: Some("fused_elem_scalar_params"),
        size: std::mem::size_of::<ScalarFmaParams>() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // Fixed: these two arguments were corrupted to `¶ms_buf` / `¶ms`
    // (mojibake for `&params_buf` / `&params`), which does not compile.
    queue.write_buffer(&params_buf, 0, bytemuck::bytes_of(&params));
    let module = cache.get_or_create_module("fused_elementwise_scalar_f32", SCALAR_SHADER);
    // Layout: 1 storage input + 1 storage output, 1 uniform params buffer.
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(
        "fused_elementwise_scalar_f32",
        "fused_mul_add_scalar_f32",
        &module,
        &layout,
    );
    // Binding order must match the shader's @binding declarations: a, out, params.
    let bind_group = cache.create_bind_group(&layout, &[a, out, &params_buf]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fused_mul_add_scalar"),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("fused_mul_add_scalar"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}