1use super::EmbeddingProvider;
27use crate::error::{Error, Result};
28use serde::{Deserialize, Serialize};
29use std::any::Any;
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::time::{Duration, Instant};
34use tokio::sync::{RwLock, Semaphore};
35use tracing::{debug, warn};
36
/// The OpenAI embedding models this provider knows how to call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OpenAIModel {
    /// Current-generation small model; the default.
    #[default]
    TextEmbedding3Small,
    /// Current-generation large model.
    TextEmbedding3Large,
    /// Legacy ada-002 model.
    TextEmbeddingAda002,
}

impl OpenAIModel {
    /// The model identifier string sent to the API.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TextEmbedding3Small => "text-embedding-3-small",
            Self::TextEmbedding3Large => "text-embedding-3-large",
            Self::TextEmbeddingAda002 => "text-embedding-ada-002",
        }
    }

    /// Native vector width produced when no custom dimension is requested.
    pub fn default_dimensions(&self) -> usize {
        match self {
            Self::TextEmbedding3Large => 3072,
            Self::TextEmbedding3Small | Self::TextEmbeddingAda002 => 1536,
        }
    }

    /// Whether the API accepts a `dimensions` parameter for this model
    /// (only the v3 family does).
    pub fn supports_custom_dimensions(&self) -> bool {
        matches!(
            self,
            Self::TextEmbedding3Small | Self::TextEmbedding3Large
        )
    }

    /// Parse an API model identifier; returns `None` for unknown names.
    /// (Deliberately an inherent method returning `Option` rather than a
    /// `FromStr` impl, so existing callers keep their `Option` contract.)
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "text-embedding-3-small" => Some(Self::TextEmbedding3Small),
            "text-embedding-3-large" => Some(Self::TextEmbedding3Large),
            "text-embedding-ada-002" => Some(Self::TextEmbeddingAda002),
            _ => None,
        }
    }
}
86
/// Tunables for the OpenAI embedding provider.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// Which embedding model to call.
    pub model: OpenAIModel,
    /// Requested vector width; `None` uses the model default. Only honored
    /// by models where `supports_custom_dimensions()` is true.
    pub dimensions: Option<usize>,
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Base delay for exponential backoff between retries.
    pub retry_base_delay_ms: u64,
    /// Cap on concurrently in-flight API requests (semaphore permits).
    pub max_concurrent_requests: usize,
    /// Maximum number of embeddings held in the in-memory cache.
    pub cache_capacity: usize,
    /// Per-request HTTP timeout in seconds.
    pub timeout_secs: u64,
    /// Override for the API base URL (e.g. a proxy); `None` means the
    /// default `https://api.openai.com/v1`.
    pub api_base: Option<String>,
}
107
108impl Default for OpenAIConfig {
109 fn default() -> Self {
110 Self {
111 model: OpenAIModel::default(),
112 dimensions: None,
113 max_retries: 3,
114 retry_base_delay_ms: 1000,
115 max_concurrent_requests: 10,
116 cache_capacity: 10_000,
117 timeout_secs: 30,
118 api_base: None,
119 }
120 }
121}
122
/// Thread-safe usage counters, updated with relaxed atomic operations.
#[derive(Debug, Default)]
pub struct UsageStats {
    /// Total prompt tokens reported by the API.
    pub prompt_tokens: AtomicU64,
    /// HTTP request attempts (each retry counts separately).
    pub requests: AtomicU64,
    /// Embeddings served from the in-memory cache.
    pub cache_hits: AtomicU64,
    /// Embeddings that required an API call.
    pub cache_misses: AtomicU64,
    /// Failed request attempts.
    pub failures: AtomicU64,
}
137
138impl UsageStats {
139 pub fn snapshot(&self) -> UsageSnapshot {
141 UsageSnapshot {
142 prompt_tokens: self.prompt_tokens.load(Ordering::Relaxed),
143 requests: self.requests.load(Ordering::Relaxed),
144 cache_hits: self.cache_hits.load(Ordering::Relaxed),
145 cache_misses: self.cache_misses.load(Ordering::Relaxed),
146 failures: self.failures.load(Ordering::Relaxed),
147 }
148 }
149}
150
/// Plain-value copy of [`UsageStats`] taken at a single point in time.
#[derive(Debug, Clone)]
pub struct UsageSnapshot {
    /// Total prompt tokens reported by the API.
    pub prompt_tokens: u64,
    /// HTTP request attempts (each retry counts separately).
    pub requests: u64,
    /// Embeddings served from the in-memory cache.
    pub cache_hits: u64,
    /// Embeddings that required an API call.
    pub cache_misses: u64,
    /// Failed request attempts.
    pub failures: u64,
}
165
166impl UsageSnapshot {
167 pub fn cache_hit_rate(&self) -> f64 {
169 let total = self.cache_hits + self.cache_misses;
170 if total == 0 {
171 0.0
172 } else {
173 self.cache_hits as f64 / total as f64
174 }
175 }
176
177 pub fn estimated_cost_usd(&self, model: OpenAIModel) -> f64 {
179 let cost_per_million = match model {
180 OpenAIModel::TextEmbedding3Small => 0.02,
181 OpenAIModel::TextEmbedding3Large => 0.13,
182 OpenAIModel::TextEmbeddingAda002 => 0.10,
183 };
184 (self.prompt_tokens as f64 / 1_000_000.0) * cost_per_million
185 }
186}
187
/// A single cached embedding plus its creation time (for TTL checks).
struct CacheEntry {
    embedding: Vec<f32>,
    created_at: Instant,
}

