use super::model::{Discretization, HiPPOMatrix};
use super::*;
use scirs2_core::ndarray::{Array1, Array2, Array3}; use trustformers_core::{
tensor::Tensor,
traits::{Config, Layer, Model},
};
/// Checks structural properties of the HiPPO initializations for an 8x8 matrix.
///
/// * LEGS and Fourier: every entry of `A + Aᵀ` must be below `1e-6` in magnitude
///   (the skew-symmetry condition the assertion messages refer to).
/// * LEGT: diagonal entry `i` must equal `-(2i + 1) / 2` to within `1e-6`.
/// * LAGT: every diagonal entry must equal `-0.5` to within `1e-6`.
#[test]
fn test_hippo_matrix_properties() {
    let size = 8;

    // LEGS: largest absolute entry of A + Aᵀ.
    let legs_matrix = HiPPOMatrix::LEGS.initialize(size);
    let legs_symmetric_part = &legs_matrix + &legs_matrix.t();
    let legs_residual = legs_symmetric_part
        .iter()
        .fold(0.0_f32, |worst, entry| worst.max(entry.abs()));
    assert!(legs_residual < 1e-6, "LEGS matrix should be skew-symmetric");

    // LEGT: compare each diagonal entry against its closed-form value.
    let legt_matrix = HiPPOMatrix::LEGT.initialize(size);
    for row in 0..size {
        let target = -(2.0 * row as f32 + 1.0) / 2.0;
        let deviation = (legt_matrix[[row, row]] - target).abs();
        assert!(deviation < 1e-6);
    }

    // LAGT: constant diagonal of -0.5.
    let lagt_matrix = HiPPOMatrix::LAGT.initialize(size);
    for row in 0..size {
        let deviation = (lagt_matrix[[row, row]] + 0.5).abs();
        assert!(deviation < 1e-6);
    }

    // Fourier: same A + Aᵀ check as for LEGS.
    let fourier_matrix = HiPPOMatrix::Fourier.initialize(size);
    let fourier_symmetric_part = &fourier_matrix + &fourier_matrix.t();
    let fourier_residual = fourier_symmetric_part
        .iter()
        .fold(0.0_f32, |worst, entry| worst.max(entry.abs()));
    assert!(
        fourier_residual < 1e-6,
        "Fourier matrix should be skew-symmetric"
    );
}
/// Discretizes the system `A = -I`, `B = 1` (state size 4, step `0.01`) with each
/// `Discretization` variant listed in the loop and checks the result.
///
/// For every method the discretized `A` must have shape `[n, n]`, the discretized
/// `B` must have shape `[n]`, and the absolute value of the trace of the
/// discretized `A` must stay below `2 * n`.
#[test]
fn test_discretization_stability() {
    let state_size = 4;
    let step = 0.01;
    let continuous_a = -Array2::<f32>::eye(state_size);
    let continuous_b = Array1::<f32>::ones(state_size);
    // Upper bound on |trace| used as the boundedness check below.
    let trace_limit = state_size as f32 * 2.0;

    for scheme in [
        Discretization::ZOH,
        Discretization::Bilinear,
        Discretization::Euler,
        Discretization::BackwardEuler,
    ] {
        let (discrete_a, discrete_b) = scheme.discretize(&continuous_a, &continuous_b, step);
        assert_eq!(discrete_a.shape(), &[state_size, state_size]);
        assert_eq!(discrete_b.shape(), &[state_size]);

        let diagonal_sum = discrete_a.diag().sum();
        assert!(
            diagonal_sum.abs() < trace_limit,
            "Discretized system should remain bounded"
        );
    }
}
/// Smoke test: an `S4Layer` can be constructed from a config with `d_state = 4`
/// and `d_model = 8`, all other fields left at their defaults.
#[test]
fn test_s4_layer_discretization() {
    let config = S4Config {
        d_state: 4,
        d_model: 8,
        ..S4Config::default()
    };
    // Only construction is exercised; the binding is never mutated, so it is
    // deliberately not declared `mut`.
    let _layer = S4Layer::new(&config).expect("operation failed");
}
/// Runs one `S4Block` forward pass on an all-ones F32 input of shape
/// `(batch_size, seq_len, d_model) = (2, 10, 16)` and checks that the output is
/// an F32 tensor with exactly the same shape.
#[test]
fn test_s4_block_forward() {
    let config = S4Config {
        d_model: 16,
        d_state: 4,
        n_layer: 1,
        vocab_size: 100,
        max_position_embeddings: 128,
        ..Default::default()
    };
    let block = S4Block::new(&config).expect("operation failed");
    let batch_size = 2;
    let seq_len = 10;
    let input_array = Array3::<f32>::ones((batch_size, seq_len, config.d_model));
    let input = Tensor::F32(input_array.into_dyn());

    // `expect` is applied directly to the result so that a failure reports the
    // underlying error; a separate `assert!(output.is_ok())` beforehand would
    // panic first and throw that error away.
    let output_tensor = block
        .forward(input)
        .expect("S4Block forward pass failed");

    match &output_tensor {
        Tensor::F32(arr) => {
            // A single slice comparison checks rank and every dimension at once
            // and prints the full actual shape on mismatch.
            assert_eq!(arr.shape(), &[batch_size, seq_len, config.d_model]);
        },
        _ => panic!("Expected F32 tensor"),
    }
}
/// Smoke test: an `S4Model` can be constructed from a two-layer config with
/// `d_model = 32`, `d_state = 8`, `vocab_size = 1000` and
/// `max_position_embeddings = 256`.
#[test]
fn test_s4_model_shapes() {
    let config = S4Config {
        d_model: 32,
        d_state: 8,
        n_layer: 2,
        vocab_size: 1000,
        max_position_embeddings: 256,
        ..Default::default()
    };
    // `config` is not used again, so it is moved into the constructor instead
    // of being cloned.
    let _model = S4Model::new(config).expect("operation failed");
}
/// Runs `S4ForLanguageModeling` on an all-zero I64 input of shape
/// `(batch_size, seq_len) = (2, 8)` and checks that the output is an F32 tensor
/// of shape `(batch_size, seq_len, vocab_size)`.
#[test]
fn test_s4_lm_forward() {
    let config = S4Config {
        d_model: 16,
        d_state: 4,
        n_layer: 1,
        vocab_size: 50,
        max_position_embeddings: 64,
        ..Default::default()
    };
    // The clone is required here: `config.vocab_size` is read again below.
    let model = S4ForLanguageModeling::new(config.clone()).expect("operation failed");
    let batch_size = 2;
    let seq_len = 8;
    let input_array = Array2::<i64>::zeros((batch_size, seq_len));
    let input = Tensor::I64(input_array.into_dyn());

    // `expect` is applied directly to the result so that a failure reports the
    // underlying error; a separate `assert!(output.is_ok())` beforehand would
    // panic first and throw that error away.
    let output_tensor = model
        .forward(input)
        .expect("S4ForLanguageModeling forward pass failed");

    match &output_tensor {
        Tensor::F32(arr) => {
            // A single slice comparison checks rank and every dimension at once
            // and prints the full actual shape on mismatch.
            assert_eq!(arr.shape(), &[batch_size, seq_len, config.vocab_size]);
        },
        _ => panic!("Expected F32 tensor output"),
    }
}
/// For each named preset (`s4-small`, `s4-base`, `s4-large`, `s4-long`):
/// the preset built by its constructor must pass `validate()`, the same name
/// must resolve through `S4Config::from_pretrained_name`, and the resolved
/// config must agree with the constructor on `d_model`, `d_state` and `n_layer`.
#[test]
fn test_config_variants() {
    let presets = [
        ("s4-small", S4Config::s4_small()),
        ("s4-base", S4Config::s4_base()),
        ("s4-large", S4Config::s4_large()),
        ("s4-long", S4Config::s4_long()),
    ];

    for (preset_name, expected) in presets {
        assert!(
            expected.validate().is_ok(),
            "Config {} should be valid",
            preset_name
        );

        let lookup = S4Config::from_pretrained_name(preset_name);
        assert!(lookup.is_some(), "Should load config for {}", preset_name);
        let actual = lookup.expect("operation failed");

        assert_eq!(actual.d_model, expected.d_model);
        assert_eq!(actual.d_state, expected.d_state);
        assert_eq!(actual.n_layer, expected.n_layer);
    }
}
/// Checks that `S4Block::new` succeeds for each post-activation name in the
/// list below, starting from the default config and overriding only `postact`.
#[test]
fn test_postact_options() {
    let mut config = S4Config::default();

    for activation in ["glu", "relu", "gelu", "silu", "identity"] {
        config.postact = String::from(activation);
        let built = S4Block::new(&config);
        assert!(
            built.is_ok(),
            "Failed to create block with postact: {}",
            activation
        );
    }
}
/// Smoke test: `S4Layer::new` succeeds when `bidirectional` is set to `true`
/// on an otherwise default config.
#[test]
fn test_bidirectional_mode() {
    let bidirectional_config = S4Config {
        bidirectional: true,
        ..S4Config::default()
    };
    assert!(S4Layer::new(&bidirectional_config).is_ok());
}
/// Initializes a 6x6 matrix with every `HiPPOMatrix` variant listed below and
/// checks that each has shape `[6, 6]` and contains at least one entry whose
/// magnitude exceeds `1e-10`. The non-zero requirement is waived for the
/// variant labelled `"legs"`.
#[test]
fn test_different_hippo_initializations() {
    let size = 6;
    let variants = [
        ("legs", HiPPOMatrix::LEGS),
        ("legt", HiPPOMatrix::LEGT),
        ("lagt", HiPPOMatrix::LAGT),
        ("fourier", HiPPOMatrix::Fourier),
        ("random", HiPPOMatrix::Random),
    ];

    for (label, variant) in variants {
        let initialized = variant.initialize(size);
        assert_eq!(
            initialized.shape(),
            &[size, size],
            "HiPPO {} has wrong shape",
            label
        );

        let any_significant = initialized.iter().any(|entry| entry.abs() > 1e-10);
        assert!(
            label == "legs" || any_significant,
            "HiPPO {} should have non-zero values",
            label
        );
    }
}
/// Smoke test: `S4Layer::new` succeeds when `lr_mult` is set to `0.1` on an
/// otherwise default config.
#[test]
fn test_lr_mult_parameter() {
    let scaled_lr_config = S4Config {
        lr_mult: 0.1,
        ..S4Config::default()
    };
    assert!(S4Layer::new(&scaled_lr_config).is_ok());
}
/// Smoke test: `S4Layer::new` succeeds for both values of the `transposed`
/// flag (first `false`, then `true`) on an otherwise default config.
#[test]
fn test_transposed_parameter() {
    for transposed in [false, true] {
        let config = S4Config {
            transposed,
            ..S4Config::default()
        };
        assert!(S4Layer::new(&config).is_ok());
    }
}
/// Checks `get_n_ssm()` on the default config and after an explicit override:
/// it must equal `d_model` for the default config and `128` once `n_ssm` is
/// set to `Some(128)`. Also checks that a layer can be built from the
/// overridden config.
#[test]
fn test_n_ssm_configuration() {
    let mut config = S4Config::default();

    // Default config: `get_n_ssm()` must match `d_model`.
    assert_eq!(config.get_n_ssm(), config.d_model);

    // Explicit override: the configured value must be returned.
    config.n_ssm = Some(128);
    assert_eq!(config.get_n_ssm(), 128);

    assert!(S4Layer::new(&config).is_ok());
}