aprender-serve 0.33.0

Pure Rust ML inference engine built from scratch - model serving for GGUF and safetensors
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
//! Quantized GGUF transformer types
//!
//! This module contains the quantized transformer layer and model structures
//! that enable fused dequantization operations for memory-efficient inference.

use crate::error::{RealizarError, Result};
use crate::quantize::QK_K;

use super::config::{GGUFConfig, ValidatedModelConfig};
use super::quantized::{QKVWeights, QuantizedTensorRef};
use super::types::{
    GGUFModel, GGUF_TYPE_F32, GGUF_TYPE_Q2_K, GGUF_TYPE_Q4_0, GGUF_TYPE_Q4_1, GGUF_TYPE_Q4_K,
    GGUF_TYPE_Q5_0, GGUF_TYPE_Q5_K, GGUF_TYPE_Q6_K, GGUF_TYPE_Q8_0,
};

/// Quantized transformer layer weights.
///
/// Unlike `GGUFTransformerLayer`, which stores dequantized `Vec<f32>`, the
/// large projection matrices here are kept as [`QuantizedTensorRef`]
/// descriptors (byte offset, byte size, element count and quantization
/// type) that locate the still-quantized data inside the model file bytes,
/// so fused dequant+dot operations can consume them directly. Norm weights
/// and biases are small and are stored as `f32`.
///
/// For layers built by `QuantizedGGUFTransformer::from_gguf_for_moe`, the
/// dense FFN fields (`ffn_up_weight`, `ffn_down_weight`) are zero-sized
/// placeholders and `ffn_gate_weight` is `None`; the real expert tensors are
/// held separately by the owning transformer.
pub struct QuantizedGGUFTransformerLayer {
    /// Attention norm weight (kept as f32 - small, read once per token)
    pub attn_norm_weight: Vec<f32>,
    /// Attention norm bias (optional)
    pub attn_norm_bias: Option<Vec<f32>>,
    /// QKV projection weights (quantized) - either one fused tensor or
    /// separate Q, K and V tensors
    pub qkv_weight: QKVWeights,
    /// QKV bias (optional, f32). For separate Q/K/V weights this is the
    /// concatenation `[q_bias | k_bias | v_bias]`, present only when all
    /// three biases exist.
    pub qkv_bias: Option<Vec<f32>>,
    /// Attention output projection (quantized)
    pub attn_output_weight: QuantizedTensorRef,
    /// Attention output bias (optional, f32)
    pub attn_output_bias: Option<Vec<f32>>,
    /// FFN up projection (quantized; placeholder for MoE layers)
    pub ffn_up_weight: QuantizedTensorRef,
    /// FFN up bias (optional, f32)
    pub ffn_up_bias: Option<Vec<f32>>,
    /// FFN down projection (quantized; placeholder for MoE layers)
    pub ffn_down_weight: QuantizedTensorRef,
    /// FFN down bias (optional, f32)
    pub ffn_down_bias: Option<Vec<f32>>,
    /// FFN gate projection (quantized, SwiGLU models like LLaMA)
    pub ffn_gate_weight: Option<QuantizedTensorRef>,
    /// FFN gate bias (optional, f32)
    pub ffn_gate_bias: Option<Vec<f32>>,
    /// FFN norm weight (pre-FFN layer norm, LLaMA-style)
    pub ffn_norm_weight: Option<Vec<f32>>,
    /// FFN norm bias (optional, f32)
    pub ffn_norm_bias: Option<Vec<f32>>,
    /// GH-279: Per-head Q RMSNorm weight [head_dim] (Qwen3)
    pub attn_q_norm_weight: Option<Vec<f32>>,
    /// GH-279: Per-head K RMSNorm weight [head_dim] (Qwen3)
    pub attn_k_norm_weight: Option<Vec<f32>>,
}

