use crate::multivector::MultiVectorEmbedding;
use crate::Result;
/// An embedding model that represents a text as a sequence of per-token
/// vectors packed into a single [`MultiVectorEmbedding`].
pub trait MultiVectorEmbedder: Send + Sync {
    /// Embeds one text, producing one vector per token.
    fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding>;

    /// Embeds several texts, returning results in the same order as `texts`.
    ///
    /// The default implementation calls [`embed_tokens`](Self::embed_tokens)
    /// on each text in turn and stops at the first error.
    fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
        texts
            .iter()
            .copied()
            .map(|single| self.embed_tokens(single))
            .collect()
    }

    /// Length of each per-token vector.
    fn token_dimension(&self) -> usize;

    /// Token limit reported by the implementation (the mock in this module
    /// truncates input beyond this many tokens).
    fn max_tokens(&self) -> usize;

    /// Identifier of the underlying model.
    fn model_id(&self) -> &str;
}
/// Deterministic, dependency-free multi-vector embedder intended for tests.
///
/// Each token is mapped to a pseudo-random unit vector derived from the
/// embedder's seed, the token's bytes and the token's position, so identical
/// inputs always yield identical embeddings.
#[derive(Debug, Clone)]
pub struct MockMultiVectorEmbedder {
    /// Length of every per-token vector.
    dim: usize,
    /// Upper bound on the number of tokens embedded per text.
    max_tokens: usize,
    /// Starting value of every token hash.
    seed: u64,
}

impl MockMultiVectorEmbedder {
    /// Seed used by [`MockMultiVectorEmbedder::new`].
    const DEFAULT_SEED: u64 = 42;
    /// Multiplier of the 64-bit linear congruential generator (Knuth's MMIX constant).
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// 2^24: integers below this convert to `f32` exactly.
    const UNIT_SCALE: f32 = 16_777_216.0;

    /// Creates a mock embedder using the default seed (`42`).
    #[must_use]
    pub fn new(dim: usize, max_tokens: usize) -> Self {
        Self::with_seed(dim, max_tokens, Self::DEFAULT_SEED)
    }

    /// Creates a mock embedder with an explicit seed, which is folded into
    /// every token hash.
    #[must_use]
    pub fn with_seed(dim: usize, max_tokens: usize, seed: u64) -> Self {
        Self { dim, max_tokens, seed }
    }

    /// Generates a `dim`-length vector from `seed` and scales it to unit L2 length.
    ///
    /// Components are pseudo-random values in `[-1, 1]` before normalization.
    /// If the norm is zero (e.g. `dim == 0`) the vector is returned unscaled.
    fn generate_unit_vector(&self, seed: u64) -> Vec<f32> {
        let mut state = seed;
        let mut vec: Vec<f32> = (0..self.dim)
            .map(|_| {
                state = state.wrapping_mul(Self::LCG_MULTIPLIER).wrapping_add(1);
                // Keep the 24 high (best-quality) LCG bits: exactly representable
                // in f32, giving `unit` in [0, 1) and hence a component in [-1, 1].
                // (Shifting by 33 and dividing by u32::MAX would cap `unit` below
                // 0.5 and make every component negative.)
                let unit = (state >> 40) as f32 / Self::UNIT_SCALE;
                unit * 2.0 - 1.0
            })
            .collect();

        let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            vec.iter_mut().for_each(|v| *v /= norm);
        }
        vec
    }

    /// Base-31 polynomial hash of `token`'s bytes starting from the embedder's
    /// seed, with the token position `index` mixed in last so that the same word
    /// at different positions receives a different seed.
    fn hash_token(&self, token: &str, index: usize) -> u64 {
        let hash = token
            .bytes()
            .fold(self.seed, |acc, byte| acc.wrapping_mul(31).wrapping_add(u64::from(byte)));
        hash.wrapping_mul(31).wrapping_add(index as u64)
    }
}
impl MultiVectorEmbedder for MockMultiVectorEmbedder {
    /// Embeds the first `max_tokens` whitespace-separated tokens of `text`.
    ///
    /// Each token's vector is seeded from the embedder seed, the token's bytes
    /// and its zero-based position. Empty or whitespace-only input (or
    /// `max_tokens == 0`) yields an embedding with zero tokens. This mock never
    /// returns `Err`.
    fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding> {
        // Stop tokenizing at `max_tokens` rather than splitting the whole text
        // and discarding the tail.
        let tokens: Vec<&str> = text.split_whitespace().take(self.max_tokens).collect();
        let num_tokens = tokens.len();
        if num_tokens == 0 {
            return Ok(MultiVectorEmbedding::new(Vec::new(), 0, self.dim));
        }

        // Token vectors are laid out back to back: `num_tokens * dim` floats.
        let mut embeddings = Vec::with_capacity(num_tokens * self.dim);
        for (i, token) in tokens.iter().enumerate() {
            embeddings.extend(self.generate_unit_vector(self.hash_token(token, i)));
        }
        Ok(MultiVectorEmbedding::new(embeddings, num_tokens, self.dim))
    }

