use anyhow::Result;
use async_trait::async_trait;
use model2vec_rs::model::StaticModel;
use std::sync::Arc;
use tracing::{debug, info};
use crate::embeddings::backend::EmbeddingBackend;
use crate::embeddings::config::EmbeddingModelType;
// Passed (wrapped in `Some`) as the second argument of
// `StaticModel::encode_with_args`, both for the load-time dimension probe and
// for every batch. Presumably the per-text token truncation limit — confirm
// against the model2vec-rs documentation.
const DEFAULT_MAX_TOKENS: usize = 512;
// Passed as the third argument of `StaticModel::encode_with_args`.
// Presumably the number of texts the library encodes per internal chunk —
// confirm against the model2vec-rs documentation.
const DEFAULT_INNER_BATCH: usize = 1024;
/// Embedding backend backed by a `model2vec_rs` [`StaticModel`].
///
/// Instances are created through [`Model2VecBackend::load`], which validates
/// the model's output dimension before the backend is handed out.
pub struct Model2VecBackend {
/// Loaded model, reference-counted so a handle can be moved into
/// `spawn_blocking` closures for each batch.
model: Arc<StaticModel>,
/// Vector length measured from a probe encoding at load time and verified
/// to equal `model_type.embedding_dimension()`.
dimension: usize,
/// Model variant this backend was loaded for; used in error and log
/// messages.
model_type: EmbeddingModelType,
}
impl Model2VecBackend {
    /// Loads the model identified by `model_type.model_id()` through
    /// `StaticModel::from_pretrained` and verifies its output dimension.
    ///
    /// Both the model load and a one-sentence probe encoding run inside a
    /// single `spawn_blocking` task, so no synchronous model work executes on
    /// the async executor thread. The probe's vector length is the dimension
    /// the backend reports afterwards.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the blocking task cannot be joined (for example because it panicked),
    /// - `StaticModel::from_pretrained` fails,
    /// - the probe yields no vector or an empty vector,
    /// - the measured dimension differs from
    ///   `model_type.embedding_dimension()`.
    pub async fn load(model_type: EmbeddingModelType) -> Result<Self> {
        let model_id = model_type.model_id().to_string();
        info!("Loading Model2Vec backend for model: {}", model_id);

        // The id is still needed for messages below, so the task gets a copy.
        let blocking_id = model_id.clone();
        let (model, runtime_dim) =
            tokio::task::spawn_blocking(move || -> Result<(StaticModel, usize)> {
                let model = StaticModel::from_pretrained(blocking_id, None, None, None)?;
                // Probe here rather than after the await: encoding is
                // synchronous work and belongs on the blocking pool, the same
                // way `process_batch` treats it.
                let probe = model.encode_with_args(
                    &["dimension probe".to_string()],
                    Some(DEFAULT_MAX_TOKENS),
                    DEFAULT_INNER_BATCH,
                );
                // A missing probe vector is reported as dimension 0.
                let dim = probe.first().map_or(0, Vec::len);
                Ok((model, dim))
            })
            .await
            .map_err(|e| anyhow::anyhow!("Model2Vec load join error: {e}"))??;

        let declared_dim = model_type.embedding_dimension();

        if runtime_dim == 0 {
            anyhow::bail!(
                "Model2Vec ({model_id}) produced an empty vector on the dimension probe"
            );
        }

        // The declared constant and the real model output must agree; fail
        // loudly at load time instead of letting the mismatch propagate.
        if runtime_dim != declared_dim {
            anyhow::bail!(
                "Model2Vec ({model_id}) reports dimension {runtime_dim} at runtime but \
                 EmbeddingModelType::{model_type:?}.embedding_dimension() = {declared_dim}; \
                 update the enum constant before shipping or HNSW indices will be sized wrong"
            );
        }

        info!(
            "Model2Vec model loaded — id: {}, dim: {}",
            model_id, runtime_dim
        );

        Ok(Self {
            model: Arc::new(model),
            dimension: runtime_dim,
            model_type,
        })
    }
}
#[async_trait]
impl EmbeddingBackend for Model2VecBackend {
fn embedding_dimension(&self) -> usize {
self.dimension
}
fn is_bert_based(&self) -> bool {
false
}
async fn process_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let model = self.model.clone();
let dim = self.dimension;
let model_type = self.model_type;
let embeddings = tokio::task::spawn_blocking(move || {
model.encode_with_args(&texts, Some(DEFAULT_MAX_TOKENS), DEFAULT_INNER_BATCH)
})
.await
.map_err(|e| anyhow::anyhow!("Model2Vec encode join error: {e}"))?;
for (i, v) in embeddings.iter().enumerate() {
if v.len() != dim {
return Err(anyhow::anyhow!(
"Model2Vec ({model_type:?}) produced vector at index {i} with dim {} \
(expected {})",
v.len(),
dim
));
}
}
debug!(
"Model2Vec encoded {} texts (dim={}) for {:?}",
embeddings.len(),
dim,
model_type
);
Ok(embeddings)
}
}