1#[cfg(feature = "gpu")]
8use crate::transformer::wgpu_block::WgpuForwardPass;
9#[cfg(feature = "gpu")]
10use crate::transformer::TransformerConfig;
11#[cfg(feature = "gpu")]
12use trueno::backends::gpu::GpuDevice;
13
14#[cfg(feature = "gpu")]
/// Returns the transpose of a row-major `rows x cols` matrix as a freshly
/// allocated row-major `cols x rows` buffer.
///
/// # Panics
/// Panics if `data` holds fewer than `rows * cols` elements.
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; rows * cols];
    // Walk the destination linearly: destination index `c * rows + r`
    // reads from source index `r * cols + c`.
    for (dst, slot) in out.iter_mut().enumerate() {
        let (c, r) = (dst / rows, dst % rows);
        *slot = data[r * cols + c];
    }
    out
}
25
26#[cfg(feature = "gpu")]
/// Trainer that runs transformer training steps through a GPU device.
///
/// Bundles the forward-pass object, the device handle, the model
/// configuration, and the AdamW / LoRA hyper-parameters that the training
/// methods hand to the device kernels.
pub struct WgpuTransformerTrainer {
    /// Forward-pass object built from the config in `new`; queried by `adapter_info`.
    forward: WgpuForwardPass,
    /// Device used for dequantization, GEMMs and AdamW updates.
    device: GpuDevice,
    /// Copy of the configuration passed to `new`; `train_step` reads
    /// `hidden_size` and `vocab_size` from it.
    config: TransformerConfig,
    /// Number of optimizer steps taken so far; incremented by each training
    /// step and passed to `adamw_step`.
    step: u32,
    /// Learning rate passed to `adamw_step`.
    lr: f32,
    /// AdamW beta1, initialised to 0.9 by `new`.
    beta1: f32,
    /// AdamW beta2, initialised to 0.95 by `new`.
    beta2: f32,
    /// AdamW epsilon, initialised to 1e-8 by `new`.
    eps: f32,
    /// AdamW weight decay, initialised to 0.1 by `new`.
    weight_decay: f32,
    /// LoRA rank, initialised to 0 by `new`.
    lora_rank: u32,
    /// LoRA alpha, initialised to 0.0 by `new`; forwarded to the attention
    /// forward pass and the layer backward pass in `full_train_step`.
    lora_alpha: f32,
}
40
41#[cfg(feature = "gpu")]
/// Host-side model state used by `WgpuTransformerTrainer::full_train_step`.
///
/// Holds the NF4-quantized base weights, the LoRA adapters, the fp32 LM head
/// with its optimizer state, the model dimensions, and per-layer caches of
/// dequantized weights.
pub struct WgpuModelState {
    /// NF4-quantized projection weights, one entry per transformer layer.
    pub layers: Vec<super::wgpu_nf4::Nf4LayerWeights>,
    /// LoRA adapter set, one entry per transformer layer.
    pub lora: Vec<super::wgpu_checkpoint::LoraLayerSet>,
    /// LM head weights as f32 (loaded from `lm_head.weight`,
    /// `model.lm_head.weight` or `model.embed_tokens.weight`).
    pub lm_head: Vec<f32>,
    /// First optimizer moment buffer for `lm_head` (same length, zero-initialised).
    pub lm_head_m: Vec<f32>,
    /// Second optimizer moment buffer for `lm_head` (same length, zero-initialised).
    pub lm_head_v: Vec<f32>,
    /// Hidden (model) dimension.
    pub hidden_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Number of attention (query) heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// FFN intermediate dimension.
    pub intermediate_size: usize,
    /// Per-layer cache of dequantized, transposed `(gate, up, down)` weights;
    /// `None` until filled by `populate_weight_cache`.
    pub ffn_cache: Vec<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>>,
    /// Per-layer cache of dequantized, transposed `(q, k, v, o)` weights;
    /// `None` until filled by `populate_weight_cache`.
    pub attn_cache: Vec<Option<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)>>,
}
69
70#[cfg(feature = "gpu")]
71impl WgpuModelState {
72 pub fn load_qwen3_4b(
79 model_dir: &std::path::Path,
80 lora_rank: u32,
81 _lora_alpha: f32,
82 ) -> Result<Self, String> {
83 use std::fs;
84
85 let config_path = model_dir.join("config.json");
86 let config_str = fs::read_to_string(&config_path)
87 .map_err(|e| format!("Cannot read config.json: {e}"))?;
88 let config: serde_json::Value =
89 serde_json::from_str(&config_str).map_err(|e| format!("Invalid config.json: {e}"))?;
90
91 let hidden_size = config["hidden_size"].as_u64().unwrap_or(2560) as usize;
92 let num_layers = config["num_hidden_layers"].as_u64().unwrap_or(36) as usize;
93 let num_heads = config["num_attention_heads"].as_u64().unwrap_or(32) as usize;
94 let num_kv_heads = config["num_key_value_heads"].as_u64().unwrap_or(8) as usize;
95 let intermediate_size = config["intermediate_size"].as_u64().unwrap_or(9728) as usize;
96 let vocab_size = config["vocab_size"].as_u64().unwrap_or(151936) as usize;
97 let head_dim = config["head_dim"].as_u64().unwrap_or(128) as usize;
98
99 eprintln!("Loading Qwen3-4B: {num_layers} layers, h={hidden_size}, i={intermediate_size}");
100
101 let mut shards: Vec<String> = fs::read_dir(model_dir)
103 .map_err(|e| format!("Cannot read model dir: {e}"))?
104 .filter_map(std::result::Result::ok)
105 .map(|e| e.file_name().to_string_lossy().to_string())
106 .filter(|n| n.ends_with(".safetensors"))
107 .collect();
108 shards.sort();
109
110 if shards.is_empty() {
111 return Err("No .safetensors files found".to_string());
112 }
113
114 let mut all_data: Vec<Vec<u8>> = Vec::new();
116 for shard in &shards {
117 let path = model_dir.join(shard);
118 eprintln!(" Loading {shard}...");
119 let data = fs::read(&path).map_err(|e| format!("Cannot read {shard}: {e}"))?;
120 all_data.push(data);
121 }
122
123 let parsed: Vec<safetensors::SafeTensors<'_>> = all_data
125 .iter()
126 .map(|d| safetensors::SafeTensors::deserialize(d))
127 .collect::<Result<Vec<_>, _>>()
128 .map_err(|e| format!("Deserialize error: {e}"))?;
129
130 let mut layers = Vec::with_capacity(num_layers);
132 let q_dim = num_heads * head_dim;
133 let block_size = 64u32;
134
135 for layer_idx in 0..num_layers {
136 let prefix = format!("model.layers.{layer_idx}");
137
138 let find_and_quantize = |name: &str,
140 rows: usize,
141 cols: usize|
142 -> Result<(Vec<u32>, Vec<f32>, u32), String> {
143 for tensors in &parsed {
144 if tensors.tensor(name).is_ok() {
145 return super::wgpu_nf4::Nf4LayerWeights::quantize_projection_from_tensors(
146 tensors, name, rows, cols,
147 );
148 }
149 }
150 Err(format!("Tensor {name} not found in any shard"))
151 };
152
153 let kv_dim = num_kv_heads * head_dim;
154 let (gate_p, gate_s, gate_n) = find_and_quantize(
155 &format!("{prefix}.mlp.gate_proj.weight"),
156 intermediate_size,
157 hidden_size,
158 )?;
159 let (up_p, up_s, up_n) = find_and_quantize(
160 &format!("{prefix}.mlp.up_proj.weight"),
161 intermediate_size,
162 hidden_size,
163 )?;
164 let (down_p, down_s, down_n) = find_and_quantize(
165 &format!("{prefix}.mlp.down_proj.weight"),
166 hidden_size,
167 intermediate_size,
168 )?;
169 let (q_p, q_s, q_n) = find_and_quantize(
170 &format!("{prefix}.self_attn.q_proj.weight"),
171 q_dim,
172 hidden_size,
173 )?;
174 let (k_p, k_s, k_n) = find_and_quantize(
175 &format!("{prefix}.self_attn.k_proj.weight"),
176 kv_dim,
177 hidden_size,
178 )?;
179 let (v_p, v_s, v_n) = find_and_quantize(
180 &format!("{prefix}.self_attn.v_proj.weight"),
181 kv_dim,
182 hidden_size,
183 )?;
184 let (o_p, o_s, o_n) = find_and_quantize(
185 &format!("{prefix}.self_attn.o_proj.weight"),
186 hidden_size,
187 q_dim,
188 )?;
189
190 let layer = super::wgpu_nf4::Nf4LayerWeights {
191 gate_packed: gate_p,
192 gate_scales: gate_s,
193 up_packed: up_p,
194 up_scales: up_s,
195 down_packed: down_p,
196 down_scales: down_s,
197 q_packed: q_p,
198 q_scales: q_s,
199 k_packed: k_p,
200 k_scales: k_s,
201 v_packed: v_p,
202 v_scales: v_s,
203 o_packed: o_p,
204 o_scales: o_s,
205 gate_n,
206 up_n,
207 down_n,
208 q_n,
209 k_n,
210 v_n,
211 o_n,
212 block_size,
213 };
214
215 let mb = layer.memory_bytes() as f64 / 1024.0 / 1024.0;
216 if layer_idx % 6 == 0 || layer_idx == num_layers - 1 {
217 eprintln!(" Layer {layer_idx}: {mb:.1} MB NF4");
218 }
219 layers.push(layer);
220 }
221
222 let mut lora = Vec::with_capacity(num_layers);
224 for _ in 0..num_layers {
225 lora.push(super::wgpu_checkpoint::LoraLayerSet::new(
226 lora_rank,
227 hidden_size as u32,
228 q_dim as u32,
229 (num_kv_heads * head_dim) as u32,
230 intermediate_size as u32,
231 ));
232 }
233
234 let last_data = all_data.last().ok_or("No shards")?;
236 let _tensors = safetensors::SafeTensors::deserialize(last_data)
237 .map_err(|e| format!("Deserialize: {e}"))?;
238
239 let mut lm_head_view = None;
241 for data in &all_data {
242 let t = safetensors::SafeTensors::deserialize(data)
243 .map_err(|e| format!("Deserialize: {e}"))?;
244 for name in ["lm_head.weight", "model.lm_head.weight", "model.embed_tokens.weight"] {
245 if let Ok(v) = t.tensor(name) {
246 let fp32: Vec<f32> = match v.dtype() {
248 safetensors::Dtype::F16 => v
249 .data()
250 .chunks_exact(2)
251 .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
252 .collect(),
253 safetensors::Dtype::BF16 => v
254 .data()
255 .chunks_exact(2)
256 .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
257 .collect(),
258 _ => bytemuck::cast_slice(v.data()).to_vec(),
259 };
260 eprintln!(" LM head from {name}: {} elements", fp32.len());
261 lm_head_view = Some(fp32);
262 break;
263 }
264 }
265 if lm_head_view.is_some() {
266 break;
267 }
268 }
269 let lm_head = lm_head_view.ok_or("lm_head/embed_tokens not found in any shard")?;
270 let lm_head_len = lm_head.len();
271 let lora_params: usize =
272 lora.iter().map(super::wgpu_checkpoint::LoraLayerSet::num_params).sum();
273 eprintln!(" LoRA params: {lora_params} (rank={lora_rank}, 7 modules/layer)");
274 eprintln!(
275 " LM head: {} elements ({:.1} MB)",
276 lm_head_len,
277 lm_head_len as f64 * 4.0 / 1024.0 / 1024.0
278 );
279 Ok(Self {
280 layers,
281 lora,
282 lm_head,
283 lm_head_m: vec![0.0f32; lm_head_len],
284 lm_head_v: vec![0.0f32; lm_head_len],
285 hidden_size,
286 num_layers,
287 vocab_size,
288 num_heads,
289 num_kv_heads,
290 head_dim,
291 intermediate_size,
292 ffn_cache: vec![None; num_layers],
293 attn_cache: vec![None; num_layers],
294 })
295 }
296
297 pub fn populate_weight_cache(
299 &mut self,
300 device: &trueno::backends::gpu::GpuDevice,
301 ) -> Result<(), String> {
302 let (h, i) = (self.hidden_size, self.intermediate_size);
303 let (qd, kvd) = (self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim);
304 for li in 0..self.num_layers {
305 let layer = &self.layers[li];
306 if self.ffn_cache[li].is_none() {
307 self.ffn_cache[li] = Some((
308 transpose(&layer.dequant_gate(device)?, i, h),
309 transpose(&layer.dequant_up(device)?, i, h),
310 transpose(&layer.dequant_down(device)?, h, i),
311 ));
312 }
313 if self.attn_cache[li].is_none() {
314 self.attn_cache[li] = Some((
315 transpose(&layer.dequant_q(device)?, qd, h),
316 transpose(&layer.dequant_k(device)?, kvd, h),
317 transpose(&layer.dequant_v(device)?, kvd, h),
318 transpose(&layer.dequant_o(device)?, h, qd),
319 ));
320 if li % 12 == 0 || li == self.num_layers - 1 {
321 eprintln!(" Cached layer {li}");
322 }
323 }
324 }
325 Ok(())
326 }
327
328 pub fn trainable_params(&self) -> usize {
330 self.lora.iter().map(super::wgpu_checkpoint::LoraLayerSet::num_params).sum::<usize>()
331 + self.lm_head.len()
332 }
    /// Saves the LoRA adapters by delegating to
    /// `wgpu_checkpoint::save_lora_checkpoint`.
    ///
    /// Only `self.lora` and `self.hidden_size` are handed to the helper; the
    /// LM head is not passed along. `dir`, `step`, `loss`, `rank` and `alpha`
    /// are forwarded unchanged.
    ///
    /// Returns whatever the helper returns (a path on success, an error
    /// string on failure).
    pub fn save_checkpoint(
        &self,
        dir: &std::path::Path,
        step: u32,
        loss: f32,
        rank: u32,
        alpha: f32,
    ) -> Result<std::path::PathBuf, String> {
        // Contract macros are defined elsewhere; they bracket the actual save.
        contract_pre_save_checkpoint!();
        let result = super::wgpu_checkpoint::save_lora_checkpoint(
            &self.lora,
            self.hidden_size,
            dir,
            step,
            loss,
            rank,
            alpha,
        );
        contract_post_save_checkpoint!(result);
        result
    }
354
    /// Loads LoRA adapters from `path` into `self.lora` by delegating to
    /// `wgpu_checkpoint::load_lora_checkpoint`, passing the layer count and
    /// hidden size of this model.
    ///
    /// Returns the helper's result: a `(u32, f32)` pair on success
    /// (presumably the saved step and loss - confirm against the helper) or
    /// an error string.
    pub fn load_checkpoint(&mut self, path: &std::path::Path) -> Result<(u32, f32), String> {
        // Contract macros are defined elsewhere; they bracket the actual load.
        contract_pre_load_checkpoint!();
        let result = super::wgpu_checkpoint::load_lora_checkpoint(
            &mut self.lora,
            self.num_layers,
            self.hidden_size,
            path,
        );
        contract_post_load_checkpoint!(result);
        result
    }
367}
368
369#[cfg(feature = "gpu")]
370impl WgpuTransformerTrainer {
371 pub fn new(config: &TransformerConfig, lr: f32) -> Result<Self, String> {
373 let forward = WgpuForwardPass::new_default(config)?;
374 let device = GpuDevice::new()?;
375
376 Ok(Self {
377 forward,
378 device,
379 config: config.clone(),
380 step: 0,
381 lr,
382 beta1: 0.9,
383 beta2: 0.95, eps: 1e-8,
385 weight_decay: 0.1, lora_rank: 0,
387 lora_alpha: 0.0,
388 })
389 }
390
391 pub fn with_lora(mut self, rank: u32, _alpha: f32) -> Self {
393 self.lora_rank = rank;
394 self
395 }
396
397 pub fn with_adamw(mut self, beta1: f32, beta2: f32, eps: f32, weight_decay: f32) -> Self {
399 self.beta1 = beta1;
400 self.beta2 = beta2;
401 self.eps = eps;
402 self.weight_decay = weight_decay;
403 self
404 }
405
    /// Returns the adapter description string reported by the forward pass.
    pub fn adapter_info(&self) -> String {
        self.forward.adapter_info()
    }
410
    /// Returns the current value of the step counter, which each training
    /// step method increments.
    pub fn current_step(&self) -> u32 {
        self.step
    }
415
416 pub fn layer_train_step(
419 &mut self,
420 hidden: &[f32], model: &mut super::wgpu_nf4::Nf4LayerWeights,
422 lora_q: &mut super::wgpu_nf4::LoraAdapter,
423 _lora_v: &mut super::wgpu_nf4::LoraAdapter,
424 seq_len: u32,
425 hidden_size: u32,
426 intermediate_size: u32,
427 ) -> Result<(Vec<f32>, f32), String> {
428 let gate_fp32 = model.dequant_gate(&self.device)?;
431 let up_fp32 = model.dequant_up(&self.device)?;
432 let down_fp32 = model.dequant_down(&self.device)?;
433
434 let s = seq_len;
435 let h = hidden_size;
436 let i = intermediate_size;
437
438 let mut gate_out = vec![0.0f32; (s * i) as usize];
440 for si in 0..s as usize {
441 for ii in 0..i as usize {
442 let mut sum = 0.0f32;
443 for hi in 0..h as usize {
444 sum += hidden[si * h as usize + hi] * gate_fp32[ii * h as usize + hi];
445 }
446 gate_out[si * i as usize + ii] = sum;
447 }
448 }
449
450 let mut up_out = vec![0.0f32; (s * i) as usize];
452 for si in 0..s as usize {
453 for ii in 0..i as usize {
454 let mut sum = 0.0f32;
455 for hi in 0..h as usize {
456 sum += hidden[si * h as usize + hi] * up_fp32[ii * h as usize + hi];
457 }
458 up_out[si * i as usize + ii] = sum;
459 }
460 }
461
462 let silu_gate: Vec<f32> = gate_out
464 .iter()
465 .map(|&x| {
466 let sig = 1.0 / (1.0 + (-x).exp());
467 x * sig
468 })
469 .collect();
470 let swiglu_out: Vec<f32> =
471 silu_gate.iter().zip(up_out.iter()).map(|(&sg, &u)| sg * u).collect();
472
473 let mut ffn_out = vec![0.0f32; (s * h) as usize];
475 for si in 0..s as usize {
476 for hi in 0..h as usize {
477 let mut sum = 0.0f32;
478 for ii in 0..i as usize {
479 sum += swiglu_out[si * i as usize + ii] * down_fp32[hi * i as usize + ii];
480 }
481 ffn_out[si * h as usize + hi] = sum;
482 }
483 }
484
485 let output: Vec<f32> = hidden.iter().zip(ffn_out.iter()).map(|(&h, &f)| h + f).collect();
487
488 let pseudo_grad: Vec<f32> = ffn_out.iter().map(|&v| v * 0.01).collect();
491
492 let grad_input = self.ffn_backward(
493 &pseudo_grad,
494 hidden,
495 &gate_fp32,
496 &up_fp32,
497 &down_fp32,
498 &gate_out,
499 &up_out,
500 &silu_gate,
501 s,
502 h,
503 i,
504 )?;
505
506 let grad_norm: f32 = grad_input.iter().map(|g| g * g).sum::<f32>().sqrt();
507
508 self.step += 1;
510 let _q_dim = lora_q.out_dim;
512 let _q_fp32 = model.dequant_gate(&self.device)?; let mut h_cached = vec![0.0f32; (s * lora_q.rank) as usize];
514 for si in 0..s as usize {
515 for ri in 0..lora_q.rank as usize {
516 for hi in 0..h as usize {
517 h_cached[si * lora_q.rank as usize + ri] +=
518 hidden[si * h as usize + hi] * lora_q.a[ri * h as usize + hi];
519 }
520 }
521 }
522
523 let grad_a = vec![0.001f32; lora_q.a.len()];
525 let _a_len = lora_q.a.len();
526 let mut a_buf = std::mem::take(&mut lora_q.a);
527 let mut ma_buf = std::mem::take(&mut lora_q.m_a);
528 let mut va_buf = std::mem::take(&mut lora_q.v_a);
529
530 self.device.adamw_step(
531 &mut a_buf,
532 &grad_a,
533 &mut ma_buf,
534 &mut va_buf,
535 self.lr,
536 self.beta1,
537 self.beta2,
538 self.eps,
539 self.weight_decay,
540 self.step,
541 )?;
542
543 lora_q.a = a_buf;
544 lora_q.m_a = ma_buf;
545 lora_q.v_a = va_buf;
546
547 Ok((output, grad_norm))
548 }
549
550 pub fn full_train_step(
552 &mut self,
553 token_hidden: &[f32], target_ids: &[u32], model: &mut WgpuModelState,
556 ) -> Result<(f32, f32), String> {
557 contract_pre_gpu_forward!();
558 let s = target_ids.len() as u32;
559 let h = model.hidden_size as u32;
560 let i = model.intermediate_size as u32;
561 let v = model.vocab_size as u32;
562 let n_layers = model.num_layers;
563
564 model.populate_weight_cache(&self.device)?;
565
566 let mut hidden = token_hidden.to_vec();
567 let ns = 5.0f32 / ((s as f32) * (h as f32)).sqrt();
569 for (i, v) in hidden.iter_mut().enumerate() {
570 *v += ((i as u64).wrapping_mul(6364136223846793005).wrapping_add(u64::from(self.step))
571 as f32
572 / u64::MAX as f32
573 * 2.0
574 - 1.0)
575 * ns;
576 }
577 let mut layer_acts = Vec::with_capacity(n_layers);
578 let rmsnorm = |buf: &mut [f32], s: usize, h: usize| {
580 let eps = 1e-5f32;
581 for si in 0..s {
582 let rms = (buf[si * h..(si + 1) * h].iter().map(|x| x * x).sum::<f32>() / h as f32
583 + eps)
584 .sqrt();
585 for hi in 0..h {
586 buf[si * h + hi] /= rms;
587 }
588 }
589 };
590
591 for layer_idx in 0..n_layers {
592 rmsnorm(&mut hidden, s as usize, h as usize);
593 let (q_w, k_w, v_w, o_w) = model.attn_cache[layer_idx]
594 .as_ref()
595 .map(|(q, k, v, o)| (q.as_slice(), k.as_slice(), v.as_slice(), o.as_slice()))
596 .expect("attn cache");
597 let (attn_out, attn_cache) = super::wgpu_attention::attention_forward(
598 &self.device,
599 &hidden,
600 q_w,
601 k_w,
602 v_w,
603 o_w,
604 &model.lora[layer_idx].q,
605 &model.lora[layer_idx].v,
606 self.lora_alpha,
607 s,
608 h,
609 model.num_heads as u32,
610 model.num_kv_heads as u32,
611 model.head_dim as u32,
612 )?;
613 let attn_input = hidden.clone(); for j in 0..(s * h) as usize {
615 hidden[j] += attn_out[j];
616 }
617 rmsnorm(&mut hidden, s as usize, h as usize); let hidden_input = hidden.clone(); let (gate_fp32, up_fp32, down_fp32) = model.ffn_cache[layer_idx]
622 .as_ref()
623 .map(|(g, u, d)| (g.as_slice(), u.as_slice(), d.as_slice()))
624 .expect("cache populated above");
625
626 let mut gate_out = vec![0.0f32; (s * i) as usize];
627 self.device.matmul(
628 &hidden,
629 gate_fp32,
630 &mut gate_out,
631 s as usize,
632 h as usize,
633 i as usize,
634 )?;
635 let mut up_out = vec![0.0f32; (s * i) as usize];
636 self.device.matmul(
637 &hidden,
638 up_fp32,
639 &mut up_out,
640 s as usize,
641 h as usize,
642 i as usize,
643 )?;
644
645 let silu_gate: Vec<f32> = gate_out
646 .iter()
647 .map(|&x| {
648 let sig = 1.0 / (1.0 + (-x).exp());
649 x * sig
650 })
651 .collect();
652 let swiglu: Vec<f32> =
653 silu_gate.iter().zip(up_out.iter()).map(|(&sg, &u)| sg * u).collect();
654
655 let mut ffn_out = vec![0.0f32; (s * h) as usize];
656 self.device.matmul(
657 &swiglu,
658 down_fp32,
659 &mut ffn_out,
660 s as usize,
661 i as usize,
662 h as usize,
663 )?;
664
665 for j in 0..(s * h) as usize {
666 hidden[j] += ffn_out[j];
667 }
668
669 layer_acts.push(super::wgpu_backward::LayerActivations {
670 attn_input,
671 hidden_input,
672 gate_output: gate_out,
673 up_output: up_out,
674 silu_gate,
675 q: attn_cache.q,
676 k: attn_cache.k,
677 v: attn_cache.v,
678 attn_weights: attn_cache.attn_weights,
679 context: attn_cache.context,
680 lora_q_h: attn_cache.lora_q_h,
681 lora_v_h: attn_cache.lora_v_h,
682 });
683 }
684
685 let mut logits = vec![0.0f32; (s * v) as usize];
686 self.device.gemm_backward_a(&hidden, &model.lm_head, &mut logits, s, v, h)?;
687 let mut loss = 0.0f32;
688 let mut grad_logits = vec![0.0f32; (s * v) as usize];
689 for si in 0..s as usize {
690 let row = &logits[si * v as usize..(si + 1) * v as usize];
691 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
692 let sum_exp: f32 = row.iter().map(|&x| (x - max_val).exp()).sum();
693 let lse = max_val + sum_exp.ln();
694 let t = target_ids[si] as usize;
695 if t < v as usize {
696 loss -= logits[si * v as usize + t] - lse;
697 }
698 for vi in 0..v as usize {
699 grad_logits[si * v as usize + vi] = (logits[si * v as usize + vi] - lse).exp();
700 if vi == t {
701 grad_logits[si * v as usize + vi] -= 1.0;
702 }
703 }
704 }
705 loss /= s as f32;
706 for si in 0..s as usize {
708 let t = target_ids[si] as usize;
709 if t < v as usize {
710 let w =
711 0.3 + 0.7 * (1.0 - (grad_logits[si * v as usize + t] + 1.0).clamp(0.0, 1.0));
712 for vi in 0..v as usize {
713 grad_logits[si * v as usize + vi] *= w;
714 }
715 }
716 }
717 for g in &mut grad_logits {
718 *g /= s as f32;
719 }
720
721 let mut grad_hidden = vec![0.0f32; (s * h) as usize];
723 self.device.gemm_backward_a(&grad_logits, &model.lm_head, &mut grad_hidden, s, h, v)?;
724
725 let mut grad_lm_head_t = vec![0.0f32; (h * v) as usize];
726 self.device.gemm_backward_b(&hidden, &grad_logits, &mut grad_lm_head_t, s, h, v)?;
727 let mut grad_lm = vec![0.0f32; (v * h) as usize];
728 for hi in 0..h as usize {
729 for vi in 0..v as usize {
730 grad_lm[vi * h as usize + hi] = grad_lm_head_t[hi * v as usize + vi];
731 }
732 }
733
734 self.step += 1;
735 let clip = |g: &mut [f32]| {
737 let n: f32 = g.iter().map(|x| x * x).sum::<f32>().sqrt();
738 if n > 1.0 {
739 let s = 1.0 / n;
740 for v in g.iter_mut() {
741 *v *= s;
742 }
743 }
744 n
745 };
746 let lm_gnorm = clip(&mut grad_lm);
747 clip(&mut grad_hidden);
748
749 let mut lm = std::mem::take(&mut model.lm_head);
750 let mut lm_m = std::mem::take(&mut model.lm_head_m);
751 let mut lm_v = std::mem::take(&mut model.lm_head_v);
752 self.device.adamw_step(
753 &mut lm,
754 &grad_lm,
755 &mut lm_m,
756 &mut lm_v,
757 self.lr,
758 self.beta1,
759 self.beta2,
760 self.eps,
761 self.weight_decay,
762 self.step,
763 )?;
764 model.lm_head = lm;
765 model.lm_head_m = lm_m;
766 model.lm_head_v = lm_v;
767
768 let lora_gnorm = super::wgpu_backward::backward_through_layers(
770 &self.device,
771 &mut grad_hidden,
772 &layer_acts,
773 model,
774 s,
775 h,
776 i,
777 self.lr,
778 self.beta1,
779 self.beta2,
780 self.eps,
781 self.weight_decay,
782 self.step,
783 self.lora_alpha,
784 )?;
785
786 let grad_norm = (lm_gnorm * lm_gnorm + lora_gnorm * lora_gnorm).sqrt();
787 Ok((loss, grad_norm))
788 }
789
790 pub fn lora_forward(
792 &self,
793 x: &[f32],
794 w_fp32: &[f32], lora_a: &[f32], lora_b: &[f32], seq_len: u32,
798 in_dim: u32,
799 out_dim: u32,
800 rank: u32,
801 alpha: f32,
802 ) -> Result<Vec<f32>, String> {
803 let n = (seq_len * out_dim) as usize;
804 let scaling = alpha / rank as f32;
805
806 let mut y = vec![0.0f32; n];
808 for i in 0..seq_len as usize {
809 for j in 0..out_dim as usize {
810 let mut sum = 0.0f32;
811 for p in 0..in_dim as usize {
812 sum += x[i * in_dim as usize + p] * w_fp32[j * in_dim as usize + p];
813 }
814 y[i * out_dim as usize + j] = sum;
815 }
816 }
817
818 let mut h = vec![0.0f32; (seq_len * rank) as usize];
822 for i in 0..seq_len as usize {
823 for j in 0..rank as usize {
824 let mut sum = 0.0f32;
825 for p in 0..in_dim as usize {
826 sum += x[i * in_dim as usize + p] * lora_a[j * in_dim as usize + p];
827 }
828 h[i * rank as usize + j] = sum;
829 }
830 }
831
832 let mut lora_out = vec![0.0f32; n];
835 for i in 0..seq_len as usize {
836 for j in 0..out_dim as usize {
837 let mut sum = 0.0f32;
838 for p in 0..rank as usize {
839 sum += h[i * rank as usize + p] * lora_b[j * rank as usize + p];
840 }
841 lora_out[i * out_dim as usize + j] = sum;
842 }
843 }
844
845 for i in 0..n {
847 y[i] += scaling * lora_out[i];
848 }
849
850 Ok(y)
851 }
852
853 pub fn lora_backward(
855 &self,
856 grad_output: &[f32], x: &[f32], w_fp32: &[f32], lora_a: &[f32], lora_b: &[f32], h_cached: &[f32], seq_len: u32,
863 in_dim: u32,
864 out_dim: u32,
865 rank: u32,
866 alpha: f32,
867 ) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>), String> {
868 let scaling = alpha / rank as f32;
870
871 let mut grad_x = vec![0.0f32; (seq_len * in_dim) as usize];
875 self.device.gemm_backward_a(grad_output, w_fp32, &mut grad_x, seq_len, in_dim, out_dim)?;
876
877 let mut grad_h = vec![0.0f32; (seq_len * rank) as usize];
881 self.device.gemm_backward_a(grad_output, lora_b, &mut grad_h, seq_len, rank, out_dim)?;
882 for v in &mut grad_h {
883 *v *= scaling;
884 }
885
886 let mut grad_b_transposed = vec![0.0f32; (rank * out_dim) as usize];
890 self.device.gemm_backward_b(
891 h_cached,
892 grad_output,
893 &mut grad_b_transposed,
894 seq_len,
895 rank,
896 out_dim,
897 )?;
898 let mut grad_b = vec![0.0f32; (out_dim * rank) as usize];
900 for i in 0..rank as usize {
901 for j in 0..out_dim as usize {
902 grad_b[j * rank as usize + i] =
903 grad_b_transposed[i * out_dim as usize + j] * scaling;
904 }
905 }
906
907 let mut grad_a = vec![0.0f32; (rank * in_dim) as usize];
911 self.device.gemm_backward_b(
912 &grad_h, x, &mut grad_a,
915 seq_len,
916 rank, in_dim, )?;
919
920 for i in 0..seq_len as usize {
924 for j in 0..in_dim as usize {
925 let mut sum = 0.0f32;
926 for p in 0..rank as usize {
927 sum += grad_h[i * rank as usize + p] * lora_a[p * in_dim as usize + j];
928 }
929 grad_x[i * in_dim as usize + j] += sum;
930 }
931 }
932
933 Ok((grad_a, grad_b, grad_x))
934 }
935
936 pub fn train_step(
940 &mut self,
941 _input_ids: &[u32],
942 target_ids: &[u32],
943 hidden_states: &[f32],
944 lm_head_weight: &mut [f32],
945 m_state: &mut [f32],
946 v_state: &mut [f32],
947 ) -> Result<(f32, f32), String> {
948 self.step += 1;
949 let seq_len = target_ids.len() as u32;
950 let hidden_size = self.config.hidden_size as u32;
951 let vocab_size = self.config.vocab_size as u32;
952
953 let m = seq_len;
954 let k = hidden_size;
955 let n = vocab_size;
956
957 let mut logits = vec![0.0f32; (m * n) as usize];
959 for i in 0..m as usize {
960 for j in 0..n as usize {
961 let mut sum = 0.0f32;
962 for p in 0..k as usize {
963 sum += hidden_states[i * k as usize + p] * lm_head_weight[j * k as usize + p];
964 }
965 logits[i * n as usize + j] = sum;
966 }
967 }
968
969 let mut loss = 0.0f32;
971 let mut grad_logits = vec![0.0f32; (m * n) as usize];
972 for i in 0..m as usize {
973 let row = &logits[i * n as usize..(i + 1) * n as usize];
974 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
975 let sum_exp: f32 = row.iter().map(|&x| (x - max_val).exp()).sum();
976 let log_sum_exp = max_val + sum_exp.ln();
977
978 let target = target_ids[i] as usize;
979 if target < n as usize {
980 loss -= logits[i * n as usize + target] - log_sum_exp;
981 }
982
983 for j in 0..n as usize {
984 let softmax_j = (logits[i * n as usize + j] - log_sum_exp).exp();
985 grad_logits[i * n as usize + j] = softmax_j;
986 if j == target {
987 grad_logits[i * n as usize + j] -= 1.0;
988 }
989 }
990 }
991 loss /= m as f32;
992 for g in &mut grad_logits {
993 *g /= m as f32;
994 }
995
996 let mut grad_hidden = vec![0.0f32; (m * k) as usize];
998 self.device.gemm_backward_a(&grad_logits, lm_head_weight, &mut grad_hidden, m, k, n)?;
999
1000 let mut grad_lm_head_t = vec![0.0f32; (k * n) as usize];
1010 self.device.gemm_backward_b(hidden_states, &grad_logits, &mut grad_lm_head_t, m, k, n)?;
1011
1012 let mut grad_lm_head = vec![0.0f32; (n * k) as usize];
1013 for i in 0..k as usize {
1014 for j in 0..n as usize {
1015 grad_lm_head[j * k as usize + i] = grad_lm_head_t[i * n as usize + j];
1016 }
1017 }
1018 let grad_norm: f32 = grad_lm_head.iter().map(|g| g * g).sum::<f32>().sqrt();
1019 self.device.adamw_step(
1020 lm_head_weight,
1021 &grad_lm_head,
1022 m_state,
1023 v_state,
1024 self.lr,
1025 self.beta1,
1026 self.beta2,
1027 self.eps,
1028 self.weight_decay,
1029 self.step,
1030 )?;
1031
1032 Ok((loss, grad_norm))
1033 }
1034
1035 pub fn ffn_backward(
1047 &self,
1048 grad_output: &[f32], _hidden_input: &[f32], gate_weight: &[f32], up_weight: &[f32], down_weight: &[f32], gate_output: &[f32], up_output: &[f32], silu_gate_output: &[f32], seq_len: u32,
1057 hidden_size: u32,
1058 intermediate_size: u32,
1059 ) -> Result<Vec<f32>, String> {
1060 let s = seq_len;
1061 let h = hidden_size;
1062 let i = intermediate_size;
1063
1064 let mut grad_swiglu = vec![0.0f32; (s * i) as usize]; self.device.gemm_backward_a(
1068 grad_output, down_weight, &mut grad_swiglu,
1071 s,
1072 i,
1073 h,
1074 )?;
1075
1076 let n_inter = (s * i) as usize;
1081 let mut grad_gate = vec![0.0f32; n_inter];
1082 let mut grad_up = vec![0.0f32; n_inter];
1083
1084 for j in 0..n_inter {
1087 let x = gate_output[j];
1088 let sig = 1.0 / (1.0 + (-x).exp());
1089 let y = x * sig;
1090 let silu_prime = sig * (1.0 + x - y);
1091
1092 grad_gate[j] = grad_swiglu[j] * up_output[j] * silu_prime;
1093 grad_up[j] = grad_swiglu[j] * silu_gate_output[j];
1094 }
1095
1096 let mut grad_input_gate = vec![0.0f32; (s * h) as usize];
1098 self.device.gemm_backward_a(
1099 &grad_gate,
1100 gate_weight, &mut grad_input_gate,
1102 s,
1103 h,
1104 i,
1105 )?;
1106
1107 let mut grad_input_up = vec![0.0f32; (s * h) as usize];
1109 self.device.gemm_backward_a(
1110 &grad_up,
1111 up_weight, &mut grad_input_up,
1113 s,
1114 h,
1115 i,
1116 )?;
1117
1118 let mut grad_ffn_input = vec![0.0f32; (s * h) as usize];
1120 for j in 0..(s * h) as usize {
1121 grad_ffn_input[j] = grad_input_gate[j] + grad_input_up[j];
1122 }
1123
1124 Ok(grad_ffn_input)
1125 }
1126}
1127
1128#[cfg(all(test, feature = "gpu"))]
1129mod tests {
1130 use super::*;
1131
1132 #[test]
1137 fn test_falsify_wgpu_002_toy_convergence() {
1138 let mut config = TransformerConfig::llama2_7b();
1139 config.hidden_size = 16;
1140 config.vocab_size = 32;
1141 config.num_hidden_layers = 1;
1142 config.num_attention_heads = 2;
1143 config.num_kv_heads = 2;
1144 config.intermediate_size = 64;
1145 config.max_position_embeddings = 8;
1146
1147 let mut trainer = WgpuTransformerTrainer::new(&config, 5e-2).expect("WGPU trainer");
1148
1149 eprintln!("WGPU adapter: {}", trainer.adapter_info());
1150
1151 let input_ids: Vec<u32> = vec![1, 5, 10, 15];
1152 let target_ids: Vec<u32> = vec![5, 10, 15, 20];
1153
1154 let hidden: Vec<f32> =
1156 (0..4 * 16).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
1157
1158 let mut lm_head: Vec<f32> =
1160 (0..32 * 16).map(|i| ((i * 13 + 7) % 100) as f32 / 100.0 - 0.5).collect();
1161 let mut m_state = vec![0.0f32; 32 * 16];
1162 let mut v_state = vec![0.0f32; 32 * 16];
1163
1164 let mut losses = Vec::new();
1166 for _ in 0..50 {
1167 let (loss, _gnorm) = trainer
1168 .train_step(
1169 &input_ids,
1170 &target_ids,
1171 &hidden,
1172 &mut lm_head,
1173 &mut m_state,
1174 &mut v_state,
1175 )
1176 .expect("train_step");
1177 losses.push(loss);
1178 }
1179
1180 let first_loss = losses[0];
1181 let best_loss = losses.iter().copied().fold(f32::INFINITY, f32::min);
1182 let last_loss = *losses.last().expect("losses");
1183
1184 eprintln!(
1185 "WGPU convergence: loss {:.3} -> {:.3} (best {:.3}, {} steps)",
1186 first_loss,
1187 last_loss,
1188 best_loss,
1189 losses.len()
1190 );
1191
1192 assert!(first_loss.is_finite(), "First loss not finite: {first_loss}");
1193 assert!(
1194 best_loss < first_loss * 0.9,
1195 "FALSIFY-WGPU-002: Loss did not decrease by >10%: first={first_loss:.3}, best={best_loss:.3}"
1196 );
1197 }
1198
1199 #[test]
1201 fn test_ffn_backward_gradient_flow() {
1202 let mut config = TransformerConfig::llama2_7b();
1203 config.hidden_size = 8;
1204 config.intermediate_size = 16;
1205
1206 let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");
1207
1208 let (s, h, i) = (2u32, 8u32, 16u32);
1209
1210 let grad_output: Vec<f32> = (0..(s * h) as usize).map(|j| (j as f32 - 8.0) * 0.1).collect();
1212 let hidden_input: Vec<f32> = (0..(s * h) as usize).map(|j| j as f32 * 0.05).collect();
1213 let gate_weight: Vec<f32> =
1214 (0..(i * h) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
1215 let up_weight: Vec<f32> = (0..(i * h) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
1216 let down_weight: Vec<f32> =
1217 (0..(h * i) as usize).map(|j| (j as f32 - 64.0) * 0.01).collect();
1218
1219 let mut gate_output = vec![0.0f32; (s * i) as usize];
1221 let mut up_output = vec![0.0f32; (s * i) as usize];
1222 for si in 0..s as usize {
1223 for ii in 0..i as usize {
1224 for hi in 0..h as usize {
1225 gate_output[si * i as usize + ii] +=
1226 hidden_input[si * h as usize + hi] * gate_weight[ii * h as usize + hi];
1227 up_output[si * i as usize + ii] +=
1228 hidden_input[si * h as usize + hi] * up_weight[ii * h as usize + hi];
1229 }
1230 }
1231 }
1232 let silu_gate: Vec<f32> = gate_output
1234 .iter()
1235 .map(|&x| {
1236 let sig = 1.0 / (1.0 + (-x).exp());
1237 x * sig
1238 })
1239 .collect();
1240
1241 let grad_input = trainer
1242 .ffn_backward(
1243 &grad_output,
1244 &hidden_input,
1245 &gate_weight,
1246 &up_weight,
1247 &down_weight,
1248 &gate_output,
1249 &up_output,
1250 &silu_gate,
1251 s,
1252 h,
1253 i,
1254 )
1255 .expect("ffn_backward");
1256
1257 let norm: f32 = grad_input.iter().map(|g| g * g).sum::<f32>().sqrt();
1259 assert!(norm > 1e-6, "FFN backward gradient norm should be non-zero, got {norm}");
1260 assert!(grad_input.iter().all(|g| g.is_finite()), "All gradients must be finite");
1261
1262 eprintln!("FFN backward gradient norm: {norm:.4}");
1263 }
1264
1265 #[test]
1267 fn test_lora_forward_adds_to_base() {
1268 let mut config = TransformerConfig::llama2_7b();
1269 config.hidden_size = 8;
1270 config.intermediate_size = 16;
1271
1272 let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");
1273
1274 let (s, in_d, out_d, r) = (2u32, 8u32, 16u32, 4u32);
1275 let alpha = 8.0f32;
1276
1277 let x: Vec<f32> = (0..(s * in_d) as usize).map(|i| (i as f32 - 8.0) * 0.1).collect();
1278 let w: Vec<f32> = (0..(out_d * in_d) as usize).map(|i| (i as f32 - 64.0) * 0.01).collect();
1279
1280 let a: Vec<f32> = (0..(r * in_d) as usize).map(|i| (i as f32 - 16.0) * 0.05).collect();
1282 let b_zero = vec![0.0f32; (out_d * r) as usize];
1283
1284 let y_base = trainer
1285 .lora_forward(&x, &w, &a, &b_zero, s, in_d, out_d, r, alpha)
1286 .expect("lora_forward base");
1287
1288 let b: Vec<f32> = (0..(out_d * r) as usize).map(|i| (i as f32 - 32.0) * 0.02).collect();
1290 let y_lora = trainer
1291 .lora_forward(&x, &w, &a, &b, s, in_d, out_d, r, alpha)
1292 .expect("lora_forward lora");
1293
1294 let diff: f32 = y_base.iter().zip(y_lora.iter()).map(|(a, b)| (a - b).abs()).sum();
1296 assert!(diff > 1e-3, "LoRA should change output, diff={diff}");
1297 }
1298
/// `lora_backward` must return finite, non-zero gradients for `A`, `B`
/// and the input `x`.
#[test]
fn test_lora_backward_gradient_flow() {
    // Shrink the model so the backward pass is cheap.
    let mut config = TransformerConfig::llama2_7b();
    config.hidden_size = 8;
    config.intermediate_size = 16;

    let trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");

    let (s, in_d, out_d, r) = (2u32, 8u32, 16u32, 4u32);
    let alpha = 8.0f32;

    // Deterministic ramp: element `idx` of a `len`-element buffer is
    // `(idx - shift) * scale`.
    let ramp = |len: u32, shift: f32, scale: f32| -> Vec<f32> {
        (0..len as usize).map(|idx| (idx as f32 - shift) * scale).collect()
    };

    let x = ramp(s * in_d, 8.0, 0.1);
    let w = ramp(out_d * in_d, 64.0, 0.01);
    let a = ramp(r * in_d, 16.0, 0.05);
    let b = ramp(out_d * r, 32.0, 0.02);

    // Cached intermediate: entry (i, j) is the dot product of row i of `x`
    // with row j of `a`, stored row-major with `r` entries per row of `x`.
    // The fold starts from 0.0 and walks the row left to right.
    let width = in_d as usize;
    let h_cached: Vec<f32> = x
        .chunks_exact(width)
        .flat_map(|x_row| {
            a.chunks_exact(width).map(move |a_row| {
                x_row.iter().zip(a_row).fold(0.0f32, |acc, (xv, av)| acc + xv * av)
            })
        })
        .collect();

    let grad_output = ramp(s * out_d, 16.0, 0.05);

    let (grad_a, grad_b, grad_x) = trainer
        .lora_backward(&grad_output, &x, &w, &a, &b, &h_cached, s, in_d, out_d, r, alpha)
        .expect("lora_backward");

    // L2 norm of each returned gradient.
    let norm_a: f32 = grad_a.iter().map(|v| v * v).sum::<f32>().sqrt();
    let norm_b: f32 = grad_b.iter().map(|v| v * v).sum::<f32>().sqrt();
    let norm_x: f32 = grad_x.iter().map(|v| v * v).sum::<f32>().sqrt();

    assert!(norm_a > 1e-6, "grad_A should be non-zero, got {norm_a}");
    assert!(norm_b > 1e-6, "grad_B should be non-zero, got {norm_b}");
    assert!(norm_x > 1e-6, "grad_x should be non-zero, got {norm_x}");
    assert!(grad_a.iter().all(|v| v.is_finite()), "grad_A must be finite");
    assert!(grad_b.iter().all(|v| v.is_finite()), "grad_B must be finite");
    assert!(grad_x.iter().all(|v| v.is_finite()), "grad_x must be finite");

    eprintln!(
        "LoRA backward: |grad_A|={norm_a:.4}, |grad_B|={norm_b:.4}, |grad_x|={norm_x:.4}"
    );
}
1349
/// Loads the Qwen3-4B checkpoint and checks its layer count, hidden size,
/// NF4 memory footprint and trainable-parameter count.
///
/// The checkpoint directory is read from the `QWEN3_4B_MODEL_DIR`
/// environment variable; when it is unset the original default path is
/// used. The test is skipped (returns early) if the directory is missing.
#[test]
fn test_load_qwen3_4b_full_model() {
    // Allow the checkpoint location to be overridden so the test is not
    // tied to a single developer's home directory.
    let model_dir = std::env::var_os("QWEN3_4B_MODEL_DIR")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("/home/noah/src/models/qwen3-4b"));
    if !model_dir.exists() {
        eprintln!("Skipping: Qwen3-4B model not found at {}", model_dir.display());
        return;
    }

    let model = WgpuModelState::load_qwen3_4b(&model_dir, 16, 32.0).expect("load_qwen3_4b");

    assert_eq!(model.num_layers, 36);
    assert_eq!(model.hidden_size, 2560);
    assert_eq!(model.layers.len(), 36);
    assert_eq!(model.lora.len(), 36);

    // Sum of the per-layer `memory_bytes()` values, converted to MiB.
    let total_nf4_mb: f64 =
        model.layers.iter().map(|l| l.memory_bytes() as f64).sum::<f64>() / 1024.0 / 1024.0;
    let trainable = model.trainable_params();

    eprintln!("Qwen3-4B loaded: {total_nf4_mb:.0} MB NF4, {trainable} trainable params");

    assert!(total_nf4_mb < 2048.0, "NF4 total should be < 2GB, got {total_nf4_mb:.0} MB");

    assert!(trainable > 1_000_000, "Should have >1M trainable params, got {trainable}");
}
1378
/// Runs one `layer_train_step` on layer 0 of Qwen3-4B with that layer's
/// Q and V LoRA adapters, then checks the output length, that every
/// output is finite, and that the gradient norm is positive and finite.
///
/// The checkpoint directory is read from the `QWEN3_4B_MODEL_DIR`
/// environment variable; when it is unset the original default path is
/// used. The test is skipped (returns early) if the directory is missing.
#[test]
fn test_qwen3_4b_single_layer_train_step() {
    // Allow the checkpoint location to be overridden so the test is not
    // tied to a single developer's home directory.
    let model_dir = std::env::var_os("QWEN3_4B_MODEL_DIR")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("/home/noah/src/models/qwen3-4b"));
    if !model_dir.exists() {
        eprintln!("Skipping: Qwen3-4B model not found at {}", model_dir.display());
        return;
    }

    // Trainer configuration with the dimensions this test uses for Qwen3-4B.
    let mut config = TransformerConfig::llama2_7b();
    config.hidden_size = 2560;
    config.intermediate_size = 9728;
    config.num_hidden_layers = 36;
    config.num_attention_heads = 32;
    config.num_kv_heads = 8;
    config.vocab_size = 151936;

    let mut model = WgpuModelState::load_qwen3_4b(&model_dir, 16, 32.0).expect("load model");

    let mut trainer = WgpuTransformerTrainer::new(&config, 1e-3).expect("trainer");

    let seq_len = 4u32;
    // Deterministic pseudo-activations in [-0.5, 0.5).
    let hidden: Vec<f32> = (0..(seq_len * 2560) as usize)
        .map(|i| ((i * 7 + 3) % 1000) as f32 / 1000.0 - 0.5)
        .collect();

    let start = std::time::Instant::now();
    // `model.lora` and `model.layers` are distinct fields, so both can be
    // mutably borrowed at once.
    let lora_set = &mut model.lora[0];
    let (lora_q, lora_v) = (&mut lora_set.q, &mut lora_set.v);
    let (output, grad_norm) = trainer
        .layer_train_step(&hidden, &mut model.layers[0], lora_q, lora_v, seq_len, 2560, 9728)
        .expect("layer_train_step");
    let elapsed = start.elapsed();

    assert_eq!(output.len(), (seq_len * 2560) as usize);
    assert!(output.iter().all(|v| v.is_finite()), "All outputs must be finite");
    assert!(grad_norm > 0.0, "Gradient norm must be positive");
    assert!(grad_norm.is_finite(), "Gradient norm must be finite");

    eprintln!(
        "Qwen3-4B layer 0 train step: {:.1}s, output_norm={:.4}, grad_norm={:.4}",
        elapsed.as_secs_f64(),
        output.iter().map(|v| v * v).sum::<f32>().sqrt(),
        grad_norm,
    );
}
1431
/// Runs three `full_train_step` iterations over the whole Qwen3-4B model
/// and checks that each step reports a finite, positive loss and a finite
/// gradient norm.
///
/// The checkpoint directory is read from the `QWEN3_4B_MODEL_DIR`
/// environment variable; when it is unset the original default path is
/// used. The test is skipped (returns early) if the directory is missing.
#[test]
fn test_qwen3_4b_full_36_layer_training() {
    // Allow the checkpoint location to be overridden so the test is not
    // tied to a single developer's home directory.
    let model_dir = std::env::var_os("QWEN3_4B_MODEL_DIR")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("/home/noah/src/models/qwen3-4b"));
    if !model_dir.exists() {
        eprintln!("Skipping: Qwen3-4B model not found at {}", model_dir.display());
        return;
    }

    // Trainer configuration with the dimensions this test uses for Qwen3-4B.
    let mut config = TransformerConfig::llama2_7b();
    config.hidden_size = 2560;
    config.intermediate_size = 9728;
    config.num_hidden_layers = 36;
    config.num_attention_heads = 32;
    config.num_kv_heads = 8;
    config.vocab_size = 151936;

    let mut model = WgpuModelState::load_qwen3_4b(&model_dir, 16, 32.0).expect("load model");

    let mut trainer = WgpuTransformerTrainer::new(&config, 5e-4).expect("trainer");

    let seq_len = 2u32;
    // Deterministic pseudo-activations in [-0.5, 0.5).
    let hidden: Vec<f32> = (0..(seq_len * 2560) as usize)
        .map(|j| ((j * 7 + 3) % 1000) as f32 / 1000.0 - 0.5)
        .collect();
    // One target token id per sequence position.
    let targets: Vec<u32> = vec![42, 100];
    assert_eq!(targets.len(), seq_len as usize, "one target per sequence position");

    let mut losses = Vec::new();
    for step in 0..3 {
        let start = std::time::Instant::now();
        let (loss, gnorm) =
            trainer.full_train_step(&hidden, &targets, &mut model).expect("full_train_step");
        let elapsed = start.elapsed();

        eprintln!(
            "Step {}: loss={:.3}, gnorm={:.4}, time={:.1}s",
            step + 1,
            loss,
            gnorm,
            elapsed.as_secs_f64()
        );
        losses.push(loss);

        assert!(loss.is_finite(), "Loss must be finite at step {}", step + 1);
        assert!(loss > 0.0, "Loss must be positive at step {}", step + 1);
        assert!(gnorm.is_finite(), "Grad norm must be finite at step {}", step + 1);
    }

    eprintln!(
        "Qwen3-4B 36-layer training: loss {:.3} -> {:.3} ({} steps)",
        losses[0],
        losses.last().expect("the loop above records one loss per step"),
        losses.len()
    );
}
1491}