1use std::time::Duration;
4
5use crate::azure::AzureConfig;
6use crate::config::ClientConfig;
7use crate::error::{ErrorResponse, OpenAIError};
8use crate::request_options::RequestOptions;
9#[cfg(feature = "audio")]
10use crate::resources::audio::Audio;
11#[cfg(feature = "batches")]
12use crate::resources::batches::Batches;
13#[cfg(feature = "beta")]
14use crate::resources::beta::assistants::Assistants;
15#[cfg(feature = "beta")]
16use crate::resources::beta::realtime::Realtime;
17#[cfg(feature = "beta")]
18use crate::resources::beta::runs::Runs;
19#[cfg(feature = "beta")]
20use crate::resources::beta::threads::Threads;
21#[cfg(feature = "beta")]
22use crate::resources::beta::vector_stores::VectorStores;
23#[cfg(feature = "chat")]
24use crate::resources::chat::Chat;
25#[cfg(feature = "embeddings")]
26use crate::resources::embeddings::Embeddings;
27#[cfg(feature = "files")]
28use crate::resources::files::Files;
29#[cfg(feature = "fine-tuning")]
30use crate::resources::fine_tuning::FineTuning;
31#[cfg(feature = "images")]
32use crate::resources::images::Images;
33#[cfg(feature = "models")]
34use crate::resources::models::Models;
35#[cfg(feature = "moderations")]
36use crate::resources::moderations::Moderations;
37#[cfg(feature = "responses")]
38use crate::resources::responses::Responses;
39#[cfg(feature = "uploads")]
40use crate::resources::uploads::Uploads;
41
/// HTTP status codes the retry helpers in this file treat as transient: a
/// response carrying one of them is retried while retry attempts remain.
const RETRYABLE_STATUS_CODES: [u16; 4] = [429, 500, 502, 503];
44
/// Client handle bundling an HTTP client, a configuration object and the
/// request options applied to every request built through it.
///
/// Cloning copies the `reqwest::Client` handle, the `Arc` pointing at the
/// configuration, and the options.
#[derive(Debug, Clone)]
pub struct OpenAI {
    /// `reqwest` client used to send every request.
    pub(crate) http: reqwest::Client,
    /// Shared configuration; supplies the base URL, the retry budget and
    /// request decoration (`build_request`).
    pub(crate) config: std::sync::Arc<dyn crate::config::Config>,
    /// Options layered on top of the configuration: extra headers, query
    /// parameters, timeout and extra JSON body fields.
    pub(crate) options: RequestOptions,
}
68
69impl OpenAI {
70 pub fn new(api_key: impl Into<String>) -> Self {
72 Self::with_config(ClientConfig::new(api_key))
73 }
74
75 pub fn with_config<C: crate::config::Config + 'static>(config: C) -> Self {
77 let options = config.initial_options();
78
79 #[cfg(not(target_arch = "wasm32"))]
80 let http = {
81 crate::ensure_tls_provider();
82
83 reqwest::Client::builder()
84 .timeout(Duration::from_secs(config.timeout_secs()))
85 .tcp_nodelay(true)
86 .tcp_keepalive(Some(Duration::from_secs(30)))
87 .pool_idle_timeout(Some(Duration::from_secs(300)))
88 .pool_max_idle_per_host(4)
89 .http2_keep_alive_interval(Some(Duration::from_secs(20)))
90 .http2_keep_alive_timeout(Duration::from_secs(10))
91 .http2_keep_alive_while_idle(true)
92 .http2_adaptive_window(true)
93 .gzip(true)
94 .build()
95 .expect("failed to build HTTP client")
96 };
97
98 #[cfg(target_arch = "wasm32")]
99 let http = reqwest::Client::new();
100 Self {
101 http,
102 config: std::sync::Arc::new(config),
103 options,
104 }
105 }
106
107 #[must_use]
122 pub fn with_options(&self, options: RequestOptions) -> Self {
123 Self {
124 http: self.http.clone(),
125 config: self.config.clone(),
126 options: self.options.merge(&options),
127 }
128 }
129
    /// Returns a reference to this client itself (identity accessor).
    pub fn client(&self) -> &Self {
        self
    }
136
137 pub fn from_env() -> Result<Self, OpenAIError> {
139 Ok(Self::with_config(ClientConfig::from_env()?))
140 }
141
    /// Creates a client from an [`AzureConfig`] by delegating to
    /// `AzureConfig::build`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `config.build()` reports.
    pub fn azure(config: AzureConfig) -> Result<Self, OpenAIError> {
        config.build()
    }
