use super::*;
use crate::{array::Array, dtype::Dtype, error::Error};
/// Absolute tolerance shared by the float comparison helpers in this file.
const TOL: f32 = 1e-5;

/// Returns `true` when `a` and `b` differ by no more than [`TOL`].
///
/// A NaN difference compares as `false`, so NaN (or `inf - inf`) is never close.
fn close(a: f32, b: f32) -> bool {
    let gap = (a - b).abs();
    gap <= TOL
}
/// Element-wise [`close`] over two slices.
///
/// Slices of different lengths are never considered close.
fn vclose(a: &[f32], b: &[f32]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b.iter()).all(|(&lhs, &rhs)| close(lhs, rhs))
}
/// A rank-3 embedding paired with a matching rank-2 mask passes both validators.
#[test]
fn validate_ok_for_well_formed_rank3_emb_and_rank2_mask() {
    let token_values = [1.0_f32, 2.0, 3.0, 4.0];
    let embeddings = Array::from_slice(&token_values, &(1, 2, 2)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 1.0], &(1, 2)).unwrap();

    let paired = validate_token_embeddings_and_mask(&embeddings, &attention);
    assert!(paired.is_ok());

    let rank_only = validate_token_embeddings_rank3(&embeddings);
    assert!(rank_only.is_ok());
}
/// Both validators reject embeddings that are not rank 3, and the
/// `RankMismatch` payload reports the rank and shape that were observed.
#[test]
fn validate_rejects_non_rank3_emb_with_observed_rank_and_shape() {
    // Rank-2 input through the combined embeddings + mask validator.
    let emb_2d = Array::from_slice(&[1.0_f32, 2.0], &(1, 2)).unwrap();
    let mask = Array::from_slice(&[1.0_f32, 1.0], &(1, 2)).unwrap();
    let paired_payload = match validate_token_embeddings_and_mask(&emb_2d, &mask) {
        Err(Error::RankMismatch(payload)) => payload,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert_eq!(paired_payload.actual(), 2, "observed rank");
    assert_eq!(paired_payload.actual_shape(), &[1, 2], "observed shape");

    // Rank-1 input through the embeddings-only validator.
    let emb_1d = Array::from_slice(&[1.0_f32, 2.0], &(2,)).unwrap();
    let rank_payload = match validate_token_embeddings_rank3(&emb_1d) {
        Err(Error::RankMismatch(payload)) => payload,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert_eq!(rank_payload.actual(), 1);
    assert_eq!(rank_payload.actual_shape(), &[2]);
}
/// A rank-3 mask is rejected with a `RankMismatch` carrying the mask's rank and shape.
#[test]
fn validate_rejects_non_rank2_mask() {
    let embeddings = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 2, 2)).unwrap();
    let mask_3d = Array::from_slice(&[1.0_f32, 1.0], &(1, 2, 1)).unwrap();

    let payload = match validate_token_embeddings_and_mask(&embeddings, &mask_3d) {
        Err(Error::RankMismatch(inner)) => inner,
        other => panic!("expected RankMismatch, got {other:?}"),
    };
    assert_eq!(payload.actual(), 3);
    assert_eq!(payload.actual_shape(), &[1, 2, 1]);
}
/// A mask whose `(batch, seq_len)` disagrees with the embeddings yields a
/// `ShapePairMismatch` that carries both leading shapes.
#[test]
fn validate_rejects_batch_or_seq_mismatch_with_both_shapes() {
    let embeddings = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 2, 2)).unwrap();
    // Sequence length 3 against the embeddings' sequence length 2.
    let bad_mask = Array::from_slice(&[1.0_f32, 1.0, 1.0], &(1, 3)).unwrap();

    let payload = match validate_token_embeddings_and_mask(&embeddings, &bad_mask) {
        Err(Error::ShapePairMismatch(inner)) => inner,
        other => panic!("expected ShapePairMismatch, got {other:?}"),
    };
    assert_eq!(payload.expected(), &[1, 2], "emb (batch, seq_len)");
    assert_eq!(payload.actual(), &[1, 3], "mask (batch, seq_len)");
}
/// Mean pooling averages only the tokens whose mask entry is 1.
#[test]
fn mean_pooling_hand_average_over_unmasked() {
    // Tokens [1, 2], [3, 4], [5, 6]; the third token is masked out.
    let tokens = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0], &(1, 3, 2)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 1.0, 0.0], &(1, 3)).unwrap();

    let mut pooled = mean_pooling(&tokens, &attention).unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);

    let got = pooled.to_vec::<f32>().unwrap();
    let want = [2.0_f32, 3.0];
    assert!(vclose(&got, &want));
}
/// With every token masked out, mean pooling must still produce finite
/// values, and they are expected to be (close to) zero.
#[test]
fn mean_pooling_all_pad_row_uses_1e9_floor_finite_near_zero() {
    let tokens = Array::from_slice(&[9.0_f32, 9.0, 7.0, 7.0], &(1, 2, 2)).unwrap();
    let all_pad = Array::from_slice(&[0.0_f32, 0.0], &(1, 2)).unwrap();

    let mut pooled = mean_pooling(&tokens, &all_pad).unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);

    let v = pooled.to_vec::<f32>().unwrap();
    let every_value_finite = v.iter().all(|x| x.is_finite());
    assert!(every_value_finite, "floor must avoid NaN: {v:?}");
    assert!(vclose(&v, &[0.0, 0.0]));
}
/// Mean pooling over F16 embeddings is expected to return an F32 result.
#[test]
fn mean_pooling_output_is_f32_even_when_input_is_f16() {
    let half_tokens = Array::from_slice(&[2.0_f32, 4.0], &(1, 2, 1))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let attention = Array::ones::<f32>(&(1, 2)).unwrap();

    let mut pooled = mean_pooling(&half_tokens, &attention).unwrap();
    let observed_dtype = pooled.dtype().unwrap();
    assert_eq!(observed_dtype, Dtype::F32);

    // Mean of 2 and 4.
    let first = pooled.to_vec::<f32>().unwrap()[0];
    assert!(close(first, 3.0));
}
/// Max pooling must ignore a masked token even when it holds the largest values.
#[test]
fn max_pooling_forces_pad_to_neg_inf_then_maxes() {
    // Tokens [1, 9], [8, 2], [100, 100]; the last one is padding.
    let tokens = Array::from_slice(&[1.0_f32, 9.0, 8.0, 2.0, 100.0, 100.0], &(1, 3, 2)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 1.0, 0.0], &(1, 3)).unwrap();

    let mut pooled = max_pooling(&tokens, &attention).unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);

    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[8.0, 9.0]));
}
/// With negative embeddings, the masked (larger) value must not win the max.
#[test]
fn max_pooling_handles_negative_values_under_mask() {
    let tokens = Array::from_slice(&[-5.0_f32, -2.0], &(1, 2, 1)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 0.0], &(1, 2)).unwrap();

    let mut pooled = max_pooling(&tokens, &attention).unwrap();
    let first = pooled.to_vec::<f32>().unwrap()[0];
    assert!(close(first, -5.0));
}
/// Max pooling over F16 embeddings is expected to keep the F16 dtype.
#[test]
fn max_pooling_preserves_f16_dtype() {
    let half_tokens = Array::from_slice(&[1.0_f32, 4.0], &(1, 2, 1))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let attention = Array::ones::<f32>(&(1, 2)).unwrap();

    let pooled = max_pooling(&half_tokens, &attention).unwrap();
    let observed_dtype = pooled.dtype().unwrap();
    assert_eq!(observed_dtype, Dtype::F16);
}
/// Under left padding, CLS pooling is expected to return the first unmasked token.
#[test]
fn cls_pooling_picks_argmax_mask_row_under_left_padding() {
    // Tokens [1, 1], [2, 2], [3, 3]; only the last one is unmasked.
    let tokens = Array::from_slice(&[1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0], &(1, 3, 2)).unwrap();
    let attention = Array::from_slice(&[0.0_f32, 0.0, 1.0], &(1, 3)).unwrap();

    let mut pooled = cls_pooling(&tokens, &attention).unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);

    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[3.0, 3.0]));
}
/// With an all-zero mask, CLS pooling is expected to return the token at index 0.
#[test]
fn cls_pooling_all_pad_row_argmax_is_index0() {
    let tokens = Array::from_slice(&[5.0_f32, 6.0, 7.0, 8.0], &(1, 2, 2)).unwrap();
    let all_pad = Array::from_slice(&[0.0_f32, 0.0], &(1, 2)).unwrap();

    let mut pooled = cls_pooling(&tokens, &all_pad).unwrap();
    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[5.0, 6.0]));
}
/// With left padding, last-token pooling is expected to return the final token.
#[test]
fn last_token_pooling_left_padded_selects_last_real() {
    let tokens = Array::from_slice(&[1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0], &(1, 3, 2)).unwrap();
    let attention = Array::from_slice(&[0.0_f32, 1.0, 1.0], &(1, 3)).unwrap();

    let mut pooled = last_token_pooling(&tokens, &attention).unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);

    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[3.0, 3.0]));
}
/// With right padding, last-token pooling is expected to skip the trailing
/// pad token and return the last unmasked one.
#[test]
fn last_token_pooling_right_padded_selects_last_real() {
    // Tokens [1, 1], [2, 2], [9, 9]; the trailing [9, 9] is padding.
    let tokens = Array::from_slice(&[1.0_f32, 1.0, 2.0, 2.0, 9.0, 9.0], &(1, 3, 2)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 1.0, 0.0], &(1, 3)).unwrap();

    let mut pooled = last_token_pooling(&tokens, &attention).unwrap();
    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[2.0, 2.0]));
}
/// With an all-zero mask, last-token pooling is expected to return zeros.
#[test]
fn last_token_pooling_all_pad_falls_back_to_zeros() {
    let tokens = Array::from_slice(&[3.0_f32, 4.0, 5.0, 6.0], &(1, 2, 2)).unwrap();
    let all_pad = Array::from_slice(&[0.0_f32, 0.0], &(1, 2)).unwrap();

    let mut pooled = last_token_pooling(&tokens, &all_pad).unwrap();
    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[0.0, 0.0]));
}
/// In a batch mixing left- and right-padded rows, each row is expected to
/// yield its own last unmasked token.
#[test]
fn last_token_pooling_mixed_pad_batch() {
    // Row 0: [1, 1], [2, 2] with mask [0, 1]; row 1: [7, 7], [9, 9] with mask [1, 0].
    let token_values = [1.0_f32, 1.0, 2.0, 2.0, 7.0, 7.0, 9.0, 9.0];
    let tokens = Array::from_slice(&token_values, &(2, 2, 2)).unwrap();
    let attention = Array::from_slice(&[0.0_f32, 1.0, 1.0, 0.0], &(2, 2)).unwrap();

    let mut pooled = last_token_pooling(&tokens, &attention).unwrap();
    assert_eq!(pooled.shape(), vec![2, 2]);

    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[2.0, 2.0, 7.0, 7.0]));
}
/// First-token pooling takes no mask and is expected to return the token at index 0.
#[test]
fn first_token_pooling_always_takes_row0_ignoring_mask() {
    let tokens = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 2, 2)).unwrap();

    let mut pooled = first_token_pooling(&tokens).unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);

    let got = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&got, &[1.0, 2.0]));
}
/// For a one-token sequence with a full mask, every pooling function is
/// expected to return that single token.
#[test]
fn single_token_sequence_all_strategies_return_that_token() {
    let want = [4.0_f32, 5.0];
    let emb = Array::from_slice(&[4.0_f32, 5.0], &(1, 1, 2)).unwrap();
    let mask = Array::from_slice(&[1.0_f32], &(1, 1)).unwrap();

    // Run every strategy up front, labelled for the assertion messages.
    let pooled_by_strategy = [
        ("mean", mean_pooling(&emb, &mask).unwrap()),
        ("max", max_pooling(&emb, &mask).unwrap()),
        ("cls", cls_pooling(&emb, &mask).unwrap()),
        ("last", last_token_pooling(&emb, &mask).unwrap()),
        ("first", first_token_pooling(&emb).unwrap()),
    ];

    for (label, mut pooled) in pooled_by_strategy {
        assert_eq!(pooled.shape(), vec![1, 2], "shape for {label}");
        let values = pooled.to_vec::<f32>().unwrap();
        assert!(vclose(&values, &want), "value for {label}");
    }
}
/// `pool` with all post-processing flags off is expected to return, for each
/// strategy, the value that strategy's reduction gives on the same input.
#[test]
fn pool_dispatches_each_strategy_to_its_reduction() {
    // Tokens [99, 1], [2, 9], [8, 3]; the first token is padding.
    let emb = Array::from_slice(&[99.0_f32, 1.0, 2.0, 9.0, 8.0, 3.0], &(1, 3, 2)).unwrap();
    let mask = Array::from_slice(&[0.0_f32, 1.0, 1.0], &(1, 3)).unwrap();

    let cases = [
        (PoolingStrategy::Mean, [5.0_f32, 6.0]),
        (PoolingStrategy::Max, [8.0, 9.0]),
        (PoolingStrategy::Cls, [2.0, 9.0]),
        (PoolingStrategy::First, [99.0, 1.0]),
        (PoolingStrategy::Last, [8.0, 3.0]),
    ];

    for (strat, want) in cases {
        let mut pooled = pool(&emb, &mask, strat, false, None, false, false).unwrap();
        assert_eq!(pooled.shape(), vec![1, 2], "shape for {strat:?}");
        let values = pooled.to_vec::<f32>().unwrap();
        assert!(vclose(&values, &want), "value for {strat:?}");
    }
}
/// `PoolingStrategy::None` is expected to hand back the rank-3 embeddings unchanged.
#[test]
fn pool_none_is_rank3_passthrough() {
    let token_values = [1.0_f32, 2.0, 3.0, 4.0];
    let emb = Array::from_slice(&token_values, &(1, 2, 2)).unwrap();
    let mask = Array::ones::<f32>(&(1, 2)).unwrap();

    let mut out = pool(&emb, &mask, PoolingStrategy::None, false, None, false, false).unwrap();
    assert_eq!(out.shape(), vec![1, 2, 2]);

    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 2.0, 3.0, 4.0]));
}
/// `pool` on rank-2 embeddings must fail with `Error::RankMismatch`.
#[test]
fn pool_propagates_rank_mismatch_from_validator() {
    let emb_2d = Array::from_slice(&[1.0_f32, 2.0], &(1, 2)).unwrap();
    let mask = Array::from_slice(&[1.0_f32, 1.0], &(1, 2)).unwrap();

    let outcome = pool(&emb_2d, &mask, PoolingStrategy::Mean, false, None, false, false);
    assert!(matches!(outcome, Err(Error::RankMismatch(_))));
}
/// With every `pool_post` option disabled, shape and values must come back unchanged.
#[test]
fn pool_post_no_transform_returns_input_unchanged() {
    let original = [1.0_f32, 2.0, 3.0, 4.0];
    let x = Array::from_slice(&original, &(2, 2)).unwrap();

    let mut out = pool_post(x, false, None, false, false).unwrap();
    assert_eq!(out.shape(), vec![2, 2]);

    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 2.0, 3.0, 4.0]));
}
/// `pool_post` with only the layer-norm flag set must match hand-computed
/// values for the row `[1, 2, 3, 4]`.
#[test]
fn pool_post_layer_norm_closed_form() {
    let x = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 4)).unwrap();

    let mut out = pool_post(x, false, None, true, false).unwrap();
    assert_eq!(out.shape(), vec![1, 4]);

    let want = [-1.3416354_f32, -0.4472118, 0.4472118, 1.3416354];
    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &want));
}
/// `pool_post` with only the RMS-norm flag set, checked on a tiny-magnitude
/// row where the epsilon term dominates the denominator.
#[test]
fn pool_post_rms_norm_closed_form_eps_load_bearing() {
    let x = Array::from_slice(&[0.001_f32, 0.001], &(1, 2)).unwrap();

    let mut out = pool_post(x, false, None, false, true).unwrap();

    let want = [0.30151135_f32, 0.30151135];
    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &want));
}
/// When both norm flags are set, the output must match the layer-norm
/// reference values and must not match the RMS-norm ones.
#[test]
fn pool_post_layer_norm_wins_when_both_norm_flags_set() {
    let x = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let layer_norm_expected = [-1.3416354_f32, -0.4472118, 0.4472118, 1.3416354];
    let rms_expected = [0.36514813_f32, 0.73029626, 1.0954444, 1.4605925];

    let mut out = pool_post(x, false, None, true, true).unwrap();
    let got = out.to_vec::<f32>().unwrap();

    let matches_layer_norm = vclose(&got, &layer_norm_expected);
    assert!(matches_layer_norm, "LayerNorm must win: {got:?}");
    let matches_rms = vclose(&got, &rms_expected);
    assert!(!matches_rms, "must not be RMSNorm: {got:?}");
}
/// With only `normalize` set, the row `[3, 4]` must come back as `[0.6, 0.8]`.
#[test]
fn pool_post_normalize_only_yields_unit_row() {
    let x = Array::from_slice(&[3.0_f32, 4.0], &(1, 2)).unwrap();

    let mut out = pool_post(x, true, None, false, false).unwrap();

    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[0.6, 0.8]));
}
/// Truncation must happen before normalization: the dropped third column
/// (99 and 12) must not influence the resulting unit rows.
#[test]
fn pool_post_truncate_before_normalize_order() {
    // Rows [3, 4, 99] and [0, 5, 12]; truncating to 2 columns leaves [3, 4] and [0, 5].
    let x = Array::from_slice(&[3.0_f32, 4.0, 99.0, 0.0, 5.0, 12.0], &(2, 3)).unwrap();

    let mut out = pool_post(x, true, Some(2), false, false).unwrap();
    assert_eq!(out.shape(), vec![2, 2]);

    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[0.6, 0.8, 0.0, 1.0]));
}
/// Truncating a `(2, 3)` array to width 2 keeps the first two columns of each row.
#[test]
fn truncate_last_dim_rank2_keeps_first_cols() {
    let source = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();

    let mut truncated = truncate_last_dim(&source, 2).unwrap();
    assert_eq!(truncated.shape(), vec![2, 2]);

    let values = truncated.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 2.0, 4.0, 5.0]));
}
/// Truncating a rank-1 array of length 4 to width 2 keeps the first two elements.
#[test]
fn truncate_last_dim_rank1() {
    let source = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(4,)).unwrap();

    let mut truncated = truncate_last_dim(&source, 2).unwrap();
    assert_eq!(truncated.shape(), vec![2]);

    let values = truncated.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 2.0]));
}
/// On a `(2, 2, 2)` array, truncating to width 1 shrinks only the last axis.
#[test]
fn truncate_last_dim_rank3_truncates_only_last_axis() {
    let source_values = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let source = Array::from_slice(&source_values, &(2, 2, 2)).unwrap();

    let mut truncated = truncate_last_dim(&source, 1).unwrap();
    assert_eq!(truncated.shape(), vec![2, 2, 1]);

    // The first element of each innermost pair survives.
    let values = truncated.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 3.0, 5.0, 7.0]));
}
/// `truncate_last_dim` must leave the array untouched when the requested
/// width equals or exceeds the current last dimension.
#[test]
fn truncate_last_dim_noop_when_dimension_ge_last() {
    let x = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let want = [1.0_f32, 2.0, 3.0, 4.0];

    // Requested width equal to the last dimension (2).
    let mut same_width = truncate_last_dim(&x, 2).unwrap();
    // Requested width larger than the last dimension (5 > 2).
    let mut wider = truncate_last_dim(&x, 5).unwrap();

    assert_eq!(same_width.shape(), vec![2, 2]);
    assert_eq!(wider.shape(), vec![2, 2]);

    assert!(vclose(&same_width.to_vec::<f32>().unwrap(), &want));
    // The second result must be passed by reference like the first; the
    // previous text had a bare `>` in place of the borrow and did not parse.
    assert!(vclose(&wider.to_vec::<f32>().unwrap(), &want));
}
/// Every strategy's `as_str` and `Display` output must equal its lowercase name.
#[test]
fn pooling_strategy_as_str_and_display_match() {
    let cases = [
        (PoolingStrategy::Mean, "mean"),
        (PoolingStrategy::Cls, "cls"),
        (PoolingStrategy::First, "first"),
        (PoolingStrategy::Last, "last"),
        (PoolingStrategy::Max, "max"),
        (PoolingStrategy::None, "none"),
    ];

    for (s, name) in cases {
        assert_eq!(s.as_str(), name);
        let displayed = format!("{s}");
        assert_eq!(displayed, name, "Display delegates to as_str");
    }
}
/// Each `is_*` predicate is true for its own variant; two cross-variant
/// checks confirm the predicates are not unconditionally true.
#[test]
fn pooling_strategy_is_variant_predicates() {
    // Positive cases: one per variant.
    let is_mean = PoolingStrategy::Mean.is_mean();
    let is_cls = PoolingStrategy::Cls.is_cls();
    let is_first = PoolingStrategy::First.is_first();
    let is_last = PoolingStrategy::Last.is_last();
    let is_max = PoolingStrategy::Max.is_max();
    let is_none = PoolingStrategy::None.is_none();
    assert!(is_mean);
    assert!(is_cls);
    assert!(is_first);
    assert!(is_last);
    assert!(is_max);
    assert!(is_none);

    // Negative cases: a predicate applied to a different variant.
    let mean_reports_max = PoolingStrategy::Mean.is_max();
    let first_reports_cls = PoolingStrategy::First.is_cls();
    assert!(!mean_reports_max);
    assert!(!first_reports_cls);
}
/// `from_mode` must accept every listed mode string, including both
/// `"lasttoken"` and `"last"` for `Last`, and must round-trip `as_str`.
#[test]
fn pooling_strategy_from_mode_accepts_known_modes_and_last_alias() {
    // Mode string -> expected variant, in the order they are checked.
    let accepted = [
        ("cls", PoolingStrategy::Cls),
        ("mean", PoolingStrategy::Mean),
        ("max", PoolingStrategy::Max),
        ("lasttoken", PoolingStrategy::Last),
        ("last", PoolingStrategy::Last),
        ("first", PoolingStrategy::First),
        ("none", PoolingStrategy::None),
    ];
    for (mode, expected) in accepted {
        assert_eq!(PoolingStrategy::from_mode(mode).unwrap(), expected);
    }

    // Parsing a variant's own `as_str` must give the same variant back.
    let every_variant = [
        PoolingStrategy::Mean,
        PoolingStrategy::Cls,
        PoolingStrategy::First,
        PoolingStrategy::Last,
        PoolingStrategy::Max,
        PoolingStrategy::None,
    ];
    for s in every_variant {
        let reparsed = PoolingStrategy::from_mode(s.as_str()).unwrap();
        assert_eq!(reparsed, s, "round-trip {s}");
    }
}
/// Unsupported mode strings must fail with `Error::UnknownEnumValue`; for
/// `"weightedmean"` the payload's type name, value and supported list are checked.
#[test]
fn pooling_strategy_from_mode_rejects_unsupported_with_typed_payload() {
    let payload = match PoolingStrategy::from_mode("weightedmean") {
        Err(Error::UnknownEnumValue(inner)) => inner,
        other => panic!("expected UnknownEnumValue, got {other:?}"),
    };
    assert_eq!(payload.type_name(), "embeddings::PoolingStrategy");
    assert_eq!(payload.value(), "weightedmean");
    assert_eq!(payload.supported(), &["cls", "lasttoken", "max", "mean"]);

    // Other unsupported strings only need to produce the same error variant.
    let nonsense = PoolingStrategy::from_mode("xyzzy");
    assert!(matches!(nonsense, Err(Error::UnknownEnumValue(_))));

    let sqrt_len = PoolingStrategy::from_mode("mean_sqrt_len_tokens");
    assert!(matches!(sqrt_len, Err(Error::UnknownEnumValue(_))));
}
/// Pins both normalization epsilon constants to `1e-5`, so an accidental
/// change to either one fails this test.
#[test]
fn eps_constants_match_documented_defaults() {
// Exact float equality is intended: the constants are compared to the same literal.
assert_eq!(LAYER_NORM_EPS, 1e-5);
assert_eq!(RMS_NORM_EPS, 1e-5);
}