1use std::time::Duration;
4
5use crate::azure::AzureConfig;
6use crate::config::ClientConfig;
7use crate::error::{ErrorResponse, OpenAIError};
8use crate::request_options::RequestOptions;
9#[cfg(feature = "audio")]
10use crate::resources::audio::Audio;
11#[cfg(feature = "batches")]
12use crate::resources::batches::Batches;
13#[cfg(feature = "beta")]
14use crate::resources::beta::assistants::Assistants;
15#[cfg(feature = "beta")]
16use crate::resources::beta::realtime::Realtime;
17#[cfg(feature = "beta")]
18use crate::resources::beta::runs::Runs;
19#[cfg(feature = "beta")]
20use crate::resources::beta::threads::Threads;
21#[cfg(feature = "beta")]
22use crate::resources::beta::vector_stores::VectorStores;
23#[cfg(feature = "chat")]
24use crate::resources::chat::Chat;
25#[cfg(feature = "embeddings")]
26use crate::resources::embeddings::Embeddings;
27#[cfg(feature = "files")]
28use crate::resources::files::Files;
29#[cfg(feature = "fine-tuning")]
30use crate::resources::fine_tuning::FineTuning;
31#[cfg(feature = "images")]
32use crate::resources::images::Images;
33#[cfg(feature = "models")]
34use crate::resources::models::Models;
35#[cfg(feature = "moderations")]
36use crate::resources::moderations::Moderations;
37#[cfg(feature = "responses")]
38use crate::resources::responses::Responses;
39#[cfg(feature = "uploads")]
40use crate::resources::uploads::Uploads;
41
/// HTTP status codes treated as transient and retried by the senders:
/// 429 (rate limited), 500, 502, 503.
const RETRYABLE_STATUS_CODES: [u16; 4] = [429, 500, 502, 503];
44
/// Client for the OpenAI API: an HTTP connection pool, a type-erased
/// configuration, and default per-request options.
///
/// Cloning is cheap — the config is reference-counted and `reqwest::Client`
/// shares its pool across clones.
#[derive(Debug, Clone)]
pub struct OpenAI {
    // Underlying HTTP client; tuned in `with_config` on native targets.
    pub(crate) http: reqwest::Client,
    // Auth / base-URL / retry configuration behind an Arc trait object.
    pub(crate) config: std::sync::Arc<dyn crate::config::Config>,
    // Default options merged into every request (see `with_options`).
    pub(crate) options: RequestOptions,
}
68
69impl OpenAI {
    /// Creates a client that authenticates with `api_key` and otherwise uses
    /// the default [`ClientConfig`].
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_config(ClientConfig::new(api_key))
    }
74
    /// Builds a client from any [`crate::config::Config`] implementation.
    ///
    /// On native targets this constructs a tuned `reqwest` client (TCP and
    /// HTTP/2 keepalive, gzip, bounded idle pool); on wasm32 the platform
    /// default client is used.
    ///
    /// # Panics
    /// Panics if the native HTTP client cannot be built.
    pub fn with_config<C: crate::config::Config + 'static>(config: C) -> Self {
        // Capture config-derived per-request defaults before `config` is
        // moved into the Arc below.
        let options = config.initial_options();

        #[cfg(not(target_arch = "wasm32"))]
        let http = {
            crate::ensure_tls_provider();

            reqwest::Client::builder()
                .timeout(Duration::from_secs(config.timeout_secs()))
                .tcp_nodelay(true)
                .tcp_keepalive(Some(Duration::from_secs(30)))
                .pool_idle_timeout(Some(Duration::from_secs(300)))
                .pool_max_idle_per_host(4)
                .http2_keep_alive_interval(Some(Duration::from_secs(20)))
                .http2_keep_alive_timeout(Duration::from_secs(10))
                .http2_keep_alive_while_idle(true)
                .http2_adaptive_window(true)
                .gzip(true)
                .build()
                .expect("failed to build HTTP client")
        };

        #[cfg(target_arch = "wasm32")]
        let http = reqwest::Client::new();
        Self {
            http,
            config: std::sync::Arc::new(config),
            options,
        }
    }
