1use anyhow::{Context, Result};
17use serde::Deserialize;
18use std::path::Path;
19
20#[derive(Debug, Clone, Deserialize)]
21pub struct Flux2VaeConfig {
22 pub in_channels: usize,
23 pub out_channels: usize,
24 pub latent_channels: usize,
25 pub layers_per_block: usize,
26 pub norm_num_groups: usize,
27 pub block_out_channels: Vec<usize>,
28 #[serde(default = "default_act_fn")]
29 pub act_fn: String,
30 #[serde(default = "default_batch_norm_eps")]
31 pub batch_norm_eps: f32,
32 #[serde(default = "default_mid_block_add_attention")]
33 pub mid_block_add_attention: bool,
34 #[serde(default = "default_use_post_quant_conv")]
35 pub use_post_quant_conv: bool,
36 #[serde(default)]
37 pub scaling_factor: f32,
38 #[serde(default)]
39 pub shift_factor: f32,
40}
41
42fn default_act_fn() -> String {
43 "silu".into()
44}
45fn default_batch_norm_eps() -> f32 {
46 1e-4
47}
48fn default_mid_block_add_attention() -> bool {
49 true
50}
51fn default_use_post_quant_conv() -> bool {
52 true
53}
54
55impl Flux2VaeConfig {
56 pub fn from_file(path: &Path) -> Result<Self> {
57 let data = std::fs::read_to_string(path).with_context(|| format!("reading {path:?}"))?;
58 Ok(serde_json::from_str(&data)?)
59 }
60
61 pub fn flux2_klein() -> Self {
62 Self {
63 in_channels: 3,
64 out_channels: 3,
65 latent_channels: 32,
66 layers_per_block: 2,
67 norm_num_groups: 32,
68 block_out_channels: vec![128, 256, 512, 512],
69 act_fn: "silu".into(),
70 batch_norm_eps: 1e-4,
71 mid_block_add_attention: true,
72 use_post_quant_conv: true,
73 scaling_factor: 1.0,
74 shift_factor: 0.0,
75 }
76 }
77
78 pub fn tiny() -> Self {
80 Self {
81 in_channels: 3,
82 out_channels: 3,
83 latent_channels: 4,
84 layers_per_block: 1,
85 norm_num_groups: 2,
86 block_out_channels: vec![8, 16],
87 act_fn: "silu".into(),
88 batch_norm_eps: 1e-4,
89 mid_block_add_attention: false,
90 use_post_quant_conv: true,
91 scaling_factor: 1.0,
92 shift_factor: 0.0,
93 }
94 }
95
96 pub fn bn_channels(&self) -> usize {
97 4 * self.latent_channels
98 }
99
100 pub fn encode_spatial_stride(&self) -> usize {
102 1 << self.block_out_channels.len().saturating_sub(1)
103 }
104}