// openai_oxide/client.rs

// OpenAI client

use std::time::Duration;

use crate::azure::AzureConfig;
use crate::config::ClientConfig;
use crate::error::{ErrorResponse, OpenAIError};
use crate::request_options::RequestOptions;
#[cfg(feature = "audio")]
use crate::resources::audio::Audio;
#[cfg(feature = "batches")]
use crate::resources::batches::Batches;
#[cfg(feature = "beta")]
use crate::resources::beta::assistants::Assistants;
#[cfg(feature = "beta")]
use crate::resources::beta::realtime::Realtime;
#[cfg(feature = "beta")]
use crate::resources::beta::runs::Runs;
#[cfg(feature = "beta")]
use crate::resources::beta::threads::Threads;
#[cfg(feature = "beta")]
use crate::resources::beta::vector_stores::VectorStores;
#[cfg(feature = "chat")]
use crate::resources::chat::Chat;
#[cfg(feature = "embeddings")]
use crate::resources::embeddings::Embeddings;
#[cfg(feature = "files")]
use crate::resources::files::Files;
#[cfg(feature = "fine-tuning")]
use crate::resources::fine_tuning::FineTuning;
#[cfg(feature = "images")]
use crate::resources::images::Images;
#[cfg(feature = "models")]
use crate::resources::models::Models;
#[cfg(feature = "moderations")]
use crate::resources::moderations::Moderations;
#[cfg(feature = "responses")]
use crate::resources::responses::Responses;
#[cfg(feature = "uploads")]
use crate::resources::uploads::Uploads;
41
/// HTTP status codes that trigger a retry.
///
/// Mirrors the retry policy of the official OpenAI SDKs: 408 (request
/// timeout), 409 (conflict), 429 (rate limited), and the transient
/// server/gateway errors 500/502/503/504.
const RETRYABLE_STATUS_CODES: [u16; 7] = [408, 409, 429, 500, 502, 503, 504];
44
/// The main OpenAI client.
///
/// See [OpenAI API docs](https://platform.openai.com/docs/api-reference) for the full API reference.
///
/// Cloning is cheap: the HTTP pool and config are shared (`reqwest::Client`
/// is internally reference-counted and the config is behind an `Arc`).
///
/// Use [`with_options()`](Self::with_options) to create a cheap clone with
/// per-request customization (extra headers, query params, timeout):
///
/// ```ignore
/// use openai_oxide::RequestOptions;
/// use std::time::Duration;
///
/// let custom = client.with_options(
///     RequestOptions::new()
///         .header("X-Custom", "value")
///         .timeout(Duration::from_secs(30))
/// );
/// ```
#[derive(Debug, Clone)]
pub struct OpenAI {
    // Shared HTTP connection pool.
    pub(crate) http: reqwest::Client,
    // Trait object so plain-OpenAI and Azure configs share one client type.
    pub(crate) config: std::sync::Arc<dyn crate::config::Config>,
    // Client-level options merged into every outgoing request.
    pub(crate) options: RequestOptions,
}
68
69impl OpenAI {
70    /// Create a new client with the given API key.
71    pub fn new(api_key: impl Into<String>) -> Self {
72        Self::with_config(ClientConfig::new(api_key))
73    }
74
    /// Create a client from a full config.
    ///
    /// Builds the shared `reqwest` client. On native targets the pool is
    /// tuned for low-latency API traffic (TCP_NODELAY, TCP/HTTP2 keep-alives,
    /// long idle timeout, gzip); on wasm32 the default client is used since
    /// those builder options are compiled out there.
    pub fn with_config<C: crate::config::Config + 'static>(config: C) -> Self {
        // Capture config-provided request options before the config is moved
        // into the Arc below.
        let options = config.initial_options();

        #[cfg(not(target_arch = "wasm32"))]
        let http = {
            crate::ensure_tls_provider();

            reqwest::Client::builder()
                .timeout(Duration::from_secs(config.timeout_secs()))
                .tcp_nodelay(true)
                .tcp_keepalive(Some(Duration::from_secs(30)))
                .pool_idle_timeout(Some(Duration::from_secs(300)))
                .pool_max_idle_per_host(4)
                .http2_keep_alive_interval(Some(Duration::from_secs(20)))
                .http2_keep_alive_timeout(Duration::from_secs(10))
                .http2_keep_alive_while_idle(true)
                .http2_adaptive_window(true)
                .gzip(true)
                .build()
                // Builder only fails on bad TLS/system config — a startup bug.
                .expect("failed to build HTTP client")
        };

        #[cfg(target_arch = "wasm32")]
        let http = reqwest::Client::new();
        Self {
            http,
            config: std::sync::Arc::new(config),
            options,
        }
    }
