Skip to main content

rig/embeddings/
builder.rs

//! This module defines the [EmbeddingsBuilder] struct, which accumulates objects to be embedded
//! and generates their embeddings in batches when built.
//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder].
4
5use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10    OneOrMany,
11    embeddings::{
12        Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13    },
14};
15
16/// Builder for creating embeddings from one or more documents of type `T`.
17/// Note: `T` can be any type that implements the [Embed] trait.
18///
19/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as
20/// it will batch the documents in a single request to the model provider.
21///
22/// # Example
23/// ```rust
24/// use std::env;
25///
26/// use rig::{
27///     embeddings::EmbeddingsBuilder,
28///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
29/// };
30/// use serde::{Deserialize, Serialize};
31///
32/// // Create OpenAI client
33/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
34/// let openai_client = Client::new(&openai_api_key);
35///
36/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
37///
38/// let embeddings = EmbeddingsBuilder::new(model.clone())
39///     .documents(vec![
40///         "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(),
41///         "2. *flurbo* (noun): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
42///         "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
43///         "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
44///         "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
45///         "2. *linlingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
46///     ])?
47///     .build()
48///     .await?;
49/// ```
50#[non_exhaustive]
51pub struct EmbeddingsBuilder<M, T>
52where
53    M: EmbeddingModel,
54    T: Embed,
55{
56    model: M,
57    documents: Vec<(T, Vec<String>)>,
58}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62    M: EmbeddingModel,
63    T: Embed,
64{
65    /// Create a new embedding builder with the given embedding model
66    pub fn new(model: M) -> Self {
67        Self {
68            model,
69            documents: vec![],
70        }
71    }
72
73    /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait.
74    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75        let mut embedder = TextEmbedder::default();
76        document.embed(&mut embedder)?;
77
78        self.documents.push((document, embedder.texts));
79
80        Ok(self)
81    }
82
83    /// Add multiple documents to be embedded to the builder. `documents` must be iterable
84    /// with items that implement the [Embed] trait.
85    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86        let builder = documents
87            .into_iter()
88            .try_fold(self, |builder, doc| builder.document(doc))?;
89
90        Ok(builder)
91    }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96    M: EmbeddingModel,
97    T: Embed + Send,
98{
99    /// Generate embeddings for all documents in the builder.
100    /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many).
101    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
102        use stream::TryStreamExt;
103
104        // Store the documents and their texts in a HashMap for easy access.
105        let mut docs = HashMap::new();
106        let mut texts = Vec::new();
107
108        // Iterate over all documents in the builder and insert their docs and texts into the lookup stores.
109        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
110            docs.insert(i, doc);
111            texts.push((i, doc_texts));
112        }
113
114        // Compute the embeddings.
115        let mut embeddings = stream::iter(texts.into_iter())
116            // Merge the texts of each document into a single list of texts.
117            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
118            // Chunk them into batches. Each batch size is at most the embedding API limit per request.
119            .chunks(M::MAX_DOCUMENTS)
120            // Generate the embeddings for each batch.
121            .map(|text| async {
122                let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
123
124                let embeddings = self.model.embed_texts(docs).await?;
125                Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
126            })
127            // Parallelize the embeddings generation over 10 concurrent requests
128            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
129            // Collect the embeddings into a HashMap.
130            .try_fold(
131                HashMap::new(),
132                |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
133                    embeddings.into_iter().for_each(|(i, embedding)| {
134                        acc.entry(i)
135                            .and_modify(|embeddings| embeddings.push(embedding.clone()))
136                            .or_insert(OneOrMany::one(embedding.clone()));
137                    });
138
139                    Ok(acc)
140                },
141            )
142            .await?;
143
144        // Merge the embeddings with their respective documents
145        docs.into_iter()
146            .map(|(i, doc)| {
147                let embedding = embeddings.remove(&i).ok_or_else(|| {
148                    crate::embeddings::EmbeddingError::ResponseError(
149                        "missing embedding for document after batch merge".to_string(),
150                    )
151                })?;
152                Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
153            })
154            .collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()
155    }
156}
157
#[cfg(test)]
mod tests {
    use crate::{
        Embed,
        client::Nothing,
        embeddings::{
            Embedding, EmbeddingModel,
            embed::{EmbedError, TextEmbedder},
        },
    };

