use crate::{
embeddings::{normalize_l2, utils::tokenize_batch},
models::modernbert::{Config, ModernBert},
Dtype,
};
use anyhow::Error as E;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use crate::embeddings::{embed::EmbeddingResult, select_device};
use super::{
bert::BertEmbed,
pooling::{ModelOutput, PooledOutputType, Pooling},
};
pub struct ModernBertEmbedder {
pub model: ModernBert,
pub tokenizer: Tokenizer,
pub device: Device,
pub pooling: Pooling,
}
impl Default for ModernBertEmbedder {
fn default() -> Self {
Self::new(
"nomic-ai/modernbert-embed-base".to_string(),
None,
None,
None,
)
.unwrap()
}
}
impl ModernBertEmbedder {
pub fn new(
model_id: String,
revision: Option<String>,
token: Option<&str>,
dtype: Option<Dtype>,
) -> Result<Self, E> {
let (config_filename, tokenizer_filename, weights_filename) = {
let api = ApiBuilder::from_env()
.with_token(token.map(|s| s.to_string()))
.build()
.unwrap();
let api = match revision {
Some(rev) => api.repo(Repo::with_revision(model_id, hf_hub::RepoType::Model, rev)),
None => api.repo(hf_hub::Repo::new(
model_id.to_string(),
hf_hub::RepoType::Model,
)),
};
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = match api.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match api.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
return Err(anyhow::Error::msg(format!(
"Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}",
e
)));
}
},
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
let trunc = TruncationParams {
strategy: tokenizers::TruncationStrategy::LongestFirst,
max_length: config.max_position_embeddings as usize,
..Default::default()
};
tokenizer
.with_padding(Some(pp))
.with_truncation(Some(trunc))
.unwrap();
let device = select_device();
let dtype = match dtype {
Some(Dtype::F16) => DType::F16,
Some(Dtype::F32) => DType::F32,
_ => DType::F32,
};
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device)? }
} else {
println!("Can't find model.safetensors, loading from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, dtype, &device)?
};
let model = ModernBert::load(vb, &config)?;
let tokenizer = tokenizer;
Ok(ModernBertEmbedder {
model,
tokenizer,
device,
pooling: Pooling::Mean,
})
}
fn embed(
&self,
text_batch: &[&str],
batch_size: Option<usize>,
) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
let batch_size = batch_size.unwrap_or(32);
let mut encodings: Vec<EmbeddingResult> = Vec::new();
for mini_text_batch in text_batch.chunks(batch_size) {
let (token_ids, attention_mask) =
tokenize_batch(&self.tokenizer, mini_text_batch, &self.device)?;
let embeddings: Tensor = self
.model
.forward(&token_ids, &attention_mask)?
.to_dtype(DType::F32)
.unwrap();
let attention_mask = PooledOutputType::from(attention_mask);
let attention_mask = Some(&attention_mask);
let model_output = ModelOutput::Tensor(embeddings.clone());
let pooled_output = match self.pooling {
Pooling::Cls => self.pooling.pool(&model_output, None)?,
Pooling::Mean => self.pooling.pool(&model_output, attention_mask)?,
Pooling::LastToken => self.pooling.pool(&model_output, attention_mask)?,
};
let pooled_output = pooled_output.to_tensor()?;
let embeddings = normalize_l2(pooled_output)?;
let batch_encodings = embeddings.to_vec2::<f32>()?;
encodings.extend(
batch_encodings
.iter()
.map(|x| EmbeddingResult::DenseVector(x.to_vec())),
);
}
Ok(encodings)
}
fn embed_late_chunking(
&self,
text_batch: &[&str],
batch_size: Option<usize>,
) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
let batch_size = batch_size.unwrap_or(32);
let mut results = Vec::new();
for mini_text_batch in text_batch.chunks(batch_size) {
let tokens = self
.tokenizer
.encode_batch(mini_text_batch.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
tokens
})
.collect::<Vec<_>>();
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
tokens
})
.collect::<Vec<_>>();
let sequence_lengths: Vec<usize> = token_ids.iter().map(|seq| seq.len()).collect();
let cumulative_seq_lengths: Vec<usize> = sequence_lengths
.iter()
.scan(0, |acc, &x| {
*acc += x;
Some(*acc)
})
.collect();
let token_ids_merged = token_ids.concat();
let attention_mask_merged = attention_mask.concat();
let device = &self.device;
let token_ids_tensor =
Tensor::new(token_ids_merged.as_slice(), device)?.unsqueeze(0)?;
let attention_mask_tensor =
Tensor::new(attention_mask_merged.as_slice(), device)?.unsqueeze(0)?;
let token_type_ids = token_ids_tensor.zeros_like()?;
let embeddings = self.model.forward(&token_ids_tensor, &token_type_ids)?;
let attention_mask_tensor = PooledOutputType::from(attention_mask_tensor);
for (i, &end_idx) in cumulative_seq_lengths.iter().enumerate() {
let start_idx = if i == 0 {
0
} else {
cumulative_seq_lengths[i - 1]
};
let seq_embeddings = embeddings.narrow(1, start_idx, end_idx - start_idx)?;
let seq_attention_mask =
attention_mask_tensor
.to_tensor()?
.narrow(1, start_idx, end_idx - start_idx)?;
let model_output = ModelOutput::Tensor(seq_embeddings);
let pooled_output = Pooling::Mean.pool(
&model_output,
Some(&PooledOutputType::from(seq_attention_mask)),
)?;
let pooled_tensor = pooled_output.to_tensor()?;
let normalized = normalize_l2(pooled_tensor)?.squeeze(0)?;
let embedding_vec = normalized.to_vec1::<f32>().unwrap();
results.push(EmbeddingResult::DenseVector(embedding_vec));
}
}
Ok(results)
}
}
impl BertEmbed for ModernBertEmbedder {
fn embed(
&self,
text_batch: &[&str],
batch_size: Option<usize>,
late_chunking: Option<bool>,
) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
if late_chunking.unwrap_or(false) {
self.embed_late_chunking(text_batch, batch_size)
} else {
self.embed(text_batch, batch_size)
}
}
}