106
107    /// Create a cheap clone of this client with additional request options.
108    ///
109    /// The returned client shares the same HTTP connection pool (`reqwest::Client`
110    /// uses `Arc` internally) but applies the merged options to every request.
111    ///
112    /// ```ignore
113    /// use openai_oxide::RequestOptions;
114    ///
115    /// let custom = client.with_options(
116    ///     RequestOptions::new().header("X-Custom", "value")
117    /// );
118    /// // All requests through `custom` will include the X-Custom header.
119    /// let resp = custom.chat().completions().create(req).await?;
120    /// ```
121    #[must_use]
122    pub fn with_options(&self, options: RequestOptions) -> Self {
123        Self {
124            http: self.http.clone(),
125            config: self.config.clone(),
126            options: self.options.merge(&options),
127        }
128    }
129
130    /// Create a client using the `OPENAI_API_KEY` environment variable.
131    pub fn from_env() -> Result<Self, OpenAIError> {
132        Ok(Self::with_config(ClientConfig::from_env()?))
133    }
134
135    /// Create a client configured for Azure OpenAI.
136    ///
137    /// # Examples
138    ///
139    /// ```ignore
140    /// use openai_oxide::{OpenAI, AzureConfig};
141    ///
142    /// let client = OpenAI::azure(
143    ///     AzureConfig::new()
144    ///         .azure_endpoint("https://my-resource.openai.azure.com")
145    ///         .azure_deployment("gpt-4")
146    ///         .api_key("my-azure-key")
147    /// )?;
148    /// ```
149    pub fn azure(config: AzureConfig) -> Result<Self, OpenAIError> {
150        config.build()
151    }
152
    /// Access the Batches resource.
    #[cfg(feature = "batches")]
    pub fn batches(&self) -> Batches<'_> {
        Batches::new(self)
    }

    /// Access the Uploads resource.
    #[cfg(feature = "uploads")]
    pub fn uploads(&self) -> Uploads<'_> {
        Uploads::new(self)
    }

    /// Access the Beta resources (Assistants, Threads, Runs, Vector Stores).
    #[cfg(feature = "beta")]
    pub fn beta(&self) -> Beta<'_> {
        Beta { client: self }
    }

    /// Access the Audio resource.
    #[cfg(feature = "audio")]
    pub fn audio(&self) -> Audio<'_> {
        Audio::new(self)
    }

    /// Access the Chat resource.
    #[cfg(feature = "chat")]
    pub fn chat(&self) -> Chat<'_> {
        Chat::new(self)
    }

    /// Access the Models resource.
    #[cfg(feature = "models")]
    pub fn models(&self) -> Models<'_> {
        Models::new(self)
    }

    /// Access the Fine-tuning resource.
    #[cfg(feature = "fine-tuning")]
    pub fn fine_tuning(&self) -> FineTuning<'_> {
        FineTuning::new(self)
    }

    /// Access the Files resource.
    #[cfg(feature = "files")]
    pub fn files(&self) -> Files<'_> {
        Files::new(self)
    }

    /// Access the Images resource.
    #[cfg(feature = "images")]
    pub fn images(&self) -> Images<'_> {
        Images::new(self)
    }

    /// Access the Moderations resource.
    #[cfg(feature = "moderations")]
    pub fn moderations(&self) -> Moderations<'_> {
        Moderations::new(self)
    }

    /// Access the Responses resource.
    #[cfg(feature = "responses")]
    pub fn responses(&self) -> Responses<'_> {
        Responses::new(self)
    }

    /// Access the Embeddings resource.
    #[cfg(feature = "embeddings")]
    pub fn embeddings(&self) -> Embeddings<'_> {
        Embeddings::new(self)
    }

    /// Access conversation endpoints (multi-turn server-side state).
    /// Always available (not feature-gated).
    pub fn conversations(&self) -> crate::resources::conversations::Conversations<'_> {
        crate::resources::conversations::Conversations::new(self)
    }

    /// Access video generation endpoints (Sora).
    /// Always available (not feature-gated).
    pub fn videos(&self) -> crate::resources::videos::Videos<'_> {
        crate::resources::videos::Videos::new(self)
    }
234
    /// Create a persistent WebSocket session to the Responses API.
    ///
    /// Opens a connection to `wss://api.openai.com/v1/responses` and returns
    /// a [`WsSession`](crate::websocket::WsSession) for low-latency,
    /// multi-turn interactions.
    ///
    /// Requires the `websocket` feature.
    ///
    /// ```ignore
    /// let mut session = client.ws_session().await?;
    /// let response = session.send(request).await?;
    /// session.close().await?;
    /// ```
    ///
    /// # Errors
    /// Propagates connection failures from [`WsSession::connect`](crate::websocket::WsSession).
    #[cfg(feature = "websocket")]
    pub async fn ws_session(&self) -> Result<crate::websocket::WsSession, OpenAIError> {
        crate::websocket::WsSession::connect(self.config.as_ref()).await
    }
252
    /// Build a request with auth headers and client-level options applied.
    ///
    /// `path` is appended verbatim to `config.base_url()` (callers include the
    /// leading slash); `config.build_request` attaches authentication and any
    /// config-specific headers before client-level options are layered on top.
    pub(crate) fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
        let url = format!("{}{}", self.config.base_url(), path);
        let req = self.http.request(method, &url);
        let mut req = self.config.build_request(req);

        // Apply client-level options
        if let Some(ref headers) = self.options.headers {
            for (key, value) in headers.iter() {
                req = req.header(key.clone(), value.clone());
            }
        }
        // Query params and per-request timeouts are compiled out on wasm32,
        // where these reqwest builder methods are not available.
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(ref query) = self.options.query {
            req = req.query(query);
        }
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(timeout) = self.options.timeout {
            req = req.timeout(timeout);
        }

        req
    }
