Skip to main content

cortex_retrieval/embedding/
ollama.rs

1//! Ollama-backed embedder (Phase 4.C enrichment layer).
2//!
3//! [`OllamaEmbedder`] calls the Ollama `/api/embeddings` REST endpoint to
4//! produce real semantic vectors. It coexists alongside
5//! [`super::LocalStubEmbedder`]: both can be stored simultaneously under
6//! different `backend_id` values in the `memory_embeddings` side table.
7//!
8//! # Backend id format
9//!
10//! `"ollama:<model_name>:<dim>"` — e.g. `"ollama:nomic-embed-text:768"`.
11//!
12//! The dimension is part of the id so that a model upgrade (which typically
13//! changes dimensionality) automatically creates a new backend bucket rather
14//! than overwriting incomparable old vectors.
15//!
16//! # Synchronous HTTP
17//!
18//! `ureq` (already a workspace dependency) is used for the HTTP call. Ollama
19//! runs locally, so latency is dominated by the model inference time rather
20//! than network round-trip. The call is blocking; callers that require async
21//! should wrap it in `tokio::task::spawn_blocking`.
22//!
23//! # Error handling
24//!
//! Network errors, HTTP non-200 responses, and parse failures all surface as
//! [`EmbedError::Backend`]. A response vector whose length differs from the
//! configured dimension (including an empty vector) surfaces as
//! [`EmbedError::DimensionMismatch`].
28
29use std::time::Duration;
30
31use serde::{Deserialize, Serialize};
32
33use super::{EmbedError, EmbedResult, Embedder};
34
/// Prefix for all Ollama backend ids.
///
/// Joined with the model name and dimension to form
/// `"ollama:<model>:<dim>"`.
pub const OLLAMA_BACKEND_ID_PREFIX: &str = "ollama";

/// Default Ollama endpoint used when none is configured.
pub const DEFAULT_OLLAMA_ENDPOINT: &str = "http://localhost:11434";

/// Default Ollama embedding model.
pub const DEFAULT_OLLAMA_EMBED_MODEL: &str = "nomic-embed-text";

/// Expected output dimension for `nomic-embed-text`.
///
/// [`OllamaEmbedder::default_nomic`] passes this value as the embedder's
/// `dim`. The dimension is fixed at construction time and is never inferred
/// from a response: a returned vector of any other length is rejected with
/// [`EmbedError::DimensionMismatch`].
pub const NOMIC_EMBED_DIM: usize = 768;

