use super::*;
use trustformers_core::{
tensor::Tensor,
traits::{Config, Model},
};
/// The default configuration carries the base-size hyperparameters and passes validation.
#[test]
fn test_recursive_config_creation() {
    let cfg = RecursiveConfig::default();

    // Model width and attention layout.
    assert_eq!(cfg.hidden_size, 768);
    assert_eq!(cfg.num_attention_heads, 12);

    // Recursion / chunking parameters.
    assert_eq!(cfg.recursion_depth, 3);
    assert_eq!(cfg.chunk_size, 512);

    let validation = cfg.validate();
    assert!(validation.is_ok());
}
/// The `long_document()` preset: wider model, 32k positions, deeper recursion, larger chunks.
#[test]
fn test_long_document_config() {
    let cfg = RecursiveConfig::long_document();
    assert_eq!(cfg.model_type, "recursive-long-document");

    // Width and positional range.
    assert_eq!(cfg.hidden_size, 1024);
    assert_eq!(cfg.max_position_embeddings, 32768);

    // Recursion and chunking.
    assert_eq!(cfg.recursion_depth, 4);
    assert_eq!(cfg.chunk_size, 1024);

    let validation = cfg.validate();
    assert!(validation.is_ok());
}
/// The `universal()` preset turns on Universal Transformer mode with adaptive computation time.
#[test]
fn test_universal_config() {
    let cfg = RecursiveConfig::universal();
    assert_eq!(cfg.model_type, "recursive-universal");

    // Universal-transformer switches and step budget.
    assert!(cfg.use_universal_transformer);
    assert_eq!(cfg.max_steps, 8);
    assert!(cfg.adaptive_computation_time);

    let validation = cfg.validate();
    assert!(validation.is_ok());
}
/// The `memory_efficient()` preset: narrower model with checkpointing and compressed memory.
#[test]
fn test_memory_efficient_config() {
    let cfg = RecursiveConfig::memory_efficient();
    assert_eq!(cfg.model_type, "recursive-memory-efficient");
    assert_eq!(cfg.hidden_size, 512);

    // Memory-saving features.
    assert!(cfg.use_gradient_checkpointing);
    assert!(cfg.use_memory_compression);
    assert_eq!(cfg.compression_ratio, 0.25);

    let validation = cfg.validate();
    assert!(validation.is_ok());
}
/// The `hierarchical()` preset: three attention levels, each with a compression ratio,
/// plus cross-level attention.
#[test]
fn test_hierarchical_config() {
    let cfg = RecursiveConfig::hierarchical();
    assert_eq!(cfg.model_type, "recursive-hierarchical");

    // Hierarchy layout.
    assert!(cfg.use_hierarchical_attention);
    assert_eq!(cfg.hierarchy_levels, 3);
    assert_eq!(cfg.level_compression_ratios.len(), 3);
    assert!(cfg.cross_level_attention);

    let validation = cfg.validate();
    assert!(validation.is_ok());
}
/// The `code_understanding()` preset: 50k vocabulary, 8k positions, hierarchical attention.
#[test]
fn test_code_understanding_config() {
    let cfg = RecursiveConfig::code_understanding();
    assert_eq!(cfg.model_type, "recursive-code");

    // Tokenization and context sizing.
    assert_eq!(cfg.vocab_size, 50000);
    assert_eq!(cfg.max_position_embeddings, 8192);

    assert!(cfg.use_hierarchical_attention);

    let validation = cfg.validate();
    assert!(validation.is_ok());
}
/// `validate()` rejects each individually invalid field. Once that field is set back to a
/// valid value, the configuration validates again.
#[test]
fn test_config_validation() {
    // Build the invalid starting point with struct-update syntax instead of assigning a field
    // right after `default()`. That keeps clippy's `field_reassign_with_default` lint quiet
    // without needing an `#[allow]`.
    let mut config = RecursiveConfig {
        // 100 is not a multiple of the default 12 attention heads; presumably that is why it
        // is rejected (TODO confirm against `validate`).
        hidden_size: 100,
        ..RecursiveConfig::default()
    };
    assert!(
        config.validate().is_err(),
        "hidden_size = 100 should be rejected"
    );
    config.hidden_size = 768;
    assert!(config.validate().is_ok(), "default config should validate");

    config.recursion_depth = 0;
    assert!(
        config.validate().is_err(),
        "recursion_depth = 0 should be rejected"
    );
    config.recursion_depth = 3;
    assert!(config.validate().is_ok());

    config.chunk_size = 0;
    assert!(
        config.validate().is_err(),
        "chunk_size = 0 should be rejected"
    );
    config.chunk_size = 512;
    assert!(config.validate().is_ok());

    // An overlap equal to the 512-token chunk size is rejected. Presumably the overlap must be
    // strictly smaller than the chunk; TODO confirm.
    config.overlap_size = 512;
    assert!(
        config.validate().is_err(),
        "overlap_size == chunk_size should be rejected"
    );
    config.overlap_size = 64;
    assert!(config.validate().is_ok());
}
/// With adaptive depth enabled, a `min_depth` above `max_depth` is rejected.
#[test]
fn test_adaptive_depth_validation() {
    let mut cfg = RecursiveConfig {
        use_adaptive_depth: true,
        min_depth: 5,
        max_depth: 3,
        ..RecursiveConfig::default()
    };

    // Inverted bounds (5 > 3) must fail...
    assert!(cfg.validate().is_err());

    // ...and raising the upper bound past the lower one fixes it.
    cfg.max_depth = 7;
    assert!(cfg.validate().is_ok());
}
/// Hierarchical attention needs at least one level and a compression ratio for every level.
/// A ratio above 1.0 is rejected.
#[test]
fn test_hierarchical_validation() {
    let mut cfg = RecursiveConfig {
        use_hierarchical_attention: true,
        hierarchy_levels: 0,
        ..RecursiveConfig::default()
    };

    // Zero levels.
    assert!(cfg.validate().is_err());

    cfg.hierarchy_levels = 3;

    // Only two ratios for three levels.
    cfg.level_compression_ratios = vec![1.0, 0.5];
    assert!(cfg.validate().is_err());

    // One ratio per level, none above 1.0.
    cfg.level_compression_ratios = vec![1.0, 0.5, 0.25];
    assert!(cfg.validate().is_ok());

    // A ratio of 1.5 is out of range.
    cfg.level_compression_ratios = vec![1.0, 1.5, 0.25];
    assert!(cfg.validate().is_err());
}
/// Universal Transformer mode rejects a zero step budget and accepts a positive one.
#[test]
fn test_universal_transformer_validation() {
    let mut cfg = RecursiveConfig {
        use_universal_transformer: true,
        max_steps: 0,
        ..RecursiveConfig::default()
    };

    // No steps at all.
    assert!(cfg.validate().is_err());

    // A positive budget is fine.
    cfg.max_steps = 10;
    assert!(cfg.validate().is_ok());
}
/// Every published preset name resolves to a config. Unknown names resolve to `None`.
#[test]
fn test_from_pretrained_name() {
    let known_names = [
        "recursive-long-document",
        "recursive-universal",
        "recursive-memory-efficient",
        "recursive-hierarchical",
        "recursive-code",
    ];
    for name in known_names {
        assert!(RecursiveConfig::from_pretrained_name(name).is_some());
    }

    assert!(RecursiveConfig::from_pretrained_name("invalid-model").is_none());
}
/// Derived quantities computed from the default configuration.
#[test]
fn test_config_helper_methods() {
    let cfg = RecursiveConfig::default();

    // 768 hidden units split across 12 heads.
    assert_eq!(cfg.head_dim(), 64);
    // Same count as the default attention heads (12).
    assert_eq!(cfg.num_kv_heads(), 12);
    // 512-token chunk minus what is presumably a 64-token default overlap (TODO confirm).
    assert_eq!(cfg.effective_chunk_size(), 448);
    assert_eq!(cfg.total_memory_capacity(), 1792);
}
/// Each `with_*` mutator writes the fields it is responsible for.
#[test]
fn test_config_with_methods() {
    let mut cfg = RecursiveConfig::default();

    // Memory: size, compression toggle, ratio.
    cfg.with_memory(2048, true, 0.5);
    assert_eq!(cfg.memory_size, 2048);
    assert!(cfg.use_memory_compression);
    assert_eq!(cfg.compression_ratio, 0.5);

    // Chunking: chunk and overlap sizes.
    cfg.with_chunks(1024, 128);
    assert_eq!(cfg.chunk_size, 1024);
    assert_eq!(cfg.overlap_size, 128);

    // Depth: after `with_depth(4, true)` the test expects max_depth == 8.
    cfg.with_depth(4, true);
    assert_eq!(cfg.recursion_depth, 4);
    assert!(cfg.use_adaptive_depth);
    assert_eq!(cfg.max_depth, 8);

    // Hierarchy: level count plus one ratio per level.
    cfg.with_hierarchy(4, vec![1.0, 0.75, 0.5, 0.25]);
    assert!(cfg.use_hierarchical_attention);
    assert_eq!(cfg.hierarchy_levels, 4);
    assert_eq!(cfg.level_compression_ratios.len(), 4);

    // Universal transformer: step budget and ACT flag.
    cfg.with_universal(12, true);
    assert!(cfg.use_universal_transformer);
    assert_eq!(cfg.max_steps, 12);
    assert!(cfg.adaptive_computation_time);
}
/// A freshly built `MemoryState` exposes content with the requested three-dimensional shape.
#[test]
fn test_memory_state_creation() -> Result<()> {
    let state = MemoryState::new(2, 1024, 768)?;
    let snapshot = state.get_content()?;

    assert_eq!(snapshot.shape(), &[2, 1024, 768]);
    Ok(())
}
/// `update` accepts a small write and then a 1000-row write into a 1024-slot memory.
/// Together the two writes total 1256 rows; how the overflow is handled is not asserted here.
#[test]
fn test_memory_state_update() -> Result<()> {
    let mut state = MemoryState::new(1, 1024, 768)?;

    let small_write = Tensor::ones(&[1, 256, 768])?;
    assert!(state.update(small_write).is_ok());

    let big_write = Tensor::ones(&[1, 1000, 768])?;
    assert!(state.update(big_write).is_ok());

    Ok(())
}
/// `read(n)` returns `n` rows, and a 1000-row read within 1024 slots succeeds.
#[test]
fn test_memory_state_read() -> Result<()> {
    let mut state = MemoryState::new(1, 1024, 768)?;

    let window = state.read(256)?;
    assert_eq!(window.shape(), &[1, 256, 768]);

    // Only success matters for the larger read.
    state.read(1000)?;
    Ok(())
}
/// The base model can be built from the default configuration.
#[test]
#[ignore] // NOTE(review): no reason recorded; consider `#[ignore = "..."]`.
fn test_recursive_transformer_creation() {
    let built = RecursiveTransformer::new(RecursiveConfig::default());
    assert!(
        built.is_ok(),
        "Failed to create RecursiveTransformer: {:?}",
        built.err()
    );
}
/// The causal-LM head can be built from the default configuration.
#[test]
#[ignore] // NOTE(review): no reason recorded; consider `#[ignore = "..."]`.
fn test_recursive_for_causal_lm_creation() {
    let built = RecursiveForCausalLM::new(RecursiveConfig::default());
    assert!(
        built.is_ok(),
        "Failed to create RecursiveForCausalLM: {:?}",
        built.err()
    );
}
/// The sequence-classification head can be built for 10 labels from the default configuration.
#[test]
#[ignore] // NOTE(review): no reason recorded; consider `#[ignore = "..."]`.
fn test_recursive_for_sequence_classification_creation() {
    let built = RecursiveForSequenceClassification::new(RecursiveConfig::default(), 10);
    assert!(
        built.is_ok(),
        "Failed to create RecursiveForSequenceClassification: {:?}",
        built.err()
    );
}
/// A `MemoryManager` can be built from the default configuration.
#[test]
fn test_memory_manager_creation() {
    let built = MemoryManager::new(RecursiveConfig::default());
    assert!(
        built.is_ok(),
        "Failed to create MemoryManager: {:?}",
        built.err()
    );
}
/// A `DepthPredictor` can be built from the default configuration.
#[test]
fn test_depth_predictor_creation() {
    let built = DepthPredictor::new(RecursiveConfig::default());
    assert!(
        built.is_ok(),
        "Failed to create DepthPredictor: {:?}",
        built.err()
    );
}
/// A `HierarchyManager` can be built from the hierarchical preset.
#[test]
fn test_hierarchy_manager_creation() {
    let built = HierarchyManager::new(RecursiveConfig::hierarchical());
    assert!(
        built.is_ok(),
        "Failed to create HierarchyManager: {:?}",
        built.err()
    );
}
/// A `UniversalController` can be built from the universal preset.
#[test]
fn test_universal_controller_creation() {
    let built = UniversalController::new(RecursiveConfig::universal());
    assert!(
        built.is_ok(),
        "Failed to create UniversalController: {:?}",
        built.err()
    );
}
/// A `[1, 100]` input yields hidden states of shape `[1, 100, hidden_size]` and logits whose
/// last axis is `vocab_size`.
#[test]
fn test_recursive_transformer_forward() {
    let config = RecursiveConfig::default();
    // Copy out the sizes the assertions need so `config` can be moved into the model.
    let (hidden_size, vocab_size) = (config.hidden_size, config.vocab_size);
    let model = RecursiveTransformer::new(config).expect("operation failed");

    let input = RecursiveInput {
        input_ids: Tensor::zeros(&[1, 100]).expect("operation failed"),
        attention_mask: None,
        position_ids: None,
        memory_state: None,
    };

    let result = model.forward(input);
    assert!(result.is_ok(), "Forward pass failed: {:?}", result.err());
    let output = result.expect("operation failed");

    let hidden_shape = output.last_hidden_state.shape();
    assert_eq!(hidden_shape[0], 1);
    assert_eq!(hidden_shape[1], 100);
    assert_eq!(hidden_shape[2], hidden_size);
    assert_eq!(output.logits.shape()[2], vocab_size);
}
/// A 512-token input spans several 128-token chunks (no overlap). The output keeps the full
/// sequence length and reports a positive recursion depth.
#[test]
fn test_recursive_transformer_long_sequence() {
    // Shrink the long-document preset so the test stays fast while still covering more than
    // one chunk.
    let config = RecursiveConfig {
        hidden_size: 256,
        num_attention_heads: 4,
        intermediate_size: 512,
        num_recursive_layers: 2,
        chunk_size: 128,
        overlap_size: 0,
        memory_size: 256,
        ..RecursiveConfig::long_document()
    };
    // `config` is not used after this point, so it is moved rather than cloned.
    let model = RecursiveTransformer::new(config).expect("operation failed");

    let input_ids = Tensor::zeros(&[1, 512]).expect("operation failed");
    let input = RecursiveInput {
        input_ids,
        attention_mask: None,
        position_ids: None,
        memory_state: None,
    };

    let result = model.forward(input);
    assert!(
        result.is_ok(),
        "Long sequence forward pass failed: {:?}",
        result.err()
    );
    let output = result.expect("operation failed");
    assert_eq!(output.last_hidden_state.shape()[0], 1);
    assert_eq!(output.last_hidden_state.shape()[1], 512);
    assert!(output.recursion_depth > 0);
    // `output` and `model` drop here in reverse declaration order (output first), which is
    // what the explicit `drop` calls used to do.
}
/// A forward pass seeded with an explicit, config-sized `MemoryState` succeeds.
#[test]
fn test_recursive_transformer_with_memory() -> Result<()> {
    let config = RecursiveConfig::default();
    // Keep the sizes needed for the memory so `config` can be moved into the model.
    let (memory_slots, hidden_size) = (config.memory_size, config.hidden_size);
    let model = RecursiveTransformer::new(config)?;

    let input_ids = Tensor::zeros(&[1, 100])?;
    let input = RecursiveInput {
        input_ids,
        attention_mask: None,
        position_ids: None,
        memory_state: Some(MemoryState::new(1, memory_slots, hidden_size)?),
    };

    let outcome = model.forward(input);
    assert!(
        outcome.is_ok(),
        "Forward pass with memory failed: {:?}",
        outcome.err()
    );
    let _output = outcome?;
    Ok(())
}
/// The causal-LM head keeps the `[1, 64]` input length and produces vocab-sized logits.
#[test]
#[ignore] // NOTE(review): no reason recorded; consider `#[ignore = "..."]`.
fn test_causal_lm_forward() {
    let config = RecursiveConfig::default();
    let vocab_size = config.vocab_size;
    let model = RecursiveForCausalLM::new(config).expect("operation failed");

    let input = RecursiveInput {
        input_ids: Tensor::zeros(&[1, 64]).expect("operation failed"),
        attention_mask: None,
        position_ids: None,
        memory_state: None,
    };

    let result = model.forward(input);
    assert!(
        result.is_ok(),
        "CausalLM forward pass failed: {:?}",
        result.err()
    );
    let output = result.expect("operation failed");

    let hidden_shape = output.last_hidden_state.shape();
    assert_eq!(hidden_shape[0], 1);
    assert_eq!(hidden_shape[1], 64);
    assert_eq!(output.logits.shape()[2], vocab_size);
}
/// For a `[2, 100]` input, sequence classification produces logits of shape `[2, num_labels]`.
#[test]
fn test_sequence_classification_forward() {
    let config = RecursiveConfig::default();
    let num_labels = 5;
    // `config` is not needed after construction, so it is moved instead of cloned.
    let model = RecursiveForSequenceClassification::new(config, num_labels)
        .expect("operation failed");

    let input_ids = Tensor::zeros(&[2, 100]).expect("operation failed");
    let input = RecursiveInput {
        input_ids,
        attention_mask: None,
        position_ids: None,
        memory_state: None,
    };

    let result = model.forward(input);
    assert!(
        result.is_ok(),
        "Classification forward pass failed: {:?}",
        result.err()
    );
    let output = result.expect("operation failed");
    assert_eq!(output.logits.shape()[0], 2);
    assert_eq!(output.logits.shape()[1], num_labels);
}
/// Short inputs get a depth in `[1, 5]`; a 5000-token input gets a depth of at least 3.
#[test]
fn test_depth_predictor_predict() -> Result<()> {
    let predictor = DepthPredictor::new(RecursiveConfig::default())?;
    let memory = MemoryState::new(1, 1024, 768)?;

    let short_depth = predictor.predict_depth(&Tensor::zeros(&[1, 100])?, &memory)?;
    assert!((1..=5).contains(&short_depth));

    let long_depth = predictor.predict_depth(&Tensor::zeros(&[1, 5000])?, &memory)?;
    assert!(long_depth >= 3);

    Ok(())
}
/// `model_info` reports display name, context length and capability flags per preset.
#[test]
fn test_model_info() {
    let long_doc = model_info("recursive-long-document").expect("operation failed");
    assert_eq!(long_doc.name, "Recursive Long Document");
    assert_eq!(long_doc.max_sequence_length, 32768);
    assert!(long_doc.memory_efficient);
    assert!(long_doc.adaptive_depth);

    // The universal preset adapts depth but is not flagged as memory-efficient.
    let universal = model_info("recursive-universal").expect("operation failed");
    assert!(universal.adaptive_depth);
    assert!(!universal.memory_efficient);
}
/// `available_models` lists exactly the five preset names.
#[test]
fn test_available_models() {
    let models = available_models();
    let expected = [
        "recursive-long-document",
        "recursive-universal",
        "recursive-memory-efficient",
        "recursive-hierarchical",
        "recursive-code",
    ];

    for name in &expected {
        assert!(models.contains(name));
    }
    assert_eq!(models.len(), expected.len());
}
/// The free-function constructors succeed for valid inputs, and `from_pretrained` rejects
/// unknown names.
#[test]
fn test_convenience_functions() {
    // Preset constructors.
    assert!(long_document().is_ok());
    assert!(universal().is_ok());
    assert!(memory_efficient().is_ok());
    assert!(hierarchical().is_ok());
    assert!(code_understanding().is_ok());

    // Lookup by name.
    assert!(from_pretrained("recursive-long-document").is_ok());
    assert!(from_pretrained("invalid-model").is_err());

    // Task-head constructors on the default config.
    // NOTE(review): this builds the same heads that the `#[ignore]`d creation tests build.
    let cfg = RecursiveConfig::default();
    assert!(for_causal_lm(cfg.clone()).is_ok());
    assert!(for_sequence_classification(cfg, 10).is_ok());
}
/// Sanity bounds for the memory-state factory, chunk-size heuristic and memory estimator.
#[test]
fn test_utility_functions() -> Result<()> {
    let cfg = RecursiveConfig::default();
    let _memory = create_memory_state(2, &cfg)?;

    // The heuristic must pick a positive chunk no larger than 2500.
    let chunk = optimal_chunk_size(10000, 1024, 768);
    assert!(chunk > 0 && chunk <= 2500);

    // Only positivity is checked; the unit of the estimate is not asserted here.
    let usage = estimate_memory_usage(&cfg, 1000);
    assert!(usage > 0);

    Ok(())
}
/// Spot-checks the distinguishing fields of each domain-specific `ConfigPresets` entry.
#[test]
fn test_config_presets() {
    // Books: large chunks, hierarchical attention.
    let book = ConfigPresets::book_processing();
    assert_eq!(book.chunk_size, 2048);
    assert!(book.use_hierarchical_attention);

    // Code: four hierarchy levels, adaptive depth.
    let code = ConfigPresets::code_analysis();
    assert_eq!(code.hierarchy_levels, 4);
    assert!(code.use_adaptive_depth);

    // Legal: compressed memory at ratio 0.3.
    let legal = ConfigPresets::legal_documents();
    assert!(legal.use_memory_compression);
    assert_eq!(legal.compression_ratio, 0.3);

    // Papers: three levels with cross-level attention.
    let paper = ConfigPresets::research_papers();
    assert_eq!(paper.hierarchy_levels, 3);
    assert!(paper.cross_level_attention);

    // Mobile: small width and chunks.
    let mobile = ConfigPresets::mobile_deployment();
    assert_eq!(mobile.hidden_size, 384);
    assert_eq!(mobile.chunk_size, 256);
}
/// There are at least eight performance tips, and they mention both memory and chunking.
#[test]
fn test_performance_tips() {
    let tips = performance_tips();
    assert!(!tips.is_empty());
    assert!(tips.len() >= 8);

    for keyword in ["memory", "chunk"] {
        assert!(tips.iter().any(|tip| tip.contains(keyword)));
    }
}
/// Minimal-input edge case: a single-token sequence (shape `[1, 1]`) must pass through the
/// default model without error.
///
/// NOTE(review): despite the name, this does not exercise a zero-length sequence. A `[1, 0]`
/// input deserves its own test once its expected behavior is defined.
#[test]
fn test_empty_sequence() {
    let config = RecursiveConfig::default();
    let model = RecursiveTransformer::new(config).expect("operation failed");

    let input_ids = Tensor::zeros(&[1, 1]).expect("operation failed");
    let input = RecursiveInput {
        input_ids,
        attention_mask: None,
        position_ids: None,
        memory_state: None,
    };

    let result = model.forward(input);
    // Surface the underlying error instead of a bare `is_ok()` failure.
    assert!(
        result.is_ok(),
        "Single-token forward pass failed: {:?}",
        result.err()
    );
}
/// Two sequences processed together keep their batch and sequence dimensions.
#[test]
#[ignore = "scirs2-core SIMD non-contiguous array panic"]
fn test_batch_processing() {
    // Tiny model; the 64-token sequence fits in exactly one 64-token chunk.
    let config = RecursiveConfig {
        hidden_size: 128,
        num_attention_heads: 4,
        intermediate_size: 256,
        num_recursive_layers: 1,
        chunk_size: 64,
        overlap_size: 0,
        memory_size: 128,
        ..RecursiveConfig::default()
    };
    let model = RecursiveTransformer::new(config).expect("operation failed");

    let batch_size = 2;
    let seq_len = 64;
    let input_ids = Tensor::zeros(&[batch_size, seq_len]).expect("operation failed");
    let input = RecursiveInput {
        input_ids,
        attention_mask: None,
        position_ids: None,
        memory_state: None,
    };

    let result = model.forward(input);
    // Surface the underlying error instead of a bare `is_ok()` failure.
    assert!(
        result.is_ok(),
        "Batched forward pass failed: {:?}",
        result.err()
    );
    let output = result.expect("operation failed");
    assert_eq!(output.last_hidden_state.shape()[0], batch_size);
    assert_eq!(output.last_hidden_state.shape()[1], seq_len);
    // `output` and `model` drop here (output first), so no explicit `drop`/`black_box` is needed.
}
/// The `Config` trait reports the architecture name `"RecursiveTransformer"`.
#[test]
fn test_config_architecture_name() {
    let cfg = RecursiveConfig::default();
    let expected = "RecursiveTransformer";
    assert_eq!(cfg.architecture(), expected);
}
/// Multi-chunk processing on a long-document-derived config. The input is scaled down to
/// 256 tokens (two 128-token chunks with no overlap) so the test runs quickly. The output
/// must keep the full sequence length and report a positive recursion depth.
#[test]
fn test_very_long_sequence() {
    let config = RecursiveConfig {
        hidden_size: 128,
        num_attention_heads: 2,
        intermediate_size: 256,
        num_recursive_layers: 1,
        chunk_size: 128,
        overlap_size: 0,
        memory_size: 128,
        recursion_depth: 2,
        ..RecursiveConfig::long_document()
    };
    let model = RecursiveTransformer::new(config).expect("operation failed");

    let input_ids = Tensor::zeros(&[1, 256]).expect("operation failed");
    let input = RecursiveInput {
        input_ids,
        attention_mask: None,
        position_ids: None,
        memory_state: None,
    };

    let result = model.forward(input);
    assert!(
        result.is_ok(),
        "Very long sequence processing failed: {:?}",
        result.err()
    );
    let output = result.expect("operation failed");
    assert_eq!(output.last_hidden_state.shape()[1], 256);
    assert!(output.recursion_depth > 0);
    // `output` and `model` drop here (output first), so no explicit `drop`/`black_box` is needed.
}