use std::fmt;
use std::sync::{Arc, Mutex};
use crate::FastEmbedOptions;
/// Failures that can occur while loading or running a fastembed text-embedding model.
///
/// Every variant carries a human-readable detail string; the `Display` impl
/// prefixes it with a short description of which stage failed.
#[derive(Debug)]
pub enum FastEmbedError {
    /// The requested model name could not be parsed as a fastembed `EmbeddingModel`.
    UnknownModel(String),
    /// Model-info lookup or `TextEmbedding` construction failed.
    Init(String),
    /// fastembed returned an error from the embed call.
    Embed(String),
    /// The mutex guarding the shared model was poisoned by an earlier panic.
    MutexPoisoned(String),
    /// Joining the blocking worker task returned an error.
    TaskPanicked(String),
}

impl fmt::Display for FastEmbedError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the stage-specific prefix, then emit "<prefix>: <detail>" once.
        let (context, detail) = match self {
            Self::UnknownModel(detail) => ("unknown fastembed model", detail),
            Self::Init(detail) => ("fastembed init failed", detail),
            Self::Embed(detail) => ("fastembed embed failed", detail),
            Self::MutexPoisoned(detail) => ("fastembed mutex poisoned", detail),
            Self::TaskPanicked(detail) => ("fastembed blocking task panicked", detail),
        };
        write!(f, "{context}: {detail}")
    }
}

impl std::error::Error for FastEmbedError {}
/// Result of an embedding request.
#[derive(Clone, Debug)]
pub struct FastEmbedResponse {
    /// Embedding vectors produced by fastembed for the request's inputs
    /// (empty when the request had no inputs).
    pub embeddings: Vec<Vec<f32>>,
    /// Identifier of the model that produced `embeddings`.
    pub model: String,
}
pub struct FastEmbedModel {
model: Arc<Mutex<fastembed::TextEmbedding>>,
model_id: String,
dims: usize,
batch_size: Option<usize>,
}
impl FastEmbedModel {
    /// Loads the fastembed text-embedding model described by `opts`.
    ///
    /// - `opts.model_name`, when set, is parsed into a fastembed `EmbeddingModel`;
    ///   otherwise `EmbeddingModel::default()` is used.
    /// - `opts.cache_dir` and `opts.show_download_progress` are forwarded to
    ///   fastembed's `TextInitOptions` when set.
    /// - `opts.max_batch_size` becomes the batch size for every `embed` call.
    ///   `Some(0)` is treated as `None` (see the comment in the body).
    ///
    /// # Errors
    ///
    /// - [`FastEmbedError::UnknownModel`] if `opts.model_name` does not parse.
    /// - [`FastEmbedError::Init`] if fastembed has no model info for the chosen
    ///   model, or if `TextEmbedding::try_new` fails.
    pub fn from_options(opts: FastEmbedOptions) -> Result<Self, FastEmbedError> {
        let fe_model = if let Some(ref name) = opts.model_name {
            name.parse::<fastembed::EmbeddingModel>()
                .map_err(|e| FastEmbedError::UnknownModel(format!("\"{name}\": {e}")))?
        } else {
            fastembed::EmbeddingModel::default()
        };

        // Dimensions and the reported model id come from fastembed's model info
        // for the selected model.
        let model_info =
            <fastembed::EmbeddingModel as fastembed::ModelTrait>::get_model_info(&fe_model)
                .ok_or_else(|| {
                    FastEmbedError::Init(format!("no model info found for {fe_model:?}"))
                })?;
        let dims = model_info.dim;
        let model_code = model_info.model_code.clone();

        let mut init_opts = fastembed::TextInitOptions::new(fe_model);
        if let Some(cache_dir) = opts.cache_dir {
            init_opts = init_opts.with_cache_dir(cache_dir);
        }
        if let Some(show) = opts.show_download_progress {
            init_opts = init_opts.with_show_download_progress(show);
        }
        let te = fastembed::TextEmbedding::try_new(init_opts)
            .map_err(|e| FastEmbedError::Init(e.to_string()))?;

        // A zero batch size cannot be honoured: fastembed batches inputs by chunk
        // size, and slice chunking panics on a chunk size of 0. That panic would
        // fire inside `spawn_blocking` while the model lock is held, poisoning
        // the mutex so that every later `embed` fails. Treat 0 as "unset".
        let batch_size = opts.max_batch_size.filter(|&n| n > 0);

        Ok(Self {
            model: Arc::new(Mutex::new(te)),
            model_id: model_code,
            dims,
            batch_size,
        })
    }

