use super::*;
/// An empty query list must be refused with the Python-parity message.
#[test]
fn score_single_vector_rejects_empty_queries() {
    let passage = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3,)).unwrap();
    let passages = std::slice::from_ref(&passage);
    let msg = score_single_vector(&[], passages).unwrap_err().to_string();
    assert!(
        msg.contains("No queries provided"),
        "expected python parity msg, got {msg}"
    );
}
/// An empty passage list must be refused with the Python-parity message.
#[test]
fn score_single_vector_rejects_empty_passages() {
    let query = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3,)).unwrap();
    let queries = std::slice::from_ref(&query);
    let msg = score_single_vector(queries, &[]).unwrap_err().to_string();
    assert!(
        msg.contains("No passages provided"),
        "expected python parity msg, got {msg}"
    );
}
/// Two queries against two passages give a (B, C) = (2, 2) f32 matrix of
/// plain dot products, laid out row-major by query.
#[test]
fn score_single_vector_dot_product_shape_and_values() {
    let vec3 = |xs: [f32; 3]| Array::from_slice::<f32>(&xs, &(3,)).unwrap();
    let queries = [vec3([1.0, 0.0, 0.0]), vec3([0.0, 1.0, 0.0])];
    let passages = [vec3([1.0, 1.0, 1.0]), vec3([2.0, 0.0, 2.0])];

    let mut scores =
        score_single_vector(&queries, &passages).expect("score_single_vector should succeed");

    assert_eq!(scores.shape(), vec![2, 2], "shape (B,C) = (2,2)");
    assert_eq!(scores.dtype().unwrap(), Dtype::F32, "f32 cast (python L63)");
    // Row 0: e0·[1,1,1] = 1, e0·[2,0,2] = 2; row 1: e1·[1,1,1] = 1, e1·[2,0,2] = 0.
    let flat = scores.to_vec::<f32>().unwrap();
    assert_eq!(flat, vec![1.0, 2.0, 1.0, 0.0]);
}
/// A zero-length query vector is rejected with `OutOfRange`, and the message
/// names the `queries` argument.
#[test]
fn score_single_vector_rejects_zero_token_query() {
    let empty_query = Array::from_slice::<f32>(&[], &(0,)).unwrap();
    let passage = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3,)).unwrap();
    let err = score_single_vector(
        std::slice::from_ref(&empty_query),
        std::slice::from_ref(&passage),
    )
    .unwrap_err();

    let is_out_of_range = matches!(err, Error::OutOfRange(_));
    assert!(is_out_of_range, "expected OutOfRange, got {err:?}");

    let msg = err.to_string();
    assert!(
        msg.contains("queries"),
        "expected 'queries' in message, got {msg}"
    );
}
/// A zero-length passage vector is rejected with `OutOfRange`, and the message
/// names the `passages` argument.
#[test]
fn score_single_vector_rejects_zero_token_passage() {
    let query = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3,)).unwrap();
    let empty_passage = Array::from_slice::<f32>(&[], &(0,)).unwrap();
    let err = score_single_vector(
        std::slice::from_ref(&query),
        std::slice::from_ref(&empty_passage),
    )
    .unwrap_err();

    let is_out_of_range = matches!(err, Error::OutOfRange(_));
    assert!(is_out_of_range, "expected OutOfRange, got {err:?}");

    let msg = err.to_string();
    assert!(
        msg.contains("passages"),
        "expected 'passages' in message, got {msg}"
    );
}
/// f16 inputs must still produce an f32 score matrix, and the cast must keep
/// the value intact. The expected score is [1, 0]·[1, 1] = 1, which is exact
/// in both f16 and f32.
#[test]
fn score_single_vector_casts_result_to_f32_from_f16() {
    let q0 = Array::from_slice::<f32>(&[1.0, 0.0], &(2,))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let p0 = Array::from_slice::<f32>(&[1.0, 1.0], &(2,))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let mut scores = score_single_vector(&[q0], &[p0]).unwrap();
    assert_eq!(scores.shape(), vec![1, 1]);
    assert_eq!(
        scores.dtype().unwrap(),
        Dtype::F32,
        "result must be f32 even with f16 inputs"
    );
    // A dtype-only check would also pass if the cast produced garbage values.
    assert_eq!(scores.to_vec::<f32>().unwrap(), vec![1.0]);
}
/// Multi-vector scoring refuses an empty query list.
#[test]
fn score_multi_vector_rejects_empty_queries() {
    let passage = Array::from_slice::<f32>(&[1.0; 4], &(2, 2)).unwrap();
    let rendered = score_multi_vector(&[], std::slice::from_ref(&passage), 128)
        .unwrap_err()
        .to_string();
    assert!(rendered.contains("No queries provided"));
}
/// Multi-vector scoring refuses an empty passage list.
#[test]
fn score_multi_vector_rejects_empty_passages() {
    let query = Array::from_slice::<f32>(&[1.0; 4], &(2, 2)).unwrap();
    let rendered = score_multi_vector(std::slice::from_ref(&query), &[], 128)
        .unwrap_err()
        .to_string();
    assert!(rendered.contains("No passages provided"));
}
/// A `batch_size` of zero is rejected, and the error mentions the parameter.
#[test]
fn score_multi_vector_rejects_zero_batch_size() {
    let ones = [1.0; 4];
    let query = Array::from_slice::<f32>(&ones, &(2, 2)).unwrap();
    let passage = Array::from_slice::<f32>(&ones, &(2, 2)).unwrap();
    let rendered = score_multi_vector(
        std::slice::from_ref(&query),
        std::slice::from_ref(&passage),
        0,
    )
    .unwrap_err()
    .to_string();
    assert!(rendered.contains("batch_size"));
}
/// A zero-token query is rejected with `OutOfRange`, and the message gives
/// its position in the caller's query list.
#[test]
fn score_multi_vector_rejects_zero_token_query() {
    let empty_query = Array::from_slice::<f32>(&[], &(0, 2)).unwrap();
    let passage = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let outcome = score_multi_vector(
        std::slice::from_ref(&empty_query),
        std::slice::from_ref(&passage),
        128,
    );
    let err = outcome.unwrap_err();

    assert!(
        matches!(err, Error::OutOfRange(_)),
        "expected OutOfRange from score_multi_vector pre-validation, got {err:?}"
    );

    let msg = err.to_string();
    let names_offender = msg.contains("queries") && msg.contains("index 0");
    assert!(
        names_offender,
        "expected 'queries' + 'index 0' in message, got {msg}"
    );
}
/// A zero-token passage that shares a tile with a non-empty one must be
/// rejected as an `OutOfRange` error. It must not flow into the computation
/// and show up as a -inf score.
#[test]
fn score_multi_vector_rejects_zero_token_passage_in_mixed_tile() {
    let query = Array::from_slice::<f32>(&[1.0, 0.0], &(1, 2)).unwrap();
    let passages = [
        Array::from_slice::<f32>(&[], &(0, 2)).unwrap(),
        Array::from_slice::<f32>(&[2.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap(),
    ];
    // batch_size = 2 puts both passages in the same tile.
    let err = score_multi_vector(std::slice::from_ref(&query), &passages, 2).unwrap_err();

    assert!(
        matches!(err, Error::OutOfRange(_)),
        "expected OutOfRange from pre-validation (NOT -inf propagation), got {err:?}"
    );

    let msg = err.to_string();
    let names_offender = msg.contains("passages") && msg.contains("index 0");
    assert!(
        names_offender,
        "expected 'passages' + 'index 0' in message, got {msg}"
    );
}
/// With batch_size = 2, the empty query at global index 3 lands at local
/// index 1 of the second tile. The error must report the global index and
/// must not leak the tile-local one.
#[test]
fn score_multi_vector_rejects_zero_token_query_at_non_zero_global_index() {
    let row = |xs: [f32; 2]| Array::from_slice::<f32>(&xs, &(1, 2)).unwrap();
    let qs = vec![
        row([1.0, 0.0]),
        row([0.0, 1.0]),
        row([1.0, 1.0]),
        Array::from_slice::<f32>(&[], &(0, 2)).unwrap(),
    ];
    let passage = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();

    let err = match score_multi_vector(&qs, std::slice::from_ref(&passage), 2) {
        Ok(_) => panic!("expected OutOfRange, got Ok"),
        Err(e) => e,
    };

    assert!(
        matches!(err, Error::OutOfRange(_)),
        "expected OutOfRange, got {err:?}"
    );

    let msg = err.to_string();
    let reports_global = msg.contains("queries") && msg.contains("index 3");
    assert!(
        reports_global,
        "expected 'queries' + global index 3, got: {msg}"
    );
    let leaks_local = msg.contains("index 1") || msg.contains("array 1");
    assert!(!leaks_local, "tile-local index leaked: {msg}");
}
/// With batch_size = 2, the empty passage at global index 3 lands at local
/// index 1 of the second tile. The error must report the global index and
/// must not leak the tile-local one. This mirrors the query-side test.
#[test]
fn score_multi_vector_rejects_zero_token_passage_at_non_zero_global_index() {
    let q = Array::from_slice::<f32>(&[1.0, 0.0], &(1, 2)).unwrap();
    let p0 = Array::from_slice::<f32>(&[1.0, 0.0], &(1, 2)).unwrap();
    let p1 = Array::from_slice::<f32>(&[0.0, 1.0], &(1, 2)).unwrap();
    let p2 = Array::from_slice::<f32>(&[1.0, 1.0], &(1, 2)).unwrap();
    let p_empty = Array::from_slice::<f32>(&[], &(0, 2)).unwrap();
    let err = score_multi_vector(std::slice::from_ref(&q), &[p0, p1, p2, p_empty], 2).unwrap_err();
    assert!(
        matches!(err, Error::OutOfRange(_)),
        "expected OutOfRange from pre-validation, got {err:?}"
    );
    let msg = format!("{err}");
    assert!(
        msg.contains("passages") && msg.contains("index 3"),
        "expected 'passages' + global 'index 3' (not tile-local 'array 1') in message, got {msg}"
    );
    // The message above claims the tile-local index is absent; enforce it
    // explicitly, as the query-side test does.
    assert!(
        !msg.contains("index 1") && !msg.contains("array 1"),
        "tile-local index leaked: {msg}"
    );
}
/// A 2-token identity query scored against itself gives MaxSim = 1 + 1 = 2.
#[test]
fn score_multi_vector_identity_pair() {
    let identity = || Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let query = identity();
    let passage = identity();

    let mut scores = score_multi_vector(&[query], &[passage], 128).unwrap();

    assert_eq!(scores.shape(), vec![1, 1]);
    assert_eq!(scores.dtype().unwrap(), Dtype::F32);
    assert_eq!(scores.to_vec::<f32>().unwrap(), vec![2.0]);
}
/// Queries and passages with different token counts, scored one per tile
/// (batch_size = 1).
#[test]
fn score_multi_vector_ragged_n_and_s_with_batching() {
    let one_token = || Array::from_slice::<f32>(&[1.0, 0.0], &(1, 2)).unwrap();
    let two_tokens = || Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();

    let queries = [one_token(), two_tokens()];
    let passages = [two_tokens(), one_token()];
    let mut scores = score_multi_vector(&queries, &passages, 1).unwrap();

    assert_eq!(scores.shape(), vec![2, 2]);
    assert_eq!(scores.dtype().unwrap(), Dtype::F32);
    // Sum over query tokens of the best-matching passage token:
    // q0: [1, 1]; q1: [1 + 1, 1 + 0].
    assert_eq!(scores.to_vec::<f32>().unwrap(), vec![1.0, 1.0, 2.0, 1.0]);
}
/// The ragged fixture scored with batch_size = 128 (one tile holds
/// everything) must give the same matrix as the batch_size = 1 case.
#[test]
fn score_multi_vector_default_batch_size_matches_tiled() {
    let one_token = || Array::from_slice::<f32>(&[1.0, 0.0], &(1, 2)).unwrap();
    let two_tokens = || Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();

    let queries = [one_token(), two_tokens()];
    let passages = [two_tokens(), one_token()];
    let mut scores = score_multi_vector(&queries, &passages, 128).unwrap();

    assert_eq!(scores.shape(), vec![2, 2]);
    assert_eq!(scores.to_vec::<f32>().unwrap(), vec![1.0, 1.0, 2.0, 1.0]);
}
/// f16 multi-vector inputs must still produce an f32 score, and the cast must
/// keep the value intact. An identity pair gives MaxSim = 1 + 1 = 2, which is
/// exact in both f16 and f32.
#[test]
fn score_multi_vector_casts_result_to_f32_from_f16() {
    let q = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let p = Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let mut scores = score_multi_vector(&[q], &[p], 128).unwrap();
    assert_eq!(scores.shape(), vec![1, 1]);
    assert_eq!(scores.dtype().unwrap(), Dtype::F32);
    // A dtype-only check would also pass if the cast produced garbage values.
    assert_eq!(scores.to_vec::<f32>().unwrap(), vec![2.0]);
}
/// A 1-token and a 2-token array are zero-padded to 2 tokens and stacked into
/// shape (2, 2, 2).
#[test]
fn pad_to_max_pads_ragged_then_stacks() {
    let short = Array::from_slice::<f32>(&[1.0, 2.0], &(1, 2)).unwrap();
    let long = Array::from_slice::<f32>(&[3.0, 4.0, 5.0, 6.0], &(2, 2)).unwrap();

    let (mut padded, _lens) = pad_to_max(&[short, long]).unwrap();

    assert_eq!(padded.shape(), vec![2, 2, 2]);
    // The short array gets one all-zero padding row.
    let flat = padded.to_vec::<f32>().unwrap();
    assert_eq!(flat, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 6.0]);
}
/// Padding nothing is an error whose message mentions "empty".
#[test]
fn pad_to_max_rejects_empty_slice() {
    let rendered = pad_to_max(&[]).unwrap_err().to_string();
    assert!(rendered.contains("empty"));
}
/// A rank-1 input is rejected with a message asking for rank-2 arrays.
#[test]
fn pad_to_max_rejects_non_rank_2() {
    let rank_one = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &(3,)).unwrap();
    let rendered = pad_to_max(std::slice::from_ref(&rank_one))
        .unwrap_err()
        .to_string();
    assert!(rendered.contains("rank-2"));
}
/// Arrays whose trailing (embedding) dimensions differ cannot be stacked.
#[test]
fn pad_to_max_rejects_mismatched_emb_dim() {
    let dim_two = Array::from_slice::<f32>(&[1.0, 2.0], &(1, 2)).unwrap();
    let dim_three = Array::from_slice::<f32>(&[3.0, 4.0, 5.0], &(1, 3)).unwrap();
    let rendered = pad_to_max(&[dim_two, dim_three]).unwrap_err().to_string();
    assert!(rendered.contains("emb_dim"));
}
/// A zero-token array is rejected with `OutOfRange`, and the message gives
/// its position in the input slice, both at index 0 and at a later index.
#[test]
fn pad_to_max_rejects_zero_token_array() {
    // Case 1: the only input is empty.
    let lone_empty = Array::from_slice::<f32>(&[], &(0, 2)).unwrap();
    let err = pad_to_max(std::slice::from_ref(&lone_empty)).unwrap_err();
    assert!(
        matches!(err, Error::OutOfRange(_)),
        "expected OutOfRange, got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("index 0"),
        "expected 'index 0' in message, got {msg}"
    );

    // Case 2: the empty array follows a valid one.
    let valid = Array::from_slice::<f32>(&[1.0, 2.0], &(1, 2)).unwrap();
    let trailing_empty = Array::from_slice::<f32>(&[], &(0, 2)).unwrap();
    let err2 = pad_to_max(&[valid, trailing_empty]).unwrap_err();
    let msg2 = err2.to_string();
    assert!(
        msg2.contains("index 1"),
        "expected 'index 1' in message, got {msg2}"
    );
}
/// Padding f16 arrays keeps f16; the padded stack is not upcast.
#[test]
fn pad_to_max_preserves_f16_dtype() {
    let short = Array::from_slice::<f32>(&[1.0, 2.0], &(1, 2))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let long = Array::from_slice::<f32>(&[3.0, 4.0, 5.0, 6.0], &(2, 2))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();

    let (padded, _lens) = pad_to_max(&[short, long]).unwrap();

    assert_eq!(padded.shape(), vec![2, 2, 2]);
    let padded_dtype = padded.dtype().unwrap();
    assert_eq!(
        padded_dtype,
        Dtype::F16,
        "padding must preserve input dtype (python L87 `dtype=a.dtype`)"
    );
}
/// The second return value lists each input's token count before padding,
/// in input order.
#[test]
fn pad_to_max_returns_original_lengths() {
    let one = Array::from_slice::<f32>(&[1.0, 2.0], &(1, 2)).unwrap();
    let two = Array::from_slice::<f32>(&[3.0, 4.0, 5.0, 6.0], &(2, 2)).unwrap();
    let three = Array::from_slice::<f32>(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &(3, 2)).unwrap();

    let (padded, lens) = pad_to_max(&[one, two, three]).unwrap();

    assert_eq!(padded.shape(), vec![3, 3, 2], "stacked to (3, max_n=3, 2)");
    assert_eq!(
        lens,
        vec![1, 2, 3],
        "original_lengths must mirror input order"
    );
}
/// Regression test for padding leaking into MaxSim. When a 1-token passage
/// with negative similarity is tiled next to a 2-token passage, its padded
/// zero row must not win the max (which would give 0.0 instead of -1.0).
/// Scores must therefore be the same for batch_size 1 and 2.
#[test]
fn score_multi_vector_ragged_negative_similarity_batch_size_agnostic() {
    let query = Array::from_slice::<f32>(&[1.0, 0.0], &(1, 2)).unwrap();
    let short_passage = Array::from_slice::<f32>(&[-1.0, 0.0], &(1, 2)).unwrap();
    let long_passage = Array::from_slice::<f32>(&[2.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let queries = std::slice::from_ref(&query);

    // batch_size = 1: each passage is scored in its own tile, so no padding.
    let untiled_passages = [
        short_passage.try_clone().unwrap(),
        long_passage.try_clone().unwrap(),
    ];
    let mut untiled = score_multi_vector(queries, &untiled_passages, 1).unwrap();
    assert_eq!(untiled.shape(), vec![1, 2]);
    let v_b1 = untiled.to_vec::<f32>().unwrap();

    // batch_size = 2: both passages share one tile, so the short one is padded.
    let mut tiled = score_multi_vector(queries, &[short_passage, long_passage], 2).unwrap();
    assert_eq!(tiled.shape(), vec![1, 2]);
    let v_tiled = tiled.to_vec::<f32>().unwrap();

    assert_eq!(
        v_b1[0], -1.0,
        "p0 alone: <q,p0_0> = -1.0; sum over the single query token = -1.0"
    );
    assert_eq!(
        v_tiled[0], -1.0,
        "p0 tiled with p1: padded zero column must be masked → -1.0, not 0.0"
    );
    assert_eq!(
        v_b1[1], v_tiled[1],
        "p1 score must be tile-invariant in both branches"
    );
    assert_eq!(
        v_b1, v_tiled,
        "score_multi_vector ranking must not depend on batch_size"
    );
}
struct TestProcessor;
impl BaseColVisionProcessor for TestProcessor {
fn process_images(&self, images: &[Vec<u8>]) -> Result<ProcessorBatch> {
let mut batch = ProcessorBatch::new();
let count = i32::try_from(images.len()).unwrap_or(0);
batch.insert(
"pixel_values_count".into(),
Array::from_slice::<i32>(&[count], &(1,))?,
);
Ok(batch)
}
fn process_queries(
&self,
queries: &[&str],
_max_length: usize,
_suffix: Option<&str>,
) -> Result<ProcessorBatch> {
let mut batch = ProcessorBatch::new();
let count = i32::try_from(queries.len()).unwrap_or(0);
batch.insert(
"input_ids_count".into(),
Array::from_slice::<i32>(&[count], &(1,))?,
);
Ok(batch)
}
fn score(&self, qs: &[Array], ps: &[Array], batch_size: usize) -> Result<Array> {
score_multi_vector(qs, ps, batch_size)
}
}
/// Calls every trait method on `TestProcessor`: image and query processing
/// produce their count keys, and `score` matches the identity-pair result.
#[test]
fn base_processor_trait_impl_round_trips() {
    let processor = TestProcessor;

    let images = vec![vec![0u8, 1, 2], vec![3u8, 4, 5]];
    let image_batch = processor.process_images(&images).unwrap();
    assert!(image_batch.contains_key("pixel_values_count"));

    let queries = vec!["query one", "query two", "query three"];
    let query_batch = processor.process_queries(&queries, 50, None).unwrap();
    assert!(query_batch.contains_key("input_ids_count"));

    let identity = || Array::from_slice::<f32>(&[1.0, 0.0, 0.0, 1.0], &(2, 2)).unwrap();
    let query_emb = identity();
    let passage_emb = identity();
    let mut scores = processor.score(&[query_emb], &[passage_emb], 128).unwrap();
    assert_eq!(scores.shape(), vec![1, 1]);
    assert_eq!(scores.to_vec::<f32>().unwrap(), vec![2.0]);
}