use std::time::Instant;

use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use rayon::slice::ParallelSlice as _;

use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
use crate::error::FaultSource;
use crate::vector::Embedding;
use crate::ThreadPoolNoAbort;

/// An embedder that talks to a local or remote Ollama server.
///
/// It is a thin wrapper around the generic REST embedder, configured with
/// Ollama's request and response shapes.
#[derive(Debug)]
pub struct Embedder {
    rest_embedder: RestEmbedder,
}

/// User-facing options for an Ollama embedder.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub embedding_model: String,
    pub url: Option<String>,
    pub api_key: Option<String>,
    pub distribution: Option<DistributionShift>,
    pub dimensions: Option<usize>,
}

impl EmbedderOptions {
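    /// Builds options for the default `nomic-embed-text` model.
    ///
    /// A minimal usage sketch, assuming a locally running Ollama instance
    /// (the explicit URL below is an assumption made for the example; when
    /// `None` is passed, the `MEILI_OLLAMA_URL` environment variable or the
    /// default local endpoint is used instead):
    ///
    /// ```ignore
    /// let options = EmbedderOptions::with_default_model(
    ///     None,                                                 // no API key
    ///     Some("http://localhost:11434/api/embeddings".into()), // explicit endpoint
    ///     None,                                                 // let the embedder determine the dimensions
    /// );
    /// ```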
    pub fn with_default_model(
        api_key: Option<String>,
        url: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            embedding_model: "nomic-embed-text".into(),
            api_key,
            url,
            distribution: None,
            dimensions,
        }
    }

    /// Converts these options into the configuration of the underlying REST embedder.
    ///
    /// Only the `/api/embeddings` (single prompt) and `/api/embed` (batched input)
    /// Ollama endpoints are supported; any other URL is rejected.
    fn into_rest_embedder_config(self) -> Result<RestEmbedderOptions, NewEmbedderError> {
        let url = self.url.unwrap_or_else(get_ollama_path);
        let model = self.embedding_model.as_str();

        // Pick the request/response templates matching the endpoint the URL points to.
        let (request, response) = if url.ends_with("/api/embeddings") {
            (
                serde_json::json!({
                    "model": model,
                    "prompt": super::rest::REQUEST_PLACEHOLDER,
                }),
                serde_json::json!({
                    "embedding": super::rest::RESPONSE_PLACEHOLDER,
                }),
            )
        } else if url.ends_with("/api/embed") {
            (
                serde_json::json!({"model": model, "input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]}),
                serde_json::json!({"embeddings": [super::rest::RESPONSE_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]}),
            )
        } else {
            return Err(NewEmbedderError::ollama_unsupported_url(url));
        };
        Ok(RestEmbedderOptions {
            api_key: self.api_key,
            dimensions: self.dimensions,
            distribution: self.distribution,
            url,
            request,
            response,
            headers: Default::default(),
        })
    }
}

impl Embedder {
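    /// Creates a new Ollama embedder from the given options, letting the
    /// underlying REST embedder determine the embedding dimensions.
    ///
    /// A minimal construction sketch, assuming a reachable Ollama server and a
    /// pulled `nomic-embed-text` model (the cache capacity of 512 entries is an
    /// arbitrary value chosen for this example):
    ///
    /// ```ignore
    /// let options = EmbedderOptions::with_default_model(None, None, None);
    /// let embedder = Embedder::new(options, 512)?;
    /// assert!(embedder.dimensions() > 0);
    /// ```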
    pub fn new(options: EmbedderOptions, cache_cap: usize) -> Result<Self, NewEmbedderError> {
        let rest_embedder = match RestEmbedder::new(
            options.into_rest_embedder_config()?,
            cache_cap,
            super::rest::ConfigurationSource::Ollama,
        ) {
            Ok(embedder) => embedder,
            // A 404 returned while determining the dimensions means the model
            // has not been pulled into Ollama: report it as "model not found".
            Err(NewEmbedderError {
                kind:
                    NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
                        kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
                        fault: _,
                    }),
                fault: _,
            }) => {
                return Err(NewEmbedderError::could_not_determine_dimension(
                    EmbedError::ollama_model_not_found(error),
                ))
            }
            Err(error) => return Err(error),
        };

        Ok(Self { rest_embedder })
    }

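    /// Embeds the given texts, forwarding to the REST embedder and mapping a
    /// 404 response to a "model not found" error.
    ///
    /// A usage sketch; passing `None` as the deadline simply disables it:
    ///
    /// ```ignore
    /// let embeddings = embedder.embed(&["first document", "second document"], None)?;
    /// assert_eq!(embeddings.len(), 2); // one embedding per input text
    /// ```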
    pub fn embed<S: AsRef<str> + serde::Serialize>(
        &self,
        texts: &[S],
        deadline: Option<Instant>,
    ) -> Result<Vec<Embedding>, EmbedError> {
        match self.rest_embedder.embed_ref(texts, deadline) {
            Ok(embeddings) => Ok(embeddings),
            Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
                Err(EmbedError::ollama_model_not_found(error))
            }
            Err(error) => Err(error),
        }
    }

    pub fn embed_index(
        &self,
        text_chunks: Vec<Vec<String>>,
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
        // If the thread pool is already saturated with embedding operations,
        // embed sequentially instead of queueing more parallel jobs.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect()
        } else {
            threads
                .install(move || {
                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect()
                })
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

    pub(crate) fn embed_index_ref(
        &self,
        texts: &[&str],
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<f32>>, EmbedError> {
        // Same saturation check as `embed_index`: fall back to sequential
        // chunked embedding when the thread pool is already busy.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                .chunks(self.prompt_count_in_chunk_hint())
                .map(move |chunk| self.embed(chunk, None))
                .collect();

            let embeddings = embeddings?;
            Ok(embeddings.into_iter().flatten().collect())
        } else {
            threads
                .install(move || {
                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                        .par_chunks(self.prompt_count_in_chunk_hint())
                        .map(move |chunk| self.embed(chunk, None))
                        .collect();

                    let embeddings = embeddings?;
                    Ok(embeddings.into_iter().flatten().collect())
                })
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

    /// Delegates to the underlying REST embedder.
    pub fn chunk_count_hint(&self) -> usize {
        self.rest_embedder.chunk_count_hint()
    }

    /// Delegates to the underlying REST embedder.
    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        self.rest_embedder.prompt_count_in_chunk_hint()
    }

    pub fn dimensions(&self) -> usize {
        self.rest_embedder.dimensions()
    }

    pub fn distribution(&self) -> Option<DistributionShift> {
        self.rest_embedder.distribution()
    }

    pub(super) fn cache(&self) -> &EmbeddingCache {
        self.rest_embedder.cache()
    }
}

/// Returns the Ollama embeddings endpoint, read from the `MEILI_OLLAMA_URL`
/// environment variable, or the default local endpoint when it is not set.
fn get_ollama_path() -> String {
    // Note: the variable must contain the full path to the embeddings endpoint,
    // not just the host name.
    std::env::var("MEILI_OLLAMA_URL")
        .unwrap_or("http://localhost:11434/api/embeddings".to_string())
}