1use anyhow::{Result, anyhow};
31use rlx_flow::blocks::{
32 LmHeadStage, Qwen3DecodeLayerSpec, Qwen3DecoderSpec, RopeTablesStage, qwen3_decode_layer_fused,
33 qwen3_prefill_layer_fused, qwen3_prefill_layer_fused_kv,
34};
35use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
36use rlx_ir::dynamic::sym;
37use rlx_ir::shape::Dim;
38use rlx_ir::{DType, Shape};
39
40use super::config::Qwen3Config;
41use rlx_core::flow_bridge::WeightLoaderSource;
42use rlx_core::weight_loader::WeightLoader;
43
/// Selects which Qwen3 graph [`Qwen3Flow::build`] assembles.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Qwen3Mode {
    /// Whole-prompt pass over a `[batch, seq]` block of token ids.
    Prefill,
    /// Single-token step that consumes per-layer `past_k_*` / `past_v_*` inputs.
    Decode,
}
49
50#[derive(Debug, Clone)]
51pub struct Qwen3PrefillOpts {
52 pub batch: usize,
53 pub seq: usize,
54 pub with_lm_head: bool,
55 pub with_kv_outputs: bool,
56 pub last_logits_only: bool,
57 pub profile: Option<CompileProfile>,
58}
59
60impl Qwen3PrefillOpts {
61 pub fn static_prefill(batch: usize, seq: usize) -> Self {
62 Self {
63 batch,
64 seq,
65 with_lm_head: false,
66 with_kv_outputs: false,
67 last_logits_only: false,
68 profile: None,
69 }
70 }
71}
72
73#[derive(Debug, Clone)]
74pub struct Qwen3DecodeOpts {
75 pub batch: usize,
76 pub past_seq: usize,
77 pub dynamic_past: bool,
78 pub use_custom_mask: bool,
79 pub profile: Option<CompileProfile>,
80}
81
82#[derive(Debug, Clone)]
83pub struct Qwen3Flow<'a> {
84 cfg: &'a Qwen3Config,
85 mode: Qwen3Mode,
86 batch: usize,
87 seq: usize,
88 past_seq: usize,
89 dynamic_past: bool,
90 with_lm_head: bool,
91 with_kv_outputs: bool,
92 last_logits_only: bool,
93 use_custom_mask: bool,
94 profile: Option<CompileProfile>,
95}
96
97impl<'a> Qwen3Flow<'a> {
98 pub fn new(cfg: &'a Qwen3Config) -> Self {
99 Self {
100 cfg,
101 mode: Qwen3Mode::Prefill,
102 batch: 1,
103 seq: 128,
104 past_seq: 0,
105 dynamic_past: false,
106 with_lm_head: false,
107 with_kv_outputs: false,
108 last_logits_only: false,
109 use_custom_mask: false,
110 profile: None,
111 }
112 }
113
114 pub fn for_prefill(cfg: &'a Qwen3Config, batch: usize, seq: usize) -> Self {
115 Self::new(cfg).prefill().batch(batch).seq(seq)
116 }
117
118 pub fn for_decode(cfg: &'a Qwen3Config, batch: usize, past_seq: usize) -> Self {
119 Self::new(cfg)
120 .decode()
121 .batch(batch)
122 .past(past_seq)
123 .lm_head()
124 }
125
126 pub fn prefill(mut self) -> Self {
127 self.mode = Qwen3Mode::Prefill;
128 self
129 }
130
131 pub fn decode(mut self) -> Self {
132 self.mode = Qwen3Mode::Decode;
133 self
134 }
135
136 pub fn batch(mut self, batch: usize) -> Self {
137 self.batch = batch;
138 self
139 }
140
141 pub fn seq(mut self, seq: usize) -> Self {
142 self.seq = seq;
143 self
144 }
145
146 pub fn past(mut self, past_seq: usize) -> Self {
147 self.past_seq = past_seq;
148 self
149 }
150
151 pub fn dynamic_past(mut self) -> Self {
152 self.dynamic_past = true;
153 self
154 }
155
156 pub fn lm_head(mut self) -> Self {
157 self.with_lm_head = true;
158 self
159 }
160
161 pub fn last_token_logits(mut self) -> Self {
162 self.with_lm_head = true;
163 self.last_logits_only = true;
164 self
165 }
166
167 pub fn export_kv(mut self) -> Self {
168 self.with_kv_outputs = true;
169 self
170 }
171
172 pub fn custom_mask(mut self) -> Self {
173 self.use_custom_mask = true;
174 self
175 }
176
177 pub fn profile(mut self, profile: CompileProfile) -> Self {
178 self.profile = Some(profile);
179 self
180 }
181
182 pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
183 match self.mode {
184 Qwen3Mode::Prefill => {
185 build_qwen3_prefill_built(self.cfg, weights, &self.into_prefill_opts())
186 }
187 Qwen3Mode::Decode => {
188 build_qwen3_decode_built(self.cfg, weights, &self.into_decode_opts())
189 }
190 }
191 }
192}
193
194impl Qwen3Flow<'_> {
195 fn into_prefill_opts(self) -> Qwen3PrefillOpts {
196 Qwen3PrefillOpts {
197 batch: self.batch,
198 seq: self.seq,
199 with_lm_head: self.with_lm_head,
200 with_kv_outputs: self.with_kv_outputs,
201 last_logits_only: self.last_logits_only,
202 profile: self.profile,
203 }
204 }
205
206 fn into_decode_opts(self) -> Qwen3DecodeOpts {
207 Qwen3DecodeOpts {
208 batch: self.batch,
209 past_seq: self.past_seq,
210 dynamic_past: self.dynamic_past,
211 use_custom_mask: self.use_custom_mask,
212 profile: self.profile,
213 }
214 }
215}
216
217pub fn build_qwen3_prefill_built(
218 cfg: &Qwen3Config,
219 weights: &mut dyn WeightLoader,
220 opts: &Qwen3PrefillOpts,
221) -> Result<BuiltModel> {
222 validate_cfg(cfg)?;
223
224 let profile = opts
225 .profile
226 .clone()
227 .unwrap_or_else(CompileProfile::llama32_prefill);
228 let f = DType::F32;
229 let h = cfg.hidden_size;
230 let nh = cfg.num_attention_heads;
231 let nkv = cfg.num_key_value_heads;
232 let dh = cfg.head_dim;
233 let eps = cfg.rms_norm_eps as f32;
234 let batch = opts.batch;
235 let seq = opts.seq;
236
237 let hidden_shape = Shape::new(&[batch, seq, h], f);
238 let (cos_data, sin_data) = rope_tables(cfg);
239 let decoder_spec = Qwen3DecoderSpec {
240 num_heads: nh,
241 num_kv_heads: nkv,
242 head_dim: dh,
243 eps,
244 hidden_shape: hidden_shape.clone(),
245 batch,
246 seq,
247 qk_norm: cfg.qk_norm,
248 attention_bias: cfg.attention_bias,
249 };
250
251 let kv_sink = SideOutputs::new();
252
253 let mut flow = ModelFlow::new("qwen3")
254 .with_profile(profile)
255 .input("input_ids", Shape::new(&[batch, seq], DType::F32))
256 .rope_tables(RopeTablesStage::param(
257 cfg.max_position_embeddings,
258 dh / 2,
259 cos_data,
260 sin_data,
261 ))
262 .zero_beta_named("zero_beta", h)
263 .zero_beta_named("zero_beta.head", dh)
264 .token_embed();
265
266 flow = flow.repeat_layers(cfg.num_hidden_layers, {
267 let spec = decoder_spec.clone();
268 let sink = kv_sink.clone();
269 let export = opts.with_kv_outputs;
270 move |i| {
271 if export {
272 qwen3_prefill_layer_fused_kv(i, spec.clone(), sink.inner())
273 } else {
274 qwen3_prefill_layer_fused(i, spec.clone())
275 }
276 }
277 });
278
279 if opts.with_lm_head && opts.last_logits_only {
280 flow = flow.gather_last_token_at(batch, seq);
281 }
282
283 flow = flow.final_norm(eps);
284
285 let mut built = if opts.with_lm_head {
286 flow.raw_stage(qwen3_lm_head_stage(cfg))
287 .output("logits")
288 .build(&mut WeightLoaderSource(weights))?
289 } else {
290 flow.output("hidden_states")
291 .build(&mut WeightLoaderSource(weights))?
292 };
293
294 if opts.with_kv_outputs {
295 built = built.with_extra_hir_outputs(kv_sink.drain());
296 }
297 Ok(built)
298}
299
300pub fn build_qwen3_decode_built(
301 cfg: &Qwen3Config,
302 weights: &mut dyn WeightLoader,
303 opts: &Qwen3DecodeOpts,
304) -> Result<BuiltModel> {
305 validate_cfg(cfg)?;
306
307 let profile = opts
308 .profile
309 .clone()
310 .unwrap_or_else(CompileProfile::llama32_decode);
311 let f = DType::F32;
312 let h = cfg.hidden_size;
313 let nh = cfg.num_attention_heads;
314 let nkv = cfg.num_key_value_heads;
315 let dh = cfg.head_dim;
316 let eps = cfg.rms_norm_eps as f32;
317 let batch = opts.batch;
318 let half = dh / 2;
319 let kv_dim = cfg.kv_proj_dim();
320
321 let hidden_shape = Shape::new(&[batch, 1, h], f);
322 let past_kv_shape = if opts.dynamic_past {
323 Shape::from_dims(
324 &[
325 Dim::Static(batch),
326 Dim::Dynamic(sym::PAST_SEQ),
327 Dim::Static(kv_dim),
328 ],
329 f,
330 )
331 } else {
332 Shape::new(&[batch, opts.past_seq, kv_dim], f)
333 };
334
335 let decode_spec = Qwen3DecodeLayerSpec {
336 num_heads: nh,
337 num_kv_heads: nkv,
338 head_dim: dh,
339 kv_group_size: cfg.kv_group_size(),
340 eps,
341 use_custom_mask: opts.use_custom_mask,
342 hidden_shape: hidden_shape.clone(),
343 batch,
344 qk_norm: cfg.qk_norm,
345 attention_bias: cfg.attention_bias,
346 };
347
348 let kv_out = SideOutputs::new();
349
350 let mut flow = ModelFlow::new("qwen3_decode")
351 .with_profile(profile)
352 .input("input_ids", Shape::new(&[batch, 1], DType::F32))
353 .input("rope_cos", Shape::new(&[1, half], f))
354 .input("rope_sin", Shape::new(&[1, half], f));
355
356 if opts.use_custom_mask {
357 flow = flow.input("mask", Shape::new(&[batch, opts.past_seq + 1], f));
358 }
359
360 for layer_idx in 0..cfg.num_hidden_layers {
361 flow = flow
362 .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
363 .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
364 }
365
366 let built = flow
367 .bind_decode_inputs(cfg.num_hidden_layers, opts.use_custom_mask)
368 .zero_beta_named("zero_beta", h)
369 .zero_beta_named("zero_beta.head", dh)
370 .token_embed()
371 .repeat_layers(cfg.num_hidden_layers, {
372 let spec = decode_spec.clone();
373 let sink = kv_out.clone();
374 move |i| qwen3_decode_layer_fused(i, spec.clone(), sink.inner())
375 })
376 .final_norm(eps)
377 .raw_stage(qwen3_lm_head_stage(cfg))
378 .output("logits")
379 .build(&mut WeightLoaderSource(weights))?
380 .with_extra_hir_outputs(kv_out.drain());
381
382 Ok(built)
383}
384
385pub fn build_qwen3_prefill_flow(
386 cfg: &Qwen3Config,
387 weights: &mut dyn WeightLoader,
388 opts: &Qwen3PrefillOpts,
389) -> Result<(
390 rlx_ir::hir::HirModule,
391 std::collections::HashMap<String, Vec<f32>>,
392)> {
393 build_qwen3_prefill_built(cfg, weights, opts)?.into_parts()
394}
395
396pub fn build_qwen3_decode_flow(
397 cfg: &Qwen3Config,
398 weights: &mut dyn WeightLoader,
399 opts: &Qwen3DecodeOpts,
400) -> Result<(
401 rlx_ir::hir::HirModule,
402 std::collections::HashMap<String, Vec<f32>>,
403)> {
404 build_qwen3_decode_built(cfg, weights, opts)?.into_parts()
405}
406
407pub fn build_qwen3_prefill_graph(
408 cfg: &Qwen3Config,
409 weights: &mut dyn WeightLoader,
410 opts: &Qwen3PrefillOpts,
411) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
412 rlx_core::flow_util::graph_from_built(build_qwen3_prefill_built(cfg, weights, opts)?)
413}
414
415pub fn build_qwen3_decode_graph(
416 cfg: &Qwen3Config,
417 weights: &mut dyn WeightLoader,
418 opts: &Qwen3DecodeOpts,
419) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
420 rlx_core::flow_util::graph_from_built(build_qwen3_decode_built(cfg, weights, opts)?)
421}
422
423pub fn build_qwen3_prefill_hir(
424 cfg: &Qwen3Config,
425 weights: &mut dyn WeightLoader,
426 opts: &Qwen3PrefillOpts,
427) -> Result<(
428 rlx_ir::hir::HirModule,
429 std::collections::HashMap<String, Vec<f32>>,
430)> {
431 build_qwen3_prefill_flow(cfg, weights, opts)
432}
433
434pub fn build_qwen3_decode_hir(
435 cfg: &Qwen3Config,
436 weights: &mut dyn WeightLoader,
437 opts: &Qwen3DecodeOpts,
438) -> Result<(
439 rlx_ir::hir::HirModule,
440 std::collections::HashMap<String, Vec<f32>>,
441)> {
442 build_qwen3_decode_flow(cfg, weights, opts)
443}
444
445fn qwen3_lm_head_stage(cfg: &Qwen3Config) -> FlowStage {
446 if cfg.tie_word_embeddings {
447 FlowStage::LmHead(LmHeadStage {
448 weight_key: None,
449 tie_word_embeddings: true,
450 vocab_size: cfg.vocab_size,
451 hidden_size: cfg.hidden_size,
452 tied_param_name: "qwen3.lm_head.tied_t".into(),
453 })
454 } else {
455 FlowStage::LmHead(LmHeadStage::separate(
456 "lm_head.weight",
457 cfg.vocab_size,
458 cfg.hidden_size,
459 ))
460 }
461}
462
463fn validate_cfg(cfg: &Qwen3Config) -> Result<()> {
464 if !cfg
465 .num_attention_heads
466 .is_multiple_of(cfg.num_key_value_heads)
467 {
468 return Err(anyhow!(
469 "num_attention_heads ({}) must be divisible by num_key_value_heads ({})",
470 cfg.num_attention_heads,
471 cfg.num_key_value_heads
472 ));
473 }
474 Ok(())
477}
478
479fn rope_tables(cfg: &Qwen3Config) -> (Vec<f32>, Vec<f32>) {
480 let dh = cfg.head_dim;
481 let half = dh / 2;
482 let mut cos_data = vec![0f32; cfg.max_position_embeddings * half];
483 let mut sin_data = vec![0f32; cfg.max_position_embeddings * half];
484 for pos in 0..cfg.max_position_embeddings {
485 for i in 0..half {
486 let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
487 let angle = pos as f64 * freq;
488 let (s, c) = angle.sin_cos();
489 cos_data[pos * half + i] = c as f32;
490 sin_data[pos * half + i] = s as f32;
491 }
492 }
493 (cos_data, sin_data)
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// One-layer Qwen3 config small enough to build quickly: 4 query heads sharing 2 K/V
    /// heads, head_dim 8, QK-norm on, untied LM head.
    fn tiny_cfg() -> Qwen3Config {
        Qwen3Config {
            vocab_size: 32,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-6,
            rope_theta: 1_000_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            qk_norm: true,
            sliding_window: None,
            max_window_layers: usize::MAX,
            use_sliding_window: false,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }

    /// Zero-filled weights for every tensor of a one-layer model, stored as
    /// `(data, shape)` with `data.len()` equal to the product of `shape`.
    fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
        let hidden = cfg.hidden_size;
        let q_dim = cfg.q_proj_dim();
        let kv_dim = cfg.kv_proj_dim();
        let ffn = cfg.intermediate_size;
        let head_dim = cfg.head_dim;
        let vocab = cfg.vocab_size;
        let layer0 = |suffix: &str| format!("model.layers.0.{suffix}");

        let specs: Vec<(String, Vec<usize>)> = vec![
            ("model.embed_tokens.weight".to_string(), vec![vocab, hidden]),
            (layer0("input_layernorm.weight"), vec![hidden]),
            (layer0("post_attention_layernorm.weight"), vec![hidden]),
            (layer0("self_attn.q_proj.weight"), vec![q_dim, hidden]),
            (layer0("self_attn.k_proj.weight"), vec![kv_dim, hidden]),
            (layer0("self_attn.v_proj.weight"), vec![kv_dim, hidden]),
            (layer0("self_attn.o_proj.weight"), vec![hidden, q_dim]),
            (layer0("self_attn.q_norm.weight"), vec![head_dim]),
            (layer0("self_attn.k_norm.weight"), vec![head_dim]),
            (layer0("mlp.gate_proj.weight"), vec![ffn, hidden]),
            (layer0("mlp.up_proj.weight"), vec![ffn, hidden]),
            (layer0("mlp.down_proj.weight"), vec![hidden, ffn]),
            ("model.norm.weight".to_string(), vec![hidden]),
            ("lm_head.weight".to_string(), vec![vocab, hidden]),
        ];

        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(name, shape)| {
                let numel: usize = shape.iter().product();
                (name, (vec![0.0f32; numel], shape))
            })
            .collect();
        WeightMap::from_tensors(tensors)
    }

    #[test]
    fn prefill_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let built = Qwen3Flow::for_prefill(&cfg, 1, 4)
            .build(&mut weights)
            .expect("prefill graph should build");
        assert_eq!(built.primary_shape().rank(), 3);
    }

    #[test]
    fn prefill_flow_export_kv() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let hir = Qwen3Flow::for_prefill(&cfg, 1, 4)
            .export_kv()
            .build(&mut weights)
            .expect("prefill graph with KV export should build")
            .into_hir()
            .expect("prefill graph should lower to HIR");
        assert!(hir.outputs.len() >= 3);
    }

    #[test]
    fn decode_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let hir = Qwen3Flow::for_decode(&cfg, 1, 4)
            .build(&mut weights)
            .expect("decode graph should build")
            .into_hir()
            .expect("decode graph should lower to HIR");
        assert!(hir.outputs.len() >= 3, "logits + new K/V");
    }
}