use crate::context::{Context, ContextParams};
use crate::error::MullamaError;
use crate::model::Model;
use crate::sys;
use crate::token::TokenId;
use std::sync::Arc;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// A sequence of per-token embedding vectors packed into one contiguous buffer.
///
/// Vector `i` occupies `data[i * dimension..(i + 1) * dimension]`.
#[derive(Debug, Clone)]
pub struct MultiVectorEmbedding {
    /// Flat storage: `n_tokens` consecutive vectors of `dimension` floats each.
    data: Vec<f32>,
    /// Number of components in every vector.
    dimension: usize,
    /// Number of complete vectors held in `data`.
    n_tokens: usize,
    /// Optional token ids supplied together with the vectors.
    token_ids: Option<Vec<TokenId>>,
    /// True once `normalize` has run and no mutable access was handed out since.
    normalized: bool,
}

impl MultiVectorEmbedding {
    /// Wraps a flat buffer holding consecutive `dimension`-sized vectors.
    ///
    /// The vector count is `data.len() / dimension`, or zero when `dimension`
    /// is zero. Debug builds assert that `data.len()` is a multiple of
    /// `dimension`; when that assertion is compiled out, a trailing partial
    /// vector stays in the buffer but is never counted or yielded.
    ///
    /// `token_ids` is stored as given; its length is not checked against the
    /// vector count.
    pub fn new(data: Vec<f32>, dimension: usize, token_ids: Option<Vec<TokenId>>) -> Self {
        let n_tokens = if dimension > 0 {
            debug_assert!(
                data.len() % dimension == 0,
                "Data length {} is not divisible by dimension {}",
                data.len(),
                dimension
            );
            data.len() / dimension
        } else {
            0
        };
        Self {
            data,
            dimension,
            n_tokens,
            token_ids,
            normalized: false,
        }
    }

    /// Creates an embedding with no vectors that still reports `dimension`.
    pub fn empty(dimension: usize) -> Self {
        Self {
            data: Vec::new(),
            dimension,
            n_tokens: 0,
            token_ids: None,
            normalized: false,
        }
    }

    /// Returns the vector at `index`, or `None` when `index` is out of range.
    #[inline]
    pub fn get(&self, index: usize) -> Option<&[f32]> {
        if index >= self.n_tokens {
            return None;
        }
        let start = index * self.dimension;
        self.data.get(start..start + self.dimension)
    }

    /// Returns the vector at `index` for in-place modification, or `None`
    /// when `index` is out of range.
    ///
    /// Handing out mutable access clears the normalized flag, because the
    /// caller may overwrite the vector with non-unit values.
    #[inline]
    pub fn get_mut(&mut self, index: usize) -> Option<&mut [f32]> {
        if index >= self.n_tokens {
            return None;
        }
        // Without this reset `is_normalized` would keep reporting `true` and a
        // later `normalize` call would be skipped even though the data changed.
        self.normalized = false;
        let start = index * self.dimension;
        self.data.get_mut(start..start + self.dimension)
    }