159
    /// Returns a [`Batches`] resource handle borrowing this client.
    ///
    /// Compiled only with the `batches` feature.
    #[cfg(feature = "batches")]
    pub fn batches(&self) -> Batches<'_> {
        Batches::new(self)
    }
165
    /// Returns an [`Uploads`] resource handle borrowing this client.
    ///
    /// Compiled only with the `uploads` feature.
    #[cfg(feature = "uploads")]
    pub fn uploads(&self) -> Uploads<'_> {
        Uploads::new(self)
    }
171
    /// Returns a [`Beta`] wrapper borrowing this client, from which the beta
    /// resource handles are obtained.
    ///
    /// Compiled only with the `beta` feature.
    #[cfg(feature = "beta")]
    pub fn beta(&self) -> Beta<'_> {
        Beta { client: self }
    }
177
    /// Returns an [`Audio`] resource handle borrowing this client.
    ///
    /// Compiled only with the `audio` feature.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_> {
        Audio::new(self)
    }
183
    /// Returns a [`Chat`] resource handle borrowing this client.
    ///
    /// Compiled only with the `chat` feature.
    #[cfg(feature = "chat")]
    pub fn chat(&self) -> Chat<'_> {
        Chat::new(self)
    }
189
    /// Returns a [`Models`] resource handle borrowing this client.
    ///
    /// Compiled only with the `models` feature.
    #[cfg(feature = "models")]
    pub fn models(&self) -> Models<'_> {
        Models::new(self)
    }
195
    /// Returns a [`FineTuning`] resource handle borrowing this client.
    ///
    /// Compiled only with the `fine-tuning` feature.
    #[cfg(feature = "fine-tuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_> {
        FineTuning::new(self)
    }
201
    /// Returns a [`Files`] resource handle borrowing this client.
    ///
    /// Compiled only with the `files` feature.
    #[cfg(feature = "files")]
    pub fn files(&self) -> Files<'_> {
        Files::new(self)
    }
207
    /// Returns an [`Images`] resource handle borrowing this client.
    ///
    /// Compiled only with the `images` feature.
    #[cfg(feature = "images")]
    pub fn images(&self) -> Images<'_> {
        Images::new(self)
    }
213
    /// Returns a [`Moderations`] resource handle borrowing this client.
    ///
    /// Compiled only with the `moderations` feature.
    #[cfg(feature = "moderations")]
    pub fn moderations(&self) -> Moderations<'_> {
        Moderations::new(self)
    }
219
    /// Returns a [`Responses`] resource handle borrowing this client.
    ///
    /// Compiled only with the `responses` feature.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_> {
        Responses::new(self)
    }
225
    /// Returns an [`Embeddings`] resource handle borrowing this client.
    ///
    /// Compiled only with the `embeddings` feature.
    #[cfg(feature = "embeddings")]
    pub fn embeddings(&self) -> Embeddings<'_> {
        Embeddings::new(self)
    }
231
    /// Returns a `Conversations` resource handle borrowing this client.
    ///
    /// Unlike most accessors in this impl, this one is not feature-gated.
    pub fn conversations(&self) -> crate::resources::conversations::Conversations<'_> {
        crate::resources::conversations::Conversations::new(self)
    }
236
    /// Returns a `Videos` resource handle borrowing this client.
    ///
    /// Unlike most accessors in this impl, this one is not feature-gated.
    pub fn videos(&self) -> crate::resources::videos::Videos<'_> {
        crate::resources::videos::Videos::new(self)
    }
241
    /// Returns the `crate::resources::realtime::Realtime` resource handle
    /// borrowing this client.
    ///
    /// Note this is a different type from the
    /// `crate::resources::beta::realtime::Realtime` handle returned by
    /// `Beta::realtime`.
    pub fn realtime(&self) -> crate::resources::realtime::Realtime<'_> {
        crate::resources::realtime::Realtime::new(self)
    }
250
    /// Opens a `WsSession` by calling `WsSession::connect` with this client's
    /// configuration.
    ///
    /// Compiled only with the `websocket` feature.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `WsSession::connect`.
    #[cfg(feature = "websocket")]
    pub async fn ws_session(&self) -> Result<crate::websocket::WsSession, OpenAIError> {
        crate::websocket::WsSession::connect(self.config.as_ref()).await
    }
