1use std::collections::HashMap;
20
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22use ferrotorch_nn::module::{Module, StateDict};
23use ferrotorch_nn::parameter::Parameter;
24use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
25
26use crate::blocks::{UNetMidBlock2D, UpDecoderBlock2D};
27use crate::config::VaeDecoderConfig;
28
29#[derive(Debug)]
31pub struct Decoder<T: Float> {
32 pub conv_in: Conv2d<T>,
34 pub mid_block: UNetMidBlock2D<T>,
36 pub up_blocks: Vec<UpDecoderBlock2D<T>>,
39 pub conv_norm_out: GroupNorm<T>,
42 pub conv_act: SiLU,
44 pub conv_out: Conv2d<T>,
46 pub config: VaeDecoderConfig,
48 training: bool,
49}
50
51impl<T: Float> Decoder<T> {
52 pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
62 cfg.validate()?;
63 let groups = cfg.norm_num_groups;
64 let resnet_eps = 1e-6_f64;
65 let top_channels =
66 *cfg.block_out_channels
67 .last()
68 .ok_or_else(|| FerrotorchError::InvalidArgument {
69 message: "Decoder::new: block_out_channels is empty (should be unreachable \
70 after validate)"
71 .into(),
72 })?;
73
74 let conv_in = Conv2d::<T>::new(
75 cfg.latent_channels,
76 top_channels,
77 (3, 3),
78 (1, 1),
79 (1, 1),
80 true,
81 )?;
82
83 let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
84
85 let reversed: Vec<usize> = cfg.block_out_channels.iter().rev().copied().collect();
86 let mut up_blocks = Vec::with_capacity(reversed.len());
87 let mut prev_out = reversed[0];
88 let num_blocks = reversed.len();
89 let resnets = cfg.resnets_per_up_block();
90 for (i, &c) in reversed.iter().enumerate() {
91 let is_final = i == num_blocks - 1;
92 up_blocks.push(UpDecoderBlock2D::<T>::new(
93 prev_out,
94 c,
95 resnets,
96 groups,
97 resnet_eps,
98 !is_final,
99 )?);
100 prev_out = c;
101 }
102
103 let bottom_channels = cfg.block_out_channels[0];
104 let conv_norm_out =
105 GroupNorm::<T>::new(groups, bottom_channels, resnet_eps, true)?;
106 let conv_out = Conv2d::<T>::new(
107 bottom_channels,
108 cfg.out_channels,
109 (3, 3),
110 (1, 1),
111 (1, 1),
112 true,
113 )?;
114
115 Ok(Self {
116 conv_in,
117 mid_block,
118 up_blocks,
119 conv_norm_out,
120 conv_act: SiLU::new(),
121 conv_out,
122 config: cfg,
123 training: false,
124 })
125 }
126}
127
128impl<T: Float> Module<T> for Decoder<T> {
129 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
130 let cfg = &self.config;
132 if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
133 return Err(FerrotorchError::ShapeMismatch {
134 message: format!(
135 "Decoder::forward: expected [B, {}, H, W], got {:?}",
136 cfg.latent_channels,
137 input.shape()
138 ),
139 });
140 }
141 let mut h = self.conv_in.forward(input)?;
142 h = self.mid_block.forward(&h)?;
143 for up in &self.up_blocks {
144 h = up.forward(&h)?;
145 }
146 h = self.conv_norm_out.forward(&h)?;
147 h = self.conv_act.forward(&h)?;
148 self.conv_out.forward(&h)
149 }
150
151 fn parameters(&self) -> Vec<&Parameter<T>> {
152 let mut out = Vec::new();
153 out.extend(self.conv_in.parameters());
154 out.extend(self.mid_block.parameters());
155 for b in &self.up_blocks {
156 out.extend(b.parameters());
157 }
158 out.extend(self.conv_norm_out.parameters());
159 out.extend(self.conv_out.parameters());
160 out
161 }
162
163 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
164 let mut out = Vec::new();
165 out.extend(self.conv_in.parameters_mut());
166 out.extend(self.mid_block.parameters_mut());
167 for b in &mut self.up_blocks {
168 out.extend(b.parameters_mut());
169 }
170 out.extend(self.conv_norm_out.parameters_mut());
171 out.extend(self.conv_out.parameters_mut());
172 out
173 }
174
175 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
176 let mut out = Vec::new();
177 for (n, p) in self.conv_in.named_parameters() {
178 out.push((format!("conv_in.{n}"), p));
179 }
180 for (n, p) in self.mid_block.named_parameters() {
181 out.push((format!("mid_block.{n}"), p));
182 }
183 for (i, b) in self.up_blocks.iter().enumerate() {
184 for (n, p) in b.named_parameters() {
185 out.push((format!("up_blocks.{i}.{n}"), p));
186 }
187 }
188 for (n, p) in self.conv_norm_out.named_parameters() {
189 out.push((format!("conv_norm_out.{n}"), p));
190 }
191 for (n, p) in self.conv_out.named_parameters() {
192 out.push((format!("conv_out.{n}"), p));
193 }
194 out
195 }
196
197 fn train(&mut self) {
198 self.training = true;
199 for b in &mut self.up_blocks {
200 b.train();
201 }
202 self.mid_block.train();
203 }
204 fn eval(&mut self) {
205 self.training = false;
206 for b in &mut self.up_blocks {
207 b.eval();
208 }
209 self.mid_block.eval();
210 }
211 fn is_training(&self) -> bool {
212 self.training
213 }
214
215 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
216 let extract = |prefix: &str| -> StateDict<T> {
217 let p = format!("{prefix}.");
218 state
219 .iter()
220 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
221 .collect()
222 };
223
224 if strict {
225 for k in state.keys() {
226 let ok = k.starts_with("conv_in.")
227 || k.starts_with("mid_block.")
228 || k.starts_with("up_blocks.")
229 || k.starts_with("conv_norm_out.")
230 || k.starts_with("conv_out.");
231 if !ok {
232 return Err(FerrotorchError::InvalidArgument {
233 message: format!("unexpected key in Decoder state_dict: \"{k}\""),
234 });
235 }
236 }
237 }
238
239 self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
240 self.mid_block
241 .load_state_dict(&extract("mid_block"), strict)?;
242 for (i, b) in self.up_blocks.iter_mut().enumerate() {
243 b.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
244 }
245 self.conv_norm_out
246 .load_state_dict(&extract("conv_norm_out"), strict)?;
247 self.conv_out
248 .load_state_dict(&extract("conv_out"), strict)?;
249 Ok(())
250 }
251}
252
253#[derive(Debug)]
261pub struct VaeDecoder<T: Float> {
262 pub post_quant_conv: Conv2d<T>,
264 pub decoder: Decoder<T>,
266 pub config: VaeDecoderConfig,
268 training: bool,
269}
270
271impl<T: Float> VaeDecoder<T> {
272 pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
278 cfg.validate()?;
279 let post_quant_conv = Conv2d::<T>::new(
280 cfg.latent_channels,
281 cfg.latent_channels,
282 (1, 1),
283 (1, 1),
284 (0, 0),
285 true,
286 )?;
287 let decoder = Decoder::<T>::new(cfg.clone())?;
288 Ok(Self {
289 post_quant_conv,
290 decoder,
291 config: cfg,
292 training: false,
293 })
294 }
295
296 pub fn decode_with_scaling(&self, latent: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
304 let inv = self.config.scaling_factor.recip();
305 let inv_t = T::from(inv).ok_or_else(|| FerrotorchError::InvalidArgument {
306 message: format!(
307 "VaeDecoder::decode_with_scaling: cannot cast 1/{} into Float",
308 self.config.scaling_factor
309 ),
310 })?;
311 let inv_tensor = ferrotorch_core::scalar::<T>(inv_t)?;
312 let scaled = ferrotorch_core::grad_fns::arithmetic::mul(latent, &inv_tensor)?;
313 self.forward(&scaled)
314 }
315}
316
317impl<T: Float> Module<T> for VaeDecoder<T> {
318 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
321 let cfg = &self.config;
322 if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
323 return Err(FerrotorchError::ShapeMismatch {
324 message: format!(
325 "VaeDecoder::forward: expected [B, {}, H, W], got {:?}",
326 cfg.latent_channels,
327 input.shape()
328 ),
329 });
330 }
331 let post = self.post_quant_conv.forward(input)?;
332 self.decoder.forward(&post)
333 }
334
335 fn parameters(&self) -> Vec<&Parameter<T>> {
336 let mut out = Vec::new();
337 out.extend(self.post_quant_conv.parameters());
338 out.extend(self.decoder.parameters());
339 out
340 }
341
342 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
343 let mut out = Vec::new();
344 out.extend(self.post_quant_conv.parameters_mut());
345 out.extend(self.decoder.parameters_mut());
346 out
347 }
348
349 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
350 let mut out = Vec::new();
351 for (n, p) in self.post_quant_conv.named_parameters() {
352 out.push((format!("post_quant_conv.{n}"), p));
353 }
354 for (n, p) in self.decoder.named_parameters() {
355 out.push((format!("decoder.{n}"), p));
356 }
357 out
358 }
359
360 fn train(&mut self) {
361 self.training = true;
362 self.decoder.train();
363 }
364 fn eval(&mut self) {
365 self.training = false;
366 self.decoder.eval();
367 }
368 fn is_training(&self) -> bool {
369 self.training
370 }
371
372 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
373 let extract = |prefix: &str| -> StateDict<T> {
374 let p = format!("{prefix}.");
375 state
376 .iter()
377 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
378 .collect()
379 };
380 if strict {
381 for k in state.keys() {
382 let ok = k.starts_with("post_quant_conv.") || k.starts_with("decoder.");
383 if !ok {
384 return Err(FerrotorchError::InvalidArgument {
385 message: format!("unexpected key in VaeDecoder state_dict: \"{k}\""),
386 });
387 }
388 }
389 }
390 self.post_quant_conv
391 .load_state_dict(&extract("post_quant_conv"), strict)?;
392 self.decoder.load_state_dict(&extract("decoder"), strict)?;
393 let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Small configuration with four stages, so the tests stay fast.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Wraps four values into a CPU tensor of shape `[1, 4, 1, 1]`.
    fn latent(values: Vec<f32>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(values), vec![1, 4, 1, 1], false).unwrap()
    }

    #[test]
    fn decoder_forward_shape() {
        let decoder = Decoder::<f32>::new(tiny_cfg()).unwrap();
        let input = latent(vec![0.01f32; 4]);

        let y = decoder.forward(&input).unwrap();

        assert_eq!(y.shape(), &[1, 3, 8, 8]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "decoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_decoder_named_parameters_include_post_quant_conv() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = vae
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();

        let expected = [
            "post_quant_conv.weight",
            "post_quant_conv.bias",
            "decoder.conv_in.weight",
            "decoder.mid_block.attentions.0.to_q.weight",
            "decoder.up_blocks.0.resnets.0.norm1.weight",
            "decoder.conv_norm_out.weight",
            "decoder.conv_out.bias",
        ];
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn vae_decoder_forward_shape() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let input = latent(vec![0.01f32; 4]);

        let y = vae.forward(&input).unwrap();

        assert_eq!(y.shape(), &[1, 3, 8, 8]);
    }

    #[test]
    fn vae_decoder_decode_with_scaling_matches_manual_div() {
        let cfg = tiny_cfg();
        let inv = (1.0 / cfg.scaling_factor) as f32;
        let vae = VaeDecoder::<f32>::new(cfg).unwrap();

        // Scale the raw values by hand and compare against the built-in path.
        let raw = latent(vec![0.05f32; 4]);
        let scaled_data: Vec<f32> = raw.data().unwrap().iter().map(|&v| v * inv).collect();
        let scaled = latent(scaled_data);

        let a = vae.decode_with_scaling(&raw).unwrap();
        let b = vae.forward(&scaled).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!(
                (x - y).abs() < 1e-4,
                "decode_with_scaling vs manual div differ: {x} vs {y}"
            );
        }
    }

    #[test]
    fn round_trip_state_dict() {
        let src = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let mut dst = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        dst.load_state_dict(&src.state_dict(), true).unwrap();

        let input = latent(vec![0.01f32; 4]);
        let a = src.forward(&input).unwrap();
        let b = dst.forward(&input).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }
}