// cortex_retrieval/embedding/ollama.rs

use std::net::IpAddr;
use std::time::Duration;

use serde::{Deserialize, Serialize};

use super::{EmbedError, EmbedResult, Embedder};
/// Prefix shared by every backend id this module produces
/// (`ollama:<model>:<dim>`).
pub const OLLAMA_BACKEND_ID_PREFIX: &str = "ollama";

/// Endpoint used by the default constructor: port 11434 on the local machine.
pub const DEFAULT_OLLAMA_ENDPOINT: &str = "http://localhost:11434";

/// Model name used by the default constructor.
pub const DEFAULT_OLLAMA_EMBED_MODEL: &str = "nomic-embed-text";

/// Vector width expected from the default `nomic-embed-text` model.
pub const NOMIC_EMBED_DIM: usize = 768;

/// Per-request HTTP timeout (milliseconds) used unless overridden: 30 seconds.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
51
52#[derive(Debug, Serialize)]
54struct EmbedRequest<'a> {
55 model: &'a str,
56 prompt: &'a str,
57}
58
59#[derive(Debug, Deserialize)]
61struct EmbedResponse {
62 embedding: Vec<f64>,
63}
64
/// Returns `true` when `endpoint` points at the local machine.
///
/// Accepted: an optional `http://` / `https://` scheme, then an authority
/// whose host is `localhost` (any ASCII case), an IPv4 address in
/// `127.0.0.0/8`, or a bracketed IPv6 loopback literal such as `[::1]`,
/// optionally followed by `:port` and a path/query/fragment.
///
/// The check is deliberately strict because it decides where request text
/// may be sent:
/// * authorities containing `@` (userinfo) are rejected, because in
///   `http://localhost:11434@evil.example` the real host is `evil.example`;
/// * IP hosts are parsed with [`IpAddr`] instead of prefix-matched, so a DNS
///   name like `127.evil.example` is not mistaken for a loopback address.
fn is_loopback_endpoint(endpoint: &str) -> bool {
    let without_scheme = endpoint
        .strip_prefix("https://")
        .or_else(|| endpoint.strip_prefix("http://"))
        .unwrap_or(endpoint);

    // The authority ends at the first path, query or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(without_scheme);

    // Userinfo can be used to disguise the real host; never accept it.
    if authority.contains('@') {
        return false;
    }

    let host = if let Some(bracketed) = authority.strip_prefix('[') {
        // IPv6 literal: `[addr]`, optionally followed by `:port`.
        match bracketed.split_once(']') {
            Some((addr, rest)) if rest.is_empty() || rest.starts_with(':') => addr,
            _ => return false,
        }
    } else {
        authority.split(':').next().unwrap_or(authority)
    };

    host.eq_ignore_ascii_case("localhost")
        || matches!(host.parse::<IpAddr>(), Ok(ip) if ip.is_loopback())
}
91
/// Embedding backend that calls an Ollama server on the local machine.
///
/// Build one with `OllamaEmbedder::new` or `OllamaEmbedder::default_nomic`.
/// Both validate their inputs and refuse endpoints that are not loopback.
#[derive(Clone, Debug)]
pub struct OllamaEmbedder {
    /// Base URL of the server; request paths are appended to it.
    endpoint: String,
    /// Model name sent with every request.
    model: String,
    /// Length every returned vector must have.
    dim: usize,
    /// Precomputed `ollama:<model>:<dim>` identifier.
    backend_id: String,
    /// Per-request HTTP timeout in milliseconds.
    timeout_ms: u64,
}
108
109impl OllamaEmbedder {
110 pub fn new(
117 endpoint: impl Into<String>,
118 model: impl Into<String>,
119 dim: usize,
120 ) -> EmbedResult<Self> {
121 let endpoint = endpoint.into();
122 let model = model.into();
123
124 if endpoint.trim().is_empty() {
125 return Err(EmbedError::InvalidInput(
126 "OllamaEmbedder: endpoint must not be empty".to_string(),
127 ));
128 }
129 if model.trim().is_empty() {
130 return Err(EmbedError::InvalidInput(
131 "OllamaEmbedder: model must not be empty".to_string(),
132 ));
133 }
134 if dim == 0 {
135 return Err(EmbedError::InvalidInput(
136 "OllamaEmbedder: dim must be > 0".to_string(),
137 ));
138 }
139
140 if !is_loopback_endpoint(&endpoint) {
143 return Err(EmbedError::InvalidInput(format!(
144 "OllamaEmbedder: endpoint must be loopback-only (localhost/127.0.0.1/::1), got `{endpoint}`"
145 )));
146 }
147
148 let backend_id = format!("{OLLAMA_BACKEND_ID_PREFIX}:{model}:{dim}");
149 Ok(Self {
150 endpoint,
151 model,
152 dim,
153 backend_id,
154 timeout_ms: DEFAULT_TIMEOUT_MS,
155 })
156 }
157
158 pub fn default_nomic() -> EmbedResult<Self> {
160 Self::new(
161 DEFAULT_OLLAMA_ENDPOINT,
162 DEFAULT_OLLAMA_EMBED_MODEL,
163 NOMIC_EMBED_DIM,
164 )
165 }
166
167 #[must_use]
169 pub fn with_timeout_ms(mut self, ms: u64) -> Self {
170 self.timeout_ms = ms;
171 self
172 }
173
174 pub fn backend_id_for(model: &str, dim: usize) -> String {
177 format!("{OLLAMA_BACKEND_ID_PREFIX}:{model}:{dim}")
178 }
179}
180
181impl Embedder for OllamaEmbedder {
182 fn backend_id(&self) -> &str {
183 &self.backend_id
184 }
185
186 fn dim(&self) -> usize {
187 self.dim
188 }
189
190 fn embed(&self, text: &str, tags: &[String]) -> EmbedResult<Vec<f32>> {
191 let prompt = if tags.is_empty() {
194 text.to_string()
195 } else {
196 format!("{text} | {}", tags.join(" "))
197 };
198
199 let url = format!("{}/api/embeddings", self.endpoint);
200
201 let body = EmbedRequest {
202 model: &self.model,
203 prompt: &prompt,
204 };
205
206 let timeout = Duration::from_millis(self.timeout_ms);
207 let agent = ureq::AgentBuilder::new().timeout(timeout).build();
208
209 let body_json = serde_json::to_value(&body)
210 .map_err(|e| EmbedError::Backend(format!("request serialization failed: {e}")))?;
211
212 let response = agent
213 .post(&url)
214 .send_json(body_json)
215 .map_err(|err| EmbedError::Backend(format!("Ollama HTTP error: {err}")))?;
216
217 if response.status() != 200 {
218 let status = response.status();
219 return Err(EmbedError::Backend(format!(
220 "Ollama returned HTTP {status}"
221 )));
222 }
223
224 let response_text = response
225 .into_string()
226 .map_err(|e| EmbedError::Backend(format!("reading Ollama response body: {e}")))?;
227
228 let parsed: EmbedResponse = serde_json::from_str(&response_text)
229 .map_err(|e| EmbedError::Backend(format!("Ollama response parse: {e}")))?;
230
231 let vector: Vec<f32> = parsed.embedding.iter().map(|&v| v as f32).collect();
232
233 if vector.len() != self.dim {
234 return Err(EmbedError::DimensionMismatch {
235 backend_id: self.backend_id.clone(),
236 expected: self.dim,
237 actual: vector.len(),
238 });
239 }
240
241 Ok(vector)
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constructor_rejects_empty_endpoint() {
        let err = OllamaEmbedder::new("", "nomic-embed-text", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn constructor_rejects_empty_model() {
        let err = OllamaEmbedder::new("http://localhost:11434", "", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn constructor_rejects_zero_dim() {
        let err = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 0).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn constructor_rejects_non_loopback_endpoint() {
        for endpoint in [
            "http://example.com:11434",
            "http://10.0.0.5:11434",
            "http://192.168.1.2",
        ] {
            let err = OllamaEmbedder::new(endpoint, "nomic-embed-text", 768).unwrap_err();
            assert!(
                matches!(err, EmbedError::InvalidInput(_)),
                "{endpoint}: expected InvalidInput, got {err:?}"
            );
        }
    }

    #[test]
    fn constructor_accepts_loopback_endpoints() {
        for endpoint in [
            "http://localhost:11434",
            "http://127.0.0.1:11434",
            "http://[::1]:11434",
            "https://LOCALHOST/",
        ] {
            assert!(
                OllamaEmbedder::new(endpoint, "nomic-embed-text", 768).is_ok(),
                "{endpoint} should be accepted"
            );
        }
    }

    #[test]
    fn backend_id_encodes_model_and_dim() {
        let e = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 768).unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), 768);
    }

    #[test]
    fn backend_id_for_matches_instance() {
        let id = OllamaEmbedder::backend_id_for("nomic-embed-text", 768);
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(id, e.backend_id());
    }

    #[test]
    fn backend_id_for_matches_custom_instance() {
        let e = OllamaEmbedder::new("http://127.0.0.1:11434", "mxbai-embed-large", 1024).unwrap();
        assert_eq!(
            OllamaEmbedder::backend_id_for("mxbai-embed-large", 1024),
            e.backend_id()
        );
    }

    #[test]
    fn default_nomic_has_expected_backend_id() {
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), NOMIC_EMBED_DIM);
    }

    #[test]
    fn default_timeout_is_applied() {
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(e.timeout_ms, DEFAULT_TIMEOUT_MS);
    }

    #[test]
    fn with_timeout_ms_overrides_default() {
        let e = OllamaEmbedder::default_nomic()
            .unwrap()
            .with_timeout_ms(5_000);
        assert_eq!(e.timeout_ms, 5_000);
    }
}