// milli_core/vector/mod.rs

1use std::collections::HashMap;
2use std::num::NonZeroUsize;
3use std::sync::{Arc, Mutex};
4use std::time::Instant;
5
6use arroy::distances::{BinaryQuantizedCosine, Cosine};
7use arroy::ItemId;
8use deserr::{DeserializeError, Deserr};
9use heed::{RoTxn, RwTxn, Unspecified};
10use ordered_float::OrderedFloat;
11use roaring::RoaringBitmap;
12use serde::{Deserialize, Serialize};
13use utoipa::ToSchema;
14
15use self::error::{EmbedError, NewEmbedderError};
16use crate::progress::Progress;
17use crate::prompt::{Prompt, PromptData};
18use crate::ThreadPoolNoAbort;
19
20pub mod composite;
21pub mod error;
22pub mod hf;
23pub mod json_template;
24pub mod manual;
25pub mod openai;
26pub mod parsed_vectors;
27pub mod settings;
28
29pub mod ollama;
30pub mod rest;
31
32pub use self::error::Error;
33
/// A single embedding: one vector of `f32` coordinates.
pub type Embedding = Vec<f32>;

/// Maximum number of in-flight requests made against a remote embedder.
// NOTE(review): presumably consumed by the `rest`/`openai`/`ollama` modules — confirm there.
pub const REQUEST_PARALLELISM: usize = 40;
/// Maximum distance tolerated between the search-time and index-time embedders of a
/// composite embedder before they are considered incompatible.
// NOTE(review): presumably enforced in the `composite` module — confirm there.
pub const MAX_COMPOSITE_DISTANCE: f32 = 0.01;
38
/// Handle over the set of arroy indexes belonging to a single embedder,
/// all stored in one LMDB database with an erased distance type.
pub struct ArroyWrapper {
    /// Whether vectors are stored binary-quantized (`BinaryQuantizedCosine`)
    /// rather than plain `Cosine`.
    quantized: bool,
    /// Identifier of the embedder; selects the range of arroy index ids
    /// through `arroy_db_range_for_embedder`.
    embedder_index: u8,
    /// The underlying database; the distance type is re-mapped on each access
    /// depending on `quantized`.
    database: arroy::Database<Unspecified>,
}
44
impl ArroyWrapper {
    /// Creates a wrapper over the arroy `database` for the embedder at `embedder_index`.
    ///
    /// `quantized` selects whether vectors are read and written through the
    /// binary-quantized or the plain cosine distance.
    pub fn new(
        database: arroy::Database<Unspecified>,
        embedder_index: u8,
        quantized: bool,
    ) -> Self {
        Self { database, embedder_index, quantized }
    }

    /// The index of the embedder this wrapper operates on.
    pub fn embedder_index(&self) -> u8 {
        self.embedder_index
    }

