use std::sync::Arc;
use super::cache::EmbeddingCache;
use super::modernbert::{ModernBertConfig, ModernBertModel};
use super::Result;
/// Settings controlling how a [`ModernBertEmbedder`] turns text into vectors.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Configuration handed to `ModernBertModel::load`; `max_seq_len` also bounds text truncation.
    pub model_config: ModernBertConfig,
    /// How token outputs are reduced to a single vector.
    pub pooling: PoolingStrategy,
    /// When `true`, output vectors are divided by their L2 norm.
    pub normalize: bool,
    /// Capacity passed to `EmbeddingCache::new`; `0` disables caching entirely.
    pub cache_size: usize,
    /// Maximum number of uncached texts sent to the model per call in `embed_batch`.
    pub batch_size: usize,
}
impl Default for EmbeddingConfig {
    /// Default model config, CLS pooling, normalised output, `cache_size` 10_000 and
    /// `batch_size` 32.
    fn default() -> Self {
        const DEFAULT_CACHE_SIZE: usize = 10_000;
        const DEFAULT_BATCH_SIZE: usize = 32;

        let model_config = ModernBertConfig::default();
        Self {
            model_config,
            pooling: PoolingStrategy::Cls,
            normalize: true,
            cache_size: DEFAULT_CACHE_SIZE,
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}
/// Strategy used to reduce a model's token outputs to one embedding vector.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Use the model's `embed` output (the default).
    #[default]
    Cls,
    /// Use the model's mean-pooled output.
    MeanPooling,
    /// Requested max pooling.
    ///
    /// NOTE(review): `ModernBertEmbedder` in this module currently computes this with mean
    /// pooling, so it yields the same vectors as `MeanPooling`.
    MaxPooling,
}
/// Produces dense text embeddings from a shared `ModernBertModel`, applying the pooling,
/// normalisation, batching and caching settings of an [`EmbeddingConfig`].
pub struct ModernBertEmbedder {
    /// Model handle; shareable with other owners through `model_arc`.
    model: Arc<ModernBertModel>,
    /// Settings consulted on every embedding call.
    config: EmbeddingConfig,
    /// Result cache keyed by input text; `None` when `config.cache_size` is 0.
    cache: Option<EmbeddingCache>,
}
impl ModernBertEmbedder {
    /// Loads a model from `config.model_config` and builds an embedder around it.
    ///
    /// A cache is attached only when `config.cache_size > 0`.
    ///
    /// # Errors
    /// Returns any error produced by `ModernBertModel::load`.
    pub fn new(config: EmbeddingConfig) -> Result<Self> {
        let model = ModernBertModel::load(config.model_config.clone())?;
        Ok(Self::from_model(Arc::new(model), config))
    }

    /// Builds an embedder around an already-loaded model, which may be shared elsewhere.
    ///
    /// A cache is attached only when `config.cache_size > 0`.
    pub fn from_model(model: Arc<ModernBertModel>, config: EmbeddingConfig) -> Self {
        let cache = (config.cache_size > 0).then(|| EmbeddingCache::new(config.cache_size));
        Self {
            model,
            config,
            cache,
        }
    }

    /// Length of the vectors this embedder produces (the model's hidden size).
    pub fn embedding_dim(&self) -> usize {
        self.model.hidden_size()
    }

