use burn::backend::NdArray;
use burn::prelude::*;
// Backend alias used by every test: CPU-only NdArray backend.
type B = NdArray;
/// Returns the CPU device shared by all tests in this file.
fn device() -> burn::backend::ndarray::NdArrayDevice {
    use burn::backend::ndarray::NdArrayDevice;
    NdArrayDevice::Cpu
}
#[test]
fn classification_head_new_stores_num_classes() {
    // The constructor must record the class count it was given.
    let dev = device();
    let head = brainharmony::ClassificationHead::<B>::new(768, 2, &dev);
    assert_eq!(head.num_classes, 2);
}
#[test]
fn classification_head_forward_shape() {
    // A [batch, seq, hidden] input must collapse to [batch, num_classes] logits.
    let dev = device();
    let head = brainharmony::ClassificationHead::<B>::new(768, 2, &dev);
    let logits = head.forward(Tensor::<B, 3>::zeros([2, 100, 768], &dev));
    assert_eq!(logits.dims(), [2, 2]);
}
#[test]
fn classification_head_forward_batch() {
    // Same shape contract as above, but with a different batch size,
    // sequence length, and class count.
    let dev = device();
    let head = brainharmony::ClassificationHead::<B>::new(768, 3, &dev);
    let logits = head.forward(Tensor::<B, 3>::zeros([4, 50, 768], &dev));
    assert_eq!(logits.dims(), [4, 3]);
}
#[test]
fn predict_classes_returns_valid_indices() {
    // With random logits the predicted indices are not deterministic,
    // but they must always fall inside [0, num_classes).
    let dev = device();
    let head = brainharmony::ClassificationHead::<B>::new(768, 2, &dev);
    let input =
        Tensor::<B, 3>::random([2, 100, 768], burn::tensor::Distribution::Default, &dev);
    let classes = brainharmony::predict_classes(head.forward(input));
    assert_eq!(classes.dims(), [2]);
    let vals: Vec<i64> = classes.into_data().to_vec::<i64>().unwrap();
    vals.iter()
        .for_each(|&v| assert!(v == 0 || v == 1, "expected 0 or 1, got {v}"));
}
#[test]
fn predict_classes_deterministic_for_known_logits() {
    // Row 0 peaks at column 1, row 1 peaks at column 0,
    // so argmax must yield [1, 0].
    let data = TensorData::new(vec![-10.0f32, 10.0, 5.0, -5.0], vec![2, 2]);
    let logits = Tensor::<B, 2>::from_data(data, &device());
    let predicted: Vec<i64> = brainharmony::predict_classes(logits)
        .into_data()
        .to_vec::<i64>()
        .unwrap();
    assert_eq!(predicted, vec![1, 0]);
}
#[test]
fn mlp_head_forward_shape() {
    // A [batch, in_dim] input must map to [batch, out_dim].
    let dev = device();
    let head = brainharmony::MLPHead::<B>::new(768, 384, 4, &dev);
    let output = head.forward(Tensor::<B, 2>::zeros([3, 768], &dev));
    assert_eq!(output.dims(), [3, 4]);
}