1use std::time::Duration;
4
5use crate::azure::AzureConfig;
6use crate::config::ClientConfig;
7use crate::error::{ErrorResponse, OpenAIError};
8use crate::request_options::RequestOptions;
9#[cfg(feature = "audio")]
10use crate::resources::audio::Audio;
11#[cfg(feature = "batches")]
12use crate::resources::batches::Batches;
13#[cfg(feature = "beta")]
14use crate::resources::beta::assistants::Assistants;
15#[cfg(feature = "beta")]
16use crate::resources::beta::realtime::Realtime;
17#[cfg(feature = "beta")]
18use crate::resources::beta::runs::Runs;
19#[cfg(feature = "beta")]
20use crate::resources::beta::threads::Threads;
21#[cfg(feature = "beta")]
22use crate::resources::beta::vector_stores::VectorStores;
23#[cfg(feature = "chat")]
24use crate::resources::chat::Chat;
25#[cfg(feature = "embeddings")]
26use crate::resources::embeddings::Embeddings;
27#[cfg(feature = "files")]
28use crate::resources::files::Files;
29#[cfg(feature = "fine-tuning")]
30use crate::resources::fine_tuning::FineTuning;
31#[cfg(feature = "images")]
32use crate::resources::images::Images;
33#[cfg(feature = "models")]
34use crate::resources::models::Models;
35#[cfg(feature = "moderations")]
36use crate::resources::moderations::Moderations;
37#[cfg(feature = "responses")]
38use crate::resources::responses::Responses;
39#[cfg(feature = "uploads")]
40use crate::resources::uploads::Uploads;
41
/// HTTP status codes that are retried with exponential backoff: request
/// timeout (408), rate limiting (429), and transient server/gateway failures
/// (500, 502, 503, 504). 408 and 504 are included to match the retry policy
/// of the official OpenAI SDKs; other 4xx codes are never retried.
const RETRYABLE_STATUS_CODES: [u16; 6] = [408, 429, 500, 502, 503, 504];
44
/// Asynchronous client for the OpenAI (and Azure OpenAI) HTTP API.
///
/// Cheap to clone: the `reqwest::Client` shares its connection pool
/// internally and the configuration is behind an `Arc`.
#[derive(Debug, Clone)]
pub struct OpenAI {
    // Shared HTTP client / connection pool used by every request.
    pub(crate) http: reqwest::Client,
    // Provider configuration (credentials, base URL, retry count) as a trait
    // object so standard and Azure configs share one client type.
    pub(crate) config: std::sync::Arc<dyn crate::config::Config>,
    // Per-request overrides (extra headers/query/body, timeout) applied by
    // `request()`; extended via `with_options`.
    pub(crate) options: RequestOptions,
}
68
69impl OpenAI {
70 pub fn new(api_key: impl Into<String>) -> Self {
72 Self::with_config(ClientConfig::new(api_key))
73 }
74
    /// Builds a client from any [`crate::config::Config`] implementation
    /// (standard, Azure, or custom).
    pub fn with_config<C: crate::config::Config + 'static>(config: C) -> Self {
        // Per-request options seeded by the config (see `Config::initial_options`).
        let options = config.initial_options();

        #[cfg(not(target_arch = "wasm32"))]
        let http = {
            crate::ensure_tls_provider();

            // Tuned connection pool: TCP/HTTP2 keep-alives keep long-lived
            // connections warm between API calls; gzip for response bodies.
            reqwest::Client::builder()
                .timeout(Duration::from_secs(config.timeout_secs()))
                .tcp_nodelay(true)
                .tcp_keepalive(Some(Duration::from_secs(30)))
                .pool_idle_timeout(Some(Duration::from_secs(300)))
                .pool_max_idle_per_host(4)
                .http2_keep_alive_interval(Some(Duration::from_secs(20)))
                .http2_keep_alive_timeout(Duration::from_secs(10))
                .http2_keep_alive_while_idle(true)
                .http2_adaptive_window(true)
                .gzip(true)
                .build()
                // A builder failure (e.g. TLS backend setup) is unrecoverable
                // at construction time, so panic rather than return an error.
                .expect("failed to build HTTP client")
        };

        // The socket/pool knobs above are unavailable on wasm32; use defaults.
        #[cfg(target_arch = "wasm32")]
        let http = reqwest::Client::new();
        Self {
            http,
            config: std::sync::Arc::new(config),
            options,
        }
    }
