Skip to main content

rlx_flux2/vae/
hir_builder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! FLUX.2 VAE HIR graph builders: the decoder (`flux2_vae_decode` trunk on GPU backends) and the encoder.
17
18use super::config::Flux2VaeConfig;
19use super::weights::{
20    AttnBlockWeights, Conv2dWeight, DownEncoderBlockWeights, Flux2VaeWeights, GroupNormWeight,
21    ResnetBlockWeights, UpDecoderBlockWeights,
22};
23use crate::builder::Flux2GraphParams;
24use crate::compile_util::{
25    compile_hir_cached, flux2_vae_decoder_aot_key, flux2_vae_encoder_aot_key,
26};
27use anyhow::Result;
28use rlx_ir::hir::{HirGraphExt, HirModule, HirMut, HirNodeId};
29use rlx_ir::op::{Activation, MaskKind};
30use rlx_ir::{DType, Op, Shape};
31use rlx_runtime::Device;
32
/// A built (not yet compiled) FLUX.2 VAE graph together with its parameter data.
pub struct Flux2VaeGraph {
    /// Graph whose parameter nodes are declared by name and shape only.
    pub hir: HirModule,
    /// Data for every parameter declared in `hir`, keyed by parameter name.
    pub params: Flux2GraphParams,
}
37
38pub fn build_flux2_vae_hir(
39    cfg: &Flux2VaeConfig,
40    weights: &Flux2VaeWeights,
41    batch: usize,
42    h: usize,
43    w: usize,
44) -> Result<Flux2VaeGraph> {
45    let lc = cfg.latent_channels;
46    let f = DType::F32;
47    let mut hir =
48        HirModule::new("flux2_vae_decoder").with_fusion_policy(rlx_ir::hir::FusionPolicy::Direct);
49    let mut params = Flux2GraphParams::new();
50    let latents = hir.input("latents", Shape::new(&[batch, lc, h, w], f));
51    let mut b = VaeHirBuilder::from_emit_parts(&mut hir, &mut params, cfg, weights, batch, h, w);
52    let (out, _, _, _) = b.emit_decoder(latents)?;
53    hir.outputs = vec![out];
54    Ok(Flux2VaeGraph { hir, params })
55}
56
57pub fn build_flux2_vae_encoder_hir(
58    cfg: &Flux2VaeConfig,
59    weights: &Flux2VaeWeights,
60    batch: usize,
61    h: usize,
62    w: usize,
63) -> Result<Flux2VaeGraph> {
64    let in_c = cfg.in_channels;
65    let f = DType::F32;
66    let mut hir =
67        HirModule::new("flux2_vae_encoder").with_fusion_policy(rlx_ir::hir::FusionPolicy::Direct);
68    let mut params = Flux2GraphParams::new();
69    let rgb = hir.input("rgb", Shape::new(&[batch, in_c, h, w], f));
70    let mut b = VaeHirBuilder::from_emit_parts(&mut hir, &mut params, cfg, weights, batch, h, w);
71    let out = b.emit_encoder(rgb)?;
72    hir.outputs = vec![out];
73    Ok(Flux2VaeGraph { hir, params })
74}
75
76pub fn compile_flux2_vae_hir(
77    cfg: &Flux2VaeConfig,
78    weights: &Flux2VaeWeights,
79    batch: usize,
80    h: usize,
81    w: usize,
82    device: Device,
83    aot: Option<&rlx_runtime::AotCache>,
84) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
85    crate::device::assert_flux2_device_available(device)?;
86    let g = build_flux2_vae_hir(cfg, weights, batch, h, w)?;
87    let key = flux2_vae_decoder_aot_key(device, batch, h, w);
88    let mut compiled = compile_hir_cached(
89        device,
90        aot,
91        &key,
92        g.hir,
93        &crate::compile_util::flux2_compile_profile(),
94    )?;
95    for (name, data) in &g.params {
96        compiled.set_param(name, data);
97    }
98    Ok((compiled, g.params))
99}
100
101pub fn compile_flux2_vae_encoder_hir(
102    cfg: &Flux2VaeConfig,
103    weights: &Flux2VaeWeights,
104    batch: usize,
105    h: usize,
106    w: usize,
107    device: Device,
108    aot: Option<&rlx_runtime::AotCache>,
109) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
110    crate::device::assert_flux2_device_available(device)?;
111    let g = build_flux2_vae_encoder_hir(cfg, weights, batch, h, w)?;
112    let key = flux2_vae_encoder_aot_key(device, batch, h, w);
113    let mut compiled = compile_hir_cached(
114        device,
115        aot,
116        &key,
117        g.hir,
118        &crate::compile_util::flux2_compile_profile(),
119    )?;
120    for (name, data) in &g.params {
121        compiled.set_param(name, data);
122    }
123    Ok((compiled, g.params))
124}
125
/// Incremental emitter for the FLUX.2 VAE encoder/decoder graphs.
///
/// Appends nodes to a borrowed [`HirModule`] and, for every weight tensor it declares as a
/// HIR parameter, stores that tensor's data in the borrowed [`Flux2GraphParams`] under the
/// same name.
pub(crate) struct VaeHirBuilder<'a> {
    /// Graph that nodes are appended to.
    hir: &'a mut HirModule,
    /// Parameter name -> data for every parameter declared via `register_param`.
    params: &'a mut Flux2GraphParams,
    /// Model hyper-parameters (channel counts, norm group count).
    cfg: &'a Flux2VaeConfig,
    /// Weight tensors; their data is cloned into `params` as layers are emitted.
    weights: &'a Flux2VaeWeights,
    /// Batch size baked into every emitted NCHW shape.
    batch: usize,
    /// Spatial height of the graph input.
    h: usize,
    /// Spatial width of the graph input.
    w: usize,
    /// Element type of emitted tensors (`F32` when built via `from_emit_parts`).
    f: DType,
    /// Group-norm epsilon.
    eps: f32,
    /// Group count used by every group norm.
    groups: usize,
}
138
139impl<'a> VaeHirBuilder<'a> {
140    pub(crate) fn from_emit_parts(
141        hir: &'a mut HirModule,
142        params: &'a mut Flux2GraphParams,
143        cfg: &'a Flux2VaeConfig,
144        weights: &'a Flux2VaeWeights,
145        batch: usize,
146        h: usize,
147        w: usize,
148    ) -> Self {
149        Self {
150            hir,
151            params,
152            cfg,
153            weights,
154            batch,
155            h,
156            w,
157            f: DType::F32,
158            eps: 1e-6,
159            groups: cfg.norm_num_groups,
160        }
161    }
162
163    pub(crate) fn emit_decoder(
164        &mut self,
165        mut x: HirNodeId,
166    ) -> Result<(HirNodeId, usize, usize, usize)> {
167        let lc = self.cfg.latent_channels;
168        let mut channels = lc;
169        let mut h = self.h;
170        let mut w = self.w;
171
172        if let Some(pqc) = &self.weights.post_quant_conv {
173            x = self.conv2d_bias(x, pqc, "post_quant_conv", channels, h, w)?;
174            channels = pqc.out_c;
175        }
176        x = self.conv2d_bias(x, &self.weights.conv_in, "conv_in", channels, h, w)?;
177        channels = self.weights.conv_in.out_c;
178
179        for (i, resnet) in self.weights.mid_resnets.iter().enumerate() {
180            x = self.resnet_block(x, resnet, &format!("mid.{i}"), channels, h, w)?;
181            channels = resnet.conv2.out_c;
182        }
183        if let Some(attn) = &self.weights.mid_attn {
184            x = self.spatial_attention(x, attn, "mid.attn", channels, h, w)?;
185        }
186
187        for (i, block) in self.weights.up_blocks.iter().enumerate() {
188            let (cur, c, hh, ww) = self.up_block(x, block, &format!("up.{i}"), channels, h, w)?;
189            x = cur;
190            channels = c;
191            h = hh;
192            w = ww;
193        }
194
195        let shape = self.nchw(channels, h, w);
196        x = self.group_norm(
197            x,
198            &self.weights.conv_norm_out,
199            "conv_norm_out",
200            shape.clone(),
201        )?;
202        x = self.g().activation(Activation::Silu, x, shape.clone());
203        x = self.conv2d_bias(x, &self.weights.conv_out, "conv_out", channels, h, w)?;
204        let out_c = self.weights.conv_out.out_c;
205        Ok((x, out_c, h, w))
206    }
207
208    fn nchw(&self, c: usize, h: usize, w: usize) -> Shape {
209        Shape::new(&[self.batch, c, h, w], self.f)
210    }
211
212    fn register_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
213        let id = self.hir.param(name, shape);
214        self.params.insert(name.to_string(), data);
215        id
216    }
217
218    fn g(&mut self) -> HirMut<'_> {
219        HirMut::new(self.hir)
220    }
221
222    pub(crate) fn emit_encoder(&mut self, mut x: HirNodeId) -> Result<HirNodeId> {
223        let in_c = self.cfg.in_channels;
224        let mut channels = in_c;
225        let mut h = self.h;
226        let mut w = self.w;
227
228        x = self.conv2d_bias(
229            x,
230            &self.weights.encoder_conv_in,
231            "encoder.conv_in",
232            channels,
233            h,
234            w,
235        )?;
236        channels = self.weights.encoder_conv_in.out_c;
237
238        for (i, block) in self.weights.encoder_down_blocks.iter().enumerate() {
239            let (cur, c, hh, ww) =
240                self.down_block(x, block, &format!("encoder.down.{i}"), channels, h, w)?;
241            x = cur;
242            channels = c;
243            h = hh;
244            w = ww;
245        }
246
247        for (i, resnet) in self.weights.encoder_mid_resnets.iter().enumerate() {
248            x = self.resnet_block(x, resnet, &format!("encoder.mid.{i}"), channels, h, w)?;
249            channels = resnet.conv2.out_c;
250        }
251        if let Some(attn) = &self.weights.encoder_mid_attn {
252            x = self.spatial_attention(x, attn, "encoder.mid.attn", channels, h, w)?;
253        }
254
255        let shape = self.nchw(channels, h, w);
256        x = self.group_norm(
257            x,
258            &self.weights.encoder_conv_norm_out,
259            "encoder.conv_norm_out",
260            shape.clone(),
261        )?;
262        x = self.g().activation(Activation::Silu, x, shape.clone());
263        x = self.conv2d_bias(
264            x,
265            &self.weights.encoder_conv_out,
266            "encoder.conv_out",
267            channels,
268            h,
269            w,
270        )?;
271        channels = self.weights.encoder_conv_out.out_c;
272
273        x = self.conv2d_bias(x, &self.weights.quant_conv, "quant_conv", channels, h, w)?;
274        let mean_c = self.weights.quant_conv.out_c / 2;
275        Ok(self.g().narrow_(x, 1, 0, mean_c))
276    }
277
278    fn group_norm(
279        &mut self,
280        x: HirNodeId,
281        gn: &GroupNormWeight,
282        name: &str,
283        shape: Shape,
284    ) -> Result<HirNodeId> {
285        let c = shape.dim(1).unwrap_static();
286        let g = self.register_param(
287            &format!("{name}.weight"),
288            gn.gamma.clone(),
289            Shape::new(&[c], self.f),
290        );
291        let b = self.register_param(
292            &format!("{name}.bias"),
293            gn.beta.clone(),
294            Shape::new(&[c], self.f),
295        );
296        let groups = self.groups;
297        let eps = self.eps;
298        Ok(self.g().group_norm(x, g, b, groups, eps))
299    }
300
301    fn conv2d_bias(
302        &mut self,
303        x: HirNodeId,
304        conv: &Conv2dWeight,
305        name: &str,
306        _in_c: usize,
307        h: usize,
308        w: usize,
309    ) -> Result<HirNodeId> {
310        let is_1x1 = conv.weight.len() == conv.out_c * conv.in_c;
311        let (kh, kw) = if is_1x1 { (1, 1) } else { (3, 3) };
312        let (pad, stride) = if is_1x1 {
313            ([0, 0], [1, 1])
314        } else {
315            ([1, 1], [1, 1])
316        };
317        let w_shape = if is_1x1 {
318            Shape::new(&[conv.out_c, conv.in_c, 1, 1], self.f)
319        } else {
320            Shape::new(&[conv.out_c, conv.in_c, 3, 3], self.f)
321        };
322        let weight = self.register_param(&format!("{name}.weight"), conv.weight.clone(), w_shape);
323        let out_shape = self.nchw(conv.out_c, h, w);
324        let y = self
325            .g()
326            .conv2d(x, weight, [kh, kw], stride, pad, 1, out_shape.clone());
327        let bias = self.register_param(
328            &format!("{name}.bias"),
329            conv.bias.clone(),
330            Shape::new(&[conv.out_c], self.f),
331        );
332        let bias4 = self.g().reshape_(bias, vec![1, conv.out_c as i64, 1, 1]);
333        let batch = self.batch;
334        let expanded = self.g().add_node(
335            Op::Expand {
336                target_shape: vec![batch as i64, conv.out_c as i64, h as i64, w as i64],
337            },
338            vec![bias4],
339            out_shape.clone(),
340        );
341        Ok(self.g().add(y, expanded))
342    }
343
344    fn resnet_block(
345        &mut self,
346        x: HirNodeId,
347        b: &ResnetBlockWeights,
348        name: &str,
349        in_c: usize,
350        h: usize,
351        w: usize,
352    ) -> Result<HirNodeId> {
353        let shape = self.nchw(in_c, h, w);
354        let mut residual = x;
355        let mut h1 = self.group_norm(x, &b.norm1, &format!("{name}.norm1"), shape.clone())?;
356        h1 = self.g().activation(Activation::Silu, h1, shape.clone());
357        h1 = self.conv2d_bias(h1, &b.conv1, &format!("{name}.conv1"), in_c, h, w)?;
358        let c1 = b.conv1.out_c;
359        let s1 = self.nchw(c1, h, w);
360        h1 = self.group_norm(h1, &b.norm2, &format!("{name}.norm2"), s1.clone())?;
361        h1 = self.g().activation(Activation::Silu, h1, s1.clone());
362        h1 = self.conv2d_bias(h1, &b.conv2, &format!("{name}.conv2"), c1, h, w)?;
363        let out_c = b.conv2.out_c;
364        if let Some(sc) = &b.shortcut {
365            residual = self.conv2d_bias(residual, sc, &format!("{name}.shortcut"), in_c, h, w)?;
366        }
367        let _out_shape = self.nchw(out_c, h, w);
368        Ok(self.g().add(h1, residual))
369    }
370
371    fn spatial_attention(
372        &mut self,
373        x: HirNodeId,
374        attn: &AttnBlockWeights,
375        name: &str,
376        channels: usize,
377        h: usize,
378        w: usize,
379    ) -> Result<HirNodeId> {
380        let shape = self.nchw(channels, h, w);
381        let normed = self.group_norm(x, &attn.norm, &format!("{name}.norm"), shape.clone())?;
382        let q = self.conv2d_bias(normed, &attn.to_q, &format!("{name}.to_q"), channels, h, w)?;
383        let k = self.conv2d_bias(normed, &attn.to_k, &format!("{name}.to_k"), channels, h, w)?;
384        let v = self.conv2d_bias(normed, &attn.to_v, &format!("{name}.to_v"), channels, h, w)?;
385        let seq = h * w;
386        let batch = self.batch;
387        let bsh = Shape::new(&[batch, seq, channels], self.f);
388        let q2 = self
389            .g()
390            .reshape_(q, vec![batch as i64, seq as i64, channels as i64]);
391        let k2 = self
392            .g()
393            .reshape_(k, vec![batch as i64, seq as i64, channels as i64]);
394        let v2 = self
395            .g()
396            .reshape_(v, vec![batch as i64, seq as i64, channels as i64]);
397        let fixed = self
398            .g()
399            .attention_kind(q2, k2, v2, 1, channels, MaskKind::None, bsh.clone());
400        let fixed4 = self.g().reshape_(
401            fixed,
402            vec![batch as i64, channels as i64, h as i64, w as i64],
403        );
404        let proj = self.conv2d_bias(
405            fixed4,
406            &attn.to_out,
407            &format!("{name}.to_out"),
408            channels,
409            h,
410            w,
411        )?;
412        Ok(self.g().add(x, proj))
413    }
414
415    fn up_block(
416        &mut self,
417        x: HirNodeId,
418        block: &UpDecoderBlockWeights,
419        name: &str,
420        mut in_c: usize,
421        h: usize,
422        w: usize,
423    ) -> Result<(HirNodeId, usize, usize, usize)> {
424        let mut cur = x;
425        for (j, resnet) in block.resnets.iter().enumerate() {
426            let out_c = resnet.conv2.out_c;
427            cur = self.resnet_block(cur, resnet, &format!("{name}.resnet.{j}"), in_c, h, w)?;
428            in_c = out_c;
429        }
430        let mut out_h = h;
431        let mut out_w = w;
432        if let Some(up) = &block.upsample {
433            let uped = self.g().resize_nearest_2x(cur);
434            out_h = h * 2;
435            out_w = w * 2;
436            cur = self.conv2d_bias(uped, up, &format!("{name}.upsample"), in_c, out_h, out_w)?;
437            in_c = up.out_c;
438        }
439        Ok((cur, in_c, out_h, out_w))
440    }
441
442    fn down_block(
443        &mut self,
444        x: HirNodeId,
445        block: &DownEncoderBlockWeights,
446        name: &str,
447        mut in_c: usize,
448        h: usize,
449        w: usize,
450    ) -> Result<(HirNodeId, usize, usize, usize)> {
451        let mut cur = x;
452        for (j, resnet) in block.resnets.iter().enumerate() {
453            let out_c = resnet.conv2.out_c;
454            cur = self.resnet_block(cur, resnet, &format!("{name}.resnet.{j}"), in_c, h, w)?;
455            in_c = out_c;
456        }
457        let mut out_h = h;
458        let mut out_w = w;
459        if let Some(down) = &block.downsample {
460            out_h = (h + 1 - 3) / 2 + 1;
461            out_w = (w + 1 - 3) / 2 + 1;
462            cur = self.conv2d_downsample(
463                cur,
464                down,
465                &format!("{name}.downsample"),
466                in_c,
467                h,
468                w,
469                out_h,
470                out_w,
471            )?;
472            in_c = down.out_c;
473        }
474        Ok((cur, in_c, out_h, out_w))
475    }
476
477    fn conv2d_downsample(
478        &mut self,
479        x: HirNodeId,
480        conv: &Conv2dWeight,
481        name: &str,
482        _in_c: usize,
483        _h: usize,
484        _w: usize,
485        out_h: usize,
486        out_w: usize,
487    ) -> Result<HirNodeId> {
488        let w_shape = Shape::new(&[conv.out_c, conv.in_c, 3, 3], self.f);
489        let weight = self.register_param(&format!("{name}.weight"), conv.weight.clone(), w_shape);
490        let out_shape = self.nchw(conv.out_c, out_h, out_w);
491        let y = self
492            .g()
493            .conv2d(x, weight, [3, 3], [2, 2], [1, 1], 1, out_shape.clone());
494        let bias = self.register_param(
495            &format!("{name}.bias"),
496            conv.bias.clone(),
497            Shape::new(&[conv.out_c], self.f),
498        );
499        let bias4 = self.g().reshape_(bias, vec![1, conv.out_c as i64, 1, 1]);
500        let batch = self.batch;
501        let expanded = self.g().add_node(
502            Op::Expand {
503                target_shape: vec![batch as i64, conv.out_c as i64, out_h as i64, out_w as i64],
504            },
505            vec![bias4],
506            out_shape.clone(),
507        );
508        Ok(self.g().add(y, expanded))
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vae::{Flux2VaeConfig, flux2_vae_decode, synthetic_vae_weights};
    use rlx_runtime::Device;

    /// Largest element-wise `|got - want|`, comparing pairwise up to the shorter length.
    fn max_abs_diff<'r>(got: &[f32], want: impl IntoIterator<Item = &'r f32>) -> f32 {
        got.iter()
            .zip(want)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    /// The decoder HIR for a tiny config must lower to MIR without error.
    #[test]
    fn vae_hir_lowers() {
        let cfg = Flux2VaeConfig::tiny();
        let weights = synthetic_vae_weights(&cfg);
        let graph = build_flux2_vae_hir(&cfg, &weights, 1, 4, 4).unwrap();
        graph.hir.lower_to_mir().expect("lower");
    }

    /// The compiled encoder (after the config's shift/scale) must match the native encoder
    /// on a deterministic 32x32 input.
    #[test]
    fn compiled_vae_encoder_matches_native() {
        let cfg = Flux2VaeConfig::tiny();
        let weights = synthetic_vae_weights(&cfg);
        let (batch, h, w_px) = (1usize, 32usize, 32usize);
        let rgb: Vec<f32> = (0..batch * 3 * h * w_px)
            .map(|i| (i as f32 * 0.001).sin())
            .collect();

        let native =
            super::super::encoder::flux2_vae_encode(&weights, &cfg, &rgb, batch, h, w_px).unwrap();

        let (mut compiled, _) =
            compile_flux2_vae_encoder_hir(&cfg, &weights, batch, h, w_px, Device::Cpu, None)
                .unwrap();
        let mut out = compiled.run(&[("rgb", rgb.as_slice())]).remove(0);

        // The graph ends at the raw latent mean; apply the config's shift/scale before
        // comparing (skipped when that transform is the identity).
        let is_identity = cfg.scaling_factor == 1.0 && cfg.shift_factor == 0.0;
        if !is_identity {
            for v in out.iter_mut() {
                *v = (*v - cfg.shift_factor) * cfg.scaling_factor;
            }
        }

        assert_eq!(out.len(), native.len());
        let max = max_abs_diff(&out, &native);
        assert!(max < 5e-2, "HIR encoder vs native max_abs_diff={max}");
    }

    /// The compiled decoder must match the native decoder on a constant 4x4 latent.
    #[test]
    fn compiled_vae_matches_native() {
        let cfg = Flux2VaeConfig::tiny();
        let weights = synthetic_vae_weights(&cfg);
        let (batch, h, w_px) = (1usize, 4usize, 4usize);
        let latents = vec![0.1f32; batch * cfg.latent_channels * h * w_px];

        let native = flux2_vae_decode(&weights, &cfg, &latents, batch, h, w_px).unwrap();

        let (mut compiled, _) =
            compile_flux2_vae_hir(&cfg, &weights, batch, h, w_px, Device::Cpu, None).unwrap();
        let out = compiled.run(&[("latents", latents.as_slice())]).remove(0);

        assert_eq!(out.len(), native.len());
        // Expected upsampling per spatial dim: 2^(len(block_out_channels) - 1).
        let scale = 2usize.pow(cfg.block_out_channels.len().saturating_sub(1) as u32);
        assert_eq!(out.len(), batch * cfg.out_channels * h * scale * w_px * scale);
        let max = max_abs_diff(&out, &native);
        assert!(max < 2e-2, "HIR vs native VAE max_abs_diff={max}");
    }
}