Skip to main content

entrenar/train/transformer_trainer/
wgpu_nf4.rs

1//! NF4 weight management for WGPU training
2//!
3//! Loads NF4-quantized weights and dequantizes on GPU per-layer.
4//! This keeps total VRAM under 16GB for Qwen3-4B (36 layers).
5//!
6//! # Contract: wgpu-transformer-trainer-v1.yaml (C-WGPU-TRAIN-001)
7//!
8//! - NF4 dequant on GPU matches CPU within ε < 1e-6 (FALSIFY-WGPU-003)
9//! - Per-layer dequant avoids storing all fp32 weights simultaneously
10
11#[cfg(feature = "gpu")]
12use trueno::backends::gpu::GpuDevice;
13
/// One transformer layer's seven linear projections in compact NF4 form.
///
/// Each projection `p` (gate, up, down, q, k, v, o) is described by three fields:
///
/// - `p_packed`: 4-bit codebook indices, eight per `u32` word.
/// - `p_scales`: one `f32` scale per `block_size` consecutive elements.
/// - `p_n`: the number of logical elements (`rows * cols`), excluding any zero
///   padding that was added to fill the final block.
///
/// When built by `from_safetensors`, element `k` (0..8) of each word sits in bits
/// `4k..4k + 4`, and each scale is its block's absolute maximum (1.0 for an
/// all-zero block).
///
/// Approximate size per projection: `n / 2` bytes of packed codes plus
/// `n / block_size * 4` bytes of scales.
#[cfg(feature = "gpu")]
pub struct Nf4LayerWeights {
    // ---- MLP ----
    /// Gate projection `[intermediate, hidden]`: packed NF4 codes
    pub gate_packed: Vec<u32>,
    /// Gate projection: per-block scales
    pub gate_scales: Vec<f32>,
    /// Gate projection: logical element count
    pub gate_n: u32,

    /// Up projection `[intermediate, hidden]`: packed NF4 codes
    pub up_packed: Vec<u32>,
    /// Up projection: per-block scales
    pub up_scales: Vec<f32>,
    /// Up projection: logical element count
    pub up_n: u32,

    /// Down projection `[hidden, intermediate]`: packed NF4 codes
    pub down_packed: Vec<u32>,
    /// Down projection: per-block scales
    pub down_scales: Vec<f32>,
    /// Down projection: logical element count
    pub down_n: u32,

    // ---- Attention ----
    /// Q projection `[num_heads * head_dim, hidden]`: packed NF4 codes
    pub q_packed: Vec<u32>,
    /// Q projection: per-block scales
    pub q_scales: Vec<f32>,
    /// Q projection: logical element count
    pub q_n: u32,

    /// K projection `[num_kv_heads * head_dim, hidden]`: packed NF4 codes
    pub k_packed: Vec<u32>,
    /// K projection: per-block scales
    pub k_scales: Vec<f32>,
    /// K projection: logical element count
    pub k_n: u32,

    /// V projection `[num_kv_heads * head_dim, hidden]`: packed NF4 codes
    pub v_packed: Vec<u32>,
    /// V projection: per-block scales
    pub v_scales: Vec<f32>,
    /// V projection: logical element count
    pub v_n: u32,

    /// O projection `[hidden, num_heads * head_dim]`: packed NF4 codes
    pub o_packed: Vec<u32>,
    /// O projection: per-block scales
    pub o_scales: Vec<f32>,
    /// O projection: logical element count
    pub o_n: u32,

    /// Number of elements sharing one scale (typically 64)
    pub block_size: u32,
}
52
#[cfg(feature = "gpu")]
impl Nf4LayerWeights {
    // Each `dequant_*` accessor returns an error if that projection's buffers are
    // too small for its element count, if `block_size` is zero, or if the GPU
    // dequant itself fails.

    /// Dequantize the gate projection to fp32 on the GPU.
    pub fn dequant_gate(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
        self.dequant_any(&self.gate_packed, &self.gate_scales, self.gate_n, device)
    }

    /// Dequantize the up projection to fp32 on the GPU.
    pub fn dequant_up(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
        self.dequant_any(&self.up_packed, &self.up_scales, self.up_n, device)
    }