/// Bounded, TTL-based in-memory embedding cache keyed by request hash.
///
/// Eviction policy on overflow: expired entries are purged first; if the
/// cache is still full, the single oldest entry is evicted (linear scan —
/// acceptable for the capacities used here).
struct EmbeddingCache {
    entries: HashMap<String, CacheEntry>,
    capacity: usize,
    ttl: Duration,
}

impl EmbeddingCache {
    fn new(capacity: usize) -> Self {
        Self {
            entries: HashMap::with_capacity(capacity),
            capacity,
            // One hour. NOTE(review): hard-coded; presumably fine because
            // embeddings are deterministic per model — confirm if upstream
            // models are ever swapped in place.
            ttl: Duration::from_secs(3600),
        }
    }

    /// Return a clone of the cached embedding if present and not expired.
    fn get(&self, key: &str) -> Option<Vec<f32>> {
        self.entries.get(key).and_then(|entry| {
            if entry.created_at.elapsed() < self.ttl {
                Some(entry.embedding.clone())
            } else {
                None
            }
        })
    }

    /// Insert or refresh an entry, evicting only when a genuinely new key
    /// would push the map past `capacity`.
    fn insert(&mut self, key: String, embedding: Vec<f32>) {
        // Bug fix: re-inserting an existing key replaces it in place and
        // cannot grow the map, so it must not trigger eviction. Previously
        // an update performed at capacity evicted the oldest *unrelated*
        // entry for no reason.
        if !self.entries.contains_key(&key) && self.entries.len() >= self.capacity {
            self.evict_expired();

            // Still full after purging expired entries: drop the oldest.
            if self.entries.len() >= self.capacity {
                if let Some(oldest_key) = self
                    .entries
                    .iter()
                    .min_by_key(|(_, v)| v.created_at)
                    .map(|(k, _)| k.clone())
                {
                    self.entries.remove(&oldest_key);
                }
            }
        }

        self.entries.insert(
            key,
            CacheEntry {
                embedding,
                created_at: Instant::now(),
            },
        );
    }

