// converge_knowledge/embedding/openai.rs

1//! OpenAI embedding provider.
2//!
3//! Uses the OpenAI API to generate high-quality text embeddings.
4//! Supports text-embedding-3-small, text-embedding-3-large, and text-embedding-ada-002.
5//!
6//! # Features
7//! - Automatic retries with exponential backoff
8//! - In-memory caching to avoid re-embedding
9//! - Rate limiting to respect API limits
10//! - Token usage tracking
11//! - Environment variable support for API key
12//!
13//! # Example
14//! ```ignore
15//! use converge_knowledge::embedding::OpenAIEmbedding;
16//!
17//! // From environment variable OPENAI_API_KEY
18//! let provider = OpenAIEmbedding::from_env()?;
19//!
20//! // Or with explicit key
21//! let provider = OpenAIEmbedding::new("sk-...", None);
22//!
23//! let embedding = provider.embed("Hello, world!").await?;
24//! ```
25
26use super::EmbeddingProvider;
27use crate::error::{Error, Result};
28use serde::{Deserialize, Serialize};
29use std::any::Any;
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::time::{Duration, Instant};
34use tokio::sync::{RwLock, Semaphore};
35use tracing::{debug, warn};
36
/// OpenAI embedding models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OpenAIModel {
    /// text-embedding-3-small: 1536 dimensions, lowest cost
    #[default]
    TextEmbedding3Small,
    /// text-embedding-3-large: 3072 dimensions, highest quality
    TextEmbedding3Large,
    /// text-embedding-ada-002: 1536 dimensions, legacy model
    TextEmbeddingAda002,
}

impl OpenAIModel {
    /// Every known model, used to round-trip identifier strings.
    const ALL: [OpenAIModel; 3] = [
        Self::TextEmbedding3Small,
        Self::TextEmbedding3Large,
        Self::TextEmbeddingAda002,
    ];

    /// Model identifier string as sent in API requests.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TextEmbedding3Small => "text-embedding-3-small",
            Self::TextEmbedding3Large => "text-embedding-3-large",
            Self::TextEmbeddingAda002 => "text-embedding-ada-002",
        }
    }

    /// Number of dimensions the model produces when none are requested.
    pub fn default_dimensions(&self) -> usize {
        match self {
            Self::TextEmbedding3Large => 3072,
            Self::TextEmbedding3Small | Self::TextEmbeddingAda002 => 1536,
        }
    }

    /// Whether the API's `dimensions` parameter is honored (v3 models only).
    pub fn supports_custom_dimensions(&self) -> bool {
        !matches!(self, Self::TextEmbeddingAda002)
    }

    /// Look up a model by its API identifier; `None` for unknown strings.
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.into_iter().find(|model| model.as_str() == s)
    }
}
86
87/// Configuration for the OpenAI embedding provider.
88#[derive(Debug, Clone)]
89pub struct OpenAIConfig {
90    /// The model to use.
91    pub model: OpenAIModel,
92    /// Custom dimensions (only for v3 models).
93    pub dimensions: Option<usize>,
94    /// Maximum retries for failed requests.
95    pub max_retries: u32,
96    /// Base delay for exponential backoff (milliseconds).
97    pub retry_base_delay_ms: u64,
98    /// Maximum concurrent requests.
99    pub max_concurrent_requests: usize,
100    /// Cache capacity (number of embeddings to cache).
101    pub cache_capacity: usize,
102    /// Request timeout in seconds.
103    pub timeout_secs: u64,
104    /// Custom API base URL (for proxies or Azure).
105    pub api_base: Option<String>,
106}
107
108impl Default for OpenAIConfig {
109    fn default() -> Self {
110        Self {
111            model: OpenAIModel::default(),
112            dimensions: None,
113            max_retries: 3,
114            retry_base_delay_ms: 1000,
115            max_concurrent_requests: 10,
116            cache_capacity: 10_000,
117            timeout_secs: 30,
118            api_base: None,
119        }
120    }
121}
122
123/// Token usage statistics.
/// Token usage statistics.
///
/// Every counter is an atomic so concurrent requests can record usage
/// without locking; use [`UsageStats::snapshot`] to read them all at once.
#[derive(Debug, Default)]
pub struct UsageStats {
    /// Total prompt tokens used.
    pub prompt_tokens: AtomicU64,
    /// Total requests made.
    pub requests: AtomicU64,
    /// Cache hits.
    pub cache_hits: AtomicU64,
    /// Cache misses.
    pub cache_misses: AtomicU64,
    /// Failed requests.
    pub failures: AtomicU64,
}

