// agcodex_core/embeddings/providers/openai.rs
use super::super::EmbeddingError;
10use super::super::EmbeddingProvider;
11use super::super::EmbeddingVector;
12use reqwest::Client;
13use serde::Deserialize;
14use serde::Serialize;
15
/// Embedding provider backed by the OpenAI `/v1/embeddings` HTTP API.
pub struct OpenAIProvider {
    // Reusable HTTP client (connection pooling across requests).
    client: Client,
    // Bearer token sent in the `Authorization` header.
    api_key: String,
    // Embedding model name, e.g. "text-embedding-3-small".
    model: String,
    // Optional requested output dimensionality; `None` means the model default.
    dimensions: Option<usize>,
    // Optional endpoint override; `None` falls back to the official OpenAI URL.
    api_endpoint: Option<String>,
}
24
25impl OpenAIProvider {
26 pub fn new(
28 api_key: String,
29 model: String,
30 dimensions: Option<usize>,
31 api_endpoint: Option<String>,
32 ) -> Self {
33 Self {
34 client: Client::new(),
35 api_key,
36 model,
37 dimensions,
38 api_endpoint,
39 }
40 }
41}
42
/// Request body for `POST /v1/embeddings`.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    // Model name, e.g. "text-embedding-3-small".
    model: String,
    // Batch of input texts to embed.
    input: Vec<String>,
    // Requested output dimensionality; omitted from the JSON when `None`
    // so the API falls back to the model default.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    // Encoding of the returned vectors ("float" or "base64"); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<String>,
}
52
53#[derive(Debug, Deserialize)]
54struct OpenAIResponse {
55 data: Vec<OpenAIEmbedding>,
56 _model: String,
57 _usage: OpenAIUsage,
58}
59
/// A single embedding entry in the response's `data` array.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    // The embedding vector itself.
    embedding: Vec<f32>,
    // Position of the corresponding input in the request batch; used to
    // restore input order, since the API does not guarantee ordering.
    index: usize,
}
65
66#[derive(Debug, Deserialize)]
67struct OpenAIUsage {
68 _prompt_tokens: usize,
69 _total_tokens: usize,
70}
71
/// Top-level error envelope returned by the OpenAI API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct OpenAIError {
    // The API wraps error details in an `error` object.
    error: OpenAIErrorDetail,
}
76
77#[derive(Debug, Deserialize)]
78struct OpenAIErrorDetail {
79 message: String,
80 #[serde(rename = "type")]
81 error_type: String,
82 _code: Option<String>,
83}
84
85#[async_trait::async_trait]
86impl EmbeddingProvider for OpenAIProvider {
87 fn model_id(&self) -> String {
88 format!("openai:{}", self.model)
89 }
90
91 fn dimensions(&self) -> usize {
92 self.dimensions.unwrap_or({
94 match self.model.as_str() {
95 "text-embedding-3-small" => 1536,
96 "text-embedding-3-large" => 3072,
97 "text-embedding-ada-002" => 1536,
98 _ => 1536,
99 }
100 })
101 }
102
103 async fn embed(&self, text: &str) -> Result<EmbeddingVector, EmbeddingError> {
104 self.embed_batch(&[text.to_string()])
105 .await
106 .map(|mut vecs| vecs.pop().unwrap_or_default())
107 }
108
109 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
110 if texts.is_empty() {
111 return Ok(vec![]);
112 }
113
114 const MAX_BATCH_SIZE: usize = 2048;
116 if texts.len() > MAX_BATCH_SIZE {
117 let mut all_embeddings = Vec::with_capacity(texts.len());
119 for chunk in texts.chunks(MAX_BATCH_SIZE) {
120 let chunk_embeddings = self.embed_batch_internal(chunk).await?;
121 all_embeddings.extend(chunk_embeddings);
122 }
123 return Ok(all_embeddings);
124 }
125
126 self.embed_batch_internal(texts).await
127 }
128
129 fn is_available(&self) -> bool {
130 !self.api_key.is_empty()
131 }
132}
133
impl OpenAIProvider {
    /// Performs one `POST /v1/embeddings` call for a single (already size-limited)
    /// batch, returning the vectors in the same order as `texts`.
    ///
    /// # Errors
    /// - `EmbeddingError::ApiError` on transport failure, non-2xx status, or
    ///   an unparsable response body.
    /// - `EmbeddingError::DimensionMismatch` if any returned vector's length
    ///   differs from `self.dimensions()`.
    async fn embed_batch_internal(
        &self,
        texts: &[String],
    ) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
        // Configured override wins; otherwise the official OpenAI endpoint.
        let endpoint = self
            .api_endpoint
            .as_deref()
            .unwrap_or("https://api.openai.com/v1/embeddings");

        let request = OpenAIRequest {
            model: self.model.clone(),
            input: texts.to_vec(),
            dimensions: self.dimensions,
            // Request plain float arrays rather than base64-packed vectors.
            encoding_format: Some("float".to_string()),
        };

        let response = self
            .client
            .post(endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| EmbeddingError::ApiError(format!("Request failed: {}", e)))?;

        // Capture the status before consuming the response body below.
        let status = response.status();
        if !status.is_success() {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            // Prefer the structured OpenAI error envelope when the body parses...
            if let Ok(error) = serde_json::from_str::<OpenAIError>(&error_text) {
                return Err(EmbeddingError::ApiError(format!(
                    "OpenAI API error ({}): {} - {}",
                    status, error.error.error_type, error.error.message
                )));
            }

            // ...otherwise fall back to the raw body text.
            return Err(EmbeddingError::ApiError(format!(
                "OpenAI API error ({}): {}",
                status, error_text
            )));
        }

        let openai_response: OpenAIResponse = response
            .json()
            .await
            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {}", e)))?;

        // The API does not guarantee ordering; restore input order via `index`.
        let mut embeddings = openai_response.data;
        embeddings.sort_by_key(|e| e.index);

        // Reject any vector whose length differs from the configured/default size.
        let expected_dims = self.dimensions();
        for embedding in &embeddings {
            if embedding.embedding.len() != expected_dims {
                return Err(EmbeddingError::DimensionMismatch {
                    expected: expected_dims,
                    actual: embedding.embedding.len(),
                });
            }
        }

        Ok(embeddings.into_iter().map(|e| e.embedding).collect())
    }
}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for test providers (default endpoint).
    fn provider(key: &str, model: &str, dims: Option<usize>) -> OpenAIProvider {
        OpenAIProvider::new(key.to_string(), model.to_string(), dims, None)
    }

    #[test]
    fn test_model_id() {
        let p = provider("test-key", "text-embedding-3-small", Some(256));
        assert_eq!(p.model_id(), "openai:text-embedding-3-small");
    }

    #[test]
    fn test_dimensions() {
        // An explicit override wins over the model default.
        let explicit = provider("test-key", "text-embedding-3-small", Some(256));
        assert_eq!(explicit.dimensions(), 256);

        // Without an override, the model's documented default applies.
        let defaulted = provider("test-key", "text-embedding-3-large", None);
        assert_eq!(defaulted.dimensions(), 3072);
    }

    #[test]
    fn test_is_available() {
        assert!(provider("test-key", "text-embedding-3-small", None).is_available());
        assert!(!provider("", "text-embedding-3-small", None).is_available());
    }
}