106
107 #[must_use]
122 pub fn with_options(&self, options: RequestOptions) -> Self {
123 Self {
124 http: self.http.clone(),
125 config: self.config.clone(),
126 options: self.options.merge(&options),
127 }
128 }
129
130 pub fn from_env() -> Result<Self, OpenAIError> {
132 Ok(Self::with_config(ClientConfig::from_env()?))
133 }
134
    /// Creates a client for the Azure OpenAI service.
    ///
    /// Construction and validation are delegated entirely to
    /// [`AzureConfig::build`].
    pub fn azure(config: AzureConfig) -> Result<Self, OpenAIError> {
        config.build()
    }
152
    /// Batches API resource.
    #[cfg(feature = "batches")]
    pub fn batches(&self) -> Batches<'_> {
        Batches::new(self)
    }

    /// Uploads API resource.
    #[cfg(feature = "uploads")]
    pub fn uploads(&self) -> Uploads<'_> {
        Uploads::new(self)
    }

    /// Beta API surface (assistants, threads, runs, vector stores, realtime).
    #[cfg(feature = "beta")]
    pub fn beta(&self) -> Beta<'_> {
        Beta { client: self }
    }

    /// Audio API resource.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_> {
        Audio::new(self)
    }

    /// Chat completions API resource.
    #[cfg(feature = "chat")]
    pub fn chat(&self) -> Chat<'_> {
        Chat::new(self)
    }

    /// Models API resource.
    #[cfg(feature = "models")]
    pub fn models(&self) -> Models<'_> {
        Models::new(self)
    }

    /// Fine-tuning API resource.
    #[cfg(feature = "fine-tuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_> {
        FineTuning::new(self)
    }

    /// Files API resource.
    #[cfg(feature = "files")]
    pub fn files(&self) -> Files<'_> {
        Files::new(self)
    }

    /// Images API resource.
    #[cfg(feature = "images")]
    pub fn images(&self) -> Images<'_> {
        Images::new(self)
    }

    /// Moderations API resource.
    #[cfg(feature = "moderations")]
    pub fn moderations(&self) -> Moderations<'_> {
        Moderations::new(self)
    }

    /// Responses API resource.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_> {
        Responses::new(self)
    }

    /// Embeddings API resource.
    #[cfg(feature = "embeddings")]
    pub fn embeddings(&self) -> Embeddings<'_> {
        Embeddings::new(self)
    }

    /// Conversations API resource (not feature-gated).
    pub fn conversations(&self) -> crate::resources::conversations::Conversations<'_> {
        crate::resources::conversations::Conversations::new(self)
    }

    /// Videos API resource (not feature-gated).
    pub fn videos(&self) -> crate::resources::videos::Videos<'_> {
        crate::resources::videos::Videos::new(self)
    }
234
    /// Opens a WebSocket session using this client's configuration.
    ///
    /// # Errors
    /// Returns an error when the connection cannot be established.
    #[cfg(feature = "websocket")]
    pub async fn ws_session(&self) -> Result<crate::websocket::WsSession, OpenAIError> {
        crate::websocket::WsSession::connect(self.config.as_ref()).await
    }