impl UsageStats {
    /// Capture the current counter values as plain integers.
    pub fn snapshot(&self) -> UsageSnapshot {
        // Relaxed loads: counters are independent, no ordering needed.
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        UsageSnapshot {
            prompt_tokens: read(&self.prompt_tokens),
            requests: read(&self.requests),
            cache_hits: read(&self.cache_hits),
            cache_misses: read(&self.cache_misses),
            failures: read(&self.failures),
        }
    }
}

/// Snapshot of usage statistics.
#[derive(Debug, Clone)]
pub struct UsageSnapshot {
    /// Total prompt tokens used.
    pub prompt_tokens: u64,
    /// Total API requests made.
    pub requests: u64,
    /// Number of cache hits.
    pub cache_hits: u64,
    /// Number of cache misses.
    pub cache_misses: u64,
    /// Number of failed requests.
    pub failures: u64,
}
165
166impl UsageSnapshot {
167    /// Calculate cache hit rate.
168    pub fn cache_hit_rate(&self) -> f64 {
169        let total = self.cache_hits + self.cache_misses;
170        if total == 0 {
171            0.0
172        } else {
173            self.cache_hits as f64 / total as f64
174        }
175    }
176
177    /// Estimate cost in USD (approximate, based on public pricing).
178    pub fn estimated_cost_usd(&self, model: OpenAIModel) -> f64 {
179        let cost_per_million = match model {
180            OpenAIModel::TextEmbedding3Small => 0.02,
181            OpenAIModel::TextEmbedding3Large => 0.13,
182            OpenAIModel::TextEmbeddingAda002 => 0.10,
183        };
184        (self.prompt_tokens as f64 / 1_000_000.0) * cost_per_million
185    }
186}
187
188/// Cached embedding entry.
/// Cached embedding entry, pairing the vector with its creation time
/// so entries can expire after the cache TTL.
struct CacheEntry {
    embedding: Vec<f32>,
    created_at: Instant,
}

/// Bounded, TTL-based cache for embeddings.
///
/// Not a true LRU: entries expire after a fixed TTL, and when the cache is
/// full the entry with the oldest creation time is evicted to make room.
struct EmbeddingCache {
    /// Map from cache key to the stored embedding + timestamp.
    entries: HashMap<String, CacheEntry>,
    /// Maximum number of entries to hold.
    capacity: usize,
    /// Time-to-live for each entry.
    ttl: Duration,
}

impl EmbeddingCache {
    /// Create a cache holding up to `capacity` embeddings with a 1-hour TTL.
    fn new(capacity: usize) -> Self {
        Self {
            entries: HashMap::with_capacity(capacity),
            capacity,
            ttl: Duration::from_secs(3600), // 1 hour TTL
        }
    }

    /// Look up an embedding, returning `None` for missing or expired entries.
    ///
    /// Expired entries are not removed here (that would need `&mut self`);
    /// they are cleaned up lazily by `evict_expired` during inserts.
    fn get(&self, key: &str) -> Option<Vec<f32>> {
        self.entries.get(key).and_then(|entry| {
            if entry.created_at.elapsed() < self.ttl {
                Some(entry.embedding.clone())
            } else {
                None
            }
        })
    }

