milli_core/vector/ollama.rs

use std::time::Instant;

use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use rayon::slice::ParallelSlice as _;

use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, EmbeddingCache, REQUEST_PARALLELISM};
use crate::error::FaultSource;
use crate::vector::Embedding;
use crate::ThreadPoolNoAbort;

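/// An embedder that talks to an Ollama server, implemented as a thin wrapper
/// around the generic REST embedder configured with Ollama-shaped request and
/// response templates.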
#[derive(Debug)]
pub struct Embedder {
    rest_embedder: RestEmbedder,
}

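/// User-facing options for an Ollama embedder. `url`, when set, must be the
/// full URL of the embeddings endpoint, not just the host (see
/// `get_ollama_path` below).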
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub embedding_model: String,
    pub url: Option<String>,
    pub api_key: Option<String>,
    pub distribution: Option<DistributionShift>,
    pub dimensions: Option<usize>,
}

impl EmbedderOptions {
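    /// Builds options that target the `nomic-embed-text` model with no
    /// distribution shift.
    ///
    /// A minimal usage sketch; the URL below assumes a default local Ollama
    /// install and is illustrative rather than a verified default:
    ///
    /// ```ignore
    /// let options = EmbedderOptions::with_default_model(
    ///     None,                                                  // no API key
    ///     Some("http://localhost:11434/api/embeddings".into()),  // full endpoint URL
    ///     None,                                                  // probe dimensions at startup
    /// );
    /// ```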
    pub fn with_default_model(
        api_key: Option<String>,
        url: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            embedding_model: "nomic-embed-text".into(),
            api_key,
            url,
            distribution: None,
            dimensions,
        }
    }

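    /// Translates these options into the generic REST embedder configuration,
    /// selecting request/response templates that match the Ollama endpoint the
    /// URL points at. For illustration, the `/api/embed` templates below
    /// describe batched payloads of roughly this shape (values abbreviated):
    ///
    /// ```ignore
    /// // request:  {"model": "nomic-embed-text", "input": ["first text", "second text"]}
    /// // response: {"embeddings": [[0.1, 0.2], [0.3, 0.4]]}
    /// ```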
    fn into_rest_embedder_config(self) -> Result<RestEmbedderOptions, NewEmbedderError> {
        let url = self.url.unwrap_or_else(get_ollama_path);
        let model = self.embedding_model.as_str();

        // **warning**: keep the more specific `/api/embeddings` check first. With
        // `ends_with` the order happens to be safe, but any refactor to a prefix or
        // substring match would make the `/api/embed` arm also match `/api/embeddings`.
        let (request, response) = if url.ends_with("/api/embeddings") {
            (
                serde_json::json!({
                    "model": model,
                    "prompt": super::rest::REQUEST_PLACEHOLDER,
                }),
                serde_json::json!({
                    "embedding": super::rest::RESPONSE_PLACEHOLDER,
                }),
            )
        } else if url.ends_with("/api/embed") {
            (
                serde_json::json!({
                    "model": model,
                    "input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER],
                }),
                serde_json::json!({
                    "embeddings": [super::rest::RESPONSE_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER],
                }),
            )
        } else {
            return Err(NewEmbedderError::ollama_unsupported_url(url));
        };
        Ok(RestEmbedderOptions {
            api_key: self.api_key,
            dimensions: self.dimensions,
            distribution: self.distribution,
            url,
            request,
            response,
            headers: Default::default(),
        })
    }
}

impl Embedder {
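    /// Builds the embedder, letting the underlying REST embedder probe the
    /// server for the embedding dimensions when they were not provided. A 404
    /// during that probe is remapped to a "model not found" error, as Ollama
    /// answers 404 for models that have not been pulled.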
    pub fn new(options: EmbedderOptions, cache_cap: usize) -> Result<Self, NewEmbedderError> {
        let rest_embedder = match RestEmbedder::new(
            options.into_rest_embedder_config()?,
            cache_cap,
            super::rest::ConfigurationSource::Ollama,
        ) {
            Ok(embedder) => embedder,
            Err(NewEmbedderError {
                kind:
                    NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
                        kind: EmbedErrorKind::RestOtherStatusCode(404, error),
                        fault: _,
                    }),
                fault: _,
            }) => {
                return Err(NewEmbedderError::could_not_determine_dimension(
                    EmbedError::ollama_model_not_found(error),
                ))
            }
            Err(error) => return Err(error),
        };

        Ok(Self { rest_embedder })
    }

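    /// Embeds the given texts in a single request, remapping a 404 response to
    /// a "model not found" error.
    ///
    /// A usage sketch, assuming `embedder` is an already-configured `Embedder`
    /// and the caller can propagate `EmbedError`:
    ///
    /// ```ignore
    /// let embeddings = embedder.embed(&["hello", "world"], None)?;
    /// assert_eq!(embeddings.len(), 2);
    /// ```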
    pub fn embed<S: AsRef<str> + serde::Serialize>(
        &self,
        texts: &[S],
        deadline: Option<Instant>,
    ) -> Result<Vec<Embedding>, EmbedError> {
        match self.rest_embedder.embed_ref(texts, deadline) {
            Ok(embeddings) => Ok(embeddings),
            Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
                Err(EmbedError::ollama_model_not_found(error))
            }
            Err(error) => Err(error),
        }
    }

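    /// Embeds chunks of texts at indexing time, one request per chunk. Work
    /// fans out on the given thread pool unless the pool is already saturated
    /// with embedding requests, in which case chunks are processed serially on
    /// the calling thread.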
    pub fn embed_index(
        &self,
        text_chunks: Vec<Vec<String>>,
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
        // If the pool is already saturated with embedding requests, run on the
        // current thread instead of spawning more rayon jobs: this avoids
        // exhausting LMDB read transactions and overflowing the stack.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            text_chunks.into_iter().map(move |chunk| self.embed(&chunk, None)).collect()
        } else {
            threads
                .install(move || {
                    text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect()
                })
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

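    /// Like `embed_index`, but borrows the texts, chunks them by
    /// `prompt_count_in_chunk_hint`, and flattens the per-chunk results into a
    /// single list of embeddings.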
    pub(crate) fn embed_index_ref(
        &self,
        texts: &[&str],
        threads: &ThreadPoolNoAbort,
    ) -> Result<Vec<Vec<f32>>, EmbedError> {
        // If the pool is already saturated with embedding requests, run on the
        // current thread instead of spawning more rayon jobs: this avoids
        // exhausting LMDB read transactions and overflowing the stack.
        if threads.active_operations() >= REQUEST_PARALLELISM {
            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                .chunks(self.prompt_count_in_chunk_hint())
                .map(move |chunk| self.embed(chunk, None))
                .collect();

            let embeddings = embeddings?;
            Ok(embeddings.into_iter().flatten().collect())
        } else {
            threads
                .install(move || {
                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
                        .par_chunks(self.prompt_count_in_chunk_hint())
                        .map(move |chunk| self.embed(chunk, None))
                        .collect();

                    let embeddings = embeddings?;
                    Ok(embeddings.into_iter().flatten().collect())
                })
                .map_err(|error| EmbedError {
                    kind: EmbedErrorKind::PanicInThreadPool(error),
                    fault: FaultSource::Bug,
                })?
        }
    }

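    /// The preferred number of chunks to process at once, delegated to the
    /// underlying REST embedder.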
    pub fn chunk_count_hint(&self) -> usize {
        self.rest_embedder.chunk_count_hint()
    }

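    /// The preferred number of texts per request, delegated to the underlying
    /// REST embedder.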
    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        self.rest_embedder.prompt_count_in_chunk_hint()
    }

    pub fn dimensions(&self) -> usize {
        self.rest_embedder.dimensions()
    }

    pub fn distribution(&self) -> Option<DistributionShift> {
        self.rest_embedder.distribution()
    }

    pub(super) fn cache(&self) -> &EmbeddingCache {
        self.rest_embedder.cache()
    }
}

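/// Reads the embeddings endpoint from the `MEILI_OLLAMA_URL` environment
/// variable, falling back to the default local Ollama endpoint.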
fn get_ollama_path() -> String {
    // Important: the hostname alone is not enough; the value must be the full URL
    // of the embeddings endpoint.
    std::env::var("MEILI_OLLAMA_URL")
        .unwrap_or_else(|_| "http://localhost:11434/api/embeddings".to_string())
}