Skip to main content

rlx_flux2/
cfg.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Classifier-free guidance: `neg + scale * (pos - neg)` (matches Modular MAX).
17
18use anyhow::Result;
19use rlx_ir::hir::{FusionPolicy, HirModule};
20use rlx_ir::op::BinaryOp;
21use rlx_ir::{DType, Op, Shape};
22use rlx_runtime::Device;
23
/// Native CFG blend in float32: `neg + scale * (pos - neg)`, applied element-wise.
///
/// Host-side reference for the graph emitted by [`emit_flux2_cfg_combine`]; both use the
/// same formula and operand order.
///
/// # Panics
///
/// Panics if `pos` and `neg` have different lengths. Without this check `zip` would stop
/// at the shorter slice and return a truncated blend without any error.
pub fn cfg_combine(pos: &[f32], neg: &[f32], scale: f32) -> Vec<f32> {
    assert_eq!(
        pos.len(),
        neg.len(),
        "cfg_combine: pos and neg must have the same length"
    );
    pos.iter()
        .zip(neg)
        .map(|(&p, &n)| n + scale * (p - n))
        .collect()
}
31
32pub struct Flux2CfgCombineGraph {
33    pub hir: HirModule,
34}
35
36/// Emit `neg + guidance_scale * (pos - neg)` into an existing HIR module (inputs must exist).
37pub fn emit_flux2_cfg_combine(
38    hir: &mut HirModule,
39    pos: rlx_ir::HirNodeId,
40    neg: rlx_ir::HirNodeId,
41    scale: rlx_ir::HirNodeId,
42    shape: Shape,
43) -> rlx_ir::HirNodeId {
44    let diff = hir.mir(Op::Binary(BinaryOp::Sub), vec![pos, neg], shape.clone());
45    let scaled = hir.mir(Op::Binary(BinaryOp::Mul), vec![diff, scale], shape.clone());
46    hir.mir(Op::Binary(BinaryOp::Add), vec![neg, scaled], shape)
47}
48
49/// HIR graph: `neg + guidance_scale * (pos - neg)` in f32.
50///
51/// Inputs: `pos`, `neg` `[batch, seq, channels]`; `guidance_scale` scalar f32.
52pub fn build_flux2_cfg_combine_hir(
53    batch: usize,
54    seq: usize,
55    channels: usize,
56) -> Flux2CfgCombineGraph {
57    let mut hir = HirModule::new("flux2_cfg_combine").with_fusion_policy(FusionPolicy::Direct);
58    let f = DType::F32;
59    let shape = Shape::new(&[batch, seq, channels], f);
60    let pos = hir.input("pos", shape.clone());
61    let neg = hir.input("neg", shape.clone());
62    let scale = hir.input("guidance_scale", Shape::scalar(f));
63    let out = emit_flux2_cfg_combine(&mut hir, pos, neg, scale, shape);
64    hir.outputs = vec![out];
65    Flux2CfgCombineGraph { hir }
66}
67
68pub fn compile_flux2_cfg_combine(
69    batch: usize,
70    seq: usize,
71    channels: usize,
72    device: Device,
73    aot: Option<&rlx_runtime::AotCache>,
74) -> Result<rlx_runtime::CompiledGraph> {
75    use crate::compile_util::{compile_hir_cached, flux2_cfg_aot_key};
76
77    crate::device::assert_flux2_device_available(device)?;
78    let g = build_flux2_cfg_combine_hir(batch, seq, channels);
79    let key = flux2_cfg_aot_key(device, batch, seq, channels);
80    compile_hir_cached(
81        device,
82        aot,
83        &key,
84        g.hir,
85        &super::compile_util::flux2_compile_profile(),
86    )
87}
88
89/// Tier-0 CFG combine via [`ModelFlow`](rlx_flow::ModelFlow).
90pub fn build_flux2_cfg_combine_built(
91    batch: usize,
92    seq: usize,
93    channels: usize,
94) -> Result<rlx_flow::BuiltModel> {
95    use rlx_flow::{MapWeights, ModelFlow};
96
97    let f = DType::F32;
98    let shape = Shape::new(&[batch, seq, channels], f);
99    let out_shape = shape.clone();
100    ModelFlow::new("flux2_cfg_combine")
101        .input("pos", shape.clone())
102        .input("neg", shape.clone())
103        .input("guidance_scale", Shape::scalar(f))
104        .plugin_named("flux2.cfg.combine", move |emit, _| {
105            let pos = emit.flow_input("pos")?.hir_id();
106            let neg = emit.flow_input("neg")?.hir_id();
107            let scale = emit.flow_input("guidance_scale")?.hir_id();
108            let (hir, _) = emit.hir_and_params();
109            let out = emit_flux2_cfg_combine(hir, pos, neg, scale, out_shape.clone());
110            Ok(Some(emit.wrap(out, out_shape.clone())))
111        })
112        .output("output")
113        .build(&mut MapWeights::default())
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_runtime::Device;

    /// The CPU-compiled HIR graph must reproduce the native `cfg_combine` reference.
    #[test]
    fn cfg_native_matches_hir_cpu() {
        let scale = 2.5f32;
        let pos: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let neg: Vec<f32> = vec![0.0, 0.5, 1.0, 1.5];
        let reference = cfg_combine(&pos, &neg, scale);

        // Shape [1, 2, 2] covers exactly the four elements above.
        let mut graph = compile_flux2_cfg_combine(1, 2, 2, Device::Cpu, None).unwrap();
        let out = graph
            .run(&[
                ("pos", pos.as_slice()),
                ("neg", neg.as_slice()),
                ("guidance_scale", &[scale]),
            ])
            .remove(0);

        assert_eq!(out.len(), reference.len());
        let mut max_diff = 0.0f32;
        for (got, want) in out.iter().zip(&reference) {
            max_diff = max_diff.max((got - want).abs());
        }
        assert!(max_diff < 1e-5, "max_diff={max_diff}");
    }
}