1use std::collections::HashMap;
2use std::num::NonZeroUsize;
3use std::sync::{Arc, Mutex};
4use std::time::Instant;
5
6use arroy::distances::{BinaryQuantizedCosine, Cosine};
7use arroy::ItemId;
8use deserr::{DeserializeError, Deserr};
9use heed::{RoTxn, RwTxn, Unspecified};
10use ordered_float::OrderedFloat;
11use roaring::RoaringBitmap;
12use serde::{Deserialize, Serialize};
13use utoipa::ToSchema;
14
15use self::error::{EmbedError, NewEmbedderError};
16use crate::progress::Progress;
17use crate::prompt::{Prompt, PromptData};
18use crate::ThreadPoolNoAbort;
19
20pub mod composite;
21pub mod error;
22pub mod hf;
23pub mod json_template;
24pub mod manual;
25pub mod openai;
26pub mod parsed_vectors;
27pub mod settings;
28
29pub mod ollama;
30pub mod rest;
31
32pub use self::error::Error;
33
/// A single embedding, stored as a vector of `f32` components.
pub type Embedding = Vec<f32>;

/// Parallelism constant for embedding requests; presumably the number of concurrent requests
/// sent to remote embedders — TODO confirm against its users.
pub const REQUEST_PARALLELISM: usize = 40;
/// Distance threshold used for composite embedders; presumably the maximum tolerated distance
/// between the search and indexing embeddings — TODO confirm against its users.
pub const MAX_COMPOSITE_DISTANCE: f32 = 1e-2;
38
39pub struct ArroyWrapper {
40 quantized: bool,
41 embedder_index: u8,
42 database: arroy::Database<Unspecified>,
43}
44
45impl ArroyWrapper {
46 pub fn new(
47 database: arroy::Database<Unspecified>,
48 embedder_index: u8,
49 quantized: bool,
50 ) -> Self {
51 Self { database, embedder_index, quantized }
52 }
53
54 pub fn embedder_index(&self) -> u8 {
55 self.embedder_index
56 }
57
58 fn readers<'a, D: arroy::Distance>(
59 &'a self,
60 rtxn: &'a RoTxn<'a>,
61 db: arroy::Database<D>,
62 ) -> impl Iterator<Item = Result<arroy::Reader<'a, D>, arroy::Error>> + 'a {
63 arroy_db_range_for_embedder(self.embedder_index).map_while(move |index| {
64 match arroy::Reader::open(rtxn, index, db) {
65 Ok(reader) => match reader.is_empty(rtxn) {
66 Ok(false) => Some(Ok(reader)),
67 Ok(true) => None,
68 Err(e) => Some(Err(e)),
69 },
70 Err(arroy::Error::MissingMetadata(_)) => None,
71 Err(e) => Some(Err(e)),
72 }
73 })
74 }
75
76 pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> {
77 let first_id = arroy_db_range_for_embedder(self.embedder_index).next().unwrap();
78 if self.quantized {
79 Ok(arroy::Reader::open(rtxn, first_id, self.quantized_db())?.dimensions())
80 } else {
81 Ok(arroy::Reader::open(rtxn, first_id, self.angular_db())?.dimensions())
82 }
83 }
84
85 #[allow(clippy::too_many_arguments)]
86 pub fn build_and_quantize<R: rand::Rng + rand::SeedableRng>(
87 &mut self,
88 wtxn: &mut RwTxn,
89 progress: &Progress,
90 rng: &mut R,
91 dimension: usize,
92 quantizing: bool,
93 arroy_memory: Option<usize>,
94 cancel: &(impl Fn() -> bool + Sync + Send),
95 ) -> Result<(), arroy::Error> {
96 for index in arroy_db_range_for_embedder(self.embedder_index) {
97 if self.quantized {
98 let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
99 if writer.need_build(wtxn)? {
100 writer.builder(rng).build(wtxn)?
101 } else if writer.is_empty(wtxn)? {
102 break;
103 }
104 } else {
105 let writer = arroy::Writer::new(self.angular_db(), index, dimension);
106 if quantizing && !self.quantized {
112 let writer = writer.prepare_changing_distance::<BinaryQuantizedCosine>(wtxn)?;
113 writer
114 .builder(rng)
115 .available_memory(arroy_memory.unwrap_or(usize::MAX))
116 .progress(|step| progress.update_progress_from_arroy(step))
117 .cancel(cancel)
118 .build(wtxn)?;
119 } else if writer.need_build(wtxn)? {
120 writer
121 .builder(rng)
122 .available_memory(arroy_memory.unwrap_or(usize::MAX))
123 .progress(|step| progress.update_progress_from_arroy(step))
124 .cancel(cancel)
125 .build(wtxn)?;
126 } else if writer.is_empty(wtxn)? {
127 break;
128 }
129 }
130 }
131 Ok(())
132 }
133
134 pub fn add_items(
139 &self,
140 wtxn: &mut RwTxn,
141 item_id: arroy::ItemId,
142 embeddings: &Embeddings<f32>,
143 ) -> Result<(), arroy::Error> {
144 let dimension = embeddings.dimension();
145 for (index, vector) in
146 arroy_db_range_for_embedder(self.embedder_index).zip(embeddings.iter())
147 {
148 if self.quantized {
149 arroy::Writer::new(self.quantized_db(), index, dimension)
150 .add_item(wtxn, item_id, vector)?
151 } else {
152 arroy::Writer::new(self.angular_db(), index, dimension)
153 .add_item(wtxn, item_id, vector)?
154 }
155 }
156 Ok(())
157 }
158
159 pub fn add_item(
161 &self,
162 wtxn: &mut RwTxn,
163 item_id: arroy::ItemId,
164 vector: &[f32],
165 ) -> Result<(), arroy::Error> {
166 if self.quantized {
167 self._add_item(wtxn, self.quantized_db(), item_id, vector)
168 } else {
169 self._add_item(wtxn, self.angular_db(), item_id, vector)
170 }
171 }
172
173 fn _add_item<D: arroy::Distance>(
174 &self,
175 wtxn: &mut RwTxn,
176 db: arroy::Database<D>,
177 item_id: arroy::ItemId,
178 vector: &[f32],
179 ) -> Result<(), arroy::Error> {
180 let dimension = vector.len();
181
182 for index in arroy_db_range_for_embedder(self.embedder_index) {
183 let writer = arroy::Writer::new(db, index, dimension);
184 if !writer.contains_item(wtxn, item_id)? {
185 writer.add_item(wtxn, item_id, vector)?;
186 break;
187 }
188 }
189 Ok(())
190 }
191
192 pub fn del_items(
194 &self,
195 wtxn: &mut RwTxn,
196 dimension: usize,
197 item_id: arroy::ItemId,
198 ) -> Result<(), arroy::Error> {
199 for index in arroy_db_range_for_embedder(self.embedder_index) {
200 if self.quantized {
201 let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
202 if !writer.del_item(wtxn, item_id)? {
203 break;
204 }
205 } else {
206 let writer = arroy::Writer::new(self.angular_db(), index, dimension);
207 if !writer.del_item(wtxn, item_id)? {
208 break;
209 }
210 }
211 }
212
213 Ok(())
214 }
215
216 pub fn del_item(
218 &self,
219 wtxn: &mut RwTxn,
220 item_id: arroy::ItemId,
221 vector: &[f32],
222 ) -> Result<bool, arroy::Error> {
223 if self.quantized {
224 self._del_item(wtxn, self.quantized_db(), item_id, vector)
225 } else {
226 self._del_item(wtxn, self.angular_db(), item_id, vector)
227 }
228 }
229
230 fn _del_item<D: arroy::Distance>(
231 &self,
232 wtxn: &mut RwTxn,
233 db: arroy::Database<D>,
234 item_id: arroy::ItemId,
235 vector: &[f32],
236 ) -> Result<bool, arroy::Error> {
237 let dimension = vector.len();
238 let mut deleted_index = None;
239
240 for index in arroy_db_range_for_embedder(self.embedder_index) {
241 let writer = arroy::Writer::new(db, index, dimension);
242 let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
243 break;
245 };
246 if candidate == vector {
247 writer.del_item(wtxn, item_id)?;
248 deleted_index = Some(index);
249 }
250 }
251
252 if let Some(deleted_index) = deleted_index {
254 let mut last_index_with_a_vector = None;
255 for index in
256 arroy_db_range_for_embedder(self.embedder_index).skip(deleted_index as usize)
257 {
258 let writer = arroy::Writer::new(db, index, dimension);
259 let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
260 break;
261 };
262 last_index_with_a_vector = Some((index, candidate));
263 }
264 if let Some((last_index, vector)) = last_index_with_a_vector {
265 let writer = arroy::Writer::new(db, last_index, dimension);
266 writer.del_item(wtxn, item_id)?;
267 let writer = arroy::Writer::new(db, deleted_index, dimension);
268 writer.add_item(wtxn, item_id, &vector)?;
269 }
270 }
271 Ok(deleted_index.is_some())
272 }
273
274 pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
275 for index in arroy_db_range_for_embedder(self.embedder_index) {
276 if self.quantized {
277 let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
278 if writer.is_empty(wtxn)? {
279 break;
280 }
281 writer.clear(wtxn)?;
282 } else {
283 let writer = arroy::Writer::new(self.angular_db(), index, dimension);
284 if writer.is_empty(wtxn)? {
285 break;
286 }
287 writer.clear(wtxn)?;
288 }
289 }
290 Ok(())
291 }
292
293 pub fn contains_item(
294 &self,
295 rtxn: &RoTxn,
296 dimension: usize,
297 item: arroy::ItemId,
298 ) -> Result<bool, arroy::Error> {
299 for index in arroy_db_range_for_embedder(self.embedder_index) {
300 let contains = if self.quantized {
301 let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
302 if writer.is_empty(rtxn)? {
303 break;
304 }
305 writer.contains_item(rtxn, item)?
306 } else {
307 let writer = arroy::Writer::new(self.angular_db(), index, dimension);
308 if writer.is_empty(rtxn)? {
309 break;
310 }
311 writer.contains_item(rtxn, item)?
312 };
313 if contains {
314 return Ok(contains);
315 }
316 }
317 Ok(false)
318 }
319
320 pub fn nns_by_item(
321 &self,
322 rtxn: &RoTxn,
323 item: ItemId,
324 limit: usize,
325 filter: Option<&RoaringBitmap>,
326 ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
327 if self.quantized {
328 self._nns_by_item(rtxn, self.quantized_db(), item, limit, filter)
329 } else {
330 self._nns_by_item(rtxn, self.angular_db(), item, limit, filter)
331 }
332 }
333
334 fn _nns_by_item<D: arroy::Distance>(
335 &self,
336 rtxn: &RoTxn,
337 db: arroy::Database<D>,
338 item: ItemId,
339 limit: usize,
340 filter: Option<&RoaringBitmap>,
341 ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
342 let mut results = Vec::new();
343
344 for reader in self.readers(rtxn, db) {
345 let reader = reader?;
346 let mut searcher = reader.nns(limit);
347 if let Some(filter) = filter {
348 searcher.candidates(filter);
349 }
350
351 if let Some(mut ret) = searcher.by_item(rtxn, item)? {
352 results.append(&mut ret);
353 } else {
354 break;
355 }
356 }
357 results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
358 Ok(results)
359 }
360
361 pub fn nns_by_vector(
362 &self,
363 rtxn: &RoTxn,
364 vector: &[f32],
365 limit: usize,
366 filter: Option<&RoaringBitmap>,
367 ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
368 if self.quantized {
369 self._nns_by_vector(rtxn, self.quantized_db(), vector, limit, filter)
370 } else {
371 self._nns_by_vector(rtxn, self.angular_db(), vector, limit, filter)
372 }
373 }
374
375 fn _nns_by_vector<D: arroy::Distance>(
376 &self,
377 rtxn: &RoTxn,
378 db: arroy::Database<D>,
379 vector: &[f32],
380 limit: usize,
381 filter: Option<&RoaringBitmap>,
382 ) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
383 let mut results = Vec::new();
384
385 for reader in self.readers(rtxn, db) {
386 let reader = reader?;
387 let mut searcher = reader.nns(limit);
388 if let Some(filter) = filter {
389 searcher.candidates(filter);
390 }
391
392 results.append(&mut searcher.by_vector(rtxn, vector)?);
393 }
394
395 results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
396
397 Ok(results)
398 }
399
400 pub fn item_vectors(&self, rtxn: &RoTxn, item_id: u32) -> Result<Vec<Vec<f32>>, arroy::Error> {
401 let mut vectors = Vec::new();
402
403 if self.quantized {
404 for reader in self.readers(rtxn, self.quantized_db()) {
405 if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
406 vectors.push(vec);
407 } else {
408 break;
409 }
410 }
411 } else {
412 for reader in self.readers(rtxn, self.angular_db()) {
413 if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
414 vectors.push(vec);
415 } else {
416 break;
417 }
418 }
419 }
420 Ok(vectors)
421 }
422
423 fn angular_db(&self) -> arroy::Database<Cosine> {
424 self.database.remap_data_type()
425 }
426
427 fn quantized_db(&self) -> arroy::Database<BinaryQuantizedCosine> {
428 self.database.remap_data_type()
429 }
430
431 pub fn aggregate_stats(
432 &self,
433 rtxn: &RoTxn,
434 stats: &mut ArroyStats,
435 ) -> Result<(), arroy::Error> {
436 if self.quantized {
437 for reader in self.readers(rtxn, self.quantized_db()) {
438 let reader = reader?;
439 let documents = reader.item_ids();
440 if documents.is_empty() {
441 break;
442 }
443 stats.documents |= documents;
444 stats.number_of_embeddings += documents.len();
445 }
446 } else {
447 for reader in self.readers(rtxn, self.angular_db()) {
448 let reader = reader?;
449 let documents = reader.item_ids();
450 if documents.is_empty() {
451 break;
452 }
453 stats.documents |= documents;
454 stats.number_of_embeddings += documents.len();
455 }
456 }
457
458 Ok(())
459 }
460}
461
462#[derive(Debug, Default, Clone)]
463pub struct ArroyStats {
464 pub number_of_embeddings: u64,
465 pub documents: RoaringBitmap,
466}
/// A flat buffer of zero or more embeddings that all have the same dimension, stored back
/// to back.
pub struct Embeddings<F> {
    /// Number of components in each embedding.
    dimension: usize,
    /// Concatenated components of all embeddings.
    data: Vec<F>,
}
472
473impl<F> Embeddings<F> {
474 pub fn new(dimension: usize) -> Self {
476 Self { data: Default::default(), dimension }
477 }
478
479 pub fn from_single_embedding(embedding: Vec<F>) -> Self {
483 Self { dimension: embedding.len(), data: embedding }
484 }
485
486 pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
490 let mut this = Self::new(dimension);
491 this.append(data)?;
492 Ok(this)
493 }
494
495 pub fn embedding_count(&self) -> usize {
497 self.data.len() / self.dimension
498 }
499
500 pub fn dimension(&self) -> usize {
502 self.dimension
503 }
504
505 pub fn into_inner(self) -> Vec<F> {
507 self.data
508 }
509
510 pub fn as_inner(&self) -> &[F] {
512 &self.data
513 }
514
515 pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
517 self.data.as_slice().chunks_exact(self.dimension)
518 }
519
520 pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
524 if embedding.len() != self.dimension {
525 return Err(embedding);
526 }
527 self.data.append(&mut embedding);
528 Ok(())
529 }
530
531 pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
535 if embeddings.len() % self.dimension != 0 {
536 return Err(embeddings);
537 }
538 self.data.append(&mut embeddings);
539 Ok(())
540 }
541}
542
543#[derive(Debug)]
545pub enum Embedder {
546 HuggingFace(hf::Embedder),
548 OpenAi(openai::Embedder),
550 UserProvided(manual::Embedder),
552 Ollama(ollama::Embedder),
554 Rest(rest::Embedder),
556 Composite(composite::Embedder),
558}
559
560#[derive(Debug)]
561struct EmbeddingCache {
562 data: Option<Mutex<lru::LruCache<String, Embedding>>>,
563}
564
565impl EmbeddingCache {
566 const MAX_TEXT_LEN: usize = 2000;
567
568 pub fn new(cap: usize) -> Self {
569 let data = NonZeroUsize::new(cap).map(lru::LruCache::new).map(Mutex::new);
570 Self { data }
571 }
572
573 pub fn get(&self, text: &str) -> Option<Embedding> {
575 let data = self.data.as_ref()?;
576 if text.len() > Self::MAX_TEXT_LEN {
577 return None;
578 }
579 let mut cache = data.lock().unwrap();
580
581 cache.get(text).cloned()
582 }
583
584 pub fn put(&self, text: String, embedding: Embedding) {
586 let Some(data) = self.data.as_ref() else {
587 return;
588 };
589 if text.len() > Self::MAX_TEXT_LEN {
590 return;
591 }
592 tracing::trace!(text, "embedding added to cache");
593
594 let mut cache = data.lock().unwrap();
595
596 cache.put(text, embedding);
597 }
598}
599
600#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
602pub struct EmbeddingConfig {
603 pub embedder_options: EmbedderOptions,
605 pub prompt: PromptData,
607 pub quantized: Option<bool>,
609 }
611
612impl EmbeddingConfig {
613 pub fn quantized(&self) -> bool {
614 self.quantized.unwrap_or_default()
615 }
616}
617
618#[derive(Clone, Default)]
622pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>);
623
624impl EmbeddingConfigs {
625 pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>) -> Self {
627 Self(data)
628 }
629
630 pub fn contains(&self, name: &str) -> bool {
631 self.0.contains_key(name)
632 }
633
634 pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
636 self.0.get(name).cloned()
637 }
638
639 pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
640 &self.0
641 }
642
643 pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
644 self.0
645 }
646}
647
648impl IntoIterator for EmbeddingConfigs {
649 type Item = (String, (Arc<Embedder>, Arc<Prompt>, bool));
650
651 type IntoIter =
652 std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>, bool)>;
653
654 fn into_iter(self) -> Self::IntoIter {
655 self.0.into_iter()
656 }
657}
658
659#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
661pub enum EmbedderOptions {
662 HuggingFace(hf::EmbedderOptions),
663 OpenAi(openai::EmbedderOptions),
664 Ollama(ollama::EmbedderOptions),
665 UserProvided(manual::EmbedderOptions),
666 Rest(rest::EmbedderOptions),
667 Composite(composite::EmbedderOptions),
668}
669
670impl Default for EmbedderOptions {
671 fn default() -> Self {
672 Self::HuggingFace(Default::default())
673 }
674}
675
676impl Embedder {
677 pub fn new(
679 options: EmbedderOptions,
680 cache_cap: usize,
681 ) -> std::result::Result<Self, NewEmbedderError> {
682 Ok(match options {
683 EmbedderOptions::HuggingFace(options) => {
684 Self::HuggingFace(hf::Embedder::new(options, cache_cap)?)
685 }
686 EmbedderOptions::OpenAi(options) => {
687 Self::OpenAi(openai::Embedder::new(options, cache_cap)?)
688 }
689 EmbedderOptions::Ollama(options) => {
690 Self::Ollama(ollama::Embedder::new(options, cache_cap)?)
691 }
692 EmbedderOptions::UserProvided(options) => {
693 Self::UserProvided(manual::Embedder::new(options))
694 }
695 EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(
696 options,
697 cache_cap,
698 rest::ConfigurationSource::User,
699 )?),
700 EmbedderOptions::Composite(options) => {
701 Self::Composite(composite::Embedder::new(options, cache_cap)?)
702 }
703 })
704 }
705
706 #[tracing::instrument(level = "debug", skip_all, target = "search")]
709 pub fn embed_search(
710 &self,
711 text: &str,
712 deadline: Option<Instant>,
713 ) -> std::result::Result<Embedding, EmbedError> {
714 if let Some(cache) = self.cache() {
715 if let Some(embedding) = cache.get(text) {
716 tracing::trace!(text, "embedding found in cache");
717 return Ok(embedding);
718 }
719 }
720 let embedding = match self {
721 Embedder::HuggingFace(embedder) => embedder.embed_one(text),
722 Embedder::OpenAi(embedder) => {
723 embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
724 }
725 Embedder::Ollama(embedder) => {
726 embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
727 }
728 Embedder::UserProvided(embedder) => embedder.embed_one(text),
729 Embedder::Rest(embedder) => embedder
730 .embed_ref(&[text], deadline)?
731 .pop()
732 .ok_or_else(EmbedError::missing_embedding),
733 Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline),
734 }?;
735
736 if let Some(cache) = self.cache() {
737 cache.put(text.to_owned(), embedding.clone());
738 }
739
740 Ok(embedding)
741 }
742
743 pub fn embed_index(
747 &self,
748 text_chunks: Vec<Vec<String>>,
749 threads: &ThreadPoolNoAbort,
750 ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
751 match self {
752 Embedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
753 Embedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads),
754 Embedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads),
755 Embedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
756 Embedder::Rest(embedder) => embedder.embed_index(text_chunks, threads),
757 Embedder::Composite(embedder) => embedder.index.embed_index(text_chunks, threads),
758 }
759 }
760
761 pub fn embed_index_ref(
763 &self,
764 texts: &[&str],
765 threads: &ThreadPoolNoAbort,
766 ) -> std::result::Result<Vec<Embedding>, EmbedError> {
767 match self {
768 Embedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
769 Embedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads),
770 Embedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads),
771 Embedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
772 Embedder::Rest(embedder) => embedder.embed_index_ref(texts, threads),
773 Embedder::Composite(embedder) => embedder.index.embed_index_ref(texts, threads),
774 }
775 }
776
777 pub fn chunk_count_hint(&self) -> usize {
779 match self {
780 Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
781 Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
782 Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
783 Embedder::UserProvided(_) => 100,
784 Embedder::Rest(embedder) => embedder.chunk_count_hint(),
785 Embedder::Composite(embedder) => embedder.index.chunk_count_hint(),
786 }
787 }
788
789 pub fn prompt_count_in_chunk_hint(&self) -> usize {
791 match self {
792 Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
793 Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
794 Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
795 Embedder::UserProvided(_) => 1,
796 Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
797 Embedder::Composite(embedder) => embedder.index.prompt_count_in_chunk_hint(),
798 }
799 }
800
801 pub fn dimensions(&self) -> usize {
803 match self {
804 Embedder::HuggingFace(embedder) => embedder.dimensions(),
805 Embedder::OpenAi(embedder) => embedder.dimensions(),
806 Embedder::Ollama(embedder) => embedder.dimensions(),
807 Embedder::UserProvided(embedder) => embedder.dimensions(),
808 Embedder::Rest(embedder) => embedder.dimensions(),
809 Embedder::Composite(embedder) => embedder.dimensions(),
810 }
811 }
812
813 pub fn distribution(&self) -> Option<DistributionShift> {
815 match self {
816 Embedder::HuggingFace(embedder) => embedder.distribution(),
817 Embedder::OpenAi(embedder) => embedder.distribution(),
818 Embedder::Ollama(embedder) => embedder.distribution(),
819 Embedder::UserProvided(embedder) => embedder.distribution(),
820 Embedder::Rest(embedder) => embedder.distribution(),
821 Embedder::Composite(embedder) => embedder.distribution(),
822 }
823 }
824
825 pub fn uses_document_template(&self) -> bool {
826 match self {
827 Embedder::HuggingFace(_)
828 | Embedder::OpenAi(_)
829 | Embedder::Ollama(_)
830 | Embedder::Rest(_) => true,
831 Embedder::UserProvided(_) => false,
832 Embedder::Composite(embedder) => embedder.index.uses_document_template(),
833 }
834 }
835
836 fn cache(&self) -> Option<&EmbeddingCache> {
837 match self {
838 Embedder::HuggingFace(embedder) => Some(embedder.cache()),
839 Embedder::OpenAi(embedder) => Some(embedder.cache()),
840 Embedder::UserProvided(_) => None,
841 Embedder::Ollama(embedder) => Some(embedder.cache()),
842 Embedder::Rest(embedder) => Some(embedder.cache()),
843 Embedder::Composite(embedder) => embedder.search.cache(),
844 }
845 }
846}
847
848#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize, ToSchema)]
853#[serde(from = "DistributionShiftSerializable")]
854#[serde(into = "DistributionShiftSerializable")]
855pub struct DistributionShift {
856 #[schema(value_type = f32)]
860 pub current_mean: OrderedFloat<f32>,
861
862 #[schema(value_type = f32)]
866 pub current_sigma: OrderedFloat<f32>,
867}
868
869impl<E> Deserr<E> for DistributionShift
870where
871 E: DeserializeError,
872{
873 fn deserialize_from_value<V: deserr::IntoValue>(
874 value: deserr::Value<V>,
875 location: deserr::ValuePointerRef<'_>,
876 ) -> Result<Self, E> {
877 let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
878 if value.mean < 0. || value.mean > 1. {
879 return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
880 None,
881 deserr::ErrorKind::Unexpected {
882 msg: format!(
883 "the distribution mean must be in the range [0, 1], got {}",
884 value.mean
885 ),
886 },
887 location,
888 )));
889 }
890 if value.sigma <= 0. || value.sigma > 1. {
891 return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
892 None,
893 deserr::ErrorKind::Unexpected {
894 msg: format!(
895 "the distribution sigma must be in the range ]0, 1], got {}",
896 value.sigma
897 ),
898 },
899 location,
900 )));
901 }
902
903 Ok(value.into())
904 }
905}
906
907#[derive(Serialize, Deserialize, Deserr)]
908#[serde(deny_unknown_fields)]
909#[deserr(deny_unknown_fields)]
910struct DistributionShiftSerializable {
911 mean: f32,
912 sigma: f32,
913}
914
915impl From<DistributionShift> for DistributionShiftSerializable {
916 fn from(
917 DistributionShift {
918 current_mean: OrderedFloat(current_mean),
919 current_sigma: OrderedFloat(current_sigma),
920 }: DistributionShift,
921 ) -> Self {
922 Self { mean: current_mean, sigma: current_sigma }
923 }
924}
925
926impl From<DistributionShiftSerializable> for DistributionShift {
927 fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
928 Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
929 }
930}
931
932impl DistributionShift {
933 pub fn new(mean: f32, sigma: f32) -> Option<Self> {
935 if sigma <= 0.0 {
936 None
937 } else {
938 Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
939 }
940 }
941
942 pub fn shift(&self, score: f32) -> f32 {
943 let current_mean = self.current_mean.0;
944 let current_sigma = self.current_sigma.0;
945 let target_mean = 0.5;
951 let target_sigma = 0.4;
952
953 let factor = target_sigma / current_sigma;
955 let offset = target_mean - (factor * current_mean);
957
958 let mut score = factor * score + offset;
959
960 if score <= 0.0 {
962 score = f32::EPSILON;
963 }
964 if score > 1.0 {
965 score = 1.0;
966 }
967
968 score
969 }
970}
971
/// Reports whether this build was compiled with the `cuda` cargo feature.
pub const fn is_cuda_enabled() -> bool {
    const CUDA_FEATURE: bool = cfg!(feature = "cuda");
    CUDA_FEATURE
}
976
/// Returns the 256 arroy store indexes reserved for the embedder `embedder_id`, in order.
///
/// The embedder id occupies the high byte of each `u16` index and the store number the low
/// byte, so two embedders never share a store.
pub fn arroy_db_range_for_embedder(embedder_id: u8) -> impl Iterator<Item = u16> {
    let first = u16::from(embedder_id) << 8;
    let last = first | u16::from(u8::MAX);
    first..=last
}