use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::{EmbeddingError, VectorStoreError};
use crate::layer1_echo::similarity::cosine_similarity;
use crate::layer1_echo::traits::SimilarityMetric;
use crate::types::{Document, DocumentId};
/// One token's embedding vector plus bookkeeping about where the token came from.
///
/// This is the unit that [`MaxSimScore`] compares between a query and a document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEmbedding {
    /// Text of the token, if the producer recorded it.
    pub token: Option<String>,
    /// Position of the token within the sequence it was produced from.
    pub position: usize,
    /// The embedding values for this token.
    pub embedding: Vec<f32>,
}

impl TokenEmbedding {
    /// Creates an embedding at `position` that carries no token text.
    #[must_use]
    pub fn new(position: usize, embedding: Vec<f32>) -> Self {
        TokenEmbedding {
            embedding,
            position,
            token: None,
        }
    }

    /// Creates an embedding at `position` that also remembers the token text.
    #[must_use]
    pub fn with_token(position: usize, token: impl Into<String>, embedding: Vec<f32>) -> Self {
        // Start from the text-less form and fill in the token afterwards.
        Self {
            token: Some(token.into()),
            ..Self::new(position, embedding)
        }
    }

    /// Number of components in this token's embedding vector.
    #[must_use]
    pub fn dimension(&self) -> usize {
        let Self { embedding, .. } = self;
        embedding.len()
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorDocument {
pub document: Document,
pub token_embeddings: Vec<TokenEmbedding>,
}
impl MultiVectorDocument {
#[must_use]
pub fn new(document: Document, token_embeddings: Vec<TokenEmbedding>) -> Self {
Self {
document,
token_embeddings,
}
}
#[must_use]
pub fn num_tokens(&self) -> usize {
self.token_embeddings.len()
}
#[must_use]
pub fn dimension(&self) -> usize {
self.token_embeddings
.first()
.map_or(0, |t| t.embedding.len())
}
}
/// Pairing of a query token with a document token and the similarity between them.
///
/// [`MaxSimScore::compute`] emits one of these per query token, pointing at the
/// best-scoring document token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenMatch {
    /// Index of the query token within the query token slice.
    pub query_token_idx: usize,
    /// Index of the matched token within the document token slice.
    pub doc_token_idx: usize,
    /// Similarity between the two tokens under the scorer's metric.
    pub similarity: f32,
}

impl TokenMatch {
    /// Records that query token `query_token_idx` matched document token
    /// `doc_token_idx` with the given `similarity`.
    #[must_use]
    pub fn new(query_token_idx: usize, doc_token_idx: usize, similarity: f32) -> Self {
        TokenMatch {
            similarity,
            doc_token_idx,
            query_token_idx,
        }
    }
}
/// One hit returned from a multi-vector search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorSearchResult {
    /// The matched document.
    pub document: Document,
    /// Aggregate score of the document for the query.
    pub score: f32,
    /// Position of this hit in the result list (the in-memory store numbers from `0`).
    pub rank: usize,
    /// Per-query-token matches that produced `score`.
    pub token_matches: Vec<TokenMatch>,
}

impl MultiVectorSearchResult {
    /// Assembles a search hit from its parts.
    #[must_use]
    pub fn new(
        document: Document,
        score: f32,
        rank: usize,
        token_matches: Vec<TokenMatch>,
    ) -> Self {
        MultiVectorSearchResult {
            token_matches,
            rank,
            score,
            document,
        }
    }
}
/// MaxSim scorer: for every query token take the best similarity against any
/// document token, then add those maxima up.
#[derive(Debug, Clone, Copy, Default)]
pub struct MaxSimScore {
    /// Metric used to compare a query token with a document token.
    pub metric: SimilarityMetric,
}

impl MaxSimScore {
    /// Creates a scorer that compares tokens with `metric`.
    #[must_use]
    pub fn new(metric: SimilarityMetric) -> Self {
        MaxSimScore { metric }
    }

