1use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10 OneOrMany,
11 embeddings::{
12 Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13 },
14};
15
16pub struct EmbeddingsBuilder<M: EmbeddingModel, T: Embed> {
51 model: M,
52 documents: Vec<(T, Vec<String>)>,
53}
54
55impl<M: EmbeddingModel, T: Embed> EmbeddingsBuilder<M, T> {
56 pub fn new(model: M) -> Self {
58 Self {
59 model,
60 documents: vec![],
61 }
62 }
63
64 pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
66 let mut embedder = TextEmbedder::default();
67 document.embed(&mut embedder)?;
68
69 self.documents.push((document, embedder.texts));
70
71 Ok(self)
72 }
73
74 pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
77 let builder = documents
78 .into_iter()
79 .try_fold(self, |builder, doc| builder.document(doc))?;
80
81 Ok(builder)
82 }
83}
84
85impl<M: EmbeddingModel, T: Embed + Send> EmbeddingsBuilder<M, T> {
86 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
89 use stream::TryStreamExt;
90
91 let mut docs = HashMap::new();
93 let mut texts = Vec::new();
94
95 for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
97 docs.insert(i, doc);
98 texts.push((i, doc_texts));
99 }
100
101 let mut embeddings = stream::iter(texts.into_iter())
103 .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
105 .chunks(M::MAX_DOCUMENTS)
107 .map(|text| async {
109 let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
110
111 let embeddings = self.model.embed_texts(docs).await?;
112 Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
113 })
114 .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
116 .try_fold(
118 HashMap::new(),
119 |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
120 embeddings.into_iter().for_each(|(i, embedding)| {
121 acc.entry(i)
122 .and_modify(|embeddings| embeddings.push(embedding.clone()))
123 .or_insert(OneOrMany::one(embedding.clone()));
124 });
125
126 Ok(acc)
127 },
128 )
129 .await?;
130
131 Ok(docs
133 .into_iter()
134 .map(|(i, doc)| {
135 (
136 doc,
137 embeddings.remove(&i).expect("Document should be present"),
138 )
139 })
140 .collect())
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::EmbeddingsBuilder;
    use crate::{
        Embed,
        embeddings::{Embedding, EmbeddingModel, embed::EmbedError, embed::TextEmbedder},
    };

    /// Stub embedding model: accepts at most 5 texts per request and answers
    /// every text with the same fixed 10-dimensional vector.
    #[derive(Clone)]
    struct Model;

    impl EmbeddingModel for Model {
        const MAX_DOCUMENTS: usize = 5;

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> {
            // Echo each input text back as the embedded document.
            let embeddings = documents
                .into_iter()
                .map(|text| Embedding {
                    document: text,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                })
                .collect();

            Ok(embeddings)
        }
    }

    /// Document that contributes one text per definition.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            self.definitions.iter().cloned().for_each(|text| {
                embedder.embed(text);
            });
            Ok(())
        }
    }

    /// Two documents with two definitions each.
    fn definitions_multiple_text() -> Vec<WordDefinition> {
        let doc0 = WordDefinition {
            id: String::from("doc0"),
            definitions: vec![
                String::from("A green alien that lives on cold planets."),
                String::from("A fictional digital currency that originated in the animated series Rick and Morty."),
            ],
        };
        let doc1 = WordDefinition {
            id: String::from("doc1"),
            definitions: vec![
                String::from("An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."),
                String::from("A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."),
            ],
        };

        vec![doc0, doc1]
    }

    /// Two further documents with a single definition each.
    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        let doc2 = WordDefinition {
            id: String::from("doc2"),
            definitions: vec![String::from("Another fake definitions")],
        };
        let doc3 = WordDefinition {
            id: String::from("doc3"),
            definitions: vec![String::from("Some fake definition")],
        };

        vec![doc2, doc3]
    }

    /// Document that contributes exactly one text.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            let text = self.definition.clone();
            embedder.embed(text);
            Ok(())
        }
    }

    /// Two single-text documents.
    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        let doc0 = WordDefinitionSingle {
            id: String::from("doc0"),
            definition: String::from("A green alien that lives on cold planets."),
        };
        let doc1 = WordDefinitionSingle {
            id: String::from("doc1"),
            definition: String::from("An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."),
        };

        vec![doc0, doc1]
    }

    /// Documents with several texts each receive one embedding per text.
    #[tokio::test]
    async fn test_build_multiple_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }

    /// Documents with a single text each receive exactly one embedding.
    #[tokio::test]
    async fn test_build_single_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_single_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );
    }

    /// Two `documents` calls accumulate into the same builder.
    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 4);

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );

        let (doc, embeddings) = &result[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings.first().document, "Another fake definitions");
    }

    /// Plain `Vec<String>` values can be used as documents.
    #[tokio::test]
    async fn test_build_string() {
        let string_docs = definitions_multiple_text()
            .into_iter()
            .map(|def| def.definitions);

        let builder = EmbeddingsBuilder::new(Model)
            .documents(string_docs)
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by the documents themselves so the assertions do not depend on
        // the output order.
        result.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

        assert_eq!(result.len(), 2);

        let (_, embeddings) = &result[0];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (_, embeddings) = &result[1];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }
}