276
277    /// Send a GET request and deserialize the response.
278    #[allow(dead_code)]
279    pub(crate) async fn get<T: serde::de::DeserializeOwned>(
280        &self,
281        path: &str,
282    ) -> Result<T, OpenAIError> {
283        self.send_with_retry(reqwest::Method::GET, path, None::<&()>)
284            .await
285    }
286
287    /// Send a GET request with query parameters and deserialize the response.
288    #[allow(dead_code)]
289    #[cfg(not(target_arch = "wasm32"))]
290    pub(crate) async fn get_with_query<T: serde::de::DeserializeOwned>(
291        &self,
292        path: &str,
293        query: &[(String, String)],
294    ) -> Result<T, OpenAIError> {
295        let mut req = self.request(reqwest::Method::GET, path);
296        if !query.is_empty() {
297            req = req.query(query);
298        }
299        let response = req.send().await?;
300        Self::handle_response(response).await
301    }
302
303    /// Send a POST request with a JSON body and deserialize the response.
304    pub(crate) async fn post<B: serde::Serialize, T: serde::de::DeserializeOwned>(
305        &self,
306        path: &str,
307        body: &B,
308    ) -> Result<T, OpenAIError> {
309        self.send_with_retry(reqwest::Method::POST, path, Some(body))
310            .await
311    }
312
313    /// Send a POST request with a JSON body and return the raw JSON value.
314    ///
315    /// This is the backbone for BYOT (bring your own types) `create_raw()` methods:
316    /// accepts any `Serialize` request and returns `serde_json::Value` instead of a
317    /// typed response, letting advanced users work with custom or untyped payloads.
318    pub(crate) async fn post_json<B: serde::Serialize>(
319        &self,
320        path: &str,
321        body: &B,
322    ) -> Result<serde_json::Value, OpenAIError> {
323        self.post(path, body).await
324    }
325
326    /// Send a POST request with no body and deserialize the response.
327    pub(crate) async fn post_empty<T: serde::de::DeserializeOwned>(
328        &self,
329        path: &str,
330    ) -> Result<T, OpenAIError> {
331        self.send_with_retry(reqwest::Method::POST, path, None::<&()>)
332            .await
333    }
334
335    /// Send a POST request with a multipart form body and deserialize the response.
336    #[cfg(not(target_arch = "wasm32"))]
337    pub(crate) async fn post_multipart<T: serde::de::DeserializeOwned>(
338        &self,
339        path: &str,
340        form: reqwest::multipart::Form,
341    ) -> Result<T, OpenAIError> {
342        let response = self
343            .request(reqwest::Method::POST, path)
344            .multipart(form)
345            .send()
346            .await?;
347        Self::handle_response(response).await
348    }
349
350    /// Send a GET request and return raw bytes.
351    pub(crate) async fn get_raw(&self, path: &str) -> Result<bytes::Bytes, OpenAIError> {
352        let response = self.request(reqwest::Method::GET, path).send().await?;
353
354        let status = response.status();
355        if status.is_success() {
356            Ok(response.bytes().await?)
357        } else {
358            Err(Self::extract_error(status.as_u16(), response).await)
359        }
360    }
361
362    /// Send a POST request with JSON body and return raw bytes (for binary responses like audio).
363    pub(crate) async fn post_raw<B: serde::Serialize>(
364        &self,
365        path: &str,
366        body: &B,
367    ) -> Result<bytes::Bytes, OpenAIError> {
368        let mut req = self.request(reqwest::Method::POST, path);
369        if self.options.extra_body.is_some() {
370            req = req.json(&self.merge_body_json(body)?);
371        } else {
372            req = req.json(body);
373        }
374        let response = req.send().await?;
375
376        let status = response.status();
377        if status.is_success() {
378            Ok(response.bytes().await?)
379        } else {
380            Err(Self::extract_error(status.as_u16(), response).await)
381        }
382    }
383
384    /// Send a DELETE request and deserialize the response.
385    #[allow(dead_code)]
386    pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
387        &self,
388        path: &str,
389    ) -> Result<T, OpenAIError> {
390        self.send_with_retry(reqwest::Method::DELETE, path, None::<&()>)
391            .await
392    }
393
394    /// Serialize body to JSON and merge extra_body fields if set.
395    fn merge_body_json<B: serde::Serialize>(
396        &self,
397        body: &B,
398    ) -> Result<serde_json::Value, OpenAIError> {
399        let mut value = serde_json::to_value(body)?;
400        if let Some(ref extra) = self.options.extra_body
401            && let serde_json::Value::Object(map) = &mut value
402            && let serde_json::Value::Object(extra_map) = extra.clone()
403        {
404            for (k, v) in extra_map {
405                map.insert(k, v);
406            }
407        }
408        Ok(value)
409    }
410
411    /// Pre-serialize request body, merging extra_body if set.
412    fn prepare_body<B: serde::Serialize>(
413        &self,
414        body: Option<&B>,
415    ) -> Result<Option<serde_json::Value>, OpenAIError> {
416        match body {
417            Some(b) if self.options.extra_body.is_some() => Ok(Some(self.merge_body_json(b)?)),
418            Some(b) => Ok(Some(serde_json::to_value(b)?)),
419            None => Ok(None),
420        }
421    }
422
423    /// WASM: retry with cross-platform sleep.
424    #[cfg(target_arch = "wasm32")]
425    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
426        &self,
427        method: reqwest::Method,
428        path: &str,
429        body: Option<&B>,
430    ) -> Result<T, OpenAIError> {
431        let body_value = self.prepare_body(body)?;
432
433        for attempt in 0..=self.config.max_retries {
434            let mut req = self.request(method.clone(), path);
435            if let Some(ref val) = body_value {
436                req = req.json(val);
437            }
438
439            let response = match req.send().await {
440                Ok(resp) => resp,
441                Err(e) if attempt == self.config.max_retries => {
442                    return Err(OpenAIError::RequestError(e));
443                }
444                Err(_) => {
445                    crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
446                    continue;
447                }
448            };
449
450            let status = response.status().as_u16();
451            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == self.config.max_retries {
452                return Self::handle_response(response).await;
453            }
454
455            crate::runtime::sleep(crate::runtime::backoff_ms(attempt)).await;
456        }
457
458        Err(OpenAIError::InvalidArgument("retry exhausted".into()))
459    }
460
    /// Send a request with retry logic for transient errors.
    ///
    /// Fast path: the first attempt is made outside the retry loop so the
    /// common success case pays no loop bookkeeping. Only transient outcomes
    /// (connection errors, or a status in `RETRYABLE_STATUS_CODES`) enter
    /// [`retry_loop`](Self::retry_loop), and only when `max_retries() > 0`.
    #[cfg(not(target_arch = "wasm32"))]
    async fn send_with_retry<B: serde::Serialize, T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<&B>,
    ) -> Result<T, OpenAIError> {
        // Serialize once up front; every attempt reuses the same JSON value.
        let body_value = self.prepare_body(body)?;

        // Fast path: first attempt, outside the loop
        let mut req = self.request(method.clone(), path);
        if let Some(ref val) = body_value {
            req = req.json(val);
        }

        let response = match req.send().await {
            Ok(resp) => resp,
            // Retries disabled: fail immediately with the transport error.
            Err(e) if self.config.max_retries() == 0 => return Err(OpenAIError::RequestError(e)),
            Err(e) => {
                // Enter retry path
                return self.retry_loop(method, path, &body_value, e, 1).await;
            }
        };

        let status = response.status().as_u16();
        if !RETRYABLE_STATUS_CODES.contains(&status) {
            return Self::handle_response(response).await;
        }

        if self.config.max_retries() == 0 {
            return Self::handle_response(response).await;
        }

        // Retryable status on first attempt — enter retry loop.
        // Honor the server's Retry-After header (seconds) when present.
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<f64>().ok());
        // Consume the failed response into an error now, so the loop has a
        // meaningful "last error" if all remaining attempts also fail.
        let last_error = Self::extract_error(status, response).await;
        tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
        self.retry_loop(method, path, &body_value, last_error, 1)
            .await
    }
