1use std::fmt;
2use std::time::Instant;
3
4use ordered_float::OrderedFloat;
5use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
6use rayon::slice::ParallelSlice as _;
7
8use super::error::{EmbedError, NewEmbedderError};
9use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
10use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
11use crate::error::FaultSource;
12use crate::vector::error::EmbedErrorKind;
13use crate::vector::Embedding;
14use crate::ThreadPoolNoAbort;
15
/// User-facing configuration of the OpenAI embedder.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    /// Endpoint to send embedding requests to; defaults to [`OPENAI_EMBEDDINGS_URL`] when `None`.
    pub url: Option<String>,
    /// API key; when `None`, a key is inferred from the environment (see `infer_api_key`).
    pub api_key: Option<String>,
    /// Which OpenAI embedding model to use.
    pub embedding_model: EmbeddingModel,
    /// Requested output dimension count; only honored by models that support overriding dimensions.
    pub dimensions: Option<usize>,
    /// User-provided distribution shift, taking precedence over the model's built-in one.
    pub distribution: Option<DistributionShift>,
}
24
25impl EmbedderOptions {
26 pub fn dimensions(&self) -> usize {
27 if self.embedding_model.supports_overriding_dimensions() {
28 self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
29 } else {
30 self.embedding_model.default_dimensions()
31 }
32 }
33
34 pub fn request(&self) -> serde_json::Value {
35 let model = self.embedding_model.name();
36
37 let mut request = serde_json::json!({
38 "model": model,
39 "input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]
40 });
41
42 if self.embedding_model.supports_overriding_dimensions() {
43 if let Some(dimensions) = self.dimensions {
44 request["dimensions"] = dimensions.into();
45 }
46 }
47
48 request
49 }
50
51 pub fn distribution(&self) -> Option<DistributionShift> {
52 self.distribution.or(self.embedding_model.distribution())
53 }
54}
55
/// The OpenAI embedding models supported by this embedder.
///
/// Serialized (both serde and deserr) under the official OpenAI model names,
/// e.g. `"text-embedding-3-small"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Hash,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
    /// `text-embedding-ada-002`: fixed 1536 dimensions, no dimension override.
    #[serde(rename = "text-embedding-ada-002")]
    #[deserr(rename = "text-embedding-ada-002")]
    TextEmbeddingAda002,

    /// `text-embedding-3-small` (the default): 1536 dimensions by default, overridable.
    #[default]
    #[serde(rename = "text-embedding-3-small")]
    #[deserr(rename = "text-embedding-3-small")]
    TextEmbedding3Small,

    /// `text-embedding-3-large`: 3072 dimensions by default, overridable.
    #[serde(rename = "text-embedding-3-large")]
    #[deserr(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
