// Build-time guard: when the crate is compiled for a big-endian target the
// build is aborted with the message below.
// NOTE(review): presumably needed because the memory-mapped safetensors weights
// are consumed in their on-disk little-endian layout — confirm.
#[cfg(target_endian = "big")]
compile_error!("fathomdb-embedder default path requires a little-endian target");
use std::path::Path;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use fathomdb_embedder_api::{Embedder, EmbedderError, EmbedderIdentity, Vector};
use tokenizers::{Tokenizer, TruncationParams};
use crate::loader::{load_pinned_default_embedder, EmbedderLoadError, LoadedWeights, HF_REVISION};
/// Name recorded in the [`EmbedderIdentity`] of the default embedder.
pub const DEFAULT_EMBEDDER_NAME: &str = "fathomdb-bge-small-en-v1.5";
/// Dimension recorded in the default embedder's identity; the loaded model
/// config's `hidden_size` must equal this value or construction fails.
pub const DEFAULT_EMBEDDER_DIM: u32 = 384;
/// `max_length` passed to the tokenizer's truncation settings.
const MAX_SEQUENCE_TOKENS: usize = 512;
/// Text embedder that pairs a `tokenizers` tokenizer with a candle
/// [`BertModel`] and exposes them through the [`Embedder`] trait.
pub struct CandleBgeEmbedder {
/// Identity (name, revision, dimension) handed out by `Embedder::identity`.
identity: EmbedderIdentity,
/// Tokenizer loaded from `LoadedWeights::tokenizer_json_path`, with
/// truncation configured at construction time.
tokenizer: Tokenizer,
/// BERT encoder built from the safetensors weights.
model: BertModel,
/// Device the weights were loaded onto and on which input tensors are
/// created (set to `Device::Cpu` by `new_from_weights`).
device: Device,
}
impl CandleBgeEmbedder {
pub fn new() -> Result<Self, EmbedderLoadError> {
let weights = load_pinned_default_embedder()?;
Self::new_from_weights(weights)
}
pub fn new_from_weights(weights: LoadedWeights) -> Result<Self, EmbedderLoadError> {
let config_bytes = std::fs::read(&weights.config_json_path).map_err(|source| {
EmbedderLoadError::CacheIoError { path: weights.config_json_path.clone(), source }
})?;
let config: BertConfig =
serde_json::from_slice(&config_bytes).map_err(|e| EmbedderLoadError::CacheIoError {
path: weights.config_json_path.clone(),
source: std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
})?;
if config.hidden_size != DEFAULT_EMBEDDER_DIM as usize {
return Err(EmbedderLoadError::DimensionMismatch {
expected: DEFAULT_EMBEDDER_DIM,
actual: config.hidden_size as u32,
});
}
let mut tokenizer = Tokenizer::from_file(&weights.tokenizer_json_path)
.map_err(|e| EmbedderLoadError::TokenizerLoad { source: e })?;
tokenizer
.with_truncation(Some(TruncationParams {
max_length: MAX_SEQUENCE_TOKENS,
..Default::default()
}))
.map_err(|e| EmbedderLoadError::TokenizerLoad { source: e })?;
let device = Device::Cpu;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
&[weights.model_safetensors_path.as_path() as &Path],
DType::F32,
&device,
)
}
.map_err(|source| EmbedderLoadError::ModelDeserialize { source })?;
let model = BertModel::load(vb, &config)
.map_err(|source| EmbedderLoadError::ModelDeserialize { source })?;
let identity =
EmbedderIdentity::new(DEFAULT_EMBEDDER_NAME, HF_REVISION, DEFAULT_EMBEDDER_DIM);
Ok(Self { identity, tokenizer, model, device })
}
}
impl Embedder for CandleBgeEmbedder {
fn identity(&self) -> EmbedderIdentity {
self.identity.clone()
}
fn embed(&self, input: &str) -> Result<Vector, EmbedderError> {
let encoding = self
.tokenizer
.encode(input, true)
.map_err(|e| EmbedderError::Failed { message: format!("tokenize: {e}") })?;
let ids: Vec<u32> = encoding.get_ids().to_vec();
let attn: Vec<u32> = encoding.get_attention_mask().to_vec();
let len = ids.len();
let embed_impl = || -> candle_core::Result<Vec<f32>> {
let input_ids = Tensor::from_vec(ids, (1, len), &self.device)?;
let attn_mask_u32 = Tensor::from_vec(attn, (1, len), &self.device)?;
let token_type_ids = input_ids.zeros_like()?;
let hidden = self.model.forward(&input_ids, &token_type_ids, Some(&attn_mask_u32))?;
let mask_f = attn_mask_u32.to_dtype(DType::F32)?.unsqueeze(2)?; let mask_f = mask_f.broadcast_as(hidden.shape())?; let summed = (hidden * &mask_f)?.sum(1)?; let counts = mask_f.sum(1)?.clamp(1e-9_f32, f32::INFINITY)?; let pooled = (summed / counts)?;
let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?; let norm = norm.clamp(1e-12_f32, f32::INFINITY)?;
let normed = pooled.broadcast_div(&norm)?;
let v: Vec<f32> = normed.squeeze(0)?.to_vec1::<f32>()?;
Ok(v)
};
embed_impl().map_err(|e| EmbedderError::Failed { message: format!("forward: {e}") })
}
}