106
107 #[must_use]
122 pub fn with_options(&self, options: RequestOptions) -> Self {
123 Self {
124 http: self.http.clone(),
125 config: self.config.clone(),
126 options: self.options.merge(&options),
127 }
128 }
129
    /// Builds a client from environment variables (see
    /// [`ClientConfig::from_env`] for which variables are read).
    ///
    /// # Errors
    /// Propagates the error from [`ClientConfig::from_env`].
    pub fn from_env() -> Result<Self, OpenAIError> {
        Ok(Self::with_config(ClientConfig::from_env()?))
    }
134
    /// Builds a client targeting Azure OpenAI from an [`AzureConfig`].
    ///
    /// # Errors
    /// Propagates any validation/build error from [`AzureConfig::build`].
    pub fn azure(config: AzureConfig) -> Result<Self, OpenAIError> {
        config.build()
    }
152
    /// Access to the Batches API.
    #[cfg(feature = "batches")]
    pub fn batches(&self) -> Batches<'_> {
        Batches::new(self)
    }

    /// Access to the Uploads API.
    #[cfg(feature = "uploads")]
    pub fn uploads(&self) -> Uploads<'_> {
        Uploads::new(self)
    }

    /// Access to beta (preview) APIs: assistants, threads, runs,
    /// vector stores, realtime.
    #[cfg(feature = "beta")]
    pub fn beta(&self) -> Beta<'_> {
        Beta { client: self }
    }

    /// Access to the Audio API.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_> {
        Audio::new(self)
    }

    /// Access to the Chat Completions API.
    #[cfg(feature = "chat")]
    pub fn chat(&self) -> Chat<'_> {
        Chat::new(self)
    }

    /// Access to the Models API.
    #[cfg(feature = "models")]
    pub fn models(&self) -> Models<'_> {
        Models::new(self)
    }

    /// Access to the Fine-tuning API.
    #[cfg(feature = "fine-tuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_> {
        FineTuning::new(self)
    }

    /// Access to the Files API.
    #[cfg(feature = "files")]
    pub fn files(&self) -> Files<'_> {
        Files::new(self)
    }

    /// Access to the Images API.
    #[cfg(feature = "images")]
    pub fn images(&self) -> Images<'_> {
        Images::new(self)
    }

    /// Access to the Moderations API.
    #[cfg(feature = "moderations")]
    pub fn moderations(&self) -> Moderations<'_> {
        Moderations::new(self)
    }

    /// Access to the Responses API.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_> {
        Responses::new(self)
    }

    /// Access to the Embeddings API.
    #[cfg(feature = "embeddings")]
    pub fn embeddings(&self) -> Embeddings<'_> {
        Embeddings::new(self)
    }

    /// Access to the Conversations API (always compiled in).
    pub fn conversations(&self) -> crate::resources::conversations::Conversations<'_> {
        crate::resources::conversations::Conversations::new(self)
    }

    /// Access to the Videos API (always compiled in).
    pub fn videos(&self) -> crate::resources::videos::Videos<'_> {
        crate::resources::videos::Videos::new(self)
    }

    /// Opens a realtime WebSocket session using this client's configuration.
    #[cfg(feature = "websocket")]
    pub async fn ws_session(&self) -> Result<crate::websocket::WsSession, OpenAIError> {
        crate::websocket::WsSession::connect(self.config.as_ref()).await
    }
252
    /// Builds a `RequestBuilder` for `path` relative to the configured base
    /// URL, applying config-level request setup first and per-request
    /// options (extra headers, query params, timeout override) on top.
    pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
        let url = format!("{}{}", self.config.base_url(), path);
        let req = self.http.request(method, &url);
        let mut req = self.config.build_request(req);

        // Per-request headers are applied after the config's own setup so
        // they take precedence on conflicts.
        if let Some(ref headers) = self.options.headers {
            for (key, value) in headers.iter() {
                req = req.header(key.clone(), value.clone());
            }
        }
        // NOTE(review): query-param and timeout overrides are skipped on
        // wasm32 — presumably because reqwest's wasm builder lacks these
        // APIs; confirm whether `.query` is in fact unavailable there.
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(ref query) = self.options.query {
            req = req.query(query);
        }
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(timeout) = self.options.timeout {
            req = req.timeout(timeout);
        }

        req
    }
