use bytemuck::{Pod, Zeroable};
use super::{
GegluParams, MatmulBackInputParams, RmsParams, RopeParams, XEntParams, cached_dispatch,
};
use crate::backend::WgpuCtx;
use crate::backend::pipelines::Pipelines;
use crate::error::{Result, RullamaError};
use crate::gguf::GgmlDtype;
/// Parameter block built by `adam_step_chained` for each `adam_step` dispatch.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets the value be viewed as raw
/// bytes. The struct consists of twelve 4-byte fields (48 bytes). The `_pad*`
/// fields carry no data and are always written as zero; presumably they exist
/// to match the shader-side struct layout (TODO confirm against the WGSL).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct AdamParams {
// Element count passed by the caller, narrowed to `u32`.
n: u32,
// Step counter copied from `AdamConfig::step`.
step: u32,
// Workgroups already dispatched times 64; presumably the first element index
// handled by this dispatch chunk (confirm against the shader).
offset: u32,
_pad1: u32,
// The five values below are copied unchanged from `AdamConfig`.
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
_pad2: f32,
_pad3: f32,
_pad4: f32,
}
/// Caller-facing Adam settings consumed by `adam_step_chained`.
///
/// Every field is copied verbatim into the parameter block of each dispatch;
/// how the values are used is defined by the `adam_step` shader, which is not
/// part of this file.
#[derive(Clone, Copy, Debug)]
pub struct AdamConfig {
/// Learning rate.
pub lr: f32,
/// Decay factor conventionally applied to the first-moment estimate.
pub beta1: f32,
/// Decay factor conventionally applied to the second-moment estimate.
pub beta2: f32,
/// Small stabilising constant (conventionally added to the denominator).
pub eps: f32,
/// Weight-decay coefficient; the default is `0.0`.
pub weight_decay: f32,
/// Optimizer step counter forwarded to the shader. The default is `1`
/// (presumably 1-based for bias correction — TODO confirm).
pub step: u32,
}
impl Default for AdamConfig {
fn default() -> Self {
Self {
lr: 1e-3,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay: 0.0,
step: 1,
}
}
}
/// Records the `adam_step` kernel over `n` elements into `enc`.
///
/// One workgroup is dispatched per 64 elements. Because a single dispatch is
/// limited to 65535 workgroups along X, the work is split into chunks; each
/// chunk receives its starting position through `AdamParams::offset`
/// (workgroups already dispatched × 64).
///
/// * `grad`, `param`, `m`, `v` – buffers handed to `cached_dispatch` in that
///   order (presumably gradient, parameters, first and second moment).
/// * `n` – element count; must fit in `u32`.
/// * `cfg` – hyper-parameters copied verbatim into every chunk's parameters.
///
/// # Panics
/// Panics if `n` exceeds `u32::MAX`. Previously the value was truncated with
/// `as`, which would have silently updated only part of the tensor.
///
/// NOTE(review): every chunk goes through `cached_dispatch` with the same
/// pipeline and buffers. If that helper keeps one parameter buffer per
/// (pipeline, buffers) key and fills it with `queue.write_buffer` — as the
/// hand-written LoRA dispatch functions in this file do — all chunks recorded
/// into `enc` before submission would observe the last `offset`. Confirm
/// against `cached_dispatch` before relying on the multi-chunk path.
#[allow(clippy::too_many_arguments)]
pub fn adam_step_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    grad: &wgpu::Buffer,
    param: &wgpu::Buffer,
    m: &wgpu::Buffer,
    v: &wgpu::Buffer,
    n: usize,
    cfg: AdamConfig,
) {
    const ELEMS_PER_GROUP: u32 = 64;
    const MAX_GROUPS_PER_DISPATCH: u32 = 65535;

    assert!(
        n <= u32::MAX as usize,
        "adam_step_chained: element count {n} does not fit in u32"
    );
    // Lossless after the check above.
    let n = n as u32;

    let total_groups = n.div_ceil(ELEMS_PER_GROUP);
    let mut groups_done: u32 = 0;
    while groups_done < total_groups {
        let groups_this = (total_groups - groups_done).min(MAX_GROUPS_PER_DISPATCH);
        let params = AdamParams {
            n,
            step: cfg.step,
            // Cannot overflow: groups_done < total_groups, so the product is < n.
            offset: groups_done * ELEMS_PER_GROUP,
            _pad1: 0,
            lr: cfg.lr,
            beta1: cfg.beta1,
            beta2: cfg.beta2,
            eps: cfg.eps,
            weight_decay: cfg.weight_decay,
            _pad2: 0.0,
            _pad3: 0.0,
            _pad4: 0.0,
        };
        cached_dispatch(
            ctx,
            enc,
            &p.adam_step,
            "adam",
            &[grad, param, m, v],
            &params,
            (groups_this, 1, 1),
        );
        groups_done += groups_this;
    }
}
/// Parameter block for the `sum_of_squares` pipeline (four 4-byte fields,
/// 16 bytes). `_p0`/`_p1` are always zero; presumably layout padding.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct SumOfSquaresParams {
// Element count passed by the caller, narrowed to `u32`.
n: u32,
// Scale value forwarded unchanged from the caller.
scale_in: f32,
_p0: u32,
_p1: u32,
}
/// Records a single-workgroup `sum_of_squares` dispatch into `enc`.
///
/// `input` and `output` are handed to `cached_dispatch` in that order. `n` is
/// narrowed to `u32` and `scale_in` is forwarded unchanged; how the shader
/// applies the scale is not visible from this file.
pub fn sum_of_squares_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    input: &wgpu::Buffer,
    output: &wgpu::Buffer,
    n: usize,
    scale_in: f32,
) {
    let element_count = n as u32;
    let buffers = [input, output];
    cached_dispatch(
        ctx,
        enc,
        &p.sum_of_squares,
        "sos",
        &buffers,
        &SumOfSquaresParams {
            n: element_count,
            scale_in,
            _p0: 0,
            _p1: 0,
        },
        (1, 1, 1),
    );
}
/// Parameter block shared by the `lora_matmul_row` and `lora_matmul_col`
/// pipelines (eight 4-byte fields, 32 bytes). The `_pad*` fields are always
/// zero; presumably layout padding.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct LoraMatmulParams {
// Row variant: the caller's `k`. Column variant: the caller's `outer`.
k: u32,
// Row variant: the caller's `n`. Column variant: the caller's `inner`.
// In both variants this value also sizes the dispatch (one workgroup per 64).
n: u32,
// 0 or 1, converted from the caller's `accumulate` flag.
accumulate: u32,
_pad: u32,
// Scale factor forwarded unchanged.
scale: f32,
_pad2: u32,
_pad3: u32,
_pad4: u32,
}
/// Parameter block for the `lora_outer_add` pipeline (eight 4-byte fields,
/// 32 bytes). The `_pad*` fields are always zero; presumably layout padding.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct LoraOuterParams {
// Extent that sizes the X dimension of the dispatch (one workgroup per 8).
outer_a: u32,
// Extent that sizes the Y dimension of the dispatch (one workgroup per 8).
outer_b: u32,
// 0 or 1, converted from the caller's `accumulate` flag.
accumulate: u32,
_pad: u32,
// Scale factor forwarded unchanged.
scale: f32,
_pad2: u32,
_pad3: u32,
_pad4: u32,
}
/// Records a `lora_matmul_row` compute pass into `enc`.
///
/// The uniform buffer and bind group (params at binding 0, then `w`, `x`, `y`
/// at bindings 1–3) come from `ctx.bind_cache`, keyed on the pipeline and the
/// three buffers, and are created on a cache miss. The parameters are then
/// uploaded with `queue.write_buffer` and `ceil(n / 64)` workgroups are
/// dispatched along X.
///
/// * `k`, `n` – sizes forwarded to the shader as `u32`.
/// * `scale` – forwarded unchanged.
/// * `accumulate` – forwarded as 0/1 (presumably add-into-`y` vs. overwrite).
#[allow(clippy::too_many_arguments)]
pub fn lora_matmul_row_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    scale: f32,
    accumulate: bool,
) {
    // Binds the whole of `buffer` at slot `binding`.
    fn bind(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let cache_key = crate::backend::CacheKey::three(&p.lora_matmul_row, w, x, y);
    let slot = ctx.bind_cache.get_or_create(cache_key, || {
        let uniform = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_mm_row.params"),
            size: std::mem::size_of::<LoraMatmulParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let layout = p.lora_matmul_row.get_bind_group_layout(0);
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_mm_row.bg"),
            layout: &layout,
            entries: &[bind(0, &uniform), bind(1, w), bind(2, x), bind(3, y)],
        });
        crate::backend::CachedDispatch {
            uniform,
            bind_group,
        }
    });

    let params = LoraMatmulParams {
        k: k as u32,
        n: n as u32,
        accumulate: accumulate as u32,
        _pad: 0,
        scale,
        _pad2: 0,
        _pad3: 0,
        _pad4: 0,
    };
    ctx.queue
        .write_buffer(&slot.uniform, 0, bytemuck::bytes_of(&params));

    let workgroups = (n as u32).div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_mm_row.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&p.lora_matmul_row);
    pass.set_bind_group(0, &slot.bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Records a `lora_matmul_col` compute pass into `enc`.
///
/// Uses the same `LoraMatmulParams` block as the row variant, with `outer`
/// stored in the `k` field and `inner` in the `n` field. The uniform buffer
/// and bind group (params at binding 0, then `w`, `x`, `y`) are fetched from
/// `ctx.bind_cache` or created on a miss. `ceil(inner / 64)` workgroups are
/// dispatched along X.
///
/// * `scale` – forwarded unchanged.
/// * `accumulate` – forwarded as 0/1 (presumably add-into-`y` vs. overwrite).
#[allow(clippy::too_many_arguments)]
pub fn lora_matmul_col_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    outer: usize,
    inner: usize,
    scale: f32,
    accumulate: bool,
) {
    // Binds the whole of `buffer` at slot `binding`.
    fn bind(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let cache_key = crate::backend::CacheKey::three(&p.lora_matmul_col, w, x, y);
    let slot = ctx.bind_cache.get_or_create(cache_key, || {
        let uniform = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_mm_col.params"),
            size: std::mem::size_of::<LoraMatmulParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let layout = p.lora_matmul_col.get_bind_group_layout(0);
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_mm_col.bg"),
            layout: &layout,
            entries: &[bind(0, &uniform), bind(1, w), bind(2, x), bind(3, y)],
        });
        crate::backend::CachedDispatch {
            uniform,
            bind_group,
        }
    });

    let params = LoraMatmulParams {
        k: outer as u32,
        n: inner as u32,
        accumulate: accumulate as u32,
        _pad: 0,
        scale,
        _pad2: 0,
        _pad3: 0,
        _pad4: 0,
    };
    ctx.queue
        .write_buffer(&slot.uniform, 0, bytemuck::bytes_of(&params));

    let workgroups = (inner as u32).div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_mm_col.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&p.lora_matmul_col);
    pass.set_bind_group(0, &slot.bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Records a `lora_outer_add` compute pass into `enc`.
///
/// The uniform buffer and bind group (params at binding 0, then `a`, `b`,
/// `out` at bindings 1–3) are fetched from `ctx.bind_cache` or created on a
/// miss. The dispatch is two-dimensional: `ceil(outer_a / 8)` workgroups
/// along X and `ceil(outer_b / 8)` along Y.
///
/// * `scale` – forwarded unchanged.
/// * `accumulate` – forwarded as 0/1 (presumably add-into-`out` vs. overwrite).
#[allow(clippy::too_many_arguments)]
pub fn lora_outer_add_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    a: &wgpu::Buffer,
    b: &wgpu::Buffer,
    out: &wgpu::Buffer,
    outer_a: usize,
    outer_b: usize,
    scale: f32,
    accumulate: bool,
) {
    // Binds the whole of `buffer` at slot `binding`.
    fn bind(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let cache_key = crate::backend::CacheKey::three(&p.lora_outer_add, a, b, out);
    let slot = ctx.bind_cache.get_or_create(cache_key, || {
        let uniform = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_outer.params"),
            size: std::mem::size_of::<LoraOuterParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let layout = p.lora_outer_add.get_bind_group_layout(0);
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_outer.bg"),
            layout: &layout,
            entries: &[bind(0, &uniform), bind(1, a), bind(2, b), bind(3, out)],
        });
        crate::backend::CachedDispatch {
            uniform,
            bind_group,
        }
    });

    let params = LoraOuterParams {
        outer_a: outer_a as u32,
        outer_b: outer_b as u32,
        accumulate: accumulate as u32,
        _pad: 0,
        scale,
        _pad2: 0,
        _pad3: 0,
        _pad4: 0,
    };
    ctx.queue
        .write_buffer(&slot.uniform, 0, bytemuck::bytes_of(&params));

    let groups_x = (outer_a as u32).div_ceil(8);
    let groups_y = (outer_b as u32).div_ceil(8);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_outer.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&p.lora_outer_add);
    pass.set_bind_group(0, &slot.bind_group, &[]);
    pass.dispatch_workgroups(groups_x, groups_y, 1);
}
/// Parameter block for the `lora_embed_col_read` pipeline (four 4-byte
/// fields, 16 bytes). `_pad` is always zero; presumably layout padding.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct LoraEmbedColParams {
// Forwarded from the caller; also sizes the dispatch (one workgroup per 64).
rank: u32,
// Forwarded from the caller unchanged.
vocab: u32,
// Forwarded from the caller unchanged (presumably the column / token index).
col: u32,
_pad: u32,
}
/// Parameter block for the `lora_embed_col_scatter_add` pipeline (eight
/// 4-byte fields, 32 bytes). The `_pad*` fields are always zero; presumably
/// layout padding.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct LoraEmbedColScatterParams {
// Forwarded from the caller; also sizes the dispatch (one workgroup per 64).
rank: u32,
// Forwarded from the caller unchanged.
vocab: u32,
// Forwarded from the caller unchanged (presumably the column / token index).
col: u32,
_pad: u32,
// Scale factor forwarded unchanged.
scale: f32,
_pad2: u32,
_pad3: u32,
_pad4: u32,
}
/// Records a `lora_embed_col_read` compute pass into `enc`.
///
/// The uniform buffer and bind group (params at binding 0, `a` at binding 1,
/// `z` at binding 2) are fetched from `ctx.bind_cache`, keyed on the pipeline
/// and both buffers, or created on a miss. `rank`, `vocab` and `col` are
/// uploaded unchanged and `ceil(rank / 64)` workgroups are dispatched along X.
pub fn lora_embed_col_read_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    a: &wgpu::Buffer,
    z: &wgpu::Buffer,
    rank: u32,
    vocab: u32,
    col: u32,
) {
    // Binds the whole of `buffer` at slot `binding`.
    fn bind(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let cache_key = crate::backend::CacheKey::two(&p.lora_embed_col_read, a, z);
    let slot = ctx.bind_cache.get_or_create(cache_key, || {
        let uniform = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_embed_col_read.params"),
            size: std::mem::size_of::<LoraEmbedColParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let layout = p.lora_embed_col_read.get_bind_group_layout(0);
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_embed_col_read.bg"),
            layout: &layout,
            entries: &[bind(0, &uniform), bind(1, a), bind(2, z)],
        });
        crate::backend::CachedDispatch {
            uniform,
            bind_group,
        }
    });

    let params = LoraEmbedColParams {
        rank,
        vocab,
        col,
        _pad: 0,
    };
    ctx.queue
        .write_buffer(&slot.uniform, 0, bytemuck::bytes_of(&params));

    let workgroups = rank.div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_embed_col_read.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&p.lora_embed_col_read);
    pass.set_bind_group(0, &slot.bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Records a `lora_embed_col_scatter_add` compute pass into `enc`.
///
/// The uniform buffer and bind group (params at binding 0, `u` at binding 1,
/// `da` at binding 2) are fetched from `ctx.bind_cache`, keyed on the
/// pipeline and both buffers, or created on a miss. `rank`, `vocab`, `col`
/// and `scale` are uploaded unchanged and `ceil(rank / 64)` workgroups are
/// dispatched along X.
#[allow(clippy::too_many_arguments)]
pub fn lora_embed_col_scatter_add_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    u: &wgpu::Buffer,
    da: &wgpu::Buffer,
    rank: u32,
    vocab: u32,
    col: u32,
    scale: f32,
) {
    // Binds the whole of `buffer` at slot `binding`.
    fn bind(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let cache_key = crate::backend::CacheKey::two(&p.lora_embed_col_scatter_add, u, da);
    let slot = ctx.bind_cache.get_or_create(cache_key, || {
        let uniform = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_embed_col_scatter.params"),
            size: std::mem::size_of::<LoraEmbedColScatterParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let layout = p.lora_embed_col_scatter_add.get_bind_group_layout(0);
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_embed_col_scatter.bg"),
            layout: &layout,
            entries: &[bind(0, &uniform), bind(1, u), bind(2, da)],
        });
        crate::backend::CachedDispatch {
            uniform,
            bind_group,
        }
    });

    let params = LoraEmbedColScatterParams {
        rank,
        vocab,
        col,
        _pad: 0,
        scale,
        _pad2: 0,
        _pad3: 0,
        _pad4: 0,
    };
    ctx.queue
        .write_buffer(&slot.uniform, 0, bytemuck::bytes_of(&params));

    let workgroups = rank.div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_embed_col_scatter.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&p.lora_embed_col_scatter_add);
    pass.set_bind_group(0, &slot.bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Parameter block shared by the `lora_matmul_fused` and
/// `lora_matmul_fused_f16b` pipelines (eight 4-byte fields, 32 bytes). The
/// `_pad*` fields are always zero; presumably layout padding.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct LoraFusedParams {
// Caller's `k`, narrowed to `u32`.
k: u32,
// Caller's `n`, narrowed to `u32`; also sizes the dispatch (one workgroup per 64).
n: u32,
// Caller's `rank`, narrowed to `u32`.
rank: u32,
// 0 or 1, converted from the caller's `accumulate` flag.
accumulate: u32,
// Scale factor forwarded unchanged.
scale: f32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
/// Records a `lora_matmul_fused` compute pass into `enc`.
///
/// The bind group holds the uniform params at binding 0 followed by `a`,
/// `b`, `x`, `y`, `z_out` at bindings 1–5. It is fetched from
/// `ctx.bind_cache` or created on a miss; the parameters are then uploaded
/// and `ceil(n / 64)` workgroups are dispatched along X.
///
/// * `k`, `n`, `rank` – sizes forwarded to the shader as `u32`.
/// * `scale` – forwarded unchanged.
/// * `accumulate` – forwarded as 0/1.
///
/// NOTE(review): the cache key covers the pipeline and `a`, `b`, `x`, `y`
/// but not `z_out`, although the cached bind group also binds `z_out`. A
/// later call with the same first four buffers and a different `z_out` would
/// appear to reuse the bind group built for the earlier `z_out` — confirm
/// that callers always pair them, or extend the key.
#[allow(clippy::too_many_arguments)]
pub fn lora_matmul_fused_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    a: &wgpu::Buffer,
    b: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    z_out: &wgpu::Buffer,
    k: usize,
    n: usize,
    rank: usize,
    scale: f32,
    accumulate: bool,
) {
    // Binds the whole of `buffer` at slot `binding`.
    fn bind(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let cache_key = crate::backend::CacheKey::four(&p.lora_matmul_fused, a, b, x, y);
    let slot = ctx.bind_cache.get_or_create(cache_key, || {
        let uniform = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_fused.params"),
            size: std::mem::size_of::<LoraFusedParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let layout = p.lora_matmul_fused.get_bind_group_layout(0);
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_fused.bg"),
            layout: &layout,
            entries: &[
                bind(0, &uniform),
                bind(1, a),
                bind(2, b),
                bind(3, x),
                bind(4, y),
                bind(5, z_out),
            ],
        });
        crate::backend::CachedDispatch {
            uniform,
            bind_group,
        }
    });

    let params = LoraFusedParams {
        k: k as u32,
        n: n as u32,
        rank: rank as u32,
        accumulate: accumulate as u32,
        scale,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    ctx.queue
        .write_buffer(&slot.uniform, 0, bytemuck::bytes_of(&params));

    let workgroups = (n as u32).div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_fused.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&p.lora_matmul_fused);
    pass.set_bind_group(0, &slot.bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Records a `lora_matmul_fused_f16b` compute pass into `enc`.
///
/// Same host-side sequence as the plain fused variant, using the
/// `lora_matmul_fused_f16b` pipeline and its own labels: the bind group
/// holds the uniform params at binding 0 followed by `a`, `b`, `x`, `y`,
/// `z_out` at bindings 1–5, and `ceil(n / 64)` workgroups are dispatched
/// along X. (The name suggests `b` is stored as f16 — confirm in the shader.)
///
/// NOTE(review): the cache key covers the pipeline and `a`, `b`, `x`, `y`
/// but not `z_out`, although the cached bind group also binds `z_out`. A
/// later call with the same first four buffers and a different `z_out` would
/// appear to reuse the bind group built for the earlier `z_out` — confirm
/// that callers always pair them, or extend the key.
#[allow(clippy::too_many_arguments)]
pub fn lora_matmul_fused_f16b_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    a: &wgpu::Buffer,
    b: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    z_out: &wgpu::Buffer,
    k: usize,
    n: usize,
    rank: usize,
    scale: f32,
    accumulate: bool,
) {
    // Binds the whole of `buffer` at slot `binding`.
    fn bind(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let cache_key = crate::backend::CacheKey::four(&p.lora_matmul_fused_f16b, a, b, x, y);
    let slot = ctx.bind_cache.get_or_create(cache_key, || {
        let uniform = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_fused_f16b.params"),
            size: std::mem::size_of::<LoraFusedParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let layout = p.lora_matmul_fused_f16b.get_bind_group_layout(0);
        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("lora_fused_f16b.bg"),
            layout: &layout,
            entries: &[
                bind(0, &uniform),
                bind(1, a),
                bind(2, b),
                bind(3, x),
                bind(4, y),
                bind(5, z_out),
            ],
        });
        crate::backend::CachedDispatch {
            uniform,
            bind_group,
        }
    });

    let params = LoraFusedParams {
        k: k as u32,
        n: n as u32,
        rank: rank as u32,
        accumulate: accumulate as u32,
        scale,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    ctx.queue
        .write_buffer(&slot.uniform, 0, bytemuck::bytes_of(&params));

    let workgroups = (n as u32).div_ceil(64);
    let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("lora_fused_f16b.pass"),
        timestamp_writes: None,
    });
    pass.set_pipeline(&p.lora_matmul_fused_f16b);
    pass.set_bind_group(0, &slot.bind_group, &[]);
    pass.dispatch_workgroups(workgroups, 1, 1);
}
/// Parameter block shared by the `attention_backward_dq` and
/// `attention_backward_dkv` pipelines (eight 4-byte fields, 32 bytes). The
/// `_pad*` fields are always zero; presumably layout padding.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct AttnBackParams {
// Caller's `head_dim`, narrowed to `u32`.
head_dim: u32,
// Caller's `n_heads`, narrowed to `u32`.
n_heads: u32,
// Caller's `n_kv_heads`, narrowed to `u32`.
n_kv_heads: u32,
// Integer quotient `n_heads / n_kv_heads`, computed on the host.
heads_per_kv: u32,
// Caller's `history_len`, narrowed to `u32`.
history_len: u32,
_pad0: u32,
_pad1: u32,
_pad2: u32,
}
/// Records the `attention_backward_dq` kernel into `enc`, one workgroup per
/// query head (`n_heads` workgroups along X).
///
/// Buffers are handed to `cached_dispatch` in the order `k_hist`, `v_hist`,
/// `probs`, `d_out`, `d_scores`, `d_q`. Which of them the shader reads and
/// which it writes is defined by the WGSL, not by this wrapper.
///
/// `heads_per_kv` is the integer quotient `n_heads / n_kv_heads`; any
/// remainder is discarded (callers presumably guarantee divisibility).
///
/// # Panics
/// Panics if `n_kv_heads` is zero. This used to surface as an opaque
/// divide-by-zero panic; it is now an explicit precondition.
#[allow(clippy::too_many_arguments)]
pub fn attention_backward_dq_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    k_hist: &wgpu::Buffer,
    v_hist: &wgpu::Buffer,
    probs: &wgpu::Buffer,
    d_out: &wgpu::Buffer,
    d_scores: &wgpu::Buffer,
    d_q: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_kv_heads: usize,
    history_len: usize,
) {
    assert!(
        n_kv_heads > 0,
        "attention backward (dq): n_kv_heads must be non-zero"
    );
    let heads_per_kv = n_heads / n_kv_heads;
    let params = AttnBackParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_kv_heads: n_kv_heads as u32,
        heads_per_kv: heads_per_kv as u32,
        history_len: history_len as u32,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        &p.attention_backward_dq,
        "attn_bwd_dq",
        &[k_hist, v_hist, probs, d_out, d_scores, d_q],
        &params,
        (n_heads as u32, 1, 1),
    );
}
/// Records the `attention_backward_dkv` kernel into `enc` on a
/// `n_kv_heads × history_len` grid of workgroups (X × Y).
///
/// Buffers are handed to `cached_dispatch` in the order `q`, `probs`,
/// `d_out`, `d_scores`, `d_k_hist`, `d_v_hist`. Which of them the shader
/// reads and which it writes is defined by the WGSL, not by this wrapper.
///
/// `heads_per_kv` is the integer quotient `n_heads / n_kv_heads`; any
/// remainder is discarded (callers presumably guarantee divisibility).
///
/// NOTE(review): `history_len` is used directly as the Y workgroup count;
/// confirm it stays within the device's per-dimension workgroup limit for
/// long histories.
///
/// # Panics
/// Panics if `n_kv_heads` is zero. This used to surface as an opaque
/// divide-by-zero panic; it is now an explicit precondition.
#[allow(clippy::too_many_arguments)]
pub fn attention_backward_dkv_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    probs: &wgpu::Buffer,
    d_out: &wgpu::Buffer,
    d_scores: &wgpu::Buffer,
    d_k_hist: &wgpu::Buffer,
    d_v_hist: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_kv_heads: usize,
    history_len: usize,
) {
    assert!(
        n_kv_heads > 0,
        "attention backward (dkv): n_kv_heads must be non-zero"
    );
    let heads_per_kv = n_heads / n_kv_heads;
    let params = AttnBackParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_kv_heads: n_kv_heads as u32,
        heads_per_kv: heads_per_kv as u32,
        history_len: history_len as u32,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        &p.attention_backward_dkv,
        "attn_bwd_dkv",
        &[q, probs, d_out, d_scores, d_k_hist, d_v_hist],
        &params,
        (n_kv_heads as u32, history_len as u32, 1),
    );
}
/// Records a single-workgroup `rmsnorm_backward` dispatch into `enc`.
///
/// Buffers are handed to `cached_dispatch` in the order `x`, `w`, `dy`, `dx`.
/// `n` is narrowed to `u32`, `eps` is forwarded unchanged and `has_weight`
/// is forwarded as 0/1.
pub fn rmsnorm_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    w: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    n: usize,
    eps: f32,
    has_weight: bool,
) {
    let buffers = [x, w, dy, dx];
    let params = RmsParams {
        has_weight: has_weight as u32,
        eps,
        n: n as u32,
        _p: 0,
    };
    let workgroups = (1, 1, 1);
    cached_dispatch(
        ctx,
        enc,
        &p.rmsnorm_backward,
        "rms_bwd",
        &buffers,
        &params,
        workgroups,
    );
}
/// Parameter block for the `rmsnorm_per_row_backward` pipeline (four 4-byte
/// fields, 16 bytes, no padding fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct RmsPerRowBackParams {
// Row count; also the number of workgroups dispatched along X.
n_rows: u32,
// Caller's `n`, narrowed to `u32` (presumably the row length — confirm).
n: u32,
// Epsilon forwarded unchanged.
eps: f32,
// 0 or 1, converted from the caller's `has_weight` flag.
has_weight: u32,
}
/// Records a `rmsnorm_per_row_backward` dispatch into `enc`, one workgroup
/// per row (`n_rows` workgroups along X).
///
/// Buffers are handed to `cached_dispatch` in the order `x`, `w`, `dy`, `dx`.
/// `n_rows` and `n` are narrowed to `u32`, `eps` is forwarded unchanged and
/// `has_weight` is forwarded as 0/1.
#[allow(clippy::too_many_arguments)]
pub fn rmsnorm_per_row_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    w: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    n_rows: usize,
    n: usize,
    eps: f32,
    has_weight: bool,
) {
    let row_count = n_rows as u32;
    let buffers = [x, w, dy, dx];
    cached_dispatch(
        ctx,
        enc,
        &p.rmsnorm_per_row_backward,
        "rms_pr_bwd",
        &buffers,
        &RmsPerRowBackParams {
            n_rows: row_count,
            n: n as u32,
            eps,
            has_weight: u32::from(has_weight),
        },
        (row_count, 1, 1),
    );
}
/// Records a `geglu_backward` dispatch into `enc`, one workgroup per 64
/// elements (`ceil(n / 64)` workgroups along X).
///
/// Buffers are handed to `cached_dispatch` in the order `gate`, `up`, `dy`,
/// `d_gate`, `d_up`. The only parameter uploaded is `n`, narrowed to `u32`.
pub fn geglu_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    gate: &wgpu::Buffer,
    up: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    d_gate: &wgpu::Buffer,
    d_up: &wgpu::Buffer,
    n: usize,
) {
    let element_count = n as u32;
    let workgroups = element_count.div_ceil(64);
    let buffers = [gate, up, dy, d_gate, d_up];
    cached_dispatch(
        ctx,
        enc,
        &p.geglu_backward,
        "geglu_bwd",
        &buffers,
        &GegluParams {
            n: element_count,
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
        (workgroups, 1, 1),
    );
}
/// Records a `rope_neox_backward` dispatch into `enc`.
///
/// The work size is `n_heads * (rope_dims / 2)` items, covered by one
/// workgroup per 64 items along X. `x` is bound first; the second binding is
/// `factors` when provided and `dummy` otherwise, and `has_factors` tells the
/// shader which case applies (1 or 0).
///
/// * `head_dim`, `n_heads`, `rope_dims`, `pos` – forwarded as `u32`.
/// * `base` – forwarded unchanged.
pub fn rope_neox_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    factors: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    pos: usize,
    rope_dims: usize,
    base: f32,
) {
    let has_factors = u32::from(factors.is_some());
    // A buffer must always be bound, so fall back to the placeholder.
    let factor_buffer = match factors {
        Some(buffer) => buffer,
        None => dummy,
    };

    let pairs_per_head = rope_dims / 2;
    let work_items = (n_heads * pairs_per_head) as u32;

    let params = RopeParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        rope_dims: rope_dims as u32,
        pos: pos as u32,
        base,
        has_factors,
        _p0: 0,
        _p1: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        &p.rope_neox_backward,
        "rope_bwd",
        &[x, factor_buffer],
        &params,
        (work_items.div_ceil(64), 1, 1),
    );
}
/// Records the Q4_K input-gradient kernel over the full range `[0, n)`.
///
/// Dispatches `k / 256` workgroups along X with `j_start = 0`, `j_end = n`
/// and `accumulate = 0`. Buffers are handed to `cached_dispatch` in the
/// order `weight`, `dy`, `dx`.
///
/// # Panics
/// Panics if `k` is not a multiple of 256.
pub fn matmul_q4_k_backward_input_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    const BLOCK: usize = 256;
    assert!(
        k.is_multiple_of(BLOCK),
        "k must be divisible by 256 for Q4_K backward"
    );
    let block_count = (k / BLOCK) as u32;
    let n_u32 = n as u32;
    cached_dispatch(
        ctx,
        enc,
        &p.matmul_q4_k_backward_input,
        "q4k_bwd",
        &[weight, dy, dx],
        &MatmulBackInputParams {
            k: k as u32,
            n: n_u32,
            j_start: 0,
            j_end: n_u32,
            accumulate: 0,
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
        (block_count, 1, 1),
    );
}
/// Records the Q4_K input-gradient kernel restricted to the tile
/// `[j_start, j_end)`.
///
/// Dispatches `k / 256` workgroups along X on the same pipeline as the
/// untiled variant. `accumulate` is forwarded as 0/1 (presumably add-into-`dx`
/// vs. overwrite). Buffers are handed to `cached_dispatch` in the order
/// `weight`, `dy`, `dx`.
///
/// # Panics
/// Panics if `k` is not a multiple of 256, or if the tile is not ordered and
/// within `n` (`j_start <= j_end <= n`).
#[allow(clippy::too_many_arguments)]
pub fn matmul_q4_k_backward_input_tile_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
    j_start: u32,
    j_end: u32,
    accumulate: bool,
) {
    const BLOCK: usize = 256;
    assert!(
        k.is_multiple_of(BLOCK),
        "k must be divisible by 256 for Q4_K backward"
    );
    assert!(j_start <= j_end, "tile out of range");
    assert!((j_end as usize) <= n, "tile out of range");

    let block_count = (k / BLOCK) as u32;
    cached_dispatch(
        ctx,
        enc,
        &p.matmul_q4_k_backward_input,
        "q4k_bwd_tile",
        &[weight, dy, dx],
        &MatmulBackInputParams {
            k: k as u32,
            n: n as u32,
            j_start,
            j_end,
            accumulate: if accumulate { 1 } else { 0 },
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
        (block_count, 1, 1),
    );
}
/// Records the Q4_0 input-gradient kernel over the full range `[0, n)`.
///
/// Dispatches `k / 32` workgroups along X with `j_start = 0`, `j_end = n`
/// and `accumulate = 0`. Buffers are handed to `cached_dispatch` in the
/// order `weight`, `dy`, `dx`.
///
/// # Panics
/// Panics if `k` is not a multiple of 32.
pub fn matmul_q4_0_backward_input_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    const BLOCK: usize = 32;
    assert!(
        k.is_multiple_of(BLOCK),
        "k must be divisible by 32 for Q4_0 backward"
    );
    let block_count = (k / BLOCK) as u32;
    let n_u32 = n as u32;
    cached_dispatch(
        ctx,
        enc,
        &p.matmul_q4_0_backward_input,
        "q4_0_bwd",
        &[weight, dy, dx],
        &MatmulBackInputParams {
            k: k as u32,
            n: n_u32,
            j_start: 0,
            j_end: n_u32,
            accumulate: 0,
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
        (block_count, 1, 1),
    );
}
/// Records the Q4_0 input-gradient kernel restricted to the tile
/// `[j_start, j_end)`.
///
/// Dispatches `k / 32` workgroups along X on the same pipeline as the
/// untiled variant. `accumulate` is forwarded as 0/1 (presumably add-into-`dx`
/// vs. overwrite). Buffers are handed to `cached_dispatch` in the order
/// `weight`, `dy`, `dx`.
///
/// # Panics
/// Panics if `k` is not a multiple of 32, or if the tile is not ordered and
/// within `n` (`j_start <= j_end <= n`).
pub fn matmul_q4_0_backward_input_tile_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
    j_start: u32,
    j_end: u32,
    accumulate: bool,
) {
    const BLOCK: usize = 32;
    assert!(
        k.is_multiple_of(BLOCK),
        "k must be divisible by 32 for Q4_0 backward"
    );
    assert!(j_start <= j_end, "tile out of range");
    assert!((j_end as usize) <= n, "tile out of range");

    let block_count = (k / BLOCK) as u32;
    cached_dispatch(
        ctx,
        enc,
        &p.matmul_q4_0_backward_input,
        "q4_0_bwd_tile",
        &[weight, dy, dx],
        &MatmulBackInputParams {
            k: k as u32,
            n: n as u32,
            j_start,
            j_end,
            accumulate: if accumulate { 1 } else { 0 },
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
        (block_count, 1, 1),
    );
}
/// Dispatches the untiled input-gradient kernel that matches `dtype`.
///
/// Supported dtypes are `Q4_K`, `Q6_K` and `Q4_0`; all arguments are passed
/// through unchanged to the dtype-specific function.
///
/// # Errors
/// Returns `RullamaError::Inference` for any other dtype.
///
/// # Panics
/// Propagates the `k`-divisibility assertion of the selected kernel wrapper.
#[allow(clippy::too_many_arguments)]
pub fn matmul_quant_backward_input_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
    dtype: GgmlDtype,
) -> Result<()> {
    match dtype {
        GgmlDtype::Q4_K => {
            matmul_q4_k_backward_input_chained(ctx, p, enc, weight, dy, dx, k, n);
            Ok(())
        }
        GgmlDtype::Q6_K => {
            matmul_q6_k_backward_input_chained(ctx, p, enc, weight, dy, dx, k, n);
            Ok(())
        }
        GgmlDtype::Q4_0 => {
            matmul_q4_0_backward_input_chained(ctx, p, enc, weight, dy, dx, k, n);
            Ok(())
        }
        other => Err(RullamaError::Inference(format!(
            "weight backward: unsupported quant dtype {other:?} (expected Q4_0, Q4_K, or Q6_K)"
        ))),
    }
}
/// Dispatches the tiled input-gradient kernel that matches `dtype`.
///
/// Supported dtypes are `Q4_K`, `Q6_K` and `Q4_0`; all arguments, including
/// the tile bounds and the `accumulate` flag, are passed through unchanged to
/// the dtype-specific function.
///
/// # Errors
/// Returns `RullamaError::Inference` for any other dtype.
///
/// # Panics
/// Propagates the `k`-divisibility and tile-range assertions of the selected
/// kernel wrapper.
#[allow(clippy::too_many_arguments)]
pub fn matmul_quant_backward_input_tile_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
    j_start: u32,
    j_end: u32,
    accumulate: bool,
    dtype: GgmlDtype,
) -> Result<()> {
    match dtype {
        GgmlDtype::Q4_K => {
            matmul_q4_k_backward_input_tile_chained(
                ctx, p, enc, weight, dy, dx, k, n, j_start, j_end, accumulate,
            );
            Ok(())
        }
        GgmlDtype::Q6_K => {
            matmul_q6_k_backward_input_tile_chained(
                ctx, p, enc, weight, dy, dx, k, n, j_start, j_end, accumulate,
            );
            Ok(())
        }
        GgmlDtype::Q4_0 => {
            matmul_q4_0_backward_input_tile_chained(
                ctx, p, enc, weight, dy, dx, k, n, j_start, j_end, accumulate,
            );
            Ok(())
        }
        other => Err(RullamaError::Inference(format!(
            "weight backward (tiled): unsupported quant dtype {other:?}"
        ))),
    }
}
/// Records the Q6_K input-gradient kernel over the full range `[0, n)`.
///
/// Dispatches `k / 256` workgroups along X with `j_start = 0`, `j_end = n`
/// and `accumulate = 0`. Buffers are handed to `cached_dispatch` in the
/// order `weight`, `dy`, `dx`.
///
/// # Panics
/// Panics if `k` is not a multiple of 256.
pub fn matmul_q6_k_backward_input_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
) {
    const BLOCK: usize = 256;
    assert!(
        k.is_multiple_of(BLOCK),
        "k must be divisible by 256 for Q6_K backward"
    );
    let block_count = (k / BLOCK) as u32;
    let n_u32 = n as u32;
    cached_dispatch(
        ctx,
        enc,
        &p.matmul_q6_k_backward_input,
        "q6k_bwd",
        &[weight, dy, dx],
        &MatmulBackInputParams {
            k: k as u32,
            n: n_u32,
            j_start: 0,
            j_end: n_u32,
            accumulate: 0,
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
        (block_count, 1, 1),
    );
}
/// Records the Q6_K input-gradient kernel restricted to the tile
/// `[j_start, j_end)`.
///
/// Dispatches `k / 256` workgroups along X on the same pipeline as the
/// untiled variant. `accumulate` is forwarded as 0/1 (presumably add-into-`dx`
/// vs. overwrite). Buffers are handed to `cached_dispatch` in the order
/// `weight`, `dy`, `dx`.
///
/// # Panics
/// Panics if `k` is not a multiple of 256, or if the tile is not ordered and
/// within `n` (`j_start <= j_end <= n`).
#[allow(clippy::too_many_arguments)]
pub fn matmul_q6_k_backward_input_tile_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    weight: &wgpu::Buffer,
    dy: &wgpu::Buffer,
    dx: &wgpu::Buffer,
    k: usize,
    n: usize,
    j_start: u32,
    j_end: u32,
    accumulate: bool,
) {
    const BLOCK: usize = 256;
    assert!(
        k.is_multiple_of(BLOCK),
        "k must be divisible by 256 for Q6_K backward"
    );
    assert!(j_start <= j_end, "tile out of range");
    assert!((j_end as usize) <= n, "tile out of range");

    let block_count = (k / BLOCK) as u32;
    cached_dispatch(
        ctx,
        enc,
        &p.matmul_q6_k_backward_input,
        "q6k_bwd_tile",
        &[weight, dy, dx],
        &MatmulBackInputParams {
            k: k as u32,
            n: n as u32,
            j_start,
            j_end,
            accumulate: if accumulate { 1 } else { 0 },
            _p0: 0,
            _p1: 0,
            _p2: 0,
        },
        (block_count, 1, 1),
    );
}
/// Records a single-workgroup `cross_entropy_backward` dispatch into `enc`.
///
/// Buffers are handed to `cached_dispatch` in the order `logits`,
/// `d_logits`, `loss_out`. `vocab_size` is narrowed to `u32` and `target` is
/// forwarded unchanged (presumably the target token index — confirm).
pub fn cross_entropy_backward_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    logits: &wgpu::Buffer,
    d_logits: &wgpu::Buffer,
    loss_out: &wgpu::Buffer,
    vocab_size: usize,
    target: u32,
) {
    let buffers = [logits, d_logits, loss_out];
    cached_dispatch(
        ctx,
        enc,
        &p.cross_entropy_backward,
        "xent_bwd",
        &buffers,
        &XEntParams {
            vocab_size: vocab_size as u32,
            target,
            _p0: 0,
            _p1: 0,
        },
        (1, 1, 1),
    );
}