    /// Drop every entry older than the TTL.
    fn evict_expired(&mut self) {
        self.entries
            .retain(|_, entry| entry.created_at.elapsed() < self.ttl);
    }
}
252
/// Embedding provider backed by the OpenAI `/embeddings` HTTP API.
///
/// Combines a shared `reqwest` client with an in-memory TTL cache, a
/// semaphore bounding concurrent requests, retry with exponential backoff,
/// and atomic usage counters.
pub struct OpenAIEmbedding {
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    config: OpenAIConfig,
    client: reqwest::Client,
    /// Cache mapping a hash of (model, dimensions, text) to its embedding.
    cache: Arc<RwLock<EmbeddingCache>>,
    /// Limits in-flight API requests to `config.max_concurrent_requests`.
    semaphore: Arc<Semaphore>,
    stats: Arc<UsageStats>,
    /// Vector width actually in effect (custom value or model default).
    effective_dimensions: usize,
}
263
264impl OpenAIEmbedding {
265 pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
267 let mut config = OpenAIConfig::default();
268 if let Some(model_str) = model {
269 if let Some(model) = OpenAIModel::from_str(&model_str) {
270 config.model = model;
271 }
272 }
273 Self::with_config(api_key, config)
274 }
275
276 pub fn from_env() -> Result<Self> {
278 let api_key = std::env::var("OPENAI_API_KEY")
279 .map_err(|_| Error::embedding("OPENAI_API_KEY environment variable not set"))?;
280 Ok(Self::new(api_key, None))
281 }
282
283 pub fn with_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
285 let effective_dimensions = config
286 .dimensions
287 .unwrap_or_else(|| config.model.default_dimensions());
288
289 let client = reqwest::Client::builder()
290 .timeout(Duration::from_secs(config.timeout_secs))
291 .pool_max_idle_per_host(config.max_concurrent_requests)
292 .build()
293 .expect("Failed to create HTTP client");
294
295 Self {
296 api_key: api_key.into(),
297 effective_dimensions,
298 cache: Arc::new(RwLock::new(EmbeddingCache::new(config.cache_capacity))),
299 semaphore: Arc::new(Semaphore::new(config.max_concurrent_requests)),
300 stats: Arc::new(UsageStats::default()),
301 client,
302 config,
303 }
304 }
305
306 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
308 if self.config.model.supports_custom_dimensions() {
309 self.config.dimensions = Some(dimensions);
310 self.effective_dimensions = dimensions;
311 }
312 self
313 }
314
315 pub fn stats(&self) -> UsageSnapshot {
317 self.stats.snapshot()
318 }
319
320 fn api_url(&self) -> String {
322 self.config
323 .api_base
324 .clone()
325 .unwrap_or_else(|| "https://api.openai.com/v1".to_string())
326 + "/embeddings"
327 }
328
329 fn cache_key(&self, text: &str) -> String {
331 use std::collections::hash_map::DefaultHasher;
332 use std::hash::{Hash, Hasher};
333
334 let mut hasher = DefaultHasher::new();
335 self.config.model.as_str().hash(&mut hasher);
336 self.effective_dimensions.hash(&mut hasher);
337 text.hash(&mut hasher);
338 format!("{:x}", hasher.finish())
339 }
340
341 async fn execute_with_retry(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
343 let mut last_error = None;
344
345 for attempt in 0..=self.config.max_retries {
346 if attempt > 0 {
347 let delay = self.config.retry_base_delay_ms * 2u64.pow(attempt - 1);
348 debug!(attempt, delay_ms = delay, "Retrying after delay");
349 tokio::time::sleep(Duration::from_millis(delay)).await;
350 }
351
352 let _permit = self
354 .semaphore
355 .acquire()
356 .await
357 .map_err(|_| Error::embedding("Semaphore closed"))?;
358
359 self.stats.requests.fetch_add(1, Ordering::Relaxed);
360
361 match self.execute_request(request).await {
362 Ok(response) => return Ok(response),
363 Err(e) => {
364 warn!(attempt, error = %e, "Request failed");
365 self.stats.failures.fetch_add(1, Ordering::Relaxed);
366
367 if e.to_string().contains("invalid_api_key")
369 || e.to_string().contains("insufficient_quota")
370 {
371 return Err(e);
372 }
373
374 last_error = Some(e);
375 }
376 }
377 }
378
379 Err(last_error.unwrap_or_else(|| Error::embedding("Unknown error")))
380 }
381
382 async fn execute_request(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
384 let response = self
385 .client
386 .post(self.api_url())
387 .header("Authorization", format!("Bearer {}", self.api_key))
388 .header("Content-Type", "application/json")
389 .json(request)
390 .send()
391 .await
392 .map_err(|e| Error::embedding(format!("Request failed: {}", e)))?;
393
394 let status = response.status();
395 let body = response
396 .text()
397 .await
398 .map_err(|e| Error::embedding(format!("Failed to read response: {}", e)))?;
399
400 if !status.is_success() {
401 let error: std::result::Result<ErrorResponse, _> = serde_json::from_str(&body);
402 return Err(match error {
403 Ok(e) => Error::embedding(format!("OpenAI API error: {}", e.error.message)),
404 Err(_) => Error::embedding(format!("API error ({}): {}", status, body)),
405 });
406 }
407
408 serde_json::from_str(&body)
409 .map_err(|e| Error::embedding(format!("Failed to parse response: {}", e)))
410 }
411
412 async fn embed_with_cache(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
414 if texts.is_empty() {
415 return Ok(Vec::new());
416 }
417
418 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
419 let mut uncached_indices = Vec::new();
420 let mut uncached_texts = Vec::new();
421
422 {
424 let cache = self.cache.read().await;
425 for (i, text) in texts.iter().enumerate() {
426 let key = self.cache_key(text);
427 if let Some(embedding) = cache.get(&key) {
428 results[i] = Some(embedding);
429 self.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
430 } else {
431 uncached_indices.push(i);
432 uncached_texts.push(*text);
433 self.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
434 }
435 }
436 }
437
438 if !uncached_texts.is_empty() {
440 debug!(
441 count = uncached_texts.len(),
442 cached = texts.len() - uncached_texts.len(),
443 "Fetching embeddings from API"
444 );
445
446 let request = EmbeddingRequest {
447 model: self.config.model.as_str().to_string(),
448 input: uncached_texts.iter().map(|s| s.to_string()).collect(),
449 dimensions: if self.config.model.supports_custom_dimensions() {
450 Some(self.effective_dimensions)
451 } else {
452 None
453 },
454 };
455
456 let response = self.execute_with_retry(&request).await?;
457
458 self.stats
460 .prompt_tokens
461 .fetch_add(response.usage.prompt_tokens as u64, Ordering::Relaxed);
462
463 let mut data = response.data;
465 data.sort_by_key(|d| d.index);
466
467 {
469 let mut cache = self.cache.write().await;
470 for (data_idx, embedding_data) in data.into_iter().enumerate() {
471 let original_idx = uncached_indices[data_idx];
472 let text = uncached_texts[data_idx];
473 let key = self.cache_key(text);
474
475 cache.insert(key, embedding_data.embedding.clone());
476 results[original_idx] = Some(embedding_data.embedding);
477 }
478 }
479 }
480
481 results
483 .into_iter()
484 .enumerate()
485 .map(|(i, opt)| {
486 opt.ok_or_else(|| Error::embedding(format!("Missing embedding for index {}", i)))
487 })
488 .collect()
489 }
490}
491
/// JSON request body for `POST /embeddings`.
#[derive(Serialize)]
struct EmbeddingRequest {
    model: String,
    input: Vec<String>,
    /// Omitted from the JSON entirely when `None`; only populated for
    /// models that support custom dimensions (see `embed_with_cache`).
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}

