#![allow(clippy::disallowed_methods)]
#[test]
fn s1_realizar_tokenizer_module_exists() {
    // Pull in the run-command source and check it wires up inference support,
    // either by naming the realizar tokenizer or gating on the feature flag.
    let run_src = include_str!("../crates/apr-cli/src/commands/run.rs");
    let mentions_realizar = run_src.contains("realizar");
    let mentions_feature = run_src.contains("#[cfg(feature = \"inference\")]");
    assert!(
        mentions_realizar || mentions_feature,
        "S1: apr-cli run command must reference realizar or inference feature"
    );
}
#[test]
fn s1b_tokenizer_vocabulary_capacity() {
    // Qwen2-0.5B-Instruct ships a 151936-entry vocabulary.
    //
    // The original check compared a `u32` value against `u32::MAX`, which is
    // true for every value that can even be written as a `u32` literal and so
    // tested nothing. Start from a wider integer type so the "fits in u32"
    // claim is actually exercised by the conversion.
    let qwen2_vocab_size: u64 = 151_936;
    assert!(
        u32::try_from(qwen2_vocab_size).is_ok(),
        "S1: Qwen2 vocab size must fit in u32"
    );
}
#[test]
fn s2_tokenizer_roundtrip_ascii() {
    // UTF-8 byte round-trip: str -> bytes -> str must be lossless for ASCII.
    let original = "Hello";
    let bytes = original.as_bytes();
    assert!(
        !bytes.is_empty(),
        "S2: encode must return non-empty bytes"
    );
    let decoded = std::str::from_utf8(bytes).expect("valid UTF-8");
    assert_eq!(
        decoded, original,
        "S2: byte round-trip must preserve string"
    );
}
#[test]
fn s3_qwen2_special_tokens() {
    // Qwen2 special token ids used by this suite; each must be a real
    // vocabulary slot (strictly below the vocabulary size).
    const EOS_TOKEN_ID: u32 = 151645;
    const BOS_TOKEN_ID: u32 = 151643;
    const PAD_TOKEN_ID: u32 = 151643;
    const VOCAB_SIZE: u32 = 151936;
    for (token_id, msg) in [
        (EOS_TOKEN_ID, "S3: EOS token must be in vocab"),
        (BOS_TOKEN_ID, "S3: BOS token must be in vocab"),
        (PAD_TOKEN_ID, "S3: PAD token must be in vocab"),
    ] {
        assert!(token_id < VOCAB_SIZE, "{msg}");
    }
}
#[test]
fn s4_model_loading_strategy() {
    // The run command source is expected to document its mmap-vs-read
    // threshold (50MB) somewhere in the file.
    let run_src = include_str!("../crates/apr-cli/src/commands/run.rs");
    let mentions_threshold = run_src.contains("50");
    let mentions_mmap = run_src.contains("mmap");
    assert!(
        mentions_threshold && mentions_mmap,
        "S4: run.rs must document 50MB mmap threshold"
    );
}
#[test]
fn s5_qwen2_tensor_count() {
    // Expected tensor count for Qwen2-0.5B-Instruct per this suite's model:
    // token embedding + (per-layer tensors * num_layers) + final norm + head.
    const EXPECTED_TENSOR_COUNT: usize = 219;
    use aprender::demo::Qwen2Config;
    let config = Qwen2Config::qwen2_0_5b_instruct();
    // Per layer: 4 attention tensors + 3 MLP tensors + 2 norms.
    let per_layer = 4 + 3 + 2;
    let calculated = 1 + (config.num_layers * per_layer) + 2;
    // The original test declared the expected constant but never compared it,
    // asserting only `calculated > 0`; compare the derived count to the
    // documented expectation so a config change is actually caught.
    assert_eq!(
        calculated, EXPECTED_TENSOR_COUNT,
        "S5: Calculated tensor count must match expected count"
    );
}
#[test]
fn s6_embedding_operation() {
    use aprender::demo::Qwen2Config;
    // The embedding width for Qwen2-0.5B must match the expected 896.
    let cfg = Qwen2Config::qwen2_0_5b_instruct();
    let hidden = cfg.hidden_size;
    assert_eq!(
        hidden, 896,
        "S6: Qwen2-0.5B hidden_size must be 896"
    );
}
#[test]
fn s7_rmsnorm_simd_compatible() {
    // RMSNorm core: 1/sqrt(mean(x^2) + eps) must stay finite for normal input.
    let values = [1.0f32, 2.0, 3.0, 4.0];
    let eps = 1e-6f32;
    let mut sum_sq = 0.0f32;
    for v in values {
        sum_sq += v * v;
    }
    let mean_sq = sum_sq / values.len() as f32;
    let rsqrt = 1.0 / (mean_sq + eps).sqrt();
    assert!(rsqrt.is_finite(), "S7: RMSNorm rsqrt must be finite");
}
#[test]
fn s8_rope_rotary_embedding() {
    // At position 0 every rotary angle is zero, so cos must be ~1, sin ~0.
    let base: f32 = 10000.0;
    let dim = 128;
    let position = 0;
    // Frequency for dimension-pair index 0: base^(-2*0/dim) == 1.0.
    let theta = base.powf(-2.0 * 0.0 / dim as f32);
    let angle = position as f32 * theta;
    assert!((angle.cos() - 1.0).abs() < 1e-5, "S8: RoPE cos(0) must be ~1.0");
    assert!(angle.sin().abs() < 1e-5, "S8: RoPE sin(0) must be ~0.0");
}
#[test]
fn s9_gqa_dimensions() {
    use aprender::demo::Qwen2Config;
    // Grouped-query attention: KV heads never exceed query heads, and the
    // head dimension derived from this config must be 64.
    let cfg = Qwen2Config::qwen2_0_5b_instruct();
    assert!(
        cfg.num_kv_heads <= cfg.num_attention_heads,
        "S9: num_kv_heads must be <= num_attention_heads"
    );
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    assert_eq!(head_dim, 64, "S9: Qwen2-0.5B head_dim must be 64");
}
#[test]
fn s10_swiglu_activation() {
    // swish(x) = x * sigmoid(x); SwiGLU multiplies the input by swish(gate).
    fn swish(x: f32) -> f32 {
        let sigmoid = 1.0 / (1.0 + (-x).exp());
        x * sigmoid
    }
    fn swiglu(x: f32, gate: f32) -> f32 {
        swish(gate) * x
    }
    // swish(0) == 0 since the input factor is zero.
    let at_zero = swish(0.0);
    assert!(
        at_zero.abs() < 0.01,
        "S10: swish(0) must be ~0, got {at_zero}"
    );
    // swish(1) == sigmoid(1) ~= 0.731.
    let at_one = swish(1.0);
    assert!(
        (at_one - 0.731).abs() < 0.01,
        "S10: swish(1) must be ~0.731, got {at_one}"
    );
    // SwiGLU(2, 1) == swish(1) * 2 ~= 1.462.
    let gated = swiglu(2.0, 1.0);
    assert!(
        (gated - 1.462).abs() < 0.01,
        "S10: SwiGLU(2.0, 1.0) must be ~1.462, got {gated}"
    );
}
#[test]
fn s11_logits_shape() {
    use aprender::demo::Qwen2Config;
    // Logits are (batch, seq, vocab); the element count must be positive and
    // the configured vocab size must match the expected Qwen2-0.5B value.
    let cfg = Qwen2Config::qwen2_0_5b_instruct();
    let (batch_size, seq_len) = (1, 10);
    let total_elements = batch_size * seq_len * cfg.vocab_size;
    assert!(
        total_elements > 0,
        "S11: Logits must have positive element count"
    );
    assert_eq!(
        cfg.vocab_size, 151936,
        "S11: Qwen2-0.5B vocab_size must be 151936"
    );
}
#[test]
fn s12_logits_finite() {
    // Every logit in the sample must be a finite f32 (no NaN or infinity).
    let logits = [1.0f32, -2.0, 0.5, -0.3, 2.1];
    for (i, logit) in logits.into_iter().enumerate() {
        assert!(
            logit.is_finite(),
            "S12: Logit at position {i} must be finite"
        );
    }
}
#[test]
fn s13_softmax_valid() {
    // Numerically-stable softmax: shift by the max before exponentiating so
    // exp never overflows for large logits.
    fn softmax(logits: &[f32]) -> Vec<f32> {
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut exp_vals = Vec::with_capacity(logits.len());
        for &l in logits {
            exp_vals.push((l - max_logit).exp());
        }
        let sum_exp: f32 = exp_vals.iter().sum();
        exp_vals.into_iter().map(|e| e / sum_exp).collect()
    }
    let probs = softmax(&[1.0, 2.0, 3.0]);
    // A probability distribution sums to 1 and each entry lies in [0, 1].
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-5,
        "S13: Softmax probabilities must sum to 1.0, got {total}"
    );
    for &p in &probs {
        assert!(
            p >= 0.0 && p <= 1.0,
            "S13: Each probability must be in [0, 1]"
        );
    }
}
#[test]
fn s14_deterministic_sampling() {
    // Greedy decoding: argmax over a fixed logit vector must be a pure,
    // deterministic function of its input.
    fn argmax(logits: &[f32]) -> usize {
        // `>=` keeps the LAST maximal element, matching the tie-breaking of
        // Iterator::max_by used by the reference implementation.
        let mut best_idx = 0;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &v) in logits.iter().enumerate() {
            if v >= best_val {
                best_idx = i;
                best_val = v;
            }
        }
        best_idx
    }
    let logits = [0.1, 0.5, 0.3, 0.9, 0.2];
    let first = argmax(&logits);
    let second = argmax(&logits);
    let third = argmax(&logits);
    assert_eq!(first, second, "S14: argmax must be deterministic");
    assert_eq!(second, third, "S14: argmax must be deterministic");
    assert_eq!(
        first, 3,
        "S14: argmax of [0.1, 0.5, 0.3, 0.9, 0.2] must be 3"
    );
}
#[test]
fn s15_kv_cache_structure() {
    use aprender::demo::Qwen2Config;
    // KV cache footprint for a 512-token context:
    // 2 tensors (K and V) * batch * kv_heads * seq_len * head_dim per layer,
    // times the layer count, at 2 bytes per element (presumably half
    // precision — the factor matches fp16/bf16).
    let cfg = Qwen2Config::qwen2_0_5b_instruct();
    let batch_size = 1;
    let seq_len = 512;
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    let per_layer_elements = 2 * batch_size * cfg.num_kv_heads * seq_len * head_dim;
    let total_elements = per_layer_elements * cfg.num_layers;
    let kv_bytes = total_elements * 2;
    assert!(
        kv_bytes < 1024 * 1024 * 1024,
        "S15: KV cache for 512 tokens must be < 1GB"
    );
}
#[test]
fn s16_arithmetic_capability() {
    // Fixture smoke test: the arithmetic probe must define both a prompt and
    // a non-empty set of acceptable answers.
    let prompt = "What is 2+2?";
    let acceptable_answers = ["4", "four", "Four"];
    let has_prompt = !prompt.is_empty();
    assert!(
        has_prompt,
        "S16: Arithmetic prompt must be non-empty"
    );
    let has_answers = !acceptable_answers.is_empty();
    assert!(
        has_answers,
        "S16: Expected answers must be defined"
    );
}
#[test]
fn s17_factual_recall() {
    // Fixture smoke test: the factual-recall probe mentions its subject and
    // defines a non-empty expected answer.
    let prompt = "The capital of France is";
    let expected = "Paris";
    let mentions_subject = prompt.contains("France");
    assert!(mentions_subject, "S17: Prompt must mention France");
    assert!(!expected.is_empty(), "S17: Expected answer must be defined");
}
#[test]
fn s18_eos_termination() {
    // Generation must stop when the EOS token id appears: it has to be
    // findable in the stream, at the expected position.
    const EOS_TOKEN_ID: u32 = 151645;
    let generated: Vec<u32> = vec![100, 200, 300, EOS_TOKEN_ID];
    let mut eos_position = None;
    for (i, &tok) in generated.iter().enumerate() {
        if tok == EOS_TOKEN_ID {
            eos_position = Some(i);
            break;
        }
    }
    assert!(
        eos_position.is_some(),
        "S18: EOS token must be detectable in sequence"
    );
    assert_eq!(eos_position.unwrap(), 3, "S18: EOS should be at position 3");
}
#[test]
fn s19_valid_utf8() {
    // Decoded output must be well-formed UTF-8 and carry no U+FFFD
    // replacement characters (the marker left by lossy decoding).
    let output = "Hello, world! 你好世界 🎉";
    let reparsed = std::str::from_utf8(output.as_bytes());
    assert!(
        reparsed.is_ok(),
        "S19: Output must be valid UTF-8"
    );
    let has_replacement = output.chars().any(|c| c == '\u{FFFD}');
    assert!(
        !has_replacement,
        "S19: Output must not contain replacement characters"
    );
}
// Splice in additional integration tests maintained in a shared include file.
include!("includes/realizar_integration_tests_include_01.rs");