1use std::time::Instant;
2
3use arroy::Distance;
4
5use super::error::CompositeEmbedderContainsHuggingFace;
6use super::{
7 hf, manual, ollama, openai, rest, DistributionShift, EmbedError, Embedding, EmbeddingCache,
8 NewEmbedderError,
9};
10use crate::ThreadPoolNoAbort;
11
/// One half of a composite [`Embedder`]: a concrete embedder used either for search or for
/// indexing.
///
/// Each variant wraps the embedder implementation of the corresponding source module.
#[derive(Debug)]
pub enum SubEmbedder {
    /// Wraps an [`hf::Embedder`].
    HuggingFace(hf::Embedder),
    /// Wraps an [`openai::Embedder`].
    OpenAi(openai::Embedder),
    /// Wraps a [`manual::Embedder`]. Unlike the other variants, it does not use a document
    /// template and exposes no embedding cache.
    UserProvided(manual::Embedder),
    /// Wraps an [`ollama::Embedder`].
    Ollama(ollama::Embedder),
    /// Wraps a [`rest::Embedder`].
    Rest(rest::Embedder),
}
25
/// Configuration for a [`SubEmbedder`], with one variant per supported embedder source.
///
/// `SubEmbedder::new` turns each variant into the matching `SubEmbedder` variant.
///
/// NOTE(review): this type derives serde traits, so variant names (and, for non
/// self-describing formats, variant order) are presumably part of the persisted form —
/// confirm before renaming or reordering variants.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum SubEmbedderOptions {
    /// Options for a Hugging Face embedder (`hf` module).
    HuggingFace(hf::EmbedderOptions),
    /// Options for an OpenAI embedder (`openai` module).
    OpenAi(openai::EmbedderOptions),
    /// Options for an Ollama embedder (`ollama` module).
    Ollama(ollama::EmbedderOptions),
    /// Options for a user-provided embedder (`manual` module).
    UserProvided(manual::EmbedderOptions),
    /// Options for a REST embedder (`rest` module).
    Rest(rest::EmbedderOptions),
}
34
35impl SubEmbedderOptions {
36 pub fn distribution(&self) -> Option<DistributionShift> {
37 match self {
38 SubEmbedderOptions::HuggingFace(embedder_options) => embedder_options.distribution,
39 SubEmbedderOptions::OpenAi(embedder_options) => embedder_options.distribution,
40 SubEmbedderOptions::Ollama(embedder_options) => embedder_options.distribution,
41 SubEmbedderOptions::UserProvided(embedder_options) => embedder_options.distribution,
42 SubEmbedderOptions::Rest(embedder_options) => embedder_options.distribution,
43 }
44 }
45}
46
/// A composite embedder that pairs one sub-embedder for search with another for indexing.
///
/// `Embedder::new` only accepts pairs whose `dimensions()` match and whose embeddings of a few
/// sample texts are within `MAX_COMPOSITE_DISTANCE` (cosine distance) of each other.
#[derive(Debug)]
pub struct Embedder {
    /// Sub-embedder for the search side. It is the only one built with a non-zero cache
    /// capacity in `new`.
    pub(super) search: SubEmbedder,
    /// Sub-embedder for the indexing side. It is built with a cache capacity of `0` in `new`.
    pub(super) index: SubEmbedder,
}
52
/// Serializable configuration of a composite [`Embedder`].
///
/// Holds one set of sub-embedder options for each side; `Embedder::new` destructures it.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    /// Options of the sub-embedder used for search.
    pub search: SubEmbedderOptions,
    /// Options of the sub-embedder used for indexing.
    pub index: SubEmbedderOptions,
}
58
59impl Embedder {
60 pub fn new(
61 EmbedderOptions { search, index }: EmbedderOptions,
62 cache_cap: usize,
63 ) -> Result<Self, NewEmbedderError> {
64 let search = SubEmbedder::new(search, cache_cap)?;
65 let index = SubEmbedder::new(index, 0)?;
67
68 if search.dimensions() != index.dimensions() {
70 return Err(NewEmbedderError::composite_dimensions_mismatch(
71 search.dimensions(),
72 index.dimensions(),
73 ));
74 }
75 let search_embeddings = search
77 .embed(
78 vec![
79 "test".into(),
80 "a brave dog".into(),
81 "This is a sample text. It is meant to compare similarity.".into(),
82 ],
83 None,
84 )
85 .map_err(|error| NewEmbedderError::composite_test_embedding_failed(error, "search"))?;
86
87 let index_embeddings = index
88 .embed(
89 vec![
90 "test".into(),
91 "a brave dog".into(),
92 "This is a sample text. It is meant to compare similarity.".into(),
93 ],
94 None,
95 )
96 .map_err(|error| {
97 NewEmbedderError::composite_test_embedding_failed(error, "indexing")
98 })?;
99
100 let hint = configuration_hint(&search, &index);
101
102 check_similarity(search_embeddings, index_embeddings, hint)?;
103
104 Ok(Self { search, index })
105 }
106
107 pub fn dimensions(&self) -> usize {
109 self.index.dimensions()
111 }
112
113 pub fn distribution(&self) -> Option<DistributionShift> {
115 self.search.distribution().or_else(|| self.index.distribution())
120 }
121}
122
123impl SubEmbedder {
124 pub fn new(
125 options: SubEmbedderOptions,
126 cache_cap: usize,
127 ) -> std::result::Result<Self, NewEmbedderError> {
128 Ok(match options {
129 SubEmbedderOptions::HuggingFace(options) => {
130 Self::HuggingFace(hf::Embedder::new(options, cache_cap)?)
131 }
132 SubEmbedderOptions::OpenAi(options) => {
133 Self::OpenAi(openai::Embedder::new(options, cache_cap)?)
134 }
135 SubEmbedderOptions::Ollama(options) => {
136 Self::Ollama(ollama::Embedder::new(options, cache_cap)?)
137 }
138 SubEmbedderOptions::UserProvided(options) => {
139 Self::UserProvided(manual::Embedder::new(options))
140 }
141 SubEmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(
142 options,
143 cache_cap,
144 rest::ConfigurationSource::User,
145 )?),
146 })
147 }
148
149 pub fn embed(
150 &self,
151 texts: Vec<String>,
152 deadline: Option<Instant>,
153 ) -> std::result::Result<Vec<Embedding>, EmbedError> {
154 match self {
155 SubEmbedder::HuggingFace(embedder) => embedder.embed(texts),
156 SubEmbedder::OpenAi(embedder) => embedder.embed(&texts, deadline),
157 SubEmbedder::Ollama(embedder) => embedder.embed(&texts, deadline),
158 SubEmbedder::UserProvided(embedder) => embedder.embed(&texts),
159 SubEmbedder::Rest(embedder) => embedder.embed(texts, deadline),
160 }
161 }
162
163 pub fn embed_one(
164 &self,
165 text: &str,
166 deadline: Option<Instant>,
167 ) -> std::result::Result<Embedding, EmbedError> {
168 match self {
169 SubEmbedder::HuggingFace(embedder) => embedder.embed_one(text),
170 SubEmbedder::OpenAi(embedder) => {
171 embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
172 }
173 SubEmbedder::Ollama(embedder) => {
174 embedder.embed(&[text], deadline)?.pop().ok_or_else(EmbedError::missing_embedding)
175 }
176 SubEmbedder::UserProvided(embedder) => embedder.embed_one(text),
177 SubEmbedder::Rest(embedder) => embedder
178 .embed_ref(&[text], deadline)?
179 .pop()
180 .ok_or_else(EmbedError::missing_embedding),
181 }
182 }
183
184 pub fn embed_index(
188 &self,
189 text_chunks: Vec<Vec<String>>,
190 threads: &ThreadPoolNoAbort,
191 ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
192 match self {
193 SubEmbedder::HuggingFace(embedder) => embedder.embed_index(text_chunks),
194 SubEmbedder::OpenAi(embedder) => embedder.embed_index(text_chunks, threads),
195 SubEmbedder::Ollama(embedder) => embedder.embed_index(text_chunks, threads),
196 SubEmbedder::UserProvided(embedder) => embedder.embed_index(text_chunks),
197 SubEmbedder::Rest(embedder) => embedder.embed_index(text_chunks, threads),
198 }
199 }
200
201 pub fn embed_index_ref(
203 &self,
204 texts: &[&str],
205 threads: &ThreadPoolNoAbort,
206 ) -> std::result::Result<Vec<Embedding>, EmbedError> {
207 match self {
208 SubEmbedder::HuggingFace(embedder) => embedder.embed_index_ref(texts),
209 SubEmbedder::OpenAi(embedder) => embedder.embed_index_ref(texts, threads),
210 SubEmbedder::Ollama(embedder) => embedder.embed_index_ref(texts, threads),
211 SubEmbedder::UserProvided(embedder) => embedder.embed_index_ref(texts),
212 SubEmbedder::Rest(embedder) => embedder.embed_index_ref(texts, threads),
213 }
214 }
215
216 pub fn chunk_count_hint(&self) -> usize {
218 match self {
219 SubEmbedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
220 SubEmbedder::OpenAi(embedder) => embedder.chunk_count_hint(),
221 SubEmbedder::Ollama(embedder) => embedder.chunk_count_hint(),
222 SubEmbedder::UserProvided(_) => 100,
223 SubEmbedder::Rest(embedder) => embedder.chunk_count_hint(),
224 }
225 }
226
227 pub fn prompt_count_in_chunk_hint(&self) -> usize {
229 match self {
230 SubEmbedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
231 SubEmbedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
232 SubEmbedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
233 SubEmbedder::UserProvided(_) => 1,
234 SubEmbedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
235 }
236 }
237
238 pub fn uses_document_template(&self) -> bool {
239 match self {
240 SubEmbedder::HuggingFace(_)
241 | SubEmbedder::OpenAi(_)
242 | SubEmbedder::Ollama(_)
243 | SubEmbedder::Rest(_) => true,
244 SubEmbedder::UserProvided(_) => false,
245 }
246 }
247
248 pub fn dimensions(&self) -> usize {
250 match self {
251 SubEmbedder::HuggingFace(embedder) => embedder.dimensions(),
252 SubEmbedder::OpenAi(embedder) => embedder.dimensions(),
253 SubEmbedder::Ollama(embedder) => embedder.dimensions(),
254 SubEmbedder::UserProvided(embedder) => embedder.dimensions(),
255 SubEmbedder::Rest(embedder) => embedder.dimensions(),
256 }
257 }
258
259 pub fn distribution(&self) -> Option<DistributionShift> {
261 match self {
262 SubEmbedder::HuggingFace(embedder) => embedder.distribution(),
263 SubEmbedder::OpenAi(embedder) => embedder.distribution(),
264 SubEmbedder::Ollama(embedder) => embedder.distribution(),
265 SubEmbedder::UserProvided(embedder) => embedder.distribution(),
266 SubEmbedder::Rest(embedder) => embedder.distribution(),
267 }
268 }
269
270 pub(super) fn cache(&self) -> Option<&EmbeddingCache> {
271 match self {
272 SubEmbedder::HuggingFace(embedder) => Some(embedder.cache()),
273 SubEmbedder::OpenAi(embedder) => Some(embedder.cache()),
274 SubEmbedder::UserProvided(_) => None,
275 SubEmbedder::Ollama(embedder) => Some(embedder.cache()),
276 SubEmbedder::Rest(embedder) => Some(embedder.cache()),
277 }
278 }
279}
280
281fn check_similarity(
282 left: Vec<Embedding>,
283 right: Vec<Embedding>,
284 hint: CompositeEmbedderContainsHuggingFace,
285) -> Result<(), NewEmbedderError> {
286 if left.len() != right.len() {
287 return Err(NewEmbedderError::composite_embedding_count_mismatch(left.len(), right.len()));
288 }
289
290 for (left, right) in left.into_iter().zip(right) {
291 let left = arroy::internals::UnalignedVector::from_slice(&left);
292 let right = arroy::internals::UnalignedVector::from_slice(&right);
293 let left = arroy::internals::Leaf {
294 header: arroy::distances::Cosine::new_header(&left),
295 vector: left,
296 };
297 let right = arroy::internals::Leaf {
298 header: arroy::distances::Cosine::new_header(&right),
299 vector: right,
300 };
301
302 let distance = arroy::distances::Cosine::built_distance(&left, &right);
303
304 if distance > super::MAX_COMPOSITE_DISTANCE {
305 return Err(NewEmbedderError::composite_embedding_value_mismatch(distance, hint));
306 }
307 }
308 Ok(())
309}
310
311fn configuration_hint(
312 search: &SubEmbedder,
313 index: &SubEmbedder,
314) -> CompositeEmbedderContainsHuggingFace {
315 match (search, index) {
316 (SubEmbedder::HuggingFace(_), SubEmbedder::HuggingFace(_)) => {
317 CompositeEmbedderContainsHuggingFace::Both
318 }
319 (SubEmbedder::HuggingFace(_), _) => CompositeEmbedderContainsHuggingFace::Search,
320 (_, SubEmbedder::HuggingFace(_)) => CompositeEmbedderContainsHuggingFace::Indexing,
321 _ => CompositeEmbedderContainsHuggingFace::None,
322 }
323}