252
    /// Builds a `RequestBuilder` for `path` relative to the configured base
    /// URL: config-supplied request decoration (auth headers, etc.) is
    /// applied first via `Config::build_request`, then per-request option
    /// overrides are layered on top.
    pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
        let url = format!("{}{}", self.config.base_url(), path);
        let req = self.http.request(method, &url);
        let mut req = self.config.build_request(req);

        // Per-request header overrides from `with_options`.
        if let Some(ref headers) = self.options.headers {
            for (key, value) in headers.iter() {
                req = req.header(key.clone(), value.clone());
            }
        }
        // Query and timeout overrides are only applied on native targets.
        // NOTE(review): option query params and timeouts are silently dropped
        // on wasm32 builds — confirm this is intended.
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(ref query) = self.options.query {
            req = req.query(query);
        }
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(timeout) = self.options.timeout {
            req = req.timeout(timeout);
        }

        req
    }
276
    /// GET `path` and deserialize the JSON response, using the client's
    /// retry policy.
    #[allow(dead_code)]
    pub(crate) async fn get<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
            .await
    }
286
    /// GET `path` with explicit query parameters (native targets).
    ///
    /// NOTE(review): unlike `get`, this sends a single request with no retry
    /// — confirm that is intentional.
    #[allow(dead_code)]
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) async fn get_with_query<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        query: &[(String, String)],
    ) -> Result<T, OpenAIError> {
        let mut req = self.request(reqwest::Method::GET, path);
        if !query.is_empty() {
            req = req.query(query);
        }
        let response = req.send().await?;
        Self::handle_response(response).await
    }
302
303 #[allow(dead_code)]
305 #[cfg(target_arch = "wasm32")]
306 pub(crate) async fn get_with_query<T: serde::de::DeserializeOwned>(
307 &self,
308 path: &str,
309 query: &[(String, String)],
310 ) -> Result<T, OpenAIError> {
311 let url = if query.is_empty() {
312 path.to_string()
313 } else {
314 let qs: Vec<String> = query.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
315 format!("{}?{}", path, qs.join("&"))
316 };
317 self.get(&url).await
318 }
319
    /// POST `body` as JSON to `path` and deserialize the response, using the
    /// client's retry policy.
    pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::POST, path, Some(body))
            .await
    }
329
    /// POST `body` and return the response as untyped JSON
    /// (`serde_json::Value`); convenience wrapper over `post`.
    pub(crate) async fn post_json<B: serde::Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<serde_json::Value, OpenAIError> {
        self.post(path, body).await
    }
342
    /// POST to `path` with no request body and deserialize the response,
    /// using the client's retry policy.
    pub(crate) async fn post_empty<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::POST, path, None::<&()>)
            .await
    }
351
    /// POST a multipart form (e.g. file uploads) and deserialize the response.
    ///
    /// Sent exactly once — the retry path is not used here, as the form body
    /// is consumed by the send.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        form: reqwest::multipart::Form,
    ) -> Result<T, OpenAIError> {
        let response = self
            .request(reqwest::Method::POST, path)
            .multipart(form)
            .send()
            .await?;
        Self::handle_response(response).await
    }
366
367 pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
369 let response = self.request(reqwest::Method::GET, path).send().await?;
370
371 let status = response.status();
372 if status.is_success() {
373 Ok(response.bytes().await?)
374 } else {
375 Err(Self::extract_error(status.as_u16(), response).await)
376 }
377 }
378
    /// POST `body` as JSON and return the raw response bytes.
    ///
    /// Sent once, without the retry path. When `extra_body` overrides are
    /// configured they are merged into the serialized body first (override
    /// keys win over `body` keys).
    pub(crate) async fn post_raw<B: serde::Serialize>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<bytes::Bytes, OpenAIError> {
        let mut req = self.request(reqwest::Method::POST, path);
        if self.options.extra_body.is_some() {
            req = req.json(&self.merge_body_json(body)?);
        } else {
            req = req.json(body);
        }
        let response = req.send().await?;

        let status = response.status();
        if status.is_success() {
            Ok(response.bytes().await?)
        } else {
            Err(Self::extract_error(status.as_u16(), response).await)
        }
    }
