#![allow(clippy::too_many_arguments)]
#![allow(clippy::needless_range_loop)]
use bytemuck::{Pod, Zeroable};
use futures_channel::oneshot;
use crate::backend::WgpuCtx;
use crate::backend::pipelines::Pipelines;
use crate::error::{Result, RullamaError};
/// Waits until the work submitted to `queue` so far has completed.
///
/// Registers a completion callback on the queue, drives the device with a waiting
/// poll, and then awaits the callback's signal.
///
/// # Errors
/// Returns `RullamaError::Inference` when the device poll fails and
/// `RullamaError::BufferMap` when the completion channel is dropped before signalling.
pub async fn fence_submitted_work(device: &wgpu::Device, queue: &wgpu::Queue) -> Result<()> {
    let (done_tx, done_rx) = oneshot::channel();
    queue.on_submitted_work_done(move || {
        // The receiver may already be gone; there is nothing useful to do then.
        let _ = done_tx.send(());
    });

    let wait_for_all = wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    };
    if let Err(e) = device.poll(wait_for_all) {
        return Err(RullamaError::Inference(format!("{e:?}")));
    }

    match done_rx.await {
        Ok(()) => Ok(()),
        Err(e) => Err(RullamaError::BufferMap(format!("{e}"))),
    }
}
/// Uniform block uploaded by the matmul helpers in this file (`run_matmul`,
/// `run_matmul_buf`, `matmul_chained_inner`).
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct MatmulParams {
    /// Caller-supplied `k`, cast to `u32`.
    k: u32,
    /// Caller-supplied `n`, cast to `u32`; the helpers size the output as `n` floats.
    n: u32,
    /// Always written as 0 by the helpers in this file; presumably padding.
    _p0: u32,
    /// Always written as 0 by the helpers in this file; presumably padding.
    _p1: u32,
}
/// Uniform block uploaded by `rmsnorm_cached` and `rmsnorm_chained`.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct RmsParams {
    /// Element count of the input.
    n: u32,
    /// Epsilon passed through unchanged from the caller.
    eps: f32,
    /// 1 when a weight buffer was supplied, 0 when a dummy buffer is bound instead.
    has_weight: u32,
    /// Always written as 0; presumably padding.
    _p: u32,
}
/// Uniform block uploaded by `softcap_cached`.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct CapParams {
    /// Element count of the input.
    n: u32,
    /// Cap value passed through unchanged from the caller.
    cap: f32,
    /// Always written as 0; presumably padding.
    _p0: u32,
    /// Always written as 0; presumably padding.
    _p1: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for the cross-entropy kernel driven by
/// `cross_entropy_backward_chained`; confirm against that function and its shader.
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct XEntParams {
    vocab_size: u32,
    target: u32,
    _p0: u32,
    _p1: u32,
}
/// Uniform block uploaded by `geglu_cached`.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct GegluParams {
    /// Element count shared by the `gate` and `up` inputs.
    n: u32,
    /// Always written as 0; presumably padding.
    _p0: u32,
    /// Always written as 0; presumably padding.
    _p1: u32,
    /// Always written as 0; presumably padding.
    _p2: u32,
}
/// Uniform block uploaded by `rope_neox_cached`.
///
/// Eight 4-byte fields, 32 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct RopeParams {
    /// Caller-supplied `head_dim`, cast to `u32`.
    head_dim: u32,
    /// Caller-supplied `n_heads`, cast to `u32`.
    n_heads: u32,
    /// Caller-supplied `rope_dims`; `rope_neox_cached` rejects odd values and values above `head_dim`.
    rope_dims: u32,
    /// Caller-supplied `pos`, cast to `u32`.
    pos: u32,
    /// Caller-supplied `base`, passed through unchanged.
    base: f32,
    /// 1 when a factors buffer was supplied, 0 when a dummy buffer is bound instead.
    has_factors: u32,
    /// Always written as 0; presumably padding.
    _p0: u32,
    /// Always written as 0; presumably padding.
    _p1: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for an attention kernel; field meanings are
/// only suggested by their names — confirm against the shader that consumes it.
/// Eight 4-byte fields, 32 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct AttnParams {
    head_dim: u32,
    n_heads: u32,
    n_kv_heads: u32,
    heads_per_kv: u32,
    pos: u32,
    history_len: u32,
    window: u32,
    _p: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a residual-add kernel — confirm against
