1#[cfg(feature = "gpu")]
12use trueno::backends::gpu::GpuDevice;
13
14#[cfg(feature = "gpu")]
/// NF4-quantized weights for the seven projections of one transformer layer.
///
/// Every projection (`gate`, `up`, `down`, `q`, `k`, `v`, `o`) is described by three fields:
///
/// * `*_packed` – 4-bit NF4 codes packed into `u32` words;
/// * `*_scales` – the `f32` scales that belong to those codes;
/// * `*_n` – the number of `f32` values produced when the projection is dequantized.
///
/// `from_safetensors` fills these from a checkpoint and the `dequant_*` methods expand
/// them back to `f32` on the GPU.
#[derive(Debug, Clone)]
pub struct Nf4LayerWeights {
    /// Packed NF4 codes of the MLP gate projection.
    pub gate_packed: Vec<u32>,
    /// Scales of the MLP gate projection.
    pub gate_scales: Vec<f32>,
    /// Packed NF4 codes of the MLP up projection.
    pub up_packed: Vec<u32>,
    /// Scales of the MLP up projection.
    pub up_scales: Vec<f32>,
    /// Packed NF4 codes of the MLP down projection.
    pub down_packed: Vec<u32>,
    /// Scales of the MLP down projection.
    pub down_scales: Vec<f32>,
    /// Packed NF4 codes of the attention query projection.
    pub q_packed: Vec<u32>,
    /// Scales of the attention query projection.
    pub q_scales: Vec<f32>,
    /// Packed NF4 codes of the attention key projection.
    pub k_packed: Vec<u32>,
    /// Scales of the attention key projection.
    pub k_scales: Vec<f32>,
    /// Packed NF4 codes of the attention value projection.
    pub v_packed: Vec<u32>,
    /// Scales of the attention value projection.
    pub v_scales: Vec<f32>,
    /// Packed NF4 codes of the attention output projection.
    pub o_packed: Vec<u32>,
    /// Scales of the attention output projection.
    pub o_scales: Vec<f32>,
    /// Element count of the dequantized gate projection.
    pub gate_n: u32,
    /// Element count of the dequantized up projection.
    pub up_n: u32,
    /// Element count of the dequantized down projection.
    pub down_n: u32,
    /// Element count of the dequantized query projection.
    pub q_n: u32,
    /// Element count of the dequantized key projection.
    pub k_n: u32,
    /// Element count of the dequantized value projection.
    pub v_n: u32,
    /// Element count of the dequantized output projection.
    pub o_n: u32,
    /// Quantization block size handed to the GPU dequantization call.
    pub block_size: u32,
}
52
53#[cfg(feature = "gpu")]
54impl Nf4LayerWeights {
55 pub fn dequant_gate(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
57 self.dequant_any(&self.gate_packed, &self.gate_scales, self.gate_n, device)
58 }
59 pub fn dequant_up(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
61 self.dequant_any(&self.up_packed, &self.up_scales, self.up_n, device)
62 }
63 pub fn dequant_down(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
65 self.dequant_any(&self.down_packed, &self.down_scales, self.down_n, device)
66 }
67
68 pub fn dequant_q(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
70 self.dequant_any(&self.q_packed, &self.q_scales, self.q_n, device)
71 }
72 pub fn dequant_k(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
74 self.dequant_any(&self.k_packed, &self.k_scales, self.k_n, device)
75 }
76 pub fn dequant_v(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
78 self.dequant_any(&self.v_packed, &self.v_scales, self.v_n, device)
79 }
80 pub fn dequant_o(&self, device: &GpuDevice) -> Result<Vec<f32>, String> {
82 self.dequant_any(&self.o_packed, &self.o_scales, self.o_n, device)
83 }
84
85 fn dequant_any(
86 &self,
87 packed: &[u32],
88 scales: &[f32],
89 n: u32,
90 device: &GpuDevice,
91 ) -> Result<Vec<f32>, String> {
92 let mut output = vec![0.0f32; n as usize];
93 device.nf4_dequant(packed, scales, &mut output, n, self.block_size)?;
94 Ok(output)
95 }
96
97 pub fn memory_bytes(&self) -> usize {
99 let packed_bytes = (self.gate_packed.len()
100 + self.up_packed.len()
101 + self.down_packed.len()
102 + self.q_packed.len()
103 + self.k_packed.len()
104 + self.v_packed.len()
105 + self.o_packed.len())
106 * 4;
107 let scale_bytes = (self.gate_scales.len()
108 + self.up_scales.len()
109 + self.down_scales.len()
110 + self.q_scales.len()
111 + self.k_scales.len()
112 + self.v_scales.len()
113 + self.o_scales.len())
114 * 4;
115 packed_bytes + scale_bytes
116 }
117
118 pub fn quantize_projection_from_tensors(
122 tensors: &safetensors::SafeTensors<'_>,
123 name: &str,
124 rows: usize,
125 cols: usize,
126 ) -> Result<(Vec<u32>, Vec<f32>, u32), String> {
127 quantize_projection(tensors, name, rows, cols)
128 }
129}
130
131#[cfg(feature = "gpu")]
/// The sixteen NF4 levels, indexed by 4-bit code and sorted ascending from -1.0 to 1.0.
#[rustfmt::skip]
const NF4_LUT: [f32; 16] = [
    // Codes 0..=7: the negative levels, ending at exactly zero.
    -1.0,          -0.696_192_8,  -0.525_073_05,  -0.394_917_5,
    -0.284_441_38, -0.184_773_43, -0.091_050_036,  0.0,
    // Codes 8..=15: the positive levels.
     0.079_580_3,   0.160_930_2,   0.246_112_3,    0.337_915_24,
     0.440_709_83,  0.562_617,     0.722_956_84,   1.0,
];

