use async_trait::async_trait;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::{api::tokio::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use std::sync::Mutex;
use crate::traits::{CerebroError, Embedder, Result};
/// Text embedder that runs a BERT model in-process through candle.
///
/// Instances are created with `LocalEmbedder::new`, which loads the tokenizer,
/// model configuration and weights.
pub struct LocalEmbedder {
/// Loaded BERT model; locked for the duration of a forward pass.
model: Mutex<BertModel>,
/// Tokenizer used to turn input text into token ids; locked while encoding.
tokenizer: Mutex<Tokenizer>,
/// Device the weights are loaded on and on which input tensors are created
/// (set to `Device::Cpu` by `new`).
device: Device,
}
impl LocalEmbedder {
pub async fn new() -> Result<Self> {
let api = Api::new().map_err(|e| CerebroError::EmbeddingError(format!("HF API Error: {}", e)))?;
let repo = api.repo(Repo::with_revision(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
RepoType::Model,
"refs/pr/21".to_string(), ));
let tokenizer_filename = repo.get("tokenizer.json").await
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let config_filename = repo.get("config.json").await
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let weights_filename = repo.get("model.safetensors").await
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let config_str = std::fs::read_to_string(config_filename).unwrap();
let config: Config = serde_json::from_str(&config_str).unwrap();
let device = Device::Cpu;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) }
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let model = BertModel::load(vb, &config)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
Ok(Self {
model: Mutex::new(model),
tokenizer: Mutex::new(tokenizer),
device,
})
}
}
#[async_trait]
impl Embedder for LocalEmbedder {
async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let tokenizer = self.tokenizer.lock().unwrap();
let tokens = tokenizer.encode_batch(texts.to_vec(), true)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let mut token_ids: Vec<Vec<u32>> = tokens.iter().map(|t| t.get_ids().to_vec()).collect();
let max_len = token_ids.iter().map(|v| v.len()).max().unwrap_or(0);
for ids in token_ids.iter_mut() {
ids.resize(max_len, 0);
}
let n_sentences = texts.len();
let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
let token_ids_tensor = Tensor::from_vec(token_ids_flat, (n_sentences, max_len), &self.device)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let token_type_ids = Tensor::zeros((n_sentences, max_len), candle_core::DType::U32, &self.device)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let model = self.model.lock().unwrap();
let embeddings = model.forward(&token_ids_tensor, &token_type_ids, None)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let pooled_embeddings = embeddings
.sum(1)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?
.broadcast_div(&Tensor::new(max_len as f32, &self.device).unwrap())
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
let mut results = Vec::with_capacity(n_sentences);
for i in 0..n_sentences {
let row = pooled_embeddings
.get(i)
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?
.to_vec1::<f32>()
.map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
results.push(row);
}
Ok(results)
}
}