use std::{cmp::max, collections::HashMap};
use futures::{stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use crate::tool::{ToolEmbedding, ToolSet, ToolType};
/// Errors that can be produced while generating embeddings.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
/// HTTP-level failure; converted automatically from `reqwest::Error`.
#[error("HttpError: {0}")]
HttpError(#[from] reqwest::Error),
/// JSON (de)serialization failure; converted automatically from `serde_json::Error`.
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
/// A document could not be prepared for embedding; carries a human-readable message
/// (e.g. `EmbeddingsBuilder::tools` uses it when a tool's context cannot be generated).
#[error("DocumentError: {0}")]
DocumentError(String),
/// Error described by a free-form message.
/// NOTE(review): presumably reported by / attributed to the embedding provider — confirm
/// against the `EmbeddingModel` implementors.
#[error("ProviderError: {0}")]
ProviderError(String),
}
pub trait EmbeddingModel: Clone + Sync + Send {
const MAX_DOCUMENTS: usize;
fn embed_document(
&self,
document: &str,
) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send
where
Self: Sync,
{
async {
Ok(self
.embed_documents(vec![document.to_string()])
.await?
.first()
.cloned()
.expect("One embedding should be present"))
}
}
fn embed_documents(
&self,
documents: Vec<String>,
) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
}
/// A document string together with its embedding vector.
#[derive(Clone, Default, Deserialize, Serialize)]
pub struct Embedding {
/// Document text associated with this embedding (presumably the text that was
/// embedded — confirm against the model implementations). Equality of
/// `Embedding` values is decided on this field alone.
pub document: String,
/// The embedding vector components.
pub vec: Vec<f64>,
}
/// Equality for [`Embedding`] looks only at the `document` text; the numeric
/// vector in `vec` takes no part in the comparison.
impl PartialEq for Embedding {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.document.as_str(), other.document.as_str());
        lhs == rhs
    }
}
// Marker impl: the `PartialEq` impl compares only the `document` strings, and
// `String` equality is a full equivalence relation, so `Eq` can be asserted.
impl Eq for Embedding {}
impl Embedding {
    /// Returns the cosine similarity between this embedding and `other`.
    ///
    /// The result is the dot product of the two vectors divided by the product
    /// of their Euclidean magnitudes. Despite the method name, larger values
    /// mean the vectors point in more similar directions (`1.0` for identical
    /// direction, `0.0` for orthogonal, `-1.0` for opposite).
    ///
    /// The previous implementation divided by the product of the *element
    /// counts* (`Vec::len`) rather than the vector magnitudes, which only
    /// rescaled the raw dot product by a constant.
    ///
    /// If the vectors have different dimensions, the dot product only covers
    /// the common prefix (the iterators are zipped). If either vector has zero
    /// magnitude (including the empty vector) the similarity is undefined and
    /// `0.0` is returned instead of `NaN`.
    pub fn distance(&self, other: &Self) -> f64 {
        /// Euclidean (L2) magnitude of a vector.
        fn magnitude(v: &[f64]) -> f64 {
            v.iter().map(|x| x * x).sum::<f64>().sqrt()
        }

        let dot_product: f64 = self
            .vec
            .iter()
            .zip(other.vec.iter())
            .map(|(x, y)| x * y)
            .sum();

        let product_of_lengths = magnitude(&self.vec) * magnitude(&other.vec);

        // Avoid 0/0 = NaN for zero-magnitude or empty vectors.
        if product_of_lengths == 0.0 {
            return 0.0;
        }

        dot_product / product_of_lengths
    }
}
/// A document together with all embeddings generated for it.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct DocumentEmbeddings {
/// Document identifier; (de)serialized under the key `_id`.
#[serde(rename = "_id")]
pub id: String,
/// The document payload as arbitrary JSON.
pub document: serde_json::Value,
/// Embeddings belonging to this document (`EmbeddingsBuilder::build` stores
/// one per text that was embedded for the document).
pub embeddings: Vec<Embedding>,
}
// Module-private shorthand for the list of per-document results returned by
// `EmbeddingsBuilder::build`.
type Embeddings = Vec<DocumentEmbeddings>;
/// Builder that collects documents and then embeds them with the model `M`
/// when `build` is called.
pub struct EmbeddingsBuilder<M: EmbeddingModel> {
// Model used to generate the embeddings.
model: M,
// Queued documents as `(id, JSON payload, texts to embed for that document)`.
documents: Vec<(String, serde_json::Value, Vec<String>)>,
}
impl<M: EmbeddingModel> EmbeddingsBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
documents: vec![],
}
}
pub fn simple_document(mut self, id: &str, document: &str) -> Self {
self.documents.push((
id.to_string(),
serde_json::Value::String(document.to_string()),
vec![document.to_string()],
));
self
}
pub fn simple_documents(mut self, documents: Vec<(String, String)>) -> Self {
self.documents
.extend(documents.into_iter().map(|(id, document)| {
(
id,
serde_json::Value::String(document.clone()),
vec![document],
)
}));
self
}
pub fn tool(mut self, tool: impl ToolEmbedding + 'static) -> Result<Self, EmbeddingError> {
self.documents.push((
tool.name(),
serde_json::to_value(tool.context())?,
tool.embedding_docs(),
));
Ok(self)
}
pub fn tools(mut self, toolset: &ToolSet) -> Result<Self, EmbeddingError> {
for (name, tool) in toolset.tools.iter() {
if let ToolType::Embedding(tool) = tool {
self.documents.push((
name.clone(),
tool.context().map_err(|e| {
EmbeddingError::DocumentError(format!(
"Failed to generate context for tool {}: {}",
name, e
))
})?,
tool.embedding_docs(),
));
}
}
Ok(self)
}
pub fn document<T: Serialize>(
mut self,
id: &str,
document: T,
embed_documents: Vec<String>,
) -> Self {
self.documents.push((
id.to_string(),
serde_json::to_value(document).expect("Document should serialize"),
embed_documents,
));
self
}
pub fn documents<T: Serialize>(mut self, documents: Vec<(String, T, Vec<String>)>) -> Self {
self.documents.extend(
documents
.into_iter()
.map(|(id, document, embed_documents)| {
(
id,
serde_json::to_value(document).expect("Document should serialize"),
embed_documents,
)
}),
);
self
}
pub fn json_document(
mut self,
id: &str,
document: serde_json::Value,
embed_documents: Vec<String>,
) -> Self {
self.documents
.push((id.to_string(), document, embed_documents));
self
}
pub fn json_documents(
mut self,
documents: Vec<(String, serde_json::Value, Vec<String>)>,
) -> Self {
self.documents.extend(documents);
self
}
pub async fn build(self) -> Result<Embeddings, EmbeddingError> {
let documents_map = self
.documents
.into_iter()
.map(|(id, document, docs)| (id, (document, docs)))
.collect::<HashMap<_, _>>();
let embeddings = stream::iter(documents_map.iter())
.flat_map(|(id, (_, docs))| {
stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone())))
})
.chunks(M::MAX_DOCUMENTS)
.map(|docs| async {
let (ids, docs): (Vec<_>, Vec<_>) = docs.into_iter().unzip();
Ok::<_, EmbeddingError>(
ids.into_iter()
.zip(self.model.embed_documents(docs).await?.into_iter())
.collect::<Vec<_>>(),
)
})
.boxed()
.buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
.try_fold(vec![], |mut acc, mut embeddings| async move {
Ok({
acc.append(&mut embeddings);
acc
})
})
.await?;
let mut document_embeddings: HashMap<String, DocumentEmbeddings> = HashMap::new();
embeddings.into_iter().for_each(|(id, embedding)| {
let (document, _) = documents_map.get(&id).expect("Document not found");
let document_embedding =
document_embeddings
.entry(id.clone())
.or_insert_with(|| DocumentEmbeddings {
id: id.clone(),
document: document.clone(),
embeddings: vec![],
});
document_embedding.embeddings.push(embedding);
});
Ok(document_embeddings.into_values().collect())
}
}