1use std::{cmp::max, collections::HashMap};
6
7use futures::{StreamExt, stream};
8
9use crate::{
10 OneOrMany,
11 completion::Usage,
12 embeddings::{
13 Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, EmbeddingResponse,
14 embed::TextEmbedder,
15 },
16};
17
18#[non_exhaustive]
53pub struct EmbeddingsBuilder<M, T>
54where
55 M: EmbeddingModel,
56 T: Embed,
57{
58 model: M,
59 documents: Vec<(T, Vec<String>)>,
60}
61
62impl<M, T> EmbeddingsBuilder<M, T>
63where
64 M: EmbeddingModel,
65 T: Embed,
66{
67 pub fn new(model: M) -> Self {
69 Self {
70 model,
71 documents: vec![],
72 }
73 }
74
75 pub fn document(mut self, document: T) -> Result<Self, EmbedError> {
77 let mut embedder = TextEmbedder::default();
78 document.embed(&mut embedder)?;
79
80 self.documents.push((document, embedder.texts));
81
82 Ok(self)
83 }
84
85 pub fn documents(self, documents: impl IntoIterator<Item = T>) -> Result<Self, EmbedError> {
88 let builder = documents
89 .into_iter()
90 .try_fold(self, |builder, doc| builder.document(doc))?;
91
92 Ok(builder)
93 }
94}
95
96impl<M, T> EmbeddingsBuilder<M, T>
97where
98 M: EmbeddingModel,
99 T: Embed + Send,
100{
101 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
106 let (result, _usage) = self.build_with_usage().await?;
107 Ok(result)
108 }
109
110 pub async fn build_with_usage(
116 self,
117 ) -> Result<(Vec<(T, OneOrMany<Embedding>)>, Usage), EmbeddingError> {
118 use stream::TryStreamExt;
119
120 let mut docs = HashMap::new();
122 let mut texts = Vec::new();
123
124 for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() {
126 docs.insert(i, doc);
127 texts.push((i, doc_texts));
128 }
129
130 let (mut embeddings, usage) = stream::iter(texts.into_iter())
132 .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text))))
134 .chunks(M::MAX_DOCUMENTS)
136 .map(|text| async {
138 let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip();
139
140 let response: EmbeddingResponse = self.model.embed_texts_with_usage(docs).await?;
141 Ok::<_, EmbeddingError>((
142 ids.into_iter().zip(response.embeddings).collect::<Vec<_>>(),
143 response.usage,
144 ))
145 })
146 .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS))
148 .try_fold(
150 (
151 HashMap::<usize, OneOrMany<Embedding>>::new(),
152 Usage::default(),
153 ),
154 |(mut acc, mut usage_acc), (chunk_embeddings, chunk_usage)| async move {
155 chunk_embeddings.into_iter().for_each(|(i, embedding)| {
156 acc.entry(i)
157 .and_modify(|embeddings| embeddings.push(embedding.clone()))
158 .or_insert(OneOrMany::one(embedding.clone()));
159 });
160 usage_acc += chunk_usage;
161 Ok((acc, usage_acc))
162 },
163 )
164 .await?;
165
166 let result = docs
168 .into_iter()
169 .map(|(i, doc)| {
170 let embedding = embeddings.remove(&i).ok_or_else(|| {
171 crate::embeddings::EmbeddingError::ResponseError(
172 "missing embedding for document after batch merge".to_string(),
173 )
174 })?;
175 Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding))
176 })
177 .collect::<Result<Vec<_>, crate::embeddings::EmbeddingError>>()?;
178
179 Ok((result, usage))
180 }
181}
182
#[cfg(test)]
mod tests {
    use crate::test_utils::{MockEmbeddingModel, MockMultiTextDocument, MockTextDocument};

    use super::EmbeddingsBuilder;

    // Fixture texts shared between the document factories and the assertions.
    const GREEN_ALIEN: &str = "A green alien that lives on cold planets.";
    const DIGITAL_CURRENCY: &str =
        "A fictional digital currency that originated in the animated series Rick and Morty.";
    const ANCIENT_TOOL: &str =
        "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.";
    const SWAMP_CREATURE: &str = "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.";
    const ANOTHER_FAKE: &str = "Another fake definitions";
    const SOME_FAKE: &str = "Some fake definition";

