Skip to main content

rlx_qwen3/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-0 Qwen3 flow — native [`ModelFlow`] assembly (QK-norm + GQA + RoPE).
17//!
18//! ```rust,ignore
19//! use rlx_models::qwen3::Qwen3Flow;
20//!
21//! // Prefill hidden states
22//! let built = Qwen3Flow::for_prefill(&cfg, 1, 128).build(&mut weights)?;
23//!
24//! // Decode step with KV side outputs
25//! let built = Qwen3Flow::for_decode(&cfg, 1, 256)
26//!     .custom_mask()
27//!     .build(&mut weights)?;
28//! ```
29
30use anyhow::{Result, anyhow};
31use rlx_flow::blocks::{
32    LmHeadStage, Qwen3DecodeLayerSpec, Qwen3DecoderSpec, RopeTablesStage, qwen3_decode_layer_fused,
33    qwen3_prefill_layer_fused, qwen3_prefill_layer_fused_kv,
34};
35use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
36use rlx_ir::dynamic::sym;
37use rlx_ir::shape::Dim;
38use rlx_ir::{DType, Shape};
39
40use super::config::Qwen3Config;
41use rlx_core::flow_bridge::WeightLoaderSource;
42use rlx_core::weight_loader::WeightLoader;
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq)]
45pub enum Qwen3Mode {
46    Prefill,
47    Decode,
48}
49
50#[derive(Debug, Clone)]
51pub struct Qwen3PrefillOpts {
52    pub batch: usize,
53    pub seq: usize,
54    pub with_lm_head: bool,
55    pub with_kv_outputs: bool,
56    pub last_logits_only: bool,
57    pub profile: Option<CompileProfile>,
58    /// When set, use these tables (`[seq, head_dim/2]`) instead of config-derived RoPE.
59    pub rope_cos: Option<Vec<f32>>,
60    pub rope_sin: Option<Vec<f32>>,
61}
62
63impl Qwen3PrefillOpts {
64    pub fn static_prefill(batch: usize, seq: usize) -> Self {
65        Self {
66            batch,
67            seq,
68            with_lm_head: false,
69            with_kv_outputs: false,
70            last_logits_only: false,
71            profile: None,
72            rope_cos: None,
73            rope_sin: None,
74        }
75    }
76}
77
78#[derive(Debug, Clone)]
79pub struct Qwen3DecodeOpts {
80    pub batch: usize,
81    pub past_seq: usize,
82    pub dynamic_past: bool,
83    pub use_custom_mask: bool,
84    pub profile: Option<CompileProfile>,
85}
86
87#[derive(Debug, Clone)]
88pub struct Qwen3Flow<'a> {
89    cfg: &'a Qwen3Config,
90    mode: Qwen3Mode,
91    batch: usize,
92    seq: usize,
93    past_seq: usize,
94    dynamic_past: bool,
95    with_lm_head: bool,
96    with_kv_outputs: bool,
97    last_logits_only: bool,
98    use_custom_mask: bool,
99    profile: Option<CompileProfile>,
100}
101
102impl<'a> Qwen3Flow<'a> {
103    pub fn new(cfg: &'a Qwen3Config) -> Self {
104        Self {
105            cfg,
106            mode: Qwen3Mode::Prefill,
107            batch: 1,
108            seq: 128,
109            past_seq: 0,
110            dynamic_past: false,
111            with_lm_head: false,
112            with_kv_outputs: false,
113            last_logits_only: false,
114            use_custom_mask: false,
115            profile: None,
116        }
117    }
118
119    pub fn for_prefill(cfg: &'a Qwen3Config, batch: usize, seq: usize) -> Self {
120        Self::new(cfg).prefill().batch(batch).seq(seq)
121    }
122
123    pub fn for_decode(cfg: &'a Qwen3Config, batch: usize, past_seq: usize) -> Self {
124        Self::new(cfg)
125            .decode()
126            .batch(batch)
127            .past(past_seq)
128            .lm_head()
129    }
130
131    pub fn prefill(mut self) -> Self {
132        self.mode = Qwen3Mode::Prefill;
133        self
134    }
135
136    pub fn decode(mut self) -> Self {
137        self.mode = Qwen3Mode::Decode;
138        self
139    }
140
141    pub fn batch(mut self, batch: usize) -> Self {
142        self.batch = batch;
143        self
144    }
145
146    pub fn seq(mut self, seq: usize) -> Self {
147        self.seq = seq;
148        self
149    }
150
151    pub fn past(mut self, past_seq: usize) -> Self {
152        self.past_seq = past_seq;
153        self
154    }
155
156    pub fn dynamic_past(mut self) -> Self {
157        self.dynamic_past = true;
158        self
159    }
160
161    pub fn lm_head(mut self) -> Self {
162        self.with_lm_head = true;
163        self
164    }
165
166    pub fn last_token_logits(mut self) -> Self {
167        self.with_lm_head = true;
168        self.last_logits_only = true;
169        self
170    }
171
172    pub fn export_kv(mut self) -> Self {
173        self.with_kv_outputs = true;
174        self
175    }
176
177    pub fn custom_mask(mut self) -> Self {
178        self.use_custom_mask = true;
179        self
180    }
181
182    pub fn profile(mut self, profile: CompileProfile) -> Self {
183        self.profile = Some(profile);
184        self
185    }
186
187    pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
188        match self.mode {
189            Qwen3Mode::Prefill => {
190                build_qwen3_prefill_built(self.cfg, weights, &self.into_prefill_opts())
191            }
192            Qwen3Mode::Decode => {
193                build_qwen3_decode_built(self.cfg, weights, &self.into_decode_opts())
194            }
195        }
196    }
197}
198
199impl Qwen3Flow<'_> {
200    fn into_prefill_opts(self) -> Qwen3PrefillOpts {
201        Qwen3PrefillOpts {
202            batch: self.batch,
203            seq: self.seq,
204            with_lm_head: self.with_lm_head,
205            with_kv_outputs: self.with_kv_outputs,
206            last_logits_only: self.last_logits_only,
207            profile: self.profile,
208            rope_cos: None,
209            rope_sin: None,
210        }
211    }
212
213    fn into_decode_opts(self) -> Qwen3DecodeOpts {
214        Qwen3DecodeOpts {
215            batch: self.batch,
216            past_seq: self.past_seq,
217            dynamic_past: self.dynamic_past,
218            use_custom_mask: self.use_custom_mask,
219            profile: self.profile,
220        }
221    }
222}
223
224pub fn build_qwen3_prefill_built(
225    cfg: &Qwen3Config,
226    weights: &mut dyn WeightLoader,
227    opts: &Qwen3PrefillOpts,
228) -> Result<BuiltModel> {
229    validate_cfg(cfg)?;
230
231    let profile = opts
232        .profile
233        .clone()
234        .unwrap_or_else(CompileProfile::llama32_prefill);
235    let f = DType::F32;
236    let h = cfg.hidden_size;
237    let nh = cfg.num_attention_heads;
238    let nkv = cfg.num_key_value_heads;
239    let dh = cfg.head_dim;
240    let eps = cfg.rms_norm_eps as f32;
241    let batch = opts.batch;
242    let seq = opts.seq;
243
244    let hidden_shape = Shape::new(&[batch, seq, h], f);
245    let (cos_data, sin_data) = rope_tables(cfg);
246    let decoder_spec = Qwen3DecoderSpec {
247        num_heads: nh,
248        num_kv_heads: nkv,
249        head_dim: dh,
250        eps,
251        hidden_shape: hidden_shape.clone(),
252        batch,
253        seq,
254        qk_norm: cfg.qk_norm,
255        attention_bias: cfg.attention_bias,
256    };
257
258    let kv_sink = SideOutputs::new();
259
260    let mut flow = ModelFlow::new("qwen3")
261        .with_profile(profile)
262        // input_ids declared I32 — backends convert from f32 host buffers at
263        // the input boundary (Metal arena.write_from_f32, MLX mc::astype).
264        // Declaring F32 here bit-reinterpreted the float bytes as ints inside
265        // the gather kernel, producing garbled-token streams.
266        .input("input_ids", Shape::new(&[batch, seq], DType::I32))
267        .rope_tables(RopeTablesStage::param(
268            cfg.max_position_embeddings,
269            dh / 2,
270            cos_data,
271            sin_data,
272        ))
273        .zero_beta_named("zero_beta", h)
274        .zero_beta_named("zero_beta.head", dh)
275        .token_embed();
276
277    flow = flow.repeat_layers(cfg.num_hidden_layers, {
278        let spec = decoder_spec.clone();
279        let sink = kv_sink.clone();
280        let export = opts.with_kv_outputs;
281        move |i| {
282            if export {
283                qwen3_prefill_layer_fused_kv(i, spec.clone(), sink.inner())
284            } else {
285                qwen3_prefill_layer_fused(i, spec.clone())
286            }
287        }
288    });
289
290    if opts.with_lm_head && opts.last_logits_only {
291        flow = flow.gather_last_token_at(batch, seq);
292    }
293
294    flow = flow.final_norm(eps);
295
296    let mut built = if opts.with_lm_head {
297        flow.raw_stage(qwen3_lm_head_stage(cfg))
298            .output("logits")
299            .build(&mut WeightLoaderSource(weights))?
300    } else {
301        flow.output("hidden_states")
302            .build(&mut WeightLoaderSource(weights))?
303    };
304
305    if opts.with_kv_outputs {
306        built = built.with_extra_hir_outputs(kv_sink.drain());
307    }
308    Ok(built)
309}
310
311pub fn build_qwen3_decode_built(
312    cfg: &Qwen3Config,
313    weights: &mut dyn WeightLoader,
314    opts: &Qwen3DecodeOpts,
315) -> Result<BuiltModel> {
316    validate_cfg(cfg)?;
317
318    let profile = opts
319        .profile
320        .clone()
321        .unwrap_or_else(CompileProfile::llama32_decode);
322    let f = DType::F32;
323    let h = cfg.hidden_size;
324    let nh = cfg.num_attention_heads;
325    let nkv = cfg.num_key_value_heads;
326    let dh = cfg.head_dim;
327    let eps = cfg.rms_norm_eps as f32;
328    let batch = opts.batch;
329    let half = dh / 2;
330    let kv_dim = cfg.kv_proj_dim();
331
332    let hidden_shape = Shape::new(&[batch, 1, h], f);
333    let past_kv_shape = if opts.dynamic_past {
334        Shape::from_dims(
335            &[
336                Dim::Static(batch),
337                Dim::Dynamic(sym::PAST_SEQ),
338                Dim::Static(kv_dim),
339            ],
340            f,
341        )
342    } else {
343        Shape::new(&[batch, opts.past_seq, kv_dim], f)
344    };
345
346    let decode_spec = Qwen3DecodeLayerSpec {
347        num_heads: nh,
348        num_kv_heads: nkv,
349        head_dim: dh,
350        kv_group_size: cfg.kv_group_size(),
351        eps,
352        use_custom_mask: opts.use_custom_mask,
353        hidden_shape: hidden_shape.clone(),
354        batch,
355        qk_norm: cfg.qk_norm,
356        attention_bias: cfg.attention_bias,
357    };
358
359    let kv_out = SideOutputs::new();
360
361    let mut flow = ModelFlow::new("qwen3_decode")
362        .with_profile(profile)
363        .input("input_ids", Shape::new(&[batch, 1], DType::I32))
364        .input("rope_cos", Shape::new(&[1, half], f))
365        .input("rope_sin", Shape::new(&[1, half], f));
366
367    if opts.use_custom_mask {
368        flow = flow.input("mask", Shape::new(&[batch, opts.past_seq + 1], f));
369    }
370
371    for layer_idx in 0..cfg.num_hidden_layers {
372        flow = flow
373            .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
374            .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
375    }
376
377    let built = flow
378        .bind_decode_inputs(cfg.num_hidden_layers, opts.use_custom_mask)
379        .zero_beta_named("zero_beta", h)
380        .zero_beta_named("zero_beta.head", dh)
381        .token_embed()
382        .repeat_layers(cfg.num_hidden_layers, {
383            let spec = decode_spec.clone();
384            let sink = kv_out.clone();
385            move |i| qwen3_decode_layer_fused(i, spec.clone(), sink.inner())
386        })
387        .final_norm(eps)
388        .raw_stage(qwen3_lm_head_stage(cfg))
389        .output("logits")
390        .build(&mut WeightLoaderSource(weights))?
391        .with_extra_hir_outputs(kv_out.drain());
392
393    Ok(built)
394}
395
396/// Decode one step from a precomputed embedding (`[batch, 1, hidden]`).
397/// Outputs `hidden_states` and per-layer K/V (no LM head — use an external codec head for TTS).
398pub fn build_qwen3_decode_embeds_built(
399    cfg: &Qwen3Config,
400    weights: &mut dyn WeightLoader,
401    opts: &Qwen3DecodeOpts,
402) -> Result<BuiltModel> {
403    validate_cfg(cfg)?;
404
405    let profile = opts
406        .profile
407        .clone()
408        .unwrap_or_else(CompileProfile::llama32_decode);
409    let f = DType::F32;
410    let h = cfg.hidden_size;
411    let nh = cfg.num_attention_heads;
412    let nkv = cfg.num_key_value_heads;
413    let dh = cfg.head_dim;
414    let eps = cfg.rms_norm_eps as f32;
415    let batch = opts.batch;
416    let half = dh / 2;
417    let kv_dim = cfg.kv_proj_dim();
418
419    let hidden_shape = Shape::new(&[batch, 1, h], f);
420    let past_kv_shape = if opts.dynamic_past {
421        Shape::from_dims(
422            &[
423                Dim::Static(batch),
424                Dim::Dynamic(sym::PAST_SEQ),
425                Dim::Static(kv_dim),
426            ],
427            f,
428        )
429    } else {
430        Shape::new(&[batch, opts.past_seq, kv_dim], f)
431    };
432
433    let decode_spec = Qwen3DecodeLayerSpec {
434        num_heads: nh,
435        num_kv_heads: nkv,
436        head_dim: dh,
437        kv_group_size: cfg.kv_group_size(),
438        eps,
439        use_custom_mask: opts.use_custom_mask,
440        hidden_shape: hidden_shape.clone(),
441        batch,
442        qk_norm: cfg.qk_norm,
443        attention_bias: cfg.attention_bias,
444    };
445
446    let kv_out = SideOutputs::new();
447
448    let mut flow = ModelFlow::new("qwen3_decode_embeds")
449        .with_profile(profile)
450        .input("inputs_embeds", hidden_shape)
451        .input("rope_cos", Shape::new(&[1, half], f))
452        .input("rope_sin", Shape::new(&[1, half], f));
453
454    if opts.use_custom_mask {
455        flow = flow.input("mask", Shape::new(&[batch, opts.past_seq + 1], f));
456    }
457
458    for layer_idx in 0..cfg.num_hidden_layers {
459        flow = flow
460            .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
461            .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
462    }
463
464    let built = flow
465        .bind_decode_inputs(cfg.num_hidden_layers, opts.use_custom_mask)
466        .zero_beta_named("zero_beta", h)
467        .zero_beta_named("zero_beta.head", dh)
468        .repeat_layers(cfg.num_hidden_layers, {
469            let spec = decode_spec.clone();
470            let sink = kv_out.clone();
471            move |i| qwen3_decode_layer_fused(i, spec.clone(), sink.inner())
472        })
473        .final_norm(eps)
474        .output("hidden_states")
475        .build(&mut WeightLoaderSource(weights))?
476        .with_extra_hir_outputs(kv_out.drain());
477
478    Ok(built)
479}
480
481/// Prefill from `inputs_embeds` (`[batch, seq, hidden]`) with optional K/V export.
482pub fn build_qwen3_prefill_embeds_built(
483    cfg: &Qwen3Config,
484    weights: &mut dyn WeightLoader,
485    opts: &Qwen3PrefillOpts,
486) -> Result<BuiltModel> {
487    validate_cfg(cfg)?;
488
489    let profile = opts
490        .profile
491        .clone()
492        .unwrap_or_else(CompileProfile::llama32_prefill);
493    let f = DType::F32;
494    let h = cfg.hidden_size;
495    let nh = cfg.num_attention_heads;
496    let nkv = cfg.num_key_value_heads;
497    let dh = cfg.head_dim;
498    let eps = cfg.rms_norm_eps as f32;
499    let batch = opts.batch;
500    let seq = opts.seq;
501
502    let hidden_shape = Shape::new(&[batch, seq, h], f);
503    let half = dh / 2;
504    let (cos_data, sin_data) = match (&opts.rope_cos, &opts.rope_sin) {
505        (Some(c), Some(s)) => (c.clone(), s.clone()),
506        _ => rope_tables(cfg),
507    };
508    let rope_max_pos = cfg.max_position_embeddings;
509    let decoder_spec = Qwen3DecoderSpec {
510        num_heads: nh,
511        num_kv_heads: nkv,
512        head_dim: dh,
513        eps,
514        hidden_shape: hidden_shape.clone(),
515        batch,
516        seq,
517        qk_norm: cfg.qk_norm,
518        attention_bias: cfg.attention_bias,
519    };
520
521    let kv_sink = SideOutputs::new();
522
523    let mut flow = ModelFlow::new("qwen3_prefill_embeds")
524        .with_profile(profile)
525        .input("inputs_embeds", hidden_shape)
526        .rope_tables(RopeTablesStage::param(
527            rope_max_pos,
528            half,
529            cos_data,
530            sin_data,
531        ))
532        .zero_beta_named("zero_beta", h)
533        .zero_beta_named("zero_beta.head", dh);
534
535    flow = flow.repeat_layers(cfg.num_hidden_layers, {
536        let spec = decoder_spec.clone();
537        let sink = kv_sink.clone();
538        let export = opts.with_kv_outputs;
539        move |i| {
540            if export {
541                qwen3_prefill_layer_fused_kv(i, spec.clone(), sink.inner())
542            } else {
543                qwen3_prefill_layer_fused(i, spec.clone())
544            }
545        }
546    });
547
548    if opts.last_logits_only {
549        flow = flow.gather_last_token_at(batch, seq);
550    }
551
552    flow = flow.final_norm(eps);
553
554    let mut built = flow
555        .output("hidden_states")
556        .build(&mut WeightLoaderSource(weights))?;
557
558    if opts.with_kv_outputs {
559        built = built.with_extra_hir_outputs(kv_sink.drain());
560    }
561    Ok(built)
562}
563
564pub fn build_qwen3_prefill_flow(
565    cfg: &Qwen3Config,
566    weights: &mut dyn WeightLoader,
567    opts: &Qwen3PrefillOpts,
568) -> Result<(
569    rlx_ir::hir::HirModule,
570    std::collections::HashMap<String, Vec<f32>>,
571)> {
572    build_qwen3_prefill_built(cfg, weights, opts)?.into_parts()
573}
574
575pub fn build_qwen3_decode_flow(
576    cfg: &Qwen3Config,
577    weights: &mut dyn WeightLoader,
578    opts: &Qwen3DecodeOpts,
579) -> Result<(
580    rlx_ir::hir::HirModule,
581    std::collections::HashMap<String, Vec<f32>>,
582)> {
583    build_qwen3_decode_built(cfg, weights, opts)?.into_parts()
584}
585
586pub fn build_qwen3_prefill_graph(
587    cfg: &Qwen3Config,
588    weights: &mut dyn WeightLoader,
589    opts: &Qwen3PrefillOpts,
590) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
591    rlx_core::flow_util::graph_from_built(build_qwen3_prefill_built(cfg, weights, opts)?)
592}
593
594pub fn build_qwen3_decode_graph(
595    cfg: &Qwen3Config,
596    weights: &mut dyn WeightLoader,
597    opts: &Qwen3DecodeOpts,
598) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
599    rlx_core::flow_util::graph_from_built(build_qwen3_decode_built(cfg, weights, opts)?)
600}
601
602pub fn build_qwen3_prefill_hir(
603    cfg: &Qwen3Config,
604    weights: &mut dyn WeightLoader,
605    opts: &Qwen3PrefillOpts,
606) -> Result<(
607    rlx_ir::hir::HirModule,
608    std::collections::HashMap<String, Vec<f32>>,
609)> {
610    build_qwen3_prefill_flow(cfg, weights, opts)
611}
612
613pub fn build_qwen3_decode_hir(
614    cfg: &Qwen3Config,
615    weights: &mut dyn WeightLoader,
616    opts: &Qwen3DecodeOpts,
617) -> Result<(
618    rlx_ir::hir::HirModule,
619    std::collections::HashMap<String, Vec<f32>>,
620)> {
621    build_qwen3_decode_flow(cfg, weights, opts)
622}
623
624fn qwen3_lm_head_stage(cfg: &Qwen3Config) -> FlowStage {
625    if cfg.tie_word_embeddings {
626        FlowStage::LmHead(LmHeadStage {
627            weight_key: None,
628            tie_word_embeddings: true,
629            vocab_size: cfg.vocab_size,
630            hidden_size: cfg.hidden_size,
631            tied_param_name: "qwen3.lm_head.tied_t".into(),
632        })
633    } else {
634        FlowStage::LmHead(LmHeadStage::separate(
635            "lm_head.weight",
636            cfg.vocab_size,
637            cfg.hidden_size,
638        ))
639    }
640}
641
642fn validate_cfg(cfg: &Qwen3Config) -> Result<()> {
643    if !cfg
644        .num_attention_heads
645        .is_multiple_of(cfg.num_key_value_heads)
646    {
647        return Err(anyhow!(
648            "num_attention_heads ({}) must be divisible by num_key_value_heads ({})",
649            cfg.num_attention_heads,
650            cfg.num_key_value_heads
651        ));
652    }
653    // Qwen 2 / 2.5 ship `attention_bias=true` (explicit bias on Q/K/V);
654    // the builder loads + adds the bias vectors. Qwen 3 sets `false`.
655    Ok(())
656}
657
658fn rope_tables(cfg: &Qwen3Config) -> (Vec<f32>, Vec<f32>) {
659    let dh = cfg.head_dim;
660    let half = dh / 2;
661    let mut cos_data = vec![0f32; cfg.max_position_embeddings * half];
662    let mut sin_data = vec![0f32; cfg.max_position_embeddings * half];
663    for pos in 0..cfg.max_position_embeddings {
664        for i in 0..half {
665            let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
666            let angle = pos as f64 * freq;
667            let (s, c) = angle.sin_cos();
668            cos_data[pos * half + i] = c as f32;
669            sin_data[pos * half + i] = s as f32;
670        }
671    }
672    (cos_data, sin_data)
673}
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678    use rlx_core::weight_map::WeightMap;
679    use std::collections::HashMap;
680
681    fn tiny_cfg() -> Qwen3Config {
682        Qwen3Config {
683            vocab_size: 32,
684            hidden_size: 16,
685            intermediate_size: 32,
686            num_hidden_layers: 1,
687            num_attention_heads: 4,
688            num_key_value_heads: 2,
689            head_dim: 8,
690            max_position_embeddings: 16,
691            rms_norm_eps: 1e-6,
692            rope_theta: 1_000_000.0,
693            hidden_act: "silu".into(),
694            tie_word_embeddings: false,
695            attention_bias: false,
696            qk_norm: true,
697            sliding_window: None,
698            max_window_layers: usize::MAX,
699            use_sliding_window: false,
700            num_experts: 0,
701            num_experts_used: 0,
702            expert_ffn_size: 0,
703            shared_expert_ffn_size: 0,
704            expert_weights_scale: 1.0,
705        }
706    }
707
708    fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
709        let h = cfg.hidden_size;
710        let q_dim = cfg.q_proj_dim();
711        let kv_dim = cfg.kv_proj_dim();
712        let int_dim = cfg.intermediate_size;
713        let dh = cfg.head_dim;
714        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
715        let z = |n: usize| vec![0.0f32; n];
716        t.insert(
717            "model.embed_tokens.weight".into(),
718            (z(cfg.vocab_size * h), vec![cfg.vocab_size, h]),
719        );
720        let lp = "model.layers.0";
721        t.insert(format!("{lp}.input_layernorm.weight"), (z(h), vec![h]));
722        t.insert(
723            format!("{lp}.post_attention_layernorm.weight"),
724            (z(h), vec![h]),
725        );
726        t.insert(
727            format!("{lp}.self_attn.q_proj.weight"),
728            (z(q_dim * h), vec![q_dim, h]),
729        );
730        t.insert(
731            format!("{lp}.self_attn.k_proj.weight"),
732            (z(kv_dim * h), vec![kv_dim, h]),
733        );
734        t.insert(
735            format!("{lp}.self_attn.v_proj.weight"),
736            (z(kv_dim * h), vec![kv_dim, h]),
737        );
738        t.insert(
739            format!("{lp}.self_attn.o_proj.weight"),
740            (z(h * q_dim), vec![h, q_dim]),
741        );
742        t.insert(format!("{lp}.self_attn.q_norm.weight"), (z(dh), vec![dh]));
743        t.insert(format!("{lp}.self_attn.k_norm.weight"), (z(dh), vec![dh]));
744        t.insert(
745            format!("{lp}.mlp.gate_proj.weight"),
746            (z(int_dim * h), vec![int_dim, h]),
747        );
748        t.insert(
749            format!("{lp}.mlp.up_proj.weight"),
750            (z(int_dim * h), vec![int_dim, h]),
751        );
752        t.insert(
753            format!("{lp}.mlp.down_proj.weight"),
754            (z(h * int_dim), vec![h, int_dim]),
755        );
756        t.insert("model.norm.weight".into(), (z(h), vec![h]));
757        t.insert(
758            "lm_head.weight".into(),
759            (z(cfg.vocab_size * h), vec![cfg.vocab_size, h]),
760        );
761        WeightMap::from_tensors(t)
762    }
763
764    #[test]
765    fn prefill_flow_builds() {
766        let cfg = tiny_cfg();
767        let mut wm = synthetic_weights(&cfg);
768        let built = Qwen3Flow::for_prefill(&cfg, 1, 4).build(&mut wm).unwrap();
769        assert_eq!(built.primary_shape().rank(), 3);
770    }
771
772    #[test]
773    fn prefill_flow_export_kv() {
774        let cfg = tiny_cfg();
775        let mut wm = synthetic_weights(&cfg);
776        let built = Qwen3Flow::for_prefill(&cfg, 1, 4)
777            .export_kv()
778            .build(&mut wm)
779            .unwrap();
780        let hir = built.into_hir().unwrap();
781        assert!(hir.outputs.len() >= 3);
782    }
783
784    #[test]
785    fn decode_flow_builds() {
786        let cfg = tiny_cfg();
787        let mut wm = synthetic_weights(&cfg);
788        let built = Qwen3Flow::for_decode(&cfg, 1, 4).build(&mut wm).unwrap();
789        let hir = built.into_hir().unwrap();
790        assert!(hir.outputs.len() >= 3, "logits + new K/V");
791    }
792}