cognee_embedding/
openai_compatible.rs1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use crate::config::EmbeddingConfig;
10use crate::engine::EmbeddingEngine;
11use crate::error::{EmbeddingError, EmbeddingResult};
12use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};
13
14#[derive(Deserialize)]
17struct EmbeddingResponse {
18 data: Vec<EmbeddingData>,
19}
20
21#[derive(Deserialize)]
22struct EmbeddingData {
23 embedding: Vec<f32>,
24}
25
26#[derive(Serialize)]
29struct EmbeddingRequest<'a> {
30 model: &'a str,
31 input: Vec<&'a str>,
32 encoding_format: &'a str,
33}
34
35pub struct OpenAICompatibleEmbeddingEngine {
60 client: reqwest::Client,
61 base_url: String,
63 model: String,
64 dimensions: usize,
65 batch_size: usize,
66 max_sequence_length: usize,
67}
68
69impl OpenAICompatibleEmbeddingEngine {
70 pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
75 let raw_endpoint = config
76 .endpoint
77 .clone()
78 .unwrap_or_else(|| "https://api.openai.com".to_string());
79
80 let base_url = normalize_base_url(&raw_endpoint);
81
82 let api_key = config.api_key.clone().unwrap_or_default();
83
84 let mut default_headers = reqwest::header::HeaderMap::new();
85 let bearer = format!("Bearer {api_key}");
86 let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
87 .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
88 default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
89
90 let client = reqwest::Client::builder()
95 .default_headers(default_headers)
96 .timeout(std::time::Duration::from_secs(30))
97 .build()
98 .map_err(|e| {
99 EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
100 })?;
101
102 Ok(Self {
103 client,
104 base_url,
105 model: config.model.clone(),
106 dimensions: config.dimensions,
107 batch_size: config.batch_size,
108 max_sequence_length: config.max_completion_tokens,
109 })
110 }
111
112 fn embeddings_url(&self) -> String {
114 format!("{}/embeddings", self.base_url)
115 }
116
117 async fn embed_batch_once(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
119 let sanitized = sanitize_embedding_inputs(texts);
120 let sanitized_strs: Vec<&str> = sanitized.iter().map(|c| c.as_ref()).collect();
121
122 let request_body = EmbeddingRequest {
123 model: &self.model,
124 input: sanitized_strs,
125 encoding_format: "float",
126 };
127
128 let response = self
129 .client
130 .post(self.embeddings_url())
131 .json(&request_body)
132 .send()
133 .await
134 .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
135
136 let status = response.status();
137 if !status.is_success() {
138 let body = response
139 .text()
140 .await
141 .unwrap_or_else(|_| "<failed to read body>".to_string());
142 return Err(if status.as_u16() == 429 || status.is_server_error() {
143 EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
145 } else {
146 EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
147 });
148 }
149
150 let parsed: EmbeddingResponse = response
151 .json()
152 .await
153 .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
154
155 let vectors: Vec<Vec<f32>> = parsed.data.into_iter().map(|d| d.embedding).collect();
156
157 let result = handle_embedding_response(texts, vectors, self.dimensions);
159 Ok(result)
160 }
161
162 async fn embed_batch_with_retry(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
168 let max_duration = std::time::Duration::from_secs(128);
169 let start = std::time::Instant::now();
170 let mut wait_secs = 2u64;
171 loop {
172 match self.embed_batch_once(texts).await {
173 Ok(result) => return Ok(result),
174 Err(e) if is_retryable(&e) && start.elapsed() < max_duration => {
175 let jitter = rand::random::<u64>() % wait_secs;
176 tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
177 wait_secs = (wait_secs * 2).min(128);
178 }
179 Err(e) => return Err(e),
180 }
181 }
182 }
183}
184
185#[async_trait]
186impl EmbeddingEngine for OpenAICompatibleEmbeddingEngine {
187 async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
188 if texts.is_empty() {
189 return Ok(Vec::new());
190 }
191
192 let mut results: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
193
194 for batch in texts.chunks(self.batch_size) {
195 let batch_results = self.embed_batch_with_retry(batch).await?;
196 results.extend(batch_results);
197 }
198
199 Ok(results)
200 }
201
202 fn dimension(&self) -> usize {
203 self.dimensions
204 }
205
206 fn batch_size(&self) -> usize {
207 self.batch_size
208 }
209
210 fn max_sequence_length(&self) -> usize {
211 self.max_sequence_length
212 }
213}
214
215fn is_retryable(e: &EmbeddingError) -> bool {
219 matches!(e, EmbeddingError::HttpError(_))
220}
221
/// Normalizes a user-supplied endpoint into an API root that ends in `/v1`.
///
/// Steps, in order:
/// 1. all trailing `/` characters are removed;
/// 2. a trailing `/v1/embeddings` is shortened to `/v1`;
/// 3. `/v1` is appended unless the result already ends with it.
pub(crate) fn normalize_base_url(url: &str) -> String {
    let trimmed = url.trim_end_matches('/');

    // Drop `/embeddings` only when it directly follows a `/v1` segment.
    let root = match trimmed.strip_suffix("/embeddings") {
        Some(prefix) if prefix.ends_with("/v1") => prefix,
        _ => trimmed,
    };

    if root.ends_with("/v1") {
        root.to_owned()
    } else {
        format!("{root}/v1")
    }
}
243
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    // ---- normalize_base_url -------------------------------------------------------------

    /// A bare origin gains the `/v1` suffix.
    #[test]
    fn test_normalize_plain_domain() {
        let normalized = normalize_base_url("https://api.openai.com");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    /// A trailing slash is dropped before `/v1` is appended.
    #[test]
    fn test_normalize_trailing_slash() {
        let normalized = normalize_base_url("https://api.openai.com/");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    /// A URL that already ends in `/v1` is returned unchanged.
    #[test]
    fn test_normalize_already_v1() {
        let normalized = normalize_base_url("https://api.openai.com/v1");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    /// `/v1/` loses only its trailing slash.
    #[test]
    fn test_normalize_v1_trailing_slash() {
        let normalized = normalize_base_url("https://api.openai.com/v1/");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    /// A full `/v1/embeddings` URL is reduced to the `/v1` root.
    #[test]
    fn test_normalize_v1_embeddings_suffix() {
        let normalized = normalize_base_url("https://api.openai.com/v1/embeddings");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    /// Host-and-port endpoints are handled like any other origin.
    #[test]
    fn test_normalize_localhost_with_port() {
        let normalized = normalize_base_url("http://localhost:11434");
        assert_eq!(normalized, "http://localhost:11434/v1");
    }

    /// Host-and-port endpoints that already end in `/v1` are left alone.
    #[test]
    fn test_normalize_localhost_with_port_v1() {
        let normalized = normalize_base_url("http://localhost:8080/v1");
        assert_eq!(normalized, "http://localhost:8080/v1");
    }

    /// An existing path prefix is kept and `/v1` is appended after it.
    #[test]
    fn test_normalize_azure_endpoint() {
        let normalized = normalize_base_url("https://myresource.openai.azure.com/openai");
        assert_eq!(normalized, "https://myresource.openai.azure.com/openai/v1");
    }

    // ---- OpenAICompatibleEmbeddingEngine::new -------------------------------------------

    /// Without an explicit endpoint the engine targets the OpenAI API root.
    #[test]
    fn test_new_with_defaults() {
        let config = EmbeddingConfig {
            model: "text-embedding-3-small".to_string(),
            dimensions: 1536,
            batch_size: 10,
            ..EmbeddingConfig::default()
        };

        let engine = OpenAICompatibleEmbeddingEngine::new(&config)
            .expect("should build engine with default config");

        assert_eq!(engine.dimension(), 1536);
        assert_eq!(engine.batch_size(), 10);
        assert_eq!(engine.base_url, "https://api.openai.com/v1");
    }

    /// A custom endpoint is normalized when the engine is constructed.
    #[test]
    fn test_new_with_custom_endpoint() {
        let config = EmbeddingConfig {
            endpoint: Some("http://localhost:8080/v1/embeddings".to_string()),
            model: "my-model".to_string(),
            dimensions: 384,
            batch_size: 5,
            ..EmbeddingConfig::default()
        };

        let engine = OpenAICompatibleEmbeddingEngine::new(&config)
            .expect("should build engine with custom endpoint");

        assert_eq!(engine.base_url, "http://localhost:8080/v1");
    }

    /// The request URL is the normalized root plus `/embeddings`.
    #[test]
    fn test_embeddings_url() {
        let config = EmbeddingConfig {
            endpoint: Some("https://api.openai.com".to_string()),
            ..EmbeddingConfig::default()
        };

        let engine = OpenAICompatibleEmbeddingEngine::new(&config).expect("should build engine");

        assert_eq!(
            engine.embeddings_url(),
            "https://api.openai.com/v1/embeddings"
        );
    }

    // ---- is_retryable -------------------------------------------------------------------

    /// `HttpError` is retryable whatever its message.
    #[test]
    fn test_is_retryable_http_error() {
        let rate_limited = EmbeddingError::HttpError("HTTP 429: rate limited".to_string());
        let unavailable = EmbeddingError::HttpError("HTTP 503: unavailable".to_string());

        assert!(is_retryable(&rate_limited));
        assert!(is_retryable(&unavailable));
    }

    /// `ApiError` and `ConfigError` are never retried.
    #[test]
    fn test_is_retryable_api_error_not_retryable() {
        let bad_request = EmbeddingError::ApiError("HTTP 400: bad request".to_string());
        let bad_config = EmbeddingError::ConfigError("bad config".to_string());

        assert!(!is_retryable(&bad_request));
        assert!(!is_retryable(&bad_config));
    }
}