Skip to main content

rlx_flux2/vae/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::{Context, Result};
17use serde::Deserialize;
18use std::path::Path;
19
20#[derive(Debug, Clone, Deserialize)]
21pub struct Flux2VaeConfig {
22    pub in_channels: usize,
23    pub out_channels: usize,
24    pub latent_channels: usize,
25    pub layers_per_block: usize,
26    pub norm_num_groups: usize,
27    pub block_out_channels: Vec<usize>,
28    #[serde(default = "default_act_fn")]
29    pub act_fn: String,
30    #[serde(default = "default_batch_norm_eps")]
31    pub batch_norm_eps: f32,
32    #[serde(default = "default_mid_block_add_attention")]
33    pub mid_block_add_attention: bool,
34    #[serde(default = "default_use_post_quant_conv")]
35    pub use_post_quant_conv: bool,
36    #[serde(default)]
37    pub scaling_factor: f32,
38    #[serde(default)]
39    pub shift_factor: f32,
40}
41
/// Serde default for `Flux2VaeConfig::act_fn`.
fn default_act_fn() -> String {
    String::from("silu")
}

/// Serde default for `Flux2VaeConfig::batch_norm_eps`.
fn default_batch_norm_eps() -> f32 {
    1.0e-4
}

/// Serde default for `Flux2VaeConfig::mid_block_add_attention`: enabled.
fn default_mid_block_add_attention() -> bool {
    true
}

/// Serde default for `Flux2VaeConfig::use_post_quant_conv`: enabled.
fn default_use_post_quant_conv() -> bool {
    true
}
54
55impl Flux2VaeConfig {
56    pub fn from_file(path: &Path) -> Result<Self> {
57        let data = std::fs::read_to_string(path).with_context(|| format!("reading {path:?}"))?;
58        Ok(serde_json::from_str(&data)?)
59    }
60
61    pub fn flux2_klein() -> Self {
62        Self {
63            in_channels: 3,
64            out_channels: 3,
65            latent_channels: 32,
66            layers_per_block: 2,
67            norm_num_groups: 32,
68            block_out_channels: vec![128, 256, 512, 512],
69            act_fn: "silu".into(),
70            batch_norm_eps: 1e-4,
71            mid_block_add_attention: true,
72            use_post_quant_conv: true,
73            scaling_factor: 1.0,
74            shift_factor: 0.0,
75        }
76    }
77
78    /// Tiny VAE for unit tests (no mid attention to shrink graph).
79    pub fn tiny() -> Self {
80        Self {
81            in_channels: 3,
82            out_channels: 3,
83            latent_channels: 4,
84            layers_per_block: 1,
85            norm_num_groups: 2,
86            block_out_channels: vec![8, 16],
87            act_fn: "silu".into(),
88            batch_norm_eps: 1e-4,
89            mid_block_add_attention: false,
90            use_post_quant_conv: true,
91            scaling_factor: 1.0,
92            shift_factor: 0.0,
93        }
94    }
95
96    pub fn bn_channels(&self) -> usize {
97        4 * self.latent_channels
98    }
99
100    /// Spatial downsample factor from RGB to latent (one halving per encoder down block except the last).
101    pub fn encode_spatial_stride(&self) -> usize {
102        1 << self.block_out_channels.len().saturating_sub(1)
103    }
104}