impl ModelTestCase {
    /// Creates a single-format test case: load a model in `format` on `device`
    /// and run one forward pass.
    ///
    /// Defaults: prompt tokens `[1, 2, 3, 4]`, no quantization override, no
    /// expected output norm, no target format (i.e. not a conversion test).
    /// Use the `with_*` builders to customize.
    pub fn new(
        desc: impl Into<String>,
        config: ModelConfig,
        format: ModelFormat,
        device: Device,
    ) -> Self {
        Self {
            desc: desc.into(),
            constructor: ConstructorInput::new(config),
            forward: ForwardInput::new(vec![1, 2, 3, 4]),
            expected_output_norm: None,
            source_format: format,
            target_format: None,
            device,
        }
    }

    /// Creates a format-conversion test case (`source` -> `target`).
    ///
    /// Delegates to [`Self::new`] via struct-update syntax so the shared
    /// defaults (tokens, seeds, norms) are defined in exactly one place.
    pub fn conversion(
        desc: impl Into<String>,
        config: ModelConfig,
        source: ModelFormat,
        target: ModelFormat,
        device: Device,
    ) -> Self {
        Self {
            target_format: Some(target),
            ..Self::new(desc, config, source, device)
        }
    }

    /// Overrides the quantization applied at construction time.
    #[must_use]
    pub fn with_quant(mut self, quant: QuantType) -> Self {
        self.constructor.quantization = Some(quant);
        self
    }

    /// Replaces the default prompt tokens for the forward pass.
    #[must_use]
    pub fn with_tokens(mut self, tokens: Vec<u32>) -> Self {
        self.forward.tokens = tokens;
        self
    }

    /// Sets the expected L2 norm of the forward-pass output, enabling a
    /// numeric check by whoever runs the case.
    #[must_use]
    pub fn with_expected_norm(mut self, norm: f32) -> Self {
        self.expected_output_norm = Some(norm);
        self
    }

    /// `true` when this case converts between formats (a target format is set).
    pub fn is_conversion_test(&self) -> bool {
        self.target_format.is_some()
    }
}
/// Outcome of executing one test case: pass/fail status plus whatever
/// artifacts (output vector, error text, timing, memory) the run produced.
#[derive(Debug)]
pub struct TestResult {
    /// Description/name of the test case that produced this result.
    pub test_case: String,
    /// Whether the case passed.
    pub passed: bool,
    /// Forward-pass output; populated on success, `None` on failure.
    pub output: Option<Vec<f32>>,
    /// Error message; populated on failure, `None` on success.
    pub error: Option<String>,
    /// Wall-clock duration of the run, in microseconds.
    pub duration_us: u64,
    /// Optional memory measurement; constructors leave this `None` — callers
    /// fill it in after construction if they measured usage.
    pub memory_bytes: Option<usize>,
}
impl TestResult {
    /// Builds a passing result carrying the captured model `output`.
    ///
    /// `test_case` is now `impl Into<String>` for consistency with
    /// [`Self::failure`]'s `error` parameter (backward compatible: `&str`
    /// callers are unaffected). `memory_bytes` starts as `None`.
    pub fn success(test_case: impl Into<String>, output: Vec<f32>, duration_us: u64) -> Self {
        Self {
            test_case: test_case.into(),
            passed: true,
            output: Some(output),
            error: None,
            duration_us,
            memory_bytes: None,
        }
    }

    /// Builds a failing result carrying the error message; no output is kept.
    pub fn failure(
        test_case: impl Into<String>,
        error: impl Into<String>,
        duration_us: u64,
    ) -> Self {
        Self {
            test_case: test_case.into(),
            passed: false,
            output: None,
            error: Some(error.into()),
            duration_us,
            memory_bytes: None,
        }
    }

