rig-core 0.37.0

An opinionated library for building LLM-powered applications.
Documentation
//! This module defines the [EmbeddingsBuilder] struct, which accumulates objects to be embedded
//! and generates the embeddings for each object in batches when built.
//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder].

use std::{cmp::max, collections::HashMap};

use futures::{StreamExt, stream};

use crate::{
    OneOrMany,
    embeddings::{
        Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
    },
};

/// Builder for creating embeddings from one or more documents of type `T`.
/// Note: `T` can be any type that implements the [Embed] trait.
///
/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as
/// it groups the texts of all added documents into batched requests to the model provider
/// (at most `M::MAX_DOCUMENTS` texts per request) instead of issuing one request per text.
///
/// # Example
/// ```no_run
/// use rig_core::{
///     client::{EmbeddingsClient, ProviderClient},
///     embeddings::EmbeddingsBuilder,
///     providers::openai,
/// };
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// // Create OpenAI client
/// let openai_client = openai::Client::from_env()?;
///
/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_3_SMALL);
///
/// let embeddings = EmbeddingsBuilder::new(model.clone())
///     .documents(vec![
///         "1. *flurbo* (noun): A green alien that lives on cold planets.".to_string(),
///         "2. *flurbo* (noun): A fictional digital currency.".to_string(),
///         "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
///         "2. *glarb-glarb* (noun): A fictional creature from marshlands.".to_string(),
///         "1. *linlingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
///         "2. *linlingdong* (noun): A rare instrument.".to_string(),
///     ])?
///     .build()
///     .await?;
/// # Ok(())
/// # }
/// ```
#[non_exhaustive]
pub struct EmbeddingsBuilder<M, T>
where
    M: EmbeddingModel,
    T: Embed,
{
    /// Embedding model that `build` sends the accumulated texts to.
    model: M,
    /// Every document added so far, paired with the texts its [Embed] implementation
    /// pushed into the `TextEmbedder` when the document was added.
    documents: Vec<(T, Vec<String>)>,
}

impl<M, T> EmbeddingsBuilder<M, T>
where
    M: EmbeddingModel,
    T: Embed,
{
    /// Start an empty builder that will embed its documents with `model`.
    pub fn new(model: M) -> Self {
        let documents = Vec::new();

        Self { model, documents }
    }

    /// Queue a single document for embedding.
    ///
    /// The document's [Embed] implementation is run immediately to extract the texts to
    /// embed; any [EmbedError] it reports is returned and the builder is consumed.
    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
        let mut collector = TextEmbedder::default();
        document.embed(&mut collector)?;

        let extracted = collector.texts;
        self.documents.push((document, extracted));

        Ok(self)
    }

    /// Queue every document yielded by `documents`, in iteration order.
    ///
    /// Stops at, and returns, the first [EmbedError] produced while extracting texts.
    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
        let mut builder = self;

        for doc in documents {
            builder = builder.document(doc)?;
        }

        Ok(builder)
    }
}

impl<M, T> EmbeddingsBuilder<M, T>
where
    M: EmbeddingModel,
    T: Embed + Send,
{
    /// Generate embeddings for all documents in the builder.
    ///
    /// Returns `(document, embeddings)` pairs. A document may produce one or many
    /// embeddings depending on how its [`Embed`] implementation uses [`TextEmbedder`].
    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
        use stream::TryStreamExt;

        // Store the documents and their texts in a HashMap for easy access.
        let mut docs = HashMap::new();
        let mut texts = Vec::new();

        // Iterate over all documents in the builder and insert their docs and texts into the lookup stores.
        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
            docs.insert(i, doc);
            texts.push((i, doc_texts));
        }

        // Compute the embeddings.
        let mut embeddings = stream::iter(texts.into_iter())
            // Merge the texts of each document into a single list of texts.
            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
            // Chunk them into batches. Each batch size is at most the embedding API limit per request.
            .chunks(M::MAX_DOCUMENTS)
            // Generate the embeddings for each batch.
            .map(|text| async {
                let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();

                let embeddings = self.model.embed_texts(docs).await?;
                Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
            })
            // Parallelize the embeddings generation over 10 concurrent requests
            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
            // Collect the embeddings into a HashMap.
            .try_fold(
                HashMap::new(),
                |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
                    embeddings.into_iter().for_each(|(i, embedding)| {
                        acc.entry(i)
                            .and_modify(|embeddings| embeddings.push(embedding.clone()))
                            .or_insert(OneOrMany::one(embedding.clone()));
                    });

                    Ok(acc)
                },
            )
            .await?;

        // Merge the embeddings with their respective documents
        docs.into_iter()
            .map(|(i, doc)| {
                let embedding = embeddings.remove(&i).ok_or_else(|| {
                    crate::embeddings::EmbeddingError::ResponseError(
                        "missing embedding for document after batch merge".to_string(),
                    )
                })?;
                Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
            })
            .collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()
    }
}

