1use std::time::Duration;
2
3use serde::{Deserialize, Serialize};
4use tracing;
5
6use crate::api_key::ApiKey;
7use crate::config::EmbedConfig;
8use crate::error::{EmbedError, Result};
9
/// Base URL used by `base_url()` when `EmbedConfig::base_url` is `None`.
const OPENAI_BASE_URL: &str = "https://api.openai.com";
/// Per-text cap applied by `truncate_text`. Note that the cap is enforced on
/// whitespace-separated words, which only approximates real tokenizer tokens.
const MAX_TOKENS_PER_TEXT: usize = 8191;
/// Largest `EmbedConfig::batch_size` that `embed` accepts; anything larger is
/// rejected with `EmbedError::BatchTooLarge`.
const MAX_BATCH_SIZE: usize = 2048;
/// Model name assigned to every embedder constructed in this file.
const DEFAULT_MODEL: &str = "text-embedding-3-small";
14
/// JSON body serialized for each POST to the `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct EmbedRequest<'a> {
    /// Model identifier sent with the request.
    model: &'a str,
    /// Texts of the current batch, borrowed from the caller.
    input: &'a [String],
    /// Requested encoding of the returned vectors; this file always sends `"float"`.
    encoding_format: &'a str,
}
21
/// Subset of the embeddings response body that this client deserializes.
#[derive(Deserialize)]
struct EmbedResponse {
    /// One entry per embedding returned for the request.
    data: Vec<EmbeddingData>,
}
26
/// A single embedding entry from the response `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    /// Index reported by the server; `embed` interprets it as the position of
    /// the corresponding input within the batch that was sent.
    index: usize,
    /// The embedding vector itself.
    embedding: Vec<f32>,
}
32
/// Embedder that POSTs batches of texts to an OpenAI-style `/v1/embeddings`
/// endpoint and returns the resulting vectors.
pub struct OpenAiEmbedder {
    /// Credential sent as the bearer token on every request.
    api_key: ApiKey,
    /// HTTP client produced by `crate::http::build_client` from `config`.
    client: reqwest::Client,
    /// Settings read by this type: `base_url`, `batch_size`, `max_retries`
    /// and `base_delay`.
    config: EmbedConfig,
    /// Model name sent with each request and used to derive `dimension()`.
    model: String,
}
39
40impl OpenAiEmbedder {
41 pub fn new(config: EmbedConfig) -> Result<Self> {
42 let api_key = ApiKey::from_env("OPENAI_API_KEY")?;
43 Self::with_api_key(config, api_key)
44 }
45
46 pub fn with_api_key(config: EmbedConfig, api_key: ApiKey) -> Result<Self> {
47 let client = crate::http::build_client(&config)
48 .map_err(|e| EmbedError::Config(format!("failed to build HTTP client: {e}")))?;
49 Ok(Self {
50 api_key,
51 client,
52 config,
53 model: DEFAULT_MODEL.to_string(),
54 })
55 }
56
57 fn dimension_for_model(model: &str) -> usize {
58 match model {
59 "text-embedding-3-large" => 3072,
60 _ => 1536,
61 }
62 }
63
64 fn base_url(&self) -> &str {
65 self.config.base_url.as_deref().unwrap_or(OPENAI_BASE_URL)
66 }
67
68 fn truncate_text(text: &str) -> String {
69 let words: Vec<&str> = text.split_whitespace().collect();
70 if words.len() <= MAX_TOKENS_PER_TEXT {
71 text.to_string()
72 } else {
73 words[..MAX_TOKENS_PER_TEXT].join(" ")
74 }
75 }
76}
77
78#[async_trait::async_trait]
79impl crate::Embedder for OpenAiEmbedder {
80 fn dimension(&self) -> usize {
81 Self::dimension_for_model(&self.model)
82 }
83
84 fn model_id(&self) -> &str {
85 &self.model
86 }
87
88 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
89 if texts.is_empty() {
90 return Err(EmbedError::EmptyInput);
91 }
92
93 if self.config.batch_size > MAX_BATCH_SIZE {
94 return Err(EmbedError::BatchTooLarge {
95 batch_size: self.config.batch_size,
96 max_batch_size: MAX_BATCH_SIZE,
97 });
98 }
99
100 let truncated: Vec<String> = texts.iter().map(|t| Self::truncate_text(t)).collect();
101 let url = format!("{}/v1/embeddings", self.base_url());
102
103 let mut all_embeddings: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
104
105 for (batch_idx, chunk) in truncated.chunks(self.config.batch_size).enumerate() {
106 let batch: Vec<String> = chunk.to_vec();
107 let batch_start = batch_idx * self.config.batch_size;
108
109 tracing::debug!(
110 model = %self.model,
111 batch_index = batch_idx,
112 batch_size = batch.len(),
113 url = %url,
114 "sending embedding request"
115 );
116
117 let response_data = self.send_with_retry(&url, &batch).await?;
118
119 for data in response_data {
120 let global_idx = batch_start + data.index;
121 if global_idx < all_embeddings.len() {
122 all_embeddings[global_idx] = Some(data.embedding);
123 }
124 }
125
126 tracing::info!(
127 model = %self.model,
128 batch_index = batch_idx,
129 batch_size = batch.len(),
130 "batch embedding completed"
131 );
132 }
133
134 all_embeddings
135 .into_iter()
136 .collect::<Option<Vec<_>>>()
137 .ok_or_else(|| EmbedError::InvalidResponse("missing embeddings in response".into()))
138 }
139}
140
141impl OpenAiEmbedder {
142 async fn send_with_retry(&self, url: &str, batch: &[String]) -> Result<Vec<EmbeddingData>> {
143 let request_body = EmbedRequest {
144 model: &self.model,
145 input: batch,
146 encoding_format: "float",
147 };
148
149 let mut last_error: Option<EmbedError> = None;
150
151 for attempt in 0..=self.config.max_retries {
152 if attempt > 0 {
153 let delay = self.config.base_delay * 2u32.pow(attempt - 1);
154 tokio::time::sleep(delay).await;
155 }
156
157 let response = self
158 .client
159 .post(url)
160 .bearer_auth(&*self.api_key)
161 .json(&request_body)
162 .send()
163 .await;
164
165 match response {
166 Ok(resp) => {
167 let status = resp.status();
168
169 if status.is_success() {
170 match resp.json::<EmbedResponse>().await {
171 Ok(parsed) => return Ok(parsed.data),
172 Err(e) => {
173 last_error = Some(EmbedError::InvalidResponse(format!(
174 "failed to parse response: {e}"
175 )));
176 break;
177 }
178 }
179 }
180
181 if status.as_u16() == 429 {
182 let retry_after = resp
183 .headers()
184 .get("retry-after")
185 .and_then(|v| v.to_str().ok())
186 .and_then(|v| v.parse::<u64>().ok())
187 .map(Duration::from_secs);
188 return Err(EmbedError::RateLimited { retry_after });
189 }
190
191 if status.as_u16() == 401 || status.as_u16() == 403 {
192 let body = resp.text().await.unwrap_or_default();
193 return Err(EmbedError::Auth(body));
194 }
195
196 let body = resp.text().await.unwrap_or_default();
197 last_error = Some(EmbedError::Http(format!(
198 "HTTP {} {}",
199 status.as_u16(),
200 body
201 )));
202 }
203 Err(e) => {
204 last_error = Some(EmbedError::Http(e.to_string()));
205 }
206 }
207 }
208
209 Err(last_error.unwrap_or_else(|| EmbedError::Http("unknown error".into())))
210 }
211}
212
213#[cfg(test)]
214#[allow(clippy::unwrap_used)]
215mod tests {
216 use super::*;
217 use crate::api_key::ApiKey;
218 use crate::config::EmbedConfig;
219 use crate::Embedder;
220 use serde_json::json;
221 use wiremock::matchers::{method, path};
222 use wiremock::{Mock, MockServer, ResponseTemplate};
223
224 fn test_config(base_url: String) -> EmbedConfig {
225 EmbedConfig {
226 base_url: Some(base_url),
227 ..EmbedConfig::default()
228 }
229 }
230
/// Fixed placeholder API key shared by the mock-server tests.
fn test_api_key() -> ApiKey {
    ApiKey::from("sk-test-key")
}
234
235 fn make_embed_response(embeddings: Vec<Vec<f32>>) -> serde_json::Value {
236 let data: Vec<_> = embeddings
237 .into_iter()
238 .enumerate()
239 .map(|(i, embedding)| {
240 json!({
241 "object": "embedding",
242 "index": i,
243 "embedding": embedding,
244 })
245 })
246 .collect();
247
248 json!({
249 "object": "list",
250 "data": data,
251 "model": "text-embedding-3-small",
252 })
253 }
254
255 #[tokio::test]
256 async fn happy_path_returns_correct_vectors() {
257 let mock_server = MockServer::start().await;
258 let expected = vec![vec![0.1_f32, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
259
260 Mock::given(method("POST"))
261 .and(path("/v1/embeddings"))
262 .respond_with(
263 ResponseTemplate::new(200).set_body_json(make_embed_response(expected.clone())),
264 )
265 .expect(1)
266 .mount(&mock_server)
267 .await;
268
269 let config = test_config(mock_server.uri());
270 let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();
271
272 let texts: Vec<String> = vec!["hello".into(), "world".into()];
273 let result = embedder.embed(&texts).await.unwrap();
274
275 assert_eq!(result.len(), 2);
276 assert_eq!(result[0], vec![0.1_f32, 0.2, 0.3]);
277 assert_eq!(result[1], vec![0.4, 0.5, 0.6]);
278 }
279
280 #[tokio::test]
281 async fn auth_failure_401_returns_auth_error() {
282 let mock_server = MockServer::start().await;
283
284 Mock::given(method("POST"))
285 .and(path("/v1/embeddings"))
286 .respond_with(ResponseTemplate::new(401).set_body_string("invalid api key"))
287 .expect(1)
288 .mount(&mock_server)
289 .await;
290
291 let config = test_config(mock_server.uri());
292 let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();
293
294 let texts: Vec<String> = vec!["hello".into()];
295 let result = embedder.embed(&texts).await;
296
297 assert!(result.is_err());
298 match result.unwrap_err() {
299 EmbedError::Auth(_) => {}
300 other => panic!("expected Auth error, got: {other:?}"),
301 }
302 }
303
304 #[tokio::test]
305 async fn rate_limit_429_returns_rate_limited_error() {
306 let mock_server = MockServer::start().await;
307
308 Mock::given(method("POST"))
309 .and(path("/v1/embeddings"))
310 .respond_with(
311 ResponseTemplate::new(429)
312 .set_body_string("rate limited")
313 .insert_header("retry-after", "42"),
314 )
315 .expect(1)
316 .mount(&mock_server)
317 .await;
318
319 let config = test_config(mock_server.uri());
320 let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();
321
322 let texts: Vec<String> = vec!["hello".into()];
323 let result = embedder.embed(&texts).await;
324
325 assert!(result.is_err());
326 match result.unwrap_err() {
327 EmbedError::RateLimited { retry_after } => {
328 assert_eq!(retry_after, Some(Duration::from_secs(42)));
329 }
330 other => panic!("expected RateLimited error, got: {other:?}"),
331 }
332 }
333
334 #[tokio::test]
335 async fn batching_splits_250_texts_into_3_chunks() {
336 let mock_server = MockServer::start().await;
337
338 let generate_response = |count: usize| -> serde_json::Value {
339 let embeddings: Vec<Vec<f32>> = (0..count).map(|_| vec![0.1, 0.2, 0.3]).collect();
340 make_embed_response(embeddings)
341 };
342
343 Mock::given(method("POST"))
344 .and(path("/v1/embeddings"))
345 .respond_with(move |req: &wiremock::Request| {
346 let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
347 let input_len = body["input"].as_array().map(|a| a.len()).unwrap_or(0);
348 let resp = generate_response(input_len);
349 ResponseTemplate::new(200).set_body_json(resp)
350 })
351 .expect(3)
352 .mount(&mock_server)
353 .await;
354
355 let config = EmbedConfig {
356 base_url: Some(mock_server.uri()),
357 ..EmbedConfig::default()
358 };
359 let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();
360
361 let texts: Vec<String> = (0..250).map(|i| format!("text {i}")).collect();
362 let result = embedder.embed(&texts).await.unwrap();
363
364 assert_eq!(result.len(), 250);
365 for embedding in &result {
366 assert_eq!(embedding, &vec![0.1_f32, 0.2, 0.3]);
367 }
368 }
369
370 #[tokio::test]
371 async fn empty_input_returns_empty_input_error() {
372 let mock_server = MockServer::start().await;
373 let config = test_config(mock_server.uri());
374 let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();
375
376 let texts: Vec<String> = vec![];
377 let result = embedder.embed(&texts).await;
378
379 assert!(result.is_err());
380 match result.unwrap_err() {
381 EmbedError::EmptyInput => {}
382 other => panic!("expected EmptyInput error, got: {other:?}"),
383 }
384 }
385
386 #[tokio::test]
387 async fn embed_query_default_impl_calls_embed() {
388 let mock_server = MockServer::start().await;
389 let expected = vec![0.1_f32, 0.2, 0.3];
390
391 Mock::given(method("POST"))
392 .and(path("/v1/embeddings"))
393 .respond_with(
394 ResponseTemplate::new(200)
395 .set_body_json(make_embed_response(vec![expected.clone()])),
396 )
397 .expect(1)
398 .mount(&mock_server)
399 .await;
400
401 let config = test_config(mock_server.uri());
402 let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();
403
404 let result = embedder.embed_query("hello").await.unwrap();
405 assert_eq!(result, expected);
406 }
407
/// Smoke test against the real API. Only compiled with the `live-providers`
/// feature, and returns early (passing) when `OPENAI_API_KEY` cannot be read.
#[cfg(feature = "live-providers")]
#[tokio::test]
async fn openai_live_smoke() {
    let key_available = std::env::var("OPENAI_API_KEY").is_ok();
    if !key_available {
        return;
    }

    let embedder = OpenAiEmbedder::new(EmbedConfig::default()).unwrap();

    assert_eq!(embedder.dimension(), 1536);
    assert_eq!(embedder.model_id(), "text-embedding-3-small");

    let inputs = vec![String::from("hello world"), String::from("goodbye world")];
    let embeddings = embedder.embed(&inputs).await.unwrap();

    assert_eq!(embeddings.len(), 2);
    for embedding in embeddings.iter() {
        assert_eq!(embedding.len(), 1536);
        let sum: f32 = embedding.iter().sum();
        assert!(sum != 0.0, "embedding should not be all zeros");
    }
}
430}