#![allow(clippy::too_many_arguments)]
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::backend::dispatch::{
attention_backward_dkv_chained, attention_backward_dq_chained, attention_chained,
attention_probs_chained, cross_entropy_backward_chained, geglu_backward_chained, geglu_chained,
lora_matmul_col_chained, lora_matmul_row_chained, lora_outer_add_chained, make_dummy_storage,
matmul_q4_k_backward_input_chained, matmul_q4_k_chained, matmul_q6_k_backward_input_chained,
matmul_q6_k_chained, residual_add_chained, rmsnorm_backward_chained, rmsnorm_chained,
rmsnorm_per_row_backward_chained, rmsnorm_per_row_chained, rope_neox_backward_chained,
rope_neox_chained, scale_chained, softcap_chained,
};
/// Caller-owned destination buffers that receive one layer's intermediate
/// activations during a capturing forward step.
///
/// For the fields whose writes are visible in `encode_layer`, the current
/// token's row is copied to byte offset `pos * row_len * 4`, where `row_len`
/// is the f32 length of that activation. Each buffer therefore needs room for
/// one row per position that will be captured.
pub struct LayerCaptureBuffers<'a> {
/// Residual stream (`hidden`) as it enters the layer; `d_model` values per row.
pub hidden_in: &'a wgpu::Buffer,
/// Output of the attention RMSNorm; `d_model` values per row.
pub norm_x_attn: &'a wgpu::Buffer,
/// Q projection (including any LoRA contribution) before the per-head norm;
/// `n_heads * head_dim` values per row.
pub q_pre_norm: &'a wgpu::Buffer,
/// Q after the per-head norm and RoPE; `n_heads * head_dim` values per row.
pub q_post_rope: &'a wgpu::Buffer,
/// K projection before the per-head norm; `n_kv_heads * head_dim` values per
/// row. Only written for layers that own their KV cache (no donor).
pub k_pre_norm: &'a wgpu::Buffer,
/// V projection before the per-head norm; `n_kv_heads * head_dim` values per
/// row. Only written for layers that own their KV cache (no donor).
pub v_pre_norm: &'a wgpu::Buffer,
/// Attention output before the output projection; `n_heads * head_dim` values per row.
pub attn_out: &'a wgpu::Buffer,
/// Output projection result (including any LoRA contribution); `d_model` values per row.
pub attn_proj: &'a wgpu::Buffer,
/// Residual stream after the attention branch has been added back; `d_model` values per row.
pub pre_ffn_rms: &'a wgpu::Buffer,
/// Output of the FFN RMSNorm; `d_model` values per row.
pub norm_x_ffn: &'a wgpu::Buffer,
// NOTE(review): the writes to the remaining fields are outside the visible
// part of `encode_layer`; the names suggest the FFN gate/up/activation/output
// and the per-layer-embedding (PLE) intermediates — confirm against the code.
pub ffn_gate: &'a wgpu::Buffer,
pub ffn_up: &'a wgpu::Buffer,
pub ffn_act: &'a wgpu::Buffer,
pub ffn_out: &'a wgpu::Buffer,
pub ple_state: &'a wgpu::Buffer,
pub ple_act: &'a wgpu::Buffer,
pub ple_proj: &'a wgpu::Buffer,
}
/// GPU buffers and hyper-parameters of one LoRA adapter attached to a single
/// projection.
///
/// In the forward pass the adapter is applied as two chained matmuls: the
/// projection input is multiplied by `a` into `z` (with factor `1.0`), then `z`
/// is multiplied by `b` and, weighted by `scale`, combined into the
/// projection's output buffer.
pub struct LoraSlot<'a> {
    /// First low-rank factor; maps the projection input to `rank` values.
    pub a: &'a wgpu::Buffer,
    /// Second low-rank factor; maps the `rank` intermediate values to the
    /// projection's output width.
    pub b: &'a wgpu::Buffer,
    /// Scratch buffer holding the `rank` intermediate values between the two
    /// matmuls.
    pub z: &'a wgpu::Buffer,
    /// Inner dimension shared by `a` and `b`.
    pub rank: u32,
    /// Multiplier passed to the second matmul.
    pub scale: f32,
}
/// Progress callback invoked once per completed layer during a forward pass.
///
/// The forward loop calls it as `cb("forward", layers_done, total_layers)`,
/// i.e. a phase label, the 1-based number of layers finished so far, and the
/// total layer count.
pub type LayerProgressCb<'a> = dyn Fn(&str, u32, u32) + 'a;
/// Optional LoRA adapters for each projection of one transformer layer.
///
/// A `None` entry means the corresponding projection runs with its base
/// weights only.
pub struct LayerLoraSlots<'a> {
/// Adapter for the attention Q projection.
pub q: Option<LoraSlot<'a>>,
/// Adapter for the attention K projection (only used on layers that compute
/// their own K/V, i.e. have no donor).
pub k: Option<LoraSlot<'a>>,
/// Adapter for the attention V projection (only used on layers that compute
/// their own K/V, i.e. have no donor).
pub v: Option<LoraSlot<'a>>,
/// Adapter for the attention output projection.
pub o: Option<LoraSlot<'a>>,
/// Adapter for the FFN gate projection.
pub ffn_gate: Option<LoraSlot<'a>>,
// NOTE(review): the uses of `ffn_up` / `ffn_down` are outside the visible
// part of `encode_layer`; presumably the FFN up and down projections — confirm.
pub ffn_up: Option<LoraSlot<'a>>,
pub ffn_down: Option<LoraSlot<'a>>,
}
use crate::backend::{Pipelines, WeightCache, WgpuCtx};
use crate::error::{Result, RullamaError};
use crate::gguf::GgmlDtype;
use crate::model::config::{Gemma4Config, LayerKind};
use crate::reference::forward::build_donor_map_pub;
use crate::reference::weights::Weights;
use bytemuck::{Pod, Zeroable};
use futures_channel::oneshot;
/// Largest context length (in tokens) a [`Forward`] may be created with.
///
/// `Forward::new` uses this value directly; `Forward::new_with_max_context`
/// rejects anything outside `1..=MAX_CONTEXT`.
pub const MAX_CONTEXT: u32 = 4096;
/// Single-token GPU forward pass with per-layer KV caches.
///
/// Owns the model configuration, the GPU context, all scratch buffers used
/// while encoding one token, and the KV cache state (`kv_k`, `kv_v`,
/// `kv_lens`, `pos`).
pub struct Forward {
// Model hyper-parameters.
cfg: Gemma4Config,
// GPU device and queue wrapper.
ctx: WgpuCtx,
// Compute pipelines shared with other users.
pipes: Arc<Pipelines>,
// GPU-resident weight buffers, looked up by tensor name.
wcache: Arc<WeightCache>,
// CPU-side weight access (used for embedding rows and per-layer scalars).
weights: Weights,
// Residual stream of the current token (`d_model` f32 values).
hidden: wgpu::Buffer,
// Scratch: RMSNorm outputs, and Q/K/V before and after their per-head norm.
norm_x: wgpu::Buffer, norm_y: wgpu::Buffer, q: wgpu::Buffer, q_norm: wgpu::Buffer, k: wgpu::Buffer, k_norm: wgpu::Buffer,
v: wgpu::Buffer,
v_norm: wgpu::Buffer,
// Scratch: attention output, its projection, and the FFN gate/up buffers.
attn_out_buf: wgpu::Buffer, attn_proj: wgpu::Buffer, ffn_gate: wgpu::Buffer, ffn_up: wgpu::Buffer,
ffn_act: wgpu::Buffer,
ffn_out: wgpu::Buffer,
// Per-layer-embedding (PLE) inputs for all layers: the token's PLE row, the
// projection of `hidden`, and their normalised/combined result.
per_layer_residual: wgpu::Buffer, per_layer_proj: wgpu::Buffer,
per_layer: wgpu::Buffer,
// NOTE(review): used by the part of `encode_layer` not visible here;
// presumably per-layer PLE scratch — confirm.
ple_state: wgpu::Buffer, ple_act: wgpu::Buffer, ple_proj: wgpu::Buffer,
// Output of one output-projection tile; also the soft-capped logits when
// `final_logit_softcap > 0`.
logits_tile: wgpu::Buffer,
// Full logits vector assembled from the tiles.
logits: wgpu::Buffer,
// Mappable staging buffer used to read the logits back to the CPU.
logits_read: wgpu::Buffer,
// Per-layer K and V caches. A layer with a donor shares the donor's `Arc`.
kv_k: Vec<Arc<wgpu::Buffer>>,
kv_v: Vec<Arc<wgpu::Buffer>>,
// Number of cached rows per layer (advanced only for layers without a donor).
kv_lens: Vec<u32>,
// `Some(d)` when the layer reuses layer `d`'s KV cache instead of owning one.
donor_map: Vec<Option<u32>>,
// First element of the optional `blk.{i}.layer_output_scale.weight` tensor.
layer_scalars: Vec<Option<f32>>,
// Placeholder buffer passed to kernels in an otherwise unused buffer argument.
dummy: wgpu::Buffer,
// Maximum number of positions the KV caches were sized for.
max_context: u32,
// Cooperative cancellation flag, checked after each layer is submitted.
cancel_flag: Arc<AtomicBool>,
// Position of the next token to be processed.
pos: u32,
}
impl Forward {
/// Creates a forward pass sized for the maximum supported context
/// (`MAX_CONTEXT` positions).
///
/// Delegates to `new_with_max_context`; see that function for the work
/// performed and the errors that can be returned.
pub async fn new(
cfg: Gemma4Config,
ctx: WgpuCtx,
pipes: Arc<Pipelines>,
weights: Weights,
wcache: Arc<WeightCache>,
) -> Result<Self> {
Self::new_with_max_context(cfg, ctx, pipes, weights, wcache, MAX_CONTEXT).await
}
pub async fn new_with_max_context(
cfg: Gemma4Config,
ctx: WgpuCtx,
pipes: Arc<Pipelines>,
weights: Weights,
wcache: Arc<WeightCache>,
max_context: u32,
) -> Result<Self> {
if max_context == 0 || max_context > MAX_CONTEXT {
return Err(crate::error::RullamaError::Inference(format!(
"max_context={max_context} out of range (1..={MAX_CONTEXT})"
)));
}
let device = &ctx.device;
let alloc_storage = |label: &str, n: usize| -> wgpu::Buffer {
device.create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: (n * 4).max(4) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
})
};
let d_model = cfg.d_model as usize;
let n_heads = cfg.n_heads as usize;
let head_dim_max = cfg.head_dim_global.max(cfg.head_dim_swa) as usize;
let n_kv_heads_max = cfg.n_kv_heads_global.max(cfg.n_kv_heads_swa) as usize;
let ffn_inter_max = (0..cfg.n_layers).map(|i| cfg.ffn(i)).max().unwrap_or(0) as usize;
let ple_dim = cfg.ple_dim as usize;
let n_layers = cfg.n_layers as usize;
let vocab = cfg.vocab_size as usize;
let hidden = alloc_storage("fwd.hidden", d_model);
let norm_x = alloc_storage("fwd.norm_x", d_model);
let norm_y = alloc_storage("fwd.norm_y", d_model);
let q = alloc_storage("fwd.q", n_heads * head_dim_max);
let q_norm = alloc_storage("fwd.q_norm", n_heads * head_dim_max);
let k = alloc_storage("fwd.k", n_kv_heads_max * head_dim_max);
let k_norm = alloc_storage("fwd.k_norm", n_kv_heads_max * head_dim_max);
let v = alloc_storage("fwd.v", n_kv_heads_max * head_dim_max);
let v_norm = alloc_storage("fwd.v_norm", n_kv_heads_max * head_dim_max);
let attn_out_buf = alloc_storage("fwd.attn_out", n_heads * head_dim_max);
let attn_proj = alloc_storage("fwd.attn_proj", d_model);
let ffn_gate = alloc_storage("fwd.ffn_gate", ffn_inter_max);
let ffn_up = alloc_storage("fwd.ffn_up", ffn_inter_max);
let ffn_act = alloc_storage("fwd.ffn_act", ffn_inter_max);
let ffn_out = alloc_storage("fwd.ffn_out", d_model);
let per_layer_residual = alloc_storage("fwd.per_layer_residual", n_layers * ple_dim.max(1));
let per_layer_proj = alloc_storage("fwd.per_layer_proj", n_layers * ple_dim.max(1));
let per_layer = alloc_storage("fwd.per_layer", n_layers * ple_dim.max(1));
let ple_state = alloc_storage("fwd.ple_state", ple_dim.max(1));
let ple_act = alloc_storage("fwd.ple_act", ple_dim.max(1));
let ple_proj = alloc_storage("fwd.ple_proj", d_model);
let logits_tile = alloc_storage("fwd.logits_tile", vocab);
let logits = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("fwd.logits"),
size: (vocab * 4) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let logits_read = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("fwd.logits_read"),
size: (vocab * 4) as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let donor_map = build_donor_map_pub(&cfg);
let mut kv_k_opt: Vec<Option<Arc<wgpu::Buffer>>> = vec![None; n_layers];
let mut kv_v_opt: Vec<Option<Arc<wgpu::Buffer>>> = vec![None; n_layers];
for i in 0..n_layers {
if donor_map[i].is_none() {
let n_kv = cfg.n_kv_heads(i as u32) as usize;
let hd = cfg.head_dim(i as u32) as usize;
let bytes = (max_context as usize * n_kv * hd * 4) as u64;
kv_k_opt[i] = Some(Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
label: Some(&format!("fwd.kv_k.{i}")),
size: bytes,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
})));
kv_v_opt[i] = Some(Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
label: Some(&format!("fwd.kv_v.{i}")),
size: bytes,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
})));
}
}
for i in 0..n_layers {
if let Some(d) = donor_map[i] {
kv_k_opt[i] = kv_k_opt[d as usize].clone();
kv_v_opt[i] = kv_v_opt[d as usize].clone();
}
}
let kv_k: Vec<Arc<wgpu::Buffer>> = kv_k_opt.into_iter().map(|x| x.unwrap()).collect();
let kv_v: Vec<Arc<wgpu::Buffer>> = kv_v_opt.into_iter().map(|x| x.unwrap()).collect();
let kv_lens = vec![0u32; n_layers];
let dummy = make_dummy_storage(device, "fwd.dummy");
let mut layer_scalars: Vec<Option<f32>> = Vec::with_capacity(n_layers);
for i in 0..cfg.n_layers {
let name = format!("blk.{i}.layer_output_scale.weight");
let v = weights.load_opt_async(&name).await?;
layer_scalars.push(v.and_then(|vec| vec.first().copied()));
}
Ok(Self {
cfg,
ctx,
pipes,
wcache,
weights,
hidden,
norm_x,
norm_y,
q,
q_norm,
k,
k_norm,
v,
v_norm,
attn_out_buf,
attn_proj,
ffn_gate,
ffn_up,
ffn_act,
ffn_out,
per_layer_residual,
per_layer_proj,
per_layer,
ple_state,
ple_act,
ple_proj,
logits_tile,
logits,
logits_read,
kv_k,
kv_v,
kv_lens,
donor_map,
layer_scalars,
dummy,
max_context,
cancel_flag: Arc::new(AtomicBool::new(false)),
pos: 0,
})
}
/// Requests cancellation of the forward pass currently in flight.
///
/// Only sets the shared flag; the forward loop observes it after submitting
/// each layer. The flag is cleared again at the start of every forward pass.
pub fn cancel(&self) {
self.cancel_flag.store(true, Ordering::Release);
}
/// Clears the cancellation flag.
fn reset_cancel(&self) {
self.cancel_flag.store(false, Ordering::Release);
}
/// Returns `Err(RullamaError::Cancelled)` if cancellation has been requested,
/// `Ok(())` otherwise.
fn check_cancelled(&self) -> Result<()> {
    let cancel_requested = self.cancel_flag.load(Ordering::Acquire);
    if !cancel_requested {
        return Ok(());
    }
    Err(RullamaError::Cancelled)
}
/// Returns a new handle to the shared cancellation flag, so a caller can
/// request cancellation without holding a reference to `self`.
pub fn cancel_flag(&self) -> Arc<AtomicBool> {
    Arc::clone(&self.cancel_flag)
}
/// Returns the model configuration this forward pass was built with.
pub fn cfg(&self) -> &Gemma4Config {
&self.cfg
}
/// Returns the position of the next token to be processed (the number of
/// forward passes completed since the last reset or snapshot load).
pub fn pos(&self) -> u32 {
self.pos
}
/// Returns the shared GPU weight cache.
pub fn wcache(&self) -> &Arc<WeightCache> {
&self.wcache
}
/// Returns the GPU context (device and queue) used by this forward pass.
pub fn ctx(&self) -> &WgpuCtx {
&self.ctx
}
/// Returns the shared compute pipelines.
pub fn pipes(&self) -> &std::sync::Arc<Pipelines> {
&self.pipes
}
/// Returns the GPU buffer into which the output-projection tiles are
/// assembled (`vocab_size` f32 values, before any soft-capping).
pub fn logits_buffer(&self) -> &wgpu::Buffer {
&self.logits
}
/// Returns the GPU buffer holding the residual stream of the current token
/// (`d_model` f32 values).
pub fn hidden_buffer(&self) -> &wgpu::Buffer {
&self.hidden
}
/// Runs only the tail of the model on the current contents of `hidden`:
/// the final RMSNorm followed by the tiled output projection.
///
/// The result is left in the GPU `logits` buffer; nothing is read back, no
/// soft-capping is applied, and neither `pos` nor the KV caches are touched.
///
/// # Errors
/// Propagates failures from the weight cache lookups and from the tile matmul.
pub async fn run_final_norm_and_output_proj_only(&mut self) -> Result<()> {
    const MAX_TILE_BYTES: usize = 8 * 1024 * 1024;

    let width = self.cfg.d_model as usize;
    let weight_cache = self.wcache.clone();
    let norm_weight = weight_cache.buffer_async("output_norm.weight").await?;
    let embd_dtype = weight_cache.dtype("token_embd.weight")?;

    // Final norm: hidden -> norm_x, submitted on its own.
    let mut norm_enc = self
        .ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fwd.out_proj_only"),
        });
    rmsnorm_chained(
        &self.ctx,
        &self.pipes,
        &mut norm_enc,
        &self.hidden,
        Some(&norm_weight),
        &self.dummy,
        &self.norm_x,
        width,
        self.cfg.rms_norm_eps,
    );
    self.ctx.queue.submit(Some(norm_enc.finish()));

    // Output projection: one submission per tile of the embedding matrix,
    // each writing its rows of logits at the tile's row offset.
    let embd_tiles = weight_cache
        .buffer_tiles_async("token_embd.weight", MAX_TILE_BYTES)
        .await?;
    for tile in embd_tiles.iter() {
        let mut tile_enc = self
            .ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("fwd.out_proj_only.tile"),
            });
        run_matmul_into_buf(
            &self.ctx,
            &self.pipes,
            &mut tile_enc,
            embd_dtype,
            &tile.buffer,
            &self.norm_x,
            &self.logits_tile,
            tile.n_rows,
            width,
            "fwd.out_proj_only_tile",
        )?;
        let dst_offset = (tile.row_start as u64) * 4;
        let tile_bytes = (tile.n_rows as u64) * 4;
        tile_enc.copy_buffer_to_buffer(&self.logits_tile, 0, &self.logits, dst_offset, tile_bytes);
        self.ctx.queue.submit(Some(tile_enc.finish()));
    }
    Ok(())
}
/// Overwrites `hidden` with `d_model` f32 values copied from `src`, starting
/// at byte offset `src_offset`.
///
/// The copy is encoded and submitted immediately as its own command buffer.
pub fn set_hidden_from(&self, src: &wgpu::Buffer, src_offset: u64) {
    let row_bytes = (self.cfg.d_model as usize as u64) * 4;
    let descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("fwd.set_hidden_from"),
    };
    let mut copy_enc = self.ctx.device.create_command_encoder(&descriptor);
    copy_enc.copy_buffer_to_buffer(src, src_offset, &self.hidden, 0, row_bytes);
    self.ctx.queue.submit(Some(copy_enc.finish()));
}
/// Starts a fresh sequence: rewinds the position to 0 and marks every
/// layer's KV cache as empty. GPU buffer contents are left untouched.
pub fn reset(&mut self) {
    self.kv_lens.fill(0);
    self.pos = 0;
}
/// Fingerprint of the KV cache geometry, used to reject snapshots taken from
/// a model with a different layout.
///
/// 32-bit FNV-1a over the little-endian bytes of `n_kv_heads` followed by
/// `head_dim`, for every layer in order.
fn kv_layout_hash(&self) -> u32 {
    const FNV_OFFSET_BASIS: u32 = 0x811C9DC5;
    const FNV_PRIME: u32 = 0x01000193;

    let mut hash = FNV_OFFSET_BASIS;
    for layer in 0..self.cfg.n_layers {
        let kv_heads_le = self.cfg.n_kv_heads(layer).to_le_bytes();
        let head_dim_le = self.cfg.head_dim(layer).to_le_bytes();
        for &b in kv_heads_le.iter().chain(head_dim_le.iter()) {
            hash = (hash ^ b as u32).wrapping_mul(FNV_PRIME);
        }
    }
    hash
}
pub async fn dump_kv(&self) -> Result<Vec<u8>> {
let n_layers = self.cfg.n_layers as usize;
struct Section {
layer_idx: u32,
kv_len: u32,
n_kv_heads: u16,
head_dim: u16,
bytes: u64,
}
let mut sections: Vec<Section> = Vec::new();
let mut total_payload: u64 = 0;
for i in 0..n_layers {
if self.donor_map[i].is_some() {
continue;
}
let kv_len = self.kv_lens[i];
if kv_len == 0 {
continue;
}
let nkv = self.cfg.n_kv_heads(i as u32);
let hd = self.cfg.head_dim(i as u32);
let bytes = (kv_len as u64) * (nkv as u64) * (hd as u64) * 4;
sections.push(Section {
layer_idx: i as u32,
kv_len,
n_kv_heads: nkv as u16,
head_dim: hd as u16,
bytes,
});
total_payload += bytes * 2; }
let mut header = Vec::<u8>::with_capacity(16 + 12 * sections.len());
header.extend_from_slice(b"RLKV");
header.push(1u8);
header.push(sections.len() as u8);
header.extend_from_slice(&[0u8, 0u8]);
header.extend_from_slice(&self.pos.to_le_bytes());
header.extend_from_slice(&self.kv_layout_hash().to_le_bytes());
for s in §ions {
header.extend_from_slice(&s.layer_idx.to_le_bytes());
header.extend_from_slice(&s.kv_len.to_le_bytes());
header.extend_from_slice(&s.n_kv_heads.to_le_bytes());
header.extend_from_slice(&s.head_dim.to_le_bytes());
}
if total_payload == 0 {
return Ok(header);
}
let staging = self.ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("fwd.kv_dump.staging"),
size: total_payload,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let mut enc = self
.ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("fwd.kv_dump.enc"),
});
let mut offset: u64 = 0;
for s in §ions {
let i = s.layer_idx as usize;
enc.copy_buffer_to_buffer(&self.kv_k[i], 0, &staging, offset, s.bytes);
offset += s.bytes;
enc.copy_buffer_to_buffer(&self.kv_v[i], 0, &staging, offset, s.bytes);
offset += s.bytes;
}
self.ctx.queue.submit(Some(enc.finish()));
let payload = read_back_bytes(&self.ctx.device, &staging).await?;
let mut out = header;
out.extend_from_slice(&payload);
Ok(out)
}
/// Restores KV cache state from a snapshot in the `RLKV` version-1 format
/// (16-byte header, one 12-byte entry per section, then K and V rows per
/// section).
///
/// The snapshot is validated completely before any state is modified. On
/// success the payload is queued for upload into the per-layer caches, every
/// owning layer's row count is set from the snapshot (layers without a
/// section become empty), and `pos` is set to the snapshot's position. Row
/// counts of layers that use a donor are left unchanged.
///
/// # Errors
/// Returns `RullamaError::Inference` for a short or truncated buffer, bad
/// magic, unknown version, layout-hash or per-layer geometry mismatch, an
/// out-of-range or donor layer index, or a position / row count exceeding
/// `max_context`.
pub fn load_kv(&mut self, bytes: &[u8]) -> Result<()> {
    if bytes.len() < 16 {
        return Err(RullamaError::Inference(format!(
            "kv snapshot too short: {} bytes",
            bytes.len()
        )));
    }
    if &bytes[0..4] != b"RLKV" {
        return Err(RullamaError::Inference("kv snapshot: bad magic".into()));
    }
    let version = bytes[4];
    if version != 1 {
        return Err(RullamaError::Inference(format!(
            "kv snapshot: unknown version {version}"
        )));
    }
    let n_owned = bytes[5] as usize;
    let position = u32::from_le_bytes(bytes[8..12].try_into().unwrap());
    let layout_hash = u32::from_le_bytes(bytes[12..16].try_into().unwrap());
    let expected_hash = self.kv_layout_hash();
    if layout_hash != expected_hash {
        return Err(RullamaError::Inference(format!(
            "kv snapshot: layout_hash mismatch (snapshot=0x{layout_hash:08X}, model=0x{expected_hash:08X})"
        )));
    }
    let header_size = 16 + 12 * n_owned;
    if bytes.len() < header_size {
        return Err(RullamaError::Inference(
            "kv snapshot: truncated header".into(),
        ));
    }
    if position > self.max_context {
        return Err(RullamaError::Inference(format!(
            "kv snapshot: position {position} exceeds max_context {}",
            self.max_context
        )));
    }

    struct Section {
        layer_idx: u32,
        kv_len: u32,
        bytes: u64,
    }
    let mut sections: Vec<Section> = Vec::with_capacity(n_owned);
    let mut total_payload: u64 = 0;
    for s in 0..n_owned {
        let off = 16 + 12 * s;
        let layer_idx = u32::from_le_bytes(bytes[off..off + 4].try_into().unwrap());
        let kv_len = u32::from_le_bytes(bytes[off + 4..off + 8].try_into().unwrap());
        let nkv = u16::from_le_bytes(bytes[off + 8..off + 10].try_into().unwrap());
        let hd = u16::from_le_bytes(bytes[off + 10..off + 12].try_into().unwrap());
        if (layer_idx as usize) >= self.kv_lens.len() {
            return Err(RullamaError::Inference(format!(
                "kv snapshot: layer_idx {layer_idx} out of range"
            )));
        }
        if self.donor_map[layer_idx as usize].is_some() {
            return Err(RullamaError::Inference(format!(
                "kv snapshot: layer {layer_idx} marked as donor in current model but snapshot has data"
            )));
        }
        let exp_nkv = self.cfg.n_kv_heads(layer_idx) as u16;
        let exp_hd = self.cfg.head_dim(layer_idx) as u16;
        if nkv != exp_nkv || hd != exp_hd {
            return Err(RullamaError::Inference(format!(
                "kv snapshot: layer {layer_idx} geometry mismatch \
(snapshot n_kv={nkv} hd={hd}, model n_kv={exp_nkv} hd={exp_hd})"
            )));
        }
        if kv_len > self.max_context {
            return Err(RullamaError::Inference(format!(
                "kv snapshot: layer {layer_idx} kv_len {kv_len} exceeds max_context {}",
                self.max_context
            )));
        }
        let layer_bytes = (kv_len as u64) * (nkv as u64) * (hd as u64) * 4;
        sections.push(Section {
            layer_idx,
            kv_len,
            bytes: layer_bytes,
        });
        // K and V are stored back to back.
        total_payload += layer_bytes * 2;
    }

    let payload_off = header_size;
    if (bytes.len() as u64) < (payload_off as u64) + total_payload {
        return Err(RullamaError::Inference(format!(
            "kv snapshot: payload truncated (have {}, need {})",
            bytes.len() - payload_off,
            total_payload,
        )));
    }

    // Validation is complete; from here on state is modified.
    // Empty every owning layer first, so layers without a section end up at 0
    // without having to search the section list per layer.
    for (len, donor) in self.kv_lens.iter_mut().zip(self.donor_map.iter()) {
        if donor.is_none() {
            *len = 0;
        }
    }
    let queue = &self.ctx.queue;
    let mut off: usize = payload_off;
    for s in &sections {
        let i = s.layer_idx as usize;
        let n = s.bytes as usize;
        queue.write_buffer(&self.kv_k[i], 0, &bytes[off..off + n]);
        off += n;
        queue.write_buffer(&self.kv_v[i], 0, &bytes[off..off + n]);
        off += n;
        self.kv_lens[i] = s.kv_len;
    }
    self.pos = position;
    Ok(())
}
/// Runs one forward pass for `token_id` without activation capture or LoRA
/// adapters and returns the logits.
///
/// # Errors
/// Same as `step_inner`, to which this delegates.
pub async fn step(&mut self, token_id: u32) -> Result<Vec<f32>> {
self.step_inner(token_id, None, None).await
}
/// Runs one forward pass for `token_id` while copying each layer's
/// intermediate activations into `capture`, optionally applying `loras`.
///
/// # Errors
/// Returns `RullamaError::Inference` when `capture` (or `loras`, if given)
/// does not contain exactly one entry per layer; otherwise propagates the
/// errors of the forward pass.
pub async fn step_capture<'a>(
    &mut self,
    token_id: u32,
    capture: &'a [LayerCaptureBuffers<'a>],
    loras: Option<&'a [LayerLoraSlots<'a>]>,
) -> Result<Vec<f32>> {
    let expected_layers = self.cfg.n_layers as usize;
    if capture.len() != expected_layers {
        return Err(RullamaError::Inference(format!(
            "step_capture: got {} capture layers, expected {}",
            capture.len(),
            self.cfg.n_layers
        )));
    }
    match loras {
        Some(slots) if slots.len() != expected_layers => {
            Err(RullamaError::Inference(format!(
                "step_capture: got {} lora slots, expected {}",
                slots.len(),
                self.cfg.n_layers
            )))
        }
        _ => self.step_inner(token_id, Some(capture), loras).await,
    }
}
/// Runs one forward pass for `token_id` with the given per-layer LoRA
/// adapters and no activation capture.
///
/// # Errors
/// Returns `RullamaError::Inference` when `loras` does not contain exactly
/// one entry per layer; otherwise propagates the errors of the forward pass.
pub async fn step_with_lora<'a>(
    &mut self,
    token_id: u32,
    loras: &'a [LayerLoraSlots<'a>],
) -> Result<Vec<f32>> {
    let slot_count = loras.len();
    if slot_count == self.cfg.n_layers as usize {
        return self.step_inner(token_id, None, Some(loras)).await;
    }
    Err(RullamaError::Inference(format!(
        "step_with_lora: got {} lora slots, expected {}",
        slot_count, self.cfg.n_layers
    )))
}
/// Runs one forward pass for `token_id` with both LoRA adapters and
/// activation capture, without a progress callback.
///
/// Delegates to `step_with_lora_seqcap_with_progress` with `None` as the
/// callback; see that function for validation and errors.
pub async fn step_with_lora_seqcap<'a>(
&mut self,
token_id: u32,
loras: &'a [LayerLoraSlots<'a>],
capture: &'a [LayerCaptureBuffers<'a>],
) -> Result<Vec<f32>> {
self.step_with_lora_seqcap_with_progress(token_id, loras, capture, None)
.await
}
/// Runs one forward pass for `token_id` with LoRA adapters and activation
/// capture, reporting per-layer progress through `progress_cb` if provided.
///
/// # Errors
/// Returns `RullamaError::Inference` when `loras` or `capture` does not
/// contain exactly one entry per layer (LoRA slots are checked first);
/// otherwise propagates the errors of the forward pass.
pub async fn step_with_lora_seqcap_with_progress<'a>(
    &mut self,
    token_id: u32,
    loras: &'a [LayerLoraSlots<'a>],
    capture: &'a [LayerCaptureBuffers<'a>],
    progress_cb: Option<&LayerProgressCb<'_>>,
) -> Result<Vec<f32>> {
    let expected_layers = self.cfg.n_layers as usize;
    let mismatch = if loras.len() != expected_layers {
        Some(format!(
            "step_with_lora_seqcap: got {} lora slots, expected {}",
            loras.len(),
            self.cfg.n_layers
        ))
    } else if capture.len() != expected_layers {
        Some(format!(
            "step_with_lora_seqcap: got {} captures, expected {}",
            capture.len(),
            self.cfg.n_layers
        ))
    } else {
        None
    };
    if let Some(message) = mismatch {
        return Err(RullamaError::Inference(message));
    }
    self.step_inner_with_progress(token_id, Some(capture), Some(loras), progress_cb)
        .await
}
/// Shared implementation of the token-based `step*` entry points that do not
/// report progress.
///
/// Delegates to `step_inner_with_progress` with `None` as the callback.
async fn step_inner<'a>(
&mut self,
token_id: u32,
capture: Option<&'a [LayerCaptureBuffers<'a>]>,
loras: Option<&'a [LayerLoraSlots<'a>]>,
) -> Result<Vec<f32>> {
self.step_inner_with_progress(token_id, capture, loras, None)
.await
}
/// Uploads the embedding of `token_id` and runs the forward pass.
///
/// The token's embedding row is scaled by `sqrt(d_model)` and written to
/// `hidden`. When the model has per-layer embeddings, the token's
/// `per_layer_token_embd` row is scaled by `sqrt(ple_dim)` and written to
/// `per_layer_residual`.
///
/// # Errors
/// Returns `RullamaError::Inference` when `token_id` is not below the
/// vocabulary size or the context is already full; otherwise propagates
/// weight-loading and forward-pass errors.
async fn step_inner_with_progress<'a>(
    &mut self,
    token_id: u32,
    capture: Option<&'a [LayerCaptureBuffers<'a>]>,
    loras: Option<&'a [LayerLoraSlots<'a>]>,
    progress_cb: Option<&LayerProgressCb<'_>>,
) -> Result<Vec<f32>> {
    let token_in_vocab = (token_id as u64) < self.cfg.vocab_size as u64;
    if !token_in_vocab {
        return Err(RullamaError::Inference(format!(
            "token_id {token_id} >= vocab_size {}",
            self.cfg.vocab_size
        )));
    }
    if self.pos >= self.max_context {
        return Err(RullamaError::Inference(format!(
            "context length exceeded max_context={}",
            self.max_context
        )));
    }

    let d_model = self.cfg.d_model as usize;
    let ple_dim = self.cfg.ple_dim as usize;
    let row = token_id as usize;

    // Token embedding -> hidden. The CPU copy is released at the end of the scope.
    {
        let mut embedding = self
            .weights
            .load_row_async("token_embd.weight", row)
            .await?;
        let embed_scale = (d_model as f32).sqrt();
        embedding.iter_mut().for_each(|x| *x *= embed_scale);
        self.ctx
            .queue
            .write_buffer(&self.hidden, 0, bytemuck::cast_slice(&embedding));
    }

    // Per-layer token embedding -> per_layer_residual (only for PLE models).
    if self.cfg.has_ple() {
        let mut ple_row = self
            .weights
            .load_row_async("per_layer_token_embd.weight", row)
            .await?;
        let ple_scale = (ple_dim as f32).sqrt();
        ple_row.iter_mut().for_each(|x| *x *= ple_scale);
        self.ctx
            .queue
            .write_buffer(&self.per_layer_residual, 0, bytemuck::cast_slice(&ple_row));
    }

    self.run_forward_from_hidden_with_progress(capture, loras, progress_cb)
        .await
}
/// Runs one forward pass starting from a caller-supplied embedding vector
/// instead of a token id, without LoRA adapters.
///
/// Delegates to `step_with_embedding_inner`; see that function for the
/// expected length of `embedding` and the errors returned.
pub async fn step_with_embedding(&mut self, embedding: &[f32]) -> Result<Vec<f32>> {
self.step_with_embedding_inner(embedding, None).await
}
/// Runs one forward pass starting from a caller-supplied embedding vector,
/// applying the given per-layer LoRA adapters.
///
/// # Errors
/// Returns `RullamaError::Inference` when `loras` does not contain exactly
/// one entry per layer; otherwise propagates the errors of the forward pass.
pub async fn step_with_embedding_with_lora<'a>(
    &mut self,
    embedding: &[f32],
    loras: &'a [LayerLoraSlots<'a>],
) -> Result<Vec<f32>> {
    let slot_count = loras.len();
    if slot_count == self.cfg.n_layers as usize {
        return self.step_with_embedding_inner(embedding, Some(loras)).await;
    }
    Err(RullamaError::Inference(format!(
        "step_with_embedding_with_lora: got {} lora slots, expected {}",
        slot_count, self.cfg.n_layers
    )))
}
/// Uploads `embedding` as the initial hidden state and runs the forward pass.
///
/// The vector is written to `hidden` unchanged (no `sqrt(d_model)` scaling is
/// applied here). For models with per-layer embeddings, the per-layer
/// residual input is zero-filled, since there is no token to look it up for.
///
/// # Errors
/// Returns `RullamaError::Inference` when `embedding` does not hold exactly
/// `d_model` values or the context is already full; otherwise propagates the
/// errors of the forward pass.
async fn step_with_embedding_inner<'a>(
    &mut self,
    embedding: &[f32],
    loras: Option<&'a [LayerLoraSlots<'a>]>,
) -> Result<Vec<f32>> {
    let d_model = self.cfg.d_model as usize;
    let failure = if embedding.len() != d_model {
        Some(format!(
            "step_with_embedding: got {} f32s, expected d_model = {d_model}",
            embedding.len(),
        ))
    } else if self.pos >= self.max_context {
        Some(format!(
            "context length exceeded max_context={}",
            self.max_context
        ))
    } else {
        None
    };
    if let Some(message) = failure {
        return Err(RullamaError::Inference(message));
    }

    let queue = &self.ctx.queue;
    queue.write_buffer(&self.hidden, 0, bytemuck::cast_slice(embedding));
    if self.cfg.has_ple() {
        let ple_len = self.cfg.n_layers as usize * self.cfg.ple_dim as usize;
        let no_ple_input = vec![0f32; ple_len];
        queue.write_buffer(
            &self.per_layer_residual,
            0,
            bytemuck::cast_slice(&no_ple_input),
        );
    }
    self.run_forward_from_hidden(None, loras).await
}
/// Runs the forward pass on the current contents of `hidden` without a
/// progress callback.
///
/// Delegates to `run_forward_from_hidden_with_progress` with `None` as the
/// callback.
async fn run_forward_from_hidden<'a>(
&mut self,
capture: Option<&'a [LayerCaptureBuffers<'a>]>,
loras: Option<&'a [LayerLoraSlots<'a>]>,
) -> Result<Vec<f32>> {
self.run_forward_from_hidden_with_progress(capture, loras, None)
.await
}
/// Runs the full forward pass on the current contents of `hidden` (and, for
/// PLE models, `per_layer_residual`) and returns the logits.
///
/// On success `pos` has advanced by one and every owning layer's KV cache
/// holds one more row.
///
/// The pass is transactional with respect to the KV row counts: layers
/// advance `kv_lens` as they are encoded, but `pos` only advances once the
/// logits have been read back. If the pass fails or is cancelled part-way,
/// the row counts are restored to their values on entry so that they stay
/// consistent with `pos` and the same position can be retried.
///
/// # Errors
/// Returns `RullamaError::Cancelled` when cancellation was requested during
/// the pass, and propagates weight-cache, matmul and read-back errors.
async fn run_forward_from_hidden_with_progress<'a>(
    &mut self,
    capture: Option<&'a [LayerCaptureBuffers<'a>]>,
    loras: Option<&'a [LayerLoraSlots<'a>]>,
    progress_cb: Option<&LayerProgressCb<'_>>,
) -> Result<Vec<f32>> {
    let kv_lens_on_entry = self.kv_lens.clone();
    let outcome = self
        .forward_pass_uncommitted(capture, loras, progress_cb)
        .await;
    if outcome.is_err() {
        // Undo the per-layer row-count increments of the aborted pass.
        self.kv_lens = kv_lens_on_entry;
    }
    outcome
}

