// milli_core/vector/composite.rs

1use std::time::Instant;
2
3use arroy::Distance;
4
5use super::error::CompositeEmbedderContainsHuggingFace;
6use super::{
7    hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
8    NewEmbedderError,
9};
10use crate::ThreadPoolNoAbort;
11
12#[derive(Debug)]
13pub enum SubEmbedder {
14    /// An embedder based on running local models, fetched from the Hugging Face Hub.
15    HuggingFace(hf::Embedder),
16    /// An embedder based on making embedding queries against the OpenAI API.
17    OpenAi(openai::Embedder),
18    /// An embedder based on the user providing the embeddings in the documents and queries.
19    UserProvided(manual::Embedder),
20    /// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
21    Ollama(ollama::Embedder),
22    /// An embedder based on making embedding queries against a generic JSON/REST embedding server.
23    Rest(rest::Embedder),
24}
25
26#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
27pub enum SubEmbedderOptions {
28    HuggingFace(hf::EmbedderOptions),
29    OpenAi(openai::EmbedderOptions),
30    Ollama(ollama::EmbedderOptions),
31    UserProvided(manual::EmbedderOptions),
32    Rest(rest::EmbedderOptions),
33}
34
35impl SubEmbedderOptions {
36    pub fn distribution(&self) -> Option<DistributionShift> {
37        match self {
38            SubEmbedderOptions::HuggingFace(embedder_options) => embedder_options.distribution,
39            SubEmbedderOptions::OpenAi(embedder_options) => embedder_options.distribution,
40            SubEmbedderOptions::Ollama(embedder_options) => embedder_options.distribution,
41            SubEmbedderOptions::UserProvided(embedder_options) => embedder_options.distribution,
42            SubEmbedderOptions::Rest(embedder_options) => embedder_options.distribution,
43        }
44    }
45}
46
47#[derive(Debug)]
48pub struct Embedder {
49    pub(super) search: SubEmbedder,
50    pub(super) index: SubEmbedder,
51}
52
53#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
54pub struct EmbedderOptions {
55    pub search: SubEmbedderOptions,
56    pub index: SubEmbedderOptions,
57}
58
59impl Embedder {
60    pub fn new(
61        EmbedderOptions { search, index }: EmbedderOptions,
62        cache_cap: usize,
63    ) -> Result<Self, NewEmbedderError> {
64        let search = SubEmbedder::new(search, cache_cap)?;
65        // cache is only used at search
66        let index = SubEmbedder::new(index, 0)?;
67
68        // check dimensions
69        if search.dimensions() != index.dimensions() {
70            return Err(NewEmbedderError::composite_dimensions_mismatch(
71                search.dimensions(),
72                index.dimensions(),
73            ));
74        }
75        // check similarity
76        let search_embeddings = search
77            .embed(
78                vec![
79                    "test".into(),
80                    "a brave dog".into(),
81                    "This is a sample text. It is meant to compare similarity.".into(),
82                ],
83                None,
84            )
85            .map_err(|error| NewEmbedderError::composite_test_embedding_failed(error, "search"))?;
86
87        let index_embeddings = index
88            .embed(
89                vec![
90                    "test".into(),
91                    "a brave dog".into(),
92                    "This is a sample text. It is meant to compare similarity.".into(),
93                ],
94                None,
95            )
96            .map_err(|error| {
97                NewEmbedderError::composite_test_embedding_failed(error, "indexing")
98            })?;
99
100        let hint = configuration_hint(&search, &index);
101
102        check_similarity(search_embeddings, index_embeddings, hint)?;
103
104        Ok(Self { search, index })
105    }
106
107    /// Indicates the dimensions of a single embedding produced by the embedder.
108    pub fn dimensions(&self) -> usize {
109        // can use the dimensions of any embedder since they should match
110        self.index.dimensions()
111    }
112
113    /// An optional distribution used to apply an affine transformation to the similarity score of a document.
114    pub fn distribution(&self) -> Option<DistributionShift> {
115        // 3 cases here:
116        // 1. distribution provided by user => use that one, which was stored in search
117        // 2. no user-provided distribution, distribution in search embedder => use that one
118        // 2. no user-provided distribution, no distribution in search embedder => use the distribution in indexing embedder
119        self.search.distribution().or_else(|| self.index.distribution())
120    }
121}
122
123impl SubEmbedder {
124    pub fn new(
125        options: SubEmbedderOptions,
126        cache_cap: usize,
127    ) -> std::result::Result<Self, NewEmbedderError> {
128        Ok(match options {
129            SubEmbedderOptions::HuggingFace(options) => {
130                Self::HuggingFace(hf::Embedder::new(options, cache_cap)?)
131            }
132            SubEmbedderOptions::OpenAi(options) => {
133                Self::OpenAi(openai::Embedder::new(options, cache_cap)?)
134            }
135            SubEmbedderOptions::Ollama(options) => {
136                Self::Ollama(ollama::Embedder::new(options, cache_cap)?)
137            }
138            SubEmbedderOptions::UserProvided(options) => {
139                Self::UserProvided(manual::Embedder::new(options))
140            }
141            SubEmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(
142                options,
143                cache_cap,
144                rest::ConfigurationSource::User,
145            )?),
146        })
147    }
148
149    pub fn embed(
150        &self,
151        texts: Vec<String>,
152        deadline: Option<Instant>,
153    ) -> std::result::Result<Vec<Embedding>, EmbedError> {
154        match self {
155            SubEmbedder::HuggingFace(embedder) => embedder.embed(texts),
156            SubEmbedder::OpenAi(embedder) => embedder.embed(&texts, deadline),
157            SubEmbedder::Ollama(embedder) => embedder.embed(&texts, deadline),
158            SubEmbedder::UserProvided(embedder) => embedder.embed(&texts),
159            SubEmbedder::Rest(embedder) => embedder.embed(texts, deadline),
160        }
161    }
162
163    pub fn embed_one(
164        &self,
165        text: &str,
166        deadline: Option<Instant>,
167    ) -> std::result::Result<Embedding, EmbedError> {
168        match self {
169            SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
170            SubEmbedder::OpenAi(embedder) => {
171                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
172            }
173            SubEmbedder::Ollama(embedder) => {
174                embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
175            }
176            SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
177            SubEmbedder::Rest(embedder) => embedder
178                .embed_ref(&[text], deadline)?
179                .pop()
180                .ok_or_else(EmbedError::missing_embedding),
181        }
182    }
183
184    /// Embed multiple chunks of texts.
185    ///
186    /// Each chunk is composed of one or multiple texts.
187    pub fn embed_index(
188        &self,
189        text_chunks: Vec<Vec<String>>,
190        threads: &ThreadPoolNoAbort,
191    ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
192        match self {
193            SubEmbedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
194            SubEmbedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads),
195            SubEmbedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads),
196            SubEmbedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
197            SubEmbedder::Rest(embedder) => embedder.embed_index(text_chunks, threads),
198        }
199    }
200
201    /// Non-owning variant of [`Self::embed_index`].
202    pub fn embed_index_ref(
203        &self,
204        texts: &[&str],
205        threads: &ThreadPoolNoAbort,
206    ) -> std::result::Result<Vec<Embedding>, EmbedError> {
207        match self {
208            SubEmbedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
209            SubEmbedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads),
210            SubEmbedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads),
211            SubEmbedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
212            SubEmbedder::Rest(embedder) => embedder.embed_index_ref(texts, threads),
213        }
214    }
215
216    /// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
217    pub fn chunk_count_hint(&self) -> usize {
218        match self {
219            SubEmbedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
220            SubEmbedder::OpenAi(embedder) => embedder.chunk_count_hint(),
221            SubEmbedder::Ollama(embedder) => embedder.chunk_count_hint(),
222            SubEmbedder::UserProvided(_) => 100,
223            SubEmbedder::Rest(embedder) => embedder.chunk_count_hint(),
224        }
225    }
226
227    /// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`]
228    pub fn prompt_count_in_chunk_hint(&self) -> usize {
229        match self {
230            SubEmbedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
231            SubEmbedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
232            SubEmbedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
233            SubEmbedder::UserProvided(_) => 1,
234            SubEmbedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
235        }
236    }
237
238    pub fn uses_document_template(&self) -> bool {
239        match self {
240            SubEmbedder::HuggingFace(_)
241            | SubEmbedder::OpenAi(_)
242            | SubEmbedder::Ollama(_)
243            | SubEmbedder::Rest(_) => true,
244            SubEmbedder::UserProvided(_) => false,
245        }
246    }
247
248    /// Indicates the dimensions of a single embedding produced by the embedder.
249    pub fn dimensions(&self) -> usize {
250        match self {
251            SubEmbedder::HuggingFace(embedder) => embedder.dimensions(),
252            SubEmbedder::OpenAi(embedder) => embedder.dimensions(),
253            SubEmbedder::Ollama(embedder) => embedder.dimensions(),
254            SubEmbedder::UserProvided(embedder) => embedder.dimensions(),
255            SubEmbedder::Rest(embedder) => embedder.dimensions(),
256        }
257    }
258
259    /// An optional distribution used to apply an affine transformation to the similarity score of a document.
260    pub fn distribution(&self) -> Option<DistributionShift> {
261        match self {
262            SubEmbedder::HuggingFace(embedder) => embedder.distribution(),
263            SubEmbedder::OpenAi(embedder) => embedder.distribution(),
264            SubEmbedder::Ollama(embedder) => embedder.distribution(),
265            SubEmbedder::UserProvided(embedder) => embedder.distribution(),
266            SubEmbedder::Rest(embedder) => embedder.distribution(),
267        }
268    }
269
270    pub(super) fn cache(&self) -> Option<&EmbeddingCache> {
271        match self {
272            SubEmbedder::HuggingFace(embedder) => Some(embedder.cache()),
273            SubEmbedder::OpenAi(embedder) => Some(embedder.cache()),
274            SubEmbedder::UserProvided(_) => None,
275            SubEmbedder::Ollama(embedder) => Some(embedder.cache()),
276            SubEmbedder::Rest(embedder) => Some(embedder.cache()),
277        }
278    }
279}
280
281fn check_similarity(
282    left: Vec<Embedding>,
283    right: Vec<Embedding>,
284    hint: CompositeEmbedderContainsHuggingFace,
285) -> Result<(), NewEmbedderError> {
286    if left.len() != right.len() {
287        return Err(NewEmbedderError::composite_embedding_count_mismatch(left.len(), right.len()));
288    }
289
290    for (left, right) in left.into_iter().zip(right) {
291        let left = arroy::internals::UnalignedVector::from_slice(&left);
292        let right = arroy::internals::UnalignedVector::from_slice(&right);
293        let left = arroy::internals::Leaf {
294            header: arroy::distances::Cosine::new_header(&left),
295            vector: left,
296        };
297        let right = arroy::internals::Leaf {
298            header: arroy::distances::Cosine::new_header(&right),
299            vector: right,
300        };
301
302        let distance = arroy::distances::Cosine::built_distance(&left, &right);
303
304        if distance > super::MAX_COMPOSITE_DISTANCE {
305            return Err(NewEmbedderError::composite_embedding_value_mismatch(distance, hint));
306        }
307    }
308    Ok(())
309}
310
311fn configuration_hint(
312    search: &SubEmbedder,
313    index: &SubEmbedder,
314) -> CompositeEmbedderContainsHuggingFace {
315    match (search, index) {
316        (SubEmbedder::HuggingFace(_), SubEmbedder::HuggingFace(_)) => {
317            CompositeEmbedderContainsHuggingFace::Both
318        }
319        (SubEmbedder::HuggingFace(_), _) => CompositeEmbedderContainsHuggingFace::Search,
320        (_, SubEmbedder::HuggingFace(_)) => CompositeEmbedderContainsHuggingFace::Indexing,
321        _ => CompositeEmbedderContainsHuggingFace::None,
322    }
323}