Skip to main content

ferrotorch_diffusion/
config.rs

1//! Configuration for the Stable-Diffusion VAE decoder.
2//!
3//! Matches the public surface of `diffusers.AutoencoderKL.config` for the
4//! fields the decoder actually consumes. Encoder-side fields (e.g.
5//! `down_block_types`) are not stored — the decoder mirror is
6//! decoder-only.
7//!
8//! ## REQ status (per `.design/ferrotorch-diffusion/config.md`)
9//!
10//! | REQ | Status | Evidence |
11//! |---|---|---|
//! | REQ-1 | SHIPPED | `Default::default` at `ferrotorch-diffusion/src/config.rs:51..64`; consumer: `ferrotorch-diffusion/src/vae.rs:271` calls `VaeDecoder::<T>::new(cfg)` |
//! | REQ-2 | SHIPPED | `validate` at `ferrotorch-diffusion/src/config.rs:79..130`; consumer: `from_json_str` at `ferrotorch-diffusion/src/config.rs:200` and `VaeDecoder::new` at `vae.rs:62` invoke it |
//! | REQ-3 | SHIPPED | `from_json_str` at `ferrotorch-diffusion/src/config.rs:157..202`; consumer: `from_file` at `ferrotorch-diffusion/src/config.rs:210..218` and the dump examples |
//! | REQ-4 | SHIPPED | `resnets_per_up_block` at `ferrotorch-diffusion/src/config.rs:134..136`; consumer: `ferrotorch-diffusion/src/vae.rs:88` uses it to size `UpDecoderBlock2D` |
16
17use ferrotorch_core::{FerrotorchError, FerrotorchResult};
18
/// Frozen config for the Stable-Diffusion VAE decoder.
///
/// Holds the decoder-relevant subset of `AutoencoderKL.config` for
/// `runwayml/stable-diffusion-v1-5`; the `Default` values match SD 1.5
/// exactly.
///
/// Two configs compare equal when every field is equal (`PartialEq`).
/// `Eq` is intentionally not derived because `scaling_factor` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct VaeDecoderConfig {
    /// Channel count of the image the encoder consumes, and therefore of
    /// the image the decoder produces. For SD 1.5: 3.
    pub out_channels: usize,
    /// Channel count of the latent tensor. For SD 1.5: 4.
    pub latent_channels: usize,
    /// Output channel count of each block level, listed in encoder order
    /// (highest-resolution block first). For SD 1.5:
    /// `[128, 256, 512, 512]`. The decoder walks the list in reverse, so
    /// the first block after `conv_in` has `block_out_channels[-1]`
    /// channels (= 512).
    pub block_out_channels: Vec<usize>,
    /// Resnet layers per Encoder / Decoder up- or down-block. The
    /// decoder's `UpDecoderBlock2D` uses `layers_per_block + 1` resnets
    /// (the diffusers convention). For SD 1.5: 2, i.e. 3 resnets per
    /// up-block.
    pub layers_per_block: usize,
    /// GroupNorm group count used by the decoder-internal `norm1` /
    /// `norm2` / `conv_norm_out`. For SD 1.5: 32.
    pub norm_num_groups: usize,
    /// Spatial size the encoder accepts and the decoder produces.
    /// For SD 1.5: 512.
    pub sample_size: usize,
    /// VAE latent scaling factor. The decoder pre-divides the latent by
    /// this value (matching `AutoencoderKL.decode`). For SD 1.5: 0.18215.
    pub scaling_factor: f64,
}
50
51impl Default for VaeDecoderConfig {
52    fn default() -> Self {
53        // SD 1.5 VAE config.
54        Self {
55            out_channels: 3,
56            latent_channels: 4,
57            block_out_channels: vec![128, 256, 512, 512],
58            layers_per_block: 2,
59            norm_num_groups: 32,
60            sample_size: 512,
61            scaling_factor: 0.18215,
62        }
63    }
64}
65
66impl VaeDecoderConfig {
67    /// SD 1.5 VAE decoder config (alias for `Default::default()`).
68    pub fn sd_v1_5() -> Self {
69        Self::default()
70    }
71
72    /// Validate field bounds (positive sizes, channels divisible by
73    /// `norm_num_groups`, at least one resolution).
74    ///
75    /// # Errors
76    ///
77    /// Returns [`FerrotorchError::InvalidArgument`] for any out-of-bounds
78    /// or arithmetic-incompatible field.
79    pub fn validate(&self) -> FerrotorchResult<()> {
80        if self.block_out_channels.is_empty() {
81            return Err(FerrotorchError::InvalidArgument {
82                message: "VaeDecoderConfig: block_out_channels must be non-empty".into(),
83            });
84        }
85        if self.norm_num_groups == 0 {
86            return Err(FerrotorchError::InvalidArgument {
87                message: "VaeDecoderConfig: norm_num_groups must be > 0".into(),
88            });
89        }
90        for &c in &self.block_out_channels {
91            if c == 0 || c % self.norm_num_groups != 0 {
92                return Err(FerrotorchError::InvalidArgument {
93                    message: format!(
94                        "VaeDecoderConfig: block_out_channels entry {c} must be > 0 and divisible \
95                         by norm_num_groups={}",
96                        self.norm_num_groups
97                    ),
98                });
99            }
100        }
101        if self.latent_channels == 0 {
102            return Err(FerrotorchError::InvalidArgument {
103                message: "VaeDecoderConfig: latent_channels must be > 0".into(),
104            });
105        }
106        if self.out_channels == 0 {
107            return Err(FerrotorchError::InvalidArgument {
108                message: "VaeDecoderConfig: out_channels must be > 0".into(),
109            });
110        }
111        if self.layers_per_block == 0 {
112            return Err(FerrotorchError::InvalidArgument {
113                message: "VaeDecoderConfig: layers_per_block must be > 0".into(),
114            });
115        }
116        if self.sample_size == 0 {
117            return Err(FerrotorchError::InvalidArgument {
118                message: "VaeDecoderConfig: sample_size must be > 0".into(),
119            });
120        }
121        if !self.scaling_factor.is_finite() || self.scaling_factor == 0.0 {
122            return Err(FerrotorchError::InvalidArgument {
123                message: format!(
124                    "VaeDecoderConfig: scaling_factor must be finite and non-zero, got {}",
125                    self.scaling_factor
126                ),
127            });
128        }
129        Ok(())
130    }
131
132    /// Number of resnets in each `UpDecoderBlock2D` (the diffusers
133    /// convention is `layers_per_block + 1`).
134    pub fn resnets_per_up_block(&self) -> usize {
135        self.layers_per_block + 1
136    }
137
138    /// Number of up-blocks (== number of down-blocks the encoder used,
139    /// == `block_out_channels.len()`).
140    pub fn num_up_blocks(&self) -> usize {
141        self.block_out_channels.len()
142    }
143
144    /// Parse a `vae/config.json` document into a [`VaeDecoderConfig`].
145    ///
146    /// Recognised keys (all optional — anything missing falls back to
147    /// the SD-1.5 defaults):
148    ///   - `out_channels`, `latent_channels`, `block_out_channels`,
149    ///     `layers_per_block`, `norm_num_groups`, `sample_size`,
150    ///     `scaling_factor`.
151    ///
152    /// # Errors
153    ///
154    /// Returns [`FerrotorchError::InvalidArgument`] on malformed JSON or
155    /// a wrong-type field (e.g. `block_out_channels` not an array of
156    /// integers).
157    pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
158        let v: serde_json::Value =
159            serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
160                message: format!("VaeDecoderConfig::from_json_str: bad JSON: {e}"),
161            })?;
162        let mut cfg = Self::default();
163        if let Some(x) = v.get("out_channels").and_then(serde_json::Value::as_u64) {
164            cfg.out_channels = x as usize;
165        }
166        if let Some(x) = v.get("latent_channels").and_then(serde_json::Value::as_u64) {
167            cfg.latent_channels = x as usize;
168        }
169        if let Some(arr) = v
170            .get("block_out_channels")
171            .and_then(serde_json::Value::as_array)
172        {
173            let mut out = Vec::with_capacity(arr.len());
174            for e in arr {
175                let n = e.as_u64().ok_or_else(|| FerrotorchError::InvalidArgument {
176                    message: format!(
177                        "VaeDecoderConfig::from_json_str: block_out_channels entry \
178                         must be a non-negative integer, got {e}"
179                    ),
180                })?;
181                out.push(n as usize);
182            }
183            cfg.block_out_channels = out;
184        }
185        if let Some(x) = v
186            .get("layers_per_block")
187            .and_then(serde_json::Value::as_u64)
188        {
189            cfg.layers_per_block = x as usize;
190        }
191        if let Some(x) = v.get("norm_num_groups").and_then(serde_json::Value::as_u64) {
192            cfg.norm_num_groups = x as usize;
193        }
194        if let Some(x) = v.get("sample_size").and_then(serde_json::Value::as_u64) {
195            cfg.sample_size = x as usize;
196        }
197        if let Some(x) = v.get("scaling_factor").and_then(serde_json::Value::as_f64) {
198            cfg.scaling_factor = x;
199        }
200        cfg.validate()?;
201        Ok(cfg)
202    }
203
204    /// Parse a `vae/config.json` file from disk.
205    ///
206    /// # Errors
207    ///
208    /// Returns [`FerrotorchError::InvalidArgument`] for I/O or parse
209    /// failures (file missing, malformed JSON, wrong-type field).
210    pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
211        let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
212            message: format!(
213                "VaeDecoderConfig::from_file: failed to read {}: {e}",
214                path.display()
215            ),
216        })?;
217        Self::from_json_str(&s)
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config carries the SD 1.5 values and passes validation.
    #[test]
    fn default_is_sd_v1_5() {
        let cfg = VaeDecoderConfig::default();

        // Derived quantities.
        assert_eq!(cfg.num_up_blocks(), 4);
        assert_eq!(cfg.resnets_per_up_block(), 3);

        // Raw fields.
        assert_eq!(cfg.sample_size, 512);
        assert_eq!(cfg.norm_num_groups, 32);
        assert_eq!(cfg.latent_channels, 4);
        assert_eq!(cfg.layers_per_block, 2);
        assert_eq!(cfg.block_out_channels, vec![128, 256, 512, 512]);

        // The published `scaling_factor`, compared within a tight tolerance.
        let drift = (cfg.scaling_factor - 0.18215).abs();
        assert!(drift < 1e-9);

        cfg.validate().unwrap();
    }

    /// A group count that does not divide the channel widths is rejected.
    #[test]
    fn validate_catches_bad_groups() {
        // 128 is not divisible by 33.
        let broken = VaeDecoderConfig {
            norm_num_groups: 33,
            ..VaeDecoderConfig::default()
        };
        let outcome = broken.validate();
        assert!(outcome.is_err());
    }

    /// A full SD 1.5 style document parses, including keys the decoder
    /// config does not store.
    #[test]
    fn from_json_str_round_trip() {
        let document = r#"{
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D",
                               "UpDecoderBlock2D", "UpDecoderBlock2D"],
            "block_out_channels": [128, 256, 512, 512],
            "layers_per_block": 2,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": 32,
            "sample_size": 512,
            "scaling_factor": 0.18215
        }"#;
        let parsed = VaeDecoderConfig::from_json_str(document).unwrap();
        assert_eq!(parsed.sample_size, 512);
        assert_eq!(parsed.layers_per_block, 2);
        assert_eq!(parsed.block_out_channels, vec![128, 256, 512, 512]);
    }
}
271}