use std::sync::Arc;
use std::{error::Error as StdError, fmt};
pub use fastembed::EmbeddingModel as FastembedModel;
use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel};
use rig::embeddings::{self, EmbeddingError};
#[cfg(feature = "hf-hub")]
use fastembed::InitOptions;
#[cfg(feature = "hf-hub")]
use rig::{Embed, embeddings::EmbeddingsBuilder};
/// Entry point for building FastEmbed-backed embedding models.
///
/// This is a unit struct: it carries no state, so every instance is
/// interchangeable with any other.
#[derive(Clone, Debug)]
pub struct Client;
/// Errors produced while constructing or using a FastEmbed embedding model.
#[derive(Debug, Clone)]
pub enum FastembedError {
/// Model metadata could not be resolved for the contained model.
UnknownModel(FastembedModel),
/// The underlying FastEmbed model failed to initialize; carries the
/// stringified cause.
Initialization(String),
/// The model was built through the trait-level `EmbeddingModel::make`,
/// which this crate does not support.
UnsupportedMake,
}
impl fmt::Display for FastembedError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FastembedError::UnknownModel(model) => {
write!(
f,
"Failed to resolve FastEmbed model metadata for {model:?}"
)
}
FastembedError::Initialization(message) => {
write!(f, "Failed to initialize FastEmbed model: {message}")
}
FastembedError::UnsupportedMake => write!(
f,
"`EmbeddingModel::make` is not supported for rig-fastembed; construct models via `Client::embedding_model` or `EmbeddingModel::new_from_user_defined`"
),
}
}
}
/// Marks `FastembedError` as a standard error type. No methods are
/// overridden, so the trait's default implementations apply.
impl StdError for FastembedError {}
impl Default for Client {
fn default() -> Self {
Self::new()
}
}
impl Client {
pub fn new() -> Self {
Self
}
#[cfg(feature = "hf-hub")]
pub fn embedding_model(
&self,
model: &FastembedModel,
) -> Result<EmbeddingModel, FastembedError> {
let ndims = TextEmbedding::get_model_info(model)
.map(|info| info.dim)
.map_err(|_| FastembedError::UnknownModel(model.clone()))?;
EmbeddingModel::new(model, ndims)
}
#[cfg(feature = "hf-hub")]
pub fn embeddings<D: Embed>(
&self,
model: &fastembed::EmbeddingModel,
) -> Result<EmbeddingsBuilder<EmbeddingModel, D>, FastembedError> {
Ok(EmbeddingsBuilder::new(self.embedding_model(model)?))
}
}
/// Embedding model handle backed by a shared FastEmbed `TextEmbedding`.
///
/// The embedder is held behind an `Arc`, so clones share the same instance.
#[derive(Clone)]
pub struct EmbeddingModel {
/// Shared embedder. `None` for instances produced by the trait-level `make`,
/// which does not initialize one.
embedder: Option<Arc<TextEmbedding>>,
/// Why the embedder is missing; reported by `embed_texts` when `embedder`
/// is `None`.
init_error: Option<FastembedError>,
/// The FastEmbed model this instance was created for.
pub model: FastembedModel,
/// Embedding dimensionality returned by `ndims()`.
ndims: usize,
}
impl EmbeddingModel {
    /// Initializes a FastEmbed text embedder for `model`.
    ///
    /// The embedder is created with download-progress display enabled.
    /// `ndims` is stored as given; it is not checked against the model.
    /// Only available with the `hf-hub` feature.
    ///
    /// # Errors
    ///
    /// Returns [`FastembedError::Initialization`] with the stringified cause
    /// when `TextEmbedding::try_new` fails.
    #[cfg(feature = "hf-hub")]
    pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Result<Self, FastembedError> {
        let options = InitOptions::new(model.clone()).with_show_download_progress(true);
        let text_embedding = TextEmbedding::try_new(options)
            .map_err(|err| FastembedError::Initialization(err.to_string()))?;

        Ok(Self {
            embedder: Some(Arc::new(text_embedding)),
            init_error: None,
            model: model.clone(),
            ndims,
        })
    }

    /// Builds an embedding model from user-supplied model data.
    ///
    /// The embedder is created with `InitOptionsUserDefined::default()`.
    /// Only the `model` field of `model_info` is read, and `ndims` is stored
    /// as given.
    ///
    /// # Errors
    ///
    /// Returns [`FastembedError::Initialization`] with the stringified cause
    /// when `TextEmbedding::try_new_from_user_defined` fails.
    pub fn new_from_user_defined(
        user_defined_model: UserDefinedEmbeddingModel,
        ndims: usize,
        model_info: &ModelInfo<FastembedModel>,
    ) -> Result<Self, FastembedError> {
        TextEmbedding::try_new_from_user_defined(
            user_defined_model,
            InitOptionsUserDefined::default(),
        )
        .map(|text_embedding| Self {
            embedder: Some(Arc::new(text_embedding)),
            init_error: None,
            model: model_info.model.clone(),
            ndims,
        })
        .map_err(|err| FastembedError::Initialization(err.to_string()))
    }
}
impl embeddings::EmbeddingModel for EmbeddingModel {
    /// Maximum number of documents per embedding request declared by this
    /// implementation.
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client;

    /// Returns a placeholder model that cannot embed.
    ///
    /// All arguments are ignored. The result has no embedder, so every call
    /// to `embed_texts` on it fails with the
    /// [`FastembedError::UnsupportedMake`] message. Build usable models with
    /// `Client::embedding_model` or `EmbeddingModel::new_from_user_defined`.
    fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
        Self {
            embedder: None,
            init_error: Some(FastembedError::UnsupportedMake),
            // Filler value: the field is not optional, and no embedder is
            // ever created for it here.
            model: FastembedModel::AllMiniLML6V2Q,
            ndims: 0,
        }
    }

    /// Embedding dimensionality supplied when the model was constructed.
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds `documents` and pairs each input text with its vector.
    ///
    /// # Errors
    ///
    /// Returns `EmbeddingError::ProviderError` when:
    /// - the model has no embedder (for example, it came from `make`),
    /// - the FastEmbed `embed` call fails, or
    /// - FastEmbed returns a different number of vectors than documents.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let Some(embedder) = &self.embedder else {
            let message = self
                .init_error
                .as_ref()
                .map(ToString::to_string)
                .unwrap_or_else(|| "FastEmbed model initialization failed".to_string());
            return Err(EmbeddingError::ProviderError(message));
        };

        let documents: Vec<String> = documents.into_iter().collect();

        // Lend the texts to FastEmbed rather than deep-copying every document;
        // the owned strings are still needed afterwards for the result.
        let borrowed_texts: Vec<&str> = documents.iter().map(String::as_str).collect();

        // NOTE(review): `embed` runs synchronously inside this async fn;
        // presumably CPU-bound. Confirm it cannot stall the executor.
        let vectors = embedder
            .embed(borrowed_texts, None)
            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;

        // `zip` would silently drop the unmatched tail on a count mismatch,
        // so reject it explicitly.
        if vectors.len() != documents.len() {
            return Err(EmbeddingError::ProviderError(format!(
                "FastEmbed returned {} embeddings for {} documents",
                vectors.len(),
                documents.len()
            )));
        }

        Ok(documents
            .into_iter()
            .zip(vectors)
            .map(|(document, vector)| embeddings::Embedding {
                document,
                // f32 -> f64 widening is lossless.
                vec: vector.into_iter().map(f64::from).collect(),
            })
            .collect())
    }
}