Skip to main content

ferrotorch_diffusion/
vae_encoder.rs

1//! Stable-Diffusion VAE encoder composition.
2//!
3//! Mirrors `diffusers.AutoencoderKL.encode(image).latent_dist` for
4//! `runwayml/stable-diffusion-v1-5`:
5//!
6//! ```text
7//! image (pixel-space, [B, 3, H, W])
8//!   -> Encoder.conv_in
9//!   -> Encoder.down_blocks[0..N]   (last block has no Downsample2D)
10//!   -> Encoder.mid_block
11//!   -> Encoder.conv_norm_out -> SiLU -> Encoder.conv_out
12//!     (output: [B, 2 * latent_channels, H/8, W/8] — mean/logvar concat)
13//!   -> quant_conv                  ([2*L -> 2*L], 1x1)
14//!   -> DiagonalGaussianDistribution::from_parameters
15//! ```
16//!
17//! The encoder-side mirror of [`crate::vae::VaeDecoder`]. The
18//! `encode_with_scaling` helper composes
19//! `latent_dist.sample(seed) * scaling_factor`, matching
20//! `AutoencoderKL.encode(x).latent_dist.sample() * vae.config.scaling_factor`.
21//! The bare [`Module::forward`] returns the raw `[B, 2*L, h, w]` parameters
22//! tensor (no split, no sample, no scaling) so callers can swap in their
23//! own sampling strategy (e.g. `.mode()` for deterministic decoding).
24
25use std::collections::HashMap;
26
27use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
28use ferrotorch_nn::module::{Module, StateDict};
29use ferrotorch_nn::parameter::Parameter;
30use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
31
32use crate::blocks::{DownEncoderBlock2D, UNetMidBlock2D};
33use crate::config::VaeDecoderConfig;
34
/// Type alias — the SD VAE encoder and decoder share their config shape
/// (mirrors `diffusers.AutoencoderKL.config`, which spans both halves).
///
/// Using a single config type also lets a downstream caller load
/// `vae/config.json` once via [`VaeDecoderConfig::from_file`] and pass it
/// to both `VaeEncoder::new` and `VaeDecoder::new`.
///
/// Field names are decoder-centric: on the encoder side `out_channels` is
/// the channel count of the *input* image.
pub type VaeEncoderConfig = VaeDecoderConfig;

