use std::collections::HashMap;
use torsh_core::{device::DeviceType, error::Result};
use torsh_nn::prelude::*;
use torsh_nn::{Module, Parameter};
use torsh_tensor::{creation, Tensor};
use crate::{ModelConfig, VisionModel};
pub struct MultiHeadAttention {
num_heads: usize,
head_dim: usize,
scale: f32,
qkv: Linear,
proj: Linear,
dropout: Dropout,
}
impl MultiHeadAttention {
pub fn new(embed_dim: usize, num_heads: usize, dropout: f32) -> Self {
let head_dim = embed_dim / num_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
Self {
num_heads,
head_dim,
scale,
qkv: Linear::new(embed_dim, embed_dim * 3, true),
proj: Linear::new(embed_dim, embed_dim, true),
dropout: Dropout::new(dropout),
}
}
}
impl Module for MultiHeadAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let batch_size = input.shape().dims()[0];
let seq_len = input.shape().dims()[1];
let _embed_dim = input.shape().dims()[2];
let qkv = self.qkv.forward(input)?;
let qkv = qkv.view(&[
batch_size as i32,
seq_len as i32,
3,
self.num_heads as i32,
self.head_dim as i32,
])?;
let _qkv = qkv.permute(&[2, 0, 3, 1, 4])?;
let _q = input.clone(); let _k = input.clone(); let v = input.clone();
let out = self.proj.forward(&v)?;
self.dropout.forward(&out)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.extend(self.qkv.parameters());
params.extend(self.proj.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters()
}
fn training(&self) -> bool {
self.qkv.training()
}
fn train(&mut self) {
self.qkv.train();
self.proj.train();
self.dropout.train();
}
fn eval(&mut self) {
self.qkv.eval();
self.proj.eval();
self.dropout.eval();
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.qkv.to_device(device)?;
self.proj.to_device(device)?;
self.dropout.to_device(device)?;
Ok(())
}
fn set_training(&mut self, training: bool) {
if training {
self.train();
} else {
self.eval();
}
}
}
pub struct TransformerBlock {
norm1: LayerNorm,
attn: MultiHeadAttention,
norm2: LayerNorm,
mlp: Sequential,
dropout: Dropout,
}
impl TransformerBlock {
pub fn new(embed_dim: usize, num_heads: usize, mlp_ratio: usize, dropout: f32) -> Result<Self> {
let mlp_hidden_dim = embed_dim * mlp_ratio;
let mlp = Sequential::new()
.add(Linear::new(embed_dim, mlp_hidden_dim, true))
.add(GELU::new(false))
.add(Linear::new(mlp_hidden_dim, embed_dim, true))
.add(Dropout::new(dropout));
Ok(Self {
norm1: LayerNorm::new(vec![embed_dim], 1e-5, true, torsh_core::DeviceType::Cpu)?,
attn: MultiHeadAttention::new(embed_dim, num_heads, dropout),
norm2: LayerNorm::new(vec![embed_dim], 1e-5, true, torsh_core::DeviceType::Cpu)?,
mlp,
dropout: Dropout::new(dropout),
})
}
}
impl Module for TransformerBlock {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let normed = self.norm1.forward(input)?;
let attn_out = self.attn.forward(&normed)?;
let x = input.add(&attn_out)?;
let normed = self.norm2.forward(&x)?;
let mlp_out = self.mlp.forward(&normed)?;
x.add(&mlp_out)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.extend(self.norm1.parameters());
params.extend(self.attn.parameters());
params.extend(self.norm2.parameters());
params.extend(self.mlp.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters()
}
fn training(&self) -> bool {
self.attn.training()
}
fn train(&mut self) {
self.norm1.train();
self.attn.train();
self.norm2.train();
self.mlp.train();
self.dropout.train();
}
fn eval(&mut self) {
self.norm1.eval();
self.attn.eval();
self.norm2.eval();
self.mlp.eval();
self.dropout.eval();
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.norm1.to_device(device)?;
self.attn.to_device(device)?;
self.norm2.to_device(device)?;
self.mlp.to_device(device)?;
self.dropout.to_device(device)?;
Ok(())
}
fn set_training(&mut self, training: bool) {
if training {
self.train();
} else {
self.eval();
}
}
}
#[derive(Debug, Clone)]
pub struct ViTConfig {
pub image_size: usize,
pub patch_size: usize,
pub embed_dim: usize,
pub num_layers: usize,
pub num_heads: usize,
pub mlp_ratio: usize,
pub num_classes: usize,
pub dropout: f32,
}
impl ViTConfig {
pub fn vit_tiny_patch16_224() -> Self {
Self {
image_size: 224,
patch_size: 16,
embed_dim: 192,
num_layers: 12,
num_heads: 3,
mlp_ratio: 4,
num_classes: 1000,
dropout: 0.0,
}
}
pub fn vit_small_patch16_224() -> Self {
Self {
image_size: 224,
patch_size: 16,
embed_dim: 384,
num_layers: 12,
num_heads: 6,
mlp_ratio: 4,
num_classes: 1000,
dropout: 0.0,
}
}
pub fn vit_base_patch16_224() -> Self {
Self {
image_size: 224,
patch_size: 16,
embed_dim: 768,
num_layers: 12,
num_heads: 12,
mlp_ratio: 4,
num_classes: 1000,
dropout: 0.1,
}
}
pub fn vit_large_patch16_224() -> Self {
Self {
image_size: 224,
patch_size: 16,
embed_dim: 1024,
num_layers: 24,
num_heads: 16,
mlp_ratio: 4,
num_classes: 1000,
dropout: 0.1,
}
}
pub fn vit_huge_patch14_224() -> Self {
Self {
image_size: 224,
patch_size: 14,
embed_dim: 1280,
num_layers: 32,
num_heads: 16,
mlp_ratio: 4,
num_classes: 1000,
dropout: 0.1,
}
}
}
pub struct VisionTransformer {
patch_embed: Conv2d,
cls_token: Parameter,
pos_embed: Parameter,
pos_dropout: Dropout,
blocks: Vec<TransformerBlock>,
norm: LayerNorm,
head: Linear,
config: ViTConfig,
is_training: bool,
}
impl VisionTransformer {
pub fn new(vit_config: ViTConfig) -> Result<Self> {
let num_patches = (vit_config.image_size / vit_config.patch_size).pow(2);
let seq_len = num_patches + 1;
let patch_embed = Conv2d::new(
3,
vit_config.embed_dim,
(vit_config.patch_size, vit_config.patch_size),
(vit_config.patch_size, vit_config.patch_size),
(0, 0),
(1, 1),
false,
1,
);
let cls_token = Parameter::new(
torsh_tensor::creation::zeros(&[1, 1, vit_config.embed_dim])
.expect("tensor creation should succeed"),
);
let pos_embed = Parameter::new(
torsh_tensor::creation::zeros(&[1, seq_len, vit_config.embed_dim])
.expect("tensor creation should succeed"),
);
let mut blocks = Vec::new();
for _ in 0..vit_config.num_layers {
blocks.push(TransformerBlock::new(
vit_config.embed_dim,
vit_config.num_heads,
vit_config.mlp_ratio,
vit_config.dropout,
)?);
}
Ok(Self {
patch_embed,
cls_token,
pos_embed,
pos_dropout: Dropout::new(vit_config.dropout),
blocks,
norm: LayerNorm::new(
vec![vit_config.embed_dim],
1e-5,
true,
torsh_core::DeviceType::Cpu,
)?,
head: Linear::new(vit_config.embed_dim, vit_config.num_classes, true),
config: vit_config,
is_training: true,
})
}
pub fn from_config(model_config: ModelConfig, vit_config: ViTConfig) -> Result<Self> {
let mut config = vit_config;
config.num_classes = model_config.num_classes;
config.dropout = model_config.dropout;
Self::new(config)
}
pub fn vit_tiny_patch16_224(config: ModelConfig) -> Result<Self> {
Self::from_config(config, ViTConfig::vit_tiny_patch16_224())
}
pub fn vit_small_patch16_224(config: ModelConfig) -> Result<Self> {
Self::from_config(config, ViTConfig::vit_small_patch16_224())
}
pub fn vit_base_patch16_224(config: ModelConfig) -> Result<Self> {
Self::from_config(config, ViTConfig::vit_base_patch16_224())
}
pub fn vit_large_patch16_224(config: ModelConfig) -> Result<Self> {
Self::from_config(config, ViTConfig::vit_large_patch16_224())
}
pub fn vit_huge_patch14_224(config: ModelConfig) -> Result<Self> {
Self::from_config(config, ViTConfig::vit_huge_patch14_224())
}
}
impl Module for VisionTransformer {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let batch_size = input.shape().dims()[0];
let mut x = self.patch_embed.forward(input)?;
let num_patches = x.shape().dims()[2] * x.shape().dims()[3];
x = x.view(&[
batch_size as i32,
self.config.embed_dim as i32,
num_patches as i32,
])?;
x = x.transpose(1, 2)?;
let cls_tokens =
self.cls_token
.tensor()
.read()
.expand(&[batch_size, 1, self.config.embed_dim])?;
x = Tensor::cat(&[&cls_tokens, &x], 1)?;
x = x.add(&*self.pos_embed.tensor().read())?;
x = self.pos_dropout.forward(&x)?;
for block in &self.blocks {
x = block.forward(&x)?;
}
x = self.norm.forward(&x)?;
let cls_token = x.slice(1, 0, 1)?; let cls_token = cls_token
.to_tensor()?
.view(&[batch_size as i32, self.config.embed_dim as i32])?;
self.head.forward(&cls_token)
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.extend(self.patch_embed.parameters());
params.insert("cls_token".to_string(), self.cls_token.clone());
params.insert("pos_embed".to_string(), self.pos_embed.clone());
for (i, block) in self.blocks.iter().enumerate() {
let block_params = block.parameters();
for (key, param) in block_params {
params.insert(format!("block{}.{}", i, key), param);
}
}
params.extend(self.norm.parameters());
params.extend(self.head.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters()
}
fn training(&self) -> bool {
self.is_training
}
fn train(&mut self) {
self.is_training = true;
self.patch_embed.train();
self.pos_dropout.train();
for block in &mut self.blocks {
block.train();
}
self.norm.train();
self.head.train();
}
fn eval(&mut self) {
self.is_training = false;
self.patch_embed.eval();
self.pos_dropout.eval();
for block in &mut self.blocks {
block.eval();
}
self.norm.eval();
self.head.eval();
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.patch_embed.to_device(device)?;
self.pos_dropout.to_device(device)?;
for block in &mut self.blocks {
block.to_device(device)?;
}
self.norm.to_device(device)?;
self.head.to_device(device)?;
Ok(())
}
fn set_training(&mut self, training: bool) {
if training {
self.train();
} else {
self.eval();
}
}
}
impl VisionModel for VisionTransformer {
fn num_classes(&self) -> usize {
self.config.num_classes
}
fn input_size(&self) -> (usize, usize) {
(self.config.image_size, self.config.image_size)
}
fn name(&self) -> &str {
"VisionTransformer"
}
}