Skip to main content

rlx_flux2/vae/
flow.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Native FLUX.2 VAE flows — decoder (`latents` → `rgb`) and encoder (`rgb` → `latents`).
17
18use anyhow::Result;
19use rlx_flow::{BuiltModel, MapWeights, ModelFlow};
20use rlx_ir::{DType, Shape};
21
22use super::config::Flux2VaeConfig;
23use super::hir_builder::VaeHirBuilder;
24use super::weights::Flux2VaeWeights;
25
26fn decoder_output_hw(weights: &Flux2VaeWeights, h: usize, w: usize) -> (usize, usize) {
27    let mut hh = h;
28    let mut ww = w;
29    for block in &weights.up_blocks {
30        if block.upsample.is_some() {
31            hh *= 2;
32            ww *= 2;
33        }
34    }
35    (hh, ww)
36}
37
38/// Tier-0 FLUX.2 VAE decoder flow.
39#[derive(Clone)]
40pub struct Flux2VaeDecoderFlow<'a> {
41    cfg: &'a Flux2VaeConfig,
42    weights: &'a Flux2VaeWeights,
43    batch: usize,
44    h: usize,
45    w: usize,
46}
47
48impl<'a> Flux2VaeDecoderFlow<'a> {
49    pub fn new(
50        cfg: &'a Flux2VaeConfig,
51        weights: &'a Flux2VaeWeights,
52        batch: usize,
53        h: usize,
54        w: usize,
55    ) -> Self {
56        Self {
57            cfg,
58            weights,
59            batch,
60            h,
61            w,
62        }
63    }
64
65    pub fn build(self) -> Result<BuiltModel> {
66        build_flux2_vae_decoder_built(self.cfg, self.weights, self.batch, self.h, self.w)
67    }
68}
69
70/// Tier-0 FLUX.2 VAE encoder flow.
71#[derive(Clone)]
72pub struct Flux2VaeEncoderFlow<'a> {
73    cfg: &'a Flux2VaeConfig,
74    weights: &'a Flux2VaeWeights,
75    batch: usize,
76    h: usize,
77    w: usize,
78}
79
80impl<'a> Flux2VaeEncoderFlow<'a> {
81    pub fn new(
82        cfg: &'a Flux2VaeConfig,
83        weights: &'a Flux2VaeWeights,
84        batch: usize,
85        h: usize,
86        w: usize,
87    ) -> Self {
88        Self {
89            cfg,
90            weights,
91            batch,
92            h,
93            w,
94        }
95    }
96
97    pub fn build(self) -> Result<BuiltModel> {
98        build_flux2_vae_encoder_built(self.cfg, self.weights, self.batch, self.h, self.w)
99    }
100}
101
102pub fn build_flux2_vae_decoder_built(
103    cfg: &Flux2VaeConfig,
104    weights: &Flux2VaeWeights,
105    batch: usize,
106    h: usize,
107    w: usize,
108) -> Result<BuiltModel> {
109    let f = DType::F32;
110    let lc = cfg.latent_channels;
111    let in_shape = Shape::new(&[batch, lc, h, w], f);
112    let (out_h, out_w) = decoder_output_hw(weights, h, w);
113    let out_shape = Shape::new(&[batch, cfg.out_channels, out_h, out_w], f);
114
115    let cfg = cfg.clone();
116    let weights = weights.clone();
117    ModelFlow::new("flux2_vae_decoder")
118        .input("latents", in_shape)
119        .plugin_named("flux2_vae.decoder", move |emit, input| {
120            let latents = input
121                .ok_or_else(|| anyhow::anyhow!("VAE decoder requires latents input"))?
122                .hir_id();
123            let (hir, params) = emit.hir_and_params();
124            let mut b = VaeHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, h, w);
125            let (out, _, _, _) = b.emit_decoder(latents)?;
126            Ok(Some(emit.wrap(out, out_shape.clone())))
127        })
128        .output("rgb")
129        .build(&mut MapWeights::default())
130}
131
132pub fn build_flux2_vae_encoder_built(
133    cfg: &Flux2VaeConfig,
134    weights: &Flux2VaeWeights,
135    batch: usize,
136    h: usize,
137    w: usize,
138) -> Result<BuiltModel> {
139    let f = DType::F32;
140    let in_c = cfg.in_channels;
141    let in_shape = Shape::new(&[batch, in_c, h, w], f);
142    let mean_c = weights.quant_conv.out_c / 2;
143    let out_shape = Shape::new(&[batch, mean_c, h, w], f);
144
145    let cfg = cfg.clone();
146    let weights = weights.clone();
147    ModelFlow::new("flux2_vae_encoder")
148        .input("rgb", in_shape)
149        .plugin_named("flux2_vae.encoder", move |emit, input| {
150            let rgb = input
151                .ok_or_else(|| anyhow::anyhow!("VAE encoder requires rgb input"))?
152                .hir_id();
153            let (hir, params) = emit.hir_and_params();
154            let mut b = VaeHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, h, w);
155            let out = b.emit_encoder(rgb)?;
156            Ok(Some(emit.wrap(out, out_shape.clone())))
157        })
158        .output("latents")
159        .build(&mut MapWeights::default())
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vae::{
        Flux2VaeConfig, build_flux2_vae_encoder_hir, build_flux2_vae_hir, synthetic_vae_weights,
    };

    // The decoder flow must emit exactly as many HIR nodes as the direct
    // `build_flux2_vae_hir` path for the same tiny config and weights.
    #[test]
    fn vae_decoder_flow_matches_hir_node_count() {
        let cfg = Flux2VaeConfig::tiny();
        let weights = synthetic_vae_weights(&cfg);
        let (batch, height, width) = (1, 4, 4);

        let reference = build_flux2_vae_hir(&cfg, &weights, batch, height, width)
            .unwrap()
            .hir;
        let flow = Flux2VaeDecoderFlow::new(&cfg, &weights, batch, height, width);
        let from_flow = flow.build().unwrap().into_hir().unwrap();

        assert_eq!(
            from_flow.len(),
            reference.len(),
            "VAE decoder flow should match hir_builder node count"
        );
    }

    // The encoder flow must emit exactly as many HIR nodes as the direct
    // `build_flux2_vae_encoder_hir` path for the same tiny config and weights.
    #[test]
    fn vae_encoder_flow_matches_hir_node_count() {
        let cfg = Flux2VaeConfig::tiny();
        let weights = synthetic_vae_weights(&cfg);
        let (batch, height, width) = (1, 32, 32);

        let reference = build_flux2_vae_encoder_hir(&cfg, &weights, batch, height, width)
            .unwrap()
            .hir;
        let flow = Flux2VaeEncoderFlow::new(&cfg, &weights, batch, height, width);
        let from_flow = flow.build().unwrap().into_hir().unwrap();

        assert_eq!(
            from_flow.len(),
            reference.len(),
            "VAE encoder flow should match hir_builder node count"
        );
    }
}