400
    /// DELETE `path` and deserialize the JSON response, using the client's
    /// retry policy.
    #[allow(dead_code)]
    pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> Result<T, OpenAIError> {
        self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
            .await
    }
410
411 fn merge_body_json<B: serde::Serialize>(
413 &self,
414 body: &B,
415 ) -> Result<serde_json::Value, OpenAIError> {
416 let mut value = serde_json::to_value(body)?;
417 if let Some(ref extra) = self.options.extra_body
418 && let serde_json::Value::Object(map) = &mut value
419 && let serde_json::Value::Object(extra_map) = extra.clone()
420 {
421 for (k, v) in extra_map {
422 map.insert(k, v);
423 }
424 }
425 Ok(value)
426 }
427
    /// Serializes the optional request body to a JSON value once (applying
    /// `extra_body` overrides when configured) so retries can re-send it
    /// without re-serializing.
    fn prepare_body<B: serde::Serialize>(
        &self,
        body: Option<&B>,
    ) -> Result<Option<serde_json::Value>, OpenAIError> {
        match body {
            Some(b) if self.options.extra_body.is_some() => Ok(Some(self.merge_body_json(b)?)),
            Some(b) => Ok(Some(serde_json::to_value(b)?)),
            None => Ok(None),
        }
    }
439
    /// wasm32 retry driver: a bounded loop with exponential backoff.
    ///
    /// Unlike the native version there is no `retry-after` handling; the
    /// delay is computed purely from the attempt number.
    #[cfg(target_arch = "wasm32")]
    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<T, OpenAIError> {
        // Serialize once up front; every attempt re-sends the same JSON value.
        let body_value = self.prepare_body(body)?;

        for attempt in 0..=self.config.max_retries() {
            let mut req = self.request(method.clone(), path);
            if let Some(ref val) = body_value {
                req = req.json(val);
            }

            let response = match req.send().await {
                Ok(resp) => resp,
                // Final attempt: surface the transport error.
                Err(e) if attempt == self.config.max_retries() => {
                    return Err(OpenAIError::RequestError(e));
                }
                // Transport error with attempts remaining: back off and retry.
                Err(_) => {
                    crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
                    continue;
                }
            };

            let status = response.status().as_u16();
            // Non-retryable status, or final attempt: hand off to the normal
            // response handler (success decode or error extraction).
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == self.config.max_retries() {
                return Self::handle_response(response).await;
            }

            crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
        }

        // Unreachable: the final loop iteration always returns above.
        Err(OpenAIError::InvalidArgument("retry exhausted".into()))
    }
477
    /// Native retry driver with an inline fast path: the first attempt is
    /// sent directly, and the shared retry loop is only entered after a
    /// transport error or a retryable status.
    #[cfg(not(target_arch = "wasm32"))]
    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<T, OpenAIError> {
        // Serialize once; retries re-send the same JSON value.
        let body_value = self.prepare_body(body)?;

        let mut req = self.request(method.clone(), path);
        if let Some(ref val) = body_value {
            req = req.json(val);
        }

        let response = match req.send().await {
            Ok(resp) => resp,
            // Retries disabled: propagate the transport error immediately.
            Err(e) if self.config.max_retries() == 0 => return Err(OpenAIError::RequestError(e)),
            Err(e) => {
                return self.retry_loop(method, path, &body_value, e, 1).await;
            }
        };

        let status = response.status().as_u16();
        if !RETRYABLE_STATUS_CODES.contains(&status) {
            return Self::handle_response(response).await;
        }

        // Retryable status but retries disabled: report it as-is.
        if self.config.max_retries() == 0 {
            return Self::handle_response(response).await;
        }

        // Retryable status with retries enabled: honor `retry-after` (integer
        // or fractional seconds) when parseable, sleep with the attempt-0
        // backoff, then continue in the shared loop starting at attempt 1
        // (attempt 0 was the inline send above).
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<f64>().ok());
        let last_error = Self::extract_error(status, response).await;
        tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
        self.retry_loop(method, path, &body_value, last_error, 1)
            .await
    }