/// Quantized GGUF Transformer for fused inference
///
/// Per Williams et al. (2009) roofline model, LLM inference is memory-bound.
/// This transformer stores weights in quantized form and uses fused
/// dequant+dot operations to minimize memory bandwidth.
///
/// # Performance Benefits
///
/// - **~7x bandwidth reduction** for Q4_K vs f32 (144 bytes vs 1024 bytes per 256 values)
/// - **Zero intermediate buffers** - dequantization happens inline with dot product
/// - **SIMD acceleration** - AVX2/FMA fused operations when available
/// - **Zero-copy loading** - quantized weights stay in the memory-mapped file;
///   only embeddings, norms and biases are copied out as `f32`
///
/// # Architecture
///
/// ```text
/// [Memory-mapped Q4_K bytes] → [fused_q4k_dot_simd] → [f32 result]
///                              (no intermediate Vec<f32>)
/// ```
pub struct QuantizedGGUFTransformer<'a> {
    /// Model configuration
    pub config: GGUFConfig,
    /// Reference to memory-mapped file data
    pub data: &'a [u8],
    /// Token embedding (kept as f32 for lookup)
    pub token_embedding: Vec<f32>,
    /// GH-278: Position embedding [context_length, hidden_dim] (GPT-2 only)
    pub position_embedding: Option<Vec<f32>>,
    /// Quantized layer weights
    pub layers: Vec<QuantizedGGUFTransformerLayer>,
    /// M32c.2: Per-layer MoE expert tensor descriptors.
    ///
    /// Empty for dense models loaded via `from_gguf`. When loaded via
    /// `from_gguf_for_moe`, `moe_layers.len() == layers.len()` and each
    /// entry holds the quantized tensor refs for `qwen3_moe`'s router and
    /// per-expert gate/up/down projections. In that case the dense FFN
    /// fields of `layers[i]` (`ffn_up_weight`, `ffn_down_weight`) are
    /// zero-sized placeholders. Consumers MUST check that
    /// `moe_layers.get(i)` is `Some(Some(_))` before dispatching the FFN;
    /// indexing `moe_layers[i]` directly panics for dense models.
    pub moe_layers: Vec<Option<crate::gguf::qwen3_moe_load::Qwen3MoeQuantizedLayer>>,
    /// Output norm weight (f32)
    pub output_norm_weight: Vec<f32>,
    /// Output norm bias (optional)
    pub output_norm_bias: Option<Vec<f32>>,
    /// LM head weight (quantized for large vocab)
    pub lm_head_weight: QuantizedTensorRef,
    /// LM head bias (optional, f32)
    pub lm_head_bias: Option<Vec<f32>>,
}