/// its user. Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct ResAddParams {
    n: u32,
    _p0: u32,
    _p1: u32,
    _p2: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a scalar-scale kernel (`s` looks like the
/// scale factor) — confirm against its user. Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct ScaleParams {
    n: u32,
    s: f32,
    _p0: u32,
    _p1: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a row-wise RMS norm kernel — confirm
/// against its user. Four 4-byte fields, 16 bytes in total, with no padding field.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct RmsPerRowParams {
    n_rows: u32,
    row_dim: u32,
    eps: f32,
    has_weight: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a 2-D convolution kernel, with the
/// `in_*`/`out_*`/`k_*`/`s_*`/`p_*` fields being input, output, kernel, stride and
/// padding extents — inferred from names only; confirm against the shader.
/// Twelve 4-byte fields, 48 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct Conv2dParams {
    in_c: u32,
    in_h: u32,
    in_w: u32,
    out_c: u32,
    out_h: u32,
    out_w: u32,
    k_h: u32,
    k_w: u32,
    s_h: u32,
    s_w: u32,
    p_h: u32,
    p_w: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a clamp kernel with bounds `lo`/`hi` —
/// confirm against its user. Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct ClampParams {
    n: u32,
    lo: f32,
    hi: f32,
    _p: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a 2-D average-pooling kernel — confirm
/// against its user. Eight 4-byte fields, 32 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct AvgPool2dParams {
    in_h: u32,
    in_w: u32,
    out_h: u32,
    out_w: u32,
    channels: u32,
    k: u32,
    _p0: u32,
    _p1: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a 2-D rotary-embedding kernel — confirm
/// against its user. Four 4-byte fields, 16 bytes in total, with no padding field.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct Rope2dParams {
    head_dim: u32,
    n_heads: u32,
    n_patches: u32,
    base: f32,
}
/// Crate-visible parameter block that is not constructed in the visible part of this
/// file.
///
/// NOTE(review): presumably the uniform for a batched matmul kernel — confirm against
/// the crate-internal callers. Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
pub(crate) struct BatchedMatmulParams {
    pub k: u32,
    pub n: u32,
    pub batch: u32,
    pub _pad: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a transpose kernel over
/// patches/heads/head_dim — confirm against its user.
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct TransposeParams {
    n_patches: u32,
    n_heads: u32,
    head_dim: u32,
    _pad: u32,
}
/// Uniform block uploaded by `pos_embed_add_chained`.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct PosEmbedAddParams {
    /// Caller-supplied `n_patches`, cast to `u32`.
    n_patches: u32,
    /// Caller-supplied `hidden_size`, cast to `u32`.
    hidden_size: u32,
    /// Caller-supplied `pos_size`, cast to `u32`.
    pos_size: u32,
    /// Always written as 0; presumably padding.
    _pad: u32,
}
/// Uniform block uploaded by the `vision_attention*_chained` functions in this file.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct VisionAttnParams {
    /// Caller-supplied `head_dim`, cast to `u32`.
    head_dim: u32,
    /// Caller-supplied `n_heads`, cast to `u32`.
    n_heads: u32,
    /// Caller-supplied `n_patches`, cast to `u32`.
    n_patches: u32,
    /// Always written as 0; presumably padding.
    _pad: u32,
}
/// Uniform block uploaded by `glu_split_chained`.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct GluSplitParams {
    /// Caller-supplied `seq`, cast to `u32`.
    seq: u32,
    /// Caller-supplied `inner`, cast to `u32`; the dispatch covers `seq * inner` items.
    inner: u32,
    /// Always written as 0; presumably padding.
    _p0: u32,
    /// Always written as 0; presumably padding.
    _p1: u32,
}
/// Uniform block uploaded by `depthwise_conv1d_chained`.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct DepthwiseConv1dParams {
    /// Caller-supplied `seq`, cast to `u32`.
    seq: u32,
    /// Caller-supplied `channels`, cast to `u32`; the dispatch covers `seq * channels` items.
    channels: u32,
    /// Caller-supplied `kernel`, cast to `u32`.
    kernel: u32,
    /// Always written as 0; presumably padding.
    _p: u32,
}
/// Uniform block carrying only an element count; uploaded by
/// `half_residual_add_chained` and `silu_chained`.
///
/// Four 4-byte fields, 16 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct ScalarNParams {
    /// Caller-supplied element count, cast to `u32`.
    n: u32,
    /// Always written as 0; presumably padding.
    _p0: u32,
    /// Always written as 0; presumably padding.
    _p1: u32,
    /// Always written as 0; presumably padding.
    _p2: u32,
}
/// Parameter block that is not constructed in the visible part of this file.
///
/// NOTE(review): presumably the uniform for a chunked/block-local attention kernel;
/// field meanings are only suggested by their names — confirm against the shader.
/// Twelve 4-byte fields, 48 bytes in total.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct BlockLocalAttnParams {
    seq: u32,
    padded_len: u32,
    hidden: u32,
    n_heads: u32,
    head_dim: u32,
    chunk_size: u32,
    context_size: u32,
    max_span: u32,
    max_past: u32,
    max_future: u32,
    pad_left: u32,
    logit_cap: f32,
}
/// Computes a workgroup grid `(x, y, z)` that covers `n_elements` items when each
/// workgroup handles `wg_size` of them.
///
/// The required count is `ceil(n_elements / wg_size)`. If it is at most 65535 it is
/// returned as `(count, 1, 1)`. Otherwise x is pinned to 65535 and y becomes
/// `ceil(count / 65535)`, so the grid may hold more workgroups than strictly needed.
///
/// # Panics
/// Panics if `wg_size` is zero (division by zero).
fn dispatch_dims_1d(n_elements: u32, wg_size: u32) -> (u32, u32, u32) {
    const MAX_WG_PER_DIM: u32 = 65535;
    let groups_needed = n_elements.div_ceil(wg_size);
    if groups_needed > MAX_WG_PER_DIM {
        // Spill the excess into the y dimension.
        let rows = groups_needed.div_ceil(MAX_WG_PER_DIM);
        return (MAX_WG_PER_DIM, rows, 1);
    }
    (groups_needed, 1, 1)
}
/// Creates a `UNIFORM | COPY_DST` buffer of `size_of::<T>()` bytes and queues a write
/// of `data`'s raw bytes into it at offset 0.
///
/// Returns the new buffer; the write is only queued here.
pub(crate) fn write_uniform<T: Pod>(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    label: &str,
    data: &T,
) -> wgpu::Buffer {
    let payload: &[u8] = bytemuck::bytes_of(data);
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size: std::mem::size_of::<T>() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    };
    let uniform = device.create_buffer(&descriptor);
    queue.write_buffer(&uniform, 0, payload);
    uniform
}
/// Creates a `STORAGE | COPY_DST` buffer and queues a write of `bytes` into it.
///
/// The buffer is at least 4 bytes long, so empty input still yields a non-empty
/// buffer; in that case no write is queued.
fn write_storage(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    label: &str,
    bytes: &[u8],
) -> wgpu::Buffer {
    let size = bytes.len().max(4) as u64;
    let storage = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some(label),
        size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    if bytes.is_empty() {
        return storage;
    }
    queue.write_buffer(&storage, 0, bytes);
    storage
}
/// Uploads an `f32` slice as a storage buffer by reinterpreting it as bytes and
/// delegating to `write_storage`.
fn write_storage_f32(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    label: &str,
    x: &[f32],
) -> wgpu::Buffer {
    let raw: &[u8] = bytemuck::cast_slice(x);
    write_storage(device, queue, label, raw)
}
/// Allocates two `n_bytes`-sized buffers for producing and reading back a result.
///
/// Returns `(out, read)`: `out` is `STORAGE | COPY_SRC` and labelled `"{label}.out"`;
/// `read` is `COPY_DST | MAP_READ` and labelled `"{label}.read"`.
fn make_output_pair(
    device: &wgpu::Device,
    label: &str,
    n_bytes: u64,
) -> (wgpu::Buffer, wgpu::Buffer) {
    let allocate = |suffix: &str, usage: wgpu::BufferUsages| {
        let name = format!("{label}.{suffix}");
        device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(&name),
            size: n_bytes,
            usage,
            mapped_at_creation: false,
        })
    };
    let compute_out = allocate(
        "out",
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    );
    let staging = allocate(
        "read",
        wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
    );
    (compute_out, staging)
}
/// Maps the whole of `read_buf` for reading, copies its contents out as `f32`s, and
/// unmaps it again.
///
/// # Errors
/// Returns `RullamaError::Inference` if the device poll fails, and
/// `RullamaError::BufferMap` if the map callback never reports or reports a failure.
async fn read_back_f32(device: &wgpu::Device, read_buf: &wgpu::Buffer) -> Result<Vec<f32>> {
    let whole = read_buf.slice(..);
    let (map_tx, map_rx) = oneshot::channel();
    whole.map_async(wgpu::MapMode::Read, move |outcome| {
        let _ = map_tx.send(outcome);
    });

    let wait_for_all = wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    };
    device
        .poll(wait_for_all)
        .map_err(|e| RullamaError::Inference(format!("{e:?}")))?;

    // First layer: the channel itself; second layer: the mapping result it carried.
    let map_outcome = map_rx
        .await
        .map_err(|e| RullamaError::BufferMap(format!("{e}")))?;
    map_outcome.map_err(|e| RullamaError::BufferMap(format!("{e}")))?;

    let view = whole.get_mapped_range();
    let floats: Vec<f32> = bytemuck::cast_slice(&view).to_vec();
    // The mapped view must be released before the buffer can be unmapped.
    drop(view);
    read_buf.unmap();
    Ok(floats)
}
async fn run_matmul(
ctx: &WgpuCtx,
pipeline: &wgpu::ComputePipeline,
label: &str,
w_bytes: &[u8],
x: &[f32],
k: usize,
n: usize,
) -> Result<Vec<f32>> {
let device = &ctx.device;
let queue = &ctx.queue;
let params = MatmulParams {
k: k as u32,
n: n as u32,
_p0: 0,
_p1: 0,
};
let p_buf = write_uniform(device, queue, &format!("{label}.params"), ¶ms);
let w_buf = write_storage(device, queue, &format!("{label}.W"), w_bytes);
let x_buf = write_storage_f32(device, queue, &format!("{label}.x"), x);
let n_bytes = (n * 4) as u64;
let (y_buf, read_buf) = make_output_pair(device, label, n_bytes);
let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some(&format!("{label}.bg")),
layout: &pipeline.get_bind_group_layout(0),
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: p_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: w_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: x_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: y_buf.as_entire_binding(),
},
],
});
let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some(&format!("{label}.encoder")),
});
{
let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(label),
timestamp_writes: None,
});
cp.set_pipeline(pipeline);
cp.set_bind_group(0, &bg, &[]);
cp.dispatch_workgroups((n as u32).div_ceil(64), 1, 1);
}
enc.copy_buffer_to_buffer(&y_buf, 0, &read_buf, 0, n_bytes);
queue.submit(Some(enc.finish()));
read_back_f32(device, &read_buf).await
}
/// Runs the `q4_k_matmul` pipeline on host-side weights `w_bytes` and input `x`
/// through `run_matmul`, returning the read-back output.
pub async fn matmul_q4_k_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    let pipeline = &p.q4_k_matmul;
    run_matmul(ctx, pipeline, "q4k_matmul", w_bytes, x, k, n).await
}
/// Runs the `q6_k_matmul` pipeline on host-side weights `w_bytes` and input `x`
/// through `run_matmul`, returning the read-back output.
pub async fn matmul_q6_k_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    let pipeline = &p.q6_k_matmul;
    run_matmul(ctx, pipeline, "q6k_matmul", w_bytes, x, k, n).await
}
/// Runs the `f16_matmul` pipeline on host-side weights `w_bytes` and input `x`
/// through `run_matmul`, returning the read-back output.
#[allow(dead_code)]
pub async fn matmul_f16_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    let pipeline = &p.f16_matmul;
    run_matmul(ctx, pipeline, "f16_matmul", w_bytes, x, k, n).await
}
async fn run_matmul_buf(
ctx: &WgpuCtx,
pipeline: &wgpu::ComputePipeline,
label: &str,
w_buf: &wgpu::Buffer,
x: &[f32],
k: usize,
n: usize,
) -> Result<Vec<f32>> {
let device = &ctx.device;
let queue = &ctx.queue;
let params = MatmulParams {
k: k as u32,
n: n as u32,
_p0: 0,
_p1: 0,
};
let p_buf = write_uniform(device, queue, &format!("{label}.params"), ¶ms);
let x_buf = write_storage_f32(device, queue, &format!("{label}.x"), x);
let n_bytes = (n * 4) as u64;
let (y_buf, read_buf) = make_output_pair(device, label, n_bytes);
let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some(&format!("{label}.bg")),
layout: &pipeline.get_bind_group_layout(0),
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: p_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: w_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: x_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: y_buf.as_entire_binding(),
},
],
});
let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some(&format!("{label}.encoder")),
});
{
let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some(label),
timestamp_writes: None,
});
cp.set_pipeline(pipeline);
cp.set_bind_group(0, &bg, &[]);
cp.dispatch_workgroups((n as u32).div_ceil(64), 1, 1);
}
enc.copy_buffer_to_buffer(&y_buf, 0, &read_buf, 0, n_bytes);
queue.submit(Some(enc.finish()));
read_back_f32(device, &read_buf).await
}
/// Runs the `q4_k_matmul` pipeline against a GPU-resident weight buffer `w` and
/// host-side input `x` through `run_matmul_buf`.
pub async fn matmul_q4_k_buf(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w: &wgpu::Buffer,
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    let pipeline = &p.q4_k_matmul;
    run_matmul_buf(ctx, pipeline, "q4k_matmul_buf", w, x, k, n).await
}
/// Runs the `q6_k_matmul` pipeline against a GPU-resident weight buffer `w` and
/// host-side input `x` through `run_matmul_buf`.
pub async fn matmul_q6_k_buf(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w: &wgpu::Buffer,
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    let pipeline = &p.q6_k_matmul;
    run_matmul_buf(ctx, pipeline, "q6k_matmul_buf", w, x, k, n).await
}
/// Runs the `f16_matmul` pipeline against a GPU-resident weight buffer `w` and
/// host-side input `x` through `run_matmul_buf`.
#[allow(dead_code)]
pub async fn matmul_f16_buf(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w: &wgpu::Buffer,
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    let pipeline = &p.f16_matmul;
    run_matmul_buf(ctx, pipeline, "f16_matmul_buf", w, x, k, n).await
}
/// Runs the `rmsnorm` pipeline over `x` and reads the result back.
///
/// Bindings: 0 = `RmsParams`, 1 = `x`, 2 = weight (or a 4-byte dummy when `weight` is
/// `None`), 3 = output. A single workgroup is dispatched. Empty input returns an empty
/// vector without touching the GPU.
///
/// # Errors
/// Forwards any error from `read_back_f32`.
pub async fn rmsnorm_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    x: &[f32],
    weight: Option<&[f32]>,
    eps: f32,
) -> Result<Vec<f32>> {
    if x.is_empty() {
        return Ok(Vec::new());
    }
    let len = x.len();
    let (device, queue) = (&ctx.device, &ctx.queue);
    let pipeline = &p.rmsnorm;

    let uniforms = write_uniform(
        device,
        queue,
        "rms.params",
        &RmsParams {
            n: len as u32,
            eps,
            has_weight: u32::from(weight.is_some()),
            _p: 0,
        },
    );
    let input = write_storage_f32(device, queue, "rms.x", x);
    let scale = if let Some(w) = weight {
        write_storage_f32(device, queue, "rms.w", w)
    } else {
        // The binding slot must be filled even when no weight is used.
        write_storage(device, queue, "rms.w_dummy", &[0u8; 4])
    };
    let out_bytes = (len * 4) as u64;
    let (output, staging) = make_output_pair(device, "rms", out_bytes);

    // Binding slots follow array order: 0 = params, 1 = x, 2 = weight, 3 = output.
    let entries: Vec<_> = [&uniforms, &input, &scale, &output]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rms.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rms.enc"),
    });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("rms.pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output, 0, &staging, 0, out_bytes);
    queue.submit(Some(encoder.finish()));
    read_back_f32(device, &staging).await
}
/// Uploads `w_bytes` and `dy`, records `matmul_q4_k_backward_input_chained`, and reads
/// back `k` floats.
///
/// When `k` or `n` is zero, returns `k` zeros without touching the GPU.
///
/// # Errors
/// Forwards any error from `read_back_f32`.
pub async fn matmul_q4_k_backward_input_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w_bytes: &[u8],
    dy: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    if k == 0 || n == 0 {
        return Ok(vec![0.0; k]);
    }
    let (device, queue) = (&ctx.device, &ctx.queue);
    let weights = write_storage(device, queue, "q4k_bwd.w", w_bytes);
    let upstream = write_storage_f32(device, queue, "q4k_bwd.dy", dy);
    let dx_bytes = (k * 4) as u64;
    let (dx, dx_staging) = make_output_pair(device, "q4k_bwd.dx", dx_bytes);

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("q4k_bwd.enc"),
    });
    matmul_q4_k_backward_input_chained(ctx, p, &mut encoder, &weights, &upstream, &dx, k, n);
    encoder.copy_buffer_to_buffer(&dx, 0, &dx_staging, 0, dx_bytes);
    queue.submit(Some(encoder.finish()));
    read_back_f32(device, &dx_staging).await
}
/// Uploads `w_bytes` and `dy`, records `matmul_q6_k_backward_input_chained`, and reads
/// back `k` floats.
///
/// When `k` or `n` is zero, returns `k` zeros without touching the GPU.
///
/// # Errors
/// Forwards any error from `read_back_f32`.
pub async fn matmul_q6_k_backward_input_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    w_bytes: &[u8],
    dy: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    if k == 0 || n == 0 {
        return Ok(vec![0.0; k]);
    }
    let (device, queue) = (&ctx.device, &ctx.queue);
    let weights = write_storage(device, queue, "q6k_bwd.w", w_bytes);
    let upstream = write_storage_f32(device, queue, "q6k_bwd.dy", dy);
    let dx_bytes = (k * 4) as u64;
    let (dx, dx_staging) = make_output_pair(device, "q6k_bwd.dx", dx_bytes);

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("q6k_bwd.enc"),
    });
    matmul_q6_k_backward_input_chained(ctx, p, &mut encoder, &weights, &upstream, &dx, k, n);
    encoder.copy_buffer_to_buffer(&dx, 0, &dx_staging, 0, dx_bytes);
    queue.submit(Some(encoder.finish()));
    read_back_f32(device, &dx_staging).await
}
/// Uploads `logits`, records `cross_entropy_backward_chained`, and reads back the
/// logits gradient together with the scalar loss.
///
/// Returns `(d_logits, loss)`. Empty `logits` yields `(vec![], 0.0)` without touching
/// the GPU.
///
/// # Errors
/// Forwards any error from `read_back_f32`.
pub async fn cross_entropy_backward_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    logits: &[f32],
    target: u32,
) -> Result<(Vec<f32>, f32)> {
    // The loss is a single f32.
    const LOSS_BYTES: u64 = 4;

    if logits.is_empty() {
        return Ok((Vec::new(), 0.0));
    }
    let vocab = logits.len();
    let (device, queue) = (&ctx.device, &ctx.queue);

    let logits_gpu = write_storage_f32(device, queue, "xent.logits", logits);
    let grad_bytes = (vocab * 4) as u64;
    let (grad_gpu, grad_staging) = make_output_pair(device, "xent.dlog", grad_bytes);
    let (loss_gpu, loss_staging) = make_output_pair(device, "xent.loss", LOSS_BYTES);

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("xent.enc"),
    });
    cross_entropy_backward_chained(
        ctx,
        p,
        &mut encoder,
        &logits_gpu,
        &grad_gpu,
        &loss_gpu,
        vocab,
        target,
    );
    encoder.copy_buffer_to_buffer(&grad_gpu, 0, &grad_staging, 0, grad_bytes);
    encoder.copy_buffer_to_buffer(&loss_gpu, 0, &loss_staging, 0, LOSS_BYTES);
    queue.submit(Some(encoder.finish()));

    let d_logits = read_back_f32(device, &grad_staging).await?;
    let loss_values = read_back_f32(device, &loss_staging).await?;
    let loss = loss_values[0];
    Ok((d_logits, loss))
}
/// Runs the `softcap` pipeline over `x` with the given `cap` and reads the result back.
///
/// Bindings: 0 = `CapParams`, 1 = `x`, 2 = output. Dispatches `ceil(n / 64)`
/// workgroups along x. Empty input returns an empty vector without touching the GPU.
///
/// # Errors
/// Forwards any error from `read_back_f32`.
pub async fn softcap_cached(ctx: &WgpuCtx, p: &Pipelines, x: &[f32], cap: f32) -> Result<Vec<f32>> {
    if x.is_empty() {
        return Ok(Vec::new());
    }
    let len = x.len();
    let (device, queue) = (&ctx.device, &ctx.queue);
    let pipeline = &p.softcap;

    let uniforms = write_uniform(
        device,
        queue,
        "cap.params",
        &CapParams {
            n: len as u32,
            cap,
            _p0: 0,
            _p1: 0,
        },
    );
    let input = write_storage_f32(device, queue, "cap.x", x);
    let out_bytes = (len * 4) as u64;
    let (output, staging) = make_output_pair(device, "cap", out_bytes);

    // Binding slots follow array order: 0 = params, 1 = x, 2 = output.
    let entries: Vec<_> = [&uniforms, &input, &output]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("cap.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("cap.enc"),
    });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("cap.pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups((len as u32).div_ceil(64), 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output, 0, &staging, 0, out_bytes);
    queue.submit(Some(encoder.finish()));
    read_back_f32(device, &staging).await
}
/// Runs the `geglu` pipeline over equally sized `gate` and `up` inputs and reads the
/// result back.
///
/// Bindings: 0 = `GegluParams`, 1 = `gate`, 2 = `up`, 3 = output. Dispatches
/// `ceil(n / 64)` workgroups along x. Empty input returns an empty vector.
///
/// # Errors
/// Returns `RullamaError::Inference` when `gate` and `up` differ in length, and
/// forwards any error from `read_back_f32`.
pub async fn geglu_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    gate: &[f32],
    up: &[f32],
) -> Result<Vec<f32>> {
    let len = gate.len();
    if len != up.len() {
        return Err(RullamaError::Inference(
            "geglu: gate/up length mismatch".into(),
        ));
    }
    if len == 0 {
        return Ok(Vec::new());
    }
    let (device, queue) = (&ctx.device, &ctx.queue);
    let pipeline = &p.geglu;

    let uniforms = write_uniform(
        device,
        queue,
        "geglu.params",
        &GegluParams {
            n: len as u32,
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
    );
    let gate_gpu = write_storage_f32(device, queue, "geglu.gate", gate);
    let up_gpu = write_storage_f32(device, queue, "geglu.up", up);
    let out_bytes = (len * 4) as u64;
    let (output, staging) = make_output_pair(device, "geglu", out_bytes);

    // Binding slots follow array order: 0 = params, 1 = gate, 2 = up, 3 = output.
    let entries: Vec<_> = [&uniforms, &gate_gpu, &up_gpu, &output]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("geglu.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("geglu.enc"),
    });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("geglu.pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups((len as u32).div_ceil(64), 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output, 0, &staging, 0, out_bytes);
    queue.submit(Some(encoder.finish()));
    read_back_f32(device, &staging).await
}
/// Runs the `rope_neox` pipeline over `x` and reads the buffer back.
///
/// `x` is uploaded into a buffer that the kernel is bound to at slot 1, and that same
/// buffer is copied back afterwards. Bindings: 0 = `RopeParams`, 1 = `x`,
/// 2 = factors (or a 4-byte dummy when `factors` is `None`). The dispatch is
/// `ceil(n_heads * (rope_dims / 2) / 64)` workgroups along x.
///
/// Empty `x` returns an empty vector without touching the GPU; otherwise the data and
/// readback buffers would be zero-sized, which wgpu rejects when they are bound.
///
/// # Errors
/// Returns `RullamaError::Inference` when `x.len() != head_dim * n_heads`, or when
/// `rope_dims` is odd or exceeds `head_dim`; forwards any error from `read_back_f32`.
pub async fn rope_neox_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    x: &[f32],
    head_dim: usize,
    n_heads: usize,
    pos: usize,
    rope_dims: usize,
    base: f32,
    factors: Option<&[f32]>,
) -> Result<Vec<f32>> {
    if x.len() != head_dim * n_heads {
        return Err(RullamaError::Inference("rope: shape mismatch".into()));
    }
    if rope_dims > head_dim || !rope_dims.is_multiple_of(2) {
        return Err(RullamaError::Inference("rope: bad rope_dims".into()));
    }
    // Validation still runs first, so bad shapes are reported even for empty input.
    if x.is_empty() {
        return Ok(Vec::new());
    }
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = RopeParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        rope_dims: rope_dims as u32,
        pos: pos as u32,
        base,
        has_factors: u32::from(factors.is_some()),
        _p0: 0,
        _p1: 0,
    };
    let p_buf = write_uniform(device, queue, "rope.params", &params);
    let x_bytes = (x.len() * 4) as u64;
    // COPY_SRC as well, because this same buffer is read back after the pass.
    let x_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("rope.x"),
        size: x_bytes,
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_DST
            | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    queue.write_buffer(&x_buf, 0, bytemuck::cast_slice(x));
    let factors_buf = match factors {
        Some(f) => write_storage_f32(device, queue, "rope.factors", f),
        None => write_storage(device, queue, "rope.factors_dummy", &[0u8; 4]),
    };
    let read_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("rope.read"),
        size: x_bytes,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rope.bg"),
        layout: &p.rope_neox.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: x_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: factors_buf.as_entire_binding(),
            },
        ],
    });
    let total = (n_heads * (rope_dims / 2)) as u32;
    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rope.enc"),
    });
    {
        // Scope the pass so it ends before the copy is recorded.
        let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("rope.pass"),
            timestamp_writes: None,
        });
        cp.set_pipeline(&p.rope_neox);
        cp.set_bind_group(0, &bg, &[]);
        cp.dispatch_workgroups(total.div_ceil(64), 1, 1);
    }
    enc.copy_buffer_to_buffer(&x_buf, 0, &read_buf, 0, x_bytes);
    queue.submit(Some(enc.finish()));
    read_back_f32(device, &read_buf).await
}
/// Creates a 4-byte `STORAGE | COPY_DST` buffer with the given label.
///
/// `rmsnorm_chained` takes such a buffer as its `dummy` argument and binds it in
/// place of the optional weight.
pub fn make_dummy_storage(device: &wgpu::Device, label: &str) -> wgpu::Buffer {
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size: 4,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}
/// Records one matmul compute pass into `enc` using GPU-resident buffers.
///
/// Uploads a `MatmulParams { k, n }` uniform and binds it at 0, with `w` at 1, `x` at 2
/// and `y` at 3. Dispatches `ceil(n / 64)` workgroups along x. Nothing is submitted
/// here.
fn matmul_chained_inner(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    enc: &mut wgpu::CommandEncoder,
    pipeline: &wgpu::ComputePipeline,
    label: &str,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let uniforms = write_uniform(
        device,
        queue,
        &format!("{label}.params"),
        &MatmulParams {
            k: k as u32,
            n: n as u32,
            _p0: 0,
            _p1: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = w, 2 = x, 3 = y.
    let entries: Vec<_> = [&uniforms, w, x, y]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(&format!("{label}.bg")),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let workgroups = (n as u32).div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(label),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Records a `q4_k_matmul` pass into `enc` via `matmul_chained_inner`, labelled
/// `"q4k_chain"`.
pub fn matmul_q4_k_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let (device, queue) = (&ctx.device, &ctx.queue);
    let pipeline = &p.q4_k_matmul;
    matmul_chained_inner(device, queue, enc, pipeline, "q4k_chain", w, x, y, k, n);
}
/// Records a `q6_k_matmul` pass into `enc` via `matmul_chained_inner`, labelled
/// `"q6k_chain"`.
pub fn matmul_q6_k_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let (device, queue) = (&ctx.device, &ctx.queue);
    let pipeline = &p.q6_k_matmul;
    matmul_chained_inner(device, queue, enc, pipeline, "q6k_chain", w, x, y, k, n);
}
/// Records an `f16_matmul` pass into `enc` via `matmul_chained_inner`, labelled
/// `"f16_chain"`.
#[allow(dead_code)]
pub fn matmul_f16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let (device, queue) = (&ctx.device, &ctx.queue);
    let pipeline = &p.f16_matmul;
    matmul_chained_inner(device, queue, enc, pipeline, "f16_chain", w, x, y, k, n);
}
/// Records a `bf16_matmul` pass into `enc` via `matmul_chained_inner`, labelled
/// `"bf16_chain"`.
#[allow(dead_code)]
pub fn matmul_bf16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let (device, queue) = (&ctx.device, &ctx.queue);
    let pipeline = &p.bf16_matmul;
    matmul_chained_inner(device, queue, enc, pipeline, "bf16_chain", w, x, y, k, n);
}
/// Records an `rmsnorm` compute pass into `enc` using GPU-resident buffers.
///
/// Uploads an `RmsParams` uniform and binds it at 0, with `x` at 1, the weight at 2
/// (`dummy` is bound when `weight` is `None`, and `has_weight` is set to 0) and `y` at
/// 3. A single workgroup is dispatched. Nothing is submitted here.
pub fn rmsnorm_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    weight: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n: usize,
    eps: f32,
) {
    let pipeline = &p.rmsnorm;
    let (scale, has_weight) = match weight {
        Some(buffer) => (buffer, 1),
        None => (dummy, 0),
    };
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "rms_chain.params",
        &RmsParams {
            n: n as u32,
            eps,
            has_weight,
            _p: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = x, 2 = weight, 3 = y.
    let entries: Vec<_> = [&uniforms, x, scale, y]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rms_chain.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("rms_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(1, 1, 1);
}
/// Records a `half_residual_add` compute pass over `n` elements into `enc`.
///
/// Uploads a `ScalarNParams` uniform and binds it at 0, with `x` at 1 and `y` at 2.
/// Dispatches `ceil(n / 64)` workgroups along x. Nothing is submitted here.
pub fn half_residual_add_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n: usize,
) {
    let pipeline = &p.half_residual_add;
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "halfres.params",
        &ScalarNParams {
            n: n as u32,
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = x, 2 = y.
    let entries: Vec<_> = [&uniforms, x, y]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("halfres.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let workgroups = (n as u32).div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("halfres.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Records a `silu` compute pass over `n` elements into `enc`.
///
/// Uploads a `ScalarNParams` uniform and binds it at 0, with `x` as the only data
/// buffer at 1. Dispatches `ceil(n / 64)` workgroups along x. Nothing is submitted
/// here.
pub fn silu_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    n: usize,
) {
    let pipeline = &p.silu;
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "silu.params",
        &ScalarNParams {
            n: n as u32,
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = x.
    let entries: Vec<_> = [&uniforms, x]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("silu.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let workgroups = (n as u32).div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("silu.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Records a `glu_split` compute pass into `enc`.
///
/// Uploads a `GluSplitParams { seq, inner }` uniform and binds it at 0, with `x` at 1
/// and `y` at 2. Dispatches `ceil(seq * inner / 64)` workgroups along x. Nothing is
/// submitted here.
pub fn glu_split_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    seq: usize,
    inner: usize,
) {
    let pipeline = &p.glu_split;
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "glu.params",
        &GluSplitParams {
            seq: seq as u32,
            inner: inner as u32,
            _p0: 0,
            _p1: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = x, 2 = y.
    let entries: Vec<_> = [&uniforms, x, y]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("glu.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let item_count = (seq * inner) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("glu.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(item_count.div_ceil(64), 1, 1);
}
/// Records a `depthwise_conv1d` compute pass into `enc`.
///
/// Uploads a `DepthwiseConv1dParams { seq, channels, kernel }` uniform and binds it at
/// 0, with `x` at 1, `w` at 2 and `y` at 3. Dispatches `ceil(seq * channels / 64)`
/// workgroups along x. Nothing is submitted here.
pub fn depthwise_conv1d_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    w: &wgpu::Buffer,
    y: &wgpu::Buffer,
    seq: usize,
    channels: usize,
    kernel: usize,
) {
    let pipeline = &p.depthwise_conv1d;
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "dwconv.params",
        &DepthwiseConv1dParams {
            seq: seq as u32,
            channels: channels as u32,
            kernel: kernel as u32,
            _p: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = x, 2 = w, 3 = y.
    let entries: Vec<_> = [&uniforms, x, w, y]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("dwconv.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let item_count = (seq * channels) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("dwconv.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(item_count.div_ceil(64), 1, 1);
}
/// Records a `pos_embed_add` compute pass into `enc`.
///
/// Uploads a `PosEmbedAddParams { n_patches, hidden_size, pos_size }` uniform and binds
/// it at 0, with `hidden` at 1, `pos_embd` at 2, `pos_x` at 3 and `pos_y` at 4.
/// Dispatches `ceil(n_patches * hidden_size / 64)` workgroups along x. Nothing is
/// submitted here.
pub fn pos_embed_add_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    hidden: &wgpu::Buffer,
    pos_embd: &wgpu::Buffer,
    pos_x: &wgpu::Buffer,
    pos_y: &wgpu::Buffer,
    n_patches: usize,
    hidden_size: usize,
    pos_size: usize,
) {
    let pipeline = &p.pos_embed_add;
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "posembed.params",
        &PosEmbedAddParams {
            n_patches: n_patches as u32,
            hidden_size: hidden_size as u32,
            pos_size: pos_size as u32,
            _pad: 0,
        },
    );
    // Binding slots follow array order:
    // 0 = params, 1 = hidden, 2 = pos_embd, 3 = pos_x, 4 = pos_y.
    let entries: Vec<_> = [&uniforms, hidden, pos_embd, pos_x, pos_y]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("posembed.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let item_count = (n_patches * hidden_size) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("posembed.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(item_count.div_ceil(64), 1, 1);
}
/// Records a vision-attention pass into `enc`, choosing the kernel from `head_dim` and
/// `n_patches`.
///
/// Selection when `head_dim <= 64`:
/// - `n_patches >= 8`: the subgroup flash variant if `p.vision_attention_flash_subgroup`
///   is present, otherwise the q8 flash variant;
/// - `4 <= n_patches < 8`: the q4 flash variant;
/// - fewer than 4 patches: the plain flash variant.
///
/// For `head_dim > 64` the `vision_attention` pipeline is recorded directly: a
/// `VisionAttnParams` uniform at binding 0, then `q`, `k`, `v`, `out` at 1..=4, with an
/// `(n_patches, n_heads, 1)` dispatch. Nothing is submitted here.
pub fn vision_attention_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    if head_dim <= 64 {
        match (n_patches, p.vision_attention_flash_subgroup.as_ref()) {
            (8.., Some(sub)) => vision_attention_flash_subgroup_chained(
                ctx, p, sub, enc, q, k, v, out, head_dim, n_heads, n_patches,
            ),
            (8.., None) => vision_attention_flash_q8_chained(
                ctx, p, enc, q, k, v, out, head_dim, n_heads, n_patches,
            ),
            (4..=7, _) => vision_attention_flash_q4_chained(
                ctx, p, enc, q, k, v, out, head_dim, n_heads, n_patches,
            ),
            _ => vision_attention_flash_chained(
                ctx, p, enc, q, k, v, out, head_dim, n_heads, n_patches,
            ),
        }
        return;
    }

    let pipeline = &p.vision_attention;
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "vattn.params",
        &VisionAttnParams {
            head_dim: head_dim as u32,
            n_heads: n_heads as u32,
            n_patches: n_patches as u32,
            _pad: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = q, 2 = k, 3 = v, 4 = out.
    let entries: Vec<_> = [&uniforms, q, k, v, out]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattn.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattn.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(n_patches as u32, n_heads as u32, 1);
}
/// Records a `vision_attention_flash` compute pass into `enc`.
///
/// Uploads a `VisionAttnParams` uniform and binds it at 0, with `q`, `k`, `v`, `out` at
/// 1..=4. The dispatch is `(n_patches, n_heads, 1)`. Nothing is submitted here.
pub fn vision_attention_flash_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let pipeline = &p.vision_attention_flash;
    let uniforms = write_uniform(
        &ctx.device,
        &ctx.queue,
        "vattnf.params",
        &VisionAttnParams {
            head_dim: head_dim as u32,
            n_heads: n_heads as u32,
            n_patches: n_patches as u32,
            _pad: 0,
        },
    );
    // Binding slots follow array order: 0 = params, 1 = q, 2 = k, 3 = v, 4 = out.
    let entries: Vec<_> = [&uniforms, q, k, v, out]
        .into_iter()
        .zip(0u32..)
        .map(|(buffer, binding)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnf.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &entries,
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnf.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(n_patches as u32, n_heads as u32, 1);
}
/// Records the `vision_attention_flash_q16` compute pass into `enc`.
///
/// A `VisionAttnParams` value is handed to `write_uniform` and the returned buffer is bound
/// at slot 0, with `q`, `k`, `v`, `out` at slots 1-4. The x extent of the dispatch is
/// `ceil(n_patches / 16)` and the y extent is `n_heads`.
pub fn vision_attention_flash_q16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.vision_attention_flash_q16;
    let (heads, patches) = (n_heads as u32, n_patches as u32);
    let uniforms = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: heads,
        n_patches: patches,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "vattnq16.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnq16.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q),
            whole(2, k),
            whole(3, v),
            whole(4, out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnq16.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // Patches are grouped sixteen to a workgroup along x.
    pass.dispatch_workgroups(patches.div_ceil(16), heads, 1);
}
/// Records the `vision_attention_flash_q8` compute pass into `enc`.
///
/// A `VisionAttnParams` value is handed to `write_uniform` and the returned buffer is bound
/// at slot 0, with `q`, `k`, `v`, `out` at slots 1-4. The x extent of the dispatch is
/// `ceil(n_patches / 8)` and the y extent is `n_heads`.
pub fn vision_attention_flash_q8_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.vision_attention_flash_q8;
    let (heads, patches) = (n_heads as u32, n_patches as u32);
    let uniforms = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: heads,
        n_patches: patches,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "vattnq8.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnq8.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q),
            whole(2, k),
            whole(3, v),
            whole(4, out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnq8.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // Patches are grouped eight to a workgroup along x.
    pass.dispatch_workgroups(patches.div_ceil(8), heads, 1);
}
/// Records a vision-attention compute pass into `enc` using the caller-supplied pipeline `sub`.
///
/// `p` is accepted but not read. A `VisionAttnParams` value is handed to `write_uniform`
/// and the returned buffer is bound at slot 0, with `q`, `k`, `v`, `out` at slots 1-4.
/// The dispatch grid is `ceil(n_patches / 8)` x `n_heads` x 1 workgroups.
pub fn vision_attention_flash_subgroup_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    sub: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    // The pipeline table is part of the signature only; the kernel comes in via `sub`.
    let _ = p;
    let (heads, patches) = (n_heads as u32, n_patches as u32);
    let uniforms = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: heads,
        n_patches: patches,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "vattnSub.params", &uniforms);
    let layout = sub.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSub.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q),
            whole(2, k),
            whole(3, v),
            whole(4, out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSub.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(sub);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(patches.div_ceil(8), heads, 1);
}
/// Runs `transpose_chained` with the `transpose_phd_to_hpd` pipeline and the label
/// prefix `tposePHDtoHPD`; all other arguments are forwarded unchanged.
///
/// NOTE(review): the name suggests a [patch, head, dim] -> [head, patch, dim] reorder;
/// confirm against the shader.
pub fn transpose_phd_to_hpd_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    src: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    n_patches: usize,
    n_heads: usize,
    head_dim: usize,
) {
    let pipeline = &p.transpose_phd_to_hpd;
    transpose_chained(
        ctx, pipeline, "tposePHDtoHPD", enc, src, dst, n_patches, n_heads, head_dim,
    )
}
/// Runs `transpose_chained` with the `transpose_hpd_to_phd` pipeline and the label
/// prefix `tposeHPDtoPHD`; all other arguments are forwarded unchanged.
///
/// NOTE(review): the name suggests a [head, patch, dim] -> [patch, head, dim] reorder;
/// confirm against the shader.
pub fn transpose_hpd_to_phd_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    src: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    n_patches: usize,
    n_heads: usize,
    head_dim: usize,
) {
    let pipeline = &p.transpose_hpd_to_phd;
    transpose_chained(
        ctx, pipeline, "tposeHPDtoPHD", enc, src, dst, n_patches, n_heads, head_dim,
    )
}
/// Shared body of the two public transpose entry points: records one compute pass for
/// `pipe` into `enc`.
///
/// `label` prefixes the debug names (`{label}.params`, `{label}.bg`, `{label}.pass`).
/// A `TransposeParams` value is handed to `write_uniform`; the returned buffer is bound
/// at slot 0, `src` at slot 1 and `dst` at slot 2. The dispatch is
/// `ceil(n_patches * n_heads * head_dim / 64)` workgroups along x.
fn transpose_chained(
    ctx: &WgpuCtx,
    pipe: &wgpu::ComputePipeline,
    label: &str,
    enc: &mut wgpu::CommandEncoder,
    src: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    n_patches: usize,
    n_heads: usize,
    head_dim: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let uniforms = TransposeParams {
        n_patches: n_patches as u32,
        n_heads: n_heads as u32,
        head_dim: head_dim as u32,
        _pad: 0,
    };
    let params_label = format!("{label}.params");
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, &params_label, &uniforms);

    let bg_label = format!("{label}.bg");
    let layout = pipe.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(&bg_label),
        layout: &layout,
        entries: &[whole(0, &uniform_buf), whole(1, src), whole(2, dst)],
    });

    let pass_label = format!("{label}.pass");
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(&pass_label),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipe);
    pass.set_bind_group(0, &bind_group, &[]);
    let element_count = (n_patches * n_heads * head_dim) as u32;
    pass.dispatch_workgroups(element_count.div_ceil(64), 1, 1);
}
/// Forwards every argument unchanged to `vision_attention_flash_sub_hpd_chained`.
///
/// Nothing else happens here, so the bind layout, dispatch shape and debug labels are
/// those of the callee; only the pipeline passed as `pipe` differs between callers.
pub fn vision_attention_flash_sub_hpd_f16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    vision_attention_flash_sub_hpd_chained(
        ctx,
        p,
        pipe,
        enc,
        q,
        k,
        v,
        out,
        head_dim,
        n_heads,
        n_patches,
    )
}
/// Forwards every argument unchanged to `vision_attention_flash_sub_hpd_chained`.
///
/// Nothing else happens here, so the bind layout, dispatch shape and debug labels are
/// those of the callee; only the pipeline passed as `pipe` differs between callers.
pub fn vision_attention_flash_hpd_f16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    vision_attention_flash_sub_hpd_chained(
        ctx,
        p,
        pipe,
        enc,
        q,
        k,
        v,
        out,
        head_dim,
        n_heads,
        n_patches,
    )
}
/// Records a vision-attention compute pass into `enc` using the caller-supplied `pipe`.
///
/// `p` is accepted but not read. A `VisionAttnParams` value is handed to `write_uniform`
/// and the returned buffer is bound at slot 0, with `q`, `k`, `v`, `out` at slots 1-4.
/// The dispatch grid is `ceil(n_patches / 16)` x `n_heads` x 1 workgroups.
pub fn vision_attention_flash_sub_hpd_f16_q16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    // The pipeline table is part of the signature only; the kernel comes in via `pipe`.
    let _ = p;
    let (heads, patches) = (n_heads as u32, n_patches as u32);
    let uniforms = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: heads,
        n_patches: patches,
        _pad: 0,
    };
    let uniform_buf = write_uniform(
        &ctx.device,
        &ctx.queue,
        "vattnSubHPDQ16.params",
        &uniforms,
    );
    let layout = pipe.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSubHPDQ16.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q),
            whole(2, k),
            whole(3, v),
            whole(4, out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSubHPDQ16.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipe);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(patches.div_ceil(16), heads, 1);
}
/// Records a vision-attention compute pass into `enc` using the caller-supplied `pipe`.
///
/// `p` is accepted but not read. A `VisionAttnParams` value is handed to `write_uniform`
/// and the returned buffer is bound at slot 0, with `q`, `k`, `v`, `out` at slots 1-4.
/// The dispatch grid is `ceil(n_patches / 8)` x `n_heads` x 1 workgroups.
pub fn vision_attention_flash_sub_hpd_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    // The pipeline table is part of the signature only; the kernel comes in via `pipe`.
    let _ = p;
    let (heads, patches) = (n_heads as u32, n_patches as u32);
    let uniforms = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: heads,
        n_patches: patches,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "vattnSubHPD.params", &uniforms);
    let layout = pipe.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSubHPD.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q),
            whole(2, k),
            whole(3, v),
            whole(4, out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSubHPD.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipe);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(patches.div_ceil(8), heads, 1);
}
/// Records a vision-attention compute pass into `enc` using the caller-supplied `pipe`.
///
/// `p` is accepted but not read. A `VisionAttnParams` value is handed to `write_uniform`
/// and the returned buffer is bound at slot 0, with `q`, `k`, `v`, `out` at slots 1-4.
/// The dispatch grid is `ceil(n_patches / 8)` x `n_heads` x 1 workgroups.
pub fn vision_attention_flash_sub_t64_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    // The pipeline table is part of the signature only; the kernel comes in via `pipe`.
    let _ = p;
    let (heads, patches) = (n_heads as u32, n_patches as u32);
    let uniforms = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: heads,
        n_patches: patches,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "vattnSubT64.params", &uniforms);
    let layout = pipe.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSubT64.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q),
            whole(2, k),
            whole(3, v),
            whole(4, out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSubT64.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipe);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(patches.div_ceil(8), heads, 1);
}
/// Records the `vision_attention_flash_q4` compute pass into `enc`.
///
/// A `VisionAttnParams` value is handed to `write_uniform` and the returned buffer is bound
/// at slot 0, with `q`, `k`, `v`, `out` at slots 1-4. The x extent of the dispatch is
/// `ceil(n_patches / 4)` and the y extent is `n_heads`.
pub fn vision_attention_flash_q4_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.vision_attention_flash_q4;
    let (heads, patches) = (n_heads as u32, n_patches as u32);
    let uniforms = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: heads,
        n_patches: patches,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "vattnq4.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnq4.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q),
            whole(2, k),
            whole(3, v),
            whole(4, out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnq4.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // Patches are grouped four to a workgroup along x.
    pass.dispatch_workgroups(patches.div_ceil(4), heads, 1);
}
/// Records the `block_local_attention` compute pass into `enc`.
///
/// Debug builds assert `head_dim == 128`, because the WGSL kernel is hard-coded to that
/// width. All scalar arguments are narrowed to `u32` (except `logit_cap`, passed as-is)
/// and packed into a `BlockLocalAttnParams` value that is handed to `write_uniform`.
/// Slots: 0 = params, 1 = `q_pad`, 2 = `k_padded`, 3 = `v_padded`, 4 = `pos_proj`,
/// 5 = `attn_out`. The dispatch grid is `padded_len` x `n_heads` x 1 workgroups.
pub fn block_local_attention_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q_pad: &wgpu::Buffer,
    k_padded: &wgpu::Buffer,
    v_padded: &wgpu::Buffer,
    pos_proj: &wgpu::Buffer,
    attn_out: &wgpu::Buffer,
    seq: usize,
    padded_len: usize,
    hidden: usize,
    n_heads: usize,
    head_dim: usize,
    chunk_size: usize,
    context_size: usize,
    max_span: usize,
    max_past: usize,
    max_future: usize,
    pad_left: usize,
    logit_cap: f32,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    debug_assert_eq!(
        head_dim, 128,
        "block_local_attention.wgsl is hard-coded to head_dim=128"
    );

    let pipeline = &p.block_local_attention;
    let (rows, heads) = (padded_len as u32, n_heads as u32);
    let uniforms = BlockLocalAttnParams {
        seq: seq as u32,
        padded_len: rows,
        hidden: hidden as u32,
        n_heads: heads,
        head_dim: head_dim as u32,
        chunk_size: chunk_size as u32,
        context_size: context_size as u32,
        max_span: max_span as u32,
        max_past: max_past as u32,
        max_future: max_future as u32,
        pad_left: pad_left as u32,
        logit_cap,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "blattn.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("blattn.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, q_pad),
            whole(2, k_padded),
            whole(3, v_padded),
            whole(4, pos_proj),
            whole(5, attn_out),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("blattn.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(rows, heads, 1);
}
/// Parameter block that `scale_per_inner_dim_chained` fills in and passes to
/// `write_uniform`.
///
/// `#[repr(C)]` with four `u32` fields (16 bytes). Field order is part of the layout,
/// so it must not be rearranged.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct ScalePerInnerDimParams {
/// Caller-supplied `n`, narrowed to `u32`; also drives the `ceil(n / 64)` dispatch.
n: u32,
/// Caller-supplied `inner_dim`, narrowed to `u32`.
inner_dim: u32,
/// Always written as 0; presumably padding for the uniform layout (confirm against the shader).
_p0: u32,
/// Always written as 0; presumably padding for the uniform layout (confirm against the shader).
_p1: u32,
}
/// Parameter block that `add_bias_batched_chained` fills in and passes to
/// `write_uniform`.
///
/// `#[repr(C)]` with four `u32` fields (16 bytes). Field order is part of the layout,
/// so it must not be rearranged.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct AddBiasBatchedParams {
/// Caller-supplied `n`, narrowed to `u32`.
n: u32,
/// Caller-supplied `batch`, narrowed to `u32`; the dispatch covers `n * batch` items.
batch: u32,
/// Always written as 0; presumably padding for the uniform layout (confirm against the shader).
_p0: u32,
/// Always written as 0; presumably padding for the uniform layout (confirm against the shader).
_p1: u32,
}
/// Records the `scale_per_inner_dim` compute pass into `enc`.
///
/// A `ScalePerInnerDimParams` value (`n`, `inner_dim`, zero padding) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0, `x` at slot 1 and `s` at
/// slot 2. The dispatch is `ceil(n / 64)` workgroups along x.
pub fn scale_per_inner_dim_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    s: &wgpu::Buffer,
    n: usize,
    inner_dim: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.scale_per_inner_dim;
    let element_count = n as u32;
    let uniforms = ScalePerInnerDimParams {
        n: element_count,
        inner_dim: inner_dim as u32,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "scale_pd.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("scale_pd.bg"),
        layout: &layout,
        entries: &[whole(0, &uniform_buf), whole(1, x), whole(2, s)],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("scale_pd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(element_count.div_ceil(64), 1, 1);
}
/// Records the `add_bias_batched` compute pass into `enc`.
///
/// An `AddBiasBatchedParams` value (`n`, `batch`, zero padding) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0, `y` at slot 1 and `bias` at
/// slot 2. The dispatch is `ceil(n * batch / 64)` workgroups along x.
pub fn add_bias_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    y: &wgpu::Buffer,
    bias: &wgpu::Buffer,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.add_bias_batched;
    let uniforms = AddBiasBatchedParams {
        n: n as u32,
        batch: batch as u32,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "addbias.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("addbias.bg"),
        layout: &layout,
        entries: &[whole(0, &uniform_buf), whole(1, y), whole(2, bias)],
    });
    // The product is formed in `usize` and only then narrowed.
    let element_count = (n * batch) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("addbias.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(element_count.div_ceil(64), 1, 1);
}
/// Returns `true` when `k >= 16` and both `n` and `batch` are at least 8.
///
/// The batched matmul dispatchers consult this last, after the v3 and v2 gates, before
/// falling back to the untiled pipeline.
fn use_tiled_batched(k: usize, n: usize, batch: usize) -> bool {
    const MIN_K: usize = 16;
    const MIN_EDGE: usize = 8;
    k >= MIN_K && n.min(batch) >= MIN_EDGE
}
/// Returns `true` when `k`, `n` and `batch` are all at least 16.
///
/// The batched matmul dispatchers consult this after the v3 gate and before the v1 gate.
fn use_tiled_batched_v2(k: usize, n: usize, batch: usize) -> bool {
    const MIN_K: usize = 16;
    const MIN_EDGE: usize = 16;
    k >= MIN_K && n.min(batch) >= MIN_EDGE
}
/// Returns `true` when `k >= 16` and both `n` and `batch` are at least 32.
///
/// The batched matmul dispatchers consult this first.
fn use_tiled_batched_v3(k: usize, n: usize, batch: usize) -> bool {
    const MIN_K: usize = 16;
    const MIN_EDGE: usize = 32;
    k >= MIN_K && n.min(batch) >= MIN_EDGE
}
/// Picks a bf16 batched-matmul kernel from `k`, `n`, `batch` and records it into `enc`.
///
/// Selection order:
/// 1. `use_tiled_batched_v3`: the optional `bf16_matmul_batched_tiled_v3_f16lds` pipeline
///    when present, otherwise the plain tiled v3 kernel;
/// 2. `use_tiled_batched_v2`: the tiled v2 kernel;
/// 3. `use_tiled_batched`: the tiled v1 kernel;
/// 4. otherwise the untiled `bf16_matmul_batched` pipeline is recorded here, with a
///    `BatchedMatmulParams` uniform at slot 0, `w`/`x`/`y` at slots 1-3 and a
///    `ceil(n / 64)` x `batch` x 1 dispatch.
pub fn matmul_bf16_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    if use_tiled_batched_v3(k, n, batch) {
        match p.bf16_matmul_batched_tiled_v3_f16lds.as_ref() {
            Some(pipe_f) => matmul_bf16_batched_tiled_v3_f16lds_chained(
                ctx, p, pipe_f, enc, w, x, y, k, n, batch,
            ),
            None => matmul_bf16_batched_tiled_v3_chained(ctx, p, enc, w, x, y, k, n, batch),
        }
        return;
    }
    if use_tiled_batched_v2(k, n, batch) {
        matmul_bf16_batched_tiled_v2_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }
    if use_tiled_batched(k, n, batch) {
        matmul_bf16_batched_tiled_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }

    // Untiled fallback.
    let pipeline = &p.bf16_matmul_batched;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "bf16bmm.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmm.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmm.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(64), rows, 1);
}
/// Picks an f16 batched-matmul kernel from `k`, `n`, `batch` and records it into `enc`.
///
/// Selection order:
/// 1. `use_tiled_batched_v3`: the optional `f16_matmul_batched_tiled_v3_f16lds` pipeline
///    when present, otherwise the plain tiled v3 kernel;
/// 2. `use_tiled_batched_v2`: the tiled v2 kernel;
/// 3. `use_tiled_batched`: the tiled v1 kernel;
/// 4. otherwise the untiled `f16_matmul_batched` pipeline is recorded here, with a
///    `BatchedMatmulParams` uniform at slot 0, `w`/`x`/`y` at slots 1-3 and a
///    `ceil(n / 64)` x `batch` x 1 dispatch.
pub fn matmul_f16_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    if use_tiled_batched_v3(k, n, batch) {
        match p.f16_matmul_batched_tiled_v3_f16lds.as_ref() {
            Some(pipe_f) => matmul_f16_batched_tiled_v3_f16lds_chained(
                ctx, p, pipe_f, enc, w, x, y, k, n, batch,
            ),
            None => matmul_f16_batched_tiled_v3_chained(ctx, p, enc, w, x, y, k, n, batch),
        }
        return;
    }
    if use_tiled_batched_v2(k, n, batch) {
        matmul_f16_batched_tiled_v2_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }
    if use_tiled_batched(k, n, batch) {
        matmul_f16_batched_tiled_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }

    // Untiled fallback.
    let pipeline = &p.f16_matmul_batched;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "f16bmm.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmm.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmm.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(64), rows, 1);
}
/// Records the `f16_matmul_batched_tiled` compute pass into `enc`.
///
/// A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero pad) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0 and `w`, `x`, `y` at slots
/// 1-3. The dispatch grid is `ceil(n / 8)` x `ceil(batch / 8)` x 1 workgroups.
pub fn matmul_f16_batched_tiled_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.f16_matmul_batched_tiled;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "f16bmmt.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(8), rows.div_ceil(8), 1);
}
/// Records the `bf16_matmul_batched_tiled` compute pass into `enc`.
///
/// A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero pad) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0 and `w`, `x`, `y` at slots
/// 1-3. The dispatch grid is `ceil(n / 8)` x `ceil(batch / 8)` x 1 workgroups.
pub fn matmul_bf16_batched_tiled_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.bf16_matmul_batched_tiled;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "bf16bmmt.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(8), rows.div_ceil(8), 1);
}
/// Records the `f16_matmul_batched_tiled_v2` compute pass into `enc`.
///
/// A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero pad) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0 and `w`, `x`, `y` at slots
/// 1-3. The dispatch grid is `ceil(n / 16)` x `ceil(batch / 16)` x 1 workgroups.
pub fn matmul_f16_batched_tiled_v2_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.f16_matmul_batched_tiled_v2;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "f16bmmt2.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt2.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt2.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(16), rows.div_ceil(16), 1);
}
/// Records the `f16_matmul_batched_tiled_v3` compute pass into `enc`.
///
/// A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero pad) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0 and `w`, `x`, `y` at slots
/// 1-3. The dispatch grid is `ceil(n / 32)` x `ceil(batch / 32)` x 1 workgroups.
pub fn matmul_f16_batched_tiled_v3_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.f16_matmul_batched_tiled_v3;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "f16bmmt3.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt3.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt3.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(32), rows.div_ceil(32), 1);
}
/// Records a batched-matmul compute pass into `enc` using the caller-supplied `pipe`.
///
/// `p` is accepted but not read. A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero
/// pad) is handed to `write_uniform`; the returned buffer is bound at slot 0 and `w`,
/// `x`, `y` at slots 1-3. The dispatch grid is `ceil(n / 32)` x `ceil(batch / 32)` x 1.
pub fn matmul_bf16_batched_tiled_v3_f16lds_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    // The pipeline table is part of the signature only; the kernel comes in via `pipe`.
    let _ = p;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "bf16bmmt3f.params", &uniforms);
    let layout = pipe.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt3f.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt3f.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipe);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(32), rows.div_ceil(32), 1);
}
/// Records a batched-matmul compute pass into `enc` using the caller-supplied `pipe`.
///
/// `p` is accepted but not read. A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero
/// pad) is handed to `write_uniform`; the returned buffer is bound at slot 0 and `w`,
/// `x`, `y` at slots 1-3. The dispatch grid is `ceil(n / 32)` x `ceil(batch / 32)` x 1.
pub fn matmul_f16_batched_tiled_v3_f16lds_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    // The pipeline table is part of the signature only; the kernel comes in via `pipe`.
    let _ = p;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "f16bmmt3f.params", &uniforms);
    let layout = pipe.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt3f.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt3f.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipe);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(32), rows.div_ceil(32), 1);
}
/// Records the `f16_matmul_batched_tiled_v4` compute pass into `enc`.
///
/// A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero pad) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0 and `w`, `x`, `y` at slots
/// 1-3. Unlike the square-tiled variants, the dispatch grid is
/// `ceil(n / 32)` x `ceil(batch / 64)` x 1 workgroups.
pub fn matmul_f16_batched_tiled_v4_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.f16_matmul_batched_tiled_v4;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "f16bmmt4.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt4.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt4.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(32), rows.div_ceil(64), 1);
}
/// Records the `bf16_matmul_batched_tiled_v3` compute pass into `enc`.
///
/// A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero pad) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0 and `w`, `x`, `y` at slots
/// 1-3. The dispatch grid is `ceil(n / 32)` x `ceil(batch / 32)` x 1 workgroups.
pub fn matmul_bf16_batched_tiled_v3_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.bf16_matmul_batched_tiled_v3;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "bf16bmmt3.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt3.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt3.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(32), rows.div_ceil(32), 1);
}
/// Records the `bf16_matmul_batched_tiled_v2` compute pass into `enc`.
///
/// A `BatchedMatmulParams` value (`k`, `n`, `batch`, zero pad) is handed to
/// `write_uniform`; the returned buffer is bound at slot 0 and `w`, `x`, `y` at slots
/// 1-3. The dispatch grid is `ceil(n / 16)` x `ceil(batch / 16)` x 1 workgroups.
pub fn matmul_bf16_batched_tiled_v2_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.bf16_matmul_batched_tiled_v2;
    let (cols, rows) = (n as u32, batch as u32);
    let uniforms = BatchedMatmulParams {
        k: k as u32,
        n: cols,
        batch: rows,
        _pad: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "bf16bmmt2.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt2.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt2.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(cols.div_ceil(16), rows.div_ceil(16), 1);
}
/// Records the `conv2d` compute pass into `enc`.
///
/// The twelve geometry arguments are narrowed to `u32` and packed into a `Conv2dParams`
/// value (`pad_h`/`pad_w` land in its `p_h`/`p_w` fields), which is handed to
/// `write_uniform`. The returned buffer is bound at slot 0 and `w`, `x`, `y` at slots 1-3.
/// The dispatch is `ceil(out_c * out_h * out_w / 64)` workgroups along x.
pub fn conv2d_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    in_c: usize,
    in_h: usize,
    in_w: usize,
    out_c: usize,
    out_h: usize,
    out_w: usize,
    k_h: usize,
    k_w: usize,
    s_h: usize,
    s_w: usize,
    pad_h: usize,
    pad_w: usize,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.conv2d;
    let uniforms = Conv2dParams {
        in_c: in_c as u32,
        in_h: in_h as u32,
        in_w: in_w as u32,
        out_c: out_c as u32,
        out_h: out_h as u32,
        out_w: out_w as u32,
        k_h: k_h as u32,
        k_w: k_w as u32,
        s_h: s_h as u32,
        s_w: s_w as u32,
        p_h: pad_h as u32,
        p_w: pad_w as u32,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "conv2d.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("conv2d.bg"),
        layout: &layout,
        entries: &[
            whole(0, &uniform_buf),
            whole(1, w),
            whole(2, x),
            whole(3, y),
        ],
    });
    // The product is formed in `usize` and only then narrowed.
    let output_count = (out_c * out_h * out_w) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("conv2d.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(output_count.div_ceil(64), 1, 1);
}
/// Records the `clamp` compute pass into `enc`.
///
/// A `ClampParams` value (`n`, `lo`, `hi`, zero pad) is handed to `write_uniform`; the
/// returned buffer is bound at slot 0 and `x` at slot 1. The workgroup grid comes from
/// `dispatch_dims_1d(n as u32, 64)`.
pub fn clamp_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    n: usize,
    lo: f32,
    hi: f32,
) {
    // Binds the full extent of `buffer` at `binding`.
    fn whole(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let pipeline = &p.clamp;
    let element_count = n as u32;
    let uniforms = ClampParams {
        n: element_count,
        lo,
        hi,
        _p: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "clamp.params", &uniforms);
    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("clamp.bg"),
        layout: &layout,
        entries: &[whole(0, &uniform_buf), whole(1, x)],
    });
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("clamp.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    let (groups_x, groups_y, groups_z) = dispatch_dims_1d(element_count, 64);
    pass.dispatch_workgroups(groups_x, groups_y, groups_z);
}
/// Records a `quick_geglu` compute pass into `enc`.
///
/// Binds a `GegluParams` uniform carrying `n` at binding 0, then `gate`, `up` and `y` at
/// bindings 1, 2 and 3, and dispatches the grid returned by `dispatch_dims_1d(n as u32, 64)`.
pub fn quick_geglu_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    gate: &wgpu::Buffer,
    up: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n: usize,
) {
    let count = n as u32;
    let uniforms = GegluParams {
        n: count,
        _p0: 0,
        _p1: 0,
        _p2: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "qgeglu.params", &uniforms);

    let pipeline = &p.quick_geglu;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        gate.as_entire_binding(),
        up.as_entire_binding(),
        y.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("qgeglu.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("qgeglu.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    let (groups_x, groups_y, groups_z) = dispatch_dims_1d(count, 64);
    pass.dispatch_workgroups(groups_x, groups_y, groups_z);
}
/// Records an `avg_pool2d` compute pass into `enc`.
///
/// The output size is `in_h / k` by `in_w / k` (integer division, so a trailing remainder
/// smaller than `k` is not counted). An `AvgPool2dParams` uniform is bound at binding 0,
/// `x` at 1 and `y` at 2, and `ceil(out_h * out_w * channels / 64)` workgroups are
/// dispatched along x. The encoder is neither finished nor submitted here.
///
/// # Panics
///
/// Panics if `k` is zero.
pub fn avg_pool2d_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    in_h: usize,
    in_w: usize,
    channels: usize,
    k: usize,
) {
    // State the precondition explicitly: without it the divisions below abort with a
    // bare "attempt to divide by zero" that does not name the offending argument.
    assert!(k != 0, "avg_pool2d: k must be non-zero");

    let device = &ctx.device;
    let queue = &ctx.queue;
    let out_h = in_h / k;
    let out_w = in_w / k;
    let params = AvgPool2dParams {
        in_h: in_h as u32,
        in_w: in_w as u32,
        out_h: out_h as u32,
        out_w: out_w as u32,
        channels: channels as u32,
        k: k as u32,
        _p0: 0,
        _p1: 0,
    };
    let p_buf = write_uniform(device, queue, "pool2d.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("pool2d.bg"),
        layout: &p.avg_pool2d.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: y.as_entire_binding(),
            },
        ],
    });
    // One count per output element; 64 is presumably the shader's workgroup size.
    let total = (out_h * out_w * channels) as u32;
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("pool2d.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.avg_pool2d);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups(total.div_ceil(64), 1, 1);
}
/// Records a `rope_2d` compute pass into `enc`.
///
/// Binds a `Rope2dParams` uniform (binding 0), `x` (1), `pos_x` (2) and `pos_y` (3), then
/// dispatches `ceil(n_patches * n_heads * (head_dim / 2) / 64)` workgroups along x.
/// `x` is the only non-position data buffer bound, so the rotation is presumably applied
/// in place — confirm against the shader.
pub fn rope_2d_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    pos_x: &wgpu::Buffer,
    pos_y: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
    base: f32,
) {
    let uniforms = Rope2dParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        base,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "rope2d.params", &uniforms);

    let pipeline = &p.rope_2d;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        x.as_entire_binding(),
        pos_x.as_entire_binding(),
        pos_y.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rope2d.bg"),
        layout: &layout,
        entries: &entries,
    });

    // `head_dim / 2` is an integer division, so an odd `head_dim` rounds down.
    let work_items = (n_patches * n_heads * (head_dim / 2)) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("rope2d.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(work_items.div_ceil(64), 1, 1);
}
/// Records an `rmsnorm_per_row` compute pass into `enc`.
///
/// Binds an `RmsPerRowParams` uniform (binding 0), `x` (1), the weight buffer (2) and
/// `y` (3), then dispatches `n_rows` workgroups along x.
///
/// When `weight` is `None`, `dummy` is bound in its place and `has_weight` is written
/// as 0 in the uniform; otherwise `has_weight` is 1.
pub fn rmsnorm_per_row_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    weight: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n_rows: usize,
    row_dim: usize,
    eps: f32,
) {
    // Binding 2 must always be populated, so fall back to the placeholder buffer.
    let (weight_buf, has_weight) = match weight {
        Some(buf) => (buf, 1u32),
        None => (dummy, 0u32),
    };
    let row_count = n_rows as u32;
    let uniforms = RmsPerRowParams {
        n_rows: row_count,
        row_dim: row_dim as u32,
        eps,
        has_weight,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "rmspr_chain.params", &uniforms);

    let pipeline = &p.rmsnorm_per_row;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        x.as_entire_binding(),
        weight_buf.as_entire_binding(),
        y.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rmspr_chain.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("rmspr_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(row_count, 1, 1);
}
/// Uniform block that `adam_step_chained` uploads for the `adam_step` pipeline.
///
/// `#[repr(C)]` with twelve 4-byte fields (48 bytes, i.e. three 16-byte rows). The
/// `_pad*` fields carry no data; `adam_step_chained` always writes them as zero.
///
/// NOTE(review): the field order presumably mirrors the shader-side struct — confirm
/// against the WGSL before reordering or resizing anything here.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct AdamParams {
/// Element count, written as `n as u32` by the caller.
n: u32,
/// Copied from `AdamConfig::step`.
step: u32,
// Padding: fills the first 16-byte row.
_pad0: u32,
_pad1: u32,
/// Copied from `AdamConfig::lr`.
lr: f32,
/// Copied from `AdamConfig::beta1`.
beta1: f32,
/// Copied from `AdamConfig::beta2`.
beta2: f32,
/// Copied from `AdamConfig::eps`.
eps: f32,
/// Copied from `AdamConfig::weight_decay`.
weight_decay: f32,
// Padding: fills the third 16-byte row.
_pad2: f32,
_pad3: f32,
_pad4: f32,
}
/// Hyper-parameters handed to `adam_step_chained`.
///
/// Every field is copied unchanged into the uniform block that the `adam_step` pipeline
/// reads; this type performs no validation of its own.
#[derive(Clone, Copy, Debug)]
pub struct AdamConfig {
    /// Learning rate.
    pub lr: f32,
    /// First-moment coefficient.
    pub beta1: f32,
    /// Second-moment coefficient.
    pub beta2: f32,
    /// Epsilon forwarded to the shader.
    pub eps: f32,
    /// Weight-decay coefficient; the default of `0.0` presumably disables it — confirm
    /// against the shader.
    pub weight_decay: f32,
    /// Optimizer step counter forwarded to the shader. Defaults to 1, so it is presumably
    /// 1-based — confirm against the shader.
    pub step: u32,
}

