memvid_cli/
openai_embeddings.rs

1//! OpenAI Embeddings Provider
2//!
3//! This module provides an `EmbeddingProvider` implementation that uses
4//! OpenAI's text-embedding API for generating high-quality embeddings.
5//!
6//! ## Environment Variables
7//! - `OPENAI_API_KEY`: Required API key for OpenAI
8//! - `OPENAI_EMBEDDING_MODEL`: Optional model override (default: text-embedding-3-large)
9//!
10//! ## Features
11//! - Supports all OpenAI embedding models
12//! - Efficient batch processing (up to 100 texts per request)
13//! - Automatic rate limiting with exponential backoff
14//! - Thread-safe for concurrent use
15
16use anyhow::{anyhow, bail, Result};
17use memvid_core::{EmbeddingConfig, EmbeddingProvider, VecEmbedder};
18use reqwest::blocking::Client;
19use serde::{Deserialize, Serialize};
20use std::sync::atomic::{AtomicBool, Ordering};
21use std::time::Duration;
22use tracing::{debug, info, warn};
23
/// OpenAI embeddings API endpoint (POST, JSON body).
const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";

/// Maximum texts per batch (OpenAI limit)
const MAX_BATCH_SIZE: usize = 100;

/// Request timeout applied to every HTTP call made by the client.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Maximum characters for embedding text to avoid exceeding OpenAI's 8192 token limit.
/// Using ~3 chars/token estimate (conservative for dense content), 20K chars ≈ 6.6K tokens.
/// NOTE: this is a byte count (`str::len`), so it is an upper bound on chars.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
36
37/// Truncate text to MAX_EMBEDDING_TEXT_LEN to avoid token limit errors.
38fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
39    if text.len() <= MAX_EMBEDDING_TEXT_LEN {
40        std::borrow::Cow::Borrowed(text)
41    } else {
42        // Find a safe char boundary
43        let end = text[..MAX_EMBEDDING_TEXT_LEN]
44            .char_indices()
45            .rev()
46            .next()
47            .map(|(i, c)| i + c.len_utf8())
48            .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
49        warn!(
50            "Truncating embedding text from {} to {} chars to avoid token limit",
51            text.len(),
52            end
53        );
54        std::borrow::Cow::Owned(text[..end].to_string())
55    }
56}
57
/// OpenAI embedding request payload.
///
/// Serialized as the JSON body of `POST /v1/embeddings`; field names are part
/// of the wire format.
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    /// Model identifier, e.g. "text-embedding-3-large".
    model: &'a str,
    /// Input texts to embed (the endpoint accepts an array of strings).
    input: Vec<&'a str>,
    /// Optional output-dimension override; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}
