#![cfg(feature = "candle")]
use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
/// HuggingFace Hub repository that the config, tokenizer and weights are fetched from.
const MODEL_REPO: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Maximum tokenised length per input; the tokenizer truncates anything longer.
const MAX_SEQ_LEN: usize = 256;
/// Length of every embedding vector produced by [`CandleEmbedder`]; must match the
/// model's hidden size.
pub const EMBED_DIM: usize = 384;
/// Sentence embedder that runs `sentence-transformers/all-MiniLM-L6-v2` locally via candle.
///
/// Produces mean-pooled, L2-normalised vectors of length [`EMBED_DIM`]. Construct one with
/// [`CandleEmbedder::new`].
pub struct CandleEmbedder {
    /// BERT encoder whose last hidden states are pooled into embeddings.
    model: BertModel,
    /// Tokenizer configured for batch-longest padding and truncation.
    tokenizer: Tokenizer,
    /// Device holding the model weights; input tensors are created on it as well.
    device: Device,
}
impl CandleEmbedder {
pub fn new() -> Result<Self> {
let device = pick_device();
let metal = matches!(device, Device::Metal(_));
tracing::info!("embedder: candle/{}", if metal { "metal" } else { "cpu" });
let api = Api::new().context("hf_hub api init")?;
let repo = api.repo(Repo::new(MODEL_REPO.to_string(), RepoType::Model));
let config_filename = repo.get("config.json").context("download config.json")?;
let tokenizer_filename = repo
.get("tokenizer.json")
.context("download tokenizer.json")?;
let weights_filename = repo
.get("model.safetensors")
.context("download model.safetensors")?;
let config_bytes = std::fs::read(&config_filename).context("read config.json")?;
let mut config: Config =
serde_json::from_slice(&config_bytes).context("parse bert config")?;
config.hidden_act = HiddenAct::GeluApproximate;
let mut tokenizer =
Tokenizer::from_file(&tokenizer_filename).map_err(anyhow::Error::msg)?;
tokenizer
.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: MAX_SEQ_LEN,
..Default::default()
}))
.map_err(anyhow::Error::msg)?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
std::slice::from_ref(&weights_filename),
DTYPE,
&device,
)
.context("load safetensors weights")?
};
let model = BertModel::load(vb, &config).context("instantiate bert model")?;
Ok(Self {
model,
tokenizer,
device,
})
}
pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let owned: Vec<String> = texts.iter().map(|s| (*s).to_string()).collect();
let encodings = self
.tokenizer
.encode_batch(owned, true)
.map_err(anyhow::Error::msg)
.context("tokenize batch")?;
let batch = encodings.len();
let seq_len = encodings
.iter()
.map(|e| e.get_ids().len())
.max()
.unwrap_or(0);
if seq_len == 0 {
return Ok(vec![vec![0.0; EMBED_DIM]; batch]);
}
let mut ids: Vec<u32> = Vec::with_capacity(batch * seq_len);
let mut mask: Vec<u32> = Vec::with_capacity(batch * seq_len);
for enc in &encodings {
ids.extend_from_slice(enc.get_ids());
mask.extend_from_slice(enc.get_attention_mask());
}
let input_ids =
Tensor::from_vec(ids, (batch, seq_len), &self.device).context("stack input_ids")?;
let attn_mask_u32 = Tensor::from_vec(mask, (batch, seq_len), &self.device)
.context("stack attention_mask")?;
let token_type_ids = input_ids.zeros_like().context("zeros for token_type_ids")?;
let hidden = self
.model
.forward(&input_ids, &token_type_ids, Some(&attn_mask_u32))
.context("bert forward")?;
let mask_f = attn_mask_u32.to_dtype(DType::F32).context("mask to f32")?;
let mask_b_t_1 = mask_f.unsqueeze(2).context("unsqueeze mask")?;
let masked = hidden
.broadcast_mul(&mask_b_t_1)
.context("apply mask to hidden")?;
let summed = masked.sum(1).context("sum along seq")?;
let counts = mask_f
.sum(1)
.context("sum mask along seq")?
.clamp(1e-9_f64, f64::INFINITY)
.context("clamp counts")?;
let counts_b_1 = counts.unsqueeze(1).context("unsqueeze counts")?;
let pooled = summed.broadcast_div(&counts_b_1).context("mean pool")?;
let norms = pooled
.sqr()
.context("square pooled")?
.sum_keepdim(1)
.context("sum squares")?
.sqrt()
.context("sqrt norms")?
.clamp(1e-12_f64, f64::INFINITY)
.context("clamp norms")?;
let normed = pooled.broadcast_div(&norms).context("l2 normalise")?;
let out: Vec<Vec<f32>> = normed
.to_vec2::<f32>()
.context("materialise embeddings to host")?;
Ok(out)
}
pub fn dimension(&self) -> usize {
EMBED_DIM
}
}
#[async_trait::async_trait]
impl trusty_common::embedder::Embedder for CandleEmbedder {
async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let refs: Vec<&str> = texts.iter().map(String::as_str).collect();
self.embed_batch(&refs)
}
fn dimension(&self) -> usize {
EMBED_DIM
}
}
/// Chooses the device the model runs on.
///
/// On Apple Silicon macOS this tries Metal device 0 first, logging a warning and falling
/// back to the CPU if it cannot be created. Every other target uses the CPU directly.
fn pick_device() -> Device {
    #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
    {
        let e = match Device::new_metal(0) {
            Ok(metal) => return metal,
            Err(e) => e,
        };
        tracing::warn!("candle: metal device init failed ({e}); falling back to CPU");
    }
    Device::Cpu
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors, truncated to the shorter length.
    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    #[ignore = "downloads ~90 MB from HuggingFace; run with --include-ignored"]
    fn embed_dimension() {
        let emb = CandleEmbedder::new().expect("load candle embedder");
        let out = emb.embed_batch(&["hello"]).expect("embed succeeds");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].len(), EMBED_DIM);
    }

    #[test]
    #[ignore = "downloads ~90 MB from HuggingFace; run with --include-ignored"]
    fn cosine_self_similarity() {
        let emb = CandleEmbedder::new().expect("load candle embedder");
        let out = emb.embed_batch(&["hello world"]).expect("embed succeeds");
        let v = &out[0];
        let dot = dot_product(v, v);
        assert!(
            (dot - 1.0).abs() < 1e-3,
            "self-cosine should be ~1.0 (got {dot})"
        );
    }

    /// Padding a short input up to a longer batch-mate must not change its embedding.
    /// This exercises the attention mask and the masked mean-pooling.
    #[test]
    #[ignore = "downloads ~90 MB from HuggingFace; run with --include-ignored"]
    fn padding_does_not_change_embedding() {
        let emb = CandleEmbedder::new().expect("load candle embedder");
        let alone = emb.embed_batch(&["hello"]).expect("embed alone");
        let padded = emb
            .embed_batch(&[
                "hello",
                "a considerably longer sentence that forces the first entry to be padded",
            ])
            .expect("embed padded batch");
        assert_eq!(padded.len(), 2);
        let cos = dot_product(&alone[0], &padded[0]);
        assert!(
            (cos - 1.0).abs() < 1e-3,
            "padding changed the embedding (cosine {cos})"
        );
    }

    #[test]
    #[ignore = "downloads ~90 MB from HuggingFace; run with --include-ignored"]
    fn empty_batch_yields_no_embeddings() {
        let emb = CandleEmbedder::new().expect("load candle embedder");
        let out = emb.embed_batch(&[]).expect("empty batch succeeds");
        assert!(out.is_empty());
    }
}