use crate::t5::{T5Config, T5ForConditionalGeneration, T5Model};
use trustformers_core::{
tensor::Tensor,
traits::{Model, TokenizedInput},
};
#[test]
fn test_t5_model_creation() {
let config = T5Config::default();
let model = T5Model::new(config.clone()).expect("operation failed");
assert_eq!(model.get_config().num_layers, 6);
assert_eq!(model.get_config().num_heads, 8);
}
#[test]
fn test_t5_forward_pass() {
let config = T5Config {
vocab_size: 100,
d_model: 64,
d_kv: 16,
d_ff: 256,
num_layers: 2,
num_heads: 4,
..Default::default()
};
let model = T5Model::new(config).expect("operation failed");
let input = super::T5Input {
input_ids: TokenizedInput {
input_ids: vec![1, 2, 3, 4, 5],
attention_mask: vec![1u8; 5],
token_type_ids: None,
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
},
decoder_input_ids: Some(TokenizedInput {
input_ids: vec![0, 6, 7, 8],
attention_mask: vec![1u8; 4],
token_type_ids: None,
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
}),
encoder_outputs: None,
};
let output = model.forward(input).expect("operation failed");
match &output.last_hidden_state {
Tensor::F32(arr) => {
println!("Output shape: {:?}", arr.shape());
if arr.ndim() == 2 {
assert_eq!(arr.shape()[0], 4); assert_eq!(arr.shape()[1], 64); } else {
assert_eq!(arr.shape()[0], 1); assert_eq!(arr.shape()[1], 4); assert_eq!(arr.shape()[2], 64); }
},
_ => panic!("Expected F32 tensor"),
}
}
#[test]
fn test_t5_lm_forward_pass() {
let config = T5Config {
vocab_size: 100,
d_model: 64,
d_kv: 16,
d_ff: 256,
num_layers: 2,
num_heads: 4,
..Default::default()
};
let model = T5ForConditionalGeneration::new(config).expect("operation failed");
let input = super::T5Input {
input_ids: TokenizedInput {
input_ids: vec![1, 2, 3, 4, 5],
attention_mask: vec![1u8; 5],
token_type_ids: None,
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
},
decoder_input_ids: Some(TokenizedInput {
input_ids: vec![0, 6, 7, 8],
attention_mask: vec![1u8; 4],
token_type_ids: None,
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
}),
encoder_outputs: None,
};
let output = model.forward(input).expect("operation failed");
match &output.logits {
Tensor::F32(arr) => {
if arr.ndim() == 2 {
assert_eq!(arr.shape()[0], 4); assert_eq!(arr.shape()[1], 100); } else {
assert_eq!(arr.shape()[0], 1); assert_eq!(arr.shape()[1], 4); assert_eq!(arr.shape()[2], 100); }
},
_ => panic!("Expected F32 tensor"),
}
}
#[test]
fn test_relative_position_bias() {
let config = T5Config {
vocab_size: 100,
d_model: 64,
d_kv: 16,
d_ff: 256,
num_layers: 1,
num_heads: 4,
relative_attention_num_buckets: 32,
relative_attention_max_distance: 128,
..Default::default()
};
let model = T5Model::new(config).expect("operation failed");
let input = super::T5Input {
input_ids: TokenizedInput {
input_ids: vec![1, 2, 3, 4],
attention_mask: vec![1u8; 4],
token_type_ids: None,
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
},
decoder_input_ids: Some(TokenizedInput {
input_ids: vec![0, 5, 6],
attention_mask: vec![1u8; 3],
token_type_ids: None,
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
}),
encoder_outputs: None,
};
let output = model.forward(input).expect("operation failed");
match &output.last_hidden_state {
Tensor::F32(arr) => {
if arr.ndim() == 2 {
assert_eq!(arr.shape()[0], 3); assert_eq!(arr.shape()[1], 64); } else {
assert_eq!(arr.shape()[0], 1); assert_eq!(arr.shape()[1], 3); assert_eq!(arr.shape()[2], 64); }
},
_ => panic!("Expected F32 tensor"),
}
}
#[test]
fn test_t5_encoder_decoder_separation() {
let config = T5Config {
vocab_size: 50,
d_model: 32,
d_kv: 8,
d_ff: 128,
num_layers: 1,
num_decoder_layers: Some(2), num_heads: 4,
..Default::default()
};
let model = T5Model::new(config).expect("operation failed");
let encoder_input = super::T5Input {
input_ids: TokenizedInput {
input_ids: vec![1, 2, 3],
attention_mask: vec![1u8; 3],
token_type_ids: None,
offset_mapping: None,
special_tokens_mask: None,
overflowing_tokens: None,
},
decoder_input_ids: None,
encoder_outputs: None,
};
let encoder_output = model.forward(encoder_input).expect("operation failed");
assert!(encoder_output.encoder_last_hidden_state.is_some());
}