Skip to main content

rlx_flux2/
builder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! FLUX.2 HIR graph builder (compile minimal + incremental lowering).
17//!
18//! Model builders target [`HirModule`] and lower via [`HirModule::lower_to_mir`].
19//! Minimal graph only — see [`super::hir_builder`] for the full transformer HIR.
20
21use super::config::Flux2Config;
22use super::weights::{Flux2Weights, LinearWeights};
23use anyhow::Result;
24use rlx_ir::{DType, FusionPolicy, Graph, HirModule, Shape};
25use std::collections::HashMap;
26
27/// Param tensors keyed by name for [`rlx_runtime::CompiledGraph::set_param`].
28pub type Flux2GraphParams = HashMap<String, Vec<f32>>;
29
30/// Build a compile-minimal HIR module: `x_embedder(hidden)` → `proj_out`.
31///
32/// Inputs: `hidden` `[batch, img_seq, in_channels]`.
33/// Output: `[batch, img_seq, proj_out_dim]`.
34pub fn build_flux2_minimal_hir(
35    cfg: &Flux2Config,
36    weights: &Flux2Weights,
37    batch: usize,
38    img_seq: usize,
39) -> Result<(HirModule, Flux2GraphParams)> {
40    let mut hir = HirModule::new("flux2_minimal").with_fusion_policy(FusionPolicy::Direct);
41    let mut params = Flux2GraphParams::new();
42    let f = DType::F32;
43
44    let hidden_in = hir.input("hidden", Shape::new(&[batch, img_seq, cfg.in_channels], f));
45    let embedded = linear_hir(
46        &mut hir,
47        &mut params,
48        hidden_in,
49        &weights.x_embedder,
50        "x_embedder",
51        Shape::new(&[batch, img_seq, weights.x_embedder.out_dim], f),
52    )?;
53    let out = linear_hir(
54        &mut hir,
55        &mut params,
56        embedded,
57        &weights.proj_out,
58        "proj_out",
59        Shape::new(&[batch, img_seq, cfg.proj_out_dim()], f),
60    )?;
61    hir.outputs = vec![out];
62    Ok((hir, params))
63}
64
65/// Lower minimal HIR to legacy [`Graph`] (MIR inner) for `Session::compile`.
66pub fn build_flux2_minimal_graph(
67    cfg: &Flux2Config,
68    weights: &Flux2Weights,
69    batch: usize,
70    img_seq: usize,
71) -> Result<(Graph, Flux2GraphParams)> {
72    rlx_core::flow_util::graph_from_built(crate::flow::build_flux2_minimal_built(
73        cfg, weights, batch, img_seq,
74    )?)
75}
76
77/// Compile minimal HIR on CPU (HIR → MIR → LIR).
78pub fn compile_flux2_minimal(
79    cfg: &Flux2Config,
80    weights: &Flux2Weights,
81    batch: usize,
82    img_seq: usize,
83) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
84    use rlx_runtime::Device;
85
86    let (hir, params) = build_flux2_minimal_hir(cfg, weights, batch, img_seq)?;
87    let profile = crate::compile_util::flux2_compile_profile();
88    let mut compiled = rlx_core::flow_bridge::compile_hir_with_profile(Device::Cpu, hir, &profile)?;
89    for (name, data) in &params {
90        compiled.set_param(name, data);
91    }
92    Ok((compiled, params))
93}
94
95pub(crate) fn linear_hir(
96    hir: &mut HirModule,
97    params: &mut Flux2GraphParams,
98    x: rlx_ir::HirNodeId,
99    lw: &LinearWeights,
100    name: &str,
101    out_shape: Shape,
102) -> Result<rlx_ir::HirNodeId> {
103    let w = hir.param(
104        format!("{name}.weight"),
105        Shape::new(&[lw.in_dim, lw.out_dim], DType::F32),
106    );
107    params.insert(format!("{name}.weight"), lw.w_t.clone());
108    let bias = if lw.bias.iter().all(|&v| v == 0.0) {
109        None
110    } else {
111        let b = hir.param(
112            format!("{name}.bias"),
113            Shape::new(&[lw.out_dim], DType::F32),
114        );
115        params.insert(format!("{name}.bias"), lw.bias.clone());
116        Some(b)
117    };
118    Ok(hir.linear(x, w, bias, None, out_shape))
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Flux2Config, extract_flux2_weights, prepare_weight_map, synthetic_weights};

    /// Batch size used by every test in this module.
    const BATCH: usize = 1;
    /// Image sequence length used by every test in this module.
    const IMG_SEQ: usize = 4;

    /// The minimal HIR has exactly one output and still has one after lowering to MIR.
    #[test]
    fn minimal_hir_lowers_to_mir() {
        let cfg = Flux2Config::tiny();
        let prepared = prepare_weight_map(synthetic_weights(&cfg));
        let weights = extract_flux2_weights(prepared, &cfg).unwrap();

        let (module, _) = build_flux2_minimal_hir(&cfg, &weights, BATCH, IMG_SEQ).unwrap();
        assert_eq!(module.outputs.len(), 1);

        let lowered = module.lower_to_mir().expect("lower");
        assert_eq!(lowered.outputs().len(), 1);
    }

    /// The compiled minimal graph runs on an all-zero input and yields the expected element count.
    #[test]
    fn minimal_compiles_on_cpu() {
        let cfg = Flux2Config::tiny();
        let prepared = prepare_weight_map(synthetic_weights(&cfg));
        let weights = extract_flux2_weights(prepared, &cfg).unwrap();

        let (mut compiled, _) = compile_flux2_minimal(&cfg, &weights, BATCH, IMG_SEQ).unwrap();

        let hidden = vec![0.0f32; BATCH * IMG_SEQ * cfg.in_channels];
        let results = compiled.run(&[("hidden", hidden.as_slice())]);
        assert_eq!(results[0].len(), BATCH * IMG_SEQ * cfg.proj_out_dim());
    }
}