Skip to main content

directory_indexer/embedding/
openai.rs

1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4
5use super::provider::{EmbeddingProvider, EmbeddingResponse, EmbeddingUsage};
6use crate::error::{IndexerError, Result};
7
8pub struct OpenAIProvider {
9    client: Client,
10    endpoint: String,
11    model: String,
12    api_key: String,
13}
14
15#[derive(Serialize)]
16struct OpenAIEmbedRequest {
17    input: Vec<String>,
18    model: String,
19}
20
21#[derive(Deserialize)]
22struct OpenAIEmbedResponse {
23    data: Vec<OpenAIEmbedData>,
24    model: String,
25    usage: OpenAIUsage,
26}
27
28#[derive(Deserialize)]
29struct OpenAIEmbedData {
30    embedding: Vec<f32>,
31    #[allow(dead_code)]
32    index: usize,
33    #[allow(dead_code)]
34    object: String,
35}
36
37#[derive(Deserialize)]
38struct OpenAIUsage {
39    prompt_tokens: u32,
40    total_tokens: u32,
41}
42
43impl OpenAIProvider {
44    pub fn new(endpoint: String, model: String, api_key: String) -> Self {
45        let client = Client::builder()
46            .timeout(std::time::Duration::from_secs(60)) // Embedding can take longer
47            .build()
48            .unwrap_or_else(|_| Client::new());
49
50        Self {
51            client,
52            endpoint,
53            model,
54            api_key,
55        }
56    }
57}
58
59#[async_trait]
60impl EmbeddingProvider for OpenAIProvider {
61    fn model_name(&self) -> &str {
62        &self.model
63    }
64
65    fn embedding_dimension(&self) -> usize {
66        // text-embedding-ada-002 uses 1536 dimensions
67        // text-embedding-3-small uses 1536 dimensions
68        // text-embedding-3-large uses 3072 dimensions
69        match self.model.as_str() {
70            "text-embedding-3-large" => 3072,
71            _ => 1536,
72        }
73    }
74
75    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<EmbeddingResponse> {
76        let request = OpenAIEmbedRequest {
77            input: texts,
78            model: self.model.clone(),
79        };
80
81        let response = self
82            .client
83            .post(format!("{}/v1/embeddings", self.endpoint))
84            .header("Authorization", format!("Bearer {}", self.api_key))
85            .header("Content-Type", "application/json")
86            .json(&request)
87            .send()
88            .await
89            .map_err(|e| IndexerError::embedding(format!("Failed to send OpenAI request: {e}")))?;
90
91        if !response.status().is_success() {
92            let status = response.status();
93            return Err(IndexerError::embedding(format!(
94                "OpenAI API returned error: {status}"
95            )));
96        }
97
98        let openai_response: OpenAIEmbedResponse = response.json().await.map_err(|e| {
99            IndexerError::embedding(format!("Failed to parse OpenAI response: {e}"))
100        })?;
101
102        let embeddings = openai_response
103            .data
104            .into_iter()
105            .map(|data| data.embedding)
106            .collect();
107
108        Ok(EmbeddingResponse {
109            embeddings,
110            model: openai_response.model,
111            usage: Some(EmbeddingUsage {
112                prompt_tokens: Some(openai_response.usage.prompt_tokens),
113                total_tokens: Some(openai_response.usage.total_tokens),
114            }),
115        })
116    }
117
118    async fn health_check(&self) -> Result<bool> {
119        // Try a simple request to check if the API key and endpoint work
120        let test_request = OpenAIEmbedRequest {
121            input: vec!["test".to_string()],
122            model: self.model.clone(),
123        };
124
125        let response = self
126            .client
127            .post(format!("{}/v1/embeddings", self.endpoint))
128            .header("Authorization", format!("Bearer {}", self.api_key))
129            .header("Content-Type", "application/json")
130            .json(&test_request)
131            .send()
132            .await;
133
134        match response {
135            Ok(resp) => Ok(resp.status().is_success()),
136            Err(_) => Ok(false),
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    //! Tests for `OpenAIProvider`, run against a local `wiremock` server.

    use super::*;
    use wiremock::matchers::{body_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    const SMALL: &str = "text-embedding-3-small";
    const LARGE: &str = "text-embedding-3-large";
    const ADA: &str = "text-embedding-ada-002";
    const OPENAI_URL: &str = "https://api.openai.com";
    const SAMPLE_VECTOR: &str = "[0.1, 0.2, 0.3]";

    /// Builds a provider from borrowed strings to keep the tests terse.
    fn make_provider(endpoint: &str, model: &str, api_key: &str) -> OpenAIProvider {
        OpenAIProvider::new(
            endpoint.to_string(),
            model.to_string(),
            api_key.to_string(),
        )
    }

    /// Renders an embeddings response body.
    ///
    /// Each entry of `vectors` is a JSON array literal; entries are numbered from
    /// zero. `tokens` is used for both `prompt_tokens` and `total_tokens`.
    fn embeddings_body(model: &str, vectors: &[&str], tokens: u32) -> String {
        let data = vectors
            .iter()
            .enumerate()
            .map(|(index, vector)| {
                format!(r#"{{"object": "embedding", "embedding": {vector}, "index": {index}}}"#)
            })
            .collect::<Vec<_>>()
            .join(", ");

        format!(
            r#"{{"object": "list", "data": [{data}], "model": "{model}", "usage": {{"prompt_tokens": {tokens}, "total_tokens": {tokens}}}}}"#
        )
    }

    /// Mounts a mock answering every `POST /v1/embeddings` with `template`.
    async fn mount_embeddings(server: &MockServer, template: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(template)
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn test_new_provider() {
        let provider = make_provider(OPENAI_URL, SMALL, "test-key");

        assert_eq!(provider.model_name(), "text-embedding-3-small");
        assert_eq!(provider.embedding_dimension(), 1536);
    }

    #[tokio::test]
    async fn test_embedding_dimensions() {
        for (model, expected) in [(SMALL, 1536), (LARGE, 3072), (ADA, 1536)] {
            let provider = make_provider(OPENAI_URL, model, "test-key");
            assert_eq!(
                provider.embedding_dimension(),
                expected,
                "unexpected dimension for {model}"
            );
        }
    }

    #[tokio::test]
    async fn test_generate_embeddings_success() {
        let mock_server = MockServer::start().await;
        let response_body = embeddings_body(SMALL, &[SAMPLE_VECTOR, "[0.4, 0.5, 0.6]"], 10);

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .and(header("authorization", "Bearer test-key"))
            .and(header("content-type", "application/json"))
            .respond_with(ResponseTemplate::new(200).set_body_string(response_body))
            .mount(&mock_server)
            .await;

        let provider = make_provider(&mock_server.uri(), SMALL, "test-key");

        let result = provider
            .generate_embeddings(vec!["hello".to_string(), "world".to_string()])
            .await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.embeddings.len(), 2);
        assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]);
        assert_eq!(response.embeddings[1], vec![0.4, 0.5, 0.6]);
        assert_eq!(response.model, "text-embedding-3-small");
        assert!(response.usage.is_some());
        assert_eq!(response.usage.unwrap().total_tokens, Some(10));
    }

    #[tokio::test]
    async fn test_generate_embeddings_single_text() {
        let mock_server = MockServer::start().await;
        let response_body = embeddings_body(SMALL, &[SAMPLE_VECTOR], 5);
        mount_embeddings(
            &mock_server,
            ResponseTemplate::new(200).set_body_string(response_body),
        )
        .await;

        let provider = make_provider(&mock_server.uri(), SMALL, "test-key");

        let result = provider
            .generate_embeddings(vec!["hello world".to_string()])
            .await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.embeddings.len(), 1);
        assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn test_generate_embeddings_api_error() {
        let mock_server = MockServer::start().await;
        mount_embeddings(
            &mock_server,
            ResponseTemplate::new(401)
                .set_body_string(r#"{"error": {"message": "Invalid API key"}}"#),
        )
        .await;

        let provider = make_provider(&mock_server.uri(), SMALL, "invalid-key");

        let result = provider.generate_embeddings(vec!["test".to_string()]).await;

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.to_string().contains("OpenAI API returned error"));
    }

    #[tokio::test]
    async fn test_generate_embeddings_invalid_json() {
        let mock_server = MockServer::start().await;
        mount_embeddings(
            &mock_server,
            ResponseTemplate::new(200).set_body_string("invalid json"),
        )
        .await;

        let provider = make_provider(&mock_server.uri(), SMALL, "test-key");

        let result = provider.generate_embeddings(vec!["test".to_string()]).await;

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error
            .to_string()
            .contains("Failed to parse OpenAI response"));
    }

    #[tokio::test]
    async fn test_health_check_success() {
        let mock_server = MockServer::start().await;
        let response_body = embeddings_body(SMALL, &[SAMPLE_VECTOR], 1);
        mount_embeddings(
            &mock_server,
            ResponseTemplate::new(200).set_body_string(response_body),
        )
        .await;

        let provider = make_provider(&mock_server.uri(), SMALL, "test-key");

        let result = provider.health_check().await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[tokio::test]
    async fn test_health_check_failure() {
        let mock_server = MockServer::start().await;
        mount_embeddings(&mock_server, ResponseTemplate::new(401)).await;

        let provider = make_provider(&mock_server.uri(), SMALL, "invalid-key");

        let result = provider.health_check().await;
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[tokio::test]
    async fn test_health_check_network_error() {
        // Target a loopback port with no listener so the connection is refused
        // locally; this avoids depending on how the resolver treats unknown hosts.
        let provider = make_provider("http://127.0.0.1:1", SMALL, "test-key");

        let result = provider.health_check().await;
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[tokio::test]
    async fn test_request_headers_and_body() {
        let mock_server = MockServer::start().await;
        let response_body = embeddings_body(ADA, &[SAMPLE_VECTOR], 8);

        // The mock only answers when headers AND the JSON body match; any
        // mismatch falls through to wiremock's 404 and fails the test.
        let expected_body = OpenAIEmbedRequest {
            input: vec!["The food was delicious".to_string()],
            model: ADA.to_string(),
        };

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .and(header("authorization", "Bearer secret-key"))
            .and(header("content-type", "application/json"))
            .and(body_json(expected_body))
            .respond_with(ResponseTemplate::new(200).set_body_string(response_body))
            .mount(&mock_server)
            .await;

        let provider = make_provider(&mock_server.uri(), ADA, "secret-key");

        let result = provider
            .generate_embeddings(vec!["The food was delicious".to_string()])
            .await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.model, "text-embedding-ada-002");
    }

    #[tokio::test]
    async fn test_empty_embeddings_list() {
        let mock_server = MockServer::start().await;
        let response_body = embeddings_body(SMALL, &[], 0);
        mount_embeddings(
            &mock_server,
            ResponseTemplate::new(200).set_body_string(response_body),
        )
        .await;

        let provider = make_provider(&mock_server.uri(), SMALL, "test-key");

        let result = provider.generate_embeddings(vec!["test".to_string()]).await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.embeddings.len(), 0);
        assert_eq!(response.model, "text-embedding-3-small");
    }
}