Skip to main content

rlx_flux2/
hir_builder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Full FLUX.2 transformer HIR builder (dual-stream + single-stream blocks).
17
18use super::config::Flux2Config;
19use super::layers::time_guidance_embed;
20use super::packed::{Flux2GgufLinearPacked, Flux2PackedParams, Nvfp4LinearPacked};
21use super::rope::flux2_pos_embed;
22use super::typed_linear::{TypedLinear, TypedLinearStore};
23use super::weights::{
24    Flux2DualAttnWeights, Flux2FeedForwardWeights, Flux2ModulationWeights, Flux2NormOutWeights,
25    Flux2ParallelAttnWeights, Flux2Weights, LinearWeights, RmsNormWeight,
26};
27use crate::builder::Flux2GraphParams;
28use anyhow::Result;
29use rlx_ir::hir::{FusionPolicy, HirModule, HirNodeId};
30use rlx_ir::op::{Activation, BinaryOp, MaskKind};
31use rlx_ir::{DType, Dim, Graph, Op, Shape};
32
33/// Non-f32 parameter blobs (`set_param_typed` at compile time).
34pub type Flux2TypedParams = Vec<(String, Vec<u8>, DType)>;
35
36pub struct Flux2ForwardGraph {
37    pub hir: HirModule,
38    pub params: Flux2GraphParams,
39    pub typed_params: Flux2TypedParams,
40}
41
42/// Build the full denoiser forward graph in HIR.
43///
44/// Inputs:
45///   - `hidden` `[batch, img_seq, in_channels]`
46///   - `encoder` `[batch, txt_seq, joint_attention_dim]`
47///   - `temb` `[batch, inner_dim]` — host-side timestep+guidance embedding
48///     (see [`super::layers::time_guidance_embed`])
49pub fn build_flux2_forward_hir(
50    cfg: &Flux2Config,
51    weights: &Flux2Weights,
52    batch: usize,
53    img_seq: usize,
54    txt_seq: usize,
55    img_ids: &[f32],
56    txt_ids: &[f32],
57    packed: Option<&Flux2PackedParams>,
58    typed_linears: Option<&TypedLinearStore>,
59) -> Result<Flux2ForwardGraph> {
60    let mut hir = HirModule::new("flux2_forward").with_fusion_policy(FusionPolicy::Direct);
61    let mut params = Flux2GraphParams::new();
62    let mut typed_params = Flux2TypedParams::new();
63    let mut b = Flux2HirBuilder::new(
64        &mut hir,
65        &mut params,
66        &mut typed_params,
67        cfg,
68        weights,
69        batch,
70        img_seq,
71        txt_seq,
72        packed,
73        typed_linears,
74    );
75    b.build_forward(img_ids, txt_ids)?;
76    Ok(Flux2ForwardGraph {
77        hir,
78        params,
79        typed_params,
80    })
81}
82
83pub fn build_flux2_forward_graph(
84    cfg: &Flux2Config,
85    weights: &Flux2Weights,
86    batch: usize,
87    img_seq: usize,
88    txt_seq: usize,
89    img_ids: &[f32],
90    txt_ids: &[f32],
91) -> Result<(Graph, Flux2GraphParams)> {
92    let built = crate::flow::Flux2Flow::new(cfg, weights)
93        .batch(batch)
94        .img_seq(img_seq)
95        .txt_seq(txt_seq)
96        .position_ids(img_ids.to_vec(), txt_ids.to_vec())
97        .build_forward(img_ids, txt_ids)?;
98    let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
99    Ok((graph, params))
100}
101
102pub fn compile_flux2_forward(
103    cfg: &Flux2Config,
104    weights: &Flux2Weights,
105    batch: usize,
106    img_seq: usize,
107    txt_seq: usize,
108    img_ids: &[f32],
109    txt_ids: &[f32],
110    device: rlx_runtime::Device,
111    packed: Option<&Flux2PackedParams>,
112    typed_linears: Option<&TypedLinearStore>,
113    aot: Option<&rlx_runtime::AotCache>,
114) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
115    use crate::compile_util::{compile_hir_cached, flux2_denoiser_aot_key};
116
117    super::device::assert_flux2_device_available(device)?;
118    let g = build_flux2_forward_hir(
119        cfg,
120        weights,
121        batch,
122        img_seq,
123        txt_seq,
124        img_ids,
125        txt_ids,
126        packed,
127        typed_linears,
128    )?;
129    let key = flux2_denoiser_aot_key(
130        device,
131        batch,
132        img_seq,
133        txt_seq,
134        img_ids,
135        txt_ids,
136        packed.is_some(),
137    );
138    let mut compiled = compile_hir_cached(
139        device,
140        aot,
141        &key,
142        g.hir,
143        &super::compile_util::flux2_compile_profile(),
144    )?;
145    for (name, data) in &g.params {
146        compiled.set_param(name, data);
147    }
148    for (name, data, dtype) in &g.typed_params {
149        compiled.set_param_typed(name, data, *dtype);
150    }
151    Ok((compiled, g.params))
152}
153
154/// Dual-stream section only: embed → mod → dual blocks → img hidden output.
155pub fn build_flux2_dual_section_hir(
156    cfg: &Flux2Config,
157    weights: &Flux2Weights,
158    batch: usize,
159    img_seq: usize,
160    txt_seq: usize,
161    img_ids: &[f32],
162    txt_ids: &[f32],
163) -> Result<Flux2ForwardGraph> {
164    let mut hir = HirModule::new("flux2_dual").with_fusion_policy(FusionPolicy::Direct);
165    let mut params = Flux2GraphParams::new();
166    let mut typed_params = Flux2TypedParams::new();
167    let mut b = Flux2HirBuilder::new(
168        &mut hir,
169        &mut params,
170        &mut typed_params,
171        cfg,
172        weights,
173        batch,
174        img_seq,
175        txt_seq,
176        None,
177        None,
178    );
179    let (hidden, _encoder, _cos, _sin, _temb) = b.build_dual_section(img_ids, txt_ids)?;
180    hir.outputs = vec![hidden];
181    Ok(Flux2ForwardGraph {
182        hir,
183        params,
184        typed_params,
185    })
186}
187
188pub(crate) struct Flux2HirBuilder<'a> {
189    hir: &'a mut HirModule,
190    params: &'a mut Flux2GraphParams,
191    typed_params: &'a mut Flux2TypedParams,
192    weights: &'a Flux2Weights,
193    packed: Option<&'a Flux2PackedParams>,
194    typed_linears: Option<&'a TypedLinearStore>,
195    cfg: &'a Flux2Config,
196    batch: usize,
197    img_seq: usize,
198    txt_seq: usize,
199    dim: usize,
200    heads: usize,
201    head_dim: usize,
202    eps: f32,
203    rope_dim: usize,
204    mlp_hidden: usize,
205    f: DType,
206}
207
208/// MSA + MLP modulation triples from [`Flux2HirBuilder::modulation_params`].
209pub(crate) type Flux2DoubleMod = (
210    (HirNodeId, HirNodeId, HirNodeId),
211    (HirNodeId, HirNodeId, HirNodeId),
212);
213
214impl<'a> Flux2HirBuilder<'a> {
215    fn new(
216        hir: &'a mut HirModule,
217        params: &'a mut Flux2GraphParams,
218        typed_params: &'a mut Flux2TypedParams,
219        cfg: &'a Flux2Config,
220        weights: &'a Flux2Weights,
221        batch: usize,
222        img_seq: usize,
223        txt_seq: usize,
224        packed: Option<&'a Flux2PackedParams>,
225        typed_linears: Option<&'a TypedLinearStore>,
226    ) -> Self {
227        let dim = cfg.inner_dim();
228        Self {
229            hir,
230            params,
231            typed_params,
232            weights,
233            packed,
234            typed_linears,
235            cfg,
236            batch,
237            img_seq,
238            txt_seq,
239            dim,
240            heads: cfg.num_attention_heads,
241            head_dim: cfg.attention_head_dim,
242            eps: cfg.eps as f32,
243            rope_dim: cfg.axes_dims_rope.iter().sum(),
244            mlp_hidden: cfg.ff_inner_dim(),
245            f: DType::F32,
246        }
247    }
248
249    pub(crate) fn from_emit_parts(
250        hir: &'a mut HirModule,
251        params: &'a mut Flux2GraphParams,
252        typed_params: &'a mut Flux2TypedParams,
253        cfg: &'a Flux2Config,
254        weights: &'a Flux2Weights,
255        batch: usize,
256        img_seq: usize,
257        txt_seq: usize,
258    ) -> Self {
259        Self::new(
260            hir,
261            params,
262            typed_params,
263            cfg,
264            weights,
265            batch,
266            img_seq,
267            txt_seq,
268            None,
269            None,
270        )
271    }
272
273    fn build_dual_section(
274        &mut self,
275        img_ids: &[f32],
276        txt_ids: &[f32],
277    ) -> Result<(HirNodeId, HirNodeId, HirNodeId, HirNodeId, HirNodeId)> {
278        let hidden_in = self.hir.input(
279            "hidden",
280            Shape::new(&[self.batch, self.img_seq, self.cfg.in_channels], self.f),
281        );
282        let enc_in = self.hir.input(
283            "encoder",
284            Shape::new(
285                &[self.batch, self.txt_seq, self.cfg.joint_attention_dim],
286                self.f,
287            ),
288        );
289        let temb_in = self.hir.input("temb", self.b1());
290
291        let mod_img = self.modulation_params(&self.weights.double_mod_img, "mod_img", temb_in)?;
292        let mod_txt = self.modulation_params(&self.weights.double_mod_txt, "mod_txt", temb_in)?;
293
294        let mut hidden = self.linear(
295            hidden_in,
296            &self.weights.x_embedder,
297            "x_embedder",
298            self.b3i(),
299        )?;
300        let mut encoder = self.linear(
301            enc_in,
302            &self.weights.context_embedder,
303            "context_embedder",
304            self.b3t(),
305        )?;
306
307        let (cos_id, sin_id) = self.rope_params(img_ids, txt_ids)?;
308
309        for (li, block) in self.weights.transformer_blocks.iter().enumerate() {
310            (hidden, encoder) = self.emit_dual_stream_block(
311                li, block, hidden, encoder, &mod_img, &mod_txt, cos_id, sin_id,
312            )?;
313        }
314
315        Ok((hidden, encoder, cos_id, sin_id, temb_in))
316    }
317
318    fn build_forward(&mut self, img_ids: &[f32], txt_ids: &[f32]) -> Result<()> {
319        let (hidden, encoder, cos_id, sin_id, temb_in) =
320            self.build_dual_section(img_ids, txt_ids)?;
321        let out = self.emit_single_stream_tail(hidden, encoder, cos_id, sin_id, temb_in)?;
322        self.hir.outputs = vec![out];
323        Ok(())
324    }
325
326    /// Concat img/txt streams → single-stream blocks → ada-norm → proj_out.
327    pub(crate) fn emit_single_stream_tail(
328        &mut self,
329        hidden: HirNodeId,
330        encoder: HirNodeId,
331        cos_id: HirNodeId,
332        sin_id: HirNodeId,
333        temb_in: HirNodeId,
334    ) -> Result<HirNodeId> {
335        let single_mod =
336            self.single_modulation_params(&self.weights.single_mod, "mod_single", temb_in)?;
337
338        let stream = self.concat(
339            vec![encoder, hidden],
340            1,
341            self.b3(self.txt_seq + self.img_seq),
342        );
343        let mut stream = stream;
344        for (li, block) in self.weights.single_transformer_blocks.iter().enumerate() {
345            let lp = format!("sblk{li}");
346            let n = self.layer_norm_no_affine(
347                stream,
348                self.b3(self.txt_seq + self.img_seq),
349                &format!("{lp}.n"),
350            )?;
351            let n = self.modulate(n, single_mod.0, single_mod.1, self.txt_seq + self.img_seq);
352            let attn =
353                self.parallel_attention(&block.attn, &format!("{lp}.attn"), n, cos_id, sin_id)?;
354            let attn_g = self.gate(attn, single_mod.2, self.txt_seq + self.img_seq);
355            stream = self.add(stream, attn_g, self.b3(self.txt_seq + self.img_seq));
356        }
357
358        let hidden_out = self.narrow(stream, 1, self.txt_seq, self.img_seq, self.b3i());
359        let normed = self.ada_norm_out(hidden_out, temb_in, &self.weights.norm_out)?;
360        self.linear(normed, &self.weights.proj_out, "proj_out", self.b3o())
361    }
362
363    /// One FLUX dual-stream transformer block (img + txt).
364    pub(crate) fn emit_dual_stream_block(
365        &mut self,
366        li: usize,
367        block: &super::weights::Flux2DoubleBlockWeights,
368        hidden: HirNodeId,
369        encoder: HirNodeId,
370        mod_img: &Flux2DoubleMod,
371        mod_txt: &Flux2DoubleMod,
372        cos_id: HirNodeId,
373        sin_id: HirNodeId,
374    ) -> Result<(HirNodeId, HirNodeId)> {
375        let lp = format!("blk{li}");
376        let (img_msa, img_mlp) = mod_img;
377        let (txt_msa, txt_mlp) = mod_txt;
378
379        let n1 = self.layer_norm_no_affine(hidden, self.b3i(), &format!("{lp}.n1"))?;
380        let n1 = self.modulate(n1, img_msa.0, img_msa.1, self.img_seq);
381        let nc = self.layer_norm_no_affine(encoder, self.b3t(), &format!("{lp}.nc"))?;
382        let nc = self.modulate(nc, txt_msa.0, txt_msa.1, self.txt_seq);
383
384        let (enc_a, img_a) =
385            self.dual_attention(&block.attn, &format!("{lp}.attn"), n1, nc, cos_id, sin_id)?;
386        let img_g = self.gate(img_a, img_msa.2, self.img_seq);
387        let hidden = self.add(hidden, img_g, self.b3i());
388        let txt_g = self.gate(enc_a, txt_msa.2, self.txt_seq);
389        let encoder = self.add(encoder, txt_g, self.b3t());
390
391        let n2 = self.layer_norm_no_affine(hidden, self.b3i(), &format!("{lp}.n2"))?;
392        let n2 = self.modulate_scale_shift(n2, img_mlp.1, img_mlp.0, self.img_seq);
393        let ff = self.feed_forward(&block.ff, &format!("{lp}.ff"), n2, self.img_seq)?;
394        let ff_g = self.gate(ff, img_mlp.2, self.img_seq);
395        let hidden = self.add(hidden, ff_g, self.b3i());
396
397        let nc2 = self.layer_norm_no_affine(encoder, self.b3t(), &format!("{lp}.nc2"))?;
398        let nc2 = self.modulate_scale_shift(nc2, txt_mlp.1, txt_mlp.0, self.txt_seq);
399        let ffc = self.feed_forward(&block.ff_context, &format!("{lp}.ffc"), nc2, self.txt_seq)?;
400        let ffc_g = self.gate(ffc, txt_mlp.2, self.txt_seq);
401        let encoder = self.add(encoder, ffc_g, self.b3t());
402        Ok((hidden, encoder))
403    }
404
405    fn b1(&self) -> Shape {
406        Shape::new(&[self.batch, self.dim], self.f)
407    }
408    fn b3i(&self) -> Shape {
409        self.b3(self.img_seq)
410    }
411    fn b3t(&self) -> Shape {
412        self.b3(self.txt_seq)
413    }
414    fn b3o(&self) -> Shape {
415        Shape::new(&[self.batch, self.img_seq, self.cfg.proj_out_dim()], self.f)
416    }
417    fn b3(&self, seq: usize) -> Shape {
418        Shape::new(&[self.batch, seq, self.dim], self.f)
419    }
420
421    fn register_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
422        let id = self.hir.param(name, shape);
423        self.params.insert(name.to_string(), data);
424        id
425    }
426
427    pub(crate) fn linear(
428        &mut self,
429        x: HirNodeId,
430        lw: &LinearWeights,
431        name: &str,
432        out_shape: Shape,
433    ) -> Result<HirNodeId> {
434        if let Some(p) = self.packed.and_then(|m| m.get_nvfp4(name)) {
435            return self.linear_nvfp4(x, p, name, out_shape);
436        }
437        if let Some(p) = self.packed.and_then(|m| m.get_gguf(name)) {
438            return self.linear_gguf(x, p, name, out_shape);
439        }
440        if let Some(tl) = self.typed_linears.and_then(|t| t.get(name)) {
441            return self.linear_typed(x, tl, name, out_shape);
442        }
443        let w = self.register_param(
444            &format!("{name}.weight"),
445            lw.w_t.clone(),
446            Shape::new(&[lw.in_dim, lw.out_dim], self.f),
447        );
448        let bias = if lw.bias.iter().all(|&v| v == 0.0) {
449            None
450        } else {
451            let b = self.register_param(
452                &format!("{name}.bias"),
453                lw.bias.clone(),
454                Shape::new(&[lw.out_dim], self.f),
455            );
456            Some(b)
457        };
458        Ok(self.hir.linear(x, w, bias, None, out_shape))
459    }
460
461    fn linear_typed(
462        &mut self,
463        x: HirNodeId,
464        tl: &TypedLinear,
465        name: &str,
466        out_shape: Shape,
467    ) -> Result<HirNodeId> {
468        let w = self.register_typed_param_shaped(
469            &format!("{name}.weight"),
470            tl.weight_bytes.clone(),
471            tl.dtype,
472            Shape::new(&[tl.in_dim, tl.out_dim], tl.dtype),
473        );
474        let bias = if tl.bias.iter().all(|&v| v == 0.0) {
475            None
476        } else {
477            let b = self.register_param(
478                &format!("{name}.bias"),
479                tl.bias.clone(),
480                Shape::new(&[tl.out_dim], self.f),
481            );
482            Some(b)
483        };
484        Ok(self.hir.linear(x, w, bias, None, out_shape))
485    }
486
487    fn linear_nvfp4(
488        &mut self,
489        x: HirNodeId,
490        p: &Nvfp4LinearPacked,
491        name: &str,
492        out_shape: Shape,
493    ) -> Result<HirNodeId> {
494        use rlx_ir::QuantScheme;
495
496        let w_name = format!("{name}.weight");
497        let s_name = format!("{name}.scale");
498        let gs_name = format!("{name}.global_scale");
499        let w = self.register_typed_param(&w_name, p.w_q.clone(), DType::U8);
500        let scale = self.register_typed_param(&s_name, p.scale.clone(), DType::U8);
501        let gs = self.register_param(&gs_name, vec![p.global_scale], Shape::scalar(self.f));
502        let mut y = self.hir.dequant_matmul(
503            x,
504            w,
505            Some(scale),
506            Some(gs),
507            QuantScheme::Nvfp4Block,
508            out_shape.clone(),
509        );
510        if p.bias.iter().any(|&v| v != 0.0) {
511            let b = self.register_param(
512                &format!("{name}.bias"),
513                p.bias.clone(),
514                Shape::new(&[p.out_dim], self.f),
515            );
516            y = self
517                .hir
518                .mir(Op::Binary(BinaryOp::Add), vec![y, b], out_shape);
519        }
520        Ok(y)
521    }
522
523    fn linear_gguf(
524        &mut self,
525        x: HirNodeId,
526        p: &Flux2GgufLinearPacked,
527        name: &str,
528        out_shape: Shape,
529    ) -> Result<HirNodeId> {
530        let w_name = format!("{name}.weight");
531        let w = self.register_typed_param(&w_name, p.w_q.clone(), DType::U8);
532        let mut y = self
533            .hir
534            .dequant_matmul(x, w, None, None, p.scheme, out_shape.clone());
535        if p.bias.iter().any(|&v| v != 0.0) {
536            let b = self.register_param(
537                &format!("{name}.bias"),
538                p.bias.clone(),
539                Shape::new(&[p.out_dim], self.f),
540            );
541            y = self
542                .hir
543                .mir(Op::Binary(BinaryOp::Add), vec![y, b], out_shape);
544        }
545        Ok(y)
546    }
547
548    fn register_typed_param(&mut self, name: &str, data: Vec<u8>, dtype: DType) -> HirNodeId {
549        let shape = Shape::new(&[data.len()], dtype);
550        let id = self.hir.param(name, shape);
551        self.typed_params.push((name.to_string(), data, dtype));
552        id
553    }
554
555    fn register_typed_param_shaped(
556        &mut self,
557        name: &str,
558        data: Vec<u8>,
559        dtype: DType,
560        shape: Shape,
561    ) -> HirNodeId {
562        let id = self.hir.param(name, shape);
563        self.typed_params.push((name.to_string(), data, dtype));
564        id
565    }
566
567    fn layer_norm_no_affine(&mut self, x: HirNodeId, shape: Shape, tag: &str) -> Result<HirNodeId> {
568        let d = self.dim;
569        let g = self.register_param(
570            &format!("{tag}.ln1"),
571            vec![1.0f32; d],
572            Shape::new(&[d], self.f),
573        );
574        let b = self.register_param(
575            &format!("{tag}.ln0"),
576            vec![0.0f32; d],
577            Shape::new(&[d], self.f),
578        );
579        Ok(self.hir.mir(
580            Op::LayerNorm {
581                axis: -1,
582                eps: self.eps,
583            },
584            vec![x, g, b],
585            shape,
586        ))
587    }
588
589    pub(crate) fn modulation_params(
590        &mut self,
591        m: &Flux2ModulationWeights,
592        tag: &str,
593        temb: HirNodeId,
594    ) -> Result<Flux2DoubleMod> {
595        let h = self
596            .hir
597            .mir(Op::Activation(Activation::Silu), vec![temb], self.b1());
598        let mod_shape = Shape::new(&[self.batch, 6 * self.dim], self.f);
599        let mod_out = self.linear(h, &m.linear, tag, mod_shape)?;
600        let last = self.hir.node(mod_out).shape.rank() - 1;
601        let d = self.dim;
602        let b1 = self.b1();
603        let s0 = self.narrow(mod_out, last, 0, d, b1.clone());
604        let s1 = self.narrow(mod_out, last, d, d, b1.clone());
605        let s2 = self.narrow(mod_out, last, 2 * d, d, b1.clone());
606        let s3 = self.narrow(mod_out, last, 3 * d, d, b1.clone());
607        let s4 = self.narrow(mod_out, last, 4 * d, d, b1.clone());
608        let s5 = self.narrow(mod_out, last, 5 * d, d, b1);
609        Ok(((s0, s1, s2), (s3, s4, s5)))
610    }
611
612    fn single_modulation_params(
613        &mut self,
614        m: &Flux2ModulationWeights,
615        tag: &str,
616        temb: HirNodeId,
617    ) -> Result<(HirNodeId, HirNodeId, HirNodeId)> {
618        let h = self
619            .hir
620            .mir(Op::Activation(Activation::Silu), vec![temb], self.b1());
621        let mod_shape = Shape::new(&[self.batch, 3 * self.dim], self.f);
622        let mod_out = self.linear(h, &m.linear, tag, mod_shape)?;
623        let last = self.hir.node(mod_out).shape.rank() - 1;
624        let d = self.dim;
625        let b1 = self.b1();
626        let s0 = self.narrow(mod_out, last, 0, d, b1.clone());
627        let s1 = self.narrow(mod_out, last, d, d, b1.clone());
628        let s2 = self.narrow(mod_out, last, 2 * d, d, b1);
629        Ok((s0, s1, s2))
630    }
631
632    fn broadcast_bd(&mut self, v: HirNodeId, seq: usize) -> HirNodeId {
633        let b1d = self.reshape(v, vec![self.batch as i64, 1, self.dim as i64]);
634        self.mir_expand(b1d, vec![self.batch as i64, seq as i64, self.dim as i64])
635    }
636
637    fn modulate(
638        &mut self,
639        x: HirNodeId,
640        shift: HirNodeId,
641        scale: HirNodeId,
642        seq: usize,
643    ) -> HirNodeId {
644        let shape = self.b3(seq);
645        let shift_b = self.broadcast_bd(shift, seq);
646        let scale_b = self.broadcast_bd(scale, seq);
647        let ones = self.ones3(seq);
648        let scaled_base = self.add(ones, scale_b, shape.clone());
649        let scaled = self.mul(x, scaled_base, shape.clone());
650        self.add(scaled, shift_b, shape)
651    }
652
653    fn modulate_scale_shift(
654        &mut self,
655        x: HirNodeId,
656        scale: HirNodeId,
657        shift: HirNodeId,
658        seq: usize,
659    ) -> HirNodeId {
660        let shape = self.b3(seq);
661        let shift_b = self.broadcast_bd(shift, seq);
662        let scale_b = self.broadcast_bd(scale, seq);
663        let ones = self.ones3(seq);
664        let scaled_base = self.add(ones, scale_b, shape.clone());
665        let scaled = self.mul(x, scaled_base, shape.clone());
666        self.add(scaled, shift_b, shape)
667    }
668
669    fn gate(&mut self, x: HirNodeId, gate: HirNodeId, seq: usize) -> HirNodeId {
670        let g = self.broadcast_bd(gate, seq);
671        self.mul(x, g, self.b3(seq))
672    }
673
674    fn feed_forward(
675        &mut self,
676        ff: &Flux2FeedForwardWeights,
677        tag: &str,
678        x: HirNodeId,
679        seq: usize,
680    ) -> Result<HirNodeId> {
681        let rows = self.batch * seq;
682        let inner = ff.linear_in.out_dim / 2;
683        let flat = self.reshape(x, vec![rows as i64, self.dim as i64]);
684        let h = self.linear(
685            flat,
686            &ff.linear_in,
687            &format!("{tag}.in"),
688            Shape::new(&[rows, ff.linear_in.out_dim], self.f),
689        )?;
690        let h3 = self.reshape(
691            h,
692            vec![self.batch as i64, seq as i64, ff.linear_in.out_dim as i64],
693        );
694        let act = self.hir.mir(
695            Op::FusedSwiGLU {
696                cast_to: None,
697                gate_first: true,
698            },
699            vec![h3],
700            self.b3(seq).with_dim(2, Dim::Static(inner)),
701        );
702        let act_flat = self.reshape(act, vec![rows as i64, inner as i64]);
703        self.linear(
704            act_flat,
705            &ff.linear_out,
706            &format!("{tag}.out"),
707            Shape::new(&[rows, self.dim], self.f),
708        )
709        .map(|o| self.reshape(o, vec![self.batch as i64, seq as i64, self.dim as i64]))
710    }
711
712    fn rms_gamma(&mut self, rms: &RmsNormWeight, name: &str) -> HirNodeId {
713        let mut g = vec![0.0f32; self.dim];
714        for h in 0..self.heads {
715            g[h * self.head_dim..(h + 1) * self.head_dim].copy_from_slice(&rms.scale);
716        }
717        self.register_param(name, g, Shape::new(&[self.dim], self.f))
718    }
719
720    fn rms_norm(&mut self, x: HirNodeId, gamma: HirNodeId, shape: Shape) -> HirNodeId {
721        let beta = self.register_param(
722            &format!("rmsb_{}", self.params.len()),
723            vec![0.0f32; self.dim],
724            Shape::new(&[self.dim], self.f),
725        );
726        self.hir.mir(
727            Op::RmsNorm {
728                axis: -1,
729                eps: 1e-6,
730            },
731            vec![x, gamma, beta],
732            shape,
733        )
734    }
735
736    fn linear_rms(
737        &mut self,
738        x: HirNodeId,
739        lw: &LinearWeights,
740        rms: &RmsNormWeight,
741        name: &str,
742        shape: Shape,
743    ) -> Result<HirNodeId> {
744        let h = self.linear(x, lw, name, shape.clone())?;
745        let g = self.rms_gamma(rms, &format!("{name}.rms"));
746        Ok(self.rms_norm(h, g, shape))
747    }
748
749    fn dual_attention(
750        &mut self,
751        attn: &Flux2DualAttnWeights,
752        tag: &str,
753        hidden: HirNodeId,
754        encoder: HirNodeId,
755        cos: HirNodeId,
756        sin: HirNodeId,
757    ) -> Result<(HirNodeId, HirNodeId)> {
758        let total = self.txt_seq + self.img_seq;
759        let b3i = self.b3i();
760        let b3t = self.b3t();
761        let q_i = self.linear_rms(
762            hidden,
763            &attn.to_q,
764            &attn.norm_q,
765            &format!("{tag}.q"),
766            b3i.clone(),
767        )?;
768        let k_i = self.linear_rms(
769            hidden,
770            &attn.to_k,
771            &attn.norm_k,
772            &format!("{tag}.k"),
773            b3i.clone(),
774        )?;
775        let v_i = self.linear(hidden, &attn.to_v, &format!("{tag}.v"), b3i)?;
776        let q_t = self.linear_rms(
777            encoder,
778            &attn.add_q,
779            &attn.norm_added_q,
780            &format!("{tag}.aq"),
781            b3t.clone(),
782        )?;
783        let k_t = self.linear_rms(
784            encoder,
785            &attn.add_k,
786            &attn.norm_added_k,
787            &format!("{tag}.ak"),
788            b3t.clone(),
789        )?;
790        let v_t = self.linear(encoder, &attn.add_v, &format!("{tag}.av"), b3t)?;
791
792        let q = self.concat(vec![q_t, q_i], 1, self.b3(total));
793        let k = self.concat(vec![k_t, k_i], 1, self.b3(total));
794        let v = self.concat(vec![v_t, v_i], 1, self.b3(total));
795
796        let q = self.rope(q, cos, sin, self.b3(total));
797        let k = self.rope(k, cos, sin, self.b3(total));
798
799        let out = self.hir.attention(
800            q,
801            k,
802            v,
803            None,
804            self.heads,
805            self.head_dim,
806            MaskKind::None,
807            self.b3(total),
808        );
809
810        let txt_out = self.narrow(out, 1, 0, self.txt_seq, self.b3t());
811        let img_out = self.narrow(out, 1, self.txt_seq, self.img_seq, self.b3i());
812        let enc_proj = self.linear(txt_out, &attn.to_add_out, &format!("{tag}.ao"), self.b3t())?;
813        let img_proj = self.linear(img_out, &attn.to_out, &format!("{tag}.o"), self.b3i())?;
814        Ok((enc_proj, img_proj))
815    }
816
817    fn parallel_attention(
818        &mut self,
819        attn: &Flux2ParallelAttnWeights,
820        tag: &str,
821        x: HirNodeId,
822        cos: HirNodeId,
823        sin: HirNodeId,
824    ) -> Result<HirNodeId> {
825        let seq = self.txt_seq + self.img_seq;
826        let rows = self.batch * seq;
827        let flat = self.reshape(x, vec![rows as i64, self.dim as i64]);
828        let fused = self.linear(
829            flat,
830            &attn.to_qkv_mlp,
831            &format!("{tag}.fused"),
832            Shape::new(&[rows, attn.to_qkv_mlp.out_dim], self.f),
833        )?;
834        let fused3 = self.reshape(
835            fused,
836            vec![
837                self.batch as i64,
838                seq as i64,
839                attn.to_qkv_mlp.out_dim as i64,
840            ],
841        );
842        let last = 2;
843        let b3s = self.b3(seq);
844        let q = self.narrow(fused3, last, 0, self.dim, b3s.clone());
845        let k = self.narrow(fused3, last, self.dim, self.dim, b3s.clone());
846        let v = self.narrow(fused3, last, 2 * self.dim, self.dim, b3s.clone());
847        let mlp = self.narrow(
848            fused3,
849            last,
850            3 * self.dim,
851            2 * self.mlp_hidden,
852            Shape::new(&[self.batch, seq, 2 * self.mlp_hidden], self.f),
853        );
854
855        let nq = self.rms_gamma(&attn.norm_q, &format!("{tag}.nq"));
856        let nk = self.rms_gamma(&attn.norm_k, &format!("{tag}.nk"));
857        let q = self.rms_norm(q, nq, b3s.clone());
858        let k = self.rms_norm(k, nk, b3s.clone());
859        let q = self.rope(q, cos, sin, self.b3(seq));
860        let k = self.rope(k, cos, sin, self.b3(seq));
861        let attn_out = self.hir.attention(
862            q,
863            k,
864            v,
865            None,
866            self.heads,
867            self.head_dim,
868            MaskKind::None,
869            self.b3(seq),
870        );
871
872        let mlp_act = self.hir.mir(
873            Op::FusedSwiGLU {
874                cast_to: None,
875                gate_first: true,
876            },
877            vec![mlp],
878            self.b3(seq).with_dim(2, Dim::Static(self.mlp_hidden)),
879        );
880        let cat = self.concat(
881            vec![attn_out, mlp_act],
882            2,
883            Shape::new(&[self.batch, seq, self.dim + self.mlp_hidden], self.f),
884        );
885        let cat_flat = self.reshape(cat, vec![rows as i64, (self.dim + self.mlp_hidden) as i64]);
886        let out = self.linear(
887            cat_flat,
888            &attn.to_out,
889            &format!("{tag}.out"),
890            Shape::new(&[rows, self.dim], self.f),
891        )?;
892        Ok(self.reshape(out, vec![self.batch as i64, seq as i64, self.dim as i64]))
893    }
894
895    fn ada_norm_out(
896        &mut self,
897        x: HirNodeId,
898        temb: HirNodeId,
899        norm: &Flux2NormOutWeights,
900    ) -> Result<HirNodeId> {
901        let h = self
902            .hir
903            .mir(Op::Activation(Activation::Silu), vec![temb], self.b1());
904        let emb = self.linear(
905            h,
906            &norm.linear,
907            "norm_out",
908            Shape::new(&[self.batch, 2 * self.dim], self.f),
909        )?;
910        let last = self.hir.node(emb).shape.rank() - 1;
911        let b1 = self.b1();
912        let scale = self.narrow(emb, last, 0, self.dim, b1.clone());
913        let shift = self.narrow(emb, last, self.dim, self.dim, b1);
914        let n = self.layer_norm_no_affine(x, self.b3i(), "norm_out_ln")?;
915        let b3i = self.b3i();
916        let scale_b = self.broadcast_bd(scale, self.img_seq);
917        let shift_b = self.broadcast_bd(shift, self.img_seq);
918        let ones = self.ones3(self.img_seq);
919        let scaled_base = self.add(ones, scale_b, b3i.clone());
920        let scaled = self.mul(n, scaled_base, b3i.clone());
921        Ok(self.add(scaled, shift_b, b3i))
922    }
923
924    pub(crate) fn rope_params(
925        &mut self,
926        img_ids: &[f32],
927        txt_ids: &[f32],
928    ) -> Result<(HirNodeId, HirNodeId)> {
929        let n_axes = 4usize;
930        let total = self.txt_seq + self.img_seq;
931        let mut ids = vec![0.0f32; total * n_axes];
932        for t in 0..self.txt_seq {
933            for a in 0..n_axes {
934                ids[t * n_axes + a] = txt_ids[t * n_axes + a];
935            }
936        }
937        for t in 0..self.img_seq {
938            for a in 0..n_axes {
939                ids[(self.txt_seq + t) * n_axes + a] = img_ids[t * n_axes + a];
940            }
941        }
942        let (cos, sin) = flux2_pos_embed(self.cfg, &ids, total, n_axes);
943        let cos_id =
944            self.register_param("rope_cos", cos, Shape::new(&[total, self.rope_dim], self.f));
945        let sin_id =
946            self.register_param("rope_sin", sin, Shape::new(&[total, self.rope_dim], self.f));
947        Ok((cos_id, sin_id))
948    }
949
950    fn rope(&mut self, x: HirNodeId, cos: HirNodeId, sin: HirNodeId, shape: Shape) -> HirNodeId {
951        self.hir.mir(
952            Op::Rope {
953                head_dim: self.head_dim,
954                n_rot: self.rope_dim.min(self.head_dim),
955            },
956            vec![x, cos, sin],
957            shape,
958        )
959    }
960
961    #[allow(dead_code)]
962    fn ones1(&mut self) -> HirNodeId {
963        self.register_param(
964            &format!("ones1_{}", self.params.len()),
965            vec![1.0f32; self.dim],
966            Shape::new(&[self.dim], self.f),
967        )
968    }
969
970    fn ones3(&mut self, seq: usize) -> HirNodeId {
971        let id = self.register_param(
972            &format!("ones3_{}", self.params.len()),
973            vec![1.0f32; self.dim],
974            Shape::new(&[1, 1, self.dim], self.f),
975        );
976        self.mir_expand(id, vec![self.batch as i64, seq as i64, self.dim as i64])
977    }
978
979    fn reshape(&mut self, x: HirNodeId, new_shape: Vec<i64>) -> HirNodeId {
980        let shape = self.infer_reshape(&self.hir.node(x).shape, &new_shape);
981        self.hir.mir(Op::Reshape { new_shape }, vec![x], shape)
982    }
983
984    fn narrow(
985        &mut self,
986        x: HirNodeId,
987        axis: usize,
988        start: usize,
989        len: usize,
990        shape: Shape,
991    ) -> HirNodeId {
992        self.hir
993            .mir(Op::Narrow { axis, start, len }, vec![x], shape)
994    }
995
996    fn concat(&mut self, inputs: Vec<HirNodeId>, axis: usize, shape: Shape) -> HirNodeId {
997        self.hir.mir(Op::Concat { axis }, inputs, shape)
998    }
999
1000    fn add(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
1001        self.hir.mir(Op::Binary(BinaryOp::Add), vec![a, b], shape)
1002    }
1003
1004    fn mul(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
1005        self.hir.mir(Op::Binary(BinaryOp::Mul), vec![a, b], shape)
1006    }
1007
1008    fn mir_expand(&mut self, x: HirNodeId, target: Vec<i64>) -> HirNodeId {
1009        let shape = self.infer_reshape(&self.hir.node(x).shape, &target);
1010        self.hir.mir(
1011            Op::Expand {
1012                target_shape: target,
1013            },
1014            vec![x],
1015            shape,
1016        )
1017    }
1018
1019    fn infer_reshape(&self, input: &Shape, new_shape: &[i64]) -> Shape {
1020        let static_dims: Vec<usize> = new_shape.iter().map(|&d| d as usize).collect();
1021        Shape::new(&static_dims, input.dtype())
1022    }
1023}
1024
1025/// Host-side temb for compiled forward (timestep × 1000, optional guidance × 1000).
1026pub fn host_temb(
1027    weights: &Flux2Weights,
1028    cfg: &Flux2Config,
1029    timestep: &[f32],
1030    guidance: Option<&[f32]>,
1031) -> Result<Vec<f32>> {
1032    let t_scaled: Vec<f32> = timestep.iter().map(|t| t * 1000.0).collect();
1033    let g_scaled = guidance.map(|g| g.iter().map(|x| x * 1000.0).collect::<Vec<_>>());
1034    time_guidance_embed(
1035        &t_scaled,
1036        g_scaled.as_deref(),
1037        &weights.time_guidance,
1038        cfg.inner_dim(),
1039    )
1040}
1041
1042/// Dual-time temb for flow-map forwards: mean(embed(t), embed(t′)).
1043pub fn host_temb_dual(
1044    weights: &Flux2Weights,
1045    cfg: &Flux2Config,
1046    timestep: &[f32],
1047    timestep_target: &[f32],
1048    guidance: Option<&[f32]>,
1049) -> Result<Vec<f32>> {
1050    let t_scaled: Vec<f32> = timestep.iter().map(|t| t * 1000.0).collect();
1051    let t2_scaled: Vec<f32> = timestep_target.iter().map(|t| t * 1000.0).collect();
1052    let g_scaled = guidance.map(|g| g.iter().map(|x| x * 1000.0).collect::<Vec<_>>());
1053    let tg_tgt = weights
1054        .time_guidance_target
1055        .as_ref()
1056        .unwrap_or(&weights.time_guidance);
1057    crate::layers::time_guidance_embed_dual(
1058        &t_scaled,
1059        &t2_scaled,
1060        g_scaled.as_deref(),
1061        &weights.time_guidance,
1062        tg_tgt,
1063        cfg.inner_dim(),
1064    )
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Flux2Config, Flux2ForwardInput, extract_flux2_weights, flux2_transformer_forward,
        prepare_weight_map, synthetic_weights,
    };

    /// Maximum tolerated element-wise |native - compiled| on the tiny config.
    const MAX_ABS_DIFF_TOL: f32 = 2e-2;

    /// Largest element-wise absolute difference between `a` and `b`.
    ///
    /// NaN-propagating on purpose. `f32::max` returns the non-NaN operand, so
    /// the naive `fold(0.0, f32::max)` silently drops NaN differences (from NaN
    /// outputs or mismatched infinities). A broken kernel would then pass the
    /// tolerance check. Here any NaN difference makes the result NaN, and
    /// `NaN < tol` is false, so the assertion fails.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, |acc, d| {
                if acc.is_nan() || d.is_nan() {
                    f32::NAN
                } else {
                    acc.max(d)
                }
            })
    }

    /// Runs the tiny synthetic model through the native forward and through
    /// the compiled HIR graph on `device`, then asserts that the outputs agree
    /// within [`MAX_ABS_DIFF_TOL`]. `label` prefixes the failure message.
    fn assert_compiled_matches_native(device: rlx_runtime::Device, label: &str) {
        let cfg = Flux2Config::tiny();
        let wm = synthetic_weights(&cfg);
        let w = extract_flux2_weights(prepare_weight_map(wm), &cfg).unwrap();
        let b = 1usize;
        let img_seq = 4usize;
        let txt_seq = 3usize;
        let hidden = (0..b * img_seq * cfg.in_channels)
            .map(|i| (i as f32 * 0.01).sin())
            .collect::<Vec<_>>();
        let encoder = (0..b * txt_seq * cfg.joint_attention_dim)
            .map(|i| (i as f32 * 0.02).cos())
            .collect::<Vec<_>>();
        let timestep = vec![0.5f32];
        let guidance = vec![3.5f32];
        let img_ids = vec![0.0f32; img_seq * 4];
        let txt_ids = vec![0.0f32; txt_seq * 4];

        let native = flux2_transformer_forward(
            &w,
            &cfg,
            Flux2ForwardInput {
                hidden_states: &hidden,
                encoder_hidden_states: &encoder,
                timestep: &timestep,
                timestep_target: None,
                guidance: Some(&guidance),
                img_ids: &img_ids,
                txt_ids: &txt_ids,
                batch: b,
                img_seq,
                txt_seq,
            },
        )
        .unwrap();

        let temb = host_temb(&w, &cfg, &timestep, Some(&guidance)).unwrap();
        let (mut compiled, _) = compile_flux2_forward(
            &cfg, &w, b, img_seq, txt_seq, &img_ids, &txt_ids, device, None, None, None,
        )
        .unwrap();
        let out = compiled
            .run(&[
                ("hidden", hidden.as_slice()),
                ("encoder", encoder.as_slice()),
                ("temb", temb.as_slice()),
            ])
            .remove(0);

        assert_eq!(out.len(), native.len());
        let max_diff = max_abs_diff(&native, &out);
        assert!(
            max_diff < MAX_ABS_DIFF_TOL,
            "{label}HIR vs native max_abs_diff={max_diff}"
        );
    }

    #[test]
    fn nvfp4_x_embedder_lowers() {
        use crate::synthetic_flux2_packed_tiny;

        let cfg = Flux2Config::tiny();
        let wm = synthetic_weights(&cfg);
        let w = extract_flux2_weights(prepare_weight_map(wm), &cfg).unwrap();
        let packed = synthetic_flux2_packed_tiny(&cfg);
        let g = build_flux2_forward_hir(
            &cfg,
            &w,
            1,
            4,
            3,
            &[0.0; 16],
            &[0.0; 12],
            Some(&packed),
            None,
        )
        .unwrap();
        assert!(!g.typed_params.is_empty());
        g.hir.lower_to_mir().expect("lower nvfp4");
    }

    #[test]
    fn forward_hir_lowers() {
        let cfg = Flux2Config::tiny();
        let wm = synthetic_weights(&cfg);
        let w = extract_flux2_weights(prepare_weight_map(wm), &cfg).unwrap();
        let g =
            build_flux2_forward_hir(&cfg, &w, 1, 4, 3, &[0.0; 16], &[0.0; 12], None, None).unwrap();
        assert_eq!(g.hir.outputs.len(), 1);
        g.hir.lower_to_mir().expect("lower");
    }

    #[test]
    fn compiled_forward_matches_native() {
        assert_compiled_matches_native(rlx_runtime::Device::Cpu, "");
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn compiled_forward_matches_native_on_cuda() {
        use rlx_runtime::Device;

        if !rlx_runtime::is_available(Device::Cuda) {
            eprintln!("skip: CUDA not available");
            return;
        }
        assert_compiled_matches_native(Device::Cuda, "CUDA ");
    }
}