1use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10 OneOrMany,
11 embeddings::{
12 Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13 },
14};
15
16#[non_exhaustive]
51pub struct EmbeddingsBuilder<M, T>
52where
53 M: EmbeddingModel,
54 T: Embed,
55{
56 model: M,
57 documents: Vec<(T, Vec<String>)>,
58}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62 M: EmbeddingModel,
63 T: Embed,
64{
65 pub fn new(model: M) -> Self {
67 Self {
68 model,
69 documents: vec![],
70 }
71 }
72
73 pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75 let mut embedder = TextEmbedder::default();
76 document.embed(&mut embedder)?;
77
78 self.documents.push((document, embedder.texts));
79
80 Ok(self)
81 }
82
83 pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86 let builder = documents
87 .into_iter()
88 .try_fold(self, |builder, doc| builder.document(doc))?;
89
90 Ok(builder)
91 }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96 M: EmbeddingModel,
97 T: Embed + Send,
98{
99 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
102 use stream::TryStreamExt;
103
104 let mut docs = HashMap::new();
106 let mut texts = Vec::new();
107
108 for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
110 docs.insert(i, doc);
111 texts.push((i, doc_texts));
112 }
113
114 let mut embeddings = stream::iter(texts.into_iter())
116 .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
118 .chunks(M::MAX_DOCUMENTS)
120 .map(|text| async {
122 let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
123
124 let embeddings = self.model.embed_texts(docs).await?;
125 Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
126 })
127 .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
129 .try_fold(
131 HashMap::new(),
132 |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
133 embeddings.into_iter().for_each(|(i, embedding)| {
134 acc.entry(i)
135 .and_modify(|embeddings| embeddings.push(embedding.clone()))
136 .or_insert(OneOrMany::one(embedding.clone()));
137 });
138
139 Ok(acc)
140 },
141 )
142 .await?;
143
144 Ok(docs
146 .into_iter()
147 .map(|(i, doc)| {
148 (
149 doc,
150 embeddings.remove(&i).expect("Document should be present"),
151 )
152 })
153 .collect())
154 }
155}
156
#[cfg(test)]
mod tests {
    use crate::{
        Embed,
        client::Nothing,
        embeddings::{
            Embedding, EmbeddingModel,
            embed::{EmbedError, TextEmbedder},
        },
    };

    use super::EmbeddingsBuilder;

    /// Fake model that echoes each text back with a fixed 10-dimensional vector.
    #[derive(Clone)]
    struct Model;

    impl EmbeddingModel for Model {
        // Small batch size so tests can exercise texts that span several batches.
        const MAX_DOCUMENTS: usize = 5;

        type Client = Nothing;

        fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
            Self
        }

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> {
            Ok(documents
                .into_iter()
                .map(|doc| Embedding {
                    // `doc` is already an owned `String`; move it instead of copying.
                    document: doc,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                })
                .collect())
        }
    }

    /// Test document that contributes one text per definition.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            for definition in &self.definitions {
                embedder.embed(definition.clone());
            }
            Ok(())
        }
    }

    fn definitions_multiple_text() -> Vec<WordDefinition> {
        vec![
            WordDefinition {
                id: "doc0".to_string(),
                definitions: vec![
                    "A green alien that lives on cold planets.".to_string(),
                    "A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
                ]
            },
            WordDefinition {
                id: "doc1".to_string(),
                definitions: vec![
                    "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
                    "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
                ]
            }
        ]
    }

    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        vec![
            WordDefinition {
                id: "doc2".to_string(),
                definitions: vec!["Another fake definitions".to_string()],
            },
            WordDefinition {
                id: "doc3".to_string(),
                definitions: vec!["Some fake definition".to_string()],
            },
        ]
    }

    /// Test document that contributes exactly one text.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            embedder.embed(self.definition.clone());
            Ok(())
        }
    }

    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        vec![
            WordDefinitionSingle {
                id: "doc0".to_string(),
                definition: "A green alien that lives on cold planets.".to_string(),
            },
            WordDefinitionSingle {
                id: "doc1".to_string(),
                definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
            }
        ]
    }

    #[tokio::test]
    async fn test_build_multiple_text() {
        let fake_definitions = definitions_multiple_text();

        let fake_model = Model;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.0.id, "doc0");
        assert_eq!(first_definition.1.len(), 2);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_single_text() {
        let fake_definitions = definitions_single_text();

        let fake_model = Model;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.0.id, "doc0");
        assert_eq!(first_definition.1.len(), 1);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 1);
        assert_eq!(
            second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let fake_definitions = definitions_multiple_text();
        let fake_definitions_single = definitions_multiple_text_2();

        let fake_model = Model;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .documents(fake_definitions_single)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 4);

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string()
        );

        let third_definition = &result[2];
        assert_eq!(third_definition.0.id, "doc2");
        assert_eq!(third_definition.1.len(), 1);
        assert_eq!(
            third_definition.1.first().document,
            "Another fake definitions".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_string() {
        let bindings = definitions_multiple_text();
        let fake_definitions = bindings.iter().map(|def| def.definitions.clone());

        let fake_model = Model;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.cmp(fake_definition_2)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.1.len(), 2);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets.".to_string()
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
        )
    }

    #[tokio::test]
    async fn test_build_texts_spanning_batches() {
        // 3 documents x 2 texts = 6 texts, more than `Model::MAX_DOCUMENTS` (5), so the last
        // document's texts are split across two batches.
        let definitions: Vec<WordDefinition> = (0..3)
            .map(|i| WordDefinition {
                id: format!("doc{i}"),
                definitions: vec![format!("doc{i} text0"), format!("doc{i} text1")],
            })
            .collect();

        let mut result = EmbeddingsBuilder::new(Model)
            .documents(definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        result.sort_by(|(a, _), (b, _)| a.id.cmp(&b.id));

        assert_eq!(result.len(), 3);
        for (definition, embeddings) in &result {
            assert_eq!(embeddings.len(), 2);

            // Compare as sorted sets: every text of the document must have been embedded.
            let mut texts = vec![embeddings.first().document.clone()];
            texts.extend(embeddings.rest().iter().map(|e| e.document.clone()));
            texts.sort();
            assert_eq!(texts, definition.definitions);
        }
    }
}