268
269 pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
271 let url = format!("{}{}", self.config.base_url(), path);
272 let req = self.http.request(method, &url);
273 let mut req = self.config.build_request(req);
274
275 if let Some(ref headers) = self.options.headers {
277 for (key, value) in headers.iter() {
278 req = req.header(key.clone(), value.clone());
279 }
280 }
281 #[cfg(not(target_arch = "wasm32"))]
282 if let Some(ref query) = self.options.query {
283 req = req.query(query);
284 }
285 #[cfg(not(target_arch = "wasm32"))]
286 if let Some(timeout) = self.options.timeout {
287 req = req.timeout(timeout);
288 }
289
290 req
291 }
292
293 #[allow(dead_code)]
295 pub(crate) async fn get<T: serde::de::DeserializeOwned>(
296 &self,
297 path: &str,
298 ) -> Result<T, OpenAIError> {
299 self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
300 .await
301 }
302
303 #[allow(dead_code)]
305 #[cfg(not(target_arch = "wasm32"))]
306 pub(crate) async fn get_with_query<T: serde::de::DeserializeOwned>(
307 &self,
308 path: &str,
309 query: &[(String, String)],
310 ) -> Result<T, OpenAIError> {
311 let mut req = self.request(reqwest::Method::GET, path);
312 if !query.is_empty() {
313 req = req.query(query);
314 }
315 let response = req.send().await?;
316 Self::handle_response(response).await
317 }
318
319 #[allow(dead_code)]
321 #[cfg(target_arch = "wasm32")]
322 pub(crate) async fn get_with_query<T: serde::de::DeserializeOwned>(
323 &self,
324 path: &str,
325 query: &[(String, String)],
326 ) -> Result<T, OpenAIError> {
327 let url = if query.is_empty() {
328 path.to_string()
329 } else {
330 let qs: Vec<String> = query.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
331 format!("{}?{}", path, qs.join("&"))
332 };
333 self.get(&url).await
334 }
335
336 pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
338 &self,
339 path: &str,
340 body: &B,
341 ) -> Result<T, OpenAIError> {
342 self.send_with_retry(reqwest::Method::POST, path, Some(body))
343 .await
344 }
345
346 pub(crate) async fn post_json<B: serde::Serialize>(
352 &self,
353 path: &str,
354 body: &B,
355 ) -> Result<serde_json::Value, OpenAIError> {
356 self.post(path, body).await
357 }
358
359 pub async fn post_json_bytes(
365 &self,
366 path: &str,
367 json_bytes: bytes::Bytes,
368 ) -> Result<serde_json::Value, OpenAIError> {
369 let req = self
370 .request(reqwest::Method::POST, path)
371 .header(reqwest::header::CONTENT_TYPE, "application/json")
372 .body(json_bytes);
373 let response = req.send().await?;
374 Self::handle_response(response).await
375 }
376
377 pub async fn post_stream_json_bytes(
379 &self,
380 path: &str,
381 json_bytes: bytes::Bytes,
382 ) -> Result<reqwest::Response, OpenAIError> {
383 let req = self
384 .request(reqwest::Method::POST, path)
385 .header(reqwest::header::CONTENT_TYPE, "application/json")
386 .header(reqwest::header::ACCEPT, "text/event-stream")
387 .header(reqwest::header::CACHE_CONTROL, "no-cache")
388 .body(json_bytes);
389 let response = req.send().await?;
390 Self::check_stream_response(response).await
391 }
392
393 pub(crate) async fn post_empty<T: serde::de::DeserializeOwned>(
395 &self,
396 path: &str,
397 ) -> Result<T, OpenAIError> {
398 self.send_with_retry(reqwest::Method::POST, path, None::<&()>)
399 .await
400 }
401
402 #[cfg(not(target_arch = "wasm32"))]
404 pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
405 &self,
406 path: &str,
407 form: reqwest::multipart::Form,
408 ) -> Result<T, OpenAIError> {
409 let response = self
410 .request(reqwest::Method::POST, path)
411 .multipart(form)
412 .send()
413 .await?;
414 Self::handle_response(response).await
415 }
416
417 pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
419 let response = self.request(reqwest::Method::GET, path).send().await?;
420
421 let status = response.status();
422 if status.is_success() {
423 Ok(response.bytes().await?)
424 } else {
425 Err(Self::extract_error(status.as_u16(), response).await)
426 }
427 }
428
429 pub(crate) async fn post_raw<B: serde::Serialize>(
431 &self,
432 path: &str,
433 body: &B,
434 ) -> Result<bytes::Bytes, OpenAIError> {
435 let mut req = self.request(reqwest::Method::POST, path);
436 if self.options.extra_body.is_some() {
437 req = req.json(&self.merge_body_json(body)?);
438 } else {
439 req = req.json(body);
440 }
441 let response = req.send().await?;
442
443 let status = response.status();
444 if status.is_success() {
445 Ok(response.bytes().await?)
446 } else {
447 Err(Self::extract_error(status.as_u16(), response).await)
448 }
449 }
450
451 #[allow(dead_code)]
453 pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
454 &self,
455 path: &str,
456 ) -> Result<T, OpenAIError> {
457 self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
458 .await
459 }
460
461 fn merge_body_json<B: serde::Serialize>(
463 &self,
464 body: &B,
465 ) -> Result<serde_json::Value, OpenAIError> {
466 let mut value = serde_json::to_value(body)?;
467 if let Some(ref extra) = self.options.extra_body
468 && let serde_json::Value::Object(map) = &mut value
469 && let serde_json::Value::Object(extra_map) = extra.clone()
470 {
471 for (k, v) in extra_map {
472 map.insert(k, v);
473 }
474 }
475 Ok(value)
476 }
477
478 fn prepare_body<B: serde::Serialize>(
480 &self,
481 body: Option<&B>,
482 ) -> Result<Option<serde_json::Value>, OpenAIError> {
483 match body {
484 Some(b) if self.options.extra_body.is_some() => Ok(Some(self.merge_body_json(b)?)),
485 Some(b) => Ok(Some(serde_json::to_value(b)?)),
486 None => Ok(None),
487 }
488 }
489
    /// wasm32 variant: sends a request, retrying transport errors and
    /// retryable statuses up to `config.max_retries()` additional times.
    ///
    /// `body`, if present, is serialized once up front and re-sent on every
    /// attempt. Between attempts the task sleeps for
    /// `crate::runtime::backoff_ms(attempt)`. The response of the final
    /// attempt (or the first non-retryable response) is handed to
    /// [`Self::handle_response`].
    #[cfg(target_arch = "wasm32")]
    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<T, OpenAIError> {
        let body_value = self.prepare_body(body)?;

        // Attempt 0 is the initial try; attempts 1..=max_retries are retries.
        for attempt in 0..=self.config.max_retries() {
            let mut req = self.request(method.clone(), path);
            if let Some(ref val) = body_value {
                req = req.json(val);
            }

            let response = match req.send().await {
                Ok(resp) => resp,
                // Out of attempts: surface the transport error.
                Err(e) if attempt == self.config.max_retries() => {
                    return Err(OpenAIError::RequestError(e));
                }
                // Transport error with attempts left: back off and try again.
                Err(_) => {
                    crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
                    continue;
                }
            };

            let status = response.status().as_u16();
            // Non-retryable status, or the last attempt: let handle_response
            // produce the final value or error.
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == self.config.max_retries() {
                return Self::handle_response(response).await;
            }

            crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
        }

        // Every iteration returns on its last attempt, so this is a fallback
        // that is not expected to be reached.
        Err(OpenAIError::InvalidArgument("retry exhausted".into()))
    }