    /// Model code reported by fastembed for the loaded model.
    #[must_use]
    pub fn model_id(&self) -> &str {
        &self.model_id
    }

    /// Embedding dimensionality reported by fastembed for the loaded model.
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.dims
    }

    /// Embeds `texts` with the loaded model.
    ///
    /// An empty `texts` slice returns an empty response without touching the
    /// model. Otherwise the inputs are copied into an owned `Vec`, because
    /// `spawn_blocking` needs a `'static` closure. Embedding then runs on
    /// tokio's blocking thread pool so inference does not stall the async
    /// executor. Calls are serialized by the model mutex.
    ///
    /// # Errors
    ///
    /// - [`FastEmbedError::MutexPoisoned`] if an earlier call panicked while
    ///   holding the model lock. The poison is never cleared, so every later
    ///   call fails this way too.
    /// - [`FastEmbedError::Embed`] if fastembed's embed call returns an error.
    /// - [`FastEmbedError::TaskPanicked`] if joining the blocking task fails.
    pub async fn embed(&self, texts: &[String]) -> Result<FastEmbedResponse, FastEmbedError> {
        if texts.is_empty() {
            return Ok(FastEmbedResponse {
                embeddings: Vec::new(),
                model: self.model_id.clone(),
            });
        }

        let texts_owned: Vec<String> = texts.to_vec();
        let batch_size = self.batch_size;
        let model_handle = Arc::clone(&self.model);

        let embeddings = tokio::task::spawn_blocking(move || {
            // The std-mutex guard lives only inside this blocking closure, so it
            // is never held across an `.await`.
            let mut model = model_handle
                .lock()
                .map_err(|e| FastEmbedError::MutexPoisoned(e.to_string()))?;
            let vectors: Vec<Vec<f32>> = model
                .embed(&texts_owned, batch_size)
                .map_err(|e| FastEmbedError::Embed(e.to_string()))?;
            Ok::<Vec<Vec<f32>>, FastEmbedError>(vectors)
        })
        .await
        // Outer `?`: join failure; inner `?`: the closure's own error.
        .map_err(|e| FastEmbedError::TaskPanicked(e.to_string()))??;

        Ok(FastEmbedResponse {
            embeddings,
            model: self.model_id.clone(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the default model, panicking if it cannot be created.
    fn default_model() -> FastEmbedModel {
        FastEmbedModel::from_options(FastEmbedOptions::default())
            .expect("should create model with default options")
    }

    #[test]
    #[ignore = "requires model download from HuggingFace"]
    fn from_options_default_loads_model() {
        let model = default_model();
        assert!(model.dimensions() > 0);
        assert!(!model.model_id().is_empty());
    }

    #[tokio::test]
    #[ignore = "requires model download from HuggingFace"]
    async fn embed_returns_correct_count() {
        let model = default_model();
        let inputs: Vec<String> = vec!["hello".to_string(), "world".to_string()];
        let response = model
            .embed(&inputs)
            .await
            .expect("embedding should succeed");
        assert_eq!(response.embeddings.len(), 2);
        let first = &response.embeddings[0];
        assert!(!first.is_empty());
        assert_eq!(first.len(), model.dimensions());
    }

    #[tokio::test]
    async fn embed_empty_input_returns_empty() {
        // Best-effort: when the model cannot be loaded here, skip instead of failing.
        let model = match FastEmbedModel::from_options(FastEmbedOptions::default()) {
            Ok(loaded) => loaded,
            Err(_) => {
                eprintln!("skipping embed_empty_input_returns_empty: model not available");
                return;
            }
        };
        let response = model.embed(&[]).await.expect("empty embed should succeed");
        assert!(response.embeddings.is_empty());
    }
}