use crate::autograd::Tensor;
use crate::finetune::linear_probe::LinearProbe;
use crate::transformer::{EncoderModel, ModelArchitecture, TransformerConfig};
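
// Shared tiny encoder configuration used by all tests below; sizes are kept
// small so full forward passes stay cheap.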
fn tiny_encoder_config() -> TransformerConfig {
    TransformerConfig {
        hidden_size: 32,
        num_hidden_layers: 2,
        num_attention_heads: 4,
        num_kv_heads: 4,
        intermediate_size: 64,
        vocab_size: 100,
        max_position_embeddings: 32,
        rms_norm_eps: 1e-5,
        architecture: ModelArchitecture::Encoder,
        ..TransformerConfig::tiny()
    }
}
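
/// FALSIFY-BIATT-001: the CLS embedding must depend on tokens that come after
/// the CLS position; under a causal mask the two inputs below would produce
/// identical embeddings.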
#[test]
fn falsify_biatt_001_no_causal_mask() {
    let config = tiny_encoder_config();
    let encoder = EncoderModel::new(&config);
    let tokens_a = vec![1, 2, 3, 4];
    let tokens_b = vec![1, 2, 99, 4];
    let out_a = encoder.cls_embedding(&tokens_a);
    let out_b = encoder.cls_embedding(&tokens_b);
    let da = out_a.data();
    let db = out_b.data();
    let sa = da.as_slice().expect("contiguous");
    let sb = db.as_slice().expect("contiguous");
    let diff: f32 = sa.iter().zip(sb.iter()).map(|(a, b)| (a - b).abs()).sum();
    assert!(
        diff > 1e-6,
        "FALSIFY-BIATT-001: CLS embedding must change when later tokens change (diff={diff}). \
         If zero, causal mask is blocking bidirectional attention."
    );
}
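
/// FALSIFY-BIATT-002: repeated calls with the same single token must return
/// bit-identical output; any nondeterminism (e.g. active dropout) falsifies this.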
#[test]
fn falsify_biatt_002_single_token_deterministic() {
    let config = tiny_encoder_config();
    let encoder = EncoderModel::new(&config);
    let out1 = encoder.cls_embedding(&[42]);
    let out2 = encoder.cls_embedding(&[42]);
    let d1 = out1.data();
    let d2 = out2.data();
    let s1 = d1.as_slice().expect("contiguous");
    let s2 = d2.as_slice().expect("contiguous");
    assert_eq!(
        s1, s2,
        "FALSIFY-BIATT-002: Single-token output must be bit-identical on repeated calls"
    );
}
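
/// FALSIFY-BIATT-003: every output element must be finite; NaN/Inf here would
/// indicate unnormalized attention weights or numeric overflow.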
#[test]
fn falsify_biatt_003_output_finite() {
    let config = tiny_encoder_config();
    let encoder = EncoderModel::new(&config);
    let tokens = vec![10, 20, 30, 40, 50];
    let output = encoder.forward(&tokens);
    let data = output.data();
    let slice = data.as_slice().expect("contiguous");
    assert!(
        slice.iter().all(|v| v.is_finite()),
        "FALSIFY-BIATT-003: All encoder outputs must be finite (implies attention weights normalized)"
    );
}
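
/// FALSIFY-PROBE-001: training a linear probe on frozen CLS embeddings must not
/// mutate the encoder; the embedding of a fixed input is compared bit-for-bit
/// before and after `LinearProbe::train`.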
#[test]
fn falsify_probe_001_encoder_frozen() {
    let config = tiny_encoder_config();
    let encoder = EncoderModel::new(&config);
    let tokens = vec![1, 2, 3];
    let before = encoder.cls_embedding(&tokens);
    let before_data = before.data();
    let before_slice = before_data.as_slice().expect("contiguous").to_vec();
    let embeddings: Vec<Vec<f32>> = (0..20)
        .map(|i| {
            let t = vec![i as u32 % 100 + 1; 3];
            let cls = encoder.cls_embedding(&t);
            let d = cls.data();
            d.as_slice().expect("contiguous").to_vec()
        })
        .collect();
    let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 15)).collect();
    let mut probe = LinearProbe::new(config.hidden_size, 2);
    probe.train(&embeddings, &labels, 50, 0.1, None);
    let after = encoder.cls_embedding(&tokens);
    let after_data = after.data();
    let after_slice = after_data.as_slice().expect("contiguous");
    assert_eq!(
        before_slice.as_slice(),
        after_slice,
        "FALSIFY-PROBE-001: Encoder weights must be unchanged after linear probe training. \
         If different, gradient leaked through frozen encoder."
    );
}
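
/// FALSIFY-PROBE-002: for a range of input embeddings, predicted probabilities
/// must form a valid distribution: sum to 1 and all strictly positive.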
#[test]
fn falsify_probe_002_probability_simplex() {
    let probe = LinearProbe::new(32, 2);
    let test_embeddings = vec![
        vec![0.0f32; 32],
        vec![1.0f32; 32],
        vec![-1.0f32; 32],
        vec![0.5f32; 32],
        vec![-0.5f32; 32],
        (0..32).map(|i| (i as f32 - 16.0) * 0.1).collect(),
    ];
    for (i, emb) in test_embeddings.iter().enumerate() {
        let probs = probe.predict_probs(&Tensor::from_vec(emb.clone(), false));
        let sum: f32 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "FALSIFY-PROBE-002: Softmax must sum to 1.0 for embedding {i}, got {sum}"
        );
        for (j, &p) in probs.iter().enumerate() {
            assert!(
                p > 0.0,
                "FALSIFY-PROBE-002: All probabilities must be > 0 for embedding {i}, got prob[{j}]={p}"
            );
        }
    }
}
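
/// FALSIFY-PROBE-003: a probe over H-dim embeddings with K classes must expose
/// exactly H*K weight + K bias trainable parameters.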
#[test]
fn falsify_probe_003_trainable_param_count() {
    let probe_codebert = LinearProbe::new(768, 2);
    assert_eq!(
        probe_codebert.num_parameters(),
        768 * 2 + 2,
        "FALSIFY-PROBE-003: CodeBERT probe must have exactly 1538 params"
    );
    let probe_tiny = LinearProbe::new(32, 2);
    assert_eq!(
        probe_tiny.num_parameters(),
        32 * 2 + 2,
        "FALSIFY-PROBE-003: Tiny probe must have exactly 66 params"
    );
    for (h, k) in [(64, 3), (128, 5), (256, 10)] {
        let probe = LinearProbe::new(h, k);
        assert_eq!(
            probe.num_parameters(),
            h * k + k,
            "FALSIFY-PROBE-003: LinearProbe({h}, {k}) must have {expected} params",
            expected = h * k + k
        );
    }
}
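
/// FALSIFY-ENC-001: output length must be seq_len * hidden_size for every
/// sequence length, i.e. one hidden vector per input token.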
#[test]
fn falsify_enc_001_shape_preservation() {
    let config = tiny_encoder_config();
    let encoder = EncoderModel::new(&config);
    for seq_len in [1, 3, 5, 10] {
        let tokens: Vec<u32> = (0..seq_len).map(|i| (i as u32) + 1).collect();
        let output = encoder.forward(&tokens);
        assert_eq!(
            output.len(),
            seq_len * config.hidden_size,
            "FALSIFY-ENC-001: Encoder output must have shape (seq_len={seq_len}, hidden={h}), \
             got len={len}",
            h = config.hidden_size,
            len = output.len()
        );
    }
}
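
/// FALSIFY-ENC-002: boundary token patterns (repeated id 0, repeated max id 99,
/// a single token, mixed ids) must yield fully finite outputs.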
#[test]
fn falsify_enc_002_no_nan_inf() {
    let config = tiny_encoder_config();
    let encoder = EncoderModel::new(&config);
    let test_cases: Vec<Vec<u32>> = vec![
        vec![0, 0, 0],
        vec![99, 99, 99],
        vec![1],
        vec![1, 50, 99, 25, 75],
    ];
    for tokens in &test_cases {
        let output = encoder.forward(tokens);
        let data = output.data();
        let slice = data.as_slice().expect("contiguous");
        assert!(
            slice.iter().all(|v| v.is_finite()),
            "FALSIFY-ENC-002: No NaN/Inf in encoder output for tokens {tokens:?}"
        );
    }
}
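
/// FALSIFY-POS-001: two forward passes over identical input must agree bitwise,
/// which fails if position embeddings are sampled rather than looked up.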
#[test]
fn falsify_pos_001_deterministic_lookup() {
    let config = tiny_encoder_config();
    let encoder = EncoderModel::new(&config);
    let tokens = vec![1, 2, 3, 4, 5];
    let out1 = encoder.forward(&tokens);
    let out2 = encoder.forward(&tokens);
    let d1 = out1.data();
    let d2 = out2.data();
    let s1 = d1.as_slice().expect("contiguous");
    let s2 = d2.as_slice().expect("contiguous");
    assert_eq!(
        s1, s2,
        "FALSIFY-POS-001: Encoder output must be bit-identical for same input (deterministic positions)"
    );
}
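
// Property-based generalizations of the simplex, parameter-count, and shape
// tests above, driven by proptest.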
mod proptests {
    use super::*;
    use proptest::prelude::*;
    proptest! {
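
        // Probability-simplex invariant over randomly sized probes.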
        #[test]
        fn falsify_probe_002_prop(
            hidden_size in 4usize..=64,
            num_classes in 2usize..=5,
        ) {
            let probe = LinearProbe::new(hidden_size, num_classes);
            let emb = Tensor::from_vec(vec![0.1f32; hidden_size], false);
            let probs = probe.predict_probs(&emb);
            let sum: f32 = probs.iter().sum();
            prop_assert!(
                (sum - 1.0).abs() < 1e-4,
                "FALSIFY-PROBE-002-prop: Softmax sum={sum} for hidden={hidden_size}, classes={num_classes}"
            );
            prop_assert!(
                probs.iter().all(|&p| p > 0.0),
                "FALSIFY-PROBE-002-prop: All probs must be > 0"
            );
        }
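
        // Parameter-count invariant: H*K + K for any (H, K) probe shape.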
        #[test]
        fn falsify_probe_003_prop(
            hidden_size in 1usize..=512,
            num_classes in 2usize..=20,
        ) {
            let probe = LinearProbe::new(hidden_size, num_classes);
            let expected = hidden_size * num_classes + num_classes;
            prop_assert_eq!(
                probe.num_parameters(),
                expected,
                "FALSIFY-PROBE-003-prop: params must be H*K + K"
            );
        }
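
        // Shape invariant for arbitrary sequence lengths.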
        #[test]
        fn falsify_enc_001_prop(
            seq_len in 1usize..=16,
        ) {
            let config = tiny_encoder_config();
            let encoder = EncoderModel::new(&config);
            let tokens: Vec<u32> = (0..seq_len).map(|i| (i as u32) % 100 + 1).collect();
            let output = encoder.forward(&tokens);
            let expected = seq_len * config.hidden_size;
            prop_assert_eq!(
                output.len(),
                expected,
                "FALSIFY-ENC-001-prop: shape mismatch for seq_len={}",
                seq_len
            );
        }
    }
}