Skip to main content

rlx_qwen3/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-0 Qwen3 flow — native [`ModelFlow`] assembly (QK-norm + GQA + RoPE).
17//!
18//! ```rust,ignore
19//! use rlx_models::qwen3::Qwen3Flow;
20//!
21//! // Prefill hidden states
22//! let built = Qwen3Flow::for_prefill(&cfg, 1, 128).build(&mut weights)?;
23//!
24//! // Decode step with KV side outputs
25//! let built = Qwen3Flow::for_decode(&cfg, 1, 256)
26//!     .custom_mask()
27//!     .build(&mut weights)?;
28//! ```
29
30use anyhow::{Result, anyhow};
31use rlx_flow::blocks::{
32    LmHeadStage, Qwen3DecodeLayerSpec, Qwen3DecoderSpec, RopeTablesStage, qwen3_decode_layer_fused,
33    qwen3_prefill_layer_fused, qwen3_prefill_layer_fused_kv,
34};
35use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
36use rlx_ir::dynamic::sym;
37use rlx_ir::shape::Dim;
38use rlx_ir::{DType, Shape};
39
40use super::config::Qwen3Config;
41use rlx_core::flow_bridge::WeightLoaderSource;
42use rlx_core::weight_loader::WeightLoader;
43
/// Which graph variant [`Qwen3Flow::build`] assembles.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Qwen3Mode {
    /// Dispatches to [`build_qwen3_prefill_built`].
    Prefill,
    /// Dispatches to [`build_qwen3_decode_built`].
    Decode,
}
49
50#[derive(Debug, Clone)]
51pub struct Qwen3PrefillOpts {
52    pub batch: usize,
53    pub seq: usize,
54    pub with_lm_head: bool,
55    pub with_kv_outputs: bool,
56    pub last_logits_only: bool,
57    pub profile: Option<CompileProfile>,
58}
59
60impl Qwen3PrefillOpts {
61    pub fn static_prefill(batch: usize, seq: usize) -> Self {
62        Self {
63            batch,
64            seq,
65            with_lm_head: false,
66            with_kv_outputs: false,
67            last_logits_only: false,
68            profile: None,
69        }
70    }
71}
72
73#[derive(Debug, Clone)]
74pub struct Qwen3DecodeOpts {
75    pub batch: usize,
76    pub past_seq: usize,
77    pub dynamic_past: bool,
78    pub use_custom_mask: bool,
79    pub profile: Option<CompileProfile>,
80}
81
82#[derive(Debug, Clone)]
83pub struct Qwen3Flow<'a> {
84    cfg: &'a Qwen3Config,
85    mode: Qwen3Mode,
86    batch: usize,
87    seq: usize,
88    past_seq: usize,
89    dynamic_past: bool,
90    with_lm_head: bool,
91    with_kv_outputs: bool,
92    last_logits_only: bool,
93    use_custom_mask: bool,
94    profile: Option<CompileProfile>,
95}
96
97impl<'a> Qwen3Flow<'a> {
98    pub fn new(cfg: &'a Qwen3Config) -> Self {
99        Self {
100            cfg,
101            mode: Qwen3Mode::Prefill,
102            batch: 1,
103            seq: 128,
104            past_seq: 0,
105            dynamic_past: false,
106            with_lm_head: false,
107            with_kv_outputs: false,
108            last_logits_only: false,
109            use_custom_mask: false,
110            profile: None,
111        }
112    }
113
114    pub fn for_prefill(cfg: &'a Qwen3Config, batch: usize, seq: usize) -> Self {
115        Self::new(cfg).prefill().batch(batch).seq(seq)
116    }
117
118    pub fn for_decode(cfg: &'a Qwen3Config, batch: usize, past_seq: usize) -> Self {
119        Self::new(cfg)
120            .decode()
121            .batch(batch)
122            .past(past_seq)
123            .lm_head()
124    }
125
126    pub fn prefill(mut self) -> Self {
127        self.mode = Qwen3Mode::Prefill;
128        self
129    }
130
131    pub fn decode(mut self) -> Self {
132        self.mode = Qwen3Mode::Decode;
133        self
134    }
135
136    pub fn batch(mut self, batch: usize) -> Self {
137        self.batch = batch;
138        self
139    }
140
141    pub fn seq(mut self, seq: usize) -> Self {
142        self.seq = seq;
143        self
144    }
145
146    pub fn past(mut self, past_seq: usize) -> Self {
147        self.past_seq = past_seq;
148        self
149    }
150
151    pub fn dynamic_past(mut self) -> Self {
152        self.dynamic_past = true;
153        self
154    }
155
156    pub fn lm_head(mut self) -> Self {
157        self.with_lm_head = true;
158        self
159    }
160
161    pub fn last_token_logits(mut self) -> Self {
162        self.with_lm_head = true;
163        self.last_logits_only = true;
164        self
165    }
166
167    pub fn export_kv(mut self) -> Self {
168        self.with_kv_outputs = true;
169        self
170    }
171
172    pub fn custom_mask(mut self) -> Self {
173        self.use_custom_mask = true;
174        self
175    }
176
177    pub fn profile(mut self, profile: CompileProfile) -> Self {
178        self.profile = Some(profile);
179        self
180    }
181
182    pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
183        match self.mode {
184            Qwen3Mode::Prefill => {
185                build_qwen3_prefill_built(self.cfg, weights, &self.into_prefill_opts())
186            }
187            Qwen3Mode::Decode => {
188                build_qwen3_decode_built(self.cfg, weights, &self.into_decode_opts())
189            }
190        }
191    }
192}
193
194impl Qwen3Flow<'_> {
195    fn into_prefill_opts(self) -> Qwen3PrefillOpts {
196        Qwen3PrefillOpts {
197            batch: self.batch,
198            seq: self.seq,
199            with_lm_head: self.with_lm_head,
200            with_kv_outputs: self.with_kv_outputs,
201            last_logits_only: self.last_logits_only,
202            profile: self.profile,
203        }
204    }
205
206    fn into_decode_opts(self) -> Qwen3DecodeOpts {
207        Qwen3DecodeOpts {
208            batch: self.batch,
209            past_seq: self.past_seq,
210            dynamic_past: self.dynamic_past,
211            use_custom_mask: self.use_custom_mask,
212            profile: self.profile,
213        }
214    }
215}
216
217pub fn build_qwen3_prefill_built(
218    cfg: &Qwen3Config,
219    weights: &mut dyn WeightLoader,
220    opts: &Qwen3PrefillOpts,
221) -> Result<BuiltModel> {
222    validate_cfg(cfg)?;
223
224    let profile = opts
225        .profile
226        .clone()
227        .unwrap_or_else(CompileProfile::llama32_prefill);
228    let f = DType::F32;
229    let h = cfg.hidden_size;
230    let nh = cfg.num_attention_heads;
231    let nkv = cfg.num_key_value_heads;
232    let dh = cfg.head_dim;
233    let eps = cfg.rms_norm_eps as f32;
234    let batch = opts.batch;
235    let seq = opts.seq;
236
237    let hidden_shape = Shape::new(&[batch, seq, h], f);
238    let (cos_data, sin_data) = rope_tables(cfg);
239    let decoder_spec = Qwen3DecoderSpec {
240        num_heads: nh,
241        num_kv_heads: nkv,
242        head_dim: dh,
243        eps,
244        hidden_shape: hidden_shape.clone(),
245        batch,
246        seq,
247        qk_norm: cfg.qk_norm,
248        attention_bias: cfg.attention_bias,
249    };
250
251    let kv_sink = SideOutputs::new();
252
253    let mut flow = ModelFlow::new("qwen3")
254        .with_profile(profile)
255        .input("input_ids", Shape::new(&[batch, seq], DType::F32))
256        .rope_tables(RopeTablesStage::param(
257            cfg.max_position_embeddings,
258            dh / 2,
259            cos_data,
260            sin_data,
261        ))
262        .zero_beta_named("zero_beta", h)
263        .zero_beta_named("zero_beta.head", dh)
264        .token_embed();
265
266    flow = flow.repeat_layers(cfg.num_hidden_layers, {
267        let spec = decoder_spec.clone();
268        let sink = kv_sink.clone();
269        let export = opts.with_kv_outputs;
270        move |i| {
271            if export {
272                qwen3_prefill_layer_fused_kv(i, spec.clone(), sink.inner())
273            } else {
274                qwen3_prefill_layer_fused(i, spec.clone())
275            }
276        }
277    });
278
279    if opts.with_lm_head && opts.last_logits_only {
280        flow = flow.gather_last_token_at(batch, seq);
281    }
282
283    flow = flow.final_norm(eps);
284
285    let mut built = if opts.with_lm_head {
286        flow.raw_stage(qwen3_lm_head_stage(cfg))
287            .output("logits")
288            .build(&mut WeightLoaderSource(weights))?
289    } else {
290        flow.output("hidden_states")
291            .build(&mut WeightLoaderSource(weights))?
292    };
293
294    if opts.with_kv_outputs {
295        built = built.with_extra_hir_outputs(kv_sink.drain());
296    }
297    Ok(built)
298}
299
300pub fn build_qwen3_decode_built(
301    cfg: &Qwen3Config,
302    weights: &mut dyn WeightLoader,
303    opts: &Qwen3DecodeOpts,
304) -> Result<BuiltModel> {
305    validate_cfg(cfg)?;
306
307    let profile = opts
308        .profile
309        .clone()
310        .unwrap_or_else(CompileProfile::llama32_decode);
311    let f = DType::F32;
312    let h = cfg.hidden_size;
313    let nh = cfg.num_attention_heads;
314    let nkv = cfg.num_key_value_heads;
315    let dh = cfg.head_dim;
316    let eps = cfg.rms_norm_eps as f32;
317    let batch = opts.batch;
318    let half = dh / 2;
319    let kv_dim = cfg.kv_proj_dim();
320
321    let hidden_shape = Shape::new(&[batch, 1, h], f);
322    let past_kv_shape = if opts.dynamic_past {
323        Shape::from_dims(
324            &[
325                Dim::Static(batch),
326                Dim::Dynamic(sym::PAST_SEQ),
327                Dim::Static(kv_dim),
328            ],
329            f,
330        )
331    } else {
332        Shape::new(&[batch, opts.past_seq, kv_dim], f)
333    };
334
335    let decode_spec = Qwen3DecodeLayerSpec {
336        num_heads: nh,
337        num_kv_heads: nkv,
338        head_dim: dh,
339        kv_group_size: cfg.kv_group_size(),
340        eps,
341        use_custom_mask: opts.use_custom_mask,
342        hidden_shape: hidden_shape.clone(),
343        batch,
344        qk_norm: cfg.qk_norm,
345        attention_bias: cfg.attention_bias,
346    };
347
348    let kv_out = SideOutputs::new();
349
350    let mut flow = ModelFlow::new("qwen3_decode")
351        .with_profile(profile)
352        .input("input_ids", Shape::new(&[batch, 1], DType::F32))
353        .input("rope_cos", Shape::new(&[1, half], f))
354        .input("rope_sin", Shape::new(&[1, half], f));
355
356    if opts.use_custom_mask {
357        flow = flow.input("mask", Shape::new(&[batch, opts.past_seq + 1], f));
358    }
359
360    for layer_idx in 0..cfg.num_hidden_layers {
361        flow = flow
362            .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
363            .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
364    }
365
366    let built = flow
367        .bind_decode_inputs(cfg.num_hidden_layers, opts.use_custom_mask)
368        .zero_beta_named("zero_beta", h)
369        .zero_beta_named("zero_beta.head", dh)
370        .token_embed()
371        .repeat_layers(cfg.num_hidden_layers, {
372            let spec = decode_spec.clone();
373            let sink = kv_out.clone();
374            move |i| qwen3_decode_layer_fused(i, spec.clone(), sink.inner())
375        })
376        .final_norm(eps)
377        .raw_stage(qwen3_lm_head_stage(cfg))
378        .output("logits")
379        .build(&mut WeightLoaderSource(weights))?
380        .with_extra_hir_outputs(kv_out.drain());
381
382    Ok(built)
383}
384
385pub fn build_qwen3_prefill_flow(
386    cfg: &Qwen3Config,
387    weights: &mut dyn WeightLoader,
388    opts: &Qwen3PrefillOpts,
389) -> Result<(
390    rlx_ir::hir::HirModule,
391    std::collections::HashMap<String, Vec<f32>>,
392)> {
393    build_qwen3_prefill_built(cfg, weights, opts)?.into_parts()
394}
395
396pub fn build_qwen3_decode_flow(
397    cfg: &Qwen3Config,
398    weights: &mut dyn WeightLoader,
399    opts: &Qwen3DecodeOpts,
400) -> Result<(
401    rlx_ir::hir::HirModule,
402    std::collections::HashMap<String, Vec<f32>>,
403)> {
404    build_qwen3_decode_built(cfg, weights, opts)?.into_parts()
405}
406
407pub fn build_qwen3_prefill_graph(
408    cfg: &Qwen3Config,
409    weights: &mut dyn WeightLoader,
410    opts: &Qwen3PrefillOpts,
411) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
412    rlx_core::flow_util::graph_from_built(build_qwen3_prefill_built(cfg, weights, opts)?)
413}
414
415pub fn build_qwen3_decode_graph(
416    cfg: &Qwen3Config,
417    weights: &mut dyn WeightLoader,
418    opts: &Qwen3DecodeOpts,
419) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
420    rlx_core::flow_util::graph_from_built(build_qwen3_decode_built(cfg, weights, opts)?)
421}
422
423pub fn build_qwen3_prefill_hir(
424    cfg: &Qwen3Config,
425    weights: &mut dyn WeightLoader,
426    opts: &Qwen3PrefillOpts,
427) -> Result<(
428    rlx_ir::hir::HirModule,
429    std::collections::HashMap<String, Vec<f32>>,
430)> {
431    build_qwen3_prefill_flow(cfg, weights, opts)
432}
433
434pub fn build_qwen3_decode_hir(
435    cfg: &Qwen3Config,
436    weights: &mut dyn WeightLoader,
437    opts: &Qwen3DecodeOpts,
438) -> Result<(
439    rlx_ir::hir::HirModule,
440    std::collections::HashMap<String, Vec<f32>>,
441)> {
442    build_qwen3_decode_flow(cfg, weights, opts)
443}
444
445fn qwen3_lm_head_stage(cfg: &Qwen3Config) -> FlowStage {
446    if cfg.tie_word_embeddings {
447        FlowStage::LmHead(LmHeadStage {
448            weight_key: None,
449            tie_word_embeddings: true,
450            vocab_size: cfg.vocab_size,
451            hidden_size: cfg.hidden_size,
452            tied_param_name: "qwen3.lm_head.tied_t".into(),
453        })
454    } else {
455        FlowStage::LmHead(LmHeadStage::separate(
456            "lm_head.weight",
457            cfg.vocab_size,
458            cfg.hidden_size,
459        ))
460    }
461}
462
463fn validate_cfg(cfg: &Qwen3Config) -> Result<()> {
464    if !cfg
465        .num_attention_heads
466        .is_multiple_of(cfg.num_key_value_heads)
467    {
468        return Err(anyhow!(
469            "num_attention_heads ({}) must be divisible by num_key_value_heads ({})",
470            cfg.num_attention_heads,
471            cfg.num_key_value_heads
472        ));
473    }
474    // Qwen 2 / 2.5 ship `attention_bias=true` (explicit bias on Q/K/V);
475    // the builder loads + adds the bias vectors. Qwen 3 sets `false`.
476    Ok(())
477}
478
479fn rope_tables(cfg: &Qwen3Config) -> (Vec<f32>, Vec<f32>) {
480    let dh = cfg.head_dim;
481    let half = dh / 2;
482    let mut cos_data = vec![0f32; cfg.max_position_embeddings * half];
483    let mut sin_data = vec![0f32; cfg.max_position_embeddings * half];
484    for pos in 0..cfg.max_position_embeddings {
485        for i in 0..half {
486            let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
487            let angle = pos as f64 * freq;
488            let (s, c) = angle.sin_cos();
489            cos_data[pos * half + i] = c as f32;
490            sin_data[pos * half + i] = s as f32;
491        }
492    }
493    (cos_data, sin_data)
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Smallest config exercised by the smoke tests: one layer, 4 query heads
    /// over 2 KV heads, QK-norm on, untied LM head.
    fn tiny_cfg() -> Qwen3Config {
        Qwen3Config {
            vocab_size: 32,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-6,
            rope_theta: 1_000_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            qk_norm: true,
            sliding_window: None,
            max_window_layers: usize::MAX,
            use_sliding_window: false,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }

    /// All-zero tensors for every weight of a single-layer model, keyed by the
    /// names the builders look up.
    fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
        let hidden = cfg.hidden_size;
        let vocab = cfg.vocab_size;
        let q_width = cfg.q_proj_dim();
        let kv_width = cfg.kv_proj_dim();
        let ffn = cfg.intermediate_size;
        let head_dim = cfg.head_dim;

        let mut tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        // Each tensor is zero-filled with as many elements as its shape holds.
        let mut put = |name: String, shape: Vec<usize>| {
            let len: usize = shape.iter().product();
            tensors.insert(name, (vec![0.0f32; len], shape));
        };

        put("model.embed_tokens.weight".into(), vec![vocab, hidden]);

        let lp = "model.layers.0";
        let layer_tensors = [
            ("input_layernorm.weight", vec![hidden]),
            ("post_attention_layernorm.weight", vec![hidden]),
            ("self_attn.q_proj.weight", vec![q_width, hidden]),
            ("self_attn.k_proj.weight", vec![kv_width, hidden]),
            ("self_attn.v_proj.weight", vec![kv_width, hidden]),
            ("self_attn.o_proj.weight", vec![hidden, q_width]),
            ("self_attn.q_norm.weight", vec![head_dim]),
            ("self_attn.k_norm.weight", vec![head_dim]),
            ("mlp.gate_proj.weight", vec![ffn, hidden]),
            ("mlp.up_proj.weight", vec![ffn, hidden]),
            ("mlp.down_proj.weight", vec![hidden, ffn]),
        ];
        for (suffix, shape) in layer_tensors {
            put(format!("{lp}.{suffix}"), shape);
        }

        put("model.norm.weight".into(), vec![hidden]);
        put("lm_head.weight".into(), vec![vocab, hidden]);

        WeightMap::from_tensors(tensors)
    }

    #[test]
    fn prefill_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let model = Qwen3Flow::for_prefill(&cfg, 1, 4)
            .build(&mut weights)
            .unwrap();
        assert_eq!(model.primary_shape().rank(), 3);
    }

    #[test]
    fn prefill_flow_export_kv() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let flow = Qwen3Flow::for_prefill(&cfg, 1, 4).export_kv();
        let hir = flow.build(&mut weights).unwrap().into_hir().unwrap();
        assert!(hir.outputs.len() >= 3);
    }

    #[test]
    fn decode_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let model = Qwen3Flow::for_decode(&cfg, 1, 4)
            .build(&mut weights)
            .unwrap();
        let hir = model.into_hir().unwrap();
        assert!(hir.outputs.len() >= 3, "logits + new K/V");
    }
}