// trueno/inference/model.rs
1//! Llama-family transformer model for inference.
2//!
3//! Composes trueno primitives (rms_norm, Q4K matmul, fused attention)
4//! into a complete transformer that loads GGUF weights and generates text.
5
6use crate::backends::q4k::matmul_q4k_f32_dispatch;
7use crate::blis::attention::fused_attention_decode;
8use crate::blis::norms::rms_norm;
9use crate::error::TruenoError;
10use crate::inference::gguf::{GgmlType, GgufFile};
11
/// Model hyperparameters extracted from GGUF metadata.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Residual-stream / embedding width (GGUF `<arch>.embedding_length`).
    pub hidden_size: usize,
    /// FFN inner width (GGUF `<arch>.feed_forward_length`).
    pub intermediate_size: usize,
    /// Number of transformer blocks (GGUF `<arch>.block_count`).
    pub num_layers: usize,
    /// Query head count (GGUF `<arch>.attention.head_count`).
    pub num_heads: usize,
    /// Key/value head count; defaults to `num_heads` when the
    /// `head_count_kv` key is absent (i.e. plain MHA, no GQA).
    pub num_kv_heads: usize,
    /// Per-head dimension, derived as `hidden_size / num_heads`.
    pub head_dim: usize,
    /// Vocabulary size (explicit key, tokens-array length, or 32000 fallback).
    pub vocab_size: usize,
    /// Epsilon for RMS normalization (defaults to 1e-5 if absent).
    pub rms_norm_eps: f32,
    /// RoPE frequency base (defaults to 10000.0 if absent).
    pub rope_theta: f32,
    /// Context length the KV cache is sized for (defaults to 2048).
    pub max_seq_len: usize,
    /// Architecture string from `general.architecture`, e.g. "llama", "qwen2".
    /// Also used as the metadata key prefix.
    pub arch: String,
}
27
impl ModelConfig {
    /// Extract config from GGUF metadata.
    ///
    /// GGUF namespaces hyperparameter keys by architecture (e.g.
    /// `llama.embedding_length`, `qwen2.block_count`), so the
    /// `general.architecture` string doubles as the key prefix.
    ///
    /// # Errors
    /// Returns `TruenoError::InvalidInput` if any of the mandatory keys
    /// (`embedding_length`, `attention.head_count`, `block_count`,
    /// `feed_forward_length`) is missing. All other fields fall back to
    /// sensible defaults.
    pub fn from_gguf(gguf: &GgufFile) -> Result<Self, TruenoError> {
        let arch = gguf.meta_str("general.architecture").unwrap_or("llama").to_string();
        let prefix = &arch; // e.g., "llama" or "qwen2"

        let hidden_size = gguf
            .meta_u32(&format!("{prefix}.embedding_length"))
            .ok_or_else(|| TruenoError::InvalidInput("Missing embedding_length in GGUF".into()))?
            as usize;

        let num_heads = gguf
            .meta_u32(&format!("{prefix}.attention.head_count"))
            .ok_or_else(|| TruenoError::InvalidInput("Missing head_count in GGUF".into()))?
            as usize;

        // Absent head_count_kv means no grouped-query attention: kv heads == q heads.
        let num_kv_heads = gguf
            .meta_u32(&format!("{prefix}.attention.head_count_kv"))
            .unwrap_or(num_heads as u32) as usize;

        let num_layers = gguf
            .meta_u32(&format!("{prefix}.block_count"))
            .ok_or_else(|| TruenoError::InvalidInput("Missing block_count in GGUF".into()))?
            as usize;

        let intermediate_size =
            gguf.meta_u32(&format!("{prefix}.feed_forward_length")).ok_or_else(|| {
                TruenoError::InvalidInput("Missing feed_forward_length in GGUF".into())
            })? as usize;

        // Derived; assumes hidden_size is an exact multiple of num_heads
        // (integer division truncates otherwise).
        let head_dim = hidden_size / num_heads;

        // vocab_size: explicit key first, then the tokens array length,
        // finally the historical LLaMA default of 32000.
        let vocab_size = gguf
            .meta_u32("tokenizer.ggml.vocab_size")
            .or_else(|| {
                // Fallback: count tokens array
                gguf.metadata.get("tokenizer.ggml.tokens").and_then(|v| {
                    if let crate::inference::gguf::MetadataValue::Array(arr) = v {
                        Some(arr.len() as u32)
                    } else {
                        None
                    }
                })
            })
            .unwrap_or(32000) as usize;

        let rms_norm_eps =
            gguf.meta_f32(&format!("{prefix}.attention.layer_norm_rms_epsilon")).unwrap_or(1e-5);

        let rope_theta = gguf.meta_f32(&format!("{prefix}.rope.freq_base")).unwrap_or(10000.0);

        let max_seq_len =
            gguf.meta_u32(&format!("{prefix}.context_length")).unwrap_or(2048) as usize;

        Ok(Self {
            hidden_size,
            intermediate_size,
            num_layers,
            num_heads,
            num_kv_heads,
            head_dim,
            vocab_size,
            rms_norm_eps,
            rope_theta,
            max_seq_len,
            arch,
        })
    }
}
97
/// A weight matrix that may be Q4K (bytes) or any-other-quant dequantized to F32.
pub enum WeightMatrix {
    /// Raw Q4K bytes, consumed directly by `matmul_q4k_f32_dispatch`.
    /// `rows` is the output dimension of the projection.
    Q4K { data: Vec<u8>, rows: usize },
    /// Dequantized F32, row-major `[rows, in_dim]` — multiplied with a
    /// scalar per-row dot-product loop.
    F32 { data: Vec<f32>, rows: usize },
}
105
106impl WeightMatrix {
107    pub fn rows(&self) -> usize {
108        match self {
109            WeightMatrix::Q4K { rows, .. } => *rows,
110            WeightMatrix::F32 { rows, .. } => *rows,
111        }
112    }
113}
114
/// Weights for a single transformer layer.
pub struct LayerWeights {
    // Attention
    /// RMS-norm gain applied before the attention block ([hidden_size]).
    pub attn_norm: Vec<f32>,
    /// Query projection: hidden_size → hidden_size.
    pub q_weight: WeightMatrix,
    /// Key projection: hidden_size → num_kv_heads * head_dim.
    pub k_weight: WeightMatrix,
    /// Value projection: hidden_size → num_kv_heads * head_dim.
    pub v_weight: WeightMatrix,
    /// Attention output projection: hidden_size → hidden_size.
    pub o_weight: WeightMatrix,
    // Qwen2/Qwen3 biases (None for LLaMA)
    pub q_bias: Option<Vec<f32>>,
    pub k_bias: Option<Vec<f32>>,
    pub v_bias: Option<Vec<f32>>,

