use super::*;
use crate::embeddings::model::{EmbeddingModel, EmbeddingModelOutput, MockEmbeddingModel};
/// Absolute tolerance used by the float comparisons in `close` and `vclose`.
const TOL: f32 = 1e-5;
/// Test double that forwards `last_hidden_state` from a wrapped `MockEmbeddingModel` but
/// replaces its pooled output with a tensor built from caller-supplied data and shape, so tests
/// can hand `encode` pooled outputs of arbitrary (including malformed) shape.
struct RawPooledModel {
    /// Source of the hidden states; its own pooled output (if any) is discarded.
    inner: MockEmbeddingModel,
    /// Flat f32 contents of the substitute pooled tensor.
    pooled_data: Vec<f32>,
    /// Shape passed to `Array::from_slice` together with `pooled_data`.
    pooled_shape: Vec<usize>,
}
impl EmbeddingModel for RawPooledModel {
    fn forward(&self, input_ids: &Array, attention_mask: &Array) -> Result<EmbeddingModelOutput> {
        let delegated = self.inner.forward(input_ids, attention_mask)?;
        let substitute =
            Array::from_slice::<f32>(&self.pooled_data, &self.pooled_shape.as_slice())?;
        let (hidden, _) = delegated.into_parts();
        let output = EmbeddingModelOutput::new(hidden, Some(substitute));
        Ok(output)
    }
}
/// Returns `true` when `a` and `b` differ by no more than the absolute tolerance `TOL`.
/// Any comparison involving NaN yields `false`.
fn close(a: f32, b: f32) -> bool {
    let gap = (a - b).abs();
    TOL >= gap
}
/// Element-wise `close` over two slices. Slices of different lengths are never close, and two
/// empty slices always are.
fn vclose(a: &[f32], b: &[f32]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b.iter()).all(|(&lhs, &rhs)| close(lhs, rhs))
}
/// Returns a whitespace-pretokenized word-level tokenizer over the vocabulary `a`..`e`
/// (ids 0..=4). Unknown words map to `a`, and no padding is configured.
///
/// The Hugging Face tokenizer is built and serialized to a per-process temp-dir fixture exactly
/// once. Every call then reloads it through `Tokenizer::from_path`.
fn word_tokenizer() -> Tokenizer {
    use tokenizers::{
        Tokenizer as HfTokenizer, models::wordlevel::WordLevel, pre_tokenizers::whitespace::Whitespace,
    };
    static FIXTURE: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();
    // Build the HF tokenizer only inside the one-shot initializer. After the first call its
    // result would be thrown away, so rebuilding it per call is wasted work.
    let dir = FIXTURE.get_or_init(|| {
        let vocab = ["a", "b", "c", "d", "e"]
            .iter()
            .enumerate()
            .map(|(i, w)| ((*w).to_string(), i as u32))
            .collect();
        let wl = WordLevel::builder()
            .vocab(vocab)
            .unk_token("a".to_string())
            .build()
            .unwrap();
        let mut hf = HfTokenizer::new(wl);
        hf.with_pre_tokenizer(Some(Whitespace {}));
        let dir =
            std::env::temp_dir().join(format!("mlxrs-emb-encode-tok-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        hf.save(dir.join("tokenizer.json"), false).unwrap();
        dir
    });
    Tokenizer::from_path(dir, None).unwrap()
}
/// Returns the same `a`..`e` word-level tokenizer as `word_tokenizer`, except that the HF
/// tokenizer itself right-pads every encoding to a fixed length of 4 using token `e` (id 4).
///
/// Note: `e` is also a real vocabulary word, so tests using this tokenizer can tell whether pad
/// cells are identified by the attention mask rather than by token id. The tokenizer is built
/// and saved to its own per-process temp-dir fixture exactly once, then reloaded on every call.
fn padded_word_tokenizer() -> Tokenizer {
    use tokenizers::{
        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer as HfTokenizer,
        models::wordlevel::WordLevel, pre_tokenizers::whitespace::Whitespace,
    };
    static FIXTURE: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();
    // All construction lives inside the one-shot initializer, so repeated calls only reload.
    let dir = FIXTURE.get_or_init(|| {
        let vocab = ["a", "b", "c", "d", "e"]
            .iter()
            .enumerate()
            .map(|(i, w)| ((*w).to_string(), i as u32))
            .collect();
        let wl = WordLevel::builder()
            .vocab(vocab)
            .unk_token("a".to_string())
            .build()
            .unwrap();
        let mut hf = HfTokenizer::new(wl);
        hf.with_pre_tokenizer(Some(Whitespace {}));
        hf.with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::Fixed(4),
            direction: PaddingDirection::Right,
            pad_to_multiple_of: None,
            pad_id: 4,
            pad_type_id: 0,
            pad_token: "e".to_string(),
        }));
        // Distinct directory name from `word_tokenizer`'s fixture so the two JSON files never
        // collide.
        let dir = std::env::temp_dir()
            .join(format!("mlxrs-emb-encode-pad-tok-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        hf.save(dir.join("tokenizer.json"), false).unwrap();
        dir
    });
    Tokenizer::from_path(dir, None).unwrap()
}
/// Texts of unequal length are right-padded to the longest one (3 tokens) with the supplied
/// pad id (7). The padded cell gets mask value 0.
#[test]
fn tokenize_and_pad_builds_right_padded_ids_and_mask() {
    let tokenizer = word_tokenizer();
    let (mut token_ids, mut attn, width) =
        tokenize_and_pad(&tokenizer, &["a b c", "d e"], false, None, 7).unwrap();
    assert_eq!(width, 3);
    assert_eq!(token_ids.shape(), vec![2, 3]);
    assert_eq!(attn.shape(), vec![2, 3]);
    let expected_ids: Vec<i32> = vec![0, 1, 2, 3, 4, 7];
    assert_eq!(token_ids.to_vec::<i32>().unwrap(), expected_ids);
    let expected_mask: Vec<f32> = vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0];
    assert_eq!(attn.to_vec::<f32>().unwrap(), expected_mask);
}
/// With `max_length = Some(2)`, every row is cut to two tokens, so no pad cells appear.
#[test]
fn tokenize_and_pad_truncates_to_max_length() {
    let tokenizer = word_tokenizer();
    let (mut token_ids, mut attn, width) =
        tokenize_and_pad(&tokenizer, &["a b c", "d e"], false, Some(2), 0).unwrap();
    assert_eq!(width, 2);
    let ids = token_ids.to_vec::<i32>().unwrap();
    assert_eq!(ids, vec![0, 1, 3, 4]);
    let mask = attn.to_vec::<f32>().unwrap();
    assert_eq!(mask, vec![1.0; 4]);
}
/// Mean pooling followed by L2 normalization. Both rows' means have equal components, so each
/// row normalizes to `[1/sqrt(2), 1/sqrt(2)]`.
#[test]
fn encode_mean_pool_normalized_two_text_batch() {
    let tokenizer = word_tokenizer();
    let model = MockEmbeddingModel::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Mean)
        .with_normalize(true);
    let mut embeddings = encode(&model, &tokenizer, &["a b c", "d e"], &config).unwrap();
    assert_eq!(embeddings.shape(), vec![2, 2]);
    let flat = embeddings.to_vec::<f32>().unwrap();
    let unit = 2.0_f32.sqrt().recip();
    let expected = [unit, unit];
    assert!(vclose(&flat[0..2], &expected));
    assert!(vclose(&flat[2..4], &expected));
}
/// Without normalization the raw means are visible. The two-token text must average to 0.5,
/// meaning its pad position contributes nothing to the mean.
#[test]
fn encode_mean_pool_unnormalized_excludes_padding() {
    let tokenizer = word_tokenizer();
    let model = MockEmbeddingModel::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Mean)
        .with_normalize(false);
    let mut embeddings = encode(&model, &tokenizer, &["a b c", "d e"], &config).unwrap();
    let flat = embeddings.to_vec::<f32>().unwrap();
    let three_token_mean: [f32; 2] = [2.0 / 3.0, 2.0 / 3.0];
    assert!(vclose(&flat[0..2], &three_token_mean));
    let two_token_mean: [f32; 2] = [0.5, 0.5];
    assert!(vclose(&flat[2..4], &two_token_mean));
}
/// CLS pooling with no model-supplied pooled output takes the first position of each row.
/// Both rows are expected to equal the mock's first canned vector `[9, 3]`.
#[test]
fn encode_cls_pool_selects_first_real_token() {
    let tokenizer = word_tokenizer();
    let model = MockEmbeddingModel::new(vec![vec![9.0, 3.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let mut pooled = encode(&model, &tokenizer, &["a b c", "d e"], &config).unwrap();
    assert_eq!(pooled.shape(), vec![2, 2]);
    let flat = pooled.to_vec::<f32>().unwrap();
    for row in [&flat[0..2], &flat[2..4]] {
        assert!(vclose(row, &[9.0, 3.0]));
    }
}
/// Padding added by the tokenizer itself (fixed length 4) must be removed. A single
/// three-token text yields width 3 and an all-ones mask.
#[test]
fn tokenize_and_pad_strips_tokenizer_applied_padding() {
    let padded = padded_word_tokenizer();
    let (mut token_ids, mut attn, width) =
        tokenize_and_pad(&padded, &["a b c"], false, None, 0).unwrap();
    assert_eq!(width, 3, "pad cells must be stripped, not counted");
    assert_eq!(token_ids.shape(), vec![1, 3]);
    let ids = token_ids.to_vec::<i32>().unwrap();
    assert_eq!(ids, vec![0, 1, 2]);
    let mask = attn.to_vec::<f32>().unwrap();
    assert_eq!(mask, vec![1.0; 3]);
}
/// A tokenizer that pads on its own must produce exactly the same ids, mask and width as an
/// unpadded one.
///
/// The padded tokenizer pads with `e` (id 4), and `e` is also a real token in `"d e"`. The
/// expected mask keeps that real `e`, so it must not be mistaken for padding.
#[test]
fn tokenize_and_pad_padded_tokenizer_matches_unpadded() {
    let plain_tok = word_tokenizer();
    let padded_tok = padded_word_tokenizer();
    let (mut plain_ids, mut plain_mask, plain_width) =
        tokenize_and_pad(&plain_tok, &["a b c", "d e"], false, None, 7).unwrap();
    let (mut pad_ids, mut pad_mask, pad_width) =
        tokenize_and_pad(&padded_tok, &["a b c", "d e"], false, None, 7).unwrap();
    assert_eq!(plain_width, pad_width);
    let plain_id_vals = plain_ids.to_vec::<i32>().unwrap();
    let pad_id_vals = pad_ids.to_vec::<i32>().unwrap();
    assert_eq!(plain_id_vals, pad_id_vals);
    let plain_mask_vals = plain_mask.to_vec::<f32>().unwrap();
    let pad_mask_vals = pad_mask.to_vec::<f32>().unwrap();
    assert_eq!(plain_mask_vals, pad_mask_vals);
    assert_eq!(pad_mask_vals, vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0]);
}
/// End-to-end check: mean-pooled embeddings are identical whether or not the tokenizer applies
/// its own padding, and they equal the known unpadded means.
#[test]
fn encode_mean_pool_invariant_to_tokenizer_padding() {
    let canned = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let (plain_model, padded_model) =
        (MockEmbeddingModel::new(canned.clone()), MockEmbeddingModel::new(canned));
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Mean)
        .with_normalize(false);
    let mut plain = encode(&plain_model, &word_tokenizer(), &["a b c", "d e"], &config).unwrap();
    let mut padded =
        encode(&padded_model, &padded_word_tokenizer(), &["a b c", "d e"], &config).unwrap();
    assert_eq!(plain.shape(), padded.shape());
    let plain_vals = plain.to_vec::<f32>().unwrap();
    let padded_vals = padded.to_vec::<f32>().unwrap();
    assert!(
        vclose(&plain_vals, &padded_vals),
        "padded={:?} unpadded={:?}",
        padded_vals,
        plain_vals
    );
    assert!(vclose(&padded_vals[0..2], &[2.0 / 3.0, 2.0 / 3.0]));
    assert!(vclose(&padded_vals[2..4], &[0.5, 0.5]));
}
/// When the model provides a pooled output, CLS pooling must return its rows
/// (`[7, 5]`, `[6, 4]`) instead of the first hidden-state position (`[9, 3]`).
#[test]
fn encode_cls_uses_model_pooled_output_when_present() {
    let tokenizer = word_tokenizer();
    let hidden = vec![vec![9.0, 3.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let model = MockEmbeddingModel::new(hidden).with_pooled(vec![vec![7.0, 5.0], vec![6.0, 4.0]]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let mut pooled = encode(&model, &tokenizer, &["a b c", "d e"], &config).unwrap();
    assert_eq!(pooled.shape(), vec![2, 2]);
    let flat = pooled.to_vec::<f32>().unwrap();
    let first = &flat[0..2];
    assert!(vclose(first, &[7.0, 5.0]), "expected pooled row 0, got {:?}", first);
    let second = &flat[2..4];
    assert!(vclose(second, &[6.0, 4.0]), "expected pooled row 1, got {:?}", second);
}
/// A model-supplied pooled output must still go through the configured post-processing.
/// L2 normalization turns `[3, 4]` into `[0.6, 0.8]`.
///
/// `dimension` is set to the full hidden width (2). The pooled path must accept it, and since no
/// components are dropped, the result is unchanged whichever order truncation and normalization
/// run in.
#[test]
fn encode_cls_pooled_output_applies_normalize_and_dimension() {
    let tok = word_tokenizer();
    let model =
        MockEmbeddingModel::new(vec![vec![9.0, 3.0], vec![0.0, 1.0]]).with_pooled(vec![vec![3.0, 4.0]]);
    let cfg = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(true)
        .with_dimension(Some(2));
    let mut emb = encode(&model, &tok, &["a b", "a b"], &cfg).unwrap();
    assert_eq!(emb.shape(), vec![2, 2]);
    let v = emb.to_vec::<f32>().unwrap();
    assert!(vclose(&v[0..2], &[0.6, 0.8]), "row 0: {:?}", &v[0..2]);
    assert!(vclose(&v[2..4], &[0.6, 0.8]), "row 1: {:?}", &v[2..4]);
}
/// With no pooled output on the model (asserted up front), CLS pooling falls back to the first
/// hidden-state position, `[9, 3]`, for both rows.
#[test]
fn encode_cls_falls_back_to_hidden_states_without_pooled_output() {
    let tokenizer = word_tokenizer();
    let model = MockEmbeddingModel::new(vec![vec![9.0, 3.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
    assert!(model.pooled.is_none());
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let mut pooled = encode(&model, &tokenizer, &["a b c", "d e"], &config).unwrap();
    let flat = pooled.to_vec::<f32>().unwrap();
    let first_position = [9.0, 3.0];
    assert!(vclose(&flat[0..2], &first_position));
    assert!(vclose(&flat[2..4], &first_position));
}
/// Builds a `RawPooledModel`. Its hidden states come from a three-vector mock
/// (`[9, 3]`, `[0, 1]`, `[1, 1]`), and its pooled output is built from `pooled_data` with
/// `pooled_shape`.
fn raw_pooled_model(pooled_data: Vec<f32>, pooled_shape: Vec<usize>) -> RawPooledModel {
    let inner = MockEmbeddingModel::new(vec![vec![9.0, 3.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
    RawPooledModel {
        pooled_data,
        pooled_shape,
        inner,
    }
}
/// Under CLS pooling, a rank-1 pooled output (shape `[2]`) must be rejected with
/// `RankMismatch` reporting rank 1.
#[test]
fn encode_cls_rejects_wrong_rank_pooled_output() {
    let tokenizer = word_tokenizer();
    let rank1 = raw_pooled_model(vec![7.0, 5.0], vec![2]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let failure = encode(&rank1, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    let is_expected = matches!(&failure, Error::RankMismatch(p) if p.actual() == 1);
    assert!(is_expected, "expected RankMismatch(actual=1), got {failure:?}");
}
/// Under `PoolingStrategy::None`, a rank-1 pooled output must also be rejected with
/// `RankMismatch` reporting rank 1.
#[test]
fn encode_none_rejects_wrong_rank_pooled_output() {
    let tokenizer = word_tokenizer();
    let rank1 = raw_pooled_model(vec![7.0, 5.0], vec![2]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::None)
        .with_normalize(false);
    let err = encode(&rank1, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    match err {
        Error::RankMismatch(ref p) if p.actual() == 1 => {}
        other => panic!("expected RankMismatch(actual=1), got {other:?}"),
    }
}
/// Under CLS pooling, a pooled output with one row for a two-text batch must be rejected with
/// `LengthMismatch`.
#[test]
fn encode_cls_rejects_wrong_batch_pooled_output() {
    let tokenizer = word_tokenizer();
    let one_row = raw_pooled_model(vec![7.0, 5.0], vec![1, 2]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let failure = encode(&one_row, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    let is_length_mismatch = matches!(&failure, Error::LengthMismatch(_));
    assert!(is_length_mismatch, "expected LengthMismatch, got {failure:?}");
}
/// Under `PoolingStrategy::None`, a pooled output whose batch size (1) differs from the input
/// batch (2) must be rejected with `LengthMismatch`.
#[test]
fn encode_none_rejects_wrong_batch_pooled_output() {
    let tokenizer = word_tokenizer();
    let one_row = raw_pooled_model(vec![7.0, 5.0], vec![1, 2]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::None)
        .with_normalize(false);
    let err = encode(&one_row, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    match err {
        Error::LengthMismatch(_) => {}
        other => panic!("expected LengthMismatch, got {other:?}"),
    }
}
/// Under CLS pooling, a pooled output of width 3 when the hidden width is 2 must be rejected
/// with `ShapePairMismatch`.
#[test]
fn encode_cls_rejects_wrong_hidden_width_pooled_output() {
    let tokenizer = word_tokenizer();
    let too_wide = raw_pooled_model(vec![7.0, 5.0, 1.0, 6.0, 4.0, 2.0], vec![2, 3]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let failure = encode(&too_wide, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    let is_shape_mismatch = matches!(&failure, Error::ShapePairMismatch(_));
    assert!(is_shape_mismatch, "expected ShapePairMismatch, got {failure:?}");
}
/// Under `PoolingStrategy::None`, a pooled output whose width (3) disagrees with the hidden
/// width (2) must be rejected with `ShapePairMismatch`.
#[test]
fn encode_none_rejects_wrong_hidden_width_pooled_output() {
    let tokenizer = word_tokenizer();
    let too_wide = raw_pooled_model(vec![7.0, 5.0, 1.0, 6.0, 4.0, 2.0], vec![2, 3]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::None)
        .with_normalize(false);
    let err = encode(&too_wide, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    match err {
        Error::ShapePairMismatch(_) => {}
        other => panic!("expected ShapePairMismatch, got {other:?}"),
    }
}
/// Positive control for the rejection tests: a `[2, 2]` raw pooled output is accepted, and its
/// rows are returned unchanged by CLS pooling.
#[test]
fn encode_cls_accepts_correct_shape_raw_pooled_output() {
    let tokenizer = word_tokenizer();
    let well_formed = raw_pooled_model(vec![7.0, 5.0, 6.0, 4.0], vec![2, 2]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let mut pooled = encode(&well_formed, &tokenizer, &["a b c", "d e"], &config).unwrap();
    assert_eq!(pooled.shape(), vec![2, 2]);
    let flat = pooled.to_vec::<f32>().unwrap();
    let (row0, row1) = (&flat[0..2], &flat[2..4]);
    assert!(vclose(row0, &[7.0, 5.0]));
    assert!(vclose(row1, &[6.0, 4.0]));
}
/// Pins the `EncodeConfig::new()` defaults: mean pooling, normalized, special tokens on,
/// `max_length` 512, pad id 0, no dimension, no layer/RMS norm. Also checks that every
/// `with_*` builder round-trips through its accessor, including the `Some(..)` forms of the
/// optional settings.
#[test]
fn encode_config_builders_and_accessors_round_trip() {
    let defaults = EncodeConfig::new();
    assert_eq!(defaults.strategy(), PoolingStrategy::Mean);
    assert!(defaults.normalize());
    assert!(defaults.add_special_tokens());
    assert_eq!(defaults.max_length(), Some(512));
    assert_eq!(defaults.pad_token_id(), 0);
    assert_eq!(defaults.dimension(), None);
    assert!(!defaults.apply_layer_norm());
    assert!(!defaults.apply_rms_norm());

    let custom = EncodeConfig::new()
        .with_strategy(PoolingStrategy::Max)
        .with_normalize(false)
        .with_add_special_tokens(false)
        .with_max_length(None)
        .with_pad_token_id(7)
        .with_dimension(Some(3))
        .with_apply_layer_norm(true)
        .with_apply_rms_norm(true);
    assert_eq!(custom.strategy(), PoolingStrategy::Max);
    assert!(!custom.normalize());
    assert!(!custom.add_special_tokens());
    assert_eq!(custom.max_length(), None);
    assert_eq!(custom.pad_token_id(), 7);
    assert_eq!(custom.dimension(), Some(3));
    assert!(custom.apply_layer_norm());
    assert!(custom.apply_rms_norm());

    let sized = EncodeConfig::new()
        .with_max_length(Some(8))
        .with_dimension(Some(16));
    assert_eq!(sized.max_length(), Some(8));
    assert_eq!(sized.dimension(), Some(16));
}
/// Returns a tokenizer whose single vocabulary entry, `big`, has id 2^31 (one past `i32::MAX`).
/// Unknown words also map to `big`, so every token it produces has an id that does not fit in
/// an `i32`.
///
/// The Hugging Face tokenizer is built and saved to a per-process temp-dir fixture exactly once.
/// Every call then reloads it through `Tokenizer::from_path`.
fn huge_id_tokenizer() -> Tokenizer {
    use tokenizers::{
        Tokenizer as HfTokenizer, models::wordlevel::WordLevel, pre_tokenizers::whitespace::Whitespace,
    };
    static FIXTURE: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();
    let dir = FIXTURE.get_or_init(|| {
        // 2_147_483_648: the smallest u32 id that cannot be represented as an i32.
        let big_id: u32 = i32::MAX as u32 + 1;
        let vocab = [("big".to_string(), big_id)].into_iter().collect();
        let wl = WordLevel::builder()
            .vocab(vocab)
            .unk_token("big".to_string())
            .build()
            .unwrap();
        let mut hf = HfTokenizer::new(wl);
        hf.with_pre_tokenizer(Some(Whitespace {}));
        let dir =
            std::env::temp_dir().join(format!("mlxrs-emb-encode-huge-tok-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        hf.save(dir.join("tokenizer.json"), false).unwrap();
        dir
    });
    Tokenizer::from_path(dir, None).unwrap()
}
/// A token id of 2^31 cannot be stored in the i32 id tensor. `tokenize_and_pad` must report
/// `OutOfRange` with context `encode: token id` and the offending value in decimal.
#[test]
fn tokenize_and_pad_rejects_token_id_above_i32_max() {
    let tokenizer = huge_id_tokenizer();
    let failure = tokenize_and_pad(&tokenizer, &["big"], false, None, 0).unwrap_err();
    let payload = match failure {
        Error::OutOfRange(p) => p,
        other => panic!("expected OutOfRange(token id), got {other:?}"),
    };
    assert_eq!(payload.context(), "encode: token id");
    assert_eq!(payload.value(), "2147483648");
}
/// The same out-of-range token id surfaces through `encode` as `OutOfRange` with context
/// `encode: token id`.
#[test]
fn encode_rejects_token_id_above_i32_max() {
    let tokenizer = huge_id_tokenizer();
    let model = MockEmbeddingModel::new(vec![vec![1.0, 0.0]]);
    let config = EncodeConfig::new().with_add_special_tokens(false);
    let err = encode(&model, &tokenizer, &["big"], &config).unwrap_err();
    let is_token_id_overflow = match &err {
        Error::OutOfRange(p) => p.context() == "encode: token id",
        _ => false,
    };
    assert!(is_token_id_overflow, "expected OutOfRange(token id), got {err:?}");
}
/// A `pad_token_id` of 2^31 cannot be stored in the i32 id tensor. It must be rejected with
/// `OutOfRange` (context `encode: pad_token_id`) even though the single-text batch needs no
/// padding.
#[test]
fn tokenize_and_pad_rejects_pad_token_id_above_i32_max() {
    let tok = word_tokenizer();
    // One past i32::MAX (2_147_483_648): the smallest u32 that does not fit in an i32.
    let bad_pad: u32 = i32::MAX as u32 + 1;
    let err = tokenize_and_pad(&tok, &["a b"], false, None, bad_pad).unwrap_err();
    match err {
        Error::OutOfRange(p) => {
            assert_eq!(p.context(), "encode: pad_token_id");
            assert_eq!(p.value(), "2147483648");
        }
        other => panic!("expected OutOfRange(pad_token_id), got {other:?}"),
    }
}
/// An out-of-range `pad_token_id` set on `EncodeConfig` surfaces through `encode` as
/// `OutOfRange` with context `encode: pad_token_id`.
#[test]
fn encode_rejects_pad_token_id_above_i32_max() {
    let tokenizer = word_tokenizer();
    let model = MockEmbeddingModel::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_pad_token_id(0x8000_0000);
    let err = encode(&model, &tokenizer, &["a b"], &config).unwrap_err();
    let is_pad_overflow = match &err {
        Error::OutOfRange(p) => p.context() == "encode: pad_token_id",
        _ => false,
    };
    assert!(is_pad_overflow, "expected OutOfRange(pad_token_id), got {err:?}");
}
/// Test double that ignores its inputs and returns caller-specified `last_hidden_state` and
/// pooled tensors. It lets tests feed `encode` hidden states of an arbitrary (wrong) rank.
struct RawShapeModel {
    /// Flat f32 contents of the returned `last_hidden_state`.
    hidden_data: Vec<f32>,
    /// Shape used to build `last_hidden_state` from `hidden_data`.
    hidden_shape: Vec<usize>,
    /// Flat f32 contents of the returned pooled output.
    pooled_data: Vec<f32>,
    /// Shape used to build the pooled output from `pooled_data`.
    pooled_shape: Vec<usize>,
}
impl EmbeddingModel for RawShapeModel {
    fn forward(&self, _input_ids: &Array, _attention_mask: &Array) -> Result<EmbeddingModelOutput> {
        let hidden = Array::from_slice::<f32>(&self.hidden_data, &self.hidden_shape.as_slice())?;
        let pooled_output =
            Array::from_slice::<f32>(&self.pooled_data, &self.pooled_shape.as_slice())?;
        let output = EmbeddingModelOutput::new(hidden, Some(pooled_output));
        Ok(output)
    }
}
/// Even when CLS pooling can use a well-formed pooled output, a rank-2 `last_hidden_state` must
/// be rejected with a `RankMismatch` that names the tensor and reports rank 2.
#[test]
fn encode_cls_rejects_non_rank3_hidden_state_in_pooled_path() {
    const EXPECTED_CONTEXT: &str =
        "encode: model last_hidden_state must be rank-3 (batch, seq_len, hidden)";
    let tokenizer = word_tokenizer();
    let model = RawShapeModel {
        hidden_data: vec![1.0, 2.0, 3.0, 4.0],
        hidden_shape: vec![2, 2],
        pooled_data: vec![7.0, 5.0, 6.0, 4.0],
        pooled_shape: vec![2, 2],
    };
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::Cls)
        .with_normalize(false);
    let failure = encode(&model, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    let payload = match failure {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch(last_hidden_state), got {other:?}"),
    };
    assert_eq!(payload.context(), EXPECTED_CONTEXT);
    assert_eq!(payload.actual(), 2);
}
/// Under `PoolingStrategy::None` with a well-formed pooled output, a rank-1
/// `last_hidden_state` must still be rejected with the rank-3 `RankMismatch`, reporting rank 1.
#[test]
fn encode_none_rejects_non_rank3_hidden_state_in_pooled_path() {
    const EXPECTED_CONTEXT: &str =
        "encode: model last_hidden_state must be rank-3 (batch, seq_len, hidden)";
    let tokenizer = word_tokenizer();
    let model = RawShapeModel {
        hidden_data: vec![1.0, 2.0, 3.0, 4.0],
        hidden_shape: vec![4],
        pooled_data: vec![7.0, 5.0, 6.0, 4.0],
        pooled_shape: vec![2, 2],
    };
    let config = EncodeConfig::new()
        .with_add_special_tokens(false)
        .with_strategy(PoolingStrategy::None)
        .with_normalize(false);
    let err = encode(&model, &tokenizer, &["a b c", "d e"], &config).unwrap_err();
    let is_expected = match &err {
        Error::RankMismatch(p) => p.context() == EXPECTED_CONTEXT && p.actual() == 1,
        _ => false,
    };
    assert!(
        is_expected,
        "expected RankMismatch(last_hidden_state, actual=1), got {err:?}"
    );
}