use super::provider::{EmbeddingProvider, EmbeddingResult};
use anyhow::{Context, Result};
use async_trait::async_trait;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
use tokenizers::Tokenizer;
/// Base URL of the Hugging Face Hub; file download URLs are built from it.
const HF_BASE_URL: &str = "https://huggingface.co";
/// `config.json` of all-MiniLM-L6-v2, embedded into the binary at compile time.
const MINILM_CONFIG: &str = include_str!("../../models/all-MiniLM-L6-v2/config.json");
/// `tokenizer.json` of all-MiniLM-L6-v2, embedded into the binary at compile time.
const MINILM_TOKENIZER: &str = include_str!("../../models/all-MiniLM-L6-v2/tokenizer.json");
/// Hugging Face Hub repository id of the model whose config and tokenizer are bundled.
const MINILM_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Safetensors weights of all-MiniLM-L6-v2, embedded only when the
/// `bundled-weights` feature is enabled; without it they are downloaded at runtime.
#[cfg(feature = "bundled-weights")]
const MINILM_WEIGHTS: &[u8] = include_bytes!("../../models/all-MiniLM-L6-v2/model.safetensors");
/// Embedding provider that runs a BERT-style model locally with candle.
///
/// Token embeddings are mean-pooled over non-padding positions and then
/// L2-normalized (see `mean_pooling` and `normalize`).
pub struct LocalEmbedder {
    /// The loaded BERT encoder.
    model: BertModel,
    /// Tokenizer parsed from the model's `tokenizer.json`.
    tokenizer: Tokenizer,
    /// Device on which input tensors are created (`new` always selects the CPU).
    device: Device,
    /// Embedding width, taken from the model config's `hidden_size`.
    dimensions: usize,
    /// The model id passed to `new`, as reported by `model_name()`.
    model_name: String,
}
/// Hand-written `Debug`: the model is shown as a placeholder and the tokenizer
/// is omitted; the trailing `..` (via `finish_non_exhaustive`) makes that
/// omission visible instead of silently hiding the field.
impl std::fmt::Debug for LocalEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalEmbedder")
            .field("model", &"<BertModel>")
            .field("dimensions", &self.dimensions)
            .field("model_name", &self.model_name)
            // The device was previously dropped from the output even though it
            // is cheap to print and useful when diagnosing placement issues.
            .field("device", &self.device)
            .finish_non_exhaustive()
    }
}
fn download_hf_file(
model_id: &str,
filename: &str,
cache_dir: &std::path::Path,
) -> Result<PathBuf> {
let safe_model_id = model_id.replace('/', "--");
let model_cache = cache_dir.join(&safe_model_id);
std::fs::create_dir_all(&model_cache)?;
let file_path = model_cache.join(filename);
if file_path.exists() {
println!(" Using cached: {}", filename);
return Ok(file_path);
}
let url = format!("{}/{}/resolve/main/{}", HF_BASE_URL, model_id, filename);
println!(" Downloading: {}", filename);
let response =
reqwest::blocking::get(&url).with_context(|| format!("Failed to download {}", url))?;
if !response.status().is_success() {
anyhow::bail!(
"Failed to download {}: HTTP {}",
filename,
response.status()
);
}
let bytes = response.bytes()?;
let mut file = std::fs::File::create(&file_path)?;
file.write_all(&bytes)?;
Ok(file_path)
}
impl LocalEmbedder {
pub fn new(model_id: &str, cache_dir: &str, verbose: bool) -> Result<Self> {
use std::time::Instant;
let total_start = Instant::now();
let device = Device::Cpu;
let cache_path = PathBuf::from(cache_dir);
std::fs::create_dir_all(&cache_path)?;
if verbose {
println!("Loading model: {}", model_id);
}
let is_bundled = model_id == MINILM_MODEL_ID || model_id == "all-MiniLM-L6-v2";
let (config_str, tokenizer_str) = if is_bundled {
if verbose {
println!(" Using bundled config and tokenizer");
}
(MINILM_CONFIG.to_string(), MINILM_TOKENIZER.to_string())
} else {
let config_path = download_hf_file(model_id, "config.json", &cache_path)?;
let tokenizer_path = download_hf_file(model_id, "tokenizer.json", &cache_path)?;
let config_str = std::fs::read_to_string(&config_path)?;
let tokenizer_str = std::fs::read_to_string(&tokenizer_path)?;
(config_str, tokenizer_str)
};
let config: Config = serde_json::from_str(&config_str)?;
let dimensions = config.hidden_size;
let tokenizer_start = Instant::now();
let tokenizer = Tokenizer::from_str(&tokenizer_str)
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
if verbose {
println!(" Tokenizer loaded in {:?}", tokenizer_start.elapsed());
}
let weights_start = Instant::now();
#[cfg(feature = "bundled-weights")]
let vb = if is_bundled {
if verbose {
println!(" Loading bundled weights into memory...");
}
VarBuilder::from_buffered_safetensors(
MINILM_WEIGHTS.to_vec(),
candle_core::DType::F32,
&device,
)?
} else {
let weights_path = download_hf_file(model_id, "model.safetensors", &cache_path)
.or_else(|_| download_hf_file(model_id, "pytorch_model.bin", &cache_path))
.context("Failed to download model weights")?;
if weights_path
.extension()
.map(|e| e == "safetensors")
.unwrap_or(false)
{
unsafe {
VarBuilder::from_mmaped_safetensors(
&[weights_path],
candle_core::DType::F32,
&device,
)?
}
} else {
VarBuilder::from_pth(&weights_path, candle_core::DType::F32, &device)?
}
};
#[cfg(not(feature = "bundled-weights"))]
let vb = {
let weights_path = if is_bundled {
download_hf_file(MINILM_MODEL_ID, "model.safetensors", &cache_path)
.context("Failed to download model weights")?
} else {
download_hf_file(model_id, "model.safetensors", &cache_path)
.or_else(|_| download_hf_file(model_id, "pytorch_model.bin", &cache_path))
.context("Failed to download model weights")?
};
if verbose {
println!(" Loading weights into memory...");
}
if weights_path
.extension()
.map(|e| e == "safetensors")
.unwrap_or(false)
{
unsafe {
VarBuilder::from_mmaped_safetensors(
&[weights_path],
candle_core::DType::F32,
&device,
)?
}
} else {
VarBuilder::from_pth(&weights_path, candle_core::DType::F32, &device)?
}
};
if verbose {
println!(" Weights loaded in {:?}", weights_start.elapsed());
}
let model_start = Instant::now();
let model = BertModel::load(vb, &config)?;
if verbose {
println!(" Model initialized in {:?}", model_start.elapsed());
println!(" Total model load time: {:?}", total_start.elapsed());
}
Ok(Self {
model,
tokenizer,
device,
dimensions,
model_name: model_id.to_string(),
})
}
fn mean_pooling(&self, embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let mask = attention_mask
.unsqueeze(2)?
.to_dtype(candle_core::DType::F32)?;
let masked = embeddings.broadcast_mul(&mask)?;
let sum = masked.sum(1)?;
let count = mask.sum(1)?.clamp(1e-9, f64::MAX)?;
Ok(sum.broadcast_div(&count)?)
}
fn normalize(&self, embeddings: &Tensor) -> Result<Tensor> {
let norm = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;
Ok(embeddings.broadcast_div(&norm)?)
}
}
#[async_trait]
impl EmbeddingProvider for LocalEmbedder {
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Embeds a single text by delegating to `embed_batch` with a one-element batch.
    async fn embed(&self, text: &str) -> Result<EmbeddingResult> {
        let results = self.embed_batch(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .context("Empty result from local model")
    }

    /// Embeds `texts` in a single forward pass, returning one result per text in order.
    ///
    /// Each encoding is right-padded (id 0, mask 0) to the longest encoding in
    /// the batch. The hidden states are mean-pooled over unmasked positions and
    /// then L2-normalized. `token_count` is the number of non-padding tokens in
    /// each encoding, special tokens included. If `tokenizer.json` configures
    /// its own padding, `encoding.len()` would include that padding.
    ///
    /// NOTE(review): tokenization and the forward pass are synchronous and
    /// CPU-bound, so this blocks the awaiting executor thread for the whole
    /// batch. Consider `spawn_blocking` if callers are latency-sensitive.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingResult>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

        let batch_size = encodings.len();
        let max_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);

