// milli_core/vector/openai.rs

use std::fmt;
use std::time::Instant;

use ordered_float::OrderedFloat;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use rayon::slice::ParallelSlice as _;

use super::error::{EmbedError, NewEmbedderError};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
use crate::error::FaultSource;
use crate::vector::error::EmbedErrorKind;
use crate::vector::Embedding;
use crate::ThreadPoolNoAbort;

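/// User-facing options for the OpenAI embedder: endpoint, credentials, model
/// selection, and optional overrides for dimensions and score distribution.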
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub url: Option<String>,
    pub api_key: Option<String>,
    pub embedding_model: EmbeddingModel,
    pub dimensions: Option<usize>,
    pub distribution: Option<DistributionShift>,
}

impl EmbedderOptions {
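    /// The effective number of dimensions: the configured override when the
    /// model supports one, the model default otherwise.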
    pub fn dimensions(&self) -> usize {
        if self.embedding_model.supports_overriding_dimensions() {
            self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
        } else {
            self.embedding_model.default_dimensions()
        }
    }

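    /// The request template sent to the REST embedder, with placeholders
    /// standing in for the actual inputs. The `dimensions` field is only set
    /// when the model supports overriding it.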
    pub fn request(&self) -> serde_json::Value {
        let model = self.embedding_model.name();

        let mut request = serde_json::json!({
            "model": model,
            "input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]
        });

        if self.embedding_model.supports_overriding_dimensions() {
            if let Some(dimensions) = self.dimensions {
                request["dimensions"] = dimensions.into();
            }
        }

        request
    }

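    /// The configured distribution shift, falling back to the default
    /// distribution of the selected model.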
    pub fn distribution(&self) -> Option<DistributionShift> {
        self.distribution.or(self.embedding_model.distribution())
    }
}

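/// The OpenAI embedding models supported out of the box.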
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Hash,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
    // # WARNING
    //
    // If you ever add a model here, make sure to also add it to the list of
    // supported models below.
    #[serde(rename = "text-embedding-ada-002")]
    #[deserr(rename = "text-embedding-ada-002")]
    TextEmbeddingAda002,

    #[default]
    #[serde(rename = "text-embedding-3-small")]
    #[deserr(rename = "text-embedding-3-small")]
    TextEmbedding3Small,

    #[serde(rename = "text-embedding-3-large")]
    #[deserr(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

impl EmbeddingModel {
    pub fn supported_models() -> &'static [&'static str] {
        &["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
    }

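    /// The maximum number of tokens a single input may contain for this model.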
    pub fn max_token(&self) -> usize {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => 8191,
            EmbeddingModel::TextEmbedding3Large => 8191,
            EmbeddingModel::TextEmbedding3Small => 8191,
        }
    }

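    /// The native number of dimensions of the embeddings produced by this model.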
    pub fn default_dimensions(&self) -> usize {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => 1536,
            EmbeddingModel::TextEmbedding3Large => 3072,
            EmbeddingModel::TextEmbedding3Small => 1536,
        }
    }

    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
            EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
            EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
        }
    }

    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
            "text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
            "text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
            _ => None,
        }
    }

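    /// The default score-distribution parameters for this model, used when
    /// the embedder options do not provide their own.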
    fn distribution(&self) -> Option<DistributionShift> {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
                current_mean: OrderedFloat(0.90),
                current_sigma: OrderedFloat(0.08),
            }),
            EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
                current_mean: OrderedFloat(0.70),
                current_sigma: OrderedFloat(0.1),
            }),
            EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
                current_mean: OrderedFloat(0.75),
                current_sigma: OrderedFloat(0.1),
            }),
        }
    }

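    /// Whether the API accepts a `dimensions` parameter for this model; only
    /// the `text-embedding-3-*` models do.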
    pub fn supports_overriding_dimensions(&self) -> bool {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => false,
            EmbeddingModel::TextEmbedding3Large => true,
            EmbeddingModel::TextEmbedding3Small => true,
        }
    }
}

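/// The default OpenAI embeddings endpoint, used when no `url` is configured.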
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";

impl EmbedderOptions {
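    /// Builds options targeting the default model (`text-embedding-3-small`)
    /// at the default endpoint.
    ///
    /// A minimal usage sketch (the import path here is assumed):
    ///
    /// ```ignore
    /// use milli_core::vector::openai::EmbedderOptions;
    ///
    /// let options = EmbedderOptions::with_default_model(None);
    /// // `text-embedding-3-small` produces 1536 dimensions by default.
    /// assert_eq!(options.dimensions(), 1536);
    /// ```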
    pub fn with_default_model(api_key: Option<String>) -> Self {
        Self {
            api_key,
            embedding_model: Default::default(),
            dimensions: None,
            distribution: None,
            url: None,
        }
    }
}

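/// Reads the API key from the `MEILI_OPENAI_API_KEY` environment variable,
/// falling back to `OPENAI_API_KEY`, then to the empty string.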
fn infer_api_key() -> String {
    std::env::var("MEILI_OPENAI_API_KEY")
        .or_else(|_| std::env::var("OPENAI_API_KEY"))
        .unwrap_or_default()
}

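/// An embedder backed by the OpenAI embeddings API, going through the generic
/// REST embedder and keeping a `tiktoken` tokenizer around to truncate inputs
/// that the API rejects as too long.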
pub struct Embedder {
    tokenizer: tiktoken_rs::CoreBPE,
    rest_embedder: RestEmbedder,
    options: EmbedderOptions,
}

impl Embedder {
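    /// Builds the embedder, resolving the API key from the options or the
    /// environment and configuring the underlying REST embedder with OpenAI's
    /// request and response shapes.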
    pub fn new(options: EmbedderOptions, cache_cap: usize) -> Result<Self, NewEmbedderError> {
        let mut inferred_api_key = Default::default();
        let api_key = options.api_key.as_ref().unwrap_or_else(|| {
            inferred_api_key = infer_api_key();
            &inferred_api_key
        });

        let url = options.url.as_deref().unwrap_or(OPENAI_EMBEDDINGS_URL).to_owned();

        let rest_embedder = RestEmbedder::new(
            RestEmbedderOptions {
                api_key: (!api_key.is_empty()).then(|| api_key.clone()),
                distribution: None,
                dimensions: Some(options.dimensions()),
                url,
                request: options.request(),
                response: serde_json::json!({
                    "data": [{
                        "embedding": super::rest::RESPONSE_PLACEHOLDER
                    },
                    super::rest::REPEAT_PLACEHOLDER
                    ]
                }),
                headers: Default::default(),
            },
            cache_cap,
            super::rest::ConfigurationSource::OpenAi,
        )?;
        // Looking at the code, it is very unclear that this can actually fail,
        // hence the unwrap.
        let tokenizer = tiktoken_rs::cl100k_base().unwrap();

        Ok(Self { options, rest_embedder, tokenizer })
    }

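    /// Embeds the given texts, falling back to tokenized (and possibly
    /// truncated) inputs when the API rejects a request as too long.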
    pub fn embed<S: AsRef<str> + serde::Serialize>(
        &self,
        texts: &[S],
        deadline: Option<Instant>,
    ) -> Result<Vec<Embedding>, EmbedError> {
        match self.rest_embedder.embed_ref(texts, deadline) {
            Ok(embeddings) => Ok(embeddings),
            Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
                tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. The input was probably too long; retrying on a tokenized version. For best performance, limit the size of your document template.");
                self.try_embed_tokenized(texts, deadline)
            }
            Err(error) => Err(error),
        }
    }

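    /// Embeds each text individually, truncating it to the model's maximum
    /// token count when it is too long to be embedded as-is.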
    fn try_embed_tokenized<S: AsRef<str>>(
        &self,
        texts: &[S],
        deadline: Option<Instant>,
    ) -> Result<Vec<Embedding>, EmbedError> {
        let mut all_embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            let text = text.as_ref();
            let max_token_count = self.options.embedding_model.max_token();
            let encoded = self.tokenizer.encode_ordinary(text);
            let len = encoded.len();
            if len < max_token_count {
                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?);
                continue;
            }

            // The text is too long: truncate it to the maximum token count
            // and embed the raw tokens instead.
            let tokens = &encoded.as_slice()[0..max_token_count];

            let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?;

            all_embeddings.push(embedding);
        }
        Ok(all_embeddings)
    }

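    /// Embeds the given chunks of texts for indexing, in parallel when the
    /// thread pool has capacity left.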
    pub fn embed_index(
        &self,
        text_chunks: Vec<Vec<String>>,
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
        // This condition helps reduce the number of active rayon jobs
        // so that we avoid consuming all the LMDB read transactions (rtxns)
        // and overflowing the stack.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect()
        } else {
            threads
                .install(move || {
                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect()
                })
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

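    /// Embeds the given texts for indexing, chunking them by the prompt count
    /// hint and flattening the results into one embedding per text.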
    pub(crate) fn embed_index_ref(
        &self,
        texts: &[&str],
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<f32>>, EmbedError> {
        // This condition helps reduce the number of active rayon jobs
        // so that we avoid consuming all the LMDB read transactions (rtxns)
        // and overflowing the stack.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                .chunks(self.prompt_count_in_chunk_hint())
                .map(move |chunk| self.embed(chunk, None))
                .collect();
            let embeddings = embeddings?;
            Ok(embeddings.into_iter().flatten().collect())
        } else {
            threads
                .install(move || {
                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                        .par_chunks(self.prompt_count_in_chunk_hint())
                        .map(move |chunk| self.embed(chunk, None))
                        .collect();

                    let embeddings = embeddings?;
                    Ok(embeddings.into_iter().flatten().collect())
                })
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

    pub fn chunk_count_hint(&self) -> usize {
        self.rest_embedder.chunk_count_hint()
    }

    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        self.rest_embedder.prompt_count_in_chunk_hint()
    }

    pub fn dimensions(&self) -> usize {
        self.options.dimensions()
    }

    pub fn distribution(&self) -> Option<DistributionShift> {
        self.options.distribution()
    }

    pub(super) fn cache(&self) -> &EmbeddingCache {
        self.rest_embedder.cache()
    }
}

impl fmt::Debug for Embedder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Embedder")
            .field("tokenizer", &"CoreBPE")
            .field("rest_embedder", &self.rest_embedder)
            .field("options", &self.options)
            .finish()
    }
}