    /// Scores `doc_tokens` against `query_tokens`.
    ///
    /// Returns the summed per-query-token maxima together with one [`TokenMatch`]
    /// per query token (in query order). If either slice is empty the result is
    /// `(0.0, vec![])`.
    ///
    /// Ties go to the earliest document token. A NaN similarity never wins the
    /// comparison; if no document token beats the initial `f32::NEG_INFINITY`,
    /// the match records document index `0` with that similarity.
    #[must_use]
    pub fn compute(
        &self,
        query_tokens: &[TokenEmbedding],
        doc_tokens: &[TokenEmbedding],
    ) -> (f32, Vec<TokenMatch>) {
        if query_tokens.is_empty() || doc_tokens.is_empty() {
            return (0.0, Vec::new());
        }

        let token_matches: Vec<TokenMatch> = query_tokens
            .iter()
            .enumerate()
            .map(|(q_idx, q_token)| {
                // Strict `>` keeps the first document token that reaches the maximum.
                let (best_doc_idx, max_sim) = doc_tokens
                    .iter()
                    .map(|d_token| {
                        self.compute_token_similarity(&q_token.embedding, &d_token.embedding)
                    })
                    .enumerate()
                    .fold((0_usize, f32::NEG_INFINITY), |best, candidate| {
                        if candidate.1 > best.1 {
                            candidate
                        } else {
                            best
                        }
                    });
                TokenMatch::new(q_idx, best_doc_idx, max_sim)
            })
            .collect();

        // Accumulate in query order, starting from +0.0.
        let total_score = token_matches
            .iter()
            .fold(0.0_f32, |acc, m| acc + m.similarity);

        (total_score, token_matches)
    }

    /// Similarity between two raw vectors under the configured metric.
    ///
    /// Dot product and Euclidean distance pair components with `zip`, so extra
    /// components of the longer vector are ignored. Euclidean distance `d` is
    /// mapped to a similarity as `1 / (1 + d)`.
    fn compute_token_similarity(self, a: &[f32], b: &[f32]) -> f32 {
        let pairs = a.iter().zip(b.iter());
        match self.metric {
            SimilarityMetric::Cosine => cosine_similarity(a, b),
            SimilarityMetric::DotProduct => pairs.map(|(x, y)| x * y).sum(),
            SimilarityMetric::Euclidean => {
                let squared_sum: f32 = pairs.map(|(x, y)| (x - y).powi(2)).sum();
                let distance = squared_sum.sqrt();
                1.0 / (1.0 + distance)
            }
        }
    }

    /// Same as [`MaxSimScore::compute`] but returns only the total score.
    #[must_use]
    pub fn score(&self, query_tokens: &[TokenEmbedding], doc_tokens: &[TokenEmbedding]) -> f32 {
        let (total, _matches) = self.compute(query_tokens, doc_tokens);
        total
    }
}
/// Asynchronous source of per-token embeddings for text.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait TokenEmbeddingProvider: Send + Sync {
/// Embeds one text, yielding a [`TokenEmbedding`] per token or an [`EmbeddingError`].
///
/// Tokenisation is left to the implementation. The mock provider in this file
/// splits on whitespace and returns `EmbeddingError::EmptyInput` when that
/// produces no tokens.
async fn embed_tokens(&self, text: &str) -> Result<Vec<TokenEmbedding>, EmbeddingError>;
/// Embeds several texts in one call.
///
/// The mock provider in this file returns one inner vector per input, in input
/// order, and stops at the first text that fails.
/// NOTE(review): other implementations are presumably expected to do the same — confirm.
async fn embed_tokens_batch(
&self,
texts: &[&str],
) -> Result<Vec<Vec<TokenEmbedding>>, EmbeddingError>;
/// Embedding dimension reported by this provider (the mock returns the value
/// it was constructed with, which is also the length of every vector it emits).
fn dimension(&self) -> usize;
/// Identifier of the embedding model behind this provider.
fn model_id(&self) -> &str;
}
/// Asynchronous storage and retrieval of [`MultiVectorDocument`]s.
///
/// Implementations must be `Send + Sync`. The behaviour notes below describe the
/// in-memory implementation in this file; other implementations presumably follow
/// the same contract — confirm before relying on it.
#[async_trait]
pub trait MultiVectorStore: Send + Sync {
/// Stores `doc`.
///
/// The in-memory store keys documents by `doc.document.id`, so inserting an id
/// that already exists replaces the stored document. It returns
/// `VectorStoreError::DimensionMismatch` when it detects an embedding whose
/// length differs from the store dimension.
async fn insert(&mut self, doc: MultiVectorDocument) -> Result<(), VectorStoreError>;
/// Stores every document in `docs`, failing on the first one that is rejected.
async fn insert_batch(
&mut self,
docs: Vec<MultiVectorDocument>,
) -> Result<(), VectorStoreError>;
/// Looks up the document stored under `id`; `Ok(None)` when there is none.
async fn get(&self, id: &DocumentId) -> Result<Option<MultiVectorDocument>, VectorStoreError>;
/// Removes the document stored under `id`.
///
/// The in-memory store returns `Ok(true)` when a document was removed and
/// `Ok(false)` when the id was not present.
async fn delete(&mut self, id: &DocumentId) -> Result<bool, VectorStoreError>;
/// Finds the best-scoring documents for `query_tokens`.
///
/// The in-memory store scores every document with [`MaxSimScore`], drops those
/// scoring below `min_score` (when given), orders the rest by descending score,
/// keeps at most `top_k`, and returns an empty list for an empty query.
async fn search(
&self,
query_tokens: &[TokenEmbedding],
top_k: usize,
min_score: Option<f32>,
) -> Result<Vec<MultiVectorSearchResult>, VectorStoreError>;
/// Number of documents currently stored.
async fn count(&self) -> usize;
/// Removes all stored documents.
async fn clear(&mut self) -> Result<(), VectorStoreError>;
/// Embedding dimension this store was configured with.
fn dimension(&self) -> usize;
/// Similarity metric this store was configured with.
fn similarity_metric(&self) -> SimilarityMetric;
}
/// Deterministic stand-in for a real token embedding model.
///
/// Embeddings are derived from the token's bytes and its position, so the same
/// input always yields the same vectors.
pub struct MockTokenEmbeddingProvider {
    /// Length of every embedding vector this provider produces.
    dimension: usize,
    /// Identifier reported as the model id.
    model_id: String,
}