527
528 #[cfg(not(target_arch = "wasm32"))]
533 async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
534 &self,
535 method: reqwest::Method,
536 path: &str,
537 body: Option<&B>,
538 ) -> Result<T, OpenAIError> {
539 let body_value = self.prepare_body(body)?;
540
541 let mut req = self.request(method.clone(), path);
543 if let Some(ref val) = body_value {
544 req = req.json(val);
545 }
546
547 let response = match req.send().await {
548 Ok(resp) => resp,
549 Err(e) if self.config.max_retries() == 0 => return Err(OpenAIError::RequestError(e)),
550 Err(e) => {
551 return self.retry_loop(method, path, &body_value, e, 1).await;
553 }
554 };
555
556 let status = response.status().as_u16();
557 if !RETRYABLE_STATUS_CODES.contains(&status) {
558 return Self::handle_response(response).await;
559 }
560
561 if self.config.max_retries() == 0 {
562 return Self::handle_response(response).await;
563 }
564
565 let retry_after = response
567 .headers()
568 .get("retry-after")
569 .and_then(|v| v.to_str().ok())
570 .and_then(|v| v.parse::<f64>().ok());
571 let last_error = Self::extract_error(status, response).await;
572 tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
573 self.retry_loop(method, path, &body_value, last_error, 1)
574 .await
575 }
576
    /// Retry stage of [`Self::send_with_retry`] on native targets.
    ///
    /// Re-sends the request for attempts `start_attempt..=max_retries`.
    /// `initial_error` is the failure that triggered retrying; it is returned
    /// only if the loop body never runs. Otherwise the result is the first
    /// non-retryable response, the response of the final attempt, or the
    /// last transport error.
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop<T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body_value: &Option<serde_json::Value>,
        initial_error: impl Into<OpenAIError>,
        start_attempt: u32,
    ) -> Result<T, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error: OpenAIError = initial_error.into();

        for attempt in start_attempt..=max_retries {
            // Rebuild the request from scratch for each attempt.
            let mut req = self.request(method.clone(), path);
            if let Some(val) = body_value {
                req = req.json(val);
            }

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    last_error = OpenAIError::RequestError(e);
                    // Sleep only if another attempt will follow.
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            // Final answer: non-retryable status, or no attempts left.
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Self::handle_response(response).await;
            }

            // Read the header before `extract_error` consumes the response.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
