use super::Embedder;
use crate::{Error, Result};
use std::path::PathBuf;
#[cfg(feature = "hf-embeddings")]
use {
candle_core::{Device, Tensor},
candle_nn::VarBuilder,
candle_transformers::models::bert::{BertModel, Config, DTYPE},
hf_hub::{api::tokio::Api, Repo, RepoType},
std::sync::Arc,
tokenizers::Tokenizer,
};
/// Text embedder backed by a BERT model obtained from the Hugging Face Hub.
///
/// The `model`, `tokenizer` and `device` fields only exist when the crate is built
/// with the `hf-embeddings` feature; without it the struct carries just the model
/// name and the embedding dimension.
pub struct HFEmbedder {
/// BERT model built from the downloaded weights.
#[cfg(feature = "hf-embeddings")]
model: Arc<BertModel>,
/// Tokenizer loaded from the repository's `tokenizer.json`.
#[cfg(feature = "hf-embeddings")]
tokenizer: Arc<Tokenizer>,
/// Device on which input tensors are created.
#[cfg(feature = "hf-embeddings")]
device: Device,
/// Repository name the embedder was created with.
model_name: String,
/// Length of each embedding vector (the model config's `hidden_size`).
dimension: usize,
}
impl HFEmbedder {
/// Creates an embedder for `model_name` without overriding the cache location.
///
/// Equivalent to calling [`HFEmbedder::new_with_options`] with `cache_dir` set to `None`.
///
/// # Errors
///
/// Propagates any error returned by [`HFEmbedder::new_with_options`].
#[cfg(feature = "hf-embeddings")]
pub async fn new(model_name: &str) -> Result<Self> {
    let cache_override: Option<PathBuf> = None;
    Self::new_with_options(model_name, cache_override).await
}
/// Fetches the files of `model_name` through the Hugging Face Hub client and builds a
/// CPU embedder from them.
///
/// # Arguments
///
/// * `model_name` - Hub repository id, e.g. `sentence-transformers/all-MiniLM-L6-v2`.
/// * `cache_dir` - Optional Hugging Face home directory. When set it is exported as the
///   process-wide `HF_HOME` variable (which [`HFEmbedder::cache_dir`] reads) and its
///   `hub` subdirectory is handed to the Hub client as the download cache.
///
/// If the `HF_TOKEN` environment variable is set, its value is passed to the Hub client
/// as the access token.
///
/// # Errors
///
/// Returns [`Error::Embedding`] when no tokio runtime is active, when `config.json`,
/// `tokenizer.json` or `model.safetensors` cannot be fetched, or when the config,
/// tokenizer or weights fail to load. Only safetensors weights are supported.
#[cfg(feature = "hf-embeddings")]
pub async fn new_with_options(model_name: &str, cache_dir: Option<PathBuf>) -> Result<Self> {
    use hf_hub::api::tokio::ApiBuilder;
    use tokio::runtime::Handle;

    // The Hub client used below is the tokio flavour; report a missing runtime as an
    // error up front.
    let _runtime = Handle::try_current()
        .map_err(|_| Error::Embedding("No tokio runtime found".to_string()))?;

    // Export the path as an OS string so a non-UTF-8 directory is not replaced by "".
    // This must happen before the builder is created, matching the original ordering.
    if let Some(dir) = &cache_dir {
        std::env::set_var("HF_HOME", dir);
    }

    let mut builder = ApiBuilder::new();
    if let Some(dir) = cache_dir {
        // Pass the directory to the client explicitly as well, so the cache location
        // does not depend on the client consulting `HF_HOME` on its own.
        builder = builder.with_cache_dir(dir.join("hub"));
    }
    if let Ok(token) = std::env::var("HF_TOKEN") {
        eprintln!("✓ Using HF_TOKEN for authentication");
        // Previously the token was only announced, never given to the client.
        builder = builder.with_token(Some(token));
    }
    let api = builder
        .build()
        .map_err(|e| Error::Embedding(format!("Failed to initialize HF Hub API: {}", e)))?;
    let repo = api.repo(Repo::new(model_name.to_string(), RepoType::Model));

    eprintln!("Downloading model files for {}...", model_name);
    let config_path = repo
        .get("config.json")
        .await
        .map_err(|e| Error::Embedding(format!("Failed to download config.json: {}", e)))?;
    let tokenizer_path = repo
        .get("tokenizer.json")
        .await
        .map_err(|e| Error::Embedding(format!("Failed to download tokenizer.json: {}", e)))?;
    // No `pytorch_model.bin` fallback: the loader below only understands safetensors,
    // and the old fallback blocked on the runtime from inside async code, which panics.
    let weights_path = repo.get("model.safetensors").await.map_err(|e| {
        Error::Embedding(format!(
            "Failed to download model weights: {} (only model.safetensors is supported)",
            e
        ))
    })?;
    eprintln!("✓ Model files cached locally");

    let config_str = std::fs::read_to_string(&config_path)
        .map_err(|e| Error::Embedding(format!("Failed to read config: {}", e)))?;
    let config: Config = serde_json::from_str(&config_str)
        .map_err(|e| Error::Embedding(format!("Failed to parse config: {}", e)))?;
    let dimension = config.hidden_size;

    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| Error::Embedding(format!("Failed to load tokenizer: {}", e)))?;

