Skip to main content

rlx_flux2/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! FLUX.2 transformer configuration (matches HuggingFace / diffusers / BFL).
17
18use anyhow::{Context, Result, bail};
19use rlx_core::gguf_support::gguf_architecture_str;
20use rlx_gguf::{GgufFile, MetaValue};
21use serde::Deserialize;
22use std::path::Path;
23
24/// FLUX.2 rectified-flow transformer (denoiser) configuration.
25#[derive(Debug, Clone, Deserialize)]
26pub struct Flux2Config {
27    #[serde(default = "default_patch_size")]
28    pub patch_size: usize,
29    #[serde(default = "default_in_channels")]
30    pub in_channels: usize,
31    pub out_channels: Option<usize>,
32    #[serde(default = "default_num_layers")]
33    pub num_layers: usize,
34    #[serde(default = "default_num_single_layers")]
35    pub num_single_layers: usize,
36    #[serde(default = "default_attention_head_dim")]
37    pub attention_head_dim: usize,
38    #[serde(default = "default_num_attention_heads")]
39    pub num_attention_heads: usize,
40    #[serde(default = "default_joint_attention_dim")]
41    pub joint_attention_dim: usize,
42    #[serde(default = "default_timestep_guidance_channels")]
43    pub timestep_guidance_channels: usize,
44    #[serde(default = "default_mlp_ratio")]
45    pub mlp_ratio: f64,
46    #[serde(default = "default_axes_dims_rope")]
47    pub axes_dims_rope: Vec<usize>,
48    #[serde(default = "default_rope_theta")]
49    pub rope_theta: usize,
50    #[serde(default = "default_eps")]
51    pub eps: f64,
52    #[serde(default = "default_guidance_embeds")]
53    pub guidance_embeds: bool,
54}
55
// Serde fallbacks for `Flux2Config`: each function supplies the value used when
// the matching key is missing from the deserialized document.

/// Fallback for `patch_size`.
fn default_patch_size() -> usize {
    1
}

/// Fallback for `in_channels`.
fn default_in_channels() -> usize {
    128
}

/// Fallback for `num_layers`.
fn default_num_layers() -> usize {
    8
}

/// Fallback for `num_single_layers`.
fn default_num_single_layers() -> usize {
    48
}

/// Fallback for `attention_head_dim`.
fn default_attention_head_dim() -> usize {
    128
}

/// Fallback for `num_attention_heads`.
fn default_num_attention_heads() -> usize {
    48
}

/// Fallback for `joint_attention_dim`.
fn default_joint_attention_dim() -> usize {
    15_360
}

/// Fallback for `timestep_guidance_channels`.
fn default_timestep_guidance_channels() -> usize {
    256
}

/// Fallback for `mlp_ratio`.
fn default_mlp_ratio() -> f64 {
    3.0
}

/// Fallback for `axes_dims_rope`: four axes of 32 each.
fn default_axes_dims_rope() -> Vec<usize> {
    vec![32; 4]
}

/// Fallback for `rope_theta`.
fn default_rope_theta() -> usize {
    2_000
}

/// Fallback for `eps`.
fn default_eps() -> f64 {
    1.0e-6
}

