rig_fastembed/
lib.rs

1use std::sync::Arc;
2
3pub use fastembed::EmbeddingModel as FastembedModel;
4use fastembed::{
5    InitOptions, InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel,
6};
7use rig::{
8    Embed,
9    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
10};
11
/// The `rig-fastembed` client.
///
/// Use this as your main entrypoint for any `rig-fastembed` functionality.
/// The client is a field-less unit struct, so cloning it is trivial.
#[derive(Clone, Debug)]
pub struct Client;
17
18impl Default for Client {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl Client {
25    /// Create a new `rig-fastembed` client.
26    pub fn new() -> Self {
27        Self
28    }
29
30    /// Create an embedding model with the given name.
31    /// Note: default embedding dimension of 0 will be used if model is not known.
32    /// If this is the case, it's better to use function `embedding_model_with_ndims`
33    ///
34    /// # Example
35    /// ```
36    /// use rig_fastembed::{Client, FastembedModel};
37    ///
38    /// // Initialize the `rig-fastembed` client
39    /// let fastembed_client = rig_fastembed::Client::new();
40    ///
41    /// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q);
42    /// ```
43    pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
44        let ndims = TextEmbedding::get_model_info(model).unwrap().dim;
45
46        EmbeddingModel::new(model, ndims)
47    }
48
49    /// Create an embedding builder with the given embedding model.
50    ///
51    /// # Example
52    /// ```
53    /// use rig_fastembed::{Client, FastembedModel};
54    ///
55    /// // Initialize the Fastembed client
56    /// let fastembed_client = Client::new();
57    ///
58    /// let embeddings = fastembed_client.embeddings(FastembedModel::AllMiniLML6V2Q)
59    ///     .simple_document("doc0", "Hello, world!")
60    ///     .simple_document("doc1", "Goodbye, world!")
61    ///     .build()
62    ///     .await
63    ///     .expect("Failed to embed documents");
64    /// ```
65    pub fn embeddings<D: Embed>(
66        &self,
67        model: &fastembed::EmbeddingModel,
68    ) -> EmbeddingsBuilder<EmbeddingModel, D> {
69        EmbeddingsBuilder::new(self.embedding_model(model))
70    }
71}
72
73#[derive(Clone)]
74pub struct EmbeddingModel {
75    embedder: Arc<TextEmbedding>,
76    pub model: FastembedModel,
77    ndims: usize,
78}
79
80impl EmbeddingModel {
81    pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self {
82        let embedder = Arc::new(
83            TextEmbedding::try_new(
84                InitOptions::new(model.to_owned()).with_show_download_progress(true),
85            )
86            .unwrap(),
87        );
88
89        Self {
90            embedder,
91            model: model.to_owned(),
92            ndims,
93        }
94    }
95
96    pub fn new_from_user_defined(
97        user_defined_model: UserDefinedEmbeddingModel,
98        ndims: usize,
99        model_info: &ModelInfo<FastembedModel>,
100    ) -> Self {
101        let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
102            user_defined_model,
103            InitOptionsUserDefined::default(),
104        )
105        .unwrap();
106
107        let embedder = Arc::new(fastembed_embedding_model);
108
109        Self {
110            embedder,
111            model: model_info.model.to_owned(),
112            ndims,
113        }
114    }
115}
116
117impl embeddings::EmbeddingModel for EmbeddingModel {
118    const MAX_DOCUMENTS: usize = 1024;
119
120    fn ndims(&self) -> usize {
121        self.ndims
122    }
123
124    async fn embed_texts(
125        &self,
126        documents: impl IntoIterator<Item = String>,
127    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
128        let documents_as_strings: Vec<String> = documents.into_iter().collect();
129
130        let documents_as_vec = self
131            .embedder
132            .embed(documents_as_strings.clone(), None)
133            .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
134
135        let docs = documents_as_strings
136            .into_iter()
137            .zip(documents_as_vec)
138            .map(|(document, embedding)| embeddings::Embedding {
139                document,
140                vec: embedding.into_iter().map(|f| f as f64).collect(),
141            })
142            .collect::<Vec<embeddings::Embedding>>();
143
144        Ok(docs)
145    }
146}