use std::sync::Mutex;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::{Repo, RepoType, api::sync::Api};
use tokenizers::Tokenizer;
use super::{BatchEmbedder, EmbedderError, QueryEmbedder, QueryEmbedderIdentity};
/// Hugging Face hub identifier of the built-in embedding model.
const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";
/// Git revision of the model repository to fetch artifacts from.
const MODEL_REVISION: &str = "main";
/// Expected length of every embedding vector produced by this model.
const MODEL_DIMENSION: usize = 384;
/// Embedder backed by the `BAAI/bge-small-en-v1.5` BERT model.
///
/// The model is not loaded at construction time: it is downloaded (or read
/// from the local hf-hub cache) and initialized lazily on the first
/// embedding call.
pub struct BuiltinBgeSmallEmbedder {
    // `None` until first use. The mutex serializes both the lazy load and
    // the forward passes, so concurrent callers cannot race to initialize.
    state: Mutex<Option<ModelState>>,
}
impl std::fmt::Debug for BuiltinBgeSmallEmbedder {
    /// Shows the model id and whether the model has been loaded, without
    /// trying to format the (large, non-`Debug`) model state itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A poisoned mutex is reported as "not loaded" rather than panicking.
        let loaded = self.state.lock().is_ok_and(|guard| guard.is_some());
        f.debug_struct("BuiltinBgeSmallEmbedder")
            .field("model_id", &MODEL_ID)
            .field("loaded", &loaded)
            .finish()
    }
}
impl Default for BuiltinBgeSmallEmbedder {
    /// Equivalent to [`BuiltinBgeSmallEmbedder::new`]: no model loaded yet.
    fn default() -> Self {
        Self::new()
    }
}
impl BuiltinBgeSmallEmbedder {
    /// Creates an embedder with no model loaded. Downloading and
    /// initializing the model happens lazily on the first embedding call.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Mutex::new(None),
        }
    }

    /// Fetches the model's config, tokenizer, and safetensors weights via
    /// the synchronous hf-hub API (network or local cache) and constructs
    /// the BERT model on CPU.
    ///
    /// Every failure path maps to `EmbedderError::Unavailable`: these are
    /// setup errors (network, filesystem, malformed artifacts), as opposed
    /// to the per-input `Failed` errors produced while embedding.
    fn load_model_state() -> Result<ModelState, EmbedderError> {
        let device = Device::Cpu;
        let repo = Repo::with_revision(
            MODEL_ID.to_owned(),
            RepoType::Model,
            MODEL_REVISION.to_owned(),
        );
        let api = Api::new()
            .map_err(|e| EmbedderError::Unavailable(format!("hf-hub api init failed: {e}")))?
            .repo(repo);
        let config_path = api
            .get("config.json")
            .map_err(|e| EmbedderError::Unavailable(format!("fetch config.json: {e}")))?;
        let tokenizer_path = api
            .get("tokenizer.json")
            .map_err(|e| EmbedderError::Unavailable(format!("fetch tokenizer.json: {e}")))?;
        let weights_path = api
            .get("model.safetensors")
            .map_err(|e| EmbedderError::Unavailable(format!("fetch model.safetensors: {e}")))?;
        let config_bytes = std::fs::read_to_string(&config_path)
            .map_err(|e| EmbedderError::Unavailable(format!("read config.json: {e}")))?;
        let config: Config = serde_json::from_str(&config_bytes)
            .map_err(|e| EmbedderError::Unavailable(format!("parse config.json: {e}")))?;
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| EmbedderError::Unavailable(format!("load tokenizer: {e}")))?;
        // SAFETY: `from_mmaped_safetensors` memory-maps the weights file,
        // which is sound only if the file is not truncated or modified
        // while mapped. The hf-hub cache file is treated as immutable once
        // downloaded.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)
                .map_err(|e| EmbedderError::Unavailable(format!("mmap safetensors: {e}")))?
        };
        let model = BertModel::load(vb, &config)
            .map_err(|e| EmbedderError::Unavailable(format!("load BertModel: {e}")))?;
        Ok(ModelState {
            tokenizer,
            model,
            device,
        })
    }

    /// Embeds one text with an already-loaded model: tokenize, run the
    /// BERT forward pass, take the first token's hidden state, L2-normalize
    /// it, and return it as a `Vec<f32>` of length `MODEL_DIMENSION`.
    ///
    /// NOTE(review): no explicit truncation is applied here; an input
    /// longer than the model's maximum sequence length presumably fails
    /// inside the forward pass — confirm whether the downloaded
    /// tokenizer.json ships with truncation enabled.
    fn embed_with_state(state: &ModelState, text: &str) -> Result<Vec<f32>, EmbedderError> {
        // `true` requests the tokenizer's special tokens (e.g. [CLS]/[SEP]).
        let encoding = state
            .tokenizer
            .encode(text, true)
            .map_err(|e| EmbedderError::Failed(format!("tokenize: {e}")))?;
        let ids = encoding.get_ids();
        if ids.is_empty() {
            return Err(EmbedderError::Failed(
                "tokenizer produced empty id sequence".to_owned(),
            ));
        }
        // Shape the ids as a batch of one: unsqueeze 1-D ids to (1, seq_len).
        let input_ids = Tensor::new(ids, &state.device)
            .and_then(|t| t.unsqueeze(0))
            .map_err(|e| EmbedderError::Failed(format!("build input_ids tensor: {e}")))?;
        // Single-segment input: all token type ids are zero.
        let token_type_ids = input_ids
            .zeros_like()
            .map_err(|e| EmbedderError::Failed(format!("build token_type_ids: {e}")))?;
        // Attention mask is `None`: a single unpadded sequence needs none.
        let hidden = state
            .model
            .forward(&input_ids, &token_type_ids, None)
            .map_err(|e| EmbedderError::Failed(format!("bert forward: {e}")))?;
        // Pool by taking batch element 0, token 0 — the [CLS] position for
        // BERT-style tokenizers.
        let cls = hidden
            .get(0)
            .and_then(|batch0| batch0.get(0))
            .map_err(|e| EmbedderError::Failed(format!("index CLS token: {e}")))?;
        let normalized =
            l2_normalize(&cls).map_err(|e| EmbedderError::Failed(format!("l2 normalize: {e}")))?;
        let as_f32 = normalized
            .to_dtype(DType::F32)
            .and_then(|t| t.to_vec1::<f32>())
            .map_err(|e| EmbedderError::Failed(format!("tensor to Vec<f32>: {e}")))?;
        // Guard against a config/weights mismatch yielding the wrong width.
        if as_f32.len() != MODEL_DIMENSION {
            return Err(EmbedderError::Failed(format!(
                "expected {MODEL_DIMENSION}-dim vector, got {}",
                as_f32.len()
            )));
        }
        Ok(as_f32)
    }
}
/// Returns `v` scaled to unit L2 norm.
///
/// Near-zero vectors (squared norm at or below `f32::EPSILON`) are returned
/// unchanged to avoid dividing by (almost) zero.
///
/// # Errors
/// Propagates any candle tensor-op error (`sqr`, `sum_all`, `to_scalar`,
/// `affine`).
fn l2_normalize(v: &Tensor) -> candle_core::Result<Tensor> {
    let norm_sq = v.sqr()?.sum_all()?.to_scalar::<f32>()?;
    if norm_sq <= f32::EPSILON {
        return Ok(v.clone());
    }
    // Compute the reciprocal norm in f64 rather than widening an f32
    // reciprocal: `affine` takes f64 parameters, so the extra precision
    // is carried all the way into the scaling.
    let scale = 1.0 / f64::from(norm_sq).sqrt();
    v.affine(scale, 0.0)
}
/// Lazily-initialized resources shared by all embedding calls.
struct ModelState {
    // Tokenizer loaded from the repo's tokenizer.json.
    tokenizer: Tokenizer,
    // The loaded BERT encoder.
    model: BertModel,
    // Device the model and input tensors live on (CPU).
    device: Device,
}
impl QueryEmbedder for BuiltinBgeSmallEmbedder {
    /// Embeds a single query, loading the model on first use.
    ///
    /// # Errors
    /// `Failed` if the state mutex is poisoned or embedding fails;
    /// `Unavailable` if lazy model loading fails.
    fn embed_query(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
        let mut guard = match self.state.lock() {
            Ok(g) => g,
            Err(_) => {
                return Err(EmbedderError::Failed(
                    "embedder state mutex poisoned".to_owned(),
                ));
            }
        };
        // Lazy initialization: first caller pays the download/load cost.
        if guard.is_none() {
            let loaded = Self::load_model_state()?;
            *guard = Some(loaded);
        }
        match guard.as_ref() {
            Some(state) => Self::embed_with_state(state, text),
            None => Err(EmbedderError::Failed(
                "model state unexpectedly None".to_owned(),
            )),
        }
    }