276
    /// GET `path` and deserialize the JSON response, with retry on
    /// transient failures.
    #[allow(dead_code)]
    pub(crate) async fn get<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
            .await
    }
286
    /// GET `path` with explicit query parameters and deserialize the
    /// JSON response.
    ///
    /// NOTE(review): unlike `get`/`post`, this performs a single attempt and
    /// does not go through `send_with_retry` — confirm that is intentional.
    #[allow(dead_code)]
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) async fn get_with_query<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        query: &[(String, String)],
    ) -> Result<T, OpenAIError> {
        let mut req = self.request(reqwest::Method::GET, path);
        if !query.is_empty() {
            req = req.query(query);
        }
        let response = req.send().await?;
        Self::handle_response(response).await
    }
302
    /// POST `body` as JSON to `path` and deserialize the response, with
    /// retry on transient failures. `options.extra_body` (if any) is merged
    /// into the serialized payload by `prepare_body`.
    pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::POST, path, Some(body))
            .await
    }
312
    /// POST `body` and return the raw `serde_json::Value` instead of a
    /// typed model; thin wrapper over [`Self::post`].
    pub(crate) async fn post_json<B: serde::Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<serde_json::Value, OpenAIError> {
        self.post(path, body).await
    }
325
    /// POST with no request body and deserialize the response, with retry
    /// on transient failures.
    pub(crate) async fn post_empty<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::POST, path, None::<&()>)
            .await
    }
334
    /// POST a multipart form (e.g. file uploads) and deserialize the
    /// JSON response.
    ///
    /// Sent as a single attempt — multipart requests are not retried here,
    /// presumably because the consumed `Form` cannot be rebuilt; confirm.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        form: reqwest::multipart::Form,
    ) -> Result<T, OpenAIError> {
        let response = self
            .request(reqwest::Method::POST, path)
            .multipart(form)
            .send()
            .await?;
        Self::handle_response(response).await
    }
349
350 pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
352 let response = self.request(reqwest::Method::GET, path).send().await?;
353
354 let status = response.status();
355 if status.is_success() {
356 Ok(response.bytes().await?)
357 } else {
358 Err(Self::extract_error(status.as_u16(), response).await)
359 }
360 }
361
362 pub(crate) async fn post_raw<B: serde::Serialize>(
364 &self,
365 path: &str,
366 body: &B,
367 ) -> Result<bytes::Bytes, OpenAIError> {
368 let mut req = self.request(reqwest::Method::POST, path);
369 if self.options.extra_body.is_some() {
370 req = req.json(&self.merge_body_json(body)?);
371 } else {
372 req = req.json(body);
373 }
374 let response = req.send().await?;
375
376 let status = response.status();
377 if status.is_success() {
378 Ok(response.bytes().await?)
379 } else {
380 Err(Self::extract_error(status.as_u16(), response).await)
381 }
382 }
383
    /// DELETE `path` and deserialize the JSON response, with retry on
    /// transient failures.
    #[allow(dead_code)]
    pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
            .await
    }
