mod bert;
mod gemma;
use crate::error::{Error, Result};
use std::path::PathBuf;
pub use bert::BertEmbedder;
pub use gemma::GemmaEmbedder;
/// Embedding models that an `Embedder` in this module can load.
///
/// Each variant knows its model id, its native embedding width, and the
/// (possibly smaller) output width requested by the caller.
#[derive(Debug, Clone)]
pub enum EmbeddingModel {
    /// `sentence-transformers/all-MiniLM-L6-v2`, 384-dimensional output.
    MiniLM,
    /// `BAAI/bge-small-en-v1.5`, 384-dimensional output.
    BgeSmall,
    /// `google/embeddinggemma-300m`, natively 768-dimensional.
    Gemma {
        /// Requested output width. The embedder truncates longer vectors to
        /// this size.
        dimensions: usize,
    },
    /// Any other model, identified by `model_id`.
    Custom {
        /// Model identifier; presumably a Hugging Face Hub repo id like the
        /// built-in ones — confirm against the loaders.
        model_id: String,
        /// Output width. Also reported as the native width.
        dimensions: usize,
    },
}

impl EmbeddingModel {
    /// Width of the vectors this model is configured to produce.
    pub fn dimensions(&self) -> usize {
        match *self {
            Self::MiniLM | Self::BgeSmall => 384,
            Self::Gemma { dimensions } | Self::Custom { dimensions, .. } => dimensions,
        }
    }

    /// Model identifier string for this variant.
    pub fn model_id(&self) -> &str {
        match self {
            Self::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
            Self::BgeSmall => "BAAI/bge-small-en-v1.5",
            Self::Gemma { .. } => "google/embeddinggemma-300m",
            Self::Custom { model_id, .. } => model_id.as_str(),
        }
    }

    /// Width of the vectors the underlying model emits before any truncation.
    pub fn native_dimensions(&self) -> usize {
        match self {
            Self::Gemma { .. } => 768,
            Self::MiniLM | Self::BgeSmall => 384,
            Self::Custom { dimensions, .. } => *dimensions,
        }
    }

    /// Whether the model's vectors may be truncated to a smaller width
    /// (Matryoshka representation). Only the Gemma variant supports it.
    pub fn supports_matryoshka(&self) -> bool {
        self.is_gemma()
    }

    /// True for the `Gemma` variant.
    fn is_gemma(&self) -> bool {
        matches!(self, Self::Gemma { .. })
    }
}
/// Text embedder that pairs a loaded backend with the model description
/// it was created from.
pub struct Embedder {
    /// Model-specific backend that produces raw vectors.
    inner: EmbedderInner,
    /// Model description; `dimensions()` is the width output is truncated to.
    model: EmbeddingModel,
}

/// Concrete backends an [`Embedder`] can dispatch to.
enum EmbedderInner {
    /// Backend used for the Gemma model variant.
    Gemma(GemmaEmbedder),
    /// Backend used for every non-Gemma model variant.
    Bert(BertEmbedder),
}
/// Construction options for an `Embedder`.
///
/// Start from [`EmbedderConfig::new`] and refine with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct EmbedderConfig {
    /// Model to load.
    pub model: EmbeddingModel,
    /// Optional cache directory (default `None`).
    /// NOTE(review): presumably forwarded to the Hugging Face Hub download
    /// cache by the backend loaders — confirm.
    pub cache_dir: Option<PathBuf>,
    /// Normalization flag (default `true`). It is interpreted by the backend
    /// loaders, which are not visible here.
    pub normalize: bool,
    /// Optional input-length limit (default `None`). It is interpreted by the
    /// backend loaders, which are not visible here.
    pub max_length: Option<usize>,
}

impl EmbedderConfig {
    /// Config for `model` with no cache directory, `normalize = true`, and no
    /// explicit length limit.
    pub fn new(model: EmbeddingModel) -> Self {
        Self {
            model,
            normalize: true,
            cache_dir: None,
            max_length: None,
        }
    }