    /// Insert an embedding, evicting entries as needed to respect capacity.
    fn insert(&mut self, key: String, embedding: Vec<f32>) {
        // Replacing an existing key does not grow the map, so only make
        // room when the key is genuinely new. (Previously, re-inserting an
        // existing key at capacity needlessly evicted an unrelated entry.)
        if !self.entries.contains_key(&key) && self.entries.len() >= self.capacity {
            // Cheap first pass: drop anything already past its TTL.
            self.evict_expired();

            // If everything is still fresh, evict the oldest entry
            // (the one closest to expiry).
            if self.entries.len() >= self.capacity {
                if let Some(oldest_key) = self
                    .entries
                    .iter()
                    .min_by_key(|(_, v)| v.created_at)
                    .map(|(k, _)| k.clone())
                {
                    self.entries.remove(&oldest_key);
                }
            }
        }

        self.entries.insert(
            key,
            CacheEntry {
                embedding,
                created_at: Instant::now(),
            },
        );
    }

    /// Remove all entries whose TTL has elapsed.
    fn evict_expired(&mut self) {
        self.entries
            .retain(|_, entry| entry.created_at.elapsed() < self.ttl);
    }
}
252
/// OpenAI embedding provider with production features.
///
/// Wraps the OpenAI embeddings HTTP API with retries, a concurrency limit,
/// an in-memory TTL cache, and usage accounting. Internal state is behind
/// `Arc`s and atomics, so `&self` methods are safe to call concurrently.
pub struct OpenAIEmbedding {
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// Provider configuration (model, retries, timeouts, cache size, ...).
    config: OpenAIConfig,
    /// Reused HTTP client (connection pooling + request timeout).
    client: reqwest::Client,
    /// Cache of previously computed embeddings, keyed by model+dims+text hash.
    cache: Arc<RwLock<EmbeddingCache>>,
    /// Caps the number of in-flight API requests.
    semaphore: Arc<Semaphore>,
    /// Token / request / cache-hit counters.
    stats: Arc<UsageStats>,
    /// Dimensions actually produced: config override or the model default.
    effective_dimensions: usize,
}
263
impl OpenAIEmbedding {
    /// Create a new OpenAI embedding provider.
    ///
    /// `model` is an optional model name string; an unrecognized name is
    /// silently ignored and the default model is used instead.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let mut config = OpenAIConfig::default();
        if let Some(model_str) = model {
            if let Some(model) = OpenAIModel::from_str(&model_str) {
                config.model = model;
            }
        }
        Self::with_config(api_key, config)
    }

    /// Create from OPENAI_API_KEY environment variable.
    ///
    /// # Errors
    /// Returns an embedding error when the variable is not set.
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| Error::embedding("OPENAI_API_KEY environment variable not set"))?;
        Ok(Self::new(api_key, None))
    }

    /// Create with custom configuration.
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be constructed (e.g. TLS backend
    /// initialization failure).
    pub fn with_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
        // An explicit `dimensions` override wins; otherwise fall back to
        // the model's native dimensionality.
        let effective_dimensions = config
            .dimensions
            .unwrap_or_else(|| config.model.default_dimensions());

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(config.timeout_secs))
            .pool_max_idle_per_host(config.max_concurrent_requests)
            .build()
            .expect("Failed to create HTTP client");

        Self {
            api_key: api_key.into(),
            effective_dimensions,
            cache: Arc::new(RwLock::new(EmbeddingCache::new(config.cache_capacity))),
            semaphore: Arc::new(Semaphore::new(config.max_concurrent_requests)),
            stats: Arc::new(UsageStats::default()),
            client,
            config,
        }
    }

    /// Set custom dimensions (for text-embedding-3-* models).
    ///
    /// Silently ignored for models that do not support custom dimensions
    /// (text-embedding-ada-002).
    pub fn with_dimensions(mut self, dimensions: usize) -> Self {
        if self.config.model.supports_custom_dimensions() {
            self.config.dimensions = Some(dimensions);
            self.effective_dimensions = dimensions;
        }
        self
    }

    /// Get a point-in-time snapshot of usage statistics.
    pub fn stats(&self) -> UsageSnapshot {
        self.stats.snapshot()
    }

    /// Full URL of the embeddings endpoint: the configured `api_base`
    /// (proxy/Azure) or the public OpenAI API, plus `/embeddings`.
    fn api_url(&self) -> String {
        self.config
            .api_base
            .clone()
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string())
            + "/embeddings"
    }

    /// Compute cache key for text.
    ///
    /// The key hashes model + effective dimensions + text, so changing the
    /// model or dimensionality never reuses a stale cached vector.
    /// NOTE(review): `DefaultHasher` output is not stable across Rust
    /// releases — fine for an in-memory cache, unsuitable for persistence.
    fn cache_key(&self, text: &str) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        self.config.model.as_str().hash(&mut hasher);
        self.effective_dimensions.hash(&mut hasher);
        text.hash(&mut hasher);
        format!("{:x}", hasher.finish())
    }

    /// Execute request with retries.
    ///
    /// Makes up to `max_retries + 1` attempts with exponential backoff
    /// (`retry_base_delay_ms * 2^(attempt-1)`). Auth and quota errors are
    /// not retried. Each attempt holds a semaphore permit for its duration.
    async fn execute_with_retry(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            if attempt > 0 {
                // Backoff happens before re-acquiring a permit so a
                // sleeping retry does not block other requests.
                let delay = self.config.retry_base_delay_ms * 2u64.pow(attempt - 1);
                debug!(attempt, delay_ms = delay, "Retrying after delay");
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }

            // Acquire semaphore permit for rate limiting
            // (released when `_permit` drops at the end of this attempt).
            let _permit = self
                .semaphore
                .acquire()
                .await
                .map_err(|_| Error::embedding("Semaphore closed"))?;

            // Counts every attempt, including retries.
            self.stats.requests.fetch_add(1, Ordering::Relaxed);

            match self.execute_request(request).await {
                Ok(response) => return Ok(response),
                Err(e) => {
                    warn!(attempt, error = %e, "Request failed");
                    self.stats.failures.fetch_add(1, Ordering::Relaxed);

                    // Don't retry on certain errors
                    // (bad key / exhausted quota cannot succeed on retry).
                    if e.to_string().contains("invalid_api_key")
                        || e.to_string().contains("insufficient_quota")
                    {
                        return Err(e);
                    }

                    last_error = Some(e);
                }
            }
        }

        Err(last_error.unwrap_or_else(|| Error::embedding("Unknown error")))
    }

    /// Execute a single request.
    ///
    /// Posts the JSON body with bearer auth, then parses either the
    /// success payload or OpenAI's structured error body.
    async fn execute_request(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
        let response = self
            .client
            .post(self.api_url())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(request)
            .send()
            .await
            .map_err(|e| Error::embedding(format!("Request failed: {}", e)))?;

        // Read the body before branching so error responses can be
        // surfaced verbatim even when they are not valid JSON.
        let status = response.status();
        let body = response
            .text()
            .await
            .map_err(|e| Error::embedding(format!("Failed to read response: {}", e)))?;

        if !status.is_success() {
            // Prefer OpenAI's structured error message; fall back to the
            // raw status + body when the error body doesn't parse.
            let error: std::result::Result<ErrorResponse, _> = serde_json::from_str(&body);
            return Err(match error {
                Ok(e) => Error::embedding(format!("OpenAI API error: {}", e.error.message)),
                Err(_) => Error::embedding(format!("API error ({}): {}", status, body)),
            });
        }

        serde_json::from_str(&body)
            .map_err(|e| Error::embedding(format!("Failed to parse response: {}", e)))
    }

    /// Embed texts, using cache where possible.
    ///
    /// Returns one embedding per input text, in input order. Only texts
    /// missing from the cache are sent to the API (in a single request),
    /// and their results are written back to the cache.
    async fn embed_with_cache(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // results[i] is filled from cache or from the API response.
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut uncached_indices = Vec::new();
        let mut uncached_texts = Vec::new();

        // Check cache first
        // (read lock scoped to this block so the API call below is not
        // made while holding it).
        {
            let cache = self.cache.read().await;
            for (i, text) in texts.iter().enumerate() {
                let key = self.cache_key(text);
                if let Some(embedding) = cache.get(&key) {
                    results[i] = Some(embedding);
                    self.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
                } else {
                    uncached_indices.push(i);
                    uncached_texts.push(*text);
                    self.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        // Fetch uncached embeddings
        if !uncached_texts.is_empty() {
            debug!(
                count = uncached_texts.len(),
                cached = texts.len() - uncached_texts.len(),
                "Fetching embeddings from API"
            );

            let request = EmbeddingRequest {
                model: self.config.model.as_str().to_string(),
                input: uncached_texts.iter().map(|s| s.to_string()).collect(),
                // `dimensions` must be omitted for models that reject it.
                dimensions: if self.config.model.supports_custom_dimensions() {
                    Some(self.effective_dimensions)
                } else {
                    None
                },
            };

            let response = self.execute_with_retry(&request).await?;

            // Track token usage
            self.stats
                .prompt_tokens
                .fetch_add(response.usage.prompt_tokens as u64, Ordering::Relaxed);

            // Sort by index to maintain order
            // (the API tags each embedding with its input index).
            let mut data = response.data;
            data.sort_by_key(|d| d.index);

            // Update cache and results
            {
                let mut cache = self.cache.write().await;
                for (data_idx, embedding_data) in data.into_iter().enumerate() {
                    let original_idx = uncached_indices[data_idx];
                    let text = uncached_texts[data_idx];
                    let key = self.cache_key(text);

                    cache.insert(key, embedding_data.embedding.clone());
                    results[original_idx] = Some(embedding_data.embedding);
                }
            }
        }

        // Unwrap all results (should all be Some now)
        results
            .into_iter()
            .enumerate()
            .map(|(i, opt)| {
                opt.ok_or_else(|| Error::embedding(format!("Missing embedding for index {}", i)))
            })
            .collect()
    }
}
491
/// Request body for `POST /v1/embeddings`.
#[derive(Serialize)]
struct EmbeddingRequest {
    /// Model identifier, e.g. "text-embedding-3-small".
    model: String,
    /// Batch of input texts to embed.
    input: Vec<String>,
    /// Requested output dimensionality; omitted from the JSON when `None`
    /// (some models reject the field entirely).
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}

/// Success response body from the embeddings endpoint.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// One entry per input text (order not guaranteed; see `index`).
    data: Vec<EmbeddingData>,
    /// Model the server actually used; parsed but currently unread.
    #[allow(dead_code)]
    model: String,
    /// Token accounting for this request.
    usage: Usage,
}

/// A single embedding in the response.
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Position of the corresponding text in the request's `input` array.
    index: usize,
}