/// Lower bound of the diffusers logvar clamp (`DiagonalGaussianDistribution.__init__`).
///
/// Together with `LOGVAR_CLAMP_MAX` this matches
/// `torch.clamp(self.logvar, -30.0, 20.0)` in
/// `diffusers.models.autoencoders.vae.DiagonalGaussianDistribution`.
const LOGVAR_CLAMP_MIN: f64 = -30.0;
/// Upper bound of the diffusers logvar clamp (see `LOGVAR_CLAMP_MIN`).
const LOGVAR_CLAMP_MAX: f64 = 20.0;
49
/// The bare `Encoder` half — matches `diffusers.models.autoencoders.vae.Encoder`.
///
/// Built by [`Encoder::new`]. Its [`Module::forward`] maps a
/// `[B, out_channels, H, W]` image batch to the raw `conv_out` tensor with
/// `2 * latent_channels` channels. `quant_conv` is *not* part of this
/// struct; it lives on [`VaeEncoder`].
#[derive(Debug)]
pub struct Encoder<T: Float> {
    /// First conv: `out_channels -> block_out_channels[0]` (k=3, pad=1).
    ///
    /// `out_channels` in [`VaeDecoderConfig`] is the image-side channel
    /// count (the decoder *output* and the encoder *input* — they're the
    /// same value, the field name is decoder-centric).
    pub conv_in: Conv2d<T>,
    /// Down-blocks in *encoder order* — block 0 operates at the lowest
    /// channel count and highest spatial resolution. The deepest block
    /// (`down_blocks[N-1]`) has no downsample (it preserves spatial
    /// resolution into the mid-block).
    pub down_blocks: Vec<DownEncoderBlock2D<T>>,
    /// VAE mid-block at `block_out_channels[-1]` channels (same module
    /// as in the decoder).
    pub mid_block: UNetMidBlock2D<T>,
    /// Final GroupNorm before the output conv (operates on
    /// `block_out_channels[-1]` channels).
    pub conv_norm_out: GroupNorm<T>,
    /// Output activation (SiLU).
    pub conv_act: SiLU,
    /// Output conv: `block_out_channels[-1] -> 2 * latent_channels`
    /// (k=3, pad=1). The factor of 2 holds the concatenated mean/logvar
    /// produced by the diagonal Gaussian head.
    pub conv_out: Conv2d<T>,
    /// Frozen copy of the config.
    pub config: VaeEncoderConfig,
    /// Train/eval flag reported by `is_training()`. It is set to `false`
    /// by `new` and toggled by `train()` / `eval()`.
    training: bool,
}
80
81impl<T: Float> Encoder<T> {
82    /// Build a randomly-initialized `Encoder`.
83    ///
84    /// # Errors
85    ///
86    /// Returns [`FerrotorchError::InvalidArgument`] for any invalid
87    /// config field (forwarded from [`VaeDecoderConfig::validate`]). In
88    /// particular `block_out_channels` must be non-empty — the index
89    /// into `[0]` / `[N-1]` below is preceded by `cfg.validate()?`
90    /// which checks exactly that.
91    pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
92        cfg.validate()?;
93        let groups = cfg.norm_num_groups;
94        let resnet_eps = 1e-6_f64;
95        let bottom_channels = cfg.block_out_channels[0];
96        let top_channels =
97            *cfg.block_out_channels
98                .last()
99                .ok_or_else(|| FerrotorchError::InvalidArgument {
100                    message: "Encoder::new: block_out_channels is empty (should be \
101                              unreachable after validate)"
102                        .into(),
103                })?;
104
105        // conv_in: image (out_channels) -> bottom_channels
106        let conv_in = Conv2d::<T>::new(
107            cfg.out_channels,
108            bottom_channels,
109            (3, 3),
110            (1, 1),
111            (1, 1),
112            true,
113        )?;
114
115        // Down-blocks. Encoder uses `layers_per_block` resnets per block
116        // (vs `layers_per_block + 1` on the decoder side).
117        let num_blocks = cfg.block_out_channels.len();
118        let resnets = cfg.layers_per_block;
119        let mut down_blocks = Vec::with_capacity(num_blocks);
120        let mut prev_out = bottom_channels;
121        for (i, &c) in cfg.block_out_channels.iter().enumerate() {
122            let is_final = i == num_blocks - 1;
123            down_blocks.push(DownEncoderBlock2D::<T>::new(
124                prev_out, c, resnets, groups, resnet_eps, !is_final,
125            )?);
126            prev_out = c;
127        }
128
129        // Mid-block at the deepest channel count.
130        let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
131
132        let conv_norm_out = GroupNorm::<T>::new(groups, top_channels, resnet_eps, true)?;
133        // The `2 *` factor produces the concatenated mean/logvar tensor.
134        let conv_out = Conv2d::<T>::new(
135            top_channels,
136            2 * cfg.latent_channels,
137            (3, 3),
138            (1, 1),
139            (1, 1),
140            true,
141        )?;
142
143        Ok(Self {
144            conv_in,
145            down_blocks,
146            mid_block,
147            conv_norm_out,
148            conv_act: SiLU::new(),
149            conv_out,
150            config: cfg,
151            training: false,
152        })
153    }
154}
155
156impl<T: Float> Module<T> for Encoder<T> {
157    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
158        // Sanity check: [B, out_channels, H, W] (out_channels is the
159        // image channel count — naming is decoder-centric).
160        let cfg = &self.config;
161        if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
162            return Err(FerrotorchError::ShapeMismatch {
163                message: format!(
164                    "Encoder::forward: expected [B, {}, H, W], got {:?}",
165                    cfg.out_channels,
166                    input.shape()
167                ),
168            });
169        }
170        let mut h = self.conv_in.forward(input)?;
171        for d in &self.down_blocks {
172            h = d.forward(&h)?;
173        }
174        h = self.mid_block.forward(&h)?;
175        h = self.conv_norm_out.forward(&h)?;
176        h = self.conv_act.forward(&h)?;
177        self.conv_out.forward(&h)
178    }
179
180    fn parameters(&self) -> Vec<&Parameter<T>> {
181        let mut out = Vec::new();
182        out.extend(self.conv_in.parameters());
183        for b in &self.down_blocks {
184            out.extend(b.parameters());
185        }
186        out.extend(self.mid_block.parameters());
187        out.extend(self.conv_norm_out.parameters());
188        out.extend(self.conv_out.parameters());
189        out
190    }
191
192    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
193        let mut out = Vec::new();
194        out.extend(self.conv_in.parameters_mut());
195        for b in &mut self.down_blocks {
196            out.extend(b.parameters_mut());
197        }
198        out.extend(self.mid_block.parameters_mut());
199        out.extend(self.conv_norm_out.parameters_mut());
200        out.extend(self.conv_out.parameters_mut());
201        out
202    }
203
204    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
205        let mut out = Vec::new();
206        for (n, p) in self.conv_in.named_parameters() {
207            out.push((format!("conv_in.{n}"), p));
208        }
209        for (i, b) in self.down_blocks.iter().enumerate() {
210            for (n, p) in b.named_parameters() {
211                out.push((format!("down_blocks.{i}.{n}"), p));
212            }
213        }
214        for (n, p) in self.mid_block.named_parameters() {
215            out.push((format!("mid_block.{n}"), p));
216        }
217        for (n, p) in self.conv_norm_out.named_parameters() {
218            out.push((format!("conv_norm_out.{n}"), p));
219        }
220        for (n, p) in self.conv_out.named_parameters() {
221            out.push((format!("conv_out.{n}"), p));
222        }
223        out
224    }
225
226    fn train(&mut self) {
227        self.training = true;
228        for b in &mut self.down_blocks {
229            b.train();
230        }
231        self.mid_block.train();
232    }
233    fn eval(&mut self) {
234        self.training = false;
235        for b in &mut self.down_blocks {
236            b.eval();
237        }
238        self.mid_block.eval();
239    }
240    fn is_training(&self) -> bool {
241        self.training
242    }
243
244    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
245        let extract = |prefix: &str| -> StateDict<T> {
246            let p = format!("{prefix}.");
247            state
248                .iter()
249                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
250                .collect()
251        };
252
253        if strict {
254            for k in state.keys() {
255                let ok = k.starts_with("conv_in.")
256                    || k.starts_with("down_blocks.")
257                    || k.starts_with("mid_block.")
258                    || k.starts_with("conv_norm_out.")
259                    || k.starts_with("conv_out.");
260                if !ok {
261                    return Err(FerrotorchError::InvalidArgument {
262                        message: format!("unexpected key in Encoder state_dict: \"{k}\""),
263                    });
264                }
265            }
266        }
267
268        self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
269        for (i, b) in self.down_blocks.iter_mut().enumerate() {
270            b.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
271        }
272        self.mid_block
273            .load_state_dict(&extract("mid_block"), strict)?;
274        self.conv_norm_out
275            .load_state_dict(&extract("conv_norm_out"), strict)?;
276        self.conv_out
277            .load_state_dict(&extract("conv_out"), strict)?;
278        Ok(())
279    }
280}
281
/// `AutoencoderKL`-style VAE encoder = [`Encoder`] + `quant_conv`.
///
/// The `Encoder` stack produces a `[B, 2*latent_channels, h, w]` tensor.
/// Each downsampling block presumably halves the spatial size, which gives
/// `H/8 x W/8` for SD-1.5's four blocks (TODO confirm for other configs).
/// The 1x1 `quant_conv` then projects this tensor through a learned linear
/// map. The split into (mean, logvar) and any sampling happen in
/// [`DiagonalGaussianDistribution`].
#[derive(Debug)]
pub struct VaeEncoder<T: Float> {
    /// The actual `Encoder` stack.
    pub encoder: Encoder<T>,
    /// 1x1 quant projection over `2 * latent_channels` channels.
    pub quant_conv: Conv2d<T>,
    /// Frozen config copy. `new` also hands a clone of it to `encoder`.
    pub config: VaeEncoderConfig,
    /// Train/eval flag reported by `is_training()`. It is set to `false`
    /// by `new` and toggled by `train()` / `eval()`.
    training: bool,
}
298
299impl<T: Float> VaeEncoder<T> {
300    /// Build a randomly-initialized `VaeEncoder`.
301    ///
302    /// # Errors
303    ///
304    /// Returns the underlying [`FerrotorchError`] on bad config dims.
305    pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
306        cfg.validate()?;
307        let encoder = Encoder::<T>::new(cfg.clone())?;
308        let two_l = 2 * cfg.latent_channels;
309        let quant_conv = Conv2d::<T>::new(two_l, two_l, (1, 1), (1, 1), (0, 0), true)?;
310        Ok(Self {
311            encoder,
312            quant_conv,
313            config: cfg,
314            training: false,
315        })
316    }
317
318    /// Encode an image into a diagonal Gaussian distribution over
319    /// latent space. Matches `AutoencoderKL.encode(image).latent_dist`.
320    ///
321    /// # Errors
322    ///
323    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
324    /// `[B, out_channels, H, W]`. Propagates downstream op errors.
325    pub fn encode(&self, image: &Tensor<T>) -> FerrotorchResult<DiagonalGaussianDistribution<T>> {
326        let params = self.forward(image)?;
327        DiagonalGaussianDistribution::from_parameters(&params, self.config.latent_channels)
328    }
329
330    /// Encode an image, sample from the latent distribution with a
331    /// deterministic seed, then multiply by `scaling_factor`. This
332    /// matches the canonical SD pipeline call:
333    ///
334    /// ```text
335    /// latent = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
336    /// ```
337    ///
338    /// The Box-Muller / xorshift noise generator runs on CPU and is
339    /// fully deterministic for a given `seed` — different from
340    /// PyTorch's CUDA RNG, so the produced latent will NOT be
341    /// numerically identical to a Python reference run; only
342    /// statistically equivalent.
343    ///
344    /// # Errors
345    ///
346    /// Returns the underlying [`FerrotorchError`] for shape or
347    /// arithmetic failures.
348    pub fn encode_with_scaling(&self, image: &Tensor<T>, seed: u64) -> FerrotorchResult<Tensor<T>> {
349        let dist = self.encode(image)?;
350        let sample = dist.sample_with_seed(seed)?;
351        let sf = T::from(self.config.scaling_factor).ok_or_else(|| {
352            FerrotorchError::InvalidArgument {
353                message: format!(
354                    "VaeEncoder::encode_with_scaling: cannot cast scaling_factor={} into Float",
355                    self.config.scaling_factor
356                ),
357            }
358        })?;
359        let sf_t = ferrotorch_core::scalar::<T>(sf)?;
360        ferrotorch_core::grad_fns::arithmetic::mul(&sample, &sf_t)
361    }
362}
363
364impl<T: Float> Module<T> for VaeEncoder<T> {
365    /// Forward returns the raw `[B, 2*latent_channels, h, w]` parameters
366    /// tensor (concatenated mean/logvar, post `quant_conv`). Callers
367    /// who want a latent should use [`Self::encode`] or
368    /// [`Self::encode_with_scaling`].
369    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
370        let cfg = &self.config;
371        if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
372            return Err(FerrotorchError::ShapeMismatch {
373                message: format!(
374                    "VaeEncoder::forward: expected [B, {}, H, W], got {:?}",
375                    cfg.out_channels,
376                    input.shape()
377                ),
378            });
379        }
380        let h = self.encoder.forward(input)?;
381        self.quant_conv.forward(&h)
382    }
383
384    fn parameters(&self) -> Vec<&Parameter<T>> {
385        let mut out = Vec::new();
386        out.extend(self.encoder.parameters());
387        out.extend(self.quant_conv.parameters());
388        out
389    }
390
391    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
392        let mut out = Vec::new();
393        out.extend(self.encoder.parameters_mut());
394        out.extend(self.quant_conv.parameters_mut());
395        out
396    }
397
398    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
399        let mut out = Vec::new();
400        for (n, p) in self.encoder.named_parameters() {
401            out.push((format!("encoder.{n}"), p));
402        }
403        for (n, p) in self.quant_conv.named_parameters() {
404            out.push((format!("quant_conv.{n}"), p));
405        }
406        out
407    }
408
409    fn train(&mut self) {
410        self.training = true;
411        self.encoder.train();
412    }
413    fn eval(&mut self) {
414        self.training = false;
415        self.encoder.eval();
416    }
417    fn is_training(&self) -> bool {
418        self.training
419    }
420
421    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
422        let extract = |prefix: &str| -> StateDict<T> {
423            let p = format!("{prefix}.");
424            state
425                .iter()
426                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
427                .collect()
428        };
429        if strict {
430            for k in state.keys() {
431                let ok = k.starts_with("encoder.") || k.starts_with("quant_conv.");
432                if !ok {
433                    return Err(FerrotorchError::InvalidArgument {
434                        message: format!("unexpected key in VaeEncoder state_dict: \"{k}\""),
435                    });
436                }
437            }
438        }
439        self.encoder.load_state_dict(&extract("encoder"), strict)?;
440        self.quant_conv
441            .load_state_dict(&extract("quant_conv"), strict)?;
442        let _: HashMap<String, Tensor<T>> = HashMap::new(); // keep HashMap import alive
443        Ok(())
444    }
445}
446
/// Diagonal Gaussian over latent space — the same parameterization
/// `diffusers.models.autoencoders.vae.DiagonalGaussianDistribution`
/// uses. Holds `mean` and `logvar` tensors (both `[B, L, h, w]`) split
/// from the encoder's concatenated parameters output.
///
/// `logvar` is clamped to `[-30, 20]` by `from_parameters`, matching the
/// diffusers reference exactly. The fields are `pub`, however, so a value
/// built directly (not through `from_parameters`) carries no clamp
/// guarantee.
#[derive(Debug)]
pub struct DiagonalGaussianDistribution<T: Float> {
    /// Mean of the diagonal Gaussian, `[B, L, h, w]`. Also returned by `mode()`.
    pub mean: Tensor<T>,
    /// Log-variance of the diagonal Gaussian (clamped), `[B, L, h, w]`.
    pub logvar: Tensor<T>,
}
461
462impl<T: Float> DiagonalGaussianDistribution<T> {
463    /// Split the encoder's `[B, 2*L, h, w]` parameters tensor into
464    /// mean / logvar halves along the channel axis.
465    ///
466    /// # Errors
467    ///
468    /// Returns [`FerrotorchError::ShapeMismatch`] when the channel
469    /// dimension is not `2 * latent_channels`. Propagates downstream
470    /// errors from the chunk / clamp ops.
471    pub fn from_parameters(params: &Tensor<T>, latent_channels: usize) -> FerrotorchResult<Self> {
472        if params.ndim() != 4 || params.shape()[1] != 2 * latent_channels {
473            return Err(FerrotorchError::ShapeMismatch {
474                message: format!(
475                    "DiagonalGaussianDistribution::from_parameters: expected [B, {}, H, W], \
476                     got {:?}",
477                    2 * latent_channels,
478                    params.shape()
479                ),
480            });
481        }
482        // chunk(2, dim=1): two [B, L, h, w] tensors — [mean, logvar].
483        let parts = params.chunk(2, 1)?;
484        let mean = parts[0].clone();
485        let logvar_raw = parts[1].clone();
486        // Match diffusers' `torch.clamp(logvar, -30.0, 20.0)`.
487        let lo = T::from(LOGVAR_CLAMP_MIN).ok_or_else(|| FerrotorchError::InvalidArgument {
488            message: format!(
489                "DiagonalGaussianDistribution: cannot cast logvar clamp min {LOGVAR_CLAMP_MIN} \
490                 into Float"
491            ),
492        })?;
493        let hi = T::from(LOGVAR_CLAMP_MAX).ok_or_else(|| FerrotorchError::InvalidArgument {
494            message: format!(
495                "DiagonalGaussianDistribution: cannot cast logvar clamp max {LOGVAR_CLAMP_MAX} \
496                 into Float"
497            ),
498        })?;
499        let logvar = logvar_raw.clamp_t(lo, hi)?;
500        Ok(Self { mean, logvar })
501    }
502
503    /// Return the distribution mode — the deterministic latent if the
504    /// caller does not want to sample. Matches
505    /// `DiagonalGaussianDistribution.mode()` in diffusers.
506    pub fn mode(&self) -> &Tensor<T> {
507        &self.mean
508    }
509
510    /// Sample from the distribution with a deterministic seed.
511    ///
512    /// Computes `mean + std * eps` where `std = exp(0.5 * logvar)`
513    /// and `eps` is a fresh `N(0, 1)` tensor produced by Box-Muller
514    /// over a seeded xorshift64 PRNG (mirroring the CPU branch of
515    /// [`ferrotorch_core::randn`] but using `seed` as the state).
516    ///
517    /// This is fully deterministic for a given `seed` on a given host
518    /// — but the bitwise output will NOT match a CUDA PyTorch reference
519    /// run, because PyTorch's CUDA RNG uses Philox, not xorshift +
520    /// Box-Muller. Tests should compare statistical properties (mean,
521    /// std) or use `.mode()` for bit-exact reference matches.
522    ///
523    /// # Errors
524    ///
525    /// Returns the underlying [`FerrotorchError`] for shape or
526    /// arithmetic failures.
527    pub fn sample_with_seed(&self, seed: u64) -> FerrotorchResult<Tensor<T>> {
528        let eps = randn_with_seed::<T>(self.mean.shape(), seed)?;
529        // std = exp(0.5 * logvar)
530        let half = T::from(0.5f64).ok_or_else(|| FerrotorchError::InvalidArgument {
531            message: "DiagonalGaussianDistribution::sample_with_seed: cannot cast 0.5 into Float"
532                .into(),
533        })?;
534        let half_t = ferrotorch_core::scalar::<T>(half)?;
535        let scaled_logvar = ferrotorch_core::grad_fns::arithmetic::mul(&self.logvar, &half_t)?;
536        let std = ferrotorch_core::grad_fns::transcendental::exp(&scaled_logvar)?;
537        let noise = ferrotorch_core::grad_fns::arithmetic::mul(&std, &eps)?;
538        ferrotorch_core::grad_fns::arithmetic::add(&self.mean, &noise)
539    }
540}
541
542/// Box-Muller + xorshift64 N(0, 1) tensor generator seeded by `seed`.
543///
544/// Local to this module so the encoder doesn't pull in a fresh `rand`
545/// dependency. The algorithm mirrors the CPU branch of
546/// [`ferrotorch_core::randn`] exactly, but takes the seed as an input
547/// rather than deriving it from system time + thread id.
548fn randn_with_seed<T: Float>(shape: &[usize], seed: u64) -> FerrotorchResult<Tensor<T>> {
549    let numel: usize = shape.iter().product();
550    let mut data = Vec::with_capacity(numel);
551    let mut state = if seed == 0 {
552        0x0000_dead_beef_cafe
553    } else {
554        seed
555    };
556
557    let mut next_uniform = || -> f64 {
558        state ^= state << 13;
559        state ^= state >> 7;
560        state ^= state << 17;
561        ((state as f64) / (u64::MAX as f64)).max(1e-300)
562    };
563
564    let mut i = 0;
565    while i < numel {
566        let u1 = next_uniform();
567        let u2 = next_uniform();
568        let r = (-2.0 * u1.ln()).sqrt();
569        let theta = 2.0 * std::f64::consts::PI * u2;
570        data.push(
571            T::from(r * theta.cos()).ok_or_else(|| FerrotorchError::InvalidArgument {
572                message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
573            })?,
574        );
575        if i + 1 < numel {
576            data.push(T::from(r * theta.sin()).ok_or_else(|| {
577                FerrotorchError::InvalidArgument {
578                    message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
579                }
580            })?);
581        }
582        i += 2;
583    }
584
585    data.truncate(numel);
586    Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false)
587}
588
// Unit tests for the encoder stack, the `quant_conv` wrapper, and the
// diagonal Gaussian head. All of them run on a tiny CPU config.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny config that exercises every architectural feature (mid-block
    /// attn, 4 down-blocks, channel-changing resnet shortcut) without
    /// making tests slow. Same shape as the decoder's `tiny_cfg`.
    fn tiny_cfg() -> VaeEncoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// Bare `Encoder`: checks the output shape and that every value is finite.
    #[test]
    fn encoder_forward_shape() {
        let cfg = tiny_cfg();
        let e = Encoder::<f32>::new(cfg.clone()).unwrap();
        // image: [1, 3, 8, 8] -> after 3 downsamples => [1, 2*4=8, 1, 1].
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let y = e.forward(&x).unwrap();
        // 8 -> 4 -> 2 -> 1 (3 downsamples, last block has no downsample).
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "encoder output non-finite: {v}");
        }
    }

    /// `quant_conv` is a 1x1, channel-preserving conv, so `VaeEncoder`
    /// keeps the bare encoder's output shape.
    #[test]
    fn vae_encoder_forward_shape() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let y = v.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
    }

    /// Spot-checks parameter names against the diffusers-style
    /// `encoder.*` / `quant_conv.*` key layout.
    #[test]
    fn vae_encoder_named_parameters_include_quant_conv() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg).unwrap();
        let names: Vec<String> = v.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "quant_conv.weight",
            "quant_conv.bias",
            "encoder.conv_in.weight",
            "encoder.down_blocks.0.resnets.0.norm1.weight",
            "encoder.mid_block.attentions.0.to_q.weight",
            "encoder.conv_norm_out.weight",
            "encoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    /// `encode` splits the moments into `[B, L, h, w]` halves, and `mode()`
    /// is exactly the mean.
    #[test]
    fn diag_gauss_split_and_mode_shapes() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let dist = v.encode(&x).unwrap();
        assert_eq!(dist.mean.shape(), &[1, cfg.latent_channels, 1, 1]);
        assert_eq!(dist.logvar.shape(), &[1, cfg.latent_channels, 1, 1]);
        // .mode() must hand back exactly the mean tensor.
        let mode = dist.mode();
        assert_eq!(mode.shape(), dist.mean.shape());
        for (a, b) in mode
            .data()
            .unwrap()
            .iter()
            .zip(dist.mean.data().unwrap().iter())
        {
            assert!((a - b).abs() < 1e-7);
        }
    }

    /// Only the upper bound of the logvar clamp is exercised here.
    #[test]
    fn diag_gauss_logvar_is_clamped() {
        // Build a synthetic params tensor where the logvar half is
        // wildly out of range [-30, 20] and verify the clamp lands.
        let l = 2;
        let mut data = vec![0.5f32; l * 2]; // mean half: 0.5 everywhere
        for d in data.iter_mut().skip(l) {
            *d = 1e6; // logvar half: way above the +20 ceiling
        }
        let params =
            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        for &v in dist.logvar.data().unwrap() {
            assert!(
                v <= LOGVAR_CLAMP_MAX as f32 + 1e-4,
                "logvar not clamped to <= {LOGVAR_CLAMP_MAX}: got {v}"
            );
        }
    }

    /// Same seed -> same sample; a different seed -> a different sample.
    #[test]
    fn diag_gauss_sample_with_seed_is_deterministic() {
        // Synthetic, controlled distribution: mean=0, logvar=0 ⇒ std=1
        // ⇒ sample == eps, which is deterministic in our PRNG given a
        // fixed seed.
        let l = 4;
        let zeros = vec![0.0f32; 2 * l * 3 * 3];
        let params =
            Tensor::from_storage(TensorStorage::cpu(zeros), vec![1, 2 * l, 3, 3], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        let s1 = dist.sample_with_seed(42).unwrap();
        let s2 = dist.sample_with_seed(42).unwrap();
        let s3 = dist.sample_with_seed(43).unwrap();
        assert_eq!(s1.shape(), &[1, l, 3, 3]);
        for (a, b) in s1.data().unwrap().iter().zip(s2.data().unwrap().iter()) {
            assert!(
                (a - b).abs() < 1e-7,
                "sample_with_seed(42) not deterministic: {a} vs {b}"
            );
        }
        // Different seed should produce different output (with extreme
        // probability — Box-Muller from a different xorshift state).
        let mut differ = false;
        for (a, c) in s1.data().unwrap().iter().zip(s3.data().unwrap().iter()) {
            if (a - c).abs() > 1e-6 {
                differ = true;
                break;
            }
        }
        assert!(
            differ,
            "sample_with_seed(42) and sample_with_seed(43) produced identical output"
        );
    }

    /// `src` and `dst` are built independently, so matching outputs after
    /// a strict load suggest the state dict round-trips. This assumes two
    /// fresh instances do not already share weights (TODO confirm the
    /// init is randomized per instance).
    #[test]
    fn vae_encoder_round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }

    /// `encode_with_scaling(x, seed)` == `encode(x).sample_with_seed(seed) * scaling_factor`.
    #[test]
    fn encode_with_scaling_applies_scaling_factor() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.05f32; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap();
        // Sample via the high-level API, then independently compute
        // sample * scaling_factor and compare.
        let scaled = v.encode_with_scaling(&x, 99).unwrap();
        let dist = v.encode(&x).unwrap();
        let raw = dist.sample_with_seed(99).unwrap();
        for (s, r) in scaled
            .data()
            .unwrap()
            .iter()
            .zip(raw.data().unwrap().iter())
        {
            let expected = r * cfg.scaling_factor as f32;
            assert!(
                (s - expected).abs() < 1e-5,
                "encode_with_scaling didn't apply scaling_factor: got {s}, expected {expected}"
            );
        }
    }
}