/// Body of the forward pass: PLE preparation, all layers, final norm, tiled
/// output projection, optional soft-cap and logits read-back.
///
/// May leave `kv_lens` advanced when it returns an error; the caller is
/// responsible for restoring it.
async fn forward_pass_uncommitted<'a>(
    &mut self,
    capture: Option<&'a [LayerCaptureBuffers<'a>]>,
    loras: Option<&'a [LayerLoraSlots<'a>]>,
    progress_cb: Option<&LayerProgressCb<'_>>,
) -> Result<Vec<f32>> {
    // A cancellation requested before this pass started does not apply to it.
    self.reset_cancel();
    let d_model = self.cfg.d_model as usize;
    let n_layers = self.cfg.n_layers as usize;
    let ple_dim = self.cfg.ple_dim as usize;
    let eps = self.cfg.rms_norm_eps;
    let pos = self.pos;
    let wc = self.wcache.clone();
    let final_norm = wc.buffer_async("output_norm.weight").await?;
    let token_embd_dtype = wc.dtype("token_embd.weight")?;
    let (ple_proj_w_buf, ple_proj_norm_w_buf, ple_proj_n) = if self.cfg.has_ple() {
        if wc.dtype("per_layer_model_proj.weight")? != GgmlDtype::Q4_K {
            return Err(RullamaError::Inference(
                "per_layer_model_proj expected Q4_K".into(),
            ));
        }
        let proj_w = wc.buffer_async("per_layer_model_proj.weight").await?;
        let proj_norm = wc.buffer_async("per_layer_proj_norm.weight").await?;
        (Some(proj_w), Some(proj_norm), n_layers * ple_dim)
    } else {
        (None, None, 0)
    };
    let mut enc = self
        .ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fwd.token_encoder"),
        });
    if self.cfg.has_ple() {
        let proj_w = ple_proj_w_buf.as_ref().unwrap();
        let proj_norm_w = ple_proj_norm_w_buf.as_ref().unwrap();
        // Project hidden to one `ple_dim` vector per layer ...
        matmul_q4_k_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            proj_w,
            &self.hidden,
            &self.per_layer_proj,
            d_model,
            ple_proj_n,
        );
        // ... scale by 1/sqrt(d_model) ...
        scale_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            &self.per_layer_proj,
            ple_proj_n,
            1.0 / (d_model as f32).sqrt(),
        );
        // ... normalise each layer's vector separately ...
        rmsnorm_per_row_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            &self.per_layer_proj,
            Some(proj_norm_w),
            &self.dummy,
            &self.per_layer,
            n_layers,
            ple_dim,
            eps,
        );
        // ... add the token's per-layer embedding and scale by 1/sqrt(2).
        residual_add_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            &self.per_layer,
            &self.per_layer_residual,
            ple_proj_n,
        );
        scale_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            &self.per_layer,
            ple_proj_n,
            1.0 / 2.0_f32.sqrt(),
        );
    }
    // One submission per layer, so cancellation and progress are observed at
    // layer granularity.
    for i in 0..n_layers as u32 {
        let cap = capture.map(|c| &c[i as usize]);
        let lora = loras.map(|l| &l[i as usize]);
        self.encode_layer(&mut enc, i, pos, cap, lora).await?;
        self.ctx.queue.submit(Some(enc.finish()));
        self.check_cancelled()?;
        if let Some(cb) = progress_cb {
            cb("forward", i + 1, n_layers as u32);
        }
        enc = self
            .ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("fwd.token_encoder.cont"),
            });
    }
    rmsnorm_chained(
        &self.ctx,
        &self.pipes,
        &mut enc,
        &self.hidden,
        Some(&final_norm),
        &self.dummy,
        &self.norm_x,
        d_model,
        eps,
    );
    self.ctx.queue.submit(Some(enc.finish()));
    enc = self
        .ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("fwd.out_proj_encoder"),
        });
    const MAX_TILE_BYTES: usize = 8 * 1024 * 1024;
    let tiles = wc
        .buffer_tiles_async("token_embd.weight", MAX_TILE_BYTES)
        .await?;
    // Output projection, one submission per tile of the embedding matrix;
    // each tile's logits are copied to its row offset in `logits`.
    for tile in &tiles {
        run_matmul_into_buf(
            &self.ctx,
            &self.pipes,
            &mut enc,
            token_embd_dtype,
            &tile.buffer,
            &self.norm_x,
            &self.logits_tile,
            tile.n_rows,
            d_model,
            "fwd.output_tile",
        )?;
        enc.copy_buffer_to_buffer(
            &self.logits_tile,
            0,
            &self.logits,
            (tile.row_start as u64) * 4,
            (tile.n_rows as u64) * 4,
        );
        self.ctx.queue.submit(Some(enc.finish()));
        enc = self
            .ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("fwd.out_proj_encoder.cont"),
            });
    }
    // With soft-capping enabled the capped logits are produced in
    // `logits_tile`; `logits` keeps the uncapped values.
    let final_src: &wgpu::Buffer = if self.cfg.final_logit_softcap > 0.0 {
        softcap_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            &self.logits,
            &self.logits_tile,
            self.cfg.vocab_size as usize,
            self.cfg.final_logit_softcap,
        );
        &self.logits_tile
    } else {
        &self.logits
    };
    enc.copy_buffer_to_buffer(
        final_src,
        0,
        &self.logits_read,
        0,
        (self.cfg.vocab_size as u64) * 4,
    );
    self.ctx.queue.submit(Some(enc.finish()));
    let logits = read_back_f32(&self.ctx.device, &self.logits_read).await?;
    // Only a fully completed pass advances the position.
    self.pos = self.pos.saturating_add(1);
    Ok(logits)
}
async fn encode_layer<'a>(
&mut self,
enc: &mut wgpu::CommandEncoder,
i: u32,
pos: u32,
capture: Option<&'a LayerCaptureBuffers<'a>>,
loras: Option<&'a LayerLoraSlots<'a>>,
) -> Result<()> {
let prefix = format!("blk.{i}.");
let d_model = self.cfg.d_model as usize;
let eps = self.cfg.rms_norm_eps;
let n_heads = self.cfg.n_heads as usize;
let n_kv_heads = self.cfg.n_kv_heads(i) as usize;
let head_dim = self.cfg.head_dim(i) as usize;
let ffn_n = self.cfg.ffn(i) as usize;
let kind = self.cfg.kind(i);
let donor = self.donor_map[i as usize];
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.hidden,
0,
cap.hidden_in,
(pos as u64) * (d_model as u64) * 4,
(d_model * 4) as u64,
);
}
let attn_norm_w = self
.wcache
.buffer_async(&format!("{prefix}attn_norm.weight"))
.await?;
let post_attn_w = self
.wcache
.buffer_async(&format!("{prefix}post_attention_norm.weight"))
.await?;
let mlp_norm_w = self
.wcache
.buffer_async(&format!("{prefix}ffn_norm.weight"))
.await?;
let post_ffw_w = self
.wcache
.buffer_async(&format!("{prefix}post_ffw_norm.weight"))
.await?;
let q_w = self
.wcache
.buffer_async(&format!("{prefix}attn_q.weight"))
.await?;
let q_norm_w = self
.wcache
.buffer_async(&format!("{prefix}attn_q_norm.weight"))
.await?;
let o_w = self
.wcache
.buffer_async(&format!("{prefix}attn_output.weight"))
.await?;
let (k_w, k_norm_w, v_w, v_w_dtype) = if donor.is_none() {
let kw = self
.wcache
.buffer_async(&format!("{prefix}attn_k.weight"))
.await?;
let knw = self
.wcache
.buffer_async(&format!("{prefix}attn_k_norm.weight"))
.await?;
let v_name = format!("{prefix}attn_v.weight");
let vw = self.wcache.buffer_async(&v_name).await?;
let dt = self.wcache.dtype(&v_name)?;
(Some(kw), Some(knw), Some(vw), Some(dt))
} else {
(None, None, None, None)
};
let gate_w = self
.wcache
.buffer_async(&format!("{prefix}ffn_gate.weight"))
.await?;
let up_w = self
.wcache
.buffer_async(&format!("{prefix}ffn_up.weight"))
.await?;
let down_name = format!("{prefix}ffn_down.weight");
let down_w = self.wcache.buffer_async(&down_name).await?;
let down_dtype = self.wcache.dtype(&down_name)?;
let (inp_gate_w, proj_w, post_norm_w) = if self.cfg.has_ple() {
let a = self
.wcache
.buffer_async(&format!("{prefix}inp_gate.weight"))
.await?;
let b = self
.wcache
.buffer_async(&format!("{prefix}proj.weight"))
.await?;
let c = self
.wcache
.buffer_async(&format!("{prefix}post_norm.weight"))
.await?;
(Some(a), Some(b), Some(c))
} else {
(None, None, None)
};
let factors_w = if matches!(kind, LayerKind::Global) {
self.wcache.buffer_opt_async("rope_freqs.weight").await?
} else {
None
};
rmsnorm_chained(
&self.ctx,
&self.pipes,
enc,
&self.hidden,
Some(&attn_norm_w),
&self.dummy,
&self.norm_x,
d_model,
eps,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.norm_x,
0,
cap.norm_x_attn,
(pos as u64) * (d_model as u64) * 4,
(d_model * 4) as u64,
);
}
matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
&q_w,
&self.norm_x,
&self.q,
d_model,
n_heads * head_dim,
);
if let Some(slot) = loras.and_then(|l| l.q.as_ref()) {
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.a,
&self.norm_x,
slot.z,
d_model,
slot.rank as usize,
1.0,
false,
);
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.b,
slot.z,
&self.q,
slot.rank as usize,
n_heads * head_dim,
slot.scale,
true,
);
}
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.q,
0,
cap.q_pre_norm,
(pos as u64) * (n_heads as u64) * (head_dim as u64) * 4,
(n_heads * head_dim * 4) as u64,
);
}
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&self.q,
Some(&q_norm_w),
&self.dummy,
&self.q_norm,
n_heads,
head_dim,
eps,
);
let (rope_base, rope_dims) = match kind {
LayerKind::SlidingWindow => {
(self.cfg.rope_freq_base_swa, self.cfg.rope_dim_swa as usize)
}
LayerKind::Global => (self.cfg.rope_freq_base, self.cfg.rope_dim_global as usize),
};
rope_neox_chained(
&self.ctx,
&self.pipes,
enc,
&self.q_norm,
factors_w.as_ref(),
&self.dummy,
head_dim,
n_heads,
pos as usize,
rope_dims,
rope_base,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.q_norm,
0,
cap.q_post_rope,
(pos as u64) * (n_heads as u64) * (head_dim as u64) * 4,
(n_heads * head_dim * 4) as u64,
);
}
if donor.is_none() {
let kw = k_w.as_ref().unwrap();
let knw = k_norm_w.as_ref().unwrap();
let vw = v_w.as_ref().unwrap();
let vdt = v_w_dtype.unwrap();
matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
kw,
&self.norm_x,
&self.k,
d_model,
n_kv_heads * head_dim,
);
if let Some(slot) = loras.and_then(|l| l.k.as_ref()) {
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.a,
&self.norm_x,
slot.z,
d_model,
slot.rank as usize,
1.0,
false,
);
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.b,
slot.z,
&self.k,
slot.rank as usize,
n_kv_heads * head_dim,
slot.scale,
true,
);
}
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.k,
0,
cap.k_pre_norm,
(pos as u64) * (n_kv_heads as u64) * (head_dim as u64) * 4,
(n_kv_heads * head_dim * 4) as u64,
);
}
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&self.k,
Some(knw),
&self.dummy,
&self.k_norm,
n_kv_heads,
head_dim,
eps,
);
rope_neox_chained(
&self.ctx,
&self.pipes,
enc,
&self.k_norm,
factors_w.as_ref(),
&self.dummy,
head_dim,
n_kv_heads,
pos as usize,
rope_dims,
rope_base,
);
match vdt {
GgmlDtype::Q6_K => matmul_q6_k_chained(
&self.ctx,
&self.pipes,
enc,
vw,
&self.norm_x,
&self.v,
d_model,
n_kv_heads * head_dim,
),
GgmlDtype::Q4_K => matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
vw,
&self.norm_x,
&self.v,
d_model,
n_kv_heads * head_dim,
),
other => {
return Err(RullamaError::Inference(format!(
"attn_v dtype {other:?} unsupported"
)));
}
}
if let Some(slot) = loras.and_then(|l| l.v.as_ref()) {
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.a,
&self.norm_x,
slot.z,
d_model,
slot.rank as usize,
1.0,
false,
);
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.b,
slot.z,
&self.v,
slot.rank as usize,
n_kv_heads * head_dim,
slot.scale,
true,
);
}
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.v,
0,
cap.v_pre_norm,
(pos as u64) * (n_kv_heads as u64) * (head_dim as u64) * 4,
(n_kv_heads * head_dim * 4) as u64,
);
}
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&self.v,
None,
&self.dummy,
&self.v_norm,
n_kv_heads,
head_dim,
eps,
);
let row_bytes = (n_kv_heads * head_dim * 4) as u64;
let dst_offset = self.kv_lens[i as usize] as u64 * row_bytes;
enc.copy_buffer_to_buffer(
&self.k_norm,
0,
&self.kv_k[i as usize],
dst_offset,
row_bytes,
);
enc.copy_buffer_to_buffer(
&self.v_norm,
0,
&self.kv_v[i as usize],
dst_offset,
row_bytes,
);
self.kv_lens[i as usize] = self.kv_lens[i as usize].saturating_add(1);
}
let history_layer = donor.map(|d| d as usize).unwrap_or(i as usize);
let history_len = self.kv_lens[history_layer] as usize;
let window = if matches!(kind, LayerKind::SlidingWindow) {
self.cfg.sliding_window as usize
} else {
0
};
attention_chained(
&self.ctx,
&self.pipes,
enc,
&self.q_norm,
&self.kv_k[i as usize],
&self.kv_v[i as usize],
&self.attn_out_buf,
head_dim,
n_heads,
n_kv_heads,
pos as usize,
history_len,
window,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.attn_out_buf,
0,
cap.attn_out,
(pos as u64) * (n_heads as u64) * (head_dim as u64) * 4,
(n_heads * head_dim * 4) as u64,
);
}
matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
&o_w,
&self.attn_out_buf,
&self.attn_proj,
n_heads * head_dim,
d_model,
);
if let Some(slot) = loras.and_then(|l| l.o.as_ref()) {
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.a,
&self.attn_out_buf,
slot.z,
n_heads * head_dim,
slot.rank as usize,
1.0,
false,
);
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.b,
slot.z,
&self.attn_proj,
slot.rank as usize,
d_model,
slot.scale,
true,
);
}
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.attn_proj,
0,
cap.attn_proj,
(pos as u64) * (d_model as u64) * 4,
(d_model * 4) as u64,
);
}
rmsnorm_chained(
&self.ctx,
&self.pipes,
enc,
&self.attn_proj,
Some(&post_attn_w),
&self.dummy,
&self.norm_y,
d_model,
eps,
);
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
&self.hidden,
&self.norm_y,
d_model,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.hidden,
0,
cap.pre_ffn_rms,
(pos as u64) * (d_model as u64) * 4,
(d_model * 4) as u64,
);
}
rmsnorm_chained(
&self.ctx,
&self.pipes,
enc,
&self.hidden,
Some(&mlp_norm_w),
&self.dummy,
&self.norm_x,
d_model,
eps,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.norm_x,
0,
cap.norm_x_ffn,
(pos as u64) * (d_model as u64) * 4,
(d_model * 4) as u64,
);
}
matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
&gate_w,
&self.norm_x,
&self.ffn_gate,
d_model,
ffn_n,
);
if let Some(slot) = loras.and_then(|l| l.ffn_gate.as_ref()) {
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.a,
&self.norm_x,
slot.z,
d_model,
slot.rank as usize,
1.0,
false,
);
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.b,
slot.z,
&self.ffn_gate,
slot.rank as usize,
ffn_n,
slot.scale,
true,
);
}
matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
&up_w,
&self.norm_x,
&self.ffn_up,
d_model,
ffn_n,
);
if let Some(slot) = loras.and_then(|l| l.ffn_up.as_ref()) {
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.a,
&self.norm_x,
slot.z,
d_model,
slot.rank as usize,
1.0,
false,
);
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.b,
slot.z,
&self.ffn_up,
slot.rank as usize,
ffn_n,
slot.scale,
true,
);
}
if let Some(cap) = capture {
let ffn_pos_off = (pos as u64) * (ffn_n as u64) * 4;
enc.copy_buffer_to_buffer(
&self.ffn_gate,
0,
cap.ffn_gate,
ffn_pos_off,
(ffn_n * 4) as u64,
);
enc.copy_buffer_to_buffer(&self.ffn_up, 0, cap.ffn_up, ffn_pos_off, (ffn_n * 4) as u64);
}
geglu_chained(
&self.ctx,
&self.pipes,
enc,
&self.ffn_gate,
&self.ffn_up,
&self.ffn_act,
ffn_n,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.ffn_act,
0,
cap.ffn_act,
(pos as u64) * (ffn_n as u64) * 4,
(ffn_n * 4) as u64,
);
}
match down_dtype {
GgmlDtype::Q6_K => matmul_q6_k_chained(
&self.ctx,
&self.pipes,
enc,
&down_w,
&self.ffn_act,
&self.ffn_out,
ffn_n,
d_model,
),
GgmlDtype::Q4_K => matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
&down_w,
&self.ffn_act,
&self.ffn_out,
ffn_n,
d_model,
),
other => {
return Err(RullamaError::Inference(format!(
"ffn_down dtype {other:?} unsupported"
)));
}
}
if let Some(slot) = loras.and_then(|l| l.ffn_down.as_ref()) {
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.a,
&self.ffn_act,
slot.z,
ffn_n,
slot.rank as usize,
1.0,
false,
);
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
slot.b,
slot.z,
&self.ffn_out,
slot.rank as usize,
d_model,
slot.scale,
true,
);
}
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.ffn_out,
0,
cap.ffn_out,
(pos as u64) * (d_model as u64) * 4,
(d_model * 4) as u64,
);
}
rmsnorm_chained(
&self.ctx,
&self.pipes,
enc,
&self.ffn_out,
Some(&post_ffw_w),
&self.dummy,
&self.norm_y,
d_model,
eps,
);
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
&self.hidden,
&self.norm_y,
d_model,
);
if self.cfg.has_ple() {
let inp_gate_w = inp_gate_w.unwrap();
let proj_w = proj_w.unwrap();
let post_norm_w = post_norm_w.unwrap();
let ple_dim = self.cfg.ple_dim as usize;
matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
&inp_gate_w,
&self.hidden,
&self.ple_state,
d_model,
ple_dim,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.ple_state,
0,
cap.ple_state,
(pos as u64) * (ple_dim as u64) * 4,
(ple_dim * 4) as u64,
);
}
let layer_off = (i as u64) * (ple_dim as u64) * 4;
let layer_bytes = (ple_dim as u64) * 4;
enc.copy_buffer_to_buffer(&self.per_layer, layer_off, &self.ple_proj, 0, layer_bytes);
geglu_chained(
&self.ctx,
&self.pipes,
enc,
&self.ple_state,
&self.ple_proj,
&self.ple_act,
ple_dim,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.ple_act,
0,
cap.ple_act,
(pos as u64) * (ple_dim as u64) * 4,
(ple_dim * 4) as u64,
);
}
matmul_q4_k_chained(
&self.ctx,
&self.pipes,
enc,
&proj_w,
&self.ple_act,
&self.ple_proj,
ple_dim,
d_model,
);
if let Some(cap) = capture {
enc.copy_buffer_to_buffer(
&self.ple_proj,
0,
cap.ple_proj,
(pos as u64) * (d_model as u64) * 4,
(d_model * 4) as u64,
);
}
rmsnorm_chained(
&self.ctx,
&self.pipes,
enc,
&self.ple_proj,
Some(&post_norm_w),
&self.dummy,
&self.norm_y,
d_model,
eps,
);
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
&self.hidden,
&self.norm_y,
d_model,
);
}
if let Some(s) = self.layer_scalars[i as usize] {
scale_chained(&self.ctx, &self.pipes, enc, &self.hidden, d_model, s);
}
Ok(())
}
}
/// Parameter block that `run_matmul_into_buf` uploads into the uniform buffer
/// bound at binding 0 of the matmul pipeline.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets the struct be viewed as raw
/// bytes when it is written to the GPU buffer.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct MatmulParams {
// Filled from the `k` argument of `run_matmul_into_buf`
// (presumably the reduction length — confirm against the shader).
k: u32,
// Filled from the `n_rows` argument of `run_matmul_into_buf`.
n: u32,
// Always written as 0 by `run_matmul_into_buf`; presumably padding that rounds
// the struct up to 16 bytes for the uniform binding — TODO confirm.
_p0: u32,
_p1: u32,
}
/// Records one quantized matmul compute pass into `enc`.
///
/// A fresh uniform buffer holding [`MatmulParams`] (`k`, `n_rows`) is created and
/// bound at binding 0; `w`, `x` and `dst` are bound at bindings 1, 2 and 3. The
/// pass dispatches `ceil(n_rows / 64)` workgroups along X. Nothing is submitted
/// here: the caller owns `enc` and decides when to submit it.
///
/// `label` is used for the compute pass and, with `.params` / `.bg` suffixes, for
/// the uniform buffer and bind group.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] when `dtype` is neither `Q4_K` nor `Q6_K`,
/// or when `k` / `n_rows` do not fit in the `u32` fields of the parameter block
/// (previously such values were silently truncated).
fn run_matmul_into_buf(
    ctx: &WgpuCtx,
    pipes: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    dtype: GgmlDtype,
    w: &wgpu::Buffer,
    x: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    n_rows: usize,
    k: usize,
    label: &str,
) -> Result<()> {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let pipeline = match dtype {
        GgmlDtype::Q4_K => &pipes.q4_k_matmul,
        GgmlDtype::Q6_K => &pipes.q6_k_matmul,
        other => {
            return Err(RullamaError::Inference(format!(
                "output proj dtype {other:?} not supported"
            )));
        }
    };
    // The shader-side parameter block is u32; refuse sizes that would wrap
    // instead of dispatching with a truncated dimension.
    let to_u32 = |what: &str, value: usize| -> Result<u32> {
        u32::try_from(value).map_err(|_| {
            RullamaError::Inference(format!("{label}: {what}={value} does not fit in u32"))
        })
    };
    let n_u32 = to_u32("n_rows", n_rows)?;
    let params = MatmulParams {
        k: to_u32("k", k)?,
        n: n_u32,
        _p0: 0,
        _p1: 0,
    };
    let p_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some(&format!("{label}.params")),
        size: std::mem::size_of::<MatmulParams>() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // Fix: the original line carried a mis-encoded `&params` token.
    queue.write_buffer(&p_buf, 0, bytemuck::bytes_of(&params));
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(&format!("{label}.bg")),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: dst.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(label),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipeline);
    cp.set_bind_group(0, &bg, &[]);
    // One workgroup per 64 output rows, rounded up.
    cp.dispatch_workgroups(n_u32.div_ceil(64), 1, 1);
    Ok(())
}
/// Copies the first `n` 4-byte elements of `buf` into a temporary mappable
/// buffer, reads them back as `f32`, and summarises them.
///
/// Returns `(max_abs, nan_count)`: the largest absolute value among the non-NaN
/// elements (0.0 if there are none) and the number of NaN elements.
///
/// The copy is submitted on `ctx.queue` immediately, and the call waits for the
/// readback to complete.
async fn read_buf_stats(ctx: &WgpuCtx, buf: &wgpu::Buffer, n: usize) -> Result<(f32, usize)> {
    let byte_len = (n * 4) as u64;
    let staging = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("trace.read"),
        size: byte_len,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let mut copy_enc = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("trace.enc"),
        });
    copy_enc.copy_buffer_to_buffer(buf, 0, &staging, 0, byte_len);
    ctx.queue.submit(Some(copy_enc.finish()));

    let values = read_back_f32(&ctx.device, &staging).await?;

    // Single pass: NaNs are counted and excluded from the running maximum.
    let stats = values
        .iter()
        .fold((0.0f32, 0usize), |(peak, nan_count), &sample| {
            if sample.is_nan() {
                (peak, nan_count + 1)
            } else {
                let magnitude = sample.abs();
                if magnitude > peak {
                    (magnitude, nan_count)
                } else {
                    (peak, nan_count)
                }
            }
        });
    Ok(stats)
}
/// Maps the whole of `buf` for reading and returns its contents as `f32` values.
///
/// The map request is issued, the device is polled with `PollType::Wait`, and the
/// mapping result is received through a oneshot channel. The buffer is unmapped
/// again before returning on the success path.
///
/// # Errors
///
/// * [`RullamaError::Inference`] if polling the device fails.
/// * [`RullamaError::BufferMap`] if the channel is dropped or the mapping fails.
async fn read_back_f32(device: &wgpu::Device, buf: &wgpu::Buffer) -> Result<Vec<f32>> {
    let whole = buf.slice(..);
    let (tx, rx) = oneshot::channel();
    whole.map_async(wgpu::MapMode::Read, move |outcome| {
        // The receiver may already be gone; nothing useful to do in that case.
        let _ = tx.send(outcome);
    });

    if let Err(e) = device.poll(wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    }) {
        return Err(RullamaError::Inference(format!("{e:?}")));
    }

    let map_outcome = rx
        .await
        .map_err(|e| RullamaError::BufferMap(e.to_string()))?;
    map_outcome.map_err(|e| RullamaError::BufferMap(e.to_string()))?;

    // Keep the mapped view in its own scope so it is released before `unmap`.
    let floats = {
        let mapped = whole.get_mapped_range();
        let copied: Vec<f32> = bytemuck::cast_slice::<u8, f32>(&mapped).to_vec();
        copied
    };
    buf.unmap();
    Ok(floats)
}
/// Maps the whole of `buf` for reading and returns a copy of its raw bytes.
///
/// Follows the same sequence as `read_back_f32`: issue the map request, poll the
/// device with `PollType::Wait`, receive the mapping result over a oneshot
/// channel, copy the data out, then unmap.
///
/// # Errors
///
/// * [`RullamaError::Inference`] if polling the device fails.
/// * [`RullamaError::BufferMap`] if the channel is dropped or the mapping fails.
async fn read_back_bytes(device: &wgpu::Device, buf: &wgpu::Buffer) -> Result<Vec<u8>> {
    let whole = buf.slice(..);
    let (tx, rx) = oneshot::channel();
    whole.map_async(wgpu::MapMode::Read, move |outcome| {
        // The receiver may already be gone; nothing useful to do in that case.
        let _ = tx.send(outcome);
    });

    if let Err(e) = device.poll(wgpu::PollType::Wait {
        submission_index: None,
        timeout: None,
    }) {
        return Err(RullamaError::Inference(format!("{e:?}")));
    }

    let map_outcome = rx
        .await
        .map_err(|e| RullamaError::BufferMap(e.to_string()))?;
    map_outcome.map_err(|e| RullamaError::BufferMap(e.to_string()))?;

    // Keep the mapped view in its own scope so it is released before `unmap`.
    let bytes = {
        let mapped = whole.get_mapped_range();
        let copied: Vec<u8> = mapped.to_vec();
        copied
    };
    buf.unmap();
    Ok(bytes)
}
/// Buffers and hyper-parameters for one LoRA adapter during the backward pass.
///
/// `d_a` and `d_b` are the buffers `backward_layer` hands to
/// `lora_outer_add_chained` as the gradient targets for the adapter's A and B
/// factors. The remaining fields (`a`, `b`, `z`, `rank`, `scale`) have the same
/// names and types as those of `LoraSlot`; presumably they refer to the same
/// adapter weights, rank-sized scratch vector and scaling — confirm with callers.
pub struct LoraGradPair<'a> {
pub a: &'a wgpu::Buffer, pub b: &'a wgpu::Buffer, pub z: &'a wgpu::Buffer, pub d_a: &'a wgpu::Buffer, pub d_b: &'a wgpu::Buffer, pub rank: u32,
pub scale: f32,
}
/// Optional LoRA gradient buffers for one transformer layer, one entry per
/// adapted projection.
///
/// The field set matches `LayerLoraSlots`. `backward_layer` only accumulates
/// gradients for a projection when both the slot in `LayerLoraSlots` and the
/// entry here are `Some`.
pub struct LayerLoraGrads<'a> {
// Attention projections.
pub q: Option<LoraGradPair<'a>>,
pub k: Option<LoraGradPair<'a>>,
pub v: Option<LoraGradPair<'a>>,
pub o: Option<LoraGradPair<'a>>,
// Feed-forward projections.
pub ffn_gate: Option<LoraGradPair<'a>>,
pub ffn_up: Option<LoraGradPair<'a>>,
pub ffn_down: Option<LoraGradPair<'a>>,
}
/// Borrowed view over the GPU scratch buffers used by `Forward::backward_step`
/// and `backward_layer`.
///
/// The struct owns nothing; every field is a reference to a caller-provided
/// `wgpu::Buffer`. Required sizes are not checked here — presumably the caller
/// allocates them to match the model configuration (TODO confirm).
// Many fields intentionally share the `d_` prefix / `_window` suffix.
#[allow(clippy::struct_field_names)]
pub struct BackwardScratchView<'a> {
// Head section: `d_logits` and `loss` are passed to
// `cross_entropy_backward_chained`; `loss` is read back (first f32) as the
// return value of the backward step.
pub d_logits: &'a wgpu::Buffer,
pub loss: &'a wgpu::Buffer,
// Passed to the token-embedding backward matmul, then fed to the final
// rmsnorm backward.
pub d_hidden_final: &'a wgpu::Buffer,
// Running hidden-state gradient carried from layer to layer; this is the
// buffer that gets traced and clipped after every layer.
pub d_hidden: &'a wgpu::Buffer,
// General-purpose d_model-sized temporaries inside `backward_layer`.
pub d_hidden_tmp: &'a wgpu::Buffer,
pub d_hidden_tmp2: &'a wgpu::Buffer,
// Attention backward: passed to `attention_probs_chained`,
// `attention_backward_dq_chained` and `attention_backward_dkv_chained`.
pub attn_probs: &'a wgpu::Buffer,
pub attn_d_scores: &'a wgpu::Buffer,
pub d_attn_out: &'a wgpu::Buffer,
pub d_q: &'a wgpu::Buffer,
// Per-history-position K/V gradients; rows are copied out at
// `position * row_bytes` offsets.
pub d_k_hist: &'a wgpu::Buffer,
pub d_v_hist: &'a wgpu::Buffer,
// Gradients at successive stages of undoing RoPE / per-row rmsnorm.
// NOTE(review): `d_q_pre_rope` is not referenced in the part of
// `backward_layer` visible here — confirm it is still needed.
pub d_q_pre_rope: &'a wgpu::Buffer,
pub d_k_pre_rope: &'a wgpu::Buffer,
pub d_q_pre_norm: &'a wgpu::Buffer,
pub d_k_pre_norm: &'a wgpu::Buffer,
pub d_v_pre_norm: &'a wgpu::Buffer,
// Feed-forward backward: `d_ffn_a` receives the down-projection backward
// result; `d_ffn_b` / `d_ffn_c` are then used for the gate / up projection
// backward matmuls respectively.
pub d_ffn_a: &'a wgpu::Buffer,
pub d_ffn_b: &'a wgpu::Buffer,
pub d_ffn_c: &'a wgpu::Buffer,
// Per-layer-embedding (PLE) path, only touched when `cfg.has_ple()`.
pub d_ple_state: &'a wgpu::Buffer,
pub d_ple_act: &'a wgpu::Buffer,
// Passed as the last buffer argument of the PLE `geglu_backward_chained`
// call and not read afterwards in the visible code.
pub d_ple_up_discard: &'a wgpu::Buffer,
// Receives this layer's slice of `per_layer` before the PLE geglu backward.
pub ple_per_layer_tmp: &'a wgpu::Buffer,
// "Window" buffers: each receives a single row copied out of the matching
// `LayerCaptureBuffers` field at the offset of the position being processed.
pub norm_x_attn_window: &'a wgpu::Buffer,
pub k_pre_norm_window: &'a wgpu::Buffer,
pub v_pre_norm_window: &'a wgpu::Buffer,
pub hidden_in_window: &'a wgpu::Buffer,
pub q_pre_norm_window: &'a wgpu::Buffer,
pub q_post_rope_window: &'a wgpu::Buffer,
pub attn_out_window: &'a wgpu::Buffer,
pub attn_proj_window: &'a wgpu::Buffer,
pub pre_ffn_rms_window: &'a wgpu::Buffer,
pub norm_x_ffn_window: &'a wgpu::Buffer,
pub ffn_gate_window: &'a wgpu::Buffer,
pub ffn_up_window: &'a wgpu::Buffer,
pub ffn_act_window: &'a wgpu::Buffer,
pub ffn_out_window: &'a wgpu::Buffer,
pub ple_state_window: &'a wgpu::Buffer,
pub ple_act_window: &'a wgpu::Buffer,
pub ple_proj_window: &'a wgpu::Buffer,
}
impl Forward {
#[allow(clippy::too_many_arguments)]
pub async fn backward_step<'a>(
&mut self,
target_id: u32,
capture: &'a [LayerCaptureBuffers<'a>],
loras: &'a [LayerLoraSlots<'a>],
grads: &'a [LayerLoraGrads<'a>],
scratch: &'a BackwardScratchView<'a>,
history_len: u32,
pos: u32,
recompute_captures: bool,
) -> Result<f32> {
self.backward_step_with_progress(
target_id,
capture,
loras,
grads,
scratch,
history_len,
pos,
recompute_captures,
None,
)
.await
}
/// Runs one backward step for the token at `pos` and returns the loss.
///
/// Sequence of work:
///
/// 1. Clears the cancel flag and checks that `capture`, `loras` and `grads` each
///    have exactly `n_layers` entries.
/// 2. Head section (one submission): cross-entropy backward from `self.logits`
///    against `target_id`, backward through the `token_embd.weight` matmul
///    (`Q4_K` or `Q6_K`), then rmsnorm backward with `output_norm.weight`,
///    leaving the gradient in `scratch.d_hidden`.
/// 3. Layers are visited from last to first. When `recompute_captures` is set,
///    the layer's forward pass is re-encoded from `cap.hidden_in` first so the
///    capture buffers are refreshed. `backward_layer` is then encoded and
///    submitted, cancellation is checked, and `progress_cb` (if any) is called
///    with `("backward", n_layers - i, n_layers)`.
/// 4. After every layer `scratch.d_hidden` is read back; if its largest finite
///    absolute value exceeds the clip threshold the buffer is rescaled to it.
/// 5. The first `f32` of `scratch.loss` is read back and returned.
///
/// Environment variables read on every call:
///
/// * `RULLAMA_TRACE_DHIDDEN` — if set, gradient statistics are printed to stderr.
/// * `RULLAMA_CLIP_DHIDDEN` — clip threshold; defaults to `1.0` when unset or
///   unparsable, and a value `<= 0` disables clipping.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] on a slice-length mismatch or an
/// unsupported `token_embd` dtype, and propagates errors from weight loading,
/// layer encoding, buffer readback and the cancellation check.
#[allow(clippy::too_many_arguments)]
pub async fn backward_step_with_progress<'a>(
    &mut self,
    target_id: u32,
    capture: &'a [LayerCaptureBuffers<'a>],
    loras: &'a [LayerLoraSlots<'a>],
    grads: &'a [LayerLoraGrads<'a>],
    scratch: &'a BackwardScratchView<'a>,
    history_len: u32,
    pos: u32,
    recompute_captures: bool,
    progress_cb: Option<&LayerProgressCb<'_>>,
) -> Result<f32> {
    self.reset_cancel();
    let n_layers = self.cfg.n_layers as usize;
    if capture.len() != n_layers || loras.len() != n_layers || grads.len() != n_layers {
        return Err(RullamaError::Inference(
            "backward_step: capture/loras/grads slice length must equal n_layers".into(),
        ));
    }
    let d_model = self.cfg.d_model as usize;
    let vocab = self.cfg.vocab_size as usize;
    let eps = self.cfg.rms_norm_eps;
    let wc = self.wcache.clone();
    let final_norm = wc.buffer_async("output_norm.weight").await?;
    let token_embd = wc.buffer_async("token_embd.weight").await?;
    let token_embd_dtype = wc.dtype("token_embd.weight")?;

    // ---- Head: loss -> d_logits -> d_hidden_final -> d_hidden ----
    let mut enc = self
        .ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("bwd.head"),
        });
    cross_entropy_backward_chained(
        &self.ctx,
        &self.pipes,
        &mut enc,
        &self.logits,
        scratch.d_logits,
        scratch.loss,
        vocab,
        target_id,
    );
    match token_embd_dtype {
        GgmlDtype::Q6_K => matmul_q6_k_backward_input_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            &token_embd,
            scratch.d_logits,
            scratch.d_hidden_final,
            d_model,
            vocab,
        ),
        GgmlDtype::Q4_K => matmul_q4_k_backward_input_chained(
            &self.ctx,
            &self.pipes,
            &mut enc,
            &token_embd,
            scratch.d_logits,
            scratch.d_hidden_final,
            d_model,
            vocab,
        ),
        other => {
            return Err(RullamaError::Inference(format!(
                "backward_step: token_embd dtype {other:?} unsupported"
            )));
        }
    }
    rmsnorm_backward_chained(
        &self.ctx,
        &self.pipes,
        &mut enc,
        &self.hidden,
        &final_norm,
        scratch.d_hidden_final,
        scratch.d_hidden,
        d_model,
        eps,
        true,
    );
    self.ctx.queue.submit(Some(enc.finish()));

    let trace_hidden = std::env::var("RULLAMA_TRACE_DHIDDEN").is_ok();
    let clip_max: f32 = std::env::var("RULLAMA_CLIP_DHIDDEN")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(1.0);
    if trace_hidden {
        let (max_abs, nans) = read_buf_stats(&self.ctx, scratch.d_hidden, d_model).await?;
        eprintln!("[trace] after head section: d_hidden max_abs={max_abs:.3e} nan={nans}");
        let (max_abs_f, nans_f) =
            read_buf_stats(&self.ctx, scratch.d_hidden_final, d_model).await?;
        eprintln!("[trace] d_hidden_final (head): max_abs={max_abs_f:.3e} nan={nans_f}");
        let (max_abs_l, nans_l) = read_buf_stats(&self.ctx, scratch.d_logits, vocab).await?;
        eprintln!("[trace] d_logits: max_abs={max_abs_l:.3e} nan={nans_l}");
    }

    // ---- Layers, last to first ----
    let d_model_bytes = (self.cfg.d_model as u64) * 4;
    for li in (0..n_layers).rev() {
        let i = li as u32;
        let cap = &capture[li];
        let lora = &loras[li];
        let grad = &grads[li];
        if recompute_captures {
            let mut renc =
                self.ctx
                    .device
                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                        label: Some("bwd.replay"),
                    });
            // Restart the layer from the hidden state captured at `pos`.
            renc.copy_buffer_to_buffer(
                cap.hidden_in,
                (pos as u64) * d_model_bytes,
                &self.hidden,
                0,
                d_model_bytes,
            );
            // A layer without a donor appends its own KV row during encoding, so
            // step the length back by one to make the replay overwrite the last row.
            let saved_len = self.kv_lens[li];
            if self.donor_map[li].is_none() && saved_len > 0 {
                self.kv_lens[li] = saved_len - 1;
            }
            let replayed = self
                .encode_layer(&mut renc, i, pos, Some(cap), Some(lora))
                .await;
            if replayed.is_err() {
                // Fix: a failed replay used to return with the length still
                // decremented. `renc` is never submitted on this path, so the
                // cache contents are untouched and only the length needs restoring.
                self.kv_lens[li] = saved_len;
            }
            replayed?;
            debug_assert_eq!(
                self.kv_lens[li], saved_len,
                "replay should leave kv_lens unchanged for layer {li}"
            );
            self.ctx.queue.submit(Some(renc.finish()));
        }
        let mut lenc =
            self.ctx
                .device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("bwd.layer"),
                });
        self.backward_layer(&mut lenc, i, history_len, pos, cap, lora, grad, scratch)
            .await?;
        self.ctx.queue.submit(Some(lenc.finish()));
        self.check_cancelled()?;
        if let Some(cb) = progress_cb {
            // Reported as a 1-based count of layers finished so far.
            let logical = (n_layers as u32) - i;
            cb("backward", logical, n_layers as u32);
        }
        if clip_max > 0.0 {
            let (max_abs, _) = read_buf_stats(&self.ctx, scratch.d_hidden, d_model).await?;
            // Non-finite maxima are left alone: no finite scale factor would fix them.
            if max_abs > clip_max && max_abs.is_finite() {
                let s = clip_max / max_abs;
                let mut cenc =
                    self.ctx
                        .device
                        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("bwd.clip"),
                        });
                scale_chained(
                    &self.ctx,
                    &self.pipes,
                    &mut cenc,
                    scratch.d_hidden,
                    d_model,
                    s,
                );
                self.ctx.queue.submit(Some(cenc.finish()));
            }
        }
        if trace_hidden {
            let (max_abs, nans) = read_buf_stats(&self.ctx, scratch.d_hidden, d_model).await?;
            eprintln!(
                "[trace] after layer {li} bwd: d_hidden max_abs={max_abs:.3e} nan={nans}"
            );
        }
    }

    // ---- Read the scalar loss back ----
    let loss_read = self.ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("bwd.loss_read"),
        size: 4,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });
    let mut renc = self
        .ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("bwd.loss_copy"),
        });
    renc.copy_buffer_to_buffer(scratch.loss, 0, &loss_read, 0, 4);
    self.ctx.queue.submit(Some(renc.finish()));
    let loss_vec = read_back_f32(&self.ctx.device, &loss_read).await?;
    Ok(loss_vec[0])
}
#[allow(clippy::too_many_arguments)]
async fn backward_layer<'a>(
&mut self,
enc: &mut wgpu::CommandEncoder,
i: u32,
history_len: u32,
pos: u32,
cap: &LayerCaptureBuffers<'a>,
lora: &LayerLoraSlots<'a>,
grad: &LayerLoraGrads<'a>,
scratch: &BackwardScratchView<'a>,
) -> Result<()> {
let prefix = format!("blk.{i}.");
let d_model = self.cfg.d_model as usize;
let eps = self.cfg.rms_norm_eps;
let n_heads = self.cfg.n_heads as usize;
let n_kv_heads = self.cfg.n_kv_heads(i) as usize;
let head_dim = self.cfg.head_dim(i) as usize;
let ffn_n = self.cfg.ffn(i) as usize;
let kind = self.cfg.kind(i);
let wc = self.wcache.clone();
let attn_norm_w = wc
.buffer_async(&format!("{prefix}attn_norm.weight"))
.await?;
let post_attn_w = wc
.buffer_async(&format!("{prefix}post_attention_norm.weight"))
.await?;
let mlp_norm_w = wc.buffer_async(&format!("{prefix}ffn_norm.weight")).await?;
let post_ffw_w = wc
.buffer_async(&format!("{prefix}post_ffw_norm.weight"))
.await?;
let q_w = wc.buffer_async(&format!("{prefix}attn_q.weight")).await?;
let q_norm_w = wc
.buffer_async(&format!("{prefix}attn_q_norm.weight"))
.await?;
let o_w = wc
.buffer_async(&format!("{prefix}attn_output.weight"))
.await?;
let k_w = wc.buffer_async(&format!("{prefix}attn_k.weight")).await?;
let k_norm_w = wc
.buffer_async(&format!("{prefix}attn_k_norm.weight"))
.await?;
let v_name = format!("{prefix}attn_v.weight");
let v_w = wc.buffer_async(&v_name).await?;
let v_w_dtype = wc.dtype(&v_name)?;
let gate_w = wc.buffer_async(&format!("{prefix}ffn_gate.weight")).await?;
let up_w = wc.buffer_async(&format!("{prefix}ffn_up.weight")).await?;
let down_name = format!("{prefix}ffn_down.weight");
let down_w = wc.buffer_async(&down_name).await?;
let down_dtype = wc.dtype(&down_name)?;
let factors_w = if matches!(kind, LayerKind::Global) {
wc.buffer_opt_async("rope_freqs.weight").await?
} else {
None
};
if let Some(s) = self.layer_scalars[i as usize] {
scale_chained(&self.ctx, &self.pipes, enc, scratch.d_hidden, d_model, s);
}
let d_model_bytes = (d_model as u64) * 4;
let kv_row_bytes = (n_kv_heads as u64) * (head_dim as u64) * 4;
let n_heads_row_bytes = (n_heads as u64) * (head_dim as u64) * 4;
let ffn_row_bytes = (ffn_n as u64) * 4;
let pos_off = pos as u64;
enc.copy_buffer_to_buffer(
cap.norm_x_attn,
pos_off * d_model_bytes,
scratch.norm_x_attn_window,
0,
d_model_bytes,
);
enc.copy_buffer_to_buffer(
cap.k_pre_norm,
pos_off * kv_row_bytes,
scratch.k_pre_norm_window,
0,
kv_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.v_pre_norm,
pos_off * kv_row_bytes,
scratch.v_pre_norm_window,
0,
kv_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.hidden_in,
pos_off * d_model_bytes,
scratch.hidden_in_window,
0,
d_model_bytes,
);
enc.copy_buffer_to_buffer(
cap.q_pre_norm,
pos_off * n_heads_row_bytes,
scratch.q_pre_norm_window,
0,
n_heads_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.q_post_rope,
pos_off * n_heads_row_bytes,
scratch.q_post_rope_window,
0,
n_heads_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.attn_out,
pos_off * n_heads_row_bytes,
scratch.attn_out_window,
0,
n_heads_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.attn_proj,
pos_off * d_model_bytes,
scratch.attn_proj_window,
0,
d_model_bytes,
);
enc.copy_buffer_to_buffer(
cap.pre_ffn_rms,
pos_off * d_model_bytes,
scratch.pre_ffn_rms_window,
0,
d_model_bytes,
);
enc.copy_buffer_to_buffer(
cap.norm_x_ffn,
pos_off * d_model_bytes,
scratch.norm_x_ffn_window,
0,
d_model_bytes,
);
enc.copy_buffer_to_buffer(
cap.ffn_gate,
pos_off * ffn_row_bytes,
scratch.ffn_gate_window,
0,
ffn_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.ffn_up,
pos_off * ffn_row_bytes,
scratch.ffn_up_window,
0,
ffn_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.ffn_act,
pos_off * ffn_row_bytes,
scratch.ffn_act_window,
0,
ffn_row_bytes,
);
enc.copy_buffer_to_buffer(
cap.ffn_out,
pos_off * d_model_bytes,
scratch.ffn_out_window,
0,
d_model_bytes,
);
if self.cfg.has_ple() {
let ple_dim_bytes = (self.cfg.ple_dim as u64) * 4;
enc.copy_buffer_to_buffer(
cap.ple_state,
pos_off * ple_dim_bytes,
scratch.ple_state_window,
0,
ple_dim_bytes,
);
enc.copy_buffer_to_buffer(
cap.ple_act,
pos_off * ple_dim_bytes,
scratch.ple_act_window,
0,
ple_dim_bytes,
);
enc.copy_buffer_to_buffer(
cap.ple_proj,
pos_off * d_model_bytes,
scratch.ple_proj_window,
0,
d_model_bytes,
);
}
if self.cfg.has_ple() {
let ple_dim = self.cfg.ple_dim as usize;
let inp_gate_w = wc.buffer_async(&format!("{prefix}inp_gate.weight")).await?;
let proj_w = wc.buffer_async("per_layer_model_proj.weight").await?;
let post_norm_w = wc.buffer_async("per_layer_proj_norm.weight").await?;
rmsnorm_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.ple_proj_window,
&post_norm_w,
scratch.d_hidden,
scratch.d_hidden_tmp,
d_model,
eps,
true,
);
matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&proj_w,
scratch.d_hidden_tmp,
scratch.d_ple_act,
ple_dim,
d_model,
);
let layer_off = (i as u64) * (ple_dim as u64) * 4;
let layer_bytes = (ple_dim as u64) * 4;
enc.copy_buffer_to_buffer(
&self.per_layer,
layer_off,
scratch.ple_per_layer_tmp,
0,
layer_bytes,
);
geglu_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.ple_state_window,
scratch.ple_per_layer_tmp,
scratch.d_ple_act,
scratch.d_ple_state,
scratch.d_ple_up_discard,
ple_dim,
);
matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&inp_gate_w,
scratch.d_ple_state,
scratch.d_hidden_tmp,
d_model,
ple_dim,
);
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden,
scratch.d_hidden_tmp,
d_model,
);
}
rmsnorm_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.ffn_out_window,
&post_ffw_w,
scratch.d_hidden,
scratch.d_hidden_tmp,
d_model,
eps,
true,
);
match down_dtype {
GgmlDtype::Q6_K => matmul_q6_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&down_w,
scratch.d_hidden_tmp,
scratch.d_ffn_a,
ffn_n,
d_model,
),
GgmlDtype::Q4_K => matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&down_w,
scratch.d_hidden_tmp,
scratch.d_ffn_a,
ffn_n,
d_model,
),
other => {
return Err(RullamaError::Inference(format!(
"ffn_down dtype {other:?} unsupported in backward"
)));
}
}
if let (Some(d_lora), Some(d_grad)) = (lora.ffn_down.as_ref(), grad.ffn_down.as_ref()) {
let r = d_lora.rank as usize;
let s = d_lora.scale;
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden_tmp,
d_lora.z,
d_grad.d_b,
d_model,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
d_lora.b,
scratch.d_hidden_tmp,
d_lora.z,
d_model,
r,
1.0,
false,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
d_lora.a,
d_lora.z,
scratch.d_ffn_a,
r,
ffn_n,
s,
true,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
d_lora.z,
scratch.ffn_act_window,
d_grad.d_a,
r,
ffn_n,
s,
true,
);
}
geglu_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.ffn_gate_window,
scratch.ffn_up_window,
scratch.d_ffn_a,
scratch.d_ffn_b,
scratch.d_ffn_c,
ffn_n,
);
matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&gate_w,
scratch.d_ffn_b,
scratch.d_hidden_tmp,
d_model,
ffn_n,
);
if let (Some(g_lora), Some(g_grad)) = (lora.ffn_gate.as_ref(), grad.ffn_gate.as_ref()) {
let r = g_lora.rank as usize;
let s = g_lora.scale;
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_ffn_b,
g_lora.z,
g_grad.d_b,
ffn_n,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
g_lora.b,
scratch.d_ffn_b,
g_lora.z,
ffn_n,
r,
1.0,
false,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
g_lora.a,
g_lora.z,
scratch.d_hidden_tmp,
r,
d_model,
s,
true,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
g_lora.z,
scratch.norm_x_ffn_window,
g_grad.d_a,
r,
d_model,
s,
true,
);
}
matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&up_w,
scratch.d_ffn_c,
scratch.d_hidden_tmp2,
d_model,
ffn_n,
);
if let (Some(u_lora), Some(u_grad)) = (lora.ffn_up.as_ref(), grad.ffn_up.as_ref()) {
let r = u_lora.rank as usize;
let s = u_lora.scale;
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_ffn_c,
u_lora.z,
u_grad.d_b,
ffn_n,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
u_lora.b,
scratch.d_ffn_c,
u_lora.z,
ffn_n,
r,
1.0,
false,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
u_lora.a,
u_lora.z,
scratch.d_hidden_tmp2,
r,
d_model,
s,
true,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
u_lora.z,
scratch.norm_x_ffn_window,
u_grad.d_a,
r,
d_model,
s,
true,
);
}
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden_tmp,
scratch.d_hidden_tmp2,
d_model,
);
rmsnorm_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.pre_ffn_rms_window,
&mlp_norm_w,
scratch.d_hidden_tmp,
scratch.d_hidden_tmp2,
d_model,
eps,
true,
);
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden,
scratch.d_hidden_tmp2,
d_model,
);
rmsnorm_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.attn_proj_window,
&post_attn_w,
scratch.d_hidden,
scratch.d_hidden_tmp,
d_model,
eps,
true,
);
matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&o_w,
scratch.d_hidden_tmp,
scratch.d_attn_out,
n_heads * head_dim,
d_model,
);
if let (Some(o_lora), Some(o_grad)) = (lora.o.as_ref(), grad.o.as_ref()) {
let r = o_lora.rank as usize;
let s = o_lora.scale;
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden_tmp,
o_lora.z,
o_grad.d_b,
d_model,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
o_lora.b,
scratch.d_hidden_tmp,
o_lora.z,
d_model,
r,
1.0,
false,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
o_lora.a,
o_lora.z,
scratch.d_attn_out,
r,
n_heads * head_dim,
s,
true,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
o_lora.z,
scratch.attn_out_window,
o_grad.d_a,
r,
n_heads * head_dim,
s,
true,
);
}
let window = if matches!(kind, LayerKind::SlidingWindow) {
self.cfg.sliding_window as usize
} else {
0
};
attention_probs_chained(
&self.ctx,
&self.pipes,
enc,
scratch.q_post_rope_window,
&self.kv_k[i as usize],
scratch.attn_probs,
head_dim,
n_heads,
n_kv_heads,
pos as usize,
history_len as usize,
window,
);
attention_backward_dq_chained(
&self.ctx,
&self.pipes,
enc,
&self.kv_k[i as usize],
&self.kv_v[i as usize],
scratch.attn_probs,
scratch.d_attn_out,
scratch.attn_d_scores,
scratch.d_q,
head_dim,
n_heads,
n_kv_heads,
history_len as usize,
);
attention_backward_dkv_chained(
&self.ctx,
&self.pipes,
enc,
scratch.q_post_rope_window,
scratch.attn_probs,
scratch.d_attn_out,
scratch.attn_d_scores,
scratch.d_k_hist,
scratch.d_v_hist,
head_dim,
n_heads,
n_kv_heads,
history_len as usize,
);
let (rope_base, rope_dims) = match kind {
LayerKind::SlidingWindow => {
(self.cfg.rope_freq_base_swa, self.cfg.rope_dim_swa as usize)
}
LayerKind::Global => (self.cfg.rope_freq_base, self.cfg.rope_dim_global as usize),
};
rope_neox_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_q,
factors_w.as_ref(),
&self.dummy,
head_dim,
n_heads,
pos as usize,
rope_dims,
rope_base,
);
rmsnorm_per_row_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.q_pre_norm_window,
&q_norm_w,
scratch.d_q,
scratch.d_q_pre_norm,
n_heads,
head_dim,
eps,
true,
);
matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&q_w,
scratch.d_q_pre_norm,
scratch.d_hidden_tmp,
d_model,
n_heads * head_dim,
);
if let (Some(q_lora), Some(q_grad)) = (lora.q.as_ref(), grad.q.as_ref()) {
let r = q_lora.rank as usize;
let s = q_lora.scale;
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_q_pre_norm,
q_lora.z,
q_grad.d_b,
n_heads * head_dim,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
q_lora.b,
scratch.d_q_pre_norm,
q_lora.z,
n_heads * head_dim,
r,
1.0,
false,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
q_lora.a,
q_lora.z,
scratch.d_hidden_tmp,
r,
d_model,
s,
true,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
q_lora.z,
scratch.norm_x_attn_window,
q_grad.d_a,
r,
d_model,
s,
true,
);
}
let donor = self.donor_map[i as usize];
if donor.is_none() {
let row_bytes = (n_kv_heads * head_dim * 4) as u64;
let dk_final_off = pos as u64 * row_bytes;
enc.copy_buffer_to_buffer(
scratch.d_k_hist,
dk_final_off,
scratch.d_k_pre_rope,
0,
row_bytes,
);
rope_neox_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_k_pre_rope,
factors_w.as_ref(),
&self.dummy,
head_dim,
n_kv_heads,
pos as usize,
rope_dims,
rope_base,
);
rmsnorm_per_row_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.k_pre_norm_window,
&k_norm_w,
scratch.d_k_pre_rope,
scratch.d_k_pre_norm,
n_kv_heads,
head_dim,
eps,
true,
);
matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&k_w,
scratch.d_k_pre_norm,
scratch.d_hidden_tmp2,
d_model,
n_kv_heads * head_dim,
);
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden_tmp,
scratch.d_hidden_tmp2,
d_model,
);
if let (Some(k_lora), Some(k_grad)) = (lora.k.as_ref(), grad.k.as_ref()) {
let r = k_lora.rank as usize;
let s = k_lora.scale;
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_k_pre_norm,
k_lora.z,
k_grad.d_b,
n_kv_heads * head_dim,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
k_lora.b,
scratch.d_k_pre_norm,
k_lora.z,
n_kv_heads * head_dim,
r,
1.0,
false,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
k_lora.a,
k_lora.z,
scratch.d_hidden_tmp,
r,
d_model,
s,
true,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
k_lora.z,
scratch.norm_x_attn_window,
k_grad.d_a,
r,
d_model,
s,
true,
);
}
enc.copy_buffer_to_buffer(
scratch.d_v_hist,
dk_final_off,
scratch.d_k_pre_norm,
0,
row_bytes,
);
rmsnorm_per_row_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.v_pre_norm_window,
&self.dummy,
scratch.d_k_pre_norm,
scratch.d_v_pre_norm,
n_kv_heads,
head_dim,
eps,
false,
);
match v_w_dtype {
GgmlDtype::Q6_K => matmul_q6_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&v_w,
scratch.d_v_pre_norm,
scratch.d_hidden_tmp2,
d_model,
n_kv_heads * head_dim,
),
GgmlDtype::Q4_K => matmul_q4_k_backward_input_chained(
&self.ctx,
&self.pipes,
enc,
&v_w,
scratch.d_v_pre_norm,
scratch.d_hidden_tmp2,
d_model,
n_kv_heads * head_dim,
),
other => {
return Err(RullamaError::Inference(format!(
"attn_v dtype {other:?} unsupported in backward"
)));
}
}
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden_tmp,
scratch.d_hidden_tmp2,
d_model,
);
if let (Some(v_lora), Some(v_grad)) = (lora.v.as_ref(), grad.v.as_ref()) {
let r = v_lora.rank as usize;
let s = v_lora.scale;
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_v_pre_norm,
v_lora.z,
v_grad.d_b,
n_kv_heads * head_dim,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
v_lora.b,
scratch.d_v_pre_norm,
v_lora.z,
n_kv_heads * head_dim,
r,
1.0,
false,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
v_lora.a,
v_lora.z,
scratch.d_hidden_tmp,
r,
d_model,
s,
true,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
v_lora.z,
scratch.norm_x_attn_window,
v_grad.d_a,
r,
d_model,
s,
true,
);
}
for hp_u in 0..history_len {
if hp_u == pos {
continue;
}
let hp = hp_u as usize;
let p_kv_off = hp_u as u64 * row_bytes;
let p_dm_off = hp_u as u64 * d_model_bytes;
enc.copy_buffer_to_buffer(
cap.norm_x_attn,
p_dm_off,
scratch.norm_x_attn_window,
0,
d_model_bytes,
);
enc.copy_buffer_to_buffer(
cap.k_pre_norm,
p_kv_off,
scratch.k_pre_norm_window,
0,
row_bytes,
);
enc.copy_buffer_to_buffer(
scratch.d_k_hist,
p_kv_off,
scratch.d_k_pre_rope,
0,
row_bytes,
);
rope_neox_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_k_pre_rope,
factors_w.as_ref(),
&self.dummy,
head_dim,
n_kv_heads,
hp,
rope_dims,
rope_base,
);
rmsnorm_per_row_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.k_pre_norm_window,
&k_norm_w,
scratch.d_k_pre_rope,
scratch.d_k_pre_norm,
n_kv_heads,
head_dim,
eps,
true,
);
if let (Some(k_lora), Some(k_grad)) = (lora.k.as_ref(), grad.k.as_ref()) {
let r = k_lora.rank as usize;
let s = k_lora.scale;
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
k_lora.a,
scratch.norm_x_attn_window,
k_lora.z,
d_model,
r,
1.0,
false,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_k_pre_norm,
k_lora.z,
k_grad.d_b,
n_kv_heads * head_dim,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
k_lora.b,
scratch.d_k_pre_norm,
k_lora.z,
n_kv_heads * head_dim,
r,
1.0,
false,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
k_lora.z,
scratch.norm_x_attn_window,
k_grad.d_a,
r,
d_model,
s,
true,
);
}
enc.copy_buffer_to_buffer(
cap.v_pre_norm,
p_kv_off,
scratch.v_pre_norm_window,
0,
row_bytes,
);
enc.copy_buffer_to_buffer(
scratch.d_v_hist,
p_kv_off,
scratch.d_k_pre_norm,
0,
row_bytes,
);
rmsnorm_per_row_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.v_pre_norm_window,
&self.dummy,
scratch.d_k_pre_norm,
scratch.d_v_pre_norm,
n_kv_heads,
head_dim,
eps,
false,
);
if let (Some(v_lora), Some(v_grad)) = (lora.v.as_ref(), grad.v.as_ref()) {
let r = v_lora.rank as usize;
let s = v_lora.scale;
lora_matmul_row_chained(
&self.ctx,
&self.pipes,
enc,
v_lora.a,
scratch.norm_x_attn_window,
v_lora.z,
d_model,
r,
1.0,
false,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_v_pre_norm,
v_lora.z,
v_grad.d_b,
n_kv_heads * head_dim,
r,
s,
true,
);
lora_matmul_col_chained(
&self.ctx,
&self.pipes,
enc,
v_lora.b,
scratch.d_v_pre_norm,
v_lora.z,
n_kv_heads * head_dim,
r,
1.0,
false,
);
lora_outer_add_chained(
&self.ctx,
&self.pipes,
enc,
v_lora.z,
scratch.norm_x_attn_window,
v_grad.d_a,
r,
d_model,
s,
true,
);
}
}
}
enc.copy_buffer_to_buffer(
cap.norm_x_attn,
(pos as u64) * d_model_bytes,
scratch.norm_x_attn_window,
0,
d_model_bytes,
);
rmsnorm_backward_chained(
&self.ctx,
&self.pipes,
enc,
scratch.hidden_in_window,
&attn_norm_w,
scratch.d_hidden_tmp,
scratch.d_hidden_tmp2,
d_model,
eps,
true,
);
residual_add_chained(
&self.ctx,
&self.pipes,
enc,
scratch.d_hidden,
scratch.d_hidden_tmp2,
d_model,
);
Ok(())
}
}