    /// Returns the config with `cache_dir` set to `path`.
    pub fn with_cache_dir(self, path: impl Into<PathBuf>) -> Self {
        Self {
            cache_dir: Some(path.into()),
            ..self
        }
    }

    /// Returns the config with `normalize` replaced.
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Returns the config with `max_length` set to `max_length`.
    pub fn with_max_length(self, max_length: usize) -> Self {
        Self {
            max_length: Some(max_length),
            ..self
        }
    }
}
impl Embedder {
pub fn new(model: EmbeddingModel) -> Result<Self> {
Self::with_config(EmbedderConfig::new(model))
}
pub fn with_config(config: EmbedderConfig) -> Result<Self> {
let model = config.model.clone();
let inner = if model.is_gemma() {
let gemma = GemmaEmbedder::load(&config)?;
EmbedderInner::Gemma(gemma)
} else {
let bert = BertEmbedder::load(&config)?;
EmbedderInner::Bert(bert)
};
Ok(Self { inner, model })
}
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
let mut embedding = match &self.inner {
EmbedderInner::Bert(bert) => bert.embed(text)?,
EmbedderInner::Gemma(gemma) => gemma.embed(text)?,
};
let target_dims = self.model.dimensions();
if embedding.len() > target_dims {
embedding.truncate(target_dims);
l2_normalize(&mut embedding);
}
Ok(embedding)
}
pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut results = match &self.inner {
EmbedderInner::Bert(bert) => bert.embed_batch(texts)?,
EmbedderInner::Gemma(gemma) => gemma.embed_batch(texts)?,
};
let target_dims = self.model.dimensions();
for embedding in &mut results {
if embedding.len() > target_dims {
embedding.truncate(target_dims);
l2_normalize(embedding);
}
}
Ok(results)
}
pub fn model(&self) -> &EmbeddingModel {
&self.model
}
pub fn dimensions(&self) -> usize {
self.model.dimensions()
}
pub fn into_embed_fn(self) -> impl Fn(&str) -> Vec<f32> + Send + Sync + 'static {
use std::sync::Arc;
let embedder = Arc::new(self);
move |text: &str| -> Vec<f32> {
embedder
.embed(text)
.unwrap_or_else(|_| vec![0.0; embedder.dimensions()])
}
}
}
pub(crate) fn download_model_files(
model_id: &str,
filenames: &[&str],
cache_dir: Option<&PathBuf>,
) -> Result<Vec<PathBuf>> {
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| Error::InvalidConfig(format!("Failed to create tokio runtime: {}", e)))?;
rt.block_on(async {
let api: Api = if let Some(cache) = cache_dir {
let cache = Cache::new(cache.clone());
ApiBuilder::from_cache(cache).build().map_err(|e| {
Error::InvalidConfig(format!("Failed to create HF Hub API: {}", e))
})?
} else {
Api::new()
.map_err(|e| Error::InvalidConfig(format!("Failed to create HF Hub API: {}", e)))?
};
let repo = api.repo(Repo::new(model_id.to_string(), RepoType::Model));
let mut paths = Vec::new();
for filename in filenames {
let path = repo.get(filename).await.map_err(|e| {
Error::Embedding(format!(
"Failed to download '{}' from '{}': {}",
filename, model_id, e
))
})?;
paths.push(path);
}
Ok(paths)
})
}
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is at most `1e-12` (including all-zero and empty
/// vectors) are left unchanged to avoid dividing by ~0. A NaN norm also fails
/// the comparison, so such vectors are left unchanged too.
pub(crate) fn l2_normalize(v: &mut [f32]) {
    const MIN_NORM: f32 = 1e-12;
    let magnitude = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
    if magnitude > MIN_NORM {
        for component in v.iter_mut() {
            *component /= magnitude;
        }
    }
}