#![cfg(feature = "embeddings")]
use mlxrs::{
Array, Dtype, Error,
embeddings::{
DEFAULT_NORMALIZE_EPS, PoolingStrategy, SWIFT_L2_EPS, cls_pooling, cosine_similarity,
cosine_similarity_matrix, first_token_pooling, l2_normalize, l2_normalize_eps,
last_token_pooling, layer_norm, max_pooling, mean_pooling, normalize, pool, pool_post,
pooling_from_st_config_bytes, pooling_from_st_config_path, pooling_from_st_config_str,
rms_norm, truncate_last_dim,
},
};
/// Absolute tolerance shared by the float comparison helpers in this file.
const TOL: f32 = 1e-5;

/// Returns `true` when `a` and `b` differ by at most [`TOL`] (NaN never matches).
fn close(a: f32, b: f32) -> bool {
    let diff = (a - b).abs();
    diff <= TOL
}

/// Element-wise [`close`] over two slices; slices of different length never match.
fn vclose(a: &[f32], b: &[f32]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b.iter()).all(|(&lhs, &rhs)| close(lhs, rhs))
}
/// Shared `(batch=2, seq=4, dim=2)` embedding fixture plus its `(2, 4)` mask.
///
/// Row 0 is right-padded: its final token (`[99, 99]`) is masked out.
/// Row 1 has every token unmasked.
fn fixture() -> (Array, Array) {
    let values: [f32; 16] = [
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 99.0, 99.0, // row 0 (last token is padding)
        10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, // row 1
    ];
    let emb = Array::from_slice(&values, &(2, 4, 2)).unwrap();
    let mask_values = [1.0_f32, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0];
    let mask = Array::from_slice(&mask_values, &(2, 4)).unwrap();
    (emb, mask)
}
/// Mean pooling an all-ones tensor with no padding must give all ones.
#[test]
fn mean_pooling_of_ones_with_full_mask_is_ones() {
    let ones_emb = Array::ones::<f32>(&(1, 3, 2)).unwrap();
    let full_mask = Array::ones::<f32>(&(1, 3)).unwrap();
    let mut out = mean_pooling(&ones_emb, &full_mask).unwrap();
    assert_eq!(out.shape(), vec![1, 2]);
    let values = out.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![1.0, 1.0]);
}
/// A masked-out token (99) must not contribute to the mean: (1 + 5) / 2 = 3.
#[test]
fn mean_pooling_ignores_padding() {
    let tokens = Array::from_slice(&[1.0_f32, 5.0, 99.0], &(1, 3, 1)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 1.0, 0.0], &(1, 3)).unwrap();
    let mut out = mean_pooling(&tokens, &attention).unwrap();
    assert_eq!(out.shape(), vec![1, 1]);
    let first = out.to_vec::<f32>().unwrap()[0];
    assert!(close(first, 3.0));
}
/// With a leading pad token, CLS pooling picks the first *unmasked* token (`[2, 2]`).
#[test]
fn cls_pooling_selects_first_real_token() {
    let tokens = Array::from_slice(&[1.0_f32, 1.0, 2.0, 2.0, 3.0, 3.0], &(1, 3, 2)).unwrap();
    let attention = Array::from_slice(&[0.0_f32, 1.0, 1.0], &(1, 3)).unwrap();
    let mut out = cls_pooling(&tokens, &attention).unwrap();
    assert_eq!(out.shape(), vec![1, 2]);
    let values = out.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.0, 2.0]);
}
/// With a trailing pad token, last-token pooling picks the last *unmasked* token (`[2, 2]`).
#[test]
fn last_token_pooling_selects_last_real_token() {
    let tokens = Array::from_slice(&[1.0_f32, 1.0, 2.0, 2.0, 9.0, 9.0], &(1, 3, 2)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 1.0, 0.0], &(1, 3)).unwrap();
    let mut out = last_token_pooling(&tokens, &attention).unwrap();
    assert_eq!(out.shape(), vec![1, 2]);
    let values = out.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.0, 2.0]);
}
/// Left padding: the last real token is the final position (`[7, 7]`).
#[test]
fn last_token_pooling_left_padded_selects_last_real_token() {
    let tokens = Array::from_slice(
        &[
            9.0_f32, 9.0, // pad
            8.0, 8.0, // pad
            6.0, 6.0, // real
            7.0, 7.0, // real (last)
        ],
        &(1, 4, 2),
    )
    .unwrap();
    let attention = Array::from_slice(&[0.0_f32, 0.0, 1.0, 1.0], &(1, 4)).unwrap();
    let mut out = last_token_pooling(&tokens, &attention).unwrap();
    assert_eq!(out.shape(), vec![1, 2]);
    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[7.0, 7.0]));
}
/// One left-padded row and one right-padded row in the same batch: each row
/// independently yields its own last unmasked token.
#[test]
fn last_token_pooling_mixed_left_and_right_pad_batch() {
    let tokens = Array::from_slice(
        &[
            90.0_f32, 90.0, 80.0, 80.0, 60.0, 60.0, 70.0, 70.0, // row 0: left-padded
            1.0, 1.0, 2.0, 2.0, 99.0, 99.0, 99.0, 99.0, // row 1: right-padded
        ],
        &(2, 4, 2),
    )
    .unwrap();
    let attention = Array::from_slice(
        &[
            0.0_f32, 0.0, 1.0, 1.0, // row 0
            1.0, 1.0, 0.0, 0.0, // row 1
        ],
        &(2, 4),
    )
    .unwrap();
    let mut out = last_token_pooling(&tokens, &attention).unwrap();
    assert_eq!(out.shape(), vec![2, 2]);
    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[70.0, 70.0, 2.0, 2.0]));
}
/// A row whose mask is entirely zero pools to an all-zero vector.
#[test]
fn last_token_pooling_all_pad_row_falls_back_to_zeros() {
    let tokens = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0], &(1, 3, 2)).unwrap();
    let no_tokens = Array::from_slice(&[0.0_f32, 0.0, 0.0], &(1, 3)).unwrap();
    let mut out = last_token_pooling(&tokens, &no_tokens).unwrap();
    assert_eq!(out.shape(), vec![1, 2]);
    let values = out.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[0.0, 0.0]));
}
/// Same left-padded case as the direct helper test, but routed through `pool`
/// with `PoolingStrategy::Last` and every boolean flag off.
#[test]
fn last_token_pooling_left_padded_via_dispatcher() {
    let tokens = Array::from_slice(
        &[
            9.0_f32, 9.0, // pad
            8.0, 8.0, // pad
            6.0, 6.0, // real
            7.0, 7.0, // real (last)
        ],
        &(1, 4, 2),
    )
    .unwrap();
    let attention = Array::from_slice(&[0.0_f32, 0.0, 1.0, 1.0], &(1, 4)).unwrap();
    let mut pooled = pool(
        &tokens,
        &attention,
        PoolingStrategy::Last,
        false,
        None,
        false,
        false,
    )
    .unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);
    let values = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[7.0, 7.0]));
}
/// Regression guard: on the right-padded shared fixture, last-token pooling still
/// picks the final *unmasked* token of each row (`[5, 6]` and `[70, 80]`).
///
/// The result is bound as `last` rather than `lt`, so `&lt` cannot be mangled into
/// an HTML entity (`<`) and break the expression.
#[test]
fn last_token_pooling_right_padded_unchanged_regression() {
    let (emb, mask) = fixture();
    let mut last = last_token_pooling(&emb, &mask).unwrap();
    assert!(vclose(
        &last.to_vec::<f32>().unwrap(),
        &[5.0, 6.0, 70.0, 80.0]
    ));
}
/// `[3, 4]` normalized with the L2 helper has Euclidean norm 1.
#[test]
fn l2_normalize_yields_unit_norm() {
    let v = Array::from_slice(&[3.0_f32, 4.0], &(1, 2)).unwrap();
    let unit = l2_normalize(&v).unwrap();
    let mut norm = mlxrs::ops::linalg_full::norm(&unit, 2.0, &[-1], false).unwrap();
    let magnitude = norm.item::<f32>().unwrap();
    assert!(close(magnitude, 1.0));
}
/// A vector compared with an identical copy has cosine similarity 1.
#[test]
fn cosine_similarity_identical_is_one() {
    let lhs = Array::from_slice(&[1.0_f32, 2.0, 3.0], &(3,)).unwrap();
    let rhs = Array::from_slice(&[1.0_f32, 2.0, 3.0], &(3,)).unwrap();
    let sim = cosine_similarity(&lhs, &rhs).unwrap();
    assert!(close(sim, 1.0));
}
/// Orthogonal unit axes have cosine similarity 0.
#[test]
fn cosine_similarity_orthogonal_is_zero() {
    let x_axis = Array::from_slice(&[1.0_f32, 0.0], &(2,)).unwrap();
    let y_axis = Array::from_slice(&[0.0_f32, 1.0], &(2,)).unwrap();
    let sim = cosine_similarity(&x_axis, &y_axis).unwrap();
    assert!(close(sim, 0.0));
}
/// Pairwise similarity of rows `[1, 0]` and `[0, 2]`: the diagonal is 1 and the
/// off-diagonal (orthogonal rows) is 0.
#[test]
fn cosine_similarity_matrix_diagonal_is_one() {
    let rows = Array::from_slice(&[1.0_f32, 0.0, 0.0, 2.0], &(2, 2)).unwrap();
    let mut sim = cosine_similarity_matrix(&rows).unwrap();
    assert_eq!(sim.shape(), vec![2, 2]);
    let cells = sim.to_vec::<f32>().unwrap();
    assert!(close(cells[0], 1.0));
    assert!(close(cells[3], 1.0));
    assert!(close(cells[1], 0.0));
}
/// The masked token (10) is larger than every real token, so max pooling must
/// return 5, not 10.
#[test]
fn max_pooling_respects_attention_mask() {
    let tokens = Array::from_slice(&[1.0_f32, 3.0, 5.0, 10.0], &(1, 4, 1)).unwrap();
    let attention = Array::from_slice(&[1.0_f32, 1.0, 1.0, 0.0], &(1, 4)).unwrap();
    let mut out = max_pooling(&tokens, &attention).unwrap();
    assert_eq!(out.shape(), vec![1, 1]);
    let best = out.to_vec::<f32>().unwrap()[0];
    assert!(close(best, 5.0));
}
/// Exact pooled values for each direct helper on the shared right-padded fixture.
///
/// Bindings use descriptive names (not `lt`) so `&lt` cannot be mangled into an
/// HTML entity (`<`) and break the expression.
#[test]
fn pooling_exact_values_fixture() {
    let (emb, mask) = fixture();
    let mut mean = mean_pooling(&emb, &mask).unwrap();
    assert!(vclose(&mean.to_vec::<f32>().unwrap(), &[3.0, 4.0, 40.0, 50.0]));
    let mut max = max_pooling(&emb, &mask).unwrap();
    assert!(vclose(
        &max.to_vec::<f32>().unwrap(),
        &[5.0, 6.0, 70.0, 80.0]
    ));
    let mut last = last_token_pooling(&emb, &mask).unwrap();
    assert!(vclose(
        &last.to_vec::<f32>().unwrap(),
        &[5.0, 6.0, 70.0, 80.0]
    ));
    let mut first = first_token_pooling(&emb).unwrap();
    assert!(vclose(
        &first.to_vec::<f32>().unwrap(),
        &[1.0, 2.0, 10.0, 20.0]
    ));
}
/// Every reducing strategy routed through `pool` (all boolean flags off) gives a
/// `(2, 2)` result with the expected values on the shared fixture.
#[test]
fn dispatcher_every_strategy_shapes_and_values() {
    let (emb, mask) = fixture();
    let cases = [
        (PoolingStrategy::Mean, vec![3.0, 4.0, 40.0, 50.0]),
        (PoolingStrategy::Max, vec![5.0, 6.0, 70.0, 80.0]),
        (PoolingStrategy::Last, vec![5.0, 6.0, 70.0, 80.0]),
        (PoolingStrategy::First, vec![1.0, 2.0, 10.0, 20.0]),
        (PoolingStrategy::Cls, vec![1.0, 2.0, 10.0, 20.0]),
    ];
    for (strat, expected) in cases {
        let mut pooled = pool(&emb, &mask, strat, false, None, false, false).unwrap();
        assert_eq!(pooled.shape(), vec![2, 2], "shape for {strat:?}");
        let got = pooled.to_vec::<f32>().unwrap();
        assert!(vclose(&got, &expected), "value for {strat:?}");
    }
}
/// `PoolingStrategy::None` returns the token embeddings unchanged (shape and values).
#[test]
fn dispatcher_none_is_passthrough() {
    let (emb, mask) = fixture();
    let mut passthrough = pool(
        &emb,
        &mask,
        PoolingStrategy::None,
        false,
        None,
        false,
        false,
    )
    .unwrap();
    assert_eq!(passthrough.shape(), vec![2, 4, 2]);
    let mut source = emb;
    let got = passthrough.to_vec::<f32>().unwrap();
    let want = source.to_vec::<f32>().unwrap();
    assert_eq!(got, want);
}
/// With the normalize flag set, every pooled row has unit L2 norm.
#[test]
fn dispatcher_normalize_flag_yields_unit_rows() {
    let (emb, mask) = fixture();
    let pooled = pool(&emb, &mask, PoolingStrategy::Mean, true, None, false, false).unwrap();
    let mut row_norms = mlxrs::ops::linalg_full::norm(&pooled, 2.0, &[-1], false).unwrap();
    let norms = row_norms.to_vec::<f32>().unwrap();
    let all_unit = norms.iter().all(|&x| close(x, 1.0));
    assert!(all_unit, "rows must be unit");
}
/// Truncating a `(2, 3)` matrix to 2 columns keeps the first two columns of each row.
#[test]
fn truncate_last_dim_basic() {
    let matrix = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();
    let mut kept = truncate_last_dim(&matrix, 2).unwrap();
    assert_eq!(kept.shape(), vec![2, 2]);
    let values = kept.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 2.0, 4.0, 5.0]));
}
/// Requesting more columns than exist leaves the array unchanged.
#[test]
fn truncate_last_dim_noop_when_ge_size() {
    let matrix = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let mut kept = truncate_last_dim(&matrix, 5).unwrap();
    assert_eq!(kept.shape(), vec![2, 2]);
    let values = kept.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 2.0, 3.0, 4.0]));
}
/// Mean of `[1,2,3,4]` and `[3,4,5,6]` is `[2,3,4,5]`; truncating to `Some(2)`
/// keeps `[2, 3]`.
#[test]
fn dispatcher_matryoshka_truncation() {
    let tokens =
        Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 6.0], &(1, 2, 4)).unwrap();
    let attention = Array::ones::<f32>(&(1, 2)).unwrap();
    let mut pooled = pool(
        &tokens,
        &attention,
        PoolingStrategy::Mean,
        false,
        Some(2),
        false,
        false,
    )
    .unwrap();
    assert_eq!(pooled.shape(), vec![1, 2]);
    let values = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[2.0, 3.0]));
}
/// p=2 normalization of `[3, 4]` along the last axis gives `[0.6, 0.8]`.
#[test]
fn normalize_l2_default() {
    let v = Array::from_slice(&[3.0_f32, 4.0], &(1, 2)).unwrap();
    let mut unit = normalize(&v, 2.0, -1, true, DEFAULT_NORMALIZE_EPS).unwrap();
    let values = unit.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[0.6, 0.8]));
}
/// p=1 normalization divides by the sum of absolute values (7).
#[test]
fn normalize_l1_p_ne_2() {
    let v = Array::from_slice(&[3.0_f32, 4.0], &(1, 2)).unwrap();
    let mut scaled = normalize(&v, 1.0, -1, true, DEFAULT_NORMALIZE_EPS).unwrap();
    let values = scaled.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[3.0 / 7.0, 4.0 / 7.0]));
}
/// p=∞ normalization divides by the largest absolute value (4), keeping signs.
#[test]
fn normalize_inf_norm() {
    let v = Array::from_slice(&[3.0_f32, -4.0], &(1, 2)).unwrap();
    let mut scaled = normalize(&v, f64::INFINITY, -1, true, DEFAULT_NORMALIZE_EPS).unwrap();
    let values = scaled.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[0.75, -1.0]));
}
/// Normalizing along axis 0: column `[3, 4]` becomes `[0.6, 0.8]`; the all-zero
/// column stays zero.
#[test]
fn normalize_axis_0_keepdims() {
    let v = Array::from_slice(&[3.0_f32, 0.0, 4.0, 0.0], &(2, 2)).unwrap();
    let mut by_column = normalize(&v, 2.0, 0, true, DEFAULT_NORMALIZE_EPS).unwrap();
    let values = by_column.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[0.6, 0.0, 0.8, 0.0]));
}
/// A zero vector normalizes to zeros (no NaN) under both eps floors. The Python
/// default eps is also checked at compile time to exceed the Swift one.
#[test]
fn normalize_zero_vector_eps_floor_python_vs_swift() {
    let zeros = Array::from_slice(&[0.0_f32, 0.0], &(1, 2)).unwrap();
    let mut python_style = l2_normalize_eps(&zeros, DEFAULT_NORMALIZE_EPS).unwrap();
    let mut swift_style = l2_normalize_eps(&zeros, SWIFT_L2_EPS).unwrap();
    assert!(vclose(&python_style.to_vec::<f32>().unwrap(), &[0.0, 0.0]));
    assert!(vclose(&swift_style.to_vec::<f32>().unwrap(), &[0.0, 0.0]));
    // Evaluated at compile time: the relative ordering of the constants is part of the contract.
    const { assert!(DEFAULT_NORMALIZE_EPS > SWIFT_L2_EPS) };
}
/// Layer norm without affine parameters produces ~zero mean and ~unit variance.
#[test]
fn layer_norm_zero_mean_unit_var() {
    let x = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let mut normed = layer_norm(&x, None, None, 1e-5).unwrap();
    let v = normed.to_vec::<f32>().unwrap();
    let mean: f32 = v.iter().sum::<f32>() / 4.0;
    assert!(mean.abs() < 1e-3, "mean ~0, got {mean}");
    let var: f32 = v
        .iter()
        .map(|x| {
            let centered = x - mean;
            centered.powi(2)
        })
        .sum::<f32>()
        / 4.0;
    assert!((var - 1.0).abs() < 1e-2, "var ~1, got {var}");
}
/// RMS norm without a weight divides by sqrt(mean(x²)) = sqrt((9 + 16) / 2).
#[test]
fn rms_norm_scales_by_rms() {
    let x = Array::from_slice(&[3.0_f32, 4.0], &(1, 2)).unwrap();
    let mut scaled = rms_norm(&x, None, 1e-6).unwrap();
    let rms = ((9.0 + 16.0) / 2.0_f32).sqrt();
    let values = scaled.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[3.0 / rms, 4.0 / rms]));
}
/// With both the layer-norm flag and the normalize flag set, rows still end up unit-norm.
#[test]
fn dispatcher_apply_layer_norm_then_normalize() {
    let (emb, mask) = fixture();
    let pooled = pool(&emb, &mask, PoolingStrategy::Mean, true, None, true, false).unwrap();
    let mut row_norms = mlxrs::ops::linalg_full::norm(&pooled, 2.0, &[-1], false).unwrap();
    let norms = row_norms.to_vec::<f32>().unwrap();
    assert!(norms.iter().all(|&x| close(x, 1.0)));
}
/// The RMS-norm flag path runs and preserves the pooled `(2, 2)` shape.
#[test]
fn dispatcher_apply_rms_norm_path() {
    let (emb, mask) = fixture();
    let with_rms = pool(&emb, &mask, PoolingStrategy::Mean, false, None, false, true).unwrap();
    assert_eq!(with_rms.shape(), vec![2, 2]);
}
/// When both norm flags are set, the output matches the layer-norm-only output.
#[test]
fn dispatcher_layer_norm_wins_over_rms_when_both_set() {
    let (emb, mask) = fixture();
    let mut both = pool(&emb, &mask, PoolingStrategy::Mean, false, None, true, true).unwrap();
    let mut just_ln = pool(&emb, &mask, PoolingStrategy::Mean, false, None, true, false).unwrap();
    let both_values = both.to_vec::<f32>().unwrap();
    let ln_values = just_ln.to_vec::<f32>().unwrap();
    assert!(vclose(&both_values, &ln_values));
}
/// The modern `pooling_mode` string key selects the strategy. Normalization
/// defaults on, and no dimension is set.
#[test]
fn st_config_modern_pooling_mode_key() {
    let parsed = pooling_from_st_config_str(r#"{"pooling_mode": "mean"}"#).unwrap();
    assert_eq!(parsed.strategy(), PoolingStrategy::Mean);
    assert!(parsed.normalize());
    assert_eq!(parsed.dimension(), None);
}
/// `word_embedding_dimension` becomes the config's dimension; the legacy CLS flag
/// selects `Cls`.
#[test]
fn st_config_word_embedding_dimension_is_matryoshka_dim() {
    let json = r#"{"word_embedding_dimension": 384, "pooling_mode_cls_token": true}"#;
    let parsed = pooling_from_st_config_str(json).unwrap();
    assert_eq!(parsed.strategy(), PoolingStrategy::Cls);
    assert_eq!(parsed.dimension(), Some(384));
}
/// A full legacy config with only the mean flag set parses via the byte entry point;
/// `embedding_dimension` supplies the dimension.
#[test]
fn st_config_legacy_mean_only() {
    let json = r#"{
"embedding_dimension": 384,
"pooling_mode_cls_token": false,
"pooling_mode_mean_tokens": true,
"pooling_mode_max_tokens": false,
"pooling_mode_mean_sqrt_len_tokens": false,
"pooling_mode_weightedmean_tokens": false,
"pooling_mode_lasttoken": false,
"include_prompt": true
}"#;
    let bytes = json.as_bytes();
    let parsed = pooling_from_st_config_bytes(bytes).unwrap();
    assert_eq!(parsed.strategy(), PoolingStrategy::Mean);
    assert_eq!(parsed.dimension(), Some(384));
}
/// When several legacy flags are true, precedence is CLS > mean > max > last-token.
#[test]
fn st_config_legacy_priority_cls_over_mean_over_max_over_last() {
    let cases = [
        (
            r#"{
"pooling_mode_cls_token": true,
"pooling_mode_mean_tokens": true,
"pooling_mode_max_tokens": true,
"pooling_mode_lasttoken": true
}"#,
            PoolingStrategy::Cls,
        ),
        (
            r#"{
"pooling_mode_cls_token": false,
"pooling_mode_mean_tokens": true,
"pooling_mode_max_tokens": true,
"pooling_mode_lasttoken": true
}"#,
            PoolingStrategy::Mean,
        ),
        (
            r#"{
"pooling_mode_max_tokens": true,
"pooling_mode_lasttoken": true
}"#,
            PoolingStrategy::Max,
        ),
        (
            r#"{"pooling_mode_lasttoken": true}"#,
            PoolingStrategy::Last,
        ),
    ];
    for (json, winner) in cases {
        let chosen = pooling_from_st_config_str(json).unwrap().strategy();
        assert_eq!(chosen, winner);
    }
}
/// With every legacy flag false, the strategy falls back to `Mean`.
#[test]
fn st_config_legacy_all_false_defaults_to_mean() {
    let json = r#"{
"pooling_mode_cls_token": false,
"pooling_mode_mean_tokens": false,
"pooling_mode_max_tokens": false,
"pooling_mode_lasttoken": false
}"#;
    let chosen = pooling_from_st_config_str(json).unwrap().strategy();
    assert_eq!(chosen, PoolingStrategy::Mean);
}
/// Weighted-mean (modern or legacy key) and unknown modes are rejected.
#[test]
fn st_config_unsupported_mode_rejected() {
    for json in [
        r#"{"pooling_mode": "weightedmean"}"#,
        r#"{"pooling_mode_weightedmean_tokens": true}"#,
        r#"{"pooling_mode": "bogus"}"#,
    ] {
        assert!(pooling_from_st_config_str(json).is_err());
    }
}
/// `include_prompt: false` is rejected.
#[test]
fn st_config_include_prompt_false_rejected() {
    let json = r#"{"pooling_mode": "mean", "include_prompt": false}"#;
    let outcome = pooling_from_st_config_str(json);
    assert!(outcome.is_err());
}
/// A list-valued (concatenated) `pooling_mode` is rejected.
#[test]
fn st_config_concatenated_list_mode_rejected() {
    let json = r#"{"pooling_mode": ["cls", "mean"]}"#;
    let outcome = pooling_from_st_config_str(json);
    assert!(outcome.is_err());
}
/// A `pooling_mode` key that is present but not a string must be rejected with
/// `Err(OutOfRange)`. It must never silently fall back to a default strategy,
/// even when legacy flags would otherwise pick one.
#[test]
fn st_config_present_malformed_pooling_mode_rejected() {
    for (json, what) in [
        (r#"{"pooling_mode": null}"#, "null"),
        (r#"{"pooling_mode": false}"#, "bool false"),
        (r#"{"pooling_mode": true}"#, "bool true"),
        (r#"{"pooling_mode": 2}"#, "number"),
        (r#"{"pooling_mode": 1.5}"#, "fractional number"),
        (r#"{"pooling_mode": {"a": 1}}"#, "object"),
    ] {
        let r = pooling_from_st_config_str(json);
        // Matching the specific error variant already rules out an `Ok` fallback;
        // a separate `is_err()` assertion would be dead code.
        assert!(
            matches!(r, Err(Error::OutOfRange(_))),
            "present malformed pooling_mode ({what}) must be Err(OutOfRange), got {r:?}"
        );
    }
    let r = pooling_from_st_config_str(r#"{"pooling_mode": null, "pooling_mode_mean_tokens": true}"#);
    assert!(
        matches!(r, Err(Error::OutOfRange(_))),
        "malformed pooling_mode alongside legacy flags must still be Err, got {r:?}"
    );
}
/// A dimension key that is present but invalid (negative, fractional, non-numeric,
/// null, zero, or beyond `usize`) must be rejected. This holds even when the other
/// alias key is valid. An absent key means `None`; a valid one is passed through.
#[test]
fn st_config_present_invalid_dimension_rejected() {
    let invalid = [
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": -1}"#,
            "negative",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": 1.5}"#,
            "fractional",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": "384"}"#,
            "string",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": null}"#,
            "null",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": false}"#,
            "bool",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": 0}"#,
            "zero (empty embedding)",
        ),
        (
            r#"{"pooling_mode": "mean", "word_embedding_dimension": 99999999999999999999999999}"#,
            "overflow > usize",
        ),
        (
            r#"{"pooling_mode": "mean", "embedding_dimension": -5}"#,
            "negative (embedding_dimension alias)",
        ),
    ];
    for (json, what) in invalid {
        let r = pooling_from_st_config_str(json);
        assert!(
            matches!(r, Err(Error::Parse(_)) | Err(Error::OutOfRange(_))),
            "present invalid dimension ({what}) must be Err(Parse) or Err(OutOfRange), got {r:?}"
        );
    }

    let r = pooling_from_st_config_str(
        r#"{"pooling_mode": "mean", "word_embedding_dimension": -1, "embedding_dimension": 384}"#,
    );
    assert!(
        matches!(r, Err(Error::Parse(_)) | Err(Error::OutOfRange(_))),
        "invalid primary key must reject, not fall back to the alias, got {r:?}"
    );

    let absent = pooling_from_st_config_str(r#"{"pooling_mode": "mean"}"#).unwrap();
    assert_eq!(absent.dimension(), None);
    let valid =
        pooling_from_st_config_str(r#"{"pooling_mode": "mean", "word_embedding_dimension": 256}"#)
            .unwrap();
    assert_eq!(valid.dimension(), Some(256));
}
/// A parsed config drives `pool` directly. The input is max pooling with dimension 1
/// and default normalization, so each row ends up as a single `1.0`.
#[test]
fn st_config_end_to_end_drives_dispatcher() {
    let (emb, mask) = fixture();
    let json = r#"{"pooling_mode_max_tokens": true, "word_embedding_dimension": 1}"#;
    let parsed = pooling_from_st_config_str(json).unwrap();
    assert_eq!(parsed.strategy(), PoolingStrategy::Max);
    let (strategy, normalize_out, dim) = (parsed.strategy(), parsed.normalize(), parsed.dimension());
    let mut pooled = pool(&emb, &mask, strategy, normalize_out, dim, false, false).unwrap();
    assert_eq!(pooled.shape(), vec![2, 1]);
    let values = pooled.to_vec::<f32>().unwrap();
    assert!(vclose(&values, &[1.0, 1.0]));
}
/// `from_mode` maps each supported mode string to its strategy and rejects
/// unsupported or unknown names.
#[test]
fn pooling_strategy_from_mode() {
    let accepted = [
        ("cls", PoolingStrategy::Cls),
        ("lasttoken", PoolingStrategy::Last),
        ("max", PoolingStrategy::Max),
        ("mean", PoolingStrategy::Mean),
        ("first", PoolingStrategy::First),
        ("none", PoolingStrategy::None),
    ];
    for (mode, want) in accepted {
        assert_eq!(PoolingStrategy::from_mode(mode).unwrap(), want);
    }
    for rejected in ["weightedmean", "xyzzy"] {
        assert!(PoolingStrategy::from_mode(rejected).is_err());
    }
}
/// Left-padded `(batch=2, seq=4, dim=2)` fixture plus its `(2, 4)` mask.
///
/// Row 0 masks its first two tokens, so its first real token is `[3, 3]`.
/// Row 1 masks only its first token, so its first real token is `[200, 200]`.
fn left_padded_fixture() -> (Array, Array) {
    let values: [f32; 16] = [
        0.0, 0.0, 9.0, 9.0, 3.0, 3.0, 4.0, 4.0, // row 0
        0.0, 0.0, 200.0, 200.0, 300.0, 300.0, 400.0, 400.0, // row 1
    ];
    let emb = Array::from_slice(&values, &(2, 4, 2)).unwrap();
    let mask_values = [0.0_f32, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0];
    let mask = Array::from_slice(&mask_values, &(2, 4)).unwrap();
    (emb, mask)
}
/// On a left-padded batch, `Cls` picks the first real token while `First` stays
/// at position 0. The dispatcher's `Cls` result matches the direct `cls_pooling` call.
#[test]
fn cls_dispatcher_is_mask_aware_on_left_padded_batch() {
    let (emb, mask) = left_padded_fixture();
    let mut cls = pool(&emb, &mask, PoolingStrategy::Cls, false, None, false, false).unwrap();
    assert_eq!(cls.shape(), vec![2, 2]);
    let cls_values = cls.to_vec::<f32>().unwrap();
    assert!(
        vclose(&cls_values, &[3.0, 3.0, 200.0, 200.0]),
        "Cls must select first real token (py cls_pooling), not pos-0"
    );

    let mut first = pool(
        &emb,
        &mask,
        PoolingStrategy::First,
        false,
        None,
        false,
        false,
    )
    .unwrap();
    let first_values = first.to_vec::<f32>().unwrap();
    assert!(
        vclose(&first_values, &[0.0, 0.0, 0.0, 0.0]),
        "First must stay strict token-0 (swift .first)"
    );

    let mut direct = cls_pooling(&emb, &mask).unwrap();
    let direct_values = direct.to_vec::<f32>().unwrap();
    assert_eq!(direct_values, cls.to_vec::<f32>().unwrap());
}
/// Both the modern and the legacy CLS config keys resolve to the mask-aware `Cls`
/// strategy, and it selects the first real token on the left-padded batch.
#[test]
fn st_config_resolved_cls_drives_mask_aware_dispatcher() {
    let (emb, mask) = left_padded_fixture();
    let configs = [
        r#"{"pooling_mode": "cls"}"#,
        r#"{"pooling_mode_cls_token": true}"#,
    ];
    for json in configs {
        let cfg = pooling_from_st_config_str(json).unwrap();
        let strategy = cfg.strategy();
        assert_eq!(
            strategy,
            PoolingStrategy::Cls,
            "ST CLS key must map to Cls (mask-aware), not First: {json}"
        );
        let mut pooled = pool(&emb, &mask, strategy, false, None, false, false).unwrap();
        let values = pooled.to_vec::<f32>().unwrap();
        assert!(
            vclose(&values, &[3.0, 3.0, 200.0, 200.0]),
            "ST-config CLS must select first real token (py cls_pooling): {json}"
        );
    }
    assert_eq!(
        PoolingStrategy::from_mode("cls").unwrap(),
        PoolingStrategy::Cls
    );
}
/// Without left padding, `Cls` and `First` give identical results on the shared fixture.
#[test]
fn cls_and_first_coincide_when_no_left_padding() {
    let (emb, mask) = fixture();
    let mut cls = pool(&emb, &mask, PoolingStrategy::Cls, false, None, false, false).unwrap();
    let mut first = pool(
        &emb,
        &mask,
        PoolingStrategy::First,
        false,
        None,
        false,
        false,
    )
    .unwrap();
    let cls_values = cls.to_vec::<f32>().unwrap();
    let first_values = first.to_vec::<f32>().unwrap();
    assert_eq!(cls_values, first_values);
    assert!(vclose(
        &cls.to_vec::<f32>().unwrap(),
        &[1.0, 2.0, 10.0, 20.0]
    ));
}
/// A ~2 MiB `1_Pooling/config.json` is rejected with `Err(CapExceeded)`. Rewriting
/// the file with a small valid config in the same directory then loads normally.
#[test]
fn st_config_path_rejects_oversize_file_without_oom() {
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let dir = std::env::temp_dir().join(format!("mlxrs-q20-oversize-{}-{}", pid, nanos));
    let pooling_dir = dir.join("1_Pooling");
    std::fs::create_dir_all(&pooling_dir).unwrap();
    let path = pooling_dir.join("config.json");

    let padding = "A".repeat(2 * 1024 * 1024);
    let blob = [r#"{"pooling_mode": "mean", "_pad": ""#, padding.as_str(), r#""}"#].concat();
    std::fs::write(&path, &blob).unwrap();
    let r = pooling_from_st_config_path(&dir);
    assert!(
        matches!(r, Err(Error::CapExceeded(_))),
        "oversize config must yield Err(CapExceeded), got {r:?}"
    );

    std::fs::write(&path, r#"{"pooling_mode": "cls"}"#).unwrap();
    let reloaded = pooling_from_st_config_path(&dir).unwrap();
    assert_eq!(reloaded.strategy(), PoolingStrategy::Cls);
    std::fs::remove_dir_all(&dir).ok();
}
/// A directory sitting where `1_Pooling/config.json` should be yields `Err(FileIo)`.
/// Replacing it with a real file then loads normally.
#[test]
fn st_config_path_rejects_non_regular_file_without_hang() {
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let dir = std::env::temp_dir().join(format!("mlxrs-q20-nonreg-{}-{}", pid, nanos));
    let cfg_as_dir = dir.join("1_Pooling").join("config.json");
    std::fs::create_dir_all(&cfg_as_dir).unwrap();

    let r = pooling_from_st_config_path(&dir);
    assert!(
        matches!(r, Err(Error::FileIo(_))),
        "non-regular (directory) config path must yield a recoverable \
         Err(FileIo) without hang/panic, got {r:?}"
    );

    std::fs::remove_dir_all(&cfg_as_dir).unwrap();
    std::fs::write(&cfg_as_dir, r#"{"pooling_mode": "max"}"#).unwrap();
    let reloaded = pooling_from_st_config_path(&dir).unwrap();
    assert_eq!(reloaded.strategy(), PoolingStrategy::Max);
    std::fs::remove_dir_all(&dir).ok();
}
/// A config whose size is exactly 1 MiB (`1 << 20` bytes) is still accepted.
#[test]
fn st_config_path_accepts_file_at_exact_cap() {
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let dir = std::env::temp_dir().join(format!("mlxrs-q20-atcap-{}-{}", pid, nanos));
    let pooling_dir = dir.join("1_Pooling");
    std::fs::create_dir_all(&pooling_dir).unwrap();
    let path = pooling_dir.join("config.json");

    let head = r#"{"pooling_mode": "mean", "_pad": ""#;
    let tail = r#""}"#;
    let cap = 1usize << 20;
    let filler = "A".repeat(cap - head.len() - tail.len());
    let blob = [head, filler.as_str(), tail].concat();
    assert_eq!(blob.len(), cap, "blob must be exactly the cap");

    std::fs::write(&path, &blob).unwrap();
    let parsed = pooling_from_st_config_path(&dir).unwrap();
    assert_eq!(parsed.strategy(), PoolingStrategy::Mean);
    std::fs::remove_dir_all(&dir).ok();
}
/// A FIFO at `1_Pooling/config.json` must be rejected with a recoverable
/// `Err(FileIo)` instead of blocking while opening it for reading.
///
/// The probe runs on a helper thread, so a regression shows up as a bounded
/// 30-second timeout rather than a hung test binary. Afterwards the FIFO is replaced
/// by a regular file to confirm that the same directory still loads normally.
#[cfg(unix)]
#[test]
fn st_config_path_fifo_returns_err_without_hang() {
    use std::os::unix::ffi::OsStrExt;
    use std::sync::mpsc::{self, RecvTimeoutError};

    let dir = std::env::temp_dir().join(format!(
        "mlxrs-q20-fifo-{}-{}",
        std::process::id(),
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    ));
    let pooling_dir = dir.join("1_Pooling");
    std::fs::create_dir_all(&pooling_dir).unwrap();
    let path = pooling_dir.join("config.json");
    let c_path = std::ffi::CString::new(path.as_os_str().as_bytes()).unwrap();
    // SAFETY: `c_path` is a valid NUL-terminated C string that outlives this call,
    // and `mkfifo` only reads through the pointer.
    let rc = unsafe { libc::mkfifo(c_path.as_ptr(), 0o600) };
    if rc != 0 {
        // Read errno first, before any other call can overwrite it.
        let mkfifo_err = std::io::Error::last_os_error();
        std::fs::remove_dir_all(&dir).ok();
        panic!("mkfifo failed (rc {rc}): {mkfifo_err}");
    }

    let probe_dir = dir.clone();
    let (tx, rx) = mpsc::channel();
    let handle = std::thread::spawn(move || {
        let r = pooling_from_st_config_path(&probe_dir);
        let _ = tx.send(matches!(r, Err(Error::FileIo(_))));
    });
    match rx.recv_timeout(std::time::Duration::from_secs(30)) {
        Ok(is_recoverable_err) => {
            handle.join().unwrap();
            if !is_recoverable_err {
                std::fs::remove_dir_all(&dir).ok();
                panic!(
                    "FIFO at config.json must yield a recoverable Err(FileIo) \
                     (rejected by is_file()==false), not Ok"
                );
            }
        }
        // The sender was dropped without a message, so the probe thread panicked.
        // Re-raise that panic instead of misreporting it as a hang.
        Err(RecvTimeoutError::Disconnected) => {
            std::fs::remove_dir_all(&dir).ok();
            match handle.join() {
                Err(payload) => std::panic::resume_unwind(payload),
                Ok(()) => unreachable!("probe thread exited without reporting a result"),
            }
        }
        Err(RecvTimeoutError::Timeout) => {
            std::fs::remove_dir_all(&dir).ok();
            panic!(
                "pooling_from_st_config_path HUNG on a writer-less FIFO at \
                 config.json — the O_NONBLOCK open regressed"
            );
        }
    }

    std::fs::remove_file(&path).unwrap();
    std::fs::write(&path, r#"{"pooling_mode": "last"}"#).unwrap();
    let cfg = pooling_from_st_config_path(&dir).unwrap();
    assert_eq!(cfg.strategy(), PoolingStrategy::Last);
    std::fs::remove_dir_all(&dir).ok();
}
/// A `config.json` that is a symlink to a regular file is followed and parsed.
/// The layout mimics an HF-cache blob store.
#[cfg(unix)]
#[test]
fn st_config_path_follows_symlink_to_regular_file() {
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let dir = std::env::temp_dir().join(format!("mlxrs-q20-symlink-{}-{}", pid, nanos));

    let blobs_dir = dir.join("blobs");
    std::fs::create_dir_all(&blobs_dir).unwrap();
    let target = blobs_dir.join("deadbeefcafef00d");
    let contents = r#"{"pooling_mode": "cls", "word_embedding_dimension": 384}"#;
    std::fs::write(&target, contents).unwrap();

    let pooling_dir = dir.join("1_Pooling");
    std::fs::create_dir_all(&pooling_dir).unwrap();
    let link = pooling_dir.join("config.json");
    std::os::unix::fs::symlink(&target, &link).unwrap();

    let cfg = pooling_from_st_config_path(&dir).expect(
        "HF-cache symlink → regular config.json must be followed and parsed, \
         not rejected (O_NOFOLLOW regressed)",
    );
    assert_eq!(cfg.strategy(), PoolingStrategy::Cls);
    assert!(cfg.normalize());
    assert_eq!(cfg.dimension(), Some(384));
    std::fs::remove_dir_all(&dir).ok();
}
/// Every pooling entry point returns `Err(RankMismatch)` for 1-D or 2-D token
/// embeddings instead of panicking.
#[test]
fn pooling_helpers_reject_non_rank3_token_embeddings_without_panic() {
    fn rank_err<T>(r: Result<T, Error>) -> bool {
        matches!(r, Err(Error::RankMismatch(_)))
    }

    let mask = Array::from_slice(&[1.0_f32, 1.0], &(1, 2)).unwrap();
    let emb_1d = Array::from_slice(&[1.0_f32, 2.0], &(2,)).unwrap();
    let emb_2d = Array::from_slice(&[1.0_f32, 2.0], &(1, 2)).unwrap();
    for emb in [&emb_1d, &emb_2d] {
        assert!(rank_err(mean_pooling(emb, &mask)));
        assert!(rank_err(max_pooling(emb, &mask)));
        assert!(rank_err(cls_pooling(emb, &mask)));
        assert!(rank_err(last_token_pooling(emb, &mask)));
        assert!(rank_err(first_token_pooling(emb)));
        assert!(rank_err(pool(
            emb,
            &mask,
            PoolingStrategy::Mean,
            false,
            None,
            false,
            false
        )));
        assert!(rank_err(pool(
            emb,
            &mask,
            PoolingStrategy::Cls,
            false,
            None,
            false,
            false
        )));
    }
}
/// Mask-taking pooling helpers return `Err(RankMismatch)` for 1-D or 3-D masks
/// instead of panicking.
#[test]
fn pooling_helpers_reject_wrong_rank_mask_without_panic() {
    fn rank_err<T>(r: Result<T, Error>) -> bool {
        matches!(r, Err(Error::RankMismatch(_)))
    }

    let emb = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 2, 2)).unwrap();
    let mask_1d = Array::from_slice(&[1.0_f32, 1.0], &(2,)).unwrap();
    let mask_3d = Array::from_slice(&[1.0_f32, 1.0], &(1, 2, 1)).unwrap();
    for mask in [&mask_1d, &mask_3d] {
        assert!(rank_err(mean_pooling(&emb, mask)));
        assert!(rank_err(max_pooling(&emb, mask)));
        assert!(rank_err(cls_pooling(&emb, mask)));
        assert!(rank_err(last_token_pooling(&emb, mask)));
    }
}
/// A mask whose sequence length (3) disagrees with the embeddings (2) yields
/// `Err(ShapePairMismatch)`.
#[test]
fn pooling_helpers_reject_mismatched_batch_or_seq_dims() {
    fn pair_err<T>(r: Result<T, Error>) -> bool {
        matches!(r, Err(Error::ShapePairMismatch(_)))
    }

    let emb = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 2, 2)).unwrap();
    let bad_mask = Array::from_slice(&[1.0_f32, 1.0, 1.0], &(1, 3)).unwrap();
    assert!(pair_err(mean_pooling(&emb, &bad_mask)));
    assert!(pair_err(cls_pooling(&emb, &bad_mask)));
}
/// Rounds `v` through IEEE half precision and widens it back to `f32`.
fn to_f16_f32(v: f32) -> f32 {
    let rounded = half::f16::from_f32(v);
    rounded.to_f32()
}
/// Rounds `v` through bfloat16 and widens it back to `f32`.
fn to_bf16_f32(v: f32) -> f32 {
    let rounded = half::bf16::from_f32(v);
    rounded.to_f32()
}
/// Dtype-aware closeness check for reduced-precision results.
///
/// The allowed error is 4 × the dtype's relative step (2⁻¹⁰ for F16, 2⁻⁷ for
/// BF16, [`TOL`] otherwise), scaled by `max(|want|, 1)`.
fn half_close(dt: Dtype, got: f32, want: f32) -> bool {
    let rel = if matches!(dt, Dtype::F16) {
        1.0 / 1024.0
    } else if matches!(dt, Dtype::BF16) {
        1.0 / 128.0
    } else {
        TOL
    };
    let scale = want.abs().max(1.0);
    let tol = rel * scale * 4.0;
    (got - want).abs() <= tol
}
/// The shared [`fixture`] with embeddings cast to `dt`; the mask stays as built there.
fn fixture_dt(dt: Dtype) -> (Array, Array) {
    let (emb, mask) = fixture();
    let cast = emb.astype(dt).unwrap();
    (cast, mask)
}
/// Asserts that `a` has dtype `want`, labelling any failure with `ctx`.
fn assert_dtype(a: &Array, want: Dtype, ctx: &str) {
    let actual = a.dtype().unwrap();
    assert_eq!(actual, want, "output dtype for {ctx}");
}
/// Max pooling keeps F16/BF16 inputs in their dtype; the fixture values are exact.
#[test]
fn max_pooling_f16_bf16_preserve_dtype_and_values() {
    /// Reads a F16 (when `dt` is F16) or BF16 array back as `f32`s.
    fn widen(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if matches!(dt, Dtype::F16) {
            arr.to_vec::<half::f16>().unwrap().iter().map(|h| h.to_f32()).collect()
        } else {
            arr.to_vec::<half::bf16>().unwrap().iter().map(|h| h.to_f32()).collect()
        }
    }

    for dt in [Dtype::F16, Dtype::BF16] {
        let (emb, mask) = fixture_dt(dt);
        let mut p = max_pooling(&emb, &mask).unwrap();
        assert_dtype(&p, dt, "max_pooling");
        let v = widen(&mut p, dt);
        assert_eq!(v, vec![5.0, 6.0, 70.0, 80.0], "max_pooling {dt:?}");
    }
}
/// CLS pooling keeps F16/BF16 inputs in their dtype; the fixture values are exact.
#[test]
fn cls_pooling_f16_bf16_preserve_dtype_and_values() {
    /// Reads a F16 (when `dt` is F16) or BF16 array back as `f32`s.
    fn widen(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if matches!(dt, Dtype::F16) {
            arr.to_vec::<half::f16>().unwrap().iter().map(|h| h.to_f32()).collect()
        } else {
            arr.to_vec::<half::bf16>().unwrap().iter().map(|h| h.to_f32()).collect()
        }
    }

    for dt in [Dtype::F16, Dtype::BF16] {
        let (emb, mask) = fixture_dt(dt);
        let mut p = cls_pooling(&emb, &mask).unwrap();
        assert_dtype(&p, dt, "cls_pooling");
        let v = widen(&mut p, dt);
        assert_eq!(v, vec![1.0, 2.0, 10.0, 20.0], "cls_pooling {dt:?}");
    }
}
/// First-token pooling keeps F16/BF16 inputs in their dtype; the fixture values are exact.
#[test]
fn first_token_pooling_f16_bf16_preserve_dtype_and_values() {
    /// Reads a F16 (when `dt` is F16) or BF16 array back as `f32`s.
    fn widen(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if matches!(dt, Dtype::F16) {
            arr.to_vec::<half::f16>().unwrap().iter().map(|h| h.to_f32()).collect()
        } else {
            arr.to_vec::<half::bf16>().unwrap().iter().map(|h| h.to_f32()).collect()
        }
    }

    for dt in [Dtype::F16, Dtype::BF16] {
        let (emb, _mask) = fixture_dt(dt);
        let mut p = first_token_pooling(&emb).unwrap();
        assert_dtype(&p, dt, "first_token_pooling");
        let v = widen(&mut p, dt);
        assert_eq!(v, vec![1.0, 2.0, 10.0, 20.0], "first_token_pooling {dt:?}");
    }
}
/// Last-token pooling keeps F16/BF16 inputs in their dtype; the fixture values are exact.
#[test]
fn last_token_pooling_f16_bf16_preserve_dtype_and_values() {
    /// Reads a F16 (when `dt` is F16) or BF16 array back as `f32`s.
    fn widen(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if matches!(dt, Dtype::F16) {
            arr.to_vec::<half::f16>().unwrap().iter().map(|h| h.to_f32()).collect()
        } else {
            arr.to_vec::<half::bf16>().unwrap().iter().map(|h| h.to_f32()).collect()
        }
    }

    for dt in [Dtype::F16, Dtype::BF16] {
        let (emb, mask) = fixture_dt(dt);
        let mut p = last_token_pooling(&emb, &mask).unwrap();
        assert_dtype(&p, dt, "last_token_pooling");
        let v = widen(&mut p, dt);
        assert_eq!(v, vec![5.0, 6.0, 70.0, 80.0], "last_token_pooling {dt:?}");
    }
}
/// Mean pooling of F16/BF16 inputs returns F32, matching the Python upcast, with
/// the same values as the F32 fixture.
#[test]
fn mean_pooling_f16_bf16_matches_python_f32_upcast() {
    for dt in [Dtype::F16, Dtype::BF16] {
        let (emb, mask) = fixture_dt(dt);
        let mut p = mean_pooling(&emb, &mask).unwrap();
        assert_dtype(&p, Dtype::F32, "mean_pooling (python forces f32)");
        let values = p.to_vec::<f32>().unwrap();
        assert!(
            vclose(&values, &[3.0, 4.0, 40.0, 50.0]),
            "mean_pooling value {dt:?}"
        );
    }
}
/// L2 normalization keeps F16/BF16 dtype. Each value must match the F32 reference,
/// rounded to the same dtype, within the dtype-aware tolerance.
#[test]
fn normalize_l2_f16_bf16_preserve_dtype_and_value() {
    /// Reads a F16 (when `dt` is F16) or BF16 array back as `f32`s.
    fn widen(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if matches!(dt, Dtype::F16) {
            arr.to_vec::<half::f16>().unwrap().iter().map(|h| h.to_f32()).collect()
        } else {
            arr.to_vec::<half::bf16>().unwrap().iter().map(|h| h.to_f32()).collect()
        }
    }

    let base = [3.0_f32, 4.0, 0.0, 12.0];
    let x_f32 = Array::from_slice(&base, &(2, 2)).unwrap();
    let mut ref_f32 = l2_normalize(&x_f32).unwrap();
    let exp_f32 = ref_f32.to_vec::<f32>().unwrap();

    for dt in [Dtype::F16, Dtype::BF16] {
        let x = Array::from_slice(&base, &(2, 2))
            .unwrap()
            .astype(dt)
            .unwrap();
        let mut p = l2_normalize(&x).unwrap();
        assert_dtype(&p, dt, "l2_normalize");
        let got = widen(&mut p, dt);
        for (g, w) in got.iter().zip(&exp_f32) {
            let want = match dt {
                Dtype::F16 => to_f16_f32(*w),
                _ => to_bf16_f32(*w),
            };
            assert!(
                half_close(dt, *g, want),
                "l2_normalize {dt:?}: got {g} want ~{want}"
            );
        }
    }
}
/// Parameterized `normalize` (p=1 with the default eps, p=2 with the Swift eps)
/// keeps F16/BF16 dtype.
#[test]
fn normalize_param_p_f16_bf16_preserve_dtype() {
    let base = [1.0_f32, 1.0, 2.0, 2.0];
    for dt in [Dtype::F16, Dtype::BF16] {
        let x = Array::from_slice(&base, &(2, 2))
            .unwrap()
            .astype(dt)
            .unwrap();
        let l1 = normalize(&x, 1.0, -1, true, DEFAULT_NORMALIZE_EPS).unwrap();
        assert_dtype(&l1, dt, "normalize p=1");
        let l2 = normalize(&x, 2.0, -1, true, SWIFT_L2_EPS).unwrap();
        assert_dtype(&l2, dt, "normalize p=2 swift-eps");
    }
}
/// Zero-vector normalization in F16/BF16 applies the eps floor in the input dtype.
///
/// With eps 1e-2, both dtypes give zeros. With the default eps, F16 gives NaN
/// (the eps underflows to 0 in f16) while BF16 gives zeros.
#[test]
fn normalize_zero_vector_f16_bf16_eps_floor_in_dtype() {
    /// Reads a F16 (when `dt` is F16) or BF16 array back as `f32`s.
    fn widen(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if matches!(dt, Dtype::F16) {
            arr.to_vec::<half::f16>().unwrap().iter().map(|h| h.to_f32()).collect()
        } else {
            arr.to_vec::<half::bf16>().unwrap().iter().map(|h| h.to_f32()).collect()
        }
    }

    for dt in [Dtype::F16, Dtype::BF16] {
        let x = Array::from_slice(&[0.0_f32, 0.0, 0.0], &(1, 3))
            .unwrap()
            .astype(dt)
            .unwrap();

        let mut p = l2_normalize_eps(&x, 1e-2).unwrap();
        assert_dtype(&p, dt, "l2_normalize zero-vector (eps 1e-2)");
        let got = widen(&mut p, dt);
        assert_eq!(got, vec![0.0, 0.0, 0.0], "zero-vector eps 1e-2 {dt:?}");

        let mut q = l2_normalize(&x).unwrap();
        assert_dtype(&q, dt, "l2_normalize zero-vector (default eps 1e-9)");
        let qv = widen(&mut q, dt);
        if matches!(dt, Dtype::F16) {
            assert!(
                qv.iter().all(|v| v.is_nan()),
                "default-eps zero-vector F16 is python-faithful NaN (1e-9 underflows in f16), got {qv:?}"
            );
        } else {
            assert_eq!(
                qv,
                vec![0.0, 0.0, 0.0],
                "default-eps zero-vector BF16 is 0.0 (1e-9 representable in bf16)"
            );
        }
    }
}
/// For the Max, Cls, Last and First strategies, `pool(..., normalize = true,
/// ...)` on F16/BF16 input must keep the input dtype and yield rows of unit L2
/// norm (within the dtype's tolerance).
#[test]
fn dispatcher_normalize_f16_bf16_preserve_dtype() {
    /// Reads `arr` back as `f32` values: F16 is read as `half::f16`, any other
    /// dtype as `half::bf16`.
    fn read_half_as_f32(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if dt == Dtype::F16 {
            let raw = arr.to_vec::<half::f16>().unwrap();
            raw.into_iter().map(half::f16::to_f32).collect()
        } else {
            let raw = arr.to_vec::<half::bf16>().unwrap();
            raw.into_iter().map(half::bf16::to_f32).collect()
        }
    }

    let strategies = [
        PoolingStrategy::Max,
        PoolingStrategy::Cls,
        PoolingStrategy::Last,
        PoolingStrategy::First,
    ];

    for dt in [Dtype::F16, Dtype::BF16] {
        let (emb, mask) = fixture_dt(dt);
        for strat in strategies {
            let mut p = pool(&emb, &mask, strat, true, None, false, false).unwrap();
            assert_dtype(&p, dt, &format!("pool {strat:?} normalize=true"));

            // One pooled row per batch entry; its width is the second axis.
            let cols = p.shape()[1];
            let got = read_half_as_f32(&mut p, dt);
            for row in got.chunks(cols) {
                let n: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt();
                assert!(
                    half_close(dt, n, 1.0),
                    "unit norm {strat:?} {dt:?}: |row|={n}"
                );
            }
        }
    }
}
/// With `normalize = true`, the Mean strategy returns F32 even for F16 input
/// (the assertion label says this mirrors the Python reference). Every pooled
/// row must have unit L2 norm.
#[test]
fn dispatcher_mean_normalize_f16_is_f32_python_parity() {
    let (emb, mask) = fixture_dt(Dtype::F16);
    let mut p = pool(&emb, &mask, PoolingStrategy::Mean, true, None, false, false).unwrap();
    assert_dtype(&p, Dtype::F32, "pool Mean normalize=true (python f32)");

    // Take the row width from the pooled shape instead of hard-coding the
    // fixture's hidden size, so every column always contributes to the norm.
    let shape = p.shape();
    assert_eq!(shape.len(), 2, "pooled output must be rank 2, got {shape:?}");
    let cols = shape[1];

    let v = p.to_vec::<f32>().unwrap();
    assert!(!v.is_empty(), "pooled output must not be empty");
    for row in v.chunks(cols) {
        let n = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(close(n, 1.0), "unit norm mean f32: {n}");
    }
}
/// `cosine_similarity_matrix` on F16/BF16 input must keep the input dtype and
/// return the full `(3, 3)` similarity matrix of the rows e1 = [1, 0],
/// e2 = [0, 1] and d = [1, 1]. The matrix must have:
/// - a unit diagonal;
/// - symmetry;
/// - cos(e1, e2) = 0;
/// - cos(e1, d) = cos(e2, d) = 1/sqrt(2), rounded into the dtype.
#[test]
fn cosine_similarity_matrix_f16_bf16_preserve_dtype() {
    /// Reads `arr` back as `f32` values: F16 is read as `half::f16`, any other
    /// dtype as `half::bf16`.
    fn read_half_as_f32(arr: &mut Array, dt: Dtype) -> Vec<f32> {
        if dt == Dtype::F16 {
            let raw = arr.to_vec::<half::f16>().unwrap();
            raw.into_iter().map(half::f16::to_f32).collect()
        } else {
            let raw = arr.to_vec::<half::bf16>().unwrap();
            raw.into_iter().map(half::bf16::to_f32).collect()
        }
    }

    let base = [1.0_f32, 0.0, 0.0, 1.0, 1.0, 1.0];
    for dt in [Dtype::F16, Dtype::BF16] {
        let x = Array::from_slice(&base, &(3, 2)).unwrap().astype(dt).unwrap();
        let mut m = cosine_similarity_matrix(&x).unwrap();
        assert_dtype(&m, dt, "cosine_similarity_matrix");
        assert_eq!(m.shape(), vec![3, 3]);

        let got = read_half_as_f32(&mut m, dt);
        // Row-major (3, 3) accessor.
        let at = |i: usize, j: usize| got[i * 3 + j];

        for i in 0..3 {
            assert!(
                half_close(dt, at(i, i), 1.0),
                "diag[{i}] {dt:?} = {}",
                at(i, i)
            );
        }

        // Cosine similarity is symmetric in its arguments.
        for i in 0..3 {
            for j in (i + 1)..3 {
                assert!(
                    half_close(dt, at(i, j), at(j, i)),
                    "asymmetric {dt:?}: [{i}][{j}] = {} vs [{j}][{i}] = {}",
                    at(i, j),
                    at(j, i)
                );
            }
        }

        // e1 and e2 are orthogonal.
        assert!(
            half_close(dt, at(0, 1), 0.0),
            "cos(e1, e2) {dt:?} = {} want 0",
            at(0, 1)
        );

        // Each basis vector makes a 45-degree angle with d = [1, 1].
        let frac = std::f32::consts::FRAC_1_SQRT_2;
        let want = if dt == Dtype::F16 {
            to_f16_f32(frac)
        } else {
            to_bf16_f32(frac)
        };
        for (i, j) in [(0, 2), (1, 2)] {
            assert!(
                half_close(dt, at(i, j), want),
                "cos(row {i}, row {j}) {dt:?} = {} want ~{want}",
                at(i, j)
            );
        }
    }
}
/// The scalar `cosine_similarity` on F16/BF16 vectors must succeed. Its result
/// must agree with the F32 result once that result is rounded into the same
/// dtype.
#[test]
fn cosine_similarity_scalar_f16_bf16_returns_similarity() {
    let av = [3.0_f32, 4.0];
    let bv = [4.0_f32, 3.0];

    // F32 reference: (3*4 + 4*3) / (5 * 5).
    let ref_f32 = cosine_similarity(
        &Array::from_slice(&av, &(2,)).unwrap(),
        &Array::from_slice(&bv, &(2,)).unwrap(),
    )
    .unwrap();

    for dt in [Dtype::F16, Dtype::BF16] {
        let a = Array::from_slice(&av, &(2,)).unwrap().astype(dt).unwrap();
        let b = Array::from_slice(&bv, &(2,)).unwrap().astype(dt).unwrap();
        let got = cosine_similarity(&a, &b)
            .unwrap_or_else(|e| panic!("cosine_similarity {dt:?} errored: {e:?}"));

        let want = match dt {
            Dtype::F16 => to_f16_f32(ref_f32),
            _ => to_bf16_f32(ref_f32),
        };
        assert!(
            half_close(dt, got, want),
            "cosine_similarity scalar {dt:?}: got {got} want ~{want}"
        );
    }
}
/// A half-precision vector compared with an identical copy must not error. The
/// result must be ~1.0 within the dtype's tolerance.
#[test]
fn cosine_similarity_scalar_f16_bf16_identical_is_one() {
    let v = [1.0_f32, 2.0, 3.0];
    for dt in [Dtype::F16, Dtype::BF16] {
        // Two independently built arrays holding the same values.
        let make = || Array::from_slice(&v, &(3,)).unwrap().astype(dt).unwrap();
        let (a, b) = (make(), make());

        let got = cosine_similarity(&a, &b)
            .unwrap_or_else(|e| panic!("cosine_similarity {dt:?} errored: {e:?}"));
        assert!(
            half_close(dt, got, 1.0),
            "cosine_similarity identical {dt:?} = {got}"
        );
    }
}
/// F32 inputs pass through the final dtype cast unchanged:
/// - exactly 0.96 for [3, 4] vs [4, 3];
/// - ~1.0 for a vector against itself;
/// - exactly 0.0 for the orthogonal basis vectors e1 and e2.
#[test]
fn cosine_similarity_scalar_f32_unchanged_after_final_cast() {
    // (3*4 + 4*3) / (5 * 5) = 24 / 25.
    let (a, b) = (
        Array::from_slice(&[3.0_f32, 4.0], &(2,)).unwrap(),
        Array::from_slice(&[4.0_f32, 3.0], &(2,)).unwrap(),
    );
    let ab = cosine_similarity(&a, &b).unwrap();
    assert_eq!(ab, 0.96_f32);

    // A vector against itself.
    let i = Array::from_slice(&[1.0_f32, 2.0, 3.0], &(3,)).unwrap();
    assert!(close(cosine_similarity(&i, &i).unwrap(), 1.0));

    // Orthogonal basis vectors.
    let (e1, e2) = (
        Array::from_slice(&[1.0_f32, 0.0], &(2,)).unwrap(),
        Array::from_slice(&[0.0_f32, 1.0], &(2,)).unwrap(),
    );
    let orth = cosine_similarity(&e1, &e2).unwrap();
    assert_eq!(orth, 0.0_f32);
}
/// A length-1 vector is broadcast-compatible with a length-3 vector, so a
/// naive elementwise implementation would silently return a (possibly > 1)
/// value. `cosine_similarity` must reject the pair with `LengthMismatch`.
#[test]
fn cosine_similarity_rejects_broadcastable_length_mismatch() {
    let a = Array::from_slice(&[1.0_f32, 2.0, 3.0], &(3,)).unwrap();
    let b = Array::from_slice(&[1.0_f32], &(1,)).unwrap();
    let r = cosine_similarity(&a, &b);
    // `matches!(r, Err(..))` already implies `r.is_err()`, so this one
    // assertion covers both the error kind and "no value was returned".
    assert!(
        matches!(r, Err(Error::LengthMismatch(_))),
        "expected Err(LengthMismatch) (must not return a possibly > 1 value), got {r:?}"
    );
}
/// Rank-1 inputs of different lengths (4 vs 3) must be rejected with
/// `LengthMismatch`, whichever argument is the longer one.
#[test]
fn cosine_similarity_rejects_unequal_lengths() {
    let a = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(4,)).unwrap();
    let b = Array::from_slice(&[1.0_f32, 2.0, 3.0], &(3,)).unwrap();
    assert!(matches!(
        cosine_similarity(&a, &b),
        Err(Error::LengthMismatch(_))
    ));
    // The validation must be symmetric: the shorter vector first is rejected too.
    let swapped = cosine_similarity(&b, &a);
    assert!(
        matches!(swapped, Err(Error::LengthMismatch(_))),
        "swapped order: expected Err(LengthMismatch), got {swapped:?}"
    );
}
/// Anything other than two rank-1 vectors must be rejected with
/// `RankMismatch`. This applies whether the rank-2 operand comes first, second,
/// or on both sides.
#[test]
fn cosine_similarity_rejects_non_rank1() {
    // Rank-2 operands: a 2x2 matrix and a 1x1 matrix.
    let m = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let s = Array::from_slice(&[1.0_f32], &(1, 1)).unwrap();
    // Rank-1 operand with the same element count as `s`, so only rank differs.
    let v = Array::from_slice(&[1.0_f32], &(1,)).unwrap();

    // Both operands rank 2.
    assert!(matches!(cosine_similarity(&m, &s), Err(Error::RankMismatch(_))));
    // Rank 1 vs rank 2.
    assert!(matches!(cosine_similarity(&v, &s), Err(Error::RankMismatch(_))));
    // Rank 2 vs rank 1.
    assert!(matches!(cosine_similarity(&s, &v), Err(Error::RankMismatch(_))));
}
/// A zero vector has no direction. Cosine similarity involving one (on either
/// side, or on both) must therefore be the finite value 0.0, not NaN/Inf.
#[test]
fn cosine_similarity_zero_vector_is_finite_zero() {
    let a = Array::from_slice(&[0.0_f32, 0.0, 0.0], &(3,)).unwrap();
    let b = Array::from_slice(&[1.0_f32, 2.0, 3.0], &(3,)).unwrap();
    let z = Array::from_slice(&[0.0_f32, 0.0, 0.0], &(3,)).unwrap();

    // zero vs non-zero
    let got = cosine_similarity(&a, &b).unwrap();
    assert!(
        got.is_finite(),
        "zero-vector cosine must be finite, got {got}"
    );
    assert_eq!(got, 0.0_f32, "zero-vector cosine must be exactly 0.0");

    // non-zero vs zero
    let got_sym = cosine_similarity(&b, &a).unwrap();
    assert!(got_sym.is_finite() && got_sym == 0.0_f32);

    // zero vs zero (two distinct arrays)
    let got_both = cosine_similarity(&a, &z).unwrap();
    assert!(
        got_both.is_finite() && got_both == 0.0_f32,
        "both-zero cosine must be finite 0.0, got {got_both}"
    );
}
/// Two empty (length-0) vectors are an equal-length rank-1 pair. The call must
/// pass validation and return the finite value 0.0 (no NaN/Inf).
#[test]
fn cosine_similarity_length_zero_is_finite() {
    let empty = || Array::from_slice::<f32>(&[], &(0,)).unwrap();
    let (a, b) = (empty(), empty());

    let got = cosine_similarity(&a, &b).unwrap_or_else(|e| {
        panic!("length-0 vs length-0 must pass the rank/length validator, got {e:?}")
    });
    assert!(
        got.is_finite(),
        "length-0 cosine must be finite (no NaN/Inf), got {got}"
    );
    assert_eq!(got, 0.0_f32, "length-0 cosine must be exactly 0.0");
}
/// In F16 and BF16, cosine similarity involving a zero vector must be the
/// finite value 0.0. This holds for zero vs non-zero, non-zero vs zero, and
/// zero vs zero.
#[test]
fn cosine_similarity_zero_vector_f16_bf16_is_finite_zero() {
    /// Builds a length-3 vector from `vals` and casts it to `dt`.
    fn vec3_as(vals: [f32; 3], dt: Dtype) -> Array {
        Array::from_slice(&vals, &(3,)).unwrap().astype(dt).unwrap()
    }

    for dt in [Dtype::F16, Dtype::BF16] {
        let zero = vec3_as([0.0, 0.0, 0.0], dt);
        let nonzero = vec3_as([1.0, 2.0, 3.0], dt);

        // zero vs non-zero
        let got = cosine_similarity(&zero, &nonzero)
            .unwrap_or_else(|e| panic!("cosine_similarity {dt:?} zero/nonzero errored: {e:?}"));
        assert!(
            got.is_finite(),
            "{dt:?} zero-vector cosine must be finite (not NaN/Inf), got {got}"
        );
        assert_eq!(
            got, 0.0_f32,
            "{dt:?} zero-vector cosine must be exactly 0.0"
        );

        // non-zero vs zero
        let got_sym = cosine_similarity(&nonzero, &zero)
            .unwrap_or_else(|e| panic!("cosine_similarity {dt:?} nonzero/zero errored: {e:?}"));
        assert!(
            got_sym.is_finite() && got_sym == 0.0_f32,
            "{dt:?} symmetric zero-vector cosine must be finite 0.0, got {got_sym}"
        );

        // zero vs a second, independently built zero vector
        let zero2 = vec3_as([0.0, 0.0, 0.0], dt);
        let got_both = cosine_similarity(&zero, &zero2)
            .unwrap_or_else(|e| panic!("cosine_similarity {dt:?} both-zero errored: {e:?}"));
        assert!(
            got_both.is_finite() && got_both == 0.0_f32,
            "{dt:?} both-zero cosine must be finite 0.0, got {got_both}"
        );
    }
}
/// Cosine similarity is scale-invariant. Colinear vectors give 1.0 and
/// anti-colinear vectors give -1.0, even when one operand's norm is tiny
/// (checked in F32 and BF16).
#[test]
fn cosine_similarity_scale_invariant_tiny_norm_f32_bf16() {
    // [1e-12] vs [1.0]: tiny but colinear.
    let a = Array::from_slice(&[1e-12_f32], &(1,)).unwrap();
    let b = Array::from_slice(&[1.0_f32], &(1,)).unwrap();
    let got = cosine_similarity(&a, &b).unwrap();
    assert!(
        got.is_finite() && got == 1.0_f32,
        "tiny-norm colinear cosine must be exactly 1.0 (scale-invariant), got {got}"
    );

    // b2 = 1e8 * a2: colinear with a large scale gap.
    let a2 = Array::from_slice(&[1e-7_f32, 2e-7, 3e-7], &(3,)).unwrap();
    let b2 = Array::from_slice(&[1e1_f32, 2e1, 3e1], &(3,)).unwrap();
    let got2 = cosine_similarity(&a2, &b2).unwrap();
    assert!(
        got2.is_finite() && (got2 - 1.0_f32).abs() <= 4.0 * f32::EPSILON,
        "scaled-colinear (b = 1e8*a) cosine must be ~1.0, got {got2}"
    );

    // [-1e-12] vs [1.0]: tiny but anti-colinear.
    let an = Array::from_slice(&[-1e-12_f32], &(1,)).unwrap();
    let bn = Array::from_slice(&[1.0_f32], &(1,)).unwrap();
    let gotn = cosine_similarity(&an, &bn).unwrap();
    assert!(
        gotn.is_finite() && gotn == -1.0_f32,
        "tiny-norm anti-colinear cosine must be exactly -1.0, got {gotn}"
    );

    // BF16: ~2^-14 vs 1.0.
    let as_bf16 = |value: f32| {
        Array::from_slice(&[value], &(1,))
            .unwrap()
            .astype(Dtype::BF16)
            .unwrap()
    };
    let small = 6.1035e-5_f32;
    let abf = as_bf16(small);
    let bbf = as_bf16(1.0);
    let gotbf = cosine_similarity(&abf, &bbf).unwrap();
    assert!(
        gotbf.is_finite() && half_close(Dtype::BF16, gotbf, 1.0),
        "bf16 small scaled-colinear cosine must be ~1.0 (scale-invariant), got {gotbf}"
    );
}
/// A zero vector compared with a vector whose sum of squares overflows the
/// dtype ([MAX, MAX]) must still give the finite value 0.0, in either argument
/// order, for both F32 and F16.
#[test]
fn cosine_similarity_zero_vs_overflowed_norm_is_finite_zero() {
    // ---- F32 ----
    let zero = Array::from_slice(&[0.0_f32, 0.0], &(2,)).unwrap();
    let overflowed = Array::from_slice(&[f32::MAX, f32::MAX], &(2,)).unwrap();

    let got = cosine_similarity(&zero, &overflowed).unwrap();
    assert!(
        got.is_finite(),
        "zero vs overflowed-norm cosine must be finite (NOT NaN/Inf), got {got}"
    );
    assert_eq!(
        got, 0.0_f32,
        "zero vs overflowed-norm cosine must be exactly 0.0, got {got}"
    );

    let got_sym = cosine_similarity(&overflowed, &zero).unwrap();
    assert!(
        got_sym.is_finite() && got_sym == 0.0_f32,
        "symmetric (overflowed-norm vs zero) cosine must be finite 0.0, got {got_sym}"
    );

    // ---- F16 ----
    let as_f16 = |pair: [f32; 2]| {
        Array::from_slice(&pair, &(2,))
            .unwrap()
            .astype(Dtype::F16)
            .unwrap()
    };
    let f16_max = half::f16::MAX.to_f32();
    let zero_h = as_f16([0.0, 0.0]);
    let overflowed_h = as_f16([f16_max, f16_max]);

    let got_h = cosine_similarity(&zero_h, &overflowed_h)
        .unwrap_or_else(|e| panic!("f16 zero vs overflowed-norm errored: {e:?}"));
    assert!(
        got_h.is_finite(),
        "f16 zero vs overflowed-norm cosine must be finite (NOT NaN/Inf), got {got_h}"
    );
    assert_eq!(
        got_h, 0.0_f32,
        "f16 zero vs overflowed-norm cosine must be exactly 0.0, got {got_h}"
    );

    let got_h_sym = cosine_similarity(&overflowed_h, &zero_h)
        .unwrap_or_else(|e| panic!("f16 symmetric overflowed-norm vs zero errored: {e:?}"));
    assert!(
        got_h_sym.is_finite() && got_h_sym == 0.0_f32,
        "f16 symmetric (overflowed-norm vs zero) cosine must be finite 0.0, got {got_h_sym}"
    );
}
/// Non-zero vectors whose squared magnitude underflows (tiny) or overflows
/// (huge) must still be treated as directions. Colinear pairs give ~1.0 and
/// anti-colinear pairs give -1.0, in F32 and F16.
#[test]
fn cosine_similarity_tiny_and_huge_nonzero_are_scale_invariant() {
    // F32 tiny: 1e-23 squared is below the smallest f32 subnormal.
    let a = Array::from_slice(&[1e-23_f32], &(1,)).unwrap();
    let b = Array::from_slice(&[1.0_f32], &(1,)).unwrap();
    let got = cosine_similarity(&a, &b).unwrap();
    assert!(
        got.is_finite() && got == 1.0_f32,
        "f32 tiny [1e-23] vs [1.0] colinear must be exactly 1.0 (scale-invariant, NOT the underflow-misclassified 0.0), got {got}"
    );

    // F32 tiny anti-colinear pair.
    let an = Array::from_slice(&[1e-30_f32], &(1,)).unwrap();
    let bn = Array::from_slice(&[-1e-30_f32], &(1,)).unwrap();
    let gotn = cosine_similarity(&an, &bn).unwrap();
    assert!(
        gotn.is_finite() && gotn == -1.0_f32,
        "f32 tiny [1e-30] vs [-1e-30] anti-colinear must be exactly -1.0, got {gotn}"
    );

    // F32 huge: the sum of squares of [MAX, MAX] overflows f32.
    let huge = Array::from_slice(&[f32::MAX, f32::MAX], &(2,)).unwrap();
    let ones = Array::from_slice(&[1.0_f32, 1.0], &(2,)).unwrap();
    let goth = cosine_similarity(&huge, &ones).unwrap();
    assert!(
        goth.is_finite() && (goth - 1.0_f32).abs() <= 4.0 * f32::EPSILON,
        "f32 huge [f32::MAX,f32::MAX] vs [1,1] colinear must be ~1.0 (overflow class terminal), got {goth}"
    );

    // F16 tiny: ~2^-14 squared is below the smallest f16 subnormal.
    let as_f16 = |value: f32| {
        Array::from_slice(&[value], &(1,))
            .unwrap()
            .astype(Dtype::F16)
            .unwrap()
    };
    let small = 6.1035e-5_f32;

    let af16 = as_f16(small);
    let bf16 = as_f16(1.0);
    let gotf = cosine_similarity(&af16, &bf16)
        .unwrap_or_else(|e| panic!("f16 tiny [6.1e-5] vs [1.0] errored: {e:?}"));
    assert!(
        gotf.is_finite() && gotf == 1.0_f32,
        "f16 tiny [6.1035e-5] vs [1.0] colinear must be exactly 1.0 (scale-invariant), got {gotf}"
    );

    let af16n = as_f16(small);
    let bf16n = as_f16(-small);
    let gotfn = cosine_similarity(&af16n, &bf16n)
        .unwrap_or_else(|e| panic!("f16 tiny anti-colinear errored: {e:?}"));
    assert!(
        gotfn.is_finite() && gotfn == -1.0_f32,
        "f16 tiny [6.1035e-5] vs [-6.1035e-5] anti-colinear must be exactly -1.0, got {gotfn}"
    );
}
/// Regression guard for the half-precision dtype handling. F32 inputs must
/// still come back as F32, with the same values as before, on every checked
/// path (`max_pooling`, `l2_normalize` and explicit `normalize`).
#[test]
fn f32_paths_bit_identical_after_dtype_fix() {
    let (emb, mask) = fixture();
    assert_dtype(&emb, Dtype::F32, "fixture emb is f32");

    // Masked max over tokens: batch row 0 ignores its padded [99, 99] token.
    let mut mx = max_pooling(&emb, &mask).unwrap();
    assert_dtype(&mx, Dtype::F32, "max_pooling f32");
    assert_eq!(mx.to_vec::<f32>().unwrap(), vec![5.0, 6.0, 70.0, 80.0]);

    let x = Array::from_slice(&[3.0_f32, 4.0], &(1, 2)).unwrap();
    let mut n = l2_normalize(&x).unwrap();
    assert_dtype(&n, Dtype::F32, "l2_normalize f32");
    assert!(vclose(&n.to_vec::<f32>().unwrap(), &[0.6, 0.8]));

    // The explicit-parameter path must also stay F32, not just match in value.
    let mut np = normalize(&x, 2.0, -1, true, DEFAULT_NORMALIZE_EPS).unwrap();
    assert_dtype(&np, Dtype::F32, "normalize p=2 f32");
    assert!(vclose(&np.to_vec::<f32>().unwrap(), &[0.6, 0.8]));
}
/// With every post-processing step disabled, `pool_post` returns its input
/// unchanged: same shape and same values.
#[test]
fn pool_post_no_transform_is_passthrough() {
    let values = [1.0_f32, 2.0, 3.0, 4.0];
    let input = Array::from_slice(&values, &(2, 2)).unwrap();
    let mut out = pool_post(input, false, None, false, false).unwrap();
    assert_eq!(out.shape(), vec![2, 2]);
    assert_eq!(out.to_vec::<f32>().unwrap(), values.to_vec());
}
/// With only `dim = Some(2)`, `pool_post` keeps the first two columns of each
/// row: [[1, 2, 3], [4, 5, 6]] -> [[1, 2], [4, 5]].
#[test]
fn pool_post_truncate_only() {
    let input = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();
    let mut out = pool_post(input, false, Some(2), false, false).unwrap();
    assert_eq!(out.shape(), vec![2, 2]);
    let kept = out.to_vec::<f32>().unwrap();
    assert!(vclose(&kept, &[1.0, 2.0, 4.0, 5.0]));
}
/// With only normalisation enabled, the row [3, 4] (norm 5) becomes the unit
/// vector [0.6, 0.8].
#[test]
fn pool_post_normalize_only_yields_unit_rows() {
    let input = Array::from_slice(&[3.0_f32, 4.0], &(1, 2)).unwrap();
    let mut out = pool_post(input, true, None, false, false).unwrap();
    assert_eq!(out.shape(), vec![1, 2]);
    let unit = out.to_vec::<f32>().unwrap();
    assert!(vclose(&unit, &[0.6, 0.8]));
}
/// `pool_post` truncates before it normalises:
/// - [3, 4, 99] -> [3, 4] -> [0.6, 0.8];
/// - [0, 5, 12] -> [0, 5] -> [0, 1].
///
/// Normalising first would let the dropped third column (99 / 12) shrink the
/// kept values.
#[test]
fn pool_post_truncate_then_normalize_order() {
    let input = Array::from_slice(&[3.0_f32, 4.0, 99.0, 0.0, 5.0, 12.0], &(2, 3)).unwrap();
    let mut out = pool_post(input, true, Some(2), false, false).unwrap();
    assert_eq!(out.shape(), vec![2, 2]);
    let rows = out.to_vec::<f32>().unwrap();
    assert!(vclose(&rows, &[0.6, 0.8, 0.0, 1.0]));
}
/// Applying `pool_post` to a mean-pooled tensor must reproduce what the `pool`
/// dispatcher returns for `PoolingStrategy::Mean`. This is checked for each
/// listed combination of (normalize, truncate dim, layer norm, rms norm).
#[test]
fn pool_post_equivalent_to_pool_dispatcher_tail() {
    let (emb, mask) = fixture();
    let cases = [
        // (normalize, dim, layer_norm, rms_norm)
        (false, None, false, false),
        (true, None, false, false),
        (false, Some(1), false, false),
        (true, Some(1), false, false),
        (false, None, true, false),
        (false, None, false, true),
        (true, None, true, false),
        (false, None, true, true),
    ];
    for (normalize, dim, ln, rms) in cases {
        let mut via_post =
            pool_post(mean_pooling(&emb, &mask).unwrap(), normalize, dim, ln, rms).unwrap();
        let mut via_pool =
            pool(&emb, &mask, PoolingStrategy::Mean, normalize, dim, ln, rms).unwrap();

        assert_eq!(
            via_post.shape(),
            via_pool.shape(),
            "shape mismatch for (norm={normalize}, dim={dim:?}, ln={ln}, rms={rms})"
        );

        let post_values = via_post.to_vec::<f32>().unwrap();
        let pool_values = via_pool.to_vec::<f32>().unwrap();
        assert!(
            vclose(&post_values, &pool_values),
            "pool_post must equal pool tail for (norm={normalize}, dim={dim:?}, ln={ln}, rms={rms})"
        );
    }
}
/// LayerNorm alone on [1, 2, 3, 4]. The mean is 2.5 and the variance 1.25, so
/// each output is (x - 2.5) / sqrt(1.25 + eps). The pinned constants match
/// plain standardisation (no scale/shift) with eps = 1e-5.
#[test]
fn pool_post_layer_norm_closed_form() {
    let input = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let mut out = pool_post(input, false, None, true, false).unwrap();
    assert_eq!(out.shape(), vec![1, 4]);
    let expected = [-1.3416354_f32, -0.4472118, 0.4472118, 1.3416354];
    assert!(vclose(&out.to_vec::<f32>().unwrap(), &expected));
}
/// RMSNorm alone on [0.001, 0.001]. The mean square is 1e-6, so without an
/// epsilon the output would be [1, 1]. The expected value
/// 0.30151135 = 0.001 / sqrt(1e-6 + 1e-5) shows that the epsilon dominates for
/// this input; that dominance is what the test pins.
#[test]
fn pool_post_rms_norm_closed_form_eps_load_bearing() {
    let input = Array::from_slice(&[0.001_f32, 0.001], &(1, 2)).unwrap();
    let mut out = pool_post(input, false, None, false, true).unwrap();
    assert_eq!(out.shape(), vec![1, 2]);
    let scaled = out.to_vec::<f32>().unwrap();
    assert!(vclose(&scaled, &[0.30151135, 0.30151135]));
}
/// When both `layer_norm` and `rms_norm` are requested, LayerNorm takes
/// precedence. The output must equal the LayerNorm closed form for
/// [1, 2, 3, 4] and must differ from the RMSNorm closed form.
///
/// NOTE(review): the LayerNorm output here has mean square ~1, so an RMSNorm
/// applied after it (RMS epsilon ~1e-5) would move values by ~1e-6, which is
/// less than `TOL`. This fixture therefore cannot tell "LayerNorm only" apart
/// from "LayerNorm then RMSNorm".
#[test]
fn pool_post_layer_norm_wins_over_rms_closed_form() {
    let layer_norm_expected = [-1.3416354_f32, -0.4472118, 0.4472118, 1.3416354];
    let rms_expected = [0.36514813_f32, 0.73029626, 1.0954444, 1.4605925];
    let x = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(1, 4)).unwrap();
    let mut p = pool_post(x, false, None, true, true).unwrap();
    // Neither norm changes the width, matching the single-norm tests.
    assert_eq!(p.shape(), vec![1, 4]);
    let got = p.to_vec::<f32>().unwrap();
    assert!(
        vclose(&got, &layer_norm_expected),
        "both flags set must yield the LayerNorm result, got {got:?}"
    );
    assert!(
        !vclose(&got, &rms_expected),
        "both flags set must NOT yield the RMSNorm result, got {got:?}"
    );
}
/// `pool_post` applies LayerNorm, then truncation, then L2 normalisation.
///
/// The input [1, 3, 5, 7] has mean 4 and variance 5, so:
/// - LayerNorm gives (x - 4) / sqrt(5 + eps) ~ [-1.3416, -0.4472, 0.4472, 1.3416];
/// - truncating to 2 gives [-1.3416, -0.4472];
/// - normalising gives [-3, -1] / sqrt(10) ~ [-0.9487, -0.3162].
///
/// Every wrong order, or a skipped LayerNorm, gives something else:
/// - no LayerNorm: [1, 3] / sqrt(10) ~ [0.3162, 0.9487];
/// - truncate before LayerNorm: LN([1, 3]) ~ [-1, 1], normalised ~ [-0.7071, 0.7071];
/// - normalise before truncate: LN(x) / 2 ~ [-0.6708, -0.2236].
///
/// A zero-mean input such as [-3, -1, 1, 3] would NOT catch a skipped
/// LayerNorm. There LayerNorm is a pure scaling, which normalisation undoes.
#[test]
fn pool_post_layer_norm_then_truncate_then_normalize_order_closed_form() {
    let x = Array::from_slice(&[1.0_f32, 3.0, 5.0, 7.0], &(1, 4)).unwrap();
    let mut p = pool_post(x, true, Some(2), true, false).unwrap();
    assert_eq!(p.shape(), vec![1, 2]);
    assert!(vclose(
        &p.to_vec::<f32>().unwrap(),
        &[-0.9486833, -0.31622776],
    ));
}
/// Every `PoolingStrategy` variant maps to its canonical lowercase name.
#[test]
fn pooling_strategy_as_str_canonical_names() {
    let expected = [
        (PoolingStrategy::Mean, "mean"),
        (PoolingStrategy::Cls, "cls"),
        (PoolingStrategy::First, "first"),
        (PoolingStrategy::Last, "last"),
        (PoolingStrategy::Max, "max"),
        (PoolingStrategy::None, "none"),
    ];
    for (strategy, name) in expected {
        assert_eq!(strategy.as_str(), name);
    }
}
/// `Display` for `PoolingStrategy` must print exactly `as_str()` for every
/// variant.
#[test]
fn pooling_strategy_display_matches_as_str() {
    let all = [
        PoolingStrategy::Mean,
        PoolingStrategy::Cls,
        PoolingStrategy::First,
        PoolingStrategy::Last,
        PoolingStrategy::Max,
        PoolingStrategy::None,
    ];
    all.iter()
        .for_each(|strategy| assert_eq!(strategy.to_string(), strategy.as_str()));
}
/// Each `is_*` predicate must be true for exactly its own variant and false
/// for every other variant. The full 6 x 6 table is checked, not a sample.
#[test]
fn pooling_strategy_is_variant_predicates() {
    const PREDICATES: [&str; 6] = ["is_mean", "is_cls", "is_first", "is_last", "is_max", "is_none"];
    // Listed in the same order as `PREDICATES`, so `true` belongs on the diagonal.
    let variants = [
        PoolingStrategy::Mean,
        PoolingStrategy::Cls,
        PoolingStrategy::First,
        PoolingStrategy::Last,
        PoolingStrategy::Max,
        PoolingStrategy::None,
    ];
    for (row, s) in variants.iter().enumerate() {
        let answers = [
            s.is_mean(),
            s.is_cls(),
            s.is_first(),
            s.is_last(),
            s.is_max(),
            s.is_none(),
        ];
        for (col, (&answer, name)) in answers.iter().zip(PREDICATES).enumerate() {
            assert_eq!(answer, row == col, "{s:?}.{name}() returned {answer}");
        }
    }
}
/// `from_mode` must parse both the literal "last" and whatever
/// `PoolingStrategy::Last.as_str()` returns into `PoolingStrategy::Last`.
#[test]
fn pooling_strategy_from_mode_last_alias() {
    for mode in ["last", PoolingStrategy::Last.as_str()] {
        let parsed = PoolingStrategy::from_mode(mode).unwrap();
        assert_eq!(parsed, PoolingStrategy::Last);
    }
}
/// On a rank-1 array, `truncate_last_dim` keeps the first `n` elements.
#[test]
fn truncate_last_dim_rank1() {
    let full = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(4,)).unwrap();
    let mut head = truncate_last_dim(&full, 2).unwrap();
    assert_eq!(head.shape(), vec![2]);
    let kept = head.to_vec::<f32>().unwrap();
    assert!(vclose(&kept, &[1.0, 2.0]));
}
/// On a rank-3 array, `truncate_last_dim` slices only the last axis. For
/// (2, 2, 2) -> (2, 2, 1), the first element of every innermost pair is kept:
/// 1, 3, 5, 7.
#[test]
fn truncate_last_dim_rank3_keeps_first_of_last_axis() {
    let cube = Array::from_slice(
        &[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        &(2, 2, 2),
    )
    .unwrap();
    let mut firsts = truncate_last_dim(&cube, 1).unwrap();
    assert_eq!(firsts.shape(), vec![2, 2, 1]);
    let kept = firsts.to_vec::<f32>().unwrap();
    assert!(vclose(&kept, &[1.0, 3.0, 5.0, 7.0]));
}
/// With `PoolingStrategy::None`, the dispatcher keeps the per-token tensor but
/// still truncates its last axis (Matryoshka dimension). A (1, 2, 3) input
/// becomes (1, 2, 2), keeping the first two features of each token.
#[test]
fn dispatcher_none_passthrough_with_matryoshka_truncation() {
    let emb = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0], &(1, 2, 3)).unwrap();
    let mask = Array::ones::<f32>(&(1, 2)).unwrap();

    // normalize = false, dim = 2, layer norm = false, rms norm = false.
    let (do_normalize, dim, ln, rms) = (false, Some(2), false, false);
    let mut tokens = pool(&emb, &mask, PoolingStrategy::None, do_normalize, dim, ln, rms).unwrap();

    assert_eq!(tokens.shape(), vec![1, 2, 2]);
    let kept = tokens.to_vec::<f32>().unwrap();
    assert!(vclose(&kept, &[1.0, 2.0, 4.0, 5.0]));
}