1use std::collections::HashMap;
30
31use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
32use ferrotorch_nn::module::{Module, StateDict};
33use ferrotorch_nn::parameter::Parameter;
34use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
35
36use crate::blocks::{UNetMidBlock2D, UpDecoderBlock2D};
37use crate::config::VaeDecoderConfig;
38
39#[derive(Debug)]
41pub struct Decoder<T: Float> {
42 pub conv_in: Conv2d<T>,
44 pub mid_block: UNetMidBlock2D<T>,
46 pub up_blocks: Vec<UpDecoderBlock2D<T>>,
49 pub conv_norm_out: GroupNorm<T>,
52 pub conv_act: SiLU,
54 pub conv_out: Conv2d<T>,
56 pub config: VaeDecoderConfig,
58 training: bool,
59}
60
61impl<T: Float> Decoder<T> {
62 pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
72 cfg.validate()?;
73 let groups = cfg.norm_num_groups;
74 let resnet_eps = 1e-6_f64;
75 let top_channels =
76 *cfg.block_out_channels
77 .last()
78 .ok_or_else(|| FerrotorchError::InvalidArgument {
79 message: "Decoder::new: block_out_channels is empty (should be unreachable \
80 after validate)"
81 .into(),
82 })?;
83
84 let conv_in = Conv2d::<T>::new(
85 cfg.latent_channels,
86 top_channels,
87 (3, 3),
88 (1, 1),
89 (1, 1),
90 true,
91 )?;
92
93 let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
94
95 let reversed: Vec<usize> = cfg.block_out_channels.iter().rev().copied().collect();
96 let mut up_blocks = Vec::with_capacity(reversed.len());
97 let mut prev_out = reversed[0];
98 let num_blocks = reversed.len();
99 let resnets = cfg.resnets_per_up_block();
100 for (i, &c) in reversed.iter().enumerate() {
101 let is_final = i == num_blocks - 1;
102 up_blocks.push(UpDecoderBlock2D::<T>::new(
103 prev_out, c, resnets, groups, resnet_eps, !is_final,
104 )?);
105 prev_out = c;
106 }
107
108 let bottom_channels = cfg.block_out_channels[0];
109 let conv_norm_out = GroupNorm::<T>::new(groups, bottom_channels, resnet_eps, true)?;
110 let conv_out = Conv2d::<T>::new(
111 bottom_channels,
112 cfg.out_channels,
113 (3, 3),
114 (1, 1),
115 (1, 1),
116 true,
117 )?;
118
119 Ok(Self {
120 conv_in,
121 mid_block,
122 up_blocks,
123 conv_norm_out,
124 conv_act: SiLU::new(),
125 conv_out,
126 config: cfg,
127 training: false,
128 })
129 }
130}
131
132impl<T: Float> Module<T> for Decoder<T> {
133 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
134 let cfg = &self.config;
136 if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
137 return Err(FerrotorchError::ShapeMismatch {
138 message: format!(
139 "Decoder::forward: expected [B, {}, H, W], got {:?}",
140 cfg.latent_channels,
141 input.shape()
142 ),
143 });
144 }
145 let mut h = self.conv_in.forward(input)?;
146 h = self.mid_block.forward(&h)?;
147 for up in &self.up_blocks {
148 h = up.forward(&h)?;
149 }
150 h = self.conv_norm_out.forward(&h)?;
151 h = self.conv_act.forward(&h)?;
152 self.conv_out.forward(&h)
153 }
154
155 fn parameters(&self) -> Vec<&Parameter<T>> {
156 let mut out = Vec::new();
157 out.extend(self.conv_in.parameters());
158 out.extend(self.mid_block.parameters());
159 for b in &self.up_blocks {
160 out.extend(b.parameters());
161 }
162 out.extend(self.conv_norm_out.parameters());
163 out.extend(self.conv_out.parameters());
164 out
165 }
166
167 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
168 let mut out = Vec::new();
169 out.extend(self.conv_in.parameters_mut());
170 out.extend(self.mid_block.parameters_mut());
171 for b in &mut self.up_blocks {
172 out.extend(b.parameters_mut());
173 }
174 out.extend(self.conv_norm_out.parameters_mut());
175 out.extend(self.conv_out.parameters_mut());
176 out
177 }
178
179 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
180 let mut out = Vec::new();
181 for (n, p) in self.conv_in.named_parameters() {
182 out.push((format!("conv_in.{n}"), p));
183 }
184 for (n, p) in self.mid_block.named_parameters() {
185 out.push((format!("mid_block.{n}"), p));
186 }
187 for (i, b) in self.up_blocks.iter().enumerate() {
188 for (n, p) in b.named_parameters() {
189 out.push((format!("up_blocks.{i}.{n}"), p));
190 }
191 }
192 for (n, p) in self.conv_norm_out.named_parameters() {
193 out.push((format!("conv_norm_out.{n}"), p));
194 }
195 for (n, p) in self.conv_out.named_parameters() {
196 out.push((format!("conv_out.{n}"), p));
197 }
198 out
199 }
200
201 fn train(&mut self) {
202 self.training = true;
203 for b in &mut self.up_blocks {
204 b.train();
205 }
206 self.mid_block.train();
207 }
208 fn eval(&mut self) {
209 self.training = false;
210 for b in &mut self.up_blocks {
211 b.eval();
212 }
213 self.mid_block.eval();
214 }
215 fn is_training(&self) -> bool {
216 self.training
217 }
218
219 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
220 let extract = |prefix: &str| -> StateDict<T> {
221 let p = format!("{prefix}.");
222 state
223 .iter()
224 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
225 .collect()
226 };
227
228 if strict {
229 for k in state.keys() {
230 let ok = k.starts_with("conv_in.")
231 || k.starts_with("mid_block.")
232 || k.starts_with("up_blocks.")
233 || k.starts_with("conv_norm_out.")
234 || k.starts_with("conv_out.");
235 if !ok {
236 return Err(FerrotorchError::InvalidArgument {
237 message: format!("unexpected key in Decoder state_dict: \"{k}\""),
238 });
239 }
240 }
241 }
242
243 self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
244 self.mid_block
245 .load_state_dict(&extract("mid_block"), strict)?;
246 for (i, b) in self.up_blocks.iter_mut().enumerate() {
247 b.load_state_dict(&extract(&format!("up_blocks.{i}")), strict)?;
248 }
249 self.conv_norm_out
250 .load_state_dict(&extract("conv_norm_out"), strict)?;
251 self.conv_out
252 .load_state_dict(&extract("conv_out"), strict)?;
253 Ok(())
254 }
255}
256
257#[derive(Debug)]
265pub struct VaeDecoder<T: Float> {
266 pub post_quant_conv: Conv2d<T>,
268 pub decoder: Decoder<T>,
270 pub config: VaeDecoderConfig,
272 training: bool,
273}
274
275impl<T: Float> VaeDecoder<T> {
276 pub fn new(cfg: VaeDecoderConfig) -> FerrotorchResult<Self> {
282 cfg.validate()?;
283 let post_quant_conv = Conv2d::<T>::new(
284 cfg.latent_channels,
285 cfg.latent_channels,
286 (1, 1),
287 (1, 1),
288 (0, 0),
289 true,
290 )?;
291 let decoder = Decoder::<T>::new(cfg.clone())?;
292 Ok(Self {
293 post_quant_conv,
294 decoder,
295 config: cfg,
296 training: false,
297 })
298 }
299
300 pub fn decode_with_scaling(&self, latent: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
308 let inv = self.config.scaling_factor.recip();
309 let inv_t = T::from(inv).ok_or_else(|| FerrotorchError::InvalidArgument {
310 message: format!(
311 "VaeDecoder::decode_with_scaling: cannot cast 1/{} into Float",
312 self.config.scaling_factor
313 ),
314 })?;
315 let inv_tensor = ferrotorch_core::scalar::<T>(inv_t)?;
316 let scaled = ferrotorch_core::grad_fns::arithmetic::mul(latent, &inv_tensor)?;
317 self.forward(&scaled)
318 }
319}
320
321impl<T: Float> Module<T> for VaeDecoder<T> {
322 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
325 let cfg = &self.config;
326 if input.ndim() != 4 || input.shape()[1] != cfg.latent_channels {
327 return Err(FerrotorchError::ShapeMismatch {
328 message: format!(
329 "VaeDecoder::forward: expected [B, {}, H, W], got {:?}",
330 cfg.latent_channels,
331 input.shape()
332 ),
333 });
334 }
335 let post = self.post_quant_conv.forward(input)?;
336 self.decoder.forward(&post)
337 }
338
339 fn parameters(&self) -> Vec<&Parameter<T>> {
340 let mut out = Vec::new();
341 out.extend(self.post_quant_conv.parameters());
342 out.extend(self.decoder.parameters());
343 out
344 }
345
346 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
347 let mut out = Vec::new();
348 out.extend(self.post_quant_conv.parameters_mut());
349 out.extend(self.decoder.parameters_mut());
350 out
351 }
352
353 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
354 let mut out = Vec::new();
355 for (n, p) in self.post_quant_conv.named_parameters() {
356 out.push((format!("post_quant_conv.{n}"), p));
357 }
358 for (n, p) in self.decoder.named_parameters() {
359 out.push((format!("decoder.{n}"), p));
360 }
361 out
362 }
363
364 fn train(&mut self) {
365 self.training = true;
366 self.decoder.train();
367 }
368 fn eval(&mut self) {
369 self.training = false;
370 self.decoder.eval();
371 }
372 fn is_training(&self) -> bool {
373 self.training
374 }
375
376 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
377 let extract = |prefix: &str| -> StateDict<T> {
378 let p = format!("{prefix}.");
379 state
380 .iter()
381 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
382 .collect()
383 };
384 if strict {
385 for k in state.keys() {
386 let ok = k.starts_with("post_quant_conv.") || k.starts_with("decoder.");
387 if !ok {
388 return Err(FerrotorchError::InvalidArgument {
389 message: format!("unexpected key in VaeDecoder state_dict: \"{k}\""),
390 });
391 }
392 }
393 }
394 self.post_quant_conv
395 .load_state_dict(&extract("post_quant_conv"), strict)?;
396 self.decoder.load_state_dict(&extract("decoder"), strict)?;
397 let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Small config: a 4-channel latent, four up blocks, 3-channel output.
    fn tiny_cfg() -> VaeDecoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// A `[1, 4, 1, 1]` latent whose four elements all equal `fill`.
    fn filled_latent(fill: f32) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(vec![fill; 4]), vec![1, 4, 1, 1], false).unwrap()
    }

    /// Assert `lhs` and `rhs` agree element-wise within `tol`.
    ///
    /// On failure the message reads `"<what>: <lhs elem> vs <rhs elem>"`.
    fn assert_close(lhs: &Tensor<f32>, rhs: &Tensor<f32>, tol: f32, what: &str) {
        let left = lhs.data().unwrap();
        let right = rhs.data().unwrap();
        for (a, b) in left.iter().zip(right.iter()) {
            assert!((a - b).abs() < tol, "{what}: {a} vs {b}");
        }
    }

    #[test]
    fn decoder_forward_shape() {
        let decoder = Decoder::<f32>::new(tiny_cfg()).unwrap();
        let out = decoder.forward(&filled_latent(0.01)).unwrap();
        // Four up blocks take the 1x1 latent up to 8x8.
        assert_eq!(out.shape(), &[1, 3, 8, 8]);
        for v in out.data().unwrap().iter() {
            assert!(v.is_finite(), "decoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_decoder_named_parameters_include_post_quant_conv() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = vae
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        let expected = [
            "post_quant_conv.weight",
            "post_quant_conv.bias",
            "decoder.conv_in.weight",
            "decoder.mid_block.attentions.0.to_q.weight",
            "decoder.up_blocks.0.resnets.0.norm1.weight",
            "decoder.conv_norm_out.weight",
            "decoder.conv_out.bias",
        ];
        for k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn vae_decoder_forward_shape() {
        let vae = VaeDecoder::<f32>::new(tiny_cfg()).unwrap();
        let out = vae.forward(&filled_latent(0.01)).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8, 8]);
    }

    #[test]
    fn vae_decoder_decode_with_scaling_matches_manual_div() {
        let cfg = tiny_cfg();
        let vae = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let latent = filled_latent(0.05);

        // Reference path: pre-scale the latent by hand, then call `forward`.
        let inv = (1.0 / cfg.scaling_factor) as f32;
        let manual: Vec<f32> = latent.data().unwrap().iter().map(|&e| e * inv).collect();
        let manual_latent =
            Tensor::from_storage(TensorStorage::cpu(manual), vec![1, 4, 1, 1], false).unwrap();

        let via_helper = vae.decode_with_scaling(&latent).unwrap();
        let via_forward = vae.forward(&manual_latent).unwrap();
        assert_close(
            &via_helper,
            &via_forward,
            1e-4,
            "decode_with_scaling vs manual div differ",
        );
    }

    #[test]
    fn round_trip_state_dict() {
        let cfg = tiny_cfg();
        let source = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        let snapshot = source.state_dict();
        let mut target = VaeDecoder::<f32>::new(cfg.clone()).unwrap();
        target.load_state_dict(&snapshot, true).unwrap();

        let latent = filled_latent(0.01);
        let expected = source.forward(&latent).unwrap();
        let actual = target.forward(&latent).unwrap();
        assert_close(&expected, &actual, 1e-5, "round-trip differs");
    }
}