526
    /// Shared retry loop for the JSON request paths, running attempts
    /// `start_attempt..=max_retries`.
    ///
    /// Each iteration rebuilds the request from the pre-serialized body. A
    /// non-retryable status — or the final attempt — is handed to the normal
    /// response handler; retryable statuses honor `retry-after` and sleep
    /// with exponential backoff. Falling out of the loop returns the last
    /// error seen.
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop<T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body_value: &Option<serde_json::Value>,
        initial_error: impl Into<OpenAIError>,
        start_attempt: u32,
    ) -> Result<T, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error: OpenAIError = initial_error.into();

        for attempt in start_attempt..=max_retries {
            let mut req = self.request(method.clone(), path);
            if let Some(val) = body_value {
                req = req.json(val);
            }

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    // Transport error: retry if attempts remain, otherwise
                    // fall out of the loop with this as the last error.
                    last_error = OpenAIError::RequestError(e);
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Self::handle_response(response).await;
            }

            // Retryable status with attempts remaining: remember the error,
            // wait (honoring `retry-after`), and try again.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
574
    /// Sends a caller-built request with the same retry policy as the JSON
    /// path, returning the response unparsed (status handling is the
    /// caller's job).
    ///
    /// Retrying requires `RequestBuilder::try_clone` to succeed; builders
    /// whose bodies cannot be cloned are sent exactly once with no retry.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) async fn send_raw_with_retry(
        &self,
        builder: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, OpenAIError> {
        let response = match builder.try_clone() {
            // Send a clone so `builder` stays available for later retries.
            Some(cloned) => match cloned.send().await {
                Ok(resp) => resp,
                Err(e) if self.config.max_retries() == 0 => {
                    return Err(OpenAIError::RequestError(e));
                }
                Err(e) => {
                    return self
                        .retry_loop_raw(builder, OpenAIError::RequestError(e), 1)
                        .await;
                }
            },
            // Non-clonable body: single-shot send, no retry possible.
            None => {
                return Ok(builder.send().await?);
            }
        };

        let status = response.status().as_u16();
        if !RETRYABLE_STATUS_CODES.contains(&status) {
            return Ok(response);
        }
        if self.config.max_retries() == 0 {
            return Ok(response);
        }

        // Retryable status with retries enabled: honor `retry-after`, sleep
        // with the attempt-0 backoff, then continue in the raw retry loop.
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<f64>().ok());
        let last_error = Self::extract_error(status, response).await;
        tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
        self.retry_loop_raw(builder, last_error, 1).await
    }
620
    /// Retry loop for caller-built requests (attempts
    /// `start_attempt..=max_retries`), cloning `builder` for each attempt.
    ///
    /// Returns the response unparsed on a non-retryable status or the final
    /// attempt; if the builder ever becomes un-clonable, the most recent
    /// error is returned instead.
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop_raw(
        &self,
        builder: reqwest::RequestBuilder,
        initial_error: OpenAIError,
        start_attempt: u32,
    ) -> Result<reqwest::Response, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error = initial_error;

        for attempt in start_attempt..=max_retries {
            let req = match builder.try_clone() {
                Some(cloned) => cloned,
                // Can no longer replay the request: surface the last error.
                None => return Err(last_error),
            };

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    last_error = OpenAIError::RequestError(e);
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Ok(response);
            }

            // Retryable status with attempts remaining: remember the error,
            // wait (honoring `retry-after`), and try again.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
