1use std::fmt;
19use std::sync::Arc;
20
21use anyhow::Result;
22use rlx_flow::stream::id as stream_id;
23use rlx_flow::{BuiltModel, CompileProfile, MapWeights, ModelFlow};
24use rlx_ir::{DType, Shape};
25
26use super::config::Flux2Config;
27use super::hir_builder::{Flux2DoubleMod, Flux2HirBuilder, Flux2TypedParams};
28use super::packed::Flux2PackedParams;
29use super::typed_linear::TypedLinearStore;
30use super::weights::Flux2Weights;
31
/// Named-value key prefix for the image-stream double-block modulation values
/// (written by `store_double_mod`, read back by `load_double_mod`).
const MOD_IMG_KEY: &str = "flux2.mod_img";
/// Named-value key prefix for the text-stream double-block modulation values.
const MOD_TXT_KEY: &str = "flux2.mod_txt";
/// Named-value key for the RoPE `cos` value produced by the `flux2.cond` plugin.
const ROPE_COS_KEY: &str = "flux2.rope_cos";
/// Named-value key for the RoPE `sin` value produced by the `flux2.cond` plugin.
const ROPE_SIN_KEY: &str = "flux2.rope_sin";
37
38#[derive(Clone)]
40pub struct Flux2Flow<'a> {
41 cfg: &'a Flux2Config,
42 weights: &'a Flux2Weights,
43 batch: usize,
44 img_seq: usize,
45 txt_seq: usize,
46 img_ids: Arc<Vec<f32>>,
47 txt_ids: Arc<Vec<f32>>,
48 profile: Option<CompileProfile>,
49}
50
51impl fmt::Debug for Flux2Flow<'_> {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 f.debug_struct("Flux2Flow")
54 .field("batch", &self.batch)
55 .field("img_seq", &self.img_seq)
56 .field("txt_seq", &self.txt_seq)
57 .field("profile", &self.profile)
58 .finish_non_exhaustive()
59 }
60}
61
62impl<'a> Flux2Flow<'a> {
63 pub fn new(cfg: &'a Flux2Config, weights: &'a Flux2Weights) -> Self {
64 Self {
65 cfg,
66 weights,
67 batch: 1,
68 img_seq: 64,
69 txt_seq: 128,
70 img_ids: Arc::new(Vec::new()),
71 txt_ids: Arc::new(Vec::new()),
72 profile: None,
73 }
74 }
75
76 pub fn batch(mut self, batch: usize) -> Self {
77 self.batch = batch;
78 self
79 }
80
81 pub fn img_seq(mut self, seq: usize) -> Self {
82 self.img_seq = seq;
83 self
84 }
85
86 pub fn txt_seq(mut self, seq: usize) -> Self {
87 self.txt_seq = seq;
88 self
89 }
90
91 pub fn position_ids(mut self, img_ids: Vec<f32>, txt_ids: Vec<f32>) -> Self {
92 self.img_ids = Arc::new(img_ids);
93 self.txt_ids = Arc::new(txt_ids);
94 self
95 }
96
97 pub fn profile(mut self, profile: CompileProfile) -> Self {
98 self.profile = Some(profile);
99 self
100 }
101
102 pub fn build_dual_blocks(self) -> Result<BuiltModel> {
104 flux2_dual_flow(
105 "flux2_dual",
106 self.cfg,
107 self.weights,
108 self.batch,
109 self.img_seq,
110 self.txt_seq,
111 self.img_ids,
112 self.txt_ids,
113 self.profile.unwrap_or_default(),
114 )
115 .load_stream(stream_id::IMG)
116 .output("hidden")
117 .build(&mut MapWeights::default())
118 }
119
120 pub fn build_forward(self, img_ids: &[f32], txt_ids: &[f32]) -> Result<Flux2ForwardBuilt> {
122 let cfg = self.cfg.clone();
123 let batch = self.batch;
124 let img_seq = self.img_seq;
125 let txt_seq = self.txt_seq;
126 let out_shape = Shape::new(&[batch, img_seq, cfg.proj_out_dim()], DType::F32);
127 let built = flux2_dual_flow(
128 "flux2_forward",
129 self.cfg,
130 self.weights,
131 batch,
132 img_seq,
133 txt_seq,
134 Arc::new(img_ids.to_vec()),
135 Arc::new(txt_ids.to_vec()),
136 self.profile.unwrap_or_default(),
137 )
138 .plugin_named("flux2.single_tail", {
139 let cfg = cfg.clone();
140 let weights = self.weights.clone();
141 move |emit, _| {
142 let img = emit
143 .state
144 .streams
145 .get(stream_id::IMG)
146 .cloned()
147 .ok_or_else(|| anyhow::anyhow!("missing img stream after dual blocks"))?;
148 let txt = emit
149 .state
150 .streams
151 .get(stream_id::TXT)
152 .cloned()
153 .ok_or_else(|| anyhow::anyhow!("missing txt stream after dual blocks"))?;
154 let cos = emit.named(ROPE_COS_KEY)?;
155 let sin = emit.named(ROPE_SIN_KEY)?;
156 let temb = emit.flow_input("temb")?.hir_id();
157 let mut typed = Flux2TypedParams::new();
158 let out = {
159 let (hir, params) = emit.hir_and_params();
160 let mut b = Flux2HirBuilder::from_emit_parts(
161 hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
162 );
163 b.emit_single_stream_tail(img.hir_id(), txt.hir_id(), cos, sin, temb)?
164 };
165 Ok(Some(emit.wrap(out, out_shape.clone())))
166 }
167 })
168 .output("hidden")
169 .build(&mut MapWeights::default())?;
170
171 Ok(Flux2ForwardBuilt {
172 graph_params: built.params.clone(),
173 typed_params: Flux2TypedParams::new(),
174 model: built,
175 })
176 }
177
178 pub fn build_minimal(self) -> Result<BuiltModel> {
180 build_flux2_minimal_built(self.cfg, self.weights, self.batch, self.img_seq)
181 }
182}
183
184pub fn build_flux2_minimal_built(
186 cfg: &Flux2Config,
187 weights: &Flux2Weights,
188 batch: usize,
189 img_seq: usize,
190) -> Result<BuiltModel> {
191 let cfg = cfg.clone();
192 let x_embedder = weights.x_embedder.clone();
193 let proj_out = weights.proj_out.clone();
194 let in_ch = cfg.in_channels;
195 let out_dim = cfg.proj_out_dim();
196 let f = DType::F32;
197 let hidden_shape = Shape::new(&[batch, img_seq, in_ch], f);
198 let embed_shape = Shape::new(&[batch, img_seq, x_embedder.out_dim], f);
199 let out_shape = Shape::new(&[batch, img_seq, out_dim], f);
200
201 ModelFlow::new("flux2_minimal")
202 .input("hidden", hidden_shape.clone())
203 .plugin_named("flux2_minimal.embed", {
204 let x_embedder = x_embedder.clone();
205 let embed_shape = embed_shape.clone();
206 move |emit, _| {
207 let hidden = emit.flow_input("hidden")?.hir_id();
208 let hir = emit
209 .module
210 .as_hir_mut()
211 .expect("flux2 minimal flow requires HIR stage");
212 let embedded = super::builder::linear_hir(
213 hir,
214 emit.params,
215 hidden,
216 &x_embedder,
217 "x_embedder",
218 embed_shape.clone(),
219 )?;
220 Ok(Some(emit.wrap(embedded, embed_shape.clone())))
221 }
222 })
223 .plugin_named("flux2_minimal.proj", {
224 let proj_out = proj_out.clone();
225 let out_shape = out_shape.clone();
226 move |emit, primary| {
227 let embedded = primary
228 .ok_or_else(|| anyhow::anyhow!("flux2 minimal proj requires embed output"))?
229 .hir_id();
230 let hir = emit
231 .module
232 .as_hir_mut()
233 .expect("flux2 minimal flow requires HIR stage");
234 let out = super::builder::linear_hir(
235 hir,
236 emit.params,
237 embedded,
238 &proj_out,
239 "proj_out",
240 out_shape.clone(),
241 )?;
242 Ok(Some(emit.wrap(out, out_shape.clone())))
243 }
244 })
245 .output("output")
246 .build(&mut MapWeights::default())
247}
248
249pub struct Flux2ForwardBuilt {
251 pub model: BuiltModel,
252 pub typed_params: Flux2TypedParams,
253 pub graph_params: crate::builder::Flux2GraphParams,
254}
255
256pub fn compile_flux2_forward_via_flow(
258 cfg: &Flux2Config,
259 weights: &Flux2Weights,
260 batch: usize,
261 img_seq: usize,
262 txt_seq: usize,
263 img_ids: &[f32],
264 txt_ids: &[f32],
265 device: rlx_runtime::Device,
266 packed: Option<&Flux2PackedParams>,
267 typed_linears: Option<&TypedLinearStore>,
268 aot: Option<&rlx_runtime::AotCache>,
269) -> Result<(rlx_runtime::CompiledGraph, crate::builder::Flux2GraphParams)> {
270 use crate::compile_util::{compile_hir_cached, flux2_denoiser_aot_key};
271
272 super::device::assert_flux2_device_available(device)?;
273 let Flux2ForwardBuilt {
274 model,
275 typed_params,
276 graph_params,
277 } = Flux2Flow::new(cfg, weights)
278 .batch(batch)
279 .img_seq(img_seq)
280 .txt_seq(txt_seq)
281 .position_ids(img_ids.to_vec(), txt_ids.to_vec())
282 .build_forward(img_ids, txt_ids)?;
283
284 let key = format!(
285 "{}_flow",
286 flux2_denoiser_aot_key(
287 device,
288 batch,
289 img_seq,
290 txt_seq,
291 img_ids,
292 txt_ids,
293 packed.is_some()
294 )
295 );
296 let hir = model
297 .into_hir()
298 .ok_or_else(|| anyhow::anyhow!("Flux2Flow build did not produce HIR"))?;
299 let profile = CompileProfile::flux2();
300 let mut compiled = compile_hir_cached(device, aot, &key, hir, &profile)?;
301 for (name, data) in &graph_params {
302 compiled.set_param(name, data);
303 }
304 for (name, data, dtype) in &typed_params {
305 compiled.set_param_typed(name, data, *dtype);
306 }
307 let _ = (packed, typed_linears);
308 Ok((compiled, graph_params))
309}
310
311#[derive(Debug, Clone, Copy)]
313pub struct Flux2CfgCombineFlow {
314 pub batch: usize,
315 pub seq: usize,
316 pub channels: usize,
317}
318
319impl Flux2CfgCombineFlow {
320 pub fn new(batch: usize, seq: usize, channels: usize) -> Self {
321 Self {
322 batch,
323 seq,
324 channels,
325 }
326 }
327
328 pub fn build(self) -> Result<BuiltModel> {
329 super::cfg::build_flux2_cfg_combine_built(self.batch, self.seq, self.channels)
330 }
331}
332
333fn flux2_dual_flow(
334 name: &str,
335 cfg: &Flux2Config,
336 weights: &Flux2Weights,
337 batch: usize,
338 img_seq: usize,
339 txt_seq: usize,
340 img_ids: Arc<Vec<f32>>,
341 txt_ids: Arc<Vec<f32>>,
342 profile: CompileProfile,
343) -> ModelFlow {
344 let cfg = cfg.clone();
345 let weights = weights.clone();
346 let dim = cfg.inner_dim();
347 let f = DType::F32;
348 let img_shape = Shape::new(&[batch, img_seq, cfg.in_channels], f);
349 let txt_shape = Shape::new(&[batch, txt_seq, cfg.joint_attention_dim], f);
350 let temb_shape = Shape::new(&[batch, dim], f);
351
352 let mut flow = ModelFlow::new(name)
353 .with_profile(profile)
354 .input("hidden", img_shape.clone())
355 .input("encoder", txt_shape.clone())
356 .input("temb", temb_shape)
357 .bind_inputs_to_streams([("hidden", stream_id::IMG), ("encoder", stream_id::TXT)])
358 .plugin_named("flux2.embed", {
359 let cfg = cfg.clone();
360 let weights = weights.clone();
361 move |emit, _| {
362 let img = emit
363 .state
364 .streams
365 .get(stream_id::IMG)
366 .cloned()
367 .ok_or_else(|| anyhow::anyhow!("missing img stream"))?;
368 let txt = emit
369 .state
370 .streams
371 .get(stream_id::TXT)
372 .cloned()
373 .ok_or_else(|| anyhow::anyhow!("missing txt stream"))?;
374 let mut typed = Flux2TypedParams::new();
375 let (hir, params) = emit.hir_and_params();
376 let mut b = Flux2HirBuilder::from_emit_parts(
377 hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
378 );
379 let img_e = b.linear(
380 img.hir_id(),
381 &weights.x_embedder,
382 "x_embedder",
383 Shape::new(&[batch, img_seq, dim], f),
384 )?;
385 let txt_e = b.linear(
386 txt.hir_id(),
387 &weights.context_embedder,
388 "context_embedder",
389 Shape::new(&[batch, txt_seq, dim], f),
390 )?;
391 let img_out = emit.wrap(img_e, Shape::new(&[batch, img_seq, dim], f));
392 let txt_out = emit.wrap(txt_e, Shape::new(&[batch, txt_seq, dim], f));
393 emit.state
394 .streams
395 .insert(stream_id::IMG.into(), img_out.clone());
396 emit.state.streams.insert(stream_id::TXT.into(), txt_out);
397 Ok(Some(img_out))
398 }
399 })
400 .plugin_named("flux2.cond", {
401 let cfg = cfg.clone();
402 let weights = weights.clone();
403 let img_ids = img_ids.clone();
404 let txt_ids = txt_ids.clone();
405 move |emit, primary| {
406 let temb = emit.flow_input("temb")?.hir_id();
407 let mut typed = Flux2TypedParams::new();
408 let (mod_img, mod_txt, cos, sin) = {
409 let (hir, params) = emit.hir_and_params();
410 let mut b = Flux2HirBuilder::from_emit_parts(
411 hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
412 );
413 let mod_img = b.modulation_params(&weights.double_mod_img, "mod_img", temb)?;
414 let mod_txt = b.modulation_params(&weights.double_mod_txt, "mod_txt", temb)?;
415 let (cos, sin) = b.rope_params(&img_ids, &txt_ids)?;
416 (mod_img, mod_txt, cos, sin)
417 };
418 store_double_mod(emit, MOD_IMG_KEY, &mod_img);
419 store_double_mod(emit, MOD_TXT_KEY, &mod_txt);
420 emit.set_named(ROPE_COS_KEY, cos);
421 emit.set_named(ROPE_SIN_KEY, sin);
422 Ok(primary)
423 }
424 });
425
426 let block_count = weights.transformer_blocks.len();
427 for li in 0..block_count {
428 let block = weights.transformer_blocks[li].clone();
429 let cfg = cfg.clone();
430 let weights = weights.clone();
431 flow = flow.dual_stream(
432 format!("blk{li}"),
433 stream_id::IMG,
434 stream_id::TXT,
435 move |emit, img, txt| {
436 let mod_img = load_double_mod(emit, MOD_IMG_KEY)?;
437 let mod_txt = load_double_mod(emit, MOD_TXT_KEY)?;
438 let cos = emit.named(ROPE_COS_KEY)?;
439 let sin = emit.named(ROPE_SIN_KEY)?;
440 let mut typed = Flux2TypedParams::new();
441 let (h, e) = {
442 let (hir, params) = emit.hir_and_params();
443 let mut b = Flux2HirBuilder::from_emit_parts(
444 hir, params, &mut typed, &cfg, &weights, batch, img_seq, txt_seq,
445 );
446 b.emit_dual_stream_block(
447 li,
448 &block,
449 img.hir_id(),
450 txt.hir_id(),
451 &mod_img,
452 &mod_txt,
453 cos,
454 sin,
455 )?
456 };
457 Ok((
458 emit.wrap(h, img.shape.clone()),
459 emit.wrap(e, txt.shape.clone()),
460 ))
461 },
462 );
463 }
464 flow
465}
466
467fn store_double_mod(emit: &mut rlx_flow::Emit<'_>, prefix: &str, m: &Flux2DoubleMod) {
468 emit.set_named(format!("{prefix}.msa.s"), m.0.0);
469 emit.set_named(format!("{prefix}.msa.c"), m.0.1);
470 emit.set_named(format!("{prefix}.msa.g"), m.0.2);
471 emit.set_named(format!("{prefix}.mlp.s"), m.1.0);
472 emit.set_named(format!("{prefix}.mlp.c"), m.1.1);
473 emit.set_named(format!("{prefix}.mlp.g"), m.1.2);
474}
475
476fn load_double_mod(emit: &rlx_flow::Emit<'_>, prefix: &str) -> Result<Flux2DoubleMod> {
477 Ok((
478 (
479 emit.named(&format!("{prefix}.msa.s"))?,
480 emit.named(&format!("{prefix}.msa.c"))?,
481 emit.named(&format!("{prefix}.msa.g"))?,
482 ),
483 (
484 emit.named(&format!("{prefix}.mlp.s"))?,
485 emit.named(&format!("{prefix}.mlp.c"))?,
486 emit.named(&format!("{prefix}.mlp.g"))?,
487 ),
488 ))
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract_flux2_weights, prepare_weight_map, synthetic_weights};

    /// Tiny synthetic config and weights shared by the parity tests.
    fn tiny_model() -> (Flux2Config, Flux2Weights) {
        let cfg = Flux2Config::tiny();
        let weights =
            extract_flux2_weights(prepare_weight_map(synthetic_weights(&cfg)), &cfg).unwrap();
        (cfg, weights)
    }

    /// All-zero position ids: four values per token for `seq` tokens.
    fn zero_ids(seq: usize) -> Vec<f32> {
        vec![0.0f32; seq * 4]
    }

    #[test]
    fn cfg_flow_matches_hir_node_count() {
        let (batch, seq, channels) = (1, 2, 2);
        let expected = crate::cfg::build_flux2_cfg_combine_hir(batch, seq, channels)
            .hir
            .len();
        let actual = crate::cfg::build_flux2_cfg_combine_built(batch, seq, channels)
            .unwrap()
            .into_hir()
            .unwrap()
            .len();
        assert_eq!(actual, expected);
    }

    #[test]
    fn dual_block_flow_matches_builder_node_count() {
        let (cfg, weights) = tiny_model();
        let (batch, img_seq, txt_seq) = (1, 4, 3);
        let img_ids = zero_ids(img_seq);
        let txt_ids = zero_ids(txt_seq);

        let builder_nodes = super::super::hir_builder::build_flux2_dual_section_hir(
            &cfg, &weights, batch, img_seq, txt_seq, &img_ids, &txt_ids,
        )
        .unwrap()
        .hir
        .len();

        let flow_nodes = Flux2Flow::new(&cfg, &weights)
            .batch(batch)
            .img_seq(img_seq)
            .txt_seq(txt_seq)
            .position_ids(img_ids, txt_ids)
            .build_dual_blocks()
            .unwrap()
            .into_hir()
            .unwrap()
            .len();

        assert_eq!(
            flow_nodes,
            builder_nodes,
            "dual-stream flow should match hir_builder node count (flow={}, builder={})",
            flow_nodes,
            builder_nodes
        );
    }

    #[test]
    fn forward_flow_compile_matches_hir_cpu() {
        use super::super::hir_builder::{compile_flux2_forward, host_temb};

        let (cfg, weights) = tiny_model();
        let (batch, img_seq, txt_seq) = (1usize, 4usize, 3usize);
        let img_ids = zero_ids(img_seq);
        let txt_ids = zero_ids(txt_seq);

        let (mut flow_c, _) = compile_flux2_forward_via_flow(
            &cfg,
            &weights,
            batch,
            img_seq,
            txt_seq,
            &img_ids,
            &txt_ids,
            rlx_runtime::Device::Cpu,
            None,
            None,
            None,
        )
        .unwrap();
        let (mut hir_c, _) = compile_flux2_forward(
            &cfg,
            &weights,
            batch,
            img_seq,
            txt_seq,
            &img_ids,
            &txt_ids,
            rlx_runtime::Device::Cpu,
            None,
            None,
            None,
        )
        .unwrap();

        let hidden = vec![0.1f32; batch * img_seq * cfg.in_channels];
        let encoder = vec![0.2f32; batch * txt_seq * cfg.joint_attention_dim];
        let temb = host_temb(&weights, &cfg, &[0.5], Some(&[3.5])).unwrap();
        // Both graphs are fed exactly the same inputs.
        let inputs = [
            ("hidden", hidden.as_slice()),
            ("encoder", encoder.as_slice()),
            ("temb", temb.as_slice()),
        ];
        let out_flow = flow_c.run(&inputs).remove(0);
        let out_hir = hir_c.run(&inputs).remove(0);

        assert_eq!(out_flow.len(), out_hir.len());
        let abs_err_sum: f32 = out_flow
            .iter()
            .zip(&out_hir)
            .map(|(a, b)| (a - b).abs())
            .sum();
        let mae = abs_err_sum / out_flow.len() as f32;
        assert!(mae < 1e-4, "flow vs hir mae={mae}");
    }
}
611}