1use std::collections::BTreeMap;
2use std::time::Instant;
3
4use deserr::Deserr;
5use rand::Rng;
6use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
7use rayon::slice::ParallelSlice as _;
8use serde::{Deserialize, Serialize};
9
10use super::error::EmbedErrorKind;
11use super::json_template::ValueTemplate;
12use super::{
13 DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
14};
15use crate::error::FaultSource;
16use crate::ThreadPoolNoAbort;
17
18pub struct Retry {
20 pub error: EmbedError,
21 strategy: RetryStrategy,
22}
23
/// Tag attached to an embedder that is forwarded to the error constructors
/// used for `400` and `401` responses.
///
/// NOTE(review): the variants presumably name the integration that produced
/// the REST configuration (OpenAI, Ollama, or a user-written one) — confirm
/// against the callers of `Embedder::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigurationSource {
    OpenAi,
    Ollama,
    User,
}
30
/// What to do after a failed embedding request.
pub enum RetryStrategy {
    /// Do not retry; the error is returned to the caller.
    GiveUp,
    /// Retry after a delay that grows exponentially with the attempt number.
    Retry,
    /// Retry after a fixed 1 ms delay; reported by `Retry::must_tokenize`.
    RetryTokenized,
    /// Retry after the exponential delay plus an extra 100 ms.
    RetryAfterRateLimit,
}
37
38impl Retry {
39 pub fn give_up(error: EmbedError) -> Self {
40 Self { error, strategy: RetryStrategy::GiveUp }
41 }
42
43 pub fn retry_later(error: EmbedError) -> Self {
44 Self { error, strategy: RetryStrategy::Retry }
45 }
46
47 pub fn retry_tokenized(error: EmbedError) -> Self {
48 Self { error, strategy: RetryStrategy::RetryTokenized }
49 }
50
51 pub fn rate_limited(error: EmbedError) -> Self {
52 Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
53 }
54
55 pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
56 match self.strategy {
57 RetryStrategy::GiveUp => Err(self.error),
58 RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
59 RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
60 RetryStrategy::RetryAfterRateLimit => {
61 Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
62 }
63 }
64 }
65
66 pub fn must_tokenize(&self) -> bool {
67 matches!(self.strategy, RetryStrategy::RetryTokenized)
68 }
69
70 pub fn into_error(self) -> EmbedError {
71 self.error
72 }
73}
74
75#[derive(Debug)]
76pub struct Embedder {
77 data: EmbedderData,
78 dimensions: usize,
79 distribution: Option<DistributionShift>,
80 cache: EmbeddingCache,
81}
82
83#[derive(Debug)]
85struct EmbedderData {
86 client: ureq::Agent,
87 bearer: Option<String>,
88 headers: BTreeMap<String, String>,
89 url: String,
90 request: Request,
91 response: Response,
92 configuration_source: ConfigurationSource,
93}
94
95#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
96pub struct EmbedderOptions {
97 pub api_key: Option<String>,
98 pub distribution: Option<DistributionShift>,
99 pub dimensions: Option<usize>,
100 pub url: String,
101 pub request: serde_json::Value,
102 pub response: serde_json::Value,
103 pub headers: BTreeMap<String, String>,
104}
105
106impl std::hash::Hash for EmbedderOptions {
107 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
108 self.api_key.hash(state);
109 self.distribution.hash(state);
110 self.dimensions.hash(state);
111 self.url.hash(state);
112 }
116}
117
118#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
119#[serde(rename_all = "camelCase")]
120#[deserr(rename_all = camelCase, deny_unknown_fields)]
121enum InputType {
122 Text,
123 TextArray,
124}
125
126impl Embedder {
127 pub fn new(
128 options: EmbedderOptions,
129 cache_cap: usize,
130 configuration_source: ConfigurationSource,
131 ) -> Result<Self, NewEmbedderError> {
132 let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
133
134 let client = ureq::AgentBuilder::new()
135 .max_idle_connections(REQUEST_PARALLELISM * 2)
136 .max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
137 .timeout(std::time::Duration::from_secs(30))
138 .build();
139
140 let request = Request::new(options.request)?;
141 let response = Response::new(options.response, &request)?;
142
143 let data = EmbedderData {
144 client,
145 bearer,
146 url: options.url,
147 request,
148 response,
149 configuration_source,
150 headers: options.headers,
151 };
152
153 let dimensions = if let Some(dimensions) = options.dimensions {
154 dimensions
155 } else {
156 infer_dimensions(&data)?
157 };
158
159 Ok(Self {
160 data,
161 dimensions,
162 distribution: options.distribution,
163 cache: EmbeddingCache::new(cache_cap),
164 })
165 }
166
167 pub fn embed(
168 &self,
169 texts: Vec<String>,
170 deadline: Option<Instant>,
171 ) -> Result<Vec<Embedding>, EmbedError> {
172 embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
173 }
174
175 pub fn embed_ref<S>(
176 &self,
177 texts: &[S],
178 deadline: Option<Instant>,
179 ) -> Result<Vec<Embedding>, EmbedError>
180 where
181 S: AsRef<str> + Serialize,
182 {
183 embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline)
184 }
185
186 pub fn embed_tokens(
187 &self,
188 tokens: &[u32],
189 deadline: Option<Instant>,
190 ) -> Result<Embedding, EmbedError> {
191 let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
192 Ok(embeddings.pop().unwrap())
194 }
195
196 pub fn embed_index(
197 &self,
198 text_chunks: Vec<Vec<String>>,
199 threads: &ThreadPoolNoAbort,
200 ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
201 if threads.active_operations() >= REQUEST_PARALLELISM {
204 text_chunks.into_iter().map(move |chunk| self.embed(chunk, None)).collect()
205 } else {
206 threads
207 .install(move || {
208 text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
209 })
210 .map_err(|error| EmbedError {
211 kind: EmbedErrorKind::PanicInThreadPool(error),
212 fault: FaultSource::Bug,
213 })?
214 }
215 }
216
217 pub(crate) fn embed_index_ref(
218 &self,
219 texts: &[&str],
220 threads: &ThreadPoolNoAbort,
221 ) -> Result<Vec<Embedding>, EmbedError> {
222 if threads.active_operations() >= REQUEST_PARALLELISM {
225 let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
226 .chunks(self.prompt_count_in_chunk_hint())
227 .map(move |chunk| self.embed_ref(chunk, None))
228 .collect();
229
230 let embeddings = embeddings?;
231 Ok(embeddings.into_iter().flatten().collect())
232 } else {
233 threads
234 .install(move || {
235 let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
236 .par_chunks(self.prompt_count_in_chunk_hint())
237 .map(move |chunk| self.embed_ref(chunk, None))
238 .collect();
239
240 let embeddings = embeddings?;
241 Ok(embeddings.into_iter().flatten().collect())
242 })
243 .map_err(|error| EmbedError {
244 kind: EmbedErrorKind::PanicInThreadPool(error),
245 fault: FaultSource::Bug,
246 })?
247 }
248 }
249
250 pub fn chunk_count_hint(&self) -> usize {
251 super::REQUEST_PARALLELISM
252 }
253
254 pub fn prompt_count_in_chunk_hint(&self) -> usize {
255 match self.data.request.input_type() {
256 InputType::Text => 1,
257 InputType::TextArray => 10,
258 }
259 }
260
261 pub fn dimensions(&self) -> usize {
262 self.dimensions
263 }
264
265 pub fn distribution(&self) -> Option<DistributionShift> {
266 self.distribution
267 }
268
269 pub(super) fn cache(&self) -> &EmbeddingCache {
270 &self.cache
271 }
272}
273
274fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
275 let v = embed(data, ["test"].as_slice(), 1, None, None)
276 .map_err(NewEmbedderError::could_not_determine_dimension)?;
277 Ok(v.first().unwrap().len())
279}
280
281fn embed<S>(
282 data: &EmbedderData,
283 inputs: &[S],
284 expected_count: usize,
285 expected_dimension: Option<usize>,
286 deadline: Option<Instant>,
287) -> Result<Vec<Embedding>, EmbedError>
288where
289 S: Serialize,
290{
291 let request = data.client.post(&data.url);
292 let request = if let Some(bearer) = &data.bearer {
293 request.set("Authorization", bearer)
294 } else {
295 request
296 };
297 let mut request = request.set("Content-Type", "application/json");
298 for (header, value) in &data.headers {
299 request = request.set(header.as_str(), value.as_str());
300 }
301
302 let body = data.request.inject_texts(inputs);
303
304 for attempt in 0..10 {
305 let response = request.clone().send_json(&body);
306 let result = check_response(response, data.configuration_source).and_then(|response| {
307 response_to_embedding(response, data, expected_count, expected_dimension)
308 });
309
310 let retry_duration = match result {
311 Ok(response) => return Ok(response),
312 Err(retry) => {
313 tracing::warn!("Failed: {}", retry.error);
314 if let Some(deadline) = deadline {
315 let now = std::time::Instant::now();
316 if now > deadline {
317 tracing::warn!("Could not embed due to deadline");
318 return Err(retry.into_error());
319 }
320
321 let duration_to_deadline = deadline - now;
322 retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline))
323 } else {
324 retry.into_duration(attempt)
325 }
326 }
327 }?;
328
329 let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); let retry_duration = retry_duration
333 + rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration);
334
335 tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
336 std::thread::sleep(retry_duration);
337 }
338
339 let response = request.send_json(&body);
340 let result = check_response(response, data.configuration_source);
341 result.map_err(Retry::into_error).and_then(|response| {
342 response_to_embedding(response, data, expected_count, expected_dimension)
343 .map_err(Retry::into_error)
344 })
345}
346
347fn check_response(
348 response: Result<ureq::Response, ureq::Error>,
349 configuration_source: ConfigurationSource,
350) -> Result<ureq::Response, Retry> {
351 match response {
352 Ok(response) => Ok(response),
353 Err(ureq::Error::Status(code, response)) => {
354 let error_response: Option<String> = response.into_string().ok();
355 Err(match code {
356 401 => Retry::give_up(EmbedError::rest_unauthorized(
357 error_response,
358 configuration_source,
359 )),
360 429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
361 400 => Retry::give_up(EmbedError::rest_bad_request(
362 error_response,
363 configuration_source,
364 )),
365 500..=599 => {
366 Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
367 }
368 402..=499 => {
369 Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
370 }
371 _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
372 })
373 }
374 Err(ureq::Error::Transport(transport)) => {
375 Err(Retry::retry_later(EmbedError::rest_network(transport)))
376 }
377 }
378}
379
380fn response_to_embedding(
381 response: ureq::Response,
382 data: &EmbedderData,
383 expected_count: usize,
384 expected_dimensions: Option<usize>,
385) -> Result<Vec<Embedding>, Retry> {
386 let response: serde_json::Value = response
387 .into_json()
388 .map_err(EmbedError::rest_response_deserialization)
389 .map_err(Retry::retry_later)?;
390
391 let embeddings = data.response.extract_embeddings(response).map_err(Retry::give_up)?;
392
393 if embeddings.len() != expected_count {
394 return Err(Retry::give_up(EmbedError::rest_response_embedding_count(
395 expected_count,
396 embeddings.len(),
397 )));
398 }
399
400 if let Some(dimensions) = expected_dimensions {
401 for embedding in &embeddings {
402 if embedding.len() != dimensions {
403 return Err(Retry::give_up(EmbedError::rest_unexpected_dimension(
404 dimensions,
405 embedding.len(),
406 )));
407 }
408 }
409 }
410
411 Ok(embeddings)
412}
413
414pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
415pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
416pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
417
418#[derive(Debug)]
419pub struct Request {
420 template: ValueTemplate,
421}
422
423impl Request {
424 pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
425 let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
426 Ok(template) => template,
427 Err(error) => {
428 let message =
429 error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
430 return Err(NewEmbedderError::rest_could_not_parse_template(message));
431 }
432 };
433
434 Ok(Self { template })
435 }
436
437 fn input_type(&self) -> InputType {
438 if self.template.has_array_value() {
439 InputType::TextArray
440 } else {
441 InputType::Text
442 }
443 }
444
445 pub fn inject_texts<S: Serialize>(
446 &self,
447 texts: impl IntoIterator<Item = S>,
448 ) -> serde_json::Value {
449 self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
450 }
451}
452
453#[derive(Debug)]
454pub struct Response {
455 template: ValueTemplate,
456}
457
458impl Response {
459 pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
460 let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
461 {
462 Ok(template) => template,
463 Err(error) => {
464 let message =
465 error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
466 return Err(NewEmbedderError::rest_could_not_parse_template(message));
467 }
468 };
469
470 match (template.has_array_value(), request.template.has_array_value()) {
471 (true, true) | (false, false) => Ok(Self {template}),
472 (true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
473 (false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
474 }
475 }
476
477 pub fn extract_embeddings(
478 &self,
479 response: serde_json::Value,
480 ) -> Result<Vec<Embedding>, EmbedError> {
481 let extracted_values: Vec<Embedding> = match self.template.extract(response) {
482 Ok(extracted_values) => extracted_values,
483 Err(error) => {
484 let error_message =
485 error.error_message("response", "{{embedding}}", "an array of numbers");
486 return Err(EmbedError::rest_extraction_error(error_message));
487 }
488 };
489 let embeddings: Vec<Embedding> = extracted_values.into_iter().collect();
490
491 Ok(embeddings)
492 }
493}