    // FFN
    /// RMS-norm gain applied before the FFN block ([hidden_size]).
    pub ffn_norm: Vec<f32>,
    /// SwiGLU gate projection: hidden_size → intermediate_size.
    pub gate_weight: WeightMatrix,
    /// SwiGLU up projection: hidden_size → intermediate_size.
    pub up_weight: WeightMatrix,
    /// Down projection: intermediate_size → hidden_size.
    pub down_weight: WeightMatrix,
}
134
/// Full model weights.
pub struct ModelWeights {
    /// Embedding table, always dequantized for random-access row lookup.
    pub token_embd: Vec<f32>,  // [vocab_size, hidden_size]
    /// Final RMS-norm gain applied before the output projection.
    pub output_norm: Vec<f32>, // [hidden_size]
    /// LM head: hidden_size → vocab_size. May be a clone of `token_embd`
    /// when the model uses tied embeddings.
    pub output_weight: WeightMatrix,
    /// One entry per transformer block, in forward order.
    pub layers: Vec<LayerWeights>,
}
142
/// Pre-allocated scratch buffers for the forward pass.
/// Eliminates all per-token heap allocations (FALSIFY-ARENA-001).
/// Contract: contracts/cgp/cgp-inference-arena-v1.yaml
///
/// Buffers are sized once in `new` and sliced to the active length at each
/// use site; contents are overwritten every token, never read across calls.
pub struct ForwardArena {
    /// RMS-normed attention input ([hidden_size]).
    pub attn_input: Vec<f32>,
    /// Query vector ([hidden_size]).
    pub q: Vec<f32>,
    /// Key vector ([num_kv_heads * head_dim]).
    pub k: Vec<f32>,
    /// Value vector ([num_kv_heads * head_dim]).
    pub v: Vec<f32>,
    /// Concatenated per-head attention output ([hidden_size]).
    pub attn_out: Vec<f32>,
    /// Attention output projection result ([hidden_size]).
    pub attn_proj: Vec<f32>,
    /// RMS-normed FFN input ([hidden_size]).
    pub ffn_input: Vec<f32>,
    /// Gate projection output ([intermediate_size]).
    pub gate: Vec<f32>,
    /// Up projection output ([intermediate_size]).
    pub up: Vec<f32>,
    /// SiLU(gate) * up ([intermediate_size]).
    pub swiglu: Vec<f32>,
    /// Down projection output ([hidden_size]).
    pub ffn_out: Vec<f32>,
    /// Residual stream carried between layers ([hidden_size]).
    pub hidden: Vec<f32>,
    /// Post-attention residual ([hidden_size]).
    pub residual: Vec<f32>,
    /// Final-norm output feeding the LM head ([hidden_size]).
    pub normed: Vec<f32>,
    /// Contiguous single-head K view gathered from the cache
    /// ([max_seq_len * head_dim]).
    pub k_cache_head: Vec<f32>,
    /// Contiguous single-head V view, same layout as `k_cache_head`.
    pub v_cache_head: Vec<f32>,
}
164
165impl ForwardArena {
166    pub fn new(config: &ModelConfig) -> Self {
167        let kv_dim = config.num_kv_heads * config.head_dim;
168        let head_cache_size = config.max_seq_len * config.head_dim;
169        Self {
170            attn_input: vec![0.0f32; config.hidden_size],
171            q: vec![0.0f32; config.hidden_size],
172            k: vec![0.0f32; kv_dim],
173            v: vec![0.0f32; kv_dim],
174            attn_out: vec![0.0f32; config.hidden_size],
175            attn_proj: vec![0.0f32; config.hidden_size],
176            ffn_input: vec![0.0f32; config.hidden_size],
177            gate: vec![0.0f32; config.intermediate_size],
178            up: vec![0.0f32; config.intermediate_size],
179            swiglu: vec![0.0f32; config.intermediate_size],
180            ffn_out: vec![0.0f32; config.hidden_size],
181            hidden: vec![0.0f32; config.hidden_size],
182            residual: vec![0.0f32; config.hidden_size],
183            normed: vec![0.0f32; config.hidden_size],
184            k_cache_head: vec![0.0f32; head_cache_size],
185            v_cache_head: vec![0.0f32; head_cache_size],
186        }
187    }
188}
189
/// KV cache for incremental decoding.
pub struct KvCache {
    /// k_cache[layer][pos * head_dim * num_kv_heads .. ] — flat per-layer
    pub k: Vec<Vec<f32>>,
    /// v_cache[layer][pos * head_dim * num_kv_heads .. ]
    pub v: Vec<Vec<f32>>,
    /// Number of positions filled.
    /// NOTE(review): `LlamaModel::forward` in this file never updates this
    /// field — confirm the caller maintains it (attention derives its own
    /// `seq_len` from `pos + 1` instead).
    pub seq_len: usize,
}
198
199impl KvCache {
200    pub fn new(config: &ModelConfig) -> Self {
201        let kv_dim = config.num_kv_heads * config.head_dim;
202        let layer_size = config.max_seq_len * kv_dim;
203        Self {
204            k: (0..config.num_layers).map(|_| vec![0.0f32; layer_size]).collect(),
205            v: (0..config.num_layers).map(|_| vec![0.0f32; layer_size]).collect(),
206            seq_len: 0,
207        }
208    }
209}
210
/// Complete transformer model ready for inference.
///
/// Pairs the extracted hyperparameters with the loaded weight tensors;
/// per-request state (KV cache, scratch arena) lives outside this struct.
pub struct LlamaModel {
    pub config: ModelConfig,
    pub weights: ModelWeights,
}
216
impl LlamaModel {
    /// Load model from a GGUF file.
    ///
    /// Extracts the config from metadata, prints a one-line shape summary to
    /// stderr (L = layers, H = hidden size, h/kv = query/kv heads, I = FFN
    /// width), then loads and (where needed) dequantizes all tensors.
    pub fn from_gguf(gguf: &GgufFile) -> Result<Self, TruenoError> {
        let config = ModelConfig::from_gguf(gguf)?;

        eprintln!(
            "Loading {} model: {}L × {}H ({}h {}kv) × {}I, vocab={}",
            config.arch,
            config.num_layers,
            config.hidden_size,
            config.num_heads,
            config.num_kv_heads,
            config.intermediate_size,
            config.vocab_size,
        );

        let weights = load_weights(gguf, &config)?;

        Ok(Self { config, weights })
    }