#[cfg(test)]
mod tests {
    use crate::test_utils::{MockEmbeddingModel, MockMultiTextDocument, MockTextDocument};

    use super::EmbeddingsBuilder;

    // Fixture texts shared by the document factories and the assertions below.
    const GREEN_ALIEN: &str = "A green alien that lives on cold planets.";
    const DIGITAL_CURRENCY: &str =
        "A fictional digital currency that originated in the animated series Rick and Morty.";
    const ANCIENT_TOOL: &str =
        "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.";
    const MARSH_CREATURE: &str = "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.";
    const ANOTHER_FAKE: &str = "Another fake definitions";
    const SOME_FAKE: &str = "Some fake definition";

    /// Two documents ("doc0", "doc1"), each carrying two texts.
    fn definitions_multiple_text() -> Vec<MockMultiTextDocument> {
        let doc0 = MockMultiTextDocument::new("doc0", [GREEN_ALIEN, DIGITAL_CURRENCY]);
        let doc1 = MockMultiTextDocument::new("doc1", [ANCIENT_TOOL, MARSH_CREATURE]);

        vec![doc0, doc1]
    }

    /// Two multi-text documents ("doc2", "doc3") that each carry only one text.
    fn definitions_multiple_text_2() -> Vec<MockMultiTextDocument> {
        let doc2 = MockMultiTextDocument::new("doc2", [ANOTHER_FAKE]);
        let doc3 = MockMultiTextDocument::new("doc3", [SOME_FAKE]);

        vec![doc2, doc3]
    }

    /// Two single-text documents ("doc0", "doc1").
    fn definitions_single_text() -> Vec<MockTextDocument> {
        let doc0 = MockTextDocument::new("doc0", GREEN_ALIEN);
        let doc1 = MockTextDocument::new("doc1", ANCIENT_TOOL);

        vec![doc0, doc1]
    }

    /// Documents with several texts each yield one embedding per text.
    #[tokio::test]
    async fn test_build_multiple_text() {
        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by document id so the positional checks below are stable.
        embedded.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.first().document, GREEN_ALIEN.to_string());

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.rest()[0].document, MARSH_CREATURE.to_string());
    }

    /// Documents with a single text yield exactly one embedding.
    #[tokio::test]
    async fn test_build_single_text() {
        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_single_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by document id so the positional checks below are stable.
        embedded.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(embedded.len(), 2);

        let (doc, vectors) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors.first().document, GREEN_ALIEN.to_string());

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors.first().document, ANCIENT_TOOL.to_string());
    }

    /// Two `documents` calls accumulate into the same builder.
    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let two_text_docs = definitions_multiple_text();
        let one_text_docs = definitions_multiple_text_2();

        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(two_text_docs)
            .unwrap()
            .documents(one_text_docs)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by document id so the positional checks below are stable.
        embedded.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(embedded.len(), 4);

        let (doc, vectors) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.first().document, ANCIENT_TOOL.to_string());

        let (doc, vectors) = &embedded[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors.first().document, ANOTHER_FAKE.to_string());
    }

    /// The raw `texts` collections of the mock documents can be embedded directly.
    #[tokio::test]
    async fn test_build_string() {
        let sources = definitions_multiple_text();
        let text_lists = sources.iter().map(|def| def.texts.clone());

        let mut embedded = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(text_lists)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by the documents themselves so the positional checks below are stable.
        embedded.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

        assert_eq!(embedded.len(), 2);

        let (_, vectors) = &embedded[0];
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.first().document, GREEN_ALIEN.to_string());

        let (_, vectors) = &embedded[1];
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.rest()[0].document, MARSH_CREATURE.to_string());
    }
}