1use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10 OneOrMany,
11 embeddings::{
12 Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, embed::TextEmbedder,
13 },
14};
15
/// Builder that collects documents of type `T` and turns them into embeddings
/// using an [`EmbeddingModel`] of type `M`.
///
/// Documents are added with `document` / `documents`; the texts to embed are
/// extracted from each document at that point. Calling `build` consumes the
/// builder and sends the collected texts to the model.
#[non_exhaustive]
pub struct EmbeddingsBuilder<M, T>
where
    M: EmbeddingModel,
    T: Embed,
{
    /// Model used to embed the collected texts.
    model: M,
    /// Every added document paired with the texts extracted from it.
    documents: Vec<(T, Vec<String>)>,
}
59
60impl<M, T> EmbeddingsBuilder<M, T>
61where
62 M: EmbeddingModel,
63 T: Embed,
64{
65 pub fn new(model: M) -> Self {
67 Self {
68 model,
69 documents: vec![],
70 }
71 }
72
73 pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
75 let mut embedder = TextEmbedder::default();
76 document.embed(&mut embedder)?;
77
78 self.documents.push((document, embedder.texts));
79
80 Ok(self)
81 }
82
83 pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
86 let builder = documents
87 .into_iter()
88 .try_fold(self, |builder, doc| builder.document(doc))?;
89
90 Ok(builder)
91 }
92}
93
94impl<M, T> EmbeddingsBuilder<M, T>
95where
96 M: EmbeddingModel,
97 T: Embed + Send,
98{
99 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
102 use stream::TryStreamExt;
103
104 let mut docs = HashMap::new();
106 let mut texts = Vec::new();
107
108 for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
110 docs.insert(i, doc);
111 texts.push((i, doc_texts));
112 }
113
114 let mut embeddings = stream::iter(texts.into_iter())
116 .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
118 .chunks(M::MAX_DOCUMENTS)
120 .map(|text| async {
122 let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
123
124 let embeddings = self.model.embed_texts(docs).await?;
125 Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::<Vec<_>>())
126 })
127 .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
129 .try_fold(
131 HashMap::new(),
132 |mut acc: HashMap<_, OneOrMany<Embedding>>, embeddings| async move {
133 embeddings.into_iter().for_each(|(i, embedding)| {
134 acc.entry(i)
135 .and_modify(|embeddings| embeddings.push(embedding.clone()))
136 .or_insert(OneOrMany::one(embedding.clone()));
137 });
138
139 Ok(acc)
140 },
141 )
142 .await?;
143
144 docs.into_iter()
146 .map(|(i, doc)| {
147 let embedding = embeddings.remove(&i).ok_or_else(|| {
148 crate::embeddings::EmbeddingError::ResponseError(
149 "missing embedding for document after batch merge".to_string(),
150 )
151 })?;
152 Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
153 })
154 .collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()
155 }
156}
157
#[cfg(test)]
mod tests {
    use crate::{
        Embed,
        client::Nothing,
        embeddings::{
            Embedding, EmbeddingModel,
            embed::{EmbedError, TextEmbedder},
        },
    };

    use super::EmbeddingsBuilder;

    /// Embedding model stub: accepts at most five texts per request and
    /// returns the same fixed ten-dimensional vector for every text.
    #[derive(Clone)]
    struct MockEmbeddingModel;

    impl EmbeddingModel for MockEmbeddingModel {
        const MAX_DOCUMENTS: usize = 5;

        type Client = Nothing;

        fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
            Self {}
        }

        fn ndims(&self) -> usize {
            10
        }

        async fn embed_texts(
            &self,
            documents: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<crate::embeddings::Embedding>, crate::embeddings::EmbeddingError> {
            Ok(documents
                .into_iter()
                .map(|doc| Embedding {
                    document: doc,
                    vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
                })
                .collect())
        }
    }

    /// Document that contributes one text per definition.
    #[derive(Clone, Debug)]
    struct WordDefinition {
        id: String,
        definitions: Vec<String>,
    }