    // `embed_tokens_batch` is inherited from the trait's default implementation,
    // which calls `embed_tokens` on each text in order.

    fn token_dimension(&self) -> usize {
        self.dim
    }

    fn max_tokens(&self) -> usize {
        self.max_tokens
    }

    fn model_id(&self) -> &str {
        "mock-multivector"
    }
}
/// Forwards every method to the boxed embedder, so `Box<dyn MultiVectorEmbedder>`
/// (or any boxed implementor) is itself a `MultiVectorEmbedder`.
impl<E: MultiVectorEmbedder + ?Sized> MultiVectorEmbedder for Box<E> {
    // Each call names the inner type's impl explicitly; `&Box<E>` deref-coerces
    // to `&E`, so this never recurses into the `Box` impl.
    fn embed_tokens(&self, text: &str) -> Result<MultiVectorEmbedding> {
        <E as MultiVectorEmbedder>::embed_tokens(self, text)
    }

    fn embed_tokens_batch(&self, texts: &[&str]) -> Result<Vec<MultiVectorEmbedding>> {
        <E as MultiVectorEmbedder>::embed_tokens_batch(self, texts)
    }

    fn token_dimension(&self) -> usize {
        <E as MultiVectorEmbedder>::token_dimension(self)
    }

    fn max_tokens(&self) -> usize {
        <E as MultiVectorEmbedder>::max_tokens(self)
    }

    fn model_id(&self) -> &str {
        <E as MultiVectorEmbedder>::model_id(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_mock_embedder_new() {
        let embedder = MockMultiVectorEmbedder::new(128, 512);
        assert_eq!(embedder.token_dimension(), 128);
        assert_eq!(embedder.max_tokens(), 512);
        assert_eq!(embedder.model_id(), "mock-multivector");
    }

    #[test]
    fn test_mock_embedder_with_seed() {
        let embedder1 = MockMultiVectorEmbedder::with_seed(128, 512, 123);
        let embedder2 = MockMultiVectorEmbedder::with_seed(128, 512, 456);
        let emb1 = embedder1.embed_tokens("test").unwrap();
        let emb2 = embedder2.embed_tokens("test").unwrap();
        assert_ne!(emb1.as_slice(), emb2.as_slice());
    }

    #[test]
    fn test_mock_embedder_deterministic() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let emb1 = embedder.embed_tokens("hello world").unwrap();
        let emb2 = embedder.embed_tokens("hello world").unwrap();
        assert_eq!(emb1.num_tokens(), emb2.num_tokens());
        assert_eq!(emb1.as_slice(), emb2.as_slice());
    }

    #[test]
    fn test_mock_embedder_token_count() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let emb = embedder.embed_tokens("one two three four five").unwrap();
        assert_eq!(emb.num_tokens(), 5);
        assert_eq!(emb.dim(), 64);
    }