impl<'a> QuantizedGGUFTransformer<'a> {
    /// Load quantized transformer from memory-mapped GGUF model
    ///
    /// # Arguments
    ///
    /// * `model` - Parsed GGUF model metadata
    /// * `data` - Memory-mapped file data (zero-copy)
    ///
    /// # Errors
    ///
    /// Returns error if required tensors are missing or have unsupported format
    pub fn from_gguf(model: &GGUFModel, data: &'a [u8]) -> Result<Self> {
        // Phase 2: Validate config at construction boundary.
        let config = ValidatedModelConfig::from_gguf(model)?.into_inner();

        // GH-704: Detect hybrid SSM architectures (Qwen3.5 Gated Delta Net) early.
        // These require a dedicated SSM inference path not yet implemented.
        let has_ssm = model
            .tensors
            .iter()
            .any(|t| t.name.contains("ssm_") || t.name.contains("ssm."));
        if has_ssm {
            let arch = &config.architecture;
            return Err(crate::RealizarError::FormatError {
                reason: format!(
                    "Architecture '{arch}' uses SSM/Gated Delta Net layers which are not yet \
                     supported for inference. Use a standard transformer model (e.g., Qwen2.5, \
                     LLaMA, Mistral) or wait for SSM support in a future release."
                ),
            });
        }

        // M32b: refuse Mixture-of-Experts architectures with a structured,
        // contract-named error before reaching the dense-FFN tensor lookup.
        // Replaces the pre-M32 cryptic "Tensor 'blk.0.ffn_up.weight' not
        // found" surface captured by FALSIFY-QW3-MOE-FORWARD-001 in
        // contracts/qwen3-moe-forward-v1.yaml.
        // M32c.2.1: dispatch qwen3_moe arch to the MoE-aware constructor
        // (M32c.2's `from_gguf_for_moe`). Loading now succeeds end-to-end;
        // the forward path emits the contract-named UnsupportedOperation
        // when it encounters the placeholder dense FFN — see M32c.2.2 for
        // the actual MoE forward wiring. Replaces M32b's load-time refusal.
        // See contracts/qwen3-moe-forward-v1.yaml.
        let canonical_arch = crate::tensor_names::normalize_architecture(&config.architecture);
        if canonical_arch == "qwen3_moe" {
            return Self::from_gguf_for_moe(model, data);
        }

        // Token embedding - keep as f32 for efficient lookup
        let token_embedding = model.get_tensor_f32("token_embd.weight", data)?;
        // GH-278: Position embedding — standard GGUF + legacy + aprender export fallback
        let position_embedding = model
            .get_tensor_f32("position_embd.weight", data)
            .or_else(|_| model.get_tensor_f32("token_pos_embd.weight", data))
            .or_else(|_| model.get_tensor_f32("model.position_embedding.weight", data))
            .ok();

        // Load layers with quantized weight references
        let mut layers = Vec::with_capacity(config.num_layers);
        for layer_idx in 0..config.num_layers {
            let layer = Self::load_quantized_layer(model, data, layer_idx)?;
            layers.push(layer);
        }

        // Output norm - small, keep as f32
        let output_norm_weight = model.get_tensor_f32("output_norm.weight", data)?;
        // GH-278: Output norm bias — standard + aprender fallback
        let output_norm_bias = model
            .get_tensor_f32("output_norm.bias", data)
            .or_else(|_| model.get_tensor_f32("model.norm.bias", data))
            .ok();

        // LM head - large, keep quantized
        // Fall back to token_embd.weight for tied embeddings (Qwen2, some LLaMA variants)
        let lm_head_weight = Self::get_tensor_ref(model, data, "output.weight")
            .or_else(|_| Self::get_tensor_ref(model, data, "token_embd.weight"))?;
        let lm_head_bias = model.get_tensor_f32("output.bias", data).ok();

        Ok(Self {
            config,
            data,
            token_embedding,
            position_embedding,
            layers,
            moe_layers: Vec::new(),
            output_norm_weight,
            output_norm_bias,
            lm_head_weight,
            lm_head_bias,
        })
    }

