#[cfg(test)]
#[allow(clippy::module_inception)]
mod tests {
use crate::vit::config::ViTConfig;
use crate::vit::model::{PatchEmbedding, ViTForImageClassification, ViTModel};
use scirs2_core::ndarray::Array4; use trustformers_core::traits::Config;
/// Checks the geometry reported by the `base` preset and that the preset validates.
#[test]
fn test_vit_config() {
    let base = ViTConfig::base();

    // Image side and patch side of the preset.
    assert_eq!((base.image_size, base.patch_size), (224, 16));

    // 224 / 16 = 14 patches per side, i.e. 14 * 14 = 196 patches;
    // the reported sequence length is one token longer than that.
    assert_eq!(base.num_patches(), 196);
    assert_eq!(base.seq_length(), 197);

    base.validate().expect("operation failed");
}
/// Verifies the headline dimensions of the `tiny`, `small`, `large` and `huge` presets.
#[test]
fn test_vit_config_variants() {
    // Every tuple below is (hidden_size, num_attention_heads, num_hidden_layers).
    let tiny_cfg = ViTConfig::tiny();
    assert_eq!(
        (
            tiny_cfg.hidden_size,
            tiny_cfg.num_attention_heads,
            tiny_cfg.num_hidden_layers
        ),
        (192, 3, 12)
    );

    let small_cfg = ViTConfig::small();
    assert_eq!(
        (
            small_cfg.hidden_size,
            small_cfg.num_attention_heads,
            small_cfg.num_hidden_layers
        ),
        (384, 6, 12)
    );

    let large_cfg = ViTConfig::large();
    assert_eq!(
        (
            large_cfg.hidden_size,
            large_cfg.num_attention_heads,
            large_cfg.num_hidden_layers
        ),
        (1024, 16, 24)
    );

    let huge_cfg = ViTConfig::huge();
    assert_eq!(
        (
            huge_cfg.hidden_size,
            huge_cfg.num_attention_heads,
            huge_cfg.num_hidden_layers
        ),
        (1280, 16, 32)
    );
    // `huge` is the only preset whose patch size is checked here.
    assert_eq!(huge_cfg.patch_size, 14);
}
/// Runs `PatchEmbedding::forward` on one 32x32 three-channel image and checks the
/// shape of the embedded patches.
#[test]
fn test_patch_embedding() {
    let config = ViTConfig {
        image_size: 32,
        patch_size: 16,
        hidden_size: 64,
        num_attention_heads: 4,
        intermediate_size: 256,
        num_hidden_layers: 2,
        ..ViTConfig::default()
    };
    let patch_embedding = PatchEmbedding::new(&config);
    let image = Array4::zeros((1, 32, 32, 3));

    // Unwrap directly: `expect` prints the underlying error on failure, whereas a
    // preceding `assert!(result.is_ok())` would panic first and throw it away.
    let patches = patch_embedding
        .forward(&image)
        .expect("patch embedding forward pass failed for a 1x32x32x3 input");

    // (32 / 16)^2 = 4 patches, each embedded into `hidden_size` = 64 values.
    assert_eq!(patches.shape(), &[1, 4, 64]);
}
/// Runs `PatchEmbedding::forward` for several image/patch size combinations and
/// checks that the number of produced patches is `(image_size / patch_size)^2`.
#[test]
fn test_patch_embedding_different_sizes() {
    // (image_size, patch_size); every image side is an exact multiple of the patch
    // side. The first pair is the combination this test originally covered.
    for (image_size, patch_size) in [(32, 16), (32, 8), (64, 16)] {
        let config = ViTConfig {
            image_size,
            patch_size,
            hidden_size: 64,
            num_attention_heads: 4,
            intermediate_size: 256,
            num_hidden_layers: 2,
            ..ViTConfig::default()
        };
        let patch_embedding = PatchEmbedding::new(&config);
        let image = Array4::zeros((1, image_size, image_size, 3));

        // Report which combination failed together with the underlying error.
        let patches = patch_embedding.forward(&image).unwrap_or_else(|err| {
            panic!(
                "patch embedding forward pass failed for image_size={}, patch_size={}: {:?}",
                image_size, patch_size, err
            )
        });

        let patches_per_side = image_size / patch_size;
        assert_eq!(
            patches.shape(),
            &[1, patches_per_side * patches_per_side, 64],
            "unexpected output shape for image_size={}, patch_size={}",
            image_size,
            patch_size
        );
    }
}
/// Builds a small `ViTModel` and checks the shape of its forward output for one
/// 32x32 three-channel image.
#[test]
fn test_vit_model() {
    let config = ViTConfig {
        image_size: 32,
        patch_size: 16,
        hidden_size: 64,
        num_attention_heads: 4,
        intermediate_size: 256,
        num_hidden_layers: 2,
        ..ViTConfig::default()
    };
    let model = ViTModel::new(config).expect("ViTModel::new failed for a valid small config");
    let images = Array4::zeros((1, 32, 32, 3));

    // Unwrap directly: `expect` prints the underlying error on failure, whereas a
    // preceding `assert!(result.is_ok())` would panic first and throw it away.
    let output = model
        .forward(&images)
        .expect("ViTModel forward pass failed for a 1x32x32x3 input");

    // 4 patch tokens plus one extra token (presumably the class token, which
    // assumes `use_class_token` defaults to true), each of width 64.
    assert_eq!(output.shape(), &[1, 5, 64]);
}
/// Builds a small `ViTForImageClassification` with 10 labels and checks that the
/// forward pass yields one row of 10 logits for a single image.
#[test]
fn test_vit_classification() {
    let config = ViTConfig {
        image_size: 32,
        patch_size: 16,
        hidden_size: 64,
        num_attention_heads: 4,
        intermediate_size: 256,
        num_hidden_layers: 2,
        num_labels: 10,
        ..ViTConfig::default()
    };
    let model = ViTForImageClassification::new(config)
        .expect("ViTForImageClassification::new failed for a valid small config");
    let images = Array4::zeros((1, 32, 32, 3));

    // Unwrap directly: `expect` prints the underlying error on failure, whereas a
    // preceding `assert!(result.is_ok())` would panic first and throw it away.
    let logits = model
        .forward(&images)
        .expect("classification forward pass failed for a 1x32x32x3 input");

    // One image in the batch, `num_labels` = 10 scores.
    assert_eq!(logits.shape(), &[1, 10]);
}
/// Checks that `ViTModel::get_class_token_output` returns one `hidden_size`-wide
/// vector per image.
#[test]
fn test_vit_class_token_output() {
    let config = ViTConfig {
        image_size: 32,
        patch_size: 16,
        hidden_size: 64,
        num_attention_heads: 4,
        intermediate_size: 256,
        num_hidden_layers: 2,
        ..ViTConfig::default()
    };
    let model = ViTModel::new(config).expect("ViTModel::new failed for a valid small config");
    let images = Array4::zeros((1, 32, 32, 3));

    // Unwrap directly: `expect` prints the underlying error on failure, whereas a
    // preceding `assert!(result.is_ok())` would panic first and throw it away.
    let class_output = model
        .get_class_token_output(&images)
        .expect("get_class_token_output failed for a 1x32x32x3 input");

    // One image in the batch, `hidden_size` = 64 features.
    assert_eq!(class_output.shape(), &[1, 64]);
}
/// Checks output shapes of a `ViTModel` configured with `use_class_token = false`.
#[test]
fn test_vit_without_class_token() {
    // The flag is set in the literal itself, so the binding needs no `mut`.
    let config = ViTConfig {
        image_size: 32,
        patch_size: 16,
        hidden_size: 64,
        num_attention_heads: 4,
        intermediate_size: 256,
        num_hidden_layers: 2,
        use_class_token: false,
        ..ViTConfig::default()
    };
    let model = ViTModel::new(config).expect("operation failed");
    let images = Array4::zeros((1, 32, 32, 3));

    // Only the (32 / 16)^2 = 4 patch tokens are present in the sequence.
    let output = model.forward(&images).expect("operation failed");
    assert_eq!(output.shape(), &[1, 4, 64]);

    // The class-token accessor still yields one 64-wide vector per image.
    // NOTE(review): how that vector is derived without a class token is not
    // visible from this test - confirm against the model implementation.
    let class_output = model.get_class_token_output(&images).expect("operation failed");
    assert_eq!(class_output.shape(), &[1, 64]);
}
/// Checks that well-known model names resolve to configs with the expected hidden size.
#[test]
fn test_from_pretrained_name() {
    // (pretrained model name, expected hidden_size)
    let expected_hidden_sizes = [
        ("vit-base-patch16-224", 768),
        ("vit-large-patch16-224", 1024),
        ("vit-tiny-patch16-224", 192),
    ];

    for (model_name, hidden_size) in expected_hidden_sizes {
        let resolved = ViTConfig::from_pretrained_name(model_name);
        assert_eq!(resolved.hidden_size, hidden_size);
    }
}
/// Checks that `validate` rejects three malformed variations of the `base` preset.
#[test]
fn test_config_validation_errors() {
    // 100 is not a multiple of 12 (presumably the reason for rejection).
    let bad_hidden_size = ViTConfig {
        hidden_size: 100,
        num_attention_heads: 12,
        ..ViTConfig::base()
    };
    assert!(bad_hidden_size.validate().is_err());

    // 225 is one more than the preset's 224 image side.
    let bad_image_size = ViTConfig {
        image_size: 225,
        ..ViTConfig::base()
    };
    assert!(bad_image_size.validate().is_err());

    // A zero patch size.
    let zero_patch_size = ViTConfig {
        patch_size: 0,
        ..ViTConfig::base()
    };
    assert!(zero_patch_size.validate().is_err());
}
/// Checks the fields and patch count of a `base` config rebuilt with a patch size of 32.
#[test]
fn test_config_with_different_patch_sizes() {
    let base_cfg = ViTConfig::base();
    let coarse_cfg = base_cfg.with_patch_size(32);

    // Both the patch size and the encoder stride are expected to be 32.
    assert_eq!((coarse_cfg.patch_size, coarse_cfg.encoder_stride), (32, 32));

    // 224 / 32 = 7 patches per side, i.e. 7 * 7 = 49 patches.
    assert_eq!(coarse_cfg.num_patches(), 49);
}
/// Checks the image size and patch count of a `base` config rebuilt for 384-pixel images.
#[test]
fn test_config_with_different_image_sizes() {
    let base_cfg = ViTConfig::base();
    let hires_cfg = base_cfg.with_image_size(384);

    assert_eq!(hires_cfg.image_size, 384);

    // 384 / 16 = 24 patches per side, i.e. 24 * 24 = 576 patches.
    assert_eq!(hires_cfg.num_patches(), 576);
}
}