        // Flat row-major (batch_size x max_len) buffers; `resize` right-pads
        // each row with zeros, so padded positions are masked out.
        let mut input_ids = Vec::with_capacity(batch_size * max_len);
        let mut attention_mask = Vec::with_capacity(batch_size * max_len);
        for (row, encoding) in encodings.iter().enumerate() {
            let row_end = (row + 1) * max_len;
            input_ids.extend(encoding.get_ids().iter().map(|&id| i64::from(id)));
            input_ids.resize(row_end, 0);
            attention_mask.extend(encoding.get_attention_mask().iter().map(|&m| i64::from(m)));
            attention_mask.resize(row_end, 0);
        }
        let input_ids = Tensor::from_vec(input_ids, (batch_size, max_len), &self.device)?;
        let attention_mask = Tensor::from_vec(attention_mask, (batch_size, max_len), &self.device)?;
        let token_type_ids = input_ids.zeros_like()?;

        let hidden_states = self
            .model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
        let pooled = self.mean_pooling(&hidden_states, &attention_mask)?;
        let rows = self.normalize(&pooled)?.to_vec2::<f32>()?;

        Ok(rows
            .into_iter()
            .zip(&encodings)
            .map(|(embedding, encoding)| EmbeddingResult {
                embedding,
                token_count: Some(
                    encoding
                        .get_attention_mask()
                        .iter()
                        .filter(|&&m| m != 0)
                        .count(),
                ),
            })
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, empty scratch directory unique to this test process and `tag`.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "local_embedder_{}_{}",
            tag,
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("create scratch dir");
        dir
    }

    #[test]
    fn download_hf_file_returns_cached_file_without_network() {
        let cache_dir = scratch_dir("cache_hit");
        // Files are cached under `<cache>/<org>--<name>/<filename>`.
        let model_dir = cache_dir.join("example-org--example-model");
        std::fs::create_dir_all(&model_dir).expect("create model dir");
        let cached = model_dir.join("config.json");
        std::fs::write(&cached, b"{}").expect("seed cache");

        let path = download_hf_file("example-org/example-model", "config.json", &cache_dir)
            .expect("a cache hit must not touch the network");

        assert_eq!(path, cached);
        assert_eq!(std::fs::read(&path).expect("read cached file"), b"{}");
        std::fs::remove_dir_all(&cache_dir).expect("clean up scratch dir");
    }

    #[test]
    fn bundled_config_parses_with_expected_width() {
        let config: Config = serde_json::from_str(MINILM_CONFIG).expect("bundled config parses");
        // all-MiniLM-L6-v2 produces 384-dimensional embeddings.
        assert_eq!(config.hidden_size, 384);
    }

    #[test]
    fn bundled_tokenizer_parses() {
        assert!(Tokenizer::from_str(MINILM_TOKENIZER).is_ok());
    }
}