Skip to main content

cognee_embedding/
ollama.rs

//! Ollama embedding engine.
//!
//! Calls the Ollama `/api/embed` endpoint once per input text, issuing the
//! requests concurrently. Accepts all three response shapes that Ollama can return:
//! - `{"embeddings": [[...]]}` — standard Ollama `/api/embed`
//! - `{"embedding": [...]}` — legacy Ollama `/api/embeddings`
//! - `{"data": [{"embedding": [...]}]}` — OpenAI-compatible fallback shape
8
9use async_trait::async_trait;
10use futures::future;
11use serde::Serialize;
12use serde_json::Value;
13
14use crate::config::EmbeddingConfig;
15use crate::engine::EmbeddingEngine;
16use crate::error::{EmbeddingError, EmbeddingResult};
17use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};
18
19// ─── Request type ─────────────────────────────────────────────────────────────
20
21#[derive(Serialize)]
22struct OllamaEmbedRequest<'a> {
23    model: &'a str,
24    input: &'a str,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    dimensions: Option<usize>,
27}
28
29// ─── Engine ───────────────────────────────────────────────────────────────────
30
31/// Embedding engine that calls the Ollama `/api/embed` HTTP endpoint.
32///
33/// Sends one request per input text concurrently using `futures::future::join_all`.
34/// Transient HTTP errors (network failures, 429, 5xx) are retried with
35/// exponential back-off starting at 8 s (doubling to 128 s) for up to 128 s total.
36///
37/// # Response shapes
38///
39/// Ollama can return embeddings in three shapes depending on the version and endpoint:
40/// - `{"embeddings": [[...]]}` — standard `/api/embed` response
41/// - `{"embedding": [...]}` — legacy single-embedding response
42/// - `{"data": [{"embedding": [...]}]}` — OpenAI-compatible shape
43///
44/// All three shapes are handled transparently.
45pub struct OllamaEmbeddingEngine {
46    client: reqwest::Client,
47    /// Full URL to the Ollama embed endpoint, e.g. `http://localhost:11434/api/embed`.
48    endpoint: String,
49    model: String,
50    dimensions: usize,
51    batch_size: usize,
52    max_completion_tokens: usize,
53}
54
55impl OllamaEmbeddingEngine {
56    /// Construct a new engine from the given [`EmbeddingConfig`].
57    ///
58    /// Returns [`EmbeddingError::ConfigError`] if the `reqwest` client cannot
59    /// be built (e.g. invalid TLS or API key header value).
60    pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
61        let endpoint = config
62            .endpoint
63            .clone()
64            .unwrap_or_else(|| "http://localhost:11434/api/embed".to_string());
65
66        let mut default_headers = reqwest::header::HeaderMap::new();
67
68        if let Some(api_key) = &config.api_key
69            && !api_key.is_empty()
70        {
71            let bearer = format!("Bearer {api_key}");
72            let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
73                .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
74            default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
75        }
76
77        let client = reqwest::Client::builder()
78            .default_headers(default_headers)
79            .timeout(std::time::Duration::from_secs(30))
80            .build()
81            .map_err(|e| {
82                EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
83            })?;
84
85        Ok(Self {
86            client,
87            endpoint,
88            model: config.model.clone(),
89            dimensions: config.dimensions,
90            batch_size: config.batch_size,
91            max_completion_tokens: config.max_completion_tokens,
92        })
93    }
94
95    /// Truncate `text` to at most `max_completion_tokens * 4` characters.
96    ///
97    /// Truncation is on a Unicode character boundary, not a byte boundary.
98    /// The factor of 4 is the same heuristic used by the Python SDK.
99    fn truncate_text<'a>(&self, text: &'a str) -> &'a str {
100        let char_limit = self.max_completion_tokens * 4;
101        let byte_pos = text
102            .char_indices()
103            .nth(char_limit)
104            .map(|(i, _)| i)
105            .unwrap_or(text.len());
106        &text[..byte_pos]
107    }
108
109    /// Call the Ollama endpoint once for a single text (no retry).
110    async fn embed_single_once(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
111        let truncated = self.truncate_text(text);
112
113        let request_body = OllamaEmbedRequest {
114            model: &self.model,
115            input: truncated,
116            // Only send `dimensions` if it's non-zero; some older Ollama versions
117            // reject unknown fields.
118            dimensions: if self.dimensions > 0 {
119                Some(self.dimensions)
120            } else {
121                None
122            },
123        };
124
125        let response = self
126            .client
127            .post(&self.endpoint)
128            .json(&request_body)
129            .send()
130            .await
131            .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
132
133        let status = response.status();
134        if !status.is_success() {
135            let body = response
136                .text()
137                .await
138                .unwrap_or_else(|_| "<failed to read body>".to_string());
139            return Err(if status.as_u16() == 429 || status.is_server_error() {
140                EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
141            } else {
142                EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
143            });
144        }
145
146        let value: Value = response
147            .json()
148            .await
149            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
150
151        extract_embedding_from_value(&value)
152    }
153
154    /// Call the endpoint with exponential-jitter retry on transient errors.
155    ///
156    /// Retries for up to 128 s total. Wait starts at 8 s (matching the Python
157    /// Ollama engine) and doubles on each attempt, capped at 128 s.  A uniform
158    /// random jitter of `[0, wait_secs)` is added to prevent thundering-herd.
159    async fn embed_single_with_retry(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
160        let max_duration = std::time::Duration::from_secs(128);
161        let start = std::time::Instant::now();
162        let mut wait_secs = 8u64;
163        loop {
164            match self.embed_single_once(text).await {
165                Ok(v) => return Ok(v),
166                Err(e)
167                    if matches!(e, EmbeddingError::HttpError(_))
168                        && start.elapsed() < max_duration =>
169                {
170                    let jitter = rand::random::<u64>() % wait_secs;
171                    tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
172                    wait_secs = (wait_secs * 2).min(128);
173                }
174                Err(e) => return Err(e),
175            }
176        }
177    }
178
179    /// Embed all texts concurrently, one request per text.
180    async fn embed_all(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
181        let sanitized = sanitize_embedding_inputs(texts);
182        let sanitized_refs: Vec<&str> = sanitized.iter().map(|s| s.as_ref()).collect();
183
184        let futures: Vec<_> = sanitized_refs
185            .iter()
186            .map(|&text| self.embed_single_with_retry(text))
187            .collect();
188
189        let results = future::join_all(futures).await;
190
191        let embeddings: EmbeddingResult<Vec<Vec<f32>>> = results.into_iter().collect();
192
193        Ok(handle_embedding_response(
194            texts,
195            embeddings?,
196            self.dimensions,
197        ))
198    }
199}
200
201#[async_trait]
202impl EmbeddingEngine for OllamaEmbeddingEngine {
203    async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
204        if texts.is_empty() {
205            return Ok(Vec::new());
206        }
207        self.embed_all(texts).await
208    }
209
210    fn dimension(&self) -> usize {
211        self.dimensions
212    }
213
214    fn batch_size(&self) -> usize {
215        self.batch_size
216    }
217
218    fn max_sequence_length(&self) -> usize {
219        self.max_completion_tokens
220    }
221}
222
223// ─── Response parsing ─────────────────────────────────────────────────────────
224
225/// Extract a `Vec<f32>` from any of the three response shapes Ollama can return.
226///
227/// Shape 1 — standard `/api/embed`:
228/// ```json
229/// {"embeddings": [[0.1, 0.2, ...]]}
230/// ```
231///
232/// Shape 2 — legacy `/api/embeddings` (single embedding):
233/// ```json
234/// {"embedding": [0.1, 0.2, ...]}
235/// ```
236///
237/// Shape 3 — OpenAI-compatible:
238/// ```json
239/// {"data": [{"embedding": [0.1, 0.2, ...]}]}
240/// ```
241fn extract_embedding_from_value(value: &Value) -> EmbeddingResult<Vec<f32>> {
242    // Shape 1: {"embeddings": [[...]]}
243    if let Some(embeddings) = value.get("embeddings") {
244        if let Some(first) = embeddings.get(0) {
245            return parse_f32_array(first);
246        }
247        return Err(EmbeddingError::ApiError(
248            "Response 'embeddings' array is empty".to_string(),
249        ));
250    }
251
252    // Shape 2: {"embedding": [...]}
253    if let Some(embedding) = value.get("embedding") {
254        return parse_f32_array(embedding);
255    }
256
257    // Shape 3: {"data": [{"embedding": [...]}]}
258    if let Some(data) = value.get("data") {
259        if let Some(first) = data.get(0)
260            && let Some(embedding) = first.get("embedding")
261        {
262            return parse_f32_array(embedding);
263        }
264        return Err(EmbeddingError::ApiError(
265            "Response 'data' array is empty or missing 'embedding' field".to_string(),
266        ));
267    }
268
269    Err(EmbeddingError::ApiError(format!(
270        "Unrecognised response shape; expected 'embeddings', 'embedding', or 'data' key. Got: {value}"
271    )))
272}
273
274/// Parse a JSON array of numbers into a `Vec<f32>`.
275fn parse_f32_array(value: &Value) -> EmbeddingResult<Vec<f32>> {
276    let arr = value.as_array().ok_or_else(|| {
277        EmbeddingError::ApiError(format!("Expected a JSON array for embedding, got: {value}"))
278    })?;
279
280    arr.iter()
281        .map(|v| {
282            v.as_f64().map(|f| f as f32).ok_or_else(|| {
283                EmbeddingError::ApiError(format!("Non-numeric value in embedding array: {v}"))
284            })
285        })
286        .collect()
287}
288
289// ─── Tests ────────────────────────────────────────────────────────────────────
290
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::EmbeddingConfig;
    use crate::provider::EmbeddingProvider;

    /// Baseline config pointing at the default local endpoint, with no API key.
    fn make_config() -> EmbeddingConfig {
        EmbeddingConfig {
            provider: EmbeddingProvider::Ollama,
            model: "avr/sfr-embedding-mistral:latest".to_string(),
            dimensions: 1024,
            endpoint: None,
            api_key: None,
            api_version: None,
            max_completion_tokens: 8191,
            batch_size: 10,
            mock: false,
            mock_mode: Default::default(),
            #[cfg(feature = "onnx")]
            onnx: Default::default(),
            huggingface_tokenizer: None,
        }
    }

    // ── Construction ─────────────────────────────────────────────────────────

    #[test]
    fn test_constructor_defaults() {
        let config = make_config();
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        assert_eq!(engine.endpoint, "http://localhost:11434/api/embed");
        assert_eq!(engine.model, "avr/sfr-embedding-mistral:latest");
        assert_eq!(engine.dimension(), 1024);
        assert_eq!(engine.batch_size(), 10);
        assert_eq!(engine.max_sequence_length(), 8191);
    }

    #[test]
    fn test_constructor_custom_endpoint() {
        let config = EmbeddingConfig {
            endpoint: Some("http://my-ollama:11434/api/embed".to_string()),
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        assert_eq!(engine.endpoint, "http://my-ollama:11434/api/embed");
    }

    #[test]
    fn test_constructor_accepts_api_key() {
        let config = EmbeddingConfig {
            api_key: Some("secret-key".to_string()),
            ..make_config()
        };
        assert!(OllamaEmbeddingEngine::new(&config).is_ok());
    }

    #[test]
    fn test_constructor_ignores_empty_api_key() {
        let config = EmbeddingConfig {
            api_key: Some(String::new()),
            ..make_config()
        };
        assert!(OllamaEmbeddingEngine::new(&config).is_ok());
    }

    #[test]
    fn test_constructor_rejects_invalid_api_key() {
        // A newline cannot appear in an HTTP header value.
        let config = EmbeddingConfig {
            api_key: Some("bad\nkey".to_string()),
            ..make_config()
        };
        let result = OllamaEmbeddingEngine::new(&config);
        assert!(matches!(result, Err(EmbeddingError::ConfigError(_))));
    }

    // ── Request serialization ────────────────────────────────────────────────

    #[test]
    fn test_request_omits_missing_dimensions() {
        let body = OllamaEmbedRequest {
            model: "m",
            input: "hi",
            dimensions: None,
        };
        let json = serde_json::to_value(&body).expect("should serialize");
        assert_eq!(json, serde_json::json!({ "model": "m", "input": "hi" }));
    }

    #[test]
    fn test_request_includes_dimensions() {
        let body = OllamaEmbedRequest {
            model: "m",
            input: "hi",
            dimensions: Some(8),
        };
        let json = serde_json::to_value(&body).expect("should serialize");
        assert_eq!(
            json,
            serde_json::json!({ "model": "m", "input": "hi", "dimensions": 8 })
        );
    }

    // ── Truncation ───────────────────────────────────────────────────────────

    #[test]
    fn test_truncate_text_short() {
        let config = EmbeddingConfig {
            max_completion_tokens: 10,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        // "hello" is 5 chars, limit is 10 * 4 = 40 — no truncation
        let result = engine.truncate_text("hello");
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_truncate_text_exact_limit() {
        let config = EmbeddingConfig {
            max_completion_tokens: 2,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        // limit = 2 * 4 = 8 chars; "abcdefgh" is exactly 8 chars → no truncation
        let result = engine.truncate_text("abcdefgh");
        assert_eq!(result, "abcdefgh");
    }

    #[test]
    fn test_truncate_text_over_limit() {
        let config = EmbeddingConfig {
            max_completion_tokens: 2,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        // limit = 2 * 4 = 8 chars; "abcdefghij" has 10 chars → truncated to 8
        let result = engine.truncate_text("abcdefghij");
        assert_eq!(result, "abcdefgh");
    }

    #[test]
    fn test_truncate_text_unicode_boundary() {
        let config = EmbeddingConfig {
            max_completion_tokens: 1,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        // limit = 1 * 4 = 4 chars
        // "héllo" has 5 chars; 'é' is 2 bytes — must truncate at char boundary
        let result = engine.truncate_text("héllo");
        // First 4 chars: 'h', 'é', 'l', 'l'
        assert_eq!(result, "héll");
        // Must be valid UTF-8
        assert!(std::str::from_utf8(result.as_bytes()).is_ok());
    }

    #[test]
    fn test_truncate_text_multibyte_only() {
        let config = EmbeddingConfig {
            max_completion_tokens: 1,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        // Every char here is 3 bytes; limit is 4 chars.
        assert_eq!(engine.truncate_text("日本語テキスト"), "日本語テ");
    }

    #[test]
    fn test_truncate_text_empty() {
        let config = make_config();
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        assert_eq!(engine.truncate_text(""), "");
    }

    // ── Response shape parsing ───────────────────────────────────────────────

    #[test]
    fn test_parse_shape1_embeddings() {
        let json = serde_json::json!({
            "embeddings": [[0.1_f64, 0.2_f64, 0.3_f64]]
        });
        let result = extract_embedding_from_value(&json).expect("should parse shape 1");
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.1_f32).abs() < 1e-6);
        assert!((result[1] - 0.2_f32).abs() < 1e-6);
        assert!((result[2] - 0.3_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_shape1_uses_first_embedding() {
        let json = serde_json::json!({ "embeddings": [[1.0_f64], [2.0_f64]] });
        let result = extract_embedding_from_value(&json).expect("should parse shape 1");
        assert_eq!(result, vec![1.0_f32]);
    }

    #[test]
    fn test_parse_shape2_embedding() {
        let json = serde_json::json!({
            "embedding": [0.4_f64, 0.5_f64]
        });
        let result = extract_embedding_from_value(&json).expect("should parse shape 2");
        assert_eq!(result.len(), 2);
        assert!((result[0] - 0.4_f32).abs() < 1e-6);
        assert!((result[1] - 0.5_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_integer_values() {
        let json = serde_json::json!({ "embedding": [1, 2] });
        let result = extract_embedding_from_value(&json).expect("integers are numeric");
        assert_eq!(result, vec![1.0_f32, 2.0_f32]);
    }

    #[test]
    fn test_parse_shape3_data() {
        let json = serde_json::json!({
            "data": [{"embedding": [0.6_f64, 0.7_f64, 0.8_f64]}]
        });
        let result = extract_embedding_from_value(&json).expect("should parse shape 3");
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.6_f32).abs() < 1e-6);
        assert!((result[1] - 0.7_f32).abs() < 1e-6);
        assert!((result[2] - 0.8_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_unrecognised_shape() {
        let json = serde_json::json!({ "unknown": "value" });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_empty_embeddings_array() {
        let json = serde_json::json!({ "embeddings": [] });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_embeddings_not_array() {
        let json = serde_json::json!({ "embeddings": "oops" });
        let result = extract_embedding_from_value(&json);
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_empty_data_array() {
        let json = serde_json::json!({ "data": [] });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_data_missing_embedding_field() {
        let json = serde_json::json!({ "data": [{ "index": 0 }] });
        let result = extract_embedding_from_value(&json);
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_embedding_not_array() {
        let json = serde_json::json!({ "embedding": 1.0_f64 });
        let result = extract_embedding_from_value(&json);
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_non_numeric_values() {
        let json = serde_json::json!({ "embedding": ["not", "numbers"] });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }
}