1use async_trait::async_trait;
12use futures::future;
13use serde::Serialize;
14use serde_json::Value;
15
16use crate::config::EmbeddingConfig;
17use crate::engine::EmbeddingEngine;
18use crate::error::{EmbeddingError, EmbeddingResult};
19use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};
20
21#[derive(Serialize)]
24struct OllamaEmbedRequest<'a> {
25 model: &'a str,
26 input: &'a str,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 dimensions: Option<usize>,
29}
30
31#[derive(Serialize)]
34struct OllamaBatchEmbedRequest<'a> {
35 model: &'a str,
36 input: Vec<&'a str>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 dimensions: Option<usize>,
39}
40
41enum BatchError {
48 ArrayUnsupported,
51 Fatal(EmbeddingError),
53}
54
55pub struct OllamaEmbeddingEngine {
74 client: reqwest::Client,
75 endpoint: String,
77 model: String,
78 dimensions: usize,
79 batch_size: usize,
80 max_completion_tokens: usize,
81}
82
83impl OllamaEmbeddingEngine {
84 pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
89 let endpoint = config
90 .endpoint
91 .clone()
92 .unwrap_or_else(|| "http://localhost:11434/api/embed".to_string());
93
94 let mut default_headers = reqwest::header::HeaderMap::new();
95
96 if let Some(api_key) = &config.api_key
97 && !api_key.is_empty()
98 {
99 let bearer = format!("Bearer {api_key}");
100 let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
101 .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
102 default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
103 }
104
105 let client = reqwest::Client::builder()
106 .default_headers(default_headers)
107 .timeout(std::time::Duration::from_secs(30))
108 .build()
109 .map_err(|e| {
110 EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
111 })?;
112
113 Ok(Self {
114 client,
115 endpoint,
116 model: config.model.clone(),
117 dimensions: config.dimensions,
118 batch_size: config.batch_size,
119 max_completion_tokens: config.max_completion_tokens,
120 })
121 }
122
123 fn truncate_text<'a>(&self, text: &'a str) -> &'a str {
128 let char_limit = self.max_completion_tokens * 4;
129 let byte_pos = text
130 .char_indices()
131 .nth(char_limit)
132 .map(|(i, _)| i)
133 .unwrap_or(text.len());
134 &text[..byte_pos]
135 }
136
137 async fn embed_single_once(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
139 let truncated = self.truncate_text(text);
140
141 let request_body = OllamaEmbedRequest {
142 model: &self.model,
143 input: truncated,
144 dimensions: if self.dimensions > 0 {
147 Some(self.dimensions)
148 } else {
149 None
150 },
151 };
152
153 let response = self
154 .client
155 .post(&self.endpoint)
156 .json(&request_body)
157 .send()
158 .await
159 .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
160
161 let status = response.status();
162 if !status.is_success() {
163 let body = response
164 .text()
165 .await
166 .unwrap_or_else(|_| "<failed to read body>".to_string());
167 return Err(if status.as_u16() == 429 || status.is_server_error() {
168 EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
169 } else {
170 EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
171 });
172 }
173
174 let value: Value = response
175 .json()
176 .await
177 .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
178
179 extract_embedding_from_value(&value)
180 }
181
182 async fn embed_single_with_retry(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
188 let max_duration = std::time::Duration::from_secs(128);
189 let start = std::time::Instant::now();
190 let mut wait_secs = 8u64;
191 loop {
192 match self.embed_single_once(text).await {
193 Ok(v) => return Ok(v),
194 Err(e)
195 if matches!(e, EmbeddingError::HttpError(_))
196 && start.elapsed() < max_duration =>
197 {
198 let jitter = rand::random::<u64>() % wait_secs;
199 tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
200 wait_secs = (wait_secs * 2).min(128);
201 }
202 Err(e) => return Err(e),
203 }
204 }
205 }
206
207 async fn embed_batch_once(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, BatchError> {
209 let truncated: Vec<&str> = texts.iter().map(|t| self.truncate_text(t)).collect();
210
211 let request_body = OllamaBatchEmbedRequest {
212 model: &self.model,
213 input: truncated,
214 dimensions: if self.dimensions > 0 {
215 Some(self.dimensions)
216 } else {
217 None
218 },
219 };
220
221 let response = self
222 .client
223 .post(&self.endpoint)
224 .json(&request_body)
225 .send()
226 .await
227 .map_err(|e| {
228 BatchError::Fatal(EmbeddingError::HttpError(format!("Request failed: {e}")))
229 })?;
230
231 let status = response.status();
232 if !status.is_success() {
233 let body = response
234 .text()
235 .await
236 .unwrap_or_else(|_| "<failed to read body>".to_string());
237 return Err(BatchError::Fatal(
238 if status.as_u16() == 429 || status.is_server_error() {
239 EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
240 } else {
241 EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
242 },
243 ));
244 }
245
246 let value: Value = response.json().await.map_err(|e| {
247 BatchError::Fatal(EmbeddingError::ApiError(format!(
248 "Failed to parse response: {e}"
249 )))
250 })?;
251
252 let embeddings =
256 extract_all_embeddings_from_value(&value).map_err(|_| BatchError::ArrayUnsupported)?;
257 if embeddings.len() != texts.len() {
258 return Err(BatchError::ArrayUnsupported);
259 }
260 Ok(embeddings)
261 }
262
263 async fn embed_batch_with_retry(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, BatchError> {
265 let max_duration = std::time::Duration::from_secs(128);
266 let start = std::time::Instant::now();
267 let mut wait_secs = 8u64;
268 loop {
269 match self.embed_batch_once(texts).await {
270 Ok(v) => return Ok(v),
271 Err(err) => {
272 let transient = matches!(&err, BatchError::Fatal(EmbeddingError::HttpError(_)));
273 if transient && start.elapsed() < max_duration {
274 let jitter = rand::random::<u64>() % wait_secs;
275 tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter))
276 .await;
277 wait_secs = (wait_secs * 2).min(128);
278 } else {
279 return Err(err);
280 }
281 }
282 }
283 }
284 }
285
286 async fn embed_all(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
292 let sanitized = sanitize_embedding_inputs(texts);
293 let sanitized_refs: Vec<&str> = sanitized.iter().map(|s| s.as_ref()).collect();
294
295 let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
296 for batch in sanitized_refs.chunks(self.batch_size.max(1)) {
297 match self.embed_batch_with_retry(batch).await {
298 Ok(batch_embeddings) => embeddings.extend(batch_embeddings),
299 Err(BatchError::ArrayUnsupported) => {
300 let futures: Vec<_> = batch
301 .iter()
302 .map(|&text| self.embed_single_with_retry(text))
303 .collect();
304 for result in future::join_all(futures).await {
305 embeddings.push(result?);
306 }
307 }
308 Err(BatchError::Fatal(e)) => return Err(e),
309 }
310 }
311
312 Ok(handle_embedding_response(
313 texts,
314 embeddings,
315 self.dimensions,
316 ))
317 }
318}
319
320#[async_trait]
321impl EmbeddingEngine for OllamaEmbeddingEngine {
322 async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
323 if texts.is_empty() {
324 return Ok(Vec::new());
325 }
326 self.embed_all(texts).await
327 }
328
329 fn dimension(&self) -> usize {
330 self.dimensions
331 }
332
333 fn batch_size(&self) -> usize {
334 self.batch_size
335 }
336
337 fn max_sequence_length(&self) -> usize {
338 self.max_completion_tokens
339 }
340}
341
342fn extract_embedding_from_value(value: &Value) -> EmbeddingResult<Vec<f32>> {
361 if let Some(embeddings) = value.get("embeddings") {
363 if let Some(first) = embeddings.get(0) {
364 return parse_f32_array(first);
365 }
366 return Err(EmbeddingError::ApiError(
367 "Response 'embeddings' array is empty".to_string(),
368 ));
369 }
370
371 if let Some(embedding) = value.get("embedding") {
373 return parse_f32_array(embedding);
374 }
375
376 if let Some(data) = value.get("data") {
378 if let Some(first) = data.get(0)
379 && let Some(embedding) = first.get("embedding")
380 {
381 return parse_f32_array(embedding);
382 }
383 return Err(EmbeddingError::ApiError(
384 "Response 'data' array is empty or missing 'embedding' field".to_string(),
385 ));
386 }
387
388 Err(EmbeddingError::ApiError(format!(
389 "Unrecognised response shape; expected 'embeddings', 'embedding', or 'data' key. Got: {value}"
390 )))
391}
392
393fn extract_all_embeddings_from_value(value: &Value) -> EmbeddingResult<Vec<Vec<f32>>> {
401 if let Some(embeddings) = value.get("embeddings").and_then(|v| v.as_array()) {
402 return embeddings.iter().map(parse_f32_array).collect();
403 }
404
405 if let Some(data) = value.get("data").and_then(|v| v.as_array()) {
406 return data
407 .iter()
408 .map(|item| {
409 item.get("embedding").ok_or_else(|| {
410 EmbeddingError::ApiError("Response 'data' item missing 'embedding'".to_string())
411 })
412 })
413 .map(|embedding| embedding.and_then(parse_f32_array))
414 .collect();
415 }
416
417 if let Some(embedding) = value.get("embedding") {
418 return Ok(vec![parse_f32_array(embedding)?]);
419 }
420
421 Err(EmbeddingError::ApiError(format!(
422 "Unrecognised response shape; expected 'embeddings', 'embedding', or 'data' key. Got: {value}"
423 )))
424}
425
426fn parse_f32_array(value: &Value) -> EmbeddingResult<Vec<f32>> {
428 let arr = value.as_array().ok_or_else(|| {
429 EmbeddingError::ApiError(format!("Expected a JSON array for embedding, got: {value}"))
430 })?;
431
432 arr.iter()
433 .map(|v| {
434 v.as_f64().map(|f| f as f32).ok_or_else(|| {
435 EmbeddingError::ApiError(format!("Non-numeric value in embedding array: {v}"))
436 })
437 })
438 .collect()
439}
440
#[cfg(test)]
// Tests may panic freely: a panic is simply a failed test.
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::EmbeddingConfig;
    use crate::provider::EmbeddingProvider;

    /// Baseline config: Ollama provider, default endpoint (`None`), 1024 dimensions,
    /// batches of 10 and an 8191-token truncation budget.
    fn make_config() -> EmbeddingConfig {
        EmbeddingConfig {
            provider: EmbeddingProvider::Ollama,
            model: "avr/sfr-embedding-mistral:latest".to_string(),
            dimensions: 1024,
            endpoint: None,
            api_key: None,
            api_version: None,
            max_completion_tokens: 8191,
            batch_size: 10,
            mock: false,
            mock_mode: Default::default(),
            #[cfg(feature = "onnx")]
            onnx: Default::default(),
            huggingface_tokenizer: None,
        }
    }

    // --- Construction -------------------------------------------------------------

    #[test]
    fn test_constructor_defaults() {
        let config = make_config();
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        assert_eq!(engine.endpoint, "http://localhost:11434/api/embed");
        assert_eq!(engine.model, "avr/sfr-embedding-mistral:latest");
        assert_eq!(engine.dimension(), 1024);
        assert_eq!(engine.batch_size(), 10);
        assert_eq!(engine.max_sequence_length(), 8191);
    }

    #[test]
    fn test_constructor_custom_endpoint() {
        let config = EmbeddingConfig {
            endpoint: Some("http://my-ollama:11434/api/embed".to_string()),
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        assert_eq!(engine.endpoint, "http://my-ollama:11434/api/embed");
    }

    // --- Truncation (limit = max_completion_tokens * 4 characters) ------------------

    #[test]
    fn test_truncate_text_short() {
        let config = EmbeddingConfig {
            max_completion_tokens: 10,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        let result = engine.truncate_text("hello");
        // 10 tokens -> 40-char budget; a 5-char input passes through untouched.
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_truncate_text_exact_limit() {
        let config = EmbeddingConfig {
            max_completion_tokens: 2,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        let result = engine.truncate_text("abcdefgh");
        // 2 tokens -> 8-char budget; an input of exactly 8 chars is kept whole.
        assert_eq!(result, "abcdefgh");
    }

    #[test]
    fn test_truncate_text_over_limit() {
        let config = EmbeddingConfig {
            max_completion_tokens: 2,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        let result = engine.truncate_text("abcdefghij");
        // 10 chars against an 8-char budget: only the first 8 survive.
        assert_eq!(result, "abcdefgh");
    }

    #[test]
    fn test_truncate_text_unicode_boundary() {
        let config = EmbeddingConfig {
            max_completion_tokens: 1,
            ..make_config()
        };
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        let result = engine.truncate_text("héllo");
        // 1 token -> 4-char budget. 'é' is two bytes in UTF-8, so the limit must be
        // counted in chars, not bytes, to yield "héll".
        assert_eq!(result, "héll");
        // The slice must still be valid UTF-8 (no split code point).
        assert!(std::str::from_utf8(result.as_bytes()).is_ok());
    }

    #[test]
    fn test_truncate_text_empty() {
        let config = make_config();
        let engine = OllamaEmbeddingEngine::new(&config).expect("should construct engine");
        assert_eq!(engine.truncate_text(""), "");
    }

    // --- Single-response parsing: the three accepted shapes and error cases ----------

    #[test]
    fn test_parse_shape1_embeddings() {
        let json = serde_json::json!({
            "embeddings": [[0.1_f64, 0.2_f64, 0.3_f64]]
        });
        let result = extract_embedding_from_value(&json).expect("should parse shape 1");
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.1_f32).abs() < 1e-6);
        assert!((result[1] - 0.2_f32).abs() < 1e-6);
        assert!((result[2] - 0.3_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_shape2_embedding() {
        let json = serde_json::json!({
            "embedding": [0.4_f64, 0.5_f64]
        });
        let result = extract_embedding_from_value(&json).expect("should parse shape 2");
        assert_eq!(result.len(), 2);
        assert!((result[0] - 0.4_f32).abs() < 1e-6);
        assert!((result[1] - 0.5_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_shape3_data() {
        let json = serde_json::json!({
            "data": [{"embedding": [0.6_f64, 0.7_f64, 0.8_f64]}]
        });
        let result = extract_embedding_from_value(&json).expect("should parse shape 3");
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.6_f32).abs() < 1e-6);
        assert!((result[1] - 0.7_f32).abs() < 1e-6);
        assert!((result[2] - 0.8_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_unrecognised_shape() {
        let json = serde_json::json!({ "unknown": "value" });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_empty_embeddings_array() {
        let json = serde_json::json!({ "embeddings": [] });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_empty_data_array() {
        let json = serde_json::json!({ "data": [] });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    #[test]
    fn test_parse_non_numeric_values() {
        let json = serde_json::json!({ "embedding": ["not", "numbers"] });
        let result = extract_embedding_from_value(&json);
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    // --- Batch-response parsing -----------------------------------------------------

    #[test]
    fn test_parse_all_embeddings_shape1() {
        let json = serde_json::json!({
            "embeddings": [[0.1_f64, 0.2_f64], [0.3_f64, 0.4_f64]]
        });
        let result = extract_all_embeddings_from_value(&json).expect("should parse batch");
        assert_eq!(result.len(), 2);
        assert!((result[0][0] - 0.1_f32).abs() < 1e-6);
        assert!((result[1][1] - 0.4_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_all_embeddings_data_shape() {
        let json = serde_json::json!({
            "data": [{"embedding": [0.1_f64]}, {"embedding": [0.2_f64]}]
        });
        let result = extract_all_embeddings_from_value(&json).expect("should parse batch");
        assert_eq!(result.len(), 2);
        assert!((result[0][0] - 0.1_f32).abs() < 1e-6);
        assert!((result[1][0] - 0.2_f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_all_embeddings_single_shape() {
        let json = serde_json::json!({ "embedding": [0.5_f64, 0.6_f64] });
        let result = extract_all_embeddings_from_value(&json).expect("should parse single");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 2);
    }

    #[test]
    fn test_parse_all_embeddings_unrecognised() {
        let json = serde_json::json!({ "nope": 1 });
        assert!(matches!(
            extract_all_embeddings_from_value(&json),
            Err(EmbeddingError::ApiError(_))
        ));
    }

    /// Config pointing at a mock server's `/api/embed`, with 2 dimensions so the
    /// fixture vectors stay tiny.
    fn config_for(server_url: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            dimensions: 2,
            endpoint: Some(format!("{server_url}/api/embed")),
            ..make_config()
        }
    }

    // --- End-to-end against a mock HTTP server ---------------------------------------

    #[tokio::test]
    async fn embed_batches_array_input() {
        let mut server = mockito::Server::new_async().await;
        // Matches only request bodies whose `input` is a JSON array.
        let batch = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":\["#.to_string()))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embeddings":[[1.0,0.0],[0.0,1.0]]}"#)
            .create_async()
            .await;

        let engine = OllamaEmbeddingEngine::new(&config_for(&server.url())).unwrap();
        let out = engine.embed(&["alpha", "beta"]).await.unwrap();

        assert_eq!(out, vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        batch.assert_async().await;
    }

    #[tokio::test]
    async fn embed_falls_back_to_per_text_when_array_rejected() {
        let mut server = mockito::Server::new_async().await;
        // The array request gets a single-vector reply (one embedding for two inputs).
        // The engine must treat this as "array input unsupported" and fall back.
        let batch = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":\["#.to_string()))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embedding":[9.9,9.9]}"#)
            .create_async()
            .await;
        // Per-text fallback requests, matched by their scalar `input` value.
        let single_a = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":"alpha""#.to_string()))
            .with_status(200)
            .with_body(r#"{"embedding":[1.0,0.0]}"#)
            .create_async()
            .await;
        let single_b = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":"beta""#.to_string()))
            .with_status(200)
            .with_body(r#"{"embedding":[0.0,1.0]}"#)
            .create_async()
            .await;

        let engine = OllamaEmbeddingEngine::new(&config_for(&server.url())).unwrap();
        let out = engine.embed(&["alpha", "beta"]).await.unwrap();

        assert_eq!(out, vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        batch.assert_async().await;
        single_a.assert_async().await;
        single_b.assert_async().await;
    }

    #[tokio::test]
    async fn embed_does_not_panic_on_zero_batch_size() {
        let mut server = mockito::Server::new_async().await;
        // batch_size 0 is clamped to 1, so each text goes out as its own one-element
        // array request: two hits, each answered with one embedding.
        let batch = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":\["#.to_string()))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embeddings":[[1.0,0.0]]}"#)
            .expect(2)
            .create_async()
            .await;

        let config = EmbeddingConfig {
            batch_size: 0,
            ..config_for(&server.url())
        };
        let engine = OllamaEmbeddingEngine::new(&config).unwrap();
        let out = engine.embed(&["alpha", "beta"]).await.unwrap();

        assert_eq!(out.len(), 2);
        batch.assert_async().await;
    }

    #[tokio::test]
    async fn embed_propagates_http_error_without_falling_back() {
        let mut server = mockito::Server::new_async().await;
        // A 404 is a non-transient client error: it must surface after exactly one
        // array request, with no retry.
        let batch = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":\["#.to_string()))
            .with_status(404)
            .with_body("model not found")
            .expect(1)
            .create_async()
            .await;
        // Any scalar-`input` (per-text fallback) request would match here; expect(0)
        // asserts that no fallback happens.
        let per_text = server
            .mock("POST", "/api/embed")
            .match_body(mockito::Matcher::Regex(r#""input":"[a-z]"#.to_string()))
            .with_status(200)
            .with_body(r#"{"embedding":[0.0,0.0]}"#)
            .expect(0)
            .create_async()
            .await;

        let engine = OllamaEmbeddingEngine::new(&config_for(&server.url())).unwrap();
        let result = engine.embed(&["alpha", "beta"]).await;

        assert!(result.is_err());
        batch.assert_async().await;
        per_text.assert_async().await;
    }
}