rig_fastembed/
lib.rs

1use std::sync::Arc;
2
3pub use fastembed::EmbeddingModel as FastembedModel;
4use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel};
5use rig::embeddings::{self, EmbeddingError};
6
7#[cfg(feature = "hf-hub")]
8use fastembed::InitOptions;
9#[cfg(feature = "hf-hub")]
10use rig::{Embed, embeddings::EmbeddingsBuilder};
11
/// The `rig-fastembed` client.
///
/// Use this as your main entrypoint for any `rig-fastembed` functionality.
/// The client is a stateless unit struct, so it is free to create, clone and debug-print.
#[derive(Clone, Debug)]
pub struct Client;
17
18impl Default for Client {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl Client {
25    /// Create a new `rig-fastembed` client.
26    pub fn new() -> Self {
27        Self
28    }
29
30    /// Create an embedding model with the given name.
31    /// Note: default embedding dimension of 0 will be used if model is not known.
32    /// If this is the case, it's better to use function `embedding_model_with_ndims`
33    ///
34    /// # Example
35    /// ```
36    /// use rig_fastembed::{Client, FastembedModel};
37    ///
38    /// // Initialize the `rig-fastembed` client
39    /// let fastembed_client = rig_fastembed::Client::new();
40    ///
41    /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
42    /// ```
43    #[cfg(feature = "hf-hub")]
44    pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
45        let ndims = TextEmbedding::get_model_info(model).unwrap().dim;
46
47        EmbeddingModel::new(model, ndims)
48    }
49
50    /// Create an embedding builder with the given embedding model.
51    ///
52    /// # Example
53    /// ```
54    /// use rig_fastembed::{Client, FastembedModel};
55    ///
56    /// // Initialize the Fastembed client
57    /// let fastembed_client = Client::new();
58    ///
59    /// let embeddings = fastembed_client.embeddings(FastembedModel::AllMiniLML6V2Q)
60    ///     .simple_document("doc0", "Hello, world!")
61    ///     .simple_document("doc1", "Goodbye, world!")
62    ///     .build()
63    ///     .await
64    ///     .expect("Failed to embed documents");
65    /// ```
66    #[cfg(feature = "hf-hub")]
67    pub fn embeddings<D: Embed>(
68        &self,
69        model: &fastembed::EmbeddingModel,
70    ) -> EmbeddingsBuilder<EmbeddingModel, D> {
71        EmbeddingsBuilder::new(self.embedding_model(model))
72    }
73}
74
75#[derive(Clone)]
76pub struct EmbeddingModel {
77    embedder: Arc<TextEmbedding>,
78    pub model: FastembedModel,
79    ndims: usize,
80}
81
82impl EmbeddingModel {
83    #[cfg(feature = "hf-hub")]
84    pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self {
85        let embedder = Arc::new(
86            TextEmbedding::try_new(
87                InitOptions::new(model.to_owned()).with_show_download_progress(true),
88            )
89            .unwrap(),
90        );
91
92        Self {
93            embedder,
94            model: model.to_owned(),
95            ndims,
96        }
97    }
98
99    pub fn new_from_user_defined(
100        user_defined_model: UserDefinedEmbeddingModel,
101        ndims: usize,
102        model_info: &ModelInfo<FastembedModel>,
103    ) -> Self {
104        let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
105            user_defined_model,
106            InitOptionsUserDefined::default(),
107        )
108        .unwrap();
109
110        let embedder = Arc::new(fastembed_embedding_model);
111
112        Self {
113            embedder,
114            model: model_info.model.to_owned(),
115            ndims,
116        }
117    }
118}
119
120impl embeddings::EmbeddingModel for EmbeddingModel {
121    const MAX_DOCUMENTS: usize = 1024;
122
123    type Client = Client;
124
125    /// **PANICS**: FastEmbed models cannot be created via this method, which will panic
126    fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
127        panic!("Cannot create a fastembed model via `EmbeddingModel::make`")
128    }
129
130    fn ndims(&self) -> usize {
131        self.ndims
132    }
133
134    async fn embed_texts(
135        &self,
136        documents: impl IntoIterator<Item = String>,
137    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
138        let documents_as_strings: Vec<String> = documents.into_iter().collect();
139
140        let documents_as_vec = self
141            .embedder
142            .embed(documents_as_strings.clone(), None)
143            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
144
145        let docs = documents_as_strings
146            .into_iter()
147            .zip(documents_as_vec)
148            .map(|(document, embedding)| embeddings::Embedding {
149                document,
150                vec: embedding.into_iter().map(|f| f as f64).collect(),
151            })
152            .collect::<Vec<embeddings::Embedding>>();
153
154        Ok(docs)
155    }
156}