509
    /// Retry loop — only called when the first attempt fails with a transient error.
    ///
    /// `start_attempt` is the number of attempts already made. The loop runs
    /// until `max_retries` is reached, sleeping with exponential backoff
    /// (bounded below by any `Retry-After` header) between attempts, and
    /// returns the last observed error if every attempt fails.
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop<T: serde::de::DeserializeOwned>(
        &self,
        method: reqwest::Method,
        path: &str,
        body_value: &Option<serde_json::Value>,
        initial_error: impl Into<OpenAIError>,
        start_attempt: u32,
    ) -> Result<T, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error: OpenAIError = initial_error.into();

        for attempt in start_attempt..=max_retries {
            let mut req = self.request(method.clone(), path);
            if let Some(val) = body_value {
                req = req.json(val);
            }

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    // Transport failure: back off and retry if attempts remain.
                    last_error = OpenAIError::RequestError(e);
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            // Success, non-retryable status, or final attempt: hand off.
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Self::handle_response(response).await;
            }

            // Honor the server's Retry-After header (seconds) when present.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
557
    /// Send a request with retry, returning the raw [`reqwest::Response`].
    ///
    /// Used by streaming and multipart endpoints that need retry but handle
    /// the response body themselves. Retry happens BEFORE consuming the body.
    /// Requests whose builder cannot be cloned (e.g. streaming bodies) are
    /// sent exactly once with no retry.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) async fn send_raw_with_retry(
        &self,
        builder: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, OpenAIError> {
        // Fast path: first attempt. Keep the original `builder` unconsumed so
        // the retry loop can clone it again later.
        let response = match builder.try_clone() {
            Some(cloned) => match cloned.send().await {
                Ok(resp) => resp,
                Err(e) if self.config.max_retries() == 0 => {
                    return Err(OpenAIError::RequestError(e));
                }
                Err(e) => {
                    return self
                        .retry_loop_raw(builder, OpenAIError::RequestError(e), 1)
                        .await;
                }
            },
            None => {
                // Cannot clone (e.g. streaming body) — no retry possible
                return Ok(builder.send().await?);
            }
        };

        let status = response.status().as_u16();
        if !RETRYABLE_STATUS_CODES.contains(&status) {
            return Ok(response);
        }
        if self.config.max_retries() == 0 {
            return Ok(response);
        }

        // Retryable status on first attempt — honor Retry-After, back off,
        // and enter the retry loop with the consumed response as last error.
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<f64>().ok());
        let last_error = Self::extract_error(status, response).await;
        tokio::time::sleep(Self::backoff_delay(0, retry_after)).await;
        self.retry_loop_raw(builder, last_error, 1).await
    }
603
    /// Retry loop for raw responses.
    ///
    /// Mirrors [`retry_loop`](Self::retry_loop) but returns the raw response
    /// (even for non-2xx, non-retryable statuses — the caller inspects it).
    /// Stops early if the builder becomes uncloneable.
    #[cfg(not(target_arch = "wasm32"))]
    async fn retry_loop_raw(
        &self,
        builder: reqwest::RequestBuilder,
        initial_error: OpenAIError,
        start_attempt: u32,
    ) -> Result<reqwest::Response, OpenAIError> {
        let max_retries = self.config.max_retries();
        let mut last_error = initial_error;

        for attempt in start_attempt..=max_retries {
            // Each attempt needs a fresh clone; bail with the last error if
            // cloning is no longer possible.
            let req = match builder.try_clone() {
                Some(cloned) => cloned,
                None => return Err(last_error),
            };

            let response = match req.send().await {
                Ok(resp) => resp,
                Err(e) => {
                    // Transport failure: back off and retry if attempts remain.
                    last_error = OpenAIError::RequestError(e);
                    if attempt < max_retries {
                        tokio::time::sleep(Self::backoff_delay(attempt, None)).await;
                        continue;
                    }
                    break;
                }
            };

            let status = response.status().as_u16();
            // Non-retryable status or final attempt: return the raw response.
            if !RETRYABLE_STATUS_CODES.contains(&status) || attempt == max_retries {
                return Ok(response);
            }

            // Honor the server's Retry-After header (seconds) when present.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.parse::<f64>().ok());
            last_error = Self::extract_error(status, response).await;
            tokio::time::sleep(Self::backoff_delay(attempt, retry_after)).await;
        }

        Err(last_error)
    }