624
625 #[cfg(not(target_arch = "wasm32"))]
630 pub(crate) async fn send_raw_with_retry(
631 &self,
632 builder: reqwest::RequestBuilder,
633 ) -> Result<reqwest::Response, OpenAIError> {
634 let response = match builder.try_clone() {
636 Some(cloned) => match cloned.send().await {
637 Ok(resp) => resp,
638 Err(e) if self.config.max_retries() == 0 => {
639 return Err(OpenAIError::RequestError(e));
640 }
641 Err(e) => {
642 return self
643 .retry_loop_raw(builder, OpenAIError::RequestError(e), 1)
644 .await;
645 }
646 },
647 None => {
648 return Ok(builder.send().await?);
650 }
651 };
652
653 let status = response.status().as_u16();
654 if !RETRYABLE_STATUS_CODES.contains(&status) {
655 return Ok(response);
656 }
657 if self.config.max_retries() == 0 {
658 return Ok(response);
659 }
660
661 let retry_after = response
662 .headers()
663 .get("retry-after")
664 .and_then(|v| v.to_str().ok())
665 .and_then(|v| v.parse::<f64>().ok());
666 let last_error = Self::extract_error(status, response).await;
667 tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
668 self.retry_loop_raw(builder, last_error, 1).await
669 }
670
    /// Retry stage of [`Self::send_raw_with_retry`] on native targets.
    ///
    /// Re-sends clones of `builder` for attempts `start_attempt..=max_retries`
    /// and returns the first non-retryable response or the response of the
    /// final attempt. If the builder cannot be cloned, or the final attempt
    /// fails at the transport level, the most recent error is returned.
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop_raw(
        &self,
        builder: reqwest::RequestBuilder,
        initial_error: OpenAIError,
        start_attempt: u32,
    ) -> Result<reqwest::Response, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error = initial_error;

        for attempt in start_attempt..=max_retries {
            // Keep the original builder intact so later attempts can clone it too.
            let req = match builder.try_clone() {
                Some(cloned) => cloned,
                None => return Err(last_error),
            };

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    last_error = OpenAIError::RequestError(e);
                    // Sleep only if another attempt will follow.
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            // Final answer: non-retryable status, or no attempts left.
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Ok(response);
            }

            // Read the header before `extract_error` consumes the response.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
