#[cfg(test)]
#[allow(clippy::module_inception)]
mod tests {
use crate::linformer::config::LinformerConfig;
use crate::linformer::model::{
LinformerEmbeddings, LinformerEncoder, LinformerFeedForward, LinformerForMaskedLM,
LinformerForSequenceClassification, LinformerLayer, LinformerModel,
};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::{Config, Layer, Model};
/// Builds a deliberately small `LinformerConfig` shared by the tests below.
///
/// Only the fields listed here are overridden; every other field keeps the
/// value supplied by `LinformerConfig::default()`.
fn tiny_config() -> LinformerConfig {
    LinformerConfig {
        // Embedding tables.
        vocab_size: 128,
        type_vocab_size: 2,
        max_position_embeddings: 64,
        // Transformer stack dimensions.
        hidden_size: 32,
        intermediate_size: 64,
        num_hidden_layers: 2,
        num_attention_heads: 4,
        // Linformer-specific switches.
        projected_attention_size: 16,
        share_projection: true,
        share_layers: false,
        use_efficient_attention: true,
        ..LinformerConfig::default()
    }
}
/// The default configuration must pass its own validation.
#[test]
fn test_linformer_config_default_validates() {
    let validation = LinformerConfig::default().validate();
    assert!(validation.is_ok());
}
/// Validation must reject `hidden_size = 33` combined with 4 attention heads.
///
/// NOTE(review): presumably rejected because 33 is not a multiple of the head
/// count — confirm against `LinformerConfig::validate`.
#[test]
fn test_linformer_config_invalid_hidden_heads() {
    let mismatched = LinformerConfig {
        num_attention_heads: 4,
        hidden_size: 33,
        ..LinformerConfig::default()
    };
    let validation = mismatched.validate();
    assert!(validation.is_err());
}
/// Validation must reject a projected attention size (1024) that is larger
/// than the maximum number of positions (512).
///
/// NOTE(review): the exact rule enforced lives in `validate`; this test only
/// pins that this particular combination is an error.
#[test]
fn test_linformer_config_invalid_projected_size() {
    let oversized = LinformerConfig {
        max_position_embeddings: 512,
        projected_attention_size: 1024,
        ..LinformerConfig::default()
    };
    let validation = oversized.validate();
    assert!(validation.is_err());
}
/// Each named preset must validate and expose its expected headline sizes.
#[test]
fn test_linformer_config_presets() {
    // Base preset: hidden size 768.
    let base_preset = LinformerConfig::linformer_base();
    assert!(base_preset.validate().is_ok());
    assert_eq!(base_preset.hidden_size, 768);

    // Large preset: hidden size 1024 across 24 layers.
    let large_preset = LinformerConfig::linformer_large();
    assert!(large_preset.validate().is_ok());
    assert_eq!(large_preset.hidden_size, 1024);
    assert_eq!(large_preset.num_hidden_layers, 24);

    // Long-sequence preset: 8192 positions.
    let long_preset = LinformerConfig::linformer_long();
    assert!(long_preset.validate().is_ok());
    assert_eq!(long_preset.max_position_embeddings, 8192);
}
/// `head_dim()` must report 8 for the tiny test configuration.
///
/// `tiny_config` sets `hidden_size` to 32 and `num_attention_heads` to 4;
/// presumably the head dimension is their quotient — verify against
/// `LinformerConfig::head_dim`.
#[test]
fn test_linformer_config_head_dim() {
    let head_dim = tiny_config().head_dim();
    assert_eq!(head_dim, 8);
}
/// `compression_ratio()` must be 0.25 (within 0.001) for the tiny configuration.
///
/// NOTE(review): 0.25 matches `projected_attention_size / max_position_embeddings`
/// (16 / 64) from `tiny_config` — confirm that this is the intended definition.
#[test]
fn test_linformer_config_compression_ratio() {
    let cfg = tiny_config();
    let deviation = (cfg.compression_ratio() - 0.25).abs();
    assert!(deviation < 0.001);
}
/// The configuration must identify its architecture as `"Linformer"`.
#[test]
fn test_linformer_config_architecture() {
    let arch = LinformerConfig::default().architecture();
    assert_eq!(arch, "Linformer");
}
/// Constructing the embedding layer from the tiny configuration must succeed.
#[test]
fn test_linformer_embeddings_creation() {
    let cfg = tiny_config();
    let built = LinformerEmbeddings::new(&cfg);
    assert!(built.is_ok());
}
/// The embedding layer must accept four token ids with no token-type or
/// position ids supplied.
#[test]
fn test_linformer_embeddings_forward() {
    let cfg = tiny_config();
    let embeddings = LinformerEmbeddings::new(&cfg).expect("Failed to create embeddings");
    // Tuple layout as passed here: (token ids, optional second field, optional third field).
    let batch = (vec![1_u32, 2, 3, 4], None, None);
    let result = embeddings.forward(batch);
    assert!(result.is_ok());
}
/// The embedding layer must accept an explicit second tuple field alongside
/// the token ids (the token-type ids, going by the test name).
#[test]
fn test_linformer_embeddings_with_token_types() {
    let cfg = tiny_config();
    let embeddings = LinformerEmbeddings::new(&cfg).expect("Failed to create embeddings");
    let batch = (vec![1_u32, 2, 3], Some(vec![0_u32, 0, 1]), None);
    let result = embeddings.forward(batch);
    assert!(result.is_ok());
}
/// The embedding layer must accept an explicit third tuple field alongside
/// the token ids (the position ids, going by the test name).
#[test]
fn test_linformer_embeddings_with_position_ids() {
    let cfg = tiny_config();
    let embeddings = LinformerEmbeddings::new(&cfg).expect("Failed to create embeddings");
    let batch = (vec![1_u32, 2, 3], None, Some(vec![0_u32, 1, 2]));
    let result = embeddings.forward(batch);
    assert!(result.is_ok());
}
/// The embedding layer must report a non-zero parameter count of at least 6000
/// for the tiny configuration.
///
/// NOTE(review): the 6000 floor is presumably derived from the embedding table
/// sizes in `tiny_config` — confirm against `parameter_count`.
#[test]
fn test_linformer_embeddings_parameter_count() {
    let cfg = tiny_config();
    let embeddings = LinformerEmbeddings::new(&cfg).expect("Failed to create embeddings");
    let count = embeddings.parameter_count();
    assert!(count > 0);
    assert!(count >= 6000);
}
/// Constructing the feed-forward block from the tiny configuration must succeed.
#[test]
fn test_linformer_ffn_creation() {
    let cfg = tiny_config();
    let built = LinformerFeedForward::new(&cfg);
    assert!(built.is_ok());
}
/// A forward pass through the feed-forward block must succeed and return a
/// tensor with the same `[1, 4, 32]` shape as the all-zero input.
#[test]
fn test_linformer_ffn_forward() {
    let cfg = tiny_config();
    let feed_forward = LinformerFeedForward::new(&cfg).expect("Failed to create FFN");
    let zeros = Tensor::zeros(&[1, 4, 32]).expect("Failed to create tensor");

    let result = feed_forward.forward(zeros);
    assert!(result.is_ok());

    let activations = result.expect("Forward pass failed");
    assert_eq!(activations.shape(), &[1, 4, 32]);
}
/// The feed-forward block must report a non-zero parameter count.
#[test]
fn test_linformer_ffn_parameter_count() {
    let cfg = tiny_config();
    let feed_forward = LinformerFeedForward::new(&cfg).expect("Failed to create FFN");
    let count = feed_forward.parameter_count();
    assert!(count > 0);
}
/// Constructing a single Linformer layer from the tiny configuration must succeed.
#[test]
fn test_linformer_layer_creation() {
    let cfg = tiny_config();
    let built = LinformerLayer::new(&cfg);
    assert!(built.is_ok());
}
/// A single Linformer layer must report a non-zero parameter count.
#[test]
fn test_linformer_layer_parameter_count() {
    let cfg = tiny_config();
    let single_layer = LinformerLayer::new(&cfg).expect("Failed to create layer");
    let count = single_layer.parameter_count();
    assert!(count > 0);
}
/// Constructing the encoder stack from the tiny configuration must succeed.
#[test]
fn test_linformer_encoder_creation() {
    let cfg = tiny_config();
    let built = LinformerEncoder::new(&cfg);
    assert!(built.is_ok());
}
/// Encoder construction must succeed when `share_layers` is switched on.
///
/// NOTE(review): the test name mentions shared projections, but the flag
/// toggled here is `share_layers`; `share_projection` is already `true` in
/// `tiny_config`. Confirm which setting this test was meant to cover.
#[test]
fn test_linformer_encoder_shared_projections() {
    let cfg = LinformerConfig {
        share_layers: true,
        ..tiny_config()
    };
    let built = LinformerEncoder::new(&cfg);
    assert!(built.is_ok());
}
/// The encoder stack must report a non-zero parameter count.
#[test]
fn test_linformer_encoder_parameter_count() {
    let cfg = tiny_config();
    let stack = LinformerEncoder::new(&cfg).expect("Failed to create encoder");
    let count = stack.parameter_count();
    assert!(count > 0);
}
/// Constructing the base model from the tiny configuration must succeed.
#[test]
fn test_linformer_model_creation() {
    let built = LinformerModel::new(tiny_config());
    assert!(built.is_ok());
}
/// Smoke test: runs a forward pass of the base model on four token ids.
///
/// The forward result is bound but never inspected, so this test fails only
/// if construction fails or the call panics.
/// NOTE(review): confirm whether asserting on the result was intended.
#[test]
fn test_linformer_model_forward() {
    let base_model = LinformerModel::new(tiny_config()).expect("Failed to create model");
    let batch = (vec![1_u32, 2, 3, 4], None, None);
    let _forward_result = base_model.forward(batch);
}
/// `get_config()` must expose the configuration the model was built with,
/// checked here through `hidden_size` (32 in `tiny_config`).
#[test]
fn test_linformer_model_get_config() {
    let base_model = LinformerModel::new(tiny_config()).expect("Failed to create model");
    let hidden = base_model.get_config().hidden_size;
    assert_eq!(hidden, 32);
}
/// The base model must report a non-zero number of parameters.
#[test]
fn test_linformer_model_num_parameters() {
    let base_model = LinformerModel::new(tiny_config()).expect("Failed to create model");
    let total = base_model.num_parameters();
    assert!(total > 0);
}
/// Constructing the sequence-classification model with `3` as the second
/// constructor argument must succeed.
#[test]
fn test_linformer_seq_cls_creation() {
    let built = LinformerForSequenceClassification::new(tiny_config(), 3);
    assert!(built.is_ok());
}
/// Smoke test: runs a forward pass of the sequence-classification model on
/// four token ids.
///
/// The forward result is bound but never inspected, so this test fails only
/// if construction fails or the call panics.
/// NOTE(review): confirm whether asserting on the result was intended.
#[test]
fn test_linformer_seq_cls_forward() {
    let classifier = LinformerForSequenceClassification::new(tiny_config(), 3)
        .expect("Failed to create model");
    let batch = (vec![1_u32, 2, 3, 4], None, None);
    let _forward_result = classifier.forward(batch);
}
/// The sequence-classification model (built with `5` as the second
/// constructor argument) must report a non-zero number of parameters.
#[test]
fn test_linformer_seq_cls_num_parameters() {
    let classifier = LinformerForSequenceClassification::new(tiny_config(), 5)
        .expect("Failed to create model");
    let total = classifier.num_parameters();
    assert!(total > 0);
}
/// Constructing the masked-LM model from the tiny configuration must succeed.
#[test]
fn test_linformer_mlm_creation() {
    let built = LinformerForMaskedLM::new(tiny_config());
    assert!(built.is_ok());
}
/// Smoke test: runs a forward pass of the masked-LM model on four token ids.
///
/// The forward result is bound but never inspected, so this test fails only
/// if construction fails or the call panics.
/// NOTE(review): confirm whether asserting on the result was intended.
#[test]
fn test_linformer_mlm_forward() {
    let masked_lm = LinformerForMaskedLM::new(tiny_config()).expect("Failed to create model");
    let batch = (vec![1_u32, 2, 3, 4], None, None);
    let _forward_result = masked_lm.forward(batch);
}
/// The masked-LM model must report a non-zero number of parameters.
#[test]
fn test_linformer_mlm_num_parameters() {
    let masked_lm = LinformerForMaskedLM::new(tiny_config()).expect("Failed to create model");
    let total = masked_lm.num_parameters();
    assert!(total > 0);
}
/// The base model must still be constructible with `use_efficient_attention`
/// switched off; a forward pass on three token ids is then run as a smoke test.
///
/// The forward result is bound but never inspected, so only a panic in the
/// call would fail that part of the test.
/// NOTE(review): confirm whether asserting on the result was intended.
#[test]
fn test_linformer_model_no_efficient_attention() {
    let cfg = LinformerConfig {
        use_efficient_attention: false,
        ..tiny_config()
    };

    let built = LinformerModel::new(cfg);
    assert!(built.is_ok());

    let base_model = built.expect("Failed to create model");
    let batch = (vec![1_u32, 2, 3], None, None);
    let _forward_result = base_model.forward(batch);
}
}