ferrotorch_diffusion/
config.rs1use ferrotorch_core::{FerrotorchError, FerrotorchResult};
9
/// Hyper-parameters for a VAE decoder.
///
/// [`Default`] produces the same values as [`VaeDecoderConfig::sd_v1_5`]. A hand-built value
/// should be passed through [`VaeDecoderConfig::validate`] before use; the per-field
/// constraints listed below are the ones that method enforces.
///
/// `PartialEq` is derived so configurations can be compared directly (for example with
/// `assert_eq!`). `Eq` is intentionally not derived because `scaling_factor` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct VaeDecoderConfig {
    /// Number of output channels. Must be non-zero.
    pub out_channels: usize,
    /// Number of latent channels. Must be non-zero.
    pub latent_channels: usize,
    /// Channel width of each block; its length is what `num_up_blocks()` reports.
    ///
    /// Must be non-empty, and every entry must be non-zero and a multiple of
    /// `norm_num_groups`.
    pub block_out_channels: Vec<usize>,
    /// Layers per block; `resnets_per_up_block()` reports this value plus one.
    /// Must be non-zero.
    pub layers_per_block: usize,
    /// Group count that has to divide every `block_out_channels` entry evenly.
    /// Must be non-zero. (Presumably the GroupNorm group count — confirm against the model
    /// code.)
    pub norm_num_groups: usize,
    /// Sample size. Must be non-zero. (Presumably the spatial edge length of the decoded
    /// sample — confirm against the model code.)
    pub sample_size: usize,
    /// Scaling factor. Must be finite and non-zero. (Presumably the latent scaling constant —
    /// confirm against the model code.)
    pub scaling_factor: f64,
}
41
42impl Default for VaeDecoderConfig {
43 fn default() -> Self {
44 Self {
46 out_channels: 3,
47 latent_channels: 4,
48 block_out_channels: vec![128, 256, 512, 512],
49 layers_per_block: 2,
50 norm_num_groups: 32,
51 sample_size: 512,
52 scaling_factor: 0.18215,
53 }
54 }
55}
56
57impl VaeDecoderConfig {
58 pub fn sd_v1_5() -> Self {
60 Self::default()
61 }
62
63 pub fn validate(&self) -> FerrotorchResult<()> {
71 if self.block_out_channels.is_empty() {
72 return Err(FerrotorchError::InvalidArgument {
73 message: "VaeDecoderConfig: block_out_channels must be non-empty".into(),
74 });
75 }
76 if self.norm_num_groups == 0 {
77 return Err(FerrotorchError::InvalidArgument {
78 message: "VaeDecoderConfig: norm_num_groups must be > 0".into(),
79 });
80 }
81 for &c in &self.block_out_channels {
82 if c == 0 || c % self.norm_num_groups != 0 {
83 return Err(FerrotorchError::InvalidArgument {
84 message: format!(
85 "VaeDecoderConfig: block_out_channels entry {c} must be > 0 and divisible \
86 by norm_num_groups={}",
87 self.norm_num_groups
88 ),
89 });
90 }
91 }
92 if self.latent_channels == 0 {
93 return Err(FerrotorchError::InvalidArgument {
94 message: "VaeDecoderConfig: latent_channels must be > 0".into(),
95 });
96 }
97 if self.out_channels == 0 {
98 return Err(FerrotorchError::InvalidArgument {
99 message: "VaeDecoderConfig: out_channels must be > 0".into(),
100 });
101 }
102 if self.layers_per_block == 0 {
103 return Err(FerrotorchError::InvalidArgument {
104 message: "VaeDecoderConfig: layers_per_block must be > 0".into(),
105 });
106 }
107 if self.sample_size == 0 {
108 return Err(FerrotorchError::InvalidArgument {
109 message: "VaeDecoderConfig: sample_size must be > 0".into(),
110 });
111 }
112 if !self.scaling_factor.is_finite() || self.scaling_factor == 0.0 {
113 return Err(FerrotorchError::InvalidArgument {
114 message: format!(
115 "VaeDecoderConfig: scaling_factor must be finite and non-zero, got {}",
116 self.scaling_factor
117 ),
118 });
119 }
120 Ok(())
121 }
122
123 pub fn resnets_per_up_block(&self) -> usize {
126 self.layers_per_block + 1
127 }
128
129 pub fn num_up_blocks(&self) -> usize {
132 self.block_out_channels.len()
133 }
134
135 pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
149 let v: serde_json::Value =
150 serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
151 message: format!("VaeDecoderConfig::from_json_str: bad JSON: {e}"),
152 })?;
153 let mut cfg = Self::default();
154 if let Some(x) = v.get("out_channels").and_then(serde_json::Value::as_u64) {
155 cfg.out_channels = x as usize;
156 }
157 if let Some(x) = v.get("latent_channels").and_then(serde_json::Value::as_u64) {
158 cfg.latent_channels = x as usize;
159 }
160 if let Some(arr) = v
161 .get("block_out_channels")
162 .and_then(serde_json::Value::as_array)
163 {
164 let mut out = Vec::with_capacity(arr.len());
165 for e in arr {
166 let n = e.as_u64().ok_or_else(|| FerrotorchError::InvalidArgument {
167 message: format!(
168 "VaeDecoderConfig::from_json_str: block_out_channels entry \
169 must be a non-negative integer, got {e}"
170 ),
171 })?;
172 out.push(n as usize);
173 }
174 cfg.block_out_channels = out;
175 }
176 if let Some(x) = v
177 .get("layers_per_block")
178 .and_then(serde_json::Value::as_u64)
179 {
180 cfg.layers_per_block = x as usize;
181 }
182 if let Some(x) = v.get("norm_num_groups").and_then(serde_json::Value::as_u64) {
183 cfg.norm_num_groups = x as usize;
184 }
185 if let Some(x) = v.get("sample_size").and_then(serde_json::Value::as_u64) {
186 cfg.sample_size = x as usize;
187 }
188 if let Some(x) = v.get("scaling_factor").and_then(serde_json::Value::as_f64) {
189 cfg.scaling_factor = x;
190 }
191 cfg.validate()?;
192 Ok(cfg)
193 }
194
195 pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
202 let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
203 message: format!(
204 "VaeDecoderConfig::from_file: failed to read {}: {e}",
205 path.display()
206 ),
207 })?;
208 Self::from_json_str(&s)
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the expected preset values and validates cleanly.
    #[test]
    fn default_is_sd_v1_5() {
        let cfg = VaeDecoderConfig::default();

        // Stored fields.
        assert_eq!(cfg.block_out_channels, [128, 256, 512, 512]);
        assert_eq!(cfg.layers_per_block, 2);
        assert_eq!(cfg.latent_channels, 4);
        assert_eq!(cfg.norm_num_groups, 32);
        assert_eq!(cfg.sample_size, 512);

        // Derived quantities.
        assert_eq!(cfg.resnets_per_up_block(), 3);
        assert_eq!(cfg.num_up_blocks(), 4);

        // Float field: compare with a tolerance rather than exact equality.
        let drift = (cfg.scaling_factor - 0.18215).abs();
        assert!(drift < 1e-9);

        cfg.validate().unwrap();
    }

    /// A group count that does not divide the block widths is rejected.
    #[test]
    fn validate_catches_bad_groups() {
        // 33 does not divide 128, the first default block width.
        let broken = VaeDecoderConfig {
            norm_num_groups: 33,
            ..Default::default()
        };
        assert!(broken.validate().is_err());
    }

    /// A JSON document with extra, unrecognised keys still parses into the expected fields.
    #[test]
    fn from_json_str_round_trip() {
        let document = r#"{
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D",
                               "UpDecoderBlock2D", "UpDecoderBlock2D"],
            "block_out_channels": [128, 256, 512, 512],
            "layers_per_block": 2,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": 32,
            "sample_size": 512,
            "scaling_factor": 0.18215
        }"#;

        let parsed = VaeDecoderConfig::from_json_str(document).unwrap();

        assert_eq!(parsed.block_out_channels, [128, 256, 512, 512]);
        assert_eq!(parsed.layers_per_block, 2);
        assert_eq!(parsed.sample_size, 512);
    }
}