    /// M32c.2: Load a `qwen3_moe`-arch GGUF, populating both the
    /// non-FFN dense fields and the per-layer MoE expert tensor
    /// descriptors. This is the qwen3_moe-aware sibling of
    /// `from_gguf` — call it instead when the architecture has been
    /// canonicalized to `qwen3_moe`.
    ///
    /// Forward dispatch is NOT yet wired (M32c.2.1). This
    /// constructor exists so M32c.2 can prove that the
    /// load infrastructure (M32c.1's `load_qwen3_moe_layer` +
    /// shared dense-FFN-skip path) works end-to-end against the
    /// real 17.3 GB Qwen3-Coder GGUF without going through the
    /// M32b load-time refusal.
    ///
    /// # Arguments
    /// * `model` - Parsed GGUF model. The caller MUST have verified
    ///   that `tensor_names::normalize_architecture(&config.architecture) == "qwen3_moe"`.
    /// * `data` - Memory-mapped file data (zero-copy).
    ///
    /// # Errors
    /// Returns an error if any of:
    /// - SSM tensor names appear (mutually exclusive with MoE)
    /// - Required non-FFN tensors are missing (token_embd, attn_*,
    ///   output_norm, output)
    /// - Any MoE tensor declared by `tensor-names-v1` v1.1.0 is
    ///   missing for any layer
    ///
    /// On success, every `layers[i]` has placeholder dense FFN
    /// `QuantizedTensorRef`s (offset=0, byte_size=0, num_elements=0,
    /// qtype=GGUF_TYPE_F32) — consumers MUST check
    /// `moe_layers[i].is_some()` before attempting any dense FFN
    /// dequantization.
    pub fn from_gguf_for_moe(model: &GGUFModel, data: &'a [u8]) -> Result<Self> {
        let config = ValidatedModelConfig::from_gguf(model)?.into_inner();

        let canonical_arch = crate::tensor_names::normalize_architecture(&config.architecture);
        if canonical_arch != "qwen3_moe" {
            return Err(crate::error::RealizarError::InvalidShape {
                reason: format!(
                    "from_gguf_for_moe: architecture '{}' (canonical '{}') is not qwen3_moe — \
                     caller should dispatch to from_gguf instead",
                    config.architecture, canonical_arch
                ),
            });
        }

        let has_ssm = model
            .tensors
            .iter()
            .any(|t| t.name.contains("ssm_") || t.name.contains("ssm."));
        if has_ssm {
            return Err(crate::RealizarError::FormatError {
                reason: format!(
                    "Architecture '{}' has both qwen3_moe arch tag AND SSM tensors — \
                     unsupported hybrid configuration",
                    config.architecture
                ),
            });
        }

        let token_embedding = model.get_tensor_f32("token_embd.weight", data)?;
        let position_embedding = model
            .get_tensor_f32("position_embd.weight", data)
            .or_else(|_| model.get_tensor_f32("token_pos_embd.weight", data))
            .or_else(|_| model.get_tensor_f32("model.position_embedding.weight", data))
            .ok();

        let mut layers = Vec::with_capacity(config.num_layers);
        let mut moe_layers = Vec::with_capacity(config.num_layers);
        for layer_idx in 0..config.num_layers {
            layers.push(Self::load_quantized_layer_moe_skeleton(
                model, data, layer_idx,
            )?);
            moe_layers.push(Some(crate::gguf::qwen3_moe_load::load_qwen3_moe_layer(
                model, data, layer_idx,
            )?));
        }

        let output_norm_weight = model.get_tensor_f32("output_norm.weight", data)?;
        let output_norm_bias = model
            .get_tensor_f32("output_norm.bias", data)
            .or_else(|_| model.get_tensor_f32("model.norm.bias", data))
            .ok();

        let lm_head_weight = Self::get_tensor_ref(model, data, "output.weight")
            .or_else(|_| Self::get_tensor_ref(model, data, "token_embd.weight"))?;
        let lm_head_bias = model.get_tensor_f32("output.bias", data).ok();

        Ok(Self {
            config,
            data,
            token_embedding,
            position_embedding,
            layers,
            moe_layers,
            output_norm_weight,
            output_norm_bias,
            lm_head_weight,
            lm_head_bias,
        })
    }