    /// Number of vectors stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.n_tokens
    }

    /// True when no vectors are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.n_tokens == 0
    }

    /// Number of components per vector.
    #[inline]
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// The whole flat buffer, vectors laid out back to back.
    #[inline]
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }

    /// The whole flat buffer for in-place modification.
    ///
    /// Clears the normalized flag for the same reason as [`Self::get_mut`].
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        self.normalized = false;
        &mut self.data
    }

    /// Copies every vector into its own `Vec<f32>`.
    pub fn to_vecs(&self) -> Vec<Vec<f32>> {
        self.iter().map(|vector| vector.to_vec()).collect()
    }

    /// Token ids stored at construction, if any.
    pub fn token_ids(&self) -> Option<&[TokenId]> {
        self.token_ids.as_deref()
    }

    /// Whether `normalize` has run since the data was last exposed mutably.
    #[inline]
    pub fn is_normalized(&self) -> bool {
        self.normalized
    }

    /// Scales every vector to unit L2 length in place.
    ///
    /// Vectors whose magnitude is not greater than `f32::EPSILON` are left
    /// untouched. Does nothing when the embedding is already flagged as
    /// normalized.
    pub fn normalize(&mut self) {
        if self.normalized {
            return;
        }
        // `chunks_exact_mut` panics on a chunk size of zero; a zero dimension
        // also means there are no vectors to scale.
        if self.dimension > 0 {
            for vector in self
                .data
                .chunks_exact_mut(self.dimension)
                .take(self.n_tokens)
            {
                let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
                if magnitude > f32::EPSILON {
                    for x in vector.iter_mut() {
                        *x /= magnitude;
                    }
                }
            }
        }
        self.normalized = true;
    }

    /// Returns a normalized copy, leaving `self` unchanged.
    pub fn normalized(&self) -> Self {
        let mut copy = self.clone();
        copy.normalize();
        copy
    }

    /// Iterates over the vectors in storage order.
    pub fn iter(&self) -> impl Iterator<Item = &[f32]> {
        // `max(1)` only guards the zero-dimension case, where `n_tokens` is
        // zero and `take` yields nothing anyway.
        self.data
            .chunks_exact(self.dimension.max(1))
            .take(self.n_tokens)
    }

    /// Bytes occupied by the float buffer plus the stored token ids
    /// (element payload only, not `Vec` capacity or struct overhead).
    pub fn size_bytes(&self) -> usize {
        let id_bytes = self
            .token_ids
            .as_ref()
            .map_or(0, |ids| ids.len() * std::mem::size_of::<TokenId>());
        self.data.len() * std::mem::size_of::<f32>() + id_bytes
    }
}
/// Settings for multi-vector embedding generation, with chainable setters.
#[derive(Debug, Clone)]
pub struct MultiVectorConfig {
    /// When true, the generator normalizes every produced vector.
    pub normalize: bool,
    /// When true, the generator drops vectors belonging to special tokens.
    pub skip_special_tokens: bool,
    /// When true, the generator records the token id of each kept vector.
    pub store_token_ids: bool,
    /// Batch size setting.
    /// NOTE(review): not read by any code visible in this file - confirm where it is consumed.
    pub batch_size: usize,
    /// Context length passed to the generator's context when greater than
    /// zero; zero leaves the context's own default in place.
    pub max_seq_len: u32,
}

impl Default for MultiVectorConfig {
    /// Normalization and special-token skipping on, token-id storage off,
    /// batch size 32, and no explicit sequence-length limit.
    fn default() -> Self {
        MultiVectorConfig {
            max_seq_len: 0,
            batch_size: 32,
            store_token_ids: false,
            skip_special_tokens: true,
            normalize: true,
        }
    }
}

impl MultiVectorConfig {
    /// Same as [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the config with `normalize` replaced.
    pub fn normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Returns the config with `skip_special_tokens` replaced.
    pub fn skip_special_tokens(self, skip: bool) -> Self {
        Self {
            skip_special_tokens: skip,
            ..self
        }
    }

    /// Returns the config with `store_token_ids` replaced.
    pub fn store_token_ids(self, store: bool) -> Self {
        Self {
            store_token_ids: store,
            ..self
        }
    }

    /// Returns the config with `batch_size` replaced.
    pub fn batch_size(self, size: usize) -> Self {
        Self {
            batch_size: size,
            ..self
        }
    }

