use crate::embeddings::TextEmbedder;
use anyhow::{anyhow, Context, Result};
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
/// Selects which embedding model repository a [`CandleEmbedder`] is built from.
#[derive(Debug, Clone)]
pub enum CandleModel {
    /// The `sentence-transformers/all-MiniLM-L6-v2` repository.
    AllMiniLML6V2,
    /// The `BAAI/bge-small-en` repository.
    BgeSmallEn,
    /// A caller-supplied repository id together with a revision string.
    Custom { repo_id: String, revision: String },
}

impl CandleModel {
    /// Returns the repository id associated with this variant.
    ///
    /// Built-in variants map to a fixed id; `Custom` yields the id it carries.
    fn repo_id(&self) -> &str {
        match self {
            Self::Custom { repo_id: id, .. } => id.as_str(),
            Self::BgeSmallEn => "BAAI/bge-small-en",
            Self::AllMiniLML6V2 => "sentence-transformers/all-MiniLM-L6-v2",
        }
    }
}
/// Text embedder backed by a BERT model loaded through candle.
///
/// Instances are created with `CandleEmbedder::new`, which fills every field.
pub struct CandleEmbedder {
// BERT model built from the downloaded config and weights.
model: BertModel,
// Tokenizer loaded from the repository's `tokenizer.json`.
tokenizer: Tokenizer,
// Device that input tensors are created on; `new` sets this to `Device::Cpu`.
device: Device,
// When true, `embed` normalizes the pooled vector; `new` sets this to `true`
// and `set_normalize` changes it.
normalize: bool,
}
impl CandleEmbedder {
pub fn new(model: CandleModel) -> Result<Self> {
let device = Device::Cpu;
let api = Api::new().context("Failed to create HuggingFace API")?;
let repo = Repo::new(model.repo_id().to_string(), RepoType::Model);
let repo = api.repo(repo);
let tokenizer_path = repo
.get("tokenizer.json")
.context("Failed to download tokenizer")?;
let tokenizer = Tokenizer::from_file(tokenizer_path)
.map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
let config_path = repo
.get("config.json")
.context("Failed to download config")?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_path)?)
.context("Failed to parse config")?;
let weights_path = repo
.get("pytorch_model.bin")
.or_else(|_| repo.get("model.safetensors"))
.context("Failed to download model weights")?;
let vb = if weights_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? }
} else {
VarBuilder::from_pth(&weights_path, DTYPE, &device)?
};
let model = BertModel::load(vb, &config)?;
Ok(Self {
model,
tokenizer,
device,
normalize: true, })
}
pub fn set_normalize(&mut self, normalize: bool) {
self.normalize = normalize;
}
fn mean_pool(&self, embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let (_batch_size, seq_len, hidden_size) = embeddings.dims3()?;
let mask = attention_mask.unsqueeze(2)?;
let mask = mask.expand(&[1, seq_len, hidden_size])?;
let masked = embeddings.mul(&mask)?;
let sum = masked.sum(1)?;
let mask_sum = mask.sum(1)?;
let mean = sum.broadcast_div(&mask_sum)?;
Ok(mean)
}
fn normalize_embeddings(&self, embeddings: &Tensor) -> Result<Tensor> {
let norm = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;
let normalized = embeddings.broadcast_div(&norm)?;
Ok(normalized)
}
}
impl TextEmbedder for CandleEmbedder {
    /// Returns the embedding width reported by this embedder.
    ///
    /// NOTE(review): this is hard-coded to 384 and is not derived from the loaded
    /// model's config, so it is likely wrong for `CandleModel::Custom` models with
    /// a different hidden size — confirm and derive it from the config if needed.
    fn dimension(&self) -> Result<usize> {
        Ok(384)
    }

    /// Embeds a single text into one vector.
    ///
    /// The text is tokenized with special tokens, run through the BERT model,
    /// mean-pooled over the attention mask, and normalized when normalization is
    /// enabled.
    ///
    /// # Errors
    ///
    /// Returns an error if tokenization fails or any tensor operation fails.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow!("Tokenization failed: {}", e))?;

        // Each tensor gets a leading batch dimension of 1.
        let input_ids = Tensor::new(encoding.get_ids(), &self.device)?.unsqueeze(0)?;
        let token_type_ids = Tensor::new(encoding.get_type_ids(), &self.device)?.unsqueeze(0)?;
        let attention_mask =
            Tensor::new(encoding.get_attention_mask(), &self.device)?.unsqueeze(0)?;

        // The second argument of `forward` is the token type ids, not the mask;
        // the mask goes third so any padding the tokenizer adds is not attended to.
        let embeddings =
            self.model
                .forward(&input_ids, &token_type_ids, Some(&attention_mask))?;

        // Pooling multiplies the mask with the embeddings, so both need one dtype.
        let pool_mask = attention_mask.to_dtype(embeddings.dtype())?;
        let pooled = self.mean_pool(&embeddings, &pool_mask)?;
        let final_embedding = if self.normalize {
            self.normalize_embeddings(&pooled)?
        } else {
            pooled
        };

        // Drop the batch dimension to get a flat vector.
        let embedding_vec = final_embedding.squeeze(0)?.to_vec1::<f32>()?;
        Ok(embedding_vec)
    }

    /// Embeds each text in order by calling `embed` on it.
    ///
    /// Stops at the first failure and returns that error.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|text| self.embed(text)).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Both tests build a `CandleEmbedder`, whose constructor fetches model files
    // from the Hub; presumably that is why they are ignored by default.

    /// A single embedding has 384 entries and approximately unit L2 norm.
    #[test]
    #[ignore]
    fn test_candle_embedder() {
        let embedder = CandleEmbedder::new(CandleModel::AllMiniLML6V2).unwrap();

        let vector = embedder.embed("This is a test sentence.").unwrap();
        assert_eq!(vector.len(), 384);

        let squared_sum: f32 = vector.iter().map(|v| v * v).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 0.01, "Embedding should be normalized");
    }

    /// Batch embedding yields one 384-entry vector per input text.
    #[test]
    #[ignore]
    fn test_batch_embed() {
        let embedder = CandleEmbedder::new(CandleModel::AllMiniLML6V2).unwrap();

        let inputs = ["First sentence", "Second sentence", "Third sentence"];
        let vectors = embedder.embed_batch(&inputs).unwrap();

        assert_eq!(vectors.len(), 3);
        assert_eq!(vectors[0].len(), 384);
    }
}