Skip to main content

alith_core/
embeddings.rs

1use async_trait::async_trait;
2use futures::stream;
3use futures::stream::StreamExt;
4use futures::stream::TryStreamExt;
5use serde::{Deserialize, Serialize};
6use std::cmp::max;
7use std::collections::HashMap;
8
9/// Struct representing an embedding
10#[derive(Clone, Default, Deserialize, Serialize, Debug)]
11pub struct EmbeddingsData {
12    pub document: String,
13    pub vec: Vec<f64>,
14}
15
16/// Trait for embeddings
17#[async_trait]
18pub trait Embeddings: Clone + Send + Sync {
19    const MAX_DOCUMENTS: usize = 1024;
20
21    /// Generate embeddings for a list of texts
22    async fn embed_texts(&self, input: Vec<String>)
23    -> Result<Vec<EmbeddingsData>, EmbeddingsError>;
24}
25
26// Trait that defines the embedding process for a document
27pub trait Embed {
28    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
29}
30
31impl<T: AsRef<str>> Embed for T {
32    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
33        let text = self.as_ref().to_string();
34        embedder.embed(text);
35        Ok(())
36    }
37}
38
/// Accumulates the texts that a document wants to have embedded.
#[derive(Default)]
pub struct TextEmbedder {
    /// Collected texts, in the order they were added.
    pub texts: Vec<String>,
}

impl TextEmbedder {
    /// Appends an owned copy of `text` to the texts queued for embedding in this
    /// [TextEmbedder].
    pub fn embed<S>(&mut self, text: S)
    where
        S: AsRef<str> + Sync + Send,
    {
        let owned = String::from(text.as_ref());
        self.texts.push(owned);
    }
}
54
/// Error raised while extracting the text to embed from a document.
///
/// Implements [`std::error::Error`], so it can be propagated with `?` into boxed
/// error types or wrapped in `EmbeddingsError::DocumentError`.
#[derive(Debug)]
pub enum EmbedError {
    /// A free-form description of what went wrong.
    Custom(String),
}

impl std::fmt::Display for EmbedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Delegate to `String`'s formatter so width/padding flags are honoured.
            Self::Custom(message) => std::fmt::Display::fmt(message, f),
        }
    }
}

impl std::error::Error for EmbedError {}
60
61#[derive(Debug, thiserror::Error)]
62pub enum EmbeddingsError {
63    /// Json error (e.g.: serialization, deserialization)
64    #[error("JsonError: {0}")]
65    JsonError(#[from] serde_json::Error),
66    /// Error processing the document for embedding
67    #[error("DocumentError: {0}")]
68    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
69    /// Error parsing the completion response
70    #[error("ResponseError: {0}")]
71    ResponseError(String),
72    /// Error returned by the embedding model provider
73    #[error("ProviderError: {0}")]
74    ProviderError(String),
75}
76
77/// The main builder struct for generating embeddings
78pub struct EmbeddingsBuilder<M: Embeddings, T: Embed> {
79    model: M,
80    documents: Vec<(T, Vec<String>)>,
81}
82
83impl<M: Embeddings, T: Embed> EmbeddingsBuilder<M, T> {
84    /// Create a new embedding builder with the given model
85    pub fn new(model: M) -> Self {
86        Self {
87            model,
88            documents: vec![],
89        }
90    }
91
92    /// Add a single document to the builder
93    pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
94        let mut embedder = TextEmbedder::default();
95        document.embed(&mut embedder)?;
96
97        self.documents.push((document, embedder.texts));
98        Ok(self)
99    }
100
101    /// Add multiple documents to the builder
102    pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
103        documents
104            .into_iter()
105            .try_fold(self, |builder, doc| builder.document(doc))
106    }
107}
108
109impl<M: Embeddings, T: Embed + Send> EmbeddingsBuilder<M, T> {
110    /// Generate embeddings for all documents
111    pub async fn build(self) -> Result<Vec<(T, Vec<EmbeddingsData>)>, EmbeddingsError> {
112        // Create lookup stores for documents and their corresponding texts
113        let mut docs = HashMap::new();
114        let mut texts = HashMap::new();
115
116        for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
117            docs.insert(i, doc);
118            texts.insert(i, doc_texts);
119        }
120
121        // Compute embeddings for the texts
122        let mut embeddings = stream::iter(texts.into_iter())
123            .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
124            .chunks(M::MAX_DOCUMENTS)
125            .map(|chunk| async {
126                let (ids, docs): (Vec<_>, Vec<_>) = chunk.into_iter().unzip();
127
128                let embeddings = self.model.embed_texts(docs).await?;
129                Ok::<_, EmbeddingsError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
130            })
131            .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
132            .try_fold(
133                HashMap::new(),
134                |mut acc: HashMap<_, Vec<EmbeddingsData>>, embeddings| async move {
135                    embeddings.into_iter().for_each(|(i, embedding)| {
136                        acc.entry(i)
137                            .and_modify(|embeds| embeds.push(embedding.clone()))
138                            .or_insert(vec![embedding]);
139                    });
140
141                    Ok(acc)
142                },
143            )
144            .await?;
145
146        // Merge the embeddings back with their respective documents
147        Ok(docs
148            .into_iter()
149            .map(|(i, doc)| {
150                (
151                    doc,
152                    embeddings
153                        .remove(&i)
154                        .expect("Document embeddings should be present"),
155                )
156            })
157            .collect())
158    }
159}