use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use crate::error::{Error, Result};
use super::{download_model_files, l2_normalize, EmbedderConfig};
/// Text embedder backed by a BERT encoder plus its Hugging Face tokenizer.
///
/// Instances are created with [`BertEmbedder::load`].
pub struct BertEmbedder {
    /// Tokenizer configured at load time with truncation and batch-longest padding.
    tokenizer: Tokenizer,
    /// The BERT encoder whose token outputs are pooled into embeddings.
    model: BertModel,
    /// Device on which input tensors are allocated.
    device: Device,
    /// When `true`, every produced vector is passed through `l2_normalize`.
    normalize: bool,
    /// Truncation limit handed to the tokenizer during `load`. It is kept for reference but
    /// never read after construction, hence the `dead_code` allowance.
    #[allow(dead_code)]
    max_length: usize,
}
impl BertEmbedder {
pub fn load(config: &EmbedderConfig) -> Result<Self> {
let model_id = config.model.model_id();
let device = Device::Cpu;
let max_length = config.max_length.unwrap_or(512);
let files = download_model_files(
model_id,
&["config.json", "tokenizer.json", "model.safetensors"],
config.cache_dir.as_ref(),
)?;
let config_path = &files[0];
let tokenizer_path = &files[1];
let weights_path = &files[2];
let bert_config: BertConfig = {
let config_str = std::fs::read_to_string(config_path)
.map_err(|e| Error::InvalidConfig(format!("Failed to read config.json: {}", e)))?;
serde_json::from_str(&config_str).map_err(|e| {
Error::InvalidConfig(format!("Failed to parse BERT config: {}", e))
})?
};
let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
Error::InvalidConfig(format!("Failed to load tokenizer: {}", e))
})?;
let _ = tokenizer.with_truncation(Some(TruncationParams {
max_length,
..Default::default()
}));
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
..Default::default()
}));
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path.clone()], DType::F32, &device)
.map_err(|e| {
Error::InvalidConfig(format!("Failed to load model weights: {}", e))
})?
};
let model = BertModel::load(vb, &bert_config)
.map_err(|e| Error::InvalidConfig(format!("Failed to build BERT model: {}", e)))?;
Ok(Self {
model,
tokenizer,
device,
normalize: config.normalize,
max_length,
})
}
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
let results = self.embed_batch(&[text])?;
Ok(results.into_iter().next().unwrap())
}
pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let encodings = self.tokenizer.encode_batch(texts.to_vec(), true).map_err(|e| {
Error::InvalidConfig(format!("Tokenization failed: {}", e))
})?;
let batch_size = encodings.len();
let max_len = encodings.iter().map(|e| e.get_ids().len()).max().unwrap_or(0);
let mut all_ids = Vec::with_capacity(batch_size * max_len);
let mut all_type_ids = Vec::with_capacity(batch_size * max_len);
let mut all_attention_mask = Vec::with_capacity(batch_size * max_len);
for encoding in &encodings {
let ids = encoding.get_ids();
let type_ids = encoding.get_type_ids();
let attention = encoding.get_attention_mask();
let len = ids.len();
all_ids.extend_from_slice(ids);
all_type_ids.extend_from_slice(type_ids);
all_attention_mask.extend_from_slice(attention);
for _ in len..max_len {
all_ids.push(0);
all_type_ids.push(0);
all_attention_mask.push(0);
}
}
let input_ids = Tensor::new(all_ids.as_slice(), &self.device)
.and_then(|t| t.reshape((batch_size, max_len)))
.map_err(|e| Error::InvalidConfig(format!("Failed to create input tensor: {}", e)))?;
let token_type_ids = Tensor::new(all_type_ids.as_slice(), &self.device)
.and_then(|t| t.reshape((batch_size, max_len)))
.map_err(|e| Error::InvalidConfig(format!("Failed to create type_ids tensor: {}", e)))?;
let attention_mask_tensor = Tensor::new(all_attention_mask.as_slice(), &self.device)
.and_then(|t| t.reshape((batch_size, max_len)))
.map_err(|e| {
Error::InvalidConfig(format!("Failed to create attention_mask tensor: {}", e))
})?;
let output = self
.model
.forward(&input_ids, &token_type_ids, Some(&attention_mask_tensor))
.map_err(|e| Error::InvalidConfig(format!("BERT forward pass failed: {}", e)))?;
let embeddings = mean_pooling(&output, &attention_mask_tensor)
.map_err(|e| Error::InvalidConfig(format!("Mean pooling failed: {}", e)))?;
let mut results = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let emb = embeddings
.get(i)
.map_err(|e| Error::InvalidConfig(format!("Failed to get embedding {}: {}", i, e)))?;
let mut vec: Vec<f32> = emb
.to_vec1()
.map_err(|e| Error::InvalidConfig(format!("Failed to convert to vec: {}", e)))?;
if self.normalize {
l2_normalize(&mut vec);
}
results.push(vec);
}
Ok(results)
}
}
/// Averages token activations over the sequence axis, counting only positions whose
/// attention-mask entry is non-zero.
///
/// `output` must be rank 3, which `dims3` checks. It is treated as `(batch, seq_len, hidden)`,
/// with `attention_mask` as `(batch, seq_len)`. The result is a `(batch, hidden)` tensor in
/// `output`'s dtype. A row whose mask is entirely zero pools to all zeros.
///
/// # Errors
/// Propagates candle errors, e.g. when `output` is not rank 3 or the shapes do not broadcast.
fn mean_pooling(output: &Tensor, attention_mask: &Tensor) -> candle_core::Result<Tensor> {
    // Only used to reject inputs that are not rank 3.
    let (_batch, _seq_len, _hidden) = output.dims3()?;
    // (batch, seq_len) -> (batch, seq_len, 1), in the activations' dtype so the product below
    // type-checks for any model precision rather than only F32.
    let mask = attention_mask.to_dtype(output.dtype())?.unsqueeze(2)?;
    // (batch, hidden): sum of activations at unmasked positions.
    let summed = output.broadcast_mul(&mask)?.sum(1)?;
    // (batch, 1): number of unmasked positions per row. The counts are whole numbers, so a
    // floor of 1 leaves every non-empty row unchanged. It turns a fully masked row into
    // 0 / 1 = 0 instead of dividing by zero, and, unlike a tiny epsilon, 1 is representable
    // in f16/bf16.
    let counts = mask.sum(1)?.maximum(1f64)?;
    summed.broadcast_div(&counts)
}