    /// Returns the config with `max_seq_len` replaced.
    pub fn max_seq_len(self, len: u32) -> Self {
        Self {
            max_seq_len: len,
            ..self
        }
    }
}
pub struct MultiVectorGenerator {
model: Arc<Model>,
context: Context,
config: MultiVectorConfig,
}
impl MultiVectorGenerator {
pub fn new(model: Arc<Model>, config: MultiVectorConfig) -> Result<Self, MullamaError> {
let mut ctx_params = ContextParams::default();
ctx_params.embeddings = true;
ctx_params.pooling_type = sys::llama_pooling_type::LLAMA_POOLING_TYPE_NONE;
if config.max_seq_len > 0 {
ctx_params.n_ctx = config.max_seq_len;
}
let context = Context::new(model.clone(), ctx_params)?;
Ok(Self {
model,
context,
config,
})
}
pub fn embed_text(&mut self, text: &str) -> Result<MultiVectorEmbedding, MullamaError> {
let tokens = self.model.tokenize(text, true, false)?;
self.embed_tokens(&tokens)
}
pub fn embed_tokens(
&mut self,
tokens: &[TokenId],
) -> Result<MultiVectorEmbedding, MullamaError> {
if tokens.is_empty() {
return Err(MullamaError::InvalidInput(
"Cannot embed empty token sequence".to_string(),
));
}
self.context.kv_cache_clear();
self.context.decode(tokens)?;
let n_embd = self.model.n_embd() as usize;
let mut embeddings_data = Vec::with_capacity(tokens.len() * n_embd);
let mut output_token_ids = if self.config.store_token_ids {
Some(Vec::with_capacity(tokens.len()))
} else {
None
};
for (i, &token) in tokens.iter().enumerate() {
if self.config.skip_special_tokens && self.is_special_token(token) {
continue;
}
if let Some(emb) = self.context.get_embeddings_ith(i as i32) {
embeddings_data.extend_from_slice(emb);
if let Some(ref mut ids) = output_token_ids {
ids.push(token);
}
} else {
return Err(MullamaError::EmbeddingError(format!(
"Failed to get embedding for token at index {}",
i
)));
}
}
let mut mv = MultiVectorEmbedding::new(embeddings_data, n_embd, output_token_ids);
if self.config.normalize {
mv.normalize();
}
Ok(mv)
}
pub fn embed_batch(
&mut self,
texts: &[&str],
) -> Result<Vec<MultiVectorEmbedding>, MullamaError> {
let mut results = Vec::with_capacity(texts.len());
for text in texts {
results.push(self.embed_text(text)?);
}
Ok(results)
}
#[inline]
pub fn embedding_dim(&self) -> usize {
self.model.n_embd() as usize
}
#[inline]
pub fn model(&self) -> &Arc<Model> {
&self.model
}
#[inline]
pub fn config(&self) -> &MultiVectorConfig {
&self.config
}
fn is_special_token(&self, token: TokenId) -> bool {
token == self.model.token_bos()
|| token == self.model.token_eos()
|| token == self.model.token_pad()
|| token == self.model.token_eot()
|| self.model.token_is_control(token)
}
}
/// Stateless namespace for late-interaction (MaxSim) scoring between
/// multi-vector embeddings.
pub struct LateInteractionScorer;

impl LateInteractionScorer {
    /// Sums, over every query vector, its largest dot product with any
    /// document vector.
    ///
    /// Returns `0.0` when either side is empty or the dimensions differ.
    /// Dot products that are NaN never win the per-vector maximum.
    pub fn max_sim(query: &MultiVectorEmbedding, document: &MultiVectorEmbedding) -> f32 {
        if query.is_empty()
            || document.is_empty()
            || query.dimension() != document.dimension()
        {
            return 0.0;
        }

        let mut total = 0.0;
        for q_vec in query.iter() {
            let best = document
                .iter()
                .map(|d_vec| Self::dot_product(q_vec, d_vec))
                .fold(f32::NEG_INFINITY, |best, sim| if sim > best { sim } else { best });
            // A maximum still at negative infinity means no usable similarity
            // was found for this query vector; it contributes nothing.
            if best > f32::NEG_INFINITY {
                total += best;
            }
        }
        total
    }

    /// [`Self::max_sim`] divided by the number of query vectors
    /// (`0.0` for an empty query).
    pub fn max_sim_normalized(
        query: &MultiVectorEmbedding,
        document: &MultiVectorEmbedding,
    ) -> f32 {
        if query.is_empty() {
            return 0.0;
        }
        Self::max_sim(query, document) / query.len() as f32
    }

    /// Mean of `max_sim(a, b)` and `max_sim(b, a)`.
    pub fn max_sim_symmetric(a: &MultiVectorEmbedding, b: &MultiVectorEmbedding) -> f32 {
        let forward = Self::max_sim(a, b);
        let backward = Self::max_sim(b, a);
        (forward + backward) / 2.0
    }

    /// The `k` best documents as `(index, score)` pairs, best first.
    ///
    /// Uses the ordering of [`Self::rank_documents`]; fewer than `k` pairs are
    /// returned when there are fewer documents.
    pub fn find_top_k(
        query: &MultiVectorEmbedding,
        documents: &[MultiVectorEmbedding],
        k: usize,
    ) -> Vec<(usize, f32)> {
        let mut ranked = Self::rank_documents(query, documents);
        ranked.truncate(k);
        ranked
    }

    /// Scores every document against `query` and returns `(index, score)`
    /// pairs sorted by descending score.
    ///
    /// Equal scores keep their original index order. NaN scores are ordered
    /// after every other score.
    pub fn rank_documents(
        query: &MultiVectorEmbedding,
        documents: &[MultiVectorEmbedding],
    ) -> Vec<(usize, f32)> {
        let mut scores: Vec<(usize, f32)> = documents
            .iter()
            .enumerate()
            .map(|(index, doc)| (index, Self::max_sim(query, doc)))
            .collect();
        scores.sort_by(|a, b| Self::cmp_score_desc(a.1, b.1));
        scores
    }