/// Number of consecutive elements that share one scale during quantization.
const NF4_BLOCK_SIZE: usize = 64;
153
154#[cfg(feature = "gpu")]
158fn quantize_to_nf4(values: &[f32]) -> (Vec<u32>, Vec<f32>) {
159 let n = values.len();
160 assert!(n.is_multiple_of(NF4_BLOCK_SIZE), "Length must be divisible by {NF4_BLOCK_SIZE}");
161
162 let num_blocks = n / NF4_BLOCK_SIZE;
163 let mut scales = Vec::with_capacity(num_blocks);
164 let mut packed_bytes = vec![0u8; n / 2]; for block_idx in 0..num_blocks {
167 let start = block_idx * NF4_BLOCK_SIZE;
168 let block = &values[start..start + NF4_BLOCK_SIZE];
169
170 let absmax = block.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
172 let scale = if absmax < 1e-10 { 1.0 } else { absmax };
173 scales.push(scale);
174
175 for (i, &val) in block.iter().enumerate() {
177 let normalized = val / scale;
178 let mut best_idx = 0u8;
179 let mut best_dist = f32::MAX;
180 for (j, &lut_val) in NF4_LUT.iter().enumerate() {
181 let dist = (normalized - lut_val).abs();
182 if dist < best_dist {
183 best_dist = dist;
184 best_idx = j as u8;
185 }
186 }
187 let elem_idx = start + i;
188 let byte_idx = elem_idx / 2;
189 if elem_idx.is_multiple_of(2) {
190 packed_bytes[byte_idx] |= best_idx; } else {
192 packed_bytes[byte_idx] |= best_idx << 4; }
194 }
195 }
196
197 let mut packed = vec![0u32; packed_bytes.len().div_ceil(4)];
199 for (i, &byte) in packed_bytes.iter().enumerate() {
200 packed[i / 4] |= u32::from(byte) << ((i % 4) * 8);
201 }
202
203 (packed, scales)
204}
205
206#[cfg(feature = "gpu")]
210fn quantize_projection(
211 tensors: &safetensors::SafeTensors<'_>,
212 name: &str,
213 rows: usize,
214 cols: usize,
215) -> Result<(Vec<u32>, Vec<f32>, u32), String> {
216 let view = tensors.tensor(name).map_err(|e| format!("Missing tensor {name}: {e}"))?;
217
218 let fp32: Vec<f32> = match view.dtype() {
219 safetensors::Dtype::F16 => view
220 .data()
221 .chunks_exact(2)
222 .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
223 .collect(),
224 safetensors::Dtype::F32 => bytemuck::cast_slice(view.data()).to_vec(),
225 safetensors::Dtype::BF16 => view
226 .data()
227 .chunks_exact(2)
228 .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
229 .collect(),
230 dt => return Err(format!("Unsupported dtype {dt:?} for {name}")),
231 };
232
233 let expected = rows * cols;
234 let mut padded = fp32;
236 if padded.len() != expected {
237 return Err(format!("{name}: expected {expected} elements, got {}", padded.len()));
238 }
239 let remainder = expected % NF4_BLOCK_SIZE;
240 if remainder != 0 {
241 padded.resize(expected + NF4_BLOCK_SIZE - remainder, 0.0);
242 }
243
244 let (packed, scales) = quantize_to_nf4(&padded);
245 Ok((packed, scales, expected as u32))
246}
247
248#[cfg(feature = "gpu")]
249impl Nf4LayerWeights {
250 pub fn from_safetensors(
256 tensors: &safetensors::SafeTensors<'_>,
257 layer_idx: usize,
258 hidden_size: usize,
259 intermediate_size: usize,
260 num_heads: usize,
261 num_kv_heads: usize,
262 head_dim: usize,
263 block_size: u32,
264 ) -> Result<Self, String> {
265 let prefix = format!("model.layers.{layer_idx}");
266 let q_dim = num_heads * head_dim;
267 let kv_dim = num_kv_heads * head_dim;
268
269 let (gate_packed, gate_scales, gate_n) = quantize_projection(
270 tensors,
271 &format!("{prefix}.mlp.gate_proj.weight"),
272 intermediate_size,
273 hidden_size,
274 )?;
275 let (up_packed, up_scales, up_n) = quantize_projection(
276 tensors,
277 &format!("{prefix}.mlp.up_proj.weight"),
278 intermediate_size,
279 hidden_size,
280 )?;
281 let (down_packed, down_scales, down_n) = quantize_projection(
282 tensors,
283 &format!("{prefix}.mlp.down_proj.weight"),
284 hidden_size,
285 intermediate_size,
286 )?;
287 let (q_packed, q_scales, q_n) = quantize_projection(
288 tensors,
289 &format!("{prefix}.self_attn.q_proj.weight"),
290 q_dim,
291 hidden_size,
292 )?;
293 let (k_packed, k_scales, k_n) = quantize_projection(
294 tensors,
295 &format!("{prefix}.self_attn.k_proj.weight"),
296 kv_dim,
297 hidden_size,
298 )?;
299 let (v_packed, v_scales, v_n) = quantize_projection(
300 tensors,
301 &format!("{prefix}.self_attn.v_proj.weight"),
302 kv_dim,
303 hidden_size,
304 )?;
305 let (o_packed, o_scales, o_n) = quantize_projection(
306 tensors,
307 &format!("{prefix}.self_attn.o_proj.weight"),
308 hidden_size,
309 q_dim,
310 )?;
311
312 Ok(Self {
313 gate_packed,
314 gate_scales,
315 up_packed,
316 up_scales,
317 down_packed,
318 down_scales,
319 q_packed,
320 q_scales,
321 k_packed,
322 k_scales,
323 v_packed,
324 v_scales,
325 o_packed,
326 o_scales,
327 gate_n,
328 up_n,
329 down_n,
330 q_n,
331 k_n,
332 v_n,
333 o_n,
334 block_size,
335 })
336 }
337}
338
339#[cfg(feature = "gpu")]
344#[derive(Clone, serde::Serialize, serde::Deserialize)]
345pub struct LoraAdapter {
346 pub a: Vec<f32>,
348 pub b: Vec<f32>,
350 pub m_a: Vec<f32>,
352 pub v_a: Vec<f32>,
354 pub m_b: Vec<f32>,
356 pub v_b: Vec<f32>,
358 pub rank: u32,
360 pub in_dim: u32,
361 pub out_dim: u32,
362}
363
364#[cfg(feature = "gpu")]
365impl LoraAdapter {
366 pub fn new(rank: u32, in_dim: u32, out_dim: u32) -> Self {
368 let a_len = (rank * in_dim) as usize;
369 let b_len = (out_dim * rank) as usize;
370
371 let scale = (2.0 / f64::from(in_dim)).sqrt() as f32;
373 let mut a = vec![0.0f32; a_len];
374 for (i, val) in a.iter_mut().enumerate() {
376 let hash = ((i as u64)
377 .wrapping_mul(6364136223846793005)
378 .wrapping_add(1442695040888963407)) as f32;
379 *val = (hash / u64::MAX as f32 * 2.0 - 1.0) * scale;
380 }
381
382 Self {
383 a,
384 b: vec![0.0f32; b_len], m_a: vec![0.0f32; a_len],
386 v_a: vec![0.0f32; a_len],
387 m_b: vec![0.0f32; b_len],
388 v_b: vec![0.0f32; b_len],
389 rank,
390 in_dim,
391 out_dim,
392 }
393 }
394
395 pub fn num_params(&self) -> usize {
397 self.a.len() + self.b.len()
398 }
399}
400
401#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// A new adapter has correctly sized matrices and an all-zero `b`.
    #[test]
    fn test_lora_adapter_creation() {
        const A_LEN: usize = 16 * 2560;
        const B_LEN: usize = 4096 * 16;

        let adapter = LoraAdapter::new(16, 2560, 4096);

        assert_eq!(adapter.a.len(), A_LEN);
        assert_eq!(adapter.b.len(), B_LEN);
        assert_eq!(adapter.num_params(), A_LEN + B_LEN);
        assert!(adapter.b.iter().all(|&v| v == 0.0));
    }

    /// A synthetic layer with the shapes used below stays under the 100 MB budget.
    #[test]
    fn test_nf4_layer_memory() {
        let hidden: u32 = 2560;
        let intermediate: u32 = 9728;
        let block: u32 = 64;

        // Element counts of the three distinct projection shapes.
        let mlp_n = hidden * intermediate;
        let q_n = hidden * 4096;
        let kv_n = hidden * 1024;

        // Eight 4-bit codes per `u32` word, one scale per block.
        let codes = |elems: u32| vec![0u32; (elems / 8) as usize];
        let scales = |elems: u32| vec![0.0f32; (elems / block) as usize];

        let layer = Nf4LayerWeights {
            gate_packed: codes(mlp_n),
            gate_scales: scales(mlp_n),
            up_packed: codes(mlp_n),
            up_scales: scales(mlp_n),
            down_packed: codes(mlp_n),
            down_scales: scales(mlp_n),
            q_packed: codes(q_n),
            q_scales: scales(q_n),
            k_packed: codes(kv_n),
            k_scales: scales(kv_n),
            v_packed: codes(kv_n),
            v_scales: scales(kv_n),
            o_packed: codes(q_n),
            o_scales: scales(q_n),
            gate_n: mlp_n,
            up_n: mlp_n,
            down_n: mlp_n,
            q_n,
            k_n: kv_n,
            v_n: kv_n,
            o_n: q_n,
            block_size: block,
        };

        let mb = layer.memory_bytes() as f64 / 1024.0 / 1024.0;
        eprintln!("Qwen3-4B NF4 layer: {mb:.1} MB");
        assert!(mb < 100.0, "NF4 layer should be < 100MB, got {mb:.1}");
    }

    /// End-to-end check against a local checkpoint; skipped when the model directory is
    /// absent. Needs a GPU once the checkpoint is found.
    #[test]
    fn test_load_qwen3_4b_layer0_nf4() {
        let model_path = std::path::Path::new("/home/noah/src/models/qwen3-4b");
        if !model_path.exists() {
            eprintln!("Skipping: Qwen3-4B model not found at {}", model_path.display());
            return;
        }

        let data = std::fs::read(model_path.join("model-00001-of-00003.safetensors"))
            .expect("read shard");
        let tensors = safetensors::SafeTensors::deserialize(&data).expect("parse safetensors");

        // layer 0, hidden 2560, intermediate 9728, 32 heads, 8 KV heads, head_dim 128,
        // block size 64.
        let layer = Nf4LayerWeights::from_safetensors(&tensors, 0, 2560, 9728, 32, 8, 128, 64)
            .expect("from_safetensors");

        let mb = layer.memory_bytes() as f64 / 1024.0 / 1024.0;
        eprintln!("Layer 0 NF4: {mb:.1} MB (gate_n={}, q_n={})", layer.gate_n, layer.q_n);

        assert_eq!(layer.gate_n, 2560 * 9728);
        assert_eq!(layer.q_n, 2560 * 4096);
        assert_eq!(layer.k_n, 2560 * 1024);
        assert!(mb < 60.0, "Layer 0 should be < 60MB NF4, got {mb:.1}");

        // Round-trip the gate projection through the GPU dequantizer.
        let device = GpuDevice::new().expect("GPU");
        let gate_fp32 = layer.dequant_gate(&device).expect("dequant_gate");
        let total = gate_fp32.len();
        assert_eq!(total, 2560 * 9728);
        assert!(gate_fp32.iter().all(|v| v.is_finite()), "All dequanted values must be finite");

        let nonzero = gate_fp32.iter().filter(|v| v.abs() > 1e-6).count();
        let pct = nonzero as f64 / total as f64 * 100.0;
        eprintln!("Gate dequant: {nonzero}/{} non-zero ({pct:.1}%)", gate_fp32.len());
        assert!(pct > 50.0, "Most dequanted values should be non-zero, got {pct:.1}%");
    }
}