1use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10 OneOrMany,
11 embeddings::{
12 Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13 },
14};
15
16#[non_exhaustive]
51pub struct EmbeddingsBuilder<M, T>
52where
53 M: EmbeddingModel,
54 T: Embed,
55{
56 model: M,
57 documents: Vec<(T, Vec<String>)>,
58}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62 M: EmbeddingModel,
63 T: Embed,
64{
65 pub fn new(model: M) -> Self {
67 Self {
68 model,
69 documents: vec![],
70 }
71 }
72
73 pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75 let mut embedder = TextEmbedder::default();
76 document.embed(&mut embedder)?;
77
78 self.documents.push((document, embedder.texts));
79
80 Ok(self)
81 }
82
83 pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86 let builder = documents
87 .into_iter()
88 .try_fold(self, |builder, doc| builder.document(doc))?;
89
90 Ok(builder)
91 }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96 M: EmbeddingModel,
97 T: Embed + Send,
98{
99 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
104 use stream::TryStreamExt;
105
106 let mut docs = HashMap::new();
108 let mut texts = Vec::new();
109
110 for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
112 docs.insert(i, doc);
113 texts.push((i, doc_texts));
114 }
115
116 let mut embeddings = stream::iter(texts.into_iter())
118 .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
120 .chunks(M::MAX_DOCUMENTS)
122 .map(|text| async {
124 let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
125
126 let embeddings = self.model.embed_texts(docs).await?;
127 Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
128 })
129 .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
131 .try_fold(
133 HashMap::new(),
134 |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
135 embeddings.into_iter().for_each(|(i, embedding)| {
136 acc.entry(i)
137 .and_modify(|embeddings| embeddings.push(embedding.clone()))
138 .or_insert(OneOrMany::one(embedding.clone()));
139 });
140
141 Ok(acc)
142 },
143 )
144 .await?;
145
146 docs.into_iter()
148 .map(|(i, doc)| {
149 let embedding = embeddings.remove(&i).ok_or_else(|| {
150 crate::embeddings::EmbeddingError::ResponseError(
151 "missing embedding for document after batch merge".to_string(),
152 )
153 })?;
154 Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
155 })
156 .collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::EmbeddingsBuilder;
    use crate::test_utils::{MockEmbeddingModel, MockMultiTextDocument, MockTextDocument};

    /// Fixture: two documents (`doc0`, `doc1`) carrying two texts each.
    fn definitions_multiple_text() -> Vec<MockMultiTextDocument> {
        let doc0 = MockMultiTextDocument::new(
            "doc0",
            [
                "A green alien that lives on cold planets.",
                "A fictional digital currency that originated in the animated series Rick and Morty.",
            ],
        );
        let doc1 = MockMultiTextDocument::new(
            "doc1",
            [
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.",
            ],
        );

        vec![doc0, doc1]
    }

    /// Fixture: two multi-text documents (`doc2`, `doc3`) that each carry a
    /// single text.
    fn definitions_multiple_text_2() -> Vec<MockMultiTextDocument> {
        let doc2 = MockMultiTextDocument::new("doc2", ["Another fake definitions"]);
        let doc3 = MockMultiTextDocument::new("doc3", ["Some fake definition"]);

        vec![doc2, doc3]
    }

    /// Fixture: two single-text documents (`doc0`, `doc1`).
    fn definitions_single_text() -> Vec<MockTextDocument> {
        let doc0 = MockTextDocument::new("doc0", "A green alien that lives on cold planets.");
        let doc1 = MockTextDocument::new(
            "doc1",
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
        );

        vec![doc0, doc1]
    }

    /// Documents with several texts each get one embedding per text.
    #[tokio::test]
    async fn test_build_multiple_text() {
        let mut result = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the order in which
        // `build` returns the documents.
        result.sort_by(|(a, _), (b, _)| a.id.cmp(&b.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }

    /// Documents with a single text get exactly one embedding.
    #[tokio::test]
    async fn test_build_single_text() {
        let mut result = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_single_text())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the order in which
        // `build` returns the documents.
        result.sort_by(|(a, _), (b, _)| a.id.cmp(&b.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        )
    }

    /// Two `documents` calls on the same builder accumulate their documents.
    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let mut result = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the order in which
        // `build` returns the documents.
        result.sort_by(|(a, _), (b, _)| a.id.cmp(&b.id));

        assert_eq!(result.len(), 4);

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        );

        let (doc, embeddings) = &result[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "Another fake definitions".to_string()
        )
    }

    /// The builder also accepts the fixtures' bare `texts` values as
    /// documents.
    #[tokio::test]
    async fn test_build_string() {
        let sources = definitions_multiple_text();
        let text_documents = sources.iter().map(|source| source.texts.clone());

        let mut result = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(text_documents)
            .unwrap()
            .build()
            .await
            .unwrap();

        // These documents have no id, so order them by their own contents.
        result.sort_by(|(a, _), (b, _)| a.cmp(b));

        assert_eq!(result.len(), 2);

        let (_, embeddings) = &result[0];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let (_, embeddings) = &result[1];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }
}