    /// Dot product of every query vector (rows) with every document vector
    /// (columns).
    ///
    /// Dimensions are not checked here: with differing dimensions each dot
    /// product covers only the shorter vector's length.
    pub fn similarity_matrix(
        query: &MultiVectorEmbedding,
        document: &MultiVectorEmbedding,
    ) -> Vec<Vec<f32>> {
        query
            .iter()
            .map(|q_vec| {
                document
                    .iter()
                    .map(|d_vec| Self::dot_product(q_vec, d_vec))
                    .collect()
            })
            .collect()
    }

    /// For each query vector, the index and dot product of its best-matching
    /// document vector (the first one on ties).
    ///
    /// When the document is empty every entry is `(0, f32::NEG_INFINITY)`.
    /// Dimensions are not checked here.
    pub fn best_matches(
        query: &MultiVectorEmbedding,
        document: &MultiVectorEmbedding,
    ) -> Vec<(usize, f32)> {
        query
            .iter()
            .map(|q_vec| {
                document.iter().enumerate().fold(
                    (0, f32::NEG_INFINITY),
                    |best, (d_idx, d_vec)| {
                        let sim = Self::dot_product(q_vec, d_vec);
                        if sim > best.1 {
                            (d_idx, sim)
                        } else {
                            best
                        }
                    },
                )
            })
            .collect()
    }

    /// `max_sim` for every (query, document) pair; one row per query.
    pub fn batch_score(
        queries: &[MultiVectorEmbedding],
        documents: &[MultiVectorEmbedding],
    ) -> Vec<Vec<f32>> {
        queries
            .iter()
            .map(|q| documents.iter().map(|d| Self::max_sim(q, d)).collect())
            .collect()
    }

    /// Descending score order in which NaN sorts after every number.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` alone makes NaN "equal" to every
    /// value, which is not a consistent ordering: a sort may then leave a NaN
    /// score in front of real ones, and the standard library documents that
    /// sorting with an inconsistent comparator may panic.
    fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Dot product over the common prefix of `a` and `b`.
    #[inline]
    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Cosine of the angle between `a` and `b`.
    ///
    /// Returns `0.0` when the lengths differ or either magnitude is below
    /// `f32::EPSILON`.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mag_a = magnitude(a);
        let mag_b = magnitude(b);
        if mag_a < f32::EPSILON || mag_b < f32::EPSILON {
            0.0
        } else {
            Self::dot_product(a, b) / (mag_a * mag_b)
        }
    }
}
/// Rayon-backed variants of the ranking and batch-scoring routines, compiled
/// only with the `parallel` feature.
#[cfg(feature = "parallel")]
impl LateInteractionScorer {
    /// The `k` best documents as `(index, score)` pairs, best first, with the
    /// per-document scoring done in parallel.
    ///
    /// Uses the ordering of [`Self::rank_documents_parallel`].
    pub fn find_top_k_parallel(
        query: &MultiVectorEmbedding,
        documents: &[MultiVectorEmbedding],
        k: usize,
    ) -> Vec<(usize, f32)> {
        let mut ranked = Self::rank_documents_parallel(query, documents);
        ranked.truncate(k);
        ranked
    }

    /// `max_sim` for every (query, document) pair; queries are processed in
    /// parallel, the documents of one query sequentially.
    pub fn batch_score_parallel(
        queries: &[MultiVectorEmbedding],
        documents: &[MultiVectorEmbedding],
    ) -> Vec<Vec<f32>> {
        queries
            .par_iter()
            .map(|q| documents.iter().map(|d| Self::max_sim(q, d)).collect())
            .collect()
    }