impl Default for AdamConfig {
    /// Returns `lr = 1e-3`, `beta1 = 0.9`, `beta2 = 0.999`, `eps = 1e-8`,
    /// `weight_decay = 0.0` and `step = 1`.
    fn default() -> Self {
        let (beta1, beta2) = (0.9, 0.999);
        AdamConfig {
            step: 1,
            lr: 1e-3,
            beta1,
            beta2,
            eps: 1e-8,
            weight_decay: 0.0,
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn adam_step_chained(
ctx: &WgpuCtx,
p: &Pipelines,
enc: &mut wgpu::CommandEncoder,
grad: &wgpu::Buffer,
param: &wgpu::Buffer,
m: &wgpu::Buffer,
v: &wgpu::Buffer,
n: usize,
cfg: AdamConfig,
) {
let device = &ctx.device;
let queue = &ctx.queue;
let params = AdamParams {
n: n as u32,
step: cfg.step,
_pad0: 0,
_pad1: 0,
lr: cfg.lr,
beta1: cfg.beta1,
beta2: cfg.beta2,
eps: cfg.eps,
weight_decay: cfg.weight_decay,
_pad2: 0.0,
_pad3: 0.0,
_pad4: 0.0,
};
let p_buf = write_uniform(device, queue, "adam.params", ¶ms);
let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("adam.bg"),
layout: &p.adam_step.get_bind_group_layout(0),
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: p_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: grad.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: param.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: m.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: v.as_entire_binding(),
},
],
});
let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("adam.pass"),
timestamp_writes: None,
});
cp.set_pipeline(&p.adam_step);
cp.set_bind_group(0, &bg, &[]);
cp.dispatch_workgroups((n as u32).div_ceil(64), 1, 1);
}
/// Uniform block that `sum_of_squares_chained` uploads for the `sum_of_squares` pipeline.
///
/// `#[repr(C)]` with four 4-byte fields (16 bytes). `_p0` and `_p1` are padding and are
/// always written as zero by `sum_of_squares_chained`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct SumOfSquaresParams {
/// Element count, written as `n as u32` by the caller.
n: u32,
/// Caller-supplied scale, forwarded unchanged; how the shader applies it is not
/// visible here.
scale_in: f32,
_p0: u32,
_p1: u32,
}
/// Records a `sum_of_squares` compute pass into `enc`.
///
/// Binds a `SumOfSquaresParams` uniform carrying `n` and `scale_in` at binding 0, `input`
/// at 1 and `output` at 2, then dispatches exactly one workgroup regardless of `n`.
pub fn sum_of_squares_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    input: &wgpu::Buffer,
    output: &wgpu::Buffer,
    n: usize,
    scale_in: f32,
) {
    let uniforms = SumOfSquaresParams {
        n: n as u32,
        scale_in,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "sos.params", &uniforms);

    let pipeline = &p.sum_of_squares;
    let layout = pipeline.get_bind_group_layout(0);
    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: uniform_buf.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: input.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: output.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("sos.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("sos.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // A single workgroup handles the whole input.
    pass.dispatch_workgroups(1, 1, 1);
}
/// Uniform block shared by `lora_matmul_row_chained` and `lora_matmul_col_chained`.
///
/// `#[repr(C)]` with eight 4-byte fields (32 bytes). The `_pad*` fields are padding and
/// are always written as zero by both callers.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct LoraMatmulParams {
/// Row variant: the caller's `k`. Column variant: the caller's `outer`.
k: u32,
/// Row variant: the caller's `n`. Column variant: the caller's `inner`.
n: u32,
/// The caller's `accumulate` flag as 0 or 1.
accumulate: u32,
_pad: u32,
/// Caller-supplied scale, forwarded unchanged.
scale: f32,
_pad2: u32,
_pad3: u32,
_pad4: u32,
}
/// Uniform block that `lora_outer_add_chained` uploads for the `lora_outer_add` pipeline.
///
/// `#[repr(C)]` with eight 4-byte fields (32 bytes). The `_pad*` fields are padding and
/// are always written as zero by `lora_outer_add_chained`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct LoraOuterParams {
/// The caller's `outer_a`; also sizes the x dimension of the dispatch.
outer_a: u32,
/// The caller's `outer_b`; also sizes the y dimension of the dispatch.
outer_b: u32,
/// The caller's `accumulate` flag as 0 or 1.
accumulate: u32,
_pad: u32,
/// Caller-supplied scale, forwarded unchanged.
scale: f32,
_pad2: u32,
_pad3: u32,
_pad4: u32,
}
/// Records a `lora_matmul_row` compute pass into `enc`.
///
/// Binds a `LoraMatmulParams` uniform (`k`, `n`, `accumulate`, `scale`) at binding 0, then
/// `w`, `x` and `y` at bindings 1, 2 and 3, and dispatches `ceil(n / 64)` workgroups
/// along x. `accumulate` is passed to the shader as 0 or 1.
#[allow(clippy::too_many_arguments)]
pub fn lora_matmul_row_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    scale: f32,
    accumulate: bool,
) {
    let out_len = n as u32;
    let uniforms = LoraMatmulParams {
        k: k as u32,
        n: out_len,
        accumulate: u32::from(accumulate),
        _pad: 0,
        scale,
        _pad2: 0,
        _pad3: 0,
        _pad4: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "lora_mm_row.params", &uniforms);

    let pipeline = &p.lora_matmul_row;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        w.as_entire_binding(),
        x.as_entire_binding(),
        y.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("lora_mm_row.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_mm_row.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(out_len.div_ceil(64), 1, 1);
}
/// Records a `lora_matmul_col` compute pass into `enc`.
///
/// Reuses `LoraMatmulParams`, storing `outer` in its `k` field and `inner` in its `n`
/// field. The uniform is bound at binding 0, `w`, `x` and `y` at bindings 1, 2 and 3, and
/// `ceil(inner / 64)` workgroups are dispatched along x. `accumulate` is passed to the
/// shader as 0 or 1.
#[allow(clippy::too_many_arguments)]
pub fn lora_matmul_col_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    outer: usize,
    inner: usize,
    scale: f32,
    accumulate: bool,
) {
    let inner_len = inner as u32;
    let uniforms = LoraMatmulParams {
        k: outer as u32,
        n: inner_len,
        accumulate: u32::from(accumulate),
        _pad: 0,
        scale,
        _pad2: 0,
        _pad3: 0,
        _pad4: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "lora_mm_col.params", &uniforms);

    let pipeline = &p.lora_matmul_col;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        w.as_entire_binding(),
        x.as_entire_binding(),
        y.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("lora_mm_col.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_mm_col.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(inner_len.div_ceil(64), 1, 1);
}
/// Records a `lora_outer_add` compute pass into `enc`.
///
/// Binds a `LoraOuterParams` uniform (`outer_a`, `outer_b`, `accumulate`, `scale`) at
/// binding 0, then `a`, `b` and `out` at bindings 1, 2 and 3, and dispatches a
/// `ceil(outer_a / 8)` by `ceil(outer_b / 8)` grid of workgroups.
#[allow(clippy::too_many_arguments)]
pub fn lora_outer_add_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    a: &wgpu::Buffer,
    b: &wgpu::Buffer,
    out: &wgpu::Buffer,
    outer_a: usize,
    outer_b: usize,
    scale: f32,
    accumulate: bool,
) {
    let (len_a, len_b) = (outer_a as u32, outer_b as u32);
    let uniforms = LoraOuterParams {
        outer_a: len_a,
        outer_b: len_b,
        accumulate: u32::from(accumulate),
        _pad: 0,
        scale,
        _pad2: 0,
        _pad3: 0,
        _pad4: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "lora_outer.params", &uniforms);

    let pipeline = &p.lora_outer_add;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        a.as_entire_binding(),
        b.as_entire_binding(),
        out.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("lora_outer.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_outer.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // 2-D grid; the divisor 8 presumably matches an 8x8 workgroup in the shader.
    pass.dispatch_workgroups(len_a.div_ceil(8), len_b.div_ceil(8), 1);
}
/// Uniform block shared by `attention_backward_dq_chained` and
/// `attention_backward_dkv_chained`.
///
/// `#[repr(C)]` with eight `u32` fields (32 bytes). The `_pad*` fields are padding and
/// are always written as zero by both callers.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct AttnBackParams {
head_dim: u32,
n_heads: u32,
n_kv_heads: u32,
/// Both callers compute this as the integer quotient `n_heads / n_kv_heads`.
heads_per_kv: u32,
history_len: u32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
/// Records an `attention_backward_dq` compute pass into `enc`.
///
/// Binds an `AttnBackParams` uniform at binding 0, then `k_hist`, `v_hist`, `probs`,
/// `d_out`, `d_scores` and `d_q` at bindings 1 through 6, and dispatches `n_heads`
/// workgroups along x. `heads_per_kv` in the uniform is `n_heads / n_kv_heads`.
/// The encoder is neither finished nor submitted here.
///
/// # Panics
///
/// Panics if `n_kv_heads` is zero or `n_heads` is not a multiple of `n_kv_heads`.
/// `attention_cached` rejects the same shapes; without the check the integer division
/// would truncate silently and the shader would receive an inconsistent `heads_per_kv`.
#[allow(clippy::too_many_arguments)]
pub fn attention_backward_dq_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    k_hist: &wgpu::Buffer,
    v_hist: &wgpu::Buffer,
    probs: &wgpu::Buffer,
    d_out: &wgpu::Buffer,
    d_scores: &wgpu::Buffer,
    d_q: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_kv_heads: usize,
    history_len: usize,
) {
    assert!(
        n_kv_heads != 0 && n_heads.is_multiple_of(n_kv_heads),
        "attn_bwd_dq: n_heads ({n_heads}) must be a multiple of non-zero n_kv_heads ({n_kv_heads})"
    );
    let query_heads = n_heads as u32;
    let uniforms = AttnBackParams {
        head_dim: head_dim as u32,
        n_heads: query_heads,
        n_kv_heads: n_kv_heads as u32,
        heads_per_kv: (n_heads / n_kv_heads) as u32,
        history_len: history_len as u32,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "attn_bwd_dq.params", &uniforms);

    let pipeline = &p.attention_backward_dq;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        k_hist.as_entire_binding(),
        v_hist.as_entire_binding(),
        probs.as_entire_binding(),
        d_out.as_entire_binding(),
        d_scores.as_entire_binding(),
        d_q.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("attn_bwd_dq.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("attn_bwd_dq.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(query_heads, 1, 1);
}
/// Records an `attention_backward_dkv` compute pass into `enc`.
///
/// Binds an `AttnBackParams` uniform at binding 0, then `q`, `probs`, `d_out`,
/// `d_scores`, `d_k_hist` and `d_v_hist` at bindings 1 through 6, and dispatches an
/// `n_kv_heads` by `history_len` grid of workgroups. `heads_per_kv` in the uniform is
/// `n_heads / n_kv_heads`. The encoder is neither finished nor submitted here.
///
/// # Panics
///
/// Panics if `n_kv_heads` is zero or `n_heads` is not a multiple of `n_kv_heads`.
/// `attention_cached` rejects the same shapes; without the check the integer division
/// would truncate silently and the shader would receive an inconsistent `heads_per_kv`.
#[allow(clippy::too_many_arguments)]
pub fn attention_backward_dkv_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    probs: &wgpu::Buffer,
    d_out: &wgpu::Buffer,
    d_scores: &wgpu::Buffer,
    d_k_hist: &wgpu::Buffer,
    d_v_hist: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_kv_heads: usize,
    history_len: usize,
) {
    assert!(
        n_kv_heads != 0 && n_heads.is_multiple_of(n_kv_heads),
        "attn_bwd_dkv: n_heads ({n_heads}) must be a multiple of non-zero n_kv_heads ({n_kv_heads})"
    );
    let kv_heads = n_kv_heads as u32;
    let history = history_len as u32;
    let uniforms = AttnBackParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_kv_heads: kv_heads,
        heads_per_kv: (n_heads / n_kv_heads) as u32,
        history_len: history,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "attn_bwd_dkv.params", &uniforms);

    let pipeline = &p.attention_backward_dkv;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        q.as_entire_binding(),
        probs.as_entire_binding(),
        d_out.as_entire_binding(),
        d_scores.as_entire_binding(),
        d_k_hist.as_entire_binding(),
        d_v_hist.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("attn_bwd_dkv.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("attn_bwd_dkv.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(kv_heads, history, 1);
}
/// Records an `rmsnorm_backward` compute pass into `enc`.
///
/// Binds an `RmsParams` uniform (`n`, `eps`, `has_weight`) at binding 0, then `x`, `w`,
/// `dy` and `dx` at bindings 1 through 4, and dispatches a single workgroup.
/// `w` is always bound; `has_weight` is passed to the shader as 0 or 1.
pub fn rmsnorm_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    w: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    n: usize,
    eps: f32,
    has_weight: bool,
) {
    let uniforms = RmsParams {
        n: n as u32,
        eps,
        has_weight: u32::from(has_weight),
        _p: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "rms_bwd.params", &uniforms);

    let pipeline = &p.rmsnorm_backward;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        x.as_entire_binding(),
        w.as_entire_binding(),
        dy.as_entire_binding(),
        dx.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rms_bwd.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("rms_bwd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // A single workgroup handles the whole vector.
    pass.dispatch_workgroups(1, 1, 1);
}
/// Uniform block that `rmsnorm_per_row_backward_chained` uploads for the
/// `rmsnorm_per_row_backward` pipeline.
///
/// `#[repr(C)]` with four 4-byte fields (16 bytes) and no padding fields.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct RmsPerRowBackParams {
/// Row count; the caller also dispatches this many workgroups.
n_rows: u32,
/// The caller's `n`; presumably the per-row length — confirm against the shader.
n: u32,
eps: f32,
/// The caller's `has_weight` flag as 0 or 1.
has_weight: u32,
}
/// Records an `rmsnorm_per_row_backward` compute pass into `enc`.
///
/// Binds an `RmsPerRowBackParams` uniform (`n_rows`, `n`, `eps`, `has_weight`) at binding
/// 0, then `x`, `w`, `dy` and `dx` at bindings 1 through 4, and dispatches `n_rows`
/// workgroups along x. `w` is always bound; `has_weight` is passed as 0 or 1.
#[allow(clippy::too_many_arguments)]
pub fn rmsnorm_per_row_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    w: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    n_rows: usize,
    n: usize,
    eps: f32,
    has_weight: bool,
) {
    let row_count = n_rows as u32;
    let uniforms = RmsPerRowBackParams {
        n_rows: row_count,
        n: n as u32,
        eps,
        has_weight: u32::from(has_weight),
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "rms_pr_bwd.params", &uniforms);

    let pipeline = &p.rmsnorm_per_row_backward;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        x.as_entire_binding(),
        w.as_entire_binding(),
        dy.as_entire_binding(),
        dx.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rms_pr_bwd.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("rms_pr_bwd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(row_count, 1, 1);
}
/// Records a `geglu_backward` compute pass into `enc`.
///
/// Binds a `GegluParams` uniform carrying `n` at binding 0, then `gate`, `up`, `dy`,
/// `d_gate` and `d_up` at bindings 1 through 5, and dispatches `ceil(n / 64)` workgroups
/// along x.
pub fn geglu_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    gate: &wgpu::Buffer,
    up: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    d_gate: &wgpu::Buffer,
    d_up: &wgpu::Buffer,
    n: usize,
) {
    let count = n as u32;
    let uniforms = GegluParams {
        n: count,
        _p0: 0,
        _p1: 0,
        _p2: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "geglu_bwd.params", &uniforms);

    let pipeline = &p.geglu_backward;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        gate.as_entire_binding(),
        up.as_entire_binding(),
        dy.as_entire_binding(),
        d_gate.as_entire_binding(),
        d_up.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("geglu_bwd.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("geglu_bwd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(count.div_ceil(64), 1, 1);
}
/// Records a `rope_neox_backward` compute pass into `enc`.
///
/// Binds a `RopeParams` uniform at binding 0, `x` at 1 and the factor buffer at 2, then
/// dispatches `ceil(n_heads * (rope_dims / 2) / 64)` workgroups along x.
///
/// When `factors` is `None`, `dummy` is bound in its place and `has_factors` is written
/// as 0 in the uniform; otherwise `has_factors` is 1.
pub fn rope_neox_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    factors: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    pos: usize,
    rope_dims: usize,
    base: f32,
) {
    // Binding 2 must always be populated, so fall back to the placeholder buffer.
    let (factor_buf, has_factors) = match factors {
        Some(buf) => (buf, 1u32),
        None => (dummy, 0u32),
    };
    let uniforms = RopeParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        rope_dims: rope_dims as u32,
        pos: pos as u32,
        base,
        has_factors,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "rope_bwd.params", &uniforms);

    let pipeline = &p.rope_neox_backward;
    let layout = pipeline.get_bind_group_layout(0);
    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: uniform_buf.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: x.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: factor_buf.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rope_bwd.bg"),
        layout: &layout,
        entries: &entries,
    });

    // `rope_dims / 2` is an integer division, so an odd `rope_dims` rounds down.
    let work_items = (n_heads * (rope_dims / 2)) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("rope_bwd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(work_items.div_ceil(64), 1, 1);
}
/// Records a `matmul_q4_k_backward_input` compute pass into `enc`.
///
/// Binds a `MatmulParams` uniform (`k`, `n`) at binding 0, then `weight`, `dy` and `dx`
/// at bindings 1, 2 and 3, and dispatches `k / 256` workgroups along x.
///
/// # Panics
///
/// Panics if `k` is not a multiple of 256.
pub fn matmul_q4_k_backward_input_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    assert!(
        k.is_multiple_of(256),
        "k must be divisible by 256 for Q4_K backward"
    );
    let uniforms = MatmulParams {
        k: k as u32,
        n: n as u32,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "q4k_bwd.params", &uniforms);

    let pipeline = &p.matmul_q4_k_backward_input;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        weight.as_entire_binding(),
        dy.as_entire_binding(),
        dx.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("q4k_bwd.bg"),
        layout: &layout,
        entries: &entries,
    });

    // Exact thanks to the assertion above: one workgroup per 256 elements of `k`.
    let groups_of_256 = (k / 256) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("q4k_bwd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(groups_of_256, 1, 1);
}
/// Records a `matmul_q6_k_backward_input` compute pass into `enc`.
///
/// Binds a `MatmulParams` uniform (`k`, `n`) at binding 0, then `weight`, `dy` and `dx`
/// at bindings 1, 2 and 3, and dispatches `k / 256` workgroups along x.
///
/// # Panics
///
/// Panics if `k` is not a multiple of 256.
pub fn matmul_q6_k_backward_input_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    assert!(
        k.is_multiple_of(256),
        "k must be divisible by 256 for Q6_K backward"
    );
    let uniforms = MatmulParams {
        k: k as u32,
        n: n as u32,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "q6k_bwd.params", &uniforms);

    let pipeline = &p.matmul_q6_k_backward_input;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        weight.as_entire_binding(),
        dy.as_entire_binding(),
        dx.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("q6k_bwd.bg"),
        layout: &layout,
        entries: &entries,
    });

    // Exact thanks to the assertion above: one workgroup per 256 elements of `k`.
    let groups_of_256 = (k / 256) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("q6k_bwd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(groups_of_256, 1, 1);
}
/// Records a `cross_entropy_backward` compute pass into `enc`.
///
/// Binds an `XEntParams` uniform (`vocab_size`, `target`) at binding 0, then `logits`,
/// `d_logits` and `loss_out` at bindings 1, 2 and 3, and dispatches a single workgroup.
pub fn cross_entropy_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    logits: &wgpu::Buffer,
    d_logits: &wgpu::Buffer,
    loss_out: &wgpu::Buffer,
    vocab_size: usize,
    target: u32,
) {
    let uniforms = XEntParams {
        vocab_size: vocab_size as u32,
        target,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "xent_bwd.params", &uniforms);

    let pipeline = &p.cross_entropy_backward;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        logits.as_entire_binding(),
        d_logits.as_entire_binding(),
        loss_out.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("xent_bwd.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("xent_bwd.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    // A single workgroup handles the whole vocabulary.
    pass.dispatch_workgroups(1, 1, 1);
}
/// Records a `softcap` compute pass into `enc`.
///
/// Binds a `CapParams` uniform (`n`, `cap`) at binding 0, `x` at 1 and `y` at 2, then
/// dispatches `ceil(n / 64)` workgroups along x.
pub fn softcap_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n: usize,
    cap: f32,
) {
    let count = n as u32;
    let uniforms = CapParams {
        n: count,
        cap,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "cap_chain.params", &uniforms);

    let pipeline = &p.softcap;
    let layout = pipeline.get_bind_group_layout(0);
    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: uniform_buf.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: x.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: y.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("cap_chain.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("cap_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(count.div_ceil(64), 1, 1);
}
/// Records a `geglu` compute pass into `enc`.
///
/// Binds a `GegluParams` uniform carrying `n` at binding 0, then `gate`, `up` and `y` at
/// bindings 1, 2 and 3, and dispatches `ceil(n / 64)` workgroups along x.
pub fn geglu_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    gate: &wgpu::Buffer,
    up: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n: usize,
) {
    let count = n as u32;
    let uniforms = GegluParams {
        n: count,
        _p0: 0,
        _p1: 0,
        _p2: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "geglu_chain.params", &uniforms);

    let pipeline = &p.geglu;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        gate.as_entire_binding(),
        up.as_entire_binding(),
        y.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("geglu_chain.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("geglu_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(count.div_ceil(64), 1, 1);
}
/// Records a `rope_neox` compute pass into `enc`.
///
/// Binds a `RopeParams` uniform at binding 0, `x` at 1 and the factor buffer at 2, then
/// dispatches `ceil(n_heads * (rope_dims / 2) / 64)` workgroups along x.
///
/// When `factors` is `None`, `dummy` is bound in its place and `has_factors` is written
/// as 0 in the uniform; otherwise `has_factors` is 1.
pub fn rope_neox_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    factors: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    pos: usize,
    rope_dims: usize,
    base: f32,
) {
    // Binding 2 must always be populated, so fall back to the placeholder buffer.
    let (factor_buf, has_factors) = match factors {
        Some(buf) => (buf, 1u32),
        None => (dummy, 0u32),
    };
    let uniforms = RopeParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        rope_dims: rope_dims as u32,
        pos: pos as u32,
        base,
        has_factors,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "rope_chain.params", &uniforms);

    let pipeline = &p.rope_neox;
    let layout = pipeline.get_bind_group_layout(0);
    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: uniform_buf.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: x.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: factor_buf.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("rope_chain.bg"),
        layout: &layout,
        entries: &entries,
    });

    // `rope_dims / 2` is an integer division, so an odd `rope_dims` rounds down.
    let work_items = (n_heads * (rope_dims / 2)) as u32;
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("rope_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(work_items.div_ceil(64), 1, 1);
}
/// Records a `residual_add` compute pass into `enc`.
///
/// Binds a `ResAddParams` uniform carrying `n` at binding 0, `x` at 1 and `y` at 2, then
/// dispatches `ceil(n / 64)` workgroups along x. Which of the two buffers receives the
/// sum is decided by the shader and is not visible here.
pub fn residual_add_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n: usize,
) {
    let count = n as u32;
    let uniforms = ResAddParams {
        n: count,
        _p0: 0,
        _p1: 0,
        _p2: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "resadd_chain.params", &uniforms);

    let pipeline = &p.residual_add;
    let layout = pipeline.get_bind_group_layout(0);
    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: uniform_buf.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: x.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 2,
            resource: y.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("resadd_chain.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("resadd_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(count.div_ceil(64), 1, 1);
}
/// Records a `scale` compute pass over `x` into `enc`.
///
/// Binds a `ScaleParams` uniform (`n`, `s`) at binding 0 and `x` at binding 1, then
/// dispatches `ceil(n / 64)` workgroups along x. `x` is the only data buffer bound, so
/// the shader presumably scales it in place — confirm against the shader source.
pub fn scale_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    n: usize,
    s: f32,
) {
    let count = n as u32;
    let uniforms = ScaleParams {
        n: count,
        s,
        _p0: 0,
        _p1: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "scale_chain.params", &uniforms);

    let pipeline = &p.scale;
    let layout = pipeline.get_bind_group_layout(0);
    let entries = [
        wgpu::BindGroupEntry {
            binding: 0,
            resource: uniform_buf.as_entire_binding(),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: x.as_entire_binding(),
        },
    ];
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("scale_chain.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("scale_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(count.div_ceil(64), 1, 1);
}
/// Records an `attention` compute pass into `enc`.
///
/// Binds an `AttnParams` uniform at binding 0, then `q`, `k_hist`, `v_hist` and `out` at
/// bindings 1 through 4, and dispatches `n_heads` workgroups along x. `pos`,
/// `history_len` and `window` are forwarded unchanged in the uniform; `heads_per_kv` is
/// `n_heads / n_kv_heads`. The encoder is neither finished nor submitted here.
///
/// # Panics
///
/// Panics if `n_kv_heads` is zero or `n_heads` is not a multiple of `n_kv_heads`.
/// `attention_cached` rejects the same shapes; without the check the integer division
/// would truncate silently and the shader would receive an inconsistent `heads_per_kv`.
pub fn attention_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k_hist: &wgpu::Buffer,
    v_hist: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_kv_heads: usize,
    pos: usize,
    history_len: usize,
    window: usize,
) {
    assert!(
        n_kv_heads != 0 && n_heads.is_multiple_of(n_kv_heads),
        "attn: n_heads ({n_heads}) must be a multiple of non-zero n_kv_heads ({n_kv_heads})"
    );
    let query_heads = n_heads as u32;
    let uniforms = AttnParams {
        head_dim: head_dim as u32,
        n_heads: query_heads,
        n_kv_heads: n_kv_heads as u32,
        heads_per_kv: (n_heads / n_kv_heads) as u32,
        pos: pos as u32,
        history_len: history_len as u32,
        window: window as u32,
        _p: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "attn_chain.params", &uniforms);

    let pipeline = &p.attention;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        q.as_entire_binding(),
        k_hist.as_entire_binding(),
        v_hist.as_entire_binding(),
        out.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("attn_chain.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("attn_chain.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(query_heads, 1, 1);
}
/// Records an `attention_probs` compute pass into `enc`.
///
/// Binds an `AttnParams` uniform at binding 0, then `q`, `k_hist` and `probs` at bindings
/// 1, 2 and 3, and dispatches `n_heads` workgroups along x. `pos`, `history_len` and
/// `window` are forwarded unchanged in the uniform; `heads_per_kv` is
/// `n_heads / n_kv_heads`. The encoder is neither finished nor submitted here.
///
/// # Panics
///
/// Panics if `n_kv_heads` is zero or `n_heads` is not a multiple of `n_kv_heads`.
/// `attention_cached` rejects the same shapes; without the check the integer division
/// would truncate silently and the shader would receive an inconsistent `heads_per_kv`.
#[allow(clippy::too_many_arguments)]
pub fn attention_probs_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k_hist: &wgpu::Buffer,
    probs: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_kv_heads: usize,
    pos: usize,
    history_len: usize,
    window: usize,
) {
    assert!(
        n_kv_heads != 0 && n_heads.is_multiple_of(n_kv_heads),
        "attn_probs: n_heads ({n_heads}) must be a multiple of non-zero n_kv_heads ({n_kv_heads})"
    );
    let query_heads = n_heads as u32;
    let uniforms = AttnParams {
        head_dim: head_dim as u32,
        n_heads: query_heads,
        n_kv_heads: n_kv_heads as u32,
        heads_per_kv: (n_heads / n_kv_heads) as u32,
        pos: pos as u32,
        history_len: history_len as u32,
        window: window as u32,
        _p: 0,
    };
    let uniform_buf = write_uniform(&ctx.device, &ctx.queue, "attn_probs.params", &uniforms);

    let pipeline = &p.attention_probs;
    let layout = pipeline.get_bind_group_layout(0);
    // Binding index == position in `resources`.
    let resources = [
        uniform_buf.as_entire_binding(),
        q.as_entire_binding(),
        k_hist.as_entire_binding(),
        probs.as_entire_binding(),
    ];
    let mut slot = 0u32;
    let entries = resources.map(|resource| {
        let binding = slot;
        slot += 1;
        wgpu::BindGroupEntry { binding, resource }
    });
    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("attn_probs.bg"),
        layout: &layout,
        entries: &entries,
    });

    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("attn_probs.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(pipeline);
    pass.set_bind_group(0, &bind_group, &[]);
    pass.dispatch_workgroups(query_heads, 1, 1);
}
/// Runs the `attention` pipeline once on host-side slices and reads the result back.
///
/// The inputs are validated, passed to `write_uniform` / `write_storage_f32`,
/// bound at bindings 0..=4 (params, q, k, v, out) and dispatched with
/// `n_heads` workgroups along x. The output buffer is then copied into a
/// read-back buffer, the encoder is submitted, and the result of
/// `read_back_f32` is returned. The output buffer is sized for
/// `n_heads * head_dim` `f32` values.
///
/// The scalar arguments are narrowed with `as u32` when the uniform is built,
/// so values above `u32::MAX` are silently truncated.
///
/// # Errors
///
/// Returns `RullamaError::Inference` when
/// * `q.len() != n_heads * head_dim` (`"attn: q shape"`),
/// * `k_hist` or `v_hist` is not `history_len * n_kv_heads * head_dim` long
///   (`"attn: kv shape"`),
/// * `n_kv_heads` is zero or does not divide `n_heads`
///   (`"attn: n_heads % n_kv_heads"`).
///
/// Any error from `read_back_f32` is propagated unchanged.
pub async fn attention_cached(
    ctx: &WgpuCtx,
    p: &Pipelines,
    q: &[f32],
    k_hist: &[f32],
    v_hist: &[f32],
    head_dim: usize,
    n_heads: usize,
    n_kv_heads: usize,
    pos: usize,
    history_len: usize,
    window: usize,
) -> Result<Vec<f32>> {
    if q.len() != n_heads * head_dim {
        return Err(RullamaError::Inference("attn: q shape".into()));
    }
    if k_hist.len() != history_len * n_kv_heads * head_dim
        || v_hist.len() != history_len * n_kv_heads * head_dim
    {
        return Err(RullamaError::Inference("attn: kv shape".into()));
    }
    // `0.is_multiple_of(0)` is `true`, so without the explicit zero check the
    // all-zero configuration would slip through and hit a division by zero
    // when `heads_per_kv` is computed below.
    if n_kv_heads == 0 || !n_heads.is_multiple_of(n_kv_heads) {
        return Err(RullamaError::Inference("attn: n_heads % n_kv_heads".into()));
    }

    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = AttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_kv_heads: n_kv_heads as u32,
        heads_per_kv: (n_heads / n_kv_heads) as u32,
        pos: pos as u32,
        history_len: history_len as u32,
        window: window as u32,
        _p: 0,
    };
    let p_buf = write_uniform(device, queue, "attn.params", &params);
    let q_buf = write_storage_f32(device, queue, "attn.q", q);
    let k_buf = write_storage_f32(device, queue, "attn.k", k_hist);
    let v_buf = write_storage_f32(device, queue, "attn.v", v_hist);

    // 4 bytes per f32 output element.
    let out_bytes = (n_heads * head_dim * 4) as u64;
    let (out_buf, read_buf) = make_output_pair(device, "attn", out_bytes);

    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("attn.bg"),
        layout: &p.attention.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out_buf.as_entire_binding(),
            },
        ],
    });

    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("attn.enc"),
    });
    {
        // Scoped so the pass is dropped (ended) before the copy is recorded.
        let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("attn.pass"),
            timestamp_writes: None,
        });
        cp.set_pipeline(&p.attention);
        cp.set_bind_group(0, &bg, &[]);
        cp.dispatch_workgroups(n_heads as u32, 1, 1);
    }
    enc.copy_buffer_to_buffer(&out_buf, 0, &read_buf, 0, out_bytes);
    queue.submit(Some(enc.finish()));
    read_back_f32(device, &read_buf).await
}
#[cfg(test)]
mod tests {
use super::*;
/// Checks `cross_entropy_backward_cached` against the CPU reference on 4096
/// pseudo-random logits: loss within 1e-3, every gradient entry within 1e-5.
#[test]
fn cross_entropy_backward_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let vocab = 4096usize;
    let target: u32 = 137;

