1use anyhow::Result;
19use rlx_flow::{BuiltModel, MapWeights, ModelFlow};
20use rlx_ir::{DType, Shape};
21
22use super::config::Flux2VaeConfig;
23use super::hir_builder::VaeHirBuilder;
24use super::weights::Flux2VaeWeights;
25
26fn decoder_output_hw(weights: &Flux2VaeWeights, h: usize, w: usize) -> (usize, usize) {
27 let mut hh = h;
28 let mut ww = w;
29 for block in &weights.up_blocks {
30 if block.upsample.is_some() {
31 hh *= 2;
32 ww *= 2;
33 }
34 }
35 (hh, ww)
36}
37
38#[derive(Clone)]
40pub struct Flux2VaeDecoderFlow<'a> {
41 cfg: &'a Flux2VaeConfig,
42 weights: &'a Flux2VaeWeights,
43 batch: usize,
44 h: usize,
45 w: usize,
46}
47
48impl<'a> Flux2VaeDecoderFlow<'a> {
49 pub fn new(
50 cfg: &'a Flux2VaeConfig,
51 weights: &'a Flux2VaeWeights,
52 batch: usize,
53 h: usize,
54 w: usize,
55 ) -> Self {
56 Self {
57 cfg,
58 weights,
59 batch,
60 h,
61 w,
62 }
63 }
64
65 pub fn build(self) -> Result<BuiltModel> {
66 build_flux2_vae_decoder_built(self.cfg, self.weights, self.batch, self.h, self.w)
67 }
68}
69
70#[derive(Clone)]
72pub struct Flux2VaeEncoderFlow<'a> {
73 cfg: &'a Flux2VaeConfig,
74 weights: &'a Flux2VaeWeights,
75 batch: usize,
76 h: usize,
77 w: usize,
78}
79
80impl<'a> Flux2VaeEncoderFlow<'a> {
81 pub fn new(
82 cfg: &'a Flux2VaeConfig,
83 weights: &'a Flux2VaeWeights,
84 batch: usize,
85 h: usize,
86 w: usize,
87 ) -> Self {
88 Self {
89 cfg,
90 weights,
91 batch,
92 h,
93 w,
94 }
95 }
96
97 pub fn build(self) -> Result<BuiltModel> {
98 build_flux2_vae_encoder_built(self.cfg, self.weights, self.batch, self.h, self.w)
99 }
100}
101
102pub fn build_flux2_vae_decoder_built(
103 cfg: &Flux2VaeConfig,
104 weights: &Flux2VaeWeights,
105 batch: usize,
106 h: usize,
107 w: usize,
108) -> Result<BuiltModel> {
109 let f = DType::F32;
110 let lc = cfg.latent_channels;
111 let in_shape = Shape::new(&[batch, lc, h, w], f);
112 let (out_h, out_w) = decoder_output_hw(weights, h, w);
113 let out_shape = Shape::new(&[batch, cfg.out_channels, out_h, out_w], f);
114
115 let cfg = cfg.clone();
116 let weights = weights.clone();
117 ModelFlow::new("flux2_vae_decoder")
118 .input("latents", in_shape)
119 .plugin_named("flux2_vae.decoder", move |emit, input| {
120 let latents = input
121 .ok_or_else(|| anyhow::anyhow!("VAE decoder requires latents input"))?
122 .hir_id();
123 let (hir, params) = emit.hir_and_params();
124 let mut b = VaeHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, h, w);
125 let (out, _, _, _) = b.emit_decoder(latents)?;
126 Ok(Some(emit.wrap(out, out_shape.clone())))
127 })
128 .output("rgb")
129 .build(&mut MapWeights::default())
130}
131
132pub fn build_flux2_vae_encoder_built(
133 cfg: &Flux2VaeConfig,
134 weights: &Flux2VaeWeights,
135 batch: usize,
136 h: usize,
137 w: usize,
138) -> Result<BuiltModel> {
139 let f = DType::F32;
140 let in_c = cfg.in_channels;
141 let in_shape = Shape::new(&[batch, in_c, h, w], f);
142 let mean_c = weights.quant_conv.out_c / 2;
143 let out_shape = Shape::new(&[batch, mean_c, h, w], f);
144
145 let cfg = cfg.clone();
146 let weights = weights.clone();
147 ModelFlow::new("flux2_vae_encoder")
148 .input("rgb", in_shape)
149 .plugin_named("flux2_vae.encoder", move |emit, input| {
150 let rgb = input
151 .ok_or_else(|| anyhow::anyhow!("VAE encoder requires rgb input"))?
152 .hir_id();
153 let (hir, params) = emit.hir_and_params();
154 let mut b = VaeHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, h, w);
155 let out = b.emit_encoder(rgb)?;
156 Ok(Some(emit.wrap(out, out_shape.clone())))
157 })
158 .output("latents")
159 .build(&mut MapWeights::default())
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vae::{
        Flux2VaeConfig, build_flux2_vae_encoder_hir, build_flux2_vae_hir, synthetic_vae_weights,
    };

    /// The flow-built decoder must emit as many HIR nodes as the direct HIR builder.
    #[test]
    fn vae_decoder_flow_matches_hir_node_count() {
        let cfg = Flux2VaeConfig::tiny();
        let weights = synthetic_vae_weights(&cfg);
        let (batch, h, w_px) = (1, 4, 4);

        let expected_nodes = build_flux2_vae_hir(&cfg, &weights, batch, h, w_px)
            .unwrap()
            .hir
            .len();
        let flow_nodes = Flux2VaeDecoderFlow::new(&cfg, &weights, batch, h, w_px)
            .build()
            .unwrap()
            .into_hir()
            .unwrap()
            .len();

        assert_eq!(
            flow_nodes, expected_nodes,
            "VAE decoder flow should match hir_builder node count"
        );
    }

    /// The flow-built encoder must emit as many HIR nodes as the direct HIR builder.
    #[test]
    fn vae_encoder_flow_matches_hir_node_count() {
        let cfg = Flux2VaeConfig::tiny();
        let weights = synthetic_vae_weights(&cfg);
        let (batch, h, w_px) = (1, 32, 32);

        let expected_nodes = build_flux2_vae_encoder_hir(&cfg, &weights, batch, h, w_px)
            .unwrap()
            .hir
            .len();
        let flow_nodes = Flux2VaeEncoderFlow::new(&cfg, &weights, batch, h, w_px)
            .build()
            .unwrap()
            .into_hir()
            .unwrap()
            .len();

        assert_eq!(
            flow_nodes, expected_nodes,
            "VAE encoder flow should match hir_builder node count"
        );
    }
}