    /// Iterates over readers for every non-empty arroy index of this embedder.
    ///
    /// Iteration stops at the first empty or never-built index, relying on the
    /// invariant that vectors are packed into the first indexes (re-established
    /// by `_del_item` after a deletion).
    fn readers<'a, D: arroy::Distance>(
        &'a self,
        rtxn: &'a RoTxn<'a>,
        db: arroy::Database<D>,
    ) -> impl Iterator<Item = Result<arroy::Reader<'a, D>, arroy::Error>> + 'a {
        arroy_db_range_for_embedder(self.embedder_index).map_while(move |index| {
            match arroy::Reader::open(rtxn, index, db) {
                Ok(reader) => match reader.is_empty(rtxn) {
                    Ok(false) => Some(Ok(reader)),
                    // An empty index means no later index holds vectors either.
                    Ok(true) => None,
                    Err(e) => Some(Err(e)),
                },
                // The index was never built: stop the iteration here.
                Err(arroy::Error::MissingMetadata(_)) => None,
                Err(e) => Some(Err(e)),
            }
        })
    }

    /// Dimensions of the vectors of this embedder, read from its first arroy index.
    ///
    /// Errors if that index cannot be opened (e.g. it was never built).
    pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> {
        // Assumes the embedder's index range is non-empty; the `unwrap` would panic otherwise.
        let first_id = arroy_db_range_for_embedder(self.embedder_index).next().unwrap();
        if self.quantized {
            Ok(arroy::Reader::open(rtxn, first_id, self.quantized_db())?.dimensions())
        } else {
            Ok(arroy::Reader::open(rtxn, first_id, self.angular_db())?.dimensions())
        }
    }

    /// Builds every arroy index of this embedder that needs it, optionally
    /// converting plain-cosine indexes to binary-quantized ones.
    ///
    /// * `arroy_memory`: memory budget handed to the arroy builder (unlimited when `None`).
    /// * `cancel`: polled by arroy to abort the build early.
    #[allow(clippy::too_many_arguments)]
    pub fn build_and_quantize<R: rand::Rng + rand::SeedableRng>(
        &mut self,
        wtxn: &mut RwTxn,
        progress: &Progress,
        rng: &mut R,
        dimension: usize,
        quantizing: bool,
        arroy_memory: Option<usize>,
        cancel: &(impl Fn() -> bool + Sync + Send),
    ) -> Result<(), arroy::Error> {
        for index in arroy_db_range_for_embedder(self.embedder_index) {
            if self.quantized {
                let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
                if writer.need_build(wtxn)? {
                    writer.builder(rng).build(wtxn)?
                } else if writer.is_empty(wtxn)? {
                    // Packed invariant: later indexes cannot hold vectors.
                    break;
                }
            } else {
                let writer = arroy::Writer::new(self.angular_db(), index, dimension);
                // If we are quantizing the databases, we can't know from meilisearch
                // if the db was empty but still contained the wrong metadata, thus we need
                // to quantize everything and can't stop early. Since this operation can
                // only happens once in the life of an embedder, it's not very performances
                // sensitive.
                // NOTE(review): `!self.quantized` is always true inside this `else`
                // branch, so that part of the check is redundant — confirm before simplifying.
                if quantizing && !self.quantized {
                    let writer = writer.prepare_changing_distance::<BinaryQuantizedCosine>(wtxn)?;
                    writer
                        .builder(rng)
                        .available_memory(arroy_memory.unwrap_or(usize::MAX))
                        .progress(|step| progress.update_progress_from_arroy(step))
                        .cancel(cancel)
                        .build(wtxn)?;
                } else if writer.need_build(wtxn)? {
                    writer
                        .builder(rng)
                        .available_memory(arroy_memory.unwrap_or(usize::MAX))
                        .progress(|step| progress.update_progress_from_arroy(step))
                        .cancel(cancel)
                        .build(wtxn)?;
                } else if writer.is_empty(wtxn)? {
                    break;
                }
            }
        }
        Ok(())
    }

    /// Overwrite all the embeddings associated with the index and item ID.
    /// /!\ It won't remove embeddings after the last passed embedding, which can leave stale embeddings.
    ///     You should call `del_items` on the `item_id` before calling this method.
    /// /!\ Cannot insert more than u8::MAX embeddings; after inserting u8::MAX embeddings, all the remaining ones will be silently ignored.
    pub fn add_items(
        &self,
        wtxn: &mut RwTxn,
        item_id: arroy::ItemId,
        embeddings: &Embeddings<f32>,
    ) -> Result<(), arroy::Error> {
        let dimension = embeddings.dimension();
        // `zip` stops at the shorter side: embeddings beyond the embedder's
        // index range are the ones silently ignored (see warning above).
        for (index, vector) in
            arroy_db_range_for_embedder(self.embedder_index).zip(embeddings.iter())
        {
            if self.quantized {
                arroy::Writer::new(self.quantized_db(), index, dimension)
                    .add_item(wtxn, item_id, vector)?
            } else {
                arroy::Writer::new(self.angular_db(), index, dimension)
                    .add_item(wtxn, item_id, vector)?
            }
        }
        Ok(())
    }

    /// Add one document int for this index where we can find an empty spot.
    pub fn add_item(
        &self,
        wtxn: &mut RwTxn,
        item_id: arroy::ItemId,
        vector: &[f32],
    ) -> Result<(), arroy::Error> {
        if self.quantized {
            self._add_item(wtxn, self.quantized_db(), item_id, vector)
        } else {
            self._add_item(wtxn, self.angular_db(), item_id, vector)
        }
    }

    /// Distance-generic body of [`Self::add_item`]: writes `vector` into the
    /// first index that does not already hold a vector for `item_id`.
    fn _add_item<D: arroy::Distance>(
        &self,
        wtxn: &mut RwTxn,
        db: arroy::Database<D>,
        item_id: arroy::ItemId,
        vector: &[f32],
    ) -> Result<(), arroy::Error> {
        let dimension = vector.len();

        for index in arroy_db_range_for_embedder(self.embedder_index) {
            let writer = arroy::Writer::new(db, index, dimension);
            if !writer.contains_item(wtxn, item_id)? {
                writer.add_item(wtxn, item_id, vector)?;
                // One free spot is enough; keep the remaining indexes untouched.
                break;
            }
        }
        Ok(())
    }

    /// Delete all embeddings from a specific `item_id`
    pub fn del_items(
        &self,
        wtxn: &mut RwTxn,
        dimension: usize,
        item_id: arroy::ItemId,
    ) -> Result<(), arroy::Error> {
        for index in arroy_db_range_for_embedder(self.embedder_index) {
            if self.quantized {
                let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
                // Packed invariant: the first index without this item marks the end.
                if !writer.del_item(wtxn, item_id)? {
                    break;
                }
            } else {
                let writer = arroy::Writer::new(self.angular_db(), index, dimension);
                if !writer.del_item(wtxn, item_id)? {
                    break;
                }
            }
        }

        Ok(())
    }

    /// Delete one item.
    ///
    /// Returns whether a vector equal to `vector` was found and deleted.
    pub fn del_item(
        &self,
        wtxn: &mut RwTxn,
        item_id: arroy::ItemId,
        vector: &[f32],
    ) -> Result<bool, arroy::Error> {
        if self.quantized {
            self._del_item(wtxn, self.quantized_db(), item_id, vector)
        } else {
            self._del_item(wtxn, self.angular_db(), item_id, vector)
        }
    }

    /// Distance-generic body of [`Self::del_item`]: removes the vector equal to
    /// `vector` for `item_id`, then re-packs so that this item's remaining
    /// vectors occupy the first indexes again.
    fn _del_item<D: arroy::Distance>(
        &self,
        wtxn: &mut RwTxn,
        db: arroy::Database<D>,
        item_id: arroy::ItemId,
        vector: &[f32],
    ) -> Result<bool, arroy::Error> {
        let dimension = vector.len();
        let mut deleted_index = None;

        for index in arroy_db_range_for_embedder(self.embedder_index) {
            let writer = arroy::Writer::new(db, index, dimension);
            let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
                // uses invariant: vectors are packed in the first writers.
                break;
            };
            if candidate == vector {
                writer.del_item(wtxn, item_id)?;
                deleted_index = Some(index);
            }
        }

        // 🥲 enforce invariant: vectors are packed in the first writers.
        if let Some(deleted_index) = deleted_index {
            let mut last_index_with_a_vector = None;
            // Look past the hole for the last index still holding a vector for this item.
            for index in
                arroy_db_range_for_embedder(self.embedder_index).skip(deleted_index as usize)
            {
                let writer = arroy::Writer::new(db, index, dimension);
                let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
                    break;
                };
                last_index_with_a_vector = Some((index, candidate));
            }
            if let Some((last_index, vector)) = last_index_with_a_vector {
                // Move the trailing vector into the hole left by the deletion.
                let writer = arroy::Writer::new(db, last_index, dimension);
                writer.del_item(wtxn, item_id)?;
                let writer = arroy::Writer::new(db, deleted_index, dimension);
                writer.add_item(wtxn, item_id, &vector)?;
            }
        }
        Ok(deleted_index.is_some())
    }

    /// Clears every arroy index of this embedder, stopping at the first empty one.
    pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
        for index in arroy_db_range_for_embedder(self.embedder_index) {
            if self.quantized {
                let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
                if writer.is_empty(wtxn)? {
                    break;
                }
                writer.clear(wtxn)?;
            } else {
                let writer = arroy::Writer::new(self.angular_db(), index, dimension);
                if writer.is_empty(wtxn)? {
                    break;
                }
                writer.clear(wtxn)?;
            }
        }
        Ok(())
    }

    /// Returns whether any arroy index of this embedder contains `item`.
    pub fn contains_item(
        &self,
        rtxn: &RoTxn,
        dimension: usize,
        item: arroy::ItemId,
    ) -> Result<bool, arroy::Error> {
        for index in arroy_db_range_for_embedder(self.embedder_index) {
            let contains = if self.quantized {
                let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
                if writer.is_empty(rtxn)? {
                    break;
                }
                writer.contains_item(rtxn, item)?
            } else {
                let writer = arroy::Writer::new(self.angular_db(), index, dimension);
                if writer.is_empty(rtxn)? {
                    break;
                }
                writer.contains_item(rtxn, item)?
            };
            if contains {
                return Ok(contains);
            }
        }
        Ok(false)
    }

    /// Returns the `limit` nearest neighbors of `item`, merged over every index
    /// of this embedder and sorted by increasing distance.
    pub fn nns_by_item(
        &self,
        rtxn: &RoTxn,
        item: ItemId,
        limit: usize,
        filter: Option<&RoaringBitmap>,
    ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
        if self.quantized {
            self._nns_by_item(rtxn, self.quantized_db(), item, limit, filter)
        } else {
            self._nns_by_item(rtxn, self.angular_db(), item, limit, filter)
        }
    }

    /// Distance-generic body of [`Self::nns_by_item`].
    fn _nns_by_item<D: arroy::Distance>(
        &self,
        rtxn: &RoTxn,
        db: arroy::Database<D>,
        item: ItemId,
        limit: usize,
        filter: Option<&RoaringBitmap>,
    ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
        let mut results = Vec::new();

        for reader in self.readers(rtxn, db) {
            let reader = reader?;
            let mut searcher = reader.nns(limit);
            if let Some(filter) = filter {
                searcher.candidates(filter);
            }

            if let Some(mut ret) = searcher.by_item(rtxn, item)? {
                results.append(&mut ret);
            } else {
                // `item` is absent from this index; packed invariant says it is
                // absent from all the following ones too.
                break;
            }
        }
        // Merge the per-index result lists into one globally distance-sorted list.
        results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
        Ok(results)
    }

    /// Returns the `limit` nearest neighbors of an arbitrary `vector`, merged
    /// over every index of this embedder and sorted by increasing distance.
    pub fn nns_by_vector(
        &self,
        rtxn: &RoTxn,
        vector: &[f32],
        limit: usize,
        filter: Option<&RoaringBitmap>,
    ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
        if self.quantized {
            self._nns_by_vector(rtxn, self.quantized_db(), vector, limit, filter)
        } else {
            self._nns_by_vector(rtxn, self.angular_db(), vector, limit, filter)
        }
    }

    /// Distance-generic body of [`Self::nns_by_vector`].
    fn _nns_by_vector<D: arroy::Distance>(
        &self,
        rtxn: &RoTxn,
        db: arroy::Database<D>,
        vector: &[f32],
        limit: usize,
        filter: Option<&RoaringBitmap>,
    ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
        let mut results = Vec::new();

        for reader in self.readers(rtxn, db) {
            let reader = reader?;
            let mut searcher = reader.nns(limit);
            if let Some(filter) = filter {
                searcher.candidates(filter);
            }

            results.append(&mut searcher.by_vector(rtxn, vector)?);
        }

        // Merge the per-index result lists into one globally distance-sorted list.
        results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));

        Ok(results)
    }

    /// Returns all the vectors stored for `item_id`, one per index that holds it.
    pub fn item_vectors(&self, rtxn: &RoTxn, item_id: u32) -> Result<Vec<Vec<f32>>, arroy::Error> {
        let mut vectors = Vec::new();

        if self.quantized {
            for reader in self.readers(rtxn, self.quantized_db()) {
                if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
                    vectors.push(vec);
                } else {
                    // Packed invariant: no later index can hold this item.
                    break;
                }
            }
        } else {
            for reader in self.readers(rtxn, self.angular_db()) {
                if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
                    vectors.push(vec);
                } else {
                    break;
                }
            }
        }
        Ok(vectors)
    }

    /// The database viewed with the plain `Cosine` distance.
    fn angular_db(&self) -> arroy::Database<Cosine> {
        self.database.remap_data_type()
    }

    /// The database viewed with the `BinaryQuantizedCosine` distance.
    fn quantized_db(&self) -> arroy::Database<BinaryQuantizedCosine> {
        self.database.remap_data_type()
    }

    /// Accumulates this embedder's item ids and embedding count into `stats`.
    pub fn aggregate_stats(
        &self,
        rtxn: &RoTxn,
        stats: &mut ArroyStats,
    ) -> Result<(), arroy::Error> {
        if self.quantized {
            for reader in self.readers(rtxn, self.quantized_db()) {
                let reader = reader?;
                let documents = reader.item_ids();
                // NOTE(review): `readers` already skips empty indexes, so this
                // break looks unreachable — confirm before removing.
                if documents.is_empty() {
                    break;
                }
                stats.documents |= documents;
                stats.number_of_embeddings += documents.len();
            }
        } else {
            for reader in self.readers(rtxn, self.angular_db()) {
                let reader = reader?;
                let documents = reader.item_ids();
                if documents.is_empty() {
                    break;
                }
                stats.documents |= documents;
                stats.number_of_embeddings += documents.len();
            }
        }

        Ok(())
    }
}
461
/// Statistics aggregated over the arroy indexes of one or more embedders
/// (filled by `ArroyWrapper::aggregate_stats`).
#[derive(Debug, Default, Clone)]
pub struct ArroyStats {
    /// Total number of embeddings across all visited indexes.
    pub number_of_embeddings: u64,
    /// Ids of the items that have at least one embedding.
    pub documents: RoaringBitmap,
}
/// One or multiple embeddings stored consecutively in a flat vector.
pub struct Embeddings<F> {
    data: Vec<F>,
    dimension: usize,
}