impl MockTokenEmbeddingProvider {
    /// Creates a provider producing `dimension`-length vectors with the default
    /// model id `"mock-token-embedder"`.
    #[must_use]
    pub fn new(dimension: usize) -> Self {
        Self::with_model_id(dimension, "mock-token-embedder")
    }

    /// Creates a provider producing `dimension`-length vectors that reports
    /// `model_id` as its model identifier.
    #[must_use]
    pub fn with_model_id(dimension: usize, model_id: impl Into<String>) -> Self {
        let model_id = model_id.into();
        Self {
            model_id,
            dimension,
        }
    }

    /// Builds the deterministic embedding for `text` at `position`.
    ///
    /// The first `dimension` bytes of `text` are mapped from `0..=255` onto
    /// `-1.0..=1.0` (remaining slots stay `0.0`), every slot is shifted by
    /// `sin(position) * 0.1`, and the vector is scaled to unit length unless its
    /// norm is zero.
    fn generate_embedding(&self, text: &str, position: usize) -> Vec<f32> {
        let mut embedding = vec![0.0_f32; self.dimension];

        // `zip` stops at the shorter side, so at most `dimension` bytes are used.
        for (slot, &byte) in embedding.iter_mut().zip(text.as_bytes()) {
            *slot = (f32::from(byte) / 255.0) * 2.0 - 1.0;
        }

        // Precision loss for huge positions is irrelevant for a mock signal.
        #[allow(clippy::cast_precision_loss)]
        let pos_factor = (position as f32).sin() * 0.1;
        embedding.iter_mut().for_each(|val| *val += pos_factor);

        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|val| *val /= norm);
        }

        embedding
    }
}
#[async_trait]
impl TokenEmbeddingProvider for MockTokenEmbeddingProvider {
    /// Splits `text` on whitespace and embeds each piece, numbering positions
    /// from `0`.
    ///
    /// # Errors
    ///
    /// Returns `EmbeddingError::EmptyInput` when `text` contains no tokens
    /// (empty or whitespace-only input).
    async fn embed_tokens(&self, text: &str) -> Result<Vec<TokenEmbedding>, EmbeddingError> {
        let token_embeddings: Vec<TokenEmbedding> = text
            .split_whitespace()
            .enumerate()
            .map(|(pos, token)| {
                TokenEmbedding::with_token(pos, token, self.generate_embedding(token, pos))
            })
            .collect();

        // An empty string also lands here: it simply yields no tokens.
        if token_embeddings.is_empty() {
            Err(EmbeddingError::EmptyInput)
        } else {
            Ok(token_embeddings)
        }
    }

