Struct BasicTransformerBlock 

pub struct BasicTransformerBlock<T: Float> {
    pub norm1: LayerNorm<T>,
    pub attn1: Attention<T>,
    pub norm2: LayerNorm<T>,
    pub attn2: Attention<T>,
    pub norm3: LayerNorm<T>,
    pub ff: FeedForward<T>,
    /* private fields */
}

Diffusers’ BasicTransformerBlock, configured the way SD-1.5’s UNet uses it. Each of the three sub-layers (self-attention, then cross-attention, then a GEGLU FeedForward) is preceded by its own LayerNorm (pre-LN) and wrapped in a residual connection.
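
The data flow described above can be sketched on a single token vector. This is an illustration only, not this crate's implementation: layer_norm here has no learned weight or bias, and the three sub-layers are toy stand-in closures chosen so the result is easy to follow.

```rust
/// LayerNorm over one token vector, without the learned weight and bias.
fn layer_norm(x: &[f32]) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter().map(|v| (v - mean) / (var + 1e-5).sqrt()).collect()
}

/// One pre-norm residual step: x + sublayer(norm(x)).
fn pre_norm_residual(x: &[f32], sublayer: &dyn Fn(&[f32]) -> Vec<f32>) -> Vec<f32> {
    let y = sublayer(&layer_norm(x));
    x.iter().zip(&y).map(|(a, b)| a + b).collect()
}

/// The three steps in block order, with toy stand-ins for the sub-layers.
fn block_forward(x: &[f32], context: &[f32]) -> Vec<f32> {
    let ctx_mean = context.iter().sum::<f32>() / context.len() as f32;
    // Stand-in for self-attention: identity.
    let self_attn = |h: &[f32]| h.to_vec();
    // Stand-in for cross-attention: broadcast the mean of the context.
    let cross_attn = |h: &[f32]| vec![ctx_mean; h.len()];
    // Stand-in for the feed-forward: scale by two.
    let feed_forward = |h: &[f32]| h.iter().map(|v| 2.0 * v).collect::<Vec<f32>>();

    let x = pre_norm_residual(x, &self_attn); // norm1 -> attn1 -> residual add
    let x = pre_norm_residual(&x, &cross_attn); // norm2 -> attn2 -> residual add
    pre_norm_residual(&x, &feed_forward) // norm3 -> ff -> residual add
}

fn main() {
    let out = block_forward(&[1.0, 3.0], &[4.0, 6.0]);
    println!("{:?}", out);
}
```

Note that only the cross-attention stand-in reads the context, which matches the split between attn1 and attn2 in the block.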

State-dict layout:

norm1.{weight,bias}   [dim], [dim]
attn1.<keys>          # self-attn (Attention with cross_attention_dim=None)
norm2.{weight,bias}   [dim], [dim]
attn2.<keys>          # cross-attn
norm3.{weight,bias}   [dim], [dim]
ff.<keys>             # FeedForward (GEGLU)
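
As a rough aid for checking a checkpoint against the layout above, the sketch below reports which top-level prefixes have no keys at all. The helper names (top_level, missing_prefixes) are hypothetical and not part of this crate; the keys under attn1, attn2 and ff are left unspecified here, as they are in the layout.

```rust
/// Top-level key prefixes from the layout above, in forward order.
const PREFIXES: [&str; 6] = ["norm1", "attn1", "norm2", "attn2", "norm3", "ff"];

/// First dot-separated component of a state-dict key.
fn top_level(key: &str) -> &str {
    key.split('.').next().unwrap_or(key)
}

/// Prefixes from PREFIXES that no key in `keys` falls under.
fn missing_prefixes(keys: &[&str]) -> Vec<&'static str> {
    PREFIXES
        .iter()
        .copied()
        .filter(|p| !keys.iter().any(|k| top_level(k) == *p))
        .collect()
}

fn main() {
    // Only the LayerNorm keys are spelled out in the layout, so only those appear here.
    let keys = [
        "norm1.weight", "norm1.bias",
        "norm2.weight", "norm2.bias",
        "norm3.weight", "norm3.bias",
    ];
    println!("missing: {:?}", missing_prefixes(&keys));
}
```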

Fields§

§norm1: LayerNorm<T>

LayerNorm before self-attn.

§attn1: Attention<T>

Self-attention.

§norm2: LayerNorm<T>

LayerNorm before cross-attn.

§attn2: Attention<T>

Cross-attention.

§norm3: LayerNorm<T>

LayerNorm before FF.

§ff: FeedForward<T>

GEGLU FeedForward.
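
For readers unfamiliar with GEGLU, the gating step can be sketched as follows: the input projection produces a vector twice the hidden width, the first half is the value and the second half is the gate, and the output is value * gelu(gate). This is a standalone illustration, not this crate's code. It assumes the tanh approximation of GELU, because the Rust standard library has no stable erf; the real layer may use the exact form.

```rust
/// tanh approximation of GELU (an assumption of this sketch).
fn gelu_tanh(x: f32) -> f32 {
    let c = (2.0_f32 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

/// GEGLU gating on an already-projected vector of even length:
/// split into (value, gate) halves and return value * gelu(gate).
fn geglu(projected: &[f32]) -> Vec<f32> {
    assert!(projected.len() % 2 == 0, "projection must have even length");
    let (value, gate) = projected.split_at(projected.len() / 2);
    value.iter().zip(gate).map(|(v, g)| v * gelu_tanh(*g)).collect()
}

fn main() {
    // Two value entries followed by two gate entries.
    let out = geglu(&[1.0, 2.0, 0.0, 10.0]);
    println!("{:?}", out);
}
```

A gate of zero closes the unit entirely, while a large positive gate passes the value through almost unchanged.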

Implementations§

impl<T: Float> BasicTransformerBlock<T>

pub fn new(dim: usize, heads: usize, dim_head: usize, cross_attention_dim: usize) -> FerrotorchResult<Self>

Build a randomly-initialized BasicTransformerBlock.

§Errors

Returns the underlying FerrotorchError for invalid dims.
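
To make the four arguments concrete, the sketch below derives them for an SD-1.5 UNet level from its channel count, assuming the published SD-1.5 configuration (8 attention heads, a 768-wide text context). BlockDims and sd15_dims are hypothetical helpers for illustration; which dimension combinations new itself rejects is not specified on this page.

```rust
/// The four dimensions, in the order `new` takes them.
#[derive(Debug, PartialEq)]
struct BlockDims {
    dim: usize,
    heads: usize,
    dim_head: usize,
    cross_attention_dim: usize,
}

/// Dimensions for an SD-1.5 UNet level with `channels` channels.
/// Assumes 8 heads and a 768-wide context; returns None if the
/// channel count cannot be split evenly across the heads.
fn sd15_dims(channels: usize) -> Option<BlockDims> {
    const HEADS: usize = 8;
    if channels == 0 || channels % HEADS != 0 {
        return None;
    }
    Some(BlockDims {
        dim: channels,
        heads: HEADS,
        dim_head: channels / HEADS,
        cross_attention_dim: 768,
    })
}

fn main() {
    for channels in [320, 640, 1280] {
        println!("{:?}", sd15_dims(channels));
    }
}
```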

pub fn forward_xattn(&self, x: &Tensor<T>, encoder_hidden_states: &Tensor<T>) -> FerrotorchResult<Tensor<T>>

Forward pass with encoder hidden states. Self-attn ignores encoder_hidden_states; cross-attn uses it as its key/value source.

§Errors

Returns FerrotorchError::ShapeMismatch on rank disagreement, underlying op errors otherwise.
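
A caller-side check before invoking forward_xattn might look like the sketch below. It assumes the Diffusers convention that both tensors are rank 3, laid out as [batch, tokens, features]; this page does not state the layout the crate expects, and check_xattn_shapes is a hypothetical helper that returns a plain String rather than a FerrotorchError.

```rust
/// Hypothetical pre-flight check on the two input shapes:
/// both must be rank 3 and agree on the batch dimension.
fn check_xattn_shapes(x: &[usize], ctx: &[usize]) -> Result<(), String> {
    if x.len() != 3 {
        return Err(format!("x must be rank 3, got rank {}", x.len()));
    }
    if ctx.len() != 3 {
        return Err(format!(
            "encoder_hidden_states must be rank 3, got rank {}",
            ctx.len()
        ));
    }
    if x[0] != ctx[0] {
        return Err(format!("batch mismatch: {} vs {}", x[0], ctx[0]));
    }
    Ok(())
}

fn main() {
    // Example shapes: 64x64 latent tokens at 320 channels, 77 text tokens at width 768.
    println!("{:?}", check_xattn_shapes(&[2, 4096, 320], &[2, 77, 768]));
    println!("{:?}", check_xattn_shapes(&[4096, 320], &[2, 77, 768]));
}
```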

Trait Implementations§

impl<T: Debug + Float> Debug for BasicTransformerBlock<T>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<T: Float> Module<T> for BasicTransformerBlock<T>

fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>

Forward pass. Takes input tensor, returns output tensor.

fn parameters(&self) -> Vec<&Parameter<T>>

Iterate over all learnable parameters.

fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>

Iterate over all learnable parameters mutably.

fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>

Named parameters for state dict serialization.

fn train(&mut self)

Set training mode. Affects dropout, batchnorm, etc.

fn eval(&mut self)

Set evaluation mode.

fn is_training(&self) -> bool

Whether the module is in training mode.

fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()>

Load parameters from a state dict.

fn to_device(&mut self, device: Device) -> Result<(), FerrotorchError>

Move all parameters and buffers to a device.

fn state_dict(&self) -> HashMap<String, Tensor<T>>

Export parameters and buffers as a state dict (torch parity).

fn buffers(&self) -> Vec<&Buffer<T>>

Iterate over all non-trainable buffers (e.g. running mean / variance in BatchNorm). Default returns empty; concrete modules with buffers override.

fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>>

Mutable iteration over all buffers. Default returns empty.

fn named_buffers(&self) -> Vec<(String, &Buffer<T>)>

Named buffers (dot-separated paths for nested modules). Default returns empty.

fn as_any(&self) -> Option<&(dyn Any + 'static)>

Downcast hook for type-erased buffer-loader dispatch. (#984)

fn children(&self) -> Vec<&dyn Module<T>>

Direct child modules. Default returns empty (leaf module).

fn named_children(&self) -> Vec<(String, &dyn Module<T>)>

Direct child modules with their attribute names. Default returns empty.

fn modules(&self) -> Vec<&dyn Module<T>>
where Self: Sized,

All modules in this subtree, depth-first (self first, then each child’s descendants in order).

fn descendants_dyn(&self) -> Vec<&dyn Module<T>>

All strict descendants of self in depth-first order. Object-safe.

fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
where Self: Sized,

All modules in this subtree with dot-separated path names. The root is named ""; child paths are joined with ".".
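
The naming rule can be illustrated with a standalone depth-first walk over a toy tree. Node and its child names are hypothetical (in particular "proj" is a made-up sub-module name); the sketch shows only how paths are formed, not how this crate stores its modules.

```rust
/// A toy module tree: a name plus child nodes.
struct Node {
    name: &'static str,
    children: Vec<Node>,
}

/// Depth-first list of dot-separated paths; the root's path is "".
fn named_modules(root: &Node) -> Vec<String> {
    fn walk(node: &Node, path: String, out: &mut Vec<String>) {
        out.push(path.clone());
        for child in &node.children {
            // Avoid a leading dot for direct children of the root.
            let child_path = if path.is_empty() {
                child.name.to_string()
            } else {
                format!("{}.{}", path, child.name)
            };
            walk(child, child_path, out);
        }
    }
    let mut out = Vec::new();
    walk(root, String::new(), &mut out);
    out
}

fn main() {
    let leaf = |name| Node { name, children: Vec::new() };
    let root = Node {
        name: "block",
        children: vec![
            leaf("norm1"),
            Node { name: "attn1", children: vec![leaf("proj")] },
            leaf("ff"),
        ],
    };
    println!("{:?}", named_modules(&root));
}
```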

fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)>

Strict descendants with dot-paths. Object-safe.

fn with_forward_hook(self, hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward hook. Returns the wrapper paired with a HookHandle that can be used to remove the hook later. The wrapper implements Module<T> itself, so it slots into any place the original module did. Mirrors torch.nn.Module.register_forward_hook.
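
The wrapper-plus-handle pattern described above can be sketched without the crate's types. Here Hooked and Handle are hypothetical stand-ins for HookedModule and HookHandle, tensors are replaced by f32 slices, and the shared slot behind an Arc<Mutex<..>> is one possible way to make removal work; the crate's actual mechanism may differ.

```rust
use std::sync::{Arc, Mutex};

/// Hook signature: called with (input, output) after each forward pass.
type Hook = Box<dyn Fn(&[f32], &[f32]) + Send + Sync>;

/// Wrapper that runs the inner function, then the hook if one is still installed.
struct Hooked<F> {
    inner: F,
    hook: Arc<Mutex<Option<Hook>>>,
}

/// Handle sharing the hook slot with the wrapper, so the hook can be removed later.
struct Handle(Arc<Mutex<Option<Hook>>>);

impl Handle {
    fn remove(&self) {
        *self.0.lock().unwrap() = None;
    }
}

impl<F: Fn(&[f32]) -> Vec<f32>> Hooked<F> {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        let out = (self.inner)(x);
        if let Some(hook) = self.hook.lock().unwrap().as_ref() {
            hook(x, &out);
        }
        out
    }
}

/// Consume `inner` and return the wrapper paired with its handle.
fn with_forward_hook<F: Fn(&[f32]) -> Vec<f32>>(inner: F, hook: Hook) -> (Hooked<F>, Handle) {
    let slot = Arc::new(Mutex::new(Some(hook)));
    (Hooked { inner, hook: Arc::clone(&slot) }, Handle(slot))
}

fn main() {
    let (module, handle) = with_forward_hook(
        |x: &[f32]| x.iter().map(|v| v + 1.0).collect::<Vec<f32>>(),
        Box::new(|input: &[f32], output: &[f32]| {
            println!("hook saw {:?} -> {:?}", input, output);
        }),
    );
    module.forward(&[1.0, 2.0]); // hook fires
    handle.remove();
    module.forward(&[3.0]); // hook no longer fires
}
```

Taking self by value, as the real method does, means the unwrapped module cannot be called by accident once the hook is attached.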

fn with_forward_pre_hook(self, hook: Box<dyn Fn(&Tensor<T>) -> Result<Tensor<T>, FerrotorchError> + Send + Sync>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward pre-hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_forward_pre_hook.

fn with_backward_hook(self, hook: Box<dyn Fn(&Tensor<T>, &Tensor<T>) + Send + Sync>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a backward hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_backward_hook.

fn zero_grad(&self) -> Result<(), FerrotorchError>

Set the gradient of every parameter to None.

fn requires_grad_(&mut self, requires_grad: bool)

Toggle requires_grad on every parameter (freeze / unfreeze the module). Mirrors torch.nn.Module.requires_grad_.

fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>))

Apply a function to every parameter in this module. Mirrors torch.nn.Module.apply for the parameter case (true apply recurses over all submodules; the recursive form requires &mut dyn Module which conflicts with this trait’s &mut self borrow).
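
The visitor shape of apply_to_parameters, and how a freeze toggle such as requires_grad_ can be built on top of it, is sketched below with toy types. Param and ToyModule are hypothetical; whether this crate implements requires_grad_ through apply_to_parameters is not stated on this page.

```rust
/// Toy parameter: some data plus a trainable flag.
struct Param {
    data: Vec<f32>,
    requires_grad: bool,
}

/// Toy module holding a flat list of parameters.
struct ToyModule {
    params: Vec<Param>,
}

impl ToyModule {
    /// Visit every parameter mutably with a caller-supplied closure.
    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Param)) {
        for p in self.params.iter_mut() {
            f(p);
        }
    }

    /// Freeze or unfreeze all parameters by reusing the visitor.
    fn requires_grad_(&mut self, requires_grad: bool) {
        self.apply_to_parameters(&mut |p: &mut Param| p.requires_grad = requires_grad);
    }
}

fn main() {
    let mut module = ToyModule {
        params: vec![
            Param { data: vec![1.0, 2.0], requires_grad: true },
            Param { data: vec![3.0], requires_grad: true },
        ],
    };
    module.requires_grad_(false); // freeze
    module.apply_to_parameters(&mut |p: &mut Param| {
        p.data.iter_mut().for_each(|v| *v *= 2.0);
    });
    for p in &module.params {
        println!("{:?} requires_grad={}", p.data, p.requires_grad);
    }
}
```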

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> ByRef<T> for T

fn by_ref(&self) -> &T

impl<T> DistributionExt for T
where T: ?Sized,

fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T
where Self: Distribution<T>,

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Imply<T> for U
where T: ?Sized, U: ?Sized,

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.