393
394 fn merge_body_json<B: serde::Serialize>(
396 &self,
397 body: &B,
398 ) -> Result<serde_json::Value, OpenAIError> {
399 let mut value = serde_json::to_value(body)?;
400 if let Some(ref extra) = self.options.extra_body
401 && let serde_json::Value::Object(map) = &mut value
402 && let serde_json::Value::Object(extra_map) = extra.clone()
403 {
404 for (k, v) in extra_map {
405 map.insert(k, v);
406 }
407 }
408 Ok(value)
409 }
410
411 fn prepare_body<B: serde::Serialize>(
413 &self,
414 body: Option<&B>,
415 ) -> Result<Option<serde_json::Value>, OpenAIError> {
416 match body {
417 Some(b) if self.options.extra_body.is_some() => Ok(Some(self.merge_body_json(b)?)),
418 Some(b) => Ok(Some(serde_json::to_value(b)?)),
419 None => Ok(None),
420 }
421 }
422
423 #[cfg(target_arch = "wasm32")]
425 async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
426 &self,
427 method: reqwest::Method,
428 path: &str,
429 body: Option<&B>,
430 ) -> Result<T, OpenAIError> {
431 let body_value = self.prepare_body(body)?;
432
433 for attempt in 0..=self.config.max_retries {
434 let mut req = self.request(method.clone(), path);
435 if let Some(ref val) = body_value {
436 req = req.json(val);
437 }
438
439 let response = match req.send().await {
440 Ok(resp) => resp,
441 Err(e) if attempt == self.config.max_retries => {
442 return Err(OpenAIError::RequestError(e));
443 }
444 Err(_) => {
445 crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
446 continue;
447 }
448 };
449
450 let status = response.status().as_u16();
451 if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == self.config.max_retries {
452 return Self::handle_response(response).await;
453 }
454
455 crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
456 }
457
458 Err(OpenAIError::InvalidArgument("retry exhausted".into()))
459 }
460
    /// Sends a request with retry on transient failures (native variant).
    ///
    /// Fast path: the first attempt runs inline; `retry_loop` is entered
    /// only after a network error or a retryable status (see
    /// `RETRYABLE_STATUS_CODES`) when `max_retries() > 0`. A numeric
    /// `retry-after` header (seconds) stretches the backoff delay when it
    /// exceeds the computed backoff.
    #[cfg(not(target_arch = "wasm32"))]
    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<T, OpenAIError> {
        // Serialize (and extra_body-merge) once; retries reuse the Value.
        let body_value = self.prepare_body(body)?;

        let mut req = self.request(method.clone(), path);
        if let Some(ref val) = body_value {
            req = req.json(val);
        }

        let response = match req.send().await {
            Ok(resp) => resp,
            // No retries configured: surface the transport error directly.
            Err(e) if self.config.max_retries() == 0 => return Err(OpenAIError::RequestError(e)),
            Err(e) => {
                return self.retry_loop(method, path, &body_value, e, 1).await;
            }
        };

        let status = response.status().as_u16();
        if !RETRYABLE_STATUS_CODES.contains(&status) {
            return Self::handle_response(response).await;
        }

        if self.config.max_retries() == 0 {
            return Self::handle_response(response).await;
        }

        // Retryable status: record the error, honor retry-after, then hand
        // the remaining attempts to `retry_loop`.
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<f64>().ok());
        let last_error = Self::extract_error(status, response).await;
        tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
        self.retry_loop(method, path, &body_value, last_error, 1)
            .await
    }
509
    /// Retry loop shared by the typed senders: runs attempts
    /// `start_attempt..=max_retries`, rebuilding the request from
    /// `body_value` each time.
    ///
    /// Returns the first non-retryable (or final-attempt) response; once
    /// attempts are exhausted, returns the last observed error.
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop<T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body_value: &Option<serde_json::Value>,
        initial_error: impl Into<OpenAIError>,
        start_attempt: u32,
    ) -> Result<T, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error: OpenAIError = initial_error.into();

        for attempt in start_attempt..=max_retries {
            let mut req = self.request(method.clone(), path);
            if let Some(val) = body_value {
                req = req.json(val);
            }

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    last_error = OpenAIError::RequestError(e);
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            // The final attempt always returns the response, retryable or not.
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Self::handle_response(response).await;
            }

            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
557
    /// Sends a prebuilt request with the same retry policy as the typed
    /// senders, returning the raw [`reqwest::Response`].
    ///
    /// Retrying requires `builder.try_clone()`; when the request body is
    /// not clonable (e.g. a stream) it is sent exactly once.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) async fn send_raw_with_retry(
        &self,
        builder: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, OpenAIError> {
        let response = match builder.try_clone() {
            Some(cloned) => match cloned.send().await {
                Ok(resp) => resp,
                Err(e) if self.config.max_retries() == 0 => {
                    return Err(OpenAIError::RequestError(e));
                }
                Err(e) => {
                    return self
                        .retry_loop_raw(builder, OpenAIError::RequestError(e), 1)
                        .await;
                }
            },
            None => {
                // Non-clonable request: single attempt, no retries possible.
                return Ok(builder.send().await?);
            }
        };

        let status = response.status().as_u16();
        if !RETRYABLE_STATUS_CODES.contains(&status) {
            return Ok(response);
        }
        if self.config.max_retries() == 0 {
            return Ok(response);
        }

        // Retryable status with retries enabled: honor retry-after, then
        // continue in the raw retry loop.
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<f64>().ok());
        let last_error = Self::extract_error(status, response).await;
        tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
        self.retry_loop_raw(builder, last_error, 1).await
    }