    let device = Device::Cpu;
    // SAFETY: `from_mmaped_safetensors` memory-maps the weights file, so this relies on
    // the file not being modified or truncated while the mapping is alive.
    // NOTE(review): nothing in this function enforces that; it is assumed of the Hub cache.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device) }
        .map_err(|e| Error::Embedding(format!("Failed to load weights: {}", e)))?;
    let model = BertModel::load(vb, &config)
        .map_err(|e| Error::Embedding(format!("Failed to create model: {}", e)))?;
    eprintln!("✓ Model loaded successfully");

    Ok(HFEmbedder {
        model: Arc::new(model),
        tokenizer: Arc::new(tokenizer),
        device,
        model_name: model_name.to_string(),
        dimension,
    })
}
/// Returns the Hugging Face home directory used for cached model files.
///
/// Resolution order:
/// 1. the `HF_HOME` environment variable, taken verbatim;
/// 2. `$HOME/.cache/huggingface`;
/// 3. `./.cache/huggingface` when `HOME` is not set either.
///
/// Values are read as OS strings, so paths that are not valid Unicode are honoured
/// instead of being treated as unset.
pub fn cache_dir() -> PathBuf {
    if let Some(hf_home) = std::env::var_os("HF_HOME") {
        return PathBuf::from(hf_home);
    }
    let home = std::env::var_os("HOME").map_or_else(|| PathBuf::from("."), PathBuf::from);
    home.join(".cache").join("huggingface")
}
/// Averages `last_hidden_state` over axis 1, weighting each position by `attention_mask`.
///
/// The mask gets a new axis at index 2, is expanded to the shape of
/// `last_hidden_state` and converted to its dtype. The masked states and the mask are
/// then each summed over axis 1, and the mask sum is clamped to at least `1e-9` so the
/// final division never divides by zero.
///
/// Assumes `last_hidden_state` is `(batch, seq, hidden)` and `attention_mask` is
/// `(batch, seq)` — TODO confirm against the model output.
///
/// # Errors
///
/// Returns [`Error::Embedding`] if any tensor operation fails.
#[cfg(feature = "hf-embeddings")]
fn mean_pooling(last_hidden_state: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    let mask_with_axis = attention_mask
        .unsqueeze(2)
        .map_err(|e| Error::Embedding(format!("Failed to expand mask: {}", e)))?;
    let mask_expanded = mask_with_axis
        .expand(last_hidden_state.shape())
        .map_err(|e| Error::Embedding(format!("Failed to expand mask shape: {}", e)))?;
    let weights = mask_expanded
        .to_dtype(last_hidden_state.dtype())
        .map_err(|e| Error::Embedding(format!("Failed to convert dtype: {}", e)))?;

    let weighted_states = (last_hidden_state * &weights)
        .map_err(|e| Error::Embedding(format!("Failed to apply mask: {}", e)))?;
    let weighted_sum = weighted_states
        .sum(1)
        .map_err(|e| Error::Embedding(format!("Failed to sum embeddings: {}", e)))?;

    let weight_total = weights
        .sum(1)
        .map_err(|e| Error::Embedding(format!("Failed to sum mask: {}", e)))?;
    // Guard against an all-zero mask before dividing.
    let weight_total = weight_total
        .clamp(1e-9, f32::MAX)
        .map_err(|e| Error::Embedding(format!("Failed to clamp: {}", e)))?;

    weighted_sum
        .broadcast_div(&weight_total)
        .map_err(|e| Error::Embedding(format!("Failed to divide: {}", e)))
}
/// Scales `tensor` to unit L2 norm along axis 1.
///
/// The norm is the square root of the sum of squares over axis 1 (kept as a
/// dimension so it broadcasts) and is clamped to at least `1e-12` before dividing.
///
/// # Errors
///
/// Returns [`Error::Embedding`] if any tensor operation fails.
#[cfg(feature = "hf-embeddings")]
fn normalize(tensor: &Tensor) -> Result<Tensor> {
    let squared = tensor
        .sqr()
        .map_err(|e| Error::Embedding(format!("Failed to square: {}", e)))?;
    let sum_of_squares = squared
        .sum_keepdim(1)
        .map_err(|e| Error::Embedding(format!("Failed to sum: {}", e)))?;
    let l2_norm = sum_of_squares
        .sqrt()
        .map_err(|e| Error::Embedding(format!("Failed to sqrt: {}", e)))?;
    // Keep the divisor away from zero for all-zero inputs.
    let divisor = l2_norm
        .clamp(1e-12, f32::MAX)
        .map_err(|e| Error::Embedding(format!("Failed to clamp: {}", e)))?;

    tensor
        .broadcast_div(&divisor)
        .map_err(|e| Error::Embedding(format!("Failed to normalize: {}", e)))
}
}
#[cfg(not(feature = "hf-embeddings"))]
impl HFEmbedder {
pub async fn new(_model_name: &str) -> Result<Self> {
Err(Error::Embedding(
"HF embeddings feature not enabled. Compile with --features hf-embeddings".to_string(),
))
}
pub async fn new_with_options(_model_name: &str, _cache_dir: Option<PathBuf>) -> Result<Self> {
Err(Error::Embedding(
"HF embeddings feature not enabled. Compile with --features hf-embeddings".to_string(),
))
}
pub fn cache_dir() -> PathBuf {
PathBuf::from(".")
}
}
/// [`Embedder`] implementation that runs the loaded BERT model, mean-pools the token
/// states and L2-normalizes the result.
impl Embedder for HFEmbedder {
    /// Length of the vectors produced by `embed`.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Embeds `text` into a unit-length vector.
    ///
    /// The text is tokenized with special tokens, run through the model as a batch of
    /// one, mean-pooled under the attention mask and L2-normalized.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Embedding`] if tokenization, tensor construction, the forward
    /// pass, pooling, normalization or the final conversion fails.
    #[cfg(feature = "hf-embeddings")]
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| Error::Embedding(format!("Tokenization failed: {}", e)))?;

        // Add a leading batch axis of size one to both inputs.
        let token_ids = Tensor::new(encoding.get_ids(), &self.device)
            .map_err(|e| Error::Embedding(format!("Failed to create token tensor: {}", e)))?
            .unsqueeze(0)
            .map_err(|e| Error::Embedding(format!("Failed to unsqueeze tokens: {}", e)))?;
        let attention_mask = Tensor::new(encoding.get_attention_mask(), &self.device)
            .map_err(|e| Error::Embedding(format!("Failed to create mask tensor: {}", e)))?
            .unsqueeze(0)
            .map_err(|e| Error::Embedding(format!("Failed to unsqueeze mask: {}", e)))?;

        // `BertModel::forward` takes (input_ids, token_type_ids, attention_mask).
        // The input is a single segment, so the token type ids are all zero; the
        // attention mask goes in the third slot so padded positions are ignored.
        let token_type_ids = token_ids.zeros_like().map_err(|e| {
            Error::Embedding(format!("Failed to create token type tensor: {}", e))
        })?;
        let outputs = self
            .model
            .forward(&token_ids, &token_type_ids, Some(&attention_mask))
            .map_err(|e| Error::Embedding(format!("Model forward failed: {}", e)))?;

        let pooled = Self::mean_pooling(&outputs, &attention_mask)?;
        let normalized = Self::normalize(&pooled)?;

        // Drop the batch axis again and copy the values out.
        normalized
            .squeeze(0)
            .map_err(|e| Error::Embedding(format!("Failed to squeeze: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| Error::Embedding(format!("Failed to convert to vec: {}", e)))
    }

    /// Always fails when the crate is built without the `hf-embeddings` feature.
    #[cfg(not(feature = "hf-embeddings"))]
    fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        Err(Error::Embedding(
            "HF embeddings feature not enabled".to_string(),
        ))
    }

    /// Repository name the embedder was created with.
    fn model_name(&self) -> &str {
        &self.model_name
    }
}
#[cfg(all(test, feature = "hf-embeddings"))]
mod tests {
use super::*;
use crate::embedding::cosine_similarity;
/// Reports whether the tests appear to run on a CI system, judged by the `CI` or
/// `GITHUB_ACTIONS` environment variable being set.
fn is_ci() -> bool {
    ["CI", "GITHUB_ACTIONS"]
        .iter()
        .any(|name| std::env::var(name).is_ok())
}
/// Returns `true` when the local Hub cache holds a snapshot of `model_name` that
/// contains `model.safetensors`.
///
/// Only safetensors weights count: the loader rejects every other format, so a
/// snapshot that holds just `pytorch_model.bin` would make the dependent tests run
/// and then fail instead of being skipped.
fn model_exists_in_cache(model_name: &str) -> bool {
    let repo_dir = format!("models--{}", model_name.replace('/', "--"));
    let snapshots = HFEmbedder::cache_dir()
        .join("hub")
        .join(repo_dir)
        .join("snapshots");

    // A missing or unreadable directory simply means "not cached".
    std::fs::read_dir(snapshots)
        .map(|entries| {
            entries
                .flatten()
                .any(|entry| entry.path().join("model.safetensors").exists())
        })
        .unwrap_or(false)
}
/// Checks that `HFEmbedder::cache_dir` honours `HF_HOME`.
///
/// The previous value of `HF_HOME` is put back before asserting, so neither a
/// developer's configured cache nor a failed assertion leaks into other tests
/// running in the same process.
#[tokio::test]
async fn test_hf_cache_dir() {
    let previous = std::env::var_os("HF_HOME");

    std::env::set_var("HF_HOME", "/custom/cache");
    let resolved = HFEmbedder::cache_dir();

    match previous {
        Some(value) => std::env::set_var("HF_HOME", value),
        None => std::env::remove_var("HF_HOME"),
    }

    assert_eq!(resolved, PathBuf::from("/custom/cache"));
}
/// Loads the MiniLM model and checks the reported dimension and model name.
/// Ignored by default; returns early on CI.
#[tokio::test]
#[ignore]
async fn test_hf_embedder_creation() {
    const MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

    if is_ci() {
        eprintln!("Skipping HF embedder test in CI (requires model download)");
        return;
    }

    let embedder = HFEmbedder::new(MODEL_ID).await.unwrap();
    assert_eq!(embedder.dimension(), 384);
    assert_eq!(embedder.model_name(), MODEL_ID);
}
/// Embeds one sentence and checks the vector length and that it has unit norm.
/// Ignored by default; returns early on CI.
#[tokio::test]
#[ignore]
async fn test_hf_embedder_embed() {
    if is_ci() {
        eprintln!("Skipping HF embedder test in CI");
        return;
    }

    let model = HFEmbedder::new("sentence-transformers/all-MiniLM-L6-v2")
        .await
        .unwrap();
    let vector = model.embed("Hello, world!").unwrap();
    assert_eq!(vector.len(), 384);

    let squared_sum = vector.iter().map(|x| x * x).sum::<f32>();
    let norm: f32 = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 1e-4);
}
/// Embeds the same text twice and checks both vectors agree component-wise.
/// Ignored by default; returns early on CI.
#[tokio::test]
#[ignore]
async fn test_hf_embedder_deterministic() {
    if is_ci() {
        eprintln!("Skipping HF embedder test in CI");
        return;
    }
    let embedder = HFEmbedder::new("sentence-transformers/all-MiniLM-L6-v2")
        .await
        .unwrap();
    let e1 = embedder.embed("test text").unwrap();
    let e2 = embedder.embed("test text").unwrap();

    // `zip` stops at the shorter input, so without this check a truncated or empty
    // embedding would let the loop below pass vacuously.
    assert_eq!(e1.len(), e2.len());
    for (a, b) in e1.iter().zip(e2.iter()) {
        assert!((a - b).abs() < 1e-6);
    }
}
/// Checks that two paraphrases score a higher cosine similarity than an unrelated
/// sentence, and that the paraphrase score exceeds 0.7.
/// Ignored by default; returns early on CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_hf_embedder_semantic_similarity() {
    const MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

    if is_ci() || !model_exists_in_cache(MODEL_ID) {
        eprintln!("Skipping semantic similarity test (model not in cache)");
        return;
    }

    let embedder = HFEmbedder::new(MODEL_ID).await.unwrap();
    // Embedded in order: anchor sentence, its paraphrase, an unrelated sentence.
    let [anchor, paraphrase, unrelated] = [
        "The cat sits on the mat",
        "A cat is sitting on a mat",
        "Dogs are loyal animals",
    ]
    .map(|sentence| embedder.embed(sentence).unwrap());

    let sim_cat = cosine_similarity(&anchor, &paraphrase);
    let sim_dog = cosine_similarity(&anchor, &unrelated);
    assert!(sim_cat > sim_dog);
    assert!(sim_cat > 0.7);
    println!(
        "Cat similarity: {:.3}, Dog similarity: {:.3}",
        sim_cat, sim_dog
    );
}
/// Embeds three texts in one batch and checks the count, length and unit norm of
/// every returned vector.
/// Ignored by default; returns early on CI or when the model is not cached.
#[tokio::test]
#[ignore]
async fn test_hf_embedder_batch() {
    const MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

    if is_ci() || !model_exists_in_cache(MODEL_ID) {
        eprintln!("Skipping batch test (model not in cache)");
        return;
    }

    let embedder = HFEmbedder::new(MODEL_ID).await.unwrap();
    let texts = vec!["first text", "second text", "third text"];
    let embeddings = embedder.embed_batch(&texts).unwrap();
    assert_eq!(embeddings.len(), 3);

    embeddings.into_iter().for_each(|embedding| {
        assert_eq!(embedding.len(), 384);
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4);
    });
}
}