Skip to main content

ferrotorch_diffusion/
vae_encoder.rs

//! Stable-Diffusion VAE encoder composition.
//!
//! Mirrors `diffusers.AutoencoderKL.encode(image).latent_dist` for
//! `runwayml/stable-diffusion-v1-5`:
//!
//! ```text
//! image (pixel-space, [B, 3, H, W])
//!   -> Encoder.conv_in
//!   -> Encoder.down_blocks[0..N]   (last block has no Downsample2D)
//!   -> Encoder.mid_block
//!   -> Encoder.conv_norm_out -> SiLU -> Encoder.conv_out
//!     (output: [B, 2 * latent_channels, H/8, W/8] — mean/logvar concat)
//!   -> quant_conv                  ([2*L -> 2*L], 1x1)
//!   -> DiagonalGaussianDistribution::from_parameters
//! ```
//!
//! This module is the encoder-side mirror of [`crate::vae::VaeDecoder`].
//! The `encode_with_scaling` helper composes
//! `latent_dist.sample(seed) * scaling_factor`, matching
//! `AutoencoderKL.encode(x).latent_dist.sample() * vae.config.scaling_factor`.
//! The bare [`Module::forward`] returns the raw `[B, 2*L, h, w]` parameters
//! tensor (no split, no sample, no scaling) so that callers can swap in
//! their own sampling strategy (e.g. `.mode()` for deterministic decoding).
24
25use std::collections::HashMap;
26
27use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
28use ferrotorch_nn::module::{Module, StateDict};
29use ferrotorch_nn::parameter::Parameter;
30use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
31
32use crate::blocks::{DownEncoderBlock2D, UNetMidBlock2D};
33use crate::config::VaeDecoderConfig;
34
35/// Type alias — the SD VAE encoder and decoder share their config shape
36/// (mirrors `diffusers.AutoencoderKL.config`, which spans both halves).
37///
38/// Using a single config type also lets a downstream caller load
39/// `vae/config.json` once via [`VaeDecoderConfig::from_file`] and pass it
40/// to both `VaeEncoder::new` and `VaeDecoder::new`.
41pub type VaeEncoderConfig = VaeDecoderConfig;
42
/// Lower bound applied to `logvar` when a diagonal Gaussian is built
/// from encoder parameters.
///
/// Together with [`LOGVAR_CLAMP_MAX`] this reproduces
/// `torch.clamp(self.logvar, -30.0, 20.0)` from
/// `diffusers.models.autoencoders.vae.DiagonalGaussianDistribution.__init__`.
const LOGVAR_CLAMP_MIN: f64 = -30.0;
/// Upper bound applied to `logvar`; the counterpart of
/// [`LOGVAR_CLAMP_MIN`].
const LOGVAR_CLAMP_MAX: f64 = 20.0;
49
50/// The bare `Encoder` half — matches `diffusers.models.autoencoders.vae.Encoder`.
51#[derive(Debug)]
52pub struct Encoder<T: Float> {
53    /// First conv: `out_channels -> block_out_channels[0]` (k=3, pad=1).
54    ///
55    /// `out_channels` in [`VaeDecoderConfig`] is the image-side channel
56    /// count (the decoder *output* and the encoder *input* — they're the
57    /// same value, the field name is decoder-centric).
58    pub conv_in: Conv2d<T>,
59    /// Down-blocks in *encoder order* — block 0 operates at the lowest
60    /// channel count and highest spatial resolution. The deepest block
61    /// (`down_blocks[N-1]`) has no downsample (it preserves spatial
62    /// resolution into the mid-block).
63    pub down_blocks: Vec<DownEncoderBlock2D<T>>,
64    /// VAE mid-block at `block_out_channels[-1]` channels (same module
65    /// as in the decoder).
66    pub mid_block: UNetMidBlock2D<T>,
67    /// Final GroupNorm before the output conv (operates on
68    /// `block_out_channels[-1]` channels).
69    pub conv_norm_out: GroupNorm<T>,
70    /// Output activation (SiLU).
71    pub conv_act: SiLU,
72    /// Output conv: `block_out_channels[-1] -> 2 * latent_channels`
73    /// (k=3, pad=1). The factor of 2 holds the concatenated mean/logvar
74    /// produced by the diagonal Gaussian head.
75    pub conv_out: Conv2d<T>,
76    /// Frozen copy of the config.
77    pub config: VaeEncoderConfig,
78    training: bool,
79}
80
81impl<T: Float> Encoder<T> {
82    /// Build a randomly-initialized `Encoder`.
83    ///
84    /// # Errors
85    ///
86    /// Returns [`FerrotorchError::InvalidArgument`] for any invalid
87    /// config field (forwarded from [`VaeDecoderConfig::validate`]). In
88    /// particular `block_out_channels` must be non-empty — the index
89    /// into `[0]` / `[N-1]` below is preceded by `cfg.validate()?`
90    /// which checks exactly that.
91    pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
92        cfg.validate()?;
93        let groups = cfg.norm_num_groups;
94        let resnet_eps = 1e-6_f64;
95        let bottom_channels = cfg.block_out_channels[0];
96        let top_channels =
97            *cfg.block_out_channels
98                .last()
99                .ok_or_else(|| FerrotorchError::InvalidArgument {
100                    message: "Encoder::new: block_out_channels is empty (should be \
101                              unreachable after validate)"
102                        .into(),
103                })?;
104
105        // conv_in: image (out_channels) -> bottom_channels
106        let conv_in = Conv2d::<T>::new(
107            cfg.out_channels,
108            bottom_channels,
109            (3, 3),
110            (1, 1),
111            (1, 1),
112            true,
113        )?;
114
115        // Down-blocks. Encoder uses `layers_per_block` resnets per block
116        // (vs `layers_per_block + 1` on the decoder side).
117        let num_blocks = cfg.block_out_channels.len();
118        let resnets = cfg.layers_per_block;
119        let mut down_blocks = Vec::with_capacity(num_blocks);
120        let mut prev_out = bottom_channels;
121        for (i, &c) in cfg.block_out_channels.iter().enumerate() {
122            let is_final = i == num_blocks - 1;
123            down_blocks.push(DownEncoderBlock2D::<T>::new(
124                prev_out,
125                c,
126                resnets,
127                groups,
128                resnet_eps,
129                !is_final,
130            )?);
131            prev_out = c;
132        }
133
134        // Mid-block at the deepest channel count.
135        let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
136
137        let conv_norm_out =
138            GroupNorm::<T>::new(groups, top_channels, resnet_eps, true)?;
139        // The `2 *` factor produces the concatenated mean/logvar tensor.
140        let conv_out = Conv2d::<T>::new(
141            top_channels,
142            2 * cfg.latent_channels,
143            (3, 3),
144            (1, 1),
145            (1, 1),
146            true,
147        )?;
148
149        Ok(Self {
150            conv_in,
151            down_blocks,
152            mid_block,
153            conv_norm_out,
154            conv_act: SiLU::new(),
155            conv_out,
156            config: cfg,
157            training: false,
158        })
159    }
160}
161
162impl<T: Float> Module<T> for Encoder<T> {
163    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
164        // Sanity check: [B, out_channels, H, W] (out_channels is the
165        // image channel count — naming is decoder-centric).
166        let cfg = &self.config;
167        if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
168            return Err(FerrotorchError::ShapeMismatch {
169                message: format!(
170                    "Encoder::forward: expected [B, {}, H, W], got {:?}",
171                    cfg.out_channels,
172                    input.shape()
173                ),
174            });
175        }
176        let mut h = self.conv_in.forward(input)?;
177        for d in &self.down_blocks {
178            h = d.forward(&h)?;
179        }
180        h = self.mid_block.forward(&h)?;
181        h = self.conv_norm_out.forward(&h)?;
182        h = self.conv_act.forward(&h)?;
183        self.conv_out.forward(&h)
184    }
185
186    fn parameters(&self) -> Vec<&Parameter<T>> {
187        let mut out = Vec::new();
188        out.extend(self.conv_in.parameters());
189        for b in &self.down_blocks {
190            out.extend(b.parameters());
191        }
192        out.extend(self.mid_block.parameters());
193        out.extend(self.conv_norm_out.parameters());
194        out.extend(self.conv_out.parameters());
195        out
196    }
197
198    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
199        let mut out = Vec::new();
200        out.extend(self.conv_in.parameters_mut());
201        for b in &mut self.down_blocks {
202            out.extend(b.parameters_mut());
203        }
204        out.extend(self.mid_block.parameters_mut());
205        out.extend(self.conv_norm_out.parameters_mut());
206        out.extend(self.conv_out.parameters_mut());
207        out
208    }
209
210    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
211        let mut out = Vec::new();
212        for (n, p) in self.conv_in.named_parameters() {
213            out.push((format!("conv_in.{n}"), p));
214        }
215        for (i, b) in self.down_blocks.iter().enumerate() {
216            for (n, p) in b.named_parameters() {
217                out.push((format!("down_blocks.{i}.{n}"), p));
218            }
219        }
220        for (n, p) in self.mid_block.named_parameters() {
221            out.push((format!("mid_block.{n}"), p));
222        }
223        for (n, p) in self.conv_norm_out.named_parameters() {
224            out.push((format!("conv_norm_out.{n}"), p));
225        }
226        for (n, p) in self.conv_out.named_parameters() {
227            out.push((format!("conv_out.{n}"), p));
228        }
229        out
230    }
231
232    fn train(&mut self) {
233        self.training = true;
234        for b in &mut self.down_blocks {
235            b.train();
236        }
237        self.mid_block.train();
238    }
239    fn eval(&mut self) {
240        self.training = false;
241        for b in &mut self.down_blocks {
242            b.eval();
243        }
244        self.mid_block.eval();
245    }
246    fn is_training(&self) -> bool {
247        self.training
248    }
249
250    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
251        let extract = |prefix: &str| -> StateDict<T> {
252            let p = format!("{prefix}.");
253            state
254                .iter()
255                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
256                .collect()
257        };
258
259        if strict {
260            for k in state.keys() {
261                let ok = k.starts_with("conv_in.")
262                    || k.starts_with("down_blocks.")
263                    || k.starts_with("mid_block.")
264                    || k.starts_with("conv_norm_out.")
265                    || k.starts_with("conv_out.");
266                if !ok {
267                    return Err(FerrotorchError::InvalidArgument {
268                        message: format!("unexpected key in Encoder state_dict: \"{k}\""),
269                    });
270                }
271            }
272        }
273
274        self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
275        for (i, b) in self.down_blocks.iter_mut().enumerate() {
276            b.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
277        }
278        self.mid_block
279            .load_state_dict(&extract("mid_block"), strict)?;
280        self.conv_norm_out
281            .load_state_dict(&extract("conv_norm_out"), strict)?;
282        self.conv_out
283            .load_state_dict(&extract("conv_out"), strict)?;
284        Ok(())
285    }
286}
287
288/// `AutoencoderKL`-style VAE encoder = [`Encoder`] + `quant_conv`.
289///
290/// The encoder produces `[B, 2*latent_channels, H/8, W/8]` after the
291/// `Encoder` stack; the 1x1 `quant_conv` then projects this through a
292/// learned linear map. The split into (mean, logvar) and any sampling
293/// happens in [`DiagonalGaussianDistribution`].
294#[derive(Debug)]
295pub struct VaeEncoder<T: Float> {
296    /// The actual `Encoder` stack.
297    pub encoder: Encoder<T>,
298    /// 1x1 quant projection over `2 * latent_channels` channels.
299    pub quant_conv: Conv2d<T>,
300    /// Frozen config copy.
301    pub config: VaeEncoderConfig,
302    training: bool,
303}
304
305impl<T: Float> VaeEncoder<T> {
306    /// Build a randomly-initialized `VaeEncoder`.
307    ///
308    /// # Errors
309    ///
310    /// Returns the underlying [`FerrotorchError`] on bad config dims.
311    pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
312        cfg.validate()?;
313        let encoder = Encoder::<T>::new(cfg.clone())?;
314        let two_l = 2 * cfg.latent_channels;
315        let quant_conv = Conv2d::<T>::new(two_l, two_l, (1, 1), (1, 1), (0, 0), true)?;
316        Ok(Self {
317            encoder,
318            quant_conv,
319            config: cfg,
320            training: false,
321        })
322    }
323
324    /// Encode an image into a diagonal Gaussian distribution over
325    /// latent space. Matches `AutoencoderKL.encode(image).latent_dist`.
326    ///
327    /// # Errors
328    ///
329    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
330    /// `[B, out_channels, H, W]`. Propagates downstream op errors.
331    pub fn encode(&self, image: &Tensor<T>) -> FerrotorchResult<DiagonalGaussianDistribution<T>> {
332        let params = self.forward(image)?;
333        DiagonalGaussianDistribution::from_parameters(&params, self.config.latent_channels)
334    }
335
336    /// Encode an image, sample from the latent distribution with a
337    /// deterministic seed, then multiply by `scaling_factor`. This
338    /// matches the canonical SD pipeline call:
339    ///
340    /// ```text
341    /// latent = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
342    /// ```
343    ///
344    /// The Box-Muller / xorshift noise generator runs on CPU and is
345    /// fully deterministic for a given `seed` — different from
346    /// PyTorch's CUDA RNG, so the produced latent will NOT be
347    /// numerically identical to a Python reference run; only
348    /// statistically equivalent.
349    ///
350    /// # Errors
351    ///
352    /// Returns the underlying [`FerrotorchError`] for shape or
353    /// arithmetic failures.
354    pub fn encode_with_scaling(
355        &self,
356        image: &Tensor<T>,
357        seed: u64,
358    ) -> FerrotorchResult<Tensor<T>> {
359        let dist = self.encode(image)?;
360        let sample = dist.sample_with_seed(seed)?;
361        let sf = T::from(self.config.scaling_factor)
362            .ok_or_else(|| FerrotorchError::InvalidArgument {
363                message: format!(
364                    "VaeEncoder::encode_with_scaling: cannot cast scaling_factor={} into Float",
365                    self.config.scaling_factor
366                ),
367            })?;
368        let sf_t = ferrotorch_core::scalar::<T>(sf)?;
369        ferrotorch_core::grad_fns::arithmetic::mul(&sample, &sf_t)
370    }
371}
372
373impl<T: Float> Module<T> for VaeEncoder<T> {
374    /// Forward returns the raw `[B, 2*latent_channels, h, w]` parameters
375    /// tensor (concatenated mean/logvar, post `quant_conv`). Callers
376    /// who want a latent should use [`Self::encode`] or
377    /// [`Self::encode_with_scaling`].
378    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
379        let cfg = &self.config;
380        if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
381            return Err(FerrotorchError::ShapeMismatch {
382                message: format!(
383                    "VaeEncoder::forward: expected [B, {}, H, W], got {:?}",
384                    cfg.out_channels,
385                    input.shape()
386                ),
387            });
388        }
389        let h = self.encoder.forward(input)?;
390        self.quant_conv.forward(&h)
391    }
392
393    fn parameters(&self) -> Vec<&Parameter<T>> {
394        let mut out = Vec::new();
395        out.extend(self.encoder.parameters());
396        out.extend(self.quant_conv.parameters());
397        out
398    }
399
400    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
401        let mut out = Vec::new();
402        out.extend(self.encoder.parameters_mut());
403        out.extend(self.quant_conv.parameters_mut());
404        out
405    }
406
407    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
408        let mut out = Vec::new();
409        for (n, p) in self.encoder.named_parameters() {
410            out.push((format!("encoder.{n}"), p));
411        }
412        for (n, p) in self.quant_conv.named_parameters() {
413            out.push((format!("quant_conv.{n}"), p));
414        }
415        out
416    }
417
418    fn train(&mut self) {
419        self.training = true;
420        self.encoder.train();
421    }
422    fn eval(&mut self) {
423        self.training = false;
424        self.encoder.eval();
425    }
426    fn is_training(&self) -> bool {
427        self.training
428    }
429
430    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
431        let extract = |prefix: &str| -> StateDict<T> {
432            let p = format!("{prefix}.");
433            state
434                .iter()
435                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
436                .collect()
437        };
438        if strict {
439            for k in state.keys() {
440                let ok = k.starts_with("encoder.") || k.starts_with("quant_conv.");
441                if !ok {
442                    return Err(FerrotorchError::InvalidArgument {
443                        message: format!("unexpected key in VaeEncoder state_dict: \"{k}\""),
444                    });
445                }
446            }
447        }
448        self.encoder.load_state_dict(&extract("encoder"), strict)?;
449        self.quant_conv
450            .load_state_dict(&extract("quant_conv"), strict)?;
451        let _: HashMap<String, Tensor<T>> = HashMap::new(); // keep HashMap import alive
452        Ok(())
453    }
454}
455
456/// Diagonal Gaussian over latent space — the same parameterization
457/// `diffusers.models.autoencoders.vae.DiagonalGaussianDistribution`
458/// uses. Holds `mean` and `logvar` tensors (both `[B, L, h, w]`) split
459/// from the encoder's concatenated parameters output.
460///
461/// `logvar` is clamped to `[-30, 20]` on construction, matching the
462/// diffusers reference exactly.
463#[derive(Debug)]
464pub struct DiagonalGaussianDistribution<T: Float> {
465    /// Mean of the diagonal Gaussian, `[B, L, h, w]`.
466    pub mean: Tensor<T>,
467    /// Log-variance of the diagonal Gaussian (clamped), `[B, L, h, w]`.
468    pub logvar: Tensor<T>,
469}
470
471impl<T: Float> DiagonalGaussianDistribution<T> {
472    /// Split the encoder's `[B, 2*L, h, w]` parameters tensor into
473    /// mean / logvar halves along the channel axis.
474    ///
475    /// # Errors
476    ///
477    /// Returns [`FerrotorchError::ShapeMismatch`] when the channel
478    /// dimension is not `2 * latent_channels`. Propagates downstream
479    /// errors from the chunk / clamp ops.
480    pub fn from_parameters(
481        params: &Tensor<T>,
482        latent_channels: usize,
483    ) -> FerrotorchResult<Self> {
484        if params.ndim() != 4 || params.shape()[1] != 2 * latent_channels {
485            return Err(FerrotorchError::ShapeMismatch {
486                message: format!(
487                    "DiagonalGaussianDistribution::from_parameters: expected [B, {}, H, W], \
488                     got {:?}",
489                    2 * latent_channels,
490                    params.shape()
491                ),
492            });
493        }
494        // chunk(2, dim=1): two [B, L, h, w] tensors — [mean, logvar].
495        let parts = params.chunk(2, 1)?;
496        let mean = parts[0].clone();
497        let logvar_raw = parts[1].clone();
498        // Match diffusers' `torch.clamp(logvar, -30.0, 20.0)`.
499        let lo = T::from(LOGVAR_CLAMP_MIN).ok_or_else(|| FerrotorchError::InvalidArgument {
500            message: format!(
501                "DiagonalGaussianDistribution: cannot cast logvar clamp min {LOGVAR_CLAMP_MIN} \
502                 into Float"
503            ),
504        })?;
505        let hi = T::from(LOGVAR_CLAMP_MAX).ok_or_else(|| FerrotorchError::InvalidArgument {
506            message: format!(
507                "DiagonalGaussianDistribution: cannot cast logvar clamp max {LOGVAR_CLAMP_MAX} \
508                 into Float"
509            ),
510        })?;
511        let logvar = logvar_raw.clamp_t(lo, hi)?;
512        Ok(Self { mean, logvar })
513    }
514
515    /// Return the distribution mode — the deterministic latent if the
516    /// caller does not want to sample. Matches
517    /// `DiagonalGaussianDistribution.mode()` in diffusers.
518    pub fn mode(&self) -> &Tensor<T> {
519        &self.mean
520    }
521
522    /// Sample from the distribution with a deterministic seed.
523    ///
524    /// Computes `mean + std * eps` where `std = exp(0.5 * logvar)`
525    /// and `eps` is a fresh `N(0, 1)` tensor produced by Box-Muller
526    /// over a seeded xorshift64 PRNG (mirroring the CPU branch of
527    /// [`ferrotorch_core::randn`] but using `seed` as the state).
528    ///
529    /// This is fully deterministic for a given `seed` on a given host
530    /// — but the bitwise output will NOT match a CUDA PyTorch reference
531    /// run, because PyTorch's CUDA RNG uses Philox, not xorshift +
532    /// Box-Muller. Tests should compare statistical properties (mean,
533    /// std) or use `.mode()` for bit-exact reference matches.
534    ///
535    /// # Errors
536    ///
537    /// Returns the underlying [`FerrotorchError`] for shape or
538    /// arithmetic failures.
539    pub fn sample_with_seed(&self, seed: u64) -> FerrotorchResult<Tensor<T>> {
540        let eps = randn_with_seed::<T>(self.mean.shape(), seed)?;
541        // std = exp(0.5 * logvar)
542        let half = T::from(0.5f64).ok_or_else(|| FerrotorchError::InvalidArgument {
543            message: "DiagonalGaussianDistribution::sample_with_seed: cannot cast 0.5 into Float"
544                .into(),
545        })?;
546        let half_t = ferrotorch_core::scalar::<T>(half)?;
547        let scaled_logvar =
548            ferrotorch_core::grad_fns::arithmetic::mul(&self.logvar, &half_t)?;
549        let std = ferrotorch_core::grad_fns::transcendental::exp(&scaled_logvar)?;
550        let noise = ferrotorch_core::grad_fns::arithmetic::mul(&std, &eps)?;
551        ferrotorch_core::grad_fns::arithmetic::add(&self.mean, &noise)
552    }
553}
554
555/// Box-Muller + xorshift64 N(0, 1) tensor generator seeded by `seed`.
556///
557/// Local to this module so the encoder doesn't pull in a fresh `rand`
558/// dependency. The algorithm mirrors the CPU branch of
559/// [`ferrotorch_core::randn`] exactly, but takes the seed as an input
560/// rather than deriving it from system time + thread id.
561fn randn_with_seed<T: Float>(shape: &[usize], seed: u64) -> FerrotorchResult<Tensor<T>> {
562    let numel: usize = shape.iter().product();
563    let mut data = Vec::with_capacity(numel);
564    let mut state = if seed == 0 { 0x0000_dead_beef_cafe } else { seed };
565
566    let mut next_uniform = || -> f64 {
567        state ^= state << 13;
568        state ^= state >> 7;
569        state ^= state << 17;
570        ((state as f64) / (u64::MAX as f64)).max(1e-300)
571    };
572
573    let mut i = 0;
574    while i < numel {
575        let u1 = next_uniform();
576        let u2 = next_uniform();
577        let r = (-2.0 * u1.ln()).sqrt();
578        let theta = 2.0 * std::f64::consts::PI * u2;
579        data.push(T::from(r * theta.cos()).ok_or_else(|| FerrotorchError::InvalidArgument {
580            message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
581        })?);
582        if i + 1 < numel {
583            data.push(T::from(r * theta.sin()).ok_or_else(|| {
584                FerrotorchError::InvalidArgument {
585                    message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
586                }
587            })?);
588        }
589        i += 2;
590    }
591
592    data.truncate(numel);
593    Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false)
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny config that exercises every architectural feature (mid-block
    /// attn, 4 down-blocks, channel-changing resnet shortcut) without
    /// making tests slow. Same shape as the decoder's `tiny_cfg`.
    fn tiny_cfg() -> VaeEncoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// `[1, 3, 8, 8]` image whose every pixel equals `value`. With the
    /// 4-block `tiny_cfg` (3 downsamples) it encodes to a 1x1 latent.
    fn constant_image(value: f32) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(vec![value; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap()
    }

    /// `[1, 2*l, 1, 1]` parameters tensor: mean half is 0.5, logvar half
    /// is `logvar` everywhere.
    fn params_with_logvar(l: usize, logvar: f32) -> Tensor<f32> {
        let mut data = vec![0.5f32; l * 2];
        for slot in data.iter_mut().skip(l) {
            *slot = logvar;
        }
        Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap()
    }

    #[test]
    fn encoder_forward_shape() {
        let cfg = tiny_cfg();
        let e = Encoder::<f32>::new(cfg.clone()).unwrap();
        let y = e.forward(&constant_image(0.01)).unwrap();
        // 8 -> 4 -> 2 -> 1 (3 downsamples, last block has no downsample),
        // with 2 * latent_channels = 8 output channels.
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "encoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_encoder_forward_shape() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let y = v.forward(&constant_image(0.01)).unwrap();
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
    }

    #[test]
    fn vae_encoder_named_parameters_include_quant_conv() {
        let v = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = v.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "quant_conv.weight",
            "quant_conv.bias",
            "encoder.conv_in.weight",
            "encoder.down_blocks.0.resnets.0.norm1.weight",
            "encoder.mid_block.attentions.0.to_q.weight",
            "encoder.conv_norm_out.weight",
            "encoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn diag_gauss_split_and_mode_shapes() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let dist = v.encode(&constant_image(0.01)).unwrap();
        assert_eq!(dist.mean.shape(), &[1, cfg.latent_channels, 1, 1]);
        assert_eq!(dist.logvar.shape(), &[1, cfg.latent_channels, 1, 1]);
        // .mode() must hand back exactly the mean tensor.
        let mode = dist.mode();
        assert_eq!(mode.shape(), dist.mean.shape());
        for (a, b) in mode
            .data()
            .unwrap()
            .iter()
            .zip(dist.mean.data().unwrap().iter())
        {
            assert!((a - b).abs() < 1e-7);
        }
    }

    #[test]
    fn diag_gauss_logvar_is_clamped() {
        // Synthetic params whose logvar half is wildly outside
        // [-30, 20]; both ends of the clamp must land.
        let l = 2;

        // Way above the +20 ceiling.
        let high = params_with_logvar(l, 1e6);
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&high, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v <= LOGVAR_CLAMP_MAX as f32 + 1e-4,
                "logvar not clamped to <= {LOGVAR_CLAMP_MAX}: got {v}"
            );
        }

        // Way below the -30 floor.
        let low = params_with_logvar(l, -1e6);
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&low, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v >= LOGVAR_CLAMP_MIN as f32 - 1e-4,
                "logvar not clamped to >= {LOGVAR_CLAMP_MIN}: got {v}"
            );
        }
    }

    #[test]
    fn diag_gauss_sample_with_seed_is_deterministic() {
        // Synthetic, controlled distribution: mean=0, logvar=0 ⇒ std=1
        // ⇒ sample == eps, which is deterministic in our PRNG given a
        // fixed seed.
        let l = 4;
        let zeros = vec![0.0f32; 2 * l * 3 * 3];
        let params =
            Tensor::from_storage(TensorStorage::cpu(zeros), vec![1, 2 * l, 3, 3], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        let s1 = dist.sample_with_seed(42).unwrap();
        let s2 = dist.sample_with_seed(42).unwrap();
        let s3 = dist.sample_with_seed(43).unwrap();
        assert_eq!(s1.shape(), &[1, l, 3, 3]);
        for (a, b) in s1.data().unwrap().iter().zip(s2.data().unwrap().iter()) {
            assert!(
                (a - b).abs() < 1e-7,
                "sample_with_seed(42) not deterministic: {a} vs {b}"
            );
        }
        // Different seed should produce different output (with extreme
        // probability — Box-Muller from a different xorshift state).
        let differ = s1
            .data()
            .unwrap()
            .iter()
            .zip(s3.data().unwrap().iter())
            .any(|(a, c)| (a - c).abs() > 1e-6);
        assert!(
            differ,
            "sample_with_seed(42) and sample_with_seed(43) produced identical output"
        );
    }

    #[test]
    fn vae_encoder_round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeEncoder::<f32>::new(cfg).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let image = constant_image(0.01);
        let a = src.forward(&image).unwrap();
        let b = dst.forward(&image).unwrap();
        for (p, q) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((p - q).abs() < 1e-5, "round-trip differs: {p} vs {q}");
        }
    }

    #[test]
    fn encode_with_scaling_applies_scaling_factor() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let image = constant_image(0.05);
        // Sample via the high-level API, then independently compute
        // sample * scaling_factor and compare.
        let scaled = v.encode_with_scaling(&image, 99).unwrap();
        let dist = v.encode(&image).unwrap();
        let raw = dist.sample_with_seed(99).unwrap();
        for (s, r) in scaled
            .data()
            .unwrap()
            .iter()
            .zip(raw.data().unwrap().iter())
        {
            let expected = r * cfg.scaling_factor as f32;
            assert!(
                (s - expected).abs() < 1e-5,
                "encode_with_scaling didn't apply scaling_factor: got {s}, expected {expected}"
            );
        }
    }
}