use super::context::GpuContext;
use wgpu::util::{BufferInitDescriptor, DeviceExt};
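// GPU kernel dispatchers: each public function uploads a small uniform
// parameter block, binds the relevant buffers, and launches one WGSL compute
// pass on the queue.
//
// The `#[repr(C)]` + `Pod` structs below mirror the uniform blocks declared in
// the shaders; the explicit padding fields keep the Rust layout in sync with
// the shader-side layout (checked by the size asserts further down).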
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ElementwiseParams {
numel: u32,
scalar: f32,
_pad0: u32,
_pad1: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MatmulParams {
m: u32,
k: u32,
n: u32,
_pad: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct BiasParams {
m: u32,
n: u32,
_pad0: u32,
_pad1: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct SgdParams {
lr: f32,
momentum: f32,
numel: u32,
_pad: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct AdamParams {
lr: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
bc1: f32,
bc2: f32,
numel: u32,
_pad: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct Im2ColParams {
c_in: u32,
h: u32,
w: u32,
k: u32,
stride: u32,
pad: u32,
out_h: u32,
out_w: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ChannelBiasParams {
channels: u32,
spatial: u32,
_pad0: u32,
_pad1: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ContiguousParams {
numel: u32,
ndim: u32,
offset: u32,
_pad: u32,
shape_lo: [u32; 4],
shape_hi: [u32; 4],
strides_lo: [u32; 4],
strides_hi: [u32; 4],
suffix_lo: [u32; 4],
suffix_hi: [u32; 4],
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct FusedDropoutParams {
numel: u32,
seed: u32,
p_threshold: u32,
scale: f32,
ndim: u32,
offset: u32,
_pad0: u32,
_pad1: u32,
shape_lo: [u32; 4],
shape_hi: [u32; 4],
strides_lo: [u32; 4],
strides_hi: [u32; 4],
suffix_lo: [u32; 4],
suffix_hi: [u32; 4],
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct CrossEntropyParams {
batch: u32,
num_classes: u32,
_pad0: u32,
_pad1: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ReduceParams {
numel: u32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct AdamWParams {
lr: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
bc1: f32,
bc2: f32,
weight_decay: f32,
numel: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct BroadcastScaleParams {
numel: u32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
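// Compile-time checks that each uniform struct has the byte size the
// corresponding WGSL uniform block expects.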
const _: () = assert!(std::mem::size_of::<BroadcastScaleParams>() == 16);
const _: () = assert!(std::mem::size_of::<CrossEntropyParams>() == 16);
const _: () = assert!(std::mem::size_of::<ReduceParams>() == 16);
const _: () = assert!(std::mem::size_of::<AdamWParams>() == 32);
const _: () = assert!(std::mem::size_of::<FusedDropoutParams>() == 128);
const _: () = assert!(std::mem::size_of::<ContiguousParams>() == 112);
const _: () = assert!(std::mem::size_of::<ElementwiseParams>() == 16);
const _: () = assert!(std::mem::size_of::<MatmulParams>() == 16);
const _: () = assert!(std::mem::size_of::<BiasParams>() == 16);
const _: () = assert!(std::mem::size_of::<SgdParams>() == 16);
const _: () = assert!(std::mem::size_of::<AdamParams>() == 32);
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MaxPool2dParams {
channels: u32,
h: u32,
w: u32,
k: u32,
stride: u32,
out_h: u32,
out_w: u32,
_pad: u32,
}
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct DropoutParams {
numel: u32,
seed: u32,
p_threshold: u32,
scale: f32,
}
const _: () = assert!(std::mem::size_of::<Im2ColParams>() == 32);
const _: () = assert!(std::mem::size_of::<ChannelBiasParams>() == 16);
const _: () = assert!(std::mem::size_of::<MaxPool2dParams>() == 32);
const _: () = assert!(std::mem::size_of::<DropoutParams>() == 16);
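/// Shared helper for elementwise binary kernels (`dst[i] = op(lhs[i], rhs[i])`).
/// The dispatch count assumes the shader entry point uses `@workgroup_size(64)`.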
fn dispatch_binary(
ctx: &GpuContext,
pipeline: &wgpu::ComputePipeline,
lhs: &wgpu::Buffer,
rhs: &wgpu::Buffer,
dst: &wgpu::Buffer,
numel: u32,
) {
let params = ElementwiseParams {
numel,
scalar: 0.0,
_pad0: 0,
_pad1: 0,
};
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.binary_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: lhs.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: rhs.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
pub fn add(ctx: &GpuContext, lhs: &wgpu::Buffer, rhs: &wgpu::Buffer, dst: &wgpu::Buffer, numel: u32) {
dispatch_binary(ctx, &ctx.pipelines.add_pipeline, lhs, rhs, dst, numel);
}
pub fn sub(ctx: &GpuContext, lhs: &wgpu::Buffer, rhs: &wgpu::Buffer, dst: &wgpu::Buffer, numel: u32) {
dispatch_binary(ctx, &ctx.pipelines.sub_pipeline, lhs, rhs, dst, numel);
}
pub fn mul(ctx: &GpuContext, lhs: &wgpu::Buffer, rhs: &wgpu::Buffer, dst: &wgpu::Buffer, numel: u32) {
dispatch_binary(ctx, &ctx.pipelines.mul_pipeline, lhs, rhs, dst, numel);
}
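/// ReLU backward: produces the input gradient from `out_grad`, masked by the
/// forward-pass input `saved_input` (zero wherever the input was non-positive).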
pub fn relu_backward(
ctx: &GpuContext,
saved_input: &wgpu::Buffer,
out_grad: &wgpu::Buffer,
dst: &wgpu::Buffer,
numel: u32,
) {
dispatch_binary(ctx, &ctx.pipelines.relu_bw_pipeline, saved_input, out_grad, dst, numel);
}
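/// Shared helper for elementwise unary kernels; `scalar` is forwarded to the
/// shader for ops that need it (e.g. `scale`) and ignored otherwise.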
fn dispatch_unary(
ctx: &GpuContext,
pipeline: &wgpu::ComputePipeline,
input: &wgpu::Buffer,
dst: &wgpu::Buffer,
numel: u32,
scalar: f32,
) {
let params = ElementwiseParams {
numel,
scalar,
_pad0: 0,
_pad1: 0,
};
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.unary_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
pub fn relu(ctx: &GpuContext, input: &wgpu::Buffer, dst: &wgpu::Buffer, numel: u32) {
dispatch_unary(ctx, &ctx.pipelines.relu_pipeline, input, dst, numel, 0.0);
}
pub fn scale(ctx: &GpuContext, input: &wgpu::Buffer, dst: &wgpu::Buffer, numel: u32, scalar: f32) {
dispatch_unary(ctx, &ctx.pipelines.scale_pipeline, input, dst, numel, scalar);
}
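/// Row-major matmul: `dst (m x n) = a (m x k) * b (k x n)`. The
/// `(n+15)/16 x (m+15)/16` dispatch assumes the shader tiles with
/// `@workgroup_size(16, 16)`.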
pub fn matmul(
ctx: &GpuContext,
a: &wgpu::Buffer,
b: &wgpu::Buffer,
dst: &wgpu::Buffer,
m: u32,
k: u32,
n: u32,
) {
let params = MatmulParams { m, k, n, _pad: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("matmul_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.matmul_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.matmul_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((n + 15) / 16, (m + 15) / 16, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
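/// Adds a length-`n` bias vector to every row of an `m x n` matrix
/// (one thread per matrix element).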
pub fn add_bias(
ctx: &GpuContext,
matrix: &wgpu::Buffer,
bias: &wgpu::Buffer,
dst: &wgpu::Buffer,
m: u32,
n: u32,
) {
let params = BiasParams { m, n, _pad0: 0, _pad1: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("bias_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.bias_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: matrix.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: bias.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.add_bias_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((m * n + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
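/// Sums an `m x n` matrix over its rows into a length-`n` vector (one thread
/// per column). `bias_placeholder` is bound only to satisfy the shared
/// `bias_layout`.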
pub fn sum_rows(
ctx: &GpuContext,
matrix: &wgpu::Buffer,
bias_placeholder: &wgpu::Buffer,
dst: &wgpu::Buffer,
m: u32,
n: u32,
) {
let params = BiasParams { m, n, _pad0: 0, _pad1: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("sum_rows_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.bias_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: matrix.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: bias_placeholder.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.sum_rows_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((n + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
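/// SGD with momentum: updates the velocity buffer `vel` and `param` in place
/// from `grad`.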
pub fn sgd_step(
ctx: &GpuContext,
grad: &wgpu::Buffer,
vel: &wgpu::Buffer,
param: &wgpu::Buffer,
numel: u32,
lr: f32,
momentum: f32,
) {
let params = SgdParams { lr, momentum, numel, _pad: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("sgd_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.sgd_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: grad.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: vel.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: param.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.sgd_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
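/// Adam update: `m`/`v` hold the first and second moment estimates; `bc1`/`bc2`
/// are the host-computed bias-correction terms (typically `1 - beta1^t` and
/// `1 - beta2^t`).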
pub fn adam_step(
ctx: &GpuContext,
grad: &wgpu::Buffer,
m: &wgpu::Buffer,
v: &wgpu::Buffer,
param: &wgpu::Buffer,
numel: u32,
lr: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
bc1: f32,
bc2: f32,
) {
let params = AdamParams { lr, beta1, beta2, epsilon, bc1, bc2, numel, _pad: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("adam_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.adam_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: grad.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: m.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: v.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: param.as_entire_binding() },
wgpu::BindGroupEntry { binding: 4, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.adam_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
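/// im2col: unfolds a `c_in x h x w` input into the column matrix used for
/// convolution-as-matmul; one thread per output element
/// (`c_in * k * k * out_h * out_w`).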
pub fn im2col_dispatch(
ctx: &GpuContext,
input: &wgpu::Buffer,
dst: &wgpu::Buffer,
c_in: u32, h: u32, w: u32,
k: u32, stride: u32, pad: u32,
out_h: u32, out_w: u32,
) {
let params = Im2ColParams { c_in, h, w, k, stride, pad, out_h, out_w };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("im2col_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.unary_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: params_buf.as_entire_binding() },
],
label: None,
});
let total = c_in * k * k * out_h * out_w;
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.im2col_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((total + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
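/// col2im: the inverse of `im2col_dispatch`, folding column gradients back
/// into the `c_in * h * w` input layout (one thread per input element).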
pub fn col2im_dispatch(
ctx: &GpuContext,
input: &wgpu::Buffer,
dst: &wgpu::Buffer,
c_in: u32, h: u32, w: u32,
k: u32, stride: u32, pad: u32,
out_h: u32, out_w: u32,
) {
let params = Im2ColParams { c_in, h, w, k, stride, pad, out_h, out_w };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("col2im_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.unary_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: params_buf.as_entire_binding() },
],
label: None,
});
let total = c_in * h * w;
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.col2im_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((total + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
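/// Adds a per-channel bias to a `channels x spatial` tensor (NCHW-style layout).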
pub fn add_channel_bias(
ctx: &GpuContext,
src: &wgpu::Buffer,
bias: &wgpu::Buffer,
dst: &wgpu::Buffer,
channels: u32,
spatial: u32,
) {
let params = ChannelBiasParams { channels, spatial, _pad0: 0, _pad1: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("channel_bias_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.bias_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: bias.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let total = channels * spatial;
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.add_channel_bias_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((total + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
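/// 2D max pooling forward pass; `indices` records the argmax position of each
/// pooling window so the backward pass can route gradients.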
pub fn max_pool2d_forward(
ctx: &GpuContext,
input: &wgpu::Buffer,
output: &wgpu::Buffer,
indices: &wgpu::Buffer,
channels: u32, h: u32, w: u32,
k: u32, stride: u32,
out_h: u32, out_w: u32,
) {
let params = MaxPool2dParams { channels, h, w, k, stride, out_h, out_w, _pad: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("pool_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.pool_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: indices.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let total = channels * out_h * out_w;
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.max_pool2d_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((total + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
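/// 2D max pooling backward pass: routes `out_grad` into `grad_input` using the
/// saved `indices`. `k` and `stride` are unused here and zeroed in the params.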
pub fn max_pool2d_backward(
ctx: &GpuContext,
out_grad: &wgpu::Buffer,
indices: &wgpu::Buffer,
grad_input: &wgpu::Buffer,
channels: u32, h: u32, w: u32,
out_h: u32, out_w: u32,
) {
let params = MaxPool2dParams { channels, h, w, k: 0, stride: 0, out_h, out_w, _pad: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("pool_bw_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.pool_bw_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: out_grad.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: indices.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: grad_input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let total = channels * out_h * out_w;
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.max_pool2d_bw_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((total + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
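/// Inverted dropout: `p_threshold` is `p` mapped onto the full `u32` range so
/// the shader can compare it against a per-element random value derived from
/// `seed`; kept values are scaled by `1 / (1 - p)`, so `p` should be in `[0, 1)`.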
pub fn dropout_forward(
ctx: &GpuContext,
input: &wgpu::Buffer,
output: &wgpu::Buffer,
mask: &wgpu::Buffer,
numel: u32,
seed: u32,
p: f32,
) {
let scale = 1.0 / (1.0 - p);
let p_threshold = (p * u32::MAX as f32) as u32;
let params = DropoutParams { numel, seed, p_threshold, scale };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("dropout_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.pool_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: mask.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.dropout_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
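/// Dropout fused with a strided gather: the input is read through the view
/// described by `shape`/`strides`/`suffix` (up to 8 dims, split across the
/// `_lo`/`_hi` uniform arrays) instead of being assumed contiguous.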
pub fn fused_dropout_forward(
ctx: &GpuContext,
input: &wgpu::Buffer,
output: &wgpu::Buffer,
mask: &wgpu::Buffer,
numel: u32,
seed: u32,
p: f32,
ndim: u32,
offset: u32,
shape: &[usize],
strides: &[usize],
suffix: &[usize],
) {
let scale = 1.0 / (1.0 - p);
let p_threshold = (p * u32::MAX as f32) as u32;
let mut params = FusedDropoutParams {
numel, seed, p_threshold, scale,
ndim, offset, _pad0: 0, _pad1: 0,
shape_lo: [0u32; 4], shape_hi: [0u32; 4],
strides_lo: [0u32; 4], strides_hi: [0u32; 4],
suffix_lo: [0u32; 4], suffix_hi: [0u32; 4],
};
for i in 0..ndim as usize {
if i < 4 {
params.shape_lo[i] = shape[i] as u32;
params.strides_lo[i] = strides[i] as u32;
params.suffix_lo[i] = suffix[i] as u32;
} else {
params.shape_hi[i - 4] = shape[i] as u32;
params.strides_hi[i - 4] = strides[i] as u32;
params.suffix_hi[i - 4] = suffix[i] as u32;
}
}
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("fused_dropout_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.pool_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: mask.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.fused_dropout_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
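/// Copies a strided/offset view (up to 8 dims) into a dense, contiguous
/// destination buffer.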
pub fn contiguous_copy(
ctx: &GpuContext,
src: &wgpu::Buffer,
dst: &wgpu::Buffer,
numel: u32,
ndim: u32,
offset: u32,
shape: &[usize],
strides: &[usize],
suffix: &[usize],
) {
let mut params = ContiguousParams {
numel,
ndim,
offset,
_pad: 0,
shape_lo: [0u32; 4],
shape_hi: [0u32; 4],
strides_lo: [0u32; 4],
strides_hi: [0u32; 4],
suffix_lo: [0u32; 4],
suffix_hi: [0u32; 4],
};
for i in 0..ndim as usize {
if i < 4 {
params.shape_lo[i] = shape[i] as u32;
params.strides_lo[i] = strides[i] as u32;
params.suffix_lo[i] = suffix[i] as u32;
} else {
params.shape_hi[i - 4] = shape[i] as u32;
params.strides_hi[i - 4] = strides[i] as u32;
params.suffix_hi[i - 4] = suffix[i] as u32;
}
}
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("contiguous_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.unary_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.contiguous_copy_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
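/// Multiplies `src` element-wise by a single scalar stored in `scalar_buf`, so
/// the value never has to be read back to the host.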
pub fn broadcast_scale(
ctx: &GpuContext,
scalar_buf: &wgpu::Buffer,
src: &wgpu::Buffer,
dst: &wgpu::Buffer,
numel: u32,
) {
let params = BroadcastScaleParams { numel, _pad0: 0, _pad1: 0, _pad2: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("broadcast_scale_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.binary_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: scalar_buf.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: src.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.broadcast_scale_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
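/// Softmax cross-entropy: one workgroup per batch row. Writes the gradient
/// w.r.t. the logits into `grad` and the per-sample loss into `loss_per_b`.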
pub fn cross_entropy_forward(
ctx: &GpuContext,
logits: &wgpu::Buffer,
targets: &wgpu::Buffer,
grad: &wgpu::Buffer,
loss_per_b: &wgpu::Buffer,
batch: u32,
num_classes: u32,
) {
let params = CrossEntropyParams { batch, num_classes, _pad0: 0, _pad1: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("ce_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.ce_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: logits.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: targets.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: grad.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: loss_per_b.as_entire_binding() },
wgpu::BindGroupEntry { binding: 4, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.ce_forward_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(batch, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
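/// Reduces the per-sample losses into a single scalar in `output` using one
/// workgroup.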
pub fn reduce_loss(
ctx: &GpuContext,
input: &wgpu::Buffer,
output: &wgpu::Buffer,
numel: u32,
) {
let params = ReduceParams { numel, _pad0: 0, _pad1: 0, _pad2: 0 };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("reduce_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.unary_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.ce_reduce_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(1, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}
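/// AdamW update: Adam with decoupled weight decay applied directly to `param`.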
pub fn adamw_step(
ctx: &GpuContext,
grad: &wgpu::Buffer,
m: &wgpu::Buffer,
v: &wgpu::Buffer,
param: &wgpu::Buffer,
numel: u32,
lr: f32, beta1: f32, beta2: f32, epsilon: f32,
bc1: f32, bc2: f32, weight_decay: f32,
) {
let params = AdamWParams { lr, beta1, beta2, epsilon, bc1, bc2, weight_decay, numel };
let params_buf = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: Some("adamw_params"),
contents: bytemuck::bytes_of(&params),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &ctx.pipelines.adam_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: grad.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: m.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: v.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: param.as_entire_binding() },
wgpu::BindGroupEntry { binding: 4, resource: params_buf.as_entire_binding() },
],
label: None,
});
let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&ctx.pipelines.adamw_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((numel + 63) / 64, 1, 1);
}
ctx.queue.submit(std::iter::once(encoder.finish()));
}