#![cfg(feature = "embedder-candle")]
use anyhow::{Context, Result};
use async_trait::async_trait;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE, HiddenAct};
use hf_hub::{Repo, RepoType, api::sync::Api};
use thiserror::Error;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use super::{EMBED_DIM, Embedder};
/// Hugging Face Hub repository that `config.json`, `tokenizer.json` and
/// `model.safetensors` are fetched from.
const MODEL_REPO: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Token limit handed to the tokenizer's truncation settings; longer inputs are
/// cut off before reaching the model.
const MAX_SEQ_LEN: usize = 128;
/// Errors produced by [`CandleEmbedder`].
///
/// Each variant wraps an [`anyhow::Error`] describing the failing step (usually
/// annotated via `.context(...)`).
#[derive(Debug, Error)]
pub enum CandleEmbedderError {
    /// Building the embedder failed: hub access, file download, config parsing,
    /// tokenizer loading/configuration, or weight loading.
    #[error("candle embedder init failed: {0}")]
    Init(#[source] anyhow::Error),
    /// Tokenising a batch or running the forward pass / pooling failed.
    #[error("candle embedder embed failed: {0}")]
    Embed(#[source] anyhow::Error),
}
/// Sentence embedder that runs a BERT model from the Hugging Face Hub locally
/// through candle, producing mean-pooled, L2-normalised vectors.
pub struct CandleEmbedder {
    /// Loaded BERT encoder.
    model: BertModel,
    /// Tokenizer from the model repo, configured for batch padding and truncation.
    tokenizer: Tokenizer,
    /// Device holding the weights; input tensors are created on it too.
    device: Device,
    /// Embedding width reported by `dimension()`; set to `EMBED_DIM` at construction.
    dim: usize,
}
impl CandleEmbedder {
pub fn new(use_metal: bool) -> Result<Self, CandleEmbedderError> {
let device = pick_device(use_metal);
Self::new_with_device(device)
}
pub fn new_with_device(device: Device) -> Result<Self, CandleEmbedderError> {
let metal = matches!(device, Device::Metal(_));
tracing::info!(
"trusty-common::embedder: building CandleEmbedder ({})",
if metal { "metal" } else { "cpu" }
);
let api = Api::new()
.context("hf_hub api init")
.map_err(CandleEmbedderError::Init)?;
let repo = api.repo(Repo::new(MODEL_REPO.to_string(), RepoType::Model));
let config_filename = repo
.get("config.json")
.context("download config.json")
.map_err(CandleEmbedderError::Init)?;
let tokenizer_filename = repo
.get("tokenizer.json")
.context("download tokenizer.json")
.map_err(CandleEmbedderError::Init)?;
let weights_filename = repo
.get("model.safetensors")
.context("download model.safetensors")
.map_err(CandleEmbedderError::Init)?;
let config_bytes = std::fs::read(&config_filename)
.context("read config.json")
.map_err(CandleEmbedderError::Init)?;
let mut config: Config = serde_json::from_slice(&config_bytes)
.context("parse bert config")
.map_err(CandleEmbedderError::Init)?;
config.hidden_act = HiddenAct::GeluApproximate;
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|e| CandleEmbedderError::Init(anyhow::anyhow!(e)))?;
tokenizer
.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: MAX_SEQ_LEN,
..Default::default()
}))
.map_err(|e| CandleEmbedderError::Init(anyhow::anyhow!(e)))?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
std::slice::from_ref(&weights_filename),
DTYPE,
&device,
)
.context("load safetensors weights")
.map_err(CandleEmbedderError::Init)?
};
let model = BertModel::load(vb, &config)
.context("instantiate bert model")
.map_err(CandleEmbedderError::Init)?;
Ok(Self {
model,
tokenizer,
device,
dim: EMBED_DIM,
})
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CandleEmbedderError> {
if texts.is_empty() {
return Ok(Vec::new());
}
let owned: Vec<String> = texts.iter().map(|s| (*s).to_string()).collect();
let encodings = self
.tokenizer
.encode_batch(owned, true)
.map_err(|e| CandleEmbedderError::Embed(anyhow::anyhow!(e)))
.context("tokenize batch")
.map_err(CandleEmbedderError::Embed)?;
let batch = encodings.len();
let seq_len = encodings
.iter()
.map(|e| e.get_ids().len())
.max()
.unwrap_or(0);
if seq_len == 0 {
return Ok(vec![vec![0.0; self.dim]; batch]);
}
let mut ids: Vec<u32> = Vec::with_capacity(batch * seq_len);
let mut mask: Vec<u32> = Vec::with_capacity(batch * seq_len);
for enc in &encodings {
ids.extend_from_slice(enc.get_ids());
mask.extend_from_slice(enc.get_attention_mask());
}
let input_ids = Tensor::from_vec(ids, (batch, seq_len), &self.device)
.context("stack input_ids")
.map_err(CandleEmbedderError::Embed)?;
let attn_mask_u32 = Tensor::from_vec(mask, (batch, seq_len), &self.device)
.context("stack attention_mask")
.map_err(CandleEmbedderError::Embed)?;
let token_type_ids = input_ids
.zeros_like()
.context("zeros for token_type_ids")
.map_err(CandleEmbedderError::Embed)?;
let hidden = self
.model
.forward(&input_ids, &token_type_ids, Some(&attn_mask_u32))
.context("bert forward")
.map_err(CandleEmbedderError::Embed)?;
let mask_f = attn_mask_u32
.to_dtype(DType::F32)
.context("mask to f32")
.map_err(CandleEmbedderError::Embed)?;
let mask_b_t_1 = mask_f
.unsqueeze(2)
.context("unsqueeze mask")
.map_err(CandleEmbedderError::Embed)?;
let masked = hidden
.broadcast_mul(&mask_b_t_1)
.context("apply mask to hidden")
.map_err(CandleEmbedderError::Embed)?;
let summed = masked
.sum(1)
.context("sum along seq")
.map_err(CandleEmbedderError::Embed)?;
let counts = mask_f
.sum(1)
.context("sum mask along seq")
.map_err(CandleEmbedderError::Embed)?
.clamp(1e-9_f64, f64::INFINITY)
.context("clamp counts")
.map_err(CandleEmbedderError::Embed)?;
let counts_b_1 = counts
.unsqueeze(1)
.context("unsqueeze counts")
.map_err(CandleEmbedderError::Embed)?;
let pooled = summed
.broadcast_div(&counts_b_1)
.context("mean pool")
.map_err(CandleEmbedderError::Embed)?;
let norms = pooled
.sqr()
.context("square pooled")
.map_err(CandleEmbedderError::Embed)?
.sum_keepdim(1)
.context("sum squares")
.map_err(CandleEmbedderError::Embed)?
.sqrt()
.context("sqrt norms")
.map_err(CandleEmbedderError::Embed)?
.clamp(1e-12_f64, f64::INFINITY)
.context("clamp norms")
.map_err(CandleEmbedderError::Embed)?;
let normed = pooled
.broadcast_div(&norms)
.context("l2 normalise")
.map_err(CandleEmbedderError::Embed)?;
let out: Vec<Vec<f32>> = normed
.to_vec2::<f32>()
.context("materialise embeddings to host")
.map_err(CandleEmbedderError::Embed)?;
Ok(out)
}
}
#[async_trait]
impl Embedder for CandleEmbedder {
    /// Embed owned strings by borrowing them and delegating to
    /// [`CandleEmbedder::embed`], which returns an empty result for empty input.
    ///
    /// NOTE(review): inference runs synchronously inside this async fn with no
    /// `.await` point, so it occupies the executor thread for the whole forward
    /// pass; large batches may warrant `spawn_blocking` at the call site.
    async fn embed_batch(&self, texts: &[String]) -> anyhow::Result<Vec<Vec<f32>>> {
        let borrowed: Vec<&str> = texts.iter().map(|text| text.as_str()).collect();
        self.embed(&borrowed).map_err(Into::into)
    }

