1use std::collections::HashMap;
20
21use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
22use ferrotorch_nn::module::{Module, StateDict};
23use ferrotorch_nn::parameter::Parameter;
24use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
25
26use crate::blocks::{UNetMidBlock2D, UpDecoderBlock2D};
27use crate::config::VaeDecoderConfig;
28
/// Convolutional decoder stack.
///
/// `forward` applies the fields in the order `conv_in`, `mid_block`, every
/// entry of `up_blocks`, `conv_norm_out`, `conv_act`, `conv_out`.
#[derive(Debug)]
pub struct Decoder<T: Float> {
    /// Input convolution from `config.latent_channels` to the last entry of
    /// `config.block_out_channels`.
    pub conv_in: Conv2d<T>,
    /// Mid block, built at the same channel width as the output of `conv_in`.
    pub mid_block: UNetMidBlock2D<T>,
    /// One block per entry of `config.block_out_channels`, visited in reverse
    /// order; each block's input width is the previous block's output width.
    pub up_blocks: Vec<UpDecoderBlock2D<T>>,
    /// Group normalisation over `config.block_out_channels[0]` channels,
    /// applied after the last up block.
    pub conv_norm_out: GroupNorm<T>,
    /// Activation applied between `conv_norm_out` and `conv_out`.
    pub conv_act: SiLU,
    /// Output convolution from `config.block_out_channels[0]` to
    /// `config.out_channels`.
    pub conv_out: Conv2d<T>,
    /// Configuration this decoder was built from.
    pub config: VaeDecoderConfig,
    /// Mode flag set by `train`/`eval` and reported by `is_training`;
    /// `new` initialises it to `false`.
    training: bool,
}
50
51impl<T: Float> Decoder<T> {
52 pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
62 cfg.validate()?;
63 let groups = cfg.norm_num_groups;
64 let resnet_eps = 1e-6_f64;
65 let top_channels =
66 *cfg.block_out_channels
67 .last()
68 .ok_or_else(|| FerrotorchError::InvalidArgument {
69 message: "Decoder::new: block_out_channels is empty (should be unreachable \
70 after validate)"
71 .into(),
72 })?;
73
74 let conv_in = Conv2d::<T>::new(
75 cfg.latent_channels,
76 top_channels,
77 (3, 3),
78 (1, 1),
79 (1, 1),
80 true,
81 )?;
82
83 let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
84
85 let reversed: Vec<usize> = cfg.block_out_channels.iter().rev().copied().collect();
86 let mut up_blocks = Vec::with_capacity(reversed.len());
87 let mut prev_out = reversed[0];
88 let num_blocks = reversed.len();
89 let resnets = cfg.resnets_per_up_block();
90 for (i, &c) in reversed.iter().enumerate() {
91 let is_final = i == num_blocks - 1;
92 up_blocks.push(UpDecoderBlock2D::<T>::new(
93 prev_out, c, resnets, groups, resnet_eps, !is_final,
94 )?);
95 prev_out = c;
96 }
97
98 let bottom_channels = cfg.block_out_channels[0];
99 let conv_norm_out = GroupNorm::<T>::new(groups, bottom_channels, resnet_eps, true)?;
100 let conv_out = Conv2d::<T>::new(
101 bottom_channels,
102 cfg.out_channels,
103 (3, 3),
104 (1, 1),
105 (1, 1),
106 true,
107 )?;
108
109 Ok(Self {
110 conv_in,
111 mid_block,
112 up_blocks,
113 conv_norm_out,
114 conv_act: SiLU::new(),
115 conv_out,
116 config: cfg,
117 training: false,
118 })
119 }
120}
121
122impl<T: Float> Module<T> for Decoder<T> {
123 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
124 let cfg = &self.config;
126 if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
127 return Err(FerrotorchError::ShapeMismatch {
128 message: format!(
129 "Decoder::forward: expected [B, {}, H, W], got {:?}",
130 cfg.latent_channels,
131 input.shape()
132 ),
133 });
134 }
135 let mut h = self.conv_in.forward(input)?;
136 h = self.mid_block.forward(&h)?;
137 for up in &self.up_blocks {
138 h = up.forward(&h)?;
139 }
140 h = self.conv_norm_out.forward(&h)?;
141 h = self.conv_act.forward(&h)?;
142 self.conv_out.forward(&h)
143 }
144
145 fn parameters(&self) -> Vec<&Parameter<T>> {
146 let mut out = Vec::new();
147 out.extend(self.conv_in.parameters());
148 out.extend(self.mid_block.parameters());
149 for b in &self.up_blocks {
150 out.extend(b.parameters());
151 }
152 out.extend(self.conv_norm_out.parameters());
153 out.extend(self.conv_out.parameters());
154 out
155 }
156
157 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
158 let mut out = Vec::new();
159 out.extend(self.conv_in.parameters_mut());
160 out.extend(self.mid_block.parameters_mut());
161 for b in &mut self.up_blocks {
162 out.extend(b.parameters_mut());
163 }
164 out.extend(self.conv_norm_out.parameters_mut());
165 out.extend(self.conv_out.parameters_mut());
166 out
167 }
168
169 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
170 let mut out = Vec::new();
171 for (n, p) in self.conv_in.named_parameters() {
172 out.push((format!("conv_in.{n}"), p));
173 }
174 for (n, p) in self.mid_block.named_parameters() {
175 out.push((format!("mid_block.{n}"), p));
176 }
177 for (i, b) in self.up_blocks.iter().enumerate() {
178 for (n, p) in b.named_parameters() {
179 out.push((format!("up_blocks.{i}.{n}"), p));
180 }
181 }
182 for (n, p) in self.conv_norm_out.named_parameters() {
183 out.push((format!("conv_norm_out.{n}"), p));
184 }
185 for (n, p) in self.conv_out.named_parameters() {
186 out.push((format!("conv_out.{n}"), p));
187 }
188 out
189 }
190
191 fn train(&mut self) {
192 self.training = true;
193 for b in &mut self.up_blocks {
194 b.train();
195 }
196 self.mid_block.train();
197 }
198 fn eval(&mut self) {
199 self.training = false;
200 for b in &mut self.up_blocks {
201 b.eval();
202 }
203 self.mid_block.eval();
204 }
205 fn is_training(&self) -> bool {
206 self.training
207 }
208
209 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
210 let extract = |prefix: &str| -> StateDict<T> {
211 let p = format!("{prefix}.");
212 state
213 .iter()
214 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
215 .collect()
216 };
217
218 if strict {
219 for k in state.keys() {
220 let ok = k.starts_with("conv_in.")
221 || k.starts_with("mid_block.")
222 || k.starts_with("up_blocks.")
223 || k.starts_with("conv_norm_out.")
224 || k.starts_with("conv_out.");
225 if !ok {
226 return Err(FerrotorchError::InvalidArgument {
227 message: format!("unexpected key in Decoder state_dict: \"{k}\""),
228 });
229 }
230 }
231 }
232
233 self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
234 self.mid_block
235 .load_state_dict(&extract("mid_block"), strict)?;
236 for (i, b) in self.up_blocks.iter_mut().enumerate() {
237 b.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
238 }
239 self.conv_norm_out
240 .load_state_dict(&extract("conv_norm_out"), strict)?;
241 self.conv_out
242 .load_state_dict(&extract("conv_out"), strict)?;
243 Ok(())
244 }
245}
246
/// VAE decoder: `post_quant_conv` followed by the inner [`Decoder`].
#[derive(Debug)]
pub struct VaeDecoder<T: Float> {
    /// Convolution from `config.latent_channels` to the same number of
    /// channels, applied to the input before `decoder`.
    pub post_quant_conv: Conv2d<T>,
    /// Decoder stack, built from a clone of `config`.
    pub decoder: Decoder<T>,
    /// Configuration this module was built from; `decode_with_scaling` reads
    /// `scaling_factor` from it.
    pub config: VaeDecoderConfig,
    /// Mode flag set by `train`/`eval` and reported by `is_training`;
    /// `new` initialises it to `false`.
    training: bool,
}
264
265impl<T: Float> VaeDecoder<T> {
266 pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
272 cfg.validate()?;
273 let post_quant_conv = Conv2d::<T>::new(
274 cfg.latent_channels,
275 cfg.latent_channels,
276 (1, 1),
277 (1, 1),
278 (0, 0),
279 true,
280 )?;
281 let decoder = Decoder::<T>::new(cfg.clone())?;
282 Ok(Self {
283 post_quant_conv,
284 decoder,
285 config: cfg,
286 training: false,
287 })
288 }
289
290 pub fn decode_with_scaling(&self, latent: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
298 let inv = self.config.scaling_factor.recip();
299 let inv_t = T::from(inv).ok_or_else(|| FerrotorchError::InvalidArgument {
300 message: format!(
301 "VaeDecoder::decode_with_scaling: cannot cast 1/{} into Float",
302 self.config.scaling_factor
303 ),
304 })?;
305 let inv_tensor = ferrotorch_core::scalar::<T>(inv_t)?;
306 let scaled = ferrotorch_core::grad_fns::arithmetic::mul(latent, &inv_tensor)?;
307 self.forward(&scaled)
308 }
309}
310
311impl<T: Float> Module<T> for VaeDecoder<T> {
312 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
315 let cfg = &self.config;
316 if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
317 return Err(FerrotorchError::ShapeMismatch {
318 message: format!(
319 "VaeDecoder::forward: expected [B, {}, H, W], got {:?}",
320 cfg.latent_channels,
321 input.shape()
322 ),
323 });
324 }
325 let post = self.post_quant_conv.forward(input)?;
326 self.decoder.forward(&post)
327 }
328
329 fn parameters(&self) -> Vec<&Parameter<T>> {
330 let mut out = Vec::new();
331 out.extend(self.post_quant_conv.parameters());
332 out.extend(self.decoder.parameters());
333 out
334 }
335
336 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
337 let mut out = Vec::new();
338 out.extend(self.post_quant_conv.parameters_mut());
339 out.extend(self.decoder.parameters_mut());
340 out
341 }
342
343 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
344 let mut out = Vec::new();
345 for (n, p) in self.post_quant_conv.named_parameters() {
346 out.push((format!("post_quant_conv.{n}"), p));
347 }
348 for (n, p) in self.decoder.named_parameters() {
349 out.push((format!("decoder.{n}"), p));
350 }
351 out
352 }
353
354 fn train(&mut self) {
355 self.training = true;
356 self.decoder.train();
357 }
358 fn eval(&mut self) {
359 self.training = false;
360 self.decoder.eval();
361 }
362 fn is_training(&self) -> bool {
363 self.training
364 }
365
366 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
367 let extract = |prefix: &str| -> StateDict<T> {
368 let p = format!("{prefix}.");
369 state
370 .iter()
371 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
372 .collect()
373 };
374 if strict {
375 for k in state.keys() {
376 let ok = k.starts_with("post_quant_conv.") || k.starts_with("decoder.");
377 if !ok {
378 return Err(FerrotorchError::InvalidArgument {
379 message: format!("unexpected key in VaeDecoder state_dict: \"{k}\""),
380 });
381 }
382 }
383 }
384 self.post_quant_conv
385 .load_state_dict(&extract("post_quant_conv"), strict)?;
386 self.decoder.load_state_dict(&extract("decoder"), strict)?;
387 let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Small four-block configuration shared by every test; the shape tests
    /// expect a 1×1 latent to decode to an 8×8 output with it.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// CPU tensor of shape `[1, 4, 1, 1]` holding `values`.
    fn latent_from(values: Vec<f32>) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(values), vec![1, 4, 1, 1], false).unwrap()
    }

    /// CPU tensor of shape `[1, 4, 1, 1]` with every element equal to `fill`.
    fn constant_latent(fill: f32) -> Tensor<f32> {
        latent_from(vec![fill; 4])
    }

    #[test]
    fn decoder_forward_shape() {
        let decoder = Decoder::<f32>::new(tiny_cfg()).unwrap();
        let output = decoder.forward(&constant_latent(0.01)).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8, 8]);
        for &v in output.data().unwrap() {
            assert!(v.is_finite(), "decoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_decoder_named_parameters_include_post_quant_conv() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = vae
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        let expected = [
            "post_quant_conv.weight",
            "post_quant_conv.bias",
            "decoder.conv_in.weight",
            "decoder.mid_block.attentions.0.to_q.weight",
            "decoder.up_blocks.0.resnets.0.norm1.weight",
            "decoder.conv_norm_out.weight",
            "decoder.conv_out.bias",
        ];
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn vae_decoder_forward_shape() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let output = vae.forward(&constant_latent(0.01)).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8, 8]);
    }

    #[test]
    fn vae_decoder_decode_with_scaling_matches_manual_div() {
        let cfg = tiny_cfg();
        let inv = (1.0 / cfg.scaling_factor) as f32;
        let vae = VaeDecoder::<f32>::new(cfg).unwrap();

        let raw = constant_latent(0.05);
        // Reference input: the same latent pre-multiplied by 1/scaling_factor.
        let manually_scaled = latent_from(raw.data().unwrap().iter().map(|&v| v * inv).collect());

        let via_helper = vae.decode_with_scaling(&raw).unwrap();
        let via_forward = vae.forward(&manually_scaled).unwrap();
        for (x, y) in via_helper
            .data()
            .unwrap()
            .iter()
            .zip(via_forward.data().unwrap().iter())
        {
            assert!(
                (x - y).abs() < 1e-4,
                "decode_with_scaling vs manual div differ: {x} vs {y}"
            );
        }
    }

    #[test]
    fn round_trip_state_dict() {
        let src = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let snapshot = src.state_dict();
        let mut dst = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        dst.load_state_dict(&snapshot, true).unwrap();

        let probe = constant_latent(0.01);
        let from_src = src.forward(&probe).unwrap();
        let from_dst = dst.forward(&probe).unwrap();
        for (x, y) in from_src
            .data()
            .unwrap()
            .iter()
            .zip(from_dst.data().unwrap().iter())
        {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }
}