    /// Embeds each text in order, stopping at the first failure.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Self::embed_tokens`] for the first text that
    /// cannot be embedded.
    async fn embed_tokens_batch(
        &self,
        texts: &[&str],
    ) -> Result<Vec<Vec<TokenEmbedding>>, EmbeddingError> {
        let mut results = Vec::with_capacity(texts.len());
        for &text in texts {
            let embedded = self.embed_tokens(text).await?;
            results.push(embedded);
        }
        Ok(results)
    }

    /// Length of every embedding vector this provider produces.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Model identifier supplied at construction.
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }
}
/// [`MultiVectorStore`] backed by a `HashMap`, scoring documents with [`MaxSimScore`].
pub struct InMemoryMultiVectorStore {
    /// Stored documents keyed by document id.
    documents: HashMap<DocumentId, MultiVectorDocument>,
    /// Embedding dimension accepted by this store.
    dimension: usize,
    /// Metric reported by `similarity_metric`.
    metric: SimilarityMetric,
    /// Scorer used by `search`; constructed with the same metric as `metric`.
    max_sim: MaxSimScore,
}

impl InMemoryMultiVectorStore {
    /// Creates an empty store for `dimension`-length embeddings using cosine similarity.
    #[must_use]
    pub fn new(dimension: usize) -> Self {
        Self::with_metric(dimension, SimilarityMetric::Cosine)
    }

    /// Creates an empty store for `dimension`-length embeddings compared with `metric`.
    #[must_use]
    pub fn with_metric(dimension: usize, metric: SimilarityMetric) -> Self {
        let max_sim = MaxSimScore::new(metric);
        let documents = HashMap::new();
        Self {
            documents,
            dimension,
            metric,
            max_sim,
        }
    }
}
impl InMemoryMultiVectorStore {
    /// Verifies that every token embedding in `doc` has this store's dimension.
    ///
    /// Reports the first offending length as `VectorStoreError::DimensionMismatch`.
    /// A document without tokens passes.
    fn validate_token_dimensions(&self, doc: &MultiVectorDocument) -> Result<(), VectorStoreError> {
        let mismatch = doc
            .token_embeddings
            .iter()
            .map(|token| token.embedding.len())
            .find(|&len| len != self.dimension);

        match mismatch {
            Some(actual) => Err(VectorStoreError::DimensionMismatch {
                expected: self.dimension,
                actual,
            }),
            None => Ok(()),
        }
    }
}

#[async_trait]
impl MultiVectorStore for InMemoryMultiVectorStore {
    /// Stores `doc` under `doc.document.id`, replacing any document already
    /// stored under that id.
    ///
    /// # Errors
    ///
    /// Returns `VectorStoreError::DimensionMismatch` if any token embedding
    /// (not just the first) differs in length from the store dimension.
    async fn insert(&mut self, doc: MultiVectorDocument) -> Result<(), VectorStoreError> {
        self.validate_token_dimensions(&doc)?;
        self.documents.insert(doc.document.id.clone(), doc);
        Ok(())
    }

    /// Stores every document in `docs`.
    ///
    /// All documents are validated before any is stored, so a failing batch
    /// leaves the store unchanged. Documents sharing an id overwrite each other
    /// in batch order.
    ///
    /// # Errors
    ///
    /// Returns the `VectorStoreError::DimensionMismatch` of the first document
    /// (in batch order) that fails validation.
    async fn insert_batch(
        &mut self,
        docs: Vec<MultiVectorDocument>,
    ) -> Result<(), VectorStoreError> {
        for doc in &docs {
            self.validate_token_dimensions(doc)?;
        }
        for doc in docs {
            self.documents.insert(doc.document.id.clone(), doc);
        }
        Ok(())
    }

    /// Returns a clone of the document stored under `id`, or `None` if absent.
    async fn get(&self, id: &DocumentId) -> Result<Option<MultiVectorDocument>, VectorStoreError> {
        Ok(self.documents.get(id).cloned())
    }

    /// Removes the document stored under `id`; `true` if one was removed.
    async fn delete(&mut self, id: &DocumentId) -> Result<bool, VectorStoreError> {
        Ok(self.documents.remove(id).is_some())
    }