87
88impl EmbeddingModel {
89 pub fn supported_models() -> &'static [&'static str] {
90 &["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
91 }
92
93 pub fn max_token(&self) -> usize {
94 match self {
95 EmbeddingModel::TextEmbeddingAda002 => 8191,
96 EmbeddingModel::TextEmbedding3Large => 8191,
97 EmbeddingModel::TextEmbedding3Small => 8191,
98 }
99 }
100
101 pub fn default_dimensions(&self) -> usize {
102 match self {
103 EmbeddingModel::TextEmbeddingAda002 => 1536,
104 EmbeddingModel::TextEmbedding3Large => 3072,
105 EmbeddingModel::TextEmbedding3Small => 1536,
106 }
107 }
108
109 pub fn name(&self) -> &'static str {
110 match self {
111 EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
112 EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
113 EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
114 }
115 }
116
117 pub fn from_name(name: &str) -> Option<Self> {
118 match name {
119 "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
120 "text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
121 "text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
122 _ => None,
123 }
124 }
125
126 fn distribution(&self) -> Option<DistributionShift> {
127 match self {
128 EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
129 current_mean: OrderedFloat(0.90),
130 current_sigma: OrderedFloat(0.08),
131 }),
132 EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
133 current_mean: OrderedFloat(0.70),
134 current_sigma: OrderedFloat(0.1),
135 }),
136 EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
137 current_mean: OrderedFloat(0.75),
138 current_sigma: OrderedFloat(0.1),
139 }),
140 }
141 }
142
143 pub fn supports_overriding_dimensions(&self) -> bool {
144 match self {
145 EmbeddingModel::TextEmbeddingAda002 => false,
146 EmbeddingModel::TextEmbedding3Large => true,
147 EmbeddingModel::TextEmbedding3Small => true,
148 }
149 }
150}
151
152pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
153
154impl EmbedderOptions {
155 pub fn with_default_model(api_key: Option<String>) -> Self {
156 Self {
157 api_key,
158 embedding_model: Default::default(),
159 dimensions: None,
160 distribution: None,
161 url: None,
162 }
163 }
164}
165
/// Looks up an OpenAI API key from the environment.
///
/// `MEILI_OPENAI_API_KEY` takes precedence over `OPENAI_API_KEY`; an empty
/// string is returned when neither is readable.
fn infer_api_key() -> String {
    for var in ["MEILI_OPENAI_API_KEY", "OPENAI_API_KEY"] {
        if let Ok(key) = std::env::var(var) {
            return key;
        }
    }
    String::new()
}
171
/// Embedder backed by the OpenAI embeddings REST API.
pub struct Embedder {
    /// cl100k_base tokenizer, used to truncate inputs that hit the model's token limit.
    tokenizer: tiktoken_rs::CoreBPE,
    /// Underlying generic REST embedder that performs the actual HTTP calls.
    rest_embedder: RestEmbedder,
    /// The options this embedder was constructed with.
    options: EmbedderOptions,
}
177
impl Embedder {
    /// Builds an OpenAI embedder from `options`, with an embedding cache of
    /// capacity `cache_cap`.
    ///
    /// When no API key is configured, one is inferred from the environment
    /// (`MEILI_OPENAI_API_KEY`, then `OPENAI_API_KEY`).
    pub fn new(options: EmbedderOptions, cache_cap: usize) -> Result<Self, NewEmbedderError> {
        // Storage for the environment-derived key, so the reference returned by
        // the closure below outlives the `unwrap_or_else` call.
        let mut inferred_api_key = Default::default();
        let api_key = options.api_key.as_ref().unwrap_or_else(|| {
            inferred_api_key = infer_api_key();
            &inferred_api_key
        });

        let url = options.url.as_deref().unwrap_or(OPENAI_EMBEDDINGS_URL).to_owned();

        let rest_embedder = RestEmbedder::new(
            RestEmbedderOptions {
                // An empty key (nothing configured nor in the environment) is
                // treated as "no key" rather than being sent as-is.
                api_key: (!api_key.is_empty()).then(|| api_key.clone()),
                distribution: None,
                dimensions: Some(options.dimensions()),
                url,
                request: options.request(),
                // Expected response shape: one embedding per input under `data`.
                response: serde_json::json!({
                    "data": [{
                        "embedding": super::rest::RESPONSE_PLACEHOLDER
                    },
                    super::rest::REPEAT_PLACEHOLDER
                    ]
                }),
                headers: Default::default(),
            },
            cache_cap,
            super::rest::ConfigurationSource::OpenAi,
        )?;

        // NOTE(review): `unwrap()` assumes loading the bundled cl100k_base
        // vocabulary cannot fail at runtime — confirm with tiktoken-rs docs.
        let tokenizer = tiktoken_rs::cl100k_base().unwrap();

        Ok(Self { options, rest_embedder, tokenizer })
    }

