#[cfg(test)]
mod tests {
use super::*;
// `new` must keep the length of the supplied byte buffer and report the
// dimensions it was given.
#[test]
fn test_quantized_tensor_q4_new() {
    // The buffer is moved into the constructor: nothing reads it afterwards,
    // so cloning it first would only add a needless allocation.
    let data = vec![0u8; 36];
    let tensor = QuantizedAprTensorQ4::new(data, 32, 2);

    assert_eq!(tensor.data.len(), 36);
    assert_eq!(tensor.in_dim, 32);
    assert_eq!(tensor.out_dim, 2);
}
// `zeros(32, 1)` must report the requested dimensions and hold 18 bytes,
// every one of them zero.
#[test]
fn test_quantized_tensor_q4_zeros_small() {
    let zeroed = QuantizedAprTensorQ4::zeros(32, 1);

    assert_eq!((zeroed.in_dim, zeroed.out_dim), (32, 1));
    assert_eq!(zeroed.data.len(), 18);
    // No byte in the backing storage may be non-zero.
    assert!(!zeroed.data.iter().any(|&byte| byte != 0));
}
// `zeros(64, 2)` is expected to allocate 72 bytes of storage.
#[test]
fn test_quantized_tensor_q4_zeros_multiple_blocks() {
    let byte_count = QuantizedAprTensorQ4::zeros(64, 2).data.len();
    assert_eq!(byte_count, 72);
}
// `zeros(33, 1)` is expected to allocate 36 bytes, the same size that
// `expected_bytes(64)` reports elsewhere in this module, i.e. the storage is
// rounded up rather than truncated.
#[test]
fn test_quantized_tensor_q4_zeros_partial_block() {
    let byte_count = QuantizedAprTensorQ4::zeros(33, 1).data.len();
    assert_eq!(byte_count, 36);
}
// Table-driven check of `expected_bytes`.
//
// Each pair is `(argument, expected byte count)`. The expected values all
// equal `ceil(argument / 32) * 18`.
#[test]
fn test_quantized_tensor_q4_expected_bytes() {
    let cases = [(32, 18), (33, 36), (64, 36), (256, 144)];

    for (input, want) in cases {
        assert_eq!(QuantizedAprTensorQ4::expected_bytes(input), want);
    }
}
// An argument of zero must yield a zero byte count.
#[test]
fn test_quantized_tensor_q4_expected_bytes_zero() {
    let bytes_for_nothing = QuantizedAprTensorQ4::expected_bytes(0);
    assert_eq!(bytes_for_nothing, 0);
}
/// Builds the small transformer configuration shared by the scratch-buffer
/// tests below.
///
/// Any field not listed here keeps the value from
/// `AprTransformerConfig::default()`.
fn make_test_config() -> AprTransformerConfig {
    AprTransformerConfig {
        architecture: String::from("apr"),
        // Model dimensions; the scratch tests assert buffer lengths of 64 and 128.
        hidden_dim: 64,
        intermediate_dim: 128,
        vocab_size: 100,
        context_length: 512,
        // Layer and attention-head counts.
        num_layers: 2,
        num_heads: 4,
        num_kv_heads: 2,
        // Numeric hyper-parameters.
        rope_theta: 10_000.0,
        eps: 1e-5,
        eos_token_id: None,
        ..Default::default()
    }
}
// Every scratch buffer built from the shared test configuration must have
// the expected length.
#[test]
fn test_inference_scratch_from_config() {
    // Sizes mirror `hidden_dim` and `intermediate_dim` in `make_test_config`.
    const HIDDEN: usize = 64;
    const INTERMEDIATE: usize = 128;

    let config = make_test_config();
    let scratch = AprInferenceScratch::from_config(&config);

    // Buffers expected to be HIDDEN long.
    assert_eq!(scratch.hidden.len(), HIDDEN);
    assert_eq!(scratch.normed.len(), HIDDEN);
    assert_eq!(scratch.q.len(), HIDDEN);
    assert_eq!(scratch.k.len(), HIDDEN);
    assert_eq!(scratch.v.len(), HIDDEN);
    assert_eq!(scratch.attn_out.len(), HIDDEN);
    assert_eq!(scratch.ffn_input.len(), HIDDEN);
    assert_eq!(scratch.ffn_out.len(), HIDDEN);

    // The QKV buffer is expected to be three times HIDDEN (192).
    assert_eq!(scratch.qkv_out.len(), 3 * HIDDEN);

    // Buffers expected to be INTERMEDIATE long.
    assert_eq!(scratch.ffn_up.len(), INTERMEDIATE);
    assert_eq!(scratch.ffn_gate.len(), INTERMEDIATE);
}
// A freshly built scratch space must contain only zeros in the buffers
// checked here: hidden, normed, qkv_out, attn_out and ffn_up.
#[test]
fn test_inference_scratch_initialized_to_zero() {
    let config = make_test_config();
    let fresh = AprInferenceScratch::from_config(&config);

    // Each check fails if any element differs from 0.0.
    assert!(!fresh.hidden.iter().any(|&value| value != 0.0));
    assert!(!fresh.normed.iter().any(|&value| value != 0.0));
    assert!(!fresh.qkv_out.iter().any(|&value| value != 0.0));
    assert!(!fresh.attn_out.iter().any(|&value| value != 0.0));
    assert!(!fresh.ffn_up.iter().any(|&value| value != 0.0));
}
// `clear` must reset buffers that were written to back to all zeros.
#[test]
fn test_inference_scratch_clear() {
    let config = make_test_config();
    let mut dirty = AprInferenceScratch::from_config(&config);

    // Dirty the first element of four buffers with distinct non-zero values.
    dirty.hidden[0] = 1.0;
    dirty.normed[0] = 2.0;
    dirty.attn_out[0] = 3.0;
    dirty.ffn_up[0] = 4.0;

    dirty.clear();

    // After clearing, none of the touched buffers may hold a non-zero value.
    assert!(!dirty.hidden.iter().any(|&value| value != 0.0));
    assert!(!dirty.normed.iter().any(|&value| value != 0.0));
    assert!(!dirty.attn_out.iter().any(|&value| value != 0.0));
    assert!(!dirty.ffn_up.iter().any(|&value| value != 0.0));
}
// Scratch buffers must follow the configured dimensions for a much larger
// configuration (hidden 4096, intermediate 11008).
#[test]
fn test_inference_scratch_large_config() {
    let large = AprTransformerConfig {
        architecture: String::from("apr"),
        // Dimensions asserted on below.
        hidden_dim: 4096,
        intermediate_dim: 11_008,
        // Remaining explicit settings; not asserted on by this test.
        vocab_size: 32_000,
        context_length: 4096,
        num_layers: 32,
        num_heads: 32,
        num_kv_heads: 8,
        rope_theta: 10_000.0,
        eps: 1e-5,
        eos_token_id: None,
        ..Default::default()
    };

    let scratch = AprInferenceScratch::from_config(&large);

    assert_eq!(scratch.hidden.len(), 4096);
    // Both feed-forward buffers are expected to match the intermediate size.
    assert_eq!(scratch.ffn_up.len(), 11_008);
    assert_eq!(scratch.ffn_gate.len(), 11_008);
}
}