1use anyhow::Result;
19use rlx_ir::hir::{FusionPolicy, HirModule};
20use rlx_ir::op::BinaryOp;
21use rlx_ir::{DType, Op, Shape};
22use rlx_runtime::Device;
23
/// Blends two prediction slices element-wise as `neg + scale * (pos - neg)`.
///
/// This is the plain-Rust reference for the graph emitted by
/// [`emit_flux2_cfg_combine`] (the `cfg` prefix presumably stands for
/// classifier-free guidance — confirm against the caller).
///
/// With `scale == 0.0` each element equals the `neg` value; with `scale == 1.0`
/// it equals `neg + (pos - neg)`.
///
/// The inputs are paired positionally. If the slices differ in length, the
/// extra trailing elements of the longer one are ignored and the result has
/// the length of the shorter slice.
pub fn cfg_combine(pos: &[f32], neg: &[f32], scale: f32) -> Vec<f32> {
    let mut blended = Vec::with_capacity(pos.len().min(neg.len()));
    for (&positive, &negative) in pos.iter().zip(neg) {
        // Keep the exact operation order (subtract, multiply, add) so the
        // floating-point result is bit-for-bit stable.
        let delta = positive - negative;
        blended.push(negative + scale * delta);
    }
    blended
}
31
/// Owns the HIR module that describes the CFG-combine graph.
///
/// Produced by [`build_flux2_cfg_combine_hir`].
pub struct Flux2CfgCombineGraph {
    /// The HIR module; [`compile_flux2_cfg_combine`] moves it into the compile step.
    pub hir: HirModule,
}
35
36pub fn emit_flux2_cfg_combine(
38 hir: &mut HirModule,
39 pos: rlx_ir::HirNodeId,
40 neg: rlx_ir::HirNodeId,
41 scale: rlx_ir::HirNodeId,
42 shape: Shape,
43) -> rlx_ir::HirNodeId {
44 let diff = hir.mir(Op::Binary(BinaryOp::Sub), vec![pos, neg], shape.clone());
45 let scaled = hir.mir(Op::Binary(BinaryOp::Mul), vec![diff, scale], shape.clone());
46 hir.mir(Op::Binary(BinaryOp::Add), vec![neg, scaled], shape)
47}
48
49pub fn build_flux2_cfg_combine_hir(
53 batch: usize,
54 seq: usize,
55 channels: usize,
56) -> Flux2CfgCombineGraph {
57 let mut hir = HirModule::new("flux2_cfg_combine").with_fusion_policy(FusionPolicy::Direct);
58 let f = DType::F32;
59 let shape = Shape::new(&[batch, seq, channels], f);
60 let pos = hir.input("pos", shape.clone());
61 let neg = hir.input("neg", shape.clone());
62 let scale = hir.input("guidance_scale", Shape::scalar(f));
63 let out = emit_flux2_cfg_combine(&mut hir, pos, neg, scale, shape);
64 hir.outputs = vec![out];
65 Flux2CfgCombineGraph { hir }
66}
67
68pub fn compile_flux2_cfg_combine(
69 batch: usize,
70 seq: usize,
71 channels: usize,
72 device: Device,
73 aot: Option<&rlx_runtime::AotCache>,
74) -> Result<rlx_runtime::CompiledGraph> {
75 use crate::compile_util::{compile_hir_cached, flux2_cfg_aot_key};
76
77 crate::device::assert_flux2_device_available(device)?;
78 let g = build_flux2_cfg_combine_hir(batch, seq, channels);
79 let key = flux2_cfg_aot_key(device, batch, seq, channels);
80 compile_hir_cached(
81 device,
82 aot,
83 &key,
84 g.hir,
85 &super::compile_util::flux2_compile_profile(),
86 )
87}
88
89pub fn build_flux2_cfg_combine_built(
91 batch: usize,
92 seq: usize,
93 channels: usize,
94) -> Result<rlx_flow::BuiltModel> {
95 use rlx_flow::{MapWeights, ModelFlow};
96
97 let f = DType::F32;
98 let shape = Shape::new(&[batch, seq, channels], f);
99 let out_shape = shape.clone();
100 ModelFlow::new("flux2_cfg_combine")
101 .input("pos", shape.clone())
102 .input("neg", shape.clone())
103 .input("guidance_scale", Shape::scalar(f))
104 .plugin_named("flux2.cfg.combine", move |emit, _| {
105 let pos = emit.flow_input("pos")?.hir_id();
106 let neg = emit.flow_input("neg")?.hir_id();
107 let scale = emit.flow_input("guidance_scale")?.hir_id();
108 let (hir, _) = emit.hir_and_params();
109 let out = emit_flux2_cfg_combine(hir, pos, neg, scale, out_shape.clone());
110 Ok(Some(emit.wrap(out, out_shape.clone())))
111 })
112 .output("output")
113 .build(&mut MapWeights::default())
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_runtime::Device;

    /// Runs the compiled graph on the CPU device and checks it against
    /// `cfg_combine` on the same four-element input (dimensions 1 x 2 x 2).
    #[test]
    fn cfg_native_matches_hir_cpu() {
        let scale = 2.5f32;
        let pos = vec![1.0f32, 2.0, 3.0, 4.0];
        let neg = vec![0.0f32, 0.5, 1.0, 1.5];
        let native = cfg_combine(&pos, &neg, scale);

        let mut compiled = compile_flux2_cfg_combine(1, 2, 2, Device::Cpu, None).unwrap();
        // The graph has a single output; take it out of the result list.
        let out = compiled
            .run(&[
                ("pos", pos.as_slice()),
                ("neg", neg.as_slice()),
                ("guidance_scale", &[scale]),
            ])
            .remove(0);

        assert_eq!(out.len(), native.len());

        // Largest absolute element-wise deviation between the two results.
        let mut max = 0.0f32;
        for (got, want) in out.iter().zip(&native) {
            max = f32::max(max, (got - want).abs());
        }
        assert!(max < 1e-5, "max_diff={max}");
    }
}