1use std::time::Duration;
4
5use crate::config::ClientConfig;
6use crate::error::{ErrorResponse, OpenAIError};
7use crate::resources::audio::Audio;
8use crate::resources::batches::Batches;
9use crate::resources::beta::assistants::Assistants;
10use crate::resources::beta::realtime::Realtime;
11use crate::resources::beta::runs::Runs;
12use crate::resources::beta::threads::Threads;
13use crate::resources::beta::vector_stores::VectorStores;
14use crate::resources::chat::Chat;
15use crate::resources::embeddings::Embeddings;
16use crate::resources::files::Files;
17use crate::resources::fine_tuning::FineTuning;
18use crate::resources::images::Images;
19use crate::resources::models::Models;
20use crate::resources::moderations::Moderations;
21use crate::resources::responses::Responses;
22use crate::resources::uploads::Uploads;
23
/// HTTP status codes treated as transient and retried by the client:
/// 429 (rate limited) plus the server/gateway errors 500, 502, 503 and 504.
const RETRYABLE_STATUS_CODES: [u16; 5] = [429, 500, 502, 503, 504];
26
27#[derive(Debug, Clone)]
29pub struct OpenAI {
30 pub(crate) http: reqwest::Client,
31 pub(crate) config: ClientConfig,
32}
33
34impl OpenAI {
35 pub fn new(api_key: impl Into<String>) -> Self {
37 Self::with_config(ClientConfig::new(api_key))
38 }
39
40 pub fn with_config(config: ClientConfig) -> Self {
42 let http = reqwest::Client::builder()
43 .timeout(Duration::from_secs(config.timeout_secs))
44 .build()
45 .expect("failed to build HTTP client");
46 Self { http, config }
47 }
48
49 pub fn from_env() -> Result<Self, OpenAIError> {
51 Ok(Self::with_config(ClientConfig::from_env()?))
52 }
53
54 pub fn batches(&self) -> Batches<'_> {
56 Batches::new(self)
57 }
58
59 pub fn uploads(&self) -> Uploads<'_> {
61 Uploads::new(self)
62 }
63
64 pub fn beta(&self) -> Beta<'_> {
66 Beta { client: self }
67 }
68
69 pub fn audio(&self) -> Audio<'_> {
71 Audio::new(self)
72 }
73
74 pub fn chat(&self) -> Chat<'_> {
76 Chat::new(self)
77 }
78
79 pub fn models(&self) -> Models<'_> {
81 Models::new(self)
82 }
83
84 pub fn fine_tuning(&self) -> FineTuning<'_> {
86 FineTuning::new(self)
87 }
88
89 pub fn files(&self) -> Files<'_> {
91 Files::new(self)
92 }
93
94 pub fn images(&self) -> Images<'_> {
96 Images::new(self)
97 }
98
99 pub fn moderations(&self) -> Moderations<'_> {
101 Moderations::new(self)
102 }
103
104 pub fn responses(&self) -> Responses<'_> {
106 Responses::new(self)
107 }
108
109 pub fn embeddings(&self) -> Embeddings<'_> {
111 Embeddings::new(self)
112 }
113
114 pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
116 let url = format!("{}{}", self.config.base_url, path);
117 let mut req = self
118 .http
119 .request(method, &url)
120 .bearer_auth(&self.config.api_key);
121
122 if let Some(ref org) = self.config.organization {
123 req = req.header("OpenAI-Organization", org);
124 }
125 if let Some(ref project) = self.config.project {
126 req = req.header("OpenAI-Project", project);
127 }
128
129 req
130 }
131
132 #[allow(dead_code)]
134 pub(crate) async fn get<T: serde::de::DeserializeOwned>(
135 &self,
136 path: &str,
137 ) -> Result<T, OpenAIError> {
138 self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
139 .await
140 }
141
142 pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
144 &self,
145 path: &str,
146 body: &B,
147 ) -> Result<T, OpenAIError> {
148 self.send_with_retry(reqwest::Method::POST, path, Some(body))
149 .await
150 }
151
152 pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
154 &self,
155 path: &str,
156 form: reqwest::multipart::Form,
157 ) -> Result<T, OpenAIError> {
158 let response = self
159 .request(reqwest::Method::POST, path)
160 .multipart(form)
161 .send()
162 .await?;
163 Self::handle_response(response).await
164 }
165
166 pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
168 let response = self.request(reqwest::Method::GET, path).send().await?;
169
170 let status = response.status();
171 if status.is_success() {
172 Ok(response.bytes().await?)
173 } else {
174 Err(Self::extract_error(status.as_u16(), response).await)
175 }
176 }
177
178 pub(crate) async fn post_raw<B: serde::Serialize>(
180 &self,
181 path: &str,
182 body: &B,
183 ) -> Result<bytes::Bytes, OpenAIError> {
184 let response = self
185 .request(reqwest::Method::POST, path)
186 .json(body)
187 .send()
188 .await?;
189
190 let status = response.status();
191 if status.is_success() {
192 Ok(response.bytes().await?)
193 } else {
194 Err(Self::extract_error(status.as_u16(), response).await)
195 }
196 }
197
198 #[allow(dead_code)]
200 pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
201 &self,
202 path: &str,
203 ) -> Result<T, OpenAIError> {
204 self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
205 .await
206 }
207
208 async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
210 &self,
211 method: reqwest::Method,
212 path: &str,
213 body: Option<&B>,
214 ) -> Result<T, OpenAIError> {
215 let max_retries = self.config.max_retries;
216 let mut last_error: Option<OpenAIError> = None;
217
218 for attempt in 0..=max_retries {
219 let mut req = self.request(method.clone(), path);
220 if let Some(b) = body {
221 req = req.json(b);
222 }
223
224 let response = match req.send().await {
225 Ok(resp) => resp,
226 Err(e) => {
227 last_error = Some(OpenAIError::RequestError(e));
228 if attempt < max_retries {
229 tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
230 continue;
231 }
232 break;
233 }
234 };
235
236 let status = response.status().as_u16();
237
238 if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
239 return Self::handle_response(response).await;
240 }
241
242 let retry_after = response
244 .headers()
245 .get("retry-after")
246 .and_then(|v| v.to_str().ok())
247 .and_then(|v| v.parse::<f64>().ok());
248
249 last_error = Some(Self::extract_error(status, response).await);
250 tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
251 }
252
253 Err(last_error.unwrap_or_else(|| {
254 OpenAIError::InvalidArgument("retry loop exhausted without error".to_string())
255 }))
256 }
257
258 fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
260 let exponential = 0.5 * 2.0_f64.powi(attempt as i32);
261 let secs = match retry_after_secs {
262 Some(ra) => ra.max(exponential),
263 None => exponential,
264 };
265 Duration::from_secs_f64(secs.min(60.0))
266 }
267
268 pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
270 response: reqwest::Response,
271 ) -> Result<T, OpenAIError> {
272 let status = response.status();
273 if status.is_success() {
274 let body = response.text().await?;
275 let value: T = serde_json::from_str(&body)?;
276 Ok(value)
277 } else {
278 Err(Self::extract_error(status.as_u16(), response).await)
279 }
280 }
281
282 async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
284 let body = response.text().await.unwrap_or_default();
285 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
286 OpenAIError::ApiError {
287 status,
288 message: error_resp.error.message,
289 type_: error_resp.error.type_,
290 code: error_resp.error.code,
291 }
292 } else {
293 OpenAIError::ApiError {
294 status,
295 message: body,
296 type_: None,
297 code: None,
298 }
299 }
300 }
301}
302
303pub struct Beta<'a> {
305 client: &'a OpenAI,
306}
307
308impl<'a> Beta<'a> {
309 pub fn assistants(&self) -> Assistants<'_> {
311 Assistants::new(self.client)
312 }
313
314 pub fn threads(&self) -> Threads<'_> {
316 Threads::new(self.client)
317 }
318
319 pub fn runs(&self, thread_id: &str) -> Runs<'_> {
321 Runs::new(self.client, thread_id.to_string())
322 }
323
324 pub fn vector_stores(&self) -> VectorStores<'_> {
326 VectorStores::new(self.client)
327 }
328
329 pub fn realtime(&self) -> Realtime<'_> {
331 Realtime::new(self.client)
332 }
333}
334
// Unit tests for client construction and backoff, plus mockito-backed HTTP tests for
// the JSON helpers, auth headers and retry policy.
#[cfg(test)]
mod tests {
    use super::*;