    /// Embeds `texts`, falling back to client-side tokenization and truncation
    /// when the API rejects the request with `BAD_REQUEST`.
    pub fn embed<S: AsRef<str> + serde::Serialize>(
        &self,
        texts: &[S],
        deadline: Option<Instant>,
    ) -> Result<Vec<Embedding>, EmbedError> {
        match self.rest_embedder.embed_ref(texts, deadline) {
            Ok(embeddings) => Ok(embeddings),
            // A 400 is assumed to mean the input was too long: retry on a
            // tokenized, truncated version of each text.
            Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
                tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
                self.try_embed_tokenized(texts, deadline)
            }
            Err(error) => Err(error),
        }
    }

    /// Fallback path: embeds each text individually, sending raw tokens
    /// truncated to the model's limit for texts that are too long.
    fn try_embed_tokenized<S: AsRef<str>>(
        &self,
        text: &[S],
        deadline: Option<Instant>,
    ) -> Result<Vec<Embedding>, EmbedError> {
        let mut all_embeddings = Vec::with_capacity(text.len());
        for text in text {
            let text = text.as_ref();
            let max_token_count = self.options.embedding_model.max_token();
            let encoded = self.tokenizer.encode_ordinary(text);
            let len = encoded.len();
            if len < max_token_count {
                // Short enough: send the raw text unchanged.
                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?);
                continue;
            }

            // Too long: keep only the first `max_token_count` tokens.
            let tokens = &encoded.as_slice()[0..max_token_count];

            let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?;

            all_embeddings.push(embedding);
        }
        Ok(all_embeddings)
    }

    /// Embeds chunks of texts at indexing time, in parallel on `threads`
    /// unless the request-parallelism budget is already saturated.
    pub fn embed_index(
        &self,
        text_chunks: Vec<Vec<String>>,
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
        // More in-flight operations than REQUEST_PARALLELISM would not help:
        // embed sequentially on the current thread instead.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect()
        } else {
            threads
                .install(move || {
                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect()
                })
                // `install` only errors when the pool panicked: surface as a bug.
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

    /// Embeds a flat list of texts at indexing time, chunked by the prompt
    /// count hint, in parallel on `threads` when the budget allows.
    pub(crate) fn embed_index_ref(
        &self,
        texts: &[&str],
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<f32>>, EmbedError> {
        // Same parallelism policy as `embed_index`.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                .chunks(self.prompt_count_in_chunk_hint())
                .map(move |chunk| self.embed(chunk, None))
                .collect();
            let embeddings = embeddings?;
            Ok(embeddings.into_iter().flatten().collect())
        } else {
            threads
                .install(move || {
                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                        .par_chunks(self.prompt_count_in_chunk_hint())
                        .map(move |chunk| self.embed(chunk, None))
                        .collect();

                    let embeddings = embeddings?;
                    Ok(embeddings.into_iter().flatten().collect())
                })
                // `install` only errors when the pool panicked: surface as a bug.
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

    /// Number of chunks to submit per batch, as advised by the REST embedder.
    pub fn chunk_count_hint(&self) -> usize {
        self.rest_embedder.chunk_count_hint()
    }

    /// Number of prompts to pack into one request, as advised by the REST embedder.
    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        self.rest_embedder.prompt_count_in_chunk_hint()
    }

    /// Dimension count of the produced embeddings (see [`EmbedderOptions::dimensions`]).
    pub fn dimensions(&self) -> usize {
        self.options.dimensions()
    }

    /// Distribution shift to apply to similarity scores, if any.
    pub fn distribution(&self) -> Option<DistributionShift> {
        self.options.distribution()
    }

    /// Embedding cache of the underlying REST embedder.
    pub(super) fn cache(&self) -> &EmbeddingCache {
        self.rest_embedder.cache()
    }
}
327
328impl fmt::Debug for Embedder {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 f.debug_struct("Embedder")
331 .field("tokenizer", &"CoreBPE")
332 .field("rest_embedder", &self.rest_embedder)
333 .field("options", &self.options)
334 .finish()
335 }
336}