1use super::config::Flux2VaeConfig;
17use anyhow::{Result, ensure};
18use rlx_core::weight_map::WeightMap;
19use std::path::{Path, PathBuf};
20
/// Weights of a 2-D convolution (or of a linear projection treated as a 1x1 convolution).
///
/// `weight` is flat in `[out_c, in_c, kh, kw]` row-major order. The kernel size is not stored;
/// it is implied by `weight.len() / (out_c * in_c)` (1 or 9 for weights built in this module).
#[derive(Debug, Clone, PartialEq)]
pub struct Conv2dWeight {
    /// Flattened kernel with `out_c * in_c * kh * kw` elements.
    pub weight: Vec<f32>,
    /// Bias values, expected to hold one entry per output channel.
    pub bias: Vec<f32>,
    /// Number of input channels.
    pub in_c: usize,
    /// Number of output channels.
    pub out_c: usize,
}
28
/// Affine parameters of a GroupNorm layer, one entry per channel.
#[derive(Debug, Clone, PartialEq)]
pub struct GroupNormWeight {
    /// Per-channel scale (read from the `.weight` tensor).
    pub gamma: Vec<f32>,
    /// Per-channel shift (read from the `.bias` tensor).
    pub beta: Vec<f32>,
}
34
/// Weights of one resnet block, loaded from `{prefix}.norm1`, `{prefix}.conv1`,
/// `{prefix}.norm2`, `{prefix}.conv2` and, when present, `{prefix}.conv_shortcut`.
#[derive(Debug, Clone)]
pub struct ResnetBlockWeights {
    /// First GroupNorm (`norm1`).
    pub norm1: GroupNormWeight,
    /// First convolution (`conv1`).
    pub conv1: Conv2dWeight,
    /// Second GroupNorm (`norm2`).
    pub norm2: GroupNormWeight,
    /// Second convolution (`conv2`).
    pub conv2: Conv2dWeight,
    /// `conv_shortcut` conv. The loader sets it only if the checkpoint contains
    /// `{prefix}.conv_shortcut.weight`; synthetic weights include it exactly when the block's
    /// input and output channel counts differ.
    pub shortcut: Option<Conv2dWeight>,
}
43
/// Weights of a mid-block attention, loaded from `{prefix}.group_norm`, `{prefix}.to_q`,
/// `{prefix}.to_k`, `{prefix}.to_v` and `{prefix}.to_out.0`.
///
/// The projections are held as [`Conv2dWeight`]; `load_conv` accepts both `[out, in]` and
/// `[out, in, 1, 1]` weight shapes for them.
#[derive(Debug, Clone)]
pub struct AttnBlockWeights {
    /// Input GroupNorm (`group_norm`).
    pub norm: GroupNormWeight,
    /// Query projection (`to_q`).
    pub to_q: Conv2dWeight,
    /// Key projection (`to_k`).
    pub to_k: Conv2dWeight,
    /// Value projection (`to_v`).
    pub to_v: Conv2dWeight,
    /// Output projection (`to_out.0`).
    pub to_out: Conv2dWeight,
}
52
/// One decoder up block (`decoder.up_blocks.{i}`).
#[derive(Debug, Clone)]
pub struct UpDecoderBlockWeights {
    /// Resnets `decoder.up_blocks.{i}.resnets.{j}`; the loaders build
    /// `layers_per_block + 1` of them.
    pub resnets: Vec<ResnetBlockWeights>,
    /// Conv of `upsamplers.0.conv`; the loaders leave it `None` for the last up block.
    pub upsample: Option<Conv2dWeight>,
}
58
/// One encoder down block (`encoder.down_blocks.{i}`).
#[derive(Debug, Clone)]
pub struct DownEncoderBlockWeights {
    /// Resnets `encoder.down_blocks.{i}.resnets.{j}`; the loaders build `layers_per_block`
    /// of them.
    pub resnets: Vec<ResnetBlockWeights>,
    /// Conv of `downsamplers.0.conv`; the loaders leave it `None` for the last down block.
    pub downsample: Option<Conv2dWeight>,
}
64
/// All weights of the FLUX.2 VAE: encoder, quantization convs, decoder, and the
/// `bn.running_mean` / `bn.running_var` statistics.
///
/// Fields without an `encoder_` prefix belong to the decoder (`decoder.*` keys).
#[derive(Debug, Clone)]
pub struct Flux2VaeWeights {
    /// `encoder.conv_in`.
    pub encoder_conv_in: Conv2dWeight,
    /// `encoder.down_blocks.*`, one per entry of `block_out_channels`.
    pub encoder_down_blocks: Vec<DownEncoderBlockWeights>,
    /// `encoder.mid_block.resnets.{0,1}`.
    pub encoder_mid_resnets: Vec<ResnetBlockWeights>,
    /// `encoder.mid_block.attentions.0`; `None` when `mid_block_add_attention` is off
    /// (and always `None` for synthetic weights).
    pub encoder_mid_attn: Option<AttnBlockWeights>,
    /// `encoder.conv_norm_out`.
    pub encoder_conv_norm_out: GroupNormWeight,
    /// `encoder.conv_out`.
    pub encoder_conv_out: Conv2dWeight,
    /// `quant_conv`.
    pub quant_conv: Conv2dWeight,
    /// `post_quant_conv`; only present when `use_post_quant_conv` is set.
    pub post_quant_conv: Option<Conv2dWeight>,
    /// `decoder.conv_in`.
    pub conv_in: Conv2dWeight,
    /// `decoder.mid_block.resnets.{0,1}`.
    pub mid_resnets: Vec<ResnetBlockWeights>,
    /// `decoder.mid_block.attentions.0`; `None` when `mid_block_add_attention` is off
    /// (and always `None` for synthetic weights).
    pub mid_attn: Option<AttnBlockWeights>,
    /// `decoder.up_blocks.*`, one per entry of `block_out_channels`.
    pub up_blocks: Vec<UpDecoderBlockWeights>,
    /// `decoder.conv_norm_out`.
    pub conv_norm_out: GroupNormWeight,
    /// `decoder.conv_out`.
    pub conv_out: Conv2dWeight,
    /// `bn.running_mean`, `bn_channels()` entries. Presumably used to (de)normalize latents
    /// — TODO confirm against the model code.
    pub bn_running_mean: Vec<f32>,
    /// `bn.running_var`, companion of `bn_running_mean`.
    pub bn_running_var: Vec<f32>,
}
84
85pub fn resolve_vae_dir(model_path: &Path) -> Option<PathBuf> {
87 crate::paths::find_component_dir(model_path, "vae")
88}
89
90pub fn load_flux2_vae_weights(path: &Path, cfg: &Flux2VaeConfig) -> Result<Flux2VaeWeights> {
91 let wm = if path.is_dir() {
92 WeightMap::from_safetensors_dir(path)?
93 } else {
94 WeightMap::from_file(
95 path.to_str()
96 .ok_or_else(|| anyhow::anyhow!("non-utf8 path"))?,
97 )?
98 };
99 extract_flux2_vae_weights(wm, cfg)
100}
101
102pub fn extract_flux2_vae_weights(
103 mut wm: WeightMap,
104 cfg: &Flux2VaeConfig,
105) -> Result<Flux2VaeWeights> {
106 let encoder_conv_in = load_conv(&mut wm, "encoder.conv_in.weight", "encoder.conv_in.bias")?;
107
108 let mut encoder_down_blocks = Vec::new();
109 let channels: Vec<usize> = cfg.block_out_channels.clone();
110 for (i, &out_ch) in channels.iter().enumerate() {
111 let in_ch = if i == 0 { channels[0] } else { channels[i - 1] };
112 let num_layers = cfg.layers_per_block;
113 let mut resnets = Vec::with_capacity(num_layers);
114 for j in 0..num_layers {
115 resnets.push(load_resnet(
116 &mut wm,
117 &format!("encoder.down_blocks.{i}.resnets.{j}"),
118 cfg.norm_num_groups,
119 )?);
120 let _ = if j == 0 { in_ch } else { out_ch };
121 }
122 let downsample = if i + 1 < channels.len() {
123 Some(load_conv(
124 &mut wm,
125 &format!("encoder.down_blocks.{i}.downsamplers.0.conv.weight"),
126 &format!("encoder.down_blocks.{i}.downsamplers.0.conv.bias"),
127 )?)
128 } else {
129 None
130 };
131 encoder_down_blocks.push(DownEncoderBlockWeights {
132 resnets,
133 downsample,
134 });
135 }
136
137 let mut encoder_mid_resnets = Vec::new();
138 for i in 0..2 {
139 encoder_mid_resnets.push(load_resnet(
140 &mut wm,
141 &format!("encoder.mid_block.resnets.{i}"),
142 cfg.norm_num_groups,
143 )?);
144 }
145 let encoder_mid_attn = if cfg.mid_block_add_attention {
146 let p = "encoder.mid_block.attentions.0";
147 Some(AttnBlockWeights {
148 norm: load_gn(&mut wm, &format!("{p}.group_norm"))?,
149 to_q: load_conv(
150 &mut wm,
151 &format!("{p}.to_q.weight"),
152 &format!("{p}.to_q.bias"),
153 )?,
154 to_k: load_conv(
155 &mut wm,
156 &format!("{p}.to_k.weight"),
157 &format!("{p}.to_k.bias"),
158 )?,
159 to_v: load_conv(
160 &mut wm,
161 &format!("{p}.to_v.weight"),
162 &format!("{p}.to_v.bias"),
163 )?,
164 to_out: load_conv(
165 &mut wm,
166 &format!("{p}.to_out.0.weight"),
167 &format!("{p}.to_out.0.bias"),
168 )?,
169 })
170 } else {
171 None
172 };
173 let encoder_conv_norm_out = load_gn(&mut wm, "encoder.conv_norm_out")?;
174 let encoder_conv_out = load_conv(&mut wm, "encoder.conv_out.weight", "encoder.conv_out.bias")?;
175 let quant_conv = load_conv(&mut wm, "quant_conv.weight", "quant_conv.bias")?;
176
177 let post_quant_conv = if cfg.use_post_quant_conv {
178 Some(load_conv(
179 &mut wm,
180 "post_quant_conv.weight",
181 "post_quant_conv.bias",
182 )?)
183 } else {
184 None
185 };
186 let conv_in = load_conv(&mut wm, "decoder.conv_in.weight", "decoder.conv_in.bias")?;
187
188 let mut mid_resnets = Vec::new();
189 for i in 0..2 {
190 mid_resnets.push(load_resnet(
191 &mut wm,
192 &format!("decoder.mid_block.resnets.{i}"),
193 cfg.norm_num_groups,
194 )?);
195 }
196 let mid_attn = if cfg.mid_block_add_attention {
197 let p = "decoder.mid_block.attentions.0";
198 Some(AttnBlockWeights {
199 norm: load_gn(&mut wm, &format!("{p}.group_norm"))?,
200 to_q: load_conv(
201 &mut wm,
202 &format!("{p}.to_q.weight"),
203 &format!("{p}.to_q.bias"),
204 )?,
205 to_k: load_conv(
206 &mut wm,
207 &format!("{p}.to_k.weight"),
208 &format!("{p}.to_k.bias"),
209 )?,
210 to_v: load_conv(
211 &mut wm,
212 &format!("{p}.to_v.weight"),
213 &format!("{p}.to_v.bias"),
214 )?,
215 to_out: load_conv(
216 &mut wm,
217 &format!("{p}.to_out.0.weight"),
218 &format!("{p}.to_out.0.bias"),
219 )?,
220 })
221 } else {
222 None
223 };
224
225 let channels: Vec<usize> = cfg.block_out_channels.clone();
226 let mut up_blocks = Vec::new();
227 let reversed: Vec<usize> = channels.iter().copied().rev().collect();
228 for (i, &out_ch) in reversed.iter().enumerate() {
229 let in_ch = if i == 0 {
230 *channels.last().unwrap()
231 } else {
232 reversed[i - 1]
233 };
234 let num_layers = cfg.layers_per_block + 1;
235 let mut resnets = Vec::with_capacity(num_layers);
236 for j in 0..num_layers {
237 let block_in = if j == 0 { in_ch } else { out_ch };
238 resnets.push(load_resnet(
239 &mut wm,
240 &format!("decoder.up_blocks.{i}.resnets.{j}"),
241 cfg.norm_num_groups,
242 )?);
243 let _ = block_in;
244 }
245 let upsample = if i + 1 < reversed.len() {
246 Some(load_conv(
247 &mut wm,
248 &format!("decoder.up_blocks.{i}.upsamplers.0.conv.weight"),
249 &format!("decoder.up_blocks.{i}.upsamplers.0.conv.bias"),
250 )?)
251 } else {
252 None
253 };
254 up_blocks.push(UpDecoderBlockWeights { resnets, upsample });
255 }
256
257 let conv_norm_out = load_gn(&mut wm, "decoder.conv_norm_out")?;
258 let conv_out = load_conv(&mut wm, "decoder.conv_out.weight", "decoder.conv_out.bias")?;
259 let (bn_running_mean, _) = wm.take("bn.running_mean")?;
260 let (bn_running_var, _) = wm.take("bn.running_var")?;
261 ensure!(
262 bn_running_mean.len() == cfg.bn_channels(),
263 "bn.running_mean len {} != {}",
264 bn_running_mean.len(),
265 cfg.bn_channels()
266 );
267
268 Ok(Flux2VaeWeights {
269 encoder_conv_in,
270 encoder_down_blocks,
271 encoder_mid_resnets,
272 encoder_mid_attn,
273 encoder_conv_norm_out,
274 encoder_conv_out,
275 quant_conv,
276 post_quant_conv,
277 conv_in,
278 mid_resnets,
279 mid_attn,
280 up_blocks,
281 conv_norm_out,
282 conv_out,
283 bn_running_mean,
284 bn_running_var,
285 })
286}
287
288fn load_conv(wm: &mut WeightMap, w_key: &str, b_key: &str) -> Result<Conv2dWeight> {
289 let (data, shape) = wm.take(w_key)?;
290 let (bias, _) = wm.take(b_key)?;
291 let (out_c, in_c, kh, kw) = match shape.as_slice() {
292 [o, i, 3, 3] => (*o, *i, 3, 3),
293 [o, i, 1, 1] => (*o, *i, 1, 1),
294 [o, i] => (*o, *i, 1, 1),
295 _ => anyhow::bail!("conv weight shape {shape:?}"),
296 };
297 ensure!(kh == kw && (kh == 3 || kh == 1), "expected 1x1 or 3x3 conv");
298 let weight = if kh == 3 {
299 let mut w = vec![0.0f32; out_c * in_c * 9];
300 for oc in 0..out_c {
301 for ic in 0..in_c {
302 for ky in 0..3 {
303 for kx in 0..3 {
304 w[(oc * in_c + ic) * 9 + ky * 3 + kx] =
305 data[((oc * in_c + ic) * 3 + ky) * 3 + kx];
306 }
307 }
308 }
309 }
310 w
311 } else {
312 data
313 };
314 Ok(Conv2dWeight {
315 weight,
316 bias,
317 in_c,
318 out_c,
319 })
320}
321
322fn load_gn(wm: &mut WeightMap, prefix: &str) -> Result<GroupNormWeight> {
323 let (gamma, _) = wm.take(&format!("{prefix}.weight"))?;
324 let (beta, _) = wm.take(&format!("{prefix}.bias"))?;
325 Ok(GroupNormWeight { gamma, beta })
326}
327
328fn zero_conv3(in_c: usize, out_c: usize) -> Conv2dWeight {
329 Conv2dWeight {
330 weight: vec![0.0; out_c * in_c * 9],
331 bias: vec![0.0; out_c],
332 in_c,
333 out_c,
334 }
335}
336
337fn zero_conv1(in_c: usize, out_c: usize) -> Conv2dWeight {
338 Conv2dWeight {
339 weight: vec![0.0; out_c * in_c],
340 bias: vec![0.0; out_c],
341 in_c,
342 out_c,
343 }
344}
345
346fn zero_gn(ch: usize) -> GroupNormWeight {
347 GroupNormWeight {
348 gamma: vec![1.0; ch],
349 beta: vec![0.0; ch],
350 }
351}
352
353fn zero_resnet(in_c: usize, out_c: usize) -> ResnetBlockWeights {
354 ResnetBlockWeights {
355 norm1: zero_gn(in_c),
356 conv1: zero_conv3(in_c, out_c),
357 norm2: zero_gn(out_c),
358 conv2: zero_conv3(out_c, out_c),
359 shortcut: if in_c != out_c {
360 Some(zero_conv1(in_c, out_c))
361 } else {
362 None
363 },
364 }
365}
366
367pub fn synthetic_vae_weights(cfg: &Flux2VaeConfig) -> Flux2VaeWeights {
369 let last = *cfg.block_out_channels.last().unwrap_or(&8);
370 let channels: Vec<usize> = cfg.block_out_channels.clone();
371 let reversed: Vec<usize> = channels.iter().copied().rev().collect();
372 let mut up_blocks = Vec::new();
373 for (i, &out_ch) in reversed.iter().enumerate() {
374 let in_ch = if i == 0 { last } else { reversed[i - 1] };
375 let num_layers = cfg.layers_per_block + 1;
376 let resnets = (0..num_layers)
377 .map(|j| {
378 let cin = if j == 0 { in_ch } else { out_ch };
379 zero_resnet(cin, out_ch)
380 })
381 .collect();
382 let upsample = if i + 1 < reversed.len() {
383 Some(zero_conv3(out_ch, out_ch))
384 } else {
385 None
386 };
387 up_blocks.push(UpDecoderBlockWeights { resnets, upsample });
388 }
389 Flux2VaeWeights {
390 encoder_conv_in: zero_conv3(cfg.in_channels, channels[0]),
391 encoder_down_blocks: {
392 let mut blocks = Vec::new();
393 for (i, &out_ch) in channels.iter().enumerate() {
394 let in_ch = if i == 0 { channels[0] } else { channels[i - 1] };
395 let num_layers = cfg.layers_per_block;
396 let resnets = (0..num_layers)
397 .map(|j| {
398 let cin = if j == 0 { in_ch } else { out_ch };
399 zero_resnet(cin, out_ch)
400 })
401 .collect();
402 let downsample = if i + 1 < channels.len() {
403 Some(zero_conv3(out_ch, out_ch))
404 } else {
405 None
406 };
407 blocks.push(DownEncoderBlockWeights {
408 resnets,
409 downsample,
410 });
411 }
412 blocks
413 },
414 encoder_mid_resnets: vec![zero_resnet(last, last), zero_resnet(last, last)],
415 encoder_mid_attn: None,
416 encoder_conv_norm_out: zero_gn(last),
417 encoder_conv_out: zero_conv3(last, cfg.latent_channels * 2),
418 quant_conv: zero_conv1(cfg.latent_channels * 2, cfg.latent_channels * 2),
419 post_quant_conv: cfg
420 .use_post_quant_conv
421 .then(|| zero_conv1(cfg.latent_channels, cfg.latent_channels)),
422 conv_in: zero_conv3(cfg.latent_channels, last),
423 mid_resnets: vec![zero_resnet(last, last), zero_resnet(last, last)],
424 mid_attn: None,
425 up_blocks,
426 conv_norm_out: zero_gn(cfg.block_out_channels[0]),
427 conv_out: zero_conv3(cfg.block_out_channels[0], cfg.out_channels),
428 bn_running_mean: vec![0.0; cfg.bn_channels()],
429 bn_running_var: vec![1.0; cfg.bn_channels()],
430 }
431}
432
433fn load_resnet(wm: &mut WeightMap, prefix: &str, groups: usize) -> Result<ResnetBlockWeights> {
434 let norm1 = load_gn(wm, &format!("{prefix}.norm1"))?;
435 let conv1 = load_conv(
436 wm,
437 &format!("{prefix}.conv1.weight"),
438 &format!("{prefix}.conv1.bias"),
439 )?;
440 let norm2 = load_gn(wm, &format!("{prefix}.norm2"))?;
441 let conv2 = load_conv(
442 wm,
443 &format!("{prefix}.conv2.weight"),
444 &format!("{prefix}.conv2.bias"),
445 )?;
446 let shortcut = if wm.has(&format!("{prefix}.conv_shortcut.weight")) {
447 Some(load_conv(
448 wm,
449 &format!("{prefix}.conv_shortcut.weight"),
450 &format!("{prefix}.conv_shortcut.bias"),
451 )?)
452 } else {
453 None
454 };
455 let _ = groups;
456 Ok(ResnetBlockWeights {
457 norm1,
458 conv1,
459 norm2,
460 conv2,
461 shortcut,
462 })
463}