use inference::batch::{mean_pooling, normalize_embeddings};
/// Pooling a fully-attended batch yields one row per batch item, each `hidden`-wide.
#[test]
fn mean_pooling_output_shape_matches_batch_x_hidden() {
    // 3 batch items x 4 tokens x 5 hidden dims of zeros; every token is attended.
    let hidden_states = vec![0.0f32; 3 * 4 * 5];
    let attention = vec![1i64; 3 * 4];

    let pooled = mean_pooling(&hidden_states, 3, 4, 5, &attention);

    assert_eq!(pooled.len(), 3, "should produce one vector per batch item");
    let every_row_is_hidden_wide = pooled.iter().all(|row| row.len() == 5);
    assert!(
        every_row_is_hidden_wide,
        "each vector should have hidden_size=5 dims"
    );
}
/// Three identical, fully-attended tokens average to that same token.
#[test]
fn mean_pooling_single_batch_known_value() {
    // One sequence, three tokens, each token = [4.0, 6.0].
    let hidden_states = vec![4.0f32, 6.0, 4.0, 6.0, 4.0, 6.0];
    let attention = vec![1i64; 3];

    let pooled = mean_pooling(&hidden_states, 1, 3, 2, &attention);
    let first_item = &pooled[0];

    assert!((first_item[0] - 4.0).abs() < 1e-6, "dim0 should be 4.0");
    assert!((first_item[1] - 6.0).abs() < 1e-6, "dim1 should be 6.0");
}
/// A token whose mask entry is 0 must be excluded from the average.
#[test]
fn mean_pooling_masked_tokens_do_not_contribute() {
    // Token 0 = [1.0, 1.0] is attended; token 1 = [9.0, 9.0] is masked out.
    let hidden_states = vec![1.0f32, 1.0, 9.0, 9.0];
    let attention = vec![1i64, 0i64];

    let pooled = mean_pooling(&hidden_states, 1, 2, 2, &attention);

    let dim0 = pooled[0][0];
    assert!(
        (dim0 - 1.0).abs() < 1e-6,
        "masked token [9.0] must not shift result; got {}",
        dim0
    );
    let dim1 = pooled[0][1];
    assert!((dim1 - 1.0).abs() < 1e-6, "got {}", dim1);
}
/// A sequence's pooled vector must not depend on which other sequences share its batch.
///
/// Sequence A is pooled once on its own and once alongside a second sequence B
/// (whose second token is masked), and A's vector must match in both runs.
#[test]
fn mean_pooling_batch_position_independence() {
    // A alone: two attended tokens, each [1, 2, 3].
    let lhs_a_alone = vec![
        1.0f32, 2.0, 3.0, //
        1.0, 2.0, 3.0,
    ];
    let mask_a_alone = vec![1i64, 1i64];
    let result_a = mean_pooling(&lhs_a_alone, 1, 2, 3, &mask_a_alone);

    // A followed by B: B's first token is [5, 6, 7]; its second token is masked.
    let lhs_pair = vec![
        1.0f32, 2.0, 3.0, 1.0, 2.0, 3.0, // A
        5.0, 6.0, 7.0, 0.0, 0.0, 0.0, // B
    ];
    let mask_pair = vec![1i64, 1i64, 1i64, 0i64];
    let result_pair = mean_pooling(&lhs_pair, 2, 2, 3, &mask_pair);

    assert_eq!(
        result_pair.len(),
        2,
        "batched call should produce one vector per batch item"
    );
    // `zip` stops at the shorter iterator, so without these checks an empty or
    // truncated pooled vector would make the comparison loop below pass vacuously.
    assert_eq!(result_a[0].len(), 3, "solo vector for A should have 3 dims");
    assert_eq!(
        result_pair[0].len(),
        3,
        "batched vector for A should have 3 dims"
    );

    for (i, (solo, batched)) in result_a[0].iter().zip(result_pair[0].iter()).enumerate() {
        assert!(
            (solo - batched).abs() < 1e-6,
            "dim {i}: solo={solo} vs batched={batched} — co-batch member B must not affect A"
        );
    }
}
/// Output rows come back in the same order as the batch items went in.
#[test]
fn mean_pooling_output_ordering_preserves_input_order() {
    // Three one-token, one-dim items with distinct values, so any reordering is visible.
    let hidden_states = vec![10.0f32, 20.0, 30.0];
    let attention = vec![1i64, 1i64, 1i64];

    let pooled = mean_pooling(&hidden_states, 3, 1, 1, &attention);

    let first = pooled[0][0];
    assert!(
        (first - 10.0).abs() < 1e-6,
        "item 0 should be 10.0, got {}",
        first
    );

    let second = pooled[1][0];
    assert!(
        (second - 20.0).abs() < 1e-6,
        "item 1 should be 20.0, got {}",
        second
    );

    let third = pooled[2][0];
    assert!(
        (third - 30.0).abs() < 1e-6,
        "item 2 should be 30.0, got {}",
        third
    );
}
/// A sequence with every token masked must still pool to finite values.
///
/// This guards against a divide-by-zero when the active-token count is 0.
#[test]
fn mean_pooling_all_mask_zero_does_not_panic_or_nan() {
    let lhs = vec![1.0f32; 2 * 2];
    let mask = vec![0i64, 0i64];

    let result = mean_pooling(&lhs, 1, 2, 2, &mask);

    // An empty output row would satisfy the finiteness loop below vacuously, so
    // pin the dimensionality first (the shape test expects hidden_size dims per row).
    assert_eq!(
        result[0].len(),
        2,
        "all-zero mask should still yield hidden_size=2 dims"
    );
    for v in &result[0] {
        assert!(
            v.is_finite(),
            "all-zero mask should not produce NaN/Inf, got {v}"
        );
    }
}
/// The mean is taken over attended tokens only: (2 + 4) / 2 = 3.
#[test]
fn mean_pooling_partial_mask_computes_active_average() {
    // Four single-dim tokens; only the first two are attended.
    let hidden_states = vec![2.0f32, 4.0, 8.0, 16.0];
    let attention = vec![1i64, 1i64, 0i64, 0i64];

    let pooled = mean_pooling(&hidden_states, 1, 4, 1, &attention);
    let got = pooled[0][0];

    assert!(
        (got - 3.0).abs() < 1e-6,
        "expected mean of active tokens = 3.0, got {}",
        got
    );
}
/// After normalization, a row's Euclidean length is 1.
#[test]
fn normalize_l2_norm_is_one() {
    // [3, 4] has norm 5 before normalization.
    let mut rows = vec![vec![3.0f32, 4.0]];
    normalize_embeddings(&mut rows);

    let sum_of_squares: f32 = rows[0].iter().map(|x| x * x).sum();
    let norm = sum_of_squares.sqrt();
    assert!(
        (norm - 1.0).abs() < 1e-5,
        "L2 norm after normalization should be 1.0, got {norm}"
    );
}
/// [3, 4] / 5 = [0.6, 0.8].
#[test]
fn normalize_known_3_4_vector_values() {
    let mut rows = vec![vec![3.0f32, 4.0]];
    normalize_embeddings(&mut rows);

    let x = rows[0][0];
    assert!((x - 0.6).abs() < 1e-5, "expected 0.6, got {}", x);
    let y = rows[0][1];
    assert!((y - 0.8).abs() < 1e-5, "expected 0.8, got {}", y);
}
/// Each row must be scaled by its own L2 norm, not by a norm shared across rows.
#[test]
fn normalize_each_row_normalized_independently() {
    // Rows with different, non-unit norms (5.0 and 2.0). Already-unit inputs would
    // let a no-op implementation pass, and equal norms would hide a single
    // batch-wide scale factor being applied to every row.
    let mut embeddings = vec![vec![3.0f32, 4.0], vec![0.0f32, 2.0]];
    normalize_embeddings(&mut embeddings);

    assert_eq!(embeddings.len(), 2, "row count unchanged");
    for (i, row) in embeddings.iter().enumerate() {
        let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "row {i} should have L2 norm 1.0, got {norm}"
        );
    }

    // Each row keeps its own direction: [3, 4] / 5 and [0, 2] / 2.
    let expected = [[0.6f32, 0.8], [0.0, 1.0]];
    for (i, (row, want)) in embeddings.iter().zip(expected.iter()).enumerate() {
        assert_eq!(row.len(), want.len(), "row {i} column count unchanged");
        for (j, (got, w)) in row.iter().zip(want.iter()).enumerate() {
            assert!(
                (got - w).abs() < 1e-5,
                "row {i} dim {j}: expected {w}, got {got}"
            );
        }
    }
}
/// A vector whose norm is tiny must not blow up to NaN or infinity.
#[test]
fn normalize_near_zero_vector_produces_finite_output() {
    let mut rows = vec![vec![1e-15f32, 1e-15]];
    normalize_embeddings(&mut rows);

    for &v in rows[0].iter() {
        assert!(
            v.is_finite(),
            "near-zero vector should not produce NaN/Inf, got {v}"
        );
    }
}
/// Negative components normalize to a unit vector pointing the same way.
#[test]
fn normalize_negative_values_produce_unit_vector() {
    let mut embeddings = vec![vec![-3.0f32, -4.0]];
    normalize_embeddings(&mut embeddings);

    let norm: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!(
        (norm - 1.0).abs() < 1e-5,
        "negative input should still normalize to unit vector, got norm={norm}"
    );

    // Unit norm alone would also accept a sign-dropping result such as [0.6, 0.8];
    // normalization must preserve direction: [-3, -4] / 5 = [-0.6, -0.8].
    assert!(
        (embeddings[0][0] + 0.6).abs() < 1e-5,
        "expected -0.6, got {}",
        embeddings[0][0]
    );
    assert!(
        (embeddings[0][1] + 0.8).abs() < 1e-5,
        "expected -0.8, got {}",
        embeddings[0][1]
    );
}
/// Normalizing a vector that already has unit length leaves it unchanged.
#[test]
fn normalize_already_unit_vector_unchanged() {
    // [0.6, 0.8] already has L2 norm 1.
    let mut rows = vec![vec![0.6f32, 0.8]];
    normalize_embeddings(&mut rows);

    let first = rows[0][0];
    assert!(
        (first - 0.6).abs() < 1e-5,
        "idempotent: expected 0.6, got {}",
        first
    );
    let second = rows[0][1];
    assert!(
        (second - 0.8).abs() < 1e-5,
        "idempotent: expected 0.8, got {}",
        second
    );
}
/// Normalization rewrites values in place without adding or dropping rows or columns.
#[test]
fn normalize_preserves_output_shape() {
    // Four identical rows of five columns.
    let mut embeddings: Vec<Vec<f32>> = vec![vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; 4];
    normalize_embeddings(&mut embeddings);

    assert_eq!(embeddings.len(), 4, "row count unchanged");
    let all_rows_five_wide = embeddings.iter().all(|row| row.len() == 5);
    assert!(all_rows_five_wide, "column count unchanged");
}