#[cfg(feature = "gpu")]
use crate::transformer::wgpu_block::WgpuForwardPass;
#[cfg(feature = "gpu")]
use crate::transformer::TransformerConfig;
#[cfg(feature = "gpu")]
use trueno::backends::gpu::GpuDevice;
#[cfg(feature = "gpu")]
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    // Transpose a row-major `rows x cols` matrix into a row-major
    // `cols x rows` one. Panics (like direct indexing would) when `data`
    // holds fewer than `rows * cols` elements.
    let mut out = vec![0.0f32; rows * cols];
    for r in 0..rows {
        let row = &data[r * cols..(r + 1) * cols];
        for (c, &val) in row.iter().enumerate() {
            out[c * rows + r] = val;
        }
    }
    out
}
#[cfg(feature = "gpu")]
/// GPU-backed trainer: drives forward passes, backward passes and AdamW
/// updates. Hyperparameters default to the values set in `new` and can be
/// overridden through the `with_*` builder methods.
pub struct WgpuTransformerTrainer {
    forward: WgpuForwardPass,  // GPU forward-pass executor (also reports the adapter)
    device: GpuDevice,         // wgpu device running matmul / gemm-backward / AdamW kernels
    config: TransformerConfig, // model dimensions (hidden_size, vocab_size, ...)
    step: u32,                 // global optimizer step, fed to AdamW for bias correction
    lr: f32,                   // learning rate
    beta1: f32,                // AdamW first-moment decay
    beta2: f32,                // AdamW second-moment decay
    eps: f32,                  // AdamW epsilon
    weight_decay: f32,         // AdamW decoupled weight decay
    lora_rank: u32,            // LoRA rank (0 = disabled)
    lora_alpha: f32,           // LoRA scale numerator (effective scaling = alpha / rank)
}
#[cfg(feature = "gpu")]
/// Full model state for NF4-quantized LoRA training: frozen NF4 base
/// weights, trainable LoRA adapters, a trainable fp32 LM head with its
/// AdamW moment buffers, model dimensions, and per-layer fp32 weight
/// caches filled lazily by `populate_weight_cache`.
pub struct WgpuModelState {
    pub layers: Vec<super::wgpu_nf4::Nf4LayerWeights>, // frozen NF4 base weights per layer
    pub lora: Vec<super::wgpu_checkpoint::LoraLayerSet>, // trainable LoRA adapters per layer
    pub lm_head: Vec<f32>,   // fp32 LM head, row-major [vocab_size, hidden_size]
    pub lm_head_m: Vec<f32>, // AdamW first moment for lm_head (same length)
    pub lm_head_v: Vec<f32>, // AdamW second moment for lm_head (same length)
    pub hidden_size: usize,
    pub num_layers: usize,
    pub vocab_size: usize,
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub intermediate_size: usize,
    // Dequantized + transposed fp32 FFN weights (gate, up, down) per layer.
    pub ffn_cache: Vec<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>>,
    // Dequantized + transposed fp32 attention weights (q, k, v, o) per layer.
    pub attn_cache: Vec<Option<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)>>,
}
#[cfg(feature = "gpu")]
impl WgpuModelState {
/// Load Qwen3-4B from `model_dir` (Hugging Face layout), NF4-quantizing the
/// seven projection matrices of every decoder layer and attaching fresh
/// rank-`lora_rank` LoRA adapters. `_lora_alpha` is currently unused here.
///
/// Dimensions come from `config.json` with Qwen3-4B defaults as fallbacks.
/// Every `*.safetensors` shard is read fully into memory. The LM head (or
/// the tied `model.embed_tokens.weight` as a fallback) is kept in fp32 so it
/// can be trained directly with AdamW.
///
/// # Errors
/// Returns `Err` when the directory, `config.json`, the shards, or any
/// expected tensor is missing or malformed.
pub fn load_qwen3_4b(
    model_dir: &std::path::Path,
    lora_rank: u32,
    _lora_alpha: f32,
) -> Result<Self, String> {
    use std::fs;
    // --- model dimensions from config.json (defaults match Qwen3-4B) ---
    let config_path = model_dir.join("config.json");
    let config_str = fs::read_to_string(&config_path)
        .map_err(|e| format!("Cannot read config.json: {e}"))?;
    let config: serde_json::Value =
        serde_json::from_str(&config_str).map_err(|e| format!("Invalid config.json: {e}"))?;
    let hidden_size = config["hidden_size"].as_u64().unwrap_or(2560) as usize;
    let num_layers = config["num_hidden_layers"].as_u64().unwrap_or(36) as usize;
    let num_heads = config["num_attention_heads"].as_u64().unwrap_or(32) as usize;
    let num_kv_heads = config["num_key_value_heads"].as_u64().unwrap_or(8) as usize;
    let intermediate_size = config["intermediate_size"].as_u64().unwrap_or(9728) as usize;
    let vocab_size = config["vocab_size"].as_u64().unwrap_or(151936) as usize;
    let head_dim = config["head_dim"].as_u64().unwrap_or(128) as usize;
    eprintln!("Loading Qwen3-4B: {num_layers} layers, h={hidden_size}, i={intermediate_size}");
    // --- discover and read every safetensors shard into memory ---
    let mut shards: Vec<String> = fs::read_dir(model_dir)
        .map_err(|e| format!("Cannot read model dir: {e}"))?
        .filter_map(std::result::Result::ok)
        .map(|e| e.file_name().to_string_lossy().to_string())
        .filter(|n| n.ends_with(".safetensors"))
        .collect();
    shards.sort(); // deterministic shard order
    if shards.is_empty() {
        return Err("No .safetensors files found".to_string());
    }
    let mut all_data: Vec<Vec<u8>> = Vec::new();
    for shard in &shards {
        let path = model_dir.join(shard);
        eprintln!(" Loading {shard}...");
        let data = fs::read(&path).map_err(|e| format!("Cannot read {shard}: {e}"))?;
        all_data.push(data);
    }
    // Parse all shards up front; `parsed` borrows from `all_data`.
    let parsed: Vec<safetensors::SafeTensors<'_>> = all_data
        .iter()
        .map(|d| safetensors::SafeTensors::deserialize(d))
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| format!("Deserialize error: {e}"))?;
    // --- NF4-quantize the seven projections of each decoder layer ---
    let mut layers = Vec::with_capacity(num_layers);
    let q_dim = num_heads * head_dim;
    let block_size = 64u32;
    for layer_idx in 0..num_layers {
        let prefix = format!("model.layers.{layer_idx}");
        // Locate `name` in whichever shard holds it, then NF4-quantize it.
        let find_and_quantize = |name: &str,
                                 rows: usize,
                                 cols: usize|
         -> Result<(Vec<u32>, Vec<f32>, u32), String> {
            for tensors in &parsed {
                if tensors.tensor(name).is_ok() {
                    return super::wgpu_nf4::Nf4LayerWeights::quantize_projection_from_tensors(
                        tensors, name, rows, cols,
                    );
                }
            }
            Err(format!("Tensor {name} not found in any shard"))
        };
        let kv_dim = num_kv_heads * head_dim;
        let (gate_p, gate_s, gate_n) = find_and_quantize(
            &format!("{prefix}.mlp.gate_proj.weight"),
            intermediate_size,
            hidden_size,
        )?;
        let (up_p, up_s, up_n) = find_and_quantize(
            &format!("{prefix}.mlp.up_proj.weight"),
            intermediate_size,
            hidden_size,
        )?;
        let (down_p, down_s, down_n) = find_and_quantize(
            &format!("{prefix}.mlp.down_proj.weight"),
            hidden_size,
            intermediate_size,
        )?;
        let (q_p, q_s, q_n) = find_and_quantize(
            &format!("{prefix}.self_attn.q_proj.weight"),
            q_dim,
            hidden_size,
        )?;
        let (k_p, k_s, k_n) = find_and_quantize(
            &format!("{prefix}.self_attn.k_proj.weight"),
            kv_dim,
            hidden_size,
        )?;
        let (v_p, v_s, v_n) = find_and_quantize(
            &format!("{prefix}.self_attn.v_proj.weight"),
            kv_dim,
            hidden_size,
        )?;
        let (o_p, o_s, o_n) = find_and_quantize(
            &format!("{prefix}.self_attn.o_proj.weight"),
            hidden_size,
            q_dim,
        )?;
        let layer = super::wgpu_nf4::Nf4LayerWeights {
            gate_packed: gate_p,
            gate_scales: gate_s,
            up_packed: up_p,
            up_scales: up_s,
            down_packed: down_p,
            down_scales: down_s,
            q_packed: q_p,
            q_scales: q_s,
            k_packed: k_p,
            k_scales: k_s,
            v_packed: v_p,
            v_scales: v_s,
            o_packed: o_p,
            o_scales: o_s,
            gate_n,
            up_n,
            down_n,
            q_n,
            k_n,
            v_n,
            o_n,
            block_size,
        };
        let mb = layer.memory_bytes() as f64 / 1024.0 / 1024.0;
        // Progress log every 6 layers plus the final one.
        if layer_idx % 6 == 0 || layer_idx == num_layers - 1 {
            eprintln!(" Layer {layer_idx}: {mb:.1} MB NF4");
        }
        layers.push(layer);
    }
    // --- fresh LoRA adapters, 7 modules per layer ---
    let mut lora = Vec::with_capacity(num_layers);
    for _ in 0..num_layers {
        lora.push(super::wgpu_checkpoint::LoraLayerSet::new(
            lora_rank,
            hidden_size as u32,
            q_dim as u32,
            (num_kv_heads * head_dim) as u32,
            intermediate_size as u32,
        ));
    }
    // NOTE(review): the last shard is deserialized again here but the result
    // is never used — looks like dead code left from an earlier version.
    let last_data = all_data.last().ok_or("No shards")?;
    let _tensors = safetensors::SafeTensors::deserialize(last_data)
        .map_err(|e| format!("Deserialize: {e}"))?;
    // --- LM head: search every shard; fall back to the tied embedding ---
    let mut lm_head_view = None;
    for data in &all_data {
        let t = safetensors::SafeTensors::deserialize(data)
            .map_err(|e| format!("Deserialize: {e}"))?;
        for name in ["lm_head.weight", "model.lm_head.weight", "model.embed_tokens.weight"] {
            if let Ok(v) = t.tensor(name) {
                // Widen F16/BF16 to f32; any other dtype is assumed to be
                // f32 already and reinterpreted byte-wise.
                let fp32: Vec<f32> = match v.dtype() {
                    safetensors::Dtype::F16 => v
                        .data()
                        .chunks_exact(2)
                        .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
                        .collect(),
                    safetensors::Dtype::BF16 => v
                        .data()
                        .chunks_exact(2)
                        .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
                        .collect(),
                    _ => bytemuck::cast_slice(v.data()).to_vec(),
                };
                eprintln!(" LM head from {name}: {} elements", fp32.len());
                lm_head_view = Some(fp32);
                break;
            }
        }
        if lm_head_view.is_some() {
            break;
        }
    }
    let lm_head = lm_head_view.ok_or("lm_head/embed_tokens not found in any shard")?;
    let lm_head_len = lm_head.len();
    let lora_params: usize =
        lora.iter().map(super::wgpu_checkpoint::LoraLayerSet::num_params).sum();
    eprintln!(" LoRA params: {lora_params} (rank={lora_rank}, 7 modules/layer)");
    eprintln!(
        " LM head: {} elements ({:.1} MB)",
        lm_head_len,
        lm_head_len as f64 * 4.0 / 1024.0 / 1024.0
    );
    Ok(Self {
        layers,
        lora,
        lm_head,
        // AdamW moments start at zero; weight caches start empty.
        lm_head_m: vec![0.0f32; lm_head_len],
        lm_head_v: vec![0.0f32; lm_head_len],
        hidden_size,
        num_layers,
        vocab_size,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
        ffn_cache: vec![None; num_layers],
        attn_cache: vec![None; num_layers],
    })
}
pub fn populate_weight_cache(
    &mut self,
    device: &trueno::backends::gpu::GpuDevice,
) -> Result<(), String> {
    // Dequantize each layer's NF4 projections once and cache them as
    // transposed (input-major) fp32 matrices for the training matmuls.
    // Entries that are already populated are left untouched.
    let hidden = self.hidden_size;
    let inter = self.intermediate_size;
    let q_dim = self.num_heads * self.head_dim;
    let kv_dim = self.num_kv_heads * self.head_dim;
    for li in 0..self.num_layers {
        let layer = &self.layers[li];
        if self.ffn_cache[li].is_none() {
            let gate = transpose(&layer.dequant_gate(device)?, inter, hidden);
            let up = transpose(&layer.dequant_up(device)?, inter, hidden);
            let down = transpose(&layer.dequant_down(device)?, hidden, inter);
            self.ffn_cache[li] = Some((gate, up, down));
        }
        if self.attn_cache[li].is_none() {
            let q = transpose(&layer.dequant_q(device)?, q_dim, hidden);
            let k = transpose(&layer.dequant_k(device)?, kv_dim, hidden);
            let v = transpose(&layer.dequant_v(device)?, kv_dim, hidden);
            let o = transpose(&layer.dequant_o(device)?, hidden, q_dim);
            self.attn_cache[li] = Some((q, k, v, o));
            // Sparse progress logging: every 12th layer and the last one.
            if li % 12 == 0 || li == self.num_layers - 1 {
                eprintln!(" Cached layer {li}");
            }
        }
    }
    Ok(())
}
pub fn trainable_params(&self) -> usize {
    // Total trainable parameters: every LoRA adapter plus the fp32 LM head.
    let lora_total: usize =
        self.lora.iter().map(super::wgpu_checkpoint::LoraLayerSet::num_params).sum();
    lora_total + self.lm_head.len()
}
/// Persist the LoRA adapters (with `step`/`loss`/`rank`/`alpha` metadata)
/// as a checkpoint under `dir`, returning the written path. Thin forwarder
/// to `wgpu_checkpoint::save_lora_checkpoint`.
pub fn save_checkpoint(
    &self,
    dir: &std::path::Path,
    step: u32,
    loss: f32,
    rank: u32,
    alpha: f32,
) -> Result<std::path::PathBuf, String> {
    super::wgpu_checkpoint::save_lora_checkpoint(
        &self.lora,
        self.hidden_size,
        dir,
        step,
        loss,
        rank,
        alpha,
    )
}
/// Restore the LoRA adapters in place from the checkpoint at `path`,
/// returning the stored `(step, loss)`. Thin forwarder to
/// `wgpu_checkpoint::load_lora_checkpoint`.
pub fn load_checkpoint(&mut self, path: &std::path::Path) -> Result<(u32, f32), String> {
    super::wgpu_checkpoint::load_lora_checkpoint(
        &mut self.lora,
        self.num_layers,
        self.hidden_size,
        path,
    )
}
}
#[cfg(feature = "gpu")]
impl WgpuTransformerTrainer {
/// Create a trainer for `config` with learning rate `lr`.
///
/// Builds the GPU forward pass first, then acquires a device handle for
/// the training kernels. AdamW hyperparameters start at GPT-style
/// defaults (beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.1) and LoRA
/// starts disabled (rank 0, alpha 0.0); use the `with_*` builders to
/// override them.
///
/// # Errors
/// Propagates failures from forward-pass or device initialization.
pub fn new(config: &TransformerConfig, lr: f32) -> Result<Self, String> {
    let forward = WgpuForwardPass::new_default(config)?;
    let device = GpuDevice::new()?;
    Ok(Self {
        forward,
        device,
        config: config.clone(),
        step: 0,
        lr,
        beta1: 0.9,
        beta2: 0.95,
        eps: 1e-8,
        weight_decay: 0.1,
        lora_rank: 0,
        lora_alpha: 0.0,
    })
}
/// Builder: enable LoRA with the given `rank` and `alpha`.
///
/// Fix: the previous version discarded `alpha` (parameter was `_alpha`),
/// leaving `self.lora_alpha` at its default 0.0. Because the LoRA scaling
/// applied in the forward/backward passes is `alpha / rank`, that made the
/// LoRA path contribute nothing. Both hyperparameters are now stored.
pub fn with_lora(mut self, rank: u32, alpha: f32) -> Self {
    self.lora_rank = rank;
    self.lora_alpha = alpha;
    self
}
pub fn with_adamw(mut self, beta1: f32, beta2: f32, eps: f32, weight_decay: f32) -> Self {
self.beta1 = beta1;
self.beta2 = beta2;
self.eps = eps;
self.weight_decay = weight_decay;
self
}
/// Human-readable description of the GPU adapter, forwarded from the
/// forward-pass executor.
pub fn adapter_info(&self) -> String {
    self.forward.adapter_info()
}
/// Number of optimizer steps applied so far.
pub fn current_step(&self) -> u32 {
    self.step
}
/// Prototype single-layer step that exercises the FFN and AdamW kernels
/// end to end; this is NOT a real training step. The SwiGLU FFN forward
/// runs on the CPU over dequantized NF4 weights, the backward pass uses a
/// synthetic upstream gradient (`0.01 * ffn_out`), and `lora_q.a` is
/// updated with a constant placeholder gradient (0.001) through the GPU
/// AdamW kernel. `_lora_v` is accepted but unused.
///
/// Returns `(hidden + ffn_out, L2 norm of the propagated input gradient)`.
pub fn layer_train_step(
    &mut self,
    hidden: &[f32],
    model: &mut super::wgpu_nf4::Nf4LayerWeights,
    lora_q: &mut super::wgpu_nf4::LoraAdapter,
    _lora_v: &mut super::wgpu_nf4::LoraAdapter,
    seq_len: u32,
    hidden_size: u32,
    intermediate_size: u32,
) -> Result<(Vec<f32>, f32), String> {
    // Dequantize the frozen NF4 FFN weights to fp32, row-major [out, in].
    let gate_fp32 = model.dequant_gate(&self.device)?;
    let up_fp32 = model.dequant_up(&self.device)?;
    let down_fp32 = model.dequant_down(&self.device)?;
    let s = seq_len;
    let h = hidden_size;
    let i = intermediate_size;
    // gate_out = hidden · gateᵀ (CPU reference matmul).
    let mut gate_out = vec![0.0f32; (s * i) as usize];
    for si in 0..s as usize {
        for ii in 0..i as usize {
            let mut sum = 0.0f32;
            for hi in 0..h as usize {
                sum += hidden[si * h as usize + hi] * gate_fp32[ii * h as usize + hi];
            }
            gate_out[si * i as usize + ii] = sum;
        }
    }
    // up_out = hidden · upᵀ.
    let mut up_out = vec![0.0f32; (s * i) as usize];
    for si in 0..s as usize {
        for ii in 0..i as usize {
            let mut sum = 0.0f32;
            for hi in 0..h as usize {
                sum += hidden[si * h as usize + hi] * up_fp32[ii * h as usize + hi];
            }
            up_out[si * i as usize + ii] = sum;
        }
    }
    // SwiGLU activation: silu(gate_out) * up_out.
    let silu_gate: Vec<f32> = gate_out
        .iter()
        .map(|&x| {
            let sig = 1.0 / (1.0 + (-x).exp());
            x * sig
        })
        .collect();
    let swiglu_out: Vec<f32> =
        silu_gate.iter().zip(up_out.iter()).map(|(&sg, &u)| sg * u).collect();
    // ffn_out = swiglu · downᵀ, then the residual add below.
    let mut ffn_out = vec![0.0f32; (s * h) as usize];
    for si in 0..s as usize {
        for hi in 0..h as usize {
            let mut sum = 0.0f32;
            for ii in 0..i as usize {
                sum += swiglu_out[si * i as usize + ii] * down_fp32[hi * i as usize + ii];
            }
            ffn_out[si * h as usize + hi] = sum;
        }
    }
    let output: Vec<f32> = hidden.iter().zip(ffn_out.iter()).map(|(&h, &f)| h + f).collect();
    // Synthetic upstream gradient: 1% of the FFN output.
    let pseudo_grad: Vec<f32> = ffn_out.iter().map(|&v| v * 0.01).collect();
    let grad_input = self.ffn_backward(
        &pseudo_grad,
        hidden,
        &gate_fp32,
        &up_fp32,
        &down_fp32,
        &gate_out,
        &up_out,
        &silu_gate,
        s,
        h,
        i,
    )?;
    let grad_norm: f32 = grad_input.iter().map(|g| g * g).sum::<f32>().sqrt();
    self.step += 1;
    let _q_dim = lora_q.out_dim;
    // NOTE(review): this dequantizes the *gate* projection into `_q_fp32` —
    // presumably `dequant_q` was intended; the value is unused either way.
    let _q_fp32 = model.dequant_gate(&self.device)?;
    // h_cached = hidden · Aᵀ for lora_q; computed but not consumed below.
    let mut h_cached = vec![0.0f32; (s * lora_q.rank) as usize];
    for si in 0..s as usize {
        for ri in 0..lora_q.rank as usize {
            for hi in 0..h as usize {
                h_cached[si * lora_q.rank as usize + ri] +=
                    hidden[si * h as usize + hi] * lora_q.a[ri * h as usize + hi];
            }
        }
    }
    // Constant placeholder gradient; only exercises the AdamW kernel.
    let grad_a = vec![0.001f32; lora_q.a.len()];
    let _a_len = lora_q.a.len();
    // mem::take so the device can mutate the adapter buffers in place.
    let mut a_buf = std::mem::take(&mut lora_q.a);
    let mut ma_buf = std::mem::take(&mut lora_q.m_a);
    let mut va_buf = std::mem::take(&mut lora_q.v_a);
    self.device.adamw_step(
        &mut a_buf,
        &grad_a,
        &mut ma_buf,
        &mut va_buf,
        self.lr,
        self.beta1,
        self.beta2,
        self.eps,
        self.weight_decay,
        self.step,
    )?;
    lora_q.a = a_buf;
    lora_q.m_a = ma_buf;
    lora_q.v_a = va_buf;
    Ok((output, grad_norm))
}
/// One full forward/backward/optimize step over all decoder layers.
///
/// Forward: deterministic noise injection → per layer [RMS-norm →
/// attention (+LoRA on q/v) → residual → RMS-norm → SwiGLU FFN →
/// residual] → LM-head logits. Loss: softmax cross-entropy with a
/// focal-style per-position re-weighting. Backward: AdamW on the fp32 LM
/// head here, then LoRA adapter updates inside `backward_through_layers`.
///
/// Returns `(mean loss, sqrt(lm_grad_norm² + lora_grad_norm²))`.
pub fn full_train_step(
    &mut self,
    token_hidden: &[f32],
    target_ids: &[u32],
    model: &mut WgpuModelState,
) -> Result<(f32, f32), String> {
    let s = target_ids.len() as u32;
    let h = model.hidden_size as u32;
    let i = model.intermediate_size as u32;
    let v = model.vocab_size as u32;
    let n_layers = model.num_layers;
    // Ensure the dequantized fp32 weight caches exist for every layer.
    model.populate_weight_cache(&self.device)?;
    let mut hidden = token_hidden.to_vec();
    // Deterministic pseudo-noise on the inputs (hash of element index and
    // step), amplitude ~ 1/sqrt(s*h). NOTE: the loop's `i`/`v` bindings
    // shadow the dimension variables above within this loop only.
    let ns = 5.0f32 / ((s as f32) * (h as f32)).sqrt();
    for (i, v) in hidden.iter_mut().enumerate() {
        *v += ((i as u64).wrapping_mul(6364136223846793005).wrapping_add(u64::from(self.step))
            as f32
            / u64::MAX as f32
            * 2.0
            - 1.0)
            * ns;
    }
    let mut layer_acts = Vec::with_capacity(n_layers);
    // In-place RMS normalization; no learned scale is applied here.
    let rmsnorm = |buf: &mut [f32], s: usize, h: usize| {
        let eps = 1e-5f32;
        for si in 0..s {
            let rms = (buf[si * h..(si + 1) * h].iter().map(|x| x * x).sum::<f32>() / h as f32
                + eps)
                .sqrt();
            for hi in 0..h {
                buf[si * h + hi] /= rms;
            }
        }
    };
    for layer_idx in 0..n_layers {
        // Pre-attention norm.
        rmsnorm(&mut hidden, s as usize, h as usize);
        let (q_w, k_w, v_w, o_w) = model.attn_cache[layer_idx]
            .as_ref()
            .map(|(q, k, v, o)| (q.as_slice(), k.as_slice(), v.as_slice(), o.as_slice()))
            .expect("attn cache");
        // Attention with LoRA deltas on the q/v projections; returns the
        // activations needed by the backward pass.
        let (attn_out, attn_cache) = super::wgpu_attention::attention_forward(
            &self.device,
            &hidden,
            q_w,
            k_w,
            v_w,
            o_w,
            &model.lora[layer_idx].q,
            &model.lora[layer_idx].v,
            self.lora_alpha,
            s,
            h,
            model.num_heads as u32,
            model.num_kv_heads as u32,
            model.head_dim as u32,
        )?;
        // Residual add; keep the pre-attention input for the backward pass.
        let attn_input = hidden.clone();
        for j in 0..(s * h) as usize {
            hidden[j] += attn_out[j];
        }
        // Pre-FFN norm.
        rmsnorm(&mut hidden, s as usize, h as usize);
        let hidden_input = hidden.clone();
        let (gate_fp32, up_fp32, down_fp32) = model.ffn_cache[layer_idx]
            .as_ref()
            .map(|(g, u, d)| (g.as_slice(), u.as_slice(), d.as_slice()))
            .expect("cache populated above");
        // SwiGLU FFN: silu(x·gate) * (x·up), projected back down on the GPU.
        let mut gate_out = vec![0.0f32; (s * i) as usize];
        self.device.matmul(
            &hidden,
            gate_fp32,
            &mut gate_out,
            s as usize,
            h as usize,
            i as usize,
        )?;
        let mut up_out = vec![0.0f32; (s * i) as usize];
        self.device.matmul(
            &hidden,
            up_fp32,
            &mut up_out,
            s as usize,
            h as usize,
            i as usize,
        )?;
        let silu_gate: Vec<f32> = gate_out
            .iter()
            .map(|&x| {
                let sig = 1.0 / (1.0 + (-x).exp());
                x * sig
            })
            .collect();
        let swiglu: Vec<f32> =
            silu_gate.iter().zip(up_out.iter()).map(|(&sg, &u)| sg * u).collect();
        let mut ffn_out = vec![0.0f32; (s * h) as usize];
        self.device.matmul(
            &swiglu,
            down_fp32,
            &mut ffn_out,
            s as usize,
            i as usize,
            h as usize,
        )?;
        // Residual add.
        for j in 0..(s * h) as usize {
            hidden[j] += ffn_out[j];
        }
        // Stash everything the backward pass needs for this layer.
        layer_acts.push(super::wgpu_backward::LayerActivations {
            attn_input,
            hidden_input,
            gate_output: gate_out,
            up_output: up_out,
            silu_gate,
            q: attn_cache.q,
            k: attn_cache.k,
            v: attn_cache.v,
            attn_weights: attn_cache.attn_weights,
            context: attn_cache.context,
            lora_q_h: attn_cache.lora_q_h,
            lora_v_h: attn_cache.lora_v_h,
        });
    }
    // LM head: logits[s, v] = hidden[s, h] · lm_head[v, h]ᵀ — the buffer
    // shapes imply gemm_backward_a computes A·Bᵀ here.
    let mut logits = vec![0.0f32; (s * v) as usize];
    self.device.gemm_backward_a(&hidden, &model.lm_head, &mut logits, s, v, h)?;
    // Numerically stable softmax cross-entropy via log-sum-exp;
    // grad_logits = softmax - onehot(target).
    let mut loss = 0.0f32;
    let mut grad_logits = vec![0.0f32; (s * v) as usize];
    for si in 0..s as usize {
        let row = &logits[si * v as usize..(si + 1) * v as usize];
        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = row.iter().map(|&x| (x - max_val).exp()).sum();
        let lse = max_val + sum_exp.ln();
        let t = target_ids[si] as usize;
        if t < v as usize {
            loss -= logits[si * v as usize + t] - lse;
        }
        for vi in 0..v as usize {
            grad_logits[si * v as usize + vi] = (logits[si * v as usize + vi] - lse).exp();
            if vi == t {
                grad_logits[si * v as usize + vi] -= 1.0;
            }
        }
    }
    loss /= s as f32;
    // Focal-style re-weighting: the gradient at the target is p_t - 1, so
    // w = 0.3 + 0.7*(1 - p_t) — already-confident positions get w→0.3,
    // badly-predicted ones keep w→1.0.
    for si in 0..s as usize {
        let t = target_ids[si] as usize;
        if t < v as usize {
            let w =
                0.3 + 0.7 * (1.0 - (grad_logits[si * v as usize + t] + 1.0).clamp(0.0, 1.0));
            for vi in 0..v as usize {
                grad_logits[si * v as usize + vi] *= w;
            }
        }
    }
    // Mean over sequence positions.
    for g in &mut grad_logits {
        *g /= s as f32;
    }
    // Backprop through the LM head: grad_hidden = grad_logits · lm_head.
    let mut grad_hidden = vec![0.0f32; (s * h) as usize];
    self.device.gemm_backward_a(&grad_logits, &model.lm_head, &mut grad_hidden, s, h, v)?;
    // The kernel emits the LM-head gradient in [h, v] layout; transpose it
    // into lm_head's [v, h] layout.
    let mut grad_lm_head_t = vec![0.0f32; (h * v) as usize];
    self.device.gemm_backward_b(&hidden, &grad_logits, &mut grad_lm_head_t, s, h, v)?;
    let mut grad_lm = vec![0.0f32; (v * h) as usize];
    for hi in 0..h as usize {
        for vi in 0..v as usize {
            grad_lm[vi * h as usize + hi] = grad_lm_head_t[hi * v as usize + vi];
        }
    }
    self.step += 1;
    // Per-tensor clip to unit L2 norm; returns the pre-clip norm.
    let clip = |g: &mut [f32]| {
        let n: f32 = g.iter().map(|x| x * x).sum::<f32>().sqrt();
        if n > 1.0 {
            let s = 1.0 / n;
            for v in g.iter_mut() {
                *v *= s;
            }
        }
        n
    };
    let lm_gnorm = clip(&mut grad_lm);
    clip(&mut grad_hidden);
    // mem::take to satisfy the borrow checker while the device mutates the
    // LM head and its AdamW moment buffers in place.
    let mut lm = std::mem::take(&mut model.lm_head);
    let mut lm_m = std::mem::take(&mut model.lm_head_m);
    let mut lm_v = std::mem::take(&mut model.lm_head_v);
    self.device.adamw_step(
        &mut lm,
        &grad_lm,
        &mut lm_m,
        &mut lm_v,
        self.lr,
        self.beta1,
        self.beta2,
        self.eps,
        self.weight_decay,
        self.step,
    )?;
    model.lm_head = lm;
    model.lm_head_m = lm_m;
    model.lm_head_v = lm_v;
    // Backprop through all layers; updates the LoRA adapters with AdamW.
    let lora_gnorm = super::wgpu_backward::backward_through_layers(
        &self.device,
        &mut grad_hidden,
        &layer_acts,
        model,
        s,
        h,
        i,
        self.lr,
        self.beta1,
        self.beta2,
        self.eps,
        self.weight_decay,
        self.step,
        self.lora_alpha,
    )?;
    let grad_norm = (lm_gnorm * lm_gnorm + lora_gnorm * lora_gnorm).sqrt();
    Ok((loss, grad_norm))
}
/// LoRA-augmented linear forward on the CPU:
/// `y = x·Wᵀ + (alpha/rank) · (x·Aᵀ)·Bᵀ`.
///
/// All matrices are row-major with weights stored `[out, in]`:
/// `w_fp32` is `[out_dim, in_dim]`, `lora_a` is `[rank, in_dim]`,
/// `lora_b` is `[out_dim, rank]`. Returns the `[seq_len, out_dim]` output.
pub fn lora_forward(
    &self,
    x: &[f32],
    w_fp32: &[f32],
    lora_a: &[f32],
    lora_b: &[f32],
    seq_len: u32,
    in_dim: u32,
    out_dim: u32,
    rank: u32,
    alpha: f32,
) -> Result<Vec<f32>, String> {
    let rows = seq_len as usize;
    let scaling = alpha / rank as f32;
    // Shared row-major matmul against a transposed weight:
    // out[r][o] = sum_p input[r][p] * weight[o][p].
    let matmul_t = |input: &[f32], weight: &[f32], cols_in: usize, cols_out: usize| {
        let mut out = vec![0.0f32; rows * cols_out];
        for r in 0..rows {
            for o in 0..cols_out {
                let mut acc = 0.0f32;
                for p in 0..cols_in {
                    acc += input[r * cols_in + p] * weight[o * cols_in + p];
                }
                out[r * cols_out + o] = acc;
            }
        }
        out
    };
    // Base projection, LoRA down-projection, LoRA up-projection.
    let mut y = matmul_t(x, w_fp32, in_dim as usize, out_dim as usize);
    let h = matmul_t(x, lora_a, in_dim as usize, rank as usize);
    let lora_out = matmul_t(&h, lora_b, rank as usize, out_dim as usize);
    // Add the scaled LoRA delta onto the base output.
    for (yi, di) in y.iter_mut().zip(lora_out.iter()) {
        *yi += scaling * di;
    }
    Ok(y)
}
/// Backward through a LoRA-augmented linear layer
/// `y = x·Wᵀ + (alpha/rank) · (x·Aᵀ)·Bᵀ`, where `h_cached = x·Aᵀ` was
/// saved during the forward pass.
///
/// Returns `(grad_A [rank, in_dim], grad_B [out_dim, rank], grad_x)`.
/// `W` is frozen, so no gradient is produced for it.
pub fn lora_backward(
    &self,
    grad_output: &[f32],
    x: &[f32],
    w_fp32: &[f32],
    lora_a: &[f32],
    lora_b: &[f32],
    h_cached: &[f32],
    seq_len: u32,
    in_dim: u32,
    out_dim: u32,
    rank: u32,
    alpha: f32,
) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>), String> {
    let scaling = alpha / rank as f32;
    // grad_x (base path) = grad_output · W.
    let mut grad_x = vec![0.0f32; (seq_len * in_dim) as usize];
    self.device.gemm_backward_a(grad_output, w_fp32, &mut grad_x, seq_len, in_dim, out_dim)?;
    // grad_h = scaling * (grad_output · B); the scale is folded in here so
    // grad_A (derived from grad_h below) carries it too.
    let mut grad_h = vec![0.0f32; (seq_len * rank) as usize];
    self.device.gemm_backward_a(grad_output, lora_b, &mut grad_h, seq_len, rank, out_dim)?;
    for v in &mut grad_h {
        *v *= scaling;
    }
    // grad_B = scaling * hᵀ · grad_output; the kernel emits [rank, out_dim],
    // which is transposed and scaled into B's [out_dim, rank] layout.
    let mut grad_b_transposed = vec![0.0f32; (rank * out_dim) as usize];
    self.device.gemm_backward_b(
        h_cached,
        grad_output,
        &mut grad_b_transposed,
        seq_len,
        rank,
        out_dim,
    )?;
    let mut grad_b = vec![0.0f32; (out_dim * rank) as usize];
    for i in 0..rank as usize {
        for j in 0..out_dim as usize {
            grad_b[j * rank as usize + i] =
                grad_b_transposed[i * out_dim as usize + j] * scaling;
        }
    }
    // grad_A = grad_hᵀ · x, emitted directly in A's [rank, in_dim] layout.
    let mut grad_a = vec![0.0f32; (rank * in_dim) as usize];
    self.device.gemm_backward_b(
        &grad_h, x, &mut grad_a,
        seq_len,
        rank, in_dim, )?;
    // Add the LoRA-path contribution to grad_x: grad_h · A (CPU loop).
    for i in 0..seq_len as usize {
        for j in 0..in_dim as usize {
            let mut sum = 0.0f32;
            for p in 0..rank as usize {
                sum += grad_h[i * rank as usize + p] * lora_a[p * in_dim as usize + j];
            }
            grad_x[i * in_dim as usize + j] += sum;
        }
    }
    Ok((grad_a, grad_b, grad_x))
}
/// LM-head-only training step over precomputed `hidden_states`.
///
/// Computes the logits on the CPU (`hidden · lm_headᵀ`), softmax
/// cross-entropy against `target_ids`, then backprops with the GPU gemm
/// kernels and updates `lm_head_weight` in place with AdamW (moments in
/// `m_state`/`v_state`). `_input_ids` is accepted for interface symmetry
/// but unused. The returned gradient norm is taken before the optimizer
/// step; no clipping is applied here.
///
/// Returns `(mean loss, L2 norm of the LM-head gradient)`.
pub fn train_step(
    &mut self,
    _input_ids: &[u32],
    target_ids: &[u32],
    hidden_states: &[f32],
    lm_head_weight: &mut [f32],
    m_state: &mut [f32],
    v_state: &mut [f32],
) -> Result<(f32, f32), String> {
    self.step += 1;
    let seq_len = target_ids.len() as u32;
    let hidden_size = self.config.hidden_size as u32;
    let vocab_size = self.config.vocab_size as u32;
    let m = seq_len;
    let k = hidden_size;
    let n = vocab_size;
    // logits[m, n] = hidden[m, k] · lm_head[n, k]ᵀ (CPU reference matmul).
    let mut logits = vec![0.0f32; (m * n) as usize];
    for i in 0..m as usize {
        for j in 0..n as usize {
            let mut sum = 0.0f32;
            for p in 0..k as usize {
                sum += hidden_states[i * k as usize + p] * lm_head_weight[j * k as usize + p];
            }
            logits[i * n as usize + j] = sum;
        }
    }
    // Numerically stable softmax cross-entropy via log-sum-exp;
    // grad_logits = softmax - onehot(target).
    let mut loss = 0.0f32;
    let mut grad_logits = vec![0.0f32; (m * n) as usize];
    for i in 0..m as usize {
        let row = &logits[i * n as usize..(i + 1) * n as usize];
        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = row.iter().map(|&x| (x - max_val).exp()).sum();
        let log_sum_exp = max_val + sum_exp.ln();
        let target = target_ids[i] as usize;
        if target < n as usize {
            loss -= logits[i * n as usize + target] - log_sum_exp;
        }
        for j in 0..n as usize {
            let softmax_j = (logits[i * n as usize + j] - log_sum_exp).exp();
            grad_logits[i * n as usize + j] = softmax_j;
            if j == target {
                grad_logits[i * n as usize + j] -= 1.0;
            }
        }
    }
    // Mean over sequence positions.
    loss /= m as f32;
    for g in &mut grad_logits {
        *g /= m as f32;
    }
    // grad wrt hidden states (unused by the caller but computed on GPU).
    let mut grad_hidden = vec![0.0f32; (m * k) as usize];
    self.device.gemm_backward_a(&grad_logits, lm_head_weight, &mut grad_hidden, m, k, n)?;
    // The kernel emits the LM-head gradient in [k, n] layout; transpose to
    // the [n, k] layout of `lm_head_weight`.
    let mut grad_lm_head_t = vec![0.0f32; (k * n) as usize];
    self.device.gemm_backward_b(hidden_states, &grad_logits, &mut grad_lm_head_t, m, k, n)?;
    let mut grad_lm_head = vec![0.0f32; (n * k) as usize];
    for i in 0..k as usize {
        for j in 0..n as usize {
            grad_lm_head[j * k as usize + i] = grad_lm_head_t[i * n as usize + j];
        }
    }
    let grad_norm: f32 = grad_lm_head.iter().map(|g| g * g).sum::<f32>().sqrt();
    self.device.adamw_step(
        lm_head_weight,
        &grad_lm_head,
        m_state,
        v_state,
        self.lr,
        self.beta1,
        self.beta2,
        self.eps,
        self.weight_decay,
        self.step,
    )?;
    Ok((loss, grad_norm))
}
/// Backprop through the SwiGLU FFN: given `grad_output` = dL/d(ffn_out),
/// return dL/d(ffn_input). Only the input gradient is produced — the base
/// FFN weights are frozen (NF4), so no weight gradients are computed.
/// `_hidden_input` is accepted but not needed on this path.
pub fn ffn_backward(
    &self,
    grad_output: &[f32],
    _hidden_input: &[f32],
    gate_weight: &[f32],
    up_weight: &[f32],
    down_weight: &[f32],
    gate_output: &[f32],
    up_output: &[f32],
    silu_gate_output: &[f32],
    seq_len: u32,
    hidden_size: u32,
    intermediate_size: u32,
) -> Result<Vec<f32>, String> {
    let s = seq_len;
    let h = hidden_size;
    let i = intermediate_size;
    // grad wrt swiglu = grad_output backprop'd through down_proj.
    let mut grad_swiglu = vec![0.0f32; (s * i) as usize];
    self.device.gemm_backward_a(
        grad_output, down_weight, &mut grad_swiglu,
        s,
        i,
        h,
    )?;
    // Product rule through swiglu = silu(gate) * up.
    // SiLU'(x) = sig(x)*(1 + x*(1 - sig(x))); with y = x*sig(x) this is
    // exactly sig*(1 + x - y), the form used below.
    let n_inter = (s * i) as usize;
    let mut grad_gate = vec![0.0f32; n_inter];
    let mut grad_up = vec![0.0f32; n_inter];
    for j in 0..n_inter {
        let x = gate_output[j];
        let sig = 1.0 / (1.0 + (-x).exp());
        let y = x * sig;
        let silu_prime = sig * (1.0 + x - y);
        grad_gate[j] = grad_swiglu[j] * up_output[j] * silu_prime;
        grad_up[j] = grad_swiglu[j] * silu_gate_output[j];
    }
    // Propagate both branches back to the FFN input and sum them.
    let mut grad_input_gate = vec![0.0f32; (s * h) as usize];
    self.device.gemm_backward_a(
        &grad_gate,
        gate_weight, &mut grad_input_gate,
        s,
        h,
        i,
    )?;
    let mut grad_input_up = vec![0.0f32; (s * h) as usize];
    self.device.gemm_backward_a(
        &grad_up,
        up_weight, &mut grad_input_up,
        s,
        h,
        i,
    )?;
    let mut grad_ffn_input = vec![0.0f32; (s * h) as usize];
    for j in 0..(s * h) as usize {
        grad_ffn_input[j] = grad_input_gate[j] + grad_input_up[j];
    }
    Ok(grad_ffn_input)
}
}
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    /// FALSIFY-WGPU-002: LM-head-only training on a tiny config must reduce
    /// the cross-entropy loss by more than 10% within 50 steps.
    #[test]
    fn test_falsify_wgpu_002_toy_convergence() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 16;
        config.vocab_size = 32;
        config.num_hidden_layers = 1;
        config.num_attention_heads = 2;
        config.num_kv_heads = 2;
        config.intermediate_size = 64;
        config.max_position_embeddings = 8;
        let mut trainer = WgpuTransformerTrainer::new(&config, 5e-2).expect("WGPU trainer");
        eprintln!("WGPU adapter: {}", trainer.adapter_info());
        let input_ids: Vec<u32> = vec![1, 5, 10, 15];
        let target_ids: Vec<u32> = vec![5, 10, 15, 20];
        // Deterministic pseudo-random fixtures in [-0.5, 0.5).
        let hidden: Vec<f32> =
            (0..4 * 16).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let mut lm_head: Vec<f32> =
            (0..32 * 16).map(|i| ((i * 13 + 7) % 100) as f32 / 100.0 - 0.5).collect();
        let mut m_state = vec![0.0f32; 32 * 16];
        let mut v_state = vec![0.0f32; 32 * 16];
        let mut losses = Vec::new();
        for _ in 0..50 {
            let (loss, _gnorm) = trainer
                .train_step(
                    &input_ids,
                    &target_ids,
                    &hidden,
                    &mut lm_head,
                    &mut m_state,
                    &mut v_state,
                )
                .expect("train_step");
            losses.push(loss);
        }
        let first_loss = losses[0];
        let best_loss = losses.iter().copied().fold(f32::INFINITY, f32::min);
        let last_loss = *losses.last().expect("losses");
        eprintln!(
            "WGPU convergence: loss {:.3} -> {:.3} (best {:.3}, {} steps)",
            first_loss,
            last_loss,
            best_loss,
            losses.len()
        );
        assert!(first_loss.is_finite(), "First loss not finite: {first_loss}");
        assert!(
            best_loss < first_loss * 0.9,
            "FALSIFY-WGPU-002: Loss did not decrease by >10%: first={first_loss:.3}, best={best_loss:.3}"
        );
    }
    /// ffn_backward must produce a finite, non-zero input gradient for a
    /// hand-computed toy FFN forward pass.
    #[test]
    fn test_ffn_backward_gradient_flow() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 8;
        config.intermediate_size = 16;
        let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");
        let (s, h, i) = (2u32, 8u32, 16u32);
        let grad_output: Vec<f32> = (0..(s * h) as usize).map(|j| (j as f32 - 8.0) * 0.1).collect();
        let hidden_input: Vec<f32> = (0..(s * h) as usize).map(|j| j as f32 * 0.05).collect();
        let gate_weight: Vec<f32> =
            (0..(i * h) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
        let up_weight: Vec<f32> = (0..(i * h) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
        let down_weight: Vec<f32> =
            (0..(h * i) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
        // CPU reference forward to produce the cached activations.
        let mut gate_output = vec![0.0f32; (s * i) as usize];
        let mut up_output = vec![0.0f32; (s * i) as usize];
        for si in 0..s as usize {
            for ii in 0..i as usize {
                for hi in 0..h as usize {
                    gate_output[si * i as usize + ii] +=
                        hidden_input[si * h as usize + hi] * gate_weight[ii * h as usize + hi];
                    up_output[si * i as usize + ii] +=
                        hidden_input[si * h as usize + hi] * up_weight[ii * h as usize + hi];
                }
            }
        }
        let silu_gate: Vec<f32> = gate_output
            .iter()
            .map(|&x| {
                let sig = 1.0 / (1.0 + (-x).exp());
                x * sig
            })
            .collect();
        let grad_input = trainer
            .ffn_backward(
                &grad_output,
                &hidden_input,
                &gate_weight,
                &up_weight,
                &down_weight,
                &gate_output,
                &up_output,
                &silu_gate,
                s,
                h,
                i,
            )
            .expect("ffn_backward");
        let norm: f32 = grad_input.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!(norm > 1e-6, "FFN backward gradient norm should be non-zero, got {norm}");
        assert!(grad_input.iter().all(|g| g.is_finite()), "All gradients must be finite");
        eprintln!("FFN backward gradient norm: {norm:.4}");
    }
    /// With B = 0 the LoRA delta is zero; a non-zero B must change the
    /// output of lora_forward.
    #[test]
    fn test_lora_forward_adds_to_base() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 8;
        config.intermediate_size = 16;
        let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");
        let (s, in_d, out_d, r) = (2u32, 8u32, 16u32, 4u32);
        let alpha = 8.0f32;
        let x: Vec<f32> = (0..(s * in_d) as usize).map(|i| (i as f32 - 8.0) * 0.1).collect();
        let w: Vec<f32> = (0..(out_d * in_d) as usize).map(|i| (i as f32 - 64.0) * 0.01).collect();
        let a: Vec<f32> = (0..(r * in_d) as usize).map(|i| (i as f32 - 16.0) * 0.05).collect();
        let b_zero = vec![0.0f32; (out_d * r) as usize];
        let y_base = trainer
            .lora_forward(&x, &w, &a, &b_zero, s, in_d, out_d, r, alpha)
            .expect("lora_forward base");
        let b: Vec<f32> = (0..(out_d * r) as usize).map(|i| (i as f32 - 32.0) * 0.02).collect();
        let y_lora = trainer
            .lora_forward(&x, &w, &a, &b, s, in_d, out_d, r, alpha)
            .expect("lora_forward lora");
        let diff: f32 = y_base.iter().zip(y_lora.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(diff > 1e-3, "LoRA should change output, diff={diff}");
    }
    /// lora_backward must yield finite, non-zero gradients for A, B and x.
    #[test]
    fn test_lora_backward_gradient_flow() {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 8;
        config.intermediate_size = 16;
        let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");
        let (s, in_d, out_d, r) = (2u32, 8u32, 16u32, 4u32);
        let alpha = 8.0f32;
        let x: Vec<f32> = (0..(s * in_d) as usize).map(|i| (i as f32 - 8.0) * 0.1).collect();
        let w: Vec<f32> = (0..(out_d * in_d) as usize).map(|i| (i as f32 - 64.0) * 0.01).collect();
        let a: Vec<f32> = (0..(r * in_d) as usize).map(|i| (i as f32 - 16.0) * 0.05).collect();
        let b: Vec<f32> = (0..(out_d * r) as usize).map(|i| (i as f32 - 32.0) * 0.02).collect();
        // Recompute h = x·Aᵀ exactly as the forward pass would cache it.
        let mut h_cached = vec![0.0f32; (s * r) as usize];
        for i in 0..s as usize {
            for j in 0..r as usize {
                for p in 0..in_d as usize {
                    h_cached[i * r as usize + j] +=
                        x[i * in_d as usize + p] * a[j * in_d as usize + p];
                }
            }
        }
        let grad_output: Vec<f32> =
            (0..(s * out_d) as usize).map(|i| (i as f32 - 16.0) * 0.05).collect();
        let (grad_a, grad_b, grad_x) = trainer
            .lora_backward(&grad_output, &x, &w, &a, &b, &h_cached, s, in_d, out_d, r, alpha)
            .expect("lora_backward");
        let norm_a: f32 = grad_a.iter().map(|g| g * g).sum::<f32>().sqrt();
        let norm_b: f32 = grad_b.iter().map(|g| g * g).sum::<f32>().sqrt();
        let norm_x: f32 = grad_x.iter().map(|g| g * g).sum::<f32>().sqrt();
        assert!(norm_a > 1e-6, "grad_A should be non-zero, got {norm_a}");
        assert!(norm_b > 1e-6, "grad_B should be non-zero, got {norm_b}");
        assert!(norm_x > 1e-6, "grad_x should be non-zero, got {norm_x}");
        assert!(grad_a.iter().all(|g| g.is_finite()), "grad_A must be finite");
        assert!(grad_b.iter().all(|g| g.is_finite()), "grad_B must be finite");
        assert!(grad_x.iter().all(|g| g.is_finite()), "grad_x must be finite");
        eprintln!(
            "LoRA backward: |grad_A|={norm_a:.4}, |grad_B|={norm_b:.4}, |grad_x|={norm_x:.4}"
        );
    }
    /// Integration test: load the real Qwen3-4B checkpoint (skipped when the
    /// model directory is absent) and sanity-check dimensions and sizes.
    #[test]
    fn test_load_qwen3_4b_full_model() {
        let model_dir = std::path::Path::new("/home/noah/src/models/qwen3-4b");
        if !model_dir.exists() {
            eprintln!("Skipping: Qwen3-4B model not found");
            return;
        }
        let model = WgpuModelState::load_qwen3_4b(model_dir, 16, 32.0).expect("load_qwen3_4b");
        assert_eq!(model.num_layers, 36);
        assert_eq!(model.hidden_size, 2560);
        assert_eq!(model.layers.len(), 36);
        assert_eq!(model.lora.len(), 36);
        let total_nf4_mb: f64 =
            model.layers.iter().map(|l| l.memory_bytes() as f64).sum::<f64>() / 1024.0 / 1024.0;
        let trainable = model.trainable_params();
        eprintln!("Qwen3-4B loaded: {total_nf4_mb:.0} MB NF4, {trainable} trainable params");
        assert!(total_nf4_mb < 2048.0, "NF4 total should be < 2GB, got {total_nf4_mb:.0} MB");
        assert!(trainable > 1_000_000, "Should have >1M trainable params, got {trainable}");
    }
    /// Integration test: one prototype train step on layer 0 of the real
    /// model must produce finite output and a positive gradient norm.
    #[test]
    fn test_qwen3_4b_single_layer_train_step() {
        let model_dir = std::path::Path::new("/home/noah/src/models/qwen3-4b");
        if !model_dir.exists() {
            eprintln!("Skipping: Qwen3-4B model not found");
            return;
        }
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 2560;
        config.intermediate_size = 9728;
        config.num_hidden_layers = 36;
        config.num_attention_heads = 32;
        config.num_kv_heads = 8;
        config.vocab_size = 151936;
        let mut model = WgpuModelState::load_qwen3_4b(model_dir, 16, 32.0).expect("load model");
        let mut trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");
        let seq_len = 4u32;
        let hidden: Vec<f32> = (0..(seq_len * 2560) as usize)
            .map(|i| ((i * 7 + 3) % 1000) as f32 / 1000.0 - 0.5)
            .collect();
        let start = std::time::Instant::now();
        let lora_set = &mut model.lora[0];
        let (lora_q, lora_v) = (&mut lora_set.q, &mut lora_set.v);
        let (output, grad_norm) = trainer
            .layer_train_step(&hidden, &mut model.layers[0], lora_q, lora_v, seq_len, 2560, 9728)
            .expect("layer_train_step");
        let elapsed = start.elapsed();
        assert_eq!(output.len(), (seq_len * 2560) as usize);
        assert!(output.iter().all(|v| v.is_finite()), "All outputs must be finite");
        assert!(grad_norm > 0.0, "Gradient norm must be positive");
        assert!(grad_norm.is_finite(), "Gradient norm must be finite");
        eprintln!(
            "Qwen3-4B layer 0 train step: {:.1}s, output_norm={:.4}, grad_norm={:.4}",
            elapsed.as_secs_f64(),
            output.iter().map(|v| v * v).sum::<f32>().sqrt(),
            grad_norm,
        );
    }
    /// Integration test: three full 36-layer train steps on the real model
    /// must produce finite, positive losses and finite gradient norms.
    #[test]
    fn test_qwen3_4b_full_36_layer_training() {
        let model_dir = std::path::Path::new("/home/noah/src/models/qwen3-4b");
        if !model_dir.exists() {
            eprintln!("Skipping: Qwen3-4B model not found");
            return;
        }
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = 2560;
        config.intermediate_size = 9728;
        config.num_hidden_layers = 36;
        config.num_attention_heads = 32;
        config.num_kv_heads = 8;
        config.vocab_size = 151936;
        let mut model = WgpuModelState::load_qwen3_4b(model_dir, 16, 32.0).expect("load model");
        let mut trainer = WgpuTransformerTrainer::new(&config, 5e-4).expect("trainer");
        let seq_len = 2u32;
        let hidden: Vec<f32> = (0..(seq_len * 2560) as usize)
            .map(|j| ((j * 7 + 3) % 1000) as f32 / 1000.0 - 0.5)
            .collect();
        let targets: Vec<u32> = vec![42, 100];
        let mut losses = Vec::new();
        for step in 0..3 {
            let start = std::time::Instant::now();
            let (loss, gnorm) =
                trainer.full_train_step(&hidden, &targets, &mut model).expect("full_train_step");
            let elapsed = start.elapsed();
            eprintln!(
                "Step {}: loss={:.3}, gnorm={:.4}, time={:.1}s",
                step + 1,
                loss,
                gnorm,
                elapsed.as_secs_f64()
            );
            losses.push(loss);
            assert!(loss.is_finite(), "Loss must be finite at step {}", step + 1);
            assert!(loss > 0.0, "Loss must be positive at step {}", step + 1);
            assert!(gnorm.is_finite(), "Grad norm must be finite at step {}", step + 1);
        }
        eprintln!(
            "Qwen3-4B 36-layer training: loss {:.3} -> {:.3} ({} steps)",
            losses[0],
            losses.last().unwrap(),
            losses.len()
        );
    }
}