1use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10 OneOrMany,
11 embeddings::{
12 Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13 },
14};
15
16#[non_exhaustive]
51pub struct EmbeddingsBuilder<M, T>
52where
53 M: EmbeddingModel,
54 T: Embed,
55{
56 model: M,
57 documents: Vec<(T, Vec<String>)>,
58}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62 M: EmbeddingModel,
63 T: Embed,
64{
65 pub fn new(model: M) -> Self {
67 Self {
68 model,
69 documents: vec![],
70 }
71 }
72
73 pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75 let mut embedder = TextEmbedder::default();
76 document.embed(&mut embedder)?;
77
78 self.documents.push((document, embedder.texts));
79
80 Ok(self)
81 }
82
83 pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86 let builder = documents
87 .into_iter()
88 .try_fold(self, |builder, doc| builder.document(doc))?;
89
90 Ok(builder)
91 }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96 M: EmbeddingModel,
97 T: Embed + Send,
98{
99 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
102 use stream::TryStreamExt;
103
104 let mut docs = HashMap::new();
106 let mut texts = Vec::new();
107
108 for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
110 docs.insert(i, doc);
111 texts.push((i, doc_texts));
112 }
113
114 let mut embeddings = stream::iter(texts.into_iter())
116 .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
118 .chunks(M::MAX_DOCUMENTS)
120 .map(|text| async {
122 let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
123
124 let embeddings = self.model.embed_texts(docs).await?;
125 Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
126 })
127 .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
129 .try_fold(
131 HashMap::new(),
132 |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
133 embeddings.into_iter().for_each(|(i, embedding)| {
134 acc.entry(i)
135 .and_modify(|embeddings| embeddings.push(embedding.clone()))
136 .or_insert(OneOrMany::one(embedding.clone()));
137 });
138
139 Ok(acc)
140 },
141 )
142 .await?;
143
144 Ok(docs
146 .into_iter()
147 .map(|(i, doc)| {
148 (
149 doc,
150 embeddings.remove(&i).expect("Document should be present"),
151 )
152 })
153 .collect())
154 }
155}
156
#[cfg(test)]
mod tests {
    use crate::{
        Embed,
        embeddings::{Embedding, EmbeddingModel, embed::EmbedError, embed::TextEmbedder},
    };

    use super::EmbeddingsBuilder;

    /// Stand-in model that answers every text with the same fixed vector.
    #[derive(Clone)]
    struct Model;

    impl EmbeddingModel for Model {
        const MAX_DOCUMENTS: usize = 5;

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<Embedding>, crate::embeddings::EmbeddingError> {
            let mut embedded = Vec::new();
            for text in documents {
                embedded.push(Embedding {
                    document: text,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                });
            }
            Ok(embedded)
        }
    }

    /// Test document carrying several texts to embed.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            self.definitions
                .iter()
                .cloned()
                .for_each(|text| embedder.embed(text));
            Ok(())
        }
    }

    /// Two documents (`doc0`, `doc1`) with two definitions each.
    fn definitions_multiple_text() -> Vec<WordDefinition> {
        let doc0 = WordDefinition {
            id: String::from("doc0"),
            definitions: vec![
                String::from("A green alien that lives on cold planets."),
                String::from("A fictional digital currency that originated in the animated series Rick and Morty."),
            ],
        };
        let doc1 = WordDefinition {
            id: String::from("doc1"),
            definitions: vec![
                String::from("An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."),
                String::from("A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."),
            ],
        };

        vec![doc0, doc1]
    }

    /// Two documents (`doc2`, `doc3`) with a single definition each.
    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        let doc2 = WordDefinition {
            id: String::from("doc2"),
            definitions: vec![String::from("Another fake definitions")],
        };
        let doc3 = WordDefinition {
            id: String::from("doc3"),
            definitions: vec![String::from("Some fake definition")],
        };

        vec![doc2, doc3]
    }

    /// Test document carrying exactly one text to embed.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            let text = self.definition.clone();
            embedder.embed(text);
            Ok(())
        }
    }

    /// Two single-text documents (`doc0`, `doc1`).
    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        let doc0 = WordDefinitionSingle {
            id: String::from("doc0"),
            definition: String::from("A green alien that lives on cold planets."),
        };
        let doc1 = WordDefinitionSingle {
            id: String::from("doc1"),
            definition: String::from("An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."),
        };

        vec![doc0, doc1]
    }

    #[tokio::test]
    async fn test_build_multiple_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }

    #[tokio::test]
    async fn test_build_single_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_single_text())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 2);

        let (doc, embeddings) = &result[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );
    }

    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let builder = EmbeddingsBuilder::new(Model)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|lhs, rhs| lhs.0.id.cmp(&rhs.0.id));

        assert_eq!(result.len(), 4);

        let (doc, embeddings) = &result[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );

        let (doc, embeddings) = &result[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings.first().document, "Another fake definitions");
    }

    #[tokio::test]
    async fn test_build_string() {
        let bindings = definitions_multiple_text();
        // Feed the raw `Vec<String>` of each definition as the document itself.
        let text_lists = bindings.iter().map(|def| def.definitions.clone());

        let builder = EmbeddingsBuilder::new(Model)
            .documents(text_lists)
            .unwrap();
        let mut result = builder.build().await.unwrap();

        // Sort by the documents themselves so the assertions are order-independent.
        result.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

        assert_eq!(result.len(), 2);

        let (_, embeddings) = &result[0];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.first().document,
            "A green alien that lives on cold planets."
        );

        let (_, embeddings) = &result[1];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(
            embeddings.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }
}