ferrotorch_diffusion/
config.rs1use ferrotorch_core::{FerrotorchError, FerrotorchResult};
18
/// Hyper-parameters describing a VAE decoder.
///
/// Build one with [`Default::default`], [`VaeDecoderConfig::sd_v1_5`],
/// [`VaeDecoderConfig::from_json_str`] or [`VaeDecoderConfig::from_file`], and
/// check hand-built values with [`VaeDecoderConfig::validate`].
///
/// `PartialEq` is derived so two configurations can be compared directly
/// (for example with `assert_eq!`). `Eq` is intentionally not derived because
/// `scaling_factor` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct VaeDecoderConfig {
    /// Output channel count. Must be `> 0`; defaults to `3`.
    ///
    /// NOTE(review): presumably the channel count of the decoded image — confirm
    /// against the decoder implementation.
    pub out_channels: usize,
    /// Latent channel count. Must be `> 0`; defaults to `4`.
    ///
    /// NOTE(review): presumably the channel count of the latent fed to the
    /// decoder — confirm against the decoder implementation.
    pub latent_channels: usize,
    /// Channel width per block; its length is what
    /// [`VaeDecoderConfig::num_up_blocks`] reports.
    ///
    /// Must be non-empty, and every entry must be `> 0` and divisible by
    /// `norm_num_groups`. Defaults to `[128, 256, 512, 512]`.
    pub block_out_channels: Vec<usize>,
    /// Layers per block. Must be `> 0`; defaults to `2`.
    ///
    /// [`VaeDecoderConfig::resnets_per_up_block`] returns this value plus one.
    pub layers_per_block: usize,
    /// Normalisation group count. Must be `> 0` and must divide every entry of
    /// `block_out_channels`; defaults to `32`.
    pub norm_num_groups: usize,
    /// Sample size. Must be `> 0`; defaults to `512`.
    ///
    /// NOTE(review): presumably the spatial resolution in pixels — confirm.
    pub sample_size: usize,
    /// Scaling factor. Must be finite and non-zero; defaults to `0.18215`.
    ///
    /// NOTE(review): presumably the latent scaling constant applied around the
    /// VAE — confirm where it is consumed.
    pub scaling_factor: f64,
}
50
51impl Default for VaeDecoderConfig {
52 fn default() -> Self {
53 Self {
55 out_channels: 3,
56 latent_channels: 4,
57 block_out_channels: vec![128, 256, 512, 512],
58 layers_per_block: 2,
59 norm_num_groups: 32,
60 sample_size: 512,
61 scaling_factor: 0.18215,
62 }
63 }
64}
65
66impl VaeDecoderConfig {
67 pub fn sd_v1_5() -> Self {
69 Self::default()
70 }
71
72 pub fn validate(&self) -> FerrotorchResult<()> {
80 if self.block_out_channels.is_empty() {
81 return Err(FerrotorchError::InvalidArgument {
82 message: "VaeDecoderConfig: block_out_channels must be non-empty".into(),
83 });
84 }
85 if self.norm_num_groups == 0 {
86 return Err(FerrotorchError::InvalidArgument {
87 message: "VaeDecoderConfig: norm_num_groups must be > 0".into(),
88 });
89 }
90 for &c in &self.block_out_channels {
91 if c == 0 || c % self.norm_num_groups != 0 {
92 return Err(FerrotorchError::InvalidArgument {
93 message: format!(
94 "VaeDecoderConfig: block_out_channels entry {c} must be > 0 and divisible \
95 by norm_num_groups={}",
96 self.norm_num_groups
97 ),
98 });
99 }
100 }
101 if self.latent_channels == 0 {
102 return Err(FerrotorchError::InvalidArgument {
103 message: "VaeDecoderConfig: latent_channels must be > 0".into(),
104 });
105 }
106 if self.out_channels == 0 {
107 return Err(FerrotorchError::InvalidArgument {
108 message: "VaeDecoderConfig: out_channels must be > 0".into(),
109 });
110 }
111 if self.layers_per_block == 0 {
112 return Err(FerrotorchError::InvalidArgument {
113 message: "VaeDecoderConfig: layers_per_block must be > 0".into(),
114 });
115 }
116 if self.sample_size == 0 {
117 return Err(FerrotorchError::InvalidArgument {
118 message: "VaeDecoderConfig: sample_size must be > 0".into(),
119 });
120 }
121 if !self.scaling_factor.is_finite() || self.scaling_factor == 0.0 {
122 return Err(FerrotorchError::InvalidArgument {
123 message: format!(
124 "VaeDecoderConfig: scaling_factor must be finite and non-zero, got {}",
125 self.scaling_factor
126 ),
127 });
128 }
129 Ok(())
130 }
131
132 pub fn resnets_per_up_block(&self) -> usize {
135 self.layers_per_block + 1
136 }
137
138 pub fn num_up_blocks(&self) -> usize {
141 self.block_out_channels.len()
142 }
143
144 pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
158 let v: serde_json::Value =
159 serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
160 message: format!("VaeDecoderConfig::from_json_str: bad JSON: {e}"),
161 })?;
162 let mut cfg = Self::default();
163 if let Some(x) = v.get("out_channels").and_then(serde_json::Value::as_u64) {
164 cfg.out_channels = x as usize;
165 }
166 if let Some(x) = v.get("latent_channels").and_then(serde_json::Value::as_u64) {
167 cfg.latent_channels = x as usize;
168 }
169 if let Some(arr) = v
170 .get("block_out_channels")
171 .and_then(serde_json::Value::as_array)
172 {
173 let mut out = Vec::with_capacity(arr.len());
174 for e in arr {
175 let n = e.as_u64().ok_or_else(|| FerrotorchError::InvalidArgument {
176 message: format!(
177 "VaeDecoderConfig::from_json_str: block_out_channels entry \
178 must be a non-negative integer, got {e}"
179 ),
180 })?;
181 out.push(n as usize);
182 }
183 cfg.block_out_channels = out;
184 }
185 if let Some(x) = v
186 .get("layers_per_block")
187 .and_then(serde_json::Value::as_u64)
188 {
189 cfg.layers_per_block = x as usize;
190 }
191 if let Some(x) = v.get("norm_num_groups").and_then(serde_json::Value::as_u64) {
192 cfg.norm_num_groups = x as usize;
193 }
194 if let Some(x) = v.get("sample_size").and_then(serde_json::Value::as_u64) {
195 cfg.sample_size = x as usize;
196 }
197 if let Some(x) = v.get("scaling_factor").and_then(serde_json::Value::as_f64) {
198 cfg.scaling_factor = x;
199 }
200 cfg.validate()?;
201 Ok(cfg)
202 }
203
204 pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
211 let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
212 message: format!(
213 "VaeDecoderConfig::from_file: failed to read {}: {e}",
214 path.display()
215 ),
216 })?;
217 Self::from_json_str(&s)
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the expected values and is valid.
    #[test]
    fn default_is_sd_v1_5() {
        let cfg = VaeDecoderConfig::default();

        // Plain fields.
        assert_eq!(cfg.block_out_channels, [128, 256, 512, 512]);
        assert_eq!(2, cfg.layers_per_block);
        assert_eq!(4, cfg.latent_channels);
        assert_eq!(32, cfg.norm_num_groups);
        assert_eq!(512, cfg.sample_size);

        // Derived quantities.
        assert_eq!(3, cfg.resnets_per_up_block());
        assert_eq!(4, cfg.num_up_blocks());

        // Floating-point field, compared with a tolerance.
        let delta = (cfg.scaling_factor - 0.18215).abs();
        assert!(delta < 1e-9);

        cfg.validate().unwrap();
    }

    /// A group count that does not divide the channel widths is rejected.
    #[test]
    fn validate_catches_bad_groups() {
        let outcome = VaeDecoderConfig {
            norm_num_groups: 33,
            ..Default::default()
        }
        .validate();
        assert!(outcome.is_err());
    }

    /// A JSON document with both recognised and unrelated keys parses.
    #[test]
    fn from_json_str_round_trip() {
        let json = r#"{
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D",
            "UpDecoderBlock2D", "UpDecoderBlock2D"],
            "block_out_channels": [128, 256, 512, 512],
            "layers_per_block": 2,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": 32,
            "sample_size": 512,
            "scaling_factor": 0.18215
        }"#;

        let parsed = VaeDecoderConfig::from_json_str(json).unwrap();

        assert_eq!(parsed.block_out_channels, [128, 256, 512, 512]);
        assert_eq!(2, parsed.layers_per_block);
        assert_eq!(512, parsed.sample_size);
    }
}