// cognee_embedding/openai_compatible.rs

//! OpenAI-compatible embedding engine.
//!
//! Supports OpenAI, Azure OpenAI, and any server implementing the OpenAI
//! `/v1/embeddings` endpoint (vLLM, llama.cpp, TEI, LocalAI, etc.).

use std::time::{Duration, Instant};

use async_trait::async_trait;
use futures::stream::{self, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};

use crate::config::EmbeddingConfig;
use crate::engine::EmbeddingEngine;
use crate::error::{EmbeddingError, EmbeddingResult};
use crate::utils::{handle_embedding_response, sanitize_embedding_inputs};

/// Upper bound on how many sub-batch requests one `embed` call keeps in
/// flight at the same time.
///
/// Sub-batches past this limit wait for an earlier one to complete. This keeps
/// the load on provider rate limits bounded while still overlapping the
/// network latency of several requests.
const MAX_CONCURRENT_BATCHES: usize = 8;

20// ─── Response types ───────────────────────────────────────────────────────────
21
22#[derive(Deserialize)]
23struct EmbeddingResponse {
24    data: Vec<EmbeddingData>,
25}
26
27#[derive(Deserialize)]
28struct EmbeddingData {
29    embedding: Vec<f32>,
30}
31
32// ─── Request type ─────────────────────────────────────────────────────────────
33
34#[derive(Serialize)]
35struct EmbeddingRequest<'a> {
36    model: &'a str,
37    input: Vec<&'a str>,
38    encoding_format: &'a str,
39}
40
41// ─── Engine ───────────────────────────────────────────────────────────────────
42
43/// Embedding engine that calls an OpenAI-compatible `/v1/embeddings` HTTP endpoint.
44///
45/// Works with:
46/// - OpenAI (`https://api.openai.com`)
47/// - Azure OpenAI (set `api_version` in config)
48/// - vLLM, llama.cpp, TEI, LocalAI (any OpenAI-compatible server)
49///
50/// # URL normalisation
51///
52/// The `base_url` is derived from `config.endpoint` and is always normalised to
53/// end with `/v1` so that the final request URL is `{base_url}/embeddings`.
54///
55/// The following transformations are applied in order:
56/// 1. Strip a trailing `/`
57/// 2. If the URL ends with `/v1/embeddings`, strip the `/embeddings` suffix.
58/// 3. If the URL does not end with `/v1`, append `/v1`.
59///
60/// # Retry behaviour
61///
62/// Transient errors (HTTP 429, 5xx, network errors) are retried with
63/// exponential back-off (starting at 2 s, doubling up to 128 s, plus
64/// a uniform random jitter in `[0, wait_secs)`) for up to 128 s total.
65pub struct OpenAICompatibleEmbeddingEngine {
66    client: reqwest::Client,
67    /// Normalised base URL ending with `/v1`.
68    base_url: String,
69    model: String,
70    dimensions: usize,
71    batch_size: usize,
72    max_sequence_length: usize,
73}
74
75impl OpenAICompatibleEmbeddingEngine {
76    /// Construct a new engine from the given [`EmbeddingConfig`].
77    ///
78    /// Returns [`EmbeddingError::ConfigError`] if the `reqwest` client cannot
79    /// be built (e.g. invalid TLS configuration).
80    pub fn new(config: &EmbeddingConfig) -> EmbeddingResult<Self> {
81        let raw_endpoint = config
82            .endpoint
83            .clone()
84            .unwrap_or_else(|| "https://api.openai.com".to_string());
85
86        let base_url = normalize_base_url(&raw_endpoint);
87
88        let api_key = config.api_key.clone().unwrap_or_default();
89
90        let mut default_headers = reqwest::header::HeaderMap::new();
91        let bearer = format!("Bearer {api_key}");
92        let auth_value = reqwest::header::HeaderValue::from_str(&bearer)
93            .map_err(|e| EmbeddingError::ConfigError(format!("Invalid API key value: {e}")))?;
94        default_headers.insert(reqwest::header::AUTHORIZATION, auth_value);
95
96        // For Azure OpenAI the api-version is sent as a query parameter, not a
97        // header.  We store the version on the struct and append it per-request.
98        // Nothing to add to default headers here.
99
100        let client = reqwest::Client::builder()
101            .default_headers(default_headers)
102            .timeout(std::time::Duration::from_secs(30))
103            .build()
104            .map_err(|e| {
105                EmbeddingError::ConfigError(format!("Failed to build HTTP client: {e}"))
106            })?;
107
108        Ok(Self {
109            client,
110            base_url,
111            model: config.model.clone(),
112            dimensions: config.dimensions,
113            batch_size: config.batch_size,
114            max_sequence_length: config.max_completion_tokens,
115        })
116    }
117
118    /// Build the full embeddings URL.
119    fn embeddings_url(&self) -> String {
120        format!("{}/embeddings", self.base_url)
121    }
122
123    /// Call the `/v1/embeddings` endpoint once (no retry).
124    async fn embed_batch_once(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
125        let sanitized = sanitize_embedding_inputs(texts);
126        let sanitized_strs: Vec<&str> = sanitized.iter().map(|c| c.as_ref()).collect();
127
128        let request_body = EmbeddingRequest {
129            model: &self.model,
130            input: sanitized_strs,
131            encoding_format: "float",
132        };
133
134        let response = self
135            .client
136            .post(self.embeddings_url())
137            .json(&request_body)
138            .send()
139            .await
140            .map_err(|e| EmbeddingError::HttpError(format!("Request failed: {e}")))?;
141
142        let status = response.status();
143        if !status.is_success() {
144            let body = response
145                .text()
146                .await
147                .unwrap_or_else(|_| "<failed to read body>".to_string());
148            return Err(if status.as_u16() == 429 || status.is_server_error() {
149                // Retryable — use HttpError so `is_retryable` can detect it
150                EmbeddingError::HttpError(format!("HTTP {status}: {body}"))
151            } else {
152                EmbeddingError::ApiError(format!("HTTP {status}: {body}"))
153            });
154        }
155
156        let parsed: EmbeddingResponse = response
157            .json()
158            .await
159            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {e}")))?;
160
161        let vectors: Vec<Vec<f32>> = parsed.data.into_iter().map(|d| d.embedding).collect();
162
163        // Zero out slots that were originally empty/whitespace
164        let result = handle_embedding_response(texts, vectors, self.dimensions);
165        Ok(result)
166    }
167
168    /// Call the endpoint with exponential-jitter retry on transient errors.
169    ///
170    /// Retries for up to 128 s total. Wait starts at 2 s and doubles on each
171    /// attempt, capped at 128 s.  A uniform random jitter of `[0, wait_secs)`
172    /// is added to prevent thundering-herd.
173    async fn embed_batch_with_retry(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
174        let max_duration = std::time::Duration::from_secs(128);
175        let start = std::time::Instant::now();
176        let mut wait_secs = 2u64;
177        loop {
178            match self.embed_batch_once(texts).await {
179                Ok(result) => return Ok(result),
180                Err(e) if is_retryable(&e) && start.elapsed() < max_duration => {
181                    let jitter = rand::random::<u64>() % wait_secs;
182                    tokio::time::sleep(std::time::Duration::from_secs(wait_secs + jitter)).await;
183                    wait_secs = (wait_secs * 2).min(128);
184                }
185                Err(e) => return Err(e),
186            }
187        }
188    }
189}
190
191#[async_trait]
192impl EmbeddingEngine for OpenAICompatibleEmbeddingEngine {
193    async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
194        if texts.is_empty() {
195            return Ok(Vec::new());
196        }
197
198        // Dispatch sub-batches concurrently (bounded by MAX_CONCURRENT_BATCHES).
199        // `try_collect` over `buffer_unordered` aborts on the first failure —
200        // cancelling in-flight retries instead of waiting them out — and the
201        // batch index restores input order afterwards.
202        let batch_futures: Vec<_> = texts
203            .chunks(self.batch_size.max(1))
204            .enumerate()
205            .map(|(index, batch)| async move {
206                self.embed_batch_with_retry(batch).await.map(|v| (index, v))
207            })
208            .collect();
209
210        let mut indexed: Vec<(usize, Vec<Vec<f32>>)> = stream::iter(batch_futures)
211            .buffer_unordered(MAX_CONCURRENT_BATCHES)
212            .try_collect()
213            .await?;
214
215        indexed.sort_by_key(|(index, _)| *index);
216        Ok(indexed.into_iter().flat_map(|(_, batch)| batch).collect())
217    }
218
219    fn dimension(&self) -> usize {
220        self.dimensions
221    }
222
223    fn batch_size(&self) -> usize {
224        self.batch_size
225    }
226
227    fn max_sequence_length(&self) -> usize {
228        self.max_sequence_length
229    }
230}
231
232// ─── Error classification ─────────────────────────────────────────────────────
233
234/// Returns `true` for errors that are worth retrying (rate-limit, server error, network).
235fn is_retryable(e: &EmbeddingError) -> bool {
236    matches!(e, EmbeddingError::HttpError(_))
237}
238
// ─── URL normalisation ────────────────────────────────────────────────────────

