1use anyhow::{Result, anyhow};
31use rlx_flow::blocks::{
32 LmHeadStage, Qwen3DecodeLayerSpec, Qwen3DecoderSpec, RopeTablesStage, qwen3_decode_layer_fused,
33 qwen3_prefill_layer_fused, qwen3_prefill_layer_fused_kv,
34};
35use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
36use rlx_ir::dynamic::sym;
37use rlx_ir::shape::Dim;
38use rlx_ir::{DType, Shape};
39
40use super::config::Qwen3Config;
41use rlx_core::flow_bridge::WeightLoaderSource;
42use rlx_core::weight_loader::WeightLoader;
43
/// Selects which graph a `Qwen3Flow` builder produces.
///
/// * `Prefill` — the whole prompt at once (`input_ids` of shape `[batch, seq]`).
/// * `Decode` — one new token per sequence (`input_ids` of shape `[batch, 1]`) run against
///   per-layer past K/V inputs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Qwen3Mode {
    /// Full-sequence pass; built by `build_qwen3_prefill_built`.
    Prefill,
    /// Single-token incremental pass; built by `build_qwen3_decode_built`.
    Decode,
}
49
/// Options for the prefill (full-prompt) graph builders, `build_qwen3_prefill_built` and
/// `build_qwen3_prefill_embeds_built`.
#[derive(Debug, Clone)]
pub struct Qwen3PrefillOpts {
    /// Static batch size baked into the input and hidden-state shapes.
    pub batch: usize,
    /// Static sequence length baked into the input and hidden-state shapes.
    pub seq: usize,
    /// Append the LM-head stage and output `"logits"` instead of `"hidden_states"`.
    /// Only read by the token-id builder; the embeds builder always outputs hidden states.
    pub with_lm_head: bool,
    /// Use the K/V-exporting layer variant and append the per-layer K/V tensors as extra
    /// HIR outputs.
    pub with_kv_outputs: bool,
    /// Keep only the last token (gathered before the final norm). The token-id builder
    /// applies this only together with `with_lm_head`; the embeds builder applies it alone.
    pub last_logits_only: bool,
    /// Compile profile; `None` falls back to `CompileProfile::llama32_prefill`.
    pub profile: Option<CompileProfile>,
    /// Optional precomputed RoPE cosine table; when not supplied the builders compute one
    /// from the config (see `rope_tables`). Presumably laid out like `rope_tables` output:
    /// `max_position_embeddings * (head_dim / 2)` values, position-major — TODO confirm.
    pub rope_cos: Option<Vec<f32>>,
    /// Optional precomputed RoPE sine table, paired with `rope_cos` (same layout).
    pub rope_sin: Option<Vec<f32>>,
}
62
63impl Qwen3PrefillOpts {
64 pub fn static_prefill(batch: usize, seq: usize) -> Self {
65 Self {
66 batch,
67 seq,
68 with_lm_head: false,
69 with_kv_outputs: false,
70 last_logits_only: false,
71 profile: None,
72 rope_cos: None,
73 rope_sin: None,
74 }
75 }
76}
77
/// Options for the single-token decode graph builders, `build_qwen3_decode_built` and
/// `build_qwen3_decode_embeds_built`.
#[derive(Debug, Clone)]
pub struct Qwen3DecodeOpts {
    /// Static batch size baked into every input shape.
    pub batch: usize,
    /// Length of the `past_k_*` / `past_v_*` inputs when `dynamic_past` is false.
    /// Also sizes the optional mask input as `past_seq + 1`.
    pub past_seq: usize,
    /// Give the past K/V inputs the symbolic length `sym::PAST_SEQ` instead of `past_seq`.
    pub dynamic_past: bool,
    /// Add a `"mask"` input of shape `[batch, past_seq + 1]`.
    /// NOTE(review): that width stays static even when `dynamic_past` is set — confirm
    /// this is intended.
    pub use_custom_mask: bool,
    /// Compile profile; `None` falls back to `CompileProfile::llama32_decode`.
    pub profile: Option<CompileProfile>,
}
86
/// Chainable builder for Qwen3 prefill or decode graphs.
///
/// Configure it with the setter methods, then call `build`. Settings that do not apply to
/// the selected mode are ignored: `seq` and `last_logits_only` in decode mode, and
/// `past_seq` / `dynamic_past` / `use_custom_mask` in prefill mode.
#[derive(Debug, Clone)]
pub struct Qwen3Flow<'a> {
    /// Model hyper-parameters shared by every builder.
    cfg: &'a Qwen3Config,
    /// Which graph `build` produces.
    mode: Qwen3Mode,
    /// Static batch size.
    batch: usize,
    /// Prefill sequence length.
    seq: usize,
    /// Static past K/V length for decode; also sizes the custom mask.
    past_seq: usize,
    /// Decode only: use a symbolic past length.
    dynamic_past: bool,
    /// Prefill only: the decode graph always ends in the LM head.
    with_lm_head: bool,
    /// Prefill only: the decode graph always exports the new K/V.
    with_kv_outputs: bool,
    /// Prefill only: keep just the last token before the LM head.
    last_logits_only: bool,
    /// Decode only: add a `"mask"` input.
    use_custom_mask: bool,
    /// Optional compile profile override.
    profile: Option<CompileProfile>,
}
101
102impl<'a> Qwen3Flow<'a> {
103 pub fn new(cfg: &'a Qwen3Config) -> Self {
104 Self {
105 cfg,
106 mode: Qwen3Mode::Prefill,
107 batch: 1,
108 seq: 128,
109 past_seq: 0,
110 dynamic_past: false,
111 with_lm_head: false,
112 with_kv_outputs: false,
113 last_logits_only: false,
114 use_custom_mask: false,
115 profile: None,
116 }
117 }
118
119 pub fn for_prefill(cfg: &'a Qwen3Config, batch: usize, seq: usize) -> Self {
120 Self::new(cfg).prefill().batch(batch).seq(seq)
121 }
122
123 pub fn for_decode(cfg: &'a Qwen3Config, batch: usize, past_seq: usize) -> Self {
124 Self::new(cfg)
125 .decode()
126 .batch(batch)
127 .past(past_seq)
128 .lm_head()
129 }
130
131 pub fn prefill(mut self) -> Self {
132 self.mode = Qwen3Mode::Prefill;
133 self
134 }
135
136 pub fn decode(mut self) -> Self {
137 self.mode = Qwen3Mode::Decode;
138 self
139 }
140
141 pub fn batch(mut self, batch: usize) -> Self {
142 self.batch = batch;
143 self
144 }
145
146 pub fn seq(mut self, seq: usize) -> Self {
147 self.seq = seq;
148 self
149 }
150
151 pub fn past(mut self, past_seq: usize) -> Self {
152 self.past_seq = past_seq;
153 self
154 }
155
156 pub fn dynamic_past(mut self) -> Self {
157 self.dynamic_past = true;
158 self
159 }
160
161 pub fn lm_head(mut self) -> Self {
162 self.with_lm_head = true;
163 self
164 }
165
166 pub fn last_token_logits(mut self) -> Self {
167 self.with_lm_head = true;
168 self.last_logits_only = true;
169 self
170 }
171
172 pub fn export_kv(mut self) -> Self {
173 self.with_kv_outputs = true;
174 self
175 }
176
177 pub fn custom_mask(mut self) -> Self {
178 self.use_custom_mask = true;
179 self
180 }
181
182 pub fn profile(mut self, profile: CompileProfile) -> Self {
183 self.profile = Some(profile);
184 self
185 }
186
187 pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
188 match self.mode {
189 Qwen3Mode::Prefill => {
190 build_qwen3_prefill_built(self.cfg, weights, &self.into_prefill_opts())
191 }
192 Qwen3Mode::Decode => {
193 build_qwen3_decode_built(self.cfg, weights, &self.into_decode_opts())
194 }
195 }
196 }
197}
198
199impl Qwen3Flow<'_> {
200 fn into_prefill_opts(self) -> Qwen3PrefillOpts {
201 Qwen3PrefillOpts {
202 batch: self.batch,
203 seq: self.seq,
204 with_lm_head: self.with_lm_head,
205 with_kv_outputs: self.with_kv_outputs,
206 last_logits_only: self.last_logits_only,
207 profile: self.profile,
208 rope_cos: None,
209 rope_sin: None,
210 }
211 }
212
213 fn into_decode_opts(self) -> Qwen3DecodeOpts {
214 Qwen3DecodeOpts {
215 batch: self.batch,
216 past_seq: self.past_seq,
217 dynamic_past: self.dynamic_past,
218 use_custom_mask: self.use_custom_mask,
219 profile: self.profile,
220 }
221 }
222}
223
224pub fn build_qwen3_prefill_built(
225 cfg: &Qwen3Config,
226 weights: &mut dyn WeightLoader,
227 opts: &Qwen3PrefillOpts,
228) -> Result<BuiltModel> {
229 validate_cfg(cfg)?;
230
231 let profile = opts
232 .profile
233 .clone()
234 .unwrap_or_else(CompileProfile::llama32_prefill);
235 let f = DType::F32;
236 let h = cfg.hidden_size;
237 let nh = cfg.num_attention_heads;
238 let nkv = cfg.num_key_value_heads;
239 let dh = cfg.head_dim;
240 let eps = cfg.rms_norm_eps as f32;
241 let batch = opts.batch;
242 let seq = opts.seq;
243
244 let hidden_shape = Shape::new(&[batch, seq, h], f);
245 let (cos_data, sin_data) = rope_tables(cfg);
246 let decoder_spec = Qwen3DecoderSpec {
247 num_heads: nh,
248 num_kv_heads: nkv,
249 head_dim: dh,
250 eps,
251 hidden_shape: hidden_shape.clone(),
252 batch,
253 seq,
254 qk_norm: cfg.qk_norm,
255 attention_bias: cfg.attention_bias,
256 };
257
258 let kv_sink = SideOutputs::new();
259
260 let mut flow = ModelFlow::new("qwen3")
261 .with_profile(profile)
262 .input("input_ids", Shape::new(&[batch, seq], DType::I32))
267 .rope_tables(RopeTablesStage::param(
268 cfg.max_position_embeddings,
269 dh / 2,
270 cos_data,
271 sin_data,
272 ))
273 .zero_beta_named("zero_beta", h)
274 .zero_beta_named("zero_beta.head", dh)
275 .token_embed();
276
277 flow = flow.repeat_layers(cfg.num_hidden_layers, {
278 let spec = decoder_spec.clone();
279 let sink = kv_sink.clone();
280 let export = opts.with_kv_outputs;
281 move |i| {
282 if export {
283 qwen3_prefill_layer_fused_kv(i, spec.clone(), sink.inner())
284 } else {
285 qwen3_prefill_layer_fused(i, spec.clone())
286 }
287 }
288 });
289
290 if opts.with_lm_head && opts.last_logits_only {
291 flow = flow.gather_last_token_at(batch, seq);
292 }
293
294 flow = flow.final_norm(eps);
295
296 let mut built = if opts.with_lm_head {
297 flow.raw_stage(qwen3_lm_head_stage(cfg))
298 .output("logits")
299 .build(&mut WeightLoaderSource(weights))?
300 } else {
301 flow.output("hidden_states")
302 .build(&mut WeightLoaderSource(weights))?
303 };
304
305 if opts.with_kv_outputs {
306 built = built.with_extra_hir_outputs(kv_sink.drain());
307 }
308 Ok(built)
309}
310
311pub fn build_qwen3_decode_built(
312 cfg: &Qwen3Config,
313 weights: &mut dyn WeightLoader,
314 opts: &Qwen3DecodeOpts,
315) -> Result<BuiltModel> {
316 validate_cfg(cfg)?;
317
318 let profile = opts
319 .profile
320 .clone()
321 .unwrap_or_else(CompileProfile::llama32_decode);
322 let f = DType::F32;
323 let h = cfg.hidden_size;
324 let nh = cfg.num_attention_heads;
325 let nkv = cfg.num_key_value_heads;
326 let dh = cfg.head_dim;
327 let eps = cfg.rms_norm_eps as f32;
328 let batch = opts.batch;
329 let half = dh / 2;
330 let kv_dim = cfg.kv_proj_dim();
331
332 let hidden_shape = Shape::new(&[batch, 1, h], f);
333 let past_kv_shape = if opts.dynamic_past {
334 Shape::from_dims(
335 &[
336 Dim::Static(batch),
337 Dim::Dynamic(sym::PAST_SEQ),
338 Dim::Static(kv_dim),
339 ],
340 f,
341 )
342 } else {
343 Shape::new(&[batch, opts.past_seq, kv_dim], f)
344 };
345
346 let decode_spec = Qwen3DecodeLayerSpec {
347 num_heads: nh,
348 num_kv_heads: nkv,
349 head_dim: dh,
350 kv_group_size: cfg.kv_group_size(),
351 eps,
352 use_custom_mask: opts.use_custom_mask,
353 hidden_shape: hidden_shape.clone(),
354 batch,
355 qk_norm: cfg.qk_norm,
356 attention_bias: cfg.attention_bias,
357 };
358
359 let kv_out = SideOutputs::new();
360
361 let mut flow = ModelFlow::new("qwen3_decode")
362 .with_profile(profile)
363 .input("input_ids", Shape::new(&[batch, 1], DType::I32))
364 .input("rope_cos", Shape::new(&[1, half], f))
365 .input("rope_sin", Shape::new(&[1, half], f));
366
367 if opts.use_custom_mask {
368 flow = flow.input("mask", Shape::new(&[batch, opts.past_seq + 1], f));
369 }
370
371 for layer_idx in 0..cfg.num_hidden_layers {
372 flow = flow
373 .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
374 .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
375 }
376
377 let built = flow
378 .bind_decode_inputs(cfg.num_hidden_layers, opts.use_custom_mask)
379 .zero_beta_named("zero_beta", h)
380 .zero_beta_named("zero_beta.head", dh)
381 .token_embed()
382 .repeat_layers(cfg.num_hidden_layers, {
383 let spec = decode_spec.clone();
384 let sink = kv_out.clone();
385 move |i| qwen3_decode_layer_fused(i, spec.clone(), sink.inner())
386 })
387 .final_norm(eps)
388 .raw_stage(qwen3_lm_head_stage(cfg))
389 .output("logits")
390 .build(&mut WeightLoaderSource(weights))?
391 .with_extra_hir_outputs(kv_out.drain());
392
393 Ok(built)
394}
395
396pub fn build_qwen3_decode_embeds_built(
399 cfg: &Qwen3Config,
400 weights: &mut dyn WeightLoader,
401 opts: &Qwen3DecodeOpts,
402) -> Result<BuiltModel> {
403 validate_cfg(cfg)?;
404
405 let profile = opts
406 .profile
407 .clone()
408 .unwrap_or_else(CompileProfile::llama32_decode);
409 let f = DType::F32;
410 let h = cfg.hidden_size;
411 let nh = cfg.num_attention_heads;
412 let nkv = cfg.num_key_value_heads;
413 let dh = cfg.head_dim;
414 let eps = cfg.rms_norm_eps as f32;
415 let batch = opts.batch;
416 let half = dh / 2;
417 let kv_dim = cfg.kv_proj_dim();
418
419 let hidden_shape = Shape::new(&[batch, 1, h], f);
420 let past_kv_shape = if opts.dynamic_past {
421 Shape::from_dims(
422 &[
423 Dim::Static(batch),
424 Dim::Dynamic(sym::PAST_SEQ),
425 Dim::Static(kv_dim),
426 ],
427 f,
428 )
429 } else {
430 Shape::new(&[batch, opts.past_seq, kv_dim], f)
431 };
432
433 let decode_spec = Qwen3DecodeLayerSpec {
434 num_heads: nh,
435 num_kv_heads: nkv,
436 head_dim: dh,
437 kv_group_size: cfg.kv_group_size(),
438 eps,
439 use_custom_mask: opts.use_custom_mask,
440 hidden_shape: hidden_shape.clone(),
441 batch,
442 qk_norm: cfg.qk_norm,
443 attention_bias: cfg.attention_bias,
444 };
445
446 let kv_out = SideOutputs::new();
447
448 let mut flow = ModelFlow::new("qwen3_decode_embeds")
449 .with_profile(profile)
450 .input("inputs_embeds", hidden_shape)
451 .input("rope_cos", Shape::new(&[1, half], f))
452 .input("rope_sin", Shape::new(&[1, half], f));
453
454 if opts.use_custom_mask {
455 flow = flow.input("mask", Shape::new(&[batch, opts.past_seq + 1], f));
456 }
457
458 for layer_idx in 0..cfg.num_hidden_layers {
459 flow = flow
460 .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
461 .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
462 }
463
464 let built = flow
465 .bind_decode_inputs(cfg.num_hidden_layers, opts.use_custom_mask)
466 .zero_beta_named("zero_beta", h)
467 .zero_beta_named("zero_beta.head", dh)
468 .repeat_layers(cfg.num_hidden_layers, {
469 let spec = decode_spec.clone();
470 let sink = kv_out.clone();
471 move |i| qwen3_decode_layer_fused(i, spec.clone(), sink.inner())
472 })
473 .final_norm(eps)
474 .output("hidden_states")
475 .build(&mut WeightLoaderSource(weights))?
476 .with_extra_hir_outputs(kv_out.drain());
477
478 Ok(built)
479}
480
481pub fn build_qwen3_prefill_embeds_built(
483 cfg: &Qwen3Config,
484 weights: &mut dyn WeightLoader,
485 opts: &Qwen3PrefillOpts,
486) -> Result<BuiltModel> {
487 validate_cfg(cfg)?;
488
489 let profile = opts
490 .profile
491 .clone()
492 .unwrap_or_else(CompileProfile::llama32_prefill);
493 let f = DType::F32;
494 let h = cfg.hidden_size;
495 let nh = cfg.num_attention_heads;
496 let nkv = cfg.num_key_value_heads;
497 let dh = cfg.head_dim;
498 let eps = cfg.rms_norm_eps as f32;
499 let batch = opts.batch;
500 let seq = opts.seq;
501
502 let hidden_shape = Shape::new(&[batch, seq, h], f);
503 let half = dh / 2;
504 let (cos_data, sin_data) = match (&opts.rope_cos, &opts.rope_sin) {
505 (Some(c), Some(s)) => (c.clone(), s.clone()),
506 _ => rope_tables(cfg),
507 };
508 let rope_max_pos = cfg.max_position_embeddings;
509 let decoder_spec = Qwen3DecoderSpec {
510 num_heads: nh,
511 num_kv_heads: nkv,
512 head_dim: dh,
513 eps,
514 hidden_shape: hidden_shape.clone(),
515 batch,
516 seq,
517 qk_norm: cfg.qk_norm,
518 attention_bias: cfg.attention_bias,
519 };
520
521 let kv_sink = SideOutputs::new();
522
523 let mut flow = ModelFlow::new("qwen3_prefill_embeds")
524 .with_profile(profile)
525 .input("inputs_embeds", hidden_shape)
526 .rope_tables(RopeTablesStage::param(
527 rope_max_pos,
528 half,
529 cos_data,
530 sin_data,
531 ))
532 .zero_beta_named("zero_beta", h)
533 .zero_beta_named("zero_beta.head", dh);
534
535 flow = flow.repeat_layers(cfg.num_hidden_layers, {
536 let spec = decoder_spec.clone();
537 let sink = kv_sink.clone();
538 let export = opts.with_kv_outputs;
539 move |i| {
540 if export {
541 qwen3_prefill_layer_fused_kv(i, spec.clone(), sink.inner())
542 } else {
543 qwen3_prefill_layer_fused(i, spec.clone())
544 }
545 }
546 });
547
548 if opts.last_logits_only {
549 flow = flow.gather_last_token_at(batch, seq);
550 }
551
552 flow = flow.final_norm(eps);
553
554 let mut built = flow
555 .output("hidden_states")
556 .build(&mut WeightLoaderSource(weights))?;
557
558 if opts.with_kv_outputs {
559 built = built.with_extra_hir_outputs(kv_sink.drain());
560 }
561 Ok(built)
562}
563
564pub fn build_qwen3_prefill_flow(
565 cfg: &Qwen3Config,
566 weights: &mut dyn WeightLoader,
567 opts: &Qwen3PrefillOpts,
568) -> Result<(
569 rlx_ir::hir::HirModule,
570 std::collections::HashMap<String, Vec<f32>>,
571)> {
572 build_qwen3_prefill_built(cfg, weights, opts)?.into_parts()
573}
574
575pub fn build_qwen3_decode_flow(
576 cfg: &Qwen3Config,
577 weights: &mut dyn WeightLoader,
578 opts: &Qwen3DecodeOpts,
579) -> Result<(
580 rlx_ir::hir::HirModule,
581 std::collections::HashMap<String, Vec<f32>>,
582)> {
583 build_qwen3_decode_built(cfg, weights, opts)?.into_parts()
584}
585
586pub fn build_qwen3_prefill_graph(
587 cfg: &Qwen3Config,
588 weights: &mut dyn WeightLoader,
589 opts: &Qwen3PrefillOpts,
590) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
591 rlx_core::flow_util::graph_from_built(build_qwen3_prefill_built(cfg, weights, opts)?)
592}
593
594pub fn build_qwen3_decode_graph(
595 cfg: &Qwen3Config,
596 weights: &mut dyn WeightLoader,
597 opts: &Qwen3DecodeOpts,
598) -> Result<(rlx_ir::Graph, std::collections::HashMap<String, Vec<f32>>)> {
599 rlx_core::flow_util::graph_from_built(build_qwen3_decode_built(cfg, weights, opts)?)
600}
601
602pub fn build_qwen3_prefill_hir(
603 cfg: &Qwen3Config,
604 weights: &mut dyn WeightLoader,
605 opts: &Qwen3PrefillOpts,
606) -> Result<(
607 rlx_ir::hir::HirModule,
608 std::collections::HashMap<String, Vec<f32>>,
609)> {
610 build_qwen3_prefill_flow(cfg, weights, opts)
611}
612
613pub fn build_qwen3_decode_hir(
614 cfg: &Qwen3Config,
615 weights: &mut dyn WeightLoader,
616 opts: &Qwen3DecodeOpts,
617) -> Result<(
618 rlx_ir::hir::HirModule,
619 std::collections::HashMap<String, Vec<f32>>,
620)> {
621 build_qwen3_decode_flow(cfg, weights, opts)
622}
623
624fn qwen3_lm_head_stage(cfg: &Qwen3Config) -> FlowStage {
625 if cfg.tie_word_embeddings {
626 FlowStage::LmHead(LmHeadStage {
627 weight_key: None,
628 tie_word_embeddings: true,
629 vocab_size: cfg.vocab_size,
630 hidden_size: cfg.hidden_size,
631 tied_param_name: "qwen3.lm_head.tied_t".into(),
632 })
633 } else {
634 FlowStage::LmHead(LmHeadStage::separate(
635 "lm_head.weight",
636 cfg.vocab_size,
637 cfg.hidden_size,
638 ))
639 }
640}
641
642fn validate_cfg(cfg: &Qwen3Config) -> Result<()> {
643 if !cfg
644 .num_attention_heads
645 .is_multiple_of(cfg.num_key_value_heads)
646 {
647 return Err(anyhow!(
648 "num_attention_heads ({}) must be divisible by num_key_value_heads ({})",
649 cfg.num_attention_heads,
650 cfg.num_key_value_heads
651 ));
652 }
653 Ok(())
656}
657
658fn rope_tables(cfg: &Qwen3Config) -> (Vec<f32>, Vec<f32>) {
659 let dh = cfg.head_dim;
660 let half = dh / 2;
661 let mut cos_data = vec![0f32; cfg.max_position_embeddings * half];
662 let mut sin_data = vec![0f32; cfg.max_position_embeddings * half];
663 for pos in 0..cfg.max_position_embeddings {
664 for i in 0..half {
665 let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
666 let angle = pos as f64 * freq;
667 let (s, c) = angle.sin_cos();
668 cos_data[pos * half + i] = c as f32;
669 sin_data[pos * half + i] = s as f32;
670 }
671 }
672 (cos_data, sin_data)
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// One-layer config (4 query heads sharing 2 KV heads, QK-norm on, untied LM head)
    /// that is small enough to build instantly.
    fn tiny_cfg() -> Qwen3Config {
        Qwen3Config {
            vocab_size: 32,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-6,
            rope_theta: 1_000_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            qk_norm: true,
            sliding_window: None,
            max_window_layers: usize::MAX,
            use_sliding_window: false,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }

    /// All-zero weights for every tensor a single-layer `cfg` model reads.
    fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
        let h = cfg.hidden_size;
        let q = cfg.q_proj_dim();
        let kv = cfg.kv_proj_dim();
        let ffn = cfg.intermediate_size;
        let dh = cfg.head_dim;
        let vocab = cfg.vocab_size;
        let layer = "model.layers.0";

        let shapes: Vec<(String, Vec<usize>)> = vec![
            ("model.embed_tokens.weight".into(), vec![vocab, h]),
            (format!("{layer}.input_layernorm.weight"), vec![h]),
            (format!("{layer}.post_attention_layernorm.weight"), vec![h]),
            (format!("{layer}.self_attn.q_proj.weight"), vec![q, h]),
            (format!("{layer}.self_attn.k_proj.weight"), vec![kv, h]),
            (format!("{layer}.self_attn.v_proj.weight"), vec![kv, h]),
            (format!("{layer}.self_attn.o_proj.weight"), vec![h, q]),
            (format!("{layer}.self_attn.q_norm.weight"), vec![dh]),
            (format!("{layer}.self_attn.k_norm.weight"), vec![dh]),
            (format!("{layer}.mlp.gate_proj.weight"), vec![ffn, h]),
            (format!("{layer}.mlp.up_proj.weight"), vec![ffn, h]),
            (format!("{layer}.mlp.down_proj.weight"), vec![h, ffn]),
            ("model.norm.weight".into(), vec![h]),
            ("lm_head.weight".into(), vec![vocab, h]),
        ];

        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = shapes
            .into_iter()
            .map(|(name, shape)| {
                let numel: usize = shape.iter().product();
                (name, (vec![0.0f32; numel], shape))
            })
            .collect();
        WeightMap::from_tensors(tensors)
    }

    #[test]
    fn prefill_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let model = Qwen3Flow::for_prefill(&cfg, 1, 4)
            .build(&mut weights)
            .expect("prefill build");
        // [batch, seq, hidden]
        assert_eq!(model.primary_shape().rank(), 3);
    }

    #[test]
    fn prefill_flow_export_kv() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let model = Qwen3Flow::for_prefill(&cfg, 1, 4)
            .export_kv()
            .build(&mut weights)
            .expect("prefill + kv build");
        let hir = model.into_hir().expect("lower to HIR");
        // hidden_states plus K and V of the single layer.
        assert!(hir.outputs.len() >= 3);
    }

    #[test]
    fn decode_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synthetic_weights(&cfg);
        let model = Qwen3Flow::for_decode(&cfg, 1, 4)
            .build(&mut weights)
            .expect("decode build");
        let hir = model.into_hir().expect("lower to HIR");
        assert!(hir.outputs.len() >= 3, "logits + new K/V");
    }
}