#[cfg(feature = "gpu")]
use super::wgpu_block::{WgpuBlock, WgpuBlockManager};
#[cfg(feature = "gpu")]
use super::wgpu_training::WgpuTrainer;
#[cfg(feature = "gpu")]
use trueno::backends::gpu::wgpu;
/// Driver for the WGSL (wgpu) backward pass of a transformer block.
///
/// Owns a [`WgpuTrainer`] that supplies GPU buffer allocation,
/// host<->device transfers, and the matmul forward/backward kernels
/// this pass is built from.
#[cfg(feature = "gpu")]
pub struct WgslBackwardPass {
    // Owned trainer; exposed via `trainer()` / `trainer_mut()` accessors.
    trainer: WgpuTrainer,
}
#[cfg(feature = "gpu")]
impl WgslBackwardPass {
    /// Creates a backward-pass driver that issues all GPU work through
    /// `trainer`.
    pub fn new(trainer: WgpuTrainer) -> Self {
        Self { trainer }
    }

    /// Backpropagates `grad_output` through one transformer block,
    /// returning a buffer holding the gradient flowing to the previous
    /// layer.
    ///
    /// Assumed `matmul_backward` contract, consistent across every call
    /// site in this file: `(input, weight, grad_output, grad_input,
    /// grad_weight, m, k, n)` for a forward `input(m×k) · weight(k×n)`.
    /// Zero-sized `grad_weight` buffers are passed for the frozen base
    /// weights — presumably a "skip weight gradient" convention in
    /// `WgpuTrainer::matmul_backward`; TODO confirm.
    ///
    /// NOTE(review): this implementation is visibly partial —
    /// `_grad_ffn_norm` and `grad_attn` are computed but never folded
    /// into the result, the SiLU/gating derivative is not applied to
    /// `grad_silu`, and the returned buffer is currently a plain copy of
    /// `grad_output` rather than the true input gradient.
    pub fn backward_layer(
        &self,
        block: &WgpuBlock,
        mgr: &WgpuBlockManager,
        grad_output: &wgpu::Buffer,
        layer_input: &wgpu::Buffer,
        seq_len: u32,
    ) -> wgpu::Buffer {
        let h = mgr.hidden_size;
        let inter = mgr.intermediate_size;
        let q_dim = mgr.num_heads * mgr.head_dim;
        let _kv_dim = mgr.num_kv_heads * mgr.head_dim;

        // --- FFN down projection: gradient w.r.t. the SiLU activations. ---
        let grad_silu = self.trainer.zeros((seq_len * inter) as usize);
        let grad_down_b = self.trainer.zeros(0);
        self.trainer.matmul_backward(
            &mgr.ffn_silu_buf,
            &block.w_down,
            grad_output,
            &grad_silu,
            &grad_down_b,
            seq_len,
            inter,
            h,
        );

        // --- FFN gate projection backward. ---
        // NOTE(review): `grad_silu` is used directly; the SiLU'(gate)·up
        // chain-rule factor appears to be missing — confirm intended.
        let grad_norm = self.trainer.zeros((seq_len * h) as usize);
        let grad_gate_b = self.trainer.zeros(0);
        self.trainer.matmul_backward(
            &mgr.norm_buf,
            &block.w_gate,
            &grad_silu,
            &grad_norm,
            &grad_gate_b,
            seq_len,
            h,
            inter,
        );

        // --- FFN up projection backward (same upstream gradient). ---
        let grad_up_b = self.trainer.zeros(0);
        let grad_norm2 = self.trainer.zeros((seq_len * h) as usize);
        self.trainer.matmul_backward(
            &mgr.norm_buf,
            &block.w_up,
            &grad_silu,
            &grad_norm2,
            &grad_up_b,
            seq_len,
            h,
            inter,
        );

        // Sum the two branch gradients on the host.
        // NOTE(review): the combined gradient is uploaded but then
        // discarded — it never contributes to the returned buffer.
        let gn1 = self.trainer.download(&grad_norm);
        let gn2 = self.trainer.download(&grad_norm2);
        let combined: Vec<f32> = gn1.iter().zip(gn2.iter()).map(|(a, b)| a + b).collect();
        let _grad_ffn_norm = self.trainer.upload(&combined);

        // --- Attention output projection backward. ---
        // NOTE(review): `grad_attn` is likewise computed and discarded.
        let grad_attn = self.trainer.zeros((seq_len * q_dim) as usize);
        let grad_o_b = self.trainer.zeros(0);
        self.trainer.matmul_backward(
            &mgr.attn_out_buf,
            &block.w_o,
            grad_output,
            &grad_attn,
            &grad_o_b,
            seq_len,
            q_dim,
            h,
        );

        // LoRA adapter gradients (only trainable parameters here).
        if let Some(lora) = &block.lora {
            self.compute_lora_gradients(block, mgr, grad_output, layer_input, lora, seq_len);
        }

        // Placeholder: return a device copy of the upstream gradient.
        let grad_input_data = self.trainer.download(grad_output);
        self.trainer.upload(&grad_input_data)
    }

    /// Computes gradients for the query LoRA adapter (`A_q`, `B_q`).
    ///
    /// Forward (re-derived from the calls below):
    ///   `xa = x · A_q` (seq×rank), then `xa · B_q` (seq×h).
    ///
    /// NOTE(review): `grad_a_q` / `grad_b_q` are local buffers that are
    /// dropped on return — nothing persists them into an optimizer step
    /// yet; confirm where accumulation is meant to happen.
    fn compute_lora_gradients(
        &self,
        _block: &WgpuBlock,
        mgr: &WgpuBlockManager,
        grad_output: &wgpu::Buffer,
        layer_input: &wgpu::Buffer,
        lora: &super::wgpu_block::WgpuLoraAdapters,
        seq_len: u32,
    ) {
        let h = mgr.hidden_size;
        let rank = lora.rank;

        // Recompute the intermediate activation xa = x · A_q.
        let xa_q = self.trainer.zeros((seq_len * rank) as usize);
        self.trainer.matmul_forward(layer_input, &lora.a_q, &xa_q, seq_len, h, rank);

        // B_q backward: grad_b_q = xa^T · grad_output and
        // grad_xa = grad_output · B_q^T.
        // BUG FIX: the gradient w.r.t. `xa` must land in its own buffer.
        // Previously `&xa_q` was passed as BOTH the forward input and the
        // grad-input destination (aliased), and the A_q backward below was
        // fed a separate, never-written zero buffer — so `grad_a_q` was
        // always computed from zeros.
        let grad_b_q = self.trainer.zeros((rank * h) as usize);
        let grad_xa = self.trainer.zeros((seq_len * rank) as usize);
        self.trainer.matmul_backward(
            &xa_q,
            &lora.b_q,
            grad_output,
            &grad_xa,
            &grad_b_q,
            seq_len,
            rank,
            h,
        );

        // A_q backward: grad_a_q = x^T · grad_xa. The gradient w.r.t. x
        // is written to a scratch buffer and intentionally not propagated.
        let grad_a_q = self.trainer.zeros((h * rank) as usize);
        let grad_x_scratch = self.trainer.zeros((seq_len * h) as usize);
        self.trainer.matmul_backward(
            layer_input,
            &lora.a_q,
            &grad_xa,
            &grad_x_scratch,
            &grad_a_q,
            seq_len,
            h,
            rank,
        );
    }

    /// Shared access to the underlying trainer.
    pub fn trainer(&self) -> &WgpuTrainer {
        &self.trainer
    }

    /// Mutable access to the underlying trainer.
    pub fn trainer_mut(&mut self) -> &mut WgpuTrainer {
        &mut self.trainer
    }
}