1use anyhow::{Context, Result, bail};
19use rlx_core::gguf_support::gguf_architecture_str;
20use rlx_gguf::{GgufFile, MetaValue};
21use serde::Deserialize;
22use std::path::Path;
23
24#[derive(Debug, Clone, Deserialize)]
26pub struct Flux2Config {
27 #[serde(default = "default_patch_size")]
28 pub patch_size: usize,
29 #[serde(default = "default_in_channels")]
30 pub in_channels: usize,
31 pub out_channels: Option<usize>,
32 #[serde(default = "default_num_layers")]
33 pub num_layers: usize,
34 #[serde(default = "default_num_single_layers")]
35 pub num_single_layers: usize,
36 #[serde(default = "default_attention_head_dim")]
37 pub attention_head_dim: usize,
38 #[serde(default = "default_num_attention_heads")]
39 pub num_attention_heads: usize,
40 #[serde(default = "default_joint_attention_dim")]
41 pub joint_attention_dim: usize,
42 #[serde(default = "default_timestep_guidance_channels")]
43 pub timestep_guidance_channels: usize,
44 #[serde(default = "default_mlp_ratio")]
45 pub mlp_ratio: f64,
46 #[serde(default = "default_axes_dims_rope")]
47 pub axes_dims_rope: Vec<usize>,
48 #[serde(default = "default_rope_theta")]
49 pub rope_theta: usize,
50 #[serde(default = "default_eps")]
51 pub eps: f64,
52 #[serde(default = "default_guidance_embeds")]
53 pub guidance_embeds: bool,
54}
55
/// Serde fallback for [`Flux2Config::patch_size`] when the key is absent.
fn default_patch_size() -> usize {
    1
}
/// Serde fallback for [`Flux2Config::in_channels`] when the key is absent.
fn default_in_channels() -> usize {
    128
}
/// Serde fallback for [`Flux2Config::num_layers`] when the key is absent.
fn default_num_layers() -> usize {
    8
}
/// Serde fallback for [`Flux2Config::num_single_layers`] when the key is absent.
fn default_num_single_layers() -> usize {
    48
}
/// Serde fallback for [`Flux2Config::attention_head_dim`] when the key is absent.
fn default_attention_head_dim() -> usize {
    128
}
/// Serde fallback for [`Flux2Config::num_attention_heads`] when the key is absent.
fn default_num_attention_heads() -> usize {
    48
}
/// Serde fallback for [`Flux2Config::joint_attention_dim`] when the key is absent.
fn default_joint_attention_dim() -> usize {
    15_360
}
/// Serde fallback for [`Flux2Config::timestep_guidance_channels`] when the key
/// is absent.
fn default_timestep_guidance_channels() -> usize {
    256
}
/// Serde fallback for [`Flux2Config::mlp_ratio`] when the key is absent.
fn default_mlp_ratio() -> f64 {
    3.0_f64
}
/// Serde fallback for [`Flux2Config::axes_dims_rope`] when the key is absent:
/// four axes of 32 each.
fn default_axes_dims_rope() -> Vec<usize> {
    vec![32; 4]
}
/// Serde fallback for [`Flux2Config::rope_theta`] when the key is absent.
fn default_rope_theta() -> usize {
    2_000
}
/// Serde fallback for [`Flux2Config::eps`] when the key is absent.
fn default_eps() -> f64 {
    1.0e-6
}
/// Serde fallback for [`Flux2Config::guidance_embeds`] when the key is absent:
/// guidance embeddings are assumed present.
fn default_guidance_embeds() -> bool {
    true
}
95
96impl Flux2Config {
97 pub fn from_file(path: &Path) -> Result<Self> {
98 let data = std::fs::read_to_string(path).with_context(|| format!("reading {path:?}"))?;
99 Ok(serde_json::from_str(&data)?)
100 }
101
102 pub fn inner_dim(&self) -> usize {
104 self.num_attention_heads * self.attention_head_dim
105 }
106
107 pub fn ff_inner_dim(&self) -> usize {
108 (self.inner_dim() as f64 * self.mlp_ratio) as usize
109 }
110
111 pub fn out_ch(&self) -> usize {
112 self.out_channels.unwrap_or(self.in_channels)
113 }
114
115 pub fn proj_out_dim(&self) -> usize {
116 self.patch_size * self.patch_size * self.out_ch()
117 }
118
119 pub fn flux2_dev() -> Self {
121 Self {
122 patch_size: 1,
123 in_channels: 128,
124 out_channels: None,
125 num_layers: 8,
126 num_single_layers: 48,
127 attention_head_dim: 128,
128 num_attention_heads: 48,
129 joint_attention_dim: 15360,
130 timestep_guidance_channels: 256,
131 mlp_ratio: 3.0,
132 axes_dims_rope: vec![32, 32, 32, 32],
133 rope_theta: 2000,
134 eps: 1e-6,
135 guidance_embeds: true,
136 }
137 }
138
139 pub fn flux2_klein_4b() -> Self {
141 Self {
142 num_layers: 4,
143 num_single_layers: 16,
144 num_attention_heads: 24,
145 attention_head_dim: 128,
146 joint_attention_dim: 7680,
147 guidance_embeds: false,
148 ..Self::flux2_dev()
149 }
150 }
151
152 pub fn flux2_klein_9b() -> Self {
154 Self {
155 num_layers: 8,
156 num_single_layers: 24,
157 num_attention_heads: 32,
158 attention_head_dim: 128,
159 joint_attention_dim: 12288,
160 guidance_embeds: false,
161 ..Self::flux2_dev()
162 }
163 }
164
165 pub fn infer_from_weight_keys<'a>(keys: impl IntoIterator<Item = &'a str>) -> Self {
167 let keys: Vec<&str> = keys.into_iter().collect();
168 let double = max_block_layers(&keys, &["double_blocks.", "transformer_blocks."]);
169 let single = max_block_layers(&keys, &["single_blocks.", "single_transformer_blocks."]);
170 let guidance = keys.iter().any(|k| {
171 k.contains("guidance_in.") || k.contains("time_guidance_embed.guidance_embedder.")
172 });
173 match (double, single) {
174 (8, 24) => Self::flux2_klein_9b(),
175 (4, 16) => Self::flux2_klein_4b(),
176 (8, 48) => Self::flux2_dev(),
177 (d, s) if d > 0 && s > 0 => {
178 let mut cfg = Self::flux2_klein_9b();
179 cfg.num_layers = d;
180 cfg.num_single_layers = s;
181 cfg.guidance_embeds = guidance;
182 cfg
183 }
184 _ => Self::flux2_klein_9b(),
185 }
186 }
187
188 pub fn from_gguf(raw: &GgufFile) -> Result<Self> {
190 let arch = gguf_architecture_str(raw).unwrap_or("flux");
191 if arch != "flux" {
192 bail!("Flux2Config::from_gguf expected architecture `flux`, got {arch}");
193 }
194 if let Some(name) = raw
195 .metadata
196 .get("general.basename")
197 .and_then(MetaValue::as_str)
198 {
199 let lower = name.to_lowercase();
200 if lower.contains("klein") && (lower.contains("9b") || lower.contains("9-b")) {
201 return Ok(Self::flux2_klein_9b());
202 }
203 if lower.contains("klein") {
204 return Ok(Self::flux2_klein_4b());
205 }
206 if lower.contains("dev") {
207 return Ok(Self::flux2_dev());
208 }
209 }
210 Ok(Self::infer_from_weight_keys(
211 raw.tensors.keys().map(|s| s.as_str()),
212 ))
213 }
214
215 pub fn tiny() -> Self {
217 Self {
218 patch_size: 1,
219 in_channels: 8,
220 out_channels: None,
221 num_layers: 1,
222 num_single_layers: 1,
223 attention_head_dim: 16,
224 num_attention_heads: 2,
225 joint_attention_dim: 16,
226 timestep_guidance_channels: 32,
227 mlp_ratio: 2.0,
228 axes_dims_rope: vec![4, 4, 4, 4],
229 rope_theta: 2000,
230 eps: 1e-6,
231 guidance_embeds: true,
232 }
233 }
234}
235
/// Returns the number of blocks implied by `keys` for the given name prefixes.
///
/// A key contributes when it *starts* with one of `prefixes` and the segment
/// that follows (up to the next `.`) parses as a `usize` index. The result is
/// the largest such index plus one, or `0` when no key contributes.
///
/// The increment saturates, so a key carrying index `usize::MAX` cannot
/// overflow (the tensor names come from weight files and are not trusted).
fn max_block_layers(keys: &[&str], prefixes: &[&str]) -> usize {
    keys.iter()
        .flat_map(|&key| {
            prefixes
                .iter()
                .filter_map(move |&pfx| key.strip_prefix(pfx))
        })
        // `split` always yields at least one segment, so `next()` only
        // short-circuits in theory; non-numeric segments are skipped.
        .filter_map(|rest| rest.split('.').next()?.parse::<usize>().ok())
        .map(|idx| idx.saturating_add(1))
        .max()
        .unwrap_or(0)
}
249
#[cfg(test)]
mod gguf_config_tests {
    use super::*;

    /// Keys using the `double_blocks.` / `single_blocks.` naming with 8 double
    /// and 24 single blocks must resolve to the klein-9B preset, which has
    /// guidance embeddings disabled.
    #[test]
    fn infer_klein_9b_from_bfl_keys() {
        let inferred = Flux2Config::infer_from_weight_keys([
            "double_blocks.0.img_attn.qkv.weight",
            "double_blocks.7.img_attn.qkv.weight",
            "single_blocks.23.linear1.weight",
        ]);
        assert_eq!(inferred.num_layers, 8);
        assert_eq!(inferred.num_single_layers, 24);
        assert!(!inferred.guidance_embeds);
    }
}
266}