716
717 #[cfg(target_arch = "wasm32")]
719 pub(crate) async fn send_raw_with_retry(
720 &self,
721 builder: reqwest::RequestBuilder,
722 ) -> Result<reqwest::Response, OpenAIError> {
723 Ok(builder.send().await?)
724 }
725
726 pub(crate) async fn check_stream_response(
728 response: reqwest::Response,
729 ) -> Result<reqwest::Response, OpenAIError> {
730 if response.status().is_success() {
731 Ok(response)
732 } else {
733 Err(Self::extract_error(response.status().as_u16(), response).await)
734 }
735 }
736
737 #[cfg(not(target_arch = "wasm32"))]
739 fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
740 let base = crate::runtime::backoff_ms(attempt);
741 match retry_after_secs {
742 Some(ra) => Duration::from_secs_f64(ra.max(base.as_secs_f64())),
743 None => base,
744 }
745 }
746
747 pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
754 response: reqwest::Response,
755 ) -> Result<T, OpenAIError> {
756 let status = response.status();
757 if status.is_success() {
758 let body = response.bytes().await?;
759 let result = Self::deserialize_body::<T>(&body);
760 match result {
761 Ok(value) => Ok(value),
762 Err(e) => {
763 tracing::error!(
764 error = %e,
765 body_len = body.len(),
766 body_preview = %String::from_utf8_lossy(&body[..body.len().min(500)]),
767 "failed to deserialize API response"
768 );
769 Err(e)
770 }
771 }
772 } else {
773 Err(Self::extract_error(status.as_u16(), response).await)
774 }
775 }
776
777 #[cfg(feature = "simd")]
779 fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
780 let mut buf = body.to_vec();
781 simd_json::from_slice::<T>(&mut buf)
782 .map_err(|e| OpenAIError::StreamError(format!("simd-json: {e}")))
783 }
784
785 #[cfg(not(feature = "simd"))]
787 fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
788 serde_json::from_slice::<T>(body).map_err(OpenAIError::from)
789 }
790
791 pub(crate) fn extract_request_id(response: &reqwest::Response) -> Option<String> {
793 response
794 .headers()
795 .get("x-request-id")
796 .and_then(|v| v.to_str().ok())
797 .map(String::from)
798 }
799
800 pub(crate) async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
802 let request_id = Self::extract_request_id(&response);
803 let body = response.text().await.unwrap_or_default();
804 if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
805 OpenAIError::ApiError {
806 status,
807 message: error_resp.error.message,
808 type_: error_resp.error.type_,
809 code: error_resp.error.code,
810 request_id,
811 }
812 } else {
813 OpenAIError::ApiError {
814 status,
815 message: body,
816 type_: None,
817 code: None,
818 request_id,
819 }
820 }
821 }
822}
823
/// Borrowed view of an [`OpenAI`] client that hands out the beta resource
/// handles (assistants, threads, runs, vector stores, realtime).
///
/// Compiled only with the `beta` feature.
#[cfg(feature = "beta")]
pub struct Beta<'a> {
    /// Client every beta resource handle is created from.
    client: &'a OpenAI,
}
829
/// Accessors for the beta resources.
///
/// Each handle is tied to the lifetime `'a` of the underlying client, not to
/// the borrow of this `Beta` value. That lets a handle outlive the
/// (typically temporary) `Beta` it came from, e.g.
/// `let assistants = client.beta().assistants();`.
#[cfg(feature = "beta")]
impl<'a> Beta<'a> {
    /// Returns an [`Assistants`] resource handle for the underlying client.
    pub fn assistants(&self) -> Assistants<'a> {
        Assistants::new(self.client)
    }

    /// Returns a [`Threads`] resource handle for the underlying client.
    pub fn threads(&self) -> Threads<'a> {
        Threads::new(self.client)
    }

    /// Returns a [`Runs`] resource handle for the given thread. The thread
    /// id is copied into the handle.
    pub fn runs(&self, thread_id: &str) -> Runs<'a> {
        Runs::new(self.client, thread_id.to_string())
    }

    /// Returns a [`VectorStores`] resource handle for the underlying client.
    pub fn vector_stores(&self) -> VectorStores<'a> {
        VectorStores::new(self.client)
    }

    /// Returns the beta [`Realtime`] resource handle for the underlying client.
    pub fn realtime(&self) -> Realtime<'a> {
        Realtime::new(self.client)
    }
}
857
858#[cfg(test)]
859mod tests {
860 use super::*;
861
    /// `OpenAI::new` stores the API key and uses the default base URL.
    #[test]
    fn test_new_client() {
        let client = OpenAI::new("sk-test-key");
        assert_eq!(client.config.api_key(), "sk-test-key");
        assert_eq!(client.config.base_url(), "https://api.openai.com/v1");
    }
868
    /// Values set on `ClientConfig` are visible through the client's config.
    #[test]
    fn test_with_config() {
        let config = ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30);
        let client = OpenAI::with_config(config);
        assert_eq!(client.config.base_url(), "https://custom.api.com");
        assert_eq!(client.config.organization(), Some("org-123"));
        assert_eq!(client.config.timeout_secs(), 30);
    }
