1use async_trait::async_trait;
2use futures::stream;
3use futures::stream::StreamExt;
4use futures::stream::TryStreamExt;
5use serde::{Deserialize, Serialize};
6use std::cmp::max;
7use std::collections::HashMap;
8
9#[derive(Clone, Default, Deserialize, Serialize, Debug)]
11pub struct EmbeddingsData {
12 pub document: String,
13 pub vec: Vec<f64>,
14}
15
16#[async_trait]
18pub trait Embeddings: Clone + Send + Sync {
19 const MAX_DOCUMENTS: usize = 1024;
20
21 async fn embed_texts(&self, input: Vec<String>)
23 -> Result<Vec<EmbeddingsData>, EmbeddingsError>;
24}
25
26pub trait Embed {
28 fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
29}
30
31impl<T: AsRef<str>> Embed for T {
32 fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
33 let text = self.as_ref().to_string();
34 embedder.embed(text);
35 Ok(())
36 }
37}
38
/// Accumulator for the texts a document wants embedded.
///
/// It starts out empty (`Default`) and is filled through `TextEmbedder::embed`.
#[derive(Debug, Clone, Default)]
pub struct TextEmbedder {
    /// Texts collected so far, in the order they were pushed.
    pub texts: Vec<String>,
}
44
45impl TextEmbedder {
46 pub fn embed<S>(&mut self, text: S)
48 where
49 S: AsRef<str> + Sync + Send,
50 {
51 self.texts.push(text.as_ref().to_string());
52 }
53}
54
/// Error returned by `Embed::embed` when a value cannot be turned into text.
///
/// Implements [`std::error::Error`], so it can be propagated with `?` into
/// `Box<dyn std::error::Error + Send + Sync>` (the payload type of
/// `EmbeddingsError::DocumentError`).
#[derive(Debug)]
pub enum EmbedError {
    /// A free-form failure described by its message.
    Custom(String),
}

impl std::fmt::Display for EmbedError {
    /// Writes the error message itself, without any prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EmbedError::Custom(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for EmbedError {}
60
61#[derive(Debug, thiserror::Error)]
62pub enum EmbeddingsError {
63 #[error("JsonError: {0}")]
65 JsonError(#[from] serde_json::Error),
66 #[error("DocumentError: {0}")]
68 DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
69 #[error("ResponseError: {0}")]
71 ResponseError(String),
72 #[error("ProviderError: {0}")]
74 ProviderError(String),
75}
76
77pub struct EmbeddingsBuilder<M: Embeddings, T: Embed> {
79 model: M,
80 documents: Vec<(T, Vec<String>)>,
81}
82
83impl<M: Embeddings, T: Embed> EmbeddingsBuilder<M, T> {
84 pub fn new(model: M) -> Self {
86 Self {
87 model,
88 documents: vec![],
89 }
90 }
91
92 pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
94 let mut embedder = TextEmbedder::default();
95 document.embed(&mut embedder)?;
96
97 self.documents.push((document, embedder.texts));
98 Ok(self)
99 }
100
101 pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
103 documents
104 .into_iter()
105 .try_fold(self, |builder, doc| builder.document(doc))
106 }
107}
108
109impl<M: Embeddings, T: Embed + Send> EmbeddingsBuilder<M, T> {
110 pub async fn build(self) -> Result<Vec<(T, Vec<EmbeddingsData>)>, EmbeddingsError> {
112 let mut docs = HashMap::new();
114 let mut texts = HashMap::new();
115
116 for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
117 docs.insert(i, doc);
118 texts.insert(i, doc_texts);
119 }
120
121 let mut embeddings = stream::iter(texts.into_iter())
123 .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
124 .chunks(M::MAX_DOCUMENTS)
125 .map(|chunk| async {
126 let (ids, docs): (Vec<_>, Vec<_>) = chunk.into_iter().unzip();
127
128 let embeddings = self.model.embed_texts(docs).await?;
129 Ok::<_, EmbeddingsError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
130 })
131 .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
132 .try_fold(
133 HashMap::new(),
134 |mut acc: HashMap<_, Vec<EmbeddingsData>>, embeddings| async move {
135 embeddings.into_iter().for_each(|(i, embedding)| {
136 acc.entry(i)
137 .and_modify(|embeds| embeds.push(embedding.clone()))
138 .or_insert(vec![embedding]);
139 });
140
141 Ok(acc)
142 },
143 )
144 .await?;
145
146 Ok(docs
148 .into_iter()
149 .map(|(i, doc)| {
150 (
151 doc,
152 embeddings
153 .remove(&i)
154 .expect("Document embeddings should be present"),
155 )
156 })
157 .collect())
158 }
159}