use crate::error::{Result, TrustformerError};
use crate::stacks::{EncoderStack, EncoderStackConfig};
use tensorlogic_ir::{EinsumGraph, EinsumNode};
/// Geometry of the patch-embedding stage of a Vision Transformer.
///
/// Images are treated as square: `image_size` is used for both sides when counting
/// patches (see [`PatchEmbeddingConfig::num_patches`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchEmbeddingConfig {
    /// Side length of the (square) input image; must be > 0.
    pub image_size: usize,
    /// Side length of each square patch; must be > 0 and divide `image_size`.
    pub patch_size: usize,
    /// Number of input channels; must be > 0.
    pub in_channels: usize,
    /// Width of each patch embedding; must be > 0.
    pub d_model: usize,
}

impl PatchEmbeddingConfig {
    /// Creates a validated patch-embedding configuration.
    ///
    /// # Errors
    ///
    /// Returns [`TrustformerError::CompilationError`] if any dimension is zero or
    /// `image_size` is not a multiple of `patch_size` (see [`Self::validate`]).
    pub fn new(
        image_size: usize,
        patch_size: usize,
        in_channels: usize,
        d_model: usize,
    ) -> Result<Self> {
        let config = Self {
            image_size,
            patch_size,
            in_channels,
            d_model,
        };
        // Single source of truth for the invariants; `validate` reports the same
        // messages, in the same order, that this constructor always has.
        config.validate()?;
        Ok(config)
    }

    /// Number of patches the image is split into: `(image_size / patch_size)^2`.
    pub fn num_patches(&self) -> usize {
        let patches_per_side = self.image_size / self.patch_size;
        patches_per_side * patches_per_side
    }

    /// Length of one flattened patch: `patch_size * patch_size * in_channels`.
    pub fn patch_dim(&self) -> usize {
        self.patch_size * self.patch_size * self.in_channels
    }

    /// Checks every invariant enforced by [`Self::new`].
    ///
    /// The fields are public and may be edited after construction, so all of them are
    /// re-checked here. Note that `0.is_multiple_of(0)` is `true`: without the explicit
    /// zero checks, `image_size == patch_size == 0` would pass and
    /// [`Self::num_patches`] would then divide by zero.
    ///
    /// # Errors
    ///
    /// Returns [`TrustformerError::CompilationError`] describing the first violated rule.
    pub fn validate(&self) -> Result<()> {
        if self.image_size == 0 {
            return Err(TrustformerError::CompilationError(
                "image_size must be > 0".into(),
            ));
        }
        if self.patch_size == 0 {
            return Err(TrustformerError::CompilationError(
                "patch_size must be > 0".into(),
            ));
        }
        if !self.image_size.is_multiple_of(self.patch_size) {
            return Err(TrustformerError::CompilationError(format!(
                "image_size ({}) must be divisible by patch_size ({})",
                self.image_size, self.patch_size
            )));
        }
        if self.in_channels == 0 {
            return Err(TrustformerError::CompilationError(
                "in_channels must be > 0".into(),
            ));
        }
        if self.d_model == 0 {
            return Err(TrustformerError::CompilationError(
                "d_model must be > 0".into(),
            ));
        }
        Ok(())
    }
}
/// Adds the patch-projection step of a Vision Transformer to an [`EinsumGraph`].
#[derive(Debug, Clone)]
pub struct PatchEmbedding {
    config: PatchEmbeddingConfig,
}

impl PatchEmbedding {
    /// Tensor id that [`Self::build_patch_embed_graph`] reads the patches from.
    const DEFAULT_PATCHES_TENSOR: usize = 0;
    /// Tensor id that [`Self::build_patch_embed_graph`] reads the projection weight from.
    const DEFAULT_WEIGHT_TENSOR: usize = 1;

