// agcodex_core/embeddings/providers/openai.rs
//! OpenAI embeddings provider - completely separate from chat models
//!
//! Supports:
//! - text-embedding-3-small (256-1536 dimensions)
//! - text-embedding-3-large (256-3072 dimensions)
//! - Batch processing (up to 2048 inputs)
//! - Uses OPENAI_EMBEDDING_KEY (not OPENAI_API_KEY)
use reqwest::Client;
use serde::Deserialize;
use serde::Serialize;

use super::super::EmbeddingError;
use super::super::EmbeddingProvider;
use super::super::EmbeddingVector;

/// OpenAI embedding provider.
///
/// Sends batched embedding requests to the OpenAI `/v1/embeddings` endpoint
/// (or a compatible custom endpoint) and is configured independently from
/// chat models — see the module docs regarding `OPENAI_EMBEDDING_KEY`.
pub struct OpenAIProvider {
    // Shared HTTP client; reused across requests for connection pooling.
    client: Client,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Embedding model name, e.g. "text-embedding-3-small".
    model: String,
    // Optional output-dimensionality override (supported by text-embedding-3-*);
    // when None, the model's default dimensionality applies.
    dimensions: Option<usize>,
    // Optional custom API URL; when None, the official OpenAI endpoint is used.
    api_endpoint: Option<String>,
}
24
25impl OpenAIProvider {
26    /// Create a new OpenAI provider
27    pub fn new(
28        api_key: String,
29        model: String,
30        dimensions: Option<usize>,
31        api_endpoint: Option<String>,
32    ) -> Self {
33        Self {
34            client: Client::new(),
35            api_key,
36            model,
37            dimensions,
38            api_endpoint,
39        }
40    }
41}
42
/// Request body for the OpenAI embeddings endpoint.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    /// Embedding model name.
    model: String,
    /// Texts to embed (the API accepts up to 2048 per request).
    input: Vec<String>,
    /// Optional output dimensionality; omitted from the JSON when not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    /// Output encoding (we request "float"); omitted from the JSON when not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
}
52
53#[derive(Debug, Deserialize)]
54struct OpenAIResponse {
55    data: Vec<OpenAIEmbedding>,
56    _model: String,
57    _usage: OpenAIUsage,
58}
59
/// A single embedding entry from the response `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    /// The embedding vector itself.
    embedding: Vec<f32>,
    /// Position of the corresponding input text in the request batch;
    /// used to restore request order before returning results.
    index: usize,
}
65
66#[derive(Debug, Deserialize)]
67struct OpenAIUsage {
68    _prompt_tokens: usize,
69    _total_tokens: usize,
70}
71
/// Top-level error envelope returned by the OpenAI API on failure.
#[derive(Debug, Deserialize)]
struct OpenAIError {
    /// The nested error payload (message, type, code).
    error: OpenAIErrorDetail,
}
76
77#[derive(Debug, Deserialize)]
78struct OpenAIErrorDetail {
79    message: String,
80    #[serde(rename = "type")]
81    error_type: String,
82    _code: Option<String>,
83}
84
85#[async_trait::async_trait]
86impl EmbeddingProvider for OpenAIProvider {
87    fn model_id(&self) -> String {
88        format!("openai:{}", self.model)
89    }
90
91    fn dimensions(&self) -> usize {
92        // Return configured dimensions or model defaults
93        self.dimensions.unwrap_or({
94            match self.model.as_str() {
95                "text-embedding-3-small" => 1536,
96                "text-embedding-3-large" => 3072,
97                "text-embedding-ada-002" => 1536,
98                _ => 1536,
99            }
100        })
101    }
102
103    async fn embed(&self, text: &str) -> Result<EmbeddingVector, EmbeddingError> {
104        self.embed_batch(&[text.to_string()])
105            .await
106            .map(|mut vecs| vecs.pop().unwrap_or_default())
107    }
108
109    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
110        if texts.is_empty() {
111            return Ok(vec![]);
112        }
113
114        // OpenAI has a limit of 2048 inputs per batch
115        const MAX_BATCH_SIZE: usize = 2048;
116        if texts.len() > MAX_BATCH_SIZE {
117            // Process in chunks
118            let mut all_embeddings = Vec::with_capacity(texts.len());
119            for chunk in texts.chunks(MAX_BATCH_SIZE) {
120                let chunk_embeddings = self.embed_batch_internal(chunk).await?;
121                all_embeddings.extend(chunk_embeddings);
122            }
123            return Ok(all_embeddings);
124        }
125
126        self.embed_batch_internal(texts).await
127    }
128
129    fn is_available(&self) -> bool {
130        !self.api_key.is_empty()
131    }
132}
133
impl OpenAIProvider {
    /// Perform a single embeddings API call for at most 2048 inputs.
    ///
    /// Callers (`embed_batch`) are responsible for chunking larger batches.
    /// Returns one vector per input, restored to the order of `texts`, or an
    /// `EmbeddingError::ApiError` / `DimensionMismatch` on failure.
    async fn embed_batch_internal(
        &self,
        texts: &[String],
    ) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
        // A configured custom endpoint (e.g. a proxy) wins over the default.
        let endpoint = self
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/embeddings");

