Struct VaeEncoder 

pub struct VaeEncoder<T: Float> {
    pub encoder: Encoder<T>,
    pub quant_conv: Conv2d<T>,
    pub config: VaeEncoderConfig,
    /* private fields */
}

AutoencoderKL-style VAE encoder: an Encoder stack followed by a 1x1 quant_conv.

The Encoder stack maps an input image [B, C, H, W] to [B, 2*latent_channels, H/8, W/8]. The 1x1 quant_conv then applies a learned linear map across channels at each spatial position, which leaves the shape unchanged. Splitting the result into (mean, logvar), and any sampling, happens in DiagonalGaussianDistribution.
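The channel split can be sketched over a flat NCHW buffer. This is a minimal illustration, not the crate's DiagonalGaussianDistribution: the helper names split_moments and std_from_logvar are hypothetical, plain f32 slices stand in for Tensor<T>, and the logvar clamp to [-30, 20] is borrowed from the Hugging Face diffusers implementation on the assumption that this port behaves the same way.

```rust
/// Hypothetical helper: split a flat NCHW buffer of shape [B, 2*C, h, w]
/// into (mean, logvar), each of shape [B, C, h, w].
fn split_moments(params: &[f32], b: usize, c: usize, h: usize, w: usize) -> (Vec<f32>, Vec<f32>) {
    let half = c * h * w; // elements in one half of one batch item
    assert_eq!(params.len(), b * 2 * half, "expected [B, 2*C, h, w]");
    let mut mean = Vec::with_capacity(b * half);
    let mut logvar = Vec::with_capacity(b * half);
    for item in params.chunks_exact(2 * half) {
        // Channels 0..C hold the mean, channels C..2C hold the log-variance.
        mean.extend_from_slice(&item[..half]);
        // Clamp range borrowed from diffusers (an assumption for this crate).
        logvar.extend(item[half..].iter().map(|&v| v.clamp(-30.0, 20.0)));
    }
    (mean, logvar)
}

/// Standard deviation from log-variance: std = exp(0.5 * logvar).
fn std_from_logvar(logvar: &[f32]) -> Vec<f32> {
    logvar.iter().map(|&lv| (0.5 * lv).exp()).collect()
}

fn main() {
    // B=2, C=1, h=1, w=2, so the raw tensor has 2 * 2 * 1 * 2 = 8 elements.
    let raw = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, -100.0, 50.0];
    let (mean, logvar) = split_moments(&raw, 2, 1, 1, 2);
    let std = std_from_logvar(&logvar);
    println!("mean={mean:?} logvar={logvar:?} std={std:?}");
}
```

In NCHW layout the mean and log-variance halves of each batch item are contiguous, so the split needs no strided gather.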

Fields

encoder: Encoder<T>

The Encoder stack.

quant_conv: Conv2d<T>

1x1 quant projection over 2 * latent_channels channels.

config: VaeEncoderConfig

Frozen copy of the configuration.

Implementations

impl<T: Float> VaeEncoder<T>

pub fn load_hf_state_dict(&mut self, hf_state: &StateDict<T>, strict: bool) -> FerrotorchResult<DropReport>

Load a HuggingFace AutoencoderKL state dict into this module.

Accepts both:

  • encoder.* / quant_conv.* (bare-VAE layout, the normalised form produced by the pin script for an encoder-only artifact)
  • vae.encoder.* / vae.quant_conv.* (full SD pipeline checkpoint, e.g. an upstream runwayml/stable-diffusion-v1-5 model.safetensors)

Any other key (decoder, post_quant_conv, UNet, text encoder, …) is recorded in the returned DropReport or, in strict mode, surfaces as FerrotorchError::InvalidArgument. This mirrors VaeDecoder::load_hf_state_dict, so the same full-checkpoint file can be loaded twice (once for the encoder and once for the decoder), each time dropping the other half.

Errors

Forwards whatever each sub-module's load_state_dict returns (ShapeMismatch on a wrong-shape tensor, InvalidArgument in strict mode when a required tensor is missing). Strict mode also reports decoder.*, post_quant_conv.*, and other foreign keys as errors, so callers loading a full VAE checkpoint must pass strict = false.
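The key routing described above can be sketched with plain string handling. This is an illustration under assumptions: normalize_key and partition_keys are hypothetical helpers, a BTreeMap of generic values stands in for StateDict<T>, a Vec of key names stands in for DropReport, and a String stands in for FerrotorchError::InvalidArgument.

```rust
use std::collections::BTreeMap;

/// Hypothetical helper: map a checkpoint key onto this module's namespace.
/// Accepts encoder.* / quant_conv.* and the same keys under a vae. prefix.
fn normalize_key(key: &str) -> Option<&str> {
    let bare = key.strip_prefix("vae.").unwrap_or(key);
    if bare.starts_with("encoder.") || bare.starts_with("quant_conv.") {
        Some(bare)
    } else {
        None // decoder.*, post_quant_conv.*, UNet, text encoder, ...
    }
}

/// Split a checkpoint into (kept, dropped). In strict mode any dropped key is an error.
fn partition_keys<V: Clone>(
    state: &BTreeMap<String, V>,
    strict: bool,
) -> Result<(BTreeMap<String, V>, Vec<String>), String> {
    let mut kept = BTreeMap::new();
    let mut dropped = Vec::new();
    for (key, value) in state {
        match normalize_key(key) {
            Some(local) => {
                kept.insert(local.to_string(), value.clone());
            }
            None => dropped.push(key.clone()),
        }
    }
    if strict && !dropped.is_empty() {
        return Err(format!("unexpected keys: {dropped:?}"));
    }
    Ok((kept, dropped))
}

fn main() {
    let mut ckpt = BTreeMap::new();
    ckpt.insert("vae.encoder.conv_in.weight".to_string(), 1);
    ckpt.insert("vae.quant_conv.bias".to_string(), 2);
    ckpt.insert("vae.decoder.conv_in.weight".to_string(), 3);
    let (kept, dropped) = partition_keys(&ckpt, false).unwrap();
    println!("kept: {:?}", kept.keys().collect::<Vec<_>>());
    println!("dropped: {dropped:?}");
    // Strict mode rejects the decoder key.
    assert!(partition_keys(&ckpt, true).is_err());
}
```

Stripping an optional vae. prefix first lets one code path serve both layouts. post_quant_conv.* is rejected because the test is a prefix match on quant_conv., not a substring search.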

impl<T: Float> VaeEncoder<T>

pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self>

Build a randomly initialized VaeEncoder.

Errors

Returns the underlying FerrotorchError if the config dimensions are invalid.

pub fn encode(&self, image: &Tensor<T>) -> FerrotorchResult<DiagonalGaussianDistribution<T>>

Encode an image into a diagonal Gaussian distribution over latent space. Matches AutoencoderKL.encode(image).latent_dist.

Errors

Returns FerrotorchError::ShapeMismatch when the input is not [B, in_channels, H, W]. Propagates errors from downstream ops.
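A check with this contract might look like the following. It is a sketch only: check_image_shape is a hypothetical name, the expected image channel count is passed in explicitly (called in_channels here, an assumption about the config field), and a String stands in for FerrotorchError::ShapeMismatch.

```rust
/// Hypothetical check: the input must be rank 4, [B, in_channels, H, W].
/// Returns (B, H, W) on success.
fn check_image_shape(shape: &[usize], in_channels: usize) -> Result<(usize, usize, usize), String> {
    match shape {
        // Rank 4 with the expected channel count.
        &[b, c, h, w] if c == in_channels => Ok((b, h, w)),
        // Rank 4 but the wrong channel count.
        &[_, c, _, _] => Err(format!("ShapeMismatch: expected {in_channels} channels, got {c}")),
        // Any other rank.
        _ => Err(format!("ShapeMismatch: expected rank 4 [B, C, H, W], got {shape:?}")),
    }
}

fn main() {
    println!("{:?}", check_image_shape(&[1, 3, 512, 512], 3)); // Ok((1, 512, 512))
    println!("{:?}", check_image_shape(&[1, 4, 512, 512], 3)); // Err("ShapeMismatch: expected 3 channels, got 4")
    println!("{:?}", check_image_shape(&[3, 512, 512], 3)); // Err("ShapeMismatch: expected rank 4 [B, C, H, W], got [3, 512, 512]")
}
```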


pub fn encode_with_scaling(&self, image: &Tensor<T>, seed: u64) -> FerrotorchResult<Tensor<T>>

Encode an image, sample from the latent distribution with a deterministic seed, then multiply by scaling_factor. This matches the canonical SD pipeline call:

latent = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor

The Box-Muller / xorshift noise generator runs on CPU and is fully deterministic for a given seed. It differs from PyTorch's CUDA RNG, so the produced latent will NOT be numerically identical to a Python reference run, only statistically equivalent.
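The sampling step can be sketched as the usual reparameterization, latent = (mean + exp(0.5 * logvar) * eps) * scaling_factor with eps ~ N(0, 1), driven by a seeded xorshift generator and the Box-Muller transform. The xorshift64 variant (shifts 13/7/17), the names XorShift64 and sample_scaled, and the f32 slices standing in for Tensor<T> are all assumptions; the crate's exact generator is not documented here, so this sketch will not reproduce its outputs bit for bit. The 0.18215 in main is the scaling_factor of the Stable Diffusion v1.x VAE.

```rust
use std::f64::consts::PI;

/// Hypothetical xorshift64 generator (Marsaglia's 13/7/17 shift triple).
struct XorShift64 {
    state: u64,
}

impl XorShift64 {
    fn new(seed: u64) -> Self {
        // xorshift must never hold a zero state, so replace seed 0.
        Self { state: if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed } }
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Uniform sample in [0, 1) built from the top 53 bits.
    fn next_unit(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Standard normal sample via the Box-Muller transform.
    fn next_normal(&mut self) -> f64 {
        let u1 = 1.0 - self.next_unit(); // in (0, 1], so ln() stays finite
        let u2 = self.next_unit();
        (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
    }
}

/// latent = (mean + exp(0.5 * logvar) * eps) * scaling_factor.
fn sample_scaled(mean: &[f32], logvar: &[f32], seed: u64, scaling_factor: f32) -> Vec<f32> {
    assert_eq!(mean.len(), logvar.len());
    let mut rng = XorShift64::new(seed);
    mean.iter()
        .zip(logvar)
        .map(|(&m, &lv)| {
            let eps = rng.next_normal() as f32;
            (m + (0.5 * lv).exp() * eps) * scaling_factor
        })
        .collect()
}

fn main() {
    let mean = [0.0f32; 4];
    let logvar = [0.0f32; 4];
    let a = sample_scaled(&mean, &logvar, 42, 0.18215);
    let b = sample_scaled(&mean, &logvar, 42, 0.18215);
    assert_eq!(a, b); // same seed, same latent
    println!("{a:?}");
}
```

Reusing a seed reproduces the same latent, which is the determinism property the paragraph above describes. Matching a PyTorch run would additionally require reproducing torch's own generator, which this sketch does not attempt.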

Errors

Returns the underlying FerrotorchError for shape or arithmetic failures.

Trait Implementations

impl<T: Debug + Float> Debug for VaeEncoder<T>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<T: Float> Module<T> for VaeEncoder<T>


fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>

Forward returns the raw [B, 2*latent_channels, H/8, W/8] parameter tensor (mean and logvar concatenated along the channel dimension, after quant_conv). Callers who want a latent should use Self::encode or Self::encode_with_scaling.

fn parameters(&self) -> Vec<&Parameter<T>>

Iterate over all learnable parameters.

fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>

Iterate over all learnable parameters mutably.

fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>

Named parameters for state dict serialization.

fn train(&mut self)

Set training mode. Affects dropout, batchnorm, etc.

fn eval(&mut self)

Set evaluation mode.

fn is_training(&self) -> bool

Whether the module is in training mode.

fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()>

Load parameters from a state dict.

fn to_device(&mut self, device: Device) -> Result<(), FerrotorchError>

Move all parameters and buffers to a device.

fn state_dict(&self) -> HashMap<String, Tensor<T>>

Export parameters and buffers as a state dict (torch parity).

fn buffers(&self) -> Vec<&Buffer<T>>

Iterate over all non-trainable buffers (e.g. running mean / variance in BatchNorm). The default implementation returns an empty Vec; concrete modules with buffers override it.

fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>>

Mutable iteration over all buffers. The default implementation returns an empty Vec.

fn named_buffers(&self) -> Vec<(String, &Buffer<T>)>

Named buffers (dot-separated paths for nested modules). The default implementation returns an empty Vec.

fn as_any(&self) -> Option<&(dyn Any + 'static)>

Downcast hook for type-erased buffer-loader dispatch (#984).

fn children(&self) -> Vec<&dyn Module<T>>

Direct child modules. The default implementation returns an empty Vec (leaf module).

fn named_children(&self) -> Vec<(String, &dyn Module<T>)>

Direct child modules with their attribute names. The default implementation returns an empty Vec.

fn modules(&self) -> Vec<&dyn Module<T>>
where Self: Sized,

All modules in this subtree, depth-first (self first, then each child's descendants in order).

fn descendants_dyn(&self) -> Vec<&dyn Module<T>>

All strict descendants of self in depth-first order. Object-safe.

fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
where Self: Sized,

All modules in this subtree with dot-separated path names. The root is named ""; child paths are joined with ".".

fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)>

Strict descendants with dot-separated paths. Object-safe.

fn with_forward_hook(self, hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Sync + Send>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward hook. Returns the wrapper paired with a HookHandle that can be used to remove the hook later. The wrapper implements Module<T> itself, so it can be used anywhere the original module was. Mirrors torch.nn.Module.register_forward_hook.

fn with_forward_pre_hook(self, hook: Box<dyn Fn(&Tensor<T>) -> Result<Tensor<T>, FerrotorchError> + Sync + Send>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward pre-hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_forward_pre_hook.

fn with_backward_hook(self, hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Sync + Send>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a backward hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_backward_hook.

fn zero_grad(&self) -> Result<(), FerrotorchError>

Set the gradient of every parameter to None.

fn requires_grad_(&mut self, requires_grad: bool)

Toggle requires_grad on every parameter (freeze / unfreeze the module). Mirrors torch.nn.Module.requires_grad_.

fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>))

Apply a function to every parameter in this module. Mirrors torch.nn.Module.apply for the parameter case (the true apply recurses over all submodules; the recursive form requires &mut dyn Module, which conflicts with this trait's &mut self borrow).

Auto Trait Implementations

impl<T> Freeze for VaeEncoder<T>

impl<T> !RefUnwindSafe for VaeEncoder<T>

impl<T> Send for VaeEncoder<T>

impl<T> Sync for VaeEncoder<T>

impl<T> Unpin for VaeEncoder<T>

impl<T> UnsafeUnpin for VaeEncoder<T>

impl<T> !UnwindSafe for VaeEncoder<T>

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> ByRef<T> for T

fn by_ref(&self) -> &T

impl<T> DistributionExt for T
where T: ?Sized,

fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T
where Self: Distribution<T>,

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.

impl<T, U> Imply<T> for U
where T: ?Sized, U: ?Sized,