Skip to main content

rig_core/embeddings/
builder.rs

//! This module defines the [EmbeddingsBuilder] struct, which accumulates objects to be embedded
//! and generates their embeddings in batches when built.
//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder].
4
5use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10    OneOrMany,
11    completion::Usage,
12    embeddings::{
13        Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, EmbeddingResponse,
14        embed::TextEmbedder,
15    },
16};
17
18/// Builder for creating embeddings from one or more documents of type `T`.
19/// Note: `T` can be any type that implements the [Embed] trait.
20///
21/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as
22/// it will batch the documents in a single request to the model provider.
23///
24/// # Example
25/// ```no_run
26/// use rig_core::{
27///     client::{EmbeddingsClient, ProviderClient},
28///     embeddings::EmbeddingsBuilder,
29///     providers::openai,
30/// };
31///
32/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
33/// // Create OpenAI client
34/// let openai_client = openai::Client::from_env()?;
35///
36/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_3_SMALL);
37///
38/// let embeddings = EmbeddingsBuilder::new(model.clone())
39///     .documents(vec![
40///         "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(),
41///         "2. *flurbo* (noun): A fictional digital currency.".to_string(),
42///         "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
43///         "2. *glarb-glarb* (noun): A fictional creature from marshlands.".to_string(),
44///         "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
45///         "2. *linlingdong* (noun): A rare instrument.".to_string(),
46///     ])?
47///     .build()
48///     .await?;
49/// # Ok(())
50/// # }
51/// ```
52#[non_exhaustive]
53pub struct EmbeddingsBuilder<M, T>
54where
55    M: EmbeddingModel,
56    T: Embed,
57{
58    model: M,
59    documents: Vec<(T, Vec<String>)>,
60}
61
62impl<M, T> EmbeddingsBuilder<M, T>
63where
64    M: EmbeddingModel,
65    T: Embed,
66{
67    /// Create a new embedding builder with the given embedding model
68    pub fn new(model: M) -> Self {
69        Self {
70            model,
71            documents: vec![],
72        }
73    }
74
75    /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait.
76    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
77        let mut embedder = TextEmbedder::default();
78        document.embed(&mut embedder)?;
79
80        self.documents.push((document, embedder.texts));
81
82        Ok(self)
83    }
84
85    /// Add multiple documents to be embedded to the builder. `documents` must be iterable
86    /// with items that implement the [Embed] trait.
87    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
88        let builder = documents
89            .into_iter()
90            .try_fold(self, |builder, doc| builder.document(doc))?;
91
92        Ok(builder)
93    }
94}
95
96impl<M, T> EmbeddingsBuilder<M, T>
97where
98    M: EmbeddingModel,
99    T: Embed + Send,
100{
101    /// Generate embeddings for all documents in the builder.
102    ///
103    /// Returns `(document, embeddings)` pairs. A document may produce one or many
104    /// embeddings depending on how its [`Embed`] implementation uses [`TextEmbedder`].
105    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
106        let (result, _usage) = self.build_with_usage().await?;
107        Ok(result)
108    }
109
110    /// Generate embeddings for all documents in the builder and return accumulated token usage.
111    ///
112    /// Returns `(document, embeddings)` pairs and the total token usage across all
113    /// batches. A document may produce one or many embeddings depending on how its
114    /// [`Embed`] implementation uses [`TextEmbedder`].
115    pub async fn build_with_usage(
116        self,
117    ) -> Result<(Vec<(T, OneOrMany<Embedding>)>, Usage), EmbeddingError> {
118        use stream::TryStreamExt;
119
120        // Store the documents and their texts in a HashMap for easy access.
121        let mut docs = HashMap::new();
122        let mut texts = Vec::new();
123
124        // Iterate over all documents in the builder and insert their docs and texts into the lookup stores.
125        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
126            docs.insert(i, doc);
127            texts.push((i, doc_texts));
128        }
129
130        // Compute the embeddings.
131        let (mut embeddings, usage) = stream::iter(texts.into_iter())
132            // Merge the texts of each document into a single list of texts.
133            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
134            // Chunk them into batches. Each batch size is at most the embedding API limit per request.
135            .chunks(M::MAX_DOCUMENTS)
136            // Generate the embeddings for each batch with usage tracking.
137            .map(|text| async {
138                let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
139
140                let response: EmbeddingResponse = self.model.embed_texts_with_usage(docs).await?;
141                Ok::<_, EmbeddingError>((
142                    ids.into_iter().zip(response.embeddings).collect::<Vec<_>>(),
143                    response.usage,
144                ))
145            })
146            // Parallelize the embeddings generation over 10 concurrent requests
147            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
148            // Collect the embeddings into a HashMap and accumulate usage.
149            .try_fold(
150                (
151                    HashMap::<usize, OneOrMany<Embedding>>::new(),
152                    Usage::default(),
153                ),
154                |(mut acc, mut usage_acc), (chunk_embeddings, chunk_usage)| async move {
155                    chunk_embeddings.into_iter().for_each(|(i, embedding)| {
156                        acc.entry(i)
157                            .and_modify(|embeddings| embeddings.push(embedding.clone()))
158                            .or_insert(OneOrMany::one(embedding.clone()));
159                    });
160                    usage_acc += chunk_usage;
161                    Ok((acc, usage_acc))
162                },
163            )
164            .await?;
165
166        // Merge the embeddings with their respective documents
167        let result = docs
168            .into_iter()
169            .map(|(i, doc)| {
170                let embedding = embeddings.remove(&i).ok_or_else(|| {
171                    crate::embeddings::EmbeddingError::ResponseError(
172                        "missing embedding for document after batch merge".to_string(),
173                    )
174                })?;
175                Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
176            })
177            .collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()?;
178
179        Ok((result, usage))
180    }
181}
182
#[cfg(test)]
mod tests {
    use crate::test_utils::{MockEmbeddingModel, MockMultiTextDocument, MockTextDocument};