    /// Scores every stored document against `query_tokens` and returns the best
    /// `top_k`, highest score first, with ranks numbered from `0`.
    ///
    /// Documents scoring below `min_score` (when given) are dropped. An empty
    /// query or `top_k == 0` yields an empty result without scoring anything.
    /// The relative order of equally scored documents is unspecified because it
    /// follows the map's iteration order.
    async fn search(
        &self,
        query_tokens: &[TokenEmbedding],
        top_k: usize,
        min_score: Option<f32>,
    ) -> Result<Vec<MultiVectorSearchResult>, VectorStoreError> {
        if query_tokens.is_empty() || top_k == 0 {
            return Ok(Vec::new());
        }

        // Carry a reference to the document itself; no id clone or second lookup.
        let mut scored: Vec<(&MultiVectorDocument, f32, Vec<TokenMatch>)> = self
            .documents
            .values()
            .map(|doc| {
                let (score, matches) = self.max_sim.compute(query_tokens, &doc.token_embeddings);
                (doc, score, matches)
            })
            .filter(|(_, score, _)| min_score.is_none_or(|min| *score >= min))
            .collect();

        // `partial_cmp` is not a total order once a NaN score appears, and
        // `sort_by` may panic on such comparators. Rank NaN as the worst
        // possible score and compare with `total_cmp` instead.
        let sort_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
        scored.sort_by(|a, b| sort_key(b.1).total_cmp(&sort_key(a.1)));
        scored.truncate(top_k);

        let results = scored
            .into_iter()
            .enumerate()
            .map(|(rank, (doc, score, token_matches))| {
                MultiVectorSearchResult::new(doc.document.clone(), score, rank, token_matches)
            })
            .collect();
        Ok(results)
    }

    /// Number of documents currently stored.
    async fn count(&self) -> usize {
        self.documents.len()
    }

    /// Removes every stored document.
    async fn clear(&mut self) -> Result<(), VectorStoreError> {
        self.documents.clear();
        Ok(())
    }

    /// Embedding dimension this store accepts.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Metric this store was constructed with.
    fn similarity_metric(&self) -> SimilarityMetric {
        self.metric
    }
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Tolerance used for approximate `f32` comparisons.
    const EPS: f32 = 1e-6;

    /// Embeds `text` with `provider` and stores it in `store` as a new document.
    async fn insert_text(
        store: &mut InMemoryMultiVectorStore,
        provider: &MockTokenEmbeddingProvider,
        text: &str,
    ) {
        let document = Document::new(text);
        let tokens = provider.embed_tokens(text).await.unwrap();
        store
            .insert(MultiVectorDocument::new(document, tokens))
            .await
            .unwrap();
    }

    // ----- plain data types -----

    #[test]
    fn test_token_embedding_creation() {
        let values = vec![0.1, 0.2, 0.3];
        let created = TokenEmbedding::new(0, values.clone());

        assert_eq!(created.position, 0);
        assert_eq!(created.embedding, values);
        assert!(created.token.is_none());
        assert_eq!(created.dimension(), 3);
    }

    #[test]
    fn test_token_embedding_with_token() {
        let values = vec![0.1, 0.2, 0.3];
        let created = TokenEmbedding::with_token(1, "hello", values.clone());

        assert_eq!(created.position, 1);
        assert_eq!(created.token, Some("hello".to_string()));
        assert_eq!(created.embedding, values);
    }

    #[test]
    fn test_multi_vector_document() {
        let token_embeddings = vec![
            TokenEmbedding::new(0, vec![0.1, 0.2]),
            TokenEmbedding::new(1, vec![0.3, 0.4]),
        ];
        let mv_doc = MultiVectorDocument::new(Document::new("test content"), token_embeddings);

        assert_eq!(mv_doc.document.content, "test content");
        assert_eq!(mv_doc.num_tokens(), 2);
        assert_eq!(mv_doc.dimension(), 2);
    }

    #[test]
    fn test_multi_vector_document_empty() {
        let mv_doc = MultiVectorDocument::new(Document::new("test"), Vec::new());

        assert_eq!(mv_doc.num_tokens(), 0);
        assert_eq!(mv_doc.dimension(), 0);
    }

    #[test]
    fn test_token_match() {
        let token_match = TokenMatch::new(0, 2, 0.95);

        assert_eq!(token_match.query_token_idx, 0);
        assert_eq!(token_match.doc_token_idx, 2);
        assert!((token_match.similarity - 0.95).abs() < EPS);
    }

    // ----- MaxSim scoring -----