    /// The configured embedding width (`self.dim`).
    fn dimension(&self) -> usize {
        self.dim
    }
}
fn pick_device(use_metal: bool) -> Device {
if use_metal {
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
{
match Device::new_metal(0) {
Ok(d) => return d,
Err(e) => {
tracing::warn!(
"trusty-common::embedder: metal device init failed ({e}); \
falling back to CPU"
);
}
}
}
}
Device::Cpu
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product; equals cosine similarity for L2-normalised inputs.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    #[ignore = "downloads ~90 MB from HuggingFace; run with --include-ignored"]
    fn candle_cpu_embeds_single_text() {
        let emb = CandleEmbedder::new(false).expect("load candle embedder");
        assert_eq!(emb.dimension(), EMBED_DIM);
        let out = emb.embed(&["hello world"]).expect("embed succeeds");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].len(), EMBED_DIM);
        let norm: f32 = dot(&out[0], &out[0]).sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-3,
            "L2 norm should be ~1.0 (got {norm})"
        );
    }

    #[test]
    #[ignore = "downloads ~90 MB from HuggingFace; run with --include-ignored"]
    fn candle_batch_embeds_consistently() {
        let emb = CandleEmbedder::new(false).expect("load candle embedder");
        let text = "the quick brown fox";

        // Repeated single-text calls must be deterministic.
        let out_a = emb.embed(&[text]).expect("first embed");
        let out_b = emb.embed(&[text]).expect("second embed");
        let repeat_cos = dot(&out_a[0], &out_b[0]);
        assert!(
            repeat_cos > 0.999,
            "self-cosine should be > 0.999 (got {repeat_cos})"
        );

        // Batched with a longer text, `text` gets padded; attention masking and
        // masked mean pooling must make that padding irrelevant.
        let batched = emb
            .embed(&[
                text,
                "a considerably longer sentence that forces the first input to be padded",
            ])
            .expect("batch embed");
        assert_eq!(batched.len(), 2);
        let batch_cos = dot(&out_a[0], &batched[0]);
        assert!(
            batch_cos > 0.999,
            "padded batch embedding should match single embedding (got {batch_cos})"
        );
    }

    #[test]
    #[ignore = "downloads ~90 MB from HuggingFace; run with --include-ignored"]
    fn candle_similar_texts_closer_than_dissimilar() {
        let emb = CandleEmbedder::new(false).expect("load candle embedder");
        let out = emb
            .embed(&["dog", "cat", "airplane"])
            .expect("embed succeeds");
        let dog_cat = dot(&out[0], &out[1]);
        let dog_airplane = dot(&out[0], &out[2]);
        assert!(
            dog_cat > dog_airplane,
            "cos(dog, cat) should exceed cos(dog, airplane) (got {dog_cat} vs {dog_airplane})"
        );
    }
}