    /// Dequantize the down projection to fp32 on the GPU.
    pub fn dequant_down(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
        self.dequant_any(&self.down_packed, &self.down_scales, self.down_n, device)
    }

    /// Dequantize the Q projection to fp32 on the GPU.
    pub fn dequant_q(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
        self.dequant_any(&self.q_packed, &self.q_scales, self.q_n, device)
    }

    /// Dequantize the K projection to fp32 on the GPU.
    pub fn dequant_k(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
        self.dequant_any(&self.k_packed, &self.k_scales, self.k_n, device)
    }

    /// Dequantize the V projection to fp32 on the GPU.
    pub fn dequant_v(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
        self.dequant_any(&self.v_packed, &self.v_scales, self.v_n, device)
    }

    /// Dequantize the O projection to fp32 on the GPU.
    pub fn dequant_o(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
        self.dequant_any(&self.o_packed, &self.o_scales, self.o_n, device)
    }

    /// Shared implementation behind the `dequant_*` accessors.
    ///
    /// The struct's fields are public, so the buffers are checked against `n`
    /// before dispatch. A hand-built or truncated struct then yields an `Err`
    /// instead of undersized buffers being handed to the GPU kernel.
    fn dequant_any(
        &self,
        packed: &[u32],
        scales: &[f32],
        n: u32,
        device: &GpuDevice,
    ) -> Result<Vec<f32>, String> {
        if self.block_size == 0 {
            return Err("NF4 block_size must be non-zero".to_string());
        }
        let len = n as usize;
        if len == 0 {
            // Nothing to decode; avoid dispatching a zero-sized GPU job.
            return Ok(Vec::new());
        }

        // Eight 4-bit codes per u32 word.
        let words_needed = len.div_ceil(8);
        if packed.len() < words_needed {
            return Err(format!(
                "NF4 packed buffer has {} u32 words, need {words_needed} for {n} elements",
                packed.len()
            ));
        }
        let scales_needed = len.div_ceil(self.block_size as usize);
        if scales.len() < scales_needed {
            return Err(format!(
                "NF4 scale buffer has {} entries, need {scales_needed} for {n} elements (block size {})",
                scales.len(),
                self.block_size
            ));
        }

        let mut output = vec![0.0f32; len];
        device.nf4_dequant(packed, scales, &mut output, n, self.block_size)?;
        Ok(output)
    }

    /// Host memory held by the packed codes and scales, in bytes.
    ///
    /// Counts `len()` only, so any unused `Vec` capacity and the struct itself
    /// are not included.
    pub fn memory_bytes(&self) -> usize {
        let packed_words: usize = [
            &self.gate_packed,
            &self.up_packed,
            &self.down_packed,
            &self.q_packed,
            &self.k_packed,
            &self.v_packed,
            &self.o_packed,
        ]
        .iter()
        .map(|v| v.len())
        .sum();
        let scale_count: usize = [
            &self.gate_scales,
            &self.up_scales,
            &self.down_scales,
            &self.q_scales,
            &self.k_scales,
            &self.v_scales,
            &self.o_scales,
        ]
        .iter()
        .map(|v| v.len())
        .sum();
        packed_words * std::mem::size_of::<u32>() + scale_count * std::mem::size_of::<f32>()
    }

