Skip to main content

rlx_qwen35/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Qwen3.5 / Qwen3.6 weight loader.
17//!
18//! Resolves every per-layer tensor named by `llama.cpp`'s
19//! `src/models/qwen35.cpp` (commit referenced in
20//! [`super::SOURCE_REF`]) from a `WeightLoader` (today only
21//! `GgufLoader` — the unsloth/froggeric files don't ship
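//! safetensors). Norms, conv kernels and other small tensors are
//! dequantized to `Vec<f32>`; matmul weights become [`MatWeight`]
//! (eager F32, or packed K-quant metadata when loaded via
//! [`Qwen35Weights::from_loader_packed`]). Apart from rank checks on
//! MoE expert tensors, shapes are not validated against the config here.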
24//!
25//! The resulting [`Qwen35Weights`] holds three groups:
26//!
27//!   - **Embed/output**: `token_embd`, `output_norm`, optional `output`
28//!     (tied to embed if missing).
29//!   - **Trunk layers** `0..num_main_layers`: each is one of
30//!     - [`Qwen35TrunkLayer::Linear`] (gated DeltaNet block) — for
31//!       layers where `(i + 1) % full_attention_interval != 0`.
32//!     - [`Qwen35TrunkLayer::FullAttn`] (standard attention) — for
33//!       layers where `(i + 1) % full_attention_interval == 0`.
34//!   - **MTP layers** `num_main_layers..num_hidden_layers`: full
35//!     attention plus the NextN-specific `eh_proj` / `enorm` / `hnorm`
36//!     / optional `embed_tokens` / `shared_head_*` tensors.
37//!
38//! This struct is the input the future forward-graph builder consumes;
//! it intentionally doesn't depend on `Graph` / `Session` so it can be
//! unit-tested without a runtime (the tests in this file drive it
//! through an in-memory mock `WeightLoader`).
41
42use crate::config::Qwen35Config;
43use anyhow::{Context, Result, anyhow};
44use rlx_core::weight_loader::{GgufLoader, WeightLoader};
45use rlx_ir::quant::QuantScheme;
46
/// Storage variant for matmul weight tensors. The big projections
/// (qkv / gate / ffn / lm_head) dominate the load footprint; the
/// `Packed` variant keeps GGUF K-quant bytes in-place so the graph
/// can emit `Op::DequantMatMul` instead of a full F32 dequant.
///
/// Norm weights, conv kernels, scalar params etc. stay as
/// [`Vec<f32>`] in the layer structs (their footprint is negligible
/// and the `RmsNorm` / `Conv` ops don't have a packed variant).
///
/// Which variant a tensor becomes depends on the loader entry point:
/// [`Qwen35Weights::from_loader`] always produces `F32`, while
/// [`Qwen35Weights::from_loader_packed`] produces `Packed` whenever
/// the GGUF loader reports packed K-quant metadata for the tensor and
/// `F32` otherwise.
#[derive(Debug, Clone)]
pub enum MatWeight {
    /// Already dequantized to f32, row-major `[out, in]`. The
    /// builder transposes to `[in, out]` before issuing `MatMul`.
    ///
    /// Exception: MoE expert tensors are stored already permuted into
    /// the rank-3 GroupedMatMul layout by the loader.
    F32(Vec<f32>),
    /// GGUF-packed K-quant metadata only. The actual bytes are
    /// looked up in the loader at upload time via
    /// [`rlx_core::weight_loader::GgufLoader::tensor_bytes_borrowed`]
    /// — eliminates the per-tensor `Vec<u8>` allocation that
    /// otherwise costs ~16 GB of memcpy on Qwen3.6-27B Q4_K_M.
    ///
    /// `key` is the loader-resolvable name (post-HF↔GGUF mapping);
    /// `shape` is `[out, in]` after the safetensors-style dim
    /// reversal. MoE expert tensors carry a rank-3 shape instead,
    /// with the expert count first.
    Packed {
        key: String,
        scheme: QuantScheme,
        shape: Vec<usize>,
    },
}
75
76impl MatWeight {
77    pub fn len(&self) -> usize {
78        match self {
79            MatWeight::F32(v) => v.len(),
80            MatWeight::Packed { shape, .. } => shape.iter().product(),
81        }
82    }
83    pub fn is_empty(&self) -> bool {
84        self.len() == 0
85    }
86    /// `[out, in]` on-disk shape. For the F32 variant the caller is
87    /// expected to track this externally (we return an empty Vec).
88    pub fn shape(&self) -> &[usize] {
89        match self {
90            MatWeight::F32(_) => &[],
91            MatWeight::Packed { shape, .. } => shape,
92        }
93    }
94    pub fn is_packed(&self) -> bool {
95        matches!(self, MatWeight::Packed { .. })
96    }
97    /// Loader-resolvable key for the packed variant. `None` for F32.
98    pub fn packed_key(&self) -> Option<&str> {
99        match self {
100            MatWeight::F32(_) => None,
101            MatWeight::Packed { key, .. } => Some(key.as_str()),
102        }
103    }
104}
105
/// Per-layer feed-forward: dense SwiGLU or MoE (routed + gated shared expert).
///
/// The loader picks the variant from `cfg.is_moe()`. The choice depends
/// only on the config, so every layer of a given model uses the same
/// variant.
#[derive(Debug, Clone)]
// `Moe` carries many more `MatWeight`s than `Dense`. Presumably the size
// gap is tolerated because there is only one value per decoder layer.
// NOTE(review): consider `Box<Qwen35MoeFfn>` if this enum ever lands in
// a hot collection.
#[allow(clippy::large_enum_variant)]
pub enum Qwen35LayerFfn {
    /// Dense SwiGLU projections from `blk.{il}.ffn_{gate,up,down}.weight`.
    Dense {
        gate: MatWeight,
        up: MatWeight,
        down: MatWeight,
    },
    /// Routed experts plus the gated shared expert; see [`Qwen35MoeFfn`].
    Moe(Qwen35MoeFfn),
}
117
/// MoE FFN tensors for one decoder layer (trunk or MTP).
///
/// Loaded from the `blk.{il}.ffn_*` keys. The routed expert gate/up
/// tensors come either from separate `ffn_gate_exps` / `ffn_up_exps`
/// tensors or, when those are absent, from a fused `ffn_gate_up_exps`
/// tensor that is split during loading.
#[derive(Debug, Clone)]
pub struct Qwen35MoeFfn {
    /// Router logits: `[n_embd, n_expert]` (`ffn_gate_inp`). Loaded like
    /// any other 2-D matmul weight (packed or F32).
    pub router: MatWeight,
    /// Expert gate projections: GroupedMatMul layout `[n_expert, n_embd, n_ff_exp]`.
    pub gate_exps: MatWeight,
    /// Expert up projections (`ffn_up_exps`), loaded the same way as
    /// `gate_exps`.
    pub up_exps: MatWeight,
    /// Expert down projections: `[n_expert, n_ff_exp, n_embd]`.
    pub down_exps: MatWeight,
    /// Shared-expert router weight `[n_embd]` (`ffn_gate_inp_shexp`).
    /// Always eagerly dequantized to f32.
    pub shared_router: Vec<f32>,
    /// Shared-expert gate projection (`ffn_gate_shexp`).
    pub shared_gate: MatWeight,
    /// Shared-expert up projection (`ffn_up_shexp`).
    pub shared_up: MatWeight,
    /// Shared-expert down projection (`ffn_down_shexp`).
    pub shared_down: MatWeight,
}
134
/// One trunk-layer tensor bundle. Either a gated-DeltaNet "linear
/// attention" block or a standard full-attention block.
///
/// The loader chooses per layer index `il`:
/// - `FullAttn` when `(il + 1) % full_attention_interval == 0`, with the
///   interval clamped to at least 1;
/// - `Linear` otherwise.
#[derive(Debug, Clone)]
pub enum Qwen35TrunkLayer {
    Linear(Qwen35LinearLayer),
    FullAttn(Qwen35FullAttnLayer),
}
142
/// Gated DeltaNet ("linear attention") trunk layer. Mirrors
/// `qwen35.cpp::load_block_trunk` for the `is_recurrent(il)` branch.
///
/// Every field is read from a `blk.{il}.*` GGUF key, named in the
/// field docs below. Fields typed `Vec<f32>` are always dequantized
/// eagerly; `MatWeight` fields may stay packed.
#[derive(Debug, Clone)]
pub struct Qwen35LinearLayer {
    /// `[n_embd]` (`attn_norm.weight`).
    pub attn_norm: Vec<f32>,
    /// `[n_embd]` (`post_attention_norm.weight`).
    pub attn_post_norm: Vec<f32>,
    /// Fused `[gate, x, k, B, C]`-style projection:
    /// `[n_embd, 2*key_dim + value_dim]` with `key_dim =
    /// ssm_state*group_count`, `value_dim = ssm_state*dt_rank`.
    /// (`attn_qkv.weight`)
    pub attn_qkv: MatWeight,
    /// `[n_embd, value_dim]` — z gating projection (`attn_gate.weight`).
    pub attn_gate: MatWeight,
    /// Depthwise 1-D conv weights over the fused channels:
    /// `[ssm_conv_kernel, key_dim*2 + value_dim]`. Kept dense —
    /// `Op::Conv` has no packed variant and the conv kernel is
    /// tiny vs the projections. (`ssm_conv1d.weight`)
    pub ssm_conv1d: Vec<f32>,
    /// `[dt_rank]` — delta-t bias (`ssm_dt.bias`).
    pub ssm_dt_bias: Vec<f32>,
    /// `[dt_rank]` — A (no-scan; used directly as scalar gate
    /// multiplier per head). Note the GGUF key has no `.weight`
    /// suffix: `ssm_a`.
    pub ssm_a: Vec<f32>,
    /// `[n_embd, dt_rank]` — β projection (`ssm_beta.weight`).
    pub ssm_beta: MatWeight,
    /// `[n_embd, dt_rank]` — α projection (`ssm_alpha.weight`).
    pub ssm_alpha: MatWeight,
    /// `[ssm_state]` — per-state-row RMS norm gate (`ssm_norm.weight`).
    pub ssm_norm: Vec<f32>,
    /// `[value_dim, n_embd]` — output projection (`ssm_out.weight`).
    pub ssm_out: MatWeight,
    /// Feed-forward block (dense or MoE, per the config).
    pub ffn: Qwen35LayerFfn,
}
177
/// Standard full-attention trunk layer (interspersed every
/// `full_attention_interval` blocks). Per `qwen35.cpp::load_block_trunk`
/// non-recurrent branch.
///
/// Also reused as the attention sub-block of MTP layers. Every field
/// is read from a `blk.{il}.*` GGUF key, named in the field docs below.
#[derive(Debug, Clone)]
pub struct Qwen35FullAttnLayer {
    /// `attn_norm.weight`.
    pub attn_norm: Vec<f32>,
    /// `post_attention_norm.weight`.
    pub attn_post_norm: Vec<f32>,
    /// `[n_embd, n_embd_head_k * n_head * 2]` — joint Q + gate
    /// projection (Qwen3-Next style). Loaded from `attn_q.weight`.
    pub attn_q_gate: MatWeight,
    /// `attn_k.weight`.
    pub attn_k: MatWeight,
    /// `attn_v.weight`.
    pub attn_v: MatWeight,
    /// `attn_output.weight`.
    pub attn_output: MatWeight,
    /// `attn_q_norm.weight`.
    pub attn_q_norm: Vec<f32>,
    /// `attn_k_norm.weight`.
    pub attn_k_norm: Vec<f32>,
    /// Feed-forward block (dense or MoE, per the config).
    pub ffn: Qwen35LayerFfn,
}
195
/// One MTP (NextN) layer. Per `qwen35.cpp::load_block_mtp`.
///
/// NextN-specific tensors are read from `blk.{il}.nextn.*` keys. The
/// optional fields are `None` whenever loading that key fails for any
/// reason; a load error is treated the same as an absent tensor.
#[derive(Debug, Clone)]
pub struct Qwen35MtpLayer {
    /// Base full-attention sub-block (shares the same shapes as
    /// [`Qwen35FullAttnLayer`]).
    pub base: Qwen35FullAttnLayer,
    /// `[2*n_embd, n_embd]` — concatenated [e, h] → hidden projection.
    pub eh_proj: MatWeight,
    /// `[n_embd]`
    pub enorm: Vec<f32>,
    /// `[n_embd]`
    pub hnorm: Vec<f32>,
    /// `[n_embd, n_vocab]` — optional; if absent the MTP head reuses
    /// the trunk's `token_embd`.
    pub embed_tokens: Option<MatWeight>,
    /// `[n_embd, n_vocab]` — optional; if absent the MTP head reuses
    /// the trunk's `output` (or tied `token_embd`).
    pub shared_head_head: Option<MatWeight>,
    /// `[n_embd]` — optional; if absent the MTP head reuses
    /// `output_norm`.
    pub shared_head_norm: Option<Vec<f32>>,
}
218
/// Top-level Qwen3.5 / Qwen3.6 weight bundle.
///
/// Built by [`Qwen35Weights::from_loader`] (all matmul weights F32) or
/// [`Qwen35Weights::from_loader_packed`] (K-quant matmul weights kept
/// packed).
#[derive(Debug, Clone)]
pub struct Qwen35Weights {
    /// `[n_vocab, n_embd]`. Kept as `Vec<f32>` because the embed
    /// table is always materialized for the `Op::Gather` lookup
    /// (no packed-gather kernel today).
    pub token_embd: Vec<f32>,
    /// `[n_embd]`
    pub output_norm: Vec<f32>,
    /// `[n_vocab, n_embd]` — optional; tied to `token_embd` if absent.
    /// May be packed when loaded via `from_loader_packed`.
    pub output: Option<MatWeight>,
    /// Packed K-quant bytes for tied LM head (`token_embd.weight`) when
    /// the GGUF table is quantized. Gather still uses [`Self::token_embd`]
    /// (eager F32); this is only for `DequantMatMul` on the logits path.
    /// Always `None` when built with `from_loader`, and also `None` when
    /// the table is not Q4_K / Q5_K / Q6_K / Q8_K.
    pub token_embd_lm: Option<MatWeight>,
    /// Decoder trunk: one entry per layer index
    /// `0..num_hidden_layers - nextn_predict_layers`, in order.
    pub trunk_layers: Vec<Qwen35TrunkLayer>,
    /// NextN / MTP layers: `nextn_predict_layers` entries for the
    /// trailing layer indices, in order.
    pub mtp_layers: Vec<Qwen35MtpLayer>,
}
238
239impl Qwen35Weights {
240    /// LM head width: tied embeddings use the full embedding table
241    /// (often wider than `cfg.vocab_size` on Qwen3.5 checkpoints).
242    pub fn lm_vocab_size(&self, cfg: &Qwen35Config) -> usize {
243        if self.token_embd.is_empty() || cfg.hidden_size == 0 {
244            return cfg.vocab_size;
245        }
246        self.token_embd.len() / cfg.hidden_size
247    }
248}
249
250impl Qwen35Weights {
251    /// Resolve every named tensor for a Qwen3.5 file. Drains the
252    /// loader's `take()` cache as it goes — the caller should not
253    /// expect to read these tensors back out afterwards. Errors on
254    /// the first missing required tensor with a precise key + reason.
255    ///
256    /// All matmul weights are loaded as `MatWeight::F32` (eager
257    /// dequant). For ≥14 B GGUFs use [`Self::from_loader_packed`]
258    /// to keep K-quant bytes packed in the arena.
259    pub fn from_loader(loader: &mut dyn WeightLoader, cfg: &Qwen35Config) -> Result<Self> {
260        Self::from_loader_inner(loader, cfg, /*pack*/ None)
261    }
262
263    /// Variant of [`Self::from_loader`] that keeps every K-quant
264    /// matmul weight packed (Q4_K / Q5_K / Q6_K / Q8_K) so the
265    /// builder can emit `Op::DequantMatMul`. Non-K-quant tensors
266    /// (F32, F16, BF16, legacy Q4_0/Q5_0/Q8_0) still fall through
267    /// to the dequant-to-F32 path.
268    ///
269    /// Memory savings on Qwen3.6-27B-Q4_K_M: ~65 GB → ~16 GB.
270    pub fn from_loader_packed(loader: &mut GgufLoader, cfg: &Qwen35Config) -> Result<Self> {
271        // Capture the raw pointer first so the &mut borrow that
272        // follows doesn't alias it (Rust's borrow checker rejects
273        // `&mut loader` and `loader as *mut` in the same call).
274        let pack_via = loader as *mut GgufLoader;
275        Self::from_loader_inner(loader, cfg, Some(pack_via))
276    }
277
278    fn from_loader_inner(
279        loader: &mut dyn WeightLoader,
280        cfg: &Qwen35Config,
281        pack_via: Option<*mut GgufLoader>,
282    ) -> Result<Self> {
283        let n_layer = cfg.num_hidden_layers;
284        let nextn = cfg.nextn_predict_layers;
285        if nextn >= n_layer {
286            return Err(anyhow!(
287                "qwen35: nextn_predict_layers={nextn} must be < num_hidden_layers={n_layer}",
288            ));
289        }
290        let n_main = n_layer - nextn;
291        let interval = cfg.full_attention_interval.max(1);
292
293        let token_embd_lm = pack_via.and_then(|p| peek_gguf_packed_mat(p, "token_embd.weight"));
294        let token_embd = take_f32(loader, "token_embd.weight")?;
295        let output_norm = take_f32(loader, "output_norm.weight")?;
296        let output = take_mat(loader, "output.weight", pack_via).ok();
297
298        let mut trunk_layers = Vec::with_capacity(n_main);
299        for il in 0..n_main {
300            let is_full_attn = ((il + 1) % interval) == 0;
301            if is_full_attn {
302                trunk_layers.push(Qwen35TrunkLayer::FullAttn(load_full_attn_layer(
303                    loader, il, cfg, pack_via,
304                )?));
305            } else {
306                trunk_layers.push(Qwen35TrunkLayer::Linear(load_linear_layer(
307                    loader, il, cfg, pack_via,
308                )?));
309            }
310        }
311
312        let mut mtp_layers = Vec::with_capacity(nextn);
313        for il in n_main..n_layer {
314            mtp_layers.push(load_mtp_layer(loader, il, cfg, pack_via)?);
315        }
316
317        Ok(Self {
318            token_embd,
319            output_norm,
320            output,
321            token_embd_lm,
322            trunk_layers,
323            mtp_layers,
324        })
325    }
326}
327
328fn peek_gguf_packed_mat(loader: *mut GgufLoader, key: &str) -> Option<MatWeight> {
329    use rlx_gguf::GgmlType;
330    use rlx_ir::quant::QuantScheme;
331    let g = unsafe { &*loader };
332    let t = g.file().get(key)?;
333    let scheme = match t.dtype {
334        GgmlType::Q4K => QuantScheme::GgufQ4K,
335        GgmlType::Q5K => QuantScheme::GgufQ5K,
336        GgmlType::Q6K => QuantScheme::GgufQ6K,
337        GgmlType::Q8K => QuantScheme::GgufQ8K,
338        _ => return None,
339    };
340    let mut shape = t.shape.clone();
341    shape.reverse();
342    Some(MatWeight::Packed {
343        key: key.to_string(),
344        scheme,
345        shape,
346    })
347}
348
349fn take_f32(loader: &mut dyn WeightLoader, key: &str) -> Result<Vec<f32>> {
350    let (data, _shape) = loader
351        .take(key)
352        .with_context(|| format!("missing tensor: {key}"))?;
353    Ok(data)
354}
355
356/// Take a matmul tensor: if `pack_via` is provided, try the packed
357/// loader first and only fall back to F32 dequant when the source
358/// tensor isn't a K-quant. SAFETY: `pack_via` must point at the
359/// same `GgufLoader` instance backing `loader`; the wrapper exists
360/// purely to thread the concrete-type method through the dyn-trait
361/// API. Constructed by [`Qwen35Weights::from_loader_packed`].
362fn take_mat(
363    loader: &mut dyn WeightLoader,
364    key: &str,
365    pack_via: Option<*mut GgufLoader>,
366) -> Result<MatWeight> {
367    if let Some(p) = pack_via {
368        // SAFETY: `p` was derived from the same `&mut GgufLoader`
369        // the caller already has exclusive access to via `loader`;
370        // we use it only to call `take_packed_metadata`, which
371        // doesn't alias with anything else inside this function.
372        let g: &mut GgufLoader = unsafe { &mut *p };
373        match g.take_packed_metadata(key) {
374            Ok(Some((scheme, shape))) => {
375                return Ok(MatWeight::Packed {
376                    key: key.to_string(),
377                    scheme,
378                    shape,
379                });
380            }
381            Ok(None) => { /* not a K-quant; fall through to F32 */ }
382            Err(_e) => { /* missing or already-taken; F32 will error */ }
383        }
384    }
385    let (data, _shape) = loader
386        .take(key)
387        .with_context(|| format!("missing tensor: {key}"))?;
388    Ok(MatWeight::F32(data))
389}
390
391/// Expert 3-D tensors: try packed K-quant first (native GGML layout,
392/// expert dimension outermost). F32 fallback permutes to `[E, K, N]`.
393fn take_expert_mat(
394    loader: &mut dyn WeightLoader,
395    key: &str,
396    pack_via: Option<*mut GgufLoader>,
397) -> Result<MatWeight> {
398    if let Some(p) = pack_via {
399        let g: &mut GgufLoader = unsafe { &mut *p };
400        if let Ok(Some((scheme, shape))) = g.take_packed_metadata(key) {
401            if shape.len() == 3 {
402                let n_expert = shape[2];
403                return Ok(MatWeight::Packed {
404                    key: key.to_string(),
405                    scheme,
406                    shape: vec![n_expert, shape[0], shape[1]],
407                });
408            }
409        }
410    }
411    let (data, shape) = loader
412        .take(key)
413        .with_context(|| format!("missing MoE tensor: {key}"))?;
414    if shape.len() != 3 {
415        return Err(anyhow!(
416            "MoE tensor {key}: expected rank-3 GGML shape, got {shape:?}"
417        ));
418    }
419    let n_expert = shape[2];
420    let permuted = permute_ggml_expert_to_grouped(&data, shape[0], shape[1], n_expert);
421    Ok(MatWeight::F32(permuted))
422}
423
/// Rearrange a GGML rank-3 expert tensor into row-major
/// `[n_expert, d0, d1]`.
///
/// In GGML order `d0` is the fastest-varying dimension, so element
/// `(i0, i1, e)` lives at `i0 + d0*i1 + d0*d1*e`. Within each expert
/// this transposes the `d1 x d0` plane into `d0 x d1`. The output has
/// the same length as `data`; the caller is expected to pass
/// `data.len() == d0 * d1 * n_expert`.
fn permute_ggml_expert_to_grouped(data: &[f32], d0: usize, d1: usize, n_expert: usize) -> Vec<f32> {
    let plane = d0 * d1;
    let mut grouped = vec![0f32; data.len()];
    for expert in 0..n_expert {
        let base = expert * plane;
        for row in 0..d0 {
            for col in 0..d1 {
                grouped[base + row * d1 + col] = data[base + row + d0 * col];
            }
        }
    }
    grouped
}
437
438fn load_layer_ffn(
439    loader: &mut dyn WeightLoader,
440    il: usize,
441    cfg: &Qwen35Config,
442    pack_via: Option<*mut GgufLoader>,
443) -> Result<Qwen35LayerFfn> {
444    let p = |suffix: &str| format!("blk.{il}.{suffix}");
445    if cfg.is_moe() {
446        Ok(Qwen35LayerFfn::Moe(load_moe_ffn(
447            loader, il, cfg, pack_via,
448        )?))
449    } else {
450        Ok(Qwen35LayerFfn::Dense {
451            gate: take_mat(loader, &p("ffn_gate.weight"), pack_via)?,
452            up: take_mat(loader, &p("ffn_up.weight"), pack_via)?,
453            down: take_mat(loader, &p("ffn_down.weight"), pack_via)?,
454        })
455    }
456}
457
458fn load_moe_ffn(
459    loader: &mut dyn WeightLoader,
460    il: usize,
461    cfg: &Qwen35Config,
462    pack_via: Option<*mut GgufLoader>,
463) -> Result<Qwen35MoeFfn> {
464    let p = |suffix: &str| format!("blk.{il}.{suffix}");
465    let router = take_mat(loader, &p("ffn_gate_inp.weight"), pack_via)?;
466    let down_exps = take_expert_mat(loader, &p("ffn_down_exps.weight"), pack_via)?;
467    let (gate_exps, up_exps) = match (
468        take_expert_mat(loader, &p("ffn_gate_exps.weight"), pack_via),
469        take_expert_mat(loader, &p("ffn_up_exps.weight"), pack_via),
470    ) {
471        (Ok(g), Ok(u)) => (g, u),
472        _ => {
473            let fused = take_expert_mat(loader, &p("ffn_gate_up_exps.weight"), pack_via)?;
474            split_fused_gate_up_exps(fused, cfg)?
475        }
476    };
477    Ok(Qwen35MoeFfn {
478        router,
479        gate_exps,
480        up_exps,
481        down_exps,
482        shared_router: take_f32(loader, &p("ffn_gate_inp_shexp.weight"))?,
483        shared_gate: take_mat(loader, &p("ffn_gate_shexp.weight"), pack_via)?,
484        shared_up: take_mat(loader, &p("ffn_up_shexp.weight"), pack_via)?,
485        shared_down: take_mat(loader, &p("ffn_down_shexp.weight"), pack_via)?,
486    })
487}
488
489/// Split fused `ffn_gate_up_exps` after permute: `[n_expert, 2*n_ff, n_embd]`.
490fn split_fused_gate_up_exps(
491    fused: MatWeight,
492    cfg: &Qwen35Config,
493) -> Result<(MatWeight, MatWeight)> {
494    let MatWeight::F32(data) = fused else {
495        return Err(anyhow!(
496            "fused gate_up_exps must be F32 after take_expert_mat"
497        ));
498    };
499    let n_ff = cfg.expert_ffn_dim();
500    let n_embd = cfg.hidden_size;
501    let n_expert = cfg.num_experts;
502    let expected = 2 * n_ff * n_embd * n_expert;
503    if data.len() != expected {
504        return Err(anyhow!(
505            "fused gate_up_exps: len {} != 2*{n_ff}*{n_embd}*{n_expert}",
506            data.len()
507        ));
508    }
509    let expert_slab = 2 * n_ff * n_embd;
510    let half = n_ff * n_embd;
511    let mut gate = Vec::with_capacity(n_expert * half);
512    let mut up = Vec::with_capacity(n_expert * half);
513    for e in 0..n_expert {
514        let base = e * expert_slab;
515        gate.extend_from_slice(&data[base..base + half]);
516        up.extend_from_slice(&data[base + half..base + expert_slab]);
517    }
518    Ok((MatWeight::F32(gate), MatWeight::F32(up)))
519}
520
521fn load_linear_layer(
522    loader: &mut dyn WeightLoader,
523    il: usize,
524    cfg: &Qwen35Config,
525    pack_via: Option<*mut GgufLoader>,
526) -> Result<Qwen35LinearLayer> {
527    let p = |suffix: &str| format!("blk.{il}.{suffix}");
528    Ok(Qwen35LinearLayer {
529        attn_norm: take_f32(loader, &p("attn_norm.weight"))?,
530        attn_post_norm: take_f32(loader, &p("post_attention_norm.weight"))?,
531        attn_qkv: take_mat(loader, &p("attn_qkv.weight"), pack_via)?,
532        attn_gate: take_mat(loader, &p("attn_gate.weight"), pack_via)?,
533        ssm_conv1d: take_f32(loader, &p("ssm_conv1d.weight"))?,
534        ssm_dt_bias: take_f32(loader, &p("ssm_dt.bias"))?,
535        ssm_a: take_f32(loader, &p("ssm_a"))?,
536        ssm_beta: take_mat(loader, &p("ssm_beta.weight"), pack_via)?,
537        ssm_alpha: take_mat(loader, &p("ssm_alpha.weight"), pack_via)?,
538        ssm_norm: take_f32(loader, &p("ssm_norm.weight"))?,
539        ssm_out: take_mat(loader, &p("ssm_out.weight"), pack_via)?,
540        ffn: load_layer_ffn(loader, il, cfg, pack_via)?,
541    })
542}
543
544fn load_full_attn_layer(
545    loader: &mut dyn WeightLoader,
546    il: usize,
547    cfg: &Qwen35Config,
548    pack_via: Option<*mut GgufLoader>,
549) -> Result<Qwen35FullAttnLayer> {
550    let p = |suffix: &str| format!("blk.{il}.{suffix}");
551    Ok(Qwen35FullAttnLayer {
552        attn_norm: take_f32(loader, &p("attn_norm.weight"))?,
553        attn_post_norm: take_f32(loader, &p("post_attention_norm.weight"))?,
554        attn_q_gate: take_mat(loader, &p("attn_q.weight"), pack_via)?,
555        attn_k: take_mat(loader, &p("attn_k.weight"), pack_via)?,
556        attn_v: take_mat(loader, &p("attn_v.weight"), pack_via)?,
557        attn_output: take_mat(loader, &p("attn_output.weight"), pack_via)?,
558        attn_q_norm: take_f32(loader, &p("attn_q_norm.weight"))?,
559        attn_k_norm: take_f32(loader, &p("attn_k_norm.weight"))?,
560        ffn: load_layer_ffn(loader, il, cfg, pack_via)?,
561    })
562}
563
564fn load_mtp_layer(
565    loader: &mut dyn WeightLoader,
566    il: usize,
567    cfg: &Qwen35Config,
568    pack_via: Option<*mut GgufLoader>,
569) -> Result<Qwen35MtpLayer> {
570    let base = load_full_attn_layer(loader, il, cfg, pack_via)?;
571    let p = |suffix: &str| format!("blk.{il}.nextn.{suffix}");
572    let eh_proj = take_mat(loader, &p("eh_proj.weight"), pack_via)?;
573    let enorm = take_f32(loader, &p("enorm.weight"))?;
574    let hnorm = take_f32(loader, &p("hnorm.weight"))?;
575    let embed_tokens = take_mat(loader, &p("embed_tokens.weight"), pack_via).ok();
576    let shared_head_head = take_mat(loader, &p("shared_head_head.weight"), pack_via).ok();
577    let shared_head_norm = take_f32(loader, &p("shared_head_norm.weight")).ok();
578    Ok(Qwen35MtpLayer {
579        base,
580        eh_proj,
581        enorm,
582        hnorm,
583        embed_tokens,
584        shared_head_head,
585        shared_head_norm,
586    })
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Tiny in-memory `WeightLoader`. `take` removes and returns the
    /// pre-populated `(data, shape)` pair for a key, and errors if the key
    /// was never populated or was already taken. The stored shapes are
    /// placeholders, since the loader under test does not validate
    /// shapes. These tests only verify which keys are requested and which
    /// struct fields they land in.
    struct MockLoader {
        store: HashMap<String, (Vec<f32>, Vec<usize>)>,
    }

    impl WeightLoader for MockLoader {
        fn len(&self) -> usize {
            self.store.len()
        }
        fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            self.store
                .remove(key)
                .ok_or_else(|| anyhow!("mock: missing key {key}"))
        }
        fn take_transposed(&mut self, _key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            unimplemented!("mock loader: not used by qwen35 loader")
        }
        fn remaining_keys(&self) -> Vec<String> {
            self.store.keys().cloned().collect()
        }
    }

    /// Insert a single-element tensor `[marker]` with shape `[1]` under `key`.
    fn populate(store: &mut HashMap<String, (Vec<f32>, Vec<usize>)>, key: &str, marker: f32) {
        store.insert(key.to_string(), (vec![marker], vec![1]));
    }

    /// Build a store holding every tensor that `Qwen35Weights::from_loader`
    /// requires for `cfg` with a dense FFN. It deliberately omits
    /// `output.weight` and the optional MTP shared-head tensors. Markers
    /// encode the layer:
    /// - 10 + il for full-attention trunk layers;
    /// - 100 + il for linear trunk layers;
    /// - 1000 + il for MTP layers.
    fn build_synth_store(cfg: &Qwen35Config) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
        let mut store = HashMap::new();
        populate(&mut store, "token_embd.weight", 1.0);
        populate(&mut store, "output_norm.weight", 2.0);
        // `output.weight` intentionally omitted to exercise the
        // tied-embeddings path.

        let n_main = cfg.num_hidden_layers - cfg.nextn_predict_layers;
        let interval = cfg.full_attention_interval.max(1);

        for il in 0..n_main {
            let is_full_attn = ((il + 1) % interval) == 0;
            let p = |suf: &str| format!("blk.{il}.{suf}");
            if is_full_attn {
                for k in [
                    "attn_norm.weight",
                    "post_attention_norm.weight",
                    "attn_q.weight",
                    "attn_k.weight",
                    "attn_v.weight",
                    "attn_output.weight",
                    "attn_q_norm.weight",
                    "attn_k_norm.weight",
                    "ffn_gate.weight",
                    "ffn_down.weight",
                    "ffn_up.weight",
                ] {
                    populate(&mut store, &p(k), 10.0 + il as f32);
                }
            } else {
                for k in [
                    "attn_norm.weight",
                    "post_attention_norm.weight",
                    "attn_qkv.weight",
                    "attn_gate.weight",
                    "ssm_conv1d.weight",
                    "ssm_dt.bias",
                    "ssm_a",
                    "ssm_beta.weight",
                    "ssm_alpha.weight",
                    "ssm_norm.weight",
                    "ssm_out.weight",
                    "ffn_gate.weight",
                    "ffn_down.weight",
                    "ffn_up.weight",
                ] {
                    populate(&mut store, &p(k), 100.0 + il as f32);
                }
            }
        }

        for il in n_main..cfg.num_hidden_layers {
            let p = |suf: &str| format!("blk.{il}.{suf}");
            for k in [
                "attn_norm.weight",
                "post_attention_norm.weight",
                "attn_q.weight",
                "attn_k.weight",
                "attn_v.weight",
                "attn_output.weight",
                "attn_q_norm.weight",
                "attn_k_norm.weight",
                "ffn_gate.weight",
                "ffn_down.weight",
                "ffn_up.weight",
                "nextn.eh_proj.weight",
                "nextn.enorm.weight",
                "nextn.hnorm.weight",
            ] {
                populate(&mut store, &p(k), 1000.0 + il as f32);
            }
        }
        store
    }

    /// Small Qwen3.5-style config with 6 layers in total: 5 trunk layers
    /// plus 1 MTP layer, and full attention every 4th trunk layer.
    /// Dimensions such as `hidden_size` are nominal. The loader does not
    /// validate shapes, so the single-element synthetic tensors load
    /// regardless. The synthetic store only provides dense FFN keys, so
    /// this relies on `cfg.is_moe()` being false for `num_experts: 0`.
    fn dummy_cfg() -> Qwen35Config {
        Qwen35Config {
            vocab_size: 0,
            hidden_size: 1024,
            intermediate_size: 3584,
            num_hidden_layers: 6,
            nextn_predict_layers: 1,
            num_attention_heads: 16,
            num_key_value_heads: 4,
            key_length: 128,
            value_length: 128,
            max_position_embeddings: 40_960,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000_000.0,
            rope_dim_count: 64,
            rope_dim_sections: vec![],
            full_attention_interval: 4,
            ssm_conv_kernel: 4,
            ssm_group_count: 16,
            ssm_inner_size: 2048,
            ssm_state_size: 128,
            ssm_time_step_rank: 16,
            tie_word_embeddings: true,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }

    /// 6 layers total: a 5-layer trunk plus 1 MTP layer. With interval=4,
    /// trunk layer 3 is full-attn and the others are linear. Verify each
    /// trunk layer is classified correctly and that the MTP block exists
    /// with the NextN tensors loaded.
    #[test]
    fn qwen35_weights_loader_classifies_layers_and_loads_mtp() {
        let cfg = dummy_cfg();
        let mut loader = MockLoader {
            store: build_synth_store(&cfg),
        };
        let w = Qwen35Weights::from_loader(&mut loader, &cfg).expect("load qwen35 weights");

        // 5 trunk layers (6 total - 1 MTP). With interval=4 only il=3
        // ((3+1)%4==0) is full-attn; the other 4 are linear.
        assert_eq!(w.trunk_layers.len(), 5); // num_hidden_layers - nextn = 6 - 1 = 5
        for (i, layer) in w.trunk_layers.iter().enumerate() {
            let want_full = ((i + 1) % 4) == 0;
            match (want_full, layer) {
                (true, Qwen35TrunkLayer::FullAttn(_)) => {}
                (false, Qwen35TrunkLayer::Linear(_)) => {}
                _ => panic!(
                    "layer {i}: want_full={want_full}, got {:?}",
                    std::mem::discriminant(layer)
                ),
            }
        }

        // 1 MTP layer with its required tensors loaded. The optional
        // shared-head tensors are omitted in the synth store, so they
        // must come back as None.
        assert_eq!(w.mtp_layers.len(), 1);
        let mtp = &w.mtp_layers[0];
        // The mock loader returns F32 only (no packed bytes); verify that
        // the synth eh_proj came through as MatWeight::F32.
        assert_eq!(mtp.eh_proj.len(), 1);
        assert!(matches!(mtp.eh_proj, MatWeight::F32(_)));
        assert_eq!(mtp.enorm.len(), 1);
        assert_eq!(mtp.hnorm.len(), 1);
        assert!(mtp.embed_tokens.is_none());
        assert!(mtp.shared_head_head.is_none());
        assert!(mtp.shared_head_norm.is_none());

        // Tied LM head: `output.weight` was intentionally omitted from
        // the synth store, so `output` should be None. The caller is then
        // expected to fall back to `token_embd`.
        assert!(w.output.is_none());
        assert_eq!(w.token_embd.len(), 1);
        assert_eq!(w.output_norm.len(), 1);
    }

    /// Missing required tensor: error mentions the exact key.
    #[test]
    fn qwen35_weights_loader_reports_missing_tensor_key() {
        let cfg = dummy_cfg();
        let mut store = build_synth_store(&cfg);
        store.remove("blk.2.ssm_conv1d.weight");
        let mut loader = MockLoader { store };
        let err = Qwen35Weights::from_loader(&mut loader, &cfg).expect_err("must error");
        let msg = format!("{err:#}");
        assert!(
            msg.contains("blk.2.ssm_conv1d.weight"),
            "error message must point at the missing key: {msg}"
        );
    }

    /// GGML expert tensors `[d0, d1, n_expert]` (d0 fastest-varying) are
    /// transposed per expert into row-major `[n_expert, d0, d1]`.
    #[test]
    fn permute_ggml_expert_to_grouped_transposes_each_expert() {
        let data: Vec<f32> = (0..12).map(|x| x as f32).collect();
        let out = permute_ggml_expert_to_grouped(&data, 2, 3, 2);
        assert_eq!(
            out,
            vec![0.0, 2.0, 4.0, 1.0, 3.0, 5.0, 6.0, 8.0, 10.0, 7.0, 9.0, 11.0]
        );
    }
}