    /// Scores every document in parallel and returns `(index, score)` pairs
    /// sorted by descending score.
    ///
    /// Equal scores keep their original index order. NaN scores are ordered
    /// after every other score.
    pub fn rank_documents_parallel(
        query: &MultiVectorEmbedding,
        documents: &[MultiVectorEmbedding],
    ) -> Vec<(usize, f32)> {
        let mut scores: Vec<(usize, f32)> = documents
            .par_iter()
            .enumerate()
            .map(|(index, doc)| (index, Self::max_sim(query, doc)))
            .collect();
        // `partial_cmp(..).unwrap_or(Equal)` alone makes NaN "equal" to every
        // value, which is not a consistent ordering; placing NaN explicitly
        // last keeps the comparator total.
        scores.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
        });
        scores
    }
}
/// Unit tests for the embedding container, the config builder and the scorer.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the approximate float comparisons.
    const TOLERANCE: f32 = 0.001;

    /// True when `actual` is within `TOLERANCE` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Builds an embedding without token ids from a flat slice.
    fn embedding(values: &[f32], dimension: usize) -> MultiVectorEmbedding {
        MultiVectorEmbedding::new(values.to_vec(), dimension, None)
    }

    #[test]
    fn test_multi_vector_embedding_creation() {
        let mv = embedding(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 3);

        assert_eq!(mv.len(), 2);
        assert_eq!(mv.dimension(), 3);
        assert!(!mv.is_empty());

        assert_eq!(mv.get(0), Some(&[1.0, 0.0, 0.0][..]));
        assert_eq!(mv.get(1), Some(&[0.0, 1.0, 0.0][..]));
        assert_eq!(mv.get(2), None);
    }

    #[test]
    fn test_multi_vector_empty() {
        let mv = MultiVectorEmbedding::empty(128);

        assert!(mv.is_empty());
        assert_eq!(mv.len(), 0);
        assert_eq!(mv.dimension(), 128);
    }

    #[test]
    fn test_multi_vector_normalization() {
        // A 3-4-0 vector has length 5, so it normalizes to (0.6, 0.8, 0).
        let mut mv = embedding(&[3.0, 4.0, 0.0], 3);
        assert!(!mv.is_normalized());

        mv.normalize();
        {
            let first = mv.get(0).unwrap();
            assert!(close(first[0], 0.6));
            assert!(close(first[1], 0.8));
        }
        assert!(mv.is_normalized());

        // Normalizing again must leave the values as they are.
        mv.normalize();
        let first = mv.get(0).unwrap();
        assert!(close(first[0], 0.6));
    }

    #[test]
    fn test_multi_vector_to_vecs() {
        let rows = embedding(&[1.0, 2.0, 3.0, 4.0], 2).to_vecs();

        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0], vec![1.0, 2.0]);
        assert_eq!(rows[1], vec![3.0, 4.0]);
    }

    #[test]
    fn test_multi_vector_with_token_ids() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 1.0], 2, Some(vec![100, 200]));

        assert_eq!(mv.token_ids(), Some(&[100, 200][..]));
    }

    #[test]
    fn test_max_sim_identical() {
        let values = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let left = embedding(&values, 3);
        let right = embedding(&values, 3);

        // Each of the two unit vectors matches itself with similarity 1.
        let score = LateInteractionScorer::max_sim(&left, &right);
        assert!(close(score, 2.0));
    }

    #[test]
    fn test_max_sim_orthogonal() {
        let query = embedding(&[1.0, 0.0, 0.0], 3);
        let doc = embedding(&[0.0, 1.0, 0.0], 3);

        let score = LateInteractionScorer::max_sim(&query, &doc);
        assert!(close(score, 0.0));
    }

    #[test]
    fn test_max_sim_empty() {
        let empty = MultiVectorEmbedding::empty(3);
        let non_empty = embedding(&[1.0, 0.0, 0.0], 3);

        assert_eq!(LateInteractionScorer::max_sim(&empty, &non_empty), 0.0);
        assert_eq!(LateInteractionScorer::max_sim(&non_empty, &empty), 0.0);
    }

    #[test]
    fn test_max_sim_dimension_mismatch() {
        let two_d = embedding(&[1.0, 0.0], 2);
        let three_d = embedding(&[1.0, 0.0, 0.0], 3);

        assert_eq!(LateInteractionScorer::max_sim(&two_d, &three_d), 0.0);
    }

    #[test]
    fn test_max_sim_normalized() {
        let q = embedding(&[1.0, 0.0, 1.0, 0.0], 2);
        let d = embedding(&[1.0, 0.0], 2);

        let score = LateInteractionScorer::max_sim(&q, &d);
        let norm_score = LateInteractionScorer::max_sim_normalized(&q, &d);
        assert_eq!(norm_score, score / 2.0);
    }

    #[test]
    fn test_max_sim_symmetric() {
        let a = embedding(&[1.0, 0.0, 0.5, 0.5], 2);
        let b = embedding(&[0.0, 1.0, 0.5, 0.5], 2);

        let ab = LateInteractionScorer::max_sim(&a, &b);
        let ba = LateInteractionScorer::max_sim(&b, &a);
        let sym = LateInteractionScorer::max_sim_symmetric(&a, &b);
        assert!(close(sym, (ab + ba) / 2.0));
    }

    #[test]
    fn test_similarity_matrix() {
        let q = embedding(&[1.0, 0.0, 0.0, 1.0], 2);
        let d = embedding(&[1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 2);

        let matrix = LateInteractionScorer::similarity_matrix(&q, &d);

        // Two query vectors by three document vectors.
        assert_eq!(matrix.len(), 2);
        assert_eq!(matrix[0].len(), 3);
        assert!(close(matrix[0][0], 1.0));
        assert!(close(matrix[0][1], 0.0));
    }

    #[test]
    fn test_best_matches() {
        let q = embedding(&[1.0, 0.0, 0.0, 1.0], 2);
        let d = embedding(&[0.5, 0.5, 1.0, 0.0, 0.0, 1.0], 2);

        let matches = LateInteractionScorer::best_matches(&q, &d);

        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].0, 1);
        assert_eq!(matches[1].0, 2);
    }

    #[test]
    fn test_find_top_k() {
        let query = embedding(&[1.0, 0.0], 2);
        let docs = vec![
            embedding(&[0.0, 1.0], 2),
            embedding(&[1.0, 0.0], 2),
            embedding(&[0.5, 0.5], 2),
        ];

        let top_k = LateInteractionScorer::find_top_k(&query, &docs, 2);

        assert_eq!(top_k.len(), 2);
        // The exact match comes first, the partial match second.
        assert_eq!(top_k[0].0, 1);
        assert_eq!(top_k[1].0, 2);
    }

    #[test]
    fn test_batch_score() {
        let axes = [[1.0, 0.0], [0.0, 1.0]];
        let queries: Vec<MultiVectorEmbedding> =
            axes.iter().map(|axis| embedding(axis, 2)).collect();
        let docs: Vec<MultiVectorEmbedding> =
            axes.iter().map(|axis| embedding(axis, 2)).collect();

        let scores = LateInteractionScorer::batch_score(&queries, &docs);

        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].len(), 2);
        assert!(scores[0][0] > scores[0][1]);
        assert!(scores[1][1] > scores[1][0]);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = [1.0, 0.0, 0.0];

        let same = [1.0, 0.0, 0.0];
        assert!(close(LateInteractionScorer::cosine_similarity(&a, &same), 1.0));

        let orthogonal = [0.0, 1.0, 0.0];
        assert!(close(
            LateInteractionScorer::cosine_similarity(&a, &orthogonal),
            0.0
        ));

        let opposite = [-1.0, 0.0, 0.0];
        assert!(close(
            LateInteractionScorer::cosine_similarity(&a, &opposite),
            -1.0
        ));
    }

    #[test]
    fn test_config_builder() {
        let config = MultiVectorConfig::default()
            .normalize(false)
            .skip_special_tokens(false)
            .store_token_ids(true)
            .batch_size(64)
            .max_seq_len(512);

        assert!(!config.normalize);
        assert!(!config.skip_special_tokens);
        assert!(config.store_token_ids);
        assert_eq!(config.batch_size, 64);
        assert_eq!(config.max_seq_len, 512);
    }

    #[test]
    fn test_size_bytes() {
        // 100 floats of 4 bytes each.
        let without_ids = embedding(&[1.0f32; 100], 10);
        assert_eq!(without_ids.size_bytes(), 100 * 4);

        // The same floats plus ten 4-byte token ids.
        let with_ids = MultiVectorEmbedding::new(vec![1.0f32; 100], 10, Some(vec![0i32; 10]));
        assert_eq!(with_ids.size_bytes(), 100 * 4 + 10 * 4);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_find_top_k() {
        let query = embedding(&[1.0, 0.0], 2);
        let docs: Vec<MultiVectorEmbedding> = (0..100)
            .map(|i| {
                let x = (i as f32) / 100.0;
                embedding(&[x, 1.0 - x], 2)
            })
            .collect();

        let top_k = LateInteractionScorer::find_top_k_parallel(&query, &docs, 5);

        assert_eq!(top_k.len(), 5);
        assert!(top_k[0].0 > 90);
    }
}