directory_indexer/embedding/
openai.rs1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4
5use super::provider::{EmbeddingProvider, EmbeddingResponse, EmbeddingUsage};
6use crate::error::{IndexerError, Result};
7
/// Embedding provider that talks to an OpenAI-style `/v1/embeddings` HTTP API.
pub struct OpenAIProvider {
    /// HTTP client reused for every request made by this provider.
    client: Client,
    /// Base URL of the API; `/v1/embeddings` is appended when sending requests.
    endpoint: String,
    /// Embedding model name sent in every request body and reported by `model_name`.
    model: String,
    /// API key sent as a bearer token in the `Authorization` header.
    api_key: String,
}
14
/// JSON request body for `POST /v1/embeddings`.
///
/// Field declaration order is the serialized key order (`input`, then `model`).
#[derive(Serialize)]
struct OpenAIEmbedRequest {
    /// Texts to embed; the API returns one entry in `data` per text.
    input: Vec<String>,
    /// Name of the embedding model to use.
    model: String,
}
20
/// JSON body of a successful `POST /v1/embeddings` response.
///
/// Only the fields this provider consumes are declared; serde ignores any
/// others (e.g. the top-level `"object": "list"` the API also sends).
#[derive(Deserialize)]
struct OpenAIEmbedResponse {
    /// One entry per embedded input text.
    data: Vec<OpenAIEmbedData>,
    /// Model name as echoed back by the server.
    model: String,
    /// Token accounting for the request.
    usage: OpenAIUsage,
}
27
/// A single element of the response's `data` array.
#[derive(Deserialize)]
struct OpenAIEmbedData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Position of this entry in the request's `input` list (per the OpenAI API reference).
    // The field must be present for deserialization to succeed, but this module
    // may not read it; the allow keeps the dead-code lint quiet either way.
    #[allow(dead_code)]
    index: usize,
    /// Type tag of the entry (`"embedding"` in the API's responses).
    // Required by deserialization but never read here.
    #[allow(dead_code)]
    object: String,
}
36
/// The `usage` object of an embeddings response.
///
/// Both counts are forwarded unchanged into `EmbeddingUsage`.
#[derive(Deserialize)]
struct OpenAIUsage {
    /// `usage.prompt_tokens` as reported by the API.
    prompt_tokens: u32,
    /// `usage.total_tokens` as reported by the API.
    total_tokens: u32,
}
42
43impl OpenAIProvider {
44 pub fn new(endpoint: String, model: String, api_key: String) -> Self {
45 let client = Client::builder()
46 .timeout(std::time::Duration::from_secs(60)) .build()
48 .unwrap_or_else(|_| Client::new());
49
50 Self {
51 client,
52 endpoint,
53 model,
54 api_key,
55 }
56 }
57}
58
59#[async_trait]
60impl EmbeddingProvider for OpenAIProvider {
61 fn model_name(&self) -> &str {
62 &self.model
63 }
64
65 fn embedding_dimension(&self) -> usize {
66 match self.model.as_str() {
70 "text-embedding-3-large" => 3072,
71 _ => 1536,
72 }
73 }
74
75 async fn generate_embeddings(&self, texts: Vec<String>) -> Result<EmbeddingResponse> {
76 let request = OpenAIEmbedRequest {
77 input: texts,
78 model: self.model.clone(),
79 };
80
81 let response = self
82 .client
83 .post(format!("{}/v1/embeddings", self.endpoint))
84 .header("Authorization", format!("Bearer {}", self.api_key))
85 .header("Content-Type", "application/json")
86 .json(&request)
87 .send()
88 .await
89 .map_err(|e| IndexerError::embedding(format!("Failed to send OpenAI request: {e}")))?;
90
91 if !response.status().is_success() {
92 let status = response.status();
93 return Err(IndexerError::embedding(format!(
94 "OpenAI API returned error: {status}"
95 )));
96 }
97
98 let openai_response: OpenAIEmbedResponse = response.json().await.map_err(|e| {
99 IndexerError::embedding(format!("Failed to parse OpenAI response: {e}"))
100 })?;
101
102 let embeddings = openai_response
103 .data
104 .into_iter()
105 .map(|data| data.embedding)
106 .collect();
107
108 Ok(EmbeddingResponse {
109 embeddings,
110 model: openai_response.model,
111 usage: Some(EmbeddingUsage {
112 prompt_tokens: Some(openai_response.usage.prompt_tokens),
113 total_tokens: Some(openai_response.usage.total_tokens),
114 }),
115 })
116 }
117
118 async fn health_check(&self) -> Result<bool> {
119 let test_request = OpenAIEmbedRequest {
121 input: vec!["test".to_string()],
122 model: self.model.clone(),
123 };
124
125 let response = self
126 .client
127 .post(format!("{}/v1/embeddings", self.endpoint))
128 .header("Authorization", format!("Bearer {}", self.api_key))
129 .header("Content-Type", "application/json")
130 .json(&test_request)
131 .send()
132 .await;
133
134 match response {
135 Ok(resp) => Ok(resp.status().is_success()),
136 Err(_) => Ok(false),
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Real OpenAI base URL; only used by tests that never send a request.
    const OPENAI_URL: &str = "https://api.openai.com";

    /// Builds a provider from borrowed fixture values.
    fn provider_at(endpoint: &str, model: &str, api_key: &str) -> OpenAIProvider {
        OpenAIProvider::new(endpoint.to_string(), model.to_string(), api_key.to_string())
    }

    /// Mounts a `POST /v1/embeddings` mock (no header matchers) that answers with `reply`.
    async fn mount_embeddings(server: &MockServer, reply: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(reply)
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn test_new_provider() {
        let provider = provider_at(OPENAI_URL, "text-embedding-3-small", "test-key");

        assert_eq!(provider.model_name(), "text-embedding-3-small");
        assert_eq!(provider.embedding_dimension(), 1536);
    }

    #[tokio::test]
    async fn test_embedding_dimensions() {
        let cases = [
            ("text-embedding-3-small", 1536),
            ("text-embedding-3-large", 3072),
            ("text-embedding-ada-002", 1536),
        ];

        for &(model, expected) in cases.iter() {
            let provider = provider_at(OPENAI_URL, model, "test-key");
            assert_eq!(provider.embedding_dimension(), expected, "model {model}");
        }
    }

    #[tokio::test]
    async fn test_generate_embeddings_success() {
        let server = MockServer::start().await;
        let body = r#"{
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [0.1, 0.2, 0.3],
                    "index": 0
                },
                {
                    "object": "embedding",
                    "embedding": [0.4, 0.5, 0.6],
                    "index": 1
                }
            ],
            "model": "text-embedding-3-small",
            "usage": {
                "prompt_tokens": 10,
                "total_tokens": 10
            }
        }"#;

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .and(header("authorization", "Bearer test-key"))
            .and(header("content-type", "application/json"))
            .respond_with(ResponseTemplate::new(200).set_body_string(body))
            .mount(&server)
            .await;

        let provider = provider_at(&server.uri(), "text-embedding-3-small", "test-key");
        let response = provider
            .generate_embeddings(vec!["hello".to_string(), "world".to_string()])
            .await
            .expect("embedding request should succeed");

        assert_eq!(response.embeddings.len(), 2);
        assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]);
        assert_eq!(response.embeddings[1], vec![0.4, 0.5, 0.6]);
        assert_eq!(response.model, "text-embedding-3-small");
        let usage = response.usage.expect("usage should be reported");
        assert_eq!(usage.total_tokens, Some(10));
    }

    #[tokio::test]
    async fn test_generate_embeddings_single_text() {
        let server = MockServer::start().await;
        let body = r#"{
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [0.1, 0.2, 0.3],
                    "index": 0
                }
            ],
            "model": "text-embedding-3-small",
            "usage": {
                "prompt_tokens": 5,
                "total_tokens": 5
            }
        }"#;
        mount_embeddings(&server, ResponseTemplate::new(200).set_body_string(body)).await;

        let provider = provider_at(&server.uri(), "text-embedding-3-small", "test-key");
        let response = provider
            .generate_embeddings(vec!["hello world".to_string()])
            .await
            .expect("embedding request should succeed");

        assert_eq!(response.embeddings.len(), 1);
        assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn test_generate_embeddings_api_error() {
        let server = MockServer::start().await;
        mount_embeddings(
            &server,
            ResponseTemplate::new(401)
                .set_body_string(r#"{"error": {"message": "Invalid API key"}}"#),
        )
        .await;

        let provider = provider_at(&server.uri(), "text-embedding-3-small", "invalid-key");
        let err = provider
            .generate_embeddings(vec!["test".to_string()])
            .await
            .expect_err("a 401 must surface as an error");

        assert!(err.to_string().contains("OpenAI API returned error"));
    }

    #[tokio::test]
    async fn test_generate_embeddings_invalid_json() {
        let server = MockServer::start().await;
        mount_embeddings(
            &server,
            ResponseTemplate::new(200).set_body_string("invalid json"),
        )
        .await;

        let provider = provider_at(&server.uri(), "text-embedding-3-small", "test-key");
        let err = provider
            .generate_embeddings(vec!["test".to_string()])
            .await
            .expect_err("a non-JSON body must surface as an error");

        assert!(err.to_string().contains("Failed to parse OpenAI response"));
    }

    #[tokio::test]
    async fn test_health_check_success() {
        let server = MockServer::start().await;
        let body = r#"{
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [0.1, 0.2, 0.3],
                    "index": 0
                }
            ],
            "model": "text-embedding-3-small",
            "usage": {
                "prompt_tokens": 1,
                "total_tokens": 1
            }
        }"#;
        mount_embeddings(&server, ResponseTemplate::new(200).set_body_string(body)).await;

        let provider = provider_at(&server.uri(), "text-embedding-3-small", "test-key");
        let healthy = provider
            .health_check()
            .await
            .expect("health check should not error");

        assert!(healthy);
    }

    #[tokio::test]
    async fn test_health_check_failure() {
        let server = MockServer::start().await;
        mount_embeddings(&server, ResponseTemplate::new(401)).await;

        let provider = provider_at(&server.uri(), "text-embedding-3-small", "invalid-key");
        let healthy = provider
            .health_check()
            .await
            .expect("health check should not error");

        assert!(!healthy);
    }

    #[tokio::test]
    async fn test_health_check_network_error() {
        // Nothing listens at this host, so the request fails at the transport level.
        let provider = provider_at(
            "http://invalid-url-that-does-not-exist:9999",
            "text-embedding-3-small",
            "test-key",
        );
        let healthy = provider
            .health_check()
            .await
            .expect("health check should not error");

        assert!(!healthy);
    }

    #[tokio::test]
    async fn test_request_headers_and_body() {
        let server = MockServer::start().await;
        let body = r#"{
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [0.1, 0.2, 0.3],
                    "index": 0
                }
            ],
            "model": "text-embedding-ada-002",
            "usage": {
                "prompt_tokens": 8,
                "total_tokens": 8
            }
        }"#;

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .and(header("authorization", "Bearer secret-key"))
            .and(header("content-type", "application/json"))
            .respond_with(ResponseTemplate::new(200).set_body_string(body))
            .mount(&server)
            .await;

        let provider = provider_at(&server.uri(), "text-embedding-ada-002", "secret-key");
        let response = provider
            .generate_embeddings(vec!["The food was delicious".to_string()])
            .await
            .expect("headers should match the mock");

        assert_eq!(response.model, "text-embedding-ada-002");
    }

    #[tokio::test]
    async fn test_empty_embeddings_list() {
        let server = MockServer::start().await;
        let body = r#"{
            "object": "list",
            "data": [],
            "model": "text-embedding-3-small",
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0
            }
        }"#;
        mount_embeddings(&server, ResponseTemplate::new(200).set_body_string(body)).await;

        let provider = provider_at(&server.uri(), "text-embedding-3-small", "test-key");
        let response = provider
            .generate_embeddings(vec!["test".to_string()])
            .await
            .expect("an empty data array is still a valid response");

        assert_eq!(response.embeddings.len(), 0);
        assert_eq!(response.model, "text-embedding-3-small");
    }
}