    /// M32c.2 helper: load the non-FFN portion of a transformer layer.
    /// Dense FFN fields are stubbed with empty `QuantizedTensorRef`
    /// placeholders — the caller MUST populate `moe_layers[i]` for
    /// these layers via `load_qwen3_moe_layer`.
    fn load_quantized_layer_moe_skeleton(
        model: &GGUFModel,
        data: &[u8],
        layer_idx: usize,
    ) -> Result<QuantizedGGUFTransformerLayer> {
        let prefix = format!("blk.{layer_idx}");

        let attn_norm_weight = model.get_tensor_f32(&format!("{prefix}.attn_norm.weight"), data)?;
        let attn_norm_bias = model
            .get_tensor_f32(&format!("{prefix}.attn_norm.bias"), data)
            .or_else(|_| model.get_tensor_f32(&format!("{prefix}.input_layernorm.bias"), data))
            .ok();

        // qwen3_moe uses separate Q/K/V (llama-style); fused QKV is unused for this arch.
        let q = Self::get_tensor_ref(model, data, &format!("{prefix}.attn_q.weight"))?;
        let k = Self::get_tensor_ref(model, data, &format!("{prefix}.attn_k.weight"))?;
        let v = Self::get_tensor_ref(model, data, &format!("{prefix}.attn_v.weight"))?;
        let q_bias = model
            .get_tensor_f32(&format!("{prefix}.attn_q.bias"), data)
            .ok();
        let k_bias = model
            .get_tensor_f32(&format!("{prefix}.attn_k.bias"), data)
            .ok();
        let v_bias = model
            .get_tensor_f32(&format!("{prefix}.attn_v.bias"), data)
            .ok();
        let qkv_bias = match (q_bias, k_bias, v_bias) {
            (Some(qb), Some(kb), Some(vb)) => {
                let mut combined = Vec::with_capacity(qb.len() + kb.len() + vb.len());
                combined.extend_from_slice(&qb);
                combined.extend_from_slice(&kb);
                combined.extend_from_slice(&vb);
                Some(combined)
            },
            _ => None,
        };
        let qkv_weight = QKVWeights::Separate { q, k, v };

        let attn_output_weight =
            Self::get_tensor_ref(model, data, &format!("{prefix}.attn_output.weight"))?;
        let attn_output_bias = model
            .get_tensor_f32(&format!("{prefix}.attn_output.bias"), data)
            .ok();

        // FFN fields stubbed — see moe_layers field for the real expert tensors.
        let dense_ffn_placeholder = QuantizedTensorRef {
            offset: 0,
            byte_size: 0,
            num_elements: 0,
            qtype: GGUF_TYPE_F32,
        };

        let ffn_norm_weight = model
            .get_tensor_f32(&format!("{prefix}.ffn_norm.weight"), data)
            .ok();
        let ffn_norm_bias = model
            .get_tensor_f32(&format!("{prefix}.ffn_norm.bias"), data)
            .or_else(|_| {
                model.get_tensor_f32(&format!("{prefix}.post_attention_layernorm.bias"), data)
            })
            .ok();

        let attn_q_norm_weight = model
            .get_tensor_f32(&format!("{prefix}.attn_q_norm.weight"), data)
            .ok();
        let attn_k_norm_weight = model
            .get_tensor_f32(&format!("{prefix}.attn_k_norm.weight"), data)
            .ok();

        Ok(QuantizedGGUFTransformerLayer {
            attn_norm_weight,
            attn_norm_bias,
            qkv_weight,
            qkv_bias,
            attn_output_weight,
            attn_output_bias,
            ffn_up_weight: dense_ffn_placeholder.clone(),
            ffn_up_bias: None,
            ffn_down_weight: dense_ffn_placeholder,
            ffn_down_bias: None,
            ffn_gate_weight: None,
            ffn_gate_bias: None,
            ffn_norm_weight,
            ffn_norm_bias,
            attn_q_norm_weight,
            attn_k_norm_weight,
        })
    }

    /// Calculate byte size for a quantized tensor based on its type and dimensions.
    fn tensor_byte_size(qtype: u32, num_elements: usize, dims: &[u64]) -> Result<usize> {
        /// Row-padded K-quant byte size: each row pads to super-block boundaries.
        fn k_quant_bytes(dims: &[u64], super_block_bytes: usize) -> usize {
            if dims.len() == 2 {
                let rows = dims[0] as usize;
                let cols = dims[1] as usize;
                rows * cols.div_ceil(QK_K) * super_block_bytes
            } else {
                let n: usize = dims.iter().map(|&d| d as usize).product();
                n.div_ceil(QK_K) * super_block_bytes
            }
        }

        match qtype {
            GGUF_TYPE_F32 => Ok(num_elements * 4),
            GGUF_TYPE_Q4_0 => Ok(num_elements.div_ceil(32) * 18),
            GGUF_TYPE_Q8_0 => Ok(num_elements.div_ceil(32) * 34),
            GGUF_TYPE_Q2_K => Ok(num_elements.div_ceil(QK_K) * 84),
            GGUF_TYPE_Q4_1 => Ok(num_elements.div_ceil(32) * 20),
            GGUF_TYPE_Q5_0 => Ok(num_elements.div_ceil(32) * 22),
            GGUF_TYPE_Q4_K => Ok(k_quant_bytes(dims, 144)),
            GGUF_TYPE_Q5_K => Ok(k_quant_bytes(dims, 176)),
            GGUF_TYPE_Q6_K => Ok(k_quant_bytes(dims, 210)),
            _ => Err(RealizarError::UnsupportedOperation {
                operation: "tensor_byte_size".to_string(),
                reason: format!("Unsupported quantization type: {qtype}"),
            }),
        }
    }