/// Token usage reported by the API.
#[derive(Deserialize)]
struct Usage {
    /// Tokens consumed by the inputs.
    prompt_tokens: usize,
    /// Total tokens; parsed but currently unread.
    #[allow(dead_code)]
    total_tokens: usize,
}

/// Error response envelope returned on non-2xx statuses.
#[derive(Deserialize)]
struct ErrorResponse {
    error: ApiError,
}

/// Structured error details from the API.
#[derive(Deserialize)]
struct ApiError {
    /// Human-readable error description.
    message: String,
    /// Machine-readable error category; parsed but currently unread.
    #[allow(dead_code)]
    #[serde(rename = "type")]
    error_type: String,
}
533
534#[async_trait::async_trait]
535impl EmbeddingProvider for OpenAIEmbedding {
536    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
537        let embeddings = self.embed_with_cache(&[text]).await?;
538        embeddings
539            .into_iter()
540            .next()
541            .ok_or_else(|| Error::embedding("No embedding returned"))
542    }
543
544    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
545        // OpenAI has a limit of ~8000 tokens per batch
546        // Split into chunks of 100 texts to be safe
547        const BATCH_SIZE: usize = 100;
548
549        if texts.len() <= BATCH_SIZE {
550            return self.embed_with_cache(texts).await;
551        }
552
553        let mut all_embeddings = Vec::with_capacity(texts.len());
554
555        for chunk in texts.chunks(BATCH_SIZE) {
556            let embeddings = self.embed_with_cache(chunk).await?;
557            all_embeddings.extend(embeddings);
558        }
559
560        Ok(all_embeddings)
561    }
562
563    fn dimensions(&self) -> usize {
564        self.effective_dimensions
565    }
566
567    fn as_any(&self) -> &dyn Any {
568        self
569    }
570}
571
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_properties() {
        assert_eq!(OpenAIModel::TextEmbedding3Small.default_dimensions(), 1536);
        assert_eq!(OpenAIModel::TextEmbedding3Large.default_dimensions(), 3072);
        assert!(OpenAIModel::TextEmbedding3Small.supports_custom_dimensions());
        assert!(!OpenAIModel::TextEmbeddingAda002.supports_custom_dimensions());
    }

    #[test]
    fn test_model_parsing() {
        assert_eq!(
            OpenAIModel::from_str("text-embedding-3-small"),
            Some(OpenAIModel::TextEmbedding3Small)
        );
        assert_eq!(OpenAIModel::from_str("unknown-model"), None);
    }

    #[test]
    fn test_openai_dimensions() {
        let provider = OpenAIEmbedding::new("test-key", None);
        assert_eq!(provider.dimensions(), 1536);

        let provider = OpenAIEmbedding::new("test-key", Some("text-embedding-3-large".to_string()));
        assert_eq!(provider.dimensions(), 3072);
    }

    #[test]
    fn test_custom_dimensions() {
        let provider = OpenAIEmbedding::new("test-key", Some("text-embedding-3-small".to_string()))
            .with_dimensions(512);
        assert_eq!(provider.dimensions(), 512);

        // Ada doesn't support custom dimensions
        let provider = OpenAIEmbedding::new("test-key", Some("text-embedding-ada-002".to_string()))
            .with_dimensions(512);
        assert_eq!(provider.dimensions(), 1536); // Unchanged
    }

    #[test]
    fn test_config_defaults() {
        let config = OpenAIConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.max_concurrent_requests, 10);
        assert_eq!(config.cache_capacity, 10_000);
    }

    #[test]
    fn test_usage_stats() {
        let stats = UsageStats::default();
        stats.prompt_tokens.fetch_add(1000, Ordering::Relaxed);
        stats.cache_hits.fetch_add(80, Ordering::Relaxed);
        stats.cache_misses.fetch_add(20, Ordering::Relaxed);

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.prompt_tokens, 1000);
        assert!((snapshot.cache_hit_rate() - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_cost_estimation() {
        let snapshot = UsageSnapshot {
            prompt_tokens: 1_000_000,
            requests: 100,
            cache_hits: 50,
            cache_misses: 50,
            failures: 0,
        };

        let cost_small = snapshot.estimated_cost_usd(OpenAIModel::TextEmbedding3Small);
        let cost_large = snapshot.estimated_cost_usd(OpenAIModel::TextEmbedding3Large);

        assert!((cost_small - 0.02).abs() < 0.001);
        assert!((cost_large - 0.13).abs() < 0.001);
    }

    // Was `#[tokio::test] async fn`: the body is fully synchronous (no
    // awaits), so a plain #[test] avoids spinning up a runtime for nothing.
    #[test]
    fn test_cache_operations() {
        let mut cache = EmbeddingCache::new(3);

        cache.insert("key1".to_string(), vec![1.0, 2.0, 3.0]);
        cache.insert("key2".to_string(), vec![4.0, 5.0, 6.0]);

        assert_eq!(cache.get("key1"), Some(vec![1.0, 2.0, 3.0]));
        assert_eq!(cache.get("key2"), Some(vec![4.0, 5.0, 6.0]));
        assert_eq!(cache.get("key3"), None);

        // Test capacity eviction
        cache.insert("key3".to_string(), vec![7.0, 8.0, 9.0]);
        cache.insert("key4".to_string(), vec![10.0, 11.0, 12.0]);

        // One of the older keys should be evicted
        assert_eq!(cache.entries.len(), 3);
    }

    #[test]
    fn test_cache_key_consistency() {
        let provider = OpenAIEmbedding::new("test-key", None);

        let key1 = provider.cache_key("hello world");
        let key2 = provider.cache_key("hello world");
        let key3 = provider.cache_key("different text");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
}