Skip to main content

cognee_embedding/
ollama.rs

1//! Ollama embedding engine.
2//!
3//! Calls the Ollama `/api/embed` endpoint with a batched array `input`,
4//! sub-batched by `batch_size`, falling back to one concurrent request per text
5//! on servers that do not accept array input. Supports all three response
6//! shapes that Ollama can return:
7//! - `{"embeddings": [[...]]}` — standard Ollama `/api/embed`
8//! - `{"embedding": [...]}` — legacy Ollama `/api/embeddings`
9//! - `{"data": [{"embedding": [...]}]}` — OpenAI-compatible fallback shape
10
11use async_trait::async_trait;
12use futures::future;
13use serde::Serialize;
14use serde_json::Value;
15
16use crate::config::EmbeddingConfig;
17use crate::engine::EmbeddingEngine;
18use crate::error::{EmbeddingError, EmbeddingResult};
19use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};
20
21// ─── Request type ─────────────────────────────────────────────────────────────
22
23#[derive(Serialize)]
24struct OllamaEmbedRequest<'a> {
25    model: &'a str,
26    input: &'a str,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    dimensions: Option<usize>,
29}
30
31/// Batched request body: recent Ollama `/api/embed` accepts an array `input`
32/// and returns one embedding per element under the `embeddings` key.
33#[derive(Serialize)]
34struct OllamaBatchEmbedRequest<'a> {
35    model: &'a str,
36    input: Vec<&'a str>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    dimensions: Option<usize>,
39}
40
41/// Outcome of a failed batched (`array input`) request.
42///
43/// Only [`BatchError::ArrayUnsupported`] triggers the per-text fallback in
44/// [`OllamaEmbeddingEngine::embed_all`]; a [`BatchError::Fatal`] (real HTTP or
45/// parse error such as 404 model-not-found) propagates instead of fanning out
46/// `1 + N` doomed requests.
47enum BatchError {
48    /// The server likely ignores/does not support array `input`: it returned a
49    /// count that does not match the inputs or an unrecognised response shape.
50    ArrayUnsupported,
51    /// A genuine error that per-text requests would hit too.
52    Fatal(EmbeddingError),
53}
54
55// ─── Engine ───────────────────────────────────────────────────────────────────
56
57/// Embedding engine that calls the Ollama `/api/embed` HTTP endpoint.
58///
59/// Sends a batched array `input` per request, sub-batched by `batch_size`, and
60/// falls back to one concurrent request per text (via
61/// `futures::future::join_all`) for servers that do not accept array input.
62/// Transient HTTP errors (network failures, 429, 5xx) are retried with
63/// exponential back-off starting at 8 s (doubling to 128 s) for up to 128 s total.
64///
65/// # Response shapes
66///
67/// Ollama can return embeddings in three shapes depending on the version and endpoint:
68/// - `{"embeddings": [[...]]}` — standard `/api/embed` response
69/// - `{"embedding": [...]}` — legacy single-embedding response
70/// - `{"data": [{"embedding": [...]}]}` — OpenAI-compatible shape
71///
72/// All three shapes are handled transparently.
73pub struct OllamaEmbeddingEngine {
74    client: reqwest::Client,
75    /// Full URL to the Ollama embed endpoint, e.g. `http://localhost:11434/api/embed`.
76    endpoint: String,
77    model: String,
78    dimensions: usize,
79    batch_size: usize,
80    max_completion_tokens: usize,
81}
82
83impl OllamaEmbeddingEngine {
84    /// Construct a new engine from the given [`EmbeddingConfig`].
85    ///
86    /// Returns [`EmbeddingError::ConfigError`] if the `reqwest` client cannot
87    /// be built (e.g. invalid TLS or API key header value).
88    pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
89        let endpoint = config
90            .endpoint
91            .clone()
92            .unwrap_or_else(|| "http://localhost:11434/api/embed".to_string());
93
94        let mut default_headers = reqwest::header::HeaderMap::new();
95
96        if let Some(api_key) = &config.api_key
97            && !api_key.is_empty()
98        {
99            let bearer = format!("Bearer {api_key}");
100            let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
101                .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
102            default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
103        }
104
105        let client = reqwest::Client::builder()
106            .default_headers(default_headers)
107            .timeout(std::time::Duration::from_secs(30))
108            .build()
109            .map_err(|e| {
110                EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
111            })?;
112
113        Ok(Self {
114            client,
115            endpoint,
116            model: config.model.clone(),
117            dimensions: config.dimensions,
118            batch_size: config.batch_size,
119            max_completion_tokens: config.max_completion_tokens,
120        })
121    }
122
123    /// Truncate `text` to at most `max_completion_tokens * 4` characters.
124    ///
125    /// Truncation is on a Unicode character boundary, not a byte boundary.
126    /// The factor of 4 is the same heuristic used by the Python SDK.
127    fn truncate_text<'a>(&self, text: &'a str) -> &'a str {
128        let char_limit = self.max_completion_tokens * 4;
129        let byte_pos = text
130            .char_indices()
131            .nth(char_limit)
132            .map(|(i, _)| i)
133            .unwrap_or(text.len());
134        &text[..byte_pos]
135    }
136
137    /// Call the Ollama endpoint once for a single text (no retry).
138    async fn embed_single_once(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
139        let truncated = self.truncate_text(text);
140
141        let request_body = OllamaEmbedRequest {
142            model: &self.model,
143            input: truncated,
144            // Only send `dimensions` if it's non-zero; some older Ollama versions
145            // reject unknown fields.
146            dimensions: if self.dimensions > 0 {
147                Some(self.dimensions)
148            } else {
149                None
150            },
151        };
152
153        let response = self
154            .client
155            .post(&self.endpoint)
156            .json(&request_body)
157            .send()
158            .await
159            .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
160
161        let status = response.status();
162        if !status.is_success() {
163            let body = response
164                .text()
165                .await
166                .unwrap_or_else(|_| "<failed to read body>".to_string());
167            return Err(if status.as_u16() == 429 || status.is_server_error() {
168                EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
169            } else {
170                EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
171            });
172        }
173
174        let value: Value = response
175            .json()
176            .await
177            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
178
179        extract_embedding_from_value(&value)
180    }
181
182    /// Call the endpoint with exponential-jitter retry on transient errors.
183    ///
184    /// Retries for up to 128 s total. Wait starts at 8 s (matching the Python
185    /// Ollama engine) and doubles on each attempt, capped at 128 s.  A uniform
186    /// random jitter of `[0, wait_secs)` is added to prevent thundering-herd.
187    async fn embed_single_with_retry(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
188        let max_duration = std::time::Duration::from_secs(128);
189        let start = std::time::Instant::now();
190        let mut wait_secs = 8u64;
191        loop {
192            match self.embed_single_once(text).await {
193                Ok(v) => return Ok(v),
194                Err(e)
195                    if matches!(e, EmbeddingError::HttpError(_))
196                        && start.elapsed() < max_duration =>
197                {
198                    let jitter = rand::random::<u64>() % wait_secs;
199                    tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
200                    wait_secs = (wait_secs * 2).min(128);
201                }
202                Err(e) => return Err(e),
203            }
204        }
205    }
206
207    /// Call the endpoint once with an array `input` (no retry).
208    async fn embed_batch_once(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, BatchError> {
209        let truncated: Vec<&str> = texts.iter().map(|t| self.truncate_text(t)).collect();
210
211        let request_body = OllamaBatchEmbedRequest {
212            model: &self.model,
213            input: truncated,
214            dimensions: if self.dimensions > 0 {
215                Some(self.dimensions)
216            } else {
217                None
218            },
219        };
220
221        let response = self
222            .client
223            .post(&self.endpoint)
224            .json(&request_body)
225            .send()
226            .await
227            .map_err(|e| {
228                BatchError::Fatal(EmbeddingError::HttpError(format!("Request failed: {e}")))
229            })?;
230
231        let status = response.status();
232        if !status.is_success() {
233            let body = response
234                .text()
235                .await
236                .unwrap_or_else(|_| "<failed to read body>".to_string());
237            return Err(BatchError::Fatal(
238                if status.as_u16() == 429 || status.is_server_error() {
239                    EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
240                } else {
241                    EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
242                },
243            ));
244        }
245
246        let value: Value = response.json().await.map_err(|e| {
247            BatchError::Fatal(EmbeddingError::ApiError(format!(
248                "Failed to parse response: {e}"
249            )))
250        })?;
251
252        // An unrecognised shape or a count that doesn't match the inputs means the
253        // server ignored/rejected array `input`; treat it as "array unsupported"
254        // so the caller can fall back to per-text requests.
255        let embeddings =
256            extract_all_embeddings_from_value(&value).map_err(|_| BatchError::ArrayUnsupported)?;
257        if embeddings.len() != texts.len() {
258            return Err(BatchError::ArrayUnsupported);
259        }
260        Ok(embeddings)
261    }
262
263    /// Batch variant of [`embed_single_with_retry`], retrying transient errors.
264    async fn embed_batch_with_retry(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, BatchError> {
265        let max_duration = std::time::Duration::from_secs(128);
266        let start = std::time::Instant::now();
267        let mut wait_secs = 8u64;
268        loop {
269            match self.embed_batch_once(texts).await {
270                Ok(v) => return Ok(v),
271                Err(err) => {
272                    let transient = matches!(&err, BatchError::Fatal(EmbeddingError::HttpError(_)));
273                    if transient && start.elapsed() < max_duration {
274                        let jitter = rand::random::<u64>() % wait_secs;
275                        tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter))
276                            .await;
277                        wait_secs = (wait_secs * 2).min(128);
278                    } else {
279                        return Err(err);
280                    }
281                }
282            }
283        }
284    }
285
286    /// Embed all texts, sub-batched by `batch_size` using array `input`.
287    ///
288    /// Only falls back to one request per text when the server signals it does
289    /// not support array `input` ([`BatchError::ArrayUnsupported`]); genuine
290    /// errors propagate rather than fanning out `1 + N` doomed requests.
291    async fn embed_all(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
292        let sanitized = sanitize_embedding_inputs(texts);
293        let sanitized_refs: Vec<&str> = sanitized.iter().map(|s| s.as_ref()).collect();
294
295        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
296        for batch in sanitized_refs.chunks(self.batch_size.max(1)) {
297            match self.embed_batch_with_retry(batch).await {
298                Ok(batch_embeddings) => embeddings.extend(batch_embeddings),
299                Err(BatchError::ArrayUnsupported) => {
300                    let futures: Vec<_> = batch
301                        .iter()
302                        .map(|&text| self.embed_single_with_retry(text))
303                        .collect();
304                    for result in future::join_all(futures).await {
305                        embeddings.push(result?);
306                    }
307                }
308                Err(BatchError::Fatal(e)) => return Err(e),
309            }
310        }
311
312        Ok(handle_embedding_response(
313            texts,
314            embeddings,
315            self.dimensions,
316        ))
317    }
318}
319
320#[async_trait]
321impl EmbeddingEngine for OllamaEmbeddingEngine {
322    async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
323        if texts.is_empty() {
324            return Ok(Vec::new());
325        }
326        self.embed_all(texts).await
327    }
328
329    fn dimension(&self) -> usize {
330        self.dimensions
331    }
332
333    fn batch_size(&self) -> usize {
334        self.batch_size
335    }
336
337    fn max_sequence_length(&self) -> usize {
338        self.max_completion_tokens
339    }
340}
341
342// ─── Response parsing ─────────────────────────────────────────────────────────
343
344/// Extract a `Vec<f32>` from any of the three response shapes Ollama can return.
345///
346/// Shape 1 — standard `/api/embed`:
347/// ```json
348/// {"embeddings": [[0.1, 0.2, ...]]}
349/// ```
350///
351/// Shape 2 — legacy `/api/embeddings` (single embedding):
352/// ```json
353/// {"embedding": [0.1, 0.2, ...]}
354/// ```
355///
356/// Shape 3 — OpenAI-compatible:
357/// ```json
358/// {"data": [{"embedding": [0.1, 0.2, ...]}]}
359/// ```
360fn extract_embedding_from_value(value: &Value) -> EmbeddingResult<Vec<f32>> {
361    // Shape 1: {"embeddings": [[...]]}
362    if let Some(embeddings) = value.get("embeddings") {
363        if let Some(first) = embeddings.get(0) {
364            return parse_f32_array(first);
365        }
366        return Err(EmbeddingError::ApiError(
367            "Response 'embeddings' array is empty".to_string(),
368        ));
369    }
370
371    // Shape 2: {"embedding": [...]}
372    if let Some(embedding) = value.get("embedding") {
373        return parse_f32_array(embedding);
374    }
375
376    // Shape 3: {"data": [{"embedding": [...]}]}
377    if let Some(data) = value.get("data") {
378        if let Some(first) = data.get(0)
379            && let Some(embedding) = first.get("embedding")
380        {
381            return parse_f32_array(embedding);
382        }
383        return Err(EmbeddingError::ApiError(
384            "Response 'data' array is empty or missing 'embedding' field".to_string(),
385        ));
386    }
387
388    Err(EmbeddingError::ApiError(format!(
389        "Unrecognised response shape; expected 'embeddings', 'embedding', or 'data' key. Got: {value}"
390    )))
391}
392
393/// Extract every embedding from a batched response (array `input`).
394///
395/// Handles the same shapes as [`extract_embedding_from_value`] but returns all
396/// embeddings rather than just the first:
397/// - `{"embeddings": [[...], [...]]}` — standard `/api/embed`
398/// - `{"data": [{"embedding": [...]}, ...]}` — OpenAI-compatible
399/// - `{"embedding": [...]}` — single embedding, returned as a one-element vec
400fn extract_all_embeddings_from_value(value: &Value) -> EmbeddingResult<Vec<Vec<f32>>> {
401    if let Some(embeddings) = value.get("embeddings").and_then(|v| v.as_array()) {
402        return embeddings.iter().map(parse_f32_array).collect();
403    }
404
405    if let Some(data) = value.get("data").and_then(|v| v.as_array()) {
406        return data
407            .iter()
408            .map(|item| {
409                item.get("embedding").ok_or_else(|| {
410                    EmbeddingError::ApiError("Response 'data' item missing 'embedding'".to_string())
411                })
412            })
413            .map(|embedding| embedding.and_then(parse_f32_array))
414            .collect();
415    }
416
417    if let Some(embedding) = value.get("embedding") {
418        return Ok(vec![parse_f32_array(embedding)?]);
419    }
420
421    Err(EmbeddingError::ApiError(format!(
422        "Unrecognised response shape; expected 'embeddings', 'embedding', or 'data' key. Got: {value}"
423    )))
424}
425
426/// Parse a JSON array of numbers into a `Vec<f32>`.
427fn parse_f32_array(value: &Value) -> EmbeddingResult<Vec<f32>> {
428    let arr = value.as_array().ok_or_else(|| {
429        EmbeddingError::ApiError(format!("Expected a JSON array for embedding, got: {value}"))
430    })?;
431
432    arr.iter()
433        .map(|v| {
434            v.as_f64().map(|f| f as f32).ok_or_else(|| {
435                EmbeddingError::ApiError(format!("Non-numeric value in embedding array: {v}"))
436            })
437        })
438        .collect()
439}
440
441// ─── Tests ────────────────────────────────────────────────────────────────────
442
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::EmbeddingConfig;
    use crate::provider::EmbeddingProvider;

    // ── Shared fixtures ──────────────────────────────────────────────────────

    /// Baseline Ollama configuration that individual tests override.
    fn make_config() -> EmbeddingConfig {
        EmbeddingConfig {
            provider: EmbeddingProvider::Ollama,
            model: "avr/sfr-embedding-mistral:latest".to_string(),
            dimensions: 1024,
            endpoint: None,
            api_key: None,
            api_version: None,
            max_completion_tokens: 8191,
            batch_size: 10,
            mock: false,
            mock_mode: Default::default(),
            #[cfg(feature = "onnx")]
            onnx: Default::default(),
            huggingface_tokenizer: None,
        }
    }

    /// Build an engine from `config`, panicking if construction fails.
    fn build_engine(config: &EmbeddingConfig) -> OllamaEmbeddingEngine {
        OllamaEmbeddingEngine::new(config).expect("should construct engine")
    }

    /// Engine whose truncation limit is `max_completion_tokens * 4` characters.
    fn engine_with_token_limit(max_completion_tokens: usize) -> OllamaEmbeddingEngine {
        build_engine(&EmbeddingConfig {
            max_completion_tokens,
            ..make_config()
        })
    }

    /// Assert two floats agree to within `1e-6`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    /// Assert a parsed embedding has the expected length and values.
    fn assert_embedding(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert_close(*got, *want);
        }
    }

    // ── Construction ─────────────────────────────────────────────────────────

    #[test]
    fn test_constructor_defaults() {
        let engine = build_engine(&make_config());
        assert_eq!(engine.endpoint, "http://localhost:11434/api/embed");
        assert_eq!(engine.model, "avr/sfr-embedding-mistral:latest");
        assert_eq!(engine.dimension(), 1024);
        assert_eq!(engine.batch_size(), 10);
        assert_eq!(engine.max_sequence_length(), 8191);
    }

    #[test]
    fn test_constructor_custom_endpoint() {
        let engine = build_engine(&EmbeddingConfig {
            endpoint: Some("http://my-ollama:11434/api/embed".to_string()),
            ..make_config()
        });
        assert_eq!(engine.endpoint, "http://my-ollama:11434/api/embed");
    }

    // ── Truncation ───────────────────────────────────────────────────────────

    #[test]
    fn test_truncate_text_short() {
        // Limit is 10 * 4 = 40 chars; "hello" has 5, so it passes through.
        let engine = engine_with_token_limit(10);
        assert_eq!(engine.truncate_text("hello"), "hello");
    }

    #[test]
    fn test_truncate_text_exact_limit() {
        // Limit is 2 * 4 = 8 chars; "abcdefgh" is exactly 8, so nothing is cut.
        let engine = engine_with_token_limit(2);
        assert_eq!(engine.truncate_text("abcdefgh"), "abcdefgh");
    }

    #[test]
    fn test_truncate_text_over_limit() {
        // Limit is 2 * 4 = 8 chars; the 10-char input is cut down to 8.
        let engine = engine_with_token_limit(2);
        assert_eq!(engine.truncate_text("abcdefghij"), "abcdefgh");
    }

    #[test]
    fn test_truncate_text_unicode_boundary() {
        // Limit is 1 * 4 = 4 chars. "héllo" has 5 chars and 'é' takes two
        // bytes, so the cut must land on a char boundary: 'h', 'é', 'l', 'l'.
        let engine = engine_with_token_limit(1);
        let truncated = engine.truncate_text("héllo");
        assert_eq!(truncated, "héll");
        // Must be valid UTF-8
        assert!(std::str::from_utf8(truncated.as_bytes()).is_ok());
    }

    #[test]
    fn test_truncate_text_empty() {
        let engine = build_engine(&make_config());
        assert_eq!(engine.truncate_text(""), "");
    }

    // ── Response shape parsing ───────────────────────────────────────────────

    #[test]
    fn test_parse_shape1_embeddings() {
        let body = serde_json::json!({
            "embeddings": [[0.1_f64, 0.2_f64, 0.3_f64]]
        });
        let embedding = extract_embedding_from_value(&body).expect("should parse shape 1");
        assert_embedding(&embedding, &[0.1_f32, 0.2_f32, 0.3_f32]);
    }

    #[test]
    fn test_parse_shape2_embedding() {
        let body = serde_json::json!({
            "embedding": [0.4_f64, 0.5_f64]
        });
        let embedding = extract_embedding_from_value(&body).expect("should parse shape 2");
        assert_embedding(&embedding, &[0.4_f32, 0.5_f32]);
    }

    #[test]
    fn test_parse_shape3_data() {
        let body = serde_json::json!({
            "data": [{"embedding": [0.6_f64, 0.7_f64, 0.8_f64]}]
        });
        let embedding = extract_embedding_from_value(&body).expect("should parse shape 3");
        assert_embedding(&embedding, &[0.6_f32, 0.7_f32, 0.8_f32]);
    }

    #[test]
    fn test_parse_unrecognised_shape() {
        let body = serde_json::json!({ "unknown": "value" });
        let outcome = extract_embedding_from_value(&body);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_empty_embeddings_array() {
        let body = serde_json::json!({ "embeddings": [] });
        let outcome = extract_embedding_from_value(&body);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_empty_data_array() {
        let body = serde_json::json!({ "data": [] });
        let outcome = extract_embedding_from_value(&body);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_non_numeric_values() {
        let body = serde_json::json!({ "embedding": ["not", "numbers"] });
        let outcome = extract_embedding_from_value(&body);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(EmbeddingError::ApiError(_))));
    }

    // ── Batched response parsing (array input) ───────────────────────────────

    #[test]
    fn test_parse_all_embeddings_shape1() {
        let body = serde_json::json!({
            "embeddings": [[0.1_f64, 0.2_f64], [0.3_f64, 0.4_f64]]
        });
        let rows = extract_all_embeddings_from_value(&body).expect("should parse batch");
        assert_eq!(rows.len(), 2);
        assert_close(rows[0][0], 0.1_f32);
        assert_close(rows[1][1], 0.4_f32);
    }

    #[test]
    fn test_parse_all_embeddings_data_shape() {
        let body = serde_json::json!({
            "data": [{"embedding": [0.1_f64]}, {"embedding": [0.2_f64]}]
        });
        let rows = extract_all_embeddings_from_value(&body).expect("should parse batch");
        assert_eq!(rows.len(), 2);
        assert_close(rows[0][0], 0.1_f32);
        assert_close(rows[1][0], 0.2_f32);
    }

    #[test]
    fn test_parse_all_embeddings_single_shape() {
        let body = serde_json::json!({ "embedding": [0.5_f64, 0.6_f64] });
        let rows = extract_all_embeddings_from_value(&body).expect("should parse single");
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].len(), 2);
    }

    #[test]
    fn test_parse_all_embeddings_unrecognised() {
        let body = serde_json::json!({ "nope": 1 });
        let outcome = extract_all_embeddings_from_value(&body);
        assert!(matches!(outcome, Err(EmbeddingError::ApiError(_))));
    }

    // ── End-to-end batching / fallback (mock HTTP server) ────────────────────

    /// Two-dimensional config pointed at the mock server's `/api/embed`.
    fn config_for(server_url: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            dimensions: 2,
            endpoint: Some(format!("{server_url}/api/embed")),
            ..make_config()
        }
    }

    /// Body matcher for requests whose `input` is a JSON array.
    fn array_input() -> mockito::Matcher {
        mockito::Matcher::Regex(r#""input":\["#.to_string())
    }

    #[tokio::test]
    async fn embed_batches_array_input() {
        let mut server = mockito::Server::new_async().await;
        let array_mock = server
            .mock("POST", "/api/embed")
            .match_body(array_input())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embeddings":[[1.0,0.0],[0.0,1.0]]}"#)
            .create_async()
            .await;

        let engine = OllamaEmbeddingEngine::new(&config_for(&server.url())).unwrap();
        let vectors = engine.embed(&["alpha", "beta"]).await.unwrap();

        assert_eq!(vectors, vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        array_mock.assert_async().await;
    }

    #[tokio::test]
    async fn embed_falls_back_to_per_text_when_array_rejected() {
        let mut server = mockito::Server::new_async().await;
        // A legacy server ignores the array and answers with one embedding:
        // the count mismatch is read as "array unsupported" and triggers the
        // per-text fallback.
        let array_mock = server
            .mock("POST", "/api/embed")
            .match_body(array_input())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embedding":[9.9,9.9]}"#)
            .create_async()
            .await;
        // The per-text requests succeed; distinct vectors prove order is kept.
        let alpha_mock = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":"alpha""#.to_string()))
            .with_status(200)
            .with_body(r#"{"embedding":[1.0,0.0]}"#)
            .create_async()
            .await;
        let beta_mock = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":"beta""#.to_string()))
            .with_status(200)
            .with_body(r#"{"embedding":[0.0,1.0]}"#)
            .create_async()
            .await;

        let engine = OllamaEmbeddingEngine::new(&config_for(&server.url())).unwrap();
        let vectors = engine.embed(&["alpha", "beta"]).await.unwrap();

        assert_eq!(vectors, vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        array_mock.assert_async().await;
        alpha_mock.assert_async().await;
        beta_mock.assert_async().await;
    }

    #[tokio::test]
    async fn embed_does_not_panic_on_zero_batch_size() {
        let mut server = mockito::Server::new_async().await;
        // With a zero batch size every text becomes its own one-item batch
        // (chunks(1)), so the array endpoint is hit once per text.
        let array_mock = server
            .mock("POST", "/api/embed")
            .match_body(array_input())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embeddings":[[1.0,0.0]]}"#)
            .expect(2)
            .create_async()
            .await;

        let zero_batch = EmbeddingConfig {
            batch_size: 0,
            ..config_for(&server.url())
        };
        let engine = OllamaEmbeddingEngine::new(&zero_batch).unwrap();
        let vectors = engine.embed(&["alpha", "beta"]).await.unwrap();

        assert_eq!(vectors.len(), 2);
        array_mock.assert_async().await;
    }

    #[tokio::test]
    async fn embed_propagates_http_error_without_falling_back() {
        let mut server = mockito::Server::new_async().await;
        // A genuine 404 (e.g. model not found) must surface as an error
        // rather than fan out into per-text requests.
        let array_mock = server
            .mock("POST", "/api/embed")
            .match_body(array_input())
            .with_status(404)
            .with_body("model not found")
            .expect(1)
            .create_async()
            .await;
        // No per-text (string input) request may ever be issued.
        let per_text_mock = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":"[a-z]"#.to_string()))
            .with_status(200)
            .with_body(r#"{"embedding":[0.0,0.0]}"#)
            .expect(0)
            .create_async()
            .await;

        let engine = OllamaEmbeddingEngine::new(&config_for(&server.url())).unwrap();
        let outcome = engine.embed(&["alpha", "beta"]).await;

        assert!(outcome.is_err());
        array_mock.assert_async().await;
        per_text_mock.assert_async().await;
    }
}
779        assert!(result.is_err());
780        batch.assert_async().await;
781        per_text.assert_async().await;
782    }
783}