Skip to main content

ferrotorch_diffusion/
vae_encoder.rs

1//! Stable-Diffusion VAE encoder composition.
2//!
3//! Mirrors `diffusers.AutoencoderKL.encode(image).latent_dist` for
4//! `runwayml/stable-diffusion-v1-5`:
5//!
6//! ```text
7//! image (pixel-space, [B, 3, H, W])
8//!   -> Encoder.conv_in
9//!   -> Encoder.down_blocks[0..N]   (last block has no Downsample2D)
10//!   -> Encoder.mid_block
11//!   -> Encoder.conv_norm_out -> SiLU -> Encoder.conv_out
12//!     (output: [B, 2 * latent_channels, H/8, W/8] — mean/logvar concat)
13//!   -> quant_conv                  ([2*L -> 2*L], 1x1)
14//!   -> DiagonalGaussianDistribution::from_parameters
15//! ```
16//!
17//! The encoder-side mirror of [`crate::vae::VaeDecoder`]. The
18//! `encode_with_scaling` helper composes
19//! `latent_dist.sample(seed) * scaling_factor`, matching
20//! `AutoencoderKL.encode(x).latent_dist.sample() * vae.config.scaling_factor`.
21//! The bare [`Module::forward`] returns the raw `[B, 2*L, h, w]` parameters
22//! tensor (no split, no sample, no scaling) so callers can swap in their
//! own sampling strategy (e.g. `.mode()` for a deterministic latent).
24//!
25//! ## REQ status (per `.design/ferrotorch-diffusion/vae_encoder.md`)
26//!
27//! | REQ | Status | Evidence |
28//! |---|---|---|
29//! | REQ-1 | SHIPPED | `Encoder<T>` at `vae_encoder.rs:51..79` and `Encoder::new` at `vae_encoder.rs:81..153`; consumer: `VaeEncoder::new` at `vae_encoder.rs:307` builds it; itself consumed by `safetensors_loader.rs:425` `load_vae_encoder` |
30//! | REQ-2 | SHIPPED | `VaeEncoder<T>` at `vae_encoder.rs:288..297` and `VaeEncoder::new` at `vae_encoder.rs:299..316`; consumer: `safetensors_loader.rs:425` `load_vae_encoder`; `gpu/vae_encoder.rs:317` `GpuVaeEncoder::from_module` consumes its `state_dict()` |
31//! | REQ-3 | SHIPPED | `VaeEncoder::encode` at `vae_encoder.rs:325..328` and `DiagonalGaussianDistribution::from_parameters` at `vae_encoder.rs:471..501`; consumer: `vae_encoder.rs:349` `encode_with_scaling` invokes it |
32//! | REQ-4 | SHIPPED | `DiagonalGaussianDistribution::sample_with_seed` at `vae_encoder.rs:527..539`, `mode` at `vae_encoder.rs:506..508`, `randn_with_seed` at `vae_encoder.rs:548..587`; consumer: `vae_encoder.rs:350` `encode_with_scaling` calls `dist.sample_with_seed(seed)` |
33//! | REQ-5 | SHIPPED | `encode_with_scaling` at `vae_encoder.rs:348..361`; consumer: re-exported via `lib.rs:148` `pub use vae_encoder::VaeEncoder` (boundary method IS the public API per goal.md S5 grandfathering) |
34//! | REQ-6 | SHIPPED | `Module<T>::forward` at `vae_encoder.rs:369..382`; consumer: `vae_encoder.rs:326` `encode` calls `self.forward(image)?` to produce the `[B, 2L, h, w]` parameters |
35//! | REQ-7 | SHIPPED | `Module<T>::load_state_dict` at `vae_encoder.rs:421..444`; consumer: `safetensors_loader.rs:394` `VaeEncoder::load_hf_state_dict` calls `self.load_state_dict(&remapped, strict)` after stripping the `vae.` prefix |
36
37use std::collections::HashMap;
38
39use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
40use ferrotorch_nn::module::{Module, StateDict};
41use ferrotorch_nn::parameter::Parameter;
42use ferrotorch_nn::{Conv2d, GroupNorm, SiLU};
43
44use crate::blocks::{DownEncoderBlock2D, UNetMidBlock2D};
45use crate::config::VaeDecoderConfig;
46
/// Type alias — the SD VAE encoder and decoder share their config shape
/// (mirrors `diffusers.AutoencoderKL.config`, which spans both halves).
///
/// Using a single config type also lets a downstream caller load
/// `vae/config.json` once via [`VaeDecoderConfig::from_file`] and pass it
/// to both `VaeEncoder::new` and `VaeDecoder::new`.
///
/// Note for encoder users: `out_channels` is the *image* channel count,
/// which the encoder consumes as its input channels (the field name is
/// decoder-centric).
pub type VaeEncoderConfig = VaeDecoderConfig;