    /// Embeds a single text, serving it from the cache when possible.
    ///
    /// The result is L2-normalised when `config.normalize` is set and is inserted into the
    /// cache (if any) before being returned.
    ///
    /// # Errors
    /// Propagates errors from the underlying model call.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        if let Some(cache) = &self.cache {
            if let Some(embedding) = cache.get(text) {
                return Ok(embedding.to_vec());
            }
        }
        let embedding = self.post_process(self.pool(text)?);
        if let Some(cache) = &self.cache {
            cache.insert(text, embedding.clone());
        }
        Ok(embedding)
    }

    /// Embeds many texts and returns one vector per input, in input order.
    ///
    /// Cached texts are served from the cache. The remaining texts go to the model in chunks
    /// of `config.batch_size`; a batch size of 0 is treated as 1. Each fresh result is
    /// post-processed like [`Self::embed`] and inserted into the cache.
    ///
    /// # Errors
    /// Propagates the first error returned by the underlying model calls.
    ///
    /// # Panics
    /// Panics if the model returns fewer vectors for a chunk than it was given texts.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached_indices: Vec<usize> = Vec::new();
        if let Some(cache) = &self.cache {
            for (i, text) in texts.iter().enumerate() {
                if let Some(embedding) = cache.get(text) {
                    results[i] = Some(embedding.to_vec());
                } else {
                    uncached_indices.push(i);
                }
            }
        } else {
            uncached_indices.extend(0..texts.len());
        }

        // `step_by(0)` / `chunks(0)` panic, so a zero batch size is clamped to 1.
        let batch_size = self.config.batch_size.max(1);
        for chunk_indices in uncached_indices.chunks(batch_size) {
            let chunk: Vec<&str> = chunk_indices.iter().map(|&i| texts[i]).collect();
            let embeddings = match self.config.pooling {
                PoolingStrategy::Cls => self.model.embed_batch(chunk.as_slice())?,
                PoolingStrategy::MeanPooling | PoolingStrategy::MaxPooling => chunk
                    .iter()
                    .map(|t| self.pool(t))
                    .collect::<Result<Vec<_>>>()?,
            };
            for (&idx, embedding) in chunk_indices.iter().zip(embeddings) {
                let embedding = self.post_process(embedding);
                if let Some(cache) = &self.cache {
                    cache.insert(texts[idx], embedding.clone());
                }
                results[idx] = Some(embedding);
            }
        }

        Ok(results
            .into_iter()
            .map(|r| r.expect("every input is filled from the cache or a model call"))
            .collect())
    }

    /// Embeds a document as `"{title} {content}"` (or just `content` without a title),
    /// truncated with `truncate_text` first.
    ///
    /// # Errors
    /// Propagates errors from [`Self::embed`].
    pub fn embed_document(&self, title: Option<&str>, content: &str) -> Result<Vec<f32>> {
        let text = match title {
            Some(t) => format!("{} {}", t, content),
            None => content.to_string(),
        };
        let truncated = self.truncate_text(&text);
        self.embed(&truncated)
    }

    /// Embeds a search query. This currently delegates directly to [`Self::embed`], with no
    /// query-specific prefix or truncation.
    ///
    /// # Errors
    /// Propagates errors from [`Self::embed`].
    pub fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        self.embed(query)
    }

    /// Cosine similarity of `a` and `b`; returns `0.0` if either vector has zero norm.
    ///
    /// Lengths are only checked in debug builds. In release builds, `zip` stops the dot
    /// product at the shorter slice, while each norm still covers its whole slice.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }

    /// Returns `embedding` scaled to unit L2 norm, or an unchanged copy if its norm is zero.
    pub fn normalize(embedding: &[f32]) -> Vec<f32> {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm == 0.0 {
            embedding.to_vec()
        } else {
            embedding.iter().map(|x| x / norm).collect()
        }
    }

    /// Runs the model on one text using the configured pooling strategy (no normalisation).
    fn pool(&self, text: &str) -> Result<Vec<f32>> {
        let pooled = match self.config.pooling {
            PoolingStrategy::Cls => self.model.embed(text)?,
            // NOTE(review): MaxPooling is not implemented and falls back to mean pooling, so
            // the two strategies currently produce identical vectors.
            PoolingStrategy::MeanPooling | PoolingStrategy::MaxPooling => {
                self.model.embed_mean_pooled(text)?
            }
        };
        Ok(pooled)
    }

    /// Applies L2 normalisation when `config.normalize` is set; otherwise returns the input.
    fn post_process(&self, embedding: Vec<f32>) -> Vec<f32> {
        if self.config.normalize {
            Self::normalize(&embedding)
        } else {
            embedding
        }
    }

    /// Shortens `text` to at most `max_seq_len * 4` bytes, preferring to cut at whitespace.
    ///
    /// The cut point is moved back to a UTF-8 character boundary so multi-byte text cannot
    /// cause a slicing panic. If the only whitespace in the kept prefix is at byte 0, the
    /// prefix is kept whole instead of collapsing to an empty string.
    fn truncate_text(&self, text: &str) -> String {
        // NOTE(review): the factor of 4 looks like a bytes-per-token heuristic — confirm.
        let max_bytes = self.config.model_config.max_seq_len.saturating_mul(4);
        if text.len() <= max_bytes {
            return text.to_string();
        }
        // Byte 0 is always a char boundary, so this loop terminates.
        let mut end = max_bytes;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        let truncated = &text[..end];
        match truncated.rfind(char::is_whitespace) {
            Some(pos) if pos > 0 => truncated[..pos].to_string(),
            _ => truncated.to_string(),
        }
    }

    /// Borrows the underlying model.
    pub fn model(&self) -> &ModernBertModel {
        &self.model
    }

    /// Returns a new shared handle to the underlying model.
    pub fn model_arc(&self) -> Arc<ModernBertModel> {
        Arc::clone(&self.model)
    }

    /// Borrows the configuration this embedder was built with.
    pub fn config(&self) -> &EmbeddingConfig {
        &self.config
    }

    /// Clears the cache; does nothing when caching is disabled.
    pub fn clear_cache(&self) {
        if let Some(cache) = &self.cache {
            cache.clear();
        }
    }

    /// The cache's current `len()`, or `None` when caching is disabled.
    pub fn cache_stats(&self) -> Option<usize> {
        self.cache.as_ref().map(|c| c.len())
    }
}
impl std::fmt::Debug for ModernBertEmbedder {
    /// Shows the embedding dimension, pooling, normalisation flag and current cache length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ModernBertEmbedder");
        out.field("embedding_dim", &self.embedding_dim());
        out.field("pooling", &self.config.pooling);
        out.field("normalize", &self.config.normalize);
        // Despite the label, this is the cache's current `len()`, not `config.cache_size`.
        let cached = self.cache.as_ref().map(|cache| cache.len());
        out.field("cache_size", &cached);
        out.finish()
    }
}
/// One embedded document as returned by `BatchDocumentEmbedder::embed_documents`.
#[derive(Debug, Clone)]
pub struct DocumentEmbedding {
    /// The vector produced for the document's title-plus-content text.
    pub embedding: Vec<f32>,
    /// Identifier copied from the input triple.
    pub document_id: String,
    /// Title copied from the input triple, if one was given.
    pub title: Option<String>,
}
/// Embeds `(id, title, content)` document triples through one batched
/// [`ModernBertEmbedder::embed_batch`] call.
pub struct BatchDocumentEmbedder {
    /// Embedder that performs the actual batch embedding.
    embedder: ModernBertEmbedder,
}
impl BatchDocumentEmbedder {
pub fn new(embedder: ModernBertEmbedder) -> Self {
Self { embedder }
}
pub fn embed_documents(
&self,
documents: &[(String, Option<String>, String)], ) -> Result<Vec<DocumentEmbedding>> {
let texts: Vec<String> = documents
.iter()
.map(|(_, title, content)| match title {
Some(t) => format!("{} {}", t, content),
None => content.clone(),
})
.collect();
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
let embeddings = self.embedder.embed_batch(&text_refs)?;
Ok(documents
.iter()
.zip(embeddings)
.map(|((id, title, _), embedding)| DocumentEmbedding {
embedding,
document_id: id.clone(),
title: title.clone(),
})
.collect())
}
pub fn embedding_dim(&self) -> usize {
self.embedder.embedding_dim()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        let embedding: [f32; 2] = [3.0, 4.0];
        let normalized = ModernBertEmbedder::normalize(&embedding);
        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);
    }

    /// A zero vector has no direction, so `normalize` must return it unchanged (no NaNs).
    #[test]
    fn test_normalize_zero_vector() {
        let zeros: [f32; 3] = [0.0; 3];
        assert_eq!(ModernBertEmbedder::normalize(&zeros), vec![0.0; 3]);
    }

    #[test]
    fn test_normalize_empty() {
        let empty: [f32; 0] = [];
        assert!(ModernBertEmbedder::normalize(&empty).is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let a: [f32; 3] = [1.0, 0.0, 0.0];
        let b: [f32; 3] = [1.0, 0.0, 0.0];
        assert!((ModernBertEmbedder::cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let c: [f32; 3] = [0.0, 1.0, 0.0];
        assert!((ModernBertEmbedder::cosine_similarity(&a, &c) - 0.0).abs() < 1e-6);
        let d: [f32; 3] = [-1.0, 0.0, 0.0];
        assert!((ModernBertEmbedder::cosine_similarity(&a, &d) - (-1.0)).abs() < 1e-6);
    }

    /// A zero-norm operand must yield 0.0 rather than NaN from division by zero.
    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a: [f32; 2] = [1.0, 2.0];
        let zero: [f32; 2] = [0.0, 0.0];
        assert_eq!(ModernBertEmbedder::cosine_similarity(&a, &zero), 0.0);
        assert_eq!(ModernBertEmbedder::cosine_similarity(&zero, &zero), 0.0);
    }

    #[test]
    fn test_cosine_similarity_normalized() {
        let a = ModernBertEmbedder::normalize(&[3.0_f32, 4.0]);
        let b = ModernBertEmbedder::normalize(&[4.0_f32, 3.0]);
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let cosine = ModernBertEmbedder::cosine_similarity(&a, &b);
        assert!((dot - cosine).abs() < 1e-6);
    }
}