649
650    /// Check a streaming response status and return error if non-2xx.
651    pub(crate) async fn check_stream_response(
652        response: reqwest::Response,
653    ) -> Result<reqwest::Response, OpenAIError> {
654        if response.status().is_success() {
655            Ok(response)
656        } else {
657            Err(Self::extract_error(response.status().as_u16(), response).await)
658        }
659    }
660
661    /// Calculate backoff delay: max(retry_after, 0.5 * 2^attempt) seconds.
662    #[cfg(not(target_arch = "wasm32"))]
663    fn backoff_delay(attempt: u32, retry_after_secs: Option<f64>) -> Duration {
664        let base = crate::runtime::backoff_ms(attempt);
665        match retry_after_secs {
666            Some(ra) => Duration::from_secs_f64(ra.max(base.as_secs_f64())),
667            None => base,
668        }
669    }
670
    /// Handle API response: check status, parse errors or deserialize body.
    ///
    /// Uses `bytes()` + `from_slice()` instead of `text()` + `from_str()`
    /// to avoid an intermediate String allocation.
    ///
    /// With `simd` feature: uses simd-json for SIMD-accelerated parsing.
    ///
    /// Deserialization failures are logged with a truncated body preview
    /// (first 500 bytes, lossy UTF-8) before the error is returned.
    pub(crate) async fn handle_response<T: serde::de::DeserializeOwned>(
        response: reqwest::Response,
    ) -> Result<T, OpenAIError> {
        let status = response.status();
        if status.is_success() {
            let body = response.bytes().await?;
            let result = Self::deserialize_body::<T>(&body);
            match result {
                Ok(value) => Ok(value),
                Err(e) => {
                    tracing::error!(
                        error = %e,
                        body_len = body.len(),
                        body_preview = %String::from_utf8_lossy(&body[..body.len().min(500)]),
                        "failed to deserialize API response"
                    );
                    Err(e)
                }
            }
        } else {
            Err(Self::extract_error(status.as_u16(), response).await)
        }
    }
700
    /// Deserialize JSON body. Uses simd-json when the `simd` feature is enabled.
    ///
    /// simd-json parses in place, so the body is copied into a mutable buffer first.
    /// NOTE(review): parse failures here surface as `OpenAIError::StreamError`,
    /// while the serde_json path converts via `OpenAIError::from` — confirm no
    /// caller distinguishes these variants for deserialization failures.
    #[cfg(feature = "simd")]
    fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
        let mut buf = body.to_vec();
        simd_json::from_slice::<T>(&mut buf)
            .map_err(|e| OpenAIError::StreamError(format!("simd-json: {e}")))
    }
