#[cfg(test)]
#[allow(clippy::module_inception)]
mod tests {
use crate::retnet::config::RetNetConfig;
use crate::retnet::model::{
AdvancedChunkProcessor, MultiScaleRetention, RetNetDecoderLayer, RetNetEmbeddings,
RetNetFFN, RetNetForLanguageModeling, RetNetForSequenceClassification, RetNetLongSequence,
RetNetModel, RetNetStateCache, RotaryPositionEmbedding,
};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::{Config, Layer, Model};
fn tiny_config() -> RetNetConfig {
RetNetConfig {
vocab_size: 128,
hidden_size: 32,
num_hidden_layers: 2,
num_heads: 4,
intermediate_size: 64,
retention_heads: 4,
max_position_embeddings: 64,
chunk_size: 16,
deepnorm: false,
..RetNetConfig::default()
}
}
#[test]
fn test_retnet_config_default_validates() {
let config = RetNetConfig::default();
assert!(config.validate().is_ok());
}
#[test]
fn test_retnet_config_invalid_hidden_size_heads() {
let config = RetNetConfig {
hidden_size: 33,
num_heads: 4,
retention_heads: 4,
..RetNetConfig::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_retnet_config_invalid_retention_heads() {
let config = RetNetConfig {
hidden_size: 32,
retention_heads: 5,
num_heads: 4,
..RetNetConfig::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_retnet_config_invalid_chunk_size() {
let config = RetNetConfig {
chunk_size: 10000,
max_position_embeddings: 2048,
..RetNetConfig::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_retnet_config_presets() {
let small = RetNetConfig::retnet_small();
assert!(small.validate().is_ok());
assert_eq!(small.hidden_size, 2048);
let medium = RetNetConfig::retnet_medium();
assert!(medium.validate().is_ok());
assert_eq!(medium.hidden_size, 2560);
let large = RetNetConfig::retnet_large();
assert!(large.validate().is_ok());
assert_eq!(large.hidden_size, 4096);
let xl = RetNetConfig::retnet_xl();
assert!(xl.validate().is_ok());
assert_eq!(xl.hidden_size, 5120);
assert!(xl.deepnorm);
}
#[test]
fn test_retnet_config_long_preset() {
let long = RetNetConfig::retnet_long();
assert!(long.validate().is_ok());
assert!(long.chunking);
assert!(long.sequence_parallel);
assert_eq!(long.max_position_embeddings, 8192);
}
#[test]
fn test_retnet_config_head_dim() {
let config = tiny_config();
assert_eq!(config.head_dim(), 8); assert_eq!(config.retention_head_dim(), 8); }
#[test]
fn test_retnet_config_retention_dim() {
let config = tiny_config();
assert_eq!(config.retention_dim(), 16);
}
#[test]
fn test_retnet_config_uses_chunking() {
let mut config = tiny_config();
config.chunking = false;
assert!(!config.uses_chunking());
config.chunking = true;
config.chunk_size = 16;
assert!(config.uses_chunking());
config.chunk_size = 0;
assert!(!config.uses_chunking());
}
#[test]
fn test_retnet_config_memory_advantage() {
let config = tiny_config();
let advantage = config.memory_advantage();
assert!((advantage - 64.0).abs() < 0.01);
}
#[test]
fn test_retnet_config_supports_long_sequences() {
let mut config = tiny_config();
config.max_position_embeddings = 1024;
config.chunking = false;
assert!(!config.supports_long_sequences());
config.max_position_embeddings = 4096;
assert!(config.supports_long_sequences());
config.max_position_embeddings = 64;
config.chunking = true;
config.chunk_size = 16;
assert!(config.supports_long_sequences());
}
#[test]
fn test_retnet_config_deepnorm_factors() {
let config = tiny_config();
let alpha = config.deepnorm_alpha();
let beta = config.deepnorm_beta();
assert!((alpha - (4.0_f32).powf(0.25)).abs() < 0.001);
assert!((beta - (16.0_f32).powf(-0.25)).abs() < 0.001);
}
#[test]
fn test_retnet_config_architecture() {
let config = RetNetConfig::default();
assert_eq!(config.architecture(), "RetNet");
}
#[test]
fn test_rotary_position_embedding_creation() {
let rope = RotaryPositionEmbedding::new(8, 64, 10000.0);
assert!(rope.is_ok());
}
#[test]
fn test_rotary_position_embedding_device() {
let rope = RotaryPositionEmbedding::new(8, 64, 10000.0)
.expect("Failed to create RotaryPositionEmbedding");
assert_eq!(rope.device(), trustformers_core::device::Device::CPU);
}
#[test]
fn test_rotary_position_embedding_inv_freq() {
let dim = 8;
let rope = RotaryPositionEmbedding::new(dim, 64, 10000.0)
.expect("Failed to create RotaryPositionEmbedding");
assert_eq!(rope.device(), trustformers_core::device::Device::CPU);
}
#[test]
fn test_state_cache_creation() {
let cache = RetNetStateCache::new(10);
assert_eq!(cache.size(), 0);
}
#[test]
fn test_state_cache_set_and_get() {
let mut cache = RetNetStateCache::new(10);
let state = Tensor::zeros(&[2, 2]).expect("Failed to create tensor");
cache.set_state(0, state).expect("Failed to set state");
assert_eq!(cache.size(), 1);
assert!(cache.get_state(0).is_some());
assert!(cache.get_state(1).is_none());
}
#[test]
fn test_state_cache_clear() {
let mut cache = RetNetStateCache::new(10);
let s1 = Tensor::zeros(&[2, 2]).expect("Failed to create tensor");
let s2 = Tensor::zeros(&[2, 2]).expect("Failed to create tensor");
cache.set_state(0, s1).expect("Failed to set state");
cache.set_state(1, s2).expect("Failed to set state");
assert_eq!(cache.size(), 2);
cache.clear();
assert_eq!(cache.size(), 0);
assert!(cache.get_state(0).is_none());
}
#[test]
fn test_state_cache_eviction() {
let mut cache = RetNetStateCache::new(2);
let s1 = Tensor::zeros(&[2, 2]).expect("Failed to create tensor");
let s2 = Tensor::zeros(&[2, 2]).expect("Failed to create tensor");
let s3 = Tensor::zeros(&[2, 2]).expect("Failed to create tensor");
cache.set_state(0, s1).expect("Failed to set state");
cache.set_state(1, s2).expect("Failed to set state");
cache.set_state(2, s3).expect("Failed to set state");
assert!(cache.size() <= 2);
}
#[test]
fn test_chunk_processor_creation() {
let processor = AdvancedChunkProcessor::new(512, 128, false);
let _ = processor;
}
#[test]
fn test_chunk_processor_short_sequence() {
let processor = AdvancedChunkProcessor::new(512, 128, false);
let input = Tensor::zeros(&[1, 10, 32]).expect("Failed to create tensor");
let result = processor.process_chunks(&input, |chunk, _state| {
let state = Tensor::zeros(&[1]).expect("Failed to create state");
Ok((chunk.clone(), state))
});
assert!(result.is_ok());
let output = result.expect("Failed to process chunks");
assert_eq!(output.shape()[0], 1);
assert_eq!(output.shape()[1], 10);
assert_eq!(output.shape()[2], 32);
}
#[test]
fn test_retnet_embeddings_creation() {
let config = tiny_config();
let emb = RetNetEmbeddings::new(&config);
assert!(emb.is_ok());
}
#[test]
fn test_retnet_embeddings_forward() {
let config = tiny_config();
let emb = RetNetEmbeddings::new(&config).expect("Failed to create embeddings");
let input = vec![1, 2, 3, 4];
let output = emb.forward(input);
assert!(output.is_ok());
let out = output.expect("Forward pass failed");
assert_eq!(out.shape()[out.shape().len() - 1], 32);
}
#[test]
fn test_retnet_embeddings_with_layernorm() {
let mut config = tiny_config();
config.layernorm_embedding = true;
let emb = RetNetEmbeddings::new(&config);
assert!(emb.is_ok());
let emb = emb.expect("Failed to create embeddings");
assert!(emb.parameter_count() > 0);
}
#[test]
fn test_retnet_embeddings_parameter_count() {
let config = tiny_config();
let emb = RetNetEmbeddings::new(&config).expect("Failed to create embeddings");
assert!(emb.parameter_count() >= 4096);
}
#[test]
fn test_retnet_ffn_creation() {
let config = tiny_config();
let ffn = RetNetFFN::new(&config);
assert!(ffn.is_ok());
}
#[test]
fn test_retnet_ffn_forward_glu() {
let mut config = tiny_config();
config.use_glu = true;
let ffn = RetNetFFN::new(&config).expect("Failed to create FFN");
let input = Tensor::zeros(&[1, 4, 32]).expect("Failed to create tensor");
let output = ffn.forward(input);
assert!(output.is_ok());
let out = output.expect("Forward pass failed");
assert_eq!(out.shape(), &[1, 4, 32]);
}
#[test]
fn test_retnet_ffn_forward_no_glu() {
let mut config = tiny_config();
config.use_glu = false;
let ffn = RetNetFFN::new(&config).expect("Failed to create FFN");
let input = Tensor::zeros(&[1, 4, 32]).expect("Failed to create tensor");
let output = ffn.forward(input);
assert!(output.is_ok());
}
#[test]
fn test_retnet_ffn_parameter_count() {
let config = tiny_config();
let ffn = RetNetFFN::new(&config).expect("Failed to create FFN");
assert!(ffn.parameter_count() > 0);
}
#[test]
fn test_retnet_model_creation() {
let config = tiny_config();
let model = RetNetModel::new(config);
assert!(model.is_ok());
}
#[test]
fn test_retnet_model_creation_and_config() {
let config = forward_config();
let model = RetNetModel::new(config).expect("Failed to create model");
assert!(model.num_parameters() > 0);
assert_eq!(model.get_config().hidden_size, 64);
}
#[test]
fn test_retnet_model_num_parameters() {
let config = tiny_config();
let model = RetNetModel::new(config).expect("Failed to create model");
assert!(model.num_parameters() > 0);
}
#[test]
fn test_retnet_model_get_config() {
let config = tiny_config();
let model = RetNetModel::new(config).expect("Failed to create model");
let cfg = model.get_config();
assert_eq!(cfg.hidden_size, 32);
assert_eq!(cfg.num_hidden_layers, 2);
}
#[test]
fn test_retnet_lm_creation() {
let config = tiny_config();
let model = RetNetForLanguageModeling::new(config);
assert!(model.is_ok());
}
fn forward_config() -> RetNetConfig {
RetNetConfig {
vocab_size: 128,
hidden_size: 64,
num_hidden_layers: 1,
num_heads: 4,
intermediate_size: 128,
retention_heads: 4,
max_position_embeddings: 512,
chunk_size: 64,
value_factor: 1.0,
deepnorm: false,
chunking: false,
..RetNetConfig::default()
}
}
#[test]
fn test_retnet_lm_creation_and_params() {
let config = forward_config();
let model = RetNetForLanguageModeling::new(config).expect("Failed to create model");
assert!(model.num_parameters() > 0);
let cfg = model.get_config();
assert_eq!(cfg.hidden_size, 64);
}
#[test]
fn test_retnet_lm_no_output_layer() {
let mut config = forward_config();
config.no_output_layer = true;
let model = RetNetForLanguageModeling::new(config);
assert!(model.is_ok());
}
#[test]
fn test_retnet_lm_num_parameters() {
let config = tiny_config();
let model = RetNetForLanguageModeling::new(config).expect("Failed to create model");
assert!(model.num_parameters() > 0);
}
#[test]
fn test_retnet_classification_creation() {
let config = tiny_config();
let model = RetNetForSequenceClassification::new(config, 3);
assert!(model.is_ok());
}
#[test]
fn test_retnet_classification_forward_creates_ok() {
let config = tiny_config();
let model =
RetNetForSequenceClassification::new(config, 3).expect("Failed to create model");
assert!(model.num_parameters() > 0);
let cfg = model.get_config();
assert_eq!(cfg.hidden_size, 32);
}
#[test]
fn test_retnet_classification_num_parameters() {
let config = tiny_config();
let model =
RetNetForSequenceClassification::new(config, 3).expect("Failed to create model");
assert!(model.num_parameters() > 0);
}
#[test]
fn test_retnet_long_sequence_creation() {
let config = tiny_config();
let model = RetNetLongSequence::new(config, 16);
assert!(model.is_ok());
}
#[test]
fn test_retnet_long_sequence_params() {
let config = forward_config();
let model = RetNetLongSequence::new(config, 32).expect("Failed to create model");
let stats = model.get_memory_stats();
assert_eq!(stats.chunk_size, 32);
}
#[test]
fn test_retnet_long_sequence_memory_stats() {
let config = tiny_config();
let model = RetNetLongSequence::new(config, 16).expect("Failed to create model");
let stats = model.get_memory_stats();
assert_eq!(stats.chunk_size, 16);
assert_eq!(stats.overlap_size, 4); assert!(stats.estimated_memory_mb > 0.0);
}
#[test]
fn test_multi_scale_retention_creation() {
let config = tiny_config();
let msr = MultiScaleRetention::new(&config);
assert!(msr.is_ok());
}
#[test]
fn test_multi_scale_retention_parameter_count() {
let config = tiny_config();
let msr = MultiScaleRetention::new(&config).expect("Failed to create MSR");
assert!(msr.parameter_count() > 0);
}
#[test]
fn test_multi_scale_retention_set_inference_mode() {
let config = tiny_config();
let mut msr = MultiScaleRetention::new(&config).expect("Failed to create MSR");
msr.set_inference_mode(Some(10));
}
#[test]
fn test_multi_scale_retention_clear_cache() {
let config = tiny_config();
let mut msr = MultiScaleRetention::new(&config).expect("Failed to create MSR");
msr.clear_cache();
}
#[test]
fn test_retnet_decoder_layer_creation() {
let config = tiny_config();
let layer = RetNetDecoderLayer::new(&config);
assert!(layer.is_ok());
}
#[test]
fn test_retnet_decoder_layer_parameter_count() {
let config = tiny_config();
let layer = RetNetDecoderLayer::new(&config).expect("Failed to create layer");
assert!(layer.parameter_count() > 0);
}
#[test]
fn test_retnet_decoder_layer_with_deepnorm() {
let mut config = tiny_config();
config.deepnorm = true;
let layer = RetNetDecoderLayer::new(&config);
assert!(layer.is_ok());
}
}