    /// Wraps `config` after re-validating it.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`PatchEmbeddingConfig::validate`].
    pub fn new(config: PatchEmbeddingConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Adds the patch-projection node, reading its inputs from the fixed tensor ids
    /// `0` (patches) and `1` (projection weight).
    ///
    /// The caller must have registered those two tensors at those ids. Use
    /// [`Self::build_patch_embed_graph_with_inputs`] when they live elsewhere.
    /// Returns the id of the newly added `patch_embeddings` tensor.
    ///
    /// # Errors
    ///
    /// Propagates any error from `EinsumGraph::add_node`.
    pub fn build_patch_embed_graph(&self, graph: &mut EinsumGraph) -> Result<usize> {
        self.build_patch_embed_graph_with_inputs(
            graph,
            Self::DEFAULT_PATCHES_TENSOR,
            Self::DEFAULT_WEIGHT_TENSOR,
        )
    }

    /// Adds an einsum `bnp,pd->bnd` node that projects tensor `patches` with tensor
    /// `weight`, and returns the id of the new `patch_embeddings` output tensor.
    ///
    /// Going by the subscripts, `patches` is presumably `(batch, num_patches, patch_dim)`
    /// and `weight` is `(patch_dim, d_model)`. Shapes are not checked here (TODO confirm
    /// where they are enforced).
    ///
    /// # Errors
    ///
    /// Propagates any error from `EinsumGraph::add_node`.
    pub fn build_patch_embed_graph_with_inputs(
        &self,
        graph: &mut EinsumGraph,
        patches: usize,
        weight: usize,
    ) -> Result<usize> {
        let output_tensor = graph.add_tensor("patch_embeddings");
        let node = EinsumNode::new("bnp,pd->bnd", vec![patches, weight], vec![output_tensor]);
        graph.add_node(node)?;
        Ok(output_tensor)
    }

    /// The validated configuration this embedding was built from.
    pub fn config(&self) -> &PatchEmbeddingConfig {
        &self.config
    }
}
/// Full Vision Transformer configuration: patch embedding, encoder stack and
/// classification head.
#[derive(Debug, Clone)]
pub struct VisionTransformerConfig {
    /// Patch-embedding geometry (image/patch size, channels, model width).
    pub patch_embed: PatchEmbeddingConfig,
    /// Encoder stack applied to the patch sequence.
    pub encoder: EncoderStackConfig,
    /// Number of output classes; must be > 0.
    pub num_classes: usize,
    /// Whether a class token is prepended, adding one position to [`Self::seq_length`].
    pub use_class_token: bool,
    /// Dropout probability associated with the classification head; must lie in `[0, 1]`.
    pub classifier_dropout: f64,
}

impl VisionTransformerConfig {
    /// Builds a configuration with a learned-position-encoding encoder, a class token
    /// enabled, and `classifier_dropout == 0.0`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`PatchEmbeddingConfig::new`] and `EncoderStackConfig::new`,
    /// and returns [`TrustformerError::CompilationError`] when `num_classes` is zero.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        image_size: usize,
        patch_size: usize,
        in_channels: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        n_layers: usize,
        num_classes: usize,
    ) -> Result<Self> {
        let patch_embed = PatchEmbeddingConfig::new(image_size, patch_size, in_channels, d_model)?;
        // One extra slot for the class token. Sizing for the class-token case keeps
        // `max_seq_len` an upper bound even if the token is later disabled with
        // `with_class_token(false)`.
        let max_seq_len = patch_embed.num_patches() + 1;
        let encoder = EncoderStackConfig::new(n_layers, d_model, n_heads, d_ff, max_seq_len)?
            .with_learned_position_encoding();
        Self::check_num_classes(num_classes)?;
        Ok(Self {
            patch_embed,
            encoder,
            num_classes,
            use_class_token: true,
            classifier_dropout: 0.0,
        })
    }

    /// Enables or disables the prepended class token.
    pub fn with_class_token(mut self, use_class_token: bool) -> Self {
        self.use_class_token = use_class_token;
        self
    }

    /// Sets only the classification-head dropout; the value is range-checked by
    /// [`Self::validate`].
    pub fn with_classifier_dropout(mut self, dropout: f64) -> Self {
        self.classifier_dropout = dropout;
        self
    }

    /// Switches the encoder to learned position encodings.
    ///
    /// [`Self::new`] already does this, so it is only needed for configs whose
    /// `encoder` was replaced afterwards.
    pub fn with_learned_position_encoding(mut self) -> Self {
        self.encoder = self.encoder.with_learned_position_encoding();
        self
    }

