rig/embeddings/
builder.rs

//! This module defines the [EmbeddingsBuilder] struct, which accumulates objects to be embedded
//! and batch-generates the embeddings for each object when built.
//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder].
4
5use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10    OneOrMany,
11    embeddings::{
12        Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13    },
14};
15
16/// Builder for creating embeddings from one or more documents of type `T`.
17/// Note: `T` can be any type that implements the [Embed] trait.
18///
19/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as
20/// it will batch the documents in a single request to the model provider.
21///
22/// # Example
23/// ```rust
24/// use std::env;
25///
26/// use rig::{
27///     embeddings::EmbeddingsBuilder,
28///     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
29/// };
30/// use serde::{Deserialize, Serialize};
31///
32/// // Create OpenAI client
33/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
34/// let openai_client = Client::new(&openai_api_key);
35///
36/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
37///
38/// let embeddings = EmbeddingsBuilder::new(model.clone())
39///     .documents(vec![
40///         "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(),
41///         "2. *flurbo* (noun): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
42///         "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
43///         "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
44///         "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
45///         "2. *linlingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
46///     ])?
47///     .build()
48///     .await?;
49/// ```
50#[non_exhaustive]
51pub struct EmbeddingsBuilder<M, T>
52where
53    M: EmbeddingModel,
54    T: Embed,
55{
56    model: M,
57    documents: Vec<(T, Vec<String>)>,
58}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62    M: EmbeddingModel,
63    T: Embed,
64{
65    /// Create a new embedding builder with the given embedding model
66    pub fn new(model: M) -> Self {
67        Self {
68            model,
69            documents: vec![],
70        }
71    }
72
73    /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait.
74    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75        let mut embedder = TextEmbedder::default();
76        document.embed(&mut embedder)?;
77
78        self.documents.push((document, embedder.texts));
79
80        Ok(self)
81    }
82
83    /// Add multiple documents to be embedded to the builder. `documents` must be iterable
84    /// with items that implement the [Embed] trait.
85    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86        let builder = documents
87            .into_iter()
88            .try_fold(self, |builder, doc| builder.document(doc))?;
89
90        Ok(builder)
91    }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96    M: EmbeddingModel,
97    T: Embed + Send,
98{
99    /// Generate embeddings for all documents in the builder.
100    /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many).
101    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
102        use stream::TryStreamExt;
103
104        // Store the documents and their texts in a HashMap for easy access.
105        let mut docs = HashMap::new();
106        let mut texts = Vec::new();
107
108        // Iterate over all documents in the builder and insert their docs and texts into the lookup stores.
109        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
110            docs.insert(i, doc);
111            texts.push((i, doc_texts));
112        }
113
114        // Compute the embeddings.
115        let mut embeddings = stream::iter(texts.into_iter())
116            // Merge the texts of each document into a single list of texts.
117            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
118            // Chunk them into batches. Each batch size is at most the embedding API limit per request.
119            .chunks(M::MAX_DOCUMENTS)
120            // Generate the embeddings for each batch.
121            .map(|text| async {
122                let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
123
124                let embeddings = self.model.embed_texts(docs).await?;
125                Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
126            })
127            // Parallelize the embeddings generation over 10 concurrent requests
128            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
129            // Collect the embeddings into a HashMap.
130            .try_fold(
131                HashMap::new(),
132                |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
133                    embeddings.into_iter().for_each(|(i, embedding)| {
134                        acc.entry(i)
135                            .and_modify(|embeddings| embeddings.push(embedding.clone()))
136                            .or_insert(OneOrMany::one(embedding.clone()));
137                    });
138
139                    Ok(acc)
140                },
141            )
142            .await?;
143
144        // Merge the embeddings with their respective documents
145        Ok(docs
146            .into_iter()
147            .map(|(i, doc)| {
148                (
149                    doc,
150                    embeddings.remove(&i).expect("Document should be present"),
151                )
152            })
153            .collect())
154    }
155}
156
#[cfg(test)]
mod tests {
    use crate::{
        Embed,
        embeddings::{Embedding, EmbeddingModel, embed::EmbedError, embed::TextEmbedder},
    };

    use super::EmbeddingsBuilder;

    /// Stand-in embedding model: every text gets the same fixed 10-element vector.
    #[derive(Clone)]
    struct Model;

    impl EmbeddingModel for Model {
        const MAX_DOCUMENTS: usize = 5;

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> {
            let mut embedded = Vec::new();
            for text in documents {
                embedded.push(Embedding {
                    document: text,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                });
            }
            Ok(embedded)
        }
    }

    /// Document that contributes one text per definition.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            self.definitions
                .iter()
                .for_each(|definition| embedder.embed(definition.clone()));
            Ok(())
        }
    }

    /// Build a [WordDefinition] fixture from borrowed strings.
    fn word(id: &str, definitions: &[&str]) -> WordDefinition {
        WordDefinition {
            id: id.to_string(),
            definitions: definitions.iter().map(|text| text.to_string()).collect(),
        }
    }

    /// Two documents, each carrying two definitions.
    fn definitions_multiple_text() -> Vec<WordDefinition> {
        vec![
            word(
                "doc0",
                &[
                    "A green alien that lives on cold planets.",
                    "A fictional digital currency that originated in the animated series Rick and Morty.",
                ],
            ),
            word(
                "doc1",
                &[
                    "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
                    "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.",
                ],
            ),
        ]
    }

    /// Two documents, each carrying a single definition.
    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        vec![
            word("doc2", &["Another fake definitions"]),
            word("doc3", &["Some fake definition"]),
        ]
    }

    /// Document that always contributes exactly one text.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            let text = self.definition.clone();
            embedder.embed(text);
            Ok(())
        }
    }

    /// Build a [WordDefinitionSingle] fixture from borrowed strings.
    fn single_word(id: &str, definition: &str) -> WordDefinitionSingle {
        WordDefinitionSingle {
            id: id.to_string(),
            definition: definition.to_string(),
        }
    }

    /// Two single-definition documents.
    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        vec![
            single_word("doc0", "A green alien that lives on cold planets."),
            single_word(
                "doc1",
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
            ),
        ]
    }

    /// Documents with several texts each get one embedding per text.
    #[tokio::test]
    async fn test_build_multiple_text() {
        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the order `build` returns.
        embedded.sort_by(|(left, _), (right, _)| left.id.cmp(&right.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        );
    }

    /// Documents with a single text each get exactly one embedding.
    #[tokio::test]
    async fn test_build_single_text() {
        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(definitions_single_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the order `build` returns.
        embedded.sort_by(|(left, _), (right, _)| left.id.cmp(&right.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        );
    }

    /// Multi-text and single-text documents can be mixed across `documents` calls.
    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the order `build` returns.
        embedded.sort_by(|(left, _), (right, _)| left.id.cmp(&right.id));

        assert_eq!(embedded.len(), 4);

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        );

        let (doc, vectors) = &embedded[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "Another fake definitions".to_string()
        );
    }

    /// Plain lists of strings can be used as documents.
    #[tokio::test]
    async fn test_build_string() {
        let sources = definitions_multiple_text();
        let text_lists = sources.iter().map(|source| source.definitions.clone());

        let mut embedded = EmbeddingsBuilder::new(Model)
            .documents(text_lists)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by the text lists themselves so the assertions are order-independent.
        embedded.sort_by(|(left, _), (right, _)| left.cmp(right));

        assert_eq!(embedded.len(), 2);

        let (_, vectors) = &embedded[0];
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (_, vectors) = &embedded[1];
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        );
    }
}