use std::{cmp::max, collections::HashMap};
use futures::{StreamExt, stream};
use crate::{
OneOrMany,
embeddings::{
Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
},
};
/// Builder that collects documents, together with the texts extracted from each of them, so
/// that their embeddings can later be generated with an [`EmbeddingModel`].
#[non_exhaustive]
pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embed> {
    /// Model whose `embed_texts` is used to turn the collected texts into embeddings.
    model: M,
    /// Queued documents, each paired with the texts its [`Embed`] implementation produced.
    documents: Vec<(T, Vec<String>)>,
}
impl<M, T> EmbeddingsBuilder<M, T>
where
    M: EmbeddingModel,
    T: Embed,
{
    /// Create an empty builder that will embed its documents with `model`.
    pub fn new(model: M) -> Self {
        let documents = Vec::new();
        Self { model, documents }
    }

    /// Queue a single document for embedding.
    ///
    /// The document's [`Embed`] implementation runs immediately, against a fresh
    /// [`TextEmbedder`], and the texts it produces are stored next to the document.
    ///
    /// # Errors
    ///
    /// Returns the [`EmbedError`] raised by the document's `embed` implementation. In that case
    /// the builder is consumed and the document is not queued.
    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
        let mut text_collector = TextEmbedder::default();
        document.embed(&mut text_collector)?;
        let extracted_texts = text_collector.texts;
        self.documents.push((document, extracted_texts));
        Ok(self)
    }

    /// Queue every document yielded by `documents`, in iteration order.
    ///
    /// # Errors
    ///
    /// Stops at the first document whose `embed` implementation fails and returns that
    /// [`EmbedError`]. Remaining documents are not consumed.
    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
        let mut builder = self;
        for doc in documents {
            builder = builder.document(doc)?;
        }
        Ok(builder)
    }
}
impl<M, T> EmbeddingsBuilder<M, T>
where
M: EmbeddingModel,
T: Embed + Send,
{
pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
use stream::TryStreamExt;
let mut docs = HashMap::new();
let mut texts = Vec::new();
for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
docs.insert(i, doc);
texts.push((i, doc_texts));
}
let mut embeddings = stream::iter(texts.into_iter())
.flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
.chunks(M::MAX_DOCUMENTS)
.map(|text| async {
let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
let embeddings = self.model.embed_texts(docs).await?;
Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
})
.buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
.try_fold(
HashMap::new(),
|mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
embeddings.into_iter().for_each(|(i, embedding)| {
acc.entry(i)
.and_modify(|embeddings| embeddings.push(embedding.clone()))
.or_insert(OneOrMany::one(embedding.clone()));
});
Ok(acc)
},
)
.await?;
docs.into_iter()
.map(|(i, doc)| {
let embedding = embeddings.remove(&i).ok_or_else(|| {
crate::embeddings::EmbeddingError::ResponseError(
"missing embedding for document after batch merge".to_string(),
)
})?;
Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
})
.collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()
}
}
#[cfg(test)]
mod tests {
    use super::EmbeddingsBuilder;
    use crate::test_utils::{MockEmbeddingModel, MockMultiTextDocument, MockTextDocument};

    /// Two documents, each contributing two texts.
    fn definitions_multiple_text() -> Vec<MockMultiTextDocument> {
        let doc0 = MockMultiTextDocument::new(
            "doc0",
            [
                "A green alien that lives on cold planets.",
                "A fictional digital currency that originated in the animated series Rick and Morty.",
            ],
        );
        let doc1 = MockMultiTextDocument::new(
            "doc1",
            [
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.",
            ],
        );
        vec![doc0, doc1]
    }

    /// Two multi-text documents that each carry exactly one text.
    fn definitions_multiple_text_2() -> Vec<MockMultiTextDocument> {
        let doc2 = MockMultiTextDocument::new("doc2", ["Another fake definitions"]);
        let doc3 = MockMultiTextDocument::new("doc3", ["Some fake definition"]);
        vec![doc2, doc3]
    }

    /// Two single-text documents.
    fn definitions_single_text() -> Vec<MockTextDocument> {
        let doc0 = MockTextDocument::new("doc0", "A green alien that lives on cold planets.");
        let doc1 = MockTextDocument::new(
            "doc1",
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
        );
        vec![doc0, doc1]
    }

    #[tokio::test]
    async fn test_build_multiple_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();
        result.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_single_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_single_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();
        result.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap();
        let mut result = builder.build().await.unwrap();
        result.sort_by(|(lhs, _), (rhs, _)| lhs.id.cmp(&rhs.id));

        assert_eq!(result.len(), 4);

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        );

        let (doc, embeddings) = &result[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "Another fake definitions".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_string() {
        // The raw text collections themselves serve as the documents here.
        let bindings = definitions_multiple_text();
        let text_documents = bindings.iter().map(|def| def.texts.clone());

        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(text_documents)
            .unwrap();
        let mut result = builder.build().await.unwrap();
        result.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

        assert_eq!(result.len(), 2);

        let (_, embeddings) = &result[0];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (_, embeddings) = &result[1];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }
}