    /// Sets the encoder layers' pre-norm flag.
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.encoder.layer_config = self.encoder.layer_config.with_pre_norm(pre_norm);
        self
    }

    /// Sets the same dropout on the encoder stack and the classification head.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.encoder = self.encoder.with_dropout(dropout);
        self.classifier_dropout = dropout;
        self
    }

    /// Token sequence length fed to the encoder: the patch count, plus one when a
    /// class token is used.
    pub fn seq_length(&self) -> usize {
        let base = self.patch_embed.num_patches();
        if self.use_class_token {
            base + 1
        } else {
            base
        }
    }

    /// Validates the patch embedding, the encoder, `num_classes` and
    /// `classifier_dropout`.
    ///
    /// # Errors
    ///
    /// Returns the first failing sub-validation, or
    /// [`TrustformerError::CompilationError`] for `num_classes == 0` or a
    /// `classifier_dropout` outside `[0, 1]` (including NaN).
    pub fn validate(&self) -> Result<()> {
        self.patch_embed.validate()?;
        self.encoder.validate()?;
        Self::check_num_classes(self.num_classes)?;
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected as well.
        if !(0.0..=1.0).contains(&self.classifier_dropout) {
            return Err(TrustformerError::CompilationError(format!(
                "classifier_dropout must be in [0, 1], got {}",
                self.classifier_dropout
            )));
        }
        Ok(())
    }

    /// Shared `num_classes > 0` check used by [`Self::new`] and [`Self::validate`].
    fn check_num_classes(num_classes: usize) -> Result<()> {
        if num_classes == 0 {
            return Err(TrustformerError::CompilationError(
                "num_classes must be > 0".into(),
            ));
        }
        Ok(())
    }
}
/// A Vision Transformer assembled from a validated [`VisionTransformerConfig`].
pub struct VisionTransformer {
    config: VisionTransformerConfig,
    patch_embed: PatchEmbedding,
    // Constructed (and therefore checked) in `new`, but not yet referenced by
    // `build_vit_graph`.
    #[allow(dead_code)]
    encoder: EncoderStack,
}

impl VisionTransformer {
    /// Validates `config` and builds the patch embedding and encoder stack from it.
    ///
    /// # Errors
    ///
    /// Returns the first error from config validation, `PatchEmbedding::new` or
    /// `EncoderStack::new`.
    pub fn new(config: VisionTransformerConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self {
            patch_embed: PatchEmbedding::new(config.patch_embed.clone())?,
            encoder: EncoderStack::new(config.encoder.clone())?,
            config,
        })
    }

    /// Emits the patch projection followed by an element-wise add of the position
    /// embeddings, and returns the id of the `positioned_embeddings` tensor.
    ///
    /// NOTE(review): class-token prepending and the encoder stack are not emitted here.
    ///
    /// # Errors
    ///
    /// Propagates any error from adding nodes to `graph`.
    pub fn build_vit_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Hard-coded input id assumed to hold the position-embedding table; the caller
        // must register tensors in that order (TODO confirm against callers).
        const POS_EMBED_TENSOR: usize = 2;

        let patch_tokens = self.patch_embed.build_patch_embed_graph(graph)?;
        let with_positions = graph.add_tensor("positioned_embeddings");
        graph.add_node(EinsumNode::elem_binary(
            "add_pos_embed".to_string(),
            patch_tokens,
            POS_EMBED_TENSOR,
            with_positions,
        ))?;
        Ok(vec![with_positions])
    }

    /// The configuration this model was built from.
    pub fn config(&self) -> &VisionTransformerConfig {
        &self.config
    }

    /// Counts the parameters implied by the configuration.
    ///
    /// The total is the sum of:
    /// - the patch projection weight (no bias term is counted);
    /// - the class token, if enabled;
    /// - a position table of `seq_length` rows;
    /// - every encoder layer;
    /// - the optional final layer norm;
    /// - a linear classification head with bias.
    pub fn count_parameters(&self) -> usize {
        let patch_cfg = &self.config.patch_embed;
        let encoder_cfg = &self.config.encoder;
        let width = patch_cfg.d_model;

        let projection = patch_cfg.patch_dim() * width;
        let class_token = if self.config.use_class_token { width } else { 0 };
        let position_table = self.config.seq_length() * width;
        let encoder_layers = crate::utils::count_encoder_layer_params(&encoder_cfg.layer_config)
            * encoder_cfg.num_layers;
        let final_norm = if encoder_cfg.final_layer_norm {
            crate::utils::count_layernorm_params(&encoder_cfg.layer_config.layer_norm)
        } else {
            0
        };
        let head = width * self.config.num_classes + self.config.num_classes;

        projection + class_token + position_table + encoder_layers + final_norm + head
    }
}
/// Named Vision Transformer sizes.
///
/// Every preset uses 224x224 inputs with 3 channels. The presets differ in patch size
/// and in encoder width, heads, feed-forward size and depth (see [`ViTPreset::config`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ViTPreset {
    /// 16x16 patches, d_model 192, 3 heads, d_ff 768, 12 layers.
    Tiny16,
    /// 16x16 patches, d_model 384, 6 heads, d_ff 1536, 12 layers.
    Small16,
    /// 16x16 patches, d_model 768, 12 heads, d_ff 3072, 12 layers.
    Base16,
    /// 16x16 patches, d_model 1024, 16 heads, d_ff 4096, 24 layers.
    Large16,
    /// 14x14 patches, d_model 1280, 16 heads, d_ff 5120, 32 layers.
    Huge14,
}

