Skip to main content

rig_fastembed/
lib.rs

1use std::sync::Arc;
2use std::{error::Error as StdError, fmt};
3
4pub use fastembed::EmbeddingModel as FastembedModel;
5use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel};
6use rig::embeddings::{self, EmbeddingError};
7
8#[cfg(feature = "hf-hub")]
9use fastembed::InitOptions;
10#[cfg(feature = "hf-hub")]
11use rig::{Embed, embeddings::EmbeddingsBuilder};
12
/// The `rig-fastembed` client.
///
/// Use this as your main entrypoint for any `rig-fastembed` functionality.
/// The client is a unit struct with no fields, so constructing or cloning it
/// carries no state.
#[derive(Debug, Clone)]
pub struct Client;
18
19#[derive(Debug, Clone)]
20pub enum FastembedError {
21    UnknownModel(FastembedModel),
22    Initialization(String),
23    UnsupportedMake,
24}
25
26impl fmt::Display for FastembedError {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        match self {
29            FastembedError::UnknownModel(model) => {
30                write!(
31                    f,
32                    "Failed to resolve FastEmbed model metadata for {model:?}"
33                )
34            }
35            FastembedError::Initialization(message) => {
36                write!(f, "Failed to initialize FastEmbed model: {message}")
37            }
38            FastembedError::UnsupportedMake => write!(
39                f,
40                "`EmbeddingModel::make` is not supported for rig-fastembed; construct models via `Client::embedding_model` or `EmbeddingModel::new_from_user_defined`"
41            ),
42        }
43    }
44}
45
46impl StdError for FastembedError {}
47
48impl Default for Client {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl Client {
55    /// Create a new `rig-fastembed` client.
56    pub fn new() -> Self {
57        Self
58    }
59
60    /// Create an embedding model with the given name.
61    /// Note: default embedding dimension of 0 will be used if model is not known.
62    /// If this is the case, it's better to use function `embedding_model_with_ndims`
63    ///
64    /// # Example
65    /// ```
66    /// use rig_fastembed::{Client, FastembedModel};
67    ///
68    /// // Initialize the `rig-fastembed` client
69    /// let fastembed_client = rig_fastembed::Client::new();
70    ///
71    /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
72    /// ```
73    #[cfg(feature = "hf-hub")]
74    pub fn embedding_model(
75        &self,
76        model: &FastembedModel,
77    ) -> Result<EmbeddingModel, FastembedError> {
78        let ndims = TextEmbedding::get_model_info(model)
79            .map(|info| info.dim)
80            .map_err(|_| FastembedError::UnknownModel(model.clone()))?;
81
82        EmbeddingModel::new(model, ndims)
83    }
84
85    /// Create an embedding builder with the given embedding model.
86    ///
87    /// # Example
88    /// ```
89    /// use rig_fastembed::{Client, FastembedModel};
90    ///
91    /// // Initialize the Fastembed client
92    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
93    /// let fastembed_client = Client::new();
94    ///
95    /// let embeddings = fastembed_client
96    ///     .embeddings(&FastembedModel::AllMiniLML6V2Q)?
97    ///     .documents(vec![
98    ///         "Hello, world!".to_string(),
99    ///         "Goodbye, world!".to_string(),
100    ///     ])?
101    ///     .build()
102    ///     .await?;
103    /// # let _ = embeddings;
104    /// # Ok(())
105    /// # }
106    /// # let _ = run();
107    /// ```
108    #[cfg(feature = "hf-hub")]
109    pub fn embeddings<D: Embed>(
110        &self,
111        model: &fastembed::EmbeddingModel,
112    ) -> Result<EmbeddingsBuilder<EmbeddingModel, D>, FastembedError> {
113        Ok(EmbeddingsBuilder::new(self.embedding_model(model)?))
114    }
115}
116
117#[derive(Clone)]
118pub struct EmbeddingModel {
119    embedder: Option<Arc<TextEmbedding>>,
120    init_error: Option<FastembedError>,
121    pub model: FastembedModel,
122    ndims: usize,
123}
124
125impl EmbeddingModel {
126    #[cfg(feature = "hf-hub")]
127    pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Result<Self, FastembedError> {
128        let embedder = Arc::new(
129            TextEmbedding::try_new(
130                InitOptions::new(model.to_owned()).with_show_download_progress(true),
131            )
132            .map_err(|err| FastembedError::Initialization(err.to_string()))?,
133        );
134
135        Ok(Self {
136            embedder: Some(embedder),
137            init_error: None,
138            model: model.to_owned(),
139            ndims,
140        })
141    }
142
143    pub fn new_from_user_defined(
144        user_defined_model: UserDefinedEmbeddingModel,
145        ndims: usize,
146        model_info: &ModelInfo<FastembedModel>,
147    ) -> Result<Self, FastembedError> {
148        let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
149            user_defined_model,
150            InitOptionsUserDefined::default(),
151        )
152        .map_err(|err| FastembedError::Initialization(err.to_string()))?;
153
154        let embedder = Arc::new(fastembed_embedding_model);
155
156        Ok(Self {
157            embedder: Some(embedder),
158            init_error: None,
159            model: model_info.model.to_owned(),
160            ndims,
161        })
162    }
163}
164
165impl embeddings::EmbeddingModel for EmbeddingModel {
166    const MAX_DOCUMENTS: usize = 1024;
167
168    type Client = Client;
169
170    fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
171        Self {
172            embedder: None,
173            init_error: Some(FastembedError::UnsupportedMake),
174            model: FastembedModel::AllMiniLML6V2Q,
175            ndims: 0,
176        }
177    }
178
179    fn ndims(&self) -> usize {
180        self.ndims
181    }
182
183    async fn embed_texts(
184        &self,
185        documents: impl IntoIterator<Item = String>,
186    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
187        let Some(embedder) = &self.embedder else {
188            let message = self
189                .init_error
190                .as_ref()
191                .map(ToString::to_string)
192                .unwrap_or_else(|| "FastEmbed model initialization failed".to_string());
193            return Err(EmbeddingError::ProviderError(message));
194        };
195
196        let documents_as_strings: Vec<String> = documents.into_iter().collect();
197
198        let documents_as_vec = embedder
199            .embed(documents_as_strings.clone(), None)
200            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
201
202        let docs = documents_as_strings
203            .into_iter()
204            .zip(documents_as_vec)
205            .map(|(document, embedding)| embeddings::Embedding {
206                document,
207                vec: embedding.into_iter().map(|f| f as f64).collect(),
208            })
209            .collect::<Vec<embeddings::Embedding>>();
210
211        Ok(docs)
212    }
213}