    /// L2 (Euclidean) norm of the output vector, or `None` when the run
    /// produced no output (e.g. a failure). An empty output yields `Some(0.0)`.
    pub fn output_l2_norm(&self) -> Option<f32> {
        self.output
            .as_ref()
            .map(|o| o.iter().map(|x| x * x).sum::<f32>().sqrt())
    }
}
/// Numeric comparison thresholds used when validating model outputs.
/// See `Default`, `strict()` and `quantized()` for the three presets.
#[derive(Debug, Clone, Copy)]
pub struct Tolerances {
    /// Maximum absolute difference allowed for f32 element comparisons.
    pub f32_abs: f32,
    /// Maximum relative difference allowed for f32 element comparisons.
    pub f32_rel: f32,
    /// Maximum L2-norm deviation (percent) tolerated for quantized outputs.
    pub quant_l2_pct: f32,
    /// Maximum divergence allowed between devices (e.g. CPU vs CUDA) for the
    /// same computation.
    pub device_parity: f32,
}
impl Default for Tolerances {
fn default() -> Self {
Self {
f32_abs: 1e-5,
f32_rel: 1e-4,
quant_l2_pct: 5.0,
device_parity: 1e-3,
}
}
}
impl Tolerances {
pub fn strict() -> Self {
Self {
f32_abs: 1e-6,
f32_rel: 1e-5,
quant_l2_pct: 1.0,
device_parity: 1e-4,
}
}
pub fn quantized() -> Self {
Self {
f32_abs: 1e-3,
f32_rel: 1e-2,
quant_l2_pct: 10.0,
device_parity: 5e-2,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ModelConfig presets and derived dimensions ----

    #[test]
    fn test_model_config_tiny() {
        let cfg = ModelConfig::tiny();
        assert_eq!(cfg.hidden_dim, 64);
        assert_eq!(cfg.num_heads, 4);
        assert_eq!(cfg.num_kv_heads, 2);
        assert_eq!(cfg.head_dim(), 16);
        assert_eq!(cfg.gqa_group_size(), 2);
        assert!(cfg.is_gqa());
        assert!(!cfg.is_mqa());
    }

    #[test]
    fn test_model_config_qwen() {
        let cfg = ModelConfig::qwen_1_5b();
        assert_eq!(cfg.hidden_dim, 1536);
        assert_eq!(cfg.num_heads, 12);
        assert_eq!(cfg.num_kv_heads, 2);
        assert_eq!(cfg.head_dim(), 128);
        assert_eq!(cfg.gqa_group_size(), 6);
        assert!(cfg.is_gqa());
    }

    #[test]
    fn test_model_config_dimensions() {
        let cfg = ModelConfig::small();
        assert_eq!(cfg.q_dim(), 256);
        assert_eq!(cfg.k_dim(), 64);
        assert_eq!(cfg.v_dim(), 64);
    }

    #[test]
    fn test_model_config_tinyllama() {
        let cfg = ModelConfig::tinyllama();
        assert_eq!(cfg.hidden_dim, 2048);
        assert_eq!(cfg.num_layers, 22);
    }

    #[test]
    fn test_model_config_default() {
        // Default is expected to mirror the `tiny` preset.
        let fallback = ModelConfig::default();
        let tiny = ModelConfig::tiny();
        assert_eq!(fallback.hidden_dim, tiny.hidden_dim);
    }

    #[test]
    fn test_param_count() {
        let n = ModelConfig::tiny().param_count();
        assert!(n > 10_000, "params={}", n);
        assert!(n < 1_000_000, "params={}", n);
    }

    // ---- QuantType ----

    #[test]
    fn test_quant_type_bits() {
        assert_eq!(QuantType::F32.bits_per_weight(), 32.0);
        assert_eq!(QuantType::Q4_K.bits_per_weight(), 4.5);
    }

    #[test]
    fn test_quant_type_all_bits() {
        // Half-precision formats are exact; block quants carry scale overhead.
        assert_eq!(QuantType::F16.bits_per_weight(), 16.0);
        assert_eq!(QuantType::BF16.bits_per_weight(), 16.0);
        assert_eq!(QuantType::Q8_0.bits_per_weight(), 8.5);
        assert_eq!(QuantType::Q4_0.bits_per_weight(), 4.5);
        assert_eq!(QuantType::Q5_K.bits_per_weight(), 5.5);
        assert_eq!(QuantType::Q6_K.bits_per_weight(), 6.5);
    }

    #[test]
    fn test_quant_type_format_support() {
        assert!(QuantType::F32.supported_by(ModelFormat::Safetensors));
        assert!(!QuantType::Q4_K.supported_by(ModelFormat::Safetensors));
        assert!(QuantType::Q4_K.supported_by(ModelFormat::GGUF));
        assert!(QuantType::Q4_K.supported_by(ModelFormat::APR));
    }

    #[test]
    fn test_quant_type_unsupported() {
        assert!(!QuantType::Q8_0.supported_by(ModelFormat::PyTorch));
        assert!(!QuantType::Q8_0.supported_by(ModelFormat::Safetensors));
    }

    // ---- ModelTestCase ----

    #[test]
    fn test_model_test_case_creation() {
        let case = ModelTestCase::new(
            "tiny CPU test",
            ModelConfig::tiny(),
            ModelFormat::GGUF,
            Device::Cpu,
        );
        assert_eq!(case.desc, "tiny CPU test");
        assert!(!case.is_conversion_test());
    }

    #[test]
    fn test_model_test_case_conversion() {
        let case = ModelTestCase::conversion(
            "GGUF to APR",
            ModelConfig::tiny(),
            ModelFormat::GGUF,
            ModelFormat::APR,
            Device::Cpu,
        );
        assert!(case.is_conversion_test());
        assert_eq!(case.source_format, ModelFormat::GGUF);
        assert_eq!(case.target_format, Some(ModelFormat::APR));
    }

    #[test]
    fn test_model_test_case_builder() {
        let case = ModelTestCase::new("test", ModelConfig::tiny(), ModelFormat::APR, Device::Cpu)
            .with_quant(QuantType::Q4_K)
            .with_tokens(vec![1, 2])
            .with_expected_norm(10.0);
        assert_eq!(case.constructor.quantization, Some(QuantType::Q4_K));
        assert_eq!(case.forward.tokens, vec![1, 2]);
        assert_eq!(case.expected_output_norm, Some(10.0));
    }

    // ---- Device / ModelFormat display and accessors ----

    #[test]
    fn test_device_display() {
        // `to_string()` goes through the same `Display` impl as `format!`.
        assert_eq!(Device::Cpu.to_string(), "CPU");
        assert_eq!(Device::Cuda(0).to_string(), "CUDA:0");
    }

    #[test]
    fn test_device_is_cuda() {
        assert!(!Device::Cpu.is_cuda());
        assert!(Device::Cuda(0).is_cuda());
        assert_eq!(Device::Cpu.cuda_id(), None);
        assert_eq!(Device::Cuda(7).cuda_id(), Some(7));
    }

    #[test]
    fn test_format_display() {
        assert_eq!(ModelFormat::GGUF.to_string(), "GGUF");
        assert_eq!(ModelFormat::APR.to_string(), "APR");
    }

    #[test]
    fn test_format_display_pytorch_safetensors() {
        assert_eq!(ModelFormat::PyTorch.to_string(), "PyTorch");
        assert_eq!(ModelFormat::Safetensors.to_string(), "Safetensors");
    }

    // ---- ForwardInput / ConstructorInput ----

    #[test]
    fn test_forward_input_seq_len() {
        let fresh = ForwardInput::new(vec![1, 2, 3]);
        assert_eq!(fresh.seq_len(), 3);
        assert_eq!(fresh.position, 0);
        let continued = ForwardInput::at_position(vec![4, 5], 10);
        assert_eq!(continued.seq_len(), 2);
        assert_eq!(continued.position, 10);
    }

    #[test]
    fn test_constructor_input_with_quant() {
        let ci = ConstructorInput::with_quant(ModelConfig::tiny(), QuantType::Q8_0, 123);
        assert_eq!(ci.quantization, Some(QuantType::Q8_0));
        assert_eq!(ci.weights_seed, 123);
    }

    // ---- TestResult / Tolerances ----

    #[test]
    fn test_test_result_l2_norm() {
        // 3-4-5 triangle: ||(3, 4)|| == 5.
        let res = TestResult::success("test", vec![3.0, 4.0], 100);
        let norm = res.output_l2_norm().expect("norm");
        assert!((norm - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_test_result_failure() {
        let res = TestResult::failure("fail test", "error message", 50);
        assert!(!res.passed);
        assert_eq!(res.test_case, "fail test");
        assert_eq!(res.error.expect("error"), "error message");
        assert!(res.output.is_none());
    }

    #[test]
    fn test_tolerances() {
        // strict tightens and quantized loosens relative to the default.
        let baseline = Tolerances::default();
        assert!(Tolerances::strict().f32_abs < baseline.f32_abs);
        assert!(Tolerances::quantized().f32_abs > baseline.f32_abs);
    }
}