impl ViTPreset {
    /// Every preset, ordered from smallest to largest model width.
    pub const ALL: [ViTPreset; 5] = [
        ViTPreset::Tiny16,
        ViTPreset::Small16,
        ViTPreset::Base16,
        ViTPreset::Large16,
        ViTPreset::Huge14,
    ];

    /// Input image side length shared by all presets.
    const IMAGE_SIZE: usize = 224;
    /// Input channel count shared by all presets.
    const IN_CHANNELS: usize = 3;

    /// Builds the [`VisionTransformerConfig`] for this preset with a `num_classes`-way head.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`VisionTransformerConfig::new`].
    pub fn config(&self, num_classes: usize) -> Result<VisionTransformerConfig> {
        let (patch_size, d_model, n_heads, d_ff, n_layers) = match self {
            ViTPreset::Tiny16 => (16, 192, 3, 768, 12),
            ViTPreset::Small16 => (16, 384, 6, 1536, 12),
            ViTPreset::Base16 => (16, 768, 12, 3072, 12),
            ViTPreset::Large16 => (16, 1024, 16, 4096, 24),
            ViTPreset::Huge14 => (14, 1280, 16, 5120, 32),
        };
        VisionTransformerConfig::new(
            Self::IMAGE_SIZE,
            patch_size,
            Self::IN_CHANNELS,
            d_model,
            n_heads,
            d_ff,
            n_layers,
            num_classes,
        )
    }

    /// Human-readable name, e.g. `"ViT-Base/16"`.
    pub fn name(&self) -> &'static str {
        match self {
            ViTPreset::Tiny16 => "ViT-Tiny/16",
            ViTPreset::Small16 => "ViT-Small/16",
            ViTPreset::Base16 => "ViT-Base/16",
            ViTPreset::Large16 => "ViT-Large/16",
            ViTPreset::Huge14 => "ViT-Huge/14",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// ViT-Base/16 shape with a 1000-way head, shared by several tests below.
    fn base16_config() -> VisionTransformerConfig {
        VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000).expect("unwrap")
    }

    #[test]
    fn test_patch_embedding_config() {
        let cfg = PatchEmbeddingConfig::new(224, 16, 3, 768).expect("unwrap");
        assert_eq!(cfg.image_size, 224);
        assert_eq!(cfg.patch_size, 16);
        assert_eq!(cfg.in_channels, 3);
        assert_eq!(cfg.d_model, 768);
        // (224 / 16)^2 = 196 patches, each 16 * 16 * 3 = 768 values long.
        assert_eq!(cfg.num_patches(), 196);
        assert_eq!(cfg.patch_dim(), 768);
    }