708
709    /// Deserialize JSON body (standard serde_json).
710    #[cfg(not(feature = "simd"))]
711    fn deserialize_body<T: serde::de::DeserializeOwned>(body: &[u8]) -> Result<T, OpenAIError> {
712        serde_json::from_slice::<T>(body).map_err(OpenAIError::from)
713    }
714
715    /// Extract the `x-request-id` header from a response.
716    pub(crate) fn extract_request_id(response: &reqwest::Response) -> Option<String> {
717        response
718            .headers()
719            .get("x-request-id")
720            .and_then(|v| v.to_str().ok())
721            .map(String::from)
722    }
723
724    /// Extract an OpenAIError from a failed response.
725    pub(crate) async fn extract_error(status: u16, response: reqwest::Response) -> OpenAIError {
726        let request_id = Self::extract_request_id(&response);
727        let body = response.text().await.unwrap_or_default();
728        if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
729            OpenAIError::ApiError {
730                status,
731                message: error_resp.error.message,
732                type_: error_resp.error.type_,
733                code: error_resp.error.code,
734                request_id,
735            }
736        } else {
737            OpenAIError::ApiError {
738                status,
739                message: body,
740                type_: None,
741                code: None,
742                request_id,
743            }
744        }
745    }
746}
747
/// Access beta endpoints (Assistants v2, Threads, Runs, Vector Stores).
///
/// Created via [`OpenAI::beta`]; borrows the client for the duration of use.
#[cfg(feature = "beta")]
pub struct Beta<'a> {
    // Borrowed parent client; all beta sub-resources are built from it.
    client: &'a OpenAI,
}
753
#[cfg(feature = "beta")]
impl<'a> Beta<'a> {
    /// Access the Assistants resource.
    pub fn assistants(&self) -> Assistants<'_> {
        Assistants::new(self.client)
    }

    /// Access the Threads resource.
    pub fn threads(&self) -> Threads<'_> {
        Threads::new(self.client)
    }

    /// Access runs scoped to a specific thread.
    pub fn runs(&self, thread_id: &str) -> Runs<'_> {
        Runs::new(self.client, thread_id.to_string())
    }

    /// Access the Vector Stores resource.
    pub fn vector_stores(&self) -> VectorStores<'_> {
        VectorStores::new(self.client)
    }

    /// Access the Realtime resource.
    pub fn realtime(&self) -> Realtime<'_> {
        Realtime::new(self.client)
    }
}
781
782#[cfg(test)]
783mod tests {
784    use super::*;
785
    // Constructor wires the key through and uses the default base URL.
    #[test]
    fn test_new_client() {
        let client = OpenAI::new("sk-test-key");
        assert_eq!(client.config.api_key(), "sk-test-key");
        assert_eq!(client.config.base_url(), "https://api.openai.com/v1");
    }

    // Builder-style config values survive into the constructed client.
    #[test]
    fn test_with_config() {
        let config = ClientConfig::new("sk-test")
            .base_url("https://custom.api.com")
            .organization("org-123")
            .timeout_secs(30);
        let client = OpenAI::with_config(config);
        assert_eq!(client.config.base_url(), "https://custom.api.com");
        assert_eq!(client.config.organization(), Some("org-123"));
        assert_eq!(client.config.timeout_secs(), 30);
    }

    // Exponential backoff schedule, Retry-After interaction, and the cap.
    #[test]
    fn test_backoff_delay() {
        // Attempt 0: 0.5s
        let d = OpenAI::backoff_delay(0, None);
        assert_eq!(d, Duration::from_millis(500));

        // Attempt 1: 1.0s
        let d = OpenAI::backoff_delay(1, None);
        assert_eq!(d, Duration::from_secs(1));

        // Attempt 2: 2.0s
        let d = OpenAI::backoff_delay(2, None);
        assert_eq!(d, Duration::from_secs(2));

        // Retry-After takes precedence when larger
        let d = OpenAI::backoff_delay(0, Some(5.0));
        assert_eq!(d, Duration::from_secs(5));

        // Exponential wins when larger than Retry-After
        let d = OpenAI::backoff_delay(3, Some(0.1));
        assert_eq!(d, Duration::from_secs(4));

        // Capped at 60s
        let d = OpenAI::backoff_delay(10, None);
        assert_eq!(d, Duration::from_secs(60));
    }
831
    // GET against a mock server deserializes a typed response.
    #[tokio::test]
    async fn test_get_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/models/gpt-4")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{"id":"gpt-4","object":"model","created":1687882411,"owned_by":"openai"}"#,
            )
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct Model {
            id: String,
            object: String,
        }

        let model: Model = client.get("/models/gpt-4").await.unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.object, "model");
        mock.assert_async().await;
    }

    // POST sends the bearer token and JSON content type, and round-trips JSON.
    #[tokio::test]
    async fn test_post_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", "Bearer sk-test")
            .match_header("content-type", "application/json")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"chatcmpl-123","object":"chat.completion"}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Serialize)]
        struct Req {
            model: String,
        }
        #[derive(serde::Deserialize)]
        struct Resp {
            id: String,
        }

        let resp: Resp = client
            .post(
                "/chat/completions",
                &Req {
                    model: "gpt-4".into(),
                },
            )
            .await
            .unwrap();
        assert_eq!(resp.id, "chatcmpl-123");
        mock.assert_async().await;
    }

    // DELETE deserializes the deletion confirmation payload.
    #[tokio::test]
    async fn test_delete_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("DELETE", "/models/ft-abc")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"id":"ft-abc","deleted":true}"#)
            .create_async()
            .await;

        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));

        #[derive(serde::Deserialize)]
        struct DeleteResp {
            id: String,
            deleted: bool,
        }

        let resp: DeleteResp = client.delete("/models/ft-abc").await.unwrap();
        assert_eq!(resp.id, "ft-abc");
        assert!(resp.deleted);
        mock.assert_async().await;
    }
920
921    #[tokio::test]
922    async fn test_api_error_response() {
923        let mut server = mockito::Server::new_async().await;
924        let mock = server
925            .mock("GET", "/models/nonexistent")
926            .with_status(404)
927            .with_header("content-type", "application/json")
928            .with_body(
929                r#"{"error":{"message":"The model 'nonexistent' does not exist","type":"invalid_request_error","param":null,"code":"model_not_found"}}"#,
930            )
931            .create_async()
932            .await;
933
934        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
935
936        #[derive(Debug, serde::Deserialize)]
937        struct Model {
938            _id: String,
939        }
940
941        let err = client
942            .get::<Model>("/models/nonexistent")
943            .await
944            .unwrap_err();
945        match err {
946            OpenAIError::ApiError {
947                status,
948                message,
949                type_,
950                code,
951                ..
952            } => {
953                assert_eq!(status, 404);
954                assert!(message.contains("does not exist"));
955                assert_eq!(type_.as_deref(), Some("invalid_request_error"));
956                assert_eq!(code.as_deref(), Some("model_not_found"));
957            }
958            other => panic!("expected ApiError, got: {other:?}"),
959        }
960        mock.assert_async().await;
961    }
962
963    #[tokio::test]
964    async fn test_auth_headers() {
965        let mut server = mockito::Server::new_async().await;
966        let mock = server
967            .mock("GET", "/test")
968            .match_header("authorization", "Bearer sk-key")
969            .match_header("OpenAI-Organization", "org-abc")
970            .match_header("OpenAI-Project", "proj-xyz")
971            .with_status(200)
972            .with_body(r#"{"ok":true}"#)
973            .create_async()
974            .await;
975
976        let client = OpenAI::with_config(
977            ClientConfig::new("sk-key")
978                .base_url(server.url())
979                .organization("org-abc")
980                .project("proj-xyz"),
981        );
982
983        #[derive(serde::Deserialize)]
984        struct Resp {
985            ok: bool,
986        }
987
988        let resp: Resp = client.get("/test").await.unwrap();
989        assert!(resp.ok);
990        mock.assert_async().await;
991    }
992
    /// A retryable 429 followed by a 200 should succeed transparently from
    /// the caller's point of view.
    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let mut server = mockito::Server::new_async().await;

        // First request returns 429, second returns 200
        // NOTE(review): `_mock_429` is never asserted, so this test does not
        // actually prove the 429 mock was hit before the 200 one — that
        // depends on mockito's matching precedence when two mocks share a
        // path. Consider asserting its hit count explicitly; verify the
        // ordering semantics against the mockito docs first.
        let _mock_429 = server
            .mock("GET", "/test")
            .with_status(429)
            .with_header("retry-after", "0")
            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
            .create_async()
            .await;

        let mock_200 = server
            .mock("GET", "/test")
            .with_status(200)
            .with_body(r#"{"ok":true}"#)
            .create_async()
            .await;

        // Allow up to 2 retries so a single 429 can be recovered from.
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(2),
        );

        #[derive(serde::Deserialize)]
        struct Resp {
            ok: bool,
        }

        // The overall call must succeed and deserialize the 200 body.
        let resp: Resp = client.get("/test").await.unwrap();
        assert!(resp.ok);
        mock_200.assert_async().await;
    }
