// rlx_flux2/vae/weights.rs
//
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use super::config::Flux2VaeConfig;
17use anyhow::{Result, ensure};
18use rlx_core::weight_map::WeightMap;
19use std::path::{Path, PathBuf};
20
/// Parameters of a convolution (or of a 1x1 / linear projection).
///
/// `weight` is a flat buffer. The constructors in this module fill it with
/// `out_c * in_c * k * k` values (`k` is 1 or 3), ordered `[out_c][in_c][ky][kx]`
/// row-major. The kernel size itself is not stored.
#[derive(Clone, Debug)]
pub struct Conv2dWeight {
    /// Flattened kernel values.
    pub weight: Vec<f32>,
    /// Bias values, expected to hold one entry per output channel.
    pub bias: Vec<f32>,
    /// Number of input channels.
    pub in_c: usize,
    /// Number of output channels.
    pub out_c: usize,
}
28
/// Affine parameters of a group-norm layer.
///
/// The loader reads them from the `<prefix>.weight` / `<prefix>.bias` tensors. The
/// synthetic constructor uses an identity transform (`gamma = 1`, `beta = 0`).
#[derive(Clone, Debug)]
pub struct GroupNormWeight {
    /// Per-channel scale.
    pub gamma: Vec<f32>,
    /// Per-channel shift.
    pub beta: Vec<f32>,
}
34
35#[derive(Debug, Clone)]
36pub struct ResnetBlockWeights {
37    pub norm1: GroupNormWeight,
38    pub conv1: Conv2dWeight,
39    pub norm2: GroupNormWeight,
40    pub conv2: Conv2dWeight,
41    pub shortcut: Option<Conv2dWeight>,
42}
43
44#[derive(Debug, Clone)]
45pub struct AttnBlockWeights {
46    pub norm: GroupNormWeight,
47    pub to_q: Conv2dWeight,
48    pub to_k: Conv2dWeight,
49    pub to_v: Conv2dWeight,
50    pub to_out: Conv2dWeight,
51}
52
53#[derive(Debug, Clone)]
54pub struct UpDecoderBlockWeights {
55    pub resnets: Vec<ResnetBlockWeights>,
56    pub upsample: Option<Conv2dWeight>,
57}
58
59#[derive(Debug, Clone)]
60pub struct DownEncoderBlockWeights {
61    pub resnets: Vec<ResnetBlockWeights>,
62    pub downsample: Option<Conv2dWeight>,
63}
64
65#[derive(Debug, Clone)]
66pub struct Flux2VaeWeights {
67    pub encoder_conv_in: Conv2dWeight,
68    pub encoder_down_blocks: Vec<DownEncoderBlockWeights>,
69    pub encoder_mid_resnets: Vec<ResnetBlockWeights>,
70    pub encoder_mid_attn: Option<AttnBlockWeights>,
71    pub encoder_conv_norm_out: GroupNormWeight,
72    pub encoder_conv_out: Conv2dWeight,
73    pub quant_conv: Conv2dWeight,
74    pub post_quant_conv: Option<Conv2dWeight>,
75    pub conv_in: Conv2dWeight,
76    pub mid_resnets: Vec<ResnetBlockWeights>,
77    pub mid_attn: Option<AttnBlockWeights>,
78    pub up_blocks: Vec<UpDecoderBlockWeights>,
79    pub conv_norm_out: GroupNormWeight,
80    pub conv_out: Conv2dWeight,
81    pub bn_running_mean: Vec<f32>,
82    pub bn_running_var: Vec<f32>,
83}
84
85/// Resolve `vae/` next to a transformer weights file or model root.
86pub fn resolve_vae_dir(model_path: &Path) -> Option<PathBuf> {
87    crate::paths::find_component_dir(model_path, "vae")
88}
89
90pub fn load_flux2_vae_weights(path: &Path, cfg: &Flux2VaeConfig) -> Result<Flux2VaeWeights> {
91    let wm = if path.is_dir() {
92        WeightMap::from_safetensors_dir(path)?
93    } else {
94        WeightMap::from_file(
95            path.to_str()
96                .ok_or_else(|| anyhow::anyhow!("non-utf8 path"))?,
97        )?
98    };
99    extract_flux2_vae_weights(wm, cfg)
100}
101
102pub fn extract_flux2_vae_weights(
103    mut wm: WeightMap,
104    cfg: &Flux2VaeConfig,
105) -> Result<Flux2VaeWeights> {
106    let encoder_conv_in = load_conv(&mut wm, "encoder.conv_in.weight", "encoder.conv_in.bias")?;
107
108    let mut encoder_down_blocks = Vec::new();
109    let channels: Vec<usize> = cfg.block_out_channels.clone();
110    for (i, &out_ch) in channels.iter().enumerate() {
111        let in_ch = if i == 0 { channels[0] } else { channels[i - 1] };
112        let num_layers = cfg.layers_per_block;
113        let mut resnets = Vec::with_capacity(num_layers);
114        for j in 0..num_layers {
115            resnets.push(load_resnet(
116                &mut wm,
117                &format!("encoder.down_blocks.{i}.resnets.{j}"),
118                cfg.norm_num_groups,
119            )?);
120            let _ = if j == 0 { in_ch } else { out_ch };
121        }
122        let downsample = if i + 1 < channels.len() {
123            Some(load_conv(
124                &mut wm,
125                &format!("encoder.down_blocks.{i}.downsamplers.0.conv.weight"),
126                &format!("encoder.down_blocks.{i}.downsamplers.0.conv.bias"),
127            )?)
128        } else {
129            None
130        };
131        encoder_down_blocks.push(DownEncoderBlockWeights {
132            resnets,
133            downsample,
134        });
135    }
136
137    let mut encoder_mid_resnets = Vec::new();
138    for i in 0..2 {
139        encoder_mid_resnets.push(load_resnet(
140            &mut wm,
141            &format!("encoder.mid_block.resnets.{i}"),
142            cfg.norm_num_groups,
143        )?);
144    }
145    let encoder_mid_attn = if cfg.mid_block_add_attention {
146        let p = "encoder.mid_block.attentions.0";
147        Some(AttnBlockWeights {
148            norm: load_gn(&mut wm, &format!("{p}.group_norm"))?,
149            to_q: load_conv(
150                &mut wm,
151                &format!("{p}.to_q.weight"),
152                &format!("{p}.to_q.bias"),
153            )?,
154            to_k: load_conv(
155                &mut wm,
156                &format!("{p}.to_k.weight"),
157                &format!("{p}.to_k.bias"),
158            )?,
159            to_v: load_conv(
160                &mut wm,
161                &format!("{p}.to_v.weight"),
162                &format!("{p}.to_v.bias"),
163            )?,
164            to_out: load_conv(
165                &mut wm,
166                &format!("{p}.to_out.0.weight"),
167                &format!("{p}.to_out.0.bias"),
168            )?,
169        })
170    } else {
171        None
172    };
173    let encoder_conv_norm_out = load_gn(&mut wm, "encoder.conv_norm_out")?;
174    let encoder_conv_out = load_conv(&mut wm, "encoder.conv_out.weight", "encoder.conv_out.bias")?;
175    let quant_conv = load_conv(&mut wm, "quant_conv.weight", "quant_conv.bias")?;
176
177    let post_quant_conv = if cfg.use_post_quant_conv {
178        Some(load_conv(
179            &mut wm,
180            "post_quant_conv.weight",
181            "post_quant_conv.bias",
182        )?)
183    } else {
184        None
185    };
186    let conv_in = load_conv(&mut wm, "decoder.conv_in.weight", "decoder.conv_in.bias")?;
187
188    let mut mid_resnets = Vec::new();
189    for i in 0..2 {
190        mid_resnets.push(load_resnet(
191            &mut wm,
192            &format!("decoder.mid_block.resnets.{i}"),
193            cfg.norm_num_groups,
194        )?);
195    }
196    let mid_attn = if cfg.mid_block_add_attention {
197        let p = "decoder.mid_block.attentions.0";
198        Some(AttnBlockWeights {
199            norm: load_gn(&mut wm, &format!("{p}.group_norm"))?,
200            to_q: load_conv(
201                &mut wm,
202                &format!("{p}.to_q.weight"),
203                &format!("{p}.to_q.bias"),
204            )?,
205            to_k: load_conv(
206                &mut wm,
207                &format!("{p}.to_k.weight"),
208                &format!("{p}.to_k.bias"),
209            )?,
210            to_v: load_conv(
211                &mut wm,
212                &format!("{p}.to_v.weight"),
213                &format!("{p}.to_v.bias"),
214            )?,
215            to_out: load_conv(
216                &mut wm,
217                &format!("{p}.to_out.0.weight"),
218                &format!("{p}.to_out.0.bias"),
219            )?,
220        })
221    } else {
222        None
223    };
224
225    let channels: Vec<usize> = cfg.block_out_channels.clone();
226    let mut up_blocks = Vec::new();
227    let reversed: Vec<usize> = channels.iter().copied().rev().collect();
228    for (i, &out_ch) in reversed.iter().enumerate() {
229        let in_ch = if i == 0 {
230            *channels.last().unwrap()
231        } else {
232            reversed[i - 1]
233        };
234        let num_layers = cfg.layers_per_block + 1;
235        let mut resnets = Vec::with_capacity(num_layers);
236        for j in 0..num_layers {
237            let block_in = if j == 0 { in_ch } else { out_ch };
238            resnets.push(load_resnet(
239                &mut wm,
240                &format!("decoder.up_blocks.{i}.resnets.{j}"),
241                cfg.norm_num_groups,
242            )?);
243            let _ = block_in;
244        }
245        let upsample = if i + 1 < reversed.len() {
246            Some(load_conv(
247                &mut wm,
248                &format!("decoder.up_blocks.{i}.upsamplers.0.conv.weight"),
249                &format!("decoder.up_blocks.{i}.upsamplers.0.conv.bias"),
250            )?)
251        } else {
252            None
253        };
254        up_blocks.push(UpDecoderBlockWeights { resnets, upsample });
255    }
256
257    let conv_norm_out = load_gn(&mut wm, "decoder.conv_norm_out")?;
258    let conv_out = load_conv(&mut wm, "decoder.conv_out.weight", "decoder.conv_out.bias")?;
259    let (bn_running_mean, _) = wm.take("bn.running_mean")?;
260    let (bn_running_var, _) = wm.take("bn.running_var")?;
261    ensure!(
262        bn_running_mean.len() == cfg.bn_channels(),
263        "bn.running_mean len {} != {}",
264        bn_running_mean.len(),
265        cfg.bn_channels()
266    );
267
268    Ok(Flux2VaeWeights {
269        encoder_conv_in,
270        encoder_down_blocks,
271        encoder_mid_resnets,
272        encoder_mid_attn,
273        encoder_conv_norm_out,
274        encoder_conv_out,
275        quant_conv,
276        post_quant_conv,
277        conv_in,
278        mid_resnets,
279        mid_attn,
280        up_blocks,
281        conv_norm_out,
282        conv_out,
283        bn_running_mean,
284        bn_running_var,
285    })
286}
287
288fn load_conv(wm: &mut WeightMap, w_key: &str, b_key: &str) -> Result<Conv2dWeight> {
289    let (data, shape) = wm.take(w_key)?;
290    let (bias, _) = wm.take(b_key)?;
291    let (out_c, in_c, kh, kw) = match shape.as_slice() {
292        [o, i, 3, 3] => (*o, *i, 3, 3),
293        [o, i, 1, 1] => (*o, *i, 1, 1),
294        [o, i] => (*o, *i, 1, 1),
295        _ => anyhow::bail!("conv weight shape {shape:?}"),
296    };
297    ensure!(kh == kw && (kh == 3 || kh == 1), "expected 1x1 or 3x3 conv");
298    let weight = if kh == 3 {
299        let mut w = vec![0.0f32; out_c * in_c * 9];
300        for oc in 0..out_c {
301            for ic in 0..in_c {
302                for ky in 0..3 {
303                    for kx in 0..3 {
304                        w[(oc * in_c + ic) * 9 + ky * 3 + kx] =
305                            data[((oc * in_c + ic) * 3 + ky) * 3 + kx];
306                    }
307                }
308            }
309        }
310        w
311    } else {
312        data
313    };
314    Ok(Conv2dWeight {
315        weight,
316        bias,
317        in_c,
318        out_c,
319    })
320}
321
322fn load_gn(wm: &mut WeightMap, prefix: &str) -> Result<GroupNormWeight> {
323    let (gamma, _) = wm.take(&format!("{prefix}.weight"))?;
324    let (beta, _) = wm.take(&format!("{prefix}.bias"))?;
325    Ok(GroupNormWeight { gamma, beta })
326}
327
328fn zero_conv3(in_c: usize, out_c: usize) -> Conv2dWeight {
329    Conv2dWeight {
330        weight: vec![0.0; out_c * in_c * 9],
331        bias: vec![0.0; out_c],
332        in_c,
333        out_c,
334    }
335}
336
337fn zero_conv1(in_c: usize, out_c: usize) -> Conv2dWeight {
338    Conv2dWeight {
339        weight: vec![0.0; out_c * in_c],
340        bias: vec![0.0; out_c],
341        in_c,
342        out_c,
343    }
344}
345
346fn zero_gn(ch: usize) -> GroupNormWeight {
347    GroupNormWeight {
348        gamma: vec![1.0; ch],
349        beta: vec![0.0; ch],
350    }
351}
352
353fn zero_resnet(in_c: usize, out_c: usize) -> ResnetBlockWeights {
354    ResnetBlockWeights {
355        norm1: zero_gn(in_c),
356        conv1: zero_conv3(in_c, out_c),
357        norm2: zero_gn(out_c),
358        conv2: zero_conv3(out_c, out_c),
359        shortcut: if in_c != out_c {
360            Some(zero_conv1(in_c, out_c))
361        } else {
362            None
363        },
364    }
365}
366
367/// Zero weights for [`Flux2VaeConfig::tiny`] basic tests.
368pub fn synthetic_vae_weights(cfg: &Flux2VaeConfig) -> Flux2VaeWeights {
369    let last = *cfg.block_out_channels.last().unwrap_or(&8);
370    let channels: Vec<usize> = cfg.block_out_channels.clone();
371    let reversed: Vec<usize> = channels.iter().copied().rev().collect();
372    let mut up_blocks = Vec::new();
373    for (i, &out_ch) in reversed.iter().enumerate() {
374        let in_ch = if i == 0 { last } else { reversed[i - 1] };
375        let num_layers = cfg.layers_per_block + 1;
376        let resnets = (0..num_layers)
377            .map(|j| {
378                let cin = if j == 0 { in_ch } else { out_ch };
379                zero_resnet(cin, out_ch)
380            })
381            .collect();
382        let upsample = if i + 1 < reversed.len() {
383            Some(zero_conv3(out_ch, out_ch))
384        } else {
385            None
386        };
387        up_blocks.push(UpDecoderBlockWeights { resnets, upsample });
388    }
389    Flux2VaeWeights {
390        encoder_conv_in: zero_conv3(cfg.in_channels, channels[0]),
391        encoder_down_blocks: {
392            let mut blocks = Vec::new();
393            for (i, &out_ch) in channels.iter().enumerate() {
394                let in_ch = if i == 0 { channels[0] } else { channels[i - 1] };
395                let num_layers = cfg.layers_per_block;
396                let resnets = (0..num_layers)
397                    .map(|j| {
398                        let cin = if j == 0 { in_ch } else { out_ch };
399                        zero_resnet(cin, out_ch)
400                    })
401                    .collect();
402                let downsample = if i + 1 < channels.len() {
403                    Some(zero_conv3(out_ch, out_ch))
404                } else {
405                    None
406                };
407                blocks.push(DownEncoderBlockWeights {
408                    resnets,
409                    downsample,
410                });
411            }
412            blocks
413        },
414        encoder_mid_resnets: vec![zero_resnet(last, last), zero_resnet(last, last)],
415        encoder_mid_attn: None,
416        encoder_conv_norm_out: zero_gn(last),
417        encoder_conv_out: zero_conv3(last, cfg.latent_channels * 2),
418        quant_conv: zero_conv1(cfg.latent_channels * 2, cfg.latent_channels * 2),
419        post_quant_conv: cfg
420            .use_post_quant_conv
421            .then(|| zero_conv1(cfg.latent_channels, cfg.latent_channels)),
422        conv_in: zero_conv3(cfg.latent_channels, last),
423        mid_resnets: vec![zero_resnet(last, last), zero_resnet(last, last)],
424        mid_attn: None,
425        up_blocks,
426        conv_norm_out: zero_gn(cfg.block_out_channels[0]),
427        conv_out: zero_conv3(cfg.block_out_channels[0], cfg.out_channels),
428        bn_running_mean: vec![0.0; cfg.bn_channels()],
429        bn_running_var: vec![1.0; cfg.bn_channels()],
430    }
431}
432
433fn load_resnet(wm: &mut WeightMap, prefix: &str, groups: usize) -> Result<ResnetBlockWeights> {
434    let norm1 = load_gn(wm, &format!("{prefix}.norm1"))?;
435    let conv1 = load_conv(
436        wm,
437        &format!("{prefix}.conv1.weight"),
438        &format!("{prefix}.conv1.bias"),
439    )?;
440    let norm2 = load_gn(wm, &format!("{prefix}.norm2"))?;
441    let conv2 = load_conv(
442        wm,
443        &format!("{prefix}.conv2.weight"),
444        &format!("{prefix}.conv2.bias"),
445    )?;
446    let shortcut = if wm.has(&format!("{prefix}.conv_shortcut.weight")) {
447        Some(load_conv(
448            wm,
449            &format!("{prefix}.conv_shortcut.weight"),
450            &format!("{prefix}.conv_shortcut.bias"),
451        )?)
452    } else {
453        None
454    };
455    let _ = groups;
456    Ok(ResnetBlockWeights {
457        norm1,
458        conv1,
459        norm2,
460        conv2,
461        shortcut,
462    })
463}