    #[test]
    fn test_mock_embedder_max_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 3);
        let emb = embedder.embed_tokens("one two three four five six").unwrap();
        // Only the first `max_tokens` tokens are embedded.
        assert_eq!(emb.num_tokens(), 3);
    }

    #[test]
    fn test_mock_embedder_zero_max_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 0);
        let emb = embedder.embed_tokens("one two three").unwrap();
        assert_eq!(emb.num_tokens(), 0);
    }

    #[test]
    fn test_mock_embedder_empty_text() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let emb = embedder.embed_tokens("").unwrap();
        assert_eq!(emb.num_tokens(), 0);
        assert!(emb.is_empty());
    }

    #[test]
    fn test_mock_embedder_whitespace_only() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let emb = embedder.embed_tokens(" \t\n ").unwrap();
        assert_eq!(emb.num_tokens(), 0);
    }

    #[test]
    fn test_mock_embedder_unit_vectors() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let emb = embedder.embed_tokens("test token").unwrap();
        for token_emb in emb.tokens() {
            let norm: f32 = token_emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 0.001, "Token not unit length: norm = {}", norm);
        }
    }

    #[test]
    fn test_mock_embedder_different_tokens() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let emb = embedder.embed_tokens("hello world").unwrap();
        let token0 = emb.token(0);
        let token1 = emb.token(1);
        assert_ne!(token0, token1);
    }

    #[test]
    fn test_mock_embedder_repeated_token_is_position_dependent() {
        // The token position is part of the hash, so the same word at two
        // positions must not share a vector.
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let emb = embedder.embed_tokens("hello hello").unwrap();
        assert_eq!(emb.num_tokens(), 2);
        assert_ne!(emb.token(0), emb.token(1));
    }

    #[test]
    fn test_mock_embedder_batch() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let texts = ["hello", "world", "test"];
        let embeddings = embedder.embed_tokens_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 3);
        assert_eq!(embeddings[0].num_tokens(), 1);
        assert_eq!(embeddings[1].num_tokens(), 1);
        assert_eq!(embeddings[2].num_tokens(), 1);
    }

    #[test]
    fn test_mock_embedder_empty_batch() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let embeddings = embedder.embed_tokens_batch(&[]).unwrap();
        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_mock_embedder_batch_consistency() {
        let embedder = MockMultiVectorEmbedder::new(64, 256);
        let texts = ["hello", "world"];
        let batch_result = embedder.embed_tokens_batch(&texts).unwrap();
        let single1 = embedder.embed_tokens("hello").unwrap();
        let single2 = embedder.embed_tokens("world").unwrap();
        assert_eq!(batch_result[0].as_slice(), single1.as_slice());
        assert_eq!(batch_result[1].as_slice(), single2.as_slice());
    }

    #[test]
    fn test_boxed_embedder() {
        let embedder: Box<dyn MultiVectorEmbedder> =
            Box::new(MockMultiVectorEmbedder::new(64, 256));
        let emb = embedder.embed_tokens("test").unwrap();
        assert_eq!(emb.num_tokens(), 1);
        assert_eq!(embedder.token_dimension(), 64);
    }

    proptest! {
        #[test]
        fn prop_embed_produces_correct_dimensions(
            dim in 16usize..256,
            text in "[a-z ]{1,100}"
        ) {
            let embedder = MockMultiVectorEmbedder::new(dim, 512);
            let emb = embedder.embed_tokens(&text).unwrap();
            prop_assert_eq!(emb.dim(), dim);
            if emb.num_tokens() > 0 {
                prop_assert_eq!(emb.token(0).len(), dim);
            }
        }

        #[test]
        fn prop_embed_respects_max_tokens(
            max_tokens in 1usize..10,
            words in 1usize..20
        ) {
            let text: String = (0..words).map(|i| format!("word{}", i)).collect::<Vec<_>>().join(" ");
            let embedder = MockMultiVectorEmbedder::new(64, max_tokens);
            let emb = embedder.embed_tokens(&text).unwrap();
            prop_assert!(emb.num_tokens() <= max_tokens);
        }

        #[test]
        fn prop_embed_is_deterministic(
            seed in 0u64..10000,
            text in "[a-z ]{1,50}"
        ) {
            let embedder = MockMultiVectorEmbedder::with_seed(64, 256, seed);
            let emb1 = embedder.embed_tokens(&text).unwrap();
            let emb2 = embedder.embed_tokens(&text).unwrap();
            prop_assert_eq!(emb1.as_slice(), emb2.as_slice());
        }

        #[test]
        fn prop_tokens_are_approximately_unit_length(
            dim in 32usize..128,
            text in "[a-z]{3,10}"
        ) {
            let embedder = MockMultiVectorEmbedder::new(dim, 256);
            let emb = embedder.embed_tokens(&text).unwrap();
            for token_emb in emb.tokens() {
                let norm: f32 = token_emb.iter().map(|x| x * x).sum::<f32>().sqrt();
                prop_assert!((norm - 1.0).abs() < 0.01);
            }
        }
    }
}