1028
1029    #[tokio::test]
1030    async fn test_retry_exhausted_returns_last_error() {
1031        let mut server = mockito::Server::new_async().await;
1032
1033        // All requests return 500
1034        let _mock = server
1035            .mock("GET", "/test")
1036            .with_status(500)
1037            .with_body(r#"{"error":{"message":"Internal server error","type":"server_error","param":null,"code":null}}"#)
1038            .expect_at_least(2)
1039            .create_async()
1040            .await;
1041
1042        let client = OpenAI::with_config(
1043            ClientConfig::new("sk-test")
1044                .base_url(server.url())
1045                .max_retries(1),
1046        );
1047
1048        #[derive(Debug, serde::Deserialize)]
1049        struct Resp {
1050            _ok: bool,
1051        }
1052
1053        let err = client.get::<Resp>("/test").await.unwrap_err();
1054        match err {
1055            OpenAIError::ApiError { status, .. } => assert_eq!(status, 500),
1056            other => panic!("expected ApiError, got: {other:?}"),
1057        }
1058    }
1059
1060    #[tokio::test]
1061    async fn test_no_retry_on_400() {
1062        let mut server = mockito::Server::new_async().await;
1063
1064        // 400 should not be retried
1065        let mock = server
1066            .mock("GET", "/test")
1067            .with_status(400)
1068            .with_body(r#"{"error":{"message":"Bad request","type":"invalid_request_error","param":null,"code":null}}"#)
1069            .expect(1)
1070            .create_async()
1071            .await;
1072
1073        let client = OpenAI::with_config(
1074            ClientConfig::new("sk-test")
1075                .base_url(server.url())
1076                .max_retries(2),
1077        );
1078
1079        #[derive(Debug, serde::Deserialize)]
1080        struct Resp {
1081            _ok: bool,
1082        }
1083
1084        let err = client.get::<Resp>("/test").await.unwrap_err();
1085        match err {
1086            OpenAIError::ApiError { status, .. } => assert_eq!(status, 400),
1087            other => panic!("expected ApiError, got: {other:?}"),
1088        }
1089        mock.assert_async().await;
1090    }
1091
1092    #[tokio::test]
1093    async fn test_zero_retries_no_retry() {
1094        let mut server = mockito::Server::new_async().await;
1095
1096        let mock = server
1097            .mock("GET", "/test")
1098            .with_status(429)
1099            .with_body(r#"{"error":{"message":"Rate limited","type":"rate_limit_error","param":null,"code":null}}"#)
1100            .expect(1)
1101            .create_async()
1102            .await;
1103
1104        let client = OpenAI::with_config(
1105            ClientConfig::new("sk-test")
1106                .base_url(server.url())
1107                .max_retries(0),
1108        );
1109
1110        #[derive(Debug, serde::Deserialize)]
1111        struct Resp {
1112            _ok: bool,
1113        }
1114
1115        let err = client.get::<Resp>("/test").await.unwrap_err();
1116        match err {
1117            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
1118            other => panic!("expected ApiError, got: {other:?}"),
1119        }
1120        mock.assert_async().await;
1121    }
1122
1123    // --- with_options() tests ---
1124
1125    #[tokio::test]
1126    async fn test_with_options_sends_extra_headers() {
1127        let mut server = mockito::Server::new_async().await;
1128        let mock = server
1129            .mock("GET", "/test")
1130            .match_header("X-Custom", "test-value")
1131            .with_status(200)
1132            .with_body(r#"{"ok":true}"#)
1133            .create_async()
1134            .await;
1135
1136        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
1137        let custom = client.with_options(RequestOptions::new().header("X-Custom", "test-value"));
1138
1139        #[derive(serde::Deserialize)]
1140        struct Resp {
1141            ok: bool,
1142        }
1143
1144        let resp: Resp = custom.get("/test").await.unwrap();
1145        assert!(resp.ok);
1146        mock.assert_async().await;
1147    }
1148
1149    #[tokio::test]
1150    async fn test_with_options_sends_query_params() {
1151        let mut server = mockito::Server::new_async().await;
1152        let mock = server
1153            .mock("GET", "/test")
1154            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
1155                "foo".into(),
1156                "bar".into(),
1157            )]))
1158            .with_status(200)
1159            .with_body(r#"{"ok":true}"#)
1160            .create_async()
1161            .await;
1162
1163        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
1164        let custom = client.with_options(RequestOptions::new().query_param("foo", "bar"));
1165
1166        #[derive(serde::Deserialize)]
1167        struct Resp {
1168            ok: bool,
1169        }
1170
1171        let resp: Resp = custom.get("/test").await.unwrap();
1172        assert!(resp.ok);
1173        mock.assert_async().await;
1174    }
1175
1176    #[tokio::test]
1177    async fn test_extra_body_merge() {
1178        let mut server = mockito::Server::new_async().await;
1179        let mock = server
1180            .mock("POST", "/test")
1181            .match_body(mockito::Matcher::Json(serde_json::json!({
1182                "model": "gpt-4",
1183                "extra_field": "injected"
1184            })))
1185            .with_status(200)
1186            .with_body(r#"{"id":"ok"}"#)
1187            .create_async()
1188            .await;
1189
1190        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
1191        let custom = client.with_options(
1192            RequestOptions::new().extra_body(serde_json::json!({"extra_field": "injected"})),
1193        );
1194
1195        #[derive(serde::Serialize)]
1196        struct Req {
1197            model: String,
1198        }
1199        #[derive(serde::Deserialize)]
1200        struct Resp {
1201            id: String,
1202        }
1203
1204        let resp: Resp = custom
1205            .post(
1206                "/test",
1207                &Req {
1208                    model: "gpt-4".into(),
1209                },
1210            )
1211            .await
1212            .unwrap();
1213        assert_eq!(resp.id, "ok");
1214        mock.assert_async().await;
1215    }
1216
1217    #[tokio::test]
1218    async fn test_timeout_override() {
1219        let mut server = mockito::Server::new_async().await;
1220        // Mock with a 5s delay — our timeout is 100ms, so it should fail
1221        let _mock = server
1222            .mock("GET", "/test")
1223            .with_status(200)
1224            .with_body(r#"{"ok":true}"#)
1225            .with_chunked_body(|_w| -> std::io::Result<()> {
1226                std::thread::sleep(std::time::Duration::from_secs(5));
1227                Ok(())
1228            })
1229            .create_async()
1230            .await;
1231
1232        let client = OpenAI::with_config(
1233            ClientConfig::new("sk-test")
1234                .base_url(server.url())
1235                .max_retries(0),
1236        );
1237        let custom = client.with_options(RequestOptions::new().timeout(Duration::from_millis(100)));
1238
1239        #[derive(Debug, serde::Deserialize)]
1240        struct Resp {
1241            _ok: bool,
1242        }
1243
1244        let err = custom.get::<Resp>("/test").await.unwrap_err();
1245        assert!(
1246            matches!(err, OpenAIError::RequestError(_)),
1247            "expected timeout error, got: {err:?}"
1248        );
1249    }
1250
1251    #[tokio::test]
1252    async fn test_options_merge_precedence() {
1253        let mut server = mockito::Server::new_async().await;
1254        // with_options header should override the default
1255        let mock = server
1256            .mock("GET", "/test")
1257            .match_header("X-A", "2")
1258            .with_status(200)
1259            .with_body(r#"{"ok":true}"#)
1260            .create_async()
1261            .await;
1262
1263        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
1264        let base = client.with_options(RequestOptions::new().header("X-A", "1"));
1265        let custom = base.with_options(RequestOptions::new().header("X-A", "2"));
1266
1267        #[derive(serde::Deserialize)]
1268        struct Resp {
1269            ok: bool,
1270        }
1271
1272        let resp: Resp = custom.get("/test").await.unwrap();
1273        assert!(resp.ok);
1274        mock.assert_async().await;
1275    }
1276
1277    #[tokio::test]
1278    async fn test_default_headers_and_query_on_config() {
1279        let mut server = mockito::Server::new_async().await;
1280        let mock = server
1281            .mock("GET", "/test")
1282            .match_header("X-Default", "from-config")
1283            .match_query(mockito::Matcher::AllOf(vec![mockito::Matcher::UrlEncoded(
1284                "cfg_param".into(),
1285                "cfg_val".into(),
1286            )]))
1287            .with_status(200)
1288            .with_body(r#"{"ok":true}"#)
1289            .create_async()
1290            .await;
1291
1292        let mut default_headers = reqwest::header::HeaderMap::new();
1293        default_headers.insert("X-Default", "from-config".parse().unwrap());
1294
1295        let client = OpenAI::with_config(
1296            ClientConfig::new("sk-test")
1297                .base_url(server.url())
1298                .default_headers(default_headers)
1299                .default_query(vec![("cfg_param".into(), "cfg_val".into())]),
1300        );
1301
1302        #[derive(serde::Deserialize)]
1303        struct Resp {
1304            ok: bool,
1305        }
1306
1307        let resp: Resp = client.get("/test").await.unwrap();
1308        assert!(resp.ok);
1309        mock.assert_async().await;
1310    }
1311
1312    #[tokio::test]
1313    async fn test_chained_with_options_merges() {
1314        let mut server = mockito::Server::new_async().await;
1315        let mock = server
1316            .mock("GET", "/test")
1317            .match_header("X-A", "from-a")
1318            .match_header("X-B", "from-b")
1319            .with_status(200)
1320            .with_body(r#"{"ok":true}"#)
1321            .create_async()
1322            .await;
1323
1324        let client = OpenAI::with_config(ClientConfig::new("sk-test").base_url(server.url()));
1325        let chained = client
1326            .with_options(RequestOptions::new().header("X-A", "from-a"))
1327            .with_options(RequestOptions::new().header("X-B", "from-b"));
1328
1329        #[derive(serde::Deserialize)]
1330        struct Resp {
1331            ok: bool,
1332        }
1333
1334        let resp: Resp = chained.get("/test").await.unwrap();
1335        assert!(resp.ok);
1336        mock.assert_async().await;
1337    }
1338}