666
    /// wasm32 variant: the request is sent exactly once, with no retry.
    #[cfg(target_arch = "wasm32")]
    pub(crate) async fn send_raw_with_retry(
        &self,
        builder: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, OpenAIError> {
        Ok(builder.send().await?)
    }
675
676 pub(crate) async fn check_stream_response(
678 response: reqwest::Response,
679 ) -> Result<reqwest::Response, OpenAIError> {
680 if response.status().is_success() {
681 Ok(response)
682 } else {
683 Err(Self::extract_error(response.status().as_u16(), response).await)
684 }
685 }
686
687 #[cfg(not(target_arch = "wasm32"))]
689 fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
690 let base = crate::runtime::backoff_ms(attempt);
691 match retry_after_secs {
692 Some(ra) => Duration::from_secs_f64(ra.max(base.as_secs_f64())),
693 None => base,
694 }
695 }
696
    /// Turns an HTTP response into a typed result: success bodies are read
    /// fully and deserialized (logging a truncated preview on decode
    /// failure), error statuses are converted via `extract_error`.
    pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
        response: reqwest::Response,
    ) -> Result<T, OpenAIError> {
        let status = response.status();
        if status.is_success() {
            let body = response.bytes().await?;
            let result = Self::deserialize_body::<T>(&body);
            match result {
                Ok(value) => Ok(value),
                Err(e) => {
                    // Log at most 500 bytes of the body; `from_utf8_lossy`
                    // tolerates a preview cut mid-UTF-8-sequence.
                    tracing::error!(
                        error = %e,
                        body_len = body.len(),
                        body_preview = %String::from_utf8_lossy(&body[..body.len().min(500)]),
                        "failed to deserialize API response"
                    );
                    Err(e)
                }
            }
        } else {
            Err(Self::extract_error(status.as_u16(), response).await)
        }
    }
726
    /// Deserializes a response body with simd-json (`simd` feature).
    ///
    /// simd-json parses in place and mutates its input, so the body is
    /// copied into a scratch buffer first.
    ///
    /// NOTE(review): failures surface as `StreamError` here, while the serde
    /// path below uses the regular JSON error conversion — confirm the
    /// variant mismatch is intentional.
    #[cfg(feature = "simd")]
    fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
        let mut buf = body.to_vec();
        simd_json::from_slice::<T>(&mut buf)
            .map_err(|e| OpenAIError::StreamError(format!("simd-json: {e}")))
    }
734
    /// Deserializes a response body with serde_json (default path).
    #[cfg(not(feature = "simd"))]
    fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
        serde_json::from_slice::<T>(body).map_err(OpenAIError::from)
    }
740
741 pub(crate) fn extract_request_id(response: &reqwest::Response) -> Option<String> {
743 response
744 .headers()
745 .get("x-request-id")
746 .and_then(|v| v.to_str().ok())
747 .map(String::from)
748 }
749
750 pub(crate) async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
752 let request_id = Self::extract_request_id(&response);
753 let body = response.text().await.unwrap_or_default();
754 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
755 OpenAIError::ApiError {
756 status,
757 message: error_resp.error.message,
758 type_: error_resp.error.type_,
759 code: error_resp.error.code,
760 request_id,
761 }
762 } else {
763 OpenAIError::ApiError {
764 status,
765 message: body,
766 type_: None,
767 code: None,
768 request_id,
769 }
770 }
771 }
772}
773
/// Namespace for beta (preview) API resources, obtained via [`OpenAI::beta`].
#[cfg(feature = "beta")]
pub struct Beta<'a> {
    // Borrowed parent client; every sub-resource is constructed from it.
    client: &'a OpenAI,
}
779
#[cfg(feature = "beta")]
impl<'a> Beta<'a> {
    /// Assistants API resource.
    pub fn assistants(&self) -> Assistants<'_> {
        Assistants::new(self.client)
    }

    /// Threads API resource.
    pub fn threads(&self) -> Threads<'_> {
        Threads::new(self.client)
    }

    /// Runs API resource scoped to `thread_id`.
    pub fn runs(&self, thread_id: &str) -> Runs<'_> {
        Runs::new(self.client, thread_id.to_string())
    }

    /// Vector stores API resource.
    pub fn vector_stores(&self) -> VectorStores<'_> {
        VectorStores::new(self.client)
    }

    /// Realtime API resource.
    pub fn realtime(&self) -> Realtime<'_> {
        Realtime::new(self.client)
    }
}
807
808#[cfg(test)]
809mod tests {
810 use super::*;
811
    // Constructor stores the key and defaults to the public API base URL.
    #[test]
    fn test_new_client() {
        let client = OpenAI::new("sk-test-key");
        assert_eq!(client.config.api_key(), "sk-test-key");
        assert_eq!(client.config.base_url(), "https://api.openai.com/v1");
    }

    // Builder-style config values are observable through the trait object.
    #[test]
    fn test_with_config() {
        let config = ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30);
        let client = OpenAI::with_config(config);
        assert_eq!(client.config.base_url(), "https://custom.api.com");
        assert_eq!(client.config.organization(), Some("org-123"));
        assert_eq!(client.config.timeout_secs(), 30);
    }

    // Backoff schedule: 500ms base doubling per attempt, `retry-after` wins
    // when larger than the computed base, capped at 60s.
    #[test]
    fn test_backoff_delay() {
        let d = OpenAI::backoff_delay(0, None);
        assert_eq!(d, Duration::from_millis(500));

        let d = OpenAI::backoff_delay(1, None);
        assert_eq!(d, Duration::from_secs(1));

        let d = OpenAI::backoff_delay(2, None);
        assert_eq!(d, Duration::from_secs(2));

        let d = OpenAI::backoff_delay(0, Some(5.0));
        assert_eq!(d, Duration::from_secs(5));

        let d = OpenAI::backoff_delay(3, Some(0.1));
        assert_eq!(d, Duration::from_secs(4));

        let d = OpenAI::backoff_delay(10, None);
        assert_eq!(d, Duration::from_secs(60));
    }
