Module patch_embed

Patch embedding for Vision Transformers.

Converts a CHW image into a sequence of patch tokens by applying a strided Conv2D with kernel_size == stride == patch_size, so each non-overlapping patch_size x patch_size patch becomes one token. Also provides 2-D sinusoidal and learnable positional encodings.
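To illustrate what a Conv2D with kernel_size == stride == patch_size computes, here is a minimal, dependency-free sketch. It is not the `PatchEmbed` API. The function name `patch_embed_sketch`, the row-major weight layout `[embed_dim, C, p, p]`, the raster patch order and the absence of padding are all assumptions made for this example.

```rust
/// Hypothetical sketch (not this crate's API): a Conv2D with
/// kernel_size == stride == p and no padding, written out explicitly.
///
/// image:  CHW, length c * h * w, row-major
/// weight: [embed_dim, c, p, p], row-major (assumed layout)
/// bias:   [embed_dim]
/// returns (n_patches, tokens) where tokens is [n_patches, embed_dim],
/// row-major, with patches in raster (row-by-row) order.
fn patch_embed_sketch(
    image: &[f32],
    (c, h, w): (usize, usize, usize),
    weight: &[f32],
    bias: &[f32],
    p: usize,
    embed_dim: usize,
) -> (usize, Vec<f32>) {
    assert_eq!(image.len(), c * h * w);
    assert_eq!(weight.len(), embed_dim * c * p * p);
    assert_eq!(bias.len(), embed_dim);
    // With stride == kernel and no padding, the output grid is
    // floor(h / p) x floor(w / p); leftover rows/columns are dropped.
    let (gh, gw) = (h / p, w / p);
    let n_patches = gh * gw;
    let mut out = vec![0.0f32; n_patches * embed_dim];
    for py in 0..gh {
        for px in 0..gw {
            let patch = py * gw + px; // raster order
            for e in 0..embed_dim {
                // One output channel = dot product of the patch with filter e.
                let mut acc = bias[e];
                for ch in 0..c {
                    for ky in 0..p {
                        for kx in 0..p {
                            let (y, x) = (py * p + ky, px * p + kx);
                            let pix = image[(ch * h + y) * w + x];
                            let wt = weight[((e * c + ch) * p + ky) * p + kx];
                            acc += pix * wt;
                        }
                    }
                }
                out[patch * embed_dim + e] = acc;
            }
        }
    }
    (n_patches, out)
}

fn main() {
    // 1-channel 4x4 image holding 0..16, patch size 2, embed_dim 1.
    let image: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let weight = vec![1.0f32; 4]; // all-ones kernel: each token is its patch sum
    let bias = vec![0.0f32];
    let (n, tokens) = patch_embed_sketch(&image, (1, 4, 4), &weight, &bias, 2, 1);
    println!("{n} {tokens:?}"); // prints "4 [10.0, 18.0, 42.0, 50.0]"
}
```

Because the stride equals the kernel size, the windows never overlap, so the convolution is equivalent to flattening each patch and multiplying it by a single [embed_dim, C*p*p] matrix. This is why patch embedding is often described as a linear projection of flattened patches.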

Re-exports

pub use conv2d_patch::PatchEmbed;
pub use conv2d_patch::PatchEmbedConfig;
pub use conv2d_patch::PatchEmbedWeights;
pub use conv2d_patch::prepend_cls;
pub use pos_embed::LearnablePosEmbed;
pub use pos_embed::add_pos_embed;
pub use pos_embed::pos_2d_sincos;
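The re-exported helpers `prepend_cls` and `add_pos_embed` suggest the usual ViT token pipeline. The sketch below shows that pipeline on flat row-major buffers. Its semantics (a [CLS] row inserted at index 0, then an element-wise add of a positional table with one row per token, including [CLS]) are assumptions, and `prepend_cls_sketch` / `add_pos_embed_sketch` are hypothetical names, not this crate's signatures.

```rust
/// Hypothetical stand-in for `prepend_cls` (assumed semantics):
/// place a [CLS] token in row 0, turning [n, d] into [n + 1, d].
fn prepend_cls_sketch(tokens: &[f32], cls: &[f32]) -> Vec<f32> {
    let d = cls.len();
    assert!(d > 0 && tokens.len() % d == 0, "tokens must be [n, d]");
    let mut out = Vec::with_capacity(tokens.len() + d);
    out.extend_from_slice(cls);
    out.extend_from_slice(tokens);
    out
}

/// Hypothetical stand-in for `add_pos_embed` (assumed semantics):
/// element-wise, in-place add of a positional table with the same shape.
fn add_pos_embed_sketch(tokens: &mut [f32], pos: &[f32]) {
    assert_eq!(tokens.len(), pos.len(), "shape mismatch");
    for (t, p) in tokens.iter_mut().zip(pos) {
        *t += p;
    }
}

fn main() {
    let d = 2;
    let patches = vec![1.0, 1.0, 2.0, 2.0]; // [2, d]
    let cls = vec![0.0, 0.0]; // [1, d]
    let mut seq = prepend_cls_sketch(&patches, &cls); // [3, d]
    let pos = vec![0.5; 3 * d]; // one positional row per token, incl. [CLS]
    add_pos_embed_sketch(&mut seq, &pos);
    println!("{seq:?}"); // prints "[0.5, 0.5, 1.5, 1.5, 2.5, 2.5]"
}
```

Some ViT variants instead add positional encodings to the patch tokens first and give [CLS] its own embedding afterwards; which order this module uses is not stated here.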

Modules

conv2d_patch
Patch embedder: strided Conv2D producing [N_patches, embed_dim].
pos_embed
2-D positional encodings for Vision Transformers.
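For `pos_2d_sincos`, one common construction (used, for example, in MAE) splits the embedding dimension in half between the patch row and the patch column, and encodes each coordinate with sines and cosines at geometrically spaced frequencies. The following sketches that convention. The name `pos_2d_sincos_sketch`, the row-then-column channel order, the base 10000 and the requirement embed_dim % 4 == 0 are assumptions, not taken from this module, whose actual output may be ordered or scaled differently.

```rust
/// Hypothetical sketch of an MAE-style 2-D sine-cosine table (assumed
/// convention): channels [0, d/2) encode the row, [d/2, d) the column;
/// each half is [sin(pos * omega_k) for k..] followed by [cos(...) for k..].
/// Returns [grid_h * grid_w, embed_dim] row-major, patches in raster order.
fn pos_2d_sincos_sketch(grid_h: usize, grid_w: usize, embed_dim: usize) -> Vec<f32> {
    assert!(embed_dim % 4 == 0, "embed_dim must be divisible by 4");
    let quarter = embed_dim / 4;
    // Frequencies omega_k = 1 / 10000^(k / quarter), k = 0..quarter.
    let omega: Vec<f64> = (0..quarter)
        .map(|k| 1.0 / 10000f64.powf(k as f64 / quarter as f64))
        .collect();
    let mut out = Vec::with_capacity(grid_h * grid_w * embed_dim);
    for row in 0..grid_h {
        for col in 0..grid_w {
            // First half of the channels: row; second half: column.
            for &pos in &[row as f64, col as f64] {
                out.extend(omega.iter().map(|w| (pos * w).sin() as f32));
                out.extend(omega.iter().map(|w| (pos * w).cos() as f32));
            }
        }
    }
    out
}

fn main() {
    let table = pos_2d_sincos_sketch(2, 3, 8);
    assert_eq!(table.len(), 2 * 3 * 8);
    // Patch (0, 0): every sin term is 0 and every cos term is 1.
    println!("{:?}", &table[..8]); // prints "[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0]"
}
```

Unlike a learnable table, this one needs no training and can be regenerated for any grid size.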