880
    /// Pins the backoff schedule and its interaction with `retry-after`.
    #[test]
    fn test_backoff_delay() {
        // Without retry-after the base schedule applies: 500ms, 1s, 2s, ...
        let d = OpenAI::backoff_delay(0, None);
        assert_eq!(d, Duration::from_millis(500));

        let d = OpenAI::backoff_delay(1, None);
        assert_eq!(d, Duration::from_secs(1));

        let d = OpenAI::backoff_delay(2, None);
        assert_eq!(d, Duration::from_secs(2));

        // A retry-after longer than the base delay wins.
        let d = OpenAI::backoff_delay(0, Some(5.0));
        assert_eq!(d, Duration::from_secs(5));

        // A retry-after shorter than the base delay is ignored.
        let d = OpenAI::backoff_delay(3, Some(0.1));
        assert_eq!(d, Duration::from_secs(4));

        // The base delay is capped at 60 seconds.
        let d = OpenAI::backoff_delay(10, None);
        assert_eq!(d, Duration::from_secs(60));
    }
907
    /// A successful `GET` is deserialized into the requested type.
    #[tokio::test]
    async fn test_get_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/gpt-4")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        // Only a subset of the response fields is declared here; the rest
        // of the payload is not inspected.
        #[derive(serde::Deserialize)]
        struct Model {
            id: String,
            object: String,
        }

        let model: Model = client.get("/models/gpt-4").await.unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.object, "model");
        mock.assert_async().await;
    }
934
    /// A `POST` carries the bearer token and a JSON content type, and the
    /// response is deserialized.
    #[tokio::test]
    async fn test_post_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", "Bearer sk-test")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = client
            .post(
                "/chat/completions",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "chatcmpl-123");
        mock.assert_async().await;
    }
971
    /// A successful `DELETE` is deserialized into the requested type.
    #[tokio::test]
    async fn test_delete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/models/ft-abc")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"ft-abc","deleted":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct DeleteResp {
            id: String,
            deleted: bool,
        }

        let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
        assert_eq!(resp.id, "ft-abc");
        assert!(resp.deleted);
        mock.assert_async().await;
    }
996
    /// A structured error body on a 404 is surfaced as `ApiError` with the
    /// status, message, type and code taken from the payload.
    #[tokio::test]
    async fn test_api_error_response() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/nonexistent")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        // Never populated: the request is expected to fail before decoding.
        #[derive(Debug, serde::Deserialize)]
        struct Model {
            _id: String,
        }

        let err = client
            .get::<Model>("/models/nonexistent")
            .await
            .unwrap_err();
        match err {
            OpenAIError::ApiError {
                status,
                message,
                type_,
                code,
                ..
            } => {
                assert_eq!(status, 404);
                assert!(message.contains("does not exist"));
                assert_eq!(type_.as_deref(), Some("invalid_request_error"));
                assert_eq!(code.as_deref(), Some("model_not_found"));
            }
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
1038
    /// API key, organization and project from the config are sent as the
    /// `authorization`, `OpenAI-Organization` and `OpenAI-Project` headers.
    #[tokio::test]
    async fn test_auth_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("authorization", "Bearer sk-key")
            .match_header("OpenAI-Organization", "org-abc")
            .match_header("OpenAI-Project", "proj-xyz")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-key")
                .base_url(server.url())
                .organization("org-abc")
                .project("proj-xyz"),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
1068
    /// A 429 followed by a 200 on the same path yields the successful
    /// response when retries are enabled.
    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let mut server = mockito::Server::new_async().await;

        // Registered first so it serves the initial attempt; mock
        // registration order matters here.
        let _mock_429 = server
            .mock("GET", "/test")
            .with_status(429)
            .with_header("retry-after", "0")
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .create_async()
            .await;

        // Serves the retry.
        let mock_200 = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock_200.assert_async().await;
    }
1104
    /// When every attempt returns 500, the caller receives an `ApiError`
    /// with status 500 once the retry budget is spent.
    #[tokio::test]
    async fn test_retry_exhausted_returns_last_error() {
        let mut server = mockito::Server::new_async().await;

        // Declares an expectation of at least two hits (initial attempt plus
        // one retry); note the mock is never asserted at the end of the test.
        let _mock = server
            .mock("GET", "/test")
            .with_status(500)
            .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
            .expect_at_least(2)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(1),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }
1135
    /// A 400 is not retryable: the endpoint is hit exactly once even though
    /// retries are enabled.
    #[tokio::test]
    async fn test_no_retry_on_400() {
        let mut server = mockito::Server::new_async().await;

        // `expect(1)` makes the final assertion fail if a retry happened.
        let mock = server
            .mock("GET", "/test")
            .with_status(400)
            .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
1167
    /// With `max_retries(0)` even a retryable 429 is returned after a single
    /// request.
    #[tokio::test]
    async fn test_zero_retries_no_retry() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", "/test")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .expect(1)
            .create_async()
            .await;

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = client.get::<Resp>("/test").await.unwrap_err();
        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
        mock.assert_async().await;
    }
1198
    /// Headers added through `with_options` are sent with the request.
    #[tokio::test]
    async fn test_with_options_sends_extra_headers() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-Custom", "test-value")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(RequestOptions::new().header("X-Custom", "test-value"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
1224
    /// Query parameters added through `with_options` are appended to the URL.
    #[tokio::test]
    async fn test_with_options_sends_query_params() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "foo".into(),
                "bar".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(RequestOptions::new().query_param("foo", "bar"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
1251
    /// Fields from the `extra_body` option are merged into the serialized
    /// request body.
    #[tokio::test]
    async fn test_extra_body_merge() {
        let mut server = mockito::Server::new_async().await;
        // The mock matches only when the body holds both the typed field and
        // the injected one.
        let mock = server
            .mock("POST", "/test")
            .match_body(mockito::Matcher::Json(serde_json::json!({
                "model": "gpt-4",
                "extra_field": "injected"
            })))
            .with_status(200)
            .with_body(r#"{"id":"ok"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let custom = client.with_options(
            RequestOptions::new().extra_body(serde_json::json!({"extra_field": "injected"})),
        );

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = custom
            .post(
                "/test",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "ok");
        mock.assert_async().await;
    }
1292
    /// A per-request timeout set through `with_options` aborts a slow
    /// response with a `RequestError`.
    #[tokio::test]
    async fn test_timeout_override() {
        let mut server = mockito::Server::new_async().await;
        // The chunked-body callback stalls for 5s, far beyond the 100ms
        // client timeout configured below.
        // NOTE(review): `with_body` is followed by `with_chunked_body`;
        // presumably the latter replaces the former — confirm against mockito.
        let _mock = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .with_chunked_body(|_w| -> std::io::Result<()> {
                std::thread::sleep(std::time::Duration::from_secs(5));
                Ok(())
            })
            .create_async()
            .await;

        // Retries are disabled so the timeout error surfaces straight away.
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let custom = client.with_options(RequestOptions::new().timeout(Duration::from_millis(100)));

        #[derive(Debug, serde::Deserialize)]
        struct Resp {
            _ok: bool,
        }

        let err = custom.get::<Resp>("/test").await.unwrap_err();
        assert!(
            matches!(err, OpenAIError::RequestError(_)),
            "expected timeout error, got: {err:?}"
        );
    }
1326
    /// When the same header is set by two successive `with_options` calls,
    /// the later value is the one sent.
    #[tokio::test]
    async fn test_options_merge_precedence() {
        let mut server = mockito::Server::new_async().await;
        // Matches only the overriding value "2".
        let mock = server
            .mock("GET", "/test")
            .match_header("X-A", "2")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let base = client.with_options(RequestOptions::new().header("X-A", "1"));
        let custom = base.with_options(RequestOptions::new().header("X-A", "2"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = custom.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
1352
    /// Default headers and query parameters configured on `ClientConfig` are
    /// applied to requests without any `with_options` call.
    #[tokio::test]
    async fn test_default_headers_and_query_on_config() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-Default", "from-config")
            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
                "cfg_param".into(),
                "cfg_val".into(),
            )]))
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let mut default_headers = reqwest::header::HeaderMap::new();
        default_headers.insert("X-Default", "from-config".parse().unwrap());

        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .default_headers(default_headers)
                .default_query(vec![("cfg_param".into(), "cfg_val".into())]),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
1387
    /// Chained `with_options` calls accumulate: headers from both calls are
    /// sent.
    #[tokio::test]
    async fn test_chained_with_options_merges() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/test")
            .match_header("X-A", "from-a")
            .match_header("X-B", "from-b")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
        let chained = client
            .with_options(RequestOptions::new().header("X-A", "from-a"))
            .with_options(RequestOptions::new().header("X-B", "from-b"));

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        let resp: Resp = chained.get("/test").await.unwrap();
        assert!(resp.ok);
        mock.assert_async().await;
    }
1414}