    /// Quantize a single projection from pre-parsed safetensors.
    ///
    /// Returns `(packed, scales, n_elements)` in the same form as this struct's
    /// per-projection fields. Public so the model loader can call it per shard.
    ///
    /// # Errors
    ///
    /// Returns an error if the tensor is missing, has an unsupported dtype, or
    /// does not contain exactly `rows * cols` elements.
    pub fn quantize_projection_from_tensors(
        tensors: &safetensors::SafeTensors<'_>,
        name: &str,
        rows: usize,
        cols: usize,
    ) -> Result<(Vec<u32>, Vec<f32>, u32), String> {
        quantize_projection(tensors, name, rows, cols)
    }
}
130
/// The 16-entry NF4 codebook, in ascending order from -1.0 to 1.0.
///
/// A 4-bit code `i` decodes to `NF4_LUT[i] * block_scale`. These values must
/// stay identical to `trueno::quantize::NF4_LUT`, which the GPU side uses.
#[cfg(feature = "gpu")]
const NF4_LUT: [f32; 16] = [
    // codes 0..=3
    -1.0, -0.696_192_8, -0.525_073_05, -0.394_917_5,
    // codes 4..=7 (code 7 is exactly zero)
    -0.284_441_38, -0.184_773_43, -0.091_050_036, 0.0,
    // codes 8..=11
    0.079_580_3, 0.160_930_2, 0.246_112_3, 0.337_915_24,
    // codes 12..=15
    0.440_709_83, 0.562_617, 0.722_956_84, 1.0,
];
151
/// Number of consecutive weights that share one absmax scale.
///
/// The quantizer in this module always uses this block size, so the
/// `block_size` recorded in [`Nf4LayerWeights`] must equal it for the GPU
/// dequant to apply the right scale to each element.
#[cfg(feature = "gpu")]
const NF4_BLOCK_SIZE: usize = 64;
153
/// Quantize fp32 values to NF4: 4-bit codes packed into `u32` words plus per-block scales.
///
/// Each block of [`NF4_BLOCK_SIZE`] values is divided by its absolute maximum.
/// Each scaled value is then replaced by the index of the nearest [`NF4_LUT`]
/// entry, with ties going to the lower index. If a block's absmax is below
/// `1e-10`, its scale is `1.0` and every finite value in it encodes as code 7
/// (0.0).
///
/// Codes are packed eight per word: element `k` of a word (0..8) occupies bits
/// `4k..4k + 4`, so the lowest nibble holds the earliest element.
///
/// Returns `(packed, scales)`, with `packed.len() == values.len() / 8` and
/// `scales.len() == values.len() / NF4_BLOCK_SIZE`. The element count is not
/// returned; callers track it themselves.
///
/// Non-finite inputs are not rejected. NaN is ignored when computing a block's
/// absmax and encodes as code 0 (`-1.0 * scale`).
///
/// # Panics
///
/// Panics if `values.len()` is not a multiple of [`NF4_BLOCK_SIZE`].
#[cfg(feature = "gpu")]
fn quantize_to_nf4(values: &[f32]) -> (Vec<u32>, Vec<f32>) {
    let n = values.len();
    assert!(n.is_multiple_of(NF4_BLOCK_SIZE), "Length must be divisible by {NF4_BLOCK_SIZE}");

    let mut scales = Vec::with_capacity(n / NF4_BLOCK_SIZE);
    // Eight 4-bit codes per u32; `n` is a multiple of 64, so this is exact.
    let mut packed = vec![0u32; n.div_ceil(8)];

    for (block_idx, block) in values.chunks_exact(NF4_BLOCK_SIZE).enumerate() {
        let absmax = block.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        // Avoid dividing by ~0; tiny values all round to the 0.0 code anyway.
        let scale = if absmax < 1e-10 { 1.0 } else { absmax };
        scales.push(scale);

        let start = block_idx * NF4_BLOCK_SIZE;
        for (offset, &val) in block.iter().enumerate() {
            let elem_idx = start + offset;
            let code = u32::from(nearest_nf4_code(val / scale));
            packed[elem_idx / 8] |= code << ((elem_idx % 8) * 4);
        }
    }

    (packed, scales)
}

