1use std::collections::HashMap;
38
39use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
40use ferrotorch_nn::module::{Module, StateDict};
41use ferrotorch_nn::parameter::Parameter;
42use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
43
44use crate::blocks::{DownEncoderBlock2D, UNetMidBlock2D};
45use crate::config::VaeDecoderConfig;
46
47pub type VaeEncoderConfig = VaeDecoderConfig;
54
55const LOGVAR_CLAMP_MIN: f64 = -30.0;
60const LOGVAR_CLAMP_MAX: f64 = 20.0;
61
/// VAE encoder trunk: maps an image batch to concatenated latent distribution
/// parameters (`2 * latent_channels` output channels).
///
/// Built by [`Encoder::new`]; `forward` applies `conv_in`, every entry of
/// `down_blocks` in order, `mid_block`, `conv_norm_out`, `conv_act` and
/// finally `conv_out`.
#[derive(Debug)]
pub struct Encoder<T: Float> {
    /// Input convolution from `config.out_channels` channels (the channel count
    /// `forward` requires of its input) to `block_out_channels[0]` channels.
    pub conv_in: Conv2d<T>,
    /// One block per entry of `block_out_channels`; each block's input width is
    /// the previous block's output width. The last block is constructed with
    /// `false` as its final argument (presumably "no downsampling" — confirm
    /// against `DownEncoderBlock2D::new`).
    pub down_blocks: Vec<DownEncoderBlock2D<T>>,
    /// Bottleneck block at `block_out_channels.last()` channels.
    pub mid_block: UNetMidBlock2D<T>,
    /// Group normalisation over the top channel count (`norm_num_groups`
    /// groups, eps `1e-6`).
    pub conv_norm_out: GroupNorm<T>,
    /// Activation applied between `conv_norm_out` and `conv_out`.
    pub conv_act: SiLU,
    /// Output convolution producing `2 * latent_channels` channels, later split
    /// into mean / log-variance halves by `DiagonalGaussianDistribution`.
    pub conv_out: Conv2d<T>,
    /// Configuration this encoder was built from; `forward` validates input
    /// channels against `config.out_channels`.
    pub config: VaeEncoderConfig,
    /// Mode flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
92
93impl<T: Float> Encoder<T> {
94 pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
104 cfg.validate()?;
105 let groups = cfg.norm_num_groups;
106 let resnet_eps = 1e-6_f64;
107 let bottom_channels = cfg.block_out_channels[0];
108 let top_channels =
109 *cfg.block_out_channels
110 .last()
111 .ok_or_else(|| FerrotorchError::InvalidArgument {
112 message: "Encoder::new: block_out_channels is empty (should be \
113 unreachable after validate)"
114 .into(),
115 })?;
116
117 let conv_in = Conv2d::<T>::new(
119 cfg.out_channels,
120 bottom_channels,
121 (3, 3),
122 (1, 1),
123 (1, 1),
124 true,
125 )?;
126
127 let num_blocks = cfg.block_out_channels.len();
130 let resnets = cfg.layers_per_block;
131 let mut down_blocks = Vec::with_capacity(num_blocks);
132 let mut prev_out = bottom_channels;
133 for (i, &c) in cfg.block_out_channels.iter().enumerate() {
134 let is_final = i == num_blocks - 1;
135 down_blocks.push(DownEncoderBlock2D::<T>::new(
136 prev_out, c, resnets, groups, resnet_eps, !is_final,
137 )?);
138 prev_out = c;
139 }
140
141 let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
143
144 let conv_norm_out = GroupNorm::<T>::new(groups, top_channels, resnet_eps, true)?;
145 let conv_out = Conv2d::<T>::new(
147 top_channels,
148 2 * cfg.latent_channels,
149 (3, 3),
150 (1, 1),
151 (1, 1),
152 true,
153 )?;
154
155 Ok(Self {
156 conv_in,
157 down_blocks,
158 mid_block,
159 conv_norm_out,
160 conv_act: SiLU::new(),
161 conv_out,
162 config: cfg,
163 training: false,
164 })
165 }
166}
167
168impl<T: Float> Module<T> for Encoder<T> {
169 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
170 let cfg = &self.config;
173 if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
174 return Err(FerrotorchError::ShapeMismatch {
175 message: format!(
176 "Encoder::forward: expected [B, {}, H, W], got {:?}",
177 cfg.out_channels,
178 input.shape()
179 ),
180 });
181 }
182 let mut h = self.conv_in.forward(input)?;
183 for d in &self.down_blocks {
184 h = d.forward(&h)?;
185 }
186 h = self.mid_block.forward(&h)?;
187 h = self.conv_norm_out.forward(&h)?;
188 h = self.conv_act.forward(&h)?;
189 self.conv_out.forward(&h)
190 }
191
192 fn parameters(&self) -> Vec<&Parameter<T>> {
193 let mut out = Vec::new();
194 out.extend(self.conv_in.parameters());
195 for b in &self.down_blocks {
196 out.extend(b.parameters());
197 }
198 out.extend(self.mid_block.parameters());
199 out.extend(self.conv_norm_out.parameters());
200 out.extend(self.conv_out.parameters());
201 out
202 }
203
204 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
205 let mut out = Vec::new();
206 out.extend(self.conv_in.parameters_mut());
207 for b in &mut self.down_blocks {
208 out.extend(b.parameters_mut());
209 }
210 out.extend(self.mid_block.parameters_mut());
211 out.extend(self.conv_norm_out.parameters_mut());
212 out.extend(self.conv_out.parameters_mut());
213 out
214 }
215
216 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
217 let mut out = Vec::new();
218 for (n, p) in self.conv_in.named_parameters() {
219 out.push((format!("conv_in.{n}"), p));
220 }
221 for (i, b) in self.down_blocks.iter().enumerate() {
222 for (n, p) in b.named_parameters() {
223 out.push((format!("down_blocks.{i}.{n}"), p));
224 }
225 }
226 for (n, p) in self.mid_block.named_parameters() {
227 out.push((format!("mid_block.{n}"), p));
228 }
229 for (n, p) in self.conv_norm_out.named_parameters() {
230 out.push((format!("conv_norm_out.{n}"), p));
231 }
232 for (n, p) in self.conv_out.named_parameters() {
233 out.push((format!("conv_out.{n}"), p));
234 }
235 out
236 }
237
238 fn train(&mut self) {
239 self.training = true;
240 for b in &mut self.down_blocks {
241 b.train();
242 }
243 self.mid_block.train();
244 }
245 fn eval(&mut self) {
246 self.training = false;
247 for b in &mut self.down_blocks {
248 b.eval();
249 }
250 self.mid_block.eval();
251 }
252 fn is_training(&self) -> bool {
253 self.training
254 }
255
256 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
257 let extract = |prefix: &str| -> StateDict<T> {
258 let p = format!("{prefix}.");
259 state
260 .iter()
261 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
262 .collect()
263 };
264
265 if strict {
266 for k in state.keys() {
267 let ok = k.starts_with("conv_in.")
268 || k.starts_with("down_blocks.")
269 || k.starts_with("mid_block.")
270 || k.starts_with("conv_norm_out.")
271 || k.starts_with("conv_out.");
272 if !ok {
273 return Err(FerrotorchError::InvalidArgument {
274 message: format!("unexpected key in Encoder state_dict: \"{k}\""),
275 });
276 }
277 }
278 }
279
280 self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
281 for (i, b) in self.down_blocks.iter_mut().enumerate() {
282 b.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
283 }
284 self.mid_block
285 .load_state_dict(&extract("mid_block"), strict)?;
286 self.conv_norm_out
287 .load_state_dict(&extract("conv_norm_out"), strict)?;
288 self.conv_out
289 .load_state_dict(&extract("conv_out"), strict)?;
290 Ok(())
291 }
292}
293
/// Full VAE encoder: the [`Encoder`] trunk followed by the `quant_conv`
/// projection; its output feeds [`DiagonalGaussianDistribution`].
#[derive(Debug)]
pub struct VaeEncoder<T: Float> {
    /// Convolutional trunk producing `2 * latent_channels` channels.
    pub encoder: Encoder<T>,
    /// Convolution from `2 * latent_channels` to `2 * latent_channels`
    /// channels, built with `(1, 1)` kernel/stride args and `(0, 0)` padding
    /// (presumably a pointwise projection — confirm against `Conv2d::new`).
    pub quant_conv: Conv2d<T>,
    /// Configuration used to build the encoder; also supplies
    /// `latent_channels` and `scaling_factor` at encode time.
    pub config: VaeEncoderConfig,
    /// Mode flag toggled by `train` / `eval` and reported by `is_training`.
    training: bool,
}
310
311impl<T: Float> VaeEncoder<T> {
312 pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
318 cfg.validate()?;
319 let encoder = Encoder::<T>::new(cfg.clone())?;
320 let two_l = 2 * cfg.latent_channels;
321 let quant_conv = Conv2d::<T>::new(two_l, two_l, (1, 1), (1, 1), (0, 0), true)?;
322 Ok(Self {
323 encoder,
324 quant_conv,
325 config: cfg,
326 training: false,
327 })
328 }
329
330 pub fn encode(&self, image: &Tensor<T>) -> FerrotorchResult<DiagonalGaussianDistribution<T>> {
338 let params = self.forward(image)?;
339 DiagonalGaussianDistribution::from_parameters(¶ms, self.config.latent_channels)
340 }
341
342 pub fn encode_with_scaling(&self, image: &Tensor<T>, seed: u64) -> FerrotorchResult<Tensor<T>> {
361 let dist = self.encode(image)?;
362 let sample = dist.sample_with_seed(seed)?;
363 let sf = T::from(self.config.scaling_factor).ok_or_else(|| {
364 FerrotorchError::InvalidArgument {
365 message: format!(
366 "VaeEncoder::encode_with_scaling: cannot cast scaling_factor={} into Float",
367 self.config.scaling_factor
368 ),
369 }
370 })?;
371 let sf_t = ferrotorch_core::scalar::<T>(sf)?;
372 ferrotorch_core::grad_fns::arithmetic::mul(&sample, &sf_t)
373 }
374}
375
376impl<T: Float> Module<T> for VaeEncoder<T> {
377 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
382 let cfg = &self.config;
383 if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
384 return Err(FerrotorchError::ShapeMismatch {
385 message: format!(
386 "VaeEncoder::forward: expected [B, {}, H, W], got {:?}",
387 cfg.out_channels,
388 input.shape()
389 ),
390 });
391 }
392 let h = self.encoder.forward(input)?;
393 self.quant_conv.forward(&h)
394 }
395
396 fn parameters(&self) -> Vec<&Parameter<T>> {
397 let mut out = Vec::new();
398 out.extend(self.encoder.parameters());
399 out.extend(self.quant_conv.parameters());
400 out
401 }
402
403 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
404 let mut out = Vec::new();
405 out.extend(self.encoder.parameters_mut());
406 out.extend(self.quant_conv.parameters_mut());
407 out
408 }
409
410 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
411 let mut out = Vec::new();
412 for (n, p) in self.encoder.named_parameters() {
413 out.push((format!("encoder.{n}"), p));
414 }
415 for (n, p) in self.quant_conv.named_parameters() {
416 out.push((format!("quant_conv.{n}"), p));
417 }
418 out
419 }
420
421 fn train(&mut self) {
422 self.training = true;
423 self.encoder.train();
424 }
425 fn eval(&mut self) {
426 self.training = false;
427 self.encoder.eval();
428 }
429 fn is_training(&self) -> bool {
430 self.training
431 }
432
433 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
434 let extract = |prefix: &str| -> StateDict<T> {
435 let p = format!("{prefix}.");
436 state
437 .iter()
438 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
439 .collect()
440 };
441 if strict {
442 for k in state.keys() {
443 let ok = k.starts_with("encoder.") || k.starts_with("quant_conv.");
444 if !ok {
445 return Err(FerrotorchError::InvalidArgument {
446 message: format!("unexpected key in VaeEncoder state_dict: \"{k}\""),
447 });
448 }
449 }
450 }
451 self.encoder.load_state_dict(&extract("encoder"), strict)?;
452 self.quant_conv
453 .load_state_dict(&extract("quant_conv"), strict)?;
454 let _: HashMap<String, Tensor<T>> = HashMap::new(); Ok(())
456 }
457}
458
/// Diagonal Gaussian over latents, parameterised element-wise by a mean and a
/// log-variance tensor of identical shape.
#[derive(Debug)]
pub struct DiagonalGaussianDistribution<T: Float> {
    /// First half of the channel dimension of the encoder output.
    pub mean: Tensor<T>,
    /// Second half of the channel dimension, clamped to
    /// `[LOGVAR_CLAMP_MIN, LOGVAR_CLAMP_MAX]` by `from_parameters`.
    pub logvar: Tensor<T>,
}
473
474impl<T: Float> DiagonalGaussianDistribution<T> {
475 pub fn from_parameters(params: &Tensor<T>, latent_channels: usize) -> FerrotorchResult<Self> {
484 if params.ndim() != 4 || params.shape()[1] != 2 * latent_channels {
485 return Err(FerrotorchError::ShapeMismatch {
486 message: format!(
487 "DiagonalGaussianDistribution::from_parameters: expected [B, {}, H, W], \
488 got {:?}",
489 2 * latent_channels,
490 params.shape()
491 ),
492 });
493 }
494 let parts = params.chunk(2, 1)?;
496 let mean = parts[0].clone();
497 let logvar_raw = parts[1].clone();
498 let lo = T::from(LOGVAR_CLAMP_MIN).ok_or_else(|| FerrotorchError::InvalidArgument {
500 message: format!(
501 "DiagonalGaussianDistribution: cannot cast logvar clamp min {LOGVAR_CLAMP_MIN} \
502 into Float"
503 ),
504 })?;
505 let hi = T::from(LOGVAR_CLAMP_MAX).ok_or_else(|| FerrotorchError::InvalidArgument {
506 message: format!(
507 "DiagonalGaussianDistribution: cannot cast logvar clamp max {LOGVAR_CLAMP_MAX} \
508 into Float"
509 ),
510 })?;
511 let logvar = logvar_raw.clamp_t(lo, hi)?;
512 Ok(Self { mean, logvar })
513 }
514
515 pub fn mode(&self) -> &Tensor<T> {
519 &self.mean
520 }
521
522 pub fn sample_with_seed(&self, seed: u64) -> FerrotorchResult<Tensor<T>> {
540 let eps = randn_with_seed::<T>(self.mean.shape(), seed)?;
541 let half = T::from(0.5f64).ok_or_else(|| FerrotorchError::InvalidArgument {
543 message: "DiagonalGaussianDistribution::sample_with_seed: cannot cast 0.5 into Float"
544 .into(),
545 })?;
546 let half_t = ferrotorch_core::scalar::<T>(half)?;
547 let scaled_logvar = ferrotorch_core::grad_fns::arithmetic::mul(&self.logvar, &half_t)?;
548 let std = ferrotorch_core::grad_fns::transcendental::exp(&scaled_logvar)?;
549 let noise = ferrotorch_core::grad_fns::arithmetic::mul(&std, &eps)?;
550 ferrotorch_core::grad_fns::arithmetic::add(&self.mean, &noise)
551 }
552}
553
554fn randn_with_seed<T: Float>(shape: &[usize], seed: u64) -> FerrotorchResult<Tensor<T>> {
561 let numel: usize = shape.iter().product();
562 let mut data = Vec::with_capacity(numel);
563 let mut state = if seed == 0 {
564 0x0000_dead_beef_cafe
565 } else {
566 seed
567 };
568
569 let mut next_uniform = || -> f64 {
570 state ^= state << 13;
571 state ^= state >> 7;
572 state ^= state << 17;
573 ((state as f64) / (u64::MAX as f64)).max(1e-300)
574 };
575
576 let mut i = 0;
577 while i < numel {
578 let u1 = next_uniform();
579 let u2 = next_uniform();
580 let r = (-2.0 * u1.ln()).sqrt();
581 let theta = 2.0 * std::f64::consts::PI * u2;
582 data.push(
583 T::from(r * theta.cos()).ok_or_else(|| FerrotorchError::InvalidArgument {
584 message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
585 })?,
586 );
587 if i + 1 < numel {
588 data.push(T::from(r * theta.sin()).ok_or_else(|| {
589 FerrotorchError::InvalidArgument {
590 message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
591 }
592 })?);
593 }
594 i += 2;
595 }
596
597 data.truncate(numel);
598 Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false)
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Small config: four stages of widths 4/8/16/16 so an 8×8 image reaches
    /// a 1×1 latent after three downsampling stages.
    fn tiny_cfg() -> VaeEncoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    #[test]
    fn encoder_forward_shape() {
        let cfg = tiny_cfg();
        let e = Encoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let y = e.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "encoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_encoder_forward_shape() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let y = v.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
    }

    #[test]
    fn vae_encoder_named_parameters_include_quant_conv() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg).unwrap();
        let names: Vec<String> = v.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "quant_conv.weight",
            "quant_conv.bias",
            "encoder.conv_in.weight",
            "encoder.down_blocks.0.resnets.0.norm1.weight",
            "encoder.mid_block.attentions.0.to_q.weight",
            "encoder.conv_norm_out.weight",
            "encoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn diag_gauss_split_and_mode_shapes() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let dist = v.encode(&x).unwrap();
        assert_eq!(dist.mean.shape(), &[1, cfg.latent_channels, 1, 1]);
        assert_eq!(dist.logvar.shape(), &[1, cfg.latent_channels, 1, 1]);
        let mode = dist.mode();
        assert_eq!(mode.shape(), dist.mean.shape());
        for (a, b) in mode
            .data()
            .unwrap()
            .iter()
            .zip(dist.mean.data().unwrap().iter())
        {
            assert!((a - b).abs() < 1e-7);
        }
    }

    #[test]
    fn diag_gauss_logvar_is_clamped() {
        let l = 2;
        // Mean half 0.5, log-variance half far above the upper clamp.
        let mut data = vec![0.5f32; l * 2];
        for d in data.iter_mut().skip(l) {
            *d = 1e6;
        }
        let params =
            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v <= LOGVAR_CLAMP_MAX as f32 + 1e-4,
                "logvar not clamped to <= {LOGVAR_CLAMP_MAX}: got {v}"
            );
        }
    }

    #[test]
    fn diag_gauss_logvar_is_clamped_from_below() {
        let l = 2;
        // Mean half 0.5, log-variance half far below the lower clamp.
        let mut data = vec![0.5f32; l * 2];
        for d in data.iter_mut().skip(l) {
            *d = -1e6;
        }
        let params =
            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v >= LOGVAR_CLAMP_MIN as f32 - 1e-4,
                "logvar not clamped to >= {LOGVAR_CLAMP_MIN}: got {v}"
            );
        }
        for &m in dist.mean.data().unwrap() {
            assert!((m - 0.5).abs() < 1e-7, "mean half altered: {m}");
        }
    }

    #[test]
    fn diag_gauss_sample_with_seed_is_deterministic() {
        let l = 4;
        // mean = 0, logvar = 0 ⇒ samples are the raw standard-normal noise.
        let zeros = vec![0.0f32; 2 * l * 3 * 3];
        let params =
            Tensor::from_storage(TensorStorage::cpu(zeros), vec![1, 2 * l, 3, 3], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        let s1 = dist.sample_with_seed(42).unwrap();
        let s2 = dist.sample_with_seed(42).unwrap();
        let s3 = dist.sample_with_seed(43).unwrap();
        assert_eq!(s1.shape(), &[1, l, 3, 3]);
        for (a, b) in s1.data().unwrap().iter().zip(s2.data().unwrap().iter()) {
            assert!(
                (a - b).abs() < 1e-7,
                "sample_with_seed(42) not deterministic: {a} vs {b}"
            );
        }
        let mut differ = false;
        for (a, c) in s1.data().unwrap().iter().zip(s3.data().unwrap().iter()) {
            if (a - c).abs() > 1e-6 {
                differ = true;
                break;
            }
        }
        assert!(
            differ,
            "sample_with_seed(42) and sample_with_seed(43) produced identical output"
        );
    }

    #[test]
    fn vae_encoder_round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }

    #[test]
    fn encode_with_scaling_applies_scaling_factor() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.05f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let scaled = v.encode_with_scaling(&x, 99).unwrap();
        // Same seed through the unscaled path must differ only by the factor.
        let dist = v.encode(&x).unwrap();
        let raw = dist.sample_with_seed(99).unwrap();
        for (s, r) in scaled
            .data()
            .unwrap()
            .iter()
            .zip(raw.data().unwrap().iter())
        {
            let expected = r * cfg.scaling_factor as f32;
            assert!(
                (s - expected).abs() < 1e-5,
                "encode_with_scaling didn't apply scaling_factor: got {s}, expected {expected}"
            );
        }
    }
}