    /// Run one forward pass for a single token at the given position.
    /// Returns logits [vocab_size].
    /// Uses ForwardArena for zero per-token allocations (FALSIFY-ARENA-001).
    ///
    /// # Errors
    /// `InvalidInput` if `token_id` indexes past the embedding table, or if a
    /// norm kernel rejects its input.
    pub fn forward(
        &self,
        token_id: u32,
        pos: usize,
        kv_cache: &mut KvCache,
        arena: &mut ForwardArena,
    ) -> Result<Vec<f32>, TruenoError> {
        let cfg = &self.config;
        let w = &self.weights;

        // Token embedding lookup (copy into arena.hidden)
        let embd_start = token_id as usize * cfg.hidden_size;
        let embd_end = embd_start + cfg.hidden_size;
        if embd_end > w.token_embd.len() {
            return Err(TruenoError::InvalidInput(format!(
                "Token ID {token_id} out of range (vocab={})",
                cfg.vocab_size
            )));
        }
        arena.hidden[..cfg.hidden_size].copy_from_slice(&w.token_embd[embd_start..embd_end]);

        // Transformer layers — each reads and rewrites arena.hidden in place.
        for (layer_idx, lw) in w.layers.iter().enumerate() {
            self.forward_layer(layer_idx, lw, pos, kv_cache, arena)?;
        }

        // Final RMS norm
        rms_norm(
            &arena.hidden[..cfg.hidden_size],
            &w.output_norm,
            cfg.rms_norm_eps,
            &mut arena.normed[..cfg.hidden_size],
        )?;

        // Output projection → logits (this one allocation is unavoidable — vocab_size is large)
        let logits =
            matmul_weight(&w.output_weight, &arena.normed[..cfg.hidden_size], cfg.hidden_size);

        Ok(logits)
    }

