pub struct VaeEncoder<T: Float> {
pub encoder: Encoder<T>,
pub quant_conv: Conv2d<T>,
pub config: VaeEncoderConfig,
/* private fields */
}
AutoencoderKL-style VAE encoder = Encoder + quant_conv.
The encoder produces [B, 2*latent_channels, H/8, W/8] after the
Encoder stack; the 1x1 quant_conv then projects this through a
learned linear map. The split into (mean, logvar) and any sampling
happens in DiagonalGaussianDistribution.
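The shape arithmetic described above can be sketched as a small standalone function. This is only an illustration, not part of the crate: the function name `encoder_output_shape` and the rule that rejects sizes not divisible by 8 are assumptions made here; the factor of 8 and the `2 * latent_channels` channel count come from the description.

```rust
/// Hypothetical helper (not a crate API): shape of the parameters tensor
/// after the Encoder stack and quant_conv, for an input of [B, C, H, W].
fn encoder_output_shape(
    batch: usize,
    height: usize,
    width: usize,
    latent_channels: usize,
) -> Option<[usize; 4]> {
    // Spatial downsampling factor taken from the H/8, W/8 in the description.
    const DOWNSAMPLE: usize = 8;
    // Assumption of this sketch: only sizes that divide evenly are accepted.
    if height % DOWNSAMPLE != 0 || width % DOWNSAMPLE != 0 {
        return None;
    }
    // Channels are doubled because mean and logvar are stored side by side.
    Some([batch, 2 * latent_channels, height / DOWNSAMPLE, width / DOWNSAMPLE])
}
```

For example, a 512x512 input with `latent_channels = 4` would give `[1, 8, 64, 64]` under these assumptions. The 1x1 quant_conv leaves this shape unchanged, since it maps `2 * latent_channels` channels to `2 * latent_channels` channels.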
Fields§
encoder: Encoder<T>
The actual Encoder stack.
quant_conv: Conv2d<T>
1x1 quant projection over 2 * latent_channels channels.
config: VaeEncoderConfig
Frozen config copy.
Implementations§
impl<T: Float> VaeEncoder<T>
pub fn load_hf_state_dict(
    &mut self,
    hf_state: &StateDict<T>,
    strict: bool,
) -> FerrotorchResult<DropReport>
Load a HuggingFace AutoencoderKL state dict into this module.
Accepts both:
- encoder.* / quant_conv.* (bare-VAE layout, the normalised form produced by the pin script for an encoder-only artifact)
- vae.encoder.* / vae.quant_conv.* (full SD pipeline checkpoint, e.g. an upstream runwayml/stable-diffusion-v1-5 model.safetensors)
Any other key (decoder, post_quant_conv, UNet, text encoder, …)
is recorded in the returned DropReport (or, in strict mode,
surfaces as FerrotorchError::InvalidArgument). This mirrors
VaeDecoder::load_hf_state_dict so the same full-checkpoint
file can be loaded twice — once for the encoder and once for the
decoder — each time dropping the other half.
§Errors
Forwards whatever each sub-module’s load_state_dict returns
(ShapeMismatch on a wrong-shape tensor, InvalidArgument in
strict mode when a required tensor is missing). Strict mode will
surface decoder.* / post_quant_conv.* / etc. as errors;
callers with a full VAE checkpoint must pass strict=false.
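The key handling described above can be sketched with plain string operations. This is a minimal illustration under stated assumptions, not the crate's implementation: `Route`, `route_key`, and `partition_keys` are hypothetical names, the error is a plain `String` rather than `FerrotorchError`, and the real `DropReport` type is replaced by a list of dropped keys.

```rust
/// Hypothetical classification of one checkpoint key (not a crate type).
#[derive(Debug, PartialEq)]
enum Route {
    /// Key belongs to the encoder; payload is the key relative to it.
    Encoder(String),
    /// Key belongs to quant_conv; payload is the key relative to it.
    QuantConv(String),
    /// Anything else: decoder, post_quant_conv, UNet, text encoder, ...
    Dropped,
}

fn route_key(key: &str) -> Route {
    // Full-pipeline checkpoints prefix VAE keys with "vae."; bare ones do not.
    let bare = key.strip_prefix("vae.").unwrap_or(key);
    if let Some(rest) = bare.strip_prefix("encoder.") {
        Route::Encoder(rest.to_string())
    } else if let Some(rest) = bare.strip_prefix("quant_conv.") {
        Route::QuantConv(rest.to_string())
    } else {
        Route::Dropped
    }
}

/// Split keys into (kept, dropped). In strict mode the first unexpected key
/// becomes an error instead of being recorded.
fn partition_keys<'a>(
    keys: &[&'a str],
    strict: bool,
) -> Result<(Vec<&'a str>, Vec<&'a str>), String> {
    let mut kept = Vec::new();
    let mut dropped = Vec::new();
    for &key in keys {
        match route_key(key) {
            Route::Dropped if strict => return Err(format!("unexpected key: {key}")),
            Route::Dropped => dropped.push(key),
            _ => kept.push(key),
        }
    }
    Ok((kept, dropped))
}
```

Note that `post_quant_conv.*` is not matched by the `quant_conv.` prefix test, so it lands in the dropped list, which is consistent with the behaviour described for a full VAE checkpoint loaded with strict=false.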
impl<T: Float> VaeEncoder<T>
pub fn new(cfg: VaeEncoderConfig) -> FerrotorchResult<Self>
Build a randomly-initialized VaeEncoder.
§Errors
Returns the underlying FerrotorchError on bad config dims.
pub fn encode(
    &self,
    image: &Tensor<T>,
) -> FerrotorchResult<DiagonalGaussianDistribution<T>>
Encode an image into a diagonal Gaussian distribution over
latent space. Matches AutoencoderKL.encode(image).latent_dist.
§Errors
Returns FerrotorchError::ShapeMismatch when the input is not
[B, out_channels, H, W]. Propagates downstream op errors.
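To make the (mean, logvar) split concrete, here is a small sketch on a flat slice. It assumes a single sample stored row-major as [2*C, H, W] with the mean channels first, matching the "concatenated mean/logvar" layout mentioned for forward; `split_mean_logvar` and `std_from_logvar` are hypothetical helpers, not methods of DiagonalGaussianDistribution.

```rust
/// Hypothetical helper: split one sample's parameters, laid out row-major as
/// [2*C, H, W], into (mean, logvar) halves along the channel axis.
fn split_mean_logvar(params: &[f32]) -> Option<(&[f32], &[f32])> {
    // An odd length cannot hold two equally sized halves.
    if params.len() % 2 != 0 {
        return None;
    }
    // With channels as the outermost axis, the first half is the mean
    // channels and the second half is the logvar channels.
    Some(params.split_at(params.len() / 2))
}

/// Standard deviation from log-variance: std = exp(0.5 * logvar).
fn std_from_logvar(logvar: &[f32]) -> Vec<f32> {
    logvar.iter().map(|lv| (0.5 * lv).exp()).collect()
}
```

For a batched tensor the same split would apply per sample, since the channel axis sits inside the batch axis.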
pub fn encode_with_scaling(
    &self,
    image: &Tensor<T>,
    seed: u64,
) -> FerrotorchResult<Tensor<T>>
Encode an image, sample from the latent distribution with a
deterministic seed, then multiply by scaling_factor. This
matches the canonical SD pipeline call:
    latent = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
The Box-Muller / xorshift noise generator runs on CPU and is
fully deterministic for a given seed. It differs from
PyTorch’s CUDA RNG, so the produced latent will NOT be
numerically identical to a Python reference run, only
statistically equivalent.
§Errors
Returns the underlying FerrotorchError for shape or
arithmetic failures.
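The seeded sampling step can be illustrated with a self-contained sketch. Everything here is an assumption for illustration: the xorshift64 shift triple (13, 7, 17), the zero-seed remapping, the use of f64, and the names `XorShift64` and `sample_scaled` are choices made for this example and may differ from the generator the crate actually uses.

```rust
/// Illustrative xorshift64 generator (Marsaglia's 13/7/17 triple).
struct XorShift64(u64);

impl XorShift64 {
    fn new(seed: u64) -> Self {
        // Zero is a fixed point of xorshift, so remap it to a nonzero constant.
        Self(if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed })
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Uniform value in (0, 1], built from the top 53 bits.
    fn next_unit(&mut self) -> f64 {
        ((self.next_u64() >> 11) + 1) as f64 / (1u64 << 53) as f64
    }

    /// One standard-normal draw via the Box-Muller transform.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_unit(); // strictly positive, so ln(u1) is finite
        let u2 = self.next_unit();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}

/// Hypothetical helper: z = (mean + exp(0.5 * logvar) * eps) * scaling_factor,
/// with eps drawn from the seeded generator above.
fn sample_scaled(mean: &[f64], logvar: &[f64], seed: u64, scaling_factor: f64) -> Vec<f64> {
    let mut rng = XorShift64::new(seed);
    mean.iter()
        .zip(logvar)
        .map(|(m, lv)| (m + (0.5 * lv).exp() * rng.next_normal()) * scaling_factor)
        .collect()
}
```

Calling `sample_scaled` twice with the same seed yields identical output, which is the property the description relies on. A value such as 0.18215 (the scaling factor used by Stable Diffusion v1 VAE configs) could be passed as `scaling_factor`; in the crate the factor would come from the config instead.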
Trait Implementations§
impl<T: Float> Module<T> for VaeEncoder<T>
fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
Forward returns the raw [B, 2*latent_channels, h, w] parameters
tensor (concatenated mean/logvar, post quant_conv). Callers
who want a latent should use Self::encode or
Self::encode_with_scaling.
fn parameters(&self) -> Vec<&Parameter<T>>
fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>
fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>
fn is_training(&self) -> bool
fn load_state_dict(
    &mut self,
    state: &StateDict<T>,
    strict: bool,
) -> FerrotorchResult<()>
fn to_device(&mut self, device: Device) -> Result<(), FerrotorchError>
fn state_dict(&self) -> HashMap<String, Tensor<T>>
fn buffers(&self) -> Vec<&Buffer<T>>
fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>>
fn named_buffers(&self) -> Vec<(String, &Buffer<T>)>
fn as_any(&self) -> Option<&(dyn Any + 'static)>
fn children(&self) -> Vec<&dyn Module<T>>
fn named_children(&self) -> Vec<(String, &dyn Module<T>)>
fn modules(&self) -> Vec<&dyn Module<T>>
where
    Self: Sized,
fn descendants_dyn(&self) -> Vec<&dyn Module<T>>
… self in depth-first order. Object-safe.
fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
where
    Self: Sized,
… ""; children paths are joined with ".".
fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)>
fn with_forward_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap this module in a HookedModule and register a forward hook.
Returns the wrapper paired with a HookHandle that can be used to
remove the hook later. The wrapper implements Module<T> itself, so
it slots into any place the original module did. Mirrors
torch.nn.Module.register_forward_hook.
fn with_forward_pre_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>) -> Result<Tensor<T>, FerrotorchError> + Send + Sync>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap this module in a HookedModule and register a forward
pre-hook. See Self::with_forward_hook. Mirrors
torch.nn.Module.register_forward_pre_hook.
fn with_backward_hook(
    self,
    hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap this module in a HookedModule and register a backward hook.
See Self::with_forward_hook. Mirrors
torch.nn.Module.register_backward_hook.
fn zero_grad(&self) -> Result<(), FerrotorchError>
… None.
fn requires_grad_(&mut self, requires_grad: bool)
Set requires_grad on every parameter (freeze / unfreeze the
module). Mirrors torch.nn.Module.requires_grad_.
fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>))
… torch.nn.Module.apply for the parameter case (true apply recurses
over all submodules; the recursive form requires &mut dyn Module
which conflicts with this trait’s &mut self borrow).
Auto Trait Implementations§
impl<T> !RefUnwindSafe for VaeEncoder<T>
impl<T> !UnwindSafe for VaeEncoder<T>
impl<T> Freeze for VaeEncoder<T>
impl<T> Send for VaeEncoder<T>
impl<T> Sync for VaeEncoder<T>
impl<T> Unpin for VaeEncoder<T>
impl<T> UnsafeUnpin for VaeEncoder<T>
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> DistributionExt for T
where
    T: ?Sized,
impl<T, U> Imply<T> for U
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.