/// Lower bound of the logvar clamp range (`DiagonalGaussianDistribution.__init__`).
///
/// Matches `torch.clamp(self.logvar, -30.0, 20.0)` in
/// `diffusers.models.autoencoders.vae.DiagonalGaussianDistribution`.
const LOGVAR_CLAMP_MIN: f64 = -30.0;
/// Upper bound of the logvar clamp range; pairs with [`LOGVAR_CLAMP_MIN`].
const LOGVAR_CLAMP_MAX: f64 = 20.0;
61
/// The bare `Encoder` half — matches `diffusers.models.autoencoders.vae.Encoder`.
///
/// Built by [`Encoder::new`]. Its `forward` maps an image batch
/// `[B, out_channels, H, W]` to the raw `conv_out` activation with
/// `2 * latent_channels` channels (mean and logvar, not yet split).
#[derive(Debug)]
pub struct Encoder<T: Float> {
    /// First conv: `out_channels -> block_out_channels[0]` (k=3, pad=1).
    ///
    /// `out_channels` in [`VaeDecoderConfig`] is the image-side channel
    /// count (the decoder *output* and the encoder *input* — they're the
    /// same value, the field name is decoder-centric).
    pub conv_in: Conv2d<T>,
    /// Down-blocks in *encoder order* — block 0 operates at the lowest
    /// channel count and highest spatial resolution. The deepest block
    /// (`down_blocks[N-1]`) has no downsample (it preserves spatial
    /// resolution into the mid-block).
    pub down_blocks: Vec<DownEncoderBlock2D<T>>,
    /// VAE mid-block at `block_out_channels[-1]` channels (same module
    /// as in the decoder).
    pub mid_block: UNetMidBlock2D<T>,
    /// Final GroupNorm before the output conv (operates on
    /// `block_out_channels[-1]` channels).
    pub conv_norm_out: GroupNorm<T>,
    /// Output activation (SiLU). Parameter-free: it is not listed by
    /// `parameters()` / `named_parameters()`.
    pub conv_act: SiLU,
    /// Output conv: `block_out_channels[-1] -> 2 * latent_channels`
    /// (k=3, pad=1). The factor of 2 holds the concatenated mean/logvar
    /// produced by the diagonal Gaussian head.
    pub conv_out: Conv2d<T>,
    /// Frozen copy of the config.
    pub config: VaeEncoderConfig,
    /// Training-mode flag: set by `Module::train` / `Module::eval`, read
    /// back by `Module::is_training`. Starts `false` (see `Encoder::new`).
    training: bool,
}
92
93impl<T: Float> Encoder<T> {
94    /// Build a randomly-initialized `Encoder`.
95    ///
96    /// # Errors
97    ///
98    /// Returns [`FerrotorchError::InvalidArgument`] for any invalid
99    /// config field (forwarded from [`VaeDecoderConfig::validate`]). In
100    /// particular `block_out_channels` must be non-empty — the index
101    /// into `[0]` / `[N-1]` below is preceded by `cfg.validate()?`
102    /// which checks exactly that.
103    pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
104        cfg.validate()?;
105        let groups = cfg.norm_num_groups;
106        let resnet_eps = 1e-6_f64;
107        let bottom_channels = cfg.block_out_channels[0];
108        let top_channels =
109            *cfg.block_out_channels
110                .last()
111                .ok_or_else(|| FerrotorchError::InvalidArgument {
112                    message: "Encoder::new: block_out_channels is empty (should be \
113                              unreachable after validate)"
114                        .into(),
115                })?;
116
117        // conv_in: image (out_channels) -> bottom_channels
118        let conv_in = Conv2d::<T>::new(
119            cfg.out_channels,
120            bottom_channels,
121            (3, 3),
122            (1, 1),
123            (1, 1),
124            true,
125        )?;
126
127        // Down-blocks. Encoder uses `layers_per_block` resnets per block
128        // (vs `layers_per_block + 1` on the decoder side).
129        let num_blocks = cfg.block_out_channels.len();
130        let resnets = cfg.layers_per_block;
131        let mut down_blocks = Vec::with_capacity(num_blocks);
132        let mut prev_out = bottom_channels;
133        for (i, &c) in cfg.block_out_channels.iter().enumerate() {
134            let is_final = i == num_blocks - 1;
135            down_blocks.push(DownEncoderBlock2D::<T>::new(
136                prev_out, c, resnets, groups, resnet_eps, !is_final,
137            )?);
138            prev_out = c;
139        }
140
141        // Mid-block at the deepest channel count.
142        let mid_block = UNetMidBlock2D::<T>::new(top_channels, groups, resnet_eps)?;
143
144        let conv_norm_out = GroupNorm::<T>::new(groups, top_channels, resnet_eps, true)?;
145        // The `2 *` factor produces the concatenated mean/logvar tensor.
146        let conv_out = Conv2d::<T>::new(
147            top_channels,
148            2 * cfg.latent_channels,
149            (3, 3),
150            (1, 1),
151            (1, 1),
152            true,
153        )?;
154
155        Ok(Self {
156            conv_in,
157            down_blocks,
158            mid_block,
159            conv_norm_out,
160            conv_act: SiLU::new(),
161            conv_out,
162            config: cfg,
163            training: false,
164        })
165    }
166}
167
168impl<T: Float> Module<T> for Encoder<T> {
169    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
170        // Sanity check: [B, out_channels, H, W] (out_channels is the
171        // image channel count — naming is decoder-centric).
172        let cfg = &self.config;
173        if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
174            return Err(FerrotorchError::ShapeMismatch {
175                message: format!(
176                    "Encoder::forward: expected [B, {}, H, W], got {:?}",
177                    cfg.out_channels,
178                    input.shape()
179                ),
180            });
181        }
182        let mut h = self.conv_in.forward(input)?;
183        for d in &self.down_blocks {
184            h = d.forward(&h)?;
185        }
186        h = self.mid_block.forward(&h)?;
187        h = self.conv_norm_out.forward(&h)?;
188        h = self.conv_act.forward(&h)?;
189        self.conv_out.forward(&h)
190    }
191
192    fn parameters(&self) -> Vec<&Parameter<T>> {
193        let mut out = Vec::new();
194        out.extend(self.conv_in.parameters());
195        for b in &self.down_blocks {
196            out.extend(b.parameters());
197        }
198        out.extend(self.mid_block.parameters());
199        out.extend(self.conv_norm_out.parameters());
200        out.extend(self.conv_out.parameters());
201        out
202    }
203
204    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
205        let mut out = Vec::new();
206        out.extend(self.conv_in.parameters_mut());
207        for b in &mut self.down_blocks {
208            out.extend(b.parameters_mut());
209        }
210        out.extend(self.mid_block.parameters_mut());
211        out.extend(self.conv_norm_out.parameters_mut());
212        out.extend(self.conv_out.parameters_mut());
213        out
214    }
215
216    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
217        let mut out = Vec::new();
218        for (n, p) in self.conv_in.named_parameters() {
219            out.push((format!("conv_in.{n}"), p));
220        }
221        for (i, b) in self.down_blocks.iter().enumerate() {
222            for (n, p) in b.named_parameters() {
223                out.push((format!("down_blocks.{i}.{n}"), p));
224            }
225        }
226        for (n, p) in self.mid_block.named_parameters() {
227            out.push((format!("mid_block.{n}"), p));
228        }
229        for (n, p) in self.conv_norm_out.named_parameters() {
230            out.push((format!("conv_norm_out.{n}"), p));
231        }
232        for (n, p) in self.conv_out.named_parameters() {
233            out.push((format!("conv_out.{n}"), p));
234        }
235        out
236    }
237
238    fn train(&mut self) {
239        self.training = true;
240        for b in &mut self.down_blocks {
241            b.train();
242        }
243        self.mid_block.train();
244    }
245    fn eval(&mut self) {
246        self.training = false;
247        for b in &mut self.down_blocks {
248            b.eval();
249        }
250        self.mid_block.eval();
251    }
252    fn is_training(&self) -> bool {
253        self.training
254    }
255
256    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
257        let extract = |prefix: &str| -> StateDict<T> {
258            let p = format!("{prefix}.");
259            state
260                .iter()
261                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
262                .collect()
263        };
264
265        if strict {
266            for k in state.keys() {
267                let ok = k.starts_with("conv_in.")
268                    || k.starts_with("down_blocks.")
269                    || k.starts_with("mid_block.")
270                    || k.starts_with("conv_norm_out.")
271                    || k.starts_with("conv_out.");
272                if !ok {
273                    return Err(FerrotorchError::InvalidArgument {
274                        message: format!("unexpected key in Encoder state_dict: \"{k}\""),
275                    });
276                }
277            }
278        }
279
280        self.conv_in.load_state_dict(&extract("conv_in"), strict)?;
281        for (i, b) in self.down_blocks.iter_mut().enumerate() {
282            b.load_state_dict(&extract(&format!("down_blocks.{i}")), strict)?;
283        }
284        self.mid_block
285            .load_state_dict(&extract("mid_block"), strict)?;
286        self.conv_norm_out
287            .load_state_dict(&extract("conv_norm_out"), strict)?;
288        self.conv_out
289            .load_state_dict(&extract("conv_out"), strict)?;
290        Ok(())
291    }
292}
293
/// `AutoencoderKL`-style VAE encoder = [`Encoder`] + `quant_conv`.
///
/// The encoder produces `[B, 2*latent_channels, H/8, W/8]` after the
/// `Encoder` stack (the `/8` assumes SD's four down-blocks, i.e. three
/// downsamples). The 1x1 `quant_conv` then projects this through a
/// learned linear map. The split into (mean, logvar) and any sampling
/// happens in [`DiagonalGaussianDistribution`].
#[derive(Debug)]
pub struct VaeEncoder<T: Float> {
    /// The actual `Encoder` stack (state-dict prefix `encoder.`).
    pub encoder: Encoder<T>,
    /// 1x1 quant projection over `2 * latent_channels` channels
    /// (state-dict prefix `quant_conv.`).
    pub quant_conv: Conv2d<T>,
    /// Frozen config copy.
    pub config: VaeEncoderConfig,
    /// Training-mode flag: set by `Module::train` / `Module::eval`, read
    /// back by `Module::is_training`. Starts `false`.
    training: bool,
}
310
311impl<T: Float> VaeEncoder<T> {
312    /// Build a randomly-initialized `VaeEncoder`.
313    ///
314    /// # Errors
315    ///
316    /// Returns the underlying [`FerrotorchError`] on bad config dims.
317    pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self> {
318        cfg.validate()?;
319        let encoder = Encoder::<T>::new(cfg.clone())?;
320        let two_l = 2 * cfg.latent_channels;
321        let quant_conv = Conv2d::<T>::new(two_l, two_l, (1, 1), (1, 1), (0, 0), true)?;
322        Ok(Self {
323            encoder,
324            quant_conv,
325            config: cfg,
326            training: false,
327        })
328    }
329
330    /// Encode an image into a diagonal Gaussian distribution over
331    /// latent space. Matches `AutoencoderKL.encode(image).latent_dist`.
332    ///
333    /// # Errors
334    ///
335    /// Returns [`FerrotorchError::ShapeMismatch`] when the input is not
336    /// `[B, out_channels, H, W]`. Propagates downstream op errors.
337    pub fn encode(&self, image: &Tensor<T>) -> FerrotorchResult<DiagonalGaussianDistribution<T>> {
338        let params = self.forward(image)?;
339        DiagonalGaussianDistribution::from_parameters(&params, self.config.latent_channels)
340    }
341
342    /// Encode an image, sample from the latent distribution with a
343    /// deterministic seed, then multiply by `scaling_factor`. This
344    /// matches the canonical SD pipeline call:
345    ///
346    /// ```text
347    /// latent = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
348    /// ```
349    ///
350    /// The Box-Muller / xorshift noise generator runs on CPU and is
351    /// fully deterministic for a given `seed` — different from
352    /// PyTorch's CUDA RNG, so the produced latent will NOT be
353    /// numerically identical to a Python reference run; only
354    /// statistically equivalent.
355    ///
356    /// # Errors
357    ///
358    /// Returns the underlying [`FerrotorchError`] for shape or
359    /// arithmetic failures.
360    pub fn encode_with_scaling(&self, image: &Tensor<T>, seed: u64) -> FerrotorchResult<Tensor<T>> {
361        let dist = self.encode(image)?;
362        let sample = dist.sample_with_seed(seed)?;
363        let sf = T::from(self.config.scaling_factor).ok_or_else(|| {
364            FerrotorchError::InvalidArgument {
365                message: format!(
366                    "VaeEncoder::encode_with_scaling: cannot cast scaling_factor={} into Float",
367                    self.config.scaling_factor
368                ),
369            }
370        })?;
371        let sf_t = ferrotorch_core::scalar::<T>(sf)?;
372        ferrotorch_core::grad_fns::arithmetic::mul(&sample, &sf_t)
373    }
374}
375
376impl<T: Float> Module<T> for VaeEncoder<T> {
377    /// Forward returns the raw `[B, 2*latent_channels, h, w]` parameters
378    /// tensor (concatenated mean/logvar, post `quant_conv`). Callers
379    /// who want a latent should use [`Self::encode`] or
380    /// [`Self::encode_with_scaling`].
381    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
382        let cfg = &self.config;
383        if input.ndim() != 4 || input.shape()[1] != cfg.out_channels {
384            return Err(FerrotorchError::ShapeMismatch {
385                message: format!(
386                    "VaeEncoder::forward: expected [B, {}, H, W], got {:?}",
387                    cfg.out_channels,
388                    input.shape()
389                ),
390            });
391        }
392        let h = self.encoder.forward(input)?;
393        self.quant_conv.forward(&h)
394    }
395
396    fn parameters(&self) -> Vec<&Parameter<T>> {
397        let mut out = Vec::new();
398        out.extend(self.encoder.parameters());
399        out.extend(self.quant_conv.parameters());
400        out
401    }
402
403    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
404        let mut out = Vec::new();
405        out.extend(self.encoder.parameters_mut());
406        out.extend(self.quant_conv.parameters_mut());
407        out
408    }
409
410    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
411        let mut out = Vec::new();
412        for (n, p) in self.encoder.named_parameters() {
413            out.push((format!("encoder.{n}"), p));
414        }
415        for (n, p) in self.quant_conv.named_parameters() {
416            out.push((format!("quant_conv.{n}"), p));
417        }
418        out
419    }
420
421    fn train(&mut self) {
422        self.training = true;
423        self.encoder.train();
424    }
425    fn eval(&mut self) {
426        self.training = false;
427        self.encoder.eval();
428    }
429    fn is_training(&self) -> bool {
430        self.training
431    }
432
433    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
434        let extract = |prefix: &str| -> StateDict<T> {
435            let p = format!("{prefix}.");
436            state
437                .iter()
438                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
439                .collect()
440        };
441        if strict {
442            for k in state.keys() {
443                let ok = k.starts_with("encoder.") || k.starts_with("quant_conv.");
444                if !ok {
445                    return Err(FerrotorchError::InvalidArgument {
446                        message: format!("unexpected key in VaeEncoder state_dict: \"{k}\""),
447                    });
448                }
449            }
450        }
451        self.encoder.load_state_dict(&extract("encoder"), strict)?;
452        self.quant_conv
453            .load_state_dict(&extract("quant_conv"), strict)?;
454        let _: HashMap<String, Tensor<T>> = HashMap::new(); // keep HashMap import alive
455        Ok(())
456    }
457}
458
/// Diagonal Gaussian over latent space — the same parameterization
/// `diffusers.models.autoencoders.vae.DiagonalGaussianDistribution`
/// uses. Holds `mean` and `logvar` tensors (both `[B, L, h, w]`) split
/// from the encoder's concatenated parameters output.
///
/// When built through `from_parameters`, `logvar` is clamped to
/// `[-30, 20]`, matching the diffusers reference. The fields are public,
/// so a value assembled by hand skips that clamp.
#[derive(Debug)]
pub struct DiagonalGaussianDistribution<T: Float> {
    /// Mean of the diagonal Gaussian, `[B, L, h, w]`. Also the
    /// distribution's mode.
    pub mean: Tensor<T>,
    /// Log-variance of the diagonal Gaussian (clamped), `[B, L, h, w]`.
    pub logvar: Tensor<T>,
}
473
474impl<T: Float> DiagonalGaussianDistribution<T> {
475    /// Split the encoder's `[B, 2*L, h, w]` parameters tensor into
476    /// mean / logvar halves along the channel axis.
477    ///
478    /// # Errors
479    ///
480    /// Returns [`FerrotorchError::ShapeMismatch`] when the channel
481    /// dimension is not `2 * latent_channels`. Propagates downstream
482    /// errors from the chunk / clamp ops.
483    pub fn from_parameters(params: &Tensor<T>, latent_channels: usize) -> FerrotorchResult<Self> {
484        if params.ndim() != 4 || params.shape()[1] != 2 * latent_channels {
485            return Err(FerrotorchError::ShapeMismatch {
486                message: format!(
487                    "DiagonalGaussianDistribution::from_parameters: expected [B, {}, H, W], \
488                     got {:?}",
489                    2 * latent_channels,
490                    params.shape()
491                ),
492            });
493        }
494        // chunk(2, dim=1): two [B, L, h, w] tensors — [mean, logvar].
495        let parts = params.chunk(2, 1)?;
496        let mean = parts[0].clone();
497        let logvar_raw = parts[1].clone();
498        // Match diffusers' `torch.clamp(logvar, -30.0, 20.0)`.
499        let lo = T::from(LOGVAR_CLAMP_MIN).ok_or_else(|| FerrotorchError::InvalidArgument {
500            message: format!(
501                "DiagonalGaussianDistribution: cannot cast logvar clamp min {LOGVAR_CLAMP_MIN} \
502                 into Float"
503            ),
504        })?;
505        let hi = T::from(LOGVAR_CLAMP_MAX).ok_or_else(|| FerrotorchError::InvalidArgument {
506            message: format!(
507                "DiagonalGaussianDistribution: cannot cast logvar clamp max {LOGVAR_CLAMP_MAX} \
508                 into Float"
509            ),
510        })?;
511        let logvar = logvar_raw.clamp_t(lo, hi)?;
512        Ok(Self { mean, logvar })
513    }
514
515    /// Return the distribution mode — the deterministic latent if the
516    /// caller does not want to sample. Matches
517    /// `DiagonalGaussianDistribution.mode()` in diffusers.
518    pub fn mode(&self) -> &Tensor<T> {
519        &self.mean
520    }
521
522    /// Sample from the distribution with a deterministic seed.
523    ///
524    /// Computes `mean + std * eps` where `std = exp(0.5 * logvar)`
525    /// and `eps` is a fresh `N(0, 1)` tensor produced by Box-Muller
526    /// over a seeded xorshift64 PRNG (mirroring the CPU branch of
527    /// [`ferrotorch_core::randn`] but using `seed` as the state).
528    ///
529    /// This is fully deterministic for a given `seed` on a given host
530    /// — but the bitwise output will NOT match a CUDA PyTorch reference
531    /// run, because PyTorch's CUDA RNG uses Philox, not xorshift +
532    /// Box-Muller. Tests should compare statistical properties (mean,
533    /// std) or use `.mode()` for bit-exact reference matches.
534    ///
535    /// # Errors
536    ///
537    /// Returns the underlying [`FerrotorchError`] for shape or
538    /// arithmetic failures.
539    pub fn sample_with_seed(&self, seed: u64) -> FerrotorchResult<Tensor<T>> {
540        let eps = randn_with_seed::<T>(self.mean.shape(), seed)?;
541        // std = exp(0.5 * logvar)
542        let half = T::from(0.5f64).ok_or_else(|| FerrotorchError::InvalidArgument {
543            message: "DiagonalGaussianDistribution::sample_with_seed: cannot cast 0.5 into Float"
544                .into(),
545        })?;
546        let half_t = ferrotorch_core::scalar::<T>(half)?;
547        let scaled_logvar = ferrotorch_core::grad_fns::arithmetic::mul(&self.logvar, &half_t)?;
548        let std = ferrotorch_core::grad_fns::transcendental::exp(&scaled_logvar)?;
549        let noise = ferrotorch_core::grad_fns::arithmetic::mul(&std, &eps)?;
550        ferrotorch_core::grad_fns::arithmetic::add(&self.mean, &noise)
551    }
552}
553
554/// Box-Muller + xorshift64 N(0, 1) tensor generator seeded by `seed`.
555///
556/// Local to this module so the encoder doesn't pull in a fresh `rand`
557/// dependency. The algorithm mirrors the CPU branch of
558/// [`ferrotorch_core::randn`] exactly, but takes the seed as an input
559/// rather than deriving it from system time + thread id.
560fn randn_with_seed<T: Float>(shape: &[usize], seed: u64) -> FerrotorchResult<Tensor<T>> {
561    let numel: usize = shape.iter().product();
562    let mut data = Vec::with_capacity(numel);
563    let mut state = if seed == 0 {
564        0x0000_dead_beef_cafe
565    } else {
566        seed
567    };
568
569    let mut next_uniform = || -> f64 {
570        state ^= state << 13;
571        state ^= state >> 7;
572        state ^= state << 17;
573        ((state as f64) / (u64::MAX as f64)).max(1e-300)
574    };
575
576    let mut i = 0;
577    while i < numel {
578        let u1 = next_uniform();
579        let u2 = next_uniform();
580        let r = (-2.0 * u1.ln()).sqrt();
581        let theta = 2.0 * std::f64::consts::PI * u2;
582        data.push(
583            T::from(r * theta.cos()).ok_or_else(|| FerrotorchError::InvalidArgument {
584                message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
585            })?,
586        );
587        if i + 1 < numel {
588            data.push(T::from(r * theta.sin()).ok_or_else(|| {
589                FerrotorchError::InvalidArgument {
590                    message: "randn_with_seed: cannot cast Box-Muller output into Float".into(),
591                }
592            })?);
593        }
594        i += 2;
595    }
596
597    data.truncate(numel);
598    Tensor::from_storage(TensorStorage::cpu(data), shape.to_vec(), false)
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny config that exercises every architectural feature (mid-block
    /// attn, 4 down-blocks, channel-changing resnet shortcut) without
    /// making tests slow. Same shape as the decoder's `tiny_cfg`.
    fn tiny_cfg() -> VaeEncoderConfig {
        VaeDecoderConfig {
            out_channels: 3,
            latent_channels: 4,
            block_out_channels: vec![4, 8, 16, 16],
            layers_per_block: 1,
            norm_num_groups: 4,
            sample_size: 8,
            scaling_factor: 0.18215,
        }
    }

    /// `[1, 3, 8, 8]` image (matching `tiny_cfg`) filled with `value`.
    fn tiny_image(value: f32) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(vec![value; 3 * 8 * 8]),
            vec![1, 3, 8, 8],
            false,
        )
        .unwrap()
    }

    #[test]
    fn encoder_forward_shape() {
        let cfg = tiny_cfg();
        let e = Encoder::<f32>::new(cfg.clone()).unwrap();
        // image: [1, 3, 8, 8] -> after 3 downsamples => [1, 2*4=8, 1, 1].
        let y = e.forward(&tiny_image(0.01)).unwrap();
        // 8 -> 4 -> 2 -> 1 (3 downsamples, last block has no downsample).
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
        for &v in y.data().unwrap() {
            assert!(v.is_finite(), "encoder output non-finite: {v}");
        }
    }

    #[test]
    fn vae_encoder_forward_shape() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let y = v.forward(&tiny_image(0.01)).unwrap();
        assert_eq!(y.shape(), &[1, 2 * cfg.latent_channels, 1, 1]);
    }

    #[test]
    fn vae_encoder_rejects_wrong_channel_count() {
        let v = VaeEncoder::<f32>::new(tiny_cfg()).unwrap();
        // 4 input channels instead of the configured 3.
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; 4 * 8 * 8]),
            vec![1, 4, 8, 8],
            false,
        )
        .unwrap();
        assert!(matches!(
            v.forward(&x),
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
        assert!(matches!(
            v.encode(&x),
            Err(FerrotorchError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn vae_encoder_named_parameters_include_quant_conv() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg).unwrap();
        let names: Vec<String> = v.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "quant_conv.weight",
            "quant_conv.bias",
            "encoder.conv_in.weight",
            "encoder.down_blocks.0.resnets.0.norm1.weight",
            "encoder.mid_block.attentions.0.to_q.weight",
            "encoder.conv_norm_out.weight",
            "encoder.conv_out.bias",
        ] {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn diag_gauss_split_and_mode_shapes() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let dist = v.encode(&tiny_image(0.01)).unwrap();
        assert_eq!(dist.mean.shape(), &[1, cfg.latent_channels, 1, 1]);
        assert_eq!(dist.logvar.shape(), &[1, cfg.latent_channels, 1, 1]);
        // .mode() must hand back exactly the mean tensor.
        let mode = dist.mode();
        assert_eq!(mode.shape(), dist.mean.shape());
        for (a, b) in mode
            .data()
            .unwrap()
            .iter()
            .zip(dist.mean.data().unwrap().iter())
        {
            assert!((a - b).abs() < 1e-7);
        }
    }

    #[test]
    fn diag_gauss_logvar_is_clamped() {
        // Synthetic params: the mean half is 0.5 everywhere. The logvar half
        // has one value far above the +20 ceiling and one far below the -30
        // floor, so both ends of the clamp are exercised.
        let l = 2;
        let data = vec![0.5f32, 0.5, 1e6, -1e6];
        let params =
            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2 * l, 1, 1], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        let logvar = dist.logvar.data().unwrap();
        for &v in logvar.iter() {
            assert!(
                v <= LOGVAR_CLAMP_MAX as f32 + 1e-4,
                "logvar not clamped to <= {LOGVAR_CLAMP_MAX}: got {v}"
            );
            assert!(
                v >= LOGVAR_CLAMP_MIN as f32 - 1e-4,
                "logvar not clamped to >= {LOGVAR_CLAMP_MIN}: got {v}"
            );
        }
        assert!(
            logvar
                .iter()
                .any(|&v| (v - LOGVAR_CLAMP_MAX as f32).abs() < 1e-4),
            "+1e6 did not land on the {LOGVAR_CLAMP_MAX} ceiling"
        );
        assert!(
            logvar
                .iter()
                .any(|&v| (v - LOGVAR_CLAMP_MIN as f32).abs() < 1e-4),
            "-1e6 did not land on the {LOGVAR_CLAMP_MIN} floor"
        );
    }

    #[test]
    fn diag_gauss_sample_with_seed_is_deterministic() {
        // Synthetic, controlled distribution: mean=0, logvar=0 ⇒ std=1
        // ⇒ sample == eps, which is deterministic in our PRNG given a
        // fixed seed.
        let l = 4;
        let zeros = vec![0.0f32; 2 * l * 3 * 3];
        let params =
            Tensor::from_storage(TensorStorage::cpu(zeros), vec![1, 2 * l, 3, 3], false).unwrap();
        let dist = DiagonalGaussianDistribution::<f32>::from_parameters(&params, l).unwrap();
        let s1 = dist.sample_with_seed(42).unwrap();
        let s2 = dist.sample_with_seed(42).unwrap();
        let s3 = dist.sample_with_seed(43).unwrap();
        assert_eq!(s1.shape(), &[1, l, 3, 3]);
        for (a, b) in s1.data().unwrap().iter().zip(s2.data().unwrap().iter()) {
            assert!(
                (a - b).abs() < 1e-7,
                "sample_with_seed(42) not deterministic: {a} vs {b}"
            );
        }
        // A different seed should produce different output (with
        // overwhelming probability — Box-Muller from a different xorshift
        // state).
        let differ = s1
            .data()
            .unwrap()
            .iter()
            .zip(s3.data().unwrap().iter())
            .any(|(a, c)| (a - c).abs() > 1e-6);
        assert!(
            differ,
            "sample_with_seed(42) and sample_with_seed(43) produced identical output"
        );
    }

    #[test]
    fn randn_with_seed_zero_seed_and_odd_numel() {
        // Seed 0 is remapped to a fixed non-zero state, so the output must
        // be non-degenerate and reproducible. An odd element count must be
        // honoured exactly.
        let first = randn_with_seed::<f32>(&[5], 0).unwrap();
        let second = randn_with_seed::<f32>(&[5], 0).unwrap();
        assert_eq!(first.shape(), &[5usize]);
        let xs = first.data().unwrap();
        let ys = second.data().unwrap();
        assert_eq!(xs.len(), 5);
        assert!(xs.iter().all(|v| v.is_finite()), "non-finite randn sample");
        assert!(xs.iter().any(|&v| v != 0.0), "seed 0 degenerated to all zeros");
        for (x, y) in xs.iter().zip(ys.iter()) {
            assert_eq!(x, y, "randn_with_seed(seed = 0) not deterministic");
        }
    }

    #[test]
    fn vae_encoder_round_trip_state_dict() {
        let cfg = tiny_cfg();
        let src = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let sd = src.state_dict();
        let mut dst = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let x = tiny_image(0.01);
        let a = src.forward(&x).unwrap();
        let b = dst.forward(&x).unwrap();
        for (p, q) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((p - q).abs() < 1e-5, "round-trip differs: {p} vs {q}");
        }
    }

    #[test]
    fn encode_with_scaling_applies_scaling_factor() {
        let cfg = tiny_cfg();
        let v = VaeEncoder::<f32>::new(cfg.clone()).unwrap();
        let x = tiny_image(0.05);
        // Sample via the high-level API, then independently compute
        // sample * scaling_factor and compare.
        let scaled = v.encode_with_scaling(&x, 99).unwrap();
        let dist = v.encode(&x).unwrap();
        let raw = dist.sample_with_seed(99).unwrap();
        for (s, r) in scaled
            .data()
            .unwrap()
            .iter()
            .zip(raw.data().unwrap().iter())
        {
            let expected = r * cfg.scaling_factor as f32;
            assert!(
                (s - expected).abs() < 1e-5,
                "encode_with_scaling didn't apply scaling_factor: got {s}, expected {expected}"
            );
        }
    }
}