/// Fallback for `guidance_embeds`.
fn default_guidance_embeds() -> bool {
    true
}
95
96impl Flux2Config {
97    pub fn from_file(path: &Path) -> Result<Self> {
98        let data = std::fs::read_to_string(path).with_context(|| format!("reading {path:?}"))?;
99        Ok(serde_json::from_str(&data)?)
100    }
101
102    /// Hidden width (`num_attention_heads * attention_head_dim`).
103    pub fn inner_dim(&self) -> usize {
104        self.num_attention_heads * self.attention_head_dim
105    }
106
107    pub fn ff_inner_dim(&self) -> usize {
108        (self.inner_dim() as f64 * self.mlp_ratio) as usize
109    }
110
111    pub fn out_ch(&self) -> usize {
112        self.out_channels.unwrap_or(self.in_channels)
113    }
114
115    pub fn proj_out_dim(&self) -> usize {
116        self.patch_size * self.patch_size * self.out_ch()
117    }
118
119    /// FLUX.2 [dev] defaults (32B-class; not runnable on commodity RAM at F32).
120    pub fn flux2_dev() -> Self {
121        Self {
122            patch_size: 1,
123            in_channels: 128,
124            out_channels: None,
125            num_layers: 8,
126            num_single_layers: 48,
127            attention_head_dim: 128,
128            num_attention_heads: 48,
129            joint_attention_dim: 15360,
130            timestep_guidance_channels: 256,
131            mlp_ratio: 3.0,
132            axes_dims_rope: vec![32, 32, 32, 32],
133            rope_theta: 2000,
134            eps: 1e-6,
135            guidance_embeds: true,
136        }
137    }
138
139    /// FLUX.2 [klein] 4B-style defaults (guidance embedder optional).
140    pub fn flux2_klein_4b() -> Self {
141        Self {
142            num_layers: 4,
143            num_single_layers: 16,
144            num_attention_heads: 24,
145            attention_head_dim: 128,
146            joint_attention_dim: 7680,
147            guidance_embeds: false,
148            ..Self::flux2_dev()
149        }
150    }
151
152    /// FLUX.2 [klein] 9B defaults (BFL `Klein9BParams`: 8 double + 24 single, 32 heads).
153    pub fn flux2_klein_9b() -> Self {
154        Self {
155            num_layers: 8,
156            num_single_layers: 24,
157            num_attention_heads: 32,
158            attention_head_dim: 128,
159            joint_attention_dim: 12288,
160            guidance_embeds: false,
161            ..Self::flux2_dev()
162        }
163    }
164
165    /// Infer variant from checkpoint tensor names (BFL `double_blocks.*` or diffusers `transformer_blocks.*`).
166    pub fn infer_from_weight_keys<'a>(keys: impl IntoIterator<Item = &'a str>) -> Self {
167        let keys: Vec<&str> = keys.into_iter().collect();
168        let double = max_block_layers(&keys, &["double_blocks.", "transformer_blocks."]);
169        let single = max_block_layers(&keys, &["single_blocks.", "single_transformer_blocks."]);
170        let guidance = keys.iter().any(|k| {
171            k.contains("guidance_in.") || k.contains("time_guidance_embed.guidance_embedder.")
172        });
173        match (double, single) {
174            (8, 24) => Self::flux2_klein_9b(),
175            (4, 16) => Self::flux2_klein_4b(),
176            (8, 48) => Self::flux2_dev(),
177            (d, s) if d > 0 && s > 0 => {
178                let mut cfg = Self::flux2_klein_9b();
179                cfg.num_layers = d;
180                cfg.num_single_layers = s;
181                cfg.guidance_embeds = guidance;
182                cfg
183            }
184            _ => Self::flux2_klein_9b(),
185        }
186    }
187
188    /// Read `flux.*` metadata when present; otherwise infer from `general.basename`.
189    pub fn from_gguf(raw: &GgufFile) -> Result<Self> {
190        let arch = gguf_architecture_str(raw).unwrap_or("flux");
191        if arch != "flux" {
192            bail!("Flux2Config::from_gguf expected architecture `flux`, got {arch}");
193        }
194        if let Some(name) = raw
195            .metadata
196            .get("general.basename")
197            .and_then(MetaValue::as_str)
198        {
199            let lower = name.to_lowercase();
200            if lower.contains("klein") && (lower.contains("9b") || lower.contains("9-b")) {
201                return Ok(Self::flux2_klein_9b());
202            }
203            if lower.contains("klein") {
204                return Ok(Self::flux2_klein_4b());
205            }
206            if lower.contains("dev") {
207                return Ok(Self::flux2_dev());
208            }
209        }
210        Ok(Self::infer_from_weight_keys(
211            raw.tensors.keys().map(|s| s.as_str()),
212        ))
213    }
214
215    /// Tiny config for unit tests and graph minimal builds.
216    pub fn tiny() -> Self {
217        Self {
218            patch_size: 1,
219            in_channels: 8,
220            out_channels: None,
221            num_layers: 1,
222            num_single_layers: 1,
223            attention_head_dim: 16,
224            num_attention_heads: 2,
225            joint_attention_dim: 16,
226            timestep_guidance_channels: 32,
227            mlp_ratio: 2.0,
228            axes_dims_rope: vec![4, 4, 4, 4],
229            rope_theta: 2000,
230            eps: 1e-6,
231            guidance_embeds: true,
232        }
233    }
234}
235
/// Counts the indexed blocks named in `keys` under any of `prefixes`.
///
/// Returns one more than the largest block index found (block indices are
/// zero-based), or `0` when no key matches. A prefix matches when it begins the
/// key or directly follows a `.`, so both `double_blocks.3.attn` and a namespaced
/// `model.diffusion_model.double_blocks.3.attn` are recognised, while
/// `single_transformer_blocks.3.attn` is not mistaken for `transformer_blocks.`.
fn max_block_layers(keys: &[&str], prefixes: &[&str]) -> usize {
    let mut count = 0usize;
    for key in keys {
        for prefix in prefixes {
            if let Some(idx) = block_index(key, prefix) {
                // Index -> count; saturate so a pathological index cannot overflow.
                count = count.max(idx.saturating_add(1));
            }
        }
    }
    count
}

