1use std::time::Duration;
4
5use crate::config::ClientConfig;
6use crate::error::{ErrorResponse, OpenAIError};
7use crate::resources::audio::Audio;
8use crate::resources::batches::Batches;
9use crate::resources::beta::assistants::Assistants;
10use crate::resources::beta::runs::Runs;
11use crate::resources::beta::threads::Threads;
12use crate::resources::beta::vector_stores::VectorStores;
13use crate::resources::chat::Chat;
14use crate::resources::embeddings::Embeddings;
15use crate::resources::files::Files;
16use crate::resources::fine_tuning::FineTuning;
17use crate::resources::images::Images;
18use crate::resources::models::Models;
19use crate::resources::moderations::Moderations;
20use crate::resources::responses::Responses;
21use crate::resources::uploads::Uploads;
22
/// HTTP status codes that `send_with_retry` treats as transient: a response
/// with one of these codes is retried while retry attempts remain.
const RETRYABLE_STATUS_CODES: [u16; 4] = [429, 500, 502, 503];
25
26#[derive(Debug, Clone)]
28pub struct OpenAI {
29 pub(crate) http: reqwest::Client,
30 pub(crate) config: ClientConfig,
31}
32
33impl OpenAI {
34 pub fn new(api_key: impl Into<String>) -> Self {
36 Self::with_config(ClientConfig::new(api_key))
37 }
38
39 pub fn with_config(config: ClientConfig) -> Self {
41 let http = reqwest::Client::builder()
42 .timeout(Duration::from_secs(config.timeout_secs))
43 .build()
44 .expect("failed to build HTTP client");
45 Self { http, config }
46 }
47
48 pub fn from_env() -> Result<Self, OpenAIError> {
50 Ok(Self::with_config(ClientConfig::from_env()?))
51 }
52
53 pub fn batches(&self) -> Batches<'_> {
55 Batches::new(self)
56 }
57
58 pub fn uploads(&self) -> Uploads<'_> {
60 Uploads::new(self)
61 }
62
63 pub fn beta(&self) -> Beta<'_> {
65 Beta { client: self }
66 }
67
68 pub fn audio(&self) -> Audio<'_> {
70 Audio::new(self)
71 }
72
73 pub fn chat(&self) -> Chat<'_> {
75 Chat::new(self)
76 }
77
78 pub fn models(&self) -> Models<'_> {
80 Models::new(self)
81 }
82
83 pub fn fine_tuning(&self) -> FineTuning<'_> {
85 FineTuning::new(self)
86 }
87
88 pub fn files(&self) -> Files<'_> {
90 Files::new(self)
91 }
92
93 pub fn images(&self) -> Images<'_> {
95 Images::new(self)
96 }
97
98 pub fn moderations(&self) -> Moderations<'_> {
100 Moderations::new(self)
101 }
102
103 pub fn responses(&self) -> Responses<'_> {
105 Responses::new(self)
106 }
107
108 pub fn embeddings(&self) -> Embeddings<'_> {
110 Embeddings::new(self)
111 }
112
113 pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
115 let url = format!("{}{}", self.config.base_url, path);
116 let mut req = self
117 .http
118 .request(method, &url)
119 .bearer_auth(&self.config.api_key);
120
121 if let Some(ref org) = self.config.organization {
122 req = req.header("OpenAI-Organization", org);
123 }
124 if let Some(ref project) = self.config.project {
125 req = req.header("OpenAI-Project", project);
126 }
127
128 req
129 }
130
131 #[allow(dead_code)]
133 pub(crate) async fn get<T: serde::de::DeserializeOwned>(
134 &self,
135 path: &str,
136 ) -> Result<T, OpenAIError> {
137 self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
138 .await
139 }
140
141 pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
143 &self,
144 path: &str,
145 body: &B,
146 ) -> Result<T, OpenAIError> {
147 self.send_with_retry(reqwest::Method::POST, path, Some(body))
148 .await
149 }
150
151 pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
153 &self,
154 path: &str,
155 form: reqwest::multipart::Form,
156 ) -> Result<T, OpenAIError> {
157 let response = self
158 .request(reqwest::Method::POST, path)
159 .multipart(form)
160 .send()
161 .await?;
162 Self::handle_response(response).await
163 }
164
165 pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
167 let response = self.request(reqwest::Method::GET, path).send().await?;
168
169 let status = response.status();
170 if status.is_success() {
171 Ok(response.bytes().await?)
172 } else {
173 Err(Self::extract_error(status.as_u16(), response).await)
174 }
175 }
176
177 pub(crate) async fn post_raw<B: serde::Serialize>(
179 &self,
180 path: &str,
181 body: &B,
182 ) -> Result<bytes::Bytes, OpenAIError> {
183 let response = self
184 .request(reqwest::Method::POST, path)
185 .json(body)
186 .send()
187 .await?;
188
189 let status = response.status();
190 if status.is_success() {
191 Ok(response.bytes().await?)
192 } else {
193 Err(Self::extract_error(status.as_u16(), response).await)
194 }
195 }
196
197 #[allow(dead_code)]
199 pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
200 &self,
201 path: &str,
202 ) -> Result<T, OpenAIError> {
203 self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
204 .await
205 }
206
207 async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
209 &self,
210 method: reqwest::Method,
211 path: &str,
212 body: Option<&B>,
213 ) -> Result<T, OpenAIError> {
214 let max_retries = self.config.max_retries;
215 let mut last_error: Option<OpenAIError> = None;
216
217 for attempt in 0..=max_retries {
218 let mut req = self.request(method.clone(), path);
219 if let Some(b) = body {
220 req = req.json(b);
221 }
222
223 let response = match req.send().await {
224 Ok(resp) => resp,
225 Err(e) => {
226 last_error = Some(OpenAIError::RequestError(e));
227 if attempt < max_retries {
228 tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
229 continue;
230 }
231 break;
232 }
233 };
234
235 let status = response.status().as_u16();
236
237 if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
238 return Self::handle_response(response).await;
239 }
240
241 let retry_after = response
243 .headers()
244 .get("retry-after")
245 .and_then(|v| v.to_str().ok())
246 .and_then(|v| v.parse::<f64>().ok());
247
248 last_error = Some(Self::extract_error(status, response).await);
249 tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
250 }
251
252 Err(last_error.unwrap_or_else(|| {
253 OpenAIError::InvalidArgument("retry loop exhausted without error".to_string())
254 }))
255 }
256
/// Computes how long to wait before the retry that follows `attempt`
/// (0-based).
///
/// The base delay is `0.5 * 2^attempt` seconds. When `retry_after_secs` is
/// given, the larger of the two values is used. The result never exceeds
/// 60 seconds.
///
/// A NaN or negative `retry_after_secs` is ignored: `f64::max` then yields
/// the exponential delay, which is never negative.
fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
    const BASE_SECS: f64 = 0.5;
    const MAX_SECS: f64 = 60.0;

    // Clamp before the cast: a plain `attempt as i32` wraps to a negative
    // exponent for attempts above `i32::MAX`, which would shrink the delay
    // towards zero instead of saturating at the cap.
    let exponent = attempt.min(i32::MAX as u32) as i32;
    let exponential = BASE_SECS * 2.0_f64.powi(exponent);

    let secs = match retry_after_secs {
        Some(ra) => ra.max(exponential),
        None => exponential,
    };

    // `secs` is non-negative and not NaN here, so after capping it is always
    // a valid argument for `from_secs_f64`.
    Duration::from_secs_f64(secs.min(MAX_SECS))
}
266
267 pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
269 response: reqwest::Response,
270 ) -> Result<T, OpenAIError> {
271 let status = response.status();
272 if status.is_success() {
273 let body = response.text().await?;
274 let value: T = serde_json::from_str(&body)?;
275 Ok(value)
276 } else {
277 Err(Self::extract_error(status.as_u16(), response).await)
278 }
279 }
280
281 async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
283 let body = response.text().await.unwrap_or_default();
284 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
285 OpenAIError::ApiError {
286 status,
287 message: error_resp.error.message,
288 type_: error_resp.error.type_,
289 code: error_resp.error.code,
290 }
291 } else {
292 OpenAIError::ApiError {
293 status,
294 message: body,
295 type_: None,
296 code: None,
297 }
298 }
299 }
300}
301
302pub struct Beta<'a> {
304 client: &'a OpenAI,
305}
306
307impl<'a> Beta<'a> {
308 pub fn assistants(&self) -> Assistants<'_> {
310 Assistants::new(self.client)
311 }
312
313 pub fn threads(&self) -> Threads<'_> {
315 Threads::new(self.client)
316 }
317
318 pub fn runs(&self, thread_id: &str) -> Runs<'_> {
320 Runs::new(self.client, thread_id.to_string())
321 }
322
323 pub fn vector_stores(&self) -> VectorStores<'_> {
325 VectorStores::new(self.client)
326 }
327}
328
329#[cfg(test)]
330mod tests {
331 use super::*;
332
/// `OpenAI::new` stores the given key and leaves the base URL at
/// "https://api.openai.com/v1".
#[test]
fn test_new_client() {
    let client = OpenAI::new("sk-test-key");
    let cfg = &client.config;

    assert_eq!(cfg.api_key, "sk-test-key");
    assert_eq!(cfg.base_url, "https://api.openai.com/v1");
}
339
/// Values set through the `ClientConfig` builder are carried into the
/// client unchanged.
#[test]
fn test_with_config() {
    let client = OpenAI::with_config(
        ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30),
    );
    let cfg = &client.config;

    assert_eq!(cfg.base_url, "https://custom.api.com");
    assert_eq!(cfg.organization.as_deref(), Some("org-123"));
    assert_eq!(cfg.timeout_secs, 30);
}
351
/// Checks the backoff schedule: exponential growth from 500 ms, a larger
/// retry-after value winning, a smaller one being ignored, and the 60 s cap.
#[test]
fn test_backoff_delay() {
    // (attempt, retry-after seconds, expected delay)
    let cases: [(u32, Option<f64>, Duration); 6] = [
        (0, None, Duration::from_millis(500)),
        (1, None, Duration::from_secs(1)),
        (2, None, Duration::from_secs(2)),
        (0, Some(5.0), Duration::from_secs(5)),
        (3, Some(0.1), Duration::from_secs(4)),
        (10, None, Duration::from_secs(60)),
    ];

    for &(attempt, retry_after, expected) in &cases {
        let delay = OpenAI::backoff_delay(attempt, retry_after);
        assert_eq!(delay, expected);
    }
}
378
379 #[tokio::test]
380 async fn test_get_success() {
381 let mut server = mockito::Server::new_async().await;
382 let mock = server
383 .mock("GET", "/models/gpt-4")
384 .with_status(200)
385 .with_header("content-type", "application/json")
386 .with_body(
387 r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
388 )
389 .create_async()
390 .await;
391
392 let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
393
394 #[derive(serde::Deserialize)]
395 struct Model {
396 id: String,
397 object: String,
398 }
399
400 let model: Model = client.get("/models/gpt-4").await.unwrap();
401 assert_eq!(model.id, "gpt-4");
402 assert_eq!(model.object, "model");
403 mock.assert_async().await;
404 }
405
406 #[tokio::test]
407 async fn test_post_success() {
408 let mut server = mockito::Server::new_async().await;
409 let mock = server
410 .mock("POST", "/chat/completions")
411 .match_header("authorization", "Bearer sk-test")
412 .match_header("content-type", "application/json")
413 .with_status(200)
414 .with_header("content-type", "application/json")
415 .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
416 .create_async()
417 .await;
418
419 let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
420
421 #[derive(serde::Serialize)]
422 struct Req {
423 model: String,
424 }
425 #[derive(serde::Deserialize)]
426 struct Resp {
427 id: String,
428 }
429
430 let resp: Resp = client
431 .post(
432 "/chat/completions",
433 &Req {
434 model: "gpt-4".into(),
435 },
436 )
437 .await
438 .unwrap();
439 assert_eq!(resp.id, "chatcmpl-123");
440 mock.assert_async().await;
441 }
442
443 #[tokio::test]
444 async fn test_delete_success() {
445 let mut server = mockito::Server::new_async().await;
446 let mock = server
447 .mock("DELETE", "/models/ft-abc")
448 .with_status(200)
449 .with_header("content-type", "application/json")
450 .with_body(r#"{"id":"ft-abc","deleted":true}"#)
451 .create_async()
452 .await;
453
454 let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
455
456 #[derive(serde::Deserialize)]
457 struct DeleteResp {
458 id: String,
459 deleted: bool,
460 }
461
462 let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
463 assert_eq!(resp.id, "ft-abc");
464 assert!(resp.deleted);
465 mock.assert_async().await;
466 }
467
468 #[tokio::test]
469 async fn test_api_error_response() {
470 let mut server = mockito::Server::new_async().await;
471 let mock = server
472 .mock("GET", "/models/nonexistent")
473 .with_status(404)
474 .with_header("content-type", "application/json")
475 .with_body(
476 r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
477 )
478 .create_async()
479 .await;
480
481 let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
482
483 #[derive(Debug, serde::Deserialize)]
484 struct Model {
485 _id: String,
486 }
487
488 let err = client
489 .get::<Model>("/models/nonexistent")
490 .await
491 .unwrap_err();
492 match err {
493 OpenAIError::ApiError {
494 status,
495 message,
496 type_,
497 code,
498 } => {
499 assert_eq!(status, 404);
500 assert!(message.contains("does not exist"));
501 assert_eq!(type_.as_deref(), Some("invalid_request_error"));
502 assert_eq!(code.as_deref(), Some("model_not_found"));
503 }
504 other => panic!("expected ApiError, got: {other:?}"),
505 }
506 mock.assert_async().await;
507 }
508
509 #[tokio::test]
510 async fn test_auth_headers() {
511 let mut server = mockito::Server::new_async().await;
512 let mock = server
513 .mock("GET", "/test")
514 .match_header("authorization", "Bearer sk-key")
515 .match_header("OpenAI-Organization", "org-abc")
516 .match_header("OpenAI-Project", "proj-xyz")
517 .with_status(200)
518 .with_body(r#"{"ok":true}"#)
519 .create_async()
520 .await;
521
522 let client = OpenAI::with_config(
523 ClientConfig::new("sk-key")
524 .base_url(server.url())
525 .organization("org-abc")
526 .project("proj-xyz"),
527 );
528
529 #[derive(serde::Deserialize)]
530 struct Resp {
531 ok: bool,
532 }
533
534 let resp: Resp = client.get("/test").await.unwrap();
535 assert!(resp.ok);
536 mock.assert_async().await;
537 }
538
539 #[tokio::test]
540 async fn test_retry_on_429_then_success() {
541 let mut server = mockito::Server::new_async().await;
542
543 let _mock_429 = server
545 .mock("GET", "/test")
546 .with_status(429)
547 .with_header("retry-after", "0")
548 .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
549 .create_async()
550 .await;
551
552 let mock_200 = server
553 .mock("GET", "/test")
554 .with_status(200)
555 .with_body(r#"{"ok":true}"#)
556 .create_async()
557 .await;
558
559 let client = OpenAI::with_config(
560 ClientConfig::new("sk-test")
561 .base_url(server.url())
562 .max_retries(2),
563 );
564
565 #[derive(serde::Deserialize)]
566 struct Resp {
567 ok: bool,
568 }
569
570 let resp: Resp = client.get("/test").await.unwrap();
571 assert!(resp.ok);
572 mock_200.assert_async().await;
573 }
574
575 #[tokio::test]
576 async fn test_retry_exhausted_returns_last_error() {
577 let mut server = mockito::Server::new_async().await;
578
579 let _mock = server
581 .mock("GET", "/test")
582 .with_status(500)
583 .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
584 .expect_at_least(2)
585 .create_async()
586 .await;
587
588 let client = OpenAI::with_config(
589 ClientConfig::new("sk-test")
590 .base_url(server.url())
591 .max_retries(1),
592 );
593
594 #[derive(Debug, serde::Deserialize)]
595 struct Resp {
596 _ok: bool,
597 }
598
599 let err = client.get::<Resp>("/test").await.unwrap_err();
600 match err {
601 OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
602 other => panic!("expected ApiError, got: {other:?}"),
603 }
604 }
605
606 #[tokio::test]
607 async fn test_no_retry_on_400() {
608 let mut server = mockito::Server::new_async().await;
609
610 let mock = server
612 .mock("GET", "/test")
613 .with_status(400)
614 .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
615 .expect(1)
616 .create_async()
617 .await;
618
619 let client = OpenAI::with_config(
620 ClientConfig::new("sk-test")
621 .base_url(server.url())
622 .max_retries(2),
623 );
624
625 #[derive(Debug, serde::Deserialize)]
626 struct Resp {
627 _ok: bool,
628 }
629
630 let err = client.get::<Resp>("/test").await.unwrap_err();
631 match err {
632 OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
633 other => panic!("expected ApiError, got: {other:?}"),
634 }
635 mock.assert_async().await;
636 }
637
638 #[tokio::test]
639 async fn test_zero_retries_no_retry() {
640 let mut server = mockito::Server::new_async().await;
641
642 let mock = server
643 .mock("GET", "/test")
644 .with_status(429)
645 .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
646 .expect(1)
647 .create_async()
648 .await;
649
650 let client = OpenAI::with_config(
651 ClientConfig::new("sk-test")
652 .base_url(server.url())
653 .max_retries(0),
654 );
655
656 #[derive(Debug, serde::Deserialize)]
657 struct Resp {
658 _ok: bool,
659 }
660
661 let err = client.get::<Resp>("/test").await.unwrap_err();
662 match err {
663 OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
664 other => panic!("expected ApiError, got: {other:?}"),
665 }
666 mock.assert_async().await;
667 }
668}