/// Successful response payload from the embeddings endpoint.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
    #[allow(dead_code)]
    model: String,
    usage: Usage,
}

/// One embedding; `index` identifies which request input it belongs to.
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
    index: usize,
}

/// Token accounting reported by the API.
#[derive(Deserialize)]
struct Usage {
    prompt_tokens: usize,
    #[allow(dead_code)]
    total_tokens: usize,
}

/// Error envelope returned on non-2xx responses.
#[derive(Deserialize)]
struct ErrorResponse {
    error: ApiError,
}

/// Error details inside [`ErrorResponse`].
#[derive(Deserialize)]
struct ApiError {
    message: String,
    #[allow(dead_code)]
    #[serde(rename = "type")]
    error_type: String,
}
533
534#[async_trait::async_trait]
535impl EmbeddingProvider for OpenAIEmbedding {
536 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
537 let embeddings = self.embed_with_cache(&[text]).await?;
538 embeddings
539 .into_iter()
540 .next()
541 .ok_or_else(|| Error::embedding("No embedding returned"))
542 }
543
544 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
545 const BATCH_SIZE: usize = 100;
548
549 if texts.len() <= BATCH_SIZE {
550 return self.embed_with_cache(texts).await;
551 }
552
553 let mut all_embeddings = Vec::with_capacity(texts.len());
554
555 for chunk in texts.chunks(BATCH_SIZE) {
556 let embeddings = self.embed_with_cache(chunk).await?;
557 all_embeddings.extend(embeddings);
558 }
559
560 Ok(all_embeddings)
561 }
562
563 fn dimensions(&self) -> usize {
564 self.effective_dimensions
565 }
566
567 fn as_any(&self) -> &dyn Any {
568 self
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_properties() {
        assert_eq!(OpenAIModel::TextEmbedding3Small.default_dimensions(), 1536);
        assert_eq!(OpenAIModel::TextEmbedding3Large.default_dimensions(), 3072);
        assert!(OpenAIModel::TextEmbedding3Small.supports_custom_dimensions());
        assert!(!OpenAIModel::TextEmbeddingAda002.supports_custom_dimensions());
    }

    #[test]
    fn test_model_parsing() {
        assert_eq!(
            OpenAIModel::from_str("text-embedding-3-small"),
            Some(OpenAIModel::TextEmbedding3Small)
        );
        assert_eq!(OpenAIModel::from_str("unknown-model"), None);
    }

    #[test]
    fn test_openai_dimensions() {
        let provider = OpenAIEmbedding::new("test-key", None);
        assert_eq!(provider.dimensions(), 1536);

        let provider = OpenAIEmbedding::new("test-key", Some("text-embedding-3-large".to_string()));
        assert_eq!(provider.dimensions(), 3072);
    }

    #[test]
    fn test_custom_dimensions() {
        let provider = OpenAIEmbedding::new("test-key", Some("text-embedding-3-small".to_string()))
            .with_dimensions(512);
        assert_eq!(provider.dimensions(), 512);

        // ada-002 does not support custom dimensions: the request is ignored
        // and the model's native width is kept.
        let provider = OpenAIEmbedding::new("test-key", Some("text-embedding-ada-002".to_string()))
            .with_dimensions(512);
        assert_eq!(provider.dimensions(), 1536);
    }

    #[test]
    fn test_config_defaults() {
        let config = OpenAIConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.max_concurrent_requests, 10);
        assert_eq!(config.cache_capacity, 10_000);
    }

    #[test]
    fn test_usage_stats() {
        let stats = UsageStats::default();
        stats.prompt_tokens.fetch_add(1000, Ordering::Relaxed);
        stats.cache_hits.fetch_add(80, Ordering::Relaxed);
        stats.cache_misses.fetch_add(20, Ordering::Relaxed);

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.prompt_tokens, 1000);
        assert!((snapshot.cache_hit_rate() - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_cost_estimation() {
        let snapshot = UsageSnapshot {
            prompt_tokens: 1_000_000,
            requests: 100,
            cache_hits: 50,
            cache_misses: 50,
            failures: 0,
        };

        let cost_small = snapshot.estimated_cost_usd(OpenAIModel::TextEmbedding3Small);
        let cost_large = snapshot.estimated_cost_usd(OpenAIModel::TextEmbedding3Large);

        assert!((cost_small - 0.02).abs() < 0.001);
        assert!((cost_large - 0.13).abs() < 0.001);
    }

    // Fix: this was `#[tokio::test] async fn` with no `.await` in the body —
    // the cache is fully synchronous, so a plain `#[test]` avoids spinning
    // up a Tokio runtime for nothing.
    #[test]
    fn test_cache_operations() {
        let mut cache = EmbeddingCache::new(3);

        cache.insert("key1".to_string(), vec![1.0, 2.0, 3.0]);
        cache.insert("key2".to_string(), vec![4.0, 5.0, 6.0]);

        assert_eq!(cache.get("key1"), Some(vec![1.0, 2.0, 3.0]));
        assert_eq!(cache.get("key2"), Some(vec![4.0, 5.0, 6.0]));
        assert_eq!(cache.get("key3"), None);

        // Filling past capacity evicts an old entry rather than growing.
        cache.insert("key3".to_string(), vec![7.0, 8.0, 9.0]);
        cache.insert("key4".to_string(), vec![10.0, 11.0, 12.0]);

        assert_eq!(cache.entries.len(), 3);
    }

    #[test]
    fn test_cache_key_consistency() {
        let provider = OpenAIEmbedding::new("test-key", None);

        let key1 = provider.cache_key("hello world");
        let key2 = provider.cache_key("hello world");
        let key3 = provider.cache_key("different text");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
}