    use super::EmbeddingsBuilder;

    /// Two documents (`doc0`, `doc1`) with two texts each.
    fn definitions_multiple_text() -> Vec<MockMultiTextDocument> {
        let doc0 = MockMultiTextDocument::new(
            "doc0",
            [
                "A green alien that lives on cold planets.",
                "A fictional digital currency that originated in the animated series Rick and Morty.",
            ],
        );
        let doc1 = MockMultiTextDocument::new(
            "doc1",
            [
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.",
            ],
        );
        vec![doc0, doc1]
    }

    /// Two documents (`doc2`, `doc3`) with a single text each.
    fn definitions_multiple_text_2() -> Vec<MockMultiTextDocument> {
        let doc2 = MockMultiTextDocument::new("doc2", ["Another fake definitions"]);
        let doc3 = MockMultiTextDocument::new("doc3", ["Some fake definition"]);
        vec![doc2, doc3]
    }

    /// Two single-text documents (`doc0`, `doc1`).
    fn definitions_single_text() -> Vec<MockTextDocument> {
        let doc0 = MockTextDocument::new("doc0", "A green alien that lives on cold planets.");
        let doc1 = MockTextDocument::new(
            "doc1",
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
        );
        vec![doc0, doc1]
    }

    #[tokio::test]
    async fn test_build_multiple_text() {
        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap()
            .build()
            .await
            .unwrap();
        embedded.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_single_text() {
        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_single_text())
            .unwrap()
            .build()
            .await
            .unwrap();
        embedded.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap()
            .build()
            .await
            .unwrap();
        embedded.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(embedded.len(), 4);

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        );

        let (doc, vectors) = &embedded[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors.first().document,
            "Another fake definitions".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_string() {
        let sources = definitions_multiple_text();
        let text_lists = sources.iter().map(|source| source.texts.clone());

        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(text_lists)
            .unwrap()
            .build()
            .await
            .unwrap();
        embedded.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

        assert_eq!(embedded.len(), 2);

        let (_, vectors) = &embedded[0];
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (_, vectors) = &embedded[1];
        assert_eq!(vectors.len(), 2);
        assert_eq!(
            vectors.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }
}