Skip to main content

cognee_embedding/
openai_compatible.rs

1//! OpenAI-compatible embedding engine.
2//!
3//! Supports OpenAI, Azure OpenAI, and any server implementing the OpenAI
4//! `/v1/embeddings` endpoint (vLLM, llama.cpp, TEI, LocalAI, etc.).
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use crate::config::EmbeddingConfig;
10use crate::engine::EmbeddingEngine;
11use crate::error::{EmbeddingError, EmbeddingResult};
12use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};
13
14// ─── Response types ───────────────────────────────────────────────────────────
15
16#[derive(Deserialize)]
17struct EmbeddingResponse {
18    data: Vec<EmbeddingData>,
19}
20
21#[derive(Deserialize)]
22struct EmbeddingData {
23    embedding: Vec<f32>,
24}
25
26// ─── Request type ─────────────────────────────────────────────────────────────
27
28#[derive(Serialize)]
29struct EmbeddingRequest<'a> {
30    model: &'a str,
31    input: Vec<&'a str>,
32    encoding_format: &'a str,
33}
34
35// ─── Engine ───────────────────────────────────────────────────────────────────
36
37/// Embedding engine that calls an OpenAI-compatible `/v1/embeddings` HTTP endpoint.
38///
39/// Works with:
40/// - OpenAI (`https://api.openai.com`)
41/// - Azure OpenAI (set `api_version` in config)
42/// - vLLM, llama.cpp, TEI, LocalAI (any OpenAI-compatible server)
43///
44/// # URL normalisation
45///
46/// The `base_url` is derived from `config.endpoint` and is always normalised to
47/// end with `/v1` so that the final request URL is `{base_url}/embeddings`.
48///
49/// The following transformations are applied in order:
50/// 1. Strip a trailing `/`
51/// 2. If the URL ends with `/v1/embeddings`, strip the `/embeddings` suffix.
52/// 3. If the URL does not end with `/v1`, append `/v1`.
53///
54/// # Retry behaviour
55///
56/// Transient errors (HTTP 429, 5xx, network errors) are retried with
57/// exponential back-off (starting at 2 s, doubling up to 128 s, plus
58/// a uniform random jitter in `[0, wait_secs)`) for up to 128 s total.
59pub struct OpenAICompatibleEmbeddingEngine {
60    client: reqwest::Client,
61    /// Normalised base URL ending with `/v1`.
62    base_url: String,
63    model: String,
64    dimensions: usize,
65    batch_size: usize,
66    max_sequence_length: usize,
67}
68
69impl OpenAICompatibleEmbeddingEngine {
70    /// Construct a new engine from the given [`EmbeddingConfig`].
71    ///
72    /// Returns [`EmbeddingError::ConfigError`] if the `reqwest` client cannot
73    /// be built (e.g. invalid TLS configuration).
74    pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
75        let raw_endpoint = config
76            .endpoint
77            .clone()
78            .unwrap_or_else(|| "https://api.openai.com".to_string());
79
80        let base_url = normalize_base_url(&raw_endpoint);
81
82        let api_key = config.api_key.clone().unwrap_or_default();
83
84        let mut default_headers = reqwest::header::HeaderMap::new();
85        let bearer = format!("Bearer {api_key}");
86        let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
87            .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
88        default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
89
90        // For Azure OpenAI the api-version is sent as a query parameter, not a
91        // header.  We store the version on the struct and append it per-request.
92        // Nothing to add to default headers here.
93
94        let client = reqwest::Client::builder()
95            .default_headers(default_headers)
96            .timeout(std::time::Duration::from_secs(30))
97            .build()
98            .map_err(|e| {
99                EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
100            })?;
101
102        Ok(Self {
103            client,
104            base_url,
105            model: config.model.clone(),
106            dimensions: config.dimensions,
107            batch_size: config.batch_size,
108            max_sequence_length: config.max_completion_tokens,
109        })
110    }
111
112    /// Build the full embeddings URL.
113    fn embeddings_url(&self) -> String {
114        format!("{}/embeddings", self.base_url)
115    }
116
117    /// Call the `/v1/embeddings` endpoint once (no retry).
118    async fn embed_batch_once(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
119        let sanitized = sanitize_embedding_inputs(texts);
120        let sanitized_strs: Vec<&str> = sanitized.iter().map(|c| c.as_ref()).collect();
121
122        let request_body = EmbeddingRequest {
123            model: &self.model,
124            input: sanitized_strs,
125            encoding_format: "float",
126        };
127
128        let response = self
129            .client
130            .post(self.embeddings_url())
131            .json(&request_body)
132            .send()
133            .await
134            .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
135
136        let status = response.status();
137        if !status.is_success() {
138            let body = response
139                .text()
140                .await
141                .unwrap_or_else(|_| "<failed to read body>".to_string());
142            return Err(if status.as_u16() == 429 || status.is_server_error() {
143                // Retryable — use HttpError so `is_retryable` can detect it
144                EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
145            } else {
146                EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
147            });
148        }
149
150        let parsed: EmbeddingResponse = response
151            .json()
152            .await
153            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
154
155        let vectors: Vec<Vec<f32>> = parsed.data.into_iter().map(|d| d.embedding).collect();
156
157        // Zero out slots that were originally empty/whitespace
158        let result = handle_embedding_response(texts, vectors, self.dimensions);
159        Ok(result)
160    }
161
162    /// Call the endpoint with exponential-jitter retry on transient errors.
163    ///
164    /// Retries for up to 128 s total. Wait starts at 2 s and doubles on each
165    /// attempt, capped at 128 s.  A uniform random jitter of `[0, wait_secs)`
166    /// is added to prevent thundering-herd.
167    async fn embed_batch_with_retry(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
168        let max_duration = std::time::Duration::from_secs(128);
169        let start = std::time::Instant::now();
170        let mut wait_secs = 2u64;
171        loop {
172            match self.embed_batch_once(texts).await {
173                Ok(result) => return Ok(result),
174                Err(e) if is_retryable(&e) && start.elapsed() < max_duration => {
175                    let jitter = rand::random::<u64>() % wait_secs;
176                    tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
177                    wait_secs = (wait_secs * 2).min(128);
178                }
179                Err(e) => return Err(e),
180            }
181        }
182    }
183}
184
185#[async_trait]
186impl EmbeddingEngine for OpenAICompatibleEmbeddingEngine {
187    async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
188        if texts.is_empty() {
189            return Ok(Vec::new());
190        }
191
192        let mut results: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
193
194        for batch in texts.chunks(self.batch_size) {
195            let batch_results = self.embed_batch_with_retry(batch).await?;
196            results.extend(batch_results);
197        }
198
199        Ok(results)
200    }
201
202    fn dimension(&self) -> usize {
203        self.dimensions
204    }
205
206    fn batch_size(&self) -> usize {
207        self.batch_size
208    }
209
210    fn max_sequence_length(&self) -> usize {
211        self.max_sequence_length
212    }
213}
214
215// ─── Error classification ─────────────────────────────────────────────────────
216
217/// Returns `true` for errors that are worth retrying (rate-limit, server error, network).
218fn is_retryable(e: &EmbeddingError) -> bool {
219    matches!(e, EmbeddingError::HttpError(_))
220}
221
222// ─── URL normalisation ────────────────────────────────────────────────────────
223
/// Normalise an endpoint URL to always end with `/v1`.
///
/// Rules (applied in order):
/// 1. Trim surrounding whitespace (endpoints read from env vars or config
///    files often carry a stray space or trailing newline).
/// 2. Strip trailing `/` characters.
/// 3. Strip a trailing `/embeddings` only when it directly follows `/v1`
///    (so `/v1/embeddings` → `/v1`). Any other `/embeddings` suffix is kept.
/// 4. Append `/v1` if the URL does not already end with `/v1`.
pub(crate) fn normalize_base_url(url: &str) -> String {
    let trimmed = url.trim().trim_end_matches('/');

    let base = trimmed
        .strip_suffix("/embeddings")
        .filter(|prefix| prefix.ends_with("/v1"))
        .unwrap_or(trimmed);

    if base.ends_with("/v1") {
        base.to_string()
    } else {
        format!("{base}/v1")
    }
}
243
244// ─── Tests ────────────────────────────────────────────────────────────────────
245
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Assert that `input` normalises to exactly `expected`.
    fn assert_normalizes(input: &str, expected: &str) {
        assert_eq!(normalize_base_url(input), expected, "input: {input:?}");
    }

    /// Construct an engine from `config`, failing the test with `why` on error.
    fn engine_from(config: &EmbeddingConfig, why: &str) -> OpenAICompatibleEmbeddingEngine {
        OpenAICompatibleEmbeddingEngine::new(config).expect(why)
    }

    // ── URL normalisation ────────────────────────────────────────────────────

    #[test]
    fn test_normalize_plain_domain() {
        assert_normalizes("https://api.openai.com", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_trailing_slash() {
        assert_normalizes("https://api.openai.com/", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_already_v1() {
        assert_normalizes("https://api.openai.com/v1", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_v1_trailing_slash() {
        assert_normalizes("https://api.openai.com/v1/", "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_v1_embeddings_suffix() {
        assert_normalizes(
            "https://api.openai.com/v1/embeddings",
            "https://api.openai.com/v1",
        );
    }

    #[test]
    fn test_normalize_localhost_with_port() {
        assert_normalizes("http://localhost:11434", "http://localhost:11434/v1");
    }

    #[test]
    fn test_normalize_localhost_with_port_v1() {
        assert_normalizes("http://localhost:8080/v1", "http://localhost:8080/v1");
    }

    #[test]
    fn test_normalize_azure_endpoint() {
        // Azure-style paths do not end in `/v1`, so one gets appended.
        assert_normalizes(
            "https://myresource.openai.azure.com/openai",
            "https://myresource.openai.azure.com/openai/v1",
        );
    }

    // ── Constructor ──────────────────────────────────────────────────────────

    #[test]
    fn test_new_with_defaults() {
        let config = EmbeddingConfig {
            model: "text-embedding-3-small".to_string(),
            dimensions: 1536,
            batch_size: 10,
            ..EmbeddingConfig::default()
        };
        let engine = engine_from(&config, "should build engine with default config");

        assert_eq!((engine.dimension(), engine.batch_size()), (1536, 10));
        assert_eq!(engine.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_new_with_custom_endpoint() {
        let config = EmbeddingConfig {
            endpoint: Some("http://localhost:8080/v1/embeddings".to_string()),
            model: "my-model".to_string(),
            dimensions: 384,
            batch_size: 5,
            ..EmbeddingConfig::default()
        };
        let engine = engine_from(&config, "should build engine with custom endpoint");

        assert_eq!(engine.base_url, "http://localhost:8080/v1");
    }

    #[test]
    fn test_embeddings_url() {
        let config = EmbeddingConfig {
            endpoint: Some("https://api.openai.com".to_string()),
            ..EmbeddingConfig::default()
        };
        let engine = engine_from(&config, "should build engine");

        let url = engine.embeddings_url();
        assert_eq!(url, "https://api.openai.com/v1/embeddings");
    }

    // ── is_retryable ─────────────────────────────────────────────────────────

    #[test]
    fn test_is_retryable_http_error() {
        for message in ["HTTP 429: rate limited", "HTTP 503: unavailable"] {
            assert!(is_retryable(&EmbeddingError::HttpError(message.to_string())));
        }
    }

    #[test]
    fn test_is_retryable_api_error_not_retryable() {
        let permanent = [
            EmbeddingError::ApiError("HTTP 400: bad request".to_string()),
            EmbeddingError::ConfigError("bad config".to_string()),
        ];
        for error in &permanent {
            assert!(!is_retryable(error));
        }
    }
}