1use std::collections::HashMap;
26
27use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
28use ferrotorch_nn::module::{Module, StateDict};
29use ferrotorch_nn::parameter::Parameter;
30use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
31
32use crate::blocks::{DownEncoderBlock2D, UNetMidBlock2D};
33use crate::config::VaeDecoderConfig;
34
/// Configuration type for the encoder side of the VAE.
///
/// This is a plain alias of [`VaeDecoderConfig`]; the encoder in this file is built from the
/// same fields (`out_channels`, `latent_channels`, `block_out_channels`, ...).
pub type VaeEncoderConfig = VaeDecoderConfig;

/// Lower bound applied to the log-variance in `DiagonalGaussianDistribution::from_parameters`.
const LOGVAR_CLAMP_MIN: f64 = -30.0;
/// Upper bound applied to the log-variance in `DiagonalGaussianDistribution::from_parameters`.
const LOGVAR_CLAMP_MAX: f64 = 20.0;
49
/// Convolutional encoder trunk.
///
/// `forward` applies, in order: `conv_in`, every entry of `down_blocks`, `mid_block`,
/// `conv_norm_out`, `conv_act` and finally `conv_out`.
#[derive(Debug)]
pub struct Encoder<T: Float> {
    /// Input convolution from `config.out_channels` channels to `block_out_channels[0]` channels.
    pub conv_in: Conv2d<T>,
    /// One block per entry of `config.block_out_channels`, applied in order. Every block except
    /// the last is constructed with its final boolean flag set (presumably "add downsampler" —
    /// confirm against `DownEncoderBlock2D::new`).
    pub down_blocks: Vec<DownEncoderBlock2D<T>>,
    /// Mid block operating on the last entry of `block_out_channels`.
    pub mid_block: UNetMidBlock2D<T>,
    /// Group normalisation applied after the mid block.
    pub conv_norm_out: GroupNorm<T>,
    /// Activation applied between `conv_norm_out` and `conv_out`.
    pub conv_act: SiLU,
    /// Output convolution producing `2 * config.latent_channels` channels.
    pub conv_out: Conv2d<T>,
    /// The configuration this encoder was built from.
    pub config: VaeEncoderConfig,
    /// Training flag toggled by `train()` / `eval()` and reported by `is_training()`.
    training: bool,
}
80
81impl<T: Float> Encoder<T> {
82 pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
92 cfg.validate()?;
93 let groups = cfg.norm_num_groups;
94 let resnet_eps = 1e-6_f64;
95 let bottom_channels = cfg.block_out_channels[0];
96 let top_channels =
97 *cfg.block_out_channels
98 .last()
99 .ok_or_else(|| FerrotorchError::InvalidArgument {
100 message: "Encoder::new: block_out_channels is empty (should be \
101 unreachable after validate)"
102 .into(),
103 })?;
104
105 let conv_in = Conv2d::<T>::new(
107 cfg.out_channels,
108 bottom_channels,
109 (3, 3),
110 (1, 1),
111 (1, 1),
112 true,
113 )?;
114
115 let num_blocks = cfg.block_out_channels.len();
118 let resnets = cfg.layers_per_block;
119 let mut down_blocks = Vec::with_capacity(num_blocks);
120 let mut prev_out = bottom_channels;
121 for (i, &c) in cfg.block_out_channels.iter().enumerate() {
122 let is_final = i == num_blocks - 1;
123 down_blocks.push(DownEncoderBlock2D::<T>::new(
124 prev_out, c, resnets, groups, resnet_eps, !is_final,
125 )?);
126 prev_out = c;
127 }
128
129 let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
131
132 let conv_norm_out = GroupNorm::<T>::new(groups, top_channels, resnet_eps, true)?;
133 let conv_out = Conv2d::<T>::new(
135 top_channels,
136 2 * cfg.latent_channels,
137 (3, 3),
138 (1, 1),
139 (1, 1),
140 true,
141 )?;
142
143 Ok(Self {
144 conv_in,
145 down_blocks,
146 mid_block,
147 conv_norm_out,
148 conv_act: SiLU::new(),
149 conv_out,
150 config: cfg,
151 training: false,
152 })
153 }
154}
155
156impl<T: Float> Module<T> for Encoder<T> {
157 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
158 let cfg = &self.config;
161 if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
162 return Err(FerrotorchError::ShapeMismatch {
163 message: format!(
164 "Encoder::forward: expected [B, {}, H, W], got {:?}",
165 cfg.out_channels,
166 input.shape()
167 ),
168 });
169 }
170 let mut h = self.conv_in.forward(input)?;
171 for d in &self.down_blocks {
172 h = d.forward(&h)?;
173 }
174 h = self.mid_block.forward(&h)?;
175 h = self.conv_norm_out.forward(&h)?;
176 h = self.conv_act.forward(&h)?;
177 self.conv_out.forward(&h)
178 }
179
180 fn parameters(&self) -> Vec<&Parameter<T>> {
181 let mut out = Vec::new();
182 out.extend(self.conv_in.parameters());
183 for b in &self.down_blocks {
184 out.extend(b.parameters());
185 }
186 out.extend(self.mid_block.parameters());
187 out.extend(self.conv_norm_out.parameters());
188 out.extend(self.conv_out.parameters());
189 out
190 }
191
192 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
193 let mut out = Vec::new();
194 out.extend(self.conv_in.parameters_mut());
195 for b in &mut self.down_blocks {
196 out.extend(b.parameters_mut());
197 }
198 out.extend(self.mid_block.parameters_mut());
199 out.extend(self.conv_norm_out.parameters_mut());
200 out.extend(self.conv_out.parameters_mut());
201 out
202 }
203
204 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
205 let mut out = Vec::new();
206 for (n, p) in self.conv_in.named_parameters() {
207 out.push((format!("conv_in.{n}"), p));
208 }
209 for (i, b) in self.down_blocks.iter().enumerate() {
210 for (n, p) in b.named_parameters() {
211 out.push((format!("down_blocks.{i}.{n}"), p));
212 }
213 }
214 for (n, p) in self.mid_block.named_parameters() {
215 out.push((format!("mid_block.{n}"), p));
216 }
217 for (n, p) in self.conv_norm_out.named_parameters() {
218 out.push((format!("conv_norm_out.{n}"), p));
219 }
220 for (n, p) in self.conv_out.named_parameters() {
221 out.push((format!("conv_out.{n}"), p));
222 }
223 out
224 }
225
226 fn train(&mut self) {
227 self.training = true;
228 for b in &mut self.down_blocks {
229 b.train();
230 }
231 self.mid_block.train();
232 }
233 fn eval(&mut self) {
234 self.training = false;
235 for b in &mut self.down_blocks {
236 b.eval();
237 }
238 self.mid_block.eval();
239 }
240 fn is_training(&self) -> bool {
241 self.training
242 }
243
244 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
245 let extract = |prefix: &str| -> StateDict<T> {
246 let p = format!("{prefix}.");
247 state
248 .iter()
249 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
250 .collect()
251 };
252
253 if strict {
254 for k in state.keys() {
255 let ok = k.starts_with("conv_in.")
256 || k.starts_with("down_blocks.")
257 || k.starts_with("mid_block.")
258 || k.starts_with("conv_norm_out.")
259 || k.starts_with("conv_out.");
260 if !ok {
261 return Err(FerrotorchError::InvalidArgument {
262 message: format!("unexpected key in Encoder state_dict: \"{k}\""),
263 });
264 }
265 }
266 }
267
268 self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
269 for (i, b) in self.down_blocks.iter_mut().enumerate() {
270 b.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
271 }
272 self.mid_block
273 .load_state_dict(&extract("mid_block"), strict)?;
274 self.conv_norm_out
275 .load_state_dict(&extract("conv_norm_out"), strict)?;
276 self.conv_out
277 .load_state_dict(&extract("conv_out"), strict)?;
278 Ok(())
279 }
280}
281
/// VAE encoder: the [`Encoder`] trunk followed by `quant_conv`.
///
/// `forward` returns the raw output of `quant_conv`; `encode` splits that output into a
/// [`DiagonalGaussianDistribution`].
#[derive(Debug)]
pub struct VaeEncoder<T: Float> {
    /// Convolutional trunk producing `2 * config.latent_channels` channels.
    pub encoder: Encoder<T>,
    /// Convolution from `2 * config.latent_channels` to `2 * config.latent_channels` channels,
    /// applied to the trunk output.
    pub quant_conv: Conv2d<T>,
    /// The configuration this encoder was built from.
    pub config: VaeEncoderConfig,
    /// Training flag toggled by `train()` / `eval()` and reported by `is_training()`.
    training: bool,
}
298
299impl<T: Float> VaeEncoder<T> {
300 pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
306 cfg.validate()?;
307 let encoder = Encoder::<T>::new(cfg.clone())?;
308 let two_l = 2 * cfg.latent_channels;
309 let quant_conv = Conv2d::<T>::new(two_l, two_l, (1, 1), (1, 1), (0, 0), true)?;
310 Ok(Self {
311 encoder,
312 quant_conv,
313 config: cfg,
314 training: false,
315 })
316 }
317
318 pub fn encode(&self, image: &Tensor<T>) -> FerrotorchResult<DiagonalGaussianDistribution<T>> {
326 let params = self.forward(image)?;
327 DiagonalGaussianDistribution::from_parameters(¶ms, self.config.latent_channels)
328 }
329
330 pub fn encode_with_scaling(&self, image: &Tensor<T>, seed: u64) -> FerrotorchResult<Tensor<T>> {
349 let dist = self.encode(image)?;
350 let sample = dist.sample_with_seed(seed)?;
351 let sf = T::from(self.config.scaling_factor).ok_or_else(|| {
352 FerrotorchError::InvalidArgument {
353 message: format!(
354 "VaeEncoder::encode_with_scaling: cannot cast scaling_factor={} into Float",
355 self.config.scaling_factor
356 ),
357 }
358 })?;
359 let sf_t = ferrotorch_core::scalar::<T>(sf)?;
360 ferrotorch_core::grad_fns::arithmetic::mul(&sample, &sf_t)
361 }
362}
363
364impl<T: Float> Module<T> for VaeEncoder<T> {
365 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
370 let cfg = &self.config;
371 if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
372 return Err(FerrotorchError::ShapeMismatch {
373 message: format!(
374 "VaeEncoder::forward: expected [B, {}, H, W], got {:?}",
375 cfg.out_channels,
376 input.shape()
377 ),
378 });
379 }
380 let h = self.encoder.forward(input)?;
381 self.quant_conv.forward(&h)
382 }
383
384 fn parameters(&self) -> Vec<&Parameter<T>> {
385 let mut out = Vec::new();
386 out.extend(self.encoder.parameters());
387 out.extend(self.quant_conv.parameters());
388 out
389 }
390
391 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
392 let mut out = Vec::new();
393 out.extend(self.encoder.parameters_mut());
394 out.extend(self.quant_conv.parameters_mut());
395 out
396 }
397
398 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
399 let mut out = Vec::new();
400 for (n, p) in self.encoder.named_parameters() {
401 out.push((format!("encoder.{n}"), p));
402 }
403 for (n, p) in self.quant_conv.named_parameters() {
404 out.push((format!("quant_conv.{n}"), p));
405 }
406 out
407 }
408
409 fn train(&mut self) {
410 self.training = true;
411 self.encoder.train();
412 }
413 fn eval(&mut self) {
414 self.training = false;
415 self.encoder.eval();
416 }
417 fn is_training(&self) -> bool {
418 self.training
419 }
420
421 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
422 let extract = |prefix: &str| -> StateDict<T> {
423 let p = format!("{prefix}.");
424 state
425 .iter()
426 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
427 .collect()
428 };
429 if strict {
430 for k in state.keys() {
431 let ok = k.starts_with("encoder.") || k.starts_with("quant_conv.");
432 if !ok {
433 return Err(FerrotorchError::InvalidArgument {
434 message: format!("unexpected key in VaeEncoder state_dict: \"{k}\""),
435 });
436 }
437 }
438 }
439 self.encoder.load_state_dict(&extract("encoder"), strict)?;
440 self.quant_conv
441 .load_state_dict(&extract("quant_conv"), strict)?;
442 let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
444 }
445}
446
/// Diagonal Gaussian described by a mean tensor and a log-variance tensor.
///
/// Built by `from_parameters`, which splits a tensor in two along dimension 1.
#[derive(Debug)]
pub struct DiagonalGaussianDistribution<T: Float> {
    /// First half of the parameter tensor along dimension 1.
    pub mean: Tensor<T>,
    /// Second half of the parameter tensor along dimension 1, clamped to
    /// `[LOGVAR_CLAMP_MIN, LOGVAR_CLAMP_MAX]`.
    pub logvar: Tensor<T>,
}
461
462impl<T: Float> DiagonalGaussianDistribution<T> {
463 pub fn from_parameters(params: &Tensor<T>, latent_channels: usize) -> FerrotorchResult<Self> {
472 if params.ndim() != 4 || params.shape()[1] != 2 * latent_channels {
473 return Err(FerrotorchError::ShapeMismatch {
474 message: format!(
475 "DiagonalGaussianDistribution::from_parameters: expected [B, {}, H, W], \
476 got {:?}",
477 2 * latent_channels,
478 params.shape()
479 ),
480 });
481 }
482 let parts = params.chunk(2, 1)?;
484 let mean = parts[0].clone();
485 let logvar_raw = parts[1].clone();
486 let lo = T::from(LOGVAR_CLAMP_MIN).ok_or_else(|| FerrotorchError::InvalidArgument {
488 message: format!(
489 "DiagonalGaussianDistribution: cannot cast logvar clamp min {LOGVAR_CLAMP_MIN} \
490 into Float"
491 ),
492 })?;
493 let hi = T::from(LOGVAR_CLAMP_MAX).ok_or_else(|| FerrotorchError::InvalidArgument {
494 message: format!(
495 "DiagonalGaussianDistribution: cannot cast logvar clamp max {LOGVAR_CLAMP_MAX} \
496 into Float"
497 ),
498 })?;
499 let logvar = logvar_raw.clamp_t(lo, hi)?;
500 Ok(Self { mean, logvar })
501 }
502
503 pub fn mode(&self) -> &Tensor<T> {
507 &self.mean
508 }
509
510 pub fn sample_with_seed(&self, seed: u64) -> FerrotorchResult<Tensor<T>> {
528 let eps = randn_with_seed::<T>(self.mean.shape(), seed)?;
529 let half = T::from(0.5f64).ok_or_else(|| FerrotorchError::InvalidArgument {
531 message: "DiagonalGaussianDistribution::sample_with_seed: cannot cast 0.5 into Float"
532 .into(),
533 })?;
534 let half_t = ferrotorch_core::scalar::<T>(half)?;
535 let scaled_logvar = ferrotorch_core::grad_fns::arithmetic::mul(&self.logvar, &half_t)?;
536 let std = ferrotorch_core::grad_fns::transcendental::exp(&scaled_logvar)?;
537 let noise = ferrotorch_core::grad_fns::arithmetic::mul(&std, &eps)?;
538 ferrotorch_core::grad_fns::arithmetic::add(&self.mean, &noise)
539 }
540}
541
542fn randn_with_seed<T: Float>(shape: &[usize], seed: u64) -> FerrotorchResult<Tensor<T>> {
549 let numel: usize = shape.iter().product();
550 let mut data = Vec::with_capacity(numel);
551 let mut state = if seed == 0 {
552 0x0000_dead_beef_cafe
553 } else {
554 seed
555 };
556
557 let mut next_uniform = || -> f64 {
558 state ^= state << 13;
559 state ^= state >> 7;
560 state ^= state << 17;
561 ((state as f64) / (u64::MAX as f64)).max(1e-300)
562 };
563
564 let mut i = 0;
565 while i < numel {
566 let u1 = next_uniform();
567 let u2 = next_uniform();
568 let r = (-2.0 * u1.ln()).sqrt();
569 let theta = 2.0 * std::f64::consts::PI * u2;
570 data.push(
571 T::from(r * theta.cos()).ok_or_else(|| FerrotorchError::InvalidArgument {
572 message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
573 })?,
574 );
575 if i + 1 < numel {
576 data.push(T::from(r * theta.sin()).ok_or_else(|| {
577 FerrotorchError::InvalidArgument {
578 message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
579 }
580 })?);
581 }
582 i += 2;
583 }
584
585 data.truncate(numel);
586 Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false)
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration used by every test in this module.
    fn tiny_cfg() -> VaeEncoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Builds a `[1, 3, 8, 8]` CPU tensor whose every element equals `fill`.
    fn constant_image(fill: f32) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(vec![fill; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap()
    }

    #[test]
    fn encoder_forward_shape() {
        let cfg = tiny_cfg();
        let trunk = Encoder::<f32>::new(cfg.clone()).unwrap();
        let out = trunk.forward(&constant_image(0.01)).unwrap();
        assert_eq!(out.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "encoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_encoder_forward_shape() {
        let cfg = tiny_cfg();
        let vae = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let out = vae.forward(&constant_image(0.01)).unwrap();
        assert_eq!(out.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
    }

    #[test]
    fn vae_encoder_named_parameters_include_quant_conv() {
        let vae = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = vae
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        let required = [
            "quant_conv.weight",
            "quant_conv.bias",
            "encoder.conv_in.weight",
            "encoder.down_blocks.0.resnets.0.norm1.weight",
            "encoder.mid_block.attentions.0.to_q.weight",
            "encoder.conv_norm_out.weight",
            "encoder.conv_out.bias",
        ];
        for k in required {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn diag_gauss_split_and_mode_shapes() {
        let cfg = tiny_cfg();
        let vae = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let dist = vae.encode(&constant_image(0.01)).unwrap();
        let latent_shape = [1, cfg.latent_channels, 1, 1];
        assert_eq!(dist.mean.shape(), &latent_shape);
        assert_eq!(dist.logvar.shape(), &latent_shape);

        // The mode must be the mean, element for element.
        let mode = dist.mode();
        assert_eq!(mode.shape(), dist.mean.shape());
        let mode_values = mode.data().unwrap();
        let mean_values = dist.mean.data().unwrap();
        for (a, b) in mode_values.iter().zip(mean_values.iter()) {
            assert!((a - b).abs() < 1e-7);
        }
    }

    #[test]
    fn diag_gauss_logvar_is_clamped() {
        let l = 2;
        // First `l` channels (mean) are 0.5; the remaining `l` (log-variance) are far above the
        // clamp ceiling.
        let data: Vec<f32> = (0..2 * l)
            .map(|i| if i < l { 0.5f32 } else { 1e6 })
            .collect();
        let params =
            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v <= LOGVAR_CLAMP_MAX as f32 + 1e-4,
                "logvar not clamped to <= {LOGVAR_CLAMP_MAX}: got {v}"
            );
        }
    }

    #[test]
    fn diag_gauss_sample_with_seed_is_deterministic() {
        let l = 4;
        // All-zero parameters: mean 0 and log-variance 0, so samples equal the raw noise.
        let params = Tensor::from_storage(
            TensorStorage::cpu(vec![0.0f32; 2 * l * 3 * 3]),
            vec![1, 2 * l, 3, 3],
            false,
        )
        .unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        let s1 = dist.sample_with_seed(42).unwrap();
        let s2 = dist.sample_with_seed(42).unwrap();
        let s3 = dist.sample_with_seed(43).unwrap();
        assert_eq!(s1.shape(), &[1, l, 3, 3]);

        for (a, b) in s1.data().unwrap().iter().zip(s2.data().unwrap().iter()) {
            assert!(
                (a - b).abs() < 1e-7,
                "sample_with_seed(42) not deterministic: {a} vs {b}"
            );
        }

        let differ = s1
            .data()
            .unwrap()
            .iter()
            .zip(s3.data().unwrap().iter())
            .any(|(a, c)| (a - c).abs() > 1e-6);
        assert!(
            differ,
            "sample_with_seed(42) and sample_with_seed(43) produced identical output"
        );
    }

    #[test]
    fn vae_encoder_round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();

        let image = constant_image(0.01);
        let from_src = src.forward(&image).unwrap();
        let from_dst = dst.forward(&image).unwrap();
        for (x, y) in from_src
            .data()
            .unwrap()
            .iter()
            .zip(from_dst.data().unwrap().iter())
        {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }

    #[test]
    fn encode_with_scaling_applies_scaling_factor() {
        let cfg = tiny_cfg();
        let vae = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let image = constant_image(0.05);
        let scaled = vae.encode_with_scaling(&image, 99).unwrap();

        // Reproduce the unscaled sample with the same seed for comparison.
        let raw = vae.encode(&image).unwrap().sample_with_seed(99).unwrap();
        let scaled_values = scaled.data().unwrap();
        let raw_values = raw.data().unwrap();
        for (s, r) in scaled_values.iter().zip(raw_values.iter()) {
            let expected = r * cfg.scaling_factor as f32;
            assert!(
                (s - expected).abs() < 1e-5,
                "encode_with_scaling didn't apply scaling_factor: got {s}, expected {expected}"
            );
        }
    }
}