/// Index of the [`NF4_LUT`] entry closest to `normalized`.
///
/// Ties resolve to the lower index. NaN compares false against every distance,
/// so it resolves to index 0.
#[cfg(feature = "gpu")]
fn nearest_nf4_code(normalized: f32) -> u8 {
    let mut best_idx = 0u8;
    let mut best_dist = f32::MAX;
    for (j, &lut_val) in NF4_LUT.iter().enumerate() {
        let dist = (normalized - lut_val).abs();
        if dist < best_dist {
            best_dist = dist;
            best_idx = j as u8;
        }
    }
    best_idx
}
205
/// Load one projection from safetensors, convert it to fp32 and quantize it to NF4.
///
/// The tensor must contain exactly `rows * cols` elements. The flattened data is
/// zero-padded up to a multiple of [`NF4_BLOCK_SIZE`] before quantization; the
/// returned element count is the unpadded `rows * cols`.
///
/// Returns `(packed, scales, n_elements)` in the layout produced by
/// [`quantize_to_nf4`].
///
/// # Errors
///
/// Returns an error if:
/// - the tensor is missing;
/// - its dtype is not F16/BF16/F32;
/// - its element count differs from `rows * cols`;
/// - `rows * cols` overflows;
/// - the element count does not fit in `u32`.
///
/// # Contract (FALSIFY-WGPU-003)
#[cfg(feature = "gpu")]
fn quantize_projection(
    tensors: &safetensors::SafeTensors<'_>,
    name: &str,
    rows: usize,
    cols: usize,
) -> Result<(Vec<u32>, Vec<f32>, u32), String> {
    let view = tensors.tensor(name).map_err(|e| format!("Missing tensor {name}: {e}"))?;

    let bytes = view.data();
    // Decode element-wise from little-endian bytes rather than reinterpreting the
    // slice. A tensor's byte range is not guaranteed to be 4-byte aligned in
    // memory, and `bytemuck::cast_slice` panics on misaligned input.
    let mut fp32: Vec<f32> = match view.dtype() {
        safetensors::Dtype::F16 => bytes
            .chunks_exact(2)
            .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
            .collect(),
        safetensors::Dtype::F32 => bytes
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
        safetensors::Dtype::BF16 => bytes
            .chunks_exact(2)
            .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
            .collect(),
        dt => return Err(format!("Unsupported dtype {dt:?} for {name}")),
    };

    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| format!("{name}: {rows} x {cols} elements overflows usize"))?;
    if fp32.len() != expected {
        return Err(format!("{name}: expected {expected} elements, got {}", fp32.len()));
    }
    // The element count is stored as u32; refuse rather than silently truncate.
    let n_elements = u32::try_from(expected)
        .map_err(|_| format!("{name}: {expected} elements exceed u32::MAX"))?;

    // Zero-pad to a whole number of NF4 blocks (no-op when already aligned).
    fp32.resize(expected.next_multiple_of(NF4_BLOCK_SIZE), 0.0);

    let (packed, scales) = quantize_to_nf4(&fp32);
    Ok((packed, scales, n_elements))
}
247
#[cfg(feature = "gpu")]
impl Nf4LayerWeights {
    /// Load a single transformer layer's weights from safetensors as NF4.
    ///
    /// Reads `model.layers.{layer_idx}.mlp.{gate,up,down}_proj.weight` and
    /// `model.layers.{layer_idx}.self_attn.{q,k,v,o}_proj.weight`, in that order.
    ///
    /// # Errors
    ///
    /// Returns an error if `block_size` differs from the block size this module
    /// quantizes with ([`NF4_BLOCK_SIZE`]). It also returns an error if any
    /// projection is missing, has an unsupported dtype, or has an unexpected
    /// element count.
    ///
    /// # Contract (FALSIFY-WGPU-003)
    ///
    /// NF4 dequant of loaded weights matches original fp32 within quantization error.
    pub fn from_safetensors(
        tensors: &safetensors::SafeTensors<'_>,
        layer_idx: usize,
        hidden_size: usize,
        intermediate_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        block_size: u32,
    ) -> Result<Self, String> {
        // The quantizer always uses NF4_BLOCK_SIZE; `block_size` is only recorded
        // for the GPU dequant. A mismatch would silently decode every element with
        // the wrong block's scale, so reject it up front.
        if block_size as usize != NF4_BLOCK_SIZE {
            return Err(format!(
                "NF4 block_size {block_size} is not supported: weights are quantized with block size {NF4_BLOCK_SIZE}"
            ));
        }

        let prefix = format!("model.layers.{layer_idx}");
        let q_dim = num_heads * head_dim;
        let kv_dim = num_kv_heads * head_dim;
        let load = |module: &str, rows: usize, cols: usize| {
            quantize_projection(tensors, &format!("{prefix}.{module}.weight"), rows, cols)
        };

        let (gate_packed, gate_scales, gate_n) =
            load("mlp.gate_proj", intermediate_size, hidden_size)?;
        let (up_packed, up_scales, up_n) = load("mlp.up_proj", intermediate_size, hidden_size)?;
        let (down_packed, down_scales, down_n) =
            load("mlp.down_proj", hidden_size, intermediate_size)?;
        let (q_packed, q_scales, q_n) = load("self_attn.q_proj", q_dim, hidden_size)?;
        let (k_packed, k_scales, k_n) = load("self_attn.k_proj", kv_dim, hidden_size)?;
        let (v_packed, v_scales, v_n) = load("self_attn.v_proj", kv_dim, hidden_size)?;
        let (o_packed, o_scales, o_n) = load("self_attn.o_proj", hidden_size, q_dim)?;

        Ok(Self {
            gate_packed,
            gate_scales,
            up_packed,
            up_scales,
            down_packed,
            down_scales,
            q_packed,
            q_scales,
            k_packed,
            k_scales,
            v_packed,
            v_scales,
            o_packed,
            o_scales,
            gate_n,
            up_n,
            down_n,
            q_n,
            k_n,
            v_n,
            o_n,
            block_size,
        })
    }
}
338
/// LoRA adapter pair for a single projection (rank-r), plus its AdamW state.
///
/// `A` has shape `[rank, in_dim]` and `B` has shape `[out_dim, rank]`. For an
/// input row `x` of length `in_dim`, the adapted forward pass is
///
/// `y = x @ W^T + x @ A^T @ B^T`
///
/// so `A` projects the input down to `rank` and `B` projects it back up to
/// `out_dim`. Backward: gradients flow through `B` and `A`; the frozen base
/// weight `W` is not updated.
#[cfg(feature = "gpu")]
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct LoraAdapter {
    // NOTE: the serde derives encode fields in declaration order. Reordering them
    // may break previously saved adapters in order-sensitive formats.
    /// A matrix `[rank, in_dim]` — fp32, trainable
    pub a: Vec<f32>,
    /// B matrix `[out_dim, rank]` — fp32, trainable
    pub b: Vec<f32>,
    /// AdamW first moment for A
    pub m_a: Vec<f32>,
    /// AdamW second moment for A
    pub v_a: Vec<f32>,
    /// AdamW first moment for B
    pub m_b: Vec<f32>,
    /// AdamW second moment for B
    pub v_b: Vec<f32>,
    /// LoRA rank `r`
    pub rank: u32,
    /// Input feature dimension (columns of `A`)
    pub in_dim: u32,
    /// Output feature dimension (rows of `B`)
    pub out_dim: u32,
}
363
#[cfg(feature = "gpu")]
impl LoraAdapter {
    /// Create a new LoRA adapter with a deterministic uniform `A` and an all-zero `B`.
    ///
    /// Entries of `A` lie in `[-s, s]` with `s = sqrt(2 / in_dim)`. They come from
    /// an affine hash of the element index (wrapping mod 2^64), so the same
    /// dimensions always produce the same `A`. This is not a statistically strong
    /// RNG.
    ///
    /// Because `B` is all zeros, the adapter's update `x @ A^T @ B^T` starts at
    /// zero. All AdamW moment buffers also start at zero.
    pub fn new(rank: u32, in_dim: u32, out_dim: u32) -> Self {
        // Multiply in usize: the u32 product can overflow for large rank * dim.
        let a_len = rank as usize * in_dim as usize;
        let b_len = out_dim as usize * rank as usize;

        let scale = (2.0 / f64::from(in_dim)).sqrt() as f32;
        let a = (0..a_len)
            .map(|i| {
                let hash = ((i as u64)
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)) as f32;
                // Map [0, 2^64] onto [-1, 1], then scale.
                (hash / u64::MAX as f32 * 2.0 - 1.0) * scale
            })
            .collect();