/// Extracts the numeric path component that follows `prefix` inside `key`.
///
/// Only occurrences of `prefix` at the start of `key` or right after a `.` are
/// considered; the first such occurrence followed by a valid `usize` wins.
fn block_index(key: &str, prefix: &str) -> Option<usize> {
    key.match_indices(prefix)
        // Require a component boundary so a prefix cannot match mid-identifier.
        .filter(|(pos, _)| *pos == 0 || key[..*pos].ends_with('.'))
        .find_map(|(pos, _)| {
            key[pos + prefix.len()..]
                .split('.')
                .next()?
                .parse::<usize>()
                .ok()
        })
}
249
#[cfg(test)]
mod gguf_config_tests {
    //! Tests for variant inference from checkpoint tensor names; one test per
    //! match arm of `Flux2Config::infer_from_weight_keys`.

    use super::*;

    /// BFL names with 8 double + 24 single blocks select the klein 9B preset.
    #[test]
    fn infer_klein_9b_from_bfl_keys() {
        let keys = [
            "double_blocks.0.img_attn.qkv.weight",
            "double_blocks.7.img_attn.qkv.weight",
            "single_blocks.23.linear1.weight",
        ];
        let cfg = Flux2Config::infer_from_weight_keys(keys);
        assert_eq!(cfg.num_layers, 8);
        assert_eq!(cfg.num_single_layers, 24);
        assert!(!cfg.guidance_embeds);
    }

    /// Diffusers names with 4 double + 16 single blocks select the klein 4B preset.
    #[test]
    fn infer_klein_4b_from_diffusers_keys() {
        let keys = [
            "transformer_blocks.0.attn.to_q.weight",
            "transformer_blocks.3.attn.to_q.weight",
            "single_transformer_blocks.15.attn.to_out.weight",
        ];
        let cfg = Flux2Config::infer_from_weight_keys(keys);
        assert_eq!(cfg.num_layers, 4);
        assert_eq!(cfg.num_single_layers, 16);
        assert_eq!(cfg.num_attention_heads, 24);
        assert_eq!(cfg.joint_attention_dim, 7680);
        assert!(!cfg.guidance_embeds);
    }

    /// 8 double + 48 single blocks select the dev preset.
    #[test]
    fn infer_dev_from_bfl_keys() {
        let keys = [
            "double_blocks.7.img_attn.qkv.weight",
            "single_blocks.47.linear1.weight",
        ];
        let cfg = Flux2Config::infer_from_weight_keys(keys);
        assert_eq!(cfg.num_layers, 8);
        assert_eq!(cfg.num_single_layers, 48);
        assert_eq!(cfg.num_attention_heads, 48);
        assert!(cfg.guidance_embeds);
    }

    /// An unrecognised layout keeps its layer counts and derives the guidance
    /// flag from the presence of guidance tensors.
    #[test]
    fn infer_custom_layout_tracks_guidance_tensors() {
        let with_guidance = [
            "double_blocks.1.img_attn.qkv.weight",
            "single_blocks.2.linear1.weight",
            "guidance_in.in_layer.weight",
        ];
        let cfg = Flux2Config::infer_from_weight_keys(with_guidance);
        assert_eq!(cfg.num_layers, 2);
        assert_eq!(cfg.num_single_layers, 3);
        assert!(cfg.guidance_embeds);

        let without_guidance = [
            "double_blocks.1.img_attn.qkv.weight",
            "single_blocks.2.linear1.weight",
        ];
        let cfg = Flux2Config::infer_from_weight_keys(without_guidance);
        assert_eq!(cfg.num_layers, 2);
        assert_eq!(cfg.num_single_layers, 3);
        assert!(!cfg.guidance_embeds);
    }

    /// With no block tensors at all the klein 9B preset is returned unchanged.
    #[test]
    fn infer_without_block_keys_falls_back_to_klein_9b() {
        let cfg = Flux2Config::infer_from_weight_keys(std::iter::empty::<&str>());
        assert_eq!(cfg.num_layers, 8);
        assert_eq!(cfg.num_single_layers, 24);
        assert_eq!(cfg.num_attention_heads, 32);
        assert!(!cfg.guidance_embeds);
    }
}