1use super::config::Flux2Config;
19use super::layers::time_guidance_embed;
20use super::packed::{Flux2GgufLinearPacked, Flux2PackedParams, Nvfp4LinearPacked};
21use super::rope::flux2_pos_embed;
22use super::typed_linear::{TypedLinear, TypedLinearStore};
23use super::weights::{
24 Flux2DualAttnWeights, Flux2FeedForwardWeights, Flux2ModulationWeights, Flux2NormOutWeights,
25 Flux2ParallelAttnWeights, Flux2Weights, LinearWeights, RmsNormWeight,
26};
27use crate::builder::Flux2GraphParams;
28use anyhow::Result;
29use rlx_ir::hir::{FusionPolicy, HirModule, HirNodeId};
30use rlx_ir::op::{Activation, BinaryOp, MaskKind};
31use rlx_ir::{DType, Dim, Graph, Op, Shape};
32
33pub type Flux2TypedParams = Vec<(String, Vec<u8>, DType)>;
35
36pub struct Flux2ForwardGraph {
37 pub hir: HirModule,
38 pub params: Flux2GraphParams,
39 pub typed_params: Flux2TypedParams,
40}
41
42pub fn build_flux2_forward_hir(
50 cfg: &Flux2Config,
51 weights: &Flux2Weights,
52 batch: usize,
53 img_seq: usize,
54 txt_seq: usize,
55 img_ids: &[f32],
56 txt_ids: &[f32],
57 packed: Option<&Flux2PackedParams>,
58 typed_linears: Option<&TypedLinearStore>,
59) -> Result<Flux2ForwardGraph> {
60 let mut hir = HirModule::new("flux2_forward").with_fusion_policy(FusionPolicy::Direct);
61 let mut params = Flux2GraphParams::new();
62 let mut typed_params = Flux2TypedParams::new();
63 let mut b = Flux2HirBuilder::new(
64 &mut hir,
65 &mut params,
66 &mut typed_params,
67 cfg,
68 weights,
69 batch,
70 img_seq,
71 txt_seq,
72 packed,
73 typed_linears,
74 );
75 b.build_forward(img_ids, txt_ids)?;
76 Ok(Flux2ForwardGraph {
77 hir,
78 params,
79 typed_params,
80 })
81}
82
83pub fn build_flux2_forward_graph(
84 cfg: &Flux2Config,
85 weights: &Flux2Weights,
86 batch: usize,
87 img_seq: usize,
88 txt_seq: usize,
89 img_ids: &[f32],
90 txt_ids: &[f32],
91) -> Result<(Graph, Flux2GraphParams)> {
92 let built = crate::flow::Flux2Flow::new(cfg, weights)
93 .batch(batch)
94 .img_seq(img_seq)
95 .txt_seq(txt_seq)
96 .position_ids(img_ids.to_vec(), txt_ids.to_vec())
97 .build_forward(img_ids, txt_ids)?;
98 let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
99 Ok((graph, params))
100}
101
102pub fn compile_flux2_forward(
103 cfg: &Flux2Config,
104 weights: &Flux2Weights,
105 batch: usize,
106 img_seq: usize,
107 txt_seq: usize,
108 img_ids: &[f32],
109 txt_ids: &[f32],
110 device: rlx_runtime::Device,
111 packed: Option<&Flux2PackedParams>,
112 typed_linears: Option<&TypedLinearStore>,
113 aot: Option<&rlx_runtime::AotCache>,
114) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
115 use crate::compile_util::{compile_hir_cached, flux2_denoiser_aot_key};
116
117 super::device::assert_flux2_device_available(device)?;
118 let g = build_flux2_forward_hir(
119 cfg,
120 weights,
121 batch,
122 img_seq,
123 txt_seq,
124 img_ids,
125 txt_ids,
126 packed,
127 typed_linears,
128 )?;
129 let key = flux2_denoiser_aot_key(
130 device,
131 batch,
132 img_seq,
133 txt_seq,
134 img_ids,
135 txt_ids,
136 packed.is_some(),
137 );
138 let mut compiled = compile_hir_cached(
139 device,
140 aot,
141 &key,
142 g.hir,
143 &super::compile_util::flux2_compile_profile(),
144 )?;
145 for (name, data) in &g.params {
146 compiled.set_param(name, data);
147 }
148 for (name, data, dtype) in &g.typed_params {
149 compiled.set_param_typed(name, data, *dtype);
150 }
151 Ok((compiled, g.params))
152}
153
154pub fn build_flux2_dual_section_hir(
156 cfg: &Flux2Config,
157 weights: &Flux2Weights,
158 batch: usize,
159 img_seq: usize,
160 txt_seq: usize,
161 img_ids: &[f32],
162 txt_ids: &[f32],
163) -> Result<Flux2ForwardGraph> {
164 let mut hir = HirModule::new("flux2_dual").with_fusion_policy(FusionPolicy::Direct);
165 let mut params = Flux2GraphParams::new();
166 let mut typed_params = Flux2TypedParams::new();
167 let mut b = Flux2HirBuilder::new(
168 &mut hir,
169 &mut params,
170 &mut typed_params,
171 cfg,
172 weights,
173 batch,
174 img_seq,
175 txt_seq,
176 None,
177 None,
178 );
179 let (hidden, _encoder, _cos, _sin, _temb) = b.build_dual_section(img_ids, txt_ids)?;
180 hir.outputs = vec![hidden];
181 Ok(Flux2ForwardGraph {
182 hir,
183 params,
184 typed_params,
185 })
186}
187
188pub(crate) struct Flux2HirBuilder<'a> {
189 hir: &'a mut HirModule,
190 params: &'a mut Flux2GraphParams,
191 typed_params: &'a mut Flux2TypedParams,
192 weights: &'a Flux2Weights,
193 packed: Option<&'a Flux2PackedParams>,
194 typed_linears: Option<&'a TypedLinearStore>,
195 cfg: &'a Flux2Config,
196 batch: usize,
197 img_seq: usize,
198 txt_seq: usize,
199 dim: usize,
200 heads: usize,
201 head_dim: usize,
202 eps: f32,
203 rope_dim: usize,
204 mlp_hidden: usize,
205 f: DType,
206}
207
208pub(crate) type Flux2DoubleMod = (
210 (HirNodeId, HirNodeId, HirNodeId),
211 (HirNodeId, HirNodeId, HirNodeId),
212);
213
214impl<'a> Flux2HirBuilder<'a> {
215 fn new(
216 hir: &'a mut HirModule,
217 params: &'a mut Flux2GraphParams,
218 typed_params: &'a mut Flux2TypedParams,
219 cfg: &'a Flux2Config,
220 weights: &'a Flux2Weights,
221 batch: usize,
222 img_seq: usize,
223 txt_seq: usize,
224 packed: Option<&'a Flux2PackedParams>,
225 typed_linears: Option<&'a TypedLinearStore>,
226 ) -> Self {
227 let dim = cfg.inner_dim();
228 Self {
229 hir,
230 params,
231 typed_params,
232 weights,
233 packed,
234 typed_linears,
235 cfg,
236 batch,
237 img_seq,
238 txt_seq,
239 dim,
240 heads: cfg.num_attention_heads,
241 head_dim: cfg.attention_head_dim,
242 eps: cfg.eps as f32,
243 rope_dim: cfg.axes_dims_rope.iter().sum(),
244 mlp_hidden: cfg.ff_inner_dim(),
245 f: DType::F32,
246 }
247 }
248
249 pub(crate) fn from_emit_parts(
250 hir: &'a mut HirModule,
251 params: &'a mut Flux2GraphParams,
252 typed_params: &'a mut Flux2TypedParams,
253 cfg: &'a Flux2Config,
254 weights: &'a Flux2Weights,
255 batch: usize,
256 img_seq: usize,
257 txt_seq: usize,
258 ) -> Self {
259 Self::new(
260 hir,
261 params,
262 typed_params,
263 cfg,
264 weights,
265 batch,
266 img_seq,
267 txt_seq,
268 None,
269 None,
270 )
271 }
272
273 fn build_dual_section(
274 &mut self,
275 img_ids: &[f32],
276 txt_ids: &[f32],
277 ) -> Result<(HirNodeId, HirNodeId, HirNodeId, HirNodeId, HirNodeId)> {
278 let hidden_in = self.hir.input(
279 "hidden",
280 Shape::new(&[self.batch, self.img_seq, self.cfg.in_channels], self.f),
281 );
282 let enc_in = self.hir.input(
283 "encoder",
284 Shape::new(
285 &[self.batch, self.txt_seq, self.cfg.joint_attention_dim],
286 self.f,
287 ),
288 );
289 let temb_in = self.hir.input("temb", self.b1());
290
291 let mod_img = self.modulation_params(&self.weights.double_mod_img, "mod_img", temb_in)?;
292 let mod_txt = self.modulation_params(&self.weights.double_mod_txt, "mod_txt", temb_in)?;
293
294 let mut hidden = self.linear(
295 hidden_in,
296 &self.weights.x_embedder,
297 "x_embedder",
298 self.b3i(),
299 )?;
300 let mut encoder = self.linear(
301 enc_in,
302 &self.weights.context_embedder,
303 "context_embedder",
304 self.b3t(),
305 )?;
306
307 let (cos_id, sin_id) = self.rope_params(img_ids, txt_ids)?;
308
309 for (li, block) in self.weights.transformer_blocks.iter().enumerate() {
310 (hidden, encoder) = self.emit_dual_stream_block(
311 li, block, hidden, encoder, &mod_img, &mod_txt, cos_id, sin_id,
312 )?;
313 }
314
315 Ok((hidden, encoder, cos_id, sin_id, temb_in))
316 }
317
318 fn build_forward(&mut self, img_ids: &[f32], txt_ids: &[f32]) -> Result<()> {
319 let (hidden, encoder, cos_id, sin_id, temb_in) =
320 self.build_dual_section(img_ids, txt_ids)?;
321 let out = self.emit_single_stream_tail(hidden, encoder, cos_id, sin_id, temb_in)?;
322 self.hir.outputs = vec![out];
323 Ok(())
324 }
325
326 pub(crate) fn emit_single_stream_tail(
328 &mut self,
329 hidden: HirNodeId,
330 encoder: HirNodeId,
331 cos_id: HirNodeId,
332 sin_id: HirNodeId,
333 temb_in: HirNodeId,
334 ) -> Result<HirNodeId> {
335 let single_mod =
336 self.single_modulation_params(&self.weights.single_mod, "mod_single", temb_in)?;
337
338 let stream = self.concat(
339 vec![encoder, hidden],
340 1,
341 self.b3(self.txt_seq + self.img_seq),
342 );
343 let mut stream = stream;
344 for (li, block) in self.weights.single_transformer_blocks.iter().enumerate() {
345 let lp = format!("sblk{li}");
346 let n = self.layer_norm_no_affine(
347 stream,
348 self.b3(self.txt_seq + self.img_seq),
349 &format!("{lp}.n"),
350 )?;
351 let n = self.modulate(n, single_mod.0, single_mod.1, self.txt_seq + self.img_seq);
352 let attn =
353 self.parallel_attention(&block.attn, &format!("{lp}.attn"), n, cos_id, sin_id)?;
354 let attn_g = self.gate(attn, single_mod.2, self.txt_seq + self.img_seq);
355 stream = self.add(stream, attn_g, self.b3(self.txt_seq + self.img_seq));
356 }
357
358 let hidden_out = self.narrow(stream, 1, self.txt_seq, self.img_seq, self.b3i());
359 let normed = self.ada_norm_out(hidden_out, temb_in, &self.weights.norm_out)?;
360 self.linear(normed, &self.weights.proj_out, "proj_out", self.b3o())
361 }
362
363 pub(crate) fn emit_dual_stream_block(
365 &mut self,
366 li: usize,
367 block: &super::weights::Flux2DoubleBlockWeights,
368 hidden: HirNodeId,
369 encoder: HirNodeId,
370 mod_img: &Flux2DoubleMod,
371 mod_txt: &Flux2DoubleMod,
372 cos_id: HirNodeId,
373 sin_id: HirNodeId,
374 ) -> Result<(HirNodeId, HirNodeId)> {
375 let lp = format!("blk{li}");
376 let (img_msa, img_mlp) = mod_img;
377 let (txt_msa, txt_mlp) = mod_txt;
378
379 let n1 = self.layer_norm_no_affine(hidden, self.b3i(), &format!("{lp}.n1"))?;
380 let n1 = self.modulate(n1, img_msa.0, img_msa.1, self.img_seq);
381 let nc = self.layer_norm_no_affine(encoder, self.b3t(), &format!("{lp}.nc"))?;
382 let nc = self.modulate(nc, txt_msa.0, txt_msa.1, self.txt_seq);
383
384 let (enc_a, img_a) =
385 self.dual_attention(&block.attn, &format!("{lp}.attn"), n1, nc, cos_id, sin_id)?;
386 let img_g = self.gate(img_a, img_msa.2, self.img_seq);
387 let hidden = self.add(hidden, img_g, self.b3i());
388 let txt_g = self.gate(enc_a, txt_msa.2, self.txt_seq);
389 let encoder = self.add(encoder, txt_g, self.b3t());
390
391 let n2 = self.layer_norm_no_affine(hidden, self.b3i(), &format!("{lp}.n2"))?;
392 let n2 = self.modulate_scale_shift(n2, img_mlp.1, img_mlp.0, self.img_seq);
393 let ff = self.feed_forward(&block.ff, &format!("{lp}.ff"), n2, self.img_seq)?;
394 let ff_g = self.gate(ff, img_mlp.2, self.img_seq);
395 let hidden = self.add(hidden, ff_g, self.b3i());
396
397 let nc2 = self.layer_norm_no_affine(encoder, self.b3t(), &format!("{lp}.nc2"))?;
398 let nc2 = self.modulate_scale_shift(nc2, txt_mlp.1, txt_mlp.0, self.txt_seq);
399 let ffc = self.feed_forward(&block.ff_context, &format!("{lp}.ffc"), nc2, self.txt_seq)?;
400 let ffc_g = self.gate(ffc, txt_mlp.2, self.txt_seq);
401 let encoder = self.add(encoder, ffc_g, self.b3t());
402 Ok((hidden, encoder))
403 }
404
405 fn b1(&self) -> Shape {
406 Shape::new(&[self.batch, self.dim], self.f)
407 }
408 fn b3i(&self) -> Shape {
409 self.b3(self.img_seq)
410 }
411 fn b3t(&self) -> Shape {
412 self.b3(self.txt_seq)
413 }
414 fn b3o(&self) -> Shape {
415 Shape::new(&[self.batch, self.img_seq, self.cfg.proj_out_dim()], self.f)
416 }
417 fn b3(&self, seq: usize) -> Shape {
418 Shape::new(&[self.batch, seq, self.dim], self.f)
419 }
420
421 fn register_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
422 let id = self.hir.param(name, shape);
423 self.params.insert(name.to_string(), data);
424 id
425 }
426
427 pub(crate) fn linear(
428 &mut self,
429 x: HirNodeId,
430 lw: &LinearWeights,
431 name: &str,
432 out_shape: Shape,
433 ) -> Result<HirNodeId> {
434 if let Some(p) = self.packed.and_then(|m| m.get_nvfp4(name)) {
435 return self.linear_nvfp4(x, p, name, out_shape);
436 }
437 if let Some(p) = self.packed.and_then(|m| m.get_gguf(name)) {
438 return self.linear_gguf(x, p, name, out_shape);
439 }
440 if let Some(tl) = self.typed_linears.and_then(|t| t.get(name)) {
441 return self.linear_typed(x, tl, name, out_shape);
442 }
443 let w = self.register_param(
444 &format!("{name}.weight"),
445 lw.w_t.clone(),
446 Shape::new(&[lw.in_dim, lw.out_dim], self.f),
447 );
448 let bias = if lw.bias.iter().all(|&v| v == 0.0) {
449 None
450 } else {
451 let b = self.register_param(
452 &format!("{name}.bias"),
453 lw.bias.clone(),
454 Shape::new(&[lw.out_dim], self.f),
455 );
456 Some(b)
457 };
458 Ok(self.hir.linear(x, w, bias, None, out_shape))
459 }
460
461 fn linear_typed(
462 &mut self,
463 x: HirNodeId,
464 tl: &TypedLinear,
465 name: &str,
466 out_shape: Shape,
467 ) -> Result<HirNodeId> {
468 let w = self.register_typed_param_shaped(
469 &format!("{name}.weight"),
470 tl.weight_bytes.clone(),
471 tl.dtype,
472 Shape::new(&[tl.in_dim, tl.out_dim], tl.dtype),
473 );
474 let bias = if tl.bias.iter().all(|&v| v == 0.0) {
475 None
476 } else {
477 let b = self.register_param(
478 &format!("{name}.bias"),
479 tl.bias.clone(),
480 Shape::new(&[tl.out_dim], self.f),
481 );
482 Some(b)
483 };
484 Ok(self.hir.linear(x, w, bias, None, out_shape))
485 }
486
487 fn linear_nvfp4(
488 &mut self,
489 x: HirNodeId,
490 p: &Nvfp4LinearPacked,
491 name: &str,
492 out_shape: Shape,
493 ) -> Result<HirNodeId> {
494 use rlx_ir::QuantScheme;
495
496 let w_name = format!("{name}.weight");
497 let s_name = format!("{name}.scale");
498 let gs_name = format!("{name}.global_scale");
499 let w = self.register_typed_param(&w_name, p.w_q.clone(), DType::U8);
500 let scale = self.register_typed_param(&s_name, p.scale.clone(), DType::U8);
501 let gs = self.register_param(&gs_name, vec![p.global_scale], Shape::scalar(self.f));
502 let mut y = self.hir.dequant_matmul(
503 x,
504 w,
505 Some(scale),
506 Some(gs),
507 QuantScheme::Nvfp4Block,
508 out_shape.clone(),
509 );
510 if p.bias.iter().any(|&v| v != 0.0) {
511 let b = self.register_param(
512 &format!("{name}.bias"),
513 p.bias.clone(),
514 Shape::new(&[p.out_dim], self.f),
515 );
516 y = self
517 .hir
518 .mir(Op::Binary(BinaryOp::Add), vec![y, b], out_shape);
519 }
520 Ok(y)
521 }
522
523 fn linear_gguf(
524 &mut self,
525 x: HirNodeId,
526 p: &Flux2GgufLinearPacked,
527 name: &str,
528 out_shape: Shape,
529 ) -> Result<HirNodeId> {
530 let w_name = format!("{name}.weight");
531 let w = self.register_typed_param(&w_name, p.w_q.clone(), DType::U8);
532 let mut y = self
533 .hir
534 .dequant_matmul(x, w, None, None, p.scheme, out_shape.clone());
535 if p.bias.iter().any(|&v| v != 0.0) {
536 let b = self.register_param(
537 &format!("{name}.bias"),
538 p.bias.clone(),
539 Shape::new(&[p.out_dim], self.f),
540 );
541 y = self
542 .hir
543 .mir(Op::Binary(BinaryOp::Add), vec![y, b], out_shape);
544 }
545 Ok(y)
546 }
547
548 fn register_typed_param(&mut self, name: &str, data: Vec<u8>, dtype: DType) -> HirNodeId {
549 let shape = Shape::new(&[data.len()], dtype);
550 let id = self.hir.param(name, shape);
551 self.typed_params.push((name.to_string(), data, dtype));
552 id
553 }
554
555 fn register_typed_param_shaped(
556 &mut self,
557 name: &str,
558 data: Vec<u8>,
559 dtype: DType,
560 shape: Shape,
561 ) -> HirNodeId {
562 let id = self.hir.param(name, shape);
563 self.typed_params.push((name.to_string(), data, dtype));
564 id
565 }
566
567 fn layer_norm_no_affine(&mut self, x: HirNodeId, shape: Shape, tag: &str) -> Result<HirNodeId> {
568 let d = self.dim;
569 let g = self.register_param(
570 &format!("{tag}.ln1"),
571 vec![1.0f32; d],
572 Shape::new(&[d], self.f),
573 );
574 let b = self.register_param(
575 &format!("{tag}.ln0"),
576 vec![0.0f32; d],
577 Shape::new(&[d], self.f),
578 );
579 Ok(self.hir.mir(
580 Op::LayerNorm {
581 axis: -1,
582 eps: self.eps,
583 },
584 vec![x, g, b],
585 shape,
586 ))
587 }
588
589 pub(crate) fn modulation_params(
590 &mut self,
591 m: &Flux2ModulationWeights,
592 tag: &str,
593 temb: HirNodeId,
594 ) -> Result<Flux2DoubleMod> {
595 let h = self
596 .hir
597 .mir(Op::Activation(Activation::Silu), vec![temb], self.b1());
598 let mod_shape = Shape::new(&[self.batch, 6 * self.dim], self.f);
599 let mod_out = self.linear(h, &m.linear, tag, mod_shape)?;
600 let last = self.hir.node(mod_out).shape.rank() - 1;
601 let d = self.dim;
602 let b1 = self.b1();
603 let s0 = self.narrow(mod_out, last, 0, d, b1.clone());
604 let s1 = self.narrow(mod_out, last, d, d, b1.clone());
605 let s2 = self.narrow(mod_out, last, 2 * d, d, b1.clone());
606 let s3 = self.narrow(mod_out, last, 3 * d, d, b1.clone());
607 let s4 = self.narrow(mod_out, last, 4 * d, d, b1.clone());
608 let s5 = self.narrow(mod_out, last, 5 * d, d, b1);
609 Ok(((s0, s1, s2), (s3, s4, s5)))
610 }
611
612 fn single_modulation_params(
613 &mut self,
614 m: &Flux2ModulationWeights,
615 tag: &str,
616 temb: HirNodeId,
617 ) -> Result<(HirNodeId, HirNodeId, HirNodeId)> {
618 let h = self
619 .hir
620 .mir(Op::Activation(Activation::Silu), vec![temb], self.b1());
621 let mod_shape = Shape::new(&[self.batch, 3 * self.dim], self.f);
622 let mod_out = self.linear(h, &m.linear, tag, mod_shape)?;
623 let last = self.hir.node(mod_out).shape.rank() - 1;
624 let d = self.dim;
625 let b1 = self.b1();
626 let s0 = self.narrow(mod_out, last, 0, d, b1.clone());
627 let s1 = self.narrow(mod_out, last, d, d, b1.clone());
628 let s2 = self.narrow(mod_out, last, 2 * d, d, b1);
629 Ok((s0, s1, s2))
630 }
631
632 fn broadcast_bd(&mut self, v: HirNodeId, seq: usize) -> HirNodeId {
633 let b1d = self.reshape(v, vec![self.batch as i64, 1, self.dim as i64]);
634 self.mir_expand(b1d, vec![self.batch as i64, seq as i64, self.dim as i64])
635 }
636
637 fn modulate(
638 &mut self,
639 x: HirNodeId,
640 shift: HirNodeId,
641 scale: HirNodeId,
642 seq: usize,
643 ) -> HirNodeId {
644 let shape = self.b3(seq);
645 let shift_b = self.broadcast_bd(shift, seq);
646 let scale_b = self.broadcast_bd(scale, seq);
647 let ones = self.ones3(seq);
648 let scaled_base = self.add(ones, scale_b, shape.clone());
649 let scaled = self.mul(x, scaled_base, shape.clone());
650 self.add(scaled, shift_b, shape)
651 }
652
653 fn modulate_scale_shift(
654 &mut self,
655 x: HirNodeId,
656 scale: HirNodeId,
657 shift: HirNodeId,
658 seq: usize,
659 ) -> HirNodeId {
660 let shape = self.b3(seq);
661 let shift_b = self.broadcast_bd(shift, seq);
662 let scale_b = self.broadcast_bd(scale, seq);
663 let ones = self.ones3(seq);
664 let scaled_base = self.add(ones, scale_b, shape.clone());
665 let scaled = self.mul(x, scaled_base, shape.clone());
666 self.add(scaled, shift_b, shape)
667 }
668
669 fn gate(&mut self, x: HirNodeId, gate: HirNodeId, seq: usize) -> HirNodeId {
670 let g = self.broadcast_bd(gate, seq);
671 self.mul(x, g, self.b3(seq))
672 }
673
674 fn feed_forward(
675 &mut self,
676 ff: &Flux2FeedForwardWeights,
677 tag: &str,
678 x: HirNodeId,
679 seq: usize,
680 ) -> Result<HirNodeId> {
681 let rows = self.batch * seq;
682 let inner = ff.linear_in.out_dim / 2;
683 let flat = self.reshape(x, vec![rows as i64, self.dim as i64]);
684 let h = self.linear(
685 flat,
686 &ff.linear_in,
687 &format!("{tag}.in"),
688 Shape::new(&[rows, ff.linear_in.out_dim], self.f),
689 )?;
690 let h3 = self.reshape(
691 h,
692 vec![self.batch as i64, seq as i64, ff.linear_in.out_dim as i64],
693 );
694 let act = self.hir.mir(
695 Op::FusedSwiGLU {
696 cast_to: None,
697 gate_first: true,
698 },
699 vec![h3],
700 self.b3(seq).with_dim(2, Dim::Static(inner)),
701 );
702 let act_flat = self.reshape(act, vec![rows as i64, inner as i64]);
703 self.linear(
704 act_flat,
705 &ff.linear_out,
706 &format!("{tag}.out"),
707 Shape::new(&[rows, self.dim], self.f),
708 )
709 .map(|o| self.reshape(o, vec![self.batch as i64, seq as i64, self.dim as i64]))
710 }
711
712 fn rms_gamma(&mut self, rms: &RmsNormWeight, name: &str) -> HirNodeId {
713 let mut g = vec![0.0f32; self.dim];
714 for h in 0..self.heads {
715 g[h * self.head_dim..(h + 1) * self.head_dim].copy_from_slice(&rms.scale);
716 }
717 self.register_param(name, g, Shape::new(&[self.dim], self.f))
718 }
719
720 fn rms_norm(&mut self, x: HirNodeId, gamma: HirNodeId, shape: Shape) -> HirNodeId {
721 let beta = self.register_param(
722 &format!("rmsb_{}", self.params.len()),
723 vec![0.0f32; self.dim],
724 Shape::new(&[self.dim], self.f),
725 );
726 self.hir.mir(
727 Op::RmsNorm {
728 axis: -1,
729 eps: 1e-6,
730 },
731 vec![x, gamma, beta],
732 shape,
733 )
734 }
735
736 fn linear_rms(
737 &mut self,
738 x: HirNodeId,
739 lw: &LinearWeights,
740 rms: &RmsNormWeight,
741 name: &str,
742 shape: Shape,
743 ) -> Result<HirNodeId> {
744 let h = self.linear(x, lw, name, shape.clone())?;
745 let g = self.rms_gamma(rms, &format!("{name}.rms"));
746 Ok(self.rms_norm(h, g, shape))
747 }
748
749 fn dual_attention(
750 &mut self,
751 attn: &Flux2DualAttnWeights,
752 tag: &str,
753 hidden: HirNodeId,
754 encoder: HirNodeId,
755 cos: HirNodeId,
756 sin: HirNodeId,
757 ) -> Result<(HirNodeId, HirNodeId)> {
758 let total = self.txt_seq + self.img_seq;
759 let b3i = self.b3i();
760 let b3t = self.b3t();
761 let q_i = self.linear_rms(
762 hidden,
763 &attn.to_q,
764 &attn.norm_q,
765 &format!("{tag}.q"),
766 b3i.clone(),
767 )?;
768 let k_i = self.linear_rms(
769 hidden,
770 &attn.to_k,
771 &attn.norm_k,
772 &format!("{tag}.k"),
773 b3i.clone(),
774 )?;
775 let v_i = self.linear(hidden, &attn.to_v, &format!("{tag}.v"), b3i)?;
776 let q_t = self.linear_rms(
777 encoder,
778 &attn.add_q,
779 &attn.norm_added_q,
780 &format!("{tag}.aq"),
781 b3t.clone(),
782 )?;
783 let k_t = self.linear_rms(
784 encoder,
785 &attn.add_k,
786 &attn.norm_added_k,
787 &format!("{tag}.ak"),
788 b3t.clone(),
789 )?;
790 let v_t = self.linear(encoder, &attn.add_v, &format!("{tag}.av"), b3t)?;
791
792 let q = self.concat(vec![q_t, q_i], 1, self.b3(total));
793 let k = self.concat(vec![k_t, k_i], 1, self.b3(total));
794 let v = self.concat(vec![v_t, v_i], 1, self.b3(total));
795
796 let q = self.rope(q, cos, sin, self.b3(total));
797 let k = self.rope(k, cos, sin, self.b3(total));
798
799 let out = self.hir.attention(
800 q,
801 k,
802 v,
803 None,
804 self.heads,
805 self.head_dim,
806 MaskKind::None,
807 self.b3(total),
808 );
809
810 let txt_out = self.narrow(out, 1, 0, self.txt_seq, self.b3t());
811 let img_out = self.narrow(out, 1, self.txt_seq, self.img_seq, self.b3i());
812 let enc_proj = self.linear(txt_out, &attn.to_add_out, &format!("{tag}.ao"), self.b3t())?;
813 let img_proj = self.linear(img_out, &attn.to_out, &format!("{tag}.o"), self.b3i())?;
814 Ok((enc_proj, img_proj))
815 }
816
817 fn parallel_attention(
818 &mut self,
819 attn: &Flux2ParallelAttnWeights,
820 tag: &str,
821 x: HirNodeId,
822 cos: HirNodeId,
823 sin: HirNodeId,
824 ) -> Result<HirNodeId> {
825 let seq = self.txt_seq + self.img_seq;
826 let rows = self.batch * seq;
827 let flat = self.reshape(x, vec![rows as i64, self.dim as i64]);
828 let fused = self.linear(
829 flat,
830 &attn.to_qkv_mlp,
831 &format!("{tag}.fused"),
832 Shape::new(&[rows, attn.to_qkv_mlp.out_dim], self.f),
833 )?;
834 let fused3 = self.reshape(
835 fused,
836 vec![
837 self.batch as i64,
838 seq as i64,
839 attn.to_qkv_mlp.out_dim as i64,
840 ],
841 );
842 let last = 2;
843 let b3s = self.b3(seq);
844 let q = self.narrow(fused3, last, 0, self.dim, b3s.clone());
845 let k = self.narrow(fused3, last, self.dim, self.dim, b3s.clone());
846 let v = self.narrow(fused3, last, 2 * self.dim, self.dim, b3s.clone());
847 let mlp = self.narrow(
848 fused3,
849 last,
850 3 * self.dim,
851 2 * self.mlp_hidden,
852 Shape::new(&[self.batch, seq, 2 * self.mlp_hidden], self.f),
853 );
854
855 let nq = self.rms_gamma(&attn.norm_q, &format!("{tag}.nq"));
856 let nk = self.rms_gamma(&attn.norm_k, &format!("{tag}.nk"));
857 let q = self.rms_norm(q, nq, b3s.clone());
858 let k = self.rms_norm(k, nk, b3s.clone());
859 let q = self.rope(q, cos, sin, self.b3(seq));
860 let k = self.rope(k, cos, sin, self.b3(seq));
861 let attn_out = self.hir.attention(
862 q,
863 k,
864 v,
865 None,
866 self.heads,
867 self.head_dim,
868 MaskKind::None,
869 self.b3(seq),
870 );
871
872 let mlp_act = self.hir.mir(
873 Op::FusedSwiGLU {
874 cast_to: None,
875 gate_first: true,
876 },
877 vec![mlp],
878 self.b3(seq).with_dim(2, Dim::Static(self.mlp_hidden)),
879 );
880 let cat = self.concat(
881 vec![attn_out, mlp_act],
882 2,
883 Shape::new(&[self.batch, seq, self.dim + self.mlp_hidden], self.f),
884 );
885 let cat_flat = self.reshape(cat, vec![rows as i64, (self.dim + self.mlp_hidden) as i64]);
886 let out = self.linear(
887 cat_flat,
888 &attn.to_out,
889 &format!("{tag}.out"),
890 Shape::new(&[rows, self.dim], self.f),
891 )?;
892 Ok(self.reshape(out, vec![self.batch as i64, seq as i64, self.dim as i64]))
893 }
894
895 fn ada_norm_out(
896 &mut self,
897 x: HirNodeId,
898 temb: HirNodeId,
899 norm: &Flux2NormOutWeights,
900 ) -> Result<HirNodeId> {
901 let h = self
902 .hir
903 .mir(Op::Activation(Activation::Silu), vec![temb], self.b1());
904 let emb = self.linear(
905 h,
906 &norm.linear,
907 "norm_out",
908 Shape::new(&[self.batch, 2 * self.dim], self.f),
909 )?;
910 let last = self.hir.node(emb).shape.rank() - 1;
911 let b1 = self.b1();
912 let scale = self.narrow(emb, last, 0, self.dim, b1.clone());
913 let shift = self.narrow(emb, last, self.dim, self.dim, b1);
914 let n = self.layer_norm_no_affine(x, self.b3i(), "norm_out_ln")?;
915 let b3i = self.b3i();
916 let scale_b = self.broadcast_bd(scale, self.img_seq);
917 let shift_b = self.broadcast_bd(shift, self.img_seq);
918 let ones = self.ones3(self.img_seq);
919 let scaled_base = self.add(ones, scale_b, b3i.clone());
920 let scaled = self.mul(n, scaled_base, b3i.clone());
921 Ok(self.add(scaled, shift_b, b3i))
922 }
923
924 pub(crate) fn rope_params(
925 &mut self,
926 img_ids: &[f32],
927 txt_ids: &[f32],
928 ) -> Result<(HirNodeId, HirNodeId)> {
929 let n_axes = 4usize;
930 let total = self.txt_seq + self.img_seq;
931 let mut ids = vec![0.0f32; total * n_axes];
932 for t in 0..self.txt_seq {
933 for a in 0..n_axes {
934 ids[t * n_axes + a] = txt_ids[t * n_axes + a];
935 }
936 }
937 for t in 0..self.img_seq {
938 for a in 0..n_axes {
939 ids[(self.txt_seq + t) * n_axes + a] = img_ids[t * n_axes + a];
940 }
941 }
942 let (cos, sin) = flux2_pos_embed(self.cfg, &ids, total, n_axes);
943 let cos_id =
944 self.register_param("rope_cos", cos, Shape::new(&[total, self.rope_dim], self.f));
945 let sin_id =
946 self.register_param("rope_sin", sin, Shape::new(&[total, self.rope_dim], self.f));
947 Ok((cos_id, sin_id))
948 }
949
950 fn rope(&mut self, x: HirNodeId, cos: HirNodeId, sin: HirNodeId, shape: Shape) -> HirNodeId {
951 self.hir.mir(
952 Op::Rope {
953 head_dim: self.head_dim,
954 n_rot: self.rope_dim.min(self.head_dim),
955 },
956 vec![x, cos, sin],
957 shape,
958 )
959 }
960
961 #[allow(dead_code)]
962 fn ones1(&mut self) -> HirNodeId {
963 self.register_param(
964 &format!("ones1_{}", self.params.len()),
965 vec![1.0f32; self.dim],
966 Shape::new(&[self.dim], self.f),
967 )
968 }
969
970 fn ones3(&mut self, seq: usize) -> HirNodeId {
971 let id = self.register_param(
972 &format!("ones3_{}", self.params.len()),
973 vec![1.0f32; self.dim],
974 Shape::new(&[1, 1, self.dim], self.f),
975 );
976 self.mir_expand(id, vec![self.batch as i64, seq as i64, self.dim as i64])
977 }
978
979 fn reshape(&mut self, x: HirNodeId, new_shape: Vec<i64>) -> HirNodeId {
980 let shape = self.infer_reshape(&self.hir.node(x).shape, &new_shape);
981 self.hir.mir(Op::Reshape { new_shape }, vec![x], shape)
982 }
983
984 fn narrow(
985 &mut self,
986 x: HirNodeId,
987 axis: usize,
988 start: usize,
989 len: usize,
990 shape: Shape,
991 ) -> HirNodeId {
992 self.hir
993 .mir(Op::Narrow { axis, start, len }, vec![x], shape)
994 }
995
996 fn concat(&mut self, inputs: Vec<HirNodeId>, axis: usize, shape: Shape) -> HirNodeId {
997 self.hir.mir(Op::Concat { axis }, inputs, shape)
998 }
999
1000 fn add(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
1001 self.hir.mir(Op::Binary(BinaryOp::Add), vec![a, b], shape)
1002 }
1003
1004 fn mul(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
1005 self.hir.mir(Op::Binary(BinaryOp::Mul), vec![a, b], shape)
1006 }
1007
1008 fn mir_expand(&mut self, x: HirNodeId, target: Vec<i64>) -> HirNodeId {
1009 let shape = self.infer_reshape(&self.hir.node(x).shape, &target);
1010 self.hir.mir(
1011 Op::Expand {
1012 target_shape: target,
1013 },
1014 vec![x],
1015 shape,
1016 )
1017 }
1018
1019 fn infer_reshape(&self, input: &Shape, new_shape: &[i64]) -> Shape {
1020 let static_dims: Vec<usize> = new_shape.iter().map(|&d| d as usize).collect();
1021 Shape::new(&static_dims, input.dtype())
1022 }
1023}
1024
1025pub fn host_temb(
1027 weights: &Flux2Weights,
1028 cfg: &Flux2Config,
1029 timestep: &[f32],
1030 guidance: Option<&[f32]>,
1031) -> Result<Vec<f32>> {
1032 let t_scaled: Vec<f32> = timestep.iter().map(|t| t * 1000.0).collect();
1033 let g_scaled = guidance.map(|g| g.iter().map(|x| x * 1000.0).collect::<Vec<_>>());
1034 time_guidance_embed(
1035 &t_scaled,
1036 g_scaled.as_deref(),
1037 &weights.time_guidance,
1038 cfg.inner_dim(),
1039 )
1040}
1041
1042pub fn host_temb_dual(
1044 weights: &Flux2Weights,
1045 cfg: &Flux2Config,
1046 timestep: &[f32],
1047 timestep_target: &[f32],
1048 guidance: Option<&[f32]>,
1049) -> Result<Vec<f32>> {
1050 let t_scaled: Vec<f32> = timestep.iter().map(|t| t * 1000.0).collect();
1051 let t2_scaled: Vec<f32> = timestep_target.iter().map(|t| t * 1000.0).collect();
1052 let g_scaled = guidance.map(|g| g.iter().map(|x| x * 1000.0).collect::<Vec<_>>());
1053 let tg_tgt = weights
1054 .time_guidance_target
1055 .as_ref()
1056 .unwrap_or(&weights.time_guidance);
1057 crate::layers::time_guidance_embed_dual(
1058 &t_scaled,
1059 &t2_scaled,
1060 g_scaled.as_deref(),
1061 &weights.time_guidance,
1062 tg_tgt,
1063 cfg.inner_dim(),
1064 )
1065}
1066
1067#[cfg(test)]
1068mod tests {
1069 use super::*;
1070 use crate::{
1071 Flux2Config, Flux2ForwardInput, extract_flux2_weights, flux2_transformer_forward,
1072 prepare_weight_map, synthetic_weights,
1073 };
1074
1075 #[test]
1076 fn nvfp4_x_embedder_lowers() {
1077 use crate::synthetic_flux2_packed_tiny;
1078
1079 let cfg = Flux2Config::tiny();
1080 let wm = synthetic_weights(&cfg);
1081 let w = extract_flux2_weights(prepare_weight_map(wm), &cfg).unwrap();
1082 let packed = synthetic_flux2_packed_tiny(&cfg);
1083 let g = build_flux2_forward_hir(
1084 &cfg,
1085 &w,
1086 1,
1087 4,
1088 3,
1089 &[0.0; 16],
1090 &[0.0; 12],
1091 Some(&packed),
1092 None,
1093 )
1094 .unwrap();
1095 assert!(!g.typed_params.is_empty());
1096 g.hir.lower_to_mir().expect("lower nvfp4");
1097 }
1098
1099 #[test]
1100 fn forward_hir_lowers() {
1101 let cfg = Flux2Config::tiny();
1102 let wm = synthetic_weights(&cfg);
1103 let w = extract_flux2_weights(prepare_weight_map(wm), &cfg).unwrap();
1104 let g =
1105 build_flux2_forward_hir(&cfg, &w, 1, 4, 3, &[0.0; 16], &[0.0; 12], None, None).unwrap();
1106 assert_eq!(g.hir.outputs.len(), 1);
1107 g.hir.lower_to_mir().expect("lower");
1108 }
1109
1110 #[test]
1111 fn compiled_forward_matches_native() {
1112 let cfg = Flux2Config::tiny();
1113 let wm = synthetic_weights(&cfg);
1114 let w = extract_flux2_weights(prepare_weight_map(wm), &cfg).unwrap();
1115 let b = 1usize;
1116 let img_seq = 4usize;
1117 let txt_seq = 3usize;
1118 let hidden = (0..b * img_seq * cfg.in_channels)
1119 .map(|i| (i as f32 * 0.01).sin())
1120 .collect::<Vec<_>>();
1121 let encoder = (0..b * txt_seq * cfg.joint_attention_dim)
1122 .map(|i| (i as f32 * 0.02).cos())
1123 .collect::<Vec<_>>();
1124 let timestep = vec![0.5f32];
1125 let guidance = vec![3.5f32];
1126 let img_ids = vec![0.0f32; img_seq * 4];
1127 let txt_ids = vec![0.0f32; txt_seq * 4];
1128
1129 let native = flux2_transformer_forward(
1130 &w,
1131 &cfg,
1132 Flux2ForwardInput {
1133 hidden_states: &hidden,
1134 encoder_hidden_states: &encoder,
1135 timestep: ×tep,
1136 timestep_target: None,
1137 guidance: Some(&guidance),
1138 img_ids: &img_ids,
1139 txt_ids: &txt_ids,
1140 batch: b,
1141 img_seq,
1142 txt_seq,
1143 },
1144 )
1145 .unwrap();
1146
1147 let temb = host_temb(&w, &cfg, ×tep, Some(&guidance)).unwrap();
1148 let (mut compiled, _) = compile_flux2_forward(
1149 &cfg,
1150 &w,
1151 b,
1152 img_seq,
1153 txt_seq,
1154 &img_ids,
1155 &txt_ids,
1156 rlx_runtime::Device::Cpu,
1157 None,
1158 None,
1159 None,
1160 )
1161 .unwrap();
1162 let out = compiled
1163 .run(&[
1164 ("hidden", hidden.as_slice()),
1165 ("encoder", encoder.as_slice()),
1166 ("temb", temb.as_slice()),
1167 ])
1168 .remove(0);
1169
1170 assert_eq!(out.len(), native.len());
1171 let max_diff = native
1172 .iter()
1173 .zip(&out)
1174 .map(|(a, b)| (a - b).abs())
1175 .fold(0.0f32, f32::max);
1176 assert!(max_diff < 2e-2, "HIR vs native max_abs_diff={max_diff}");
1177 }
1178
1179 #[cfg(feature = "cuda")]
1180 #[test]
1181 fn compiled_forward_matches_native_on_cuda() {
1182 use rlx_runtime::Device;
1183
1184 if !rlx_runtime::is_available(Device::Cuda) {
1185 eprintln!("skip: CUDA not available");
1186 return;
1187 }
1188 let cfg = Flux2Config::tiny();
1189 let wm = synthetic_weights(&cfg);
1190 let w = extract_flux2_weights(prepare_weight_map(wm), &cfg).unwrap();
1191 let b = 1usize;
1192 let img_seq = 4usize;
1193 let txt_seq = 3usize;
1194 let hidden = (0..b * img_seq * cfg.in_channels)
1195 .map(|i| (i as f32 * 0.01).sin())
1196 .collect::<Vec<_>>();
1197 let encoder = (0..b * txt_seq * cfg.joint_attention_dim)
1198 .map(|i| (i as f32 * 0.02).cos())
1199 .collect::<Vec<_>>();
1200 let timestep = vec![0.5f32];
1201 let guidance = vec![3.5f32];
1202 let img_ids = vec![0.0f32; img_seq * 4];
1203 let txt_ids = vec![0.0f32; txt_seq * 4];
1204
1205 let native = flux2_transformer_forward(
1206 &w,
1207 &cfg,
1208 Flux2ForwardInput {
1209 hidden_states: &hidden,
1210 encoder_hidden_states: &encoder,
1211 timestep: ×tep,
1212 timestep_target: None,
1213 guidance: Some(&guidance),
1214 img_ids: &img_ids,
1215 txt_ids: &txt_ids,
1216 batch: b,
1217 img_seq,
1218 txt_seq,
1219 },
1220 )
1221 .unwrap();
1222
1223 let temb = host_temb(&w, &cfg, ×tep, Some(&guidance)).unwrap();
1224 let (mut compiled, _) = compile_flux2_forward(
1225 &cfg,
1226 &w,
1227 b,
1228 img_seq,
1229 txt_seq,
1230 &img_ids,
1231 &txt_ids,
1232 Device::Cuda,
1233 None,
1234 None,
1235 None,
1236 )
1237 .unwrap();
1238 let out = compiled
1239 .run(&[
1240 ("hidden", hidden.as_slice()),
1241 ("encoder", encoder.as_slice()),
1242 ("temb", temb.as_slice()),
1243 ])
1244 .remove(0);
1245
1246 assert_eq!(out.len(), native.len());
1247 let max_diff = native
1248 .iter()
1249 .zip(&out)
1250 .map(|(a, b)| (a - b).abs())
1251 .fold(0.0f32, f32::max);
1252 assert!(
1253 max_diff < 2e-2,
1254 "CUDA HIR vs native max_abs_diff={max_diff}"
1255 );
1256 }
1257}