cognee_embedding/
openai_compatible.rs1use async_trait::async_trait;
7use futures::stream::{self, StreamExt, TryStreamExt};
8use serde::{Deserialize, Serialize};
9
10use crate::config::EmbeddingConfig;
11use crate::engine::EmbeddingEngine;
12use crate::error::{EmbeddingError, EmbeddingResult};
13use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};
14
/// Maximum number of embedding batches `embed` keeps in flight at the same
/// time (passed to `buffer_unordered`).
const MAX_CONCURRENT_BATCHES: usize = 8;
19
20#[derive(Deserialize)]
23struct EmbeddingResponse {
24 data: Vec<EmbeddingData>,
25}
26
27#[derive(Deserialize)]
28struct EmbeddingData {
29 embedding: Vec<f32>,
30}
31
32#[derive(Serialize)]
35struct EmbeddingRequest<'a> {
36 model: &'a str,
37 input: Vec<&'a str>,
38 encoding_format: &'a str,
39}
40
41pub struct OpenAICompatibleEmbeddingEngine {
66 client: reqwest::Client,
67 base_url: String,
69 model: String,
70 dimensions: usize,
71 batch_size: usize,
72 max_sequence_length: usize,
73}
74
75impl OpenAICompatibleEmbeddingEngine {
76 pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
81 let raw_endpoint = config
82 .endpoint
83 .clone()
84 .unwrap_or_else(|| "https://api.openai.com".to_string());
85
86 let base_url = normalize_base_url(&raw_endpoint);
87
88 let api_key = config.api_key.clone().unwrap_or_default();
89
90 let mut default_headers = reqwest::header::HeaderMap::new();
91 let bearer = format!("Bearer {api_key}");
92 let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
93 .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
94 default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
95
96 let client = reqwest::Client::builder()
101 .default_headers(default_headers)
102 .timeout(std::time::Duration::from_secs(30))
103 .build()
104 .map_err(|e| {
105 EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
106 })?;
107
108 Ok(Self {
109 client,
110 base_url,
111 model: config.model.clone(),
112 dimensions: config.dimensions,
113 batch_size: config.batch_size,
114 max_sequence_length: config.max_completion_tokens,
115 })
116 }
117
118 fn embeddings_url(&self) -> String {
120 format!("{}/embeddings", self.base_url)
121 }
122
123 async fn embed_batch_once(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
125 let sanitized = sanitize_embedding_inputs(texts);
126 let sanitized_strs: Vec<&str> = sanitized.iter().map(|c| c.as_ref()).collect();
127
128 let request_body = EmbeddingRequest {
129 model: &self.model,
130 input: sanitized_strs,
131 encoding_format: "float",
132 };
133
134 let response = self
135 .client
136 .post(self.embeddings_url())
137 .json(&request_body)
138 .send()
139 .await
140 .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
141
142 let status = response.status();
143 if !status.is_success() {
144 let body = response
145 .text()
146 .await
147 .unwrap_or_else(|_| "<failed to read body>".to_string());
148 return Err(if status.as_u16() == 429 || status.is_server_error() {
149 EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
151 } else {
152 EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
153 });
154 }
155
156 let parsed: EmbeddingResponse = response
157 .json()
158 .await
159 .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
160
161 let vectors: Vec<Vec<f32>> = parsed.data.into_iter().map(|d| d.embedding).collect();
162
163 let result = handle_embedding_response(texts, vectors, self.dimensions);
165 Ok(result)
166 }
167
168 async fn embed_batch_with_retry(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
174 let max_duration = std::time::Duration::from_secs(128);
175 let start = std::time::Instant::now();
176 let mut wait_secs = 2u64;
177 loop {
178 match self.embed_batch_once(texts).await {
179 Ok(result) => return Ok(result),
180 Err(e) if is_retryable(&e) && start.elapsed() < max_duration => {
181 let jitter = rand::random::<u64>() % wait_secs;
182 tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
183 wait_secs = (wait_secs * 2).min(128);
184 }
185 Err(e) => return Err(e),
186 }
187 }
188 }
189}
190
191#[async_trait]
192impl EmbeddingEngine for OpenAICompatibleEmbeddingEngine {
193 async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
194 if texts.is_empty() {
195 return Ok(Vec::new());
196 }
197
198 let batch_futures: Vec<_> = texts
203 .chunks(self.batch_size.max(1))
204 .enumerate()
205 .map(|(index, batch)| async move {
206 self.embed_batch_with_retry(batch).await.map(|v| (index, v))
207 })
208 .collect();
209
210 let mut indexed: Vec<(usize, Vec<Vec<f32>>)> = stream::iter(batch_futures)
211 .buffer_unordered(MAX_CONCURRENT_BATCHES)
212 .try_collect()
213 .await?;
214
215 indexed.sort_by_key(|(index, _)| *index);
216 Ok(indexed.into_iter().flat_map(|(_, batch)| batch).collect())
217 }
218
219 fn dimension(&self) -> usize {
220 self.dimensions
221 }
222
223 fn batch_size(&self) -> usize {
224 self.batch_size
225 }
226
227 fn max_sequence_length(&self) -> usize {
228 self.max_sequence_length
229 }
230}
231
232fn is_retryable(e: &EmbeddingError) -> bool {
236 matches!(e, EmbeddingError::HttpError(_))
237}
238
/// Normalizes a user-supplied endpoint into an API root ending in `/v1`.
///
/// Steps:
/// 1. Strip surrounding whitespace (config values from env vars or files often
///    carry a trailing newline) and any trailing `/` characters.
/// 2. If the URL points at the full embeddings route (`…/v1/embeddings`), drop the
///    `/embeddings` segment.
/// 3. Append `/v1` unless the URL already ends with it.
///
/// Examples: `https://api.openai.com/` → `https://api.openai.com/v1`;
/// `http://localhost:8080/v1/embeddings` → `http://localhost:8080/v1`;
/// `https://res.openai.azure.com/openai` → `https://res.openai.azure.com/openai/v1`.
pub(crate) fn normalize_base_url(url: &str) -> String {
    let trimmed = url.trim().trim_end_matches('/');

    // Only strip "/embeddings" when it follows "/v1", matching `…/v1/embeddings`.
    let base = trimmed
        .strip_suffix("/embeddings")
        .filter(|prefix| prefix.ends_with("/v1"))
        .unwrap_or(trimmed);

    if base.ends_with("/v1") {
        base.to_string()
    } else {
        format!("{base}/v1")
    }
}
260
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Checks one `normalize_base_url` case, echoing the input on failure.
    fn assert_normalizes_to(input: &str, expected: &str) {
        assert_eq!(normalize_base_url(input), expected, "input: {input:?}");
    }

    /// Constructs an engine from `config`, panicking with `why` if that fails.
    fn build_engine(config: &EmbeddingConfig, why: &str) -> OpenAICompatibleEmbeddingEngine {
        OpenAICompatibleEmbeddingEngine::new(config).expect(why)
    }

    #[test]
    fn test_normalize_plain_domain() {
        assert_normalizes_to("https://api.openai.com", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_trailing_slash() {
        assert_normalizes_to("https://api.openai.com/", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_already_v1() {
        assert_normalizes_to("https://api.openai.com/v1", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_v1_trailing_slash() {
        assert_normalizes_to("https://api.openai.com/v1/", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_v1_embeddings_suffix() {
        assert_normalizes_to(
            "https://api.openai.com/v1/embeddings",
            "https://api.openai.com/v1",
        );
    }

    #[test]
    fn test_normalize_localhost_with_port() {
        assert_normalizes_to("http://localhost:11434", "http://localhost:11434/v1");
    }

    #[test]
    fn test_normalize_localhost_with_port_v1() {
        assert_normalizes_to("http://localhost:8080/v1", "http://localhost:8080/v1");
    }

    #[test]
    fn test_normalize_azure_endpoint() {
        // Path prefixes such as Azure's "/openai" are kept; "/v1" goes after them.
        assert_normalizes_to(
            "https://myresource.openai.azure.com/openai",
            "https://myresource.openai.azure.com/openai/v1",
        );
    }

    #[test]
    fn test_new_with_defaults() {
        let config = EmbeddingConfig {
            model: "text-embedding-3-small".to_string(),
            dimensions: 1536,
            batch_size: 10,
            ..EmbeddingConfig::default()
        };
        let engine = build_engine(&config, "should build engine with default config");

        assert_eq!(engine.dimension(), 1536);
        assert_eq!(engine.batch_size(), 10);
        assert_eq!(engine.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_new_with_custom_endpoint() {
        let engine = build_engine(
            &EmbeddingConfig {
                endpoint: Some("http://localhost:8080/v1/embeddings".to_string()),
                model: "my-model".to_string(),
                dimensions: 384,
                batch_size: 5,
                ..EmbeddingConfig::default()
            },
            "should build engine with custom endpoint",
        );

        assert_eq!(engine.base_url, "http://localhost:8080/v1");
    }

    #[test]
    fn test_embeddings_url() {
        let config = EmbeddingConfig {
            endpoint: Some("https://api.openai.com".to_string()),
            ..EmbeddingConfig::default()
        };
        let engine = build_engine(&config, "should build engine");

        assert_eq!(engine.embeddings_url(), "https://api.openai.com/v1/embeddings");
    }

    #[test]
    fn test_is_retryable_http_error() {
        for message in ["HTTP 429: rate limited", "HTTP 503: unavailable"] {
            assert!(is_retryable(&EmbeddingError::HttpError(message.to_string())));
        }
    }

    #[test]
    fn test_is_retryable_api_error_not_retryable() {
        let permanent = [
            EmbeddingError::ApiError("HTTP 400: bad request".to_string()),
            EmbeddingError::ConfigError("bad config".to_string()),
        ];
        for error in &permanent {
            assert!(!is_retryable(error));
        }
    }
}