    // Fixed-seed LCG so the logits are reproducible; each draw is mapped to
    // [-0.5, 0.5) and then scaled by 4.
    let mut lcg: u32 = 0x1234_5678;
    let logits: Vec<f32> = (0..vocab)
        .map(|_| {
            lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
            (((lcg >> 8) as f32 / 16_777_216.0) - 0.5) * 4.0
        })
        .collect();

    let mut cpu_grad = vec![0.0f32; vocab];
    let cpu_loss = crate::reference::ops::cross_entropy_backward(&logits, target, &mut cpu_grad);

    let (gpu_grad, gpu_loss) =
        pollster::block_on(cross_entropy_backward_cached(&ctx, &p, &logits, target))
            .expect("gpu");

    assert!(
        (cpu_loss - gpu_loss).abs() < 1e-3,
        "loss cpu={cpu_loss} gpu={gpu_loss}"
    );

    let max_diff = cpu_grad
        .iter()
        .zip(gpu_grad.iter())
        .map(|(c, g)| (c - g).abs())
        .fold(0.0f32, f32::max);
    assert!(max_diff < 1e-5, "d_logits max_diff = {max_diff}");
}
/// Checks `matmul_q4_k_backward_input_cached` against the CPU reference for a
/// 16-row, k = 256 weight matrix filled with pseudo-random bytes.
#[test]
fn matmul_q4_k_backward_input_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let k = 256usize;
    let n = 16usize;
    // The size computation assumes 144 bytes per 256 weights.
    let row_bytes = (k / 256) * 144;
    let total_bytes = n * row_bytes;

    // Fixed-seed LCG fill; the byte is taken from bits 16..24 of the state.
    let mut lcg: u32 = 0xDEAD_BEEF;
    let mut w_bytes: Vec<u8> = (0..total_bytes)
        .map(|_| {
            lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
            (lcg >> 16) as u8
        })
        .collect();

    // Pin the first four bytes of every row to 00 2C 00 28.
    // NOTE(review): presumably these are two little-endian f16 scale fields
    // (0x2C00 and 0x2800), fixed so random bytes cannot yield inf/NaN scales —
    // confirm against the reference block layout.
    for row in w_bytes.chunks_exact_mut(row_bytes) {
        row[..4].copy_from_slice(&[0x00, 0x2C, 0x00, 0x28]);
    }

    let dy: Vec<f32> = (0..n).map(|j| ((j as i32 - 8) as f32) * 0.25).collect();

    let mut cpu_dx = vec![0.0f32; k];
    crate::reference::ops::matmul_q4_k_backward_input(&w_bytes, &dy, k, n, &mut cpu_dx);

    let gpu_dx = pollster::block_on(matmul_q4_k_backward_input_cached(
        &ctx, &p, &w_bytes, &dy, k, n,
    ))
    .expect("gpu");

    // Track the worst absolute and relative error in a single sweep.
    let (max_diff, max_rel) = cpu_dx.iter().zip(gpu_dx.iter()).fold(
        (0.0f32, 0.0f32),
        |(abs_acc, rel_acc), (c, g)| {
            let d = (c - g).abs();
            let r = d / c.abs().max(1e-6);
            (abs_acc.max(d), rel_acc.max(r))
        },
    );
    assert!(
        max_diff < 1e-3 && max_rel < 1e-3,
        "q4_k_bwd_input max_abs={max_diff} max_rel={max_rel}"
    );
}
/// Checks `matmul_q6_k_backward_input_cached` against the CPU reference for a
/// 16-row, k = 256 weight matrix filled with pseudo-random bytes.
#[test]
fn matmul_q6_k_backward_input_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let k = 256usize;
    let n = 16usize;
    // The size computation assumes 210 bytes per 256 weights.
    let block_bytes = 210usize;
    let row_bytes = (k / 256) * block_bytes;
    let total_bytes = n * row_bytes;

    // Fixed-seed LCG fill; the byte is taken from bits 16..24 of the state.
    let mut lcg: u32 = 0xCAFEBABE;
    let mut w_bytes: Vec<u8> = (0..total_bytes)
        .map(|_| {
            lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
            (lcg >> 16) as u8
        })
        .collect();

    // Rows are whole multiples of `block_bytes`, so walking the buffer block by
    // block visits every (row, block) pair. The last two bytes of each block
    // are pinned to 00 28.
    // NOTE(review): presumably a little-endian f16 scale (0x2800) fixed to keep
    // the dequantized values finite — confirm against the reference layout.
    for block in w_bytes.chunks_exact_mut(block_bytes) {
        block[208] = 0x00;
        block[209] = 0x28;
    }

    let dy: Vec<f32> = (0..n).map(|j| ((j as i32 - 8) as f32) * 0.25).collect();

    let mut cpu_dx = vec![0.0f32; k];
    crate::reference::ops::matmul_q6_k_backward_input(&w_bytes, &dy, k, n, &mut cpu_dx);

    let gpu_dx = pollster::block_on(matmul_q6_k_backward_input_cached(
        &ctx, &p, &w_bytes, &dy, k, n,
    ))
    .expect("gpu");

    // Track the worst absolute and relative error in a single sweep.
    let (max_diff, max_rel) = cpu_dx.iter().zip(gpu_dx.iter()).fold(
        (0.0f32, 0.0f32),
        |(abs_acc, rel_acc), (c, g)| {
            let d = (c - g).abs();
            let r = d / c.abs().max(1e-6);
            (abs_acc.max(d), rel_acc.max(r))
        },
    );
    assert!(
        max_diff < 1e-3 && max_rel < 1e-3,
        "q6_k_bwd_input max_abs={max_diff} max_rel={max_rel}"
    );
}
/// Checks `rmsnorm_backward_chained` (weighted) against the CPU reference on a
/// 64-element vector; the input gradient must agree within 1e-4.
#[test]
fn rmsnorm_backward_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let (device, queue) = (&ctx.device, &ctx.queue);

    let n = 64usize;
    let eps = 1e-6f32;
    let dx_bytes = (n * 4) as u64;
    let x: Vec<f32> = (0..n).map(|i| (i as f32 - 30.0) * 0.05).collect();
    let w: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3).sin() * 0.3 + 1.0).collect();
    let dy: Vec<f32> = (0..n).map(|i| (i as f32 * 0.7).cos() * 0.5).collect();

    // CPU reference.
    let mut cpu_dx = vec![0.0f32; n];
    crate::reference::ops::rmsnorm_backward(&x, Some(&w), &dy, eps, &mut cpu_dx);

    // GPU pass followed by a copy into the mappable read buffer.
    let x_buf = write_storage_f32(device, queue, "x", &x);
    let w_buf = write_storage_f32(device, queue, "w", &w);
    let dy_buf = write_storage_f32(device, queue, "dy", &dy);
    let (dx_buf, dx_read) = make_output_pair(device, "dx", dx_bytes);
    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rms_bwd.enc"),
    });
    rmsnorm_backward_chained(
        &ctx, &p, &mut enc, &x_buf, &w_buf, &dy_buf, &dx_buf, n, eps, true,
    );
    enc.copy_buffer_to_buffer(&dx_buf, 0, &dx_read, 0, dx_bytes);
    queue.submit(Some(enc.finish()));
    let gpu_dx = pollster::block_on(read_back_f32(device, &dx_read)).expect("readback");

    let max_diff = cpu_dx
        .iter()
        .zip(gpu_dx.iter())
        .map(|(c, g)| (c - g).abs())
        .fold(0.0f32, f32::max);
    assert!(max_diff < 1e-4, "rmsnorm_bwd max_diff = {max_diff}");
}
/// Checks `rmsnorm_per_row_backward_chained` against the CPU reference on a
/// 4 x 32 input, first with a weight vector and then without one (a dummy
/// storage buffer is bound and the weight flag is `false`).
#[test]
fn rmsnorm_per_row_backward_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let (device, queue) = (&ctx.device, &ctx.queue);

    let n_rows = 4usize;
    let n = 32usize;
    let total = n_rows * n;
    let out_bytes = (total * 4) as u64;
    let eps = 1e-6f32;
    let x: Vec<f32> = (0..total)
        .map(|i| ((i as i32 - 30) as f32) * 0.05)
        .collect();
    let w: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3).sin() * 0.3 + 1.0).collect();
    let dy: Vec<f32> = (0..total).map(|i| (i as f32 * 0.7).cos() * 0.5).collect();

    // Largest element-wise absolute difference between two result vectors.
    let largest_gap = |cpu: &[f32], gpu: &[f32]| -> f32 {
        cpu.iter()
            .zip(gpu.iter())
            .map(|(c, g)| (c - g).abs())
            .fold(0.0f32, f32::max)
    };

    // ---- weighted variant ----
    let mut cpu_dx = vec![0.0f32; total];
    crate::reference::ops::rmsnorm_per_row_backward(
        &x,
        Some(&w),
        &dy,
        eps,
        n_rows,
        n,
        &mut cpu_dx,
    );
    let x_buf = write_storage_f32(device, queue, "x", &x);
    let w_buf = write_storage_f32(device, queue, "w", &w);
    let dy_buf = write_storage_f32(device, queue, "dy", &dy);
    let (dx_buf, dx_read) = make_output_pair(device, "dx", out_bytes);
    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rms_pr_bwd.enc"),
    });
    rmsnorm_per_row_backward_chained(
        &ctx, &p, &mut enc, &x_buf, &w_buf, &dy_buf, &dx_buf, n_rows, n, eps, true,
    );
    enc.copy_buffer_to_buffer(&dx_buf, 0, &dx_read, 0, out_bytes);
    queue.submit(Some(enc.finish()));
    let gpu_dx = pollster::block_on(read_back_f32(device, &dx_read)).expect("readback");
    let max_diff = largest_gap(&cpu_dx, &gpu_dx);
    assert!(max_diff < 1e-4, "rmsnorm_per_row_bwd max_diff = {max_diff}");

    // ---- unweighted variant (reuses the x / dy uploads) ----
    let mut cpu_dx_u = vec![0.0f32; total];
    crate::reference::ops::rmsnorm_per_row_backward(
        &x,
        None,
        &dy,
        eps,
        n_rows,
        n,
        &mut cpu_dx_u,
    );
    let dummy = make_dummy_storage(device, "dummy");
    let (dx_u_buf, dx_u_read) = make_output_pair(device, "dx_u", out_bytes);
    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rms_pr_bwd_u.enc"),
    });
    rmsnorm_per_row_backward_chained(
        &ctx, &p, &mut enc, &x_buf, &dummy, &dy_buf, &dx_u_buf, n_rows, n, eps, false,
    );
    enc.copy_buffer_to_buffer(&dx_u_buf, 0, &dx_u_read, 0, out_bytes);
    queue.submit(Some(enc.finish()));
    let gpu_dx_u = pollster::block_on(read_back_f32(device, &dx_u_read)).expect("readback");
    let max_diff_u = largest_gap(&cpu_dx_u, &gpu_dx_u);
    assert!(
        max_diff_u < 1e-4,
        "rmsnorm_per_row_bwd unweighted max_diff = {max_diff_u}"
    );
}
/// Checks `sum_of_squares_chained` against a CPU sum for four input lengths at
/// scale 1.0, then once more at scale 0.5; relative error must stay below 1e-4.
#[test]
fn sum_of_squares_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let (device, queue) = (&ctx.device, &ctx.queue);

    // Unscaled reduction over several lengths.
    for n in [63usize, 256, 1024, 4097] {
        let x: Vec<f32> = (0..n).map(|i| ((i as i32 - 100) as f32) * 0.03).collect();
        let cpu_sos: f32 = x.iter().map(|&v| v * v).sum();

        let x_buf = write_storage_f32(device, queue, "x", &x);
        let (out_buf, out_read) = make_output_pair(device, "sos", 4);
        let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("sos.enc"),
        });
        sum_of_squares_chained(&ctx, &p, &mut enc, &x_buf, &out_buf, n, 1.0);
        enc.copy_buffer_to_buffer(&out_buf, 0, &out_read, 0, 4);
        queue.submit(Some(enc.finish()));
        let gpu = pollster::block_on(read_back_f32(device, &out_read)).expect("readback")[0];

        let denom = cpu_sos.abs().max(1e-6);
        let rel = (cpu_sos - gpu).abs() / denom;
        assert!(rel < 1e-4, "n={n} cpu={cpu_sos} gpu={gpu} rel={rel}");
    }

    // Scaled reduction: the CPU side squares `v * 0.5` and the GPU call gets 0.5.
    let n = 256usize;
    let x: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1 - 5.0).collect();
    let cpu_sos: f32 = x.iter().map(|&v| (v * 0.5) * (v * 0.5)).sum();

    let x_buf = write_storage_f32(device, queue, "x", &x);
    let (out_buf, out_read) = make_output_pair(device, "sos.scaled", 4);
    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("sos.scaled.enc"),
    });
    sum_of_squares_chained(&ctx, &p, &mut enc, &x_buf, &out_buf, n, 0.5);
    enc.copy_buffer_to_buffer(&out_buf, 0, &out_read, 0, 4);
    queue.submit(Some(enc.finish()));
    let gpu = pollster::block_on(read_back_f32(device, &out_read)).expect("readback")[0];

    let rel = (cpu_sos - gpu).abs() / cpu_sos.abs().max(1e-6);
    assert!(rel < 1e-4, "scaled cpu={cpu_sos} gpu={gpu} rel={rel}");
}
/// Checks `geglu_backward_chained` against the CPU reference on 64 elements;
/// both the gate and the up gradients must agree within 1e-5.
#[test]
fn geglu_backward_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let (device, queue) = (&ctx.device, &ctx.queue);

    let n = 64usize;
    let grad_bytes = (n * 4) as u64;
    let gate: Vec<f32> = (0..n).map(|i| (i as f32 - 30.0) * 0.05).collect();
    let up: Vec<f32> = (0..n).map(|i| (i as f32) * 0.02 + 0.5).collect();
    let dy: Vec<f32> = (0..n).map(|i| (i as f32 * 0.4).sin()).collect();

    // CPU reference.
    let mut cpu_dg = vec![0.0f32; n];
    let mut cpu_du = vec![0.0f32; n];
    crate::reference::ops::geglu_backward(&gate, &up, &dy, &mut cpu_dg, &mut cpu_du);

    // GPU pass; both gradient buffers are copied out in the same submission.
    let g_buf = write_storage_f32(device, queue, "gate", &gate);
    let u_buf = write_storage_f32(device, queue, "up", &up);
    let dy_buf = write_storage_f32(device, queue, "dy", &dy);
    let (dg_buf, dg_read) = make_output_pair(device, "dg", grad_bytes);
    let (du_buf, du_read) = make_output_pair(device, "du", grad_bytes);
    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("geglu_bwd.enc"),
    });
    geglu_backward_chained(
        &ctx, &p, &mut enc, &g_buf, &u_buf, &dy_buf, &dg_buf, &du_buf, n,
    );
    enc.copy_buffer_to_buffer(&dg_buf, 0, &dg_read, 0, grad_bytes);
    enc.copy_buffer_to_buffer(&du_buf, 0, &du_read, 0, grad_bytes);
    queue.submit(Some(enc.finish()));
    let gpu_dg = pollster::block_on(read_back_f32(device, &dg_read)).expect("dg readback");
    let gpu_du = pollster::block_on(read_back_f32(device, &du_read)).expect("du readback");

    // Indexed on purpose: a short read-back must fail loudly, not be truncated.
    let max_dg = (0..n)
        .map(|i| (cpu_dg[i] - gpu_dg[i]).abs())
        .fold(0.0f32, f32::max);
    let max_du = (0..n)
        .map(|i| (cpu_du[i] - gpu_du[i]).abs())
        .fold(0.0f32, f32::max);
    assert!(
        max_dg < 1e-5 && max_du < 1e-5,
        "geglu_bwd max_dg={max_dg} max_du={max_du}"
    );
}
/// Checks one `adam_step_chained` update (step 1, zero-initialised moments)
/// against the CPU reference on 128 parameters; the updated parameters must
/// agree within 1e-6.
#[test]
fn adam_step_gpu_vs_cpu() {
    /// Creates a STORAGE | COPY_DST | COPY_SRC buffer sized for `data` and
    /// queues a write of `data` into it.
    fn upload_rw(
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        label: &str,
        data: &[f32],
    ) -> wgpu::Buffer {
        let buf = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (data.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        queue.write_buffer(&buf, 0, bytemuck::cast_slice(data));
        buf
    }

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let n = 128usize;
    let param: Vec<f32> = (0..n).map(|i| (i as f32 * 0.07).sin() * 0.5).collect();
    let grad: Vec<f32> = (0..n).map(|i| (i as f32 * 0.13).cos() * 0.1).collect();
    let lr = 1e-3;
    let beta1 = 0.9;
    let beta2 = 0.999;
    let eps = 1e-8;
    let wd = 0.01;
    let step = 1u32;

    // CPU reference works on its own copies of the parameters and moments.
    let mut m_cpu = vec![0.0f32; n];
    let mut v_cpu = vec![0.0f32; n];
    let mut param_cpu = param.clone();
    crate::reference::ops::adam_step(
        &grad,
        &mut param_cpu,
        &mut m_cpu,
        &mut v_cpu,
        lr,
        beta1,
        beta2,
        eps,
        wd,
        step,
    );

    let device = &ctx.device;
    let queue = &ctx.queue;
    let grad_buf = write_storage_f32(device, queue, "g", &grad);
    // The parameter buffer also needs COPY_SRC so it can be copied out for
    // read-back; the moment buffers are created with the same usage flags.
    let zeros = vec![0.0f32; n];
    let param_buf = upload_rw(device, queue, "param", &param);
    let m_buf = upload_rw(device, queue, "m", &zeros);
    let v_buf = upload_rw(device, queue, "v", &zeros);
    let param_read = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("param.read"),
        size: (n * 4) as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("adam.enc"),
    });
    adam_step_chained(
        &ctx,
        &p,
        &mut enc,
        &grad_buf,
        &param_buf,
        &m_buf,
        &v_buf,
        n,
        AdamConfig {
            lr,
            beta1,
            beta2,
            eps,
            weight_decay: wd,
            step,
        },
    );
    enc.copy_buffer_to_buffer(&param_buf, 0, &param_read, 0, (n * 4) as u64);
    queue.submit(Some(enc.finish()));
    let gpu_param = pollster::block_on(read_back_f32(device, &param_read)).unwrap();

    let max_diff = gpu_param
        .iter()
        .zip(param_cpu.iter())
        .map(|(g, c)| (g - c).abs())
        .fold(0.0f32, f32::max);
    assert!(max_diff < 1e-6, "adam max_diff = {max_diff}");
}
/// Runs a LoRA forward pass (z, then y) and backward pass (u, dA, dB, dx) on
/// both the CPU reference and the chained GPU kernels with k = 16, r = 4,
/// n = 12, and requires every intermediate to agree within 1e-5.
#[test]
fn lora_forward_backward_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let (device, queue) = (&ctx.device, &ctx.queue);

    let k = 16usize;
    let r = 4usize;
    let n = 12usize;
    let scale = 0.5f32;
    let a: Vec<f32> = (0..r * k).map(|i| (i as f32 * 0.17).sin() * 0.4).collect();
    let b: Vec<f32> = (0..n * r).map(|i| (i as f32 * 0.29).cos() * 0.3).collect();
    let x: Vec<f32> = (0..k)
        .map(|i| (i as f32 * 0.31).sin() * 0.5 + 0.1)
        .collect();
    let dy: Vec<f32> = (0..n)
        .map(|i| (i as f32 * 0.47).cos() * 0.3 + 0.2)
        .collect();

    // ---- CPU reference, in the same order as the GPU passes below ----
    let mut z_cpu = vec![0f32; r];
    let mut y_cpu = vec![0f32; n];
    let mut u_cpu = vec![0f32; r];
    let mut da_cpu = vec![0f32; r * k];
    let mut db_cpu = vec![0f32; n * r];
    let mut dx_cpu = vec![0f32; k];
    crate::reference::ops::lora_matmul_row(&a, &x, &mut z_cpu, k, r, 1.0, false);
    crate::reference::ops::lora_matmul_row(&b, &z_cpu, &mut y_cpu, r, n, scale, false);
    crate::reference::ops::lora_matmul_col(&b, &dy, &mut u_cpu, n, r, 1.0, false);
    crate::reference::ops::lora_outer_add(&u_cpu, &x, &mut da_cpu, scale, false);
    crate::reference::ops::lora_outer_add(&dy, &z_cpu, &mut db_cpu, scale, false);
    crate::reference::ops::lora_matmul_col(&a, &u_cpu, &mut dx_cpu, r, k, scale, false);

    // ---- GPU: uploads, one output/read pair per result ----
    let z_bytes = (r * 4) as u64;
    let y_bytes = (n * 4) as u64;
    let u_bytes = (r * 4) as u64;
    let da_bytes = (r * k * 4) as u64;
    let db_bytes = (n * r * 4) as u64;
    let dx_bytes = (k * 4) as u64;
    let a_buf = write_storage_f32(device, queue, "A", &a);
    let b_buf = write_storage_f32(device, queue, "B", &b);
    let x_buf = write_storage_f32(device, queue, "x", &x);
    let dy_buf = write_storage_f32(device, queue, "dy", &dy);
    let (z_buf, z_read) = make_output_pair(device, "z", z_bytes);
    let (y_buf, y_read) = make_output_pair(device, "y", y_bytes);
    let (u_buf, u_read) = make_output_pair(device, "u", u_bytes);
    let (da_buf, da_read) = make_output_pair(device, "dA", da_bytes);
    let (db_buf, db_read) = make_output_pair(device, "dB", db_bytes);
    let (dx_buf, dx_read) = make_output_pair(device, "dx", dx_bytes);

    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("lora_fb.enc"),
    });
    lora_matmul_row_chained(&ctx, &p, &mut enc, &a_buf, &x_buf, &z_buf, k, r, 1.0, false);
    lora_matmul_row_chained(
        &ctx, &p, &mut enc, &b_buf, &z_buf, &y_buf, r, n, scale, false,
    );
    lora_matmul_col_chained(
        &ctx, &p, &mut enc, &b_buf, &dy_buf, &u_buf, n, r, 1.0, false,
    );
    lora_outer_add_chained(
        &ctx, &p, &mut enc, &u_buf, &x_buf, &da_buf, r, k, scale, false,
    );
    lora_outer_add_chained(
        &ctx, &p, &mut enc, &dy_buf, &z_buf, &db_buf, n, r, scale, false,
    );
    lora_matmul_col_chained(
        &ctx, &p, &mut enc, &a_buf, &u_buf, &dx_buf, r, k, scale, false,
    );

    // Stage every result into its mappable twin before submitting.
    let staged = [
        (&z_buf, z_bytes, &z_read),
        (&y_buf, y_bytes, &y_read),
        (&u_buf, u_bytes, &u_read),
        (&da_buf, da_bytes, &da_read),
        (&db_buf, db_bytes, &db_read),
        (&dx_buf, dx_bytes, &dx_read),
    ];
    for (src, sz, dst) in staged {
        enc.copy_buffer_to_buffer(src, 0, dst, 0, sz);
    }
    queue.submit(Some(enc.finish()));

    let z_gpu = pollster::block_on(read_back_f32(device, &z_read)).unwrap();
    let y_gpu = pollster::block_on(read_back_f32(device, &y_read)).unwrap();
    let u_gpu = pollster::block_on(read_back_f32(device, &u_read)).unwrap();
    let da_gpu = pollster::block_on(read_back_f32(device, &da_read)).unwrap();
    let db_gpu = pollster::block_on(read_back_f32(device, &db_read)).unwrap();
    let dx_gpu = pollster::block_on(read_back_f32(device, &dx_read)).unwrap();

    // Largest element-wise absolute difference between two slices.
    let max_abs_diff = |lhs: &[f32], rhs: &[f32]| {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(l, r)| (l - r).abs())
            .fold(0.0f32, f32::max)
    };
    assert!(max_abs_diff(&z_cpu, &z_gpu) < 1e-5);
    assert!(max_abs_diff(&y_cpu, &y_gpu) < 1e-5);
    assert!(max_abs_diff(&u_cpu, &u_gpu) < 1e-5);
    assert!(max_abs_diff(&da_cpu, &da_gpu) < 1e-5);
    assert!(max_abs_diff(&db_cpu, &db_gpu) < 1e-5);
    assert!(max_abs_diff(&dx_cpu, &dx_gpu) < 1e-5);
}
/// Checks the chained attention backward kernels (`dq`, then `dk`/`dv`)
/// against the CPU reference for 2 heads sharing 1 KV head, head_dim 8 and a
/// history of 5; all three gradients must agree within 1e-5.
#[test]
fn attention_backward_gpu_vs_cpu() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let (device, queue) = (&ctx.device, &ctx.queue);

    let n_heads = 2usize;
    let n_kv_heads = 1usize;
    let head_dim = 8usize;
    let history_len = 5usize;
    let q_len = n_heads * head_dim;
    let kv_len = history_len * n_kv_heads * head_dim;
    let q_bytes = (q_len * 4) as u64;
    let kv_bytes = (kv_len * 4) as u64;

    let q: Vec<f32> = (0..q_len).map(|i| (i as f32 * 0.31).sin() * 0.4).collect();
    let k_hist: Vec<f32> = (0..kv_len).map(|i| (i as f32 * 0.17).cos() * 0.3).collect();
    let v_hist: Vec<f32> = (0..kv_len).map(|i| (i as f32 * 0.23).sin() * 0.5).collect();
    let d_out: Vec<f32> = (0..q_len)
        .map(|i| (i as f32 * 0.47).cos() * 0.3 + 0.1)
        .collect();

    // The CPU forward pass is run only for `probs`; its output vector is discarded.
    let mut out_unused = vec![0f32; q_len];
    let mut probs = vec![0f32; n_heads * history_len];
    crate::reference::ops::attention_forward(
        &q,
        &k_hist,
        &v_hist,
        &mut out_unused,
        &mut probs,
        head_dim,
        n_heads,
        n_kv_heads,
        history_len,
    );

    let mut cpu_dq = vec![0f32; q_len];
    let mut cpu_dk = vec![0f32; kv_len];
    let mut cpu_dv = vec![0f32; kv_len];
    crate::reference::ops::attention_backward(
        &q,
        &k_hist,
        &v_hist,
        &probs,
        &d_out,
        &mut cpu_dq,
        &mut cpu_dk,
        &mut cpu_dv,
        head_dim,
        n_heads,
        n_kv_heads,
        history_len,
    );

    // GPU inputs; the CPU-computed `probs` are uploaded and shared by both kernels.
    let q_buf = write_storage_f32(device, queue, "q", &q);
    let k_buf = write_storage_f32(device, queue, "k_hist", &k_hist);
    let v_buf = write_storage_f32(device, queue, "v_hist", &v_hist);
    let probs_buf = write_storage_f32(device, queue, "probs", &probs);
    let dout_buf = write_storage_f32(device, queue, "d_out", &d_out);
    // `d_scores` is passed to both kernels and never read back here, so its
    // read-side twin is dropped.
    let (ds_buf, _) = make_output_pair(device, "d_scores", (n_heads * history_len * 4) as u64);
    let (dq_buf, dq_read) = make_output_pair(device, "d_q", q_bytes);
    let (dk_buf, dk_read) = make_output_pair(device, "d_k_hist", kv_bytes);
    let (dv_buf, dv_read) = make_output_pair(device, "d_v_hist", kv_bytes);

    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("attn_bwd.enc"),
    });
    attention_backward_dq_chained(
        &ctx,
        &p,
        &mut enc,
        &k_buf,
        &v_buf,
        &probs_buf,
        &dout_buf,
        &ds_buf,
        &dq_buf,
        head_dim,
        n_heads,
        n_kv_heads,
        history_len,
    );
    attention_backward_dkv_chained(
        &ctx,
        &p,
        &mut enc,
        &q_buf,
        &probs_buf,
        &dout_buf,
        &ds_buf,
        &dk_buf,
        &dv_buf,
        head_dim,
        n_heads,
        n_kv_heads,
        history_len,
    );
    enc.copy_buffer_to_buffer(&dq_buf, 0, &dq_read, 0, q_bytes);
    enc.copy_buffer_to_buffer(&dk_buf, 0, &dk_read, 0, kv_bytes);
    enc.copy_buffer_to_buffer(&dv_buf, 0, &dv_read, 0, kv_bytes);
    queue.submit(Some(enc.finish()));

    let gpu_dq = pollster::block_on(read_back_f32(device, &dq_read)).expect("dq");
    let gpu_dk = pollster::block_on(read_back_f32(device, &dk_read)).expect("dk");
    let gpu_dv = pollster::block_on(read_back_f32(device, &dv_read)).expect("dv");

    // Largest element-wise absolute difference between two slices.
    let max_abs_diff = |lhs: &[f32], rhs: &[f32]| -> f32 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(l, r)| (l - r).abs())
            .fold(0.0f32, f32::max)
    };
    let dq_diff = max_abs_diff(&cpu_dq, &gpu_dq);
    let dk_diff = max_abs_diff(&cpu_dk, &gpu_dk);
    let dv_diff = max_abs_diff(&cpu_dv, &gpu_dv);
    assert!(
        dq_diff < 1e-5 && dk_diff < 1e-5 && dv_diff < 1e-5,
        "attn_bwd diffs dq={dq_diff} dk={dk_diff} dv={dv_diff}"
    );
}
/// Applies `rope_neox_chained` and then `rope_neox_backward_chained` to the
/// same buffer and requires the result to match the original values within
/// 1e-4. No factors buffer is supplied (`None`, plus a one-element dummy).
#[test]
fn rope_neox_forward_then_backward_gpu_is_identity() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let (device, queue) = (&ctx.device, &ctx.queue);

    let head_dim = 16usize;
    let n_heads = 4usize;
    let rope_dims = 16usize;
    let pos = 11usize;
    let base = 10_000.0f32;
    let total = head_dim * n_heads;
    let total_bytes = (total * 4) as u64;
    let orig: Vec<f32> = (0..total).map(|i| (i as f32) * 0.07 - 1.5).collect();

    // Both kernels are handed this same buffer, so it also needs COPY_SRC for
    // the final read-back copy.
    let x_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("rope.x"),
        size: total_bytes,
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_DST
            | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    queue.write_buffer(&x_buf, 0, bytemuck::cast_slice(&orig));
    let dummy = write_storage_f32(device, queue, "dummy", &[0.0]);
    let read_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("rope.read"),
        size: total_bytes,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let mut enc = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("rope.enc"),
    });
    rope_neox_chained(
        &ctx, &p, &mut enc, &x_buf, None, &dummy, head_dim, n_heads, pos, rope_dims, base,
    );
    rope_neox_backward_chained(
        &ctx, &p, &mut enc, &x_buf, None, &dummy, head_dim, n_heads, pos, rope_dims, base,
    );
    enc.copy_buffer_to_buffer(&x_buf, 0, &read_buf, 0, total_bytes);
    queue.submit(Some(enc.finish()));
    let out = pollster::block_on(read_back_f32(device, &read_buf)).expect("readback");

    let max_drift = orig
        .iter()
        .zip(out.iter())
        .map(|(before, after)| (before - after).abs())
        .fold(0.0f32, f32::max);
    assert!(max_drift < 1e-4, "rope fwd+bwd drift = {max_drift}");
}
/// With `u32::MAX` as the target, `cross_entropy_backward_cached` must return
/// a loss of exactly 0.0 and an all-zero gradient.
#[test]
fn cross_entropy_backward_gpu_masked_target_is_zero() {
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let p = Pipelines::new(&ctx.device);
    let logits: Vec<f32> = (0..512).map(|i| (i as f32) * 0.01).collect();

    let (gpu_grad, gpu_loss) =
        pollster::block_on(cross_entropy_backward_cached(&ctx, &p, &logits, u32::MAX))
            .expect("gpu");

    assert_eq!(gpu_loss, 0.0);
    // Exact comparison is intentional: every entry has to be precisely zero.
    for &g in gpu_grad.iter() {
        assert_eq!(g, 0.0);
    }
}
}