1use super::config::Flux2Config;
22use super::weights::{Flux2Weights, LinearWeights};
23use anyhow::Result;
24use rlx_ir::{DType, FusionPolicy, Graph, HirModule, Shape};
25use std::collections::HashMap;
26
/// Parameter buffers for a FLUX.2 graph, keyed by parameter name.
///
/// Keys are the names under which parameters are declared in the graph
/// (for example `"x_embedder.weight"` or `"proj_out.bias"` as produced by
/// `linear_hir`); values are the corresponding `f32` data as flat vectors.
pub type Flux2GraphParams = HashMap<String, Vec<f32>>;
29
30pub fn build_flux2_minimal_hir(
35 cfg: &Flux2Config,
36 weights: &Flux2Weights,
37 batch: usize,
38 img_seq: usize,
39) -> Result<(HirModule, Flux2GraphParams)> {
40 let mut hir = HirModule::new("flux2_minimal").with_fusion_policy(FusionPolicy::Direct);
41 let mut params = Flux2GraphParams::new();
42 let f = DType::F32;
43
44 let hidden_in = hir.input("hidden", Shape::new(&[batch, img_seq, cfg.in_channels], f));
45 let embedded = linear_hir(
46 &mut hir,
47 &mut params,
48 hidden_in,
49 &weights.x_embedder,
50 "x_embedder",
51 Shape::new(&[batch, img_seq, weights.x_embedder.out_dim], f),
52 )?;
53 let out = linear_hir(
54 &mut hir,
55 &mut params,
56 embedded,
57 &weights.proj_out,
58 "proj_out",
59 Shape::new(&[batch, img_seq, cfg.proj_out_dim()], f),
60 )?;
61 hir.outputs = vec![out];
62 Ok((hir, params))
63}
64
65pub fn build_flux2_minimal_graph(
67 cfg: &Flux2Config,
68 weights: &Flux2Weights,
69 batch: usize,
70 img_seq: usize,
71) -> Result<(Graph, Flux2GraphParams)> {
72 rlx_core::flow_util::graph_from_built(crate::flow::build_flux2_minimal_built(
73 cfg, weights, batch, img_seq,
74 )?)
75}
76
77pub fn compile_flux2_minimal(
79 cfg: &Flux2Config,
80 weights: &Flux2Weights,
81 batch: usize,
82 img_seq: usize,
83) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
84 use rlx_runtime::Device;
85
86 let (hir, params) = build_flux2_minimal_hir(cfg, weights, batch, img_seq)?;
87 let profile = crate::compile_util::flux2_compile_profile();
88 let mut compiled = rlx_core::flow_bridge::compile_hir_with_profile(Device::Cpu, hir, &profile)?;
89 for (name, data) in ¶ms {
90 compiled.set_param(name, data);
91 }
92 Ok((compiled, params))
93}
94
95pub(crate) fn linear_hir(
96 hir: &mut HirModule,
97 params: &mut Flux2GraphParams,
98 x: rlx_ir::HirNodeId,
99 lw: &LinearWeights,
100 name: &str,
101 out_shape: Shape,
102) -> Result<rlx_ir::HirNodeId> {
103 let w = hir.param(
104 format!("{name}.weight"),
105 Shape::new(&[lw.in_dim, lw.out_dim], DType::F32),
106 );
107 params.insert(format!("{name}.weight"), lw.w_t.clone());
108 let bias = if lw.bias.iter().all(|&v| v == 0.0) {
109 None
110 } else {
111 let b = hir.param(
112 format!("{name}.bias"),
113 Shape::new(&[lw.out_dim], DType::F32),
114 );
115 params.insert(format!("{name}.bias"), lw.bias.clone());
116 Some(b)
117 };
118 Ok(hir.linear(x, w, bias, None, out_shape))
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Flux2Config, extract_flux2_weights, prepare_weight_map, synthetic_weights};

    /// Sequence length used by both smoke tests.
    const SEQ: usize = 4;

    /// Shared fixture: the tiny config plus synthetic weights extracted for it.
    fn tiny_model() -> (Flux2Config, Flux2Weights) {
        let cfg = Flux2Config::tiny();
        let raw = synthetic_weights(&cfg);
        let weights = extract_flux2_weights(prepare_weight_map(raw), &cfg).unwrap();
        (cfg, weights)
    }

    #[test]
    fn minimal_hir_lowers_to_mir() {
        let (cfg, weights) = tiny_model();
        let (hir, _) = build_flux2_minimal_hir(&cfg, &weights, 1, SEQ).unwrap();
        assert_eq!(hir.outputs.len(), 1);

        let mir = hir.lower_to_mir().expect("lower");
        assert_eq!(mir.outputs().len(), 1);
    }

    #[test]
    fn minimal_compiles_on_cpu() {
        let (cfg, weights) = tiny_model();
        let (mut compiled, _) = compile_flux2_minimal(&cfg, &weights, 1, SEQ).unwrap();

        let hidden = vec![0.0f32; SEQ * cfg.in_channels];
        let out = compiled.run(&[("hidden", hidden.as_slice())]);
        assert_eq!(out[0].len(), SEQ * cfg.proj_out_dim());
    }
}