    impl Embed for WordDefinition {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            for definition in &self.definitions {
                embedder.embed(definition.clone());
            }
            Ok(())
        }
    }

    /// Two documents ("doc0", "doc1") with two definitions each.
    fn definitions_multiple_text() -> Vec<WordDefinition> {
        vec![
            WordDefinition {
                id: "doc0".to_string(),
                definitions: vec![
                    "A green alien that lives on cold planets.".to_string(),
                    "A fictional digital currency that originated in the animated series Rick and Morty.".to_string(),
                ],
            },
            WordDefinition {
                id: "doc1".to_string(),
                definitions: vec![
                    "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
                    "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string(),
                ],
            },
        ]
    }

    /// Two documents ("doc2", "doc3") with a single definition each.
    fn definitions_multiple_text_2() -> Vec<WordDefinition> {
        vec![
            WordDefinition {
                id: "doc2".to_string(),
                definitions: vec!["Another fake definitions".to_string()],
            },
            WordDefinition {
                id: "doc3".to_string(),
                definitions: vec!["Some fake definition".to_string()],
            },
        ]
    }

    /// Document that contributes exactly one text.
    #[derive(Clone, Debug)]
    struct WordDefinitionSingle {
        id: String,
        definition: String,
    }

    impl Embed for WordDefinitionSingle {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            embedder.embed(self.definition.clone());
            Ok(())
        }
    }

    /// Two single-text documents ("doc0", "doc1").
    fn definitions_single_text() -> Vec<WordDefinitionSingle> {
        vec![
            WordDefinitionSingle {
                id: "doc0".to_string(),
                definition: "A green alien that lives on cold planets.".to_string(),
            },
            WordDefinitionSingle {
                id: "doc1".to_string(),
                definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
            },
        ]
    }

    #[tokio::test]
    async fn test_build_multiple_text() {
        let fake_definitions = definitions_multiple_text();

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.0.id, "doc0");
        assert_eq!(first_definition.1.len(), 2);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets."
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }

    #[tokio::test]
    async fn test_build_single_text() {
        let fake_definitions = definitions_single_text();

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.0.id, "doc0");
        assert_eq!(first_definition.1.len(), 1);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets."
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 1);
        assert_eq!(
            second_definition.1.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );
    }

    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let fake_definitions = definitions_multiple_text();
        let fake_definitions_single = definitions_multiple_text_2();

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .documents(fake_definitions_single)
            .unwrap()
            .build()
            .await
            .unwrap();

        // Sort by id so the assertions do not depend on the output order.
        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.id.cmp(&fake_definition_2.id)
        });

        assert_eq!(result.len(), 4);

        let second_definition = &result[1];
        assert_eq!(second_definition.0.id, "doc1");
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.first().document,
            "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."
        );

        let third_definition = &result[2];
        assert_eq!(third_definition.0.id, "doc2");
        assert_eq!(third_definition.1.len(), 1);
        assert_eq!(
            third_definition.1.first().document,
            "Another fake definitions"
        );
    }

    #[tokio::test]
    async fn test_build_string() {
        let bindings = definitions_multiple_text();
        let fake_definitions = bindings.iter().map(|def| def.definitions.clone());

        let fake_model = MockEmbeddingModel;
        let mut result = EmbeddingsBuilder::new(fake_model)
            .documents(fake_definitions)
            .unwrap()
            .build()
            .await
            .unwrap();

        // The documents are plain `Vec<String>`s here, so sort on them directly.
        result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| {
            fake_definition_1.cmp(fake_definition_2)
        });

        assert_eq!(result.len(), 2);

        let first_definition = &result[0];
        assert_eq!(first_definition.1.len(), 2);
        assert_eq!(
            first_definition.1.first().document,
            "A green alien that lives on cold planets."
        );

        let second_definition = &result[1];
        assert_eq!(second_definition.1.len(), 2);
        assert_eq!(
            second_definition.1.rest()[0].document,
            "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy."
        );
    }

    #[tokio::test]
    async fn test_build_without_documents() {
        let result = EmbeddingsBuilder::<_, WordDefinition>::new(MockEmbeddingModel)
            .build()
            .await
            .unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_build_document_spanning_batches() {
        // Seven texts with `MAX_DOCUMENTS = 5` split this one document over two batches.
        let definitions: Vec<String> = (0..7).map(|n| format!("definition {n}")).collect();
        let document = WordDefinition {
            id: "doc0".to_string(),
            definitions: definitions.clone(),
        };

        let result = EmbeddingsBuilder::new(MockEmbeddingModel)
            .document(document)
            .unwrap()
            .build()
            .await
            .unwrap();

        assert_eq!(result.len(), 1);

        let (built_document, embeddings) = &result[0];
        assert_eq!(built_document.id, "doc0");
        assert_eq!(embeddings.len(), 7);

        // Compare as a sorted list so the check only requires that every text
        // was embedded exactly once, whatever order the batches were merged in.
        let first = embeddings.first();
        let rest = embeddings.rest();
        let mut embedded_texts = vec![first.document.clone()];
        embedded_texts.extend(rest.iter().map(|embedding| embedding.document.clone()));
        embedded_texts.sort();

        assert_eq!(embedded_texts, definitions);
    }
}