impl<F> Embeddings<F> {
    /// Declares an empty vector of embeddings of the specified dimensions.
    pub fn new(dimension: usize) -> Self {
        Self { data: Vec::new(), dimension }
    }

    /// Declares a vector of embeddings containing a single element.
    ///
    /// The dimension is inferred from the length of the passed embedding.
    pub fn from_single_embedding(embedding: Vec<F>) -> Self {
        let dimension = embedding.len();
        Self { data: embedding, dimension }
    }

    /// Declares a vector of embeddings from its components.
    ///
    /// `data.len()` must be a multiple of `dimension`, otherwise the data is
    /// handed back through the `Err` variant.
    pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
        match data.len() % dimension {
            0 => Ok(Self { data, dimension }),
            _ => Err(data),
        }
    }

    /// Returns the number of embeddings in this vector of embeddings.
    pub fn embedding_count(&self) -> usize {
        self.data.len() / self.dimension
    }

    /// Dimension of a single embedding.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Deconstructs self into the inner flat vector.
    pub fn into_inner(self) -> Vec<F> {
        self.data
    }

    /// A reference to the inner flat vector.
    pub fn as_inner(&self) -> &[F] {
        &self.data
    }

    /// Iterates over the embeddings contained in the flat vector, one
    /// `dimension`-sized slice at a time.
    pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
        self.data.chunks_exact(self.dimension)
    }

    /// Push an embedding at the end of the embeddings.
    ///
    /// If `embedding.len() != self.dimension`, the embedding is handed back
    /// through the `Err` variant and nothing is stored.
    pub fn push(&mut self, embedding: Vec<F>) -> Result<(), Vec<F>> {
        if embedding.len() == self.dimension {
            self.data.extend(embedding);
            Ok(())
        } else {
            Err(embedding)
        }
    }

    /// Append a flat vector of embeddings at the end of the embeddings.
    ///
    /// If `embeddings.len() % self.dimension != 0`, the data is handed back
    /// through the `Err` variant and nothing is stored.
    pub fn append(&mut self, embeddings: Vec<F>) -> Result<(), Vec<F>> {
        if embeddings.len() % self.dimension == 0 {
            self.data.extend(embeddings);
            Ok(())
        } else {
            Err(embeddings)
        }
    }
}
542
/// An embedder can be used to transform text into embeddings.
///
/// Each variant wraps the embedder implementation of the module of the same name.
#[derive(Debug)]
pub enum Embedder {
    /// An embedder based on running local models, fetched from the Hugging Face Hub.
    HuggingFace(hf::Embedder),
    /// An embedder based on making embedding queries against the OpenAI API.
    OpenAi(openai::Embedder),
    /// An embedder based on the user providing the embeddings in the documents and queries.
    UserProvided(manual::Embedder),
    /// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
    Ollama(ollama::Embedder),
    /// An embedder based on making embedding queries against a generic JSON/REST embedding server.
    Rest(rest::Embedder),
    /// An embedder composed of an embedder at search time and an embedder at indexing time.
    Composite(composite::Embedder),
}
559
/// A bounded, thread-safe LRU cache from embedded text to its embedding.
#[derive(Debug)]
struct EmbeddingCache {
    // `None` when caching is disabled, i.e. the cache was built with a capacity of 0.
    data: Option<Mutex<lru::LruCache<String, Embedding>>>,
}
564
565impl EmbeddingCache {
566    const MAX_TEXT_LEN: usize = 2000;
567
568    pub fn new(cap: usize) -> Self {
569        let data = NonZeroUsize::new(cap).map(lru::LruCache::new).map(Mutex::new);
570        Self { data }
571    }
572
573    /// Get the embedding corresponding to `text`, if any is present in the cache.
574    pub fn get(&self, text: &str) -> Option<Embedding> {
575        let data = self.data.as_ref()?;
576        if text.len() > Self::MAX_TEXT_LEN {
577            return None;
578        }
579        let mut cache = data.lock().unwrap();
580
581        cache.get(text).cloned()
582    }
583
584    /// Puts a new embedding for the specified `text`
585    pub fn put(&self, text: String, embedding: Embedding) {
586        let Some(data) = self.data.as_ref() else {
587            return;
588        };
589        if text.len() > Self::MAX_TEXT_LEN {
590            return;
591        }
592        tracing::trace!(text, "embedding added to cache");
593
594        let mut cache = data.lock().unwrap();
595
596        cache.put(text, embedding);
597    }
598}
599
/// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig {
    /// Options of the embedder, specific to each kind of embedder
    pub embedder_options: EmbedderOptions,
    /// Document template used to render a document into embeddable text
    pub prompt: PromptData,
    /// If this embedder is binary quantized; `None` means unset and is read as `false`
    /// (see [`EmbeddingConfig::quantized`])
    pub quantized: Option<bool>,
    // TODO: add metrics and anything needed
}
611
612impl EmbeddingConfig {
613    pub fn quantized(&self) -> bool {
614        self.quantized.unwrap_or_default()
615    }
616}
617
/// Map of embedder configurations.
///
/// Each configuration is mapped to a name.
// NOTE(review): the `bool` of each entry presumably indicates whether the embedder
// is binary quantized — confirm against the callers of `EmbeddingConfigs::new`.
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>);
623
impl EmbeddingConfigs {
    /// Create the map from its internal components.
    pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>) -> Self {
        Self(data)
    }

    /// Whether an embedder with that `name` is part of the map.
    pub fn contains(&self, name: &str) -> bool {
        self.0.contains_key(name)
    }

    /// Get an embedder configuration and template from its name.
    pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
        self.0.get(name).cloned()
    }

    /// A shared reference to the inner map.
    pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
        &self.0
    }

    /// Deconstructs self into the inner map.
    pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
        self.0
    }
}
647
/// Iterate over the `(name, (embedder, prompt, bool))` entries of the map,
/// in the arbitrary order of the underlying `HashMap`.
impl IntoIterator for EmbeddingConfigs {
    type Item = (String, (Arc<Embedder>, Arc<Prompt>, bool));