857
    // Happy-path GET deserializes the JSON body into the caller's type.
    #[tokio::test]
    async fn test_get_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/gpt-4")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct Model {
            id: String,
            object: String,
        }

        let model: Model = client.get("/models/gpt-4").await.unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.object, "model");
        mock.assert_async().await;
    }

    // POST sends bearer auth + JSON content-type and round-trips the body.
    #[tokio::test]
    async fn test_post_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", "Bearer sk-test")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = client
            .post(
                "/chat/completions",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "chatcmpl-123");
        mock.assert_async().await;
    }

    // DELETE deserializes the deletion acknowledgement.
    #[tokio::test]
    async fn test_delete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/models/ft-abc")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"ft-abc","deleted":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct DeleteResp {
            id: String,
            deleted: bool,
        }

        let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
        assert_eq!(resp.id, "ft-abc");
        assert!(resp.deleted);
        mock.assert_async().await;
    }

    // A 404 with the standard error envelope maps to ApiError with the
    // message/type/code fields populated from the body.
    #[tokio::test]
    async fn test_api_error_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/nonexistent")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(Debug, serde::Deserialize)]
        struct Model {
            _id: String,
        }

        let err = client
            .get::<Model>("/models/nonexistent")
            .await
            .unwrap_err();
        match err {
            OpenAIError::ApiError {
                status,
                message,
                type_,
                code,
                ..
            } => {
                assert_eq!(status, 404);
                assert!(message.contains("does not exist"));
                assert_eq!(type_.as_deref(), Some("invalid_request_error"));
                assert_eq!(code.as_deref(), Some("model_not_found"));
            }
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
988
    // Organization/project config values become the OpenAI-* headers.
    #[tokio::test]
    async fn test_auth_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("authorization", "Bearer sk-key")
            .match_header("OpenAI-Organization", "org-abc")
            .match_header("OpenAI-Project", "proj-xyz")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-key")
                .base_url(server.url())
                .organization("org-abc")
                .project("proj-xyz"),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // A 429 followed by a 200 succeeds within the retry budget (mockito
    // serves the first matching mock once, then falls through to the 200).
    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let mut server = mockito::Server::new_async().await;

        let _mock_429 = server
            .mock("GET", "/test")
            .with_status(429)
            .with_header("retry-after", "0")
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .create_async()
            .await;

        let mock_200 = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock_200.assert_async().await;
    }

    // Persistent 500s exhaust the retry budget and surface the final error.
    #[tokio::test]
    async fn test_retry_exhausted_returns_last_error() {
        let mut server = mockito::Server::new_async().await;

        let _mock = server
            .mock("GET", "/test")
            .with_status(500)
            .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
            .expect_at_least(2)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(1),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    // Client errors outside the retryable set are never retried (expect(1)).
    #[tokio::test]
    async fn test_no_retry_on_400() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(400)
            .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }

    // max_retries(0) disables retries even for a retryable 429 (expect(1)).
    #[tokio::test]
    async fn test_zero_retries_no_retry() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