        let request = OpenAIRequest {
            model: self.model.clone(),
            input: texts.to_vec(),
            dimensions: self.dimensions,
            // Request raw float arrays rather than base64-encoded output.
            encoding_format: Some("float".to_string()),
        };

        let response = self
            .client
            .post(endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| EmbeddingError::ApiError(format!("Request failed: {}", e)))?;

        // Capture the status before the body is consumed below.
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            // Try to parse OpenAI error format
            if let Ok(error) = serde_json::from_str::<OpenAIError>(&error_text) {
                return Err(EmbeddingError::ApiError(format!(
                    "OpenAI API error ({}): {} - {}",
                    status, error.error.error_type, error.error.message
                )));
            }

            // Fall back to the raw body when it isn't the structured format.
            return Err(EmbeddingError::ApiError(format!(
                "OpenAI API error ({}): {}",
                status, error_text
            )));
        }

        let openai_response: OpenAIResponse = response
            .json()
            .await
            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {}", e)))?;

        // Sort by index to ensure correct order
        let mut embeddings = openai_response.data;
        embeddings.sort_by_key(|e| e.index);

        // Validate dimensions
        let expected_dims = self.dimensions();
        for embedding in &embeddings {
            if embedding.embedding.len() != expected_dims {
                return Err(EmbeddingError::DimensionMismatch {
                    expected: expected_dims,
                    actual: embedding.embedding.len(),
                });
            }
        }

        Ok(embeddings.into_iter().map(|e| e.embedding).collect())
    }
}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared constructor for test providers (no custom endpoint).
    fn make_provider(key: &str, model: &str, dims: Option<usize>) -> OpenAIProvider {
        OpenAIProvider::new(key.to_string(), model.to_string(), dims, None)
    }

    #[test]
    fn test_model_id() {
        let provider = make_provider("test-key", "text-embedding-3-small", Some(256));
        assert_eq!(provider.model_id(), "openai:text-embedding-3-small");
    }

    #[test]
    fn test_dimensions() {
        // An explicit override wins over the model default.
        let provider = make_provider("test-key", "text-embedding-3-small", Some(256));
        assert_eq!(provider.dimensions(), 256);

        // With no override, the model's default dimensionality applies.
        let provider_default = make_provider("test-key", "text-embedding-3-large", None);
        assert_eq!(provider_default.dimensions(), 3072);
    }

    #[test]
    fn test_is_available() {
        // Any non-empty key marks the provider as available.
        assert!(make_provider("test-key", "text-embedding-3-small", None).is_available());

        // An empty key means the provider cannot be used.
        assert!(!make_provider("", "text-embedding-3-small", None).is_available());
    }
}
258}