Skip to main content

rig_core/embeddings/
builder.rs

1//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded
2//! and batch generates the embeddings for each object when built.
3//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder].
4
5use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10    OneOrMany,
11    embeddings::{
12        Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13    },
14};
15
16/// Builder for creating embeddings from one or more documents of type `T`.
17/// Note: `T` can be any type that implements the [Embed] trait.
18///
19/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as
20/// it will batch the documents in a single request to the model provider.
21///
22/// # Example
23/// ```no_run
24/// use rig_core::{
25///     client::{EmbeddingsClient, ProviderClient},
26///     embeddings::EmbeddingsBuilder,
27///     providers::openai,
28/// };
29///
30/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
31/// // Create OpenAI client
32/// let openai_client = openai::Client::from_env()?;
33///
34/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_3_SMALL);
35///
36/// let embeddings = EmbeddingsBuilder::new(model.clone())
37///     .documents(vec![
38///         "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(),
39///         "2. *flurbo* (noun): A fictional digital currency.".to_string(),
40///         "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
41///         "2. *glarb-glarb* (noun): A fictional creature from marshlands.".to_string(),
42///         "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
43///         "2. *linlingdong* (noun): A rare instrument.".to_string(),
44///     ])?
45///     .build()
46///     .await?;
47/// # Ok(())
48/// # }
49/// ```
50#[non_exhaustive]
51pub struct EmbeddingsBuilder<M, T>
52where
53    M: EmbeddingModel,
54    T: Embed,
55{
56    model: M,
57    documents: Vec<(T, Vec<String>)>,
58}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62    M: EmbeddingModel,
63    T: Embed,
64{
65    /// Create a new embedding builder with the given embedding model
66    pub fn new(model: M) -> Self {
67        Self {
68            model,
69            documents: vec![],
70        }
71    }
72
73    /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait.
74    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75        let mut embedder = TextEmbedder::default();
76        document.embed(&mut embedder)?;
77
78        self.documents.push((document, embedder.texts));
79
80        Ok(self)
81    }
82
83    /// Add multiple documents to be embedded to the builder. `documents` must be iterable
84    /// with items that implement the [Embed] trait.
85    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86        let builder = documents
87            .into_iter()
88            .try_fold(self, |builder, doc| builder.document(doc))?;
89
90        Ok(builder)
91    }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96    M: EmbeddingModel,
97    T: Embed + Send,
98{
99    /// Generate embeddings for all documents in the builder.
100    ///
101    /// Returns `(document, embeddings)` pairs. A document may produce one or many
102    /// embeddings depending on how its [`Embed`] implementation uses [`TextEmbedder`].
103    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
104        use stream::TryStreamExt;
105
106        // Store the documents and their texts in a HashMap for easy access.
107        let mut docs = HashMap::new();
108        let mut texts = Vec::new();
109
110        // Iterate over all documents in the builder and insert their docs and texts into the lookup stores.
111        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
112            docs.insert(i, doc);
113            texts.push((i, doc_texts));
114        }
115
116        // Compute the embeddings.
117        let mut embeddings = stream::iter(texts.into_iter())
118            // Merge the texts of each document into a single list of texts.
119            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
120            // Chunk them into batches. Each batch size is at most the embedding API limit per request.
121            .chunks(M::MAX_DOCUMENTS)
122            // Generate the embeddings for each batch.
123            .map(|text| async {
124                let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
125
126                let embeddings = self.model.embed_texts(docs).await?;
127                Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
128            })
129            // Parallelize the embeddings generation over 10 concurrent requests
130            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
131            // Collect the embeddings into a HashMap.
132            .try_fold(
133                HashMap::new(),
134                |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
135                    embeddings.into_iter().for_each(|(i, embedding)| {
136                        acc.entry(i)
137                            .and_modify(|embeddings| embeddings.push(embedding.clone()))
138                            .or_insert(OneOrMany::one(embedding.clone()));
139                    });
140
141                    Ok(acc)
142                },
143            )
144            .await?;
145
146        // Merge the embeddings with their respective documents
147        docs.into_iter()
148            .map(|(i, doc)| {
149                let embedding = embeddings.remove(&i).ok_or_else(|| {
150                    crate::embeddings::EmbeddingError::ResponseError(
151                        "missing embedding for document after batch merge".to_string(),
152                    )
153                })?;
154                Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
155            })
156            .collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::EmbeddingsBuilder;
    use crate::test_utils::{MockEmbeddingModel, MockMultiTextDocument, MockTextDocument};

    /// Two documents ("doc0", "doc1"), each carrying two texts.
    fn definitions_multiple_text() -> Vec<MockMultiTextDocument> {
        let doc0 = MockMultiTextDocument::new(
            "doc0",
            [
                "A green alien that lives on cold planets.",
                "A fictional digital currency that originated in the animated series Rick and Morty.",
            ],
        );
        let doc1 = MockMultiTextDocument::new(
            "doc1",
            [
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.",
            ],
        );

        vec![doc0, doc1]
    }

    /// Two documents ("doc2", "doc3"), each carrying a single text.
    fn definitions_multiple_text_2() -> Vec<MockMultiTextDocument> {
        let doc2 = MockMultiTextDocument::new("doc2", ["Another fake definitions"]);
        let doc3 = MockMultiTextDocument::new("doc3", ["Some fake definition"]);

        vec![doc2, doc3]
    }

    /// Two single-text documents ("doc0", "doc1").
    fn definitions_single_text() -> Vec<MockTextDocument> {
        let doc0 = MockTextDocument::new("doc0", "A green alien that lives on cold planets.");
        let doc1 = MockTextDocument::new(
            "doc1",
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
        );

        vec![doc0, doc1]
    }

    /// Documents with several texts each yield one embedding per text.
    #[tokio::test]
    async fn test_build_multiple_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the returned order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            String::from("A green alien that lives on cold planets.")
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            String::from(
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
            )
        )
    }

    /// Single-text documents yield exactly one embedding each.
    #[tokio::test]
    async fn test_build_single_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_single_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the returned order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            String::from("A green alien that lives on cold planets.")
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            String::from(
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
            )
        )
    }

    /// Two `documents` calls accumulate; multi-text and single-text documents can be mixed.
    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let two_text_docs = definitions_multiple_text();
        let one_text_docs = definitions_multiple_text_2();

        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(two_text_docs)
            .unwrap()
            .documents(one_text_docs)
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the returned order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 4);

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            String::from(
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
            )
        );

        let (doc, embeddings) = &result[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            String::from("Another fake definitions")
        )
    }

    /// The raw `texts` of each mock document can be queued as documents themselves.
    #[tokio::test]
    async fn test_build_string() {
        let source_docs = definitions_multiple_text();
        let text_lists = source_docs.iter().map(|def| def.texts.clone());

        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(text_lists)
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by the documents themselves so the assertions do not depend on the returned order.
        result.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

        assert_eq!(result.len(), 2);

        let (_, embeddings) = &result[0];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            String::from("A green alien that lives on cold planets.")
        );

        let (_, embeddings) = &result[1];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            String::from(
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
            )
        )
    }
}