/// Normalise an endpoint URL so that it ends with `/v1`.
///
/// Rules (applied in order):
/// 1. Trim leading whitespace and any trailing run of `/` and whitespace.
///    Endpoints read from env vars or files often end in a newline.
/// 2. Strip a final `/embeddings` segment, but only when it directly follows
///    `/v1`. So `/v1/embeddings` → `/v1`, while a bare `/embeddings` is kept.
/// 3. Append `/v1` if the URL does not already end with it.
///
/// ```text
/// https://api.openai.com              → https://api.openai.com/v1
/// http://localhost:8080/v1/embeddings → http://localhost:8080/v1
/// https://host/openai                 → https://host/openai/v1
/// ```
pub(crate) fn normalize_base_url(url: &str) -> String {
    let trimmed = url
        .trim_start()
        .trim_end_matches(|c: char| c == '/' || c.is_whitespace());

    let base = trimmed
        .strip_suffix("/embeddings")
        .filter(|prefix| prefix.ends_with("/v1"))
        .unwrap_or(trimmed);

    if base.ends_with("/v1") {
        base.to_string()
    } else {
        format!("{base}/v1")
    }
}

// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Build an engine from `config`, panicking with `why` if construction fails.
    fn engine_from(config: &EmbeddingConfig, why: &str) -> OpenAICompatibleEmbeddingEngine {
        OpenAICompatibleEmbeddingEngine::new(config).expect(why)
    }

    // ── URL normalisation ────────────────────────────────────────────────────

    #[test]
    fn test_normalize_plain_domain() {
        let normalized = normalize_base_url("https://api.openai.com");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_trailing_slash() {
        let normalized = normalize_base_url("https://api.openai.com/");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_already_v1() {
        let normalized = normalize_base_url("https://api.openai.com/v1");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_v1_trailing_slash() {
        let normalized = normalize_base_url("https://api.openai.com/v1/");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_v1_embeddings_suffix() {
        let normalized = normalize_base_url("https://api.openai.com/v1/embeddings");
        assert_eq!(normalized, "https://api.openai.com/v1");
    }

    #[test]
    fn test_normalize_localhost_with_port() {
        let normalized = normalize_base_url("http://localhost:11434");
        assert_eq!(normalized, "http://localhost:11434/v1");
    }

    #[test]
    fn test_normalize_localhost_with_port_v1() {
        let normalized = normalize_base_url("http://localhost:8080/v1");
        assert_eq!(normalized, "http://localhost:8080/v1");
    }

    #[test]
    fn test_normalize_azure_endpoint() {
        // Azure resource URLs end with their own API path rather than `/v1`,
        // so `/v1` is appended after it.
        let normalized = normalize_base_url("https://myresource.openai.azure.com/openai");
        assert_eq!(normalized, "https://myresource.openai.azure.com/openai/v1");
    }

    // ── Constructor ──────────────────────────────────────────────────────────

    #[test]
    fn test_new_with_defaults() {
        let engine = engine_from(
            &EmbeddingConfig {
                model: "text-embedding-3-small".to_string(),
                dimensions: 1536,
                batch_size: 10,
                ..EmbeddingConfig::default()
            },
            "should build engine with default config",
        );
        assert_eq!(engine.dimension(), 1536);
        assert_eq!(engine.batch_size(), 10);
        assert_eq!(engine.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn test_new_with_custom_endpoint() {
        let engine = engine_from(
            &EmbeddingConfig {
                endpoint: Some("http://localhost:8080/v1/embeddings".to_string()),
                model: "my-model".to_string(),
                dimensions: 384,
                batch_size: 5,
                ..EmbeddingConfig::default()
            },
            "should build engine with custom endpoint",
        );
        assert_eq!(engine.base_url, "http://localhost:8080/v1");
    }

    #[test]
    fn test_embeddings_url() {
        let engine = engine_from(
            &EmbeddingConfig {
                endpoint: Some("https://api.openai.com".to_string()),
                ..EmbeddingConfig::default()
            },
            "should build engine",
        );
        let url = engine.embeddings_url();
        assert_eq!(url, "https://api.openai.com/v1/embeddings");
    }

    // ── is_retryable ─────────────────────────────────────────────────────────

    #[test]
    fn test_is_retryable_http_error() {
        for message in ["HTTP 429: rate limited", "HTTP 503: unavailable"] {
            let error = EmbeddingError::HttpError(message.to_string());
            assert!(is_retryable(&error));
        }
    }

    #[test]
    fn test_is_retryable_api_error_not_retryable() {
        let permanent = [
            EmbeddingError::ApiError("HTTP 400: bad request".to_string()),
            EmbeddingError::ConfigError("bad config".to_string()),
        ];
        for error in &permanent {
            assert!(!is_retryable(error));
        }
    }
}