    use super::EmbeddingsBuilder;

    /// Deterministic stand-in model: every text gets the same 10-dimensional vector and the
    /// returned [Embedding] echoes the text it was created from.
    #[derive(Clone)]
    struct MockEmbeddingModel;

    impl EmbeddingModel for MockEmbeddingModel {
        // Small limit so tests can exercise multi-batch behavior cheaply.
        const MAX_DOCUMENTS: usize = 5;

        type Client = Nothing;

        fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
            Self {}
        }

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> {
            Ok(documents
                .into_iter()
                .map(|doc| Embedding {
                    document: doc,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                })
                .collect())
        }
    }

    /// Document type that embeds each of its definitions as a separate text.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            for definition in &self.definitions {
                embedder.embed(definition.clone());
            }
            Ok(())
        }
    }

    fn definitions_multiple_text() -> Vec<WordDefinition> {
        vec![
            WordDefinition {
                id: "doc0".to_string(),
                definitions: vec![
                    "A green alien that lives on cold planets.".to_string(),
                    "A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
                ]
            },
            WordDefinition {
                id: "doc1".to_string(),
                definitions: vec![
                    "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
                    "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
                ]
            }
        ]
    }

    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        vec![
            WordDefinition {
                id: "doc2".to_string(),
                definitions: vec!["Another fake definitions".to_string()],
            },
            WordDefinition {
                id: "doc3".to_string(),
                definitions: vec!["Some fake definition".to_string()],
            },
        ]
    }

    /// Document type that embeds exactly one text.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            embedder.embed(self.definition.clone());
            Ok(())
        }
    }

    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        vec![
            WordDefinitionSingle {
                id: "doc0".to_string(),
                definition: "A green alien that lives on cold planets.".to_string(),
            },
            WordDefinitionSingle {
                id: "doc1".to_string(),
                definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
            }
        ]
    }

    #[tokio::test]
    async fn test_build_multiple_text() {
        let fake_definitions = definitions_multiple_text();

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.0.id, "doc0");
        assert_eq!(first_definition.1.len(), 2);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_single_text() {
        let fake_definitions = definitions_single_text();

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.0.id, "doc0");
        assert_eq!(first_definition.1.len(), 1);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 1);
        assert_eq!(
            second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let fake_definitions = definitions_multiple_text();
        let fake_definitions_single = definitions_multiple_text_2();

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .documents(fake_definitions_single)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 4);

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        );

        let third_definition = &result[2];
        assert_eq!(third_definition.0.id, "doc2");
        assert_eq!(third_definition.1.len(), 1);
        assert_eq!(
            third_definition.1.first().document,
            "Another fake definitions".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_string() {
        let bindings = definitions_multiple_text();
        let fake_definitions = bindings.iter().map(|def| def.definitions.clone());

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.cmp(fake_definition_2)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.1.len(), 2);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }

    /// A document with more texts than `MAX_DOCUMENTS` is split across batches; all of its
    /// embeddings must still be merged back onto it.
    #[tokio::test]
    async fn test_build_document_spanning_multiple_batches() {
        let long = WordDefinition {
            id: "doc0".to_string(),
            definitions: (0..7).map(|n| format!("definition {n}")).collect(),
        };
        let short = WordDefinition {
            id: "doc1".to_string(),
            definitions: vec!["only definition".to_string()],
        };

        let mut result = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(vec![long, short])
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(a, _), (b, _)| a.id.cmp(&b.id));

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0.id, "doc0");
        assert_eq!(result[0].1.len(), 7);
        assert_eq!(result[1].0.id, "doc1");
        assert_eq!(result[1].1.len(), 1);
    }

    /// A document whose `Embed` impl produces no text cannot be given an embedding, so
    /// `build` must report an error instead of returning a partial result.
    #[tokio::test]
    async fn test_build_document_without_text_fails() {
        let empty = WordDefinition {
            id: "empty".to_string(),
            definitions: vec![],
        };

        let result = EmbeddingsBuilder::new(MockEmbeddingModel)
            .document(empty)
            .unwrap()
            .build()
            .await;

        assert!(result.is_err());
    }
}