#[cfg(feature = "embeddings-candle")]
use std::any::Any;
#[cfg(feature = "embeddings-candle")]
use async_trait::async_trait;
#[cfg(feature = "embeddings-candle")]
use candle_core::{DType, Device, Tensor};
#[cfg(feature = "embeddings-candle")]
use candle_nn::VarBuilder;
#[cfg(feature = "embeddings-candle")]
use candle_transformers::models::bert::{BertModel, Config};
#[cfg(feature = "embeddings-candle")]
use hf_hub::api::sync::ApiBuilder;
#[cfg(feature = "embeddings-candle")]
use tokenizers::Tokenizer;
#[cfg(feature = "embeddings-candle")]
use crate::embedding::embedder::{EmbedInput, EmbedInputType, Embedder};
#[cfg(feature = "embeddings-candle")]
use crate::error::{LaurusError, Result};
#[cfg(feature = "embeddings-candle")]
use crate::vector::core::vector::Vector;
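/// Text embedder backed by a Hugging Face BERT checkpoint, run locally via
/// [Candle](https://github.com/huggingface/candle).
///
/// The config, weights (`model.safetensors`), and tokenizer are fetched from
/// the Hub on construction and cached on disk; inference runs on CUDA when
/// available, otherwise on the CPU. Output vectors are mean-pooled over
/// tokens and L2-normalized, so a dot product equals cosine similarity.
///
/// # Example
///
/// A minimal sketch, assuming the `embeddings-candle` feature is enabled and
/// a multi-threaded Tokio runtime; the crate paths and model name below are
/// illustrative:
///
/// ```ignore
/// use laurus::embedding::candle_bert::CandleBertEmbedder;
/// use laurus::embedding::embedder::{EmbedInput, Embedder};
///
/// // Any BERT-style repo with model.safetensors + tokenizer.json should work.
/// let embedder = CandleBertEmbedder::new("sentence-transformers/all-MiniLM-L6-v2")?;
/// let vector = embedder.embed(&EmbedInput::Text("hello world")).await?;
/// ```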
#[cfg(feature = "embeddings-candle")]
pub struct CandleBertEmbedder {
model: BertModel,
tokenizer: Tokenizer,
device: Device,
dim: usize,
model_name: String,
}
#[cfg(feature = "embeddings-candle")]
impl std::fmt::Debug for CandleBertEmbedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CandleBertEmbedder")
.field("model_name", &self.model_name)
.field("dimension", &self.dim)
.finish()
}
}
#[cfg(feature = "embeddings-candle")]
impl CandleBertEmbedder {
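    /// Downloads (or reuses from the local cache) the config, weights, and
    /// tokenizer for `model_name` from the Hugging Face Hub and loads the
    /// model onto the best available device. This uses the blocking `hf_hub`
    /// API, so call it outside the async executor.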
pub fn new(model_name: &str) -> Result<Self> {
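        // Prefer CUDA device 0 when built with CUDA support; cuda_if_available
        // falls back to the CPU otherwise.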
let device = Device::cuda_if_available(0)
.map_err(|e| LaurusError::InvalidOperation(format!("Device setup failed: {}", e)))?;
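        // Resolve the Hub cache dir: $HF_HOME, then ~/.cache/huggingface,
        // then /tmp/huggingface as a last resort.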
let cache_dir = std::env::var("HF_HOME")
.or_else(|_| std::env::var("HOME").map(|home| format!("{}/.cache/huggingface", home)))
.unwrap_or_else(|_| "/tmp/huggingface".to_string());
let api = ApiBuilder::new()
.with_cache_dir(cache_dir.into())
.build()
.map_err(|e| {
LaurusError::InvalidOperation(format!("HF API initialization failed: {}", e))
})?;
let repo = api.model(model_name.to_string());
let config_filename = repo
.get("config.json")
.map_err(|e| LaurusError::InvalidOperation(format!("Config download failed: {}", e)))?;
let config_str = std::fs::read_to_string(config_filename)
.map_err(|e| LaurusError::InvalidOperation(format!("Config read failed: {}", e)))?;
let config: Config = serde_json::from_str(&config_str)
.map_err(|e| LaurusError::InvalidOperation(format!("Config parse failed: {}", e)))?;
let weights_filename = repo.get("model.safetensors").map_err(|e| {
LaurusError::InvalidOperation(format!("Weights download failed: {}", e))
})?;
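        // SAFETY: the safetensors file is memory-mapped; this is sound as long
        // as the file is not truncated or modified while the model is in use,
        // which is the standard Candle loading pattern.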
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device).map_err(
|e| LaurusError::InvalidOperation(format!("VarBuilder creation failed: {}", e)),
)?
};
let model = BertModel::load(vb, &config)
.map_err(|e| LaurusError::InvalidOperation(format!("Model load failed: {}", e)))?;
let tokenizer_filename = repo.get("tokenizer.json").map_err(|e| {
LaurusError::InvalidOperation(format!("Tokenizer download failed: {}", e))
})?;
let tokenizer = Tokenizer::from_file(tokenizer_filename)
.map_err(|e| LaurusError::InvalidOperation(format!("Tokenizer load failed: {}", e)))?;
let dim = config.hidden_size;
Ok(Self {
model,
tokenizer,
device,
dim,
model_name: model_name.to_string(),
})
}
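    /// Embeds a single text, running the blocking inference via
    /// `tokio::task::block_in_place` (requires the multi-threaded runtime).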
    async fn embed_text(&self, text: &str) -> Result<Vector> {
        // block_in_place does not require 'static, so `text` can be borrowed
        // straight into the blocking closure without an intermediate clone.
        tokio::task::block_in_place(|| self.embed_text_sync(text))
    }
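    /// Tokenizes the text, runs the model forward pass, mean-pools the token
    /// embeddings, and L2-normalizes the result.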
fn embed_text_sync(&self, text: &str) -> Result<Vector> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| LaurusError::InvalidOperation(format!("Tokenization failed: {}", e)))?;
let token_ids = encoding.get_ids();
let attention_mask = encoding.get_attention_mask();
let token_ids_tensor = Tensor::new(token_ids, &self.device)
.map_err(|e| LaurusError::InvalidOperation(format!("Tensor creation failed: {}", e)))?
.unsqueeze(0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let attention_mask_tensor = Tensor::new(attention_mask, &self.device)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.unsqueeze(0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
        // BertModel::forward takes (input_ids, token_type_ids, attention_mask).
        // For a single sequence the token type ids are all zeros; the attention
        // mask goes in the third argument so padding positions are ignored.
        let token_type_ids = token_ids_tensor
            .zeros_like()
            .map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
        let embeddings = self
            .model
            .forward(&token_ids_tensor, &token_type_ids, Some(&attention_mask_tensor))
            .map_err(|e| LaurusError::InvalidOperation(format!("Model forward failed: {}", e)))?;
let pooled = self.mean_pool(&embeddings, &attention_mask_tensor)?;
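        // L2-normalize so downstream dot products equal cosine similarity.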
let norm = pooled
.sqr()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.sum_all()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.sqrt()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.to_scalar::<f32>()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let normalized = pooled
.affine((1.0 / norm) as f64, 0.0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let vector_data: Vec<f32> = normalized
.squeeze(0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.to_vec1()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
Ok(Vector::new(vector_data))
}
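    /// Masked mean pooling: `sum(E * M) / sum(M)`, where `E` is the token
    /// embedding tensor `[1, seq, hidden]` and `M` is the attention mask
    /// broadcast to the same shape, so padding contributes nothing to either
    /// sum. Assumes at least one attended token; the tokenizer always emits
    /// special tokens, so the mask is never all zero.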
fn mean_pool(&self, embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let mask_expanded = attention_mask
.unsqueeze(2)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.expand(embeddings.shape())
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.to_dtype(embeddings.dtype())
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let masked_embeddings = embeddings
.mul(&mask_expanded)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let sum_embeddings = masked_embeddings
.sum(1)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let sum_mask = mask_expanded
.sum(1)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let mean = sum_embeddings
.div(&sum_mask)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
Ok(mean)
}
}
#[cfg(feature = "embeddings-candle")]
#[async_trait]
impl Embedder for CandleBertEmbedder {
async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector> {
match input {
EmbedInput::Text(text) => self.embed_text(text).await,
_ => Err(LaurusError::invalid_argument(
"CandleBertEmbedder only supports text input",
)),
}
}
fn supported_input_types(&self) -> Vec<EmbedInputType> {
vec![EmbedInputType::Text]
}
fn name(&self) -> &str {
&self.model_name
}
fn as_any(&self) -> &dyn Any {
self
}
}