    type IntoIter =
        std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>, bool)>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}
658
/// Options of an embedder, specific to each kind of embedder.
///
/// Each variant carries the options type of the embedder module of the same name.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
    HuggingFace(hf::EmbedderOptions),
    OpenAi(openai::EmbedderOptions),
    Ollama(ollama::EmbedderOptions),
    UserProvided(manual::EmbedderOptions),
    Rest(rest::EmbedderOptions),
    Composite(composite::EmbedderOptions),
}
669
impl Default for EmbedderOptions {
    /// Defaults to a local Hugging Face embedder with its own default options.
    fn default() -> Self {
        Self::HuggingFace(Default::default())
    }
}
675
impl Embedder {
    /// Spawns a new embedder built from its options.
    ///
    /// `cache_cap` is the capacity of the embedding cache passed to each
    /// embedder kind; the user-provided embedder takes none.
    pub fn new(
        options: EmbedderOptions,
        cache_cap: usize,
    ) -> std::result::Result<Self, NewEmbedderError> {
        Ok(match options {
            EmbedderOptions::HuggingFace(options) => {
                Self::HuggingFace(hf::Embedder::new(options, cache_cap)?)
            }
            EmbedderOptions::OpenAi(options) => {
                Self::OpenAi(openai::Embedder::new(options, cache_cap)?)
            }
            EmbedderOptions::Ollama(options) => {
                Self::Ollama(ollama::Embedder::new(options, cache_cap)?)
            }
            EmbedderOptions::UserProvided(options) => {
                Self::UserProvided(manual::Embedder::new(options))
            }
            EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(
                options,
                cache_cap,
                rest::ConfigurationSource::User,
            )?),
            EmbedderOptions::Composite(options) => {
                Self::Composite(composite::Embedder::new(options, cache_cap)?)
            }
        })
    }

    /// Embed one `text` in a search context.
    ///
    /// Checks the embedding cache first, computes the embedding on a miss and
    /// stores it back in the cache. `deadline` is forwarded to remote embedders.
    #[tracing::instrument(level = "debug", skip_all, target = "search")]
    pub fn embed_search(
        &self,
        text: &str,
        deadline: Option<Instant>,
    ) -> std::result::Result<Embedding, EmbedError> {
        if let Some(cache) = self.cache() {
            if let Some(embedding) = cache.get(text) {
                tracing::trace!(text, "embedding found in cache");
                return Ok(embedding);
            }
        }
        let embedding = match self {
            Embedder::HuggingFace(embedder) => embedder.embed_one(text),
            Embedder::OpenAi(embedder) => {
                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
            }
            Embedder::Ollama(embedder) => {
                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
            }
            Embedder::UserProvided(embedder) => embedder.embed_one(text),
            Embedder::Rest(embedder) => embedder
                .embed_ref(&[text], deadline)?
                .pop()
                .ok_or_else(EmbedError::missing_embedding),
            // The composite embedder uses its search-time half here.
            Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline),
        }?;

        if let Some(cache) = self.cache() {
            cache.put(text.to_owned(), embedding.clone());
        }

        Ok(embedding)
    }

    /// Embed multiple chunks of texts.
    ///
    /// Each chunk is composed of one or multiple texts.
    pub fn embed_index(
        &self,
        text_chunks: Vec<Vec<String>>,
        threads: &ThreadPoolNoAbort,
    ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
        match self {
            Embedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
            Embedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads),
            Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads),
            Embedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
            Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads),
            // The composite embedder uses its indexing-time half here.
            Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads),
        }
    }

    /// Non-owning variant of [`Self::embed_index`].
    pub fn embed_index_ref(
        &self,
        texts: &[&str],
        threads: &ThreadPoolNoAbort,
    ) -> std::result::Result<Vec<Embedding>, EmbedError> {
        match self {
            Embedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
            Embedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads),
            Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads),
            Embedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
            Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads),
            Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads),
        }
    }

    /// Indicates the preferred number of chunks to pass to [`Self::embed_index`]
    pub fn chunk_count_hint(&self) -> usize {
        match self {
            Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
            Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
            Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
            Embedder::UserProvided(_) => 100,
            Embedder::Rest(embedder) => embedder.chunk_count_hint(),
            Embedder::Composite(embedder) => embedder.index.chunk_count_hint(),
        }
    }

    /// Indicates the preferred number of texts in a single chunk passed to [`Self::embed_index`]
    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        match self {
            Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
            Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
            Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
            Embedder::UserProvided(_) => 1,
            Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
            Embedder::Composite(embedder) => embedder.index.prompt_count_in_chunk_hint(),
        }
    }

    /// Indicates the dimensions of a single embedding produced by the embedder.
    pub fn dimensions(&self) -> usize {
        match self {
            Embedder::HuggingFace(embedder) => embedder.dimensions(),
            Embedder::OpenAi(embedder) => embedder.dimensions(),
            Embedder::Ollama(embedder) => embedder.dimensions(),
            Embedder::UserProvided(embedder) => embedder.dimensions(),
            Embedder::Rest(embedder) => embedder.dimensions(),
            Embedder::Composite(embedder) => embedder.dimensions(),
        }
    }

    /// An optional distribution used to apply an affine transformation to the similarity score of a document.
    pub fn distribution(&self) -> Option<DistributionShift> {
        match self {
            Embedder::HuggingFace(embedder) => embedder.distribution(),
            Embedder::OpenAi(embedder) => embedder.distribution(),
            Embedder::Ollama(embedder) => embedder.distribution(),
            Embedder::UserProvided(embedder) => embedder.distribution(),
            Embedder::Rest(embedder) => embedder.distribution(),
            Embedder::Composite(embedder) => embedder.distribution(),
        }
    }

    /// Whether this embedder renders documents through a document template
    /// before embedding them (`false` only when the user provides the embeddings).
    pub fn uses_document_template(&self) -> bool {
        match self {
            Embedder::HuggingFace(_)
            | Embedder::OpenAi(_)
            | Embedder::Ollama(_)
            | Embedder::Rest(_) => true,
            Embedder::UserProvided(_) => false,
            Embedder::Composite(embedder) => embedder.index.uses_document_template(),
        }
    }

    /// The embedding cache of this embedder, if any.
    ///
    /// `None` for the user-provided embedder; the composite embedder delegates
    /// to its search-time half.
    fn cache(&self) -> Option<&EmbeddingCache> {
        match self {
            Embedder::HuggingFace(embedder) => Some(embedder.cache()),
            Embedder::OpenAi(embedder) => Some(embedder.cache()),
            Embedder::UserProvided(_) => None,
            Embedder::Ollama(embedder) => Some(embedder.cache()),
            Embedder::Rest(embedder) => Some(embedder.cache()),
            Embedder::Composite(embedder) => embedder.search.cache(),
        }
    }
}
847
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.
/// This allows to correct effects where results are too "packed" around a certain value.
///
/// (De)serialization goes through `DistributionShiftSerializable`, which also
/// enforces the valid ranges of both fields (see the `Deserr` impl below).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ToSchema)]
#[serde(from = "DistributionShiftSerializable")]
#[serde(into = "DistributionShiftSerializable")]
pub struct DistributionShift {
    /// Value where the results are "packed".
    ///
    /// Similarity scores are translated so that they are packed around 0.5 instead
    #[schema(value_type = f32)]
    pub current_mean: OrderedFloat<f32>,

