rig/embeddings/
builder.rs

1//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded
2//! and batch generates the embeddings for each object when built.
3//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder].
4
5use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10    OneOrMany,
11    embeddings::{
12        Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13    },
14};
15
16/// Builder for creating embeddings from one or more documents of type `T`.
17/// Note: `T` can be any type that implements the [Embed] trait.
18///
19/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as
20/// it will batch the documents in a single request to the model provider.
21///
22/// # Example
23/// ```rust
24/// use std::env;
25///
26/// use rig::{
27///     embeddings::EmbeddingsBuilder,
28///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
29/// };
30/// use serde::{Deserialize, Serialize};
31///
32/// // Create OpenAI client
33/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
34/// let openai_client = Client::new(&openai_api_key);
35///
36/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
37///
38/// let embeddings = EmbeddingsBuilder::new(model.clone())
39///     .documents(vec![
40///         "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(),
41///         "2. *flurbo* (noun): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
42///         "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
43///         "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
44///         "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
45///         "2. *linlingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
46///     ])?
47///     .build()
48///     .await?;
49/// ```
50#[non_exhaustive]
51pub struct EmbeddingsBuilder<M, T>
52where
53    M: EmbeddingModel,
54    T: Embed,
55{
56    model: M,
57    documents: Vec<(T, Vec<String>)>,
58}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62    M: EmbeddingModel,
63    T: Embed,
64{
65    /// Create a new embedding builder with the given embedding model
66    pub fn new(model: M) -> Self {
67        Self {
68            model,
69            documents: vec![],
70        }
71    }
72
73    /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait.
74    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75        let mut embedder = TextEmbedder::default();
76        document.embed(&mut embedder)?;
77
78        self.documents.push((document, embedder.texts));
79
80        Ok(self)
81    }
82
83    /// Add multiple documents to be embedded to the builder. `documents` must be iterable
84    /// with items that implement the [Embed] trait.
85    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86        let builder = documents
87            .into_iter()
88            .try_fold(self, |builder, doc| builder.document(doc))?;
89
90        Ok(builder)
91    }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96    M: EmbeddingModel,
97    T: Embed + Send,
98{
99    /// Generate embeddings for all documents in the builder.
100    /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many).
101    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
102        use stream::TryStreamExt;
103
104        // Store the documents and their texts in a HashMap for easy access.
105        let mut docs = HashMap::new();
106        let mut texts = Vec::new();
107
108        // Iterate over all documents in the builder and insert their docs and texts into the lookup stores.
109        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
110            docs.insert(i, doc);
111            texts.push((i, doc_texts));
112        }
113
114        // Compute the embeddings.
115        let mut embeddings = stream::iter(texts.into_iter())
116            // Merge the texts of each document into a single list of texts.
117            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
118            // Chunk them into batches. Each batch size is at most the embedding API limit per request.
119            .chunks(M::MAX_DOCUMENTS)
120            // Generate the embeddings for each batch.
121            .map(|text| async {
122                let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
123
124                let embeddings = self.model.embed_texts(docs).await?;
125                Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
126            })
127            // Parallelize the embeddings generation over 10 concurrent requests
128            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
129            // Collect the embeddings into a HashMap.
130            .try_fold(
131                HashMap::new(),
132                |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
133                    embeddings.into_iter().for_each(|(i, embedding)| {
134                        acc.entry(i)
135                            .and_modify(|embeddings| embeddings.push(embedding.clone()))
136                            .or_insert(OneOrMany::one(embedding.clone()));
137                    });
138
139                    Ok(acc)
140                },
141            )
142            .await?;
143
144        // Merge the embeddings with their respective documents
145        Ok(docs
146            .into_iter()
147            .map(|(i, doc)| {
148                (
149                    doc,
150                    embeddings.remove(&i).expect("Document should be present"),
151                )
152            })
153            .collect())
154    }
155}
156
#[cfg(test)]
mod tests {
    use crate::{
        Embed,
        client::Nothing,
        embeddings::{
            Embedding, EmbeddingModel,
            embed::{EmbedError, TextEmbedder},
        },
    };

    use super::EmbeddingsBuilder;

    /// Fake embedding model: accepts at most 5 texts per request and returns the same
    /// fixed 10-dimensional vector for every text, echoing the text back as `document`.
    #[derive(Clone)]
    struct Model;

    impl EmbeddingModel for Model {
        const MAX_DOCUMENTS: usize = 5;

        type Client = Nothing;

        fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
            Self
        }

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> {
            let mut embedded = Vec::new();
            for text in documents {
                embedded.push(Embedding {
                    document: text,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                });
            }
            Ok(embedded)
        }
    }

    /// Document exposing every one of its definitions as a text to embed.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            self.definitions
                .iter()
                .for_each(|definition| embedder.embed(definition.clone()));
            Ok(())
        }
    }

    /// Two documents ("doc0", "doc1") with two definitions each.
    fn definitions_multiple_text() -> Vec<WordDefinition> {
        let doc0 = WordDefinition {
            id: "doc0".into(),
            definitions: vec![
                "A green alien that lives on cold planets.".into(),
                "A fictional digital currency that originated in the animated series Rick and Morty.".into(),
            ],
        };
        let doc1 = WordDefinition {
            id: "doc1".into(),
            definitions: vec![
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".into(),
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".into(),
            ],
        };

        vec![doc0, doc1]
    }

    /// Two documents ("doc2", "doc3") with a single definition each.
    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        let doc2 = WordDefinition {
            id: "doc2".into(),
            definitions: vec!["Another fake definitions".into()],
        };
        let doc3 = WordDefinition {
            id: "doc3".into(),
            definitions: vec!["Some fake definition".into()],
        };

        vec![doc2, doc3]
    }

    /// Document exposing exactly one text to embed.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            let text = self.definition.clone();
            embedder.embed(text);
            Ok(())
        }
    }

    /// Two single-definition documents ("doc0", "doc1").
    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        let doc0 = WordDefinitionSingle {
            id: "doc0".into(),
            definition: "A green alien that lives on cold planets.".into(),
        };
        let doc1 = WordDefinitionSingle {
            id: "doc1".into(),
            definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".into(),
        };

        vec![doc0, doc1]
    }

    /// Documents with several texts each get one embedding per text.
    #[tokio::test]
    async fn test_build_multiple_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the order of the output.
        result.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }

    /// Documents with a single text each get exactly one embedding.
    #[tokio::test]
    async fn test_build_single_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_single_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the order of the output.
        result.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );
    }

    /// Two `documents` calls can be chained, mixing multi-text and single-text documents.
    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the order of the output.
        result.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(result.len(), 4);

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );

        let (doc, embeddings) = &result[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings.first().document, "Another fake definitions");
    }

    /// Plain `Vec<String>` values can be used as documents.
    #[tokio::test]
    async fn test_build_string() {
        let bindings = definitions_multiple_text();
        let string_docs = bindings.iter().map(|def| def.definitions.clone());

        let builder = EmbeddingsBuilder::new(Model)
            .documents(string_docs)
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by the documents themselves so the assertions do not depend on output order.
        result.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

        assert_eq!(result.len(), 2);

        let (_, embeddings) = &result[0];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (_, embeddings) = &result[1];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }
}