    #[test]
    fn test_max_sim_score_empty() {
        let scorer = MaxSimScore::new(SimilarityMetric::Cosine);
        let no_query: Vec<TokenEmbedding> = Vec::new();
        let no_doc: Vec<TokenEmbedding> = Vec::new();

        let (score, matches) = scorer.compute(&no_query, &no_doc);

        assert_eq!(score, 0.0);
        assert!(matches.is_empty());
    }

    #[test]
    fn test_max_sim_score_identical() {
        let scorer = MaxSimScore::new(SimilarityMetric::Cosine);
        let unit_x = vec![1.0, 0.0, 0.0];
        let query = vec![TokenEmbedding::new(0, unit_x.clone())];
        let doc = vec![TokenEmbedding::new(0, unit_x)];

        let (score, matches) = scorer.compute(&query, &doc);

        assert!((score - 1.0).abs() < EPS);
        assert_eq!(matches.len(), 1);
        assert!((matches[0].similarity - 1.0).abs() < EPS);
    }

    #[test]
    fn test_max_sim_score_multiple_tokens() {
        let scorer = MaxSimScore::new(SimilarityMetric::Cosine);
        let query = vec![
            TokenEmbedding::new(0, vec![1.0, 0.0]),
            TokenEmbedding::new(1, vec![0.0, 1.0]),
        ];
        // Index 0 partially matches both query tokens; indices 1 and 2 match
        // the first and second query token exactly.
        let doc = vec![
            TokenEmbedding::new(0, vec![0.5, 0.5]),
            TokenEmbedding::new(1, vec![1.0, 0.0]),
            TokenEmbedding::new(2, vec![0.0, 1.0]),
        ];

        let (score, matches) = scorer.compute(&query, &doc);

        assert!(score > 1.9);
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].doc_token_idx, 1);
        assert!((matches[0].similarity - 1.0).abs() < EPS);
        assert_eq!(matches[1].doc_token_idx, 2);
        assert!((matches[1].similarity - 1.0).abs() < EPS);
    }

    #[test]
    fn test_max_sim_score_with_dot_product() {
        let scorer = MaxSimScore::new(SimilarityMetric::DotProduct);
        let query = vec![TokenEmbedding::new(0, vec![2.0, 3.0])];
        let doc = vec![TokenEmbedding::new(0, vec![4.0, 5.0])];

        // 2*4 + 3*5 = 23
        let score = scorer.score(&query, &doc);

        assert!((score - 23.0).abs() < EPS);
    }

    // ----- mock embedding provider -----

    #[tokio::test]
    async fn test_mock_token_embedding_provider() {
        let provider = MockTokenEmbeddingProvider::new(32);
        assert_eq!(provider.dimension(), 32);
        assert_eq!(provider.model_id(), "mock-token-embedder");

        let tokens = provider.embed_tokens("hello world test").await.unwrap();
        assert_eq!(tokens.len(), 3);

        let first = &tokens[0];
        assert_eq!(first.position, 0);
        assert_eq!(first.token, Some("hello".to_string()));
        assert_eq!(first.embedding.len(), 32);
    }

    #[tokio::test]
    async fn test_mock_token_embedding_provider_empty() {
        let provider = MockTokenEmbeddingProvider::new(32);

        for blank in ["", " "] {
            let result = provider.embed_tokens(blank).await;
            assert!(result.is_err());
        }
    }

    #[tokio::test]
    async fn test_mock_token_embedding_provider_batch() {
        let provider = MockTokenEmbeddingProvider::new(16);
        let texts = ["hello world", "foo bar baz"];

        let batches = provider.embed_tokens_batch(&texts).await.unwrap();

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 2);
        assert_eq!(batches[1].len(), 3);
    }

    #[tokio::test]
    async fn test_mock_provider_custom_model_id() {
        let provider = MockTokenEmbeddingProvider::with_model_id(32, "custom-model");
        assert_eq!(provider.model_id(), "custom-model");
    }

    // ----- in-memory store -----

    #[tokio::test]
    async fn test_in_memory_multi_vector_store_insert() {
        let mut store = InMemoryMultiVectorStore::new(32);
        let provider = MockTokenEmbeddingProvider::new(32);

        let doc = Document::new("test document");
        let tokens = provider.embed_tokens("test document").await.unwrap();
        store
            .insert(MultiVectorDocument::new(doc.clone(), tokens))
            .await
            .unwrap();
        assert_eq!(store.count().await, 1);

        let retrieved = store.get(&doc.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().document.content, "test document");
    }

    #[tokio::test]
    async fn test_in_memory_multi_vector_store_dimension_mismatch() {
        let mut store = InMemoryMultiVectorStore::new(32);
        // Two components where the store expects 32.
        let too_short = vec![TokenEmbedding::new(0, vec![0.1, 0.2])];
        let mv_doc = MultiVectorDocument::new(Document::new("test"), too_short);

        let result = store.insert(mv_doc).await;

        assert!(matches!(
            result,
            Err(VectorStoreError::DimensionMismatch { .. })
        ));
    }

    #[tokio::test]
    async fn test_in_memory_multi_vector_store_search() {
        let mut store = InMemoryMultiVectorStore::new(32);
        let provider = MockTokenEmbeddingProvider::new(32);
        insert_text(&mut store, &provider, "quick brown fox").await;
        insert_text(&mut store, &provider, "lazy dog sleeps").await;
        insert_text(&mut store, &provider, "quick fox jumps").await;

        let query_tokens = provider.embed_tokens("quick fox").await.unwrap();
        let results = store.search(&query_tokens, 2, None).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].rank, 0);
        assert_eq!(results[1].rank, 1);
        let top_content = &results[0].document.content;
        assert!(top_content.contains("quick") || top_content.contains("fox"));
    }

    #[tokio::test]
    async fn test_in_memory_multi_vector_store_search_with_min_score() {
        let mut store = InMemoryMultiVectorStore::new(32);
        let provider = MockTokenEmbeddingProvider::new(32);
        insert_text(&mut store, &provider, "test doc").await;

        let query_tokens = provider.embed_tokens("completely different").await.unwrap();
        // A threshold of 100 is far above anything two cosine terms can reach.
        let results = store.search(&query_tokens, 10, Some(100.0)).await.unwrap();

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_multi_vector_store_delete() {
        let mut store = InMemoryMultiVectorStore::new(32);
        let provider = MockTokenEmbeddingProvider::new(32);

        let doc = Document::new("test");
        let id = doc.id.clone();
        let tokens = provider.embed_tokens("test").await.unwrap();
        store
            .insert(MultiVectorDocument::new(doc, tokens))
            .await
            .unwrap();

        assert!(store.delete(&id).await.unwrap());
        assert_eq!(store.count().await, 0);
        // Deleting again reports that nothing was there.
        assert!(!store.delete(&id).await.unwrap());
    }

    #[tokio::test]
    async fn test_in_memory_multi_vector_store_clear() {
        let mut store = InMemoryMultiVectorStore::new(32);
        let provider = MockTokenEmbeddingProvider::new(32);
        for i in 0..5 {
            insert_text(&mut store, &provider, &format!("doc {i}")).await;
        }
        assert_eq!(store.count().await, 5);

        store.clear().await.unwrap();

        assert_eq!(store.count().await, 0);
    }

    #[tokio::test]
    async fn test_in_memory_multi_vector_store_batch_insert() {
        let mut store = InMemoryMultiVectorStore::new(32);
        let provider = MockTokenEmbeddingProvider::new(32);

        let mut docs = Vec::new();
        for i in 0..3 {
            let text = format!("document {i}");
            let tokens = provider.embed_tokens(&text).await.unwrap();
            docs.push(MultiVectorDocument::new(Document::new(text), tokens));
        }
        store.insert_batch(docs).await.unwrap();

        assert_eq!(store.count().await, 3);
    }

    #[test]
    fn test_multi_vector_search_result() {
        let token_matches = vec![TokenMatch::new(0, 1, 0.9), TokenMatch::new(1, 2, 0.8)];
        let result = MultiVectorSearchResult::new(Document::new("test"), 1.7, 0, token_matches);

        assert_eq!(result.document.content, "test");
        assert!((result.score - 1.7).abs() < EPS);
        assert_eq!(result.rank, 0);
        assert_eq!(result.token_matches.len(), 2);
    }

    #[test]
    fn test_store_metric_config() {
        let store = InMemoryMultiVectorStore::with_metric(64, SimilarityMetric::DotProduct);

        assert_eq!(store.dimension(), 64);
        assert_eq!(store.similarity_metric(), SimilarityMetric::DotProduct);
    }

    #[tokio::test]
    async fn test_search_empty_query() {
        let store = InMemoryMultiVectorStore::new(32);

        let results = store.search(&[], 10, None).await.unwrap();

        assert!(results.is_empty());
    }
}