    /// Forward one layer, reading/writing arena.hidden in-place.
    ///
    /// Pre-norm architecture: RMSNorm → attention → residual add, then
    /// RMSNorm → SwiGLU FFN → residual add. KV for the current position is
    /// appended to the cache before attention runs over positions 0..=pos.
    fn forward_layer(
        &self,
        layer_idx: usize,
        lw: &LayerWeights,
        pos: usize,
        kv_cache: &mut KvCache,
        arena: &mut ForwardArena,
    ) -> Result<(), TruenoError> {
        let cfg = &self.config;
        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
        let h_sz = cfg.hidden_size;

        // === Attention block ===
        rms_norm(
            &arena.hidden[..h_sz],
            &lw.attn_norm,
            cfg.rms_norm_eps,
            &mut arena.attn_input[..h_sz],
        )?;

        // QKV projections into arena buffers
        matmul_weight_into(&lw.q_weight, &arena.attn_input[..h_sz], h_sz, &mut arena.q[..h_sz]);
        matmul_weight_into(&lw.k_weight, &arena.attn_input[..h_sz], h_sz, &mut arena.k[..kv_dim]);
        matmul_weight_into(&lw.v_weight, &arena.attn_input[..h_sz], h_sz, &mut arena.v[..kv_dim]);

        // Optional bias (Qwen2/Qwen3)
        if let Some(bias) = &lw.q_bias {
            for (v, b) in arena.q[..h_sz].iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }
        if let Some(bias) = &lw.k_bias {
            for (v, b) in arena.k[..kv_dim].iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }
        if let Some(bias) = &lw.v_bias {
            for (v, b) in arena.v[..kv_dim].iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }

        // RoPE in-place
        // NOTE(review): apply_rope rotates adjacent element pairs (2i, 2i+1).
        // Some architectures use NeoX-style split-half rotation instead —
        // confirm this matches the rope type of every supported arch.
        apply_rope(&mut arena.q[..h_sz], cfg.num_heads, cfg.head_dim, pos, cfg.rope_theta);
        apply_rope(&mut arena.k[..kv_dim], cfg.num_kv_heads, cfg.head_dim, pos, cfg.rope_theta);

        // Store K,V in cache
        let kv_off = pos * kv_dim;
        kv_cache.k[layer_idx][kv_off..kv_off + kv_dim].copy_from_slice(&arena.k[..kv_dim]);
        kv_cache.v[layer_idx][kv_off..kv_off + kv_dim].copy_from_slice(&arena.v[..kv_dim]);

        // Causal decode: attend over every position written so far.
        let seq_len = pos + 1;

        // Multi-head attention
        arena.attn_out[..h_sz].fill(0.0);
        // GQA: each group of `heads_per_kv` query heads shares one KV head.
        let heads_per_kv = cfg.num_heads / cfg.num_kv_heads;

        for h in 0..cfg.num_heads {
            let kv_h = h / heads_per_kv;
            let q_head = &arena.q[h * cfg.head_dim..(h + 1) * cfg.head_dim];

            // Build contiguous K/V view for this head (reuse arena buffers).
            // O(seq_len * head_dim) copies per head per token; gathers the
            // strided per-position rows of this head's KV into flat scratch.
            let view_len = seq_len * cfg.head_dim;
            for s in 0..seq_len {
                let src_off = s * kv_dim + kv_h * cfg.head_dim;
                let dst_off = s * cfg.head_dim;
                arena.k_cache_head[dst_off..dst_off + cfg.head_dim]
                    .copy_from_slice(&kv_cache.k[layer_idx][src_off..src_off + cfg.head_dim]);
                arena.v_cache_head[dst_off..dst_off + cfg.head_dim]
                    .copy_from_slice(&kv_cache.v[layer_idx][src_off..src_off + cfg.head_dim]);
            }

            let out_head = &mut arena.attn_out[h * cfg.head_dim..(h + 1) * cfg.head_dim];
            fused_attention_decode(
                q_head,
                &arena.k_cache_head[..view_len],
                &arena.v_cache_head[..view_len],
                cfg.head_dim,
                seq_len,
                out_head,
            );
        }

        // Output projection
        matmul_weight_into(
            &lw.o_weight,
            &arena.attn_out[..h_sz],
            h_sz,
            &mut arena.attn_proj[..h_sz],
        );

        // Residual: residual = hidden + attn_proj
        for i in 0..h_sz {
            arena.residual[i] = arena.hidden[i] + arena.attn_proj[i];
        }

        // === FFN block ===
        rms_norm(
            &arena.residual[..h_sz],
            &lw.ffn_norm,
            cfg.rms_norm_eps,
            &mut arena.ffn_input[..h_sz],
        )?;

        let i_sz = cfg.intermediate_size;
        matmul_weight_into(
            &lw.gate_weight,
            &arena.ffn_input[..h_sz],
            h_sz,
            &mut arena.gate[..i_sz],
        );
        matmul_weight_into(&lw.up_weight, &arena.ffn_input[..h_sz], h_sz, &mut arena.up[..i_sz]);

        // SiLU(gate) * up → swiglu
        for i in 0..i_sz {
            let g = arena.gate[i];
            let silu_g = g / (1.0 + (-g).exp());
            arena.swiglu[i] = silu_g * arena.up[i];
        }

        // Down projection
        matmul_weight_into(
            &lw.down_weight,
            &arena.swiglu[..i_sz],
            i_sz,
            &mut arena.ffn_out[..h_sz],
        );

        // Residual → hidden (for next layer)
        for i in 0..h_sz {
            arena.hidden[i] = arena.residual[i] + arena.ffn_out[i];
        }

        Ok(())
    }
}
419
/// Apply Rotary Position Embedding (RoPE) in-place.
///
/// Each head is treated independently. Within a head, consecutive element
/// pairs (2i, 2i+1) are rotated by angle `pos / theta^(2i / head_dim)` —
/// adjacent-pair RoPE with per-pair frequency decreasing geometrically.
fn apply_rope(x: &mut [f32], num_heads: usize, head_dim: usize, pos: usize, theta: f32) {
    for head in x.chunks_exact_mut(head_dim).take(num_heads) {
        for (pair_idx, pair) in head.chunks_exact_mut(2).enumerate() {
            let exponent = (2 * pair_idx) as f32 / head_dim as f32;
            let freq = 1.0 / theta.powf(exponent);
            let angle = pos as f32 * freq;
            let (sin_a, cos_a) = angle.sin_cos();
            let (a, b) = (pair[0], pair[1]);
            pair[0] = a * cos_a - b * sin_a;
            pair[1] = a * sin_a + b * cos_a;
        }
    }
}
435
/// Load all weights from GGUF into model weight structs.
///
/// Q4K projection weights are kept as raw bytes for the fused kernel; all
/// other dtypes are dequantized to F32 at load time. Tensor names follow the
/// llama.cpp convention (`token_embd.weight`, `blk.<i>.<name>.weight`, ...).
///
/// # Errors
/// `InvalidInput` if any mandatory tensor or its data is missing.
fn load_weights(gguf: &GgufFile, config: &ModelConfig) -> Result<ModelWeights, TruenoError> {
    // Token embeddings — may be F32, F16, or quantized (Q4K/Q6K).
    // For quantized embeddings, dequantize the full table at load time
    // since we need random-access per-token lookup.
    let token_embd = load_f32_or_dequant_tensor(
        gguf,
        "token_embd.weight",
        config.vocab_size * config.hidden_size,
    )?;

    // Output norm
    let output_norm = load_f32_tensor(gguf, "output_norm.weight", config.hidden_size)?;

    // Output projection — Q4K kept as bytes; everything else dequantized to F32.
    // Falls back to tied embeddings if output.weight not present.
    let output_weight = if gguf.tensor_info("output.weight").is_some() {
        load_weight_matrix(gguf, "output.weight", config.hidden_size)?
    } else {
        // Tied embeddings
        // NOTE(review): clones the full [vocab × hidden] F32 table, doubling
        // embedding memory for tied-weight models.
        WeightMatrix::F32 { data: token_embd.clone(), rows: config.vocab_size }
    };

    // Layers
    let mut layers = Vec::with_capacity(config.num_layers);
    for i in 0..config.num_layers {
        let prefix = format!("blk.{i}");

        let attn_norm =
            load_f32_tensor(gguf, &format!("{prefix}.attn_norm.weight"), config.hidden_size)?;
        let ffn_norm =
            load_f32_tensor(gguf, &format!("{prefix}.ffn_norm.weight"), config.hidden_size)?;

        let q_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_q.weight"), config.hidden_size)?;
        let k_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_k.weight"), config.hidden_size)?;
        let v_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_v.weight"), config.hidden_size)?;
        let o_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_output.weight"), config.hidden_size)?;

        // Qwen2/Qwen3 attention biases (optional — LLaMA has none)
        let kv_dim = config.num_kv_heads * config.head_dim;
        let q_bias = load_optional_f32(gguf, &format!("{prefix}.attn_q.bias"), config.hidden_size);
        let k_bias = load_optional_f32(gguf, &format!("{prefix}.attn_k.bias"), kv_dim);
        let v_bias = load_optional_f32(gguf, &format!("{prefix}.attn_v.bias"), kv_dim);

        let gate_weight =
            load_weight_matrix(gguf, &format!("{prefix}.ffn_gate.weight"), config.hidden_size)?;
        let up_weight =
            load_weight_matrix(gguf, &format!("{prefix}.ffn_up.weight"), config.hidden_size)?;
        // Down projection's input dimension is the FFN width, not hidden_size.
        let down_weight = load_weight_matrix(
            gguf,
            &format!("{prefix}.ffn_down.weight"),
            config.intermediate_size,
        )?;

        // One-time shape sanity print for the first layer only.
        if i == 0 {
            eprintln!(
                "  Layer 0: Q[{}×{}] K[{}×{}] V[{}×{}] Gate[{}×{}]",
                q_weight.rows(),
                config.hidden_size,
                k_weight.rows(),
                config.hidden_size,
                v_weight.rows(),
                config.hidden_size,
                gate_weight.rows(),
                config.hidden_size,
            );
        }

        layers.push(LayerWeights {
            attn_norm,
            q_weight,
            k_weight,
            v_weight,
            o_weight,
            q_bias,
            k_bias,
            v_bias,
            ffn_norm,
            gate_weight,
            up_weight,
            down_weight,
        });
    }

    eprintln!("  Loaded {} layers", layers.len());

    Ok(ModelWeights { token_embd, output_norm, output_weight, layers })
}
528
/// Load a tensor as F32, dequantizing if quantized.
/// For Q4K weights, uses trueno's dequantize_q4k_to_f32.
///
/// `expected_elements` sizes only the F32/F16/BF16 conversion and the
/// all-zeros fallback; quantized paths size themselves from the tensor's
/// own element count.
///
/// # Errors
/// `InvalidInput` if the tensor or its data is absent from the GGUF.
fn load_f32_or_dequant_tensor(
    gguf: &GgufFile,
    name: &str,
    expected_elements: usize,
) -> Result<Vec<f32>, TruenoError> {
    let info = gguf
        .tensor_info(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
    let data = gguf
        .tensor_data(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;

    match info.dtype {
        GgmlType::F32 | GgmlType::F16 | GgmlType::Bf16 => {
            Ok(to_f32_from_any(data, info.dtype, expected_elements))
        }
        GgmlType::Q4K => {
            let n_elements = info.n_elements() as usize;
            Ok(crate::backends::q4k::dequantize_q4k_to_f32(data, n_elements))
        }
        GgmlType::Q6K => Ok(dequantize_q6k_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q5K => Ok(dequantize_q5k_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q8_0 => Ok(dequantize_q8_0_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q4_0 => Ok(dequantize_q4_0_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q4_1 => Ok(dequantize_q4_1_to_f32(data, info.n_elements() as usize)),
        _ => {
            // Unknown dtype: warn loudly and substitute zeros rather than
            // abort the whole load.
            eprintln!(
                "  WARNING: tensor '{name}' has unsupported dtype {:?}, using zeros",
                info.dtype
            );
            Ok(vec![0.0f32; expected_elements])
        }
    }
}
565
566/// Load an optional F32 tensor (returns None if tensor doesn't exist in GGUF).
567fn load_optional_f32(gguf: &GgufFile, name: &str, expected_elements: usize) -> Option<Vec<f32>> {
568    let info = gguf.tensor_info(name)?;
569    let data = gguf.tensor_data(name)?;
570    Some(to_f32_from_any(data, info.dtype, expected_elements))
571}
572
573/// Load a tensor as F32 (dequantizing F16 if needed).
574fn load_f32_tensor(
575    gguf: &GgufFile,
576    name: &str,
577    expected_elements: usize,
578) -> Result<Vec<f32>, TruenoError> {
579    let info = gguf
580        .tensor_info(name)
581        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
582    let data = gguf
583        .tensor_data(name)
584        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;
585
586    Ok(to_f32_from_any(data, info.dtype, expected_elements))
587}
588
/// Load a weight tensor as a `WeightMatrix`.
/// Q4K weights are kept as raw bytes for the fused matmul kernel.
/// All other quantization types are dequantized to F32 at load time.
///
/// `in_dim` is the projection's input dimension; the row count is derived
/// as `n_elements / in_dim`.
/// NOTE(review): integer division — assumes `n_elements` is an exact
/// multiple of `in_dim`; a mismatched tensor would silently truncate.
///
/// # Errors
/// `InvalidInput` if the tensor or its data is missing.
fn load_weight_matrix(
    gguf: &GgufFile,
    name: &str,
    in_dim: usize,
) -> Result<WeightMatrix, TruenoError> {
    let info = gguf
        .tensor_info(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
    let data = gguf
        .tensor_data(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;

    let n_elements = info.n_elements() as usize;
    let out_dim = n_elements / in_dim;

    match info.dtype {
        // Q4K stays packed — the fused kernel consumes the raw block bytes.
        GgmlType::Q4K => Ok(WeightMatrix::Q4K { data: data.to_vec(), rows: out_dim }),
        GgmlType::F32 | GgmlType::F16 | GgmlType::Bf16 => {
            let f32_data = to_f32_from_any(data, info.dtype, n_elements);
            Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
        }
        // All remaining quant formats are expanded to F32 once, at load time.
        GgmlType::Q6K => {
            let f32_data = dequantize_q6k_to_f32(data, n_elements);
            Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
        }
        GgmlType::Q5K => {
            let f32_data = dequantize_q5k_to_f32(data, n_elements);
            Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
        }
        GgmlType::Q8_0 => {
            let f32_data = dequantize_q8_0_to_f32(data, n_elements);
            Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
        }
        GgmlType::Q4_0 => {
            let f32_data = dequantize_q4_0_to_f32(data, n_elements);
            Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
        }
        GgmlType::Q4_1 => {
            let f32_data = dequantize_q4_1_to_f32(data, n_elements);
            Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
        }
        _ => {
            // Unknown dtype: warn and substitute zeros instead of failing the load.
            eprintln!("  WARNING: tensor '{name}' dtype {:?} unsupported, using zeros", info.dtype);
            Ok(WeightMatrix::F32 { data: vec![0.0f32; n_elements], rows: out_dim })
        }
    }
}
639
640/// Dispatch matrix-vector multiply based on weight type (allocating).
641/// Used only for output projection (vocab_size too large for arena).
642fn matmul_weight(weight: &WeightMatrix, input: &[f32], in_dim: usize) -> Vec<f32> {
643    match weight {
644        WeightMatrix::Q4K { data, rows } => matmul_q4k_f32_dispatch(data, input, *rows, in_dim),
645        WeightMatrix::F32 { data, rows } => {
646            let mut out = vec![0.0f32; *rows];
647            for i in 0..*rows {
648                let row = &data[i * in_dim..(i + 1) * in_dim];
649                out[i] = row.iter().zip(input.iter()).map(|(a, b)| a * b).sum();
650            }
651            out
652        }
653    }
654}
655
/// Dispatch matrix-vector multiply into a pre-allocated output buffer (zero-alloc).
///
/// NOTE(review): the Q4K arm calls the allocating dispatch and copies the
/// result, so it still heap-allocates one Vec per call — only the F32 arm is
/// truly zero-alloc. An `_into` variant of the Q4K kernel would fix this.
fn matmul_weight_into(weight: &WeightMatrix, input: &[f32], in_dim: usize, out: &mut [f32]) {
    match weight {
        WeightMatrix::Q4K { data, rows } => {
            // Allocates `result`, then memcpy into `out` — see NOTE above.
            let result = matmul_q4k_f32_dispatch(data, input, *rows, in_dim);
            out[..*rows].copy_from_slice(&result);
        }
        WeightMatrix::F32 { data, rows } => {
            // Row-major dot products: out[i] = data[i, :] · input.
            for i in 0..*rows {
                let row = &data[i * in_dim..(i + 1) * in_dim];
                out[i] = row.iter().zip(input.iter()).map(|(a, b)| a * b).sum();
            }
        }
    }
}
671
/// Convert IEEE 754 half-precision (FP16) bits to f32.
///
/// Handles all four cases: ±0, subnormals (renormalized into f32 normals,
/// which can represent every half subnormal exactly), Inf/NaN, and normals
/// (exponent rebias −15 → −127, mantissa widened 10 → 23 bits).
fn f16_to_f32(bits: u16) -> f32 {
    let sign_bit = (bits as u32 & 0x8000) << 16;
    let exponent = ((bits >> 10) & 0x1F) as u32;
    let mantissa = (bits & 0x3FF) as u32;

    match exponent {
        // Zero or subnormal half.
        0 => {
            if mantissa == 0 {
                return f32::from_bits(sign_bit); // ±0
            }
            // Renormalize: shift the leading 1 up to bit 10, adjusting the
            // exponent by one per shift (mantissa has at most 10 bits, so its
            // leading zero count is >= 22; `shift` is 1..=10).
            let shift = mantissa.leading_zeros() - 21;
            let frac = (mantissa << shift) & 0x3FF;
            let exp32 = (113 - shift) << 23; // e = -14 - shift, rebias by +127
            f32::from_bits(sign_bit | exp32 | (frac << 13))
        }
        // Inf (mantissa == 0) or NaN (mantissa != 0); payload bits preserved.
        31 => f32::from_bits(sign_bit | 0x7F80_0000 | (mantissa << 13)),
        // Normal: rebias exponent by 127 - 15 = 112.
        _ => f32::from_bits(sign_bit | ((exponent + 112) << 23) | (mantissa << 13)),
    }
}
700
701/// Convert tensor bytes to f32, handling F32, F16, BF16.
702fn to_f32_from_any(data: &[u8], dtype: GgmlType, n_elements: usize) -> Vec<f32> {
703    match dtype {
704        GgmlType::F32 => {
705            // Safe: read f32 values from aligned-or-unaligned bytes
706            let count = n_elements.min(data.len() / 4);
707            (0..count)
708                .map(|i| {
709                    let off = i * 4;
710                    f32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]])
711                })
712                .collect()
713        }
714        GgmlType::F16 => {
715            let count = n_elements.min(data.len() / 2);
716            (0..count)
717                .map(|i| {
718                    let off = i * 2;
719                    let bits = u16::from_le_bytes([data[off], data[off + 1]]);
720                    f16_to_f32(bits)
721                })
722                .collect()
723        }
724        GgmlType::Bf16 => {
725            let count = n_elements.min(data.len() / 2);
726            (0..count)
727                .map(|i| {
728                    let off = i * 2;
729                    let bits = u16::from_le_bytes([data[off], data[off + 1]]);
730                    f32::from_bits((bits as u32) << 16)
731                })
732                .collect()
733        }
734        _ => {
735            // For quantized norms (shouldn't happen), return zeros
736            vec![0.0f32; n_elements]
737        }
738    }
739}
740
741/// Dequantize Q6_K to F32.
742///
743/// Q6_K layout per 256-element super-block (210 bytes):
744/// - ql[128]: lower 4 bits of each 6-bit value (2 values per byte)
745/// - qh[64]:  upper 2 bits (4 values per byte, 2 bits each)
746/// - scales[16]: signed 8-bit scales for 16 groups of 16
747/// - d[2]: f16 global scale
748///
749/// Value = d * scale[group] * q6  where q6 = (low4 | high2<<4) as i8 - 32
750fn dequantize_q6k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
751    const BLOCK_SIZE: usize = 256;
752    const BLOCK_BYTES: usize = 210;
753
754    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
755    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
756
757    for sb in 0..num_blocks {
758        let sb_start = sb * BLOCK_BYTES;
759        if sb_start + BLOCK_BYTES > data.len() {
760            break;
761        }
762        let block = &data[sb_start..sb_start + BLOCK_BYTES];
763        let ql = &block[0..128];
764        let qh = &block[128..192];
765        let scales = &block[192..208];
766        let d = f16_to_f32(u16::from_le_bytes([block[208], block[209]]));
767
768        let out_base = sb * BLOCK_SIZE;
769        for group in 0..16usize {
770            let scale = (scales[group] as i8) as f32;
771            let group_off = group * 16;
772            for j in 0..16usize {
773                let idx = group_off + j;
774                let ql_byte = ql[idx / 2];
775                let low4 = if idx % 2 == 0 { ql_byte & 0x0F } else { ql_byte >> 4 };
776                let qh_byte = qh[idx / 4];
777                let high2 = (qh_byte >> ((idx % 4) * 2)) & 0x03;
778                let q6 = ((low4 | (high2 << 4)) as i8).wrapping_sub(32) as f32;
779                result[out_base + idx] = d * scale * q6;
780            }
781        }
782    }
783
784    result.truncate(num_elements);
785    result
786}
787
788/// Dequantize Q5_K to F32.
789///
790/// Q5_K layout per 256-element super-block (176 bytes):
791/// - d[2]:       f16 super-block scale
792/// - dmin[2]:    f16 super-block min scale
793/// - scales[12]: packed 6-bit scales and mins (8 sub-blocks × 2 values)
794/// - qh[32]:     high bit of each 5-bit value (1 bit per element = 32 bytes)
795/// - qs[128]:    lower 4 bits per element (2 per byte)
796///
797/// Value = d * scale * q5 - dmin * min
798fn dequantize_q5k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
799    const BLOCK_SIZE: usize = 256;
800    const BLOCK_BYTES: usize = 176;
801
802    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
803    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
804
805    for sb in 0..num_blocks {
806        let sb_start = sb * BLOCK_BYTES;
807        if sb_start + BLOCK_BYTES > data.len() {
808            break;
809        }
810        let block = &data[sb_start..sb_start + BLOCK_BYTES];
811        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
812        let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
813
814        // Unpack 6-bit scales and mins (same layout as Q4K)
815        let sc = &block[4..16];
816        let mut scales = [0u8; 8];
817        let mut mins = [0u8; 8];
818        for i in 0..4 {
819            scales[i] = sc[i] & 0x3F;
820            mins[i] = sc[i + 4] & 0x3F;
821            scales[i + 4] = (sc[i + 8] & 0x0F) | ((sc[i] >> 6) << 4);
822            mins[i + 4] = (sc[i + 8] >> 4) | ((sc[i + 4] >> 6) << 4);
823        }
824
825        let qh = &block[16..48];
826        let qs = &block[48..176];
827
828        let out_base = sb * BLOCK_SIZE;
829        for sub in 0..8usize {
830            let scale = d * scales[sub] as f32;
831            let min = dmin * mins[sub] as f32;
832            let sub_off = sub * 32;
833            for j in 0..32usize {
834                let idx = sub_off + j;
835                let low4 = (qs[idx / 2] >> ((idx % 2) * 4)) & 0x0F;
836                let high1 = (qh[idx / 8] >> (idx % 8)) & 0x01;
837                let q5 = (low4 | (high1 << 4)) as f32;
838                result[out_base + idx] = scale * q5 - min;
839            }
840        }
841    }
842
843    result.truncate(num_elements);
844    result
845}
846
847/// Dequantize Q8_0 to F32.
848///
849/// Q8_0 layout per 32-element block (34 bytes):
850/// - d[2]:   f16 block scale
851/// - qs[32]: signed 8-bit quantized values
852///
853/// Value = d * qs[i]
854fn dequantize_q8_0_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
855    const BLOCK_SIZE: usize = 32;
856    const BLOCK_BYTES: usize = 34;
857
858    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
859    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
860
861    for b in 0..num_blocks {
862        let b_start = b * BLOCK_BYTES;
863        if b_start + BLOCK_BYTES > data.len() {
864            break;
865        }
866        let block = &data[b_start..b_start + BLOCK_BYTES];
867        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
868        let out_base = b * BLOCK_SIZE;
869        for j in 0..BLOCK_SIZE {
870            result[out_base + j] = d * (block[2 + j] as i8) as f32;
871        }
872    }
873
874    result.truncate(num_elements);
875    result
876}
877
878/// Dequantize Q4_0 to F32.
879///
880/// Q4_0 layout per 32-element block (18 bytes):
881/// - d[2]:   f16 block scale
882/// - qs[16]: 4-bit quantized values, 2 per byte
883///
884/// Value = d * (q4 - 8)  where q4 ∈ 0..15 (centered: subtract 8)
885fn dequantize_q4_0_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
886    const BLOCK_SIZE: usize = 32;
887    const BLOCK_BYTES: usize = 18;
888
889    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
890    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
891
892    for b in 0..num_blocks {
893        let b_start = b * BLOCK_BYTES;
894        if b_start + BLOCK_BYTES > data.len() {
895            break;
896        }
897        let block = &data[b_start..b_start + BLOCK_BYTES];
898        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
899        let out_base = b * BLOCK_SIZE;
900        for j in 0..16 {
901            let byte = block[2 + j];
902            let lo = (byte & 0x0F) as i32 - 8;
903            let hi = ((byte >> 4) & 0x0F) as i32 - 8;
904            result[out_base + j * 2] = d * lo as f32;
905            result[out_base + j * 2 + 1] = d * hi as f32;
906        }
907    }
908
909    result.truncate(num_elements);
910    result
911}
912
913/// Dequantize Q4_1 to F32.
914///
915/// Q4_1 layout per 32-element block (20 bytes):
916/// - d[2]:   f16 scale
917/// - m[2]:   f16 min (additive offset)
918/// - qs[16]: 4-bit quantized values, 2 per byte
919///
920/// Value = d * q4 + m  where q4 ∈ 0..15
921fn dequantize_q4_1_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
922    const BLOCK_SIZE: usize = 32;
923    const BLOCK_BYTES: usize = 20;
924
925    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
926    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
927
928    for b in 0..num_blocks {
929        let b_start = b * BLOCK_BYTES;
930        if b_start + BLOCK_BYTES > data.len() {
931            break;
932        }
933        let block = &data[b_start..b_start + BLOCK_BYTES];
934        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
935        let m = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
936        let out_base = b * BLOCK_SIZE;
937        for j in 0..16 {
938            let byte = block[4 + j];
939            let lo = (byte & 0x0F) as f32;
940            let hi = ((byte >> 4) & 0x0F) as f32;
941            result[out_base + j * 2] = d * lo + m;
942            result[out_base + j * 2 + 1] = d * hi + m;
943        }
944    }
945
946    result.truncate(num_elements);
947    result
948}