    // `new` stores the key verbatim and uses the default base URL.
    #[test]
    fn test_new_client() {
        let client = OpenAI::new("sk-test-key");
        assert_eq!(client.config.api_key, "sk-test-key");
        assert_eq!(client.config.base_url, "https://api.openai.com/v1");
    }

    // Builder overrides on `ClientConfig` are preserved by `with_config`.
    #[test]
    fn test_with_config() {
        let config = ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30);
        let client = OpenAI::with_config(config);
        assert_eq!(client.config.base_url, "https://custom.api.com");
        assert_eq!(client.config.organization.as_deref(), Some("org-123"));
        assert_eq!(client.config.timeout_secs, 30);
    }

    // Pins the backoff schedule: 0.5 s doubling per attempt, Retry-After wins only when
    // larger, and everything is capped at 60 s.
    #[test]
    fn test_backoff_delay() {
        // Attempt 0 uses the 0.5 s base delay.
        let d = OpenAI::backoff_delay(0, None);
        assert_eq!(d, Duration::from_millis(500));

        let d = OpenAI::backoff_delay(1, None);
        assert_eq!(d, Duration::from_secs(1));

        let d = OpenAI::backoff_delay(2, None);
        assert_eq!(d, Duration::from_secs(2));

        // A Retry-After longer than the exponential delay is honoured.
        let d = OpenAI::backoff_delay(0, Some(5.0));
        assert_eq!(d, Duration::from_secs(5));

        // A Retry-After shorter than the exponential delay is ignored.
        let d = OpenAI::backoff_delay(3, Some(0.1));
        assert_eq!(d, Duration::from_secs(4));

        // 0.5 * 2^10 = 512 s is clamped to the 60 s cap.
        let d = OpenAI::backoff_delay(10, None);
        assert_eq!(d, Duration::from_secs(60));
    }

    // `get` decodes a 200 JSON body into the requested type.
    #[tokio::test]
    async fn test_get_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/gpt-4")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        // Deliberately partial: unknown JSON fields are ignored by serde.
        #[derive(serde::Deserialize)]
        struct Model {
            id: String,
            object: String,
        }

        let model: Model = client.get("/models/gpt-4").await.unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.object, "model");
        mock.assert_async().await;
    }

    // `post` sends a JSON body with bearer auth and decodes the response.
    #[tokio::test]
    async fn test_post_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", "Bearer sk-test")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = client
            .post(
                "/chat/completions",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "chatcmpl-123");
        mock.assert_async().await;
    }

    // `delete` issues a DELETE and decodes the response.
    #[tokio::test]
    async fn test_delete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/models/ft-abc")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"ft-abc","deleted":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct DeleteResp {
            id: String,
            deleted: bool,
        }

        let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
        assert_eq!(resp.id, "ft-abc");
        assert!(resp.deleted);
        mock.assert_async().await;
    }

    // A structured error body is mapped field-by-field into `OpenAIError::ApiError`.
    #[tokio::test]
    async fn test_api_error_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/nonexistent")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        // Never deserialized: the request fails before the body is decoded as `Model`.
        #[derive(Debug, serde::Deserialize)]
        struct Model {
            _id: String,
        }

        let err = client
            .get::<Model>("/models/nonexistent")
            .await
            .unwrap_err();
        match err {
            OpenAIError::ApiError {
                status,
                message,
                type_,
                code,
            } => {
                assert_eq!(status, 404);
                assert!(message.contains("does not exist"));
                assert_eq!(type_.as_deref(), Some("invalid_request_error"));
                assert_eq!(code.as_deref(), Some("model_not_found"));
            }
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // Bearer token, organization and project headers are all attached to the request.
    #[tokio::test]
    async fn test_auth_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("authorization", "Bearer sk-key")
            .match_header("OpenAI-Organization", "org-abc")
            .match_header("OpenAI-Project", "proj-xyz")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-key")
                .base_url(server.url())
                .organization("org-abc")
                .project("proj-xyz"),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // A 429 followed by a 200 succeeds after one retry.
    // This relies on mockito answering with the earlier-created 429 mock first and then
    // falling through to the 200 mock. NOTE(review): that ordering depends on mockito's
    // matching rules for the version in use; confirm if mockito is upgraded.
    // The test waits about 0.5 s: `backoff_delay` never goes below the exponential base,
    // even when Retry-After is 0.
    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let mut server = mockito::Server::new_async().await;

        let _mock_429 = server
            .mock("GET", "/test")
            .with_status(429)
            .with_header("retry-after", "0")
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .create_async()
            .await;

        let mock_200 = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock_200.assert_async().await;
    }

    // With max_retries(1) the client makes two attempts, and the final 500 is surfaced.
    // NOTE(review): `_mock` is never asserted, so `expect_at_least(2)` is not enforced.
    #[tokio::test]
    async fn test_retry_exhausted_returns_last_error() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", "/test")
            .with_status(500)
            .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
            .expect_at_least(2)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(1),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    // A non-retryable status (400) is returned after exactly one request.
    #[tokio::test]
    async fn test_no_retry_on_400() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(400)
            .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // max_retries(0) disables retries even for a retryable status.
    #[tokio::test]
    async fn test_zero_retries_no_retry() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
}
674}