use super::*;
/// EM-001: `forward` must return a tensor of shape `[1, seq_len, hidden_size]`
/// for every sequence length tried, using only in-vocabulary token ids.
#[test]
fn falsify_em_001_output_shape() {
    let vocab_size = 100;
    let hidden_size = 64;
    let embed = Embedding::new(vocab_size, hidden_size);

    for seq_len in [1, 5, 10, 50, 100] {
        // Cycle through the vocabulary so every id is valid.
        let ids: Vec<u32> = (0..seq_len).map(|i| (i % vocab_size) as u32).collect();
        let output = embed.forward(&ids);
        assert_eq!(
            output.shape(),
            &[1, seq_len, hidden_size],
            "FALSIFIED EM-001: shape for seq_len={seq_len} is {:?}, expected [1, {seq_len}, {hidden_size}]",
            output.shape()
        );
    }
}
/// EM-002: a batch made entirely of out-of-vocabulary ids must not panic.
/// It must still yield shape `[1, 4, hidden_size]`, and every element must be 0.0.
#[test]
fn falsify_em_002_oob_panic_freedom() {
    let vocab_size = 50;
    let hidden_size = 16;
    let embed = Embedding::new(vocab_size, hidden_size);

    // Each id is >= vocab_size, up to the extreme u32::MAX.
    let oob_ids = vec![50u32, 100, 999, u32::MAX];
    let output = embed.forward(&oob_ids);

    assert_eq!(
        output.shape(),
        &[1, 4, hidden_size],
        "FALSIFIED EM-002: OOB shape {:?} != [1, 4, {hidden_size}]",
        output.shape()
    );

    output.data().iter().enumerate().for_each(|(i, &val)| {
        assert!(
            val.abs() < 1e-10,
            "FALSIFIED EM-002: OOB output[{i}] = {val}, expected 0.0"
        );
    });
}
/// EM-002b: a batch that mixes valid and out-of-vocabulary ids must
/// copy the matching weight row for each valid id and emit an all-zero row for the OOB id.
/// No position may affect another.
///
/// Whether a position is valid or OOB is derived from `ids` and `vocab_size`,
/// not hard-coded per position.
#[test]
fn falsify_em_002b_mixed_valid_oob() {
    let vocab_size = 10;
    let hidden_size = 4;
    let embed = Embedding::new(vocab_size, hidden_size);
    // Positions 0 and 2 hold valid ids; position 1 holds an id >= vocab_size.
    let ids = vec![0u32, 100, 5];
    let output = embed.forward(&ids);

    // Check the shape first. A malformed output then fails with a clear message
    // rather than an index-out-of-bounds panic in the row checks below.
    assert_eq!(
        output.shape(),
        &[1, ids.len(), hidden_size],
        "FALSIFIED EM-002b: shape {:?} != [1, {}, {hidden_size}]",
        output.shape(),
        ids.len()
    );

    let out_data = output.data();
    let weight_data = embed.weight().data();
    for (pos, &token_id) in ids.iter().enumerate() {
        let token_idx = token_id as usize;
        for d in 0..hidden_size {
            let actual = out_data[pos * hidden_size + d];
            if token_idx < vocab_size {
                let expected = weight_data[token_idx * hidden_size + d];
                assert!(
                    (actual - expected).abs() < 1e-10,
                    "FALSIFIED EM-002b: valid token at pos {pos}, dim {d}: {actual} != {expected}"
                );
            } else {
                assert!(
                    actual.abs() < 1e-10,
                    "FALSIFIED EM-002b: OOB token at pos {pos}, dim {d}: {actual} != 0.0"
                );
            }
        }
    }
}
/// EM-003: calling `forward` twice on the same ids with the same embedding
/// must produce equal output data.
#[test]
fn falsify_em_003_deterministic() {
    let embed = Embedding::new(50, 32);
    let ids = vec![0u32, 5, 10, 49, 1, 23];

    let (first, second) = (embed.forward(&ids), embed.forward(&ids));

    assert_eq!(
        first.data(),
        second.data(),
        "FALSIFIED EM-003: two calls with identical inputs differ"
    );
}
/// EM-004: looking up every id in `0..200` (the full vocabulary) must yield
/// only finite values. The test reports the first non-finite element, if any.
#[test]
fn falsify_em_004_finite_output() {
    let embed = Embedding::new(200, 128);
    let ids: Vec<u32> = (0..200).collect();
    let output = embed.forward(&ids);

    let data = output.data();
    let first_bad = data.iter().enumerate().find(|(_, v)| !v.is_finite());
    if let Some((i, &val)) = first_bad {
        panic!("FALSIFIED EM-004: output[{i}] = {val} (not finite)");
    }
}
/// EM-005: output row `p` must equal weight row `ids[p]`, element for element,
/// for every position in a batch of valid ids.
#[test]
fn falsify_em_005_row_lookup_correctness() {
    let vocab_size = 10;
    let hidden_size = 4;
    let embed = Embedding::new(vocab_size, hidden_size);
    let ids = vec![0u32, 3, 7, 9, 1];
    let output = embed.forward(&ids);
    let out_data = output.data();
    let weight_data = embed.weight().data();

    for (seq_pos, token_idx) in ids.iter().map(|&t| t as usize).enumerate() {
        // Flat offsets of the start of this output row and its source weight row.
        let weight_base = token_idx * hidden_size;
        let out_base = seq_pos * hidden_size;
        for d in 0..hidden_size {
            let (expected, actual) = (weight_data[weight_base + d], out_data[out_base + d]);
            assert!(
                (actual - expected).abs() < 1e-10,
                "FALSIFIED EM-005: output[{seq_pos}][{d}] = {actual}, expected W[{token_idx}][{d}] = {expected}"
            );
        }
    }
}
/// Property-based variants of the EM-001 (output shape) and EM-004 (finite
/// output) falsification checks, each run for 15 generated cases.
mod qwen2_em_proptest_falsify {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(15))]

        // EM-001 (property form): shape is [1, seq_len, hidden_size] for any
        // seq_len in 1..=30, with ids cycling through the vocabulary.
        #[test]
        fn falsify_em_001_prop_output_shape(
            seq_len in 1..=30usize,
        ) {
            let vocab_size = 50;
            let hidden_size = 16;
            let embed = Embedding::new(vocab_size, hidden_size);
            let ids: Vec<u32> = (0..vocab_size)
                .cycle()
                .take(seq_len)
                .map(|tok| tok as u32)
                .collect();
            let output = embed.forward(&ids);
            prop_assert_eq!(
                output.shape(),
                &[1, seq_len, hidden_size],
                "FALSIFIED EM-001-prop: shape {:?} != [1, {}, {}]",
                output.shape(), seq_len, hidden_size
            );
        }

        // EM-004 (property form): ten consecutive ids starting at `seed`,
        // wrapped into the vocabulary, must all embed to finite values.
        #[test]
        fn falsify_em_004_prop_finite_output(
            seed in 0..200u32,
        ) {
            let vocab_size = 40;
            let hidden_size = 8;
            let embed = Embedding::new(vocab_size, hidden_size);
            let start = seed as usize;
            let ids: Vec<u32> = (start..start + 10)
                .map(|raw| (raw % vocab_size) as u32)
                .collect();
            let output = embed.forward(&ids);
            let data = output.data();
            for (idx, &value) in data.iter().enumerate() {
                prop_assert!(
                    value.is_finite(),
                    "FALSIFIED EM-004-prop: output[{}]={} not finite (seed={})",
                    idx, value, seed
                );
            }
        }
    }
}