rig/embeddings/
builder.rs

//! This module defines the [EmbeddingsBuilder] struct, which accumulates objects to be embedded
//! and batch-generates the embeddings for each object when built.
//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder].
4
5use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10    OneOrMany,
11    embeddings::{
12        Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13    },
14};
15
16/// Builder for creating embeddings from one or more documents of type `T`.
17/// Note: `T` can be any type that implements the [Embed] trait.
18///
19/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as
20/// it will batch the documents in a single request to the model provider.
21///
22/// # Example
23/// ```rust
24/// use std::env;
25///
26/// use rig::{
27///     embeddings::EmbeddingsBuilder,
28///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
29/// };
30/// use serde::{Deserialize, Serialize};
31///
32/// // Create OpenAI client
33/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
34/// let openai_client = Client::new(&openai_api_key);
35///
36/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
37///
38/// let embeddings = EmbeddingsBuilder::new(model.clone())
39///     .documents(vec![
40///         "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(),
41///         "2. *flurbo* (noun): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
42///         "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
43///         "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
44///         "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
45///         "2. *linlingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
46///     ])?
47///     .build()
48///     .await?;
49/// ```
50pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embed> {
51    model: M,
52    documents: Vec<(T, Vec<String>)>,
53}
54
55impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> {
56    /// Create a new embedding builder with the given embedding model
57    pub fn new(model: M) -> Self {
58        Self {
59            model,
60            documents: vec![],
61        }
62    }
63
64    /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait.
65    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
66        let mut embedder = TextEmbedder::default();
67        document.embed(&mut embedder)?;
68
69        self.documents.push((document, embedder.texts));
70
71        Ok(self)
72    }
73
74    /// Add multiple documents to be embedded to the builder. `documents` must be iterable
75    /// with items that implement the [Embed] trait.
76    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
77        let builder = documents
78            .into_iter()
79            .try_fold(self, |builder, doc| builder.document(doc))?;
80
81        Ok(builder)
82    }
83}
84
85impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
86    /// Generate embeddings for all documents in the builder.
87    /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many).
88    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
89        use stream::TryStreamExt;
90
91        // Store the documents and their texts in a HashMap for easy access.
92        let mut docs = HashMap::new();
93        let mut texts = Vec::new();
94
95        // Iterate over all documents in the builder and insert their docs and texts into the lookup stores.
96        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
97            docs.insert(i, doc);
98            texts.push((i, doc_texts));
99        }
100
101        // Compute the embeddings.
102        let mut embeddings = stream::iter(texts.into_iter())
103            // Merge the texts of each document into a single list of texts.
104            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
105            // Chunk them into batches. Each batch size is at most the embedding API limit per request.
106            .chunks(M::MAX_DOCUMENTS)
107            // Generate the embeddings for each batch.
108            .map(|text| async {
109                let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
110
111                let embeddings = self.model.embed_texts(docs).await?;
112                Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
113            })
114            // Parallelize the embeddings generation over 10 concurrent requests
115            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
116            // Collect the embeddings into a HashMap.
117            .try_fold(
118                HashMap::new(),
119                |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
120                    embeddings.into_iter().for_each(|(i, embedding)| {
121                        acc.entry(i)
122                            .and_modify(|embeddings| embeddings.push(embedding.clone()))
123                            .or_insert(OneOrMany::one(embedding.clone()));
124                    });
125
126                    Ok(acc)
127                },
128            )
129            .await?;
130
131        // Merge the embeddings with their respective documents
132        Ok(docs
133            .into_iter()
134            .map(|(i, doc)| {
135                (
136                    doc,
137                    embeddings.remove(&i).expect("Document should be present"),
138                )
139            })
140            .collect())
141    }
142}
143
#[cfg(test)]
mod tests {
    use crate::{
        Embed,
        embeddings::{Embedding, EmbeddingModel, embed::EmbedError, embed::TextEmbedder},
    };

    use super::EmbeddingsBuilder;

    /// Stub model: echoes each input text back as the embedding's `document` and attaches
    /// the same fixed 10-element vector to every embedding.
    #[derive(Clone)]
    struct Model;

    impl EmbeddingModel for Model {
        const MAX_DOCUMENTS: usize = 5;

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<Embedding>, crate::embeddings::EmbeddingError> {
            let embedded: Vec<Embedding> = documents
                .into_iter()
                .map(|text| Embedding {
                    document: text,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                })
                .collect();

            Ok(embedded)
        }
    }

    /// Document that contributes one text per definition.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            self.definitions
                .iter()
                .cloned()
                .for_each(|text| embedder.embed(text));

            Ok(())
        }
    }

    /// Two documents with two definitions each.
    fn definitions_multiple_text() -> Vec<WordDefinition> {
        let doc0 = WordDefinition {
            id: String::from("doc0"),
            definitions: vec![
                String::from("A green alien that lives on cold planets."),
                String::from("A fictional digital currency that originated in the animated series Rick and Morty."),
            ],
        };
        let doc1 = WordDefinition {
            id: String::from("doc1"),
            definitions: vec![
                String::from("An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."),
                String::from("A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."),
            ],
        };

        vec![doc0, doc1]
    }

    /// Two documents with a single definition each.
    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        let doc2 = WordDefinition {
            id: String::from("doc2"),
            definitions: vec![String::from("Another fake definitions")],
        };
        let doc3 = WordDefinition {
            id: String::from("doc3"),
            definitions: vec![String::from("Some fake definition")],
        };

        vec![doc2, doc3]
    }

    /// Document that contributes exactly one text.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            let text = self.definition.clone();
            embedder.embed(text);

            Ok(())
        }
    }

    /// Two single-definition documents.
    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        let doc0 = WordDefinitionSingle {
            id: String::from("doc0"),
            definition: String::from("A green alien that lives on cold planets."),
        };
        let doc1 = WordDefinitionSingle {
            id: String::from("doc1"),
            definition: String::from("An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."),
        };

        vec![doc0, doc1]
    }

    #[tokio::test]
    async fn test_build_multiple_text() {
        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the output order.
        embedded.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }

    #[tokio::test]
    async fn test_build_single_text() {
        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(definitions_single_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the output order.
        embedded.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );
    }

    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the output order.
        embedded.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(embedded.len(), 4);

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );

        let (doc, vectors) = &embedded[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors.first().document, "Another fake definitions");
    }

    #[tokio::test]
    async fn test_build_string() {
        // Embed the raw definition lists (`Vec<String>`) rather than the wrapper structs.
        let raw_definitions = definitions_multiple_text()
            .into_iter()
            .map(|def| def.definitions);

        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(raw_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by the documents themselves so the assertions do not depend on output order.
        embedded.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

        assert_eq!(embedded.len(), 2);

        let (_, vectors) = &embedded[0];
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets."
        );

        let (_, vectors) = &embedded[1];
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }
}
387}