Skip to main content

rig_fastembed/
lib.rs

1//! Local embedding model integration backed by `fastembed`.
2//!
3//! This crate adapts `fastembed` text embedding models to Rig's
4//! [`rig_core::embeddings::EmbeddingModel`] trait. The default feature set
5//! enables Hugging Face model downloads and ONNX Runtime binary downloads.
6//!
7//! `rig-fastembed` is native-only and does not target `wasm32-unknown-unknown`.
8//! The root `rig` facade re-exports this crate as `rig::fastembed` when one of
9//! its Fastembed features is enabled.
10
11use std::sync::Arc;
12use std::{error::Error as StdError, fmt};
13
14pub use fastembed::EmbeddingModel as FastembedModel;
15use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel};
16use rig_core::embeddings::{self, EmbeddingError};
17
18#[cfg(feature = "hf-hub")]
19use fastembed::InitOptions;
20#[cfg(feature = "hf-hub")]
21use rig_core::{Embed, embeddings::EmbeddingsBuilder};
22
/// The `rig-fastembed` client.
///
/// Use this as your main entrypoint for any `rig-fastembed` functionality.
/// It is a stateless unit struct, so constructing or cloning it is free.
#[derive(Debug, Clone)]
pub struct Client;
28
29#[derive(Debug, Clone)]
30pub enum FastembedError {
31    UnknownModel(FastembedModel),
32    Initialization(String),
33    UnsupportedMake,
34}
35
36impl fmt::Display for FastembedError {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        match self {
39            FastembedError::UnknownModel(model) => {
40                write!(
41                    f,
42                    "Failed to resolve FastEmbed model metadata for {model:?}"
43                )
44            }
45            FastembedError::Initialization(message) => {
46                write!(f, "Failed to initialize FastEmbed model: {message}")
47            }
48            FastembedError::UnsupportedMake => write!(
49                f,
50                "`EmbeddingModel::make` is not supported for rig-fastembed; construct models via `Client::embedding_model` or `EmbeddingModel::new_from_user_defined`"
51            ),
52        }
53    }
54}
55
56impl StdError for FastembedError {}
57
58impl Default for Client {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl Client {
65    /// Create a new `rig-fastembed` client.
66    pub fn new() -> Self {
67        Self
68    }
69
70    /// Create an embedding model with the given name.
71    /// Note: default embedding dimension of 0 will be used if model is not known.
72    /// If this is the case, it's better to use function `embedding_model_with_ndims`
73    ///
74    /// # Example
75    /// ```
76    /// use rig_fastembed::{Client, FastembedModel};
77    ///
78    /// // Initialize the `rig-fastembed` client
79    /// let fastembed_client = rig_fastembed::Client::new();
80    ///
81    /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
82    /// ```
83    #[cfg(feature = "hf-hub")]
84    pub fn embedding_model(
85        &self,
86        model: &FastembedModel,
87    ) -> Result<EmbeddingModel, FastembedError> {
88        let ndims = TextEmbedding::get_model_info(model)
89            .map(|info| info.dim)
90            .map_err(|_| FastembedError::UnknownModel(model.clone()))?;
91
92        EmbeddingModel::new(model, ndims)
93    }
94
95    /// Create an embedding builder with the given embedding model.
96    ///
97    /// # Example
98    /// ```
99    /// use rig_fastembed::{Client, FastembedModel};
100    ///
101    /// // Initialize the Fastembed client
102    /// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
103    /// let fastembed_client = Client::new();
104    ///
105    /// let embeddings = fastembed_client
106    ///     .embeddings(&FastembedModel::AllMiniLML6V2Q)?
107    ///     .documents(vec![
108    ///         "Hello, world!".to_string(),
109    ///         "Goodbye, world!".to_string(),
110    ///     ])?
111    ///     .build()
112    ///     .await?;
113    /// # let _ = embeddings;
114    /// # Ok(())
115    /// # }
116    /// # let _ = run();
117    /// ```
118    #[cfg(feature = "hf-hub")]
119    pub fn embeddings<D: Embed>(
120        &self,
121        model: &fastembed::EmbeddingModel,
122    ) -> Result<EmbeddingsBuilder<EmbeddingModel, D>, FastembedError> {
123        Ok(EmbeddingsBuilder::new(self.embedding_model(model)?))
124    }
125}
126
127#[derive(Clone)]
128pub struct EmbeddingModel {
129    embedder: Option<Arc<TextEmbedding>>,
130    init_error: Option<FastembedError>,
131    pub model: FastembedModel,
132    ndims: usize,
133}
134
135impl EmbeddingModel {
136    #[cfg(feature = "hf-hub")]
137    pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Result<Self, FastembedError> {
138        let embedder = Arc::new(
139            TextEmbedding::try_new(
140                InitOptions::new(model.to_owned()).with_show_download_progress(true),
141            )
142            .map_err(|err| FastembedError::Initialization(err.to_string()))?,
143        );
144
145        Ok(Self {
146            embedder: Some(embedder),
147            init_error: None,
148            model: model.to_owned(),
149            ndims,
150        })
151    }
152
153    pub fn new_from_user_defined(
154        user_defined_model: UserDefinedEmbeddingModel,
155        ndims: usize,
156        model_info: &ModelInfo<FastembedModel>,
157    ) -> Result<Self, FastembedError> {
158        let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
159            user_defined_model,
160            InitOptionsUserDefined::default(),
161        )
162        .map_err(|err| FastembedError::Initialization(err.to_string()))?;
163
164        let embedder = Arc::new(fastembed_embedding_model);
165
166        Ok(Self {
167            embedder: Some(embedder),
168            init_error: None,
169            model: model_info.model.to_owned(),
170            ndims,
171        })
172    }
173}
174
175impl embeddings::EmbeddingModel for EmbeddingModel {
176    const MAX_DOCUMENTS: usize = 1024;
177
178    type Client = Client;
179
180    fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
181        Self {
182            embedder: None,
183            init_error: Some(FastembedError::UnsupportedMake),
184            model: FastembedModel::AllMiniLML6V2Q,
185            ndims: 0,
186        }
187    }
188
189    fn ndims(&self) -> usize {
190        self.ndims
191    }
192
193    async fn embed_texts(
194        &self,
195        documents: impl IntoIterator<Item = String>,
196    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
197        let Some(embedder) = &self.embedder else {
198            let message = self
199                .init_error
200                .as_ref()
201                .map(ToString::to_string)
202                .unwrap_or_else(|| "FastEmbed model initialization failed".to_string());
203            return Err(EmbeddingError::ProviderError(message));
204        };
205
206        let documents_as_strings: Vec<String> = documents.into_iter().collect();
207
208        let documents_as_vec = embedder
209            .embed(documents_as_strings.clone(), None)
210            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
211
212        let docs = documents_as_strings
213            .into_iter()
214            .zip(documents_as_vec)
215            .map(|(document, embedding)| embeddings::Embedding {
216                document,
217                vec: embedding.into_iter().map(|f| f as f64).collect(),
218            })
219            .collect::<Vec<embeddings::Embedding>>();
220
221        Ok(docs)
222    }
223}