use super::classification::{cross_entropy_loss, ClassificationHead};
use crate::autograd::Tensor;
#[test]
fn falsify_class_ent_001_forward_output_shape() {
    // F-CLASS-001: the head must collapse a (seq_len * hidden_size) input
    // down to exactly `num_classes` logits, independent of sequence length.
    let classifier = ClassificationHead::new(64, 5);
    let seq_len = 3;
    let activations = Tensor::from_vec(vec![0.1f32; seq_len * 64], false);
    let output = classifier.forward(&activations, seq_len);
    assert_eq!(
        output.len(),
        5,
        "F-CLASS-001: forward must produce exactly num_classes={} logits, got {}",
        classifier.num_classes(),
        output.len()
    );
}
#[test]
fn falsify_class_ent_001b_correct_dims_correct_shape() {
    // F-CLASS-001: exercise several (hidden_size, num_classes, seq_len)
    // combinations; the logit count must always equal num_classes.
    let cases = [(32, 3, 1), (128, 10, 5), (256, 2, 8)];
    for &(hidden_size, num_classes, seq_len) in &cases {
        let head = ClassificationHead::new(hidden_size, num_classes);
        let activations = Tensor::from_vec(vec![0.05f32; seq_len * hidden_size], false);
        let logits = head.forward(&activations, seq_len);
        assert_eq!(
            logits.len(),
            num_classes,
            "F-CLASS-001: hidden_size={hidden_size}, num_classes={num_classes}, seq_len={seq_len} \
             must produce {num_classes} logits, got {}",
            logits.len()
        );
    }
}
#[test]
#[should_panic(expected = "F-CLASS-002")]
fn falsify_class_ent_002_label_out_of_range() {
    // F-CLASS-002: label == num_classes is one past the last valid index
    // (valid labels are 0..num_classes) and must be rejected with a panic.
    let num_classes = 5;
    let logits = Tensor::from_vec(vec![1.0, 2.0, -1.0, 0.5, 3.0], false);
    let _ = cross_entropy_loss(&logits, num_classes, num_classes);
}
#[test]
fn falsify_class_ent_002b_label_boundary_valid() {
    // Boundary companion to F-CLASS-002: label = num_classes - 1 (here 4
    // of 5) is the last valid class and must yield a finite loss.
    let scores = Tensor::from_vec(vec![1.0, 2.0, -1.0, 0.5, 3.0], false);
    let loss_val = cross_entropy_loss(&scores, 4, 5).data()[0];
    assert!(
        loss_val.is_finite(),
        "F-CLASS-005: boundary label=4 must produce finite loss, got {loss_val}"
    );
}
#[test]
#[should_panic(expected = "F-CLASS-004")]
fn falsify_class_ent_003_hidden_size_zero() {
    // F-CLASS-004: a zero-width hidden dimension is degenerate; the
    // constructor is expected to panic rather than build an unusable head.
    let _head = ClassificationHead::new(0, 5);
}
#[test]
#[should_panic(expected = "F-CLASS-004")]
fn falsify_class_ent_003b_num_classes_one() {
    // F-CLASS-004: a one-class classifier cannot discriminate anything;
    // the constructor is expected to reject num_classes < 2 with a panic.
    let _head = ClassificationHead::new(64, 1);
}
#[test]
fn falsify_class_ent_004_cross_entropy_finite() {
    // F-CLASS-005: for a well-formed logit vector the loss must be finite,
    // and strictly positive when the target (class 2) is not the argmax
    // (class 4 holds the largest logit, 3.0).
    let scores = Tensor::from_vec(vec![1.0, 2.0, -1.0, 0.5, 3.0], false);
    let loss_val = cross_entropy_loss(&scores, 2, 5).data()[0];
    assert!(loss_val.is_finite(), "F-CLASS-005: cross_entropy_loss must be finite, got {loss_val}");
    assert!(
        loss_val > 0.0,
        "Cross-entropy loss must be positive for non-dominant class, got {loss_val}"
    );
}
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // F-CLASS-001 (property form): output length equals num_classes
        // across randomly drawn dimensions.
        #[test]
        fn falsify_class_ent_001_prop(
            hidden_size in 8usize..=512,
            num_classes in 2usize..=20,
            seq_len in 1usize..=10,
        ) {
            let head = ClassificationHead::new(hidden_size, num_classes);
            let activations = Tensor::from_vec(vec![0.01f32; seq_len * hidden_size], false);
            let logits = head.forward(&activations, seq_len);
            prop_assert_eq!(
                logits.len(),
                num_classes,
                "F-CLASS-001-prop: hidden_size={}, num_classes={}, seq_len={} produced {} logits",
                hidden_size,
                num_classes,
                seq_len,
                logits.len()
            );
        }

        // Accessors and parameter count must agree with construction:
        // weight matrix (hidden_size * num_classes) plus one bias per class.
        #[test]
        fn falsify_class_ent_003_prop(
            hidden_size in 1usize..=512,
            num_classes in 2usize..=20,
        ) {
            let head = ClassificationHead::new(hidden_size, num_classes);
            prop_assert_eq!(head.hidden_size(), hidden_size);
            prop_assert_eq!(head.num_classes(), num_classes);
            prop_assert_eq!(head.num_parameters(), hidden_size * num_classes + num_classes);
        }

        // F-CLASS-005 (property form): loss stays finite for every valid
        // label over symmetric logit spreads of varying width.
        #[test]
        fn falsify_class_ent_004_prop(
            num_classes in 2usize..=20,
            label_offset in 0usize..20,
        ) {
            // Fold the raw offset into the valid label range [0, num_classes).
            let label = label_offset % num_classes;
            // Logits spread symmetrically around zero in steps of 0.5.
            let raw: Vec<f32> = (0..num_classes)
                .map(|i| (i as f32 - num_classes as f32 / 2.0) * 0.5)
                .collect();
            let loss = cross_entropy_loss(&Tensor::from_vec(raw, false), label, num_classes);
            let loss_val = loss.data()[0];
            prop_assert!(
                loss_val.is_finite(),
                "F-CLASS-005-prop: loss must be finite for num_classes={}, label={}, got {}",
                num_classes,
                label,
                loss_val
            );
        }
    }
}