    /// Two documents ("doc0", "doc1") carrying two texts each.
    fn definitions_multiple_text() -> Vec<MockMultiTextDocument> {
        let doc0 = MockMultiTextDocument::new("doc0", [GREEN_ALIEN, DIGITAL_CURRENCY]);
        let doc1 = MockMultiTextDocument::new("doc1", [ANCIENT_TOOL, SWAMP_CREATURE]);
        vec![doc0, doc1]
    }

    /// Two multi-text documents ("doc2", "doc3") that carry a single text each.
    fn definitions_multiple_text_2() -> Vec<MockMultiTextDocument> {
        let doc2 = MockMultiTextDocument::new("doc2", [ANOTHER_FAKE]);
        let doc3 = MockMultiTextDocument::new("doc3", [SOME_FAKE]);
        vec![doc2, doc3]
    }

    /// Two single-text documents ("doc0", "doc1").
    fn definitions_single_text() -> Vec<MockTextDocument> {
        let doc0 = MockTextDocument::new("doc0", GREEN_ALIEN);
        let doc1 = MockTextDocument::new("doc1", ANCIENT_TOOL);
        vec![doc0, doc1]
    }

    /// Documents with several texts each get one embedding per text.
    #[tokio::test]
    async fn test_build_multiple_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap();
        let mut embedded = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        embedded.sort_by(|(left, _), (right, _)| left.id.cmp(&right.id));

        assert_eq!(embedded.len(), 2);

        let (doc, embeddings) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings.first().document, GREEN_ALIEN.to_string());

        let (doc, embeddings) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings.rest()[0].document, SWAMP_CREATURE.to_string());
    }

    /// Documents with a single text get exactly one embedding.
    #[tokio::test]
    async fn test_build_single_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_single_text())
            .unwrap();
        let mut embedded = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        embedded.sort_by(|(left, _), (right, _)| left.id.cmp(&right.id));

        assert_eq!(embedded.len(), 2);

        let (doc, embeddings) = &embedded[0];
        assert_eq!(doc.id, "doc0");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings.first().document, GREEN_ALIEN.to_string());

        let (doc, embeddings) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings.first().document, ANCIENT_TOOL.to_string());
    }

    /// Two `documents` calls accumulate into the same builder.
    #[tokio::test]
    async fn test_build_multiple_and_single_text() {
        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(definitions_multiple_text())
            .unwrap()
            .documents(definitions_multiple_text_2())
            .unwrap();
        let mut embedded = builder.build().await.unwrap();

        // Sort by id so the assertions do not depend on the output order.
        embedded.sort_by(|(left, _), (right, _)| left.id.cmp(&right.id));

        assert_eq!(embedded.len(), 4);

        let (doc, embeddings) = &embedded[1];
        assert_eq!(doc.id, "doc1");
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings.first().document, ANCIENT_TOOL.to_string());

        let (doc, embeddings) = &embedded[2];
        assert_eq!(doc.id, "doc2");
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings.first().document, ANOTHER_FAKE.to_string());
    }

    /// The builder also accepts the bare `texts` values of the mock documents.
    #[tokio::test]
    async fn test_build_string() {
        let source_docs = definitions_multiple_text();
        let text_groups = source_docs.iter().map(|doc| doc.texts.clone());

        let builder = EmbeddingsBuilder::new(MockEmbeddingModel)
            .documents(text_groups)
            .unwrap();
        let mut embedded = builder.build().await.unwrap();

        // Sort by the documents themselves so the assertions do not depend on
        // the output order.
        embedded.sort_by(|(left, _), (right, _)| left.cmp(right));

        assert_eq!(embedded.len(), 2);

        let (_, embeddings) = &embedded[0];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings.first().document, GREEN_ALIEN.to_string());

        let (_, embeddings) = &embedded[1];
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings.rest()[0].document, SWAMP_CREATURE.to_string());
    }
}