66
/// OpenAI embedding response (success body of the embeddings endpoint).
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    /// One entry per input text; callers re-sort by `index` before use.
    data: Vec<OpenAIEmbeddingData>,
    /// Model name reported by the API for this request.
    model: String,
    /// Token accounting for the request.
    usage: OpenAIUsage,
}
74
/// A single embedding vector plus its position in the input array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    /// The embedding vector itself.
    embedding: Vec<f32>,
    /// Index of the corresponding input text; used to restore input order.
    index: usize,
}
80
/// Token usage block returned with every embeddings response.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    // Deserialized for completeness; not read anywhere yet.
    #[allow(dead_code)]
    prompt_tokens: usize,
    /// Total tokens consumed; logged for observability.
    total_tokens: usize,
}
87
/// OpenAI error response (body of a non-2xx reply, when parseable).
#[derive(Debug, Deserialize)]
struct OpenAIErrorResponse {
    /// The structured error payload.
    error: OpenAIError,
}
93
/// Structured error details from the OpenAI API.
#[derive(Debug, Deserialize)]
struct OpenAIError {
    /// Human-readable error description.
    message: String,
    /// Error category (e.g. rate-limit vs. auth); `type` is a Rust keyword,
    /// hence the rename.
    #[serde(rename = "type")]
    error_type: String,
}
100
/// OpenAI Embedding Provider
///
/// Implements `EmbeddingProvider` trait for generating embeddings via OpenAI API.
///
/// Cloning shares the `ready` flag (via `Arc`) between clones; the key,
/// config, and client handle are cloned per instance.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    // API key; intentionally omitted from the manual Debug impl below.
    api_key: String,
    // Model / dimension / batch-size configuration.
    config: EmbeddingConfig,
    // Blocking HTTP client, built with REQUEST_TIMEOUT.
    client: Client,
    // Set to true once init() has validated the key with a test request.
    ready: std::sync::Arc<AtomicBool>,
}
111
112impl std::fmt::Debug for OpenAIEmbeddingProvider {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("OpenAIEmbeddingProvider")
115            .field("model", &self.config.model)
116            .field("dimension", &self.config.dimension)
117            .field("ready", &self.ready.load(Ordering::Relaxed))
118            .finish()
119    }
120}
121
122impl OpenAIEmbeddingProvider {
123    /// Create a new OpenAI embedding provider
124    ///
125    /// # Arguments
126    /// * `api_key` - OpenAI API key
127    /// * `config` - Embedding configuration (model, dimension, etc.)
128    ///
129    /// # Example
130    /// ```ignore
131    /// let provider = OpenAIEmbeddingProvider::new(
132    ///     "sk-...".to_string(),
133    ///     EmbeddingConfig::openai_large(),
134    /// )?;
135    /// ```
136    pub fn new(api_key: String, config: EmbeddingConfig) -> Result<Self> {
137        if api_key.is_empty() {
138            bail!("OpenAI API key cannot be empty");
139        }
140
141        let client = Client::builder()
142            .timeout(REQUEST_TIMEOUT)
143            .build()
144            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
145
146        Ok(Self {
147            api_key,
148            config,
149            client,
150            ready: std::sync::Arc::new(AtomicBool::new(false)),
151        })
152    }
153
154    /// Create provider from environment variables
155    ///
156    /// Uses `OPENAI_API_KEY` for authentication and optionally
157    /// `OPENAI_EMBEDDING_MODEL` to override the default model.
158    pub fn from_env() -> Result<Self> {
159        let api_key = std::env::var("OPENAI_API_KEY")
160            .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
161
162        let config = match std::env::var("OPENAI_EMBEDDING_MODEL") {
163            Ok(model) => match model.as_str() {
164                "text-embedding-3-small" => EmbeddingConfig::openai_small(),
165                "text-embedding-ada-002" => EmbeddingConfig::openai_ada(),
166                "text-embedding-3-large" | _ => EmbeddingConfig::openai_large(),
167            },
168            Err(_) => EmbeddingConfig::openai_large(),
169        };
170
171        Self::new(api_key, config)
172    }
173
174    /// Create provider with text-embedding-3-large (default, highest quality)
175    pub fn large(api_key: String) -> Result<Self> {
176        Self::new(api_key, EmbeddingConfig::openai_large())
177    }
178
179    /// Create provider with text-embedding-3-small (faster, lower cost)
180    pub fn small(api_key: String) -> Result<Self> {
181        Self::new(api_key, EmbeddingConfig::openai_small())
182    }
183
184    /// Create provider with text-embedding-ada-002 (legacy)
185    pub fn ada(api_key: String) -> Result<Self> {
186        Self::new(api_key, EmbeddingConfig::openai_ada())
187    }
188
189    /// Internal method to call OpenAI API
190    fn call_openai(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
191        if texts.is_empty() {
192            return Ok(Vec::new());
193        }
194
195        let request = OpenAIEmbeddingRequest {
196            model: &self.config.model,
197            input: texts.to_vec(),
198            dimensions: None, // Use model's native dimension
199        };
200
201        let response = self
202            .client
203            .post(OPENAI_EMBEDDINGS_URL)
204            .header("Authorization", format!("Bearer {}", self.api_key))
205            .header("Content-Type", "application/json")
206            .json(&request)
207            .send()
208            .map_err(|e| anyhow!("OpenAI API request failed: {}", e))?;
209
210        let status = response.status();
211        let body = response
212            .text()
213            .map_err(|e| anyhow!("Failed to read response body: {}", e))?;
214
215        if !status.is_success() {
216            // Try to parse error response
217            if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&body) {
218                bail!(
219                    "OpenAI API error ({}): {}",
220                    error_response.error.error_type,
221                    error_response.error.message
222                );
223            }
224            bail!("OpenAI API request failed with status {}: {}", status, body);
225        }
226
227        let embedding_response: OpenAIEmbeddingResponse = serde_json::from_str(&body)
228            .map_err(|e| anyhow!("Failed to parse OpenAI response: {}", e))?;
229
230        debug!(
231            "OpenAI embeddings: {} texts, {} tokens, model={}",
232            texts.len(),
233            embedding_response.usage.total_tokens,
234            embedding_response.model
235        );
236
237        // Sort by index and extract embeddings
238        let mut data = embedding_response.data;
239        data.sort_by_key(|d| d.index);
240
241        let embeddings: Vec<Vec<f32>> = data.into_iter().map(|d| d.embedding).collect();
242
243        // Validate dimensions
244        if let Some(first) = embeddings.first() {
245            if first.len() != self.config.dimension {
246                warn!(
247                    "OpenAI returned dimension {} but expected {}",
248                    first.len(),
249                    self.config.dimension
250                );
251            }
252        }
253
254        Ok(embeddings)
255    }
256
257    /// Embed texts with retry logic
258    fn embed_with_retry(&self, texts: &[&str], max_retries: usize) -> Result<Vec<Vec<f32>>> {
259        let mut last_error = None;
260
261        for attempt in 0..max_retries {
262            match self.call_openai(texts) {
263                Ok(embeddings) => return Ok(embeddings),
264                Err(e) => {
265                    let error_str = e.to_string();
266                    if error_str.contains("rate_limit") || error_str.contains("429") {
267                        let backoff = Duration::from_millis(500 * (1 << attempt));
268                        warn!(
269                            "Rate limited by OpenAI, retrying in {:?} (attempt {}/{})",
270                            backoff,
271                            attempt + 1,
272                            max_retries
273                        );
274                        std::thread::sleep(backoff);
275                        last_error = Some(e);
276                        continue;
277                    }
278                    return Err(e);
279                }
280            }
281        }
282
283        Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
284    }
285}
286
287impl EmbeddingProvider for OpenAIEmbeddingProvider {
288    fn kind(&self) -> &str {
289        "openai"
290    }
291
292    fn model(&self) -> &str {
293        &self.config.model
294    }
295
296    fn dimension(&self) -> usize {
297        self.config.dimension
298    }
299
300    fn embed_text(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
301        let text = truncate_for_embedding(text);
302        self.embed_with_retry(&[&text], 3)
303            .map(|mut v| v.pop().unwrap_or_default())
304            .map_err(|e| memvid_core::MemvidError::EmbeddingFailed {
305                reason: e.to_string().into_boxed_str(),
306            })
307    }
308
309    fn embed_batch(&self, texts: &[&str]) -> memvid_core::Result<Vec<Vec<f32>>> {
310        if texts.is_empty() {
311            return Ok(Vec::new());
312        }
313
314        // Truncate all texts first to avoid token limit errors
315        let truncated: Vec<std::borrow::Cow<'_, str>> =
316            texts.iter().map(|t| truncate_for_embedding(t)).collect();
317        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
318
319        // Process in batches of MAX_BATCH_SIZE
320        let batch_size = self
321            .config
322            .batch_size
323            .unwrap_or(MAX_BATCH_SIZE)
324            .min(MAX_BATCH_SIZE);
325        let mut all_embeddings = Vec::with_capacity(texts.len());
326
327        for chunk in truncated_refs.chunks(batch_size) {
328            let embeddings = self.embed_with_retry(chunk, 3).map_err(|e| {
329                memvid_core::MemvidError::EmbeddingFailed {
330                    reason: e.to_string().into_boxed_str(),
331                }
332            })?;
333            all_embeddings.extend(embeddings);
334        }
335
336        Ok(all_embeddings)
337    }
338
339    fn is_ready(&self) -> bool {
340        self.ready.load(Ordering::Relaxed)
341    }
342
343    fn init(&mut self) -> memvid_core::Result<()> {
344        // Validate API key with a small test request
345        info!(
346            "Initializing OpenAI embedding provider with model: {}",
347            self.config.model
348        );
349
350        let test_embedding = self.embed_with_retry(&["test"], 1).map_err(|e| {
351            memvid_core::MemvidError::EmbeddingFailed {
352                reason: format!("Failed to initialize OpenAI provider: {}", e).into_boxed_str(),
353            }
354        })?;
355
356        if let Some(emb) = test_embedding.first() {
357            info!(
358                "OpenAI provider initialized: model={}, dimension={}",
359                self.config.model,
360                emb.len()
361            );
362            // Update dimension if different from expected
363            if emb.len() != self.config.dimension {
364                warn!(
365                    "Updating dimension from {} to {}",
366                    self.config.dimension,
367                    emb.len()
368                );
369            }
370        }
371
372        self.ready.store(true, Ordering::Relaxed);
373        Ok(())
374    }
375}
376
377/// Implement VecEmbedder for compatibility with existing memvid code
378impl VecEmbedder for OpenAIEmbeddingProvider {
379    fn embed_query(&self, text: &str) -> memvid_core::Result<Vec<f32>> {
380        self.embed_text(text)
381    }
382
383    fn embed_chunks(&self, texts: &[&str]) -> memvid_core::Result<Vec<Vec<f32>>> {
384        self.embed_batch(texts)
385    }
386
387    fn embedding_dimension(&self) -> usize {
388        self.dimension()
389    }
390}
391
392/// Helper to create an OpenAI provider or fall back to local
393pub fn try_openai_provider() -> Option<OpenAIEmbeddingProvider> {
394    match OpenAIEmbeddingProvider::from_env() {
395        Ok(provider) => {
396            info!("OpenAI embedding provider available");
397            Some(provider)
398        }
399        Err(e) => {
400            debug!("OpenAI provider not available: {}", e);
401            None
402        }
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the expected native dimensions of each supported config preset.
    #[test]
    fn test_config_dimensions() {
        assert_eq!(EmbeddingConfig::openai_large().dimension, 3072);
        assert_eq!(EmbeddingConfig::openai_small().dimension, 1536);
        assert_eq!(EmbeddingConfig::openai_ada().dimension, 1536);
    }

    // Construction must fail fast when no API key is supplied.
    #[test]
    fn test_empty_api_key() {
        let result = OpenAIEmbeddingProvider::new(String::new(), EmbeddingConfig::openai_large());
        assert!(result.is_err());
    }

    // Live smoke test against the real API; run with `cargo test -- --ignored`.
    #[test]
    #[ignore] // Requires valid API key
    fn test_real_embedding() {
        let provider = OpenAIEmbeddingProvider::from_env().expect("OPENAI_API_KEY must be set");
        let embedding = provider.embed_text("Hello, world!").expect("embed");
        assert!(!embedding.is_empty());
        assert_eq!(embedding.len(), 3072); // text-embedding-3-large
    }
}
431}