603
    /// Raw counterpart of `retry_loop`: re-clones `builder` for each attempt
    /// and returns the response as-is (no deserialization).
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop_raw(
        &self,
        builder: reqwest::RequestBuilder,
        initial_error: OpenAIError,
        start_attempt: u32,
    ) -> Result<reqwest::Response, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error = initial_error;

        for attempt in start_attempt..=max_retries {
            // If the builder stops being clonable, give up with the last error.
            let req = match builder.try_clone() {
                Some(cloned) => cloned,
                None => return Err(last_error),
            };

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    last_error = OpenAIError::RequestError(e);
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            // The final attempt always returns the response, retryable or not.
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Ok(response);
            }

            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
649
650 pub(crate) async fn check_stream_response(
652 response: reqwest::Response,
653 ) -> Result<reqwest::Response, OpenAIError> {
654 if response.status().is_success() {
655 Ok(response)
656 } else {
657 Err(Self::extract_error(response.status().as_u16(), response).await)
658 }
659 }
660
661 #[cfg(not(target_arch = "wasm32"))]
663 fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
664 let base = crate::runtime::backoff_ms(attempt);
665 match retry_after_secs {
666 Some(ra) => Duration::from_secs_f64(ra.max(base.as_secs_f64())),
667 None => base,
668 }
669 }
670
    /// Deserializes a successful response body into `T`; converts any
    /// non-success response into an [`OpenAIError`] via `extract_error`.
    ///
    /// Deserialization failures are logged with a truncated body preview to
    /// help diagnose schema mismatches.
    pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
        response: reqwest::Response,
    ) -> Result<T, OpenAIError> {
        let status = response.status();
        if status.is_success() {
            let body = response.bytes().await?;
            let result = Self::deserialize_body::<T>(&body);
            match result {
                Ok(value) => Ok(value),
                Err(e) => {
                    // Preview capped at 500 bytes; from_utf8_lossy tolerates
                    // a cut landing mid-UTF-8-sequence.
                    tracing::error!(
                        error = %e,
                        body_len = body.len(),
                        body_preview = %String::from_utf8_lossy(&body[..body.len().min(500)]),
                        "failed to deserialize API response"
                    );
                    Err(e)
                }
            }
        } else {
            Err(Self::extract_error(status.as_u16(), response).await)
        }
    }
700
    /// JSON deserialization via `simd-json`, which mutates its input — the
    /// body is copied into a scratch buffer first.
    #[cfg(feature = "simd")]
    fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
        let mut buf = body.to_vec();
        simd_json::from_slice::<T>(&mut buf)
            .map_err(|e| OpenAIError::StreamError(format!("simd-json: {e}")))
    }
708
    /// JSON deserialization via `serde_json` (default when the `simd`
    /// feature is disabled).
    #[cfg(not(feature = "simd"))]
    fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
        serde_json::from_slice::<T>(body).map_err(OpenAIError::from)
    }