1148
    // Headers set via `with_options` are sent on the request.
    #[tokio::test]
    async fn test_with_options_sends_extra_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-Custom", "test-value")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(RequestOptions::new().header("X-Custom", "test-value"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // Query parameters set via `with_options` are URL-encoded onto the request.
    #[tokio::test]
    async fn test_with_options_sends_query_params() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "foo".into(),
                "bar".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(RequestOptions::new().query_param("foo", "bar"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // `extra_body` keys are merged into the serialized request body.
    #[tokio::test]
    async fn test_extra_body_merge() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/test")
            .match_body(mockito::Matcher::Json(serde_json::json!({
                "model": "gpt-4",
                "extra_field": "injected"
            })))
            .with_status(200)
            .with_body(r#"{"id":"ok"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(
            RequestOptions::new().extra_body(serde_json::json!({"extra_field": "injected"})),
        );

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = custom
            .post(
                "/test",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "ok");
        mock.assert_async().await;
    }

    // A per-request timeout override fires against a deliberately slow
    // (5s chunked) mock body and surfaces as a RequestError.
    #[tokio::test]
    async fn test_timeout_override() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .with_chunked_body(|_w| -> std::io::Result<()> {
                std::thread::sleep(std::time::Duration::from_secs(5));
                Ok(())
            })
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let custom = client.with_options(RequestOptions::new().timeout(Duration::from_millis(100)));

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = custom.get::<Resp>("/test").await.unwrap_err();
        assert!(
            matches!(err, OpenAIError::RequestError(_)),
            "expected timeout error, got: {err:?}"
        );
    }
1276
    // When the same header is set twice, the later `with_options` wins.
    #[tokio::test]
    async fn test_options_merge_precedence() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-A", "2")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let base = client.with_options(RequestOptions::new().header("X-A", "1"));
        let custom = base.with_options(RequestOptions::new().header("X-A", "2"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }

    // Default headers/query configured on the config apply to every request.
    #[tokio::test]
    async fn test_default_headers_and_query_on_config() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-Default", "from-config")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "cfg_param".into(),
                "cfg_val".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let mut default_headers = reqwest::header::HeaderMap::new();
        default_headers.insert("X-Default", "from-config".parse().unwrap());

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .default_headers(default_headers)
                .default_query(vec![("cfg_param".into(), "cfg_val".into())]),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
1337
1338 #[tokio::test]
1339 async fn test_chained_with_options_merges() {
1340 let mut server = mockito::Server::new_async().await;
1341 let mock = server
1342 .mock("GET", "/test")
1343 .match_header("X-A", "from-a")
1344 .match_header("X-B", "from-b")
1345 .with_status(200)
1346 .with_body(r#"{"ok":true}"#)
1347 .create_async()
1348 .await;
1349
1350 let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
1351 let chained = client
1352 .with_options(RequestOptions::new().header("X-A", "from-a"))
1353 .with_options(RequestOptions::new().header("X-B", "from-b"));
1354
1355 #[derive(serde::Deserialize)]
1356 struct Resp {
1357 ok: bool,
1358 }
1359
1360 let resp: Resp = chained.get("/test").await.unwrap();
1361 assert!(resp.ok);
1362 mock.assert_async().await;
1363 }
1364}