Skip to main content

ferrotorch_diffusion/
config.rs

//! Configuration for the Stable-Diffusion VAE decoder.
//!
//! Matches the public surface of `diffusers.AutoencoderKL.config` for the
//! fields the decoder actually consumes. Encoder-side fields (e.g.
//! `down_block_types`) are not stored, because this mirror is
//! decoder-only.
7
8use ferrotorch_core::{FerrotorchError, FerrotorchResult};
9
/// Frozen config for the Stable-Diffusion VAE decoder.
///
/// Mirrors the decoder-relevant subset of `AutoencoderKL.config` for
/// `runwayml/stable-diffusion-v1-5`. The defaults match SD 1.5 exactly.
///
/// `PartialEq` is derived so two configs (for example a parsed one and
/// the defaults) can be compared directly. `Eq` is deliberately not
/// derived because `scaling_factor` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct VaeDecoderConfig {
    /// Number of input channels of the image the encoder consumes (and
    /// therefore of the image the decoder produces). For SD 1.5: 3.
    pub out_channels: usize,
    /// Number of latent channels. For SD 1.5: 4.
    pub latent_channels: usize,
    /// Per-block-level output channel counts (in encoder order: from
    /// the highest-resolution block out). For SD 1.5: `[128, 256, 512,
    /// 512]`. The decoder walks these in reverse, so the first block
    /// after `conv_in` has `block_out_channels[-1]` channels (= 512).
    pub block_out_channels: Vec<usize>,
    /// Number of resnet layers in each Encoder / Decoder up- or
    /// down-block. The decoder's `UpDecoderBlock2D` uses
    /// `layers_per_block + 1` resnets (the diffusers convention). For
    /// SD 1.5: 2 (so each up-block has 3 resnets).
    pub layers_per_block: usize,
    /// Number of GroupNorm groups (decoder-internal `norm1` / `norm2` /
    /// `conv_norm_out`). For SD 1.5: 32.
    pub norm_num_groups: usize,
    /// Spatial size the encoder accepts (and the decoder produces).
    /// For SD 1.5: 512.
    pub sample_size: usize,
    /// VAE latent scaling factor. The decoder pre-divides the latent by
    /// this value (matching `AutoencoderKL.decode`). For SD 1.5: 0.18215.
    pub scaling_factor: f64,
}
41
42impl Default for VaeDecoderConfig {
43    fn default() -> Self {
44        // SD 1.5 VAE config.
45        Self {
46            out_channels: 3,
47            latent_channels: 4,
48            block_out_channels: vec![128, 256, 512, 512],
49            layers_per_block: 2,
50            norm_num_groups: 32,
51            sample_size: 512,
52            scaling_factor: 0.18215,
53        }
54    }
55}
56
57impl VaeDecoderConfig {
58    /// SD 1.5 VAE decoder config (alias for `Default::default()`).
59    pub fn sd_v1_5() -> Self {
60        Self::default()
61    }
62
63    /// Validate field bounds (positive sizes, channels divisible by
64    /// `norm_num_groups`, at least one resolution).
65    ///
66    /// # Errors
67    ///
68    /// Returns [`FerrotorchError::InvalidArgument`] for any out-of-bounds
69    /// or arithmetic-incompatible field.
70    pub fn validate(&self) -> FerrotorchResult<()> {
71        if self.block_out_channels.is_empty() {
72            return Err(FerrotorchError::InvalidArgument {
73                message: "VaeDecoderConfig: block_out_channels must be non-empty".into(),
74            });
75        }
76        if self.norm_num_groups == 0 {
77            return Err(FerrotorchError::InvalidArgument {
78                message: "VaeDecoderConfig: norm_num_groups must be > 0".into(),
79            });
80        }
81        for &c in &self.block_out_channels {
82            if c == 0 || c % self.norm_num_groups != 0 {
83                return Err(FerrotorchError::InvalidArgument {
84                    message: format!(
85                        "VaeDecoderConfig: block_out_channels entry {c} must be > 0 and divisible \
86                         by norm_num_groups={}",
87                        self.norm_num_groups
88                    ),
89                });
90            }
91        }
92        if self.latent_channels == 0 {
93            return Err(FerrotorchError::InvalidArgument {
94                message: "VaeDecoderConfig: latent_channels must be > 0".into(),
95            });
96        }
97        if self.out_channels == 0 {
98            return Err(FerrotorchError::InvalidArgument {
99                message: "VaeDecoderConfig: out_channels must be > 0".into(),
100            });
101        }
102        if self.layers_per_block == 0 {
103            return Err(FerrotorchError::InvalidArgument {
104                message: "VaeDecoderConfig: layers_per_block must be > 0".into(),
105            });
106        }
107        if self.sample_size == 0 {
108            return Err(FerrotorchError::InvalidArgument {
109                message: "VaeDecoderConfig: sample_size must be > 0".into(),
110            });
111        }
112        if !self.scaling_factor.is_finite() || self.scaling_factor == 0.0 {
113            return Err(FerrotorchError::InvalidArgument {
114                message: format!(
115                    "VaeDecoderConfig: scaling_factor must be finite and non-zero, got {}",
116                    self.scaling_factor
117                ),
118            });
119        }
120        Ok(())
121    }
122
123    /// Number of resnets in each `UpDecoderBlock2D` (the diffusers
124    /// convention is `layers_per_block + 1`).
125    pub fn resnets_per_up_block(&self) -> usize {
126        self.layers_per_block + 1
127    }
128
129    /// Number of up-blocks (== number of down-blocks the encoder used,
130    /// == `block_out_channels.len()`).
131    pub fn num_up_blocks(&self) -> usize {
132        self.block_out_channels.len()
133    }
134
135    /// Parse a `vae/config.json` document into a [`VaeDecoderConfig`].
136    ///
137    /// Recognised keys (all optional — anything missing falls back to
138    /// the SD-1.5 defaults):
139    ///   - `out_channels`, `latent_channels`, `block_out_channels`,
140    ///     `layers_per_block`, `norm_num_groups`, `sample_size`,
141    ///     `scaling_factor`.
142    ///
143    /// # Errors
144    ///
145    /// Returns [`FerrotorchError::InvalidArgument`] on malformed JSON or
146    /// a wrong-type field (e.g. `block_out_channels` not an array of
147    /// integers).
148    pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
149        let v: serde_json::Value =
150            serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
151                message: format!("VaeDecoderConfig::from_json_str: bad JSON: {e}"),
152            })?;
153        let mut cfg = Self::default();
154        if let Some(x) = v.get("out_channels").and_then(serde_json::Value::as_u64) {
155            cfg.out_channels = x as usize;
156        }
157        if let Some(x) = v.get("latent_channels").and_then(serde_json::Value::as_u64) {
158            cfg.latent_channels = x as usize;
159        }
160        if let Some(arr) = v
161            .get("block_out_channels")
162            .and_then(serde_json::Value::as_array)
163        {
164            let mut out = Vec::with_capacity(arr.len());
165            for e in arr {
166                let n = e.as_u64().ok_or_else(|| FerrotorchError::InvalidArgument {
167                    message: format!(
168                        "VaeDecoderConfig::from_json_str: block_out_channels entry \
169                         must be a non-negative integer, got {e}"
170                    ),
171                })?;
172                out.push(n as usize);
173            }
174            cfg.block_out_channels = out;
175        }
176        if let Some(x) = v
177            .get("layers_per_block")
178            .and_then(serde_json::Value::as_u64)
179        {
180            cfg.layers_per_block = x as usize;
181        }
182        if let Some(x) = v.get("norm_num_groups").and_then(serde_json::Value::as_u64) {
183            cfg.norm_num_groups = x as usize;
184        }
185        if let Some(x) = v.get("sample_size").and_then(serde_json::Value::as_u64) {
186            cfg.sample_size = x as usize;
187        }
188        if let Some(x) = v.get("scaling_factor").and_then(serde_json::Value::as_f64) {
189            cfg.scaling_factor = x;
190        }
191        cfg.validate()?;
192        Ok(cfg)
193    }
194
195    /// Parse a `vae/config.json` file from disk.
196    ///
197    /// # Errors
198    ///
199    /// Returns [`FerrotorchError::InvalidArgument`] for I/O or parse
200    /// failures (file missing, malformed JSON, wrong-type field).
201    pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
202        let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
203            message: format!(
204                "VaeDecoderConfig::from_file: failed to read {}: {e}",
205                path.display()
206            ),
207        })?;
208        Self::from_json_str(&s)
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults must reproduce the published SD 1.5 VAE config.
    #[test]
    fn default_is_sd_v1_5() {
        let c = VaeDecoderConfig::default();
        assert_eq!(c.out_channels, 3);
        assert_eq!(c.block_out_channels, vec![128, 256, 512, 512]);
        assert_eq!(c.layers_per_block, 2);
        assert_eq!(c.latent_channels, 4);
        assert_eq!(c.norm_num_groups, 32);
        assert_eq!(c.sample_size, 512);
        assert_eq!(c.resnets_per_up_block(), 3);
        assert_eq!(c.num_up_blocks(), 4);
        // Compare the published `scaling_factor` within a tight tolerance.
        assert!((c.scaling_factor - 0.18215).abs() < 1e-9);
        c.validate().unwrap();
    }

    #[test]
    fn validate_catches_bad_groups() {
        // 128 not divisible by 33
        let c = VaeDecoderConfig {
            norm_num_groups: 33,
            ..VaeDecoderConfig::default()
        };
        assert!(c.validate().is_err());
    }

    /// Each individually invalid field must be rejected on its own.
    #[test]
    fn validate_catches_zero_and_non_finite_fields() {
        let broken = [
            VaeDecoderConfig {
                block_out_channels: Vec::new(),
                ..VaeDecoderConfig::default()
            },
            VaeDecoderConfig {
                norm_num_groups: 0,
                ..VaeDecoderConfig::default()
            },
            VaeDecoderConfig {
                latent_channels: 0,
                ..VaeDecoderConfig::default()
            },
            VaeDecoderConfig {
                out_channels: 0,
                ..VaeDecoderConfig::default()
            },
            VaeDecoderConfig {
                layers_per_block: 0,
                ..VaeDecoderConfig::default()
            },
            VaeDecoderConfig {
                sample_size: 0,
                ..VaeDecoderConfig::default()
            },
            VaeDecoderConfig {
                scaling_factor: 0.0,
                ..VaeDecoderConfig::default()
            },
            VaeDecoderConfig {
                scaling_factor: f64::NAN,
                ..VaeDecoderConfig::default()
            },
        ];
        for c in &broken {
            assert!(c.validate().is_err(), "expected rejection of {c:?}");
        }
    }

    #[test]
    fn from_json_str_round_trip() {
        let json = r#"{
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D",
                               "UpDecoderBlock2D", "UpDecoderBlock2D"],
            "block_out_channels": [128, 256, 512, 512],
            "layers_per_block": 2,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": 32,
            "sample_size": 512,
            "scaling_factor": 0.18215
        }"#;
        let c = VaeDecoderConfig::from_json_str(json).unwrap();
        assert_eq!(c.block_out_channels, vec![128, 256, 512, 512]);
        assert_eq!(c.layers_per_block, 2);
        assert_eq!(c.sample_size, 512);
    }

    /// Uses values that all differ from the defaults, so the test fails
    /// if the parser ignores its input.
    #[test]
    fn from_json_str_overrides_defaults() {
        let json = r#"{
            "out_channels": 1,
            "latent_channels": 8,
            "block_out_channels": [64, 128],
            "layers_per_block": 1,
            "norm_num_groups": 16,
            "sample_size": 256,
            "scaling_factor": 0.13025
        }"#;
        let c = VaeDecoderConfig::from_json_str(json).unwrap();
        assert_eq!(c.out_channels, 1);
        assert_eq!(c.latent_channels, 8);
        assert_eq!(c.block_out_channels, vec![64, 128]);
        assert_eq!(c.layers_per_block, 1);
        assert_eq!(c.norm_num_groups, 16);
        assert_eq!(c.sample_size, 256);
        assert!((c.scaling_factor - 0.13025).abs() < 1e-12);
        assert_eq!(c.resnets_per_up_block(), 2);
        assert_eq!(c.num_up_blocks(), 2);
    }

    #[test]
    fn from_json_str_missing_keys_use_defaults() {
        let c = VaeDecoderConfig::from_json_str("{}").unwrap();
        let d = VaeDecoderConfig::default();
        assert_eq!(c.out_channels, d.out_channels);
        assert_eq!(c.latent_channels, d.latent_channels);
        assert_eq!(c.block_out_channels, d.block_out_channels);
        assert_eq!(c.layers_per_block, d.layers_per_block);
        assert_eq!(c.norm_num_groups, d.norm_num_groups);
        assert_eq!(c.sample_size, d.sample_size);
        assert!((c.scaling_factor - d.scaling_factor).abs() < 1e-12);
    }

    #[test]
    fn from_json_str_rejects_malformed_json() {
        assert!(VaeDecoderConfig::from_json_str("{ not json").is_err());
        assert!(VaeDecoderConfig::from_json_str("").is_err());
    }

    #[test]
    fn from_json_str_rejects_bad_channel_entry() {
        let with_string = r#"{ "block_out_channels": [128, "256"] }"#;
        assert!(VaeDecoderConfig::from_json_str(with_string).is_err());
        let with_negative = r#"{ "block_out_channels": [128, -256] }"#;
        assert!(VaeDecoderConfig::from_json_str(with_negative).is_err());
    }

    /// Well-typed but inconsistent values must be caught by validation.
    #[test]
    fn from_json_str_rejects_values_failing_validation() {
        assert!(VaeDecoderConfig::from_json_str(r#"{ "norm_num_groups": 33 }"#).is_err());
        assert!(VaeDecoderConfig::from_json_str(r#"{ "norm_num_groups": 0 }"#).is_err());
        assert!(VaeDecoderConfig::from_json_str(r#"{ "block_out_channels": [] }"#).is_err());
        assert!(VaeDecoderConfig::from_json_str(r#"{ "scaling_factor": 0.0 }"#).is_err());
    }
}