    /// Identity descriptor: model id/revision, dimension, and L2 policy.
    fn identity(&self) -> QueryEmbedderIdentity {
        QueryEmbedderIdentity {
            model_identity: MODEL_ID.to_owned(),
            model_version: MODEL_REVISION.to_owned(),
            dimension: MODEL_DIMENSION,
            normalization_policy: "l2".to_owned(),
        }
    }

    /// Maximum input length in tokens accepted per query.
    fn max_tokens(&self) -> usize {
        512
    }
}
impl BatchEmbedder for BuiltinBgeSmallEmbedder {
    /// Embeds each text in order, loading the model on first use.
    /// Stops at the first per-text failure and returns that error.
    fn batch_embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedderError> {
        let mut guard = self.state.lock().map_err(|_| {
            EmbedderError::Failed("embedder state mutex poisoned".to_owned())
        })?;
        // Lazy initialization: first caller pays the download/load cost.
        if guard.is_none() {
            *guard = Some(Self::load_model_state()?);
        }
        let Some(state) = guard.as_ref() else {
            return Err(EmbedderError::Failed(
                "model state unexpectedly None".to_owned(),
            ));
        };
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(Self::embed_with_state(state, text)?);
        }
        Ok(embeddings)
    }

    /// Identity descriptor: model id/revision, dimension, and L2 policy.
    fn identity(&self) -> QueryEmbedderIdentity {
        QueryEmbedderIdentity {
            model_identity: MODEL_ID.to_owned(),
            model_version: MODEL_REVISION.to_owned(),
            dimension: MODEL_DIMENSION,
            normalization_policy: "l2".to_owned(),
        }
    }

    /// Maximum input length in tokens accepted per text.
    fn max_tokens(&self) -> usize {
        512
    }
}
#[cfg(test)]
// Tests use `expect` freely: panicking is the correct test-failure signal.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::embedder::{BatchEmbedder, QueryEmbedder};

    #[test]
    fn builtin_bge_small_max_tokens_returns_512() {
        let embedder = BuiltinBgeSmallEmbedder::new();
        assert_eq!(QueryEmbedder::max_tokens(&embedder), 512);
        assert_eq!(BatchEmbedder::max_tokens(&embedder), 512);
    }

    /// End-to-end check against the real model. Ignored by default because
    /// it downloads the model artifacts from the Hugging Face hub on first
    /// run (network access, large download); run with `cargo test -- --ignored`.
    #[test]
    #[ignore = "downloads BGE-small from the Hugging Face hub; run with --ignored"]
    fn builtin_bge_small_batch_embed_returns_one_vector_per_input() {
        let embedder = BuiltinBgeSmallEmbedder::new();
        let texts = vec![
            "hello world".to_owned(),
            "machine learning".to_owned(),
            "rust programming".to_owned(),
        ];
        let result = embedder
            .batch_embed(&texts)
            .expect("batch_embed must succeed");
        assert_eq!(result.len(), 3, "one vector per input text");
        for (i, vec) in result.iter().enumerate() {
            assert_eq!(
                vec.len(),
                384,
                "vector {i} must have BGE-small dimension 384"
            );
        }
    }
}