    #[test]
    fn test_patch_embedding_invalid_size() {
        // 225 is not a multiple of the patch size 16.
        let outcome = PatchEmbeddingConfig::new(225, 16, 3, 768);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_patch_embedding_graph() {
        let cfg = PatchEmbeddingConfig::new(224, 16, 3, 768).expect("unwrap");
        let embed = PatchEmbedding::new(cfg).expect("unwrap");
        let mut graph = EinsumGraph::new();
        // The projection node reads tensor ids 0 and 1, so register them first.
        graph.add_tensor("image");
        graph.add_tensor("W_patch_embed");
        let out_id = embed
            .build_patch_embed_graph(&mut graph)
            .expect("unwrap");
        assert!(out_id > 0);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_vit_config_creation() {
        let cfg = base16_config();
        assert_eq!(cfg.num_classes, 1000);
        assert!(cfg.use_class_token);
        // 196 patches plus the class token.
        assert_eq!(cfg.seq_length(), 197);
    }

    #[test]
    fn test_vit_config_without_class_token() {
        let cfg = base16_config().with_class_token(false);
        assert!(!cfg.use_class_token);
        // Patches only.
        assert_eq!(cfg.seq_length(), 196);
    }

    #[test]
    fn test_vit_creation() {
        let vit = VisionTransformer::new(base16_config()).expect("unwrap");
        assert!(vit.config().validate().is_ok());
    }

    #[test]
    fn test_vit_graph_building() {
        let cfg =
            VisionTransformerConfig::new(224, 16, 3, 384, 6, 1536, 2, 10).expect("unwrap");
        let vit = VisionTransformer::new(cfg).expect("unwrap");
        let mut graph = EinsumGraph::new();
        // Ids 0, 1, 2: patches, projection weight, position embeddings.
        graph.add_tensor("patches");
        graph.add_tensor("W_patch_embed");
        graph.add_tensor("pos_embed");
        let built = vit.build_vit_graph(&mut graph);
        assert!(built.is_ok());
        let outputs = built.expect("unwrap");
        assert!(!outputs.is_empty());
    }

    #[test]
    fn test_vit_parameter_count() {
        let vit = VisionTransformer::new(base16_config()).expect("unwrap");
        assert!(vit.count_parameters() > 0);
    }

    #[test]
    fn test_vit_presets() {
        let presets = [
            ViTPreset::Tiny16,
            ViTPreset::Small16,
            ViTPreset::Base16,
            ViTPreset::Large16,
            ViTPreset::Huge14,
        ];
        for preset in presets {
            let cfg = preset.config(1000).expect("unwrap");
            assert!(cfg.validate().is_ok());
            assert_eq!(cfg.num_classes, 1000);
            let vit = VisionTransformer::new(cfg).expect("unwrap");
            assert!(vit.count_parameters() > 0);
        }
    }

    #[test]
    fn test_vit_preset_names() {
        let expected = [
            (ViTPreset::Tiny16, "ViT-Tiny/16"),
            (ViTPreset::Small16, "ViT-Small/16"),
            (ViTPreset::Base16, "ViT-Base/16"),
            (ViTPreset::Large16, "ViT-Large/16"),
            (ViTPreset::Huge14, "ViT-Huge/14"),
        ];
        for (preset, name) in expected {
            assert_eq!(preset.name(), name);
        }
    }

    #[test]
    fn test_different_image_sizes() {
        for (image_size, patch_size) in [(224, 16), (384, 16), (512, 32)] {
            let cfg = PatchEmbeddingConfig::new(image_size, patch_size, 3, 768).expect("unwrap");
            let per_side = image_size / patch_size;
            assert_eq!(cfg.num_patches(), per_side * per_side);
        }
    }

    #[test]
    fn test_vit_config_builder() {
        let cfg = base16_config()
            .with_class_token(true)
            .with_classifier_dropout(0.1)
            .with_pre_norm(true)
            .with_dropout(0.1);
        assert!(cfg.use_class_token);
        assert!((cfg.classifier_dropout - 0.1).abs() < 1e-10);
        assert!(cfg.encoder.layer_config.pre_norm);
        assert!(cfg.validate().is_ok());
    }
}