cognee_embedding/
ollama.rs1use async_trait::async_trait;
10use futures::future;
11use serde::Serialize;
12use serde_json::Value;
13
14use crate::config::EmbeddingConfig;
15use crate::engine::EmbeddingEngine;
16use crate::error::{EmbeddingError, EmbeddingResult};
17use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};
18
19#[derive(Serialize)]
22struct OllamaEmbedRequest<'a> {
23 model: &'a str,
24 input: &'a str,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 dimensions: Option<usize>,
27}
28
29pub struct OllamaEmbeddingEngine {
46 client: reqwest::Client,
47 endpoint: String,
49 model: String,
50 dimensions: usize,
51 batch_size: usize,
52 max_completion_tokens: usize,
53}
54
55impl OllamaEmbeddingEngine {
56 pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
61 let endpoint = config
62 .endpoint
63 .clone()
64 .unwrap_or_else(|| "http://localhost:11434/api/embed".to_string());
65
66 let mut default_headers = reqwest::header::HeaderMap::new();
67
68 if let Some(api_key) = &config.api_key
69 && !api_key.is_empty()
70 {
71 let bearer = format!("Bearer {api_key}");
72 let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
73 .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
74 default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
75 }
76
77 let client = reqwest::Client::builder()
78 .default_headers(default_headers)
79 .timeout(std::time::Duration::from_secs(30))
80 .build()
81 .map_err(|e| {
82 EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
83 })?;
84
85 Ok(Self {
86 client,
87 endpoint,
88 model: config.model.clone(),
89 dimensions: config.dimensions,
90 batch_size: config.batch_size,
91 max_completion_tokens: config.max_completion_tokens,
92 })
93 }
94
95 fn truncate_text<'a>(&self, text: &'a str) -> &'a str {
100 let char_limit = self.max_completion_tokens * 4;
101 let byte_pos = text
102 .char_indices()
103 .nth(char_limit)
104 .map(|(i, _)| i)
105 .unwrap_or(text.len());
106 &text[..byte_pos]
107 }
108
109 async fn embed_single_once(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
111 let truncated = self.truncate_text(text);
112
113 let request_body = OllamaEmbedRequest {
114 model: &self.model,
115 input: truncated,
116 dimensions: if self.dimensions > 0 {
119 Some(self.dimensions)
120 } else {
121 None
122 },
123 };
124
125 let response = self
126 .client
127 .post(&self.endpoint)
128 .json(&request_body)
129 .send()
130 .await
131 .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
132
133 let status = response.status();
134 if !status.is_success() {
135 let body = response
136 .text()
137 .await
138 .unwrap_or_else(|_| "<failed to read body>".to_string());
139 return Err(if status.as_u16() == 429 || status.is_server_error() {
140 EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
141 } else {
142 EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
143 });
144 }
145
146 let value: Value = response
147 .json()
148 .await
149 .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
150
151 extract_embedding_from_value(&value)
152 }
153
154 async fn embed_single_with_retry(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
160 let max_duration = std::time::Duration::from_secs(128);
161 let start = std::time::Instant::now();
162 let mut wait_secs = 8u64;
163 loop {
164 match self.embed_single_once(text).await {
165 Ok(v) => return Ok(v),
166 Err(e)
167 if matches!(e, EmbeddingError::HttpError(_))
168 && start.elapsed() < max_duration =>
169 {
170 let jitter = rand::random::<u64>() % wait_secs;
171 tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
172 wait_secs = (wait_secs * 2).min(128);
173 }
174 Err(e) => return Err(e),
175 }
176 }
177 }
178
179 async fn embed_all(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
181 let sanitized = sanitize_embedding_inputs(texts);
182 let sanitized_refs: Vec<&str> = sanitized.iter().map(|s| s.as_ref()).collect();
183
184 let futures: Vec<_> = sanitized_refs
185 .iter()
186 .map(|&text| self.embed_single_with_retry(text))
187 .collect();
188
189 let results = future::join_all(futures).await;
190
191 let embeddings: EmbeddingResult<Vec<Vec<f32>>> = results.into_iter().collect();
192
193 Ok(handle_embedding_response(
194 texts,
195 embeddings?,
196 self.dimensions,
197 ))
198 }
199}
200
201#[async_trait]
202impl EmbeddingEngine for OllamaEmbeddingEngine {
203 async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
204 if texts.is_empty() {
205 return Ok(Vec::new());
206 }
207 self.embed_all(texts).await
208 }
209
210 fn dimension(&self) -> usize {
211 self.dimensions
212 }
213
214 fn batch_size(&self) -> usize {
215 self.batch_size
216 }
217
218 fn max_sequence_length(&self) -> usize {
219 self.max_completion_tokens
220 }
221}
222
223fn extract_embedding_from_value(value: &Value) -> EmbeddingResult<Vec<f32>> {
242 if let Some(embeddings) = value.get("embeddings") {
244 if let Some(first) = embeddings.get(0) {
245 return parse_f32_array(first);
246 }
247 return Err(EmbeddingError::ApiError(
248 "Response 'embeddings' array is empty".to_string(),
249 ));
250 }
251
252 if let Some(embedding) = value.get("embedding") {
254 return parse_f32_array(embedding);
255 }
256
257 if let Some(data) = value.get("data") {
259 if let Some(first) = data.get(0)
260 && let Some(embedding) = first.get("embedding")
261 {
262 return parse_f32_array(embedding);
263 }
264 return Err(EmbeddingError::ApiError(
265 "Response 'data' array is empty or missing 'embedding' field".to_string(),
266 ));
267 }
268
269 Err(EmbeddingError::ApiError(format!(
270 "Unrecognised response shape; expected 'embeddings', 'embedding', or 'data' key. Got: {value}"
271 )))
272}
273
274fn parse_f32_array(value: &Value) -> EmbeddingResult<Vec<f32>> {
276 let arr = value.as_array().ok_or_else(|| {
277 EmbeddingError::ApiError(format!("Expected a JSON array for embedding, got: {value}"))
278 })?;
279
280 arr.iter()
281 .map(|v| {
282 v.as_f64().map(|f| f as f32).ok_or_else(|| {
283 EmbeddingError::ApiError(format!("Non-numeric value in embedding array: {v}"))
284 })
285 })
286 .collect()
287}
288
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::EmbeddingConfig;
    use crate::provider::EmbeddingProvider;

    /// Baseline configuration that individual tests override field by field.
    fn make_config() -> EmbeddingConfig {
        EmbeddingConfig {
            provider: EmbeddingProvider::Ollama,
            model: "avr/sfr-embedding-mistral:latest".to_string(),
            dimensions: 1024,
            endpoint: None,
            api_key: None,
            api_version: None,
            max_completion_tokens: 8191,
            batch_size: 10,
            mock: false,
            mock_mode: Default::default(),
            #[cfg(feature = "onnx")]
            onnx: Default::default(),
            huggingface_tokenizer: None,
        }
    }

    /// Builds an engine from `config`, failing the test when construction fails.
    fn build_engine(config: &EmbeddingConfig) -> OllamaEmbeddingEngine {
        OllamaEmbeddingEngine::new(config).expect("should construct engine")
    }

    /// Builds an engine whose token budget is `max_completion_tokens`.
    fn engine_with_token_limit(max_completion_tokens: usize) -> OllamaEmbeddingEngine {
        build_engine(&EmbeddingConfig {
            max_completion_tokens,
            ..make_config()
        })
    }

    /// Asserts that `actual` matches `expected` element-wise within `1e-6`.
    fn assert_vector_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// Asserts that `result` is an `EmbeddingError::ApiError`.
    fn assert_api_error(result: EmbeddingResult<Vec<f32>>) {
        assert!(result.is_err());
        assert!(matches!(result, Err(EmbeddingError::ApiError(_))));
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn test_constructor_defaults() {
        let engine = build_engine(&make_config());
        assert_eq!(engine.endpoint, "http://localhost:11434/api/embed");
        assert_eq!(engine.model, "avr/sfr-embedding-mistral:latest");
        assert_eq!(engine.dimension(), 1024);
        assert_eq!(engine.batch_size(), 10);
        assert_eq!(engine.max_sequence_length(), 8191);
    }

    #[test]
    fn test_constructor_custom_endpoint() {
        let engine = build_engine(&EmbeddingConfig {
            endpoint: Some("http://my-ollama:11434/api/embed".to_string()),
            ..make_config()
        });
        assert_eq!(engine.endpoint, "http://my-ollama:11434/api/embed");
    }

    // --- truncation -------------------------------------------------------

    #[test]
    fn test_truncate_text_short() {
        let engine = engine_with_token_limit(10);
        assert_eq!(engine.truncate_text("hello"), "hello");
    }

    #[test]
    fn test_truncate_text_exact_limit() {
        // Budget of 2 tokens -> 8 characters; the input is exactly 8 long.
        let engine = engine_with_token_limit(2);
        assert_eq!(engine.truncate_text("abcdefgh"), "abcdefgh");
    }

    #[test]
    fn test_truncate_text_over_limit() {
        // Budget of 2 tokens -> 8 characters; the last two are cut off.
        let engine = engine_with_token_limit(2);
        assert_eq!(engine.truncate_text("abcdefghij"), "abcdefgh");
    }

    #[test]
    fn test_truncate_text_unicode_boundary() {
        // Budget of 1 token -> 4 characters; 'é' is multi-byte but counts once.
        let engine = engine_with_token_limit(1);
        let truncated = engine.truncate_text("héllo");
        assert_eq!(truncated, "héll");
        assert!(std::str::from_utf8(truncated.as_bytes()).is_ok());
    }

    #[test]
    fn test_truncate_text_empty() {
        let engine = build_engine(&make_config());
        assert_eq!(engine.truncate_text(""), "");
    }

    // --- response parsing -------------------------------------------------

    #[test]
    fn test_parse_shape1_embeddings() {
        let payload = serde_json::json!({
            "embeddings": [[0.1_f64, 0.2_f64, 0.3_f64]]
        });
        let vector = extract_embedding_from_value(&payload).expect("should parse shape 1");
        assert_vector_close(&vector, &[0.1_f32, 0.2_f32, 0.3_f32]);
    }

    #[test]
    fn test_parse_shape2_embedding() {
        let payload = serde_json::json!({
            "embedding": [0.4_f64, 0.5_f64]
        });
        let vector = extract_embedding_from_value(&payload).expect("should parse shape 2");
        assert_vector_close(&vector, &[0.4_f32, 0.5_f32]);
    }

    #[test]
    fn test_parse_shape3_data() {
        let payload = serde_json::json!({
            "data": [{"embedding": [0.6_f64, 0.7_f64, 0.8_f64]}]
        });
        let vector = extract_embedding_from_value(&payload).expect("should parse shape 3");
        assert_vector_close(&vector, &[0.6_f32, 0.7_f32, 0.8_f32]);
    }

    #[test]
    fn test_parse_unrecognised_shape() {
        let payload = serde_json::json!({ "unknown": "value" });
        assert_api_error(extract_embedding_from_value(&payload));
    }

    #[test]
    fn test_parse_empty_embeddings_array() {
        let payload = serde_json::json!({ "embeddings": [] });
        assert_api_error(extract_embedding_from_value(&payload));
    }

    #[test]
    fn test_parse_empty_data_array() {
        let payload = serde_json::json!({ "data": [] });
        assert_api_error(extract_embedding_from_value(&payload));
    }

    #[test]
    fn test_parse_non_numeric_values() {
        let payload = serde_json::json!({ "embedding": ["not", "numbers"] });
        assert_api_error(extract_embedding_from_value(&payload));
    }
}