/// Default HTTP timeout for embedding calls (milliseconds).
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
51
/// Ollama `/api/embeddings` request body.
///
/// Both fields are borrowed, so building a request does not copy the model
/// name or the prompt text.
#[derive(Debug, Serialize)]
struct EmbedRequest<'a> {
    /// Name of the Ollama model asked to produce the embedding.
    model: &'a str,
    /// Text to embed (the claim text, optionally followed by its tags).
    prompt: &'a str,
}
58
/// Ollama `/api/embeddings` response body.
///
/// Only the `embedding` field is declared here.
#[derive(Debug, Deserialize)]
struct EmbedResponse {
    /// Embedding components, deserialized as `f64`; the caller narrows them
    /// to `f32`.
    embedding: Vec<f64>,
}
64
/// Return `true` if `endpoint` targets a loopback address.
///
/// Accepted hosts: `localhost` (case-insensitive), any IPv4 literal in
/// `127.0.0.0/8`, and the bracketed IPv6 loopback literal (`[::1]`).
/// The scheme, port, path, query, and fragment are ignored; only the host
/// part of the authority is inspected.
///
/// The check is deliberately strict because it guards a security boundary:
///
/// * IP hosts are parsed as real address literals rather than matched by
///   prefix, so DNS names such as `127.evil.example.com` are rejected.
/// * Any authority containing userinfo (`user:pass@host`) is rejected, since
///   the text before `@` is not the host and could otherwise spoof one
///   (e.g. `http://localhost:x@evil.example.com/`).
fn is_loopback_endpoint(endpoint: &str) -> bool {
    let without_scheme = endpoint
        .strip_prefix("https://")
        .or_else(|| endpoint.strip_prefix("http://"))
        .unwrap_or(endpoint);

    // The authority ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(without_scheme);

    // Userinfo would let the apparent host differ from the real one.
    if authority.contains('@') {
        return false;
    }

    if let Some(bracketed) = authority.strip_prefix('[') {
        // IPv6 literal bracket form: `[::1]:11434`.
        let literal = bracketed.split(']').next().unwrap_or(bracketed);
        return matches!(
            literal.parse::<std::net::Ipv6Addr>(),
            Ok(ip) if ip.is_loopback()
        );
    }

    // IPv4 or hostname: drop `:port`.
    let host = authority.split(':').next().unwrap_or(authority);
    host.eq_ignore_ascii_case("localhost")
        || matches!(
            host.parse::<std::net::Ipv4Addr>(),
            Ok(ip) if ip.is_loopback()
        )
}
91
/// Embedder that calls Ollama's `/api/embeddings` endpoint.
///
/// Construct via [`OllamaEmbedder::new`] (endpoint + model + expected
/// dimension) or [`OllamaEmbedder::default_nomic`] (loopback,
/// `nomic-embed-text`, 768-dim).
///
/// The `backend_id` is fixed at construction time as
/// `"ollama:<model>:<dim>"`.
#[derive(Debug, Clone)]
pub struct OllamaEmbedder {
    /// Base URL of the Ollama server; validated as loopback-only in `new`.
    endpoint: String,
    /// Model name sent with every embedding request.
    model: String,
    /// Expected length of every returned embedding vector.
    dim: usize,
    /// Precomputed `"ollama:<model>:<dim>"` identifier.
    backend_id: String,
    /// HTTP timeout applied to each embedding call, in milliseconds.
    timeout_ms: u64,
}
108
109impl OllamaEmbedder {
110    /// Construct an embedder targeting `endpoint` (e.g.
111    /// `"http://localhost:11434"`) with `model` (e.g. `"nomic-embed-text"`)
112    /// and expected output dimensionality `dim`.
113    ///
114    /// Returns an error if `endpoint` is empty, `model` is empty, or `dim`
115    /// is zero.
116    pub fn new(
117        endpoint: impl Into<String>,
118        model: impl Into<String>,
119        dim: usize,
120    ) -> EmbedResult<Self> {
121        let endpoint = endpoint.into();
122        let model = model.into();
123
124        if endpoint.trim().is_empty() {
125            return Err(EmbedError::InvalidInput(
126                "OllamaEmbedder: endpoint must not be empty".to_string(),
127            ));
128        }
129        if model.trim().is_empty() {
130            return Err(EmbedError::InvalidInput(
131                "OllamaEmbedder: model must not be empty".to_string(),
132            ));
133        }
134        if dim == 0 {
135            return Err(EmbedError::InvalidInput(
136                "OllamaEmbedder: dim must be > 0".to_string(),
137            ));
138        }
139
140        // Enforce loopback-only: the endpoint hostname must be localhost,
141        // 127.0.0.1, or ::1. This mirrors the guardrail in CLAUDE.md §Ollama.
142        if !is_loopback_endpoint(&endpoint) {
143            return Err(EmbedError::InvalidInput(format!(
144                "OllamaEmbedder: endpoint must be loopback-only (localhost/127.0.0.1/::1), got `{endpoint}`"
145            )));
146        }
147
148        let backend_id = format!("{OLLAMA_BACKEND_ID_PREFIX}:{model}:{dim}");
149        Ok(Self {
150            endpoint,
151            model,
152            dim,
153            backend_id,
154            timeout_ms: DEFAULT_TIMEOUT_MS,
155        })
156    }
157
158    /// Convenience constructor: loopback Ollama, `nomic-embed-text`, 768 dim.
159    pub fn default_nomic() -> EmbedResult<Self> {
160        Self::new(
161            DEFAULT_OLLAMA_ENDPOINT,
162            DEFAULT_OLLAMA_EMBED_MODEL,
163            NOMIC_EMBED_DIM,
164        )
165    }
166
167    /// Override the HTTP timeout (milliseconds). Default is 30 000.
168    #[must_use]
169    pub fn with_timeout_ms(mut self, ms: u64) -> Self {
170        self.timeout_ms = ms;
171        self
172    }
173
174    /// Return the backend id for this embedder without constructing a full
175    /// instance. Useful for querying the store before creating the embedder.
176    pub fn backend_id_for(model: &str, dim: usize) -> String {
177        format!("{OLLAMA_BACKEND_ID_PREFIX}:{model}:{dim}")
178    }
179}
180
181impl Embedder for OllamaEmbedder {
182    fn backend_id(&self) -> &str {
183        &self.backend_id
184    }
185
186    fn dim(&self) -> usize {
187        self.dim
188    }
189
190    fn embed(&self, text: &str, tags: &[String]) -> EmbedResult<Vec<f32>> {
191        // Build the prompt: concatenate claim text + tags with a separator so
192        // tags influence the embedding without polluting the main claim signal.
193        let prompt = if tags.is_empty() {
194            text.to_string()
195        } else {
196            format!("{text} | {}", tags.join(" "))
197        };
198
199        let url = format!("{}/api/embeddings", self.endpoint);
200
201        let body = EmbedRequest {
202            model: &self.model,
203            prompt: &prompt,
204        };
205
206        let timeout = Duration::from_millis(self.timeout_ms);
207        let agent = ureq::AgentBuilder::new().timeout(timeout).build();
208
209        let body_json = serde_json::to_value(&body)
210            .map_err(|e| EmbedError::Backend(format!("request serialization failed: {e}")))?;
211
212        let response = agent
213            .post(&url)
214            .send_json(body_json)
215            .map_err(|err| EmbedError::Backend(format!("Ollama HTTP error: {err}")))?;
216
217        if response.status() != 200 {
218            let status = response.status();
219            return Err(EmbedError::Backend(format!(
220                "Ollama returned HTTP {status}"
221            )));
222        }
223
224        let response_text = response
225            .into_string()
226            .map_err(|e| EmbedError::Backend(format!("reading Ollama response body: {e}")))?;
227
228        let parsed: EmbedResponse = serde_json::from_str(&response_text)
229            .map_err(|e| EmbedError::Backend(format!("Ollama response parse: {e}")))?;
230
231        let vector: Vec<f32> = parsed.embedding.iter().map(|&v| v as f32).collect();
232
233        if vector.len() != self.dim {
234            return Err(EmbedError::DimensionMismatch {
235                backend_id: self.backend_id.clone(),
236                expected: self.dim,
237                actual: vector.len(),
238            });
239        }
240
241        Ok(vector)
242    }
243}
244
/// Unit tests for constructor validation and backend-id formatting.
///
/// None of these tests perform network I/O; `embed` is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constructor_rejects_empty_endpoint() {
        let err = OllamaEmbedder::new("", "nomic-embed-text", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn constructor_rejects_empty_model() {
        let err = OllamaEmbedder::new("http://localhost:11434", "", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn constructor_rejects_zero_dim() {
        let err = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 0).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    // The loopback guardrail is a security boundary, so pin both directions.
    #[test]
    fn constructor_rejects_non_loopback_endpoint() {
        let err =
            OllamaEmbedder::new("http://example.com:11434", "nomic-embed-text", 768).unwrap_err();
        assert!(
            matches!(err, EmbedError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn constructor_accepts_loopback_variants() {
        for endpoint in [
            "http://localhost:11434",
            "http://127.0.0.1:11434",
            "http://[::1]:11434",
        ] {
            assert!(
                OllamaEmbedder::new(endpoint, "nomic-embed-text", 768).is_ok(),
                "expected `{endpoint}` to be accepted as loopback"
            );
        }
    }

    #[test]
    fn backend_id_encodes_model_and_dim() {
        let e = OllamaEmbedder::new("http://localhost:11434", "nomic-embed-text", 768).unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), 768);
    }

    #[test]
    fn backend_id_for_matches_instance() {
        let id = OllamaEmbedder::backend_id_for("nomic-embed-text", 768);
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(id, e.backend_id());
    }

    #[test]
    fn default_nomic_has_expected_backend_id() {
        let e = OllamaEmbedder::default_nomic().unwrap();
        assert_eq!(e.backend_id(), "ollama:nomic-embed-text:768");
        assert_eq!(e.dim(), NOMIC_EMBED_DIM);
    }

    #[test]
    fn with_timeout_ms_overrides_default() {
        let e = OllamaEmbedder::default_nomic()
            .unwrap()
            .with_timeout_ms(5_000);
        assert_eq!(e.timeout_ms, 5_000);
    }
}