Skip to main content

argyph_embed/
openai.rs

1use std::time::Duration;
2
3use serde::{Deserialize, Serialize};
4use tracing;
5
6use crate::api_key::ApiKey;
7use crate::config::EmbedConfig;
8use crate::error::{EmbedError, Result};
9
/// Model name assigned to every embedder built by `OpenAiEmbedder::with_api_key`.
const DEFAULT_MODEL: &str = "text-embedding-3-small";
/// Endpoint root used when the configuration does not supply a `base_url`.
const OPENAI_BASE_URL: &str = "https://api.openai.com";
/// Largest `batch_size` that `embed` accepts before returning `BatchTooLarge`.
const MAX_BATCH_SIZE: usize = 2048;
/// Per-text cap applied by `truncate_text`, which counts whitespace-separated words.
const MAX_TOKENS_PER_TEXT: usize = 8191;
14
15#[derive(Serialize)]
16struct EmbedRequest<'a> {
17    model: &'a str,
18    input: &'a [String],
19    encoding_format: &'a str,
20}
21
22#[derive(Deserialize)]
23struct EmbedResponse {
24    data: Vec<EmbeddingData>,
25}
26
27#[derive(Deserialize)]
28struct EmbeddingData {
29    index: usize,
30    embedding: Vec<f32>,
31}
32
33pub struct OpenAiEmbedder {
34    api_key: ApiKey,
35    client: reqwest::Client,
36    config: EmbedConfig,
37    model: String,
38}
39
40impl OpenAiEmbedder {
41    pub fn new(config: EmbedConfig) -> Result<Self> {
42        let api_key = ApiKey::from_env("OPENAI_API_KEY")?;
43        Self::with_api_key(config, api_key)
44    }
45
46    pub fn with_api_key(config: EmbedConfig, api_key: ApiKey) -> Result<Self> {
47        let client = crate::http::build_client(&config)
48            .map_err(|e| EmbedError::Config(format!("failed to build HTTP client: {e}")))?;
49        Ok(Self {
50            api_key,
51            client,
52            config,
53            model: DEFAULT_MODEL.to_string(),
54        })
55    }
56
57    fn dimension_for_model(model: &str) -> usize {
58        match model {
59            "text-embedding-3-large" => 3072,
60            _ => 1536,
61        }
62    }
63
64    fn base_url(&self) -> &str {
65        self.config.base_url.as_deref().unwrap_or(OPENAI_BASE_URL)
66    }
67
68    fn truncate_text(text: &str) -> String {
69        let words: Vec<&str> = text.split_whitespace().collect();
70        if words.len() <= MAX_TOKENS_PER_TEXT {
71            text.to_string()
72        } else {
73            words[..MAX_TOKENS_PER_TEXT].join(" ")
74        }
75    }
76}
77
78#[async_trait::async_trait]
79impl crate::Embedder for OpenAiEmbedder {
80    fn dimension(&self) -> usize {
81        Self::dimension_for_model(&self.model)
82    }
83
84    fn model_id(&self) -> &str {
85        &self.model
86    }
87
88    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
89        if texts.is_empty() {
90            return Err(EmbedError::EmptyInput);
91        }
92
93        if self.config.batch_size > MAX_BATCH_SIZE {
94            return Err(EmbedError::BatchTooLarge {
95                batch_size: self.config.batch_size,
96                max_batch_size: MAX_BATCH_SIZE,
97            });
98        }
99
100        let truncated: Vec<String> = texts.iter().map(|t| Self::truncate_text(t)).collect();
101        let url = format!("{}/v1/embeddings", self.base_url());
102
103        let mut all_embeddings: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
104
105        for (batch_idx, chunk) in truncated.chunks(self.config.batch_size).enumerate() {
106            let batch: Vec<String> = chunk.to_vec();
107            let batch_start = batch_idx * self.config.batch_size;
108
109            tracing::debug!(
110                model = %self.model,
111                batch_index = batch_idx,
112                batch_size = batch.len(),
113                url = %url,
114                "sending embedding request"
115            );
116
117            let response_data = self.send_with_retry(&url, &batch).await?;
118
119            for data in response_data {
120                let global_idx = batch_start + data.index;
121                if global_idx < all_embeddings.len() {
122                    all_embeddings[global_idx] = Some(data.embedding);
123                }
124            }
125
126            tracing::info!(
127                model = %self.model,
128                batch_index = batch_idx,
129                batch_size = batch.len(),
130                "batch embedding completed"
131            );
132        }
133
134        all_embeddings
135            .into_iter()
136            .collect::<Option<Vec<_>>>()
137            .ok_or_else(|| EmbedError::InvalidResponse("missing embeddings in response".into()))
138    }
139}
140
141impl OpenAiEmbedder {
142    async fn send_with_retry(&self, url: &str, batch: &[String]) -> Result<Vec<EmbeddingData>> {
143        let request_body = EmbedRequest {
144            model: &self.model,
145            input: batch,
146            encoding_format: "float",
147        };
148
149        let mut last_error: Option<EmbedError> = None;
150
151        for attempt in 0..=self.config.max_retries {
152            if attempt > 0 {
153                let delay = self.config.base_delay * 2u32.pow(attempt - 1);
154                tokio::time::sleep(delay).await;
155            }
156
157            let response = self
158                .client
159                .post(url)
160                .bearer_auth(&*self.api_key)
161                .json(&request_body)
162                .send()
163                .await;
164
165            match response {
166                Ok(resp) => {
167                    let status = resp.status();
168
169                    if status.is_success() {
170                        match resp.json::<EmbedResponse>().await {
171                            Ok(parsed) => return Ok(parsed.data),
172                            Err(e) => {
173                                last_error = Some(EmbedError::InvalidResponse(format!(
174                                    "failed to parse response: {e}"
175                                )));
176                                break;
177                            }
178                        }
179                    }
180
181                    if status.as_u16() == 429 {
182                        let retry_after = resp
183                            .headers()
184                            .get("retry-after")
185                            .and_then(|v| v.to_str().ok())
186                            .and_then(|v| v.parse::<u64>().ok())
187                            .map(Duration::from_secs);
188                        return Err(EmbedError::RateLimited { retry_after });
189                    }
190
191                    if status.as_u16() == 401 || status.as_u16() == 403 {
192                        let body = resp.text().await.unwrap_or_default();
193                        return Err(EmbedError::Auth(body));
194                    }
195
196                    let body = resp.text().await.unwrap_or_default();
197                    last_error = Some(EmbedError::Http(format!(
198                        "HTTP {} {}",
199                        status.as_u16(),
200                        body
201                    )));
202                }
203                Err(e) => {
204                    last_error = Some(EmbedError::Http(e.to_string()));
205                }
206            }
207        }
208
209        Err(last_error.unwrap_or_else(|| EmbedError::Http("unknown error".into())))
210    }
211}
212
// Tests unwrap freely: a panic is the desired failure mode here.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::api_key::ApiKey;
    use crate::config::EmbedConfig;
    use crate::Embedder;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Default configuration pointed at the given mock server.
    fn test_config(base_url: String) -> EmbedConfig {
        EmbedConfig {
            base_url: Some(base_url),
            ..EmbedConfig::default()
        }
    }

    /// Placeholder key; the mock server never validates it.
    fn test_api_key() -> ApiKey {
        ApiKey::from("sk-test-key")
    }

    /// Builds a response body whose `data[i].index` is `i`, in input order.
    fn make_embed_response(embeddings: Vec<Vec<f32>>) -> serde_json::Value {
        let data: Vec<_> = embeddings
            .into_iter()
            .enumerate()
            .map(|(i, embedding)| {
                json!({
                    "object": "embedding",
                    "index": i,
                    "embedding": embedding,
                })
            })
            .collect();

        json!({
            "object": "list",
            "data": data,
            "model": "text-embedding-3-small",
        })
    }

    #[tokio::test]
    async fn happy_path_returns_correct_vectors() {
        let mock_server = MockServer::start().await;
        let expected = vec![vec![0.1_f32, 0.2, 0.3], vec![0.4, 0.5, 0.6]];

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(make_embed_response(expected.clone())),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let config = test_config(mock_server.uri());
        let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();

        let texts: Vec<String> = vec!["hello".into(), "world".into()];
        let result = embedder.embed(&texts).await.unwrap();

        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn auth_failure_401_returns_auth_error() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(ResponseTemplate::new(401).set_body_string("invalid api key"))
            .expect(1)
            .mount(&mock_server)
            .await;

        let config = test_config(mock_server.uri());
        let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();

        let texts: Vec<String> = vec!["hello".into()];
        let err = embedder.embed(&texts).await.unwrap_err();

        assert!(
            matches!(err, EmbedError::Auth(_)),
            "expected Auth error, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn rate_limit_429_returns_rate_limited_error() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(
                ResponseTemplate::new(429)
                    .set_body_string("rate limited")
                    .insert_header("retry-after", "42"),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let config = test_config(mock_server.uri());
        let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();

        let texts: Vec<String> = vec!["hello".into()];
        let err = embedder.embed(&texts).await.unwrap_err();

        match err {
            EmbedError::RateLimited { retry_after } => {
                assert_eq!(retry_after, Some(Duration::from_secs(42)));
            }
            other => panic!("expected RateLimited error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn batching_splits_250_texts_into_3_chunks() {
        let mock_server = MockServer::start().await;

        let generate_response = |count: usize| -> serde_json::Value {
            let embeddings: Vec<Vec<f32>> = (0..count).map(|_| vec![0.1, 0.2, 0.3]).collect();
            make_embed_response(embeddings)
        };

        // Echo back as many embeddings as the request carried inputs.
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(move |req: &wiremock::Request| {
                let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
                let input_len = body["input"].as_array().map(|a| a.len()).unwrap_or(0);
                let resp = generate_response(input_len);
                ResponseTemplate::new(200).set_body_json(resp)
            })
            .expect(3)
            .mount(&mock_server)
            .await;

        // Pin the batch size so that "250 texts => 3 requests" does not depend
        // on whatever `EmbedConfig::default()` happens to use.
        let config = EmbedConfig {
            base_url: Some(mock_server.uri()),
            batch_size: 100,
            ..EmbedConfig::default()
        };
        let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();

        let texts: Vec<String> = (0..250).map(|i| format!("text {i}")).collect();
        let result = embedder.embed(&texts).await.unwrap();

        assert_eq!(result.len(), 250);
        for embedding in &result {
            assert_eq!(embedding, &vec![0.1_f32, 0.2, 0.3]);
        }
    }

    #[tokio::test]
    async fn empty_input_returns_empty_input_error() {
        let mock_server = MockServer::start().await;
        let config = test_config(mock_server.uri());
        let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();

        let texts: Vec<String> = vec![];
        let err = embedder.embed(&texts).await.unwrap_err();

        assert!(
            matches!(err, EmbedError::EmptyInput),
            "expected EmptyInput error, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn embed_query_default_impl_calls_embed() {
        let mock_server = MockServer::start().await;
        let expected = vec![0.1_f32, 0.2, 0.3];

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(make_embed_response(vec![expected.clone()])),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let config = test_config(mock_server.uri());
        let embedder = OpenAiEmbedder::with_api_key(config, test_api_key()).unwrap();

        let result = embedder.embed_query("hello").await.unwrap();
        assert_eq!(result, expected);
    }

    /// Hits the real API; compiled only with the `live-providers` feature and
    /// skipped when no key is present in the environment.
    #[cfg(feature = "live-providers")]
    #[tokio::test]
    async fn openai_live_smoke() {
        if std::env::var("OPENAI_API_KEY").is_err() {
            return;
        }
        let config = EmbedConfig::default();
        let embedder = OpenAiEmbedder::new(config).unwrap();

        assert_eq!(embedder.dimension(), 1536);
        assert_eq!(embedder.model_id(), "text-embedding-3-small");

        let texts: Vec<String> = vec!["hello world".into(), "goodbye world".into()];
        let embeddings = embedder.embed(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 1536);
            // Check components directly: a non-zero vector can still sum to zero.
            assert!(
                embedding.iter().any(|v| *v != 0.0),
                "embedding should not be all zeros"
            );
        }
    }
}