714
715 pub(crate) fn extract_request_id(response: &reqwest::Response) -> Option<String> {
717 response
718 .headers()
719 .get("x-request-id")
720 .and_then(|v| v.to_str().ok())
721 .map(String::from)
722 }
723
724 pub(crate) async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
726 let request_id = Self::extract_request_id(&response);
727 let body = response.text().await.unwrap_or_default();
728 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
729 OpenAIError::ApiError {
730 status,
731 message: error_resp.error.message,
732 type_: error_resp.error.type_,
733 code: error_resp.error.code,
734 request_id,
735 }
736 } else {
737 OpenAIError::ApiError {
738 status,
739 message: body,
740 type_: None,
741 code: None,
742 request_id,
743 }
744 }
745 }
746}
747
/// Accessor for beta (preview) API surfaces; obtained via [`OpenAI::beta`].
#[cfg(feature = "beta")]
pub struct Beta<'a> {
    // Borrowed parent client; all sub-resources delegate to it.
    client: &'a OpenAI,
}
753
#[cfg(feature = "beta")]
impl<'a> Beta<'a> {
    /// Access to the Assistants API.
    pub fn assistants(&self) -> Assistants<'_> {
        Assistants::new(self.client)
    }

    /// Access to the Threads API.
    pub fn threads(&self) -> Threads<'_> {
        Threads::new(self.client)
    }

    /// Access to the Runs API scoped to `thread_id`.
    pub fn runs(&self, thread_id: &str) -> Runs<'_> {
        Runs::new(self.client, thread_id.to_string())
    }

    /// Access to the Vector Stores API.
    pub fn vector_stores(&self) -> VectorStores<'_> {
        VectorStores::new(self.client)
    }

    /// Access to the Realtime API.
    pub fn realtime(&self) -> Realtime<'_> {
        Realtime::new(self.client)
    }
}
781
#[cfg(test)]
mod tests {
    use super::*;

    // Construction with just an API key uses the default base URL.
    #[test]
    fn test_new_client() {
        let client = OpenAI::new("sk-test-key");
        assert_eq!(client.config.api_key(), "sk-test-key");
        assert_eq!(client.config.base_url(), "https://api.openai.com/v1");
    }

