use super::{BatchedMatmulParams, MatmulParams, cached_dispatch, write_uniform};
use crate::backend::WgpuCtx;
use crate::backend::pipelines::Pipelines;
use crate::error::{Result, RullamaError};
use crate::gguf::GgmlDtype;
/// Records one matmul dispatch into `enc` using the given `pipeline`.
///
/// `k` and `n` are narrowed to `u32` and packed into a [`MatmulParams`] with
/// both padding fields zeroed. The buffers are handed to [`cached_dispatch`]
/// in the order `[w, x, y]`, and the grid is `ceil(n / 64) x 1 x 1`
/// workgroups. `label` is forwarded to [`cached_dispatch`] unchanged.
///
/// NOTE(review): presumably `w` is the weight matrix, `x` the input and `y`
/// the output, with 64 outputs per workgroup — confirm against the shaders.
fn matmul_chained_inner(
    ctx: &WgpuCtx,
    enc: &mut wgpu::CommandEncoder,
    pipeline: &wgpu::ComputePipeline,
    label: &str,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = MatmulParams {
        k: k as u32,
        n: n as u32,
        _p0: 0,
        _p1: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        pipeline,
        label,
        &[w, x, y],
        &params,
        ((n as u32).div_ceil(64), 1, 1),
    );
}
/// Records a matmul using the `q4_k_matmul` pipeline into `enc`.
///
/// Thin wrapper over `matmul_chained_inner` that supplies the pipeline and
/// the `"q4k_chain"` label; all other arguments are forwarded as-is.
pub fn matmul_q4_k_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let pipeline = &p.q4_k_matmul;
    matmul_chained_inner(ctx, enc, pipeline, "q4k_chain", w, x, y, k, n);
}
/// Records a matmul using the `q6_k_matmul` pipeline into `enc`.
///
/// Thin wrapper over `matmul_chained_inner` that supplies the pipeline and
/// the `"q6k_chain"` label; all other arguments are forwarded as-is.
pub fn matmul_q6_k_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let pipeline = &p.q6_k_matmul;
    matmul_chained_inner(ctx, enc, pipeline, "q6k_chain", w, x, y, k, n);
}
/// Records a matmul using the `q4_0_matmul` pipeline into `enc`.
///
/// Thin wrapper over `matmul_chained_inner` that supplies the pipeline and
/// the `"q4_0_chain"` label; all other arguments are forwarded as-is.
pub fn matmul_q4_0_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let pipeline = &p.q4_0_matmul;
    matmul_chained_inner(ctx, enc, pipeline, "q4_0_chain", w, x, y, k, n);
}
/// Records a matmul using the `q5_0_matmul` pipeline into `enc`.
///
/// Thin wrapper over `matmul_chained_inner` that supplies the pipeline and
/// the `"q5_0_chain"` label; all other arguments are forwarded as-is.
pub fn matmul_q5_0_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let pipeline = &p.q5_0_matmul;
    matmul_chained_inner(ctx, enc, pipeline, "q5_0_chain", w, x, y, k, n);
}
/// Records a matmul using the `q8_0_matmul` pipeline into `enc`.
///
/// Thin wrapper over `matmul_chained_inner` that supplies the pipeline and
/// the `"q8_0_chain"` label; all other arguments are forwarded as-is.
pub fn matmul_q8_0_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let pipeline = &p.q8_0_matmul;
    matmul_chained_inner(ctx, enc, pipeline, "q8_0_chain", w, x, y, k, n);
}
pub fn matmul_quant_chained(
ctx: &WgpuCtx,
p: &Pipelines,
enc: &mut wgpu::CommandEncoder,
w: &wgpu::Buffer,
x: &wgpu::Buffer,
y: &wgpu::Buffer,
k: usize,
n: usize,
dtype: GgmlDtype,
) -> Result<()> {
match dtype {
GgmlDtype::Q4_K => matmul_q4_k_chained(ctx, p, enc, w, x, y, k, n),
GgmlDtype::Q6_K => matmul_q6_k_chained(ctx, p, enc, w, x, y, k, n),
GgmlDtype::Q4_0 => matmul_q4_0_chained(ctx, p, enc, w, x, y, k, n),
GgmlDtype::Q5_0 => matmul_q5_0_chained(ctx, p, enc, w, x, y, k, n),
GgmlDtype::Q8_0 => matmul_q8_0_chained(ctx, p, enc, w, x, y, k, n),
GgmlDtype::F16 => matmul_f16_chained(ctx, p, enc, w, x, y, k, n),
other => {
return Err(RullamaError::Inference(format!(
"weight matmul: unsupported quant dtype {other:?} (expected F16, Q4_0, Q5_0, Q8_0, Q4_K, or Q6_K)"
)));
}
}
Ok(())
}
/// Records a matmul using the `f16_matmul` pipeline into `enc`.
///
/// Thin wrapper over `matmul_chained_inner` that supplies the pipeline and
/// the `"f16_chain"` label; all other arguments are forwarded as-is.
pub fn matmul_f16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let pipeline = &p.f16_matmul;
    matmul_chained_inner(ctx, enc, pipeline, "f16_chain", w, x, y, k, n);
}
/// Records a matmul using the `bf16_matmul` pipeline into `enc`.
///
/// Thin wrapper over `matmul_chained_inner` that supplies the pipeline and
/// the `"bf16_chain"` label; all other arguments are forwarded as-is.
// NOTE(review): `matmul_quant_chained` has no BF16 arm, so this may have no
// callers — presumably the reason for the `dead_code` allowance; confirm.
#[allow(dead_code)]
pub fn matmul_bf16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    let pipeline = &p.bf16_matmul;
    matmul_chained_inner(ctx, enc, pipeline, "bf16_chain", w, x, y, k, n);
}
/// Returns `true` when the shape meets the thresholds for the first tiled
/// batched kernel: `k >= 16`, `n >= 8` and `batch >= 8`.
fn use_tiled_batched(k: usize, n: usize, batch: usize) -> bool {
    const MIN_K: usize = 16;
    const MIN_N_AND_BATCH: usize = 8;
    // `n` and `batch` share a threshold, so checking the smaller one suffices.
    k >= MIN_K && n.min(batch) >= MIN_N_AND_BATCH
}
/// Returns `true` when the shape meets the thresholds for the v2 tiled
/// batched kernel: `k`, `n` and `batch` must each be at least 16.
fn use_tiled_batched_v2(k: usize, n: usize, batch: usize) -> bool {
    const MIN_DIM: usize = 16;
    // All three share one threshold, so only the smallest needs checking.
    k.min(n).min(batch) >= MIN_DIM
}
/// Returns `true` when the shape meets the thresholds for the v3 tiled
/// batched kernel: `k >= 16`, `n >= 32` and `batch >= 32`.
fn use_tiled_batched_v3(k: usize, n: usize, batch: usize) -> bool {
    // Reject as soon as any dimension falls short of its minimum.
    let too_small = k < 16 || n < 32 || batch < 32;
    !too_small
}
/// Records a batched BF16 matmul into `enc`, choosing a kernel by shape.
///
/// Selection order (first match wins):
/// 1. `use_tiled_batched_v3` — the optional
///    `bf16_matmul_batched_tiled_v3_f16lds` pipeline if present, otherwise
///    the plain v3 tiled kernel.
/// 2. `use_tiled_batched_v2` — the v2 tiled kernel.
/// 3. `use_tiled_batched` — the first tiled kernel.
/// 4. Otherwise the untiled `bf16_matmul_batched` pipeline, dispatched as
///    `ceil(n / 64) x batch x 1` workgroups.
///
/// In the fallback path bind group 0 is: 0 = params uniform, 1 = `w`,
/// 2 = `x`, 3 = `y`. This function only records the compute pass into
/// `enc`; it does not submit it.
pub fn matmul_bf16_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    if use_tiled_batched_v3(k, n, batch) {
        match p.bf16_matmul_batched_tiled_v3_f16lds.as_ref() {
            Some(pipe_f) => matmul_bf16_batched_tiled_v3_f16lds_chained(
                ctx, p, pipe_f, enc, w, x, y, k, n, batch,
            ),
            None => matmul_bf16_batched_tiled_v3_chained(ctx, p, enc, w, x, y, k, n, batch),
        }
        return;
    }
    if use_tiled_batched_v2(k, n, batch) {
        matmul_bf16_batched_tiled_v2_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }
    if use_tiled_batched(k, n, batch) {
        matmul_bf16_batched_tiled_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }

    // Fallback: untiled kernel for shapes below every tiling threshold.
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.bf16_matmul_batched;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "bf16bmm.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmm.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmm.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
}
/// Records a batched F16 matmul into `enc`, choosing a kernel by shape.
///
/// Selection order (first match wins):
/// 1. `use_tiled_batched_v3` — the optional
///    `f16_matmul_batched_tiled_v3_f16lds` pipeline if present, otherwise
///    the plain v3 tiled kernel.
/// 2. `use_tiled_batched_v2` — the v2 tiled kernel.
/// 3. `use_tiled_batched` — the first tiled kernel.
/// 4. Otherwise the untiled `f16_matmul_batched` pipeline, dispatched as
///    `ceil(n / 64) x batch x 1` workgroups.
///
/// In the fallback path bind group 0 is: 0 = params uniform, 1 = `w`,
/// 2 = `x`, 3 = `y`. This function only records the compute pass into
/// `enc`; it does not submit it.
pub fn matmul_f16_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    if use_tiled_batched_v3(k, n, batch) {
        match p.f16_matmul_batched_tiled_v3_f16lds.as_ref() {
            Some(pipe_f) => matmul_f16_batched_tiled_v3_f16lds_chained(
                ctx, p, pipe_f, enc, w, x, y, k, n, batch,
            ),
            None => matmul_f16_batched_tiled_v3_chained(ctx, p, enc, w, x, y, k, n, batch),
        }
        return;
    }
    if use_tiled_batched_v2(k, n, batch) {
        matmul_f16_batched_tiled_v2_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }
    if use_tiled_batched(k, n, batch) {
        matmul_f16_batched_tiled_chained(ctx, p, enc, w, x, y, k, n, batch);
        return;
    }

    // Fallback: untiled kernel for shapes below every tiling threshold.
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.f16_matmul_batched;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "f16bmm.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmm.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmm.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
}
/// Records the `f16_matmul_batched_tiled` kernel into `enc`.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 8) x ceil(batch / 8) x 1` workgroups. Only records the compute
/// pass; it does not submit it.
///
/// NOTE(review): presumably each workgroup covers an 8x8 output tile —
/// confirm against the shader.
pub fn matmul_f16_batched_tiled_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.f16_matmul_batched_tiled;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "f16bmmt.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(8), (batch as u32).div_ceil(8), 1);
}
/// Records the `bf16_matmul_batched_tiled` kernel into `enc`.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 8) x ceil(batch / 8) x 1` workgroups. Only records the compute
/// pass; it does not submit it.
///
/// NOTE(review): presumably each workgroup covers an 8x8 output tile —
/// confirm against the shader.
pub fn matmul_bf16_batched_tiled_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.bf16_matmul_batched_tiled;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "bf16bmmt.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(8), (batch as u32).div_ceil(8), 1);
}
/// Records the `f16_matmul_batched_tiled_v2` kernel into `enc`.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 16) x ceil(batch / 16) x 1` workgroups. Only records the
/// compute pass; it does not submit it.
///
/// NOTE(review): presumably each workgroup covers a 16x16 output tile —
/// confirm against the shader.
pub fn matmul_f16_batched_tiled_v2_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.f16_matmul_batched_tiled_v2;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "f16bmmt2.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt2.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt2.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(16), (batch as u32).div_ceil(16), 1);
}
/// Records the `f16_matmul_batched_tiled_v3` kernel into `enc`.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 32) x ceil(batch / 32) x 1` workgroups. Only records the
/// compute pass; it does not submit it.
///
/// NOTE(review): presumably each workgroup covers a 32x32 output tile —
/// confirm against the shader.
pub fn matmul_f16_batched_tiled_v3_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.f16_matmul_batched_tiled_v3;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "f16bmmt3.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt3.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt3.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(32), (batch as u32).div_ceil(32), 1);
}
/// Records the caller-supplied `pipe` (the BF16 v3 `f16lds` variant) into
/// `enc`.
///
/// Unlike the other tiled recorders, the pipeline comes in as an argument
/// rather than being read from `p`; `p` is accepted but unused.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 32) x ceil(batch / 32) x 1` workgroups. Only records the
/// compute pass; it does not submit it.
///
/// NOTE(review): the `f16lds` suffix presumably refers to f16 tiles in
/// workgroup memory — confirm against the shader.
pub fn matmul_bf16_batched_tiled_v3_f16lds_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // `p` is intentionally unused; this silences the unused-variable lint
    // without changing the parameter name.
    let _ = p;
    let device = &ctx.device;
    let queue = &ctx.queue;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "bf16bmmt3f.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt3f.bg"),
        layout: &pipe.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt3f.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipe);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(32), (batch as u32).div_ceil(32), 1);
}
/// Records the caller-supplied `pipe` (the F16 v3 `f16lds` variant) into
/// `enc`.
///
/// Unlike the other tiled recorders, the pipeline comes in as an argument
/// rather than being read from `p`; `p` is accepted but unused.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 32) x ceil(batch / 32) x 1` workgroups. Only records the
/// compute pass; it does not submit it.
///
/// NOTE(review): the `f16lds` suffix presumably refers to f16 tiles in
/// workgroup memory — confirm against the shader.
pub fn matmul_f16_batched_tiled_v3_f16lds_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    // `p` is intentionally unused; this silences the unused-variable lint
    // without changing the parameter name.
    let _ = p;
    let device = &ctx.device;
    let queue = &ctx.queue;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "f16bmmt3f.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt3f.bg"),
        layout: &pipe.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt3f.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipe);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(32), (batch as u32).div_ceil(32), 1);
}
/// Records the `f16_matmul_batched_tiled_v4` kernel into `enc`.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 32) x ceil(batch / 64) x 1` workgroups — note the batch divisor
/// is 64, not 32. Only records the compute pass; it does not submit it.
///
/// NOTE(review): presumably each workgroup covers a 32 (n) x 64 (batch)
/// output tile — confirm against the shader.
pub fn matmul_f16_batched_tiled_v4_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.f16_matmul_batched_tiled_v4;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "f16bmmt4.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("f16bmmt4.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("f16bmmt4.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(32), (batch as u32).div_ceil(64), 1);
}
/// Records the `bf16_matmul_batched_tiled_v3` kernel into `enc`.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 32) x ceil(batch / 32) x 1` workgroups. Only records the
/// compute pass; it does not submit it.
///
/// NOTE(review): presumably each workgroup covers a 32x32 output tile —
/// confirm against the shader.
pub fn matmul_bf16_batched_tiled_v3_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.bf16_matmul_batched_tiled_v3;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "bf16bmmt3.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt3.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt3.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(32), (batch as u32).div_ceil(32), 1);
}
/// Records the `bf16_matmul_batched_tiled_v2` kernel into `enc`.
///
/// Bind group 0 is: 0 = params uniform (`k`, `n`, `batch` as `u32`),
/// 1 = `w`, 2 = `x`, 3 = `y`. The grid is
/// `ceil(n / 16) x ceil(batch / 16) x 1` workgroups. Only records the
/// compute pass; it does not submit it.
///
/// NOTE(review): presumably each workgroup covers a 16x16 output tile —
/// confirm against the shader.
pub fn matmul_bf16_batched_tiled_v2_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    batch: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = &p.bf16_matmul_batched_tiled_v2;
    // NOTE: `as u32` truncates silently if a dimension exceeds `u32::MAX`.
    let params = BatchedMatmulParams {
        k: k as u32,
        n: n as u32,
        batch: batch as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "bf16bmmt2.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bf16bmmt2.bg"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("bf16bmmt2.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups((n as u32).div_ceil(16), (batch as u32).div_ceil(16), 1);
}