Skip to main content

rlx_flux2/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fluent FLUX.2 assembly — dual-stream blocks via `rlx-flow` streams + plugins.
17
18use std::fmt;
19use std::sync::Arc;
20
21use anyhow::Result;
22use rlx_flow::stream::id as stream_id;
23use rlx_flow::{BuiltModel, CompileProfile, MapWeights, ModelFlow};
24use rlx_ir::{DType, Shape};
25
26use super::config::Flux2Config;
27use super::hir_builder::{Flux2DoubleMod, Flux2HirBuilder, Flux2TypedParams};
28use super::packed::Flux2PackedParams;
29use super::typed_linear::TypedLinearStore;
30use super::weights::Flux2Weights;
31
/// Named handles for FLUX dual-stream conditioning (stored in [`rlx_flow::FlowState::named`]).
///
/// Key prefix for the image-stream modulation handles: `store_double_mod` and
/// `load_double_mod` append `.msa.{s,c,g}` and `.mlp.{s,c,g}` to it.
const MOD_IMG_KEY: &str = "flux2.mod_img";
/// Key prefix for the text-stream modulation handles (same suffix scheme as the image prefix).
const MOD_TXT_KEY: &str = "flux2.mod_txt";
/// Key under which the `flux2.cond` plugin stores the first value returned by `rope_params`.
const ROPE_COS_KEY: &str = "flux2.rope_cos";
/// Key under which the `flux2.cond` plugin stores the second value returned by `rope_params`.
const ROPE_SIN_KEY: &str = "flux2.rope_sin";
37
38/// Tier-0 FLUX.2 dual-stream flow builder.
39#[derive(Clone)]
40pub struct Flux2Flow<'a> {
41    cfg: &'a Flux2Config,
42    weights: &'a Flux2Weights,
43    batch: usize,
44    img_seq: usize,
45    txt_seq: usize,
46    img_ids: Arc<Vec<f32>>,
47    txt_ids: Arc<Vec<f32>>,
48    profile: Option<CompileProfile>,
49}
50
51impl fmt::Debug for Flux2Flow<'_> {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("Flux2Flow")
54            .field("batch", &self.batch)
55            .field("img_seq", &self.img_seq)
56            .field("txt_seq", &self.txt_seq)
57            .field("profile", &self.profile)
58            .finish_non_exhaustive()
59    }
60}
61
62impl<'a> Flux2Flow<'a> {
63    pub fn new(cfg: &'a Flux2Config, weights: &'a Flux2Weights) -> Self {
64        Self {
65            cfg,
66            weights,
67            batch: 1,
68            img_seq: 64,
69            txt_seq: 128,
70            img_ids: Arc::new(Vec::new()),
71            txt_ids: Arc::new(Vec::new()),
72            profile: None,
73        }
74    }
75
76    pub fn batch(mut self, batch: usize) -> Self {
77        self.batch = batch;
78        self
79    }
80
81    pub fn img_seq(mut self, seq: usize) -> Self {
82        self.img_seq = seq;
83        self
84    }
85
86    pub fn txt_seq(mut self, seq: usize) -> Self {
87        self.txt_seq = seq;
88        self
89    }
90
91    pub fn position_ids(mut self, img_ids: Vec<f32>, txt_ids: Vec<f32>) -> Self {
92        self.img_ids = Arc::new(img_ids);
93        self.txt_ids = Arc::new(txt_ids);
94        self
95    }
96
97    pub fn profile(mut self, profile: CompileProfile) -> Self {
98        self.profile = Some(profile);
99        self
100    }
101
102    /// Build img/txt embed inputs + dual-stream transformer blocks (no single-stream tail).
103    pub fn build_dual_blocks(self) -> Result<BuiltModel> {
104        flux2_dual_flow(
105            "flux2_dual",
106            self.cfg,
107            self.weights,
108            self.batch,
109            self.img_seq,
110            self.txt_seq,
111            self.img_ids,
112            self.txt_ids,
113            self.profile.unwrap_or_default(),
114        )
115        .load_stream(stream_id::IMG)
116        .output("hidden")
117        .build(&mut MapWeights::default())
118    }
119
120    /// Full denoiser forward: native dual-stream blocks + single-stream tail.
121    pub fn build_forward(self, img_ids: &[f32], txt_ids: &[f32]) -> Result<Flux2ForwardBuilt> {
122        let cfg = self.cfg.clone();
123        let batch = self.batch;
124        let img_seq = self.img_seq;
125        let txt_seq = self.txt_seq;
126        let out_shape = Shape::new(&[batch, img_seq, cfg.proj_out_dim()], DType::F32);
127        let built = flux2_dual_flow(
128            "flux2_forward",
129            self.cfg,
130            self.weights,
131            batch,
132            img_seq,
133            txt_seq,
134            Arc::new(img_ids.to_vec()),
135            Arc::new(txt_ids.to_vec()),
136            self.profile.unwrap_or_default(),
137        )
138        .plugin_named("flux2.single_tail", {
139            let cfg = cfg.clone();
140            let weights = self.weights.clone();
141            move |emit, _| {
142                let img = emit
143                    .state
144                    .streams
145                    .get(stream_id::IMG)
146                    .cloned()
147                    .ok_or_else(|| anyhow::anyhow!("missing img stream after dual blocks"))?;
148                let txt = emit
149                    .state
150                    .streams
151                    .get(stream_id::TXT)
152                    .cloned()
153                    .ok_or_else(|| anyhow::anyhow!("missing txt stream after dual blocks"))?;
154                let cos = emit.named(ROPE_COS_KEY)?;
155                let sin = emit.named(ROPE_SIN_KEY)?;
156                let temb = emit.flow_input("temb")?.hir_id();
157                let mut typed = Flux2TypedParams::new();
158                let out = {
159                    let (hir, params) = emit.hir_and_params();
160                    let mut b = Flux2HirBuilder::from_emit_parts(
161                        hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
162                    );
163                    b.emit_single_stream_tail(img.hir_id(), txt.hir_id(), cos, sin, temb)?
164                };
165                Ok(Some(emit.wrap(out, out_shape.clone())))
166            }
167        })
168        .output("hidden")
169        .build(&mut MapWeights::default())?;
170
171        Ok(Flux2ForwardBuilt {
172            graph_params: built.params.clone(),
173            typed_params: Flux2TypedParams::new(),
174            model: built,
175        })
176    }
177
178    /// Compile-minimal path: x_embedder → proj_out.
179    pub fn build_minimal(self) -> Result<BuiltModel> {
180        build_flux2_minimal_built(self.cfg, self.weights, self.batch, self.img_seq)
181    }
182}
183
184/// Compile-minimal FLUX.2 flow: `hidden` → x_embedder → proj_out.
185pub fn build_flux2_minimal_built(
186    cfg: &Flux2Config,
187    weights: &Flux2Weights,
188    batch: usize,
189    img_seq: usize,
190) -> Result<BuiltModel> {
191    let cfg = cfg.clone();
192    let x_embedder = weights.x_embedder.clone();
193    let proj_out = weights.proj_out.clone();
194    let in_ch = cfg.in_channels;
195    let out_dim = cfg.proj_out_dim();
196    let f = DType::F32;
197    let hidden_shape = Shape::new(&[batch, img_seq, in_ch], f);
198    let embed_shape = Shape::new(&[batch, img_seq, x_embedder.out_dim], f);
199    let out_shape = Shape::new(&[batch, img_seq, out_dim], f);
200
201    ModelFlow::new("flux2_minimal")
202        .input("hidden", hidden_shape.clone())
203        .plugin_named("flux2_minimal.embed", {
204            let x_embedder = x_embedder.clone();
205            let embed_shape = embed_shape.clone();
206            move |emit, _| {
207                let hidden = emit.flow_input("hidden")?.hir_id();
208                let hir = emit
209                    .module
210                    .as_hir_mut()
211                    .expect("flux2 minimal flow requires HIR stage");
212                let embedded = super::builder::linear_hir(
213                    hir,
214                    emit.params,
215                    hidden,
216                    &x_embedder,
217                    "x_embedder",
218                    embed_shape.clone(),
219                )?;
220                Ok(Some(emit.wrap(embedded, embed_shape.clone())))
221            }
222        })
223        .plugin_named("flux2_minimal.proj", {
224            let proj_out = proj_out.clone();
225            let out_shape = out_shape.clone();
226            move |emit, primary| {
227                let embedded = primary
228                    .ok_or_else(|| anyhow::anyhow!("flux2 minimal proj requires embed output"))?
229                    .hir_id();
230                let hir = emit
231                    .module
232                    .as_hir_mut()
233                    .expect("flux2 minimal flow requires HIR stage");
234                let out = super::builder::linear_hir(
235                    hir,
236                    emit.params,
237                    embedded,
238                    &proj_out,
239                    "proj_out",
240                    out_shape.clone(),
241                )?;
242                Ok(Some(emit.wrap(out, out_shape.clone())))
243            }
244        })
245        .output("output")
246        .build(&mut MapWeights::default())
247}
248
/// Full forward build product (includes non-f32 typed param blobs).
///
/// NOTE(review): `Flux2Flow::build_forward` currently sets `typed_params` to a
/// fresh `Flux2TypedParams::new()` — confirm whether the typed blobs are
/// expected to be populated by the flow path.
pub struct Flux2ForwardBuilt {
    /// The built flow model.
    pub model: BuiltModel,
    /// Typed parameter blobs that accompany the model.
    pub typed_params: Flux2TypedParams,
    /// Clone of the built model's `params`, taken at build time.
    pub graph_params: crate::builder::Flux2GraphParams,
}
255
256/// Compile denoiser via tier-0 [`Flux2Flow`] wrapper (same numerics as [`super::hir_builder::compile_flux2_forward`]).
257pub fn compile_flux2_forward_via_flow(
258    cfg: &Flux2Config,
259    weights: &Flux2Weights,
260    batch: usize,
261    img_seq: usize,
262    txt_seq: usize,
263    img_ids: &[f32],
264    txt_ids: &[f32],
265    device: rlx_runtime::Device,
266    packed: Option<&Flux2PackedParams>,
267    typed_linears: Option<&TypedLinearStore>,
268    aot: Option<&rlx_runtime::AotCache>,
269) -> Result<(rlx_runtime::CompiledGraph, crate::builder::Flux2GraphParams)> {
270    use crate::compile_util::{compile_hir_cached, flux2_denoiser_aot_key};
271
272    super::device::assert_flux2_device_available(device)?;
273    let Flux2ForwardBuilt {
274        model,
275        typed_params,
276        graph_params,
277    } = Flux2Flow::new(cfg, weights)
278        .batch(batch)
279        .img_seq(img_seq)
280        .txt_seq(txt_seq)
281        .position_ids(img_ids.to_vec(), txt_ids.to_vec())
282        .build_forward(img_ids, txt_ids)?;
283
284    let key = format!(
285        "{}_flow",
286        flux2_denoiser_aot_key(
287            device,
288            batch,
289            img_seq,
290            txt_seq,
291            img_ids,
292            txt_ids,
293            packed.is_some()
294        )
295    );
296    let hir = model
297        .into_hir()
298        .ok_or_else(|| anyhow::anyhow!("Flux2Flow build did not produce HIR"))?;
299    let profile = CompileProfile::flux2();
300    let mut compiled = compile_hir_cached(device, aot, &key, hir, &profile)?;
301    for (name, data) in &graph_params {
302        compiled.set_param(name, data);
303    }
304    for (name, data, dtype) in &typed_params {
305        compiled.set_param_typed(name, data, *dtype);
306    }
307    let _ = (packed, typed_linears);
308    Ok((compiled, graph_params))
309}
310
311/// Tier-0 CFG combine: `neg + scale * (pos - neg)`.
312#[derive(Debug, Clone, Copy)]
313pub struct Flux2CfgCombineFlow {
314    pub batch: usize,
315    pub seq: usize,
316    pub channels: usize,
317}
318
319impl Flux2CfgCombineFlow {
320    pub fn new(batch: usize, seq: usize, channels: usize) -> Self {
321        Self {
322            batch,
323            seq,
324            channels,
325        }
326    }
327
328    pub fn build(self) -> Result<BuiltModel> {
329        super::cfg::build_flux2_cfg_combine_built(self.batch, self.seq, self.channels)
330    }
331}
332
333fn flux2_dual_flow(
334    name: &str,
335    cfg: &Flux2Config,
336    weights: &Flux2Weights,
337    batch: usize,
338    img_seq: usize,
339    txt_seq: usize,
340    img_ids: Arc<Vec<f32>>,
341    txt_ids: Arc<Vec<f32>>,
342    profile: CompileProfile,
343) -> ModelFlow {
344    let cfg = cfg.clone();
345    let weights = weights.clone();
346    let dim = cfg.inner_dim();
347    let f = DType::F32;
348    let img_shape = Shape::new(&[batch, img_seq, cfg.in_channels], f);
349    let txt_shape = Shape::new(&[batch, txt_seq, cfg.joint_attention_dim], f);
350    let temb_shape = Shape::new(&[batch, dim], f);
351
352    let mut flow = ModelFlow::new(name)
353        .with_profile(profile)
354        .input("hidden", img_shape.clone())
355        .input("encoder", txt_shape.clone())
356        .input("temb", temb_shape)
357        .bind_inputs_to_streams([("hidden", stream_id::IMG), ("encoder", stream_id::TXT)])
358        .plugin_named("flux2.embed", {
359            let cfg = cfg.clone();
360            let weights = weights.clone();
361            move |emit, _| {
362                let img = emit
363                    .state
364                    .streams
365                    .get(stream_id::IMG)
366                    .cloned()
367                    .ok_or_else(|| anyhow::anyhow!("missing img stream"))?;
368                let txt = emit
369                    .state
370                    .streams
371                    .get(stream_id::TXT)
372                    .cloned()
373                    .ok_or_else(|| anyhow::anyhow!("missing txt stream"))?;
374                let mut typed = Flux2TypedParams::new();
375                let (hir, params) = emit.hir_and_params();
376                let mut b = Flux2HirBuilder::from_emit_parts(
377                    hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
378                );
379                let img_e = b.linear(
380                    img.hir_id(),
381                    &weights.x_embedder,
382                    "x_embedder",
383                    Shape::new(&[batch, img_seq, dim], f),
384                )?;
385                let txt_e = b.linear(
386                    txt.hir_id(),
387                    &weights.context_embedder,
388                    "context_embedder",
389                    Shape::new(&[batch, txt_seq, dim], f),
390                )?;
391                let img_out = emit.wrap(img_e, Shape::new(&[batch, img_seq, dim], f));
392                let txt_out = emit.wrap(txt_e, Shape::new(&[batch, txt_seq, dim], f));
393                emit.state
394                    .streams
395                    .insert(stream_id::IMG.into(), img_out.clone());
396                emit.state.streams.insert(stream_id::TXT.into(), txt_out);
397                Ok(Some(img_out))
398            }
399        })
400        .plugin_named("flux2.cond", {
401            let cfg = cfg.clone();
402            let weights = weights.clone();
403            let img_ids = img_ids.clone();
404            let txt_ids = txt_ids.clone();
405            move |emit, primary| {
406                let temb = emit.flow_input("temb")?.hir_id();
407                let mut typed = Flux2TypedParams::new();
408                let (mod_img, mod_txt, cos, sin) = {
409                    let (hir, params) = emit.hir_and_params();
410                    let mut b = Flux2HirBuilder::from_emit_parts(
411                        hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
412                    );
413                    let mod_img = b.modulation_params(&weights.double_mod_img, "mod_img", temb)?;
414                    let mod_txt = b.modulation_params(&weights.double_mod_txt, "mod_txt", temb)?;
415                    let (cos, sin) = b.rope_params(&img_ids, &txt_ids)?;
416                    (mod_img, mod_txt, cos, sin)
417                };
418                store_double_mod(emit, MOD_IMG_KEY, &mod_img);
419                store_double_mod(emit, MOD_TXT_KEY, &mod_txt);
420                emit.set_named(ROPE_COS_KEY, cos);
421                emit.set_named(ROPE_SIN_KEY, sin);
422                Ok(primary)
423            }
424        });
425
426    let block_count = weights.transformer_blocks.len();
427    for li in 0..block_count {
428        let block = weights.transformer_blocks[li].clone();
429        let cfg = cfg.clone();
430        let weights = weights.clone();
431        flow = flow.dual_stream(
432            format!("blk{li}"),
433            stream_id::IMG,
434            stream_id::TXT,
435            move |emit, img, txt| {
436                let mod_img = load_double_mod(emit, MOD_IMG_KEY)?;
437                let mod_txt = load_double_mod(emit, MOD_TXT_KEY)?;
438                let cos = emit.named(ROPE_COS_KEY)?;
439                let sin = emit.named(ROPE_SIN_KEY)?;
440                let mut typed = Flux2TypedParams::new();
441                let (h, e) = {
442                    let (hir, params) = emit.hir_and_params();
443                    let mut b = Flux2HirBuilder::from_emit_parts(
444                        hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
445                    );
446                    b.emit_dual_stream_block(
447                        li,
448                        &block,
449                        img.hir_id(),
450                        txt.hir_id(),
451                        &mod_img,
452                        &mod_txt,
453                        cos,
454                        sin,
455                    )?
456                };
457                Ok((
458                    emit.wrap(h, img.shape.clone()),
459                    emit.wrap(e, txt.shape.clone()),
460                ))
461            },
462        );
463    }
464    flow
465}
466
467fn store_double_mod(emit: &mut rlx_flow::Emit<'_>, prefix: &str, m: &Flux2DoubleMod) {
468    emit.set_named(format!("{prefix}.msa.s"), m.0.0);
469    emit.set_named(format!("{prefix}.msa.c"), m.0.1);
470    emit.set_named(format!("{prefix}.msa.g"), m.0.2);
471    emit.set_named(format!("{prefix}.mlp.s"), m.1.0);
472    emit.set_named(format!("{prefix}.mlp.c"), m.1.1);
473    emit.set_named(format!("{prefix}.mlp.g"), m.1.2);
474}
475
476fn load_double_mod(emit: &rlx_flow::Emit<'_>, prefix: &str) -> Result<Flux2DoubleMod> {
477    Ok((
478        (
479            emit.named(&format!("{prefix}.msa.s"))?,
480            emit.named(&format!("{prefix}.msa.c"))?,
481            emit.named(&format!("{prefix}.msa.g"))?,
482        ),
483        (
484            emit.named(&format!("{prefix}.mlp.s"))?,
485            emit.named(&format!("{prefix}.mlp.c"))?,
486            emit.named(&format!("{prefix}.mlp.g"))?,
487        ),
488    ))
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract_flux2_weights, prepare_weight_map, synthetic_weights};

    /// The CFG combine flow must emit as many HIR nodes as the direct HIR builder.
    #[test]
    fn cfg_flow_matches_hir_node_count() {
        let (batch, seq, channels) = (1, 2, 2);

        let reference = crate::cfg::build_flux2_cfg_combine_hir(batch, seq, channels).hir;
        let via_flow = crate::cfg::build_flux2_cfg_combine_built(batch, seq, channels)
            .unwrap()
            .into_hir()
            .unwrap();

        assert_eq!(via_flow.len(), reference.len());
    }

    /// The dual-stream flow must emit as many HIR nodes as the reference builder.
    #[test]
    fn dual_block_flow_matches_builder_node_count() {
        let cfg = Flux2Config::tiny();
        let weights =
            extract_flux2_weights(prepare_weight_map(synthetic_weights(&cfg)), &cfg).unwrap();
        let (batch, img_seq, txt_seq) = (1, 4, 3);
        let img_ids = vec![0.0f32; img_seq * 4];
        let txt_ids = vec![0.0f32; txt_seq * 4];

        let reference = super::super::hir_builder::build_flux2_dual_section_hir(
            &cfg, &weights, batch, img_seq, txt_seq, &img_ids, &txt_ids,
        )
        .unwrap();
        let ref_hir = reference.hir;

        let flow_hir = Flux2Flow::new(&cfg, &weights)
            .batch(batch)
            .img_seq(img_seq)
            .txt_seq(txt_seq)
            .position_ids(img_ids, txt_ids)
            .build_dual_blocks()
            .unwrap()
            .into_hir()
            .unwrap();

        assert_eq!(
            flow_hir.len(),
            ref_hir.len(),
            "dual-stream flow should match hir_builder node count (flow={}, builder={})",
            flow_hir.len(),
            ref_hir.len()
        );
    }

    /// Compiling through the flow and through the HIR builder must agree numerically on CPU.
    #[test]
    fn forward_flow_compile_matches_hir_cpu() {
        use super::super::hir_builder::compile_flux2_forward;

        let cfg = Flux2Config::tiny();
        let weights =
            extract_flux2_weights(prepare_weight_map(synthetic_weights(&cfg)), &cfg).unwrap();
        let batch = 1usize;
        let img_seq = 4usize;
        let txt_seq = 3usize;
        let img_ids = vec![0.0f32; img_seq * 4];
        let txt_ids = vec![0.0f32; txt_seq * 4];
        let device = rlx_runtime::Device::Cpu;

        let (mut flow_c, _) = super::compile_flux2_forward_via_flow(
            &cfg, &weights, batch, img_seq, txt_seq, &img_ids, &txt_ids, device, None, None, None,
        )
        .unwrap();
        let (mut hir_c, _) = compile_flux2_forward(
            &cfg, &weights, batch, img_seq, txt_seq, &img_ids, &txt_ids, device, None, None, None,
        )
        .unwrap();

        // Identical constant inputs are fed to both compiled graphs.
        let hidden = vec![0.1f32; batch * img_seq * cfg.in_channels];
        let encoder = vec![0.2f32; batch * txt_seq * cfg.joint_attention_dim];
        let temb =
            super::super::hir_builder::host_temb(&weights, &cfg, &[0.5], Some(&[3.5])).unwrap();
        let inputs = [
            ("hidden", hidden.as_slice()),
            ("encoder", encoder.as_slice()),
            ("temb", temb.as_slice()),
        ];

        let out_flow = flow_c.run(&inputs).remove(0);
        let out_hir = hir_c.run(&inputs).remove(0);
        assert_eq!(out_flow.len(), out_hir.len());

        let abs_err_sum = out_flow
            .iter()
            .zip(out_hir.iter())
            .map(|(a, b)| (a - b).abs())
            .sum::<f32>();
        let mae: f32 = abs_err_sum / out_flow.len() as f32;
        assert!(mae < 1e-4, "flow vs hir mae={mae}");
    }
}