    // Builder-style config values are visible through the client.
    #[test]
    fn test_with_config() {
        let config = ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30);
        let client = OpenAI::with_config(config);
        assert_eq!(client.config.base_url(), "https://custom.api.com");
        assert_eq!(client.config.organization(), Some("org-123"));
        assert_eq!(client.config.timeout_secs(), 30);
    }

    // Backoff doubles per attempt from a 500ms base, caps at 60s, and a
    // larger retry-after wins over the computed backoff.
    #[test]
    fn test_backoff_delay() {
        let d = OpenAI::backoff_delay(0, None);
        assert_eq!(d, Duration::from_millis(500));

        let d = OpenAI::backoff_delay(1, None);
        assert_eq!(d, Duration::from_secs(1));

        let d = OpenAI::backoff_delay(2, None);
        assert_eq!(d, Duration::from_secs(2));

        let d = OpenAI::backoff_delay(0, Some(5.0));
        assert_eq!(d, Duration::from_secs(5));

        let d = OpenAI::backoff_delay(3, Some(0.1));
        assert_eq!(d, Duration::from_secs(4));

        let d = OpenAI::backoff_delay(10, None);
        assert_eq!(d, Duration::from_secs(60));
    }

    // Happy-path GET deserializes a typed model.
    #[tokio::test]
    async fn test_get_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/gpt-4")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct Model {
            id: String,
            object: String,
        }

        let model: Model = client.get("/models/gpt-4").await.unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.object, "model");
        mock.assert_async().await;
    }

    // POST sends auth + JSON content-type headers and deserializes the reply.
    #[tokio::test]
    async fn test_post_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", "Bearer sk-test")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = client
            .post(
                "/chat/completions",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "chatcmpl-123");
        mock.assert_async().await;
    }

    // DELETE round-trips a deletion acknowledgement body.
    #[tokio::test]
    async fn test_delete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/models/ft-abc")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"ft-abc","deleted":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct DeleteResp {
            id: String,
            deleted: bool,
        }

        let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
        assert_eq!(resp.id, "ft-abc");
        assert!(resp.deleted);
        mock.assert_async().await;
    }

    // A structured error payload is surfaced as ApiError with all fields.
    #[tokio::test]
    async fn test_api_error_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/nonexistent")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(Debug, serde::Deserialize)]
        struct Model {
            _id: String,
        }

        let err = client
            .get::<Model>("/models/nonexistent")
            .await
            .unwrap_err();
        match err {
            OpenAIError::ApiError {
                status,
                message,
                type_,
                code,
                ..
            } => {
                assert_eq!(status, 404);
                assert!(message.contains("does not exist"));
                assert_eq!(type_.as_deref(), Some("invalid_request_error"));
                assert_eq!(code.as_deref(), Some("model_not_found"));
            }
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // Organization and project headers from config are attached to requests.
    #[tokio::test]
    async fn test_auth_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("authorization", "Bearer sk-key")
            .match_header("OpenAI-Organization", "org-abc")
            .match_header("OpenAI-Project", "proj-xyz")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-key")
                .base_url(server.url())
                .organization("org-abc")
                .project("proj-xyz"),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // Retry path: a 429 followed by a 200 should succeed transparently.
    // NOTE(review): both mocks match "GET /test" and mockito resolves
    // overlapping mocks by its own ordering rules — confirm the 429 is
    // actually served first, otherwise this passes without exercising the
    // retry at all.
    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let mut server = mockito::Server::new_async().await;

        let _mock_429 = server
            .mock("GET", "/test")
            .with_status(429)
            .with_header("retry-after", "0")
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .create_async()
            .await;

        let mock_200 = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock_200.assert_async().await;
    }

    // With max_retries(1), persistent 500s exhaust retries and return the
    // last ApiError.
    #[tokio::test]
    async fn test_retry_exhausted_returns_last_error() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", "/test")
            .with_status(500)
            .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
            .expect_at_least(2)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(1),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    // 400 is not in RETRYABLE_STATUS_CODES: exactly one request is made.
    #[tokio::test]
    async fn test_no_retry_on_400() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(400)
            .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // max_retries(0) disables retries even for a retryable 429.
    #[tokio::test]
    async fn test_zero_retries_no_retry() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // Per-request headers from with_options reach the wire.
    #[tokio::test]
    async fn test_with_options_sends_extra_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-Custom", "test-value")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(RequestOptions::new().header("X-Custom", "test-value"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // Per-request query params from with_options reach the wire.
    #[tokio::test]
    async fn test_with_options_sends_query_params() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "foo".into(),
                "bar".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(RequestOptions::new().query_param("foo", "bar"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // extra_body keys are merged into the serialized request payload.
    #[tokio::test]
    async fn test_extra_body_merge() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/test")
            .match_body(mockito::Matcher::Json(serde_json::json!({
                "model": "gpt-4",
                "extra_field": "injected"
            })))
            .with_status(200)
            .with_body(r#"{"id":"ok"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(
            RequestOptions::new().extra_body(serde_json::json!({"extra_field": "injected"})),
        );

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = custom
            .post(
                "/test",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "ok");
        mock.assert_async().await;
    }

    // A per-request 100ms timeout fires against a server that stalls 5s.
    #[tokio::test]
    async fn test_timeout_override() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .with_chunked_body(|_w| -> std::io::Result<()> {
                std::thread::sleep(std::time::Duration::from_secs(5));
                Ok(())
            })
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let custom = client.with_options(RequestOptions::new().timeout(Duration::from_millis(100)));

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = custom.get::<Resp>("/test").await.unwrap_err();
        assert!(
            matches!(err, OpenAIError::RequestError(_)),
            "expected timeout error, got: {err:?}"
        );
    }

    // The later with_options call wins when the same header is set twice.
    #[tokio::test]
    async fn test_options_merge_precedence() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-A", "2")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let base = client.with_options(RequestOptions::new().header("X-A", "1"));
        let custom = base.with_options(RequestOptions::new().header("X-A", "2"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // Config-level default headers and default query params are applied.
    #[tokio::test]
    async fn test_default_headers_and_query_on_config() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-Default", "from-config")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "cfg_param".into(),
                "cfg_val".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let mut default_headers = reqwest::header::HeaderMap::new();
        default_headers.insert("X-Default", "from-config".parse().unwrap());

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .default_headers(default_headers)
                .default_query(vec![("cfg_param".into(), "cfg_val".into())]),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // Chained with_options calls accumulate headers from both calls.
    #[tokio::test]
    async fn test_chained_with_options_merges() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-A", "from-a")
            .match_header("X-B", "from-b")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let chained = client
            .with_options(RequestOptions::new().header("X-A", "from-a"))
            .with_options(RequestOptions::new().header("X-B", "from-b"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = chained.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
}