    /// PAR-058: Auto-correct qtype when header claims wrong type.
    fn resolve_qtype(
        name: &str,
        claimed_qtype: u32,
        byte_size: usize,
        num_elements: usize,
        offset: usize,
        data_len: usize,
    ) -> (usize, u32) {
        if offset + byte_size <= data_len {
            return (byte_size, claimed_qtype);
        }
        let avail = data_len.saturating_sub(offset);
        let q4_0_size = num_elements.div_ceil(32) * 18;
        if q4_0_size <= avail && q4_0_size > 0 {
            eprintln!(
                "[PAR-058-RESOLVED] Tensor '{name}' qtype mismatch: header says {claimed_qtype} but byte size suggests Q4_0. Using Q4_0."
            );
            return (q4_0_size, GGUF_TYPE_Q4_0);
        }
        let q8_0_size = num_elements.div_ceil(32) * 34;
        if q8_0_size <= avail && q8_0_size > 0 {
            eprintln!(
                "[PAR-058-RESOLVED] Tensor '{name}' qtype mismatch: header says {claimed_qtype} but byte size suggests Q8_0. Using Q8_0."
            );
            return (q8_0_size, GGUF_TYPE_Q8_0);
        }
        (byte_size, claimed_qtype)
    }

    /// Get tensor reference (offset + size + qtype) without dequantization
    pub(crate) fn get_tensor_ref(
        model: &GGUFModel,
        data: &[u8],
        name: &str,
    ) -> Result<QuantizedTensorRef> {
        let tensor = model
            .tensors
            .iter()
            .find(|t| t.name == name)
            .ok_or_else(|| RealizarError::InvalidShape {
                reason: format!("Tensor '{}' not found", name),
            })?;

        let num_elements: usize = tensor.dims.iter().map(|&d| d as usize).product();
        let offset = model.tensor_data_start + tensor.offset as usize;
        let byte_size = Self::tensor_byte_size(tensor.qtype, num_elements, &tensor.dims)?;
        let (byte_size, actual_qtype) = Self::resolve_qtype(
            name,
            tensor.qtype,
            byte_size,
            num_elements,
            offset,
            data.len(),
        );

        if offset + byte_size > data.len() {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Tensor '{}' data range [{}, {}) exceeds file size {}",
                    name,
                    offset,
                    offset + byte_size,
                    data.len()
                ),
            });
        }

        Ok(QuantizedTensorRef {
            offset,
            byte_size,
            num_elements,
            qtype: actual_qtype,
        })
    }

    /// Load a single quantized transformer layer
    fn load_quantized_layer(
        model: &GGUFModel,
        data: &[u8],
        layer_idx: usize,
    ) -> Result<QuantizedGGUFTransformerLayer> {
        let prefix = format!("blk.{}", layer_idx);

        // Attention norm - small, keep as f32
        let attn_norm_weight =
            model.get_tensor_f32(&format!("{}.attn_norm.weight", prefix), data)?;
        // GH-278: Attention norm bias — standard GGUF + aprender fallback
        let attn_norm_bias = model
            .get_tensor_f32(&format!("{}.attn_norm.bias", prefix), data)
            .or_else(|_| model.get_tensor_f32(&format!("{}.input_layernorm.bias", prefix), data))
            .ok();

        // QKV - large, keep quantized
        // Try fused first (phi-2 style), fall back to separate (llama style)
        let (qkv_weight, qkv_bias) = if let Ok(fused) =
            Self::get_tensor_ref(model, data, &format!("{}.attn_qkv.weight", prefix))
        {
            // phi-2 style: fused QKV tensor
            let bias = model
                .get_tensor_f32(&format!("{}.attn_qkv.bias", prefix), data)
                .ok();
            (QKVWeights::Fused(fused), bias)
        } else {
            // llama style: separate Q, K, V tensors
            let q = Self::get_tensor_ref(model, data, &format!("{}.attn_q.weight", prefix))?;
            let k = Self::get_tensor_ref(model, data, &format!("{}.attn_k.weight", prefix))?;
            let v = Self::get_tensor_ref(model, data, &format!("{}.attn_v.weight", prefix))?;

            // Try to get biases (llama usually doesn't have them)
            let q_bias = model
                .get_tensor_f32(&format!("{}.attn_q.bias", prefix), data)
                .ok();
            let k_bias = model
                .get_tensor_f32(&format!("{}.attn_k.bias", prefix), data)
                .ok();
            let v_bias = model
                .get_tensor_f32(&format!("{}.attn_v.bias", prefix), data)
                .ok();

            let bias = match (q_bias, k_bias, v_bias) {
                (Some(qb), Some(kb), Some(vb)) => {
                    let mut combined = Vec::with_capacity(qb.len() + kb.len() + vb.len());
                    combined.extend_from_slice(&qb);
                    combined.extend_from_slice(&kb);
                    combined.extend_from_slice(&vb);
                    Some(combined)
                },
                _ => None,
            };

            (QKVWeights::Separate { q, k, v }, bias)
        };

        // Attention output - large, keep quantized
        let attn_output_weight =
            Self::get_tensor_ref(model, data, &format!("{}.attn_output.weight", prefix))?;
        let attn_output_bias = model
            .get_tensor_f32(&format!("{}.attn_output.bias", prefix), data)
            .ok();

        // FFN - large, keep quantized
        let ffn_up_weight =
            Self::get_tensor_ref(model, data, &format!("{}.ffn_up.weight", prefix))?;
        // GH-278: FFN biases — standard GGUF + aprender fallback
        let ffn_up_bias = model
            .get_tensor_f32(&format!("{}.ffn_up.bias", prefix), data)
            .or_else(|_| model.get_tensor_f32(&format!("{}.mlp.up_proj.bias", prefix), data))
            .ok();
        let ffn_down_weight =
            Self::get_tensor_ref(model, data, &format!("{}.ffn_down.weight", prefix))?;
        let ffn_down_bias = model
            .get_tensor_f32(&format!("{}.ffn_down.bias", prefix), data)
            .or_else(|_| model.get_tensor_f32(&format!("{}.mlp.down_proj.bias", prefix), data))
            .ok();

        // FFN gate - SwiGLU models like LLaMA have this
        let ffn_gate_weight =
            Self::get_tensor_ref(model, data, &format!("{}.ffn_gate.weight", prefix)).ok();
        let ffn_gate_bias = model
            .get_tensor_f32(&format!("{}.ffn_gate.bias", prefix), data)
            .ok();

        // FFN norm - LLaMA-style pre-FFN layer norm
        let ffn_norm_weight = model
            .get_tensor_f32(&format!("{}.ffn_norm.weight", prefix), data)
            .ok();
        // GH-278: FFN norm bias — standard GGUF + aprender fallback
        let ffn_norm_bias = model
            .get_tensor_f32(&format!("{}.ffn_norm.bias", prefix), data)
            .or_else(|_| {
                model.get_tensor_f32(&format!("{}.post_attention_layernorm.bias", prefix), data)
            })
            .ok();

        // GH-279: QK norm weights (Qwen3 per-head RMSNorm on Q and K)
        let attn_q_norm_weight = model
            .get_tensor_f32(&format!("{}.attn_q_norm.weight", prefix), data)
            .ok();
        let attn_k_norm_weight = model
            .get_tensor_f32(&format!("{}.attn_k_norm.weight", prefix), data)
            .ok();

        Ok(QuantizedGGUFTransformerLayer {
            attn_norm_weight,
            attn_norm_bias,
            qkv_weight,
            qkv_bias,
            attn_output_weight,
            attn_output_bias,
            ffn_up_weight,
            ffn_up_bias,
            ffn_down_weight,
            ffn_down_bias,
            ffn_gate_weight,
            ffn_gate_bias,
            ffn_norm_weight,
            ffn_norm_bias,
            attn_q_norm_weight,
            attn_k_norm_weight,
        })
    }
}

// Further items for this module live in a sibling source file; `include!`
// splices that file's contents in textually at this point, so they share
// this module's scope and imports.
include!("transformer_quantized_layer_field.rs");