    /// standard deviation of a similarity score.
    ///
    /// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
    #[schema(value_type = f32)]
    pub current_sigma: OrderedFloat<f32>,
}
868
869impl<E> Deserr<E> for DistributionShift
870where
871    E: DeserializeError,
872{
873    fn deserialize_from_value<V: deserr::IntoValue>(
874        value: deserr::Value<V>,
875        location: deserr::ValuePointerRef<'_>,
876    ) -> Result<Self, E> {
877        let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
878        if value.mean < 0. || value.mean > 1. {
879            return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
880                None,
881                deserr::ErrorKind::Unexpected {
882                    msg: format!(
883                        "the distribution mean must be in the range [0, 1], got {}",
884                        value.mean
885                    ),
886                },
887                location,
888            )));
889        }
890        if value.sigma <= 0. || value.sigma > 1. {
891            return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
892                None,
893                deserr::ErrorKind::Unexpected {
894                    msg: format!(
895                        "the distribution sigma must be in the range ]0, 1], got {}",
896                        value.sigma
897                    ),
898                },
899                location,
900            )));
901        }
902
903        Ok(value.into())
904    }
905}
906
/// Flat, wire-level representation of [`DistributionShift`] with plain `mean`/`sigma`
/// fields (the in-memory type wraps its fields in `OrderedFloat` under longer names).
///
/// Unknown fields are rejected by both serde and deserr.
#[derive(Serialize, Deserialize, Deserr)]
#[serde(deny_unknown_fields)]
#[deserr(deny_unknown_fields)]
struct DistributionShiftSerializable {
    mean: f32,
    sigma: f32,
}
914
915impl From<DistributionShift> for DistributionShiftSerializable {
916    fn from(
917        DistributionShift {
918            current_mean: OrderedFloat(current_mean),
919            current_sigma: OrderedFloat(current_sigma),
920        }: DistributionShift,
921    ) -> Self {
922        Self { mean: current_mean, sigma: current_sigma }
923    }
924}
925
926impl From<DistributionShiftSerializable> for DistributionShift {
927    fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
928        Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
929    }
930}
931
932impl DistributionShift {
933    /// `None` if sigma <= 0.
934    pub fn new(mean: f32, sigma: f32) -> Option<Self> {
935        if sigma <= 0.0 {
936            None
937        } else {
938            Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
939        }
940    }
941
942    pub fn shift(&self, score: f32) -> f32 {
943        let current_mean = self.current_mean.0;
944        let current_sigma = self.current_sigma.0;
945        // <https://math.stackexchange.com/a/2894689>
946        // We're somewhat abusively mapping the distribution of distances to a gaussian.
947        // The parameters we're given is the mean and sigma of the native result distribution.
948        // We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
949
950        let target_mean = 0.5;
951        let target_sigma = 0.4;
952
953        // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
954        let factor = target_sigma / current_sigma;
955        // a*mu1 + b = mu2 => b = mu2 - a*mu1
956        let offset = target_mean - (factor * current_mean);
957
958        let mut score = factor * score + offset;
959
960        // clamp the final score in the ]0, 1] interval.
961        if score <= 0.0 {
962            score = f32::EPSILON;
963        }
964        if score > 1.0 {
965            score = 1.0;
966        }
967
968        score
969    }
970}
971
/// Whether CUDA is supported in this version of Meilisearch.
///
/// Resolved at compile time from the `cuda` cargo feature flag.
pub const fn is_cuda_enabled() -> bool {
    cfg!(feature = "cuda")
}
976
/// Returns the 256 consecutive arroy database indexes reserved for the given embedder.
///
/// The embedder id forms the high byte of each `u16` index and the low byte
/// enumerates `0..=255`.
pub fn arroy_db_range_for_embedder(embedder_id: u8) -> impl Iterator<Item = u16> {
    let base = u16::from(embedder_id) << 8;

    (0u16..=u16::from(u8::MAX)).map(move |low| base | low)
}