1use std::collections::HashMap;
26
27use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
28use ferrotorch_nn::module::{Module, StateDict};
29use ferrotorch_nn::parameter::Parameter;
30use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
31
32use crate::blocks::{DownEncoderBlock2D, UNetMidBlock2D};
33use crate::config::VaeDecoderConfig;
34
35pub type VaeEncoderConfig = VaeDecoderConfig;
42
43const LOGVAR_CLAMP_MIN: f64 = -30.0;
48const LOGVAR_CLAMP_MAX: f64 = 20.0;
49
50#[derive(Debug)]
52pub struct Encoder<T: Float> {
53 pub conv_in: Conv2d<T>,
59 pub down_blocks: Vec<DownEncoderBlock2D<T>>,
64 pub mid_block: UNetMidBlock2D<T>,
67 pub conv_norm_out: GroupNorm<T>,
70 pub conv_act: SiLU,
72 pub conv_out: Conv2d<T>,
76 pub config: VaeEncoderConfig,
78 training: bool,
79}
80
81impl<T: Float> Encoder<T> {
82 pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
92 cfg.validate()?;
93 let groups = cfg.norm_num_groups;
94 let resnet_eps = 1e-6_f64;
95 let bottom_channels = cfg.block_out_channels[0];
96 let top_channels =
97 *cfg.block_out_channels
98 .last()
99 .ok_or_else(|| FerrotorchError::InvalidArgument {
100 message: "Encoder::new: block_out_channels is empty (should be \
101 unreachable after validate)"
102 .into(),
103 })?;
104
105 let conv_in = Conv2d::<T>::new(
107 cfg.out_channels,
108 bottom_channels,
109 (3, 3),
110 (1, 1),
111 (1, 1),
112 true,
113 )?;
114
115 let num_blocks = cfg.block_out_channels.len();
118 let resnets = cfg.layers_per_block;
119 let mut down_blocks = Vec::with_capacity(num_blocks);
120 let mut prev_out = bottom_channels;
121 for (i, &c) in cfg.block_out_channels.iter().enumerate() {
122 let is_final = i == num_blocks - 1;
123 down_blocks.push(DownEncoderBlock2D::<T>::new(
124 prev_out,
125 c,
126 resnets,
127 groups,
128 resnet_eps,
129 !is_final,
130 )?);
131 prev_out = c;
132 }
133
134 let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
136
137 let conv_norm_out =
138 GroupNorm::<T>::new(groups, top_channels, resnet_eps, true)?;
139 let conv_out = Conv2d::<T>::new(
141 top_channels,
142 2 * cfg.latent_channels,
143 (3, 3),
144 (1, 1),
145 (1, 1),
146 true,
147 )?;
148
149 Ok(Self {
150 conv_in,
151 down_blocks,
152 mid_block,
153 conv_norm_out,
154 conv_act: SiLU::new(),
155 conv_out,
156 config: cfg,
157 training: false,
158 })
159 }
160}
161
162impl<T: Float> Module<T> for Encoder<T> {
163 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
164 let cfg = &self.config;
167 if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
168 return Err(FerrotorchError::ShapeMismatch {
169 message: format!(
170 "Encoder::forward: expected [B, {}, H, W], got {:?}",
171 cfg.out_channels,
172 input.shape()
173 ),
174 });
175 }
176 let mut h = self.conv_in.forward(input)?;
177 for d in &self.down_blocks {
178 h = d.forward(&h)?;
179 }
180 h = self.mid_block.forward(&h)?;
181 h = self.conv_norm_out.forward(&h)?;
182 h = self.conv_act.forward(&h)?;
183 self.conv_out.forward(&h)
184 }
185
186 fn parameters(&self) -> Vec<&Parameter<T>> {
187 let mut out = Vec::new();
188 out.extend(self.conv_in.parameters());
189 for b in &self.down_blocks {
190 out.extend(b.parameters());
191 }
192 out.extend(self.mid_block.parameters());
193 out.extend(self.conv_norm_out.parameters());
194 out.extend(self.conv_out.parameters());
195 out
196 }
197
198 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
199 let mut out = Vec::new();
200 out.extend(self.conv_in.parameters_mut());
201 for b in &mut self.down_blocks {
202 out.extend(b.parameters_mut());
203 }
204 out.extend(self.mid_block.parameters_mut());
205 out.extend(self.conv_norm_out.parameters_mut());
206 out.extend(self.conv_out.parameters_mut());
207 out
208 }
209
210 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
211 let mut out = Vec::new();
212 for (n, p) in self.conv_in.named_parameters() {
213 out.push((format!("conv_in.{n}"), p));
214 }
215 for (i, b) in self.down_blocks.iter().enumerate() {
216 for (n, p) in b.named_parameters() {
217 out.push((format!("down_blocks.{i}.{n}"), p));
218 }
219 }
220 for (n, p) in self.mid_block.named_parameters() {
221 out.push((format!("mid_block.{n}"), p));
222 }
223 for (n, p) in self.conv_norm_out.named_parameters() {
224 out.push((format!("conv_norm_out.{n}"), p));
225 }
226 for (n, p) in self.conv_out.named_parameters() {
227 out.push((format!("conv_out.{n}"), p));
228 }
229 out
230 }
231
232 fn train(&mut self) {
233 self.training = true;
234 for b in &mut self.down_blocks {
235 b.train();
236 }
237 self.mid_block.train();
238 }
239 fn eval(&mut self) {
240 self.training = false;
241 for b in &mut self.down_blocks {
242 b.eval();
243 }
244 self.mid_block.eval();
245 }
246 fn is_training(&self) -> bool {
247 self.training
248 }
249
250 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
251 let extract = |prefix: &str| -> StateDict<T> {
252 let p = format!("{prefix}.");
253 state
254 .iter()
255 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
256 .collect()
257 };
258
259 if strict {
260 for k in state.keys() {
261 let ok = k.starts_with("conv_in.")
262 || k.starts_with("down_blocks.")
263 || k.starts_with("mid_block.")
264 || k.starts_with("conv_norm_out.")
265 || k.starts_with("conv_out.");
266 if !ok {
267 return Err(FerrotorchError::InvalidArgument {
268 message: format!("unexpected key in Encoder state_dict: \"{k}\""),
269 });
270 }
271 }
272 }
273
274 self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
275 for (i, b) in self.down_blocks.iter_mut().enumerate() {
276 b.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
277 }
278 self.mid_block
279 .load_state_dict(&extract("mid_block"), strict)?;
280 self.conv_norm_out
281 .load_state_dict(&extract("conv_norm_out"), strict)?;
282 self.conv_out
283 .load_state_dict(&extract("conv_out"), strict)?;
284 Ok(())
285 }
286}
287
288#[derive(Debug)]
295pub struct VaeEncoder<T: Float> {
296 pub encoder: Encoder<T>,
298 pub quant_conv: Conv2d<T>,
300 pub config: VaeEncoderConfig,
302 training: bool,
303}
304
305impl<T: Float> VaeEncoder<T> {
306 pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
312 cfg.validate()?;
313 let encoder = Encoder::<T>::new(cfg.clone())?;
314 let two_l = 2 * cfg.latent_channels;
315 let quant_conv = Conv2d::<T>::new(two_l, two_l, (1, 1), (1, 1), (0, 0), true)?;
316 Ok(Self {
317 encoder,
318 quant_conv,
319 config: cfg,
320 training: false,
321 })
322 }
323
324 pub fn encode(&self, image: &Tensor<T>) -> FerrotorchResult<DiagonalGaussianDistribution<T>> {
332 let params = self.forward(image)?;
333 DiagonalGaussianDistribution::from_parameters(¶ms, self.config.latent_channels)
334 }
335
336 pub fn encode_with_scaling(
355 &self,
356 image: &Tensor<T>,
357 seed: u64,
358 ) -> FerrotorchResult<Tensor<T>> {
359 let dist = self.encode(image)?;
360 let sample = dist.sample_with_seed(seed)?;
361 let sf = T::from(self.config.scaling_factor)
362 .ok_or_else(|| FerrotorchError::InvalidArgument {
363 message: format!(
364 "VaeEncoder::encode_with_scaling: cannot cast scaling_factor={} into Float",
365 self.config.scaling_factor
366 ),
367 })?;
368 let sf_t = ferrotorch_core::scalar::<T>(sf)?;
369 ferrotorch_core::grad_fns::arithmetic::mul(&sample, &sf_t)
370 }
371}
372
373impl<T: Float> Module<T> for VaeEncoder<T> {
374 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
379 let cfg = &self.config;
380 if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
381 return Err(FerrotorchError::ShapeMismatch {
382 message: format!(
383 "VaeEncoder::forward: expected [B, {}, H, W], got {:?}",
384 cfg.out_channels,
385 input.shape()
386 ),
387 });
388 }
389 let h = self.encoder.forward(input)?;
390 self.quant_conv.forward(&h)
391 }
392
393 fn parameters(&self) -> Vec<&Parameter<T>> {
394 let mut out = Vec::new();
395 out.extend(self.encoder.parameters());
396 out.extend(self.quant_conv.parameters());
397 out
398 }
399
400 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
401 let mut out = Vec::new();
402 out.extend(self.encoder.parameters_mut());
403 out.extend(self.quant_conv.parameters_mut());
404 out
405 }
406
407 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
408 let mut out = Vec::new();
409 for (n, p) in self.encoder.named_parameters() {
410 out.push((format!("encoder.{n}"), p));
411 }
412 for (n, p) in self.quant_conv.named_parameters() {
413 out.push((format!("quant_conv.{n}"), p));
414 }
415 out
416 }
417
418 fn train(&mut self) {
419 self.training = true;
420 self.encoder.train();
421 }
422 fn eval(&mut self) {
423 self.training = false;
424 self.encoder.eval();
425 }
426 fn is_training(&self) -> bool {
427 self.training
428 }
429
430 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
431 let extract = |prefix: &str| -> StateDict<T> {
432 let p = format!("{prefix}.");
433 state
434 .iter()
435 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
436 .collect()
437 };
438 if strict {
439 for k in state.keys() {
440 let ok = k.starts_with("encoder.") || k.starts_with("quant_conv.");
441 if !ok {
442 return Err(FerrotorchError::InvalidArgument {
443 message: format!("unexpected key in VaeEncoder state_dict: \"{k}\""),
444 });
445 }
446 }
447 }
448 self.encoder.load_state_dict(&extract("encoder"), strict)?;
449 self.quant_conv
450 .load_state_dict(&extract("quant_conv"), strict)?;
451 let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
453 }
454}
455
456#[derive(Debug)]
464pub struct DiagonalGaussianDistribution<T: Float> {
465 pub mean: Tensor<T>,
467 pub logvar: Tensor<T>,
469}
470
471impl<T: Float> DiagonalGaussianDistribution<T> {
472 pub fn from_parameters(
481 params: &Tensor<T>,
482 latent_channels: usize,
483 ) -> FerrotorchResult<Self> {
484 if params.ndim() != 4 || params.shape()[1] != 2 * latent_channels {
485 return Err(FerrotorchError::ShapeMismatch {
486 message: format!(
487 "DiagonalGaussianDistribution::from_parameters: expected [B, {}, H, W], \
488 got {:?}",
489 2 * latent_channels,
490 params.shape()
491 ),
492 });
493 }
494 let parts = params.chunk(2, 1)?;
496 let mean = parts[0].clone();
497 let logvar_raw = parts[1].clone();
498 let lo = T::from(LOGVAR_CLAMP_MIN).ok_or_else(|| FerrotorchError::InvalidArgument {
500 message: format!(
501 "DiagonalGaussianDistribution: cannot cast logvar clamp min {LOGVAR_CLAMP_MIN} \
502 into Float"
503 ),
504 })?;
505 let hi = T::from(LOGVAR_CLAMP_MAX).ok_or_else(|| FerrotorchError::InvalidArgument {
506 message: format!(
507 "DiagonalGaussianDistribution: cannot cast logvar clamp max {LOGVAR_CLAMP_MAX} \
508 into Float"
509 ),
510 })?;
511 let logvar = logvar_raw.clamp_t(lo, hi)?;
512 Ok(Self { mean, logvar })
513 }
514
515 pub fn mode(&self) -> &Tensor<T> {
519 &self.mean
520 }
521
522 pub fn sample_with_seed(&self, seed: u64) -> FerrotorchResult<Tensor<T>> {
540 let eps = randn_with_seed::<T>(self.mean.shape(), seed)?;
541 let half = T::from(0.5f64).ok_or_else(|| FerrotorchError::InvalidArgument {
543 message: "DiagonalGaussianDistribution::sample_with_seed: cannot cast 0.5 into Float"
544 .into(),
545 })?;
546 let half_t = ferrotorch_core::scalar::<T>(half)?;
547 let scaled_logvar =
548 ferrotorch_core::grad_fns::arithmetic::mul(&self.logvar, &half_t)?;
549 let std = ferrotorch_core::grad_fns::transcendental::exp(&scaled_logvar)?;
550 let noise = ferrotorch_core::grad_fns::arithmetic::mul(&std, &eps)?;
551 ferrotorch_core::grad_fns::arithmetic::add(&self.mean, &noise)
552 }
553}
554
555fn randn_with_seed<T: Float>(shape: &[usize], seed: u64) -> FerrotorchResult<Tensor<T>> {
562 let numel: usize = shape.iter().product();
563 let mut data = Vec::with_capacity(numel);
564 let mut state = if seed == 0 { 0x0000_dead_beef_cafe } else { seed };
565
566 let mut next_uniform = || -> f64 {
567 state ^= state << 13;
568 state ^= state >> 7;
569 state ^= state << 17;
570 ((state as f64) / (u64::MAX as f64)).max(1e-300)
571 };
572
573 let mut i = 0;
574 while i < numel {
575 let u1 = next_uniform();
576 let u2 = next_uniform();
577 let r = (-2.0 * u1.ln()).sqrt();
578 let theta = 2.0 * std::f64::consts::PI * u2;
579 data.push(T::from(r * theta.cos()).ok_or_else(|| FerrotorchError::InvalidArgument {
580 message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
581 })?);
582 if i + 1 < numel {
583 data.push(T::from(r * theta.sin()).ok_or_else(|| {
584 FerrotorchError::InvalidArgument {
585 message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
586 }
587 })?);
588 }
589 i += 2;
590 }
591
592 data.truncate(numel);
593 Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false)
594}
595
#[cfg(test)]
mod tests {
    //! Shape, naming, clamping, determinism and state-dict round-trip tests for
    //! the encoder, all run on a deliberately tiny configuration.

