ferrotorch_diffusion/
config.rs1use ferrotorch_core::{FerrotorchError, FerrotorchResult};
9
/// Hyper-parameters describing a VAE decoder.
///
/// [`Default`] produces the values that `VaeDecoderConfig::sd_v1_5` also
/// returns. Construct a custom configuration with struct-update syntax and
/// check it with `validate` before use.
#[derive(Debug, Clone, PartialEq)]
pub struct VaeDecoderConfig {
    /// Number of output channels. Must be non-zero. Default: `3`.
    pub out_channels: usize,
    /// Number of latent channels. Must be non-zero. Default: `4`.
    pub latent_channels: usize,
    /// Channel count per block. Must be non-empty, and every entry must be
    /// non-zero and divisible by `norm_num_groups`.
    /// Default: `[128, 256, 512, 512]`.
    pub block_out_channels: Vec<usize>,
    /// Layers per block. Must be non-zero. Default: `2`.
    pub layers_per_block: usize,
    /// Group count that every `block_out_channels` entry must be divisible
    /// by. Must be non-zero. Default: `32`.
    pub norm_num_groups: usize,
    /// Sample size. Must be non-zero. Default: `512`.
    // NOTE(review): presumably the image side length in pixels — confirm.
    pub sample_size: usize,
    /// Scaling factor. Must be finite and non-zero. Default: `0.18215`.
    pub scaling_factor: f64,
}

