milli_core/vector/rest.rs

1use std::collections::BTreeMap;
2use std::time::Instant;
3
4use deserr::Deserr;
5use rand::Rng;
6use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
7use rayon::slice::ParallelSlice as _;
8use serde::{Deserialize, Serialize};
9
10use super::error::EmbedErrorKind;
11use super::json_template::ValueTemplate;
12use super::{
13    DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
14};
15use crate::error::FaultSource;
16use crate::ThreadPoolNoAbort;
17
18// retrying in case of failure
19pub struct Retry {
20    pub error: EmbedError,
21    strategy: RetryStrategy,
22}
23
/// Where the configuration of a REST embedder came from.
///
/// It is forwarded to some error constructors (e.g. for HTTP 400/401 responses).
/// NOTE(review): variants presumably mean "OpenAI preset", "Ollama preset" and
/// "user-provided REST settings" — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigurationSource {
    /// Configuration originating from the OpenAI embedder.
    OpenAi,
    /// Configuration originating from the Ollama embedder.
    Ollama,
    /// Configuration written directly by the user.
    User,
}
30
/// How a failed request is handled; the delays are computed by `Retry::into_duration`.
pub enum RetryStrategy {
    /// Do not retry: the error is returned to the caller.
    GiveUp,
    /// Retry after an exponential back-off of `10^attempt` milliseconds.
    Retry,
    /// Retry after 1 ms; `Retry::must_tokenize` reports `true` for this strategy.
    RetryTokenized,
    /// Retry after `100 + 10^attempt` milliseconds (used for HTTP 429 responses).
    RetryAfterRateLimit,
}
37
38impl Retry {
39    pub fn give_up(error: EmbedError) -> Self {
40        Self { error, strategy: RetryStrategy::GiveUp }
41    }
42
43    pub fn retry_later(error: EmbedError) -> Self {
44        Self { error, strategy: RetryStrategy::Retry }
45    }
46
47    pub fn retry_tokenized(error: EmbedError) -> Self {
48        Self { error, strategy: RetryStrategy::RetryTokenized }
49    }
50
51    pub fn rate_limited(error: EmbedError) -> Self {
52        Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
53    }
54
55    pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
56        match self.strategy {
57            RetryStrategy::GiveUp => Err(self.error),
58            RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
59            RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
60            RetryStrategy::RetryAfterRateLimit => {
61                Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
62            }
63        }
64    }
65
66    pub fn must_tokenize(&self) -> bool {
67        matches!(self.strategy, RetryStrategy::RetryTokenized)
68    }
69
70    pub fn into_error(self) -> EmbedError {
71        self.error
72    }
73}
74
75#[derive(Debug)]
76pub struct Embedder {
77    data: EmbedderData,
78    dimensions: usize,
79    distribution: Option<DistributionShift>,
80    cache: EmbeddingCache,
81}
82
83/// All data needed to perform requests and parse responses
84#[derive(Debug)]
85struct EmbedderData {
86    client: ureq::Agent,
87    bearer: Option<String>,
88    headers: BTreeMap<String, String>,
89    url: String,
90    request: Request,
91    response: Response,
92    configuration_source: ConfigurationSource,
93}
94
95#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
96pub struct EmbedderOptions {
97    pub api_key: Option<String>,
98    pub distribution: Option<DistributionShift>,
99    pub dimensions: Option<usize>,
100    pub url: String,
101    pub request: serde_json::Value,
102    pub response: serde_json::Value,
103    pub headers: BTreeMap<String, String>,
104}
105
106impl std::hash::Hash for EmbedderOptions {
107    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
108        self.api_key.hash(state);
109        self.distribution.hash(state);
110        self.dimensions.hash(state);
111        self.url.hash(state);
112        // skip hashing the request and response
113        // collisions in regular usage should be minimal,
114        // and the list is limited to 256 values anyway
115    }
116}
117
118#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
119#[serde(rename_all = "camelCase")]
120#[deserr(rename_all = camelCase, deny_unknown_fields)]
121enum InputType {
122    Text,
123    TextArray,
124}
125
126impl Embedder {
127    pub fn new(
128        options: EmbedderOptions,
129        cache_cap: usize,
130        configuration_source: ConfigurationSource,
131    ) -> Result<Self, NewEmbedderError> {
132        let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
133
134        let client = ureq::AgentBuilder::new()
135            .max_idle_connections(REQUEST_PARALLELISM * 2)
136            .max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
137            .timeout(std::time::Duration::from_secs(30))
138            .build();
139
140        let request = Request::new(options.request)?;
141        let response = Response::new(options.response, &request)?;
142
143        let data = EmbedderData {
144            client,
145            bearer,
146            url: options.url,
147            request,
148            response,
149            configuration_source,
150            headers: options.headers,
151        };
152
153        let dimensions = if let Some(dimensions) = options.dimensions {
154            dimensions
155        } else {
156            infer_dimensions(&data)?
157        };
158
159        Ok(Self {
160            data,
161            dimensions,
162            distribution: options.distribution,
163            cache: EmbeddingCache::new(cache_cap),
164        })
165    }
166
167    pub fn embed(
168        &self,
169        texts: Vec<String>,
170        deadline: Option<Instant>,
171    ) -> Result<Vec<Embedding>, EmbedError> {
172        embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
173    }
174
175    pub fn embed_ref<S>(
176        &self,
177        texts: &[S],
178        deadline: Option<Instant>,
179    ) -> Result<Vec<Embedding>, EmbedError>
180    where
181        S: AsRef<str> + Serialize,
182    {
183        embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline)
184    }
185
186    pub fn embed_tokens(
187        &self,
188        tokens: &[u32],
189        deadline: Option<Instant>,
190    ) -> Result<Embedding, EmbedError> {
191        let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
192        // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
193        Ok(embeddings.pop().unwrap())
194    }
195
196    pub fn embed_index(
197        &self,
198        text_chunks: Vec<Vec<String>>,
199        threads: &ThreadPoolNoAbort,
200    ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
201        // This condition helps reduce the number of active rayon jobs
202        // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
203        if threads.active_operations() >= REQUEST_PARALLELISM {
204            text_chunks.into_iter().map(move |chunk| self.embed(chunk, None)).collect()
205        } else {
206            threads
207                .install(move || {
208                    text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
209                })
210                .map_err(|error| EmbedError {
211                    kind: EmbedErrorKind::PanicInThreadPool(error),
212                    fault: FaultSource::Bug,
213                })?
214        }
215    }
216
217    pub(crate) fn embed_index_ref(
218        &self,
219        texts: &[&str],
220        threads: &ThreadPoolNoAbort,
221    ) -> Result<Vec<Embedding>, EmbedError> {
222        // This condition helps reduce the number of active rayon jobs
223        // so that we avoid consuming all the LMDB rtxns and avoid stack overflows.
224        if threads.active_operations() >= REQUEST_PARALLELISM {
225            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
226                .chunks(self.prompt_count_in_chunk_hint())
227                .map(move |chunk| self.embed_ref(chunk, None))
228                .collect();
229
230            let embeddings = embeddings?;
231            Ok(embeddings.into_iter().flatten().collect())
232        } else {
233            threads
234                .install(move || {
235                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
236                        .par_chunks(self.prompt_count_in_chunk_hint())
237                        .map(move |chunk| self.embed_ref(chunk, None))
238                        .collect();
239
240                    let embeddings = embeddings?;
241                    Ok(embeddings.into_iter().flatten().collect())
242                })
243                .map_err(|error| EmbedError {
244                    kind: EmbedErrorKind::PanicInThreadPool(error),
245                    fault: FaultSource::Bug,
246                })?
247        }
248    }
249
250    pub fn chunk_count_hint(&self) -> usize {
251        super::REQUEST_PARALLELISM
252    }
253
254    pub fn prompt_count_in_chunk_hint(&self) -> usize {
255        match self.data.request.input_type() {
256            InputType::Text => 1,
257            InputType::TextArray => 10,
258        }
259    }
260
261    pub fn dimensions(&self) -> usize {
262        self.dimensions
263    }
264
265    pub fn distribution(&self) -> Option<DistributionShift> {
266        self.distribution
267    }
268
269    pub(super) fn cache(&self) -> &EmbeddingCache {
270        &self.cache
271    }
272}
273
274fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
275    let v = embed(data, ["test"].as_slice(), 1, None, None)
276        .map_err(NewEmbedderError::could_not_determine_dimension)?;
277    // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
278    Ok(v.first().unwrap().len())
279}
280
281fn embed<S>(
282    data: &EmbedderData,
283    inputs: &[S],
284    expected_count: usize,
285    expected_dimension: Option<usize>,
286    deadline: Option<Instant>,
287) -> Result<Vec<Embedding>, EmbedError>
288where
289    S: Serialize,
290{
291    let request = data.client.post(&data.url);
292    let request = if let Some(bearer) = &data.bearer {
293        request.set("Authorization", bearer)
294    } else {
295        request
296    };
297    let mut request = request.set("Content-Type", "application/json");
298    for (header, value) in &data.headers {
299        request = request.set(header.as_str(), value.as_str());
300    }
301
302    let body = data.request.inject_texts(inputs);
303
304    for attempt in 0..10 {
305        let response = request.clone().send_json(&body);
306        let result = check_response(response, data.configuration_source).and_then(|response| {
307            response_to_embedding(response, data, expected_count, expected_dimension)
308        });
309
310        let retry_duration = match result {
311            Ok(response) => return Ok(response),
312            Err(retry) => {
313                tracing::warn!("Failed: {}", retry.error);
314                if let Some(deadline) = deadline {
315                    let now = std::time::Instant::now();
316                    if now > deadline {
317                        tracing::warn!("Could not embed due to deadline");
318                        return Err(retry.into_error());
319                    }
320
321                    let duration_to_deadline = deadline - now;
322                    retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline))
323                } else {
324                    retry.into_duration(attempt)
325                }
326            }
327        }?;
328
329        let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
330
331        // randomly up to double the retry duration
332        let retry_duration = retry_duration
333            + rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration);
334
335        tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
336        std::thread::sleep(retry_duration);
337    }
338
339    let response = request.send_json(&body);
340    let result = check_response(response, data.configuration_source);
341    result.map_err(Retry::into_error).and_then(|response| {
342        response_to_embedding(response, data, expected_count, expected_dimension)
343            .map_err(Retry::into_error)
344    })
345}
346
347fn check_response(
348    response: Result<ureq::Response, ureq::Error>,
349    configuration_source: ConfigurationSource,
350) -> Result<ureq::Response, Retry> {
351    match response {
352        Ok(response) => Ok(response),
353        Err(ureq::Error::Status(code, response)) => {
354            let error_response: Option<String> = response.into_string().ok();
355            Err(match code {
356                401 => Retry::give_up(EmbedError::rest_unauthorized(
357                    error_response,
358                    configuration_source,
359                )),
360                429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
361                400 => Retry::give_up(EmbedError::rest_bad_request(
362                    error_response,
363                    configuration_source,
364                )),
365                500..=599 => {
366                    Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
367                }
368                402..=499 => {
369                    Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
370                }
371                _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
372            })
373        }
374        Err(ureq::Error::Transport(transport)) => {
375            Err(Retry::retry_later(EmbedError::rest_network(transport)))
376        }
377    }
378}
379
380fn response_to_embedding(
381    response: ureq::Response,
382    data: &EmbedderData,
383    expected_count: usize,
384    expected_dimensions: Option<usize>,
385) -> Result<Vec<Embedding>, Retry> {
386    let response: serde_json::Value = response
387        .into_json()
388        .map_err(EmbedError::rest_response_deserialization)
389        .map_err(Retry::retry_later)?;
390
391    let embeddings = data.response.extract_embeddings(response).map_err(Retry::give_up)?;
392
393    if embeddings.len() != expected_count {
394        return Err(Retry::give_up(EmbedError::rest_response_embedding_count(
395            expected_count,
396            embeddings.len(),
397        )));
398    }
399
400    if let Some(dimensions) = expected_dimensions {
401        for embedding in &embeddings {
402            if embedding.len() != dimensions {
403                return Err(Retry::give_up(EmbedError::rest_unexpected_dimension(
404                    dimensions,
405                    embedding.len(),
406                )));
407            }
408        }
409    }
410
411    Ok(embeddings)
412}
413
414pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
415pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
416pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
417
418#[derive(Debug)]
419pub struct Request {
420    template: ValueTemplate,
421}
422
423impl Request {
424    pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
425        let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
426            Ok(template) => template,
427            Err(error) => {
428                let message =
429                    error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
430                return Err(NewEmbedderError::rest_could_not_parse_template(message));
431            }
432        };
433
434        Ok(Self { template })
435    }
436
437    fn input_type(&self) -> InputType {
438        if self.template.has_array_value() {
439            InputType::TextArray
440        } else {
441            InputType::Text
442        }
443    }
444
445    pub fn inject_texts<S: Serialize>(
446        &self,
447        texts: impl IntoIterator<Item = S>,
448    ) -> serde_json::Value {
449        self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
450    }
451}
452
453#[derive(Debug)]
454pub struct Response {
455    template: ValueTemplate,
456}
457
458impl Response {
459    pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
460        let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
461        {
462            Ok(template) => template,
463            Err(error) => {
464                let message =
465                    error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
466                return Err(NewEmbedderError::rest_could_not_parse_template(message));
467            }
468        };
469
470        match (template.has_array_value(), request.template.has_array_value()) {
471            (true, true) | (false, false) => Ok(Self {template}),
472            (true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
473            (false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
474        }
475    }
476
477    pub fn extract_embeddings(
478        &self,
479        response: serde_json::Value,
480    ) -> Result<Vec<Embedding>, EmbedError> {
481        let extracted_values: Vec<Embedding> = match self.template.extract(response) {
482            Ok(extracted_values) => extracted_values,
483            Err(error) => {
484                let error_message =
485                    error.error_message("response", "{{embedding}}", "an array of numbers");
486                return Err(EmbedError::rest_extraction_error(error_message));
487            }
488        };
489        let embeddings: Vec<Embedding> = extracted_values.into_iter().collect();
490
491        Ok(embeddings)
492    }
493}