        Self {
            a,
            // Zero B makes the adapter's initial contribution zero.
            b: vec![0.0f32; b_len],
            m_a: vec![0.0f32; a_len],
            v_a: vec![0.0f32; a_len],
            m_b: vec![0.0f32; b_len],
            v_b: vec![0.0f32; b_len],
            rank,
            in_dim,
            out_dim,
        }
    }

    /// Total trainable parameters (`A` plus `B`).
    pub fn num_params(&self) -> usize {
        self.a.len() + self.b.len()
    }
}
400
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// Default location of the local Qwen3-4B checkpoint used by the integration test.
    const DEFAULT_QWEN3_4B_PATH: &str = "/home/noah/src/models/qwen3-4b";

    /// Resolve the Qwen3-4B checkpoint directory.
    ///
    /// `ENTRENAR_QWEN3_4B_PATH` overrides the default, so the test can run on other
    /// machines.
    fn qwen3_4b_path() -> std::path::PathBuf {
        std::env::var_os("ENTRENAR_QWEN3_4B_PATH").map_or_else(
            || std::path::PathBuf::from(DEFAULT_QWEN3_4B_PATH),
            std::path::PathBuf::from,
        )
    }

    /// CPU reference dequantizer for the layout written by `quantize_to_nf4`.
    ///
    /// Codes are packed eight per `u32`, with element `k` of a word in bits
    /// `4k..4k + 4`.
    fn cpu_dequant(packed: &[u32], scales: &[f32], n: usize, block_size: usize) -> Vec<f32> {
        (0..n)
            .map(|e| {
                let code = (packed[e / 8] >> ((e % 8) * 4)) & 0xF;
                NF4_LUT[code as usize] * scales[e / block_size]
            })
            .collect()
    }

    #[test]
    fn test_lora_adapter_creation() {
        let adapter = LoraAdapter::new(16, 2560, 4096);
        assert_eq!(adapter.a.len(), 16 * 2560);
        assert_eq!(adapter.b.len(), 4096 * 16);
        assert_eq!(adapter.num_params(), 16 * 2560 + 4096 * 16);
        // B starts at zero, so the adapter's initial contribution is zero.
        assert!(adapter.b.iter().all(|&v| v == 0.0));
        // Optimizer state starts empty.
        assert!(adapter
            .m_a
            .iter()
            .chain(&adapter.v_a)
            .chain(&adapter.m_b)
            .chain(&adapter.v_b)
            .all(|&v| v == 0.0));
        // A is bounded by the init scale and is deterministic.
        let bound = (2.0f64 / 2560.0).sqrt() as f32;
        assert!(adapter.a.iter().all(|v| v.abs() <= bound));
        assert_eq!(adapter.a, LoraAdapter::new(16, 2560, 4096).a);
    }

    #[test]
    fn test_quantize_to_nf4_all_zero_block() {
        let (packed, scales) = quantize_to_nf4(&[0.0; NF4_BLOCK_SIZE]);
        assert_eq!(scales, vec![1.0]);
        assert_eq!(packed.len(), NF4_BLOCK_SIZE / 8);
        // Code 7 is the 0.0 entry of the codebook.
        assert!(packed.iter().all(|&w| w == 0x7777_7777));
    }

    #[test]
    fn test_quantize_to_nf4_cpu_roundtrip() {
        let values: Vec<f32> = (0..4 * NF4_BLOCK_SIZE)
            .map(|i| (i as f32 * 0.37).sin() * (1 + i / NF4_BLOCK_SIZE) as f32)
            .collect();
        let (packed, scales) = quantize_to_nf4(&values);
        assert_eq!(packed.len(), values.len() / 8);
        assert_eq!(scales.len(), values.len() / NF4_BLOCK_SIZE);

        let deq = cpu_dequant(&packed, &scales, values.len(), NF4_BLOCK_SIZE);
        // The widest codebook gap (-1.0 .. -0.696) bounds rounding error at ~0.152 * scale.
        for (i, (&orig, &got)) in values.iter().zip(&deq).enumerate() {
            let scale = scales[i / NF4_BLOCK_SIZE];
            assert!((orig - got).abs() <= 0.16 * scale, "element {i}: {orig} vs {got}");
        }
        // Each block's absmax element maps exactly to +/-1.0 * scale.
        for (block, &scale) in deq.chunks_exact(NF4_BLOCK_SIZE).zip(&scales) {
            let max_abs = block.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
            assert_eq!(max_abs, scale);
        }
    }

    #[test]
    fn test_nf4_layer_memory() {
        // Simulate one Qwen3-4B layer (MLP + attention projections)
        let h: u32 = 2560;
        let i: u32 = 9728;
        let bs: u32 = 64;

        let layer = Nf4LayerWeights {
            gate_packed: vec![0u32; (h * i / 8) as usize], // 4 bits per param, 8 params per u32
            gate_scales: vec![0.0f32; (h * i / bs) as usize],
            up_packed: vec![0u32; (h * i / 8) as usize],
            up_scales: vec![0.0f32; (h * i / bs) as usize],
            down_packed: vec![0u32; (i * h / 8) as usize],
            down_scales: vec![0.0f32; (i * h / bs) as usize],
            q_packed: vec![0u32; (h * 4096 / 8) as usize],
            q_scales: vec![0.0f32; (h * 4096 / bs) as usize],
            k_packed: vec![0u32; (h * 1024 / 8) as usize],
            k_scales: vec![0.0f32; (h * 1024 / bs) as usize],
            v_packed: vec![0u32; (h * 1024 / 8) as usize],
            v_scales: vec![0.0f32; (h * 1024 / bs) as usize],
            o_packed: vec![0u32; (4096 * h / 8) as usize],
            o_scales: vec![0.0f32; (4096 * h / bs) as usize],
            gate_n: h * i,
            up_n: h * i,
            down_n: i * h,
            q_n: h * 4096,
            k_n: h * 1024,
            v_n: h * 1024,
            o_n: 4096 * h,
            block_size: bs,
        };

        let mb = layer.memory_bytes() as f64 / 1024.0 / 1024.0;
        eprintln!("Qwen3-4B NF4 layer: {mb:.1} MB");
        assert!(mb < 100.0, "NF4 layer should be < 100MB, got {mb:.1}");
    }

    /// Load Qwen3-4B layer 0 from safetensors and quantize to NF4
    ///
    /// # Contract (FALSIFY-WGPU-003): NF4 round-trip preserves relative accuracy,
    /// and GPU dequant matches the CPU reference.
    #[test]
    fn test_load_qwen3_4b_layer0_nf4() {
        let model_path = qwen3_4b_path();
        if !model_path.exists() {
            eprintln!("Skipping: Qwen3-4B model not found at {}", model_path.display());
            return;
        }

        // Load first shard
        let shard_path = model_path.join("model-00001-of-00003.safetensors");
        let data = std::fs::read(&shard_path).expect("read shard");
        let tensors = safetensors::SafeTensors::deserialize(&data).expect("parse safetensors");

        let layer = Nf4LayerWeights::from_safetensors(
            &tensors, 0,    // layer 0
            2560, // hidden_size
            9728, // intermediate_size
            32,   // num_heads
            8,    // num_kv_heads
            128,  // head_dim
            64,   // block_size
        )
        .expect("from_safetensors");

        let mb = layer.memory_bytes() as f64 / 1024.0 / 1024.0;
        eprintln!("Layer 0 NF4: {mb:.1} MB (gate_n={}, q_n={})", layer.gate_n, layer.q_n);

        assert_eq!(layer.gate_n, 2560 * 9728);
        assert_eq!(layer.q_n, 2560 * 4096);
        assert_eq!(layer.k_n, 2560 * 1024);
        assert!(mb < 60.0, "Layer 0 should be < 60MB NF4, got {mb:.1}");

        // Verify dequant round-trip on GPU
        let device = GpuDevice::new().expect("GPU");
        let gate_fp32 = layer.dequant_gate(&device).expect("dequant_gate");
        assert_eq!(gate_fp32.len(), (2560 * 9728) as usize);
        assert!(gate_fp32.iter().all(|v| v.is_finite()), "All dequanted values must be finite");

        // Check non-trivial values (not all zero)
        let nonzero = gate_fp32.iter().filter(|&&v| v.abs() > 1e-6).count();
        let pct = nonzero as f64 / gate_fp32.len() as f64 * 100.0;
        eprintln!("Gate dequant: {nonzero}/{} non-zero ({pct:.1}%)", gate_fp32.len());
        assert!(pct > 50.0, "Most dequanted values should be non-zero, got {pct:.1}%");

        // FALSIFY-WGPU-003: GPU dequant must agree with the CPU reference decoder.
        let checked = gate_fp32.len().min(1 << 16);
        let cpu = cpu_dequant(
            &layer.gate_packed,
            &layer.gate_scales,
            checked,
            layer.block_size as usize,
        );
        for (idx, (&gpu_v, &cpu_v)) in gate_fp32.iter().zip(&cpu).enumerate() {
            assert!((gpu_v - cpu_v).abs() < 1e-6, "gate[{idx}]: gpu {gpu_v} vs cpu {cpu_v}");
        }
    }
}