impl Default for VaeDecoderConfig {
    /// Returns the built-in default configuration.
    fn default() -> Self {
        Self {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![128, 256, 512, 512],
            layers_per_block: 2,
            norm_num_groups: 32,
            sample_size: 512,
            scaling_factor: 0.18215,
        }
    }
}
56
57impl VaeDecoderConfig {
58 pub fn sd_v1_5() -> Self {
60 Self::default()
61 }
62
63 pub fn validate(&self) -> FerrotorchResult<()> {
71 if self.block_out_channels.is_empty() {
72 return Err(FerrotorchError::InvalidArgument {
73 message: "VaeDecoderConfig: block_out_channels must be non-empty".into(),
74 });
75 }
76 if self.norm_num_groups == 0 {
77 return Err(FerrotorchError::InvalidArgument {
78 message: "VaeDecoderConfig: norm_num_groups must be > 0".into(),
79 });
80 }
81 for &c in &self.block_out_channels {
82 if c == 0 || c % self.norm_num_groups != 0 {
83 return Err(FerrotorchError::InvalidArgument {
84 message: format!(
85 "VaeDecoderConfig: block_out_channels entry {c} must be > 0 and divisible \
86 by norm_num_groups={}",
87 self.norm_num_groups
88 ),
89 });
90 }
91 }
92 if self.latent_channels == 0 {
93 return Err(FerrotorchError::InvalidArgument {
94 message: "VaeDecoderConfig: latent_channels must be > 0".into(),
95 });
96 }
97 if self.out_channels == 0 {
98 return Err(FerrotorchError::InvalidArgument {
99 message: "VaeDecoderConfig: out_channels must be > 0".into(),
100 });
101 }
102 if self.layers_per_block == 0 {
103 return Err(FerrotorchError::InvalidArgument {
104 message: "VaeDecoderConfig: layers_per_block must be > 0".into(),
105 });
106 }
107 if self.sample_size == 0 {
108 return Err(FerrotorchError::InvalidArgument {
109 message: "VaeDecoderConfig: sample_size must be > 0".into(),
110 });
111 }
112 if !self.scaling_factor.is_finite() || self.scaling_factor == 0.0 {
113 return Err(FerrotorchError::InvalidArgument {
114 message: format!(
115 "VaeDecoderConfig: scaling_factor must be finite and non-zero, got {}",
116 self.scaling_factor
117 ),
118 });
119 }
120 Ok(())
121 }
122
123 pub fn resnets_per_up_block(&self) -> usize {
126 self.layers_per_block + 1
127 }
128
129 pub fn num_up_blocks(&self) -> usize {
132 self.block_out_channels.len()
133 }
134
135 pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
149 let v: serde_json::Value = serde_json::from_str(s).map_err(|e| {
150 FerrotorchError::InvalidArgument {
151 message: format!("VaeDecoderConfig::from_json_str: bad JSON: {e}"),
152 }
153 })?;
154 let mut cfg = Self::default();
155 if let Some(x) = v.get("out_channels").and_then(serde_json::Value::as_u64) {
156 cfg.out_channels = x as usize;
157 }
158 if let Some(x) = v.get("latent_channels").and_then(serde_json::Value::as_u64) {
159 cfg.latent_channels = x as usize;
160 }
161 if let Some(arr) = v.get("block_out_channels").and_then(serde_json::Value::as_array) {
162 let mut out = Vec::with_capacity(arr.len());
163 for e in arr {
164 let n = e.as_u64().ok_or_else(|| FerrotorchError::InvalidArgument {
165 message: format!(
166 "VaeDecoderConfig::from_json_str: block_out_channels entry \
167 must be a non-negative integer, got {e}"
168 ),
169 })?;
170 out.push(n as usize);
171 }
172 cfg.block_out_channels = out;
173 }
174 if let Some(x) = v.get("layers_per_block").and_then(serde_json::Value::as_u64) {
175 cfg.layers_per_block = x as usize;
176 }
177 if let Some(x) = v.get("norm_num_groups").and_then(serde_json::Value::as_u64) {
178 cfg.norm_num_groups = x as usize;
179 }
180 if let Some(x) = v.get("sample_size").and_then(serde_json::Value::as_u64) {
181 cfg.sample_size = x as usize;
182 }
183 if let Some(x) = v.get("scaling_factor").and_then(serde_json::Value::as_f64) {
184 cfg.scaling_factor = x;
185 }
186 cfg.validate()?;
187 Ok(cfg)
188 }
189
190 pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
197 let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
198 message: format!(
199 "VaeDecoderConfig::from_file: failed to read {}: {e}",
200 path.display()
201 ),
202 })?;
203 Self::from_json_str(&s)
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the SD v1.5 values and validates.
    #[test]
    fn default_is_sd_v1_5() {
        let c = VaeDecoderConfig::default();
        assert_eq!(c.block_out_channels, vec![128, 256, 512, 512]);
        assert_eq!(c.layers_per_block, 2);
        assert_eq!(c.latent_channels, 4);
        assert_eq!(c.norm_num_groups, 32);
        assert_eq!(c.sample_size, 512);
        assert_eq!(c.resnets_per_up_block(), 3);
        assert_eq!(c.num_up_blocks(), 4);
        assert!((c.scaling_factor - 0.18215).abs() < 1e-9);
        c.validate().unwrap();
    }

    /// A group count that does not divide the channel counts is rejected.
    #[test]
    fn validate_catches_bad_groups() {
        let c = VaeDecoderConfig {
            norm_num_groups: 33,
            ..VaeDecoderConfig::default()
        };
        assert!(c.validate().is_err());
    }

    /// A full config document parses; keys the parser does not know are
    /// ignored.
    #[test]
    fn from_json_str_round_trip() {
        let json = r#"{
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D",
                               "UpDecoderBlock2D", "UpDecoderBlock2D"],
            "block_out_channels": [128, 256, 512, 512],
            "layers_per_block": 2,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": 32,
            "sample_size": 512,
            "scaling_factor": 0.18215
        }"#;
        let c = VaeDecoderConfig::from_json_str(json).unwrap();
        assert_eq!(c.block_out_channels, vec![128, 256, 512, 512]);
        assert_eq!(c.layers_per_block, 2);
        assert_eq!(c.sample_size, 512);
    }

    /// Every recognised key overrides its default. The values differ from
    /// the defaults so a parser that ignored its input would fail here.
    #[test]
    fn from_json_str_overrides_defaults() {
        let json = r#"{
            "out_channels": 1,
            "latent_channels": 8,
            "block_out_channels": [64, 128],
            "layers_per_block": 1,
            "norm_num_groups": 16,
            "sample_size": 256,
            "scaling_factor": 0.13025
        }"#;
        let c = VaeDecoderConfig::from_json_str(json).unwrap();
        assert_eq!(c.out_channels, 1);
        assert_eq!(c.latent_channels, 8);
        assert_eq!(c.block_out_channels, vec![64, 128]);
        assert_eq!(c.layers_per_block, 1);
        assert_eq!(c.norm_num_groups, 16);
        assert_eq!(c.sample_size, 256);
        assert!((c.scaling_factor - 0.13025).abs() < 1e-9);
        assert_eq!(c.resnets_per_up_block(), 2);
        assert_eq!(c.num_up_blocks(), 2);
    }

    /// An empty object leaves every field at its default.
    #[test]
    fn from_json_str_empty_object_uses_defaults() {
        let c = VaeDecoderConfig::from_json_str("{}").unwrap();
        let d = VaeDecoderConfig::default();
        assert_eq!(c.out_channels, d.out_channels);
        assert_eq!(c.latent_channels, d.latent_channels);
        assert_eq!(c.block_out_channels, d.block_out_channels);
        assert_eq!(c.layers_per_block, d.layers_per_block);
        assert_eq!(c.norm_num_groups, d.norm_num_groups);
        assert_eq!(c.sample_size, d.sample_size);
        assert!((c.scaling_factor - d.scaling_factor).abs() < 1e-12);
    }

    /// Text that is not JSON is reported as an error.
    #[test]
    fn from_json_str_rejects_malformed_json() {
        assert!(VaeDecoderConfig::from_json_str("{ not json").is_err());
    }

    /// A `block_out_channels` element that is not a non-negative integer is
    /// an error.
    #[test]
    fn from_json_str_rejects_bad_channel_entry() {
        let json = r#"{ "block_out_channels": [128, -1] }"#;
        assert!(VaeDecoderConfig::from_json_str(json).is_err());
    }

    /// Parsed values still go through `validate`.
    #[test]
    fn from_json_str_runs_validation() {
        let json = r#"{ "norm_num_groups": 0 }"#;
        assert!(VaeDecoderConfig::from_json_str(json).is_err());
    }
}