    use super::*;

    /// Small configuration: four blocks, so an 8x8 input ends up 1x1.
    fn tiny_cfg() -> VaeEncoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// A `[1, 3, 8, 8]` image whose every element equals `value`.
    fn constant_image(value: f32) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(vec![value; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap()
    }

    #[test]
    fn encoder_forward_shape() {
        let cfg = tiny_cfg();
        let e = Encoder::<f32>::new(cfg.clone()).unwrap();
        let x = constant_image(0.01);
        let y = e.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "encoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_encoder_forward_shape() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = constant_image(0.01);
        let y = v.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
    }

    #[test]
    fn vae_encoder_named_parameters_include_quant_conv() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg).unwrap();
        let names: Vec<String> = v.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "quant_conv.weight",
            "quant_conv.bias",
            "encoder.conv_in.weight",
            "encoder.down_blocks.0.resnets.0.norm1.weight",
            "encoder.mid_block.attentions.0.to_q.weight",
            "encoder.conv_norm_out.weight",
            "encoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn diag_gauss_split_and_mode_shapes() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = constant_image(0.01);
        let dist = v.encode(&x).unwrap();
        assert_eq!(dist.mean.shape(), &[1, cfg.latent_channels, 1, 1]);
        assert_eq!(dist.logvar.shape(), &[1, cfg.latent_channels, 1, 1]);
        // `mode()` must hand back the mean unchanged.
        let mode = dist.mode();
        assert_eq!(mode.shape(), dist.mean.shape());
        for (a, b) in mode
            .data()
            .unwrap()
            .iter()
            .zip(dist.mean.data().unwrap().iter())
        {
            assert!((a - b).abs() < 1e-7);
        }
    }

    #[test]
    fn diag_gauss_logvar_is_clamped() {
        let l = 2;
        // First `l` channels are the mean; the remaining `l` are a huge logvar.
        let mut data = vec![0.5f32; l * 2];
        for d in data.iter_mut().skip(l) {
            *d = 1e6;
        }
        let params =
            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v <= LOGVAR_CLAMP_MAX as f32 + 1e-4,
                "logvar not clamped to <= {LOGVAR_CLAMP_MAX}: got {v}"
            );
        }
    }

    #[test]
    fn diag_gauss_logvar_is_clamped_from_below() {
        let l = 2;
        // Mirror of the upper-bound test with a hugely negative logvar.
        let mut data = vec![0.5f32; l * 2];
        for d in data.iter_mut().skip(l) {
            *d = -1e6;
        }
        let params =
            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v >= LOGVAR_CLAMP_MIN as f32 - 1e-4,
                "logvar not clamped to >= {LOGVAR_CLAMP_MIN}: got {v}"
            );
        }
    }

    #[test]
    fn diag_gauss_sample_with_seed_is_deterministic() {
        let l = 4;
        // All-zero parameters: mean 0 and logvar 0, so the sample is the noise.
        let zeros = vec![0.0f32; 2 * l * 3 * 3];
        let params = Tensor::from_storage(
            TensorStorage::cpu(zeros),
            vec![1, 2 * l, 3, 3],
            false,
        )
        .unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        let s1 = dist.sample_with_seed(42).unwrap();
        let s2 = dist.sample_with_seed(42).unwrap();
        let s3 = dist.sample_with_seed(43).unwrap();
        assert_eq!(s1.shape(), &[1, l, 3, 3]);
        for (a, b) in s1.data().unwrap().iter().zip(s2.data().unwrap().iter()) {
            assert!(
                (a - b).abs() < 1e-7,
                "sample_with_seed(42) not deterministic: {a} vs {b}"
            );
        }
        // A different seed must change at least one element.
        let differ = s1
            .data()
            .unwrap()
            .iter()
            .zip(s3.data().unwrap().iter())
            .any(|(a, c)| (a - c).abs() > 1e-6);
        assert!(differ, "sample_with_seed(42) and sample_with_seed(43) produced identical output");
    }

    #[test]
    fn vae_encoder_round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let x = constant_image(0.01);
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }

    #[test]
    fn encode_with_scaling_applies_scaling_factor() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = constant_image(0.05);
        let scaled = v.encode_with_scaling(&x, 99).unwrap();
        // Same image and seed through the unscaled path gives the reference.
        let dist = v.encode(&x).unwrap();
        let raw = dist.sample_with_seed(99).unwrap();
        for (s, r) in scaled
            .data()
            .unwrap()
            .iter()
            .zip(raw.data().unwrap().iter())
        {
            let expected = r * cfg.scaling_factor as f32;
            assert!(
